/* setsockopt() handler for AF_VSOCK stream sockets.
 *
 * Only options at level AF_VSOCK are handled; any other level returns
 * -ENOPROTOOPT.  The three buffer-size options copy in a u64 and pass it to
 * the corresponding transport hook.  SO_VM_SOCKETS_CONNECT_TIMEOUT copies in
 * a struct timeval and stores it, converted to jiffies, in
 * vsk->connect_timeout (a result of 0 is replaced by
 * VSOCK_DEFAULT_CONNECT_TIMEOUT).
 *
 * Returns 0 on success, -EINVAL if optlen is smaller than the option value,
 * -EFAULT if copying from user space fails, -ERANGE for an out-of-range
 * timeout, or -ENOPROTOOPT for an unknown level/option.
 */
static int vsock_stream_setsockopt(struct socket *sock,
				   int level,
				   int optname,
				   char __user *optval,
				   unsigned int optlen)
{
	int err;
	struct sock *sk;
	struct vsock_sock *vsk;
	u64 val;

	if (level != AF_VSOCK)
		return -ENOPROTOOPT;

/* Copy exactly sizeof(_v) bytes from optval into _v, bailing out to exit
 * (with the socket lock held, so it gets released) on short or faulting
 * input.
 */
#define COPY_IN(_v)                                              \
	do {                                                     \
		if (optlen < sizeof(_v)) {                       \
			err = -EINVAL;                           \
			goto exit;                               \
		}                                                \
		if (copy_from_user(&_v, optval, sizeof(_v)) != 0) { \
			err = -EFAULT;                           \
			goto exit;                               \
		}                                                \
	} while (0)

	err = 0;
	sk = sock->sk;
	vsk = vsock_sk(sk);

	lock_sock(sk);

	switch (optname) {
	case SO_VM_SOCKETS_BUFFER_SIZE:
		COPY_IN(val);
		transport->set_buffer_size(vsk, val);
		break;

	case SO_VM_SOCKETS_BUFFER_MAX_SIZE:
		COPY_IN(val);
		transport->set_max_buffer_size(vsk, val);
		break;

	case SO_VM_SOCKETS_BUFFER_MIN_SIZE:
		COPY_IN(val);
		transport->set_min_buffer_size(vsk, val);
		break;

	case SO_VM_SOCKETS_CONNECT_TIMEOUT: {
		struct timeval tv;

		COPY_IN(tv);
		/* tv_usec must lie in [0, USEC_PER_SEC).  A negative value
		 * passes the upper-bound test alone, but DIV_ROUND_UP() then
		 * yields a negative jiffy count, leaving connect_timeout
		 * below tv_sec * HZ (negative when tv_sec is 0).
		 */
		if (tv.tv_sec >= 0 && tv.tv_usec >= 0 &&
		    tv.tv_usec < USEC_PER_SEC &&
		    tv.tv_sec < (MAX_SCHEDULE_TIMEOUT / HZ - 1)) {
			vsk->connect_timeout = tv.tv_sec * HZ +
			    DIV_ROUND_UP(tv.tv_usec, (1000000 / HZ));
			/* A zero timeout would never wait; fall back to the
			 * default instead.
			 */
			if (vsk->connect_timeout == 0)
				vsk->connect_timeout =
				    VSOCK_DEFAULT_CONNECT_TIMEOUT;
		} else {
			err = -ERANGE;
		}
		break;
	}

	default:
		err = -ENOPROTOOPT;
		break;
	}

#undef COPY_IN

exit:
	release_sock(sk);
	return err;
}
static int vsock_accept(struct socket *sock, struct socket *newsock, int flags) { struct sock *listener; int err; struct sock *connected; struct vsock_sock *vconnected; long timeout; DEFINE_WAIT(wait); err = 0; listener = sock->sk; lock_sock(listener); if (sock->type != SOCK_STREAM) { err = -EOPNOTSUPP; goto out; } if (listener->sk_state != SS_LISTEN) { err = -EINVAL; goto out; } /* Wait for children sockets to appear; these are the new sockets * created upon connection establishment. */ timeout = sock_sndtimeo(listener, flags & O_NONBLOCK); prepare_to_wait(sk_sleep(listener), &wait, TASK_INTERRUPTIBLE); while ((connected = vsock_dequeue_accept(listener)) == NULL && listener->sk_err == 0) { release_sock(listener); timeout = schedule_timeout(timeout); lock_sock(listener); if (signal_pending(current)) { err = sock_intr_errno(timeout); goto out_wait; } else if (timeout == 0) { err = -EAGAIN; goto out_wait; } prepare_to_wait(sk_sleep(listener), &wait, TASK_INTERRUPTIBLE); } if (listener->sk_err) err = -listener->sk_err; if (connected) { listener->sk_ack_backlog--; lock_sock(connected); vconnected = vsock_sk(connected); /* If the listener socket has received an error, then we should * reject this socket and return. Note that we simply mark the * socket rejected, drop our reference, and let the cleanup * function handle the cleanup; the fact that we found it in * the listener's accept queue guarantees that the cleanup * function hasn't run yet. */ if (err) { vconnected->rejected = true; release_sock(connected); sock_put(connected); goto out_wait; } newsock->state = SS_CONNECTED; sock_graft(connected, newsock); release_sock(connected); sock_put(connected); } out_wait: finish_wait(sk_sleep(listener), &wait); out: release_sock(listener); return err; }
/* recvmsg() for datagram vsock sockets.
 *
 * All work is delegated to the transport's dgram_dequeue hook; its return
 * value is passed straight back to the caller.
 */
