static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd) { struct socket *sock, *oldsock; struct vhost_virtqueue *vq; struct vhost_ubuf_ref *ubufs, *oldubufs = NULL; int r; mutex_lock(&n->dev.mutex); r = vhost_dev_check_owner(&n->dev); if (r) goto err; if (index >= VHOST_NET_VQ_MAX) { r = -ENOBUFS; goto err; } vq = n->vqs + index; mutex_lock(&vq->mutex); /* Verify that ring has been setup correctly. */ if (!vhost_vq_access_ok(vq)) { r = -EFAULT; goto err_vq; } sock = get_socket(fd); if (IS_ERR(sock)) { r = PTR_ERR(sock); goto err_vq; } /* start polling new socket */ oldsock = rcu_dereference_protected(vq->private_data, lockdep_is_held(&vq->mutex)); if (sock != oldsock) { ubufs = vhost_ubuf_alloc(vq, sock && vhost_sock_zcopy(sock)); if (IS_ERR(ubufs)) { r = PTR_ERR(ubufs); goto err_ubufs; } oldubufs = vq->ubufs; vq->ubufs = ubufs; vhost_net_disable_vq(n, vq); rcu_assign_pointer(vq->private_data, sock); vhost_net_enable_vq(n, vq); r = vhost_init_used(vq); if (r) goto err_vq; } mutex_unlock(&vq->mutex); if (oldubufs) { vhost_ubuf_put_and_wait(oldubufs); mutex_lock(&vq->mutex); vhost_zerocopy_signal_used(vq); mutex_unlock(&vq->mutex); } if (oldsock) { vhost_net_flush_vq(n, index); fput(oldsock->file); } mutex_unlock(&n->dev.mutex); return 0; err_ubufs: fput(sock->file); err_vq: mutex_unlock(&vq->mutex); err: mutex_unlock(&n->dev.mutex); return r; }
static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd) { struct socket *sock, *oldsock; struct vhost_virtqueue *vq; struct vhost_net_virtqueue *nvq; struct vhost_net_ubuf_ref *ubufs, *oldubufs = NULL; int r; mutex_lock(&n->dev.mutex); r = vhost_dev_check_owner(&n->dev); if (r) goto err; if (index >= VHOST_NET_VQ_MAX) { r = -ENOBUFS; goto err; } vq = &n->vqs[index].vq; nvq = &n->vqs[index]; mutex_lock(&vq->mutex); /* Verify that ring has been setup correctly. */ if (!vhost_vq_access_ok(vq)) { r = -EFAULT; goto err_vq; } sock = get_socket(fd); if (IS_ERR(sock)) { r = PTR_ERR(sock); goto err_vq; } /* start polling new socket */ oldsock = vq->private_data; if (sock != oldsock) { ubufs = vhost_net_ubuf_alloc(vq, sock && vhost_sock_zcopy(sock)); if (IS_ERR(ubufs)) { r = PTR_ERR(ubufs); goto err_ubufs; } vhost_net_disable_vq(n, vq); vq->private_data = sock; r = vhost_init_used(vq); if (r) goto err_used; r = vhost_net_enable_vq(n, vq); if (r) goto err_used; oldubufs = nvq->ubufs; nvq->ubufs = ubufs; n->tx_packets = 0; n->tx_zcopy_err = 0; n->tx_flush = false; } mutex_unlock(&vq->mutex); if (oldubufs) { vhost_net_ubuf_put_wait_and_free(oldubufs); mutex_lock(&vq->mutex); vhost_zerocopy_signal_used(n, vq); mutex_unlock(&vq->mutex); } if (oldsock) { vhost_net_flush_vq(n, index); sockfd_put(oldsock); } mutex_unlock(&n->dev.mutex); return 0; err_used: vq->private_data = oldsock; vhost_net_enable_vq(n, vq); if (ubufs) vhost_net_ubuf_put_wait_and_free(ubufs); err_ubufs: sockfd_put(sock); err_vq: mutex_unlock(&vq->mutex); err: mutex_unlock(&n->dev.mutex); return r; }
/* Expects to be always run from workqueue - which acts as * read-size critical section for our kind of RCU. */ static void handle_tx(struct vhost_net *net) { struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_TX]; unsigned out, in, s; int head; struct msghdr msg = { .msg_name = NULL, .msg_namelen = 0, .msg_control = NULL, .msg_controllen = 0, .msg_iov = vq->iov, .msg_flags = MSG_DONTWAIT, }; size_t len, total_len = 0; int err, wmem; size_t hdr_size; struct socket *sock; struct vhost_ubuf_ref *uninitialized_var(ubufs); bool zcopy; /* TODO: check that we are running from vhost_worker? */ sock = rcu_dereference_check(vq->private_data, 1); if (!sock) return; wmem = atomic_read(&sock->sk->sk_wmem_alloc); if (wmem >= sock->sk->sk_sndbuf) { mutex_lock(&vq->mutex); tx_poll_start(net, sock); mutex_unlock(&vq->mutex); return; } mutex_lock(&vq->mutex); vhost_disable_notify(&net->dev, vq); if (wmem < sock->sk->sk_sndbuf / 2) tx_poll_stop(net); hdr_size = vq->vhost_hlen; zcopy = vhost_sock_zcopy(sock); for (;;) { /* Release DMAs done buffers first */ if (zcopy) vhost_zerocopy_signal_used(vq); head = vhost_get_vq_desc(&net->dev, vq, vq->iov, ARRAY_SIZE(vq->iov), &out, &in, NULL, NULL); /* On error, stop handling until the next kick. */ if (unlikely(head < 0)) break; /* Nothing new? Wait for eventfd to tell us they refilled. */ if (head == vq->num) { int num_pends; wmem = atomic_read(&sock->sk->sk_wmem_alloc); if (wmem >= sock->sk->sk_sndbuf * 3 / 4) { tx_poll_start(net, sock); set_bit(SOCK_ASYNC_NOSPACE, &sock->flags); break; } /* If more outstanding DMAs, queue the work. * Handle upend_idx wrap around */ num_pends = likely(vq->upend_idx >= vq->done_idx) ? (vq->upend_idx - vq->done_idx) : (vq->upend_idx + UIO_MAXIOV - vq->done_idx); if (unlikely(num_pends > VHOST_MAX_PEND)) { tx_poll_start(net, sock); set_bit(SOCK_ASYNC_NOSPACE, &sock->flags); break; } if (unlikely(vhost_enable_notify(&net->dev, vq))) { vhost_disable_notify(&net->dev, vq); continue; } break; } if (in) { vq_err(vq, "Unexpected descriptor format for TX: " "out %d, int %d\n", out, in); break; } /* Skip header. TODO: support TSO. */ s = move_iovec_hdr(vq->iov, vq->hdr, hdr_size, out); msg.msg_iovlen = out; len = iov_length(vq->iov, out); /* Sanity check */ if (!len) { vq_err(vq, "Unexpected header len for TX: " "%zd expected %zd\n", iov_length(vq->hdr, s), hdr_size); break; } /* use msg_control to pass vhost zerocopy ubuf info to skb */ if (zcopy) { vq->heads[vq->upend_idx].id = head; if (len < VHOST_GOODCOPY_LEN) { /* copy don't need to wait for DMA done */ vq->heads[vq->upend_idx].len = VHOST_DMA_DONE_LEN; msg.msg_control = NULL; msg.msg_controllen = 0; ubufs = NULL; } else { struct ubuf_info *ubuf = &vq->ubuf_info[head]; vq->heads[vq->upend_idx].len = len; ubuf->callback = vhost_zerocopy_callback; ubuf->ctx = vq->ubufs; ubuf->desc = vq->upend_idx; msg.msg_control = ubuf; msg.msg_controllen = sizeof(ubuf); ubufs = vq->ubufs; kref_get(&ubufs->kref); } vq->upend_idx = (vq->upend_idx + 1) % UIO_MAXIOV; } /* TODO: Check specific error and bomb out unless ENOBUFS? */ err = sock->ops->sendmsg(NULL, sock, &msg, len); if (unlikely(err < 0)) { if (zcopy) { if (ubufs) vhost_ubuf_put(ubufs); vq->upend_idx = ((unsigned)vq->upend_idx - 1) % UIO_MAXIOV; } vhost_discard_vq_desc(vq, 1); tx_poll_start(net, sock); break; } if (err != len) pr_debug("Truncated TX packet: " " len %d != %zd\n", err, len); if (!zcopy) vhost_add_used_and_signal(&net->dev, vq, head, 0); total_len += len; if (unlikely(total_len >= VHOST_NET_WEIGHT)) { vhost_poll_queue(&vq->poll); break; } } mutex_unlock(&vq->mutex); } static int peek_head_len(struct sock *sk) { struct sk_buff *head; int len = 0; unsigned long flags; spin_lock_irqsave(&sk->sk_receive_queue.lock, flags); head = skb_peek(&sk->sk_receive_queue); if (likely(head)) { len = head->len; if (vlan_tx_tag_present(head)) len += VLAN_HLEN; } spin_unlock_irqrestore(&sk->sk_receive_queue.lock, flags); return len; } /* This is a multi-buffer version of vhost_get_desc, that works if * vq has read descriptors only. * @vq - the relevant virtqueue * @datalen - data length we'll be reading * @iovcount - returned count of io vectors we fill * @log - vhost log * @log_num - log offset * @quota - headcount quota, 1 for big buffer * returns number of buffer heads allocated, negative on error */ static int get_rx_bufs(struct vhost_virtqueue *vq, struct vring_used_elem *heads, int datalen, unsigned *iovcount, struct vhost_log *log, unsigned *log_num, unsigned int quota) { unsigned int out, in; int seg = 0; int headcount = 0; unsigned d; int r, nlogs = 0; while (datalen > 0 && headcount < quota) { if (unlikely(seg >= UIO_MAXIOV)) { r = -ENOBUFS; goto err; } d = vhost_get_vq_desc(vq->dev, vq, vq->iov + seg, ARRAY_SIZE(vq->iov) - seg, &out, &in, log, log_num); if (d == vq->num) { r = 0; goto err; } if (unlikely(out || in <= 0)) { vq_err(vq, "unexpected descriptor format for RX: " "out %d, in %d\n", out, in); r = -EINVAL; goto err; } if (unlikely(log)) { nlogs += *log_num; log += *log_num; } heads[headcount].id = d; heads[headcount].len = iov_length(vq->iov + seg, in); datalen -= heads[headcount].len; ++headcount; seg += in; } heads[headcount - 1].len += datalen; *iovcount = seg; if (unlikely(log)) *log_num = nlogs; return headcount; err: vhost_discard_vq_desc(vq, headcount); return r; }