static int vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg,
			       size_t len, int flags)
{
	struct vsock_sock *vsk = vsock_sk(sock->sk);

	return transport->dgram_dequeue(vsk, msg, len, flags);
}
static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr, int addr_len, int flags) { int err; struct sock *sk; struct vsock_sock *vsk; struct sockaddr_vm *remote_addr; long timeout; DEFINE_WAIT(wait); err = 0; sk = sock->sk; vsk = vsock_sk(sk); lock_sock(sk); /* XXX AF_UNSPEC should make us disconnect like AF_INET. */ switch (sock->state) { case SS_CONNECTED: err = -EISCONN; goto out; case SS_DISCONNECTING: err = -EINVAL; goto out; case SS_CONNECTING: /* This continues on so we can move sock into the SS_CONNECTED * state once the connection has completed (at which point err * will be set to zero also). Otherwise, we will either wait * for the connection or return -EALREADY should this be a * non-blocking call. */ err = -EALREADY; break; default: if ((sk->sk_state == SS_LISTEN) || vsock_addr_cast(addr, addr_len, &remote_addr) != 0) { err = -EINVAL; goto out; } /* The hypervisor and well-known contexts do not have socket * endpoints. */ if (!transport->stream_allow(remote_addr->svm_cid, remote_addr->svm_port)) { err = -ENETUNREACH; goto out; } /* Set the remote address that we are connecting to. */ memcpy(&vsk->remote_addr, remote_addr, sizeof(vsk->remote_addr)); err = vsock_auto_bind(vsk); if (err) goto out; sk->sk_state = SS_CONNECTING; err = transport->connect(vsk); if (err < 0) goto out; /* Mark sock as connecting and set the error code to in * progress in case this is a non-blocking connect. */ sock->state = SS_CONNECTING; err = -EINPROGRESS; } /* The receive path will handle all communication until we are able to * enter the connected state. Here we wait for the connection to be * completed or a notification of an error. 
*/ timeout = vsk->connect_timeout; prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); while (sk->sk_state != SS_CONNECTED && sk->sk_err == 0) { if (flags & O_NONBLOCK) { /* If we're not going to block, we schedule a timeout * function to generate a timeout on the connection * attempt, in case the peer doesn't respond in a * timely manner. We hold on to the socket until the * timeout fires. */ sock_hold(sk); INIT_DELAYED_WORK(&vsk->dwork, vsock_connect_timeout); schedule_delayed_work(&vsk->dwork, timeout); /* Skip ahead to preserve error code set above. */ goto out_wait; } release_sock(sk); timeout = schedule_timeout(timeout); lock_sock(sk); if (signal_pending(current)) { err = sock_intr_errno(timeout); goto out_wait_error; } else if (timeout == 0) { err = -ETIMEDOUT; goto out_wait_error; } prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); } if (sk->sk_err) { err = -sk->sk_err; goto out_wait_error; } else err = 0; out_wait: finish_wait(sk_sleep(sk), &wait); out: release_sock(sk); return err; out_wait_error: sk->sk_state = SS_UNCONNECTED; sock->state = SS_UNCONNECTED; goto out_wait; }
int virtio_transport_shutdown(struct vsock_sock *vsk, int mode) { struct virtio_vsock_pkt_info info = { .op = VIRTIO_VSOCK_OP_SHUTDOWN, .type = VIRTIO_VSOCK_TYPE_STREAM, .flags = (mode & RCV_SHUTDOWN ? VIRTIO_VSOCK_SHUTDOWN_RCV : 0) | (mode & SEND_SHUTDOWN ? VIRTIO_VSOCK_SHUTDOWN_SEND : 0), }; return virtio_transport_send_pkt_info(vsk, &info); } EXPORT_SYMBOL_GPL(virtio_transport_shutdown); int virtio_transport_dgram_enqueue(struct vsock_sock *vsk, struct sockaddr_vm *remote_addr, struct msghdr *msg, size_t dgram_len) { return -EOPNOTSUPP; } EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue); ssize_t virtio_transport_stream_enqueue(struct vsock_sock *vsk, struct msghdr *msg, size_t len) { struct virtio_vsock_pkt_info info = { .op = VIRTIO_VSOCK_OP_RW, .type = VIRTIO_VSOCK_TYPE_STREAM, .msg = msg, .pkt_len = len, }; return virtio_transport_send_pkt_info(vsk, &info); } EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue); void virtio_transport_destruct(struct vsock_sock *vsk) { struct virtio_vsock_sock *vvs = vsk->trans; kfree(vvs); } EXPORT_SYMBOL_GPL(virtio_transport_destruct); static int virtio_transport_reset(struct vsock_sock *vsk, struct virtio_vsock_pkt *pkt) { struct virtio_vsock_pkt_info info = { .op = VIRTIO_VSOCK_OP_RST, .type = VIRTIO_VSOCK_TYPE_STREAM, .reply = !!pkt, }; /* Send RST only if the original pkt is not a RST pkt */ if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST) return 0; return virtio_transport_send_pkt_info(vsk, &info); } /* Normally packets are associated with a socket. There may be no socket if an * attempt was made to connect to a socket that does not exist. 
*/ static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt) { struct virtio_vsock_pkt_info info = { .op = VIRTIO_VSOCK_OP_RST, .type = le16_to_cpu(pkt->hdr.type), .reply = true, }; /* Send RST only if the original pkt is not a RST pkt */ if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST) return 0; pkt = virtio_transport_alloc_pkt(&info, 0, le64_to_cpu(pkt->hdr.dst_cid), le32_to_cpu(pkt->hdr.dst_port), le64_to_cpu(pkt->hdr.src_cid), le32_to_cpu(pkt->hdr.src_port)); if (!pkt) return -ENOMEM; return virtio_transport_get_ops()->send_pkt(pkt); } static void virtio_transport_wait_close(struct sock *sk, long timeout) { if (timeout) { DEFINE_WAIT_FUNC(wait, woken_wake_function); add_wait_queue(sk_sleep(sk), &wait); do { if (sk_wait_event(sk, &timeout, sock_flag(sk, SOCK_DONE), &wait)) break; } while (!signal_pending(current) && timeout); remove_wait_queue(sk_sleep(sk), &wait); } } static void virtio_transport_do_close(struct vsock_sock *vsk, bool cancel_timeout) { struct sock *sk = sk_vsock(vsk); sock_set_flag(sk, SOCK_DONE); vsk->peer_shutdown = SHUTDOWN_MASK; if (vsock_stream_has_data(vsk) <= 0) sk->sk_state = SS_DISCONNECTING; sk->sk_state_change(sk); if (vsk->close_work_scheduled && (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) { vsk->close_work_scheduled = false; vsock_remove_sock(vsk); /* Release refcnt obtained when we scheduled the timeout */ sock_put(sk); } } static void virtio_transport_close_timeout(struct work_struct *work) { struct vsock_sock *vsk = container_of(work, struct vsock_sock, close_work.work); struct sock *sk = sk_vsock(vsk); sock_hold(sk); lock_sock(sk); if (!sock_flag(sk, SOCK_DONE)) { (void)virtio_transport_reset(vsk, NULL); virtio_transport_do_close(vsk, false); } vsk->close_work_scheduled = false; release_sock(sk); sock_put(sk); } /* User context, vsk->sk is locked */ static bool virtio_transport_close(struct vsock_sock *vsk) { struct sock *sk = &vsk->sk; if (!(sk->sk_state == SS_CONNECTED || sk->sk_state == 
SS_DISCONNECTING)) return true; /* Already received SHUTDOWN from peer, reply with RST */ if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) { (void)virtio_transport_reset(vsk, NULL); return true; } if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK) (void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK); if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING)) virtio_transport_wait_close(sk, sk->sk_lingertime); if (sock_flag(sk, SOCK_DONE)) { return true; } sock_hold(sk); INIT_DELAYED_WORK(&vsk->close_work, virtio_transport_close_timeout); vsk->close_work_scheduled = true; schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT); return false; } void virtio_transport_release(struct vsock_sock *vsk) { struct sock *sk = &vsk->sk; bool remove_sock = true; lock_sock(sk); if (sk->sk_type == SOCK_STREAM) remove_sock = virtio_transport_close(vsk); release_sock(sk); if (remove_sock) vsock_remove_sock(vsk); } EXPORT_SYMBOL_GPL(virtio_transport_release); static int virtio_transport_recv_connecting(struct sock *sk, struct virtio_vsock_pkt *pkt) { struct vsock_sock *vsk = vsock_sk(sk); int err; int skerr; switch (le16_to_cpu(pkt->hdr.op)) { case VIRTIO_VSOCK_OP_RESPONSE: sk->sk_state = SS_CONNECTED; sk->sk_socket->state = SS_CONNECTED; vsock_insert_connected(vsk); sk->sk_state_change(sk); break; case VIRTIO_VSOCK_OP_INVALID: break; case VIRTIO_VSOCK_OP_RST: skerr = ECONNRESET; err = 0; goto destroy; default: skerr = EPROTO; err = -EINVAL; goto destroy; } return 0; destroy: virtio_transport_reset(vsk, pkt); sk->sk_state = SS_UNCONNECTED; sk->sk_err = skerr; sk->sk_error_report(sk); return err; } static int virtio_transport_recv_connected(struct sock *sk, struct virtio_vsock_pkt *pkt) { struct vsock_sock *vsk = vsock_sk(sk); struct virtio_vsock_sock *vvs = vsk->trans; int err = 0; switch (le16_to_cpu(pkt->hdr.op)) { case VIRTIO_VSOCK_OP_RW: pkt->len = le32_to_cpu(pkt->hdr.len); pkt->off = 0; spin_lock_bh(&vvs->rx_lock); virtio_transport_inc_rx_pkt(vvs, 
pkt); list_add_tail(&pkt->list, &vvs->rx_queue); spin_unlock_bh(&vvs->rx_lock); sk->sk_data_ready(sk); return err; case VIRTIO_VSOCK_OP_CREDIT_UPDATE: sk->sk_write_space(sk); break; case VIRTIO_VSOCK_OP_SHUTDOWN: if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV) vsk->peer_shutdown |= RCV_SHUTDOWN; if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND) vsk->peer_shutdown |= SEND_SHUTDOWN; if (vsk->peer_shutdown == SHUTDOWN_MASK && vsock_stream_has_data(vsk) <= 0) sk->sk_state = SS_DISCONNECTING; if (le32_to_cpu(pkt->hdr.flags)) sk->sk_state_change(sk); break; case VIRTIO_VSOCK_OP_RST: virtio_transport_do_close(vsk, true); break; default: err = -EINVAL; break; } virtio_transport_free_pkt(pkt); return err; } static void virtio_transport_recv_disconnecting(struct sock *sk, struct virtio_vsock_pkt *pkt) { struct vsock_sock *vsk = vsock_sk(sk); if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST) virtio_transport_do_close(vsk, true); } static int virtio_transport_send_response(struct vsock_sock *vsk, struct virtio_vsock_pkt *pkt) { struct virtio_vsock_pkt_info info = { .op = VIRTIO_VSOCK_OP_RESPONSE, .type = VIRTIO_VSOCK_TYPE_STREAM, .remote_cid = le64_to_cpu(pkt->hdr.src_cid), .remote_port = le32_to_cpu(pkt->hdr.src_port), .reply = true, }; return virtio_transport_send_pkt_info(vsk, &info); } /* Handle server socket */ static int virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt) { struct vsock_sock *vsk = vsock_sk(sk); struct vsock_sock *vchild; struct sock *child; if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) { virtio_transport_reset(vsk, pkt); return -EINVAL; } if (sk_acceptq_is_full(sk)) { virtio_transport_reset(vsk, pkt); return -ENOMEM; } child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL, sk->sk_type, 0); if (!child) { virtio_transport_reset(vsk, pkt); return -ENOMEM; } sk->sk_ack_backlog++; lock_sock_nested(child, SINGLE_DEPTH_NESTING); child->sk_state = SS_CONNECTED; vchild = vsock_sk(child); 
vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid), le32_to_cpu(pkt->hdr.dst_port)); vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid), le32_to_cpu(pkt->hdr.src_port)); vsock_insert_connected(vchild); vsock_enqueue_accept(sk, child); virtio_transport_send_response(vchild, pkt); release_sock(child); sk->sk_data_ready(sk); return 0; } static bool virtio_transport_space_update(struct sock *sk, struct virtio_vsock_pkt *pkt) { struct vsock_sock *vsk = vsock_sk(sk); struct virtio_vsock_sock *vvs = vsk->trans; bool space_available; /* buf_alloc and fwd_cnt is always included in the hdr */ spin_lock_bh(&vvs->tx_lock); vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc); vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt); space_available = virtio_transport_has_space(vsk); spin_unlock_bh(&vvs->tx_lock); return space_available; } /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex * lock. */ void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt) { struct sockaddr_vm src, dst; struct vsock_sock *vsk; struct sock *sk; bool space_available; vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid), le32_to_cpu(pkt->hdr.src_port)); vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid), le32_to_cpu(pkt->hdr.dst_port)); trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port, dst.svm_cid, dst.svm_port, le32_to_cpu(pkt->hdr.len), le16_to_cpu(pkt->hdr.type), le16_to_cpu(pkt->hdr.op), le32_to_cpu(pkt->hdr.flags), le32_to_cpu(pkt->hdr.buf_alloc), le32_to_cpu(pkt->hdr.fwd_cnt)); if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) { (void)virtio_transport_reset_no_sock(pkt); goto free_pkt; } /* The socket must be in connected or bound table * otherwise send reset back */ sk = vsock_find_connected_socket(&src, &dst); if (!sk) { sk = vsock_find_bound_socket(&dst); if (!sk) { (void)virtio_transport_reset_no_sock(pkt); goto free_pkt; } } vsk = vsock_sk(sk); space_available = virtio_transport_space_update(sk, pkt); 
lock_sock(sk); /* Update CID in case it has changed after a transport reset event */ vsk->local_addr.svm_cid = dst.svm_cid; if (space_available) sk->sk_write_space(sk); switch (sk->sk_state) { case VSOCK_SS_LISTEN: virtio_transport_recv_listen(sk, pkt); virtio_transport_free_pkt(pkt); break; case SS_CONNECTING: virtio_transport_recv_connecting(sk, pkt); virtio_transport_free_pkt(pkt); break; case SS_CONNECTED: virtio_transport_recv_connected(sk, pkt); break; case SS_DISCONNECTING: virtio_transport_recv_disconnecting(sk, pkt); virtio_transport_free_pkt(pkt); break; default: virtio_transport_free_pkt(pkt); break; } release_sock(sk); /* Release refcnt obtained when we fetched this socket out of the * bound or connected list. */ sock_put(sk); return; free_pkt: virtio_transport_free_pkt(pkt); } EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt); void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt) { kfree(pkt->buf); kfree(pkt); } EXPORT_SYMBOL_GPL(virtio_transport_free_pkt); MODULE_LICENSE("GPL v2"); MODULE_AUTHOR("Asias He"); MODULE_DESCRIPTION("common code for virtio vsock");