ssize_t sock_conn_send_src_addr(struct sock_ep_attr *ep_attr,
				struct sock_tx_ctx *tx_ctx,
				struct sock_conn *conn)
{
	int ret;
	uint64_t total_len;
	struct sock_op tx_op = { 0 };

	tx_op.op = SOCK_OP_CONN_MSG;
	SOCK_LOG_DBG("New conn msg on TX: %p using conn: %p\n", tx_ctx, conn);

	total_len = 0;
	tx_op.src_iov_len = sizeof(struct sockaddr_in);
	total_len = tx_op.src_iov_len + sizeof(struct sock_op_send);

	sock_tx_ctx_start(tx_ctx);
	if (ofi_rbavail(&tx_ctx->rb) < total_len) {
		ret = -FI_EAGAIN;
		goto err;
	}

	sock_tx_ctx_write_op_send(tx_ctx, &tx_op, 0, (uintptr_t) NULL, 0, 0,
				  ep_attr, conn);
	sock_tx_ctx_write(tx_ctx, ep_attr->src_addr, sizeof(struct sockaddr_in));
	sock_tx_ctx_commit(tx_ctx);
	conn->address_published = 1;
	return 0;

err:
	sock_tx_ctx_abort(tx_ctx);
	return ret;
}
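/*
 * The function above follows the tx ring-buffer protocol used throughout
 * this file: reserve with sock_tx_ctx_start(), verify the ring has room for
 * the whole entry, stream the header and payload with sock_tx_ctx_write(),
 * then publish atomically with sock_tx_ctx_commit() (or roll back with
 * sock_tx_ctx_abort()). A minimal sketch of that pattern, using a
 * hypothetical helper name and payload:
 */
#if 0
static ssize_t example_tx_pattern(struct sock_tx_ctx *tx_ctx,
				  struct sock_op *op, const void *payload,
				  size_t len)
{
	sock_tx_ctx_start(tx_ctx);		/* reserve the ring */
	if (ofi_rbavail(&tx_ctx->rb) < sizeof(*op) + len) {
		sock_tx_ctx_abort(tx_ctx);	/* nothing becomes visible */
		return -FI_EAGAIN;
	}
	sock_tx_ctx_write(tx_ctx, op, sizeof(*op));
	sock_tx_ctx_write(tx_ctx, payload, len);
	sock_tx_ctx_commit(tx_ctx);		/* publish the whole entry */
	return 0;
}
#endif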
ssize_t sock_ep_tx_atomic(struct fid_ep *ep,
			  const struct fi_msg_atomic *msg,
			  const struct fi_ioc *comparev, void **compare_desc,
			  size_t compare_count, struct fi_ioc *resultv,
			  void **result_desc, size_t result_count,
			  uint64_t flags)
{
	int i, ret;
	size_t datatype_sz;
	struct sock_op tx_op;
	union sock_iov tx_iov;
	struct sock_conn *conn;
	struct sock_tx_ctx *tx_ctx;
	uint64_t total_len, src_len, dst_len;
	struct sock_ep *sock_ep;

	switch (ep->fid.fclass) {
	case FI_CLASS_EP:
		sock_ep = container_of(ep, struct sock_ep, ep);
		tx_ctx = sock_ep->tx_ctx;
		break;
	case FI_CLASS_TX_CTX:
		tx_ctx = container_of(ep, struct sock_tx_ctx, fid.ctx);
		sock_ep = tx_ctx->ep;
		break;
	default:
		SOCK_LOG_ERROR("Invalid EP type\n");
		return -FI_EINVAL;
	}

	if (msg->iov_count > SOCK_EP_MAX_IOV_LIMIT ||
	    msg->rma_iov_count > SOCK_EP_MAX_IOV_LIMIT)
		return -FI_EINVAL;

	if (!tx_ctx->enabled)
		return -FI_EOPBADSTATE;

	if (sock_ep->connected) {
		conn = sock_ep_lookup_conn(sock_ep);
	} else {
		conn = sock_av_lookup_addr(sock_ep, tx_ctx->av, msg->addr);
		if (!conn) {
			SOCK_LOG_ERROR("Address lookup failed\n");
			return -errno;
		}
	}
	if (!conn)
		return -FI_EAGAIN;

	SOCK_EP_SET_TX_OP_FLAGS(flags);
	if (flags & SOCK_USE_OP_FLAGS)
		flags |= tx_ctx->attr.op_flags;

	if (msg->op == FI_ATOMIC_READ)
		flags &= ~FI_INJECT;

	if (sock_ep_is_send_cq_low(&tx_ctx->comp, flags)) {
		SOCK_LOG_ERROR("CQ size low\n");
		return -FI_EAGAIN;
	}

	if (flags & FI_TRIGGER) {
		ret = sock_queue_atomic_op(ep, msg, comparev, compare_count,
					   resultv, result_count, flags,
					   SOCK_OP_ATOMIC);
		if (ret != 1)
			return ret;
	}

	src_len = 0;
	datatype_sz = fi_datatype_size(msg->datatype);
	if (flags & FI_INJECT) {
		for (i = 0; i < msg->iov_count; i++)
			src_len += (msg->msg_iov[i].count * datatype_sz);
		if (src_len > SOCK_EP_MAX_INJECT_SZ)
			return -FI_EINVAL;
		total_len = src_len;
	} else {
		total_len = msg->iov_count * sizeof(union sock_iov);
	}

	total_len += (sizeof(struct sock_op_send) +
		      (msg->rma_iov_count * sizeof(union sock_iov)) +
		      (result_count * sizeof(union sock_iov)));

	sock_tx_ctx_start(tx_ctx);
	if (rbfdavail(&tx_ctx->rbfd) < total_len) {
		ret = -FI_EAGAIN;
		goto err;
	}

	memset(&tx_op, 0, sizeof(tx_op));
	tx_op.op = SOCK_OP_ATOMIC;
	tx_op.dest_iov_len = msg->rma_iov_count;
	tx_op.atomic.op = msg->op;
	tx_op.atomic.datatype = msg->datatype;
	tx_op.atomic.res_iov_len = result_count;
	tx_op.atomic.cmp_iov_len = compare_count;

	if (flags & FI_INJECT)
		tx_op.src_iov_len = src_len;
	else
		tx_op.src_iov_len = msg->iov_count;

	sock_tx_ctx_write_op_send(tx_ctx, &tx_op, flags,
				  (uintptr_t) msg->context, msg->addr,
				  (uintptr_t) msg->msg_iov[0].addr,
				  sock_ep, conn);

	if (flags & FI_REMOTE_CQ_DATA)
		sock_tx_ctx_write(tx_ctx, &msg->data, sizeof(uint64_t));

	src_len = 0;
	if (flags & FI_INJECT) {
		for (i = 0; i < msg->iov_count; i++) {
			sock_tx_ctx_write(tx_ctx, msg->msg_iov[i].addr,
					  msg->msg_iov[i].count * datatype_sz);
			src_len += (msg->msg_iov[i].count * datatype_sz);
		}
	} else {
		for (i = 0; i < msg->iov_count; i++) {
			tx_iov.ioc.addr = (uintptr_t) msg->msg_iov[i].addr;
			tx_iov.ioc.count = msg->msg_iov[i].count;
			sock_tx_ctx_write(tx_ctx, &tx_iov, sizeof(tx_iov));
			src_len += (tx_iov.ioc.count * datatype_sz);
		}
	}

#ifdef ENABLE_DEBUG
	if (src_len > SOCK_EP_MAX_ATOMIC_SZ) {
		ret = -FI_EINVAL;
		goto err;
	}
#endif

	dst_len = 0;
	for (i = 0; i < msg->rma_iov_count; i++) {
		tx_iov.ioc.addr = msg->rma_iov[i].addr;
		tx_iov.ioc.key = msg->rma_iov[i].key;
		tx_iov.ioc.count = msg->rma_iov[i].count;
		sock_tx_ctx_write(tx_ctx, &tx_iov, sizeof(tx_iov));
		dst_len += (tx_iov.ioc.count * datatype_sz);
	}

	if (msg->iov_count && dst_len != src_len) {
		SOCK_LOG_ERROR("Buffer length mismatch\n");
		ret = -FI_EINVAL;
		goto err;
	} else {
		src_len = dst_len;
	}

	dst_len = 0;
	for (i = 0; i < result_count; i++) {
		tx_iov.ioc.addr = (uintptr_t) resultv[i].addr;
		tx_iov.ioc.count = resultv[i].count;
		sock_tx_ctx_write(tx_ctx, &tx_iov, sizeof(tx_iov));
		dst_len += (tx_iov.ioc.count * datatype_sz);
	}

#ifdef ENABLE_DEBUG
	if (result_count && (dst_len != src_len)) {
		SOCK_LOG_ERROR("Buffer length mismatch\n");
		ret = -FI_EINVAL;
		goto err;
	}
#endif

	dst_len = 0;
	for (i = 0; i < compare_count; i++) {
		tx_iov.ioc.addr = (uintptr_t) comparev[i].addr;
		tx_iov.ioc.count = comparev[i].count;
		sock_tx_ctx_write(tx_ctx, &tx_iov, sizeof(tx_iov));
		dst_len += (tx_iov.ioc.count * datatype_sz);
	}

#ifdef ENABLE_DEBUG
	if (compare_count && (dst_len != src_len)) {
		SOCK_LOG_ERROR("Buffer length mismatch\n");
		ret = -FI_EINVAL;
		goto err;
	}
#endif

	sock_tx_ctx_commit(tx_ctx);
	return 0;

err:
	sock_tx_ctx_abort(tx_ctx);
	return ret;
}
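/*
 * For reference, the atomic TX entry written above is laid out in the ring
 * in this order (inferred from the sequence of sock_tx_ctx_write() calls in
 * sock_ep_tx_atomic):
 *
 *   struct sock_op_send header
 *   [uint64_t CQ data]                         if FI_REMOTE_CQ_DATA
 *   source operands: raw inline bytes          if FI_INJECT
 *                    union sock_iov[iov_count] otherwise
 *   union sock_iov[rma_iov_count]              remote target descriptors
 *   union sock_iov[result_count]               fetch/result buffers
 *   union sock_iov[compare_count]              compare buffers
 *
 * Each length pass checks that the source, target, result, and compare
 * spans all describe the same number of bytes (count * datatype size).
 */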
static ssize_t sock_ep_tx_atomic(struct fid_ep *ep,
				 const struct fi_msg_atomic *msg,
				 const struct fi_ioc *comparev,
				 void **compare_desc, size_t compare_count,
				 struct fi_ioc *resultv, void **result_desc,
				 size_t result_count, uint64_t flags)
{
	int i, ret;
	size_t datatype_sz;
	struct sock_op tx_op;
	union sock_iov tx_iov;
	struct sock_conn *conn;
	struct sock_tx_ctx *tx_ctx;
	uint64_t total_len, src_len, dst_len;
	struct sock_ep *sock_ep;

	switch (ep->fid.fclass) {
	case FI_CLASS_EP:
		sock_ep = container_of(ep, struct sock_ep, ep);
		tx_ctx = sock_ep->tx_ctx;
		break;
	case FI_CLASS_TX_CTX:
		tx_ctx = container_of(ep, struct sock_tx_ctx, fid.ctx);
		sock_ep = tx_ctx->ep;
		break;
	default:
		SOCK_LOG_ERROR("Invalid EP type\n");
		return -FI_EINVAL;
	}

	assert(tx_ctx->enabled &&
	       msg->iov_count <= SOCK_EP_MAX_IOV_LIMIT &&
	       msg->rma_iov_count <= SOCK_EP_MAX_IOV_LIMIT);

	if (sock_ep->connected)
		conn = sock_ep_lookup_conn(sock_ep);
	else
		conn = sock_av_lookup_addr(sock_ep, tx_ctx->av, msg->addr);

	if (!conn)
		return -FI_EAGAIN;

	src_len = 0;
	datatype_sz = fi_datatype_size(msg->datatype);
	if (flags & FI_INJECT) {
		for (i = 0; i < msg->iov_count; i++)
			src_len += (msg->msg_iov[i].count * datatype_sz);
		assert(src_len <= SOCK_EP_MAX_INJECT_SZ);
		total_len = src_len;
	} else {
		total_len = msg->iov_count * sizeof(union sock_iov);
	}

	total_len += (sizeof(tx_op) +
		      (msg->rma_iov_count * sizeof(union sock_iov)) +
		      (result_count * sizeof(union sock_iov)));

	sock_tx_ctx_start(tx_ctx);
	if (rbfdavail(&tx_ctx->rbfd) < total_len) {
		ret = -FI_EAGAIN;
		goto err;
	}

	flags |= tx_ctx->attr.op_flags;
	memset(&tx_op, 0, sizeof(tx_op));
	tx_op.op = SOCK_OP_ATOMIC;
	tx_op.dest_iov_len = msg->rma_iov_count;
	tx_op.atomic.op = msg->op;
	tx_op.atomic.datatype = msg->datatype;
	tx_op.atomic.res_iov_len = result_count;
	tx_op.atomic.cmp_iov_len = compare_count;

	if (flags & FI_INJECT)
		tx_op.src_iov_len = src_len;
	else
		tx_op.src_iov_len = msg->iov_count;

	sock_tx_ctx_write_op_send(tx_ctx, &tx_op, flags,
				  (uintptr_t) msg->context, msg->addr,
				  (uintptr_t) msg->msg_iov[0].addr,
				  sock_ep, conn);

	if (flags & FI_REMOTE_CQ_DATA)
		sock_tx_ctx_write(tx_ctx, &msg->data, sizeof(uint64_t));

	src_len = 0;
	if (flags & FI_INJECT) {
		for (i = 0; i < msg->iov_count; i++) {
			sock_tx_ctx_write(tx_ctx, msg->msg_iov[i].addr,
					  msg->msg_iov[i].count * datatype_sz);
			src_len += (msg->msg_iov[i].count * datatype_sz);
		}
	} else {
		for (i = 0; i < msg->iov_count; i++) {
			tx_iov.ioc.addr = (uintptr_t) msg->msg_iov[i].addr;
			tx_iov.ioc.count = msg->msg_iov[i].count;
			tx_iov.ioc.key = (uintptr_t) msg->desc[i];
			sock_tx_ctx_write(tx_ctx, &tx_iov, sizeof(tx_iov));
			src_len += (tx_iov.ioc.count * datatype_sz);
		}
	}
	assert(src_len <= SOCK_EP_MAX_ATOMIC_SZ);

	dst_len = 0;
	for (i = 0; i < msg->rma_iov_count; i++) {
		tx_iov.ioc.addr = msg->rma_iov[i].addr;
		tx_iov.ioc.key = msg->rma_iov[i].key;
		tx_iov.ioc.count = msg->rma_iov[i].count;
		sock_tx_ctx_write(tx_ctx, &tx_iov, sizeof(tx_iov));
		dst_len += (tx_iov.ioc.count * datatype_sz);
	}

	if (dst_len != src_len) {
		SOCK_LOG_ERROR("Buffer length mismatch\n");
		ret = -FI_EINVAL;
		goto err;
	}

	dst_len = 0;
	for (i = 0; i < result_count; i++) {
		tx_iov.ioc.addr = (uintptr_t) resultv[i].addr;
		tx_iov.ioc.count = resultv[i].count;
		sock_tx_ctx_write(tx_ctx, &tx_iov, sizeof(tx_iov));
		dst_len += (tx_iov.ioc.count * datatype_sz);
	}

	if (result_count && (dst_len != src_len)) {
		SOCK_LOG_ERROR("Buffer length mismatch\n");
		ret = -FI_EINVAL;
		goto err;
	}

	dst_len = 0;
	for (i = 0; i < compare_count; i++) {
		tx_iov.ioc.addr = (uintptr_t) comparev[i].addr;
		tx_iov.ioc.count = comparev[i].count;
		sock_tx_ctx_write(tx_ctx, &tx_iov, sizeof(tx_iov));
		dst_len += (tx_iov.ioc.count * datatype_sz);
	}

	if (compare_count && (dst_len != src_len)) {
		SOCK_LOG_ERROR("Buffer length mismatch\n");
		ret = -FI_EINVAL;
		goto err;
	}

	sock_tx_ctx_commit(tx_ctx);
	return 0;

err:
	SOCK_LOG_INFO("Not enough space for TX entry, try again\n");
	sock_tx_ctx_abort(tx_ctx);
	return ret;
}
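/*
 * Note the two encodings of tx_op.src_iov_len used by both variants above:
 * with FI_INJECT it holds the total payload size in bytes (the operand data
 * is copied inline into the ring), while without it it holds the number of
 * union sock_iov descriptors that follow. A worked example, assuming a
 * hypothetical 2-element iov of 8 FI_UINT64 operands each:
 *
 *   inject:     ring payload = 2 * 8 * sizeof(uint64_t) = 128 data bytes
 *   non-inject: ring payload = 2 * sizeof(union sock_iov) descriptors
 */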
static ssize_t sock_ep_recv(struct fid_ep *ep, void *buf, size_t len,
			    void *desc, fi_addr_t src_addr, void *context)
{
	struct iovec msg_iov = {
		.iov_base = buf,
		.iov_len = len,
	};
	struct fi_msg msg = {
		.msg_iov = &msg_iov,
		.desc = &desc,
		.iov_count = 1,
		.addr = src_addr,
		.context = context,
		.data = 0,
	};

	return sock_ep_recvmsg(ep, &msg, SOCK_USE_OP_FLAGS);
}

static ssize_t sock_ep_recvv(struct fid_ep *ep, const struct iovec *iov,
			     void **desc, size_t count, fi_addr_t src_addr,
			     void *context)
{
	struct fi_msg msg = {
		.msg_iov = iov,
		.desc = desc,
		.iov_count = count,
		.addr = src_addr,
		.context = context,
		.data = 0,
	};

	return sock_ep_recvmsg(ep, &msg, SOCK_USE_OP_FLAGS);
}

ssize_t sock_ep_sendmsg(struct fid_ep *ep, const struct fi_msg *msg,
			uint64_t flags)
{
	int ret;
	size_t i;
	uint64_t total_len, op_flags;
	struct sock_op tx_op;
	union sock_iov tx_iov;
	struct sock_conn *conn;
	struct sock_tx_ctx *tx_ctx;
	struct sock_ep *sock_ep;
	struct sock_ep_attr *ep_attr;

	switch (ep->fid.fclass) {
	case FI_CLASS_EP:
		sock_ep = container_of(ep, struct sock_ep, ep);
		ep_attr = sock_ep->attr;
		tx_ctx = sock_ep->attr->tx_ctx->use_shared ?
			 sock_ep->attr->tx_ctx->stx_ctx : sock_ep->attr->tx_ctx;
		op_flags = sock_ep->tx_attr.op_flags;
		break;
	case FI_CLASS_TX_CTX:
		tx_ctx = container_of(ep, struct sock_tx_ctx, fid.ctx);
		ep_attr = tx_ctx->ep_attr;
		op_flags = tx_ctx->attr.op_flags;
		break;
	default:
		SOCK_LOG_ERROR("Invalid EP type\n");
		return -FI_EINVAL;
	}

#if ENABLE_DEBUG
	if (msg->iov_count > SOCK_EP_MAX_IOV_LIMIT)
		return -FI_EINVAL;
#endif
	if (!tx_ctx->enabled)
		return -FI_EOPBADSTATE;

	if (sock_drop_packet(ep_attr))
		return 0;

	ret = sock_ep_get_conn(ep_attr, tx_ctx, msg->addr, &conn);
	if (ret)
		return ret;

	SOCK_LOG_DBG("New sendmsg on TX: %p using conn: %p\n", tx_ctx, conn);

	SOCK_EP_SET_TX_OP_FLAGS(flags);
	if (flags & SOCK_USE_OP_FLAGS)
		flags |= op_flags;

	if (flags & FI_TRIGGER) {
		ret = sock_queue_msg_op(ep, msg, flags, FI_OP_SEND);
		if (ret != 1)
			return ret;
	}

	memset(&tx_op, 0, sizeof(struct sock_op));
	tx_op.op = SOCK_OP_SEND;

	total_len = 0;
	if (flags & FI_INJECT) {
		for (i = 0; i < msg->iov_count; i++)
			total_len += msg->msg_iov[i].iov_len;
		if (total_len > SOCK_EP_MAX_INJECT_SZ)
			return -FI_EINVAL;
		tx_op.src_iov_len = total_len;
	} else {
		tx_op.src_iov_len = msg->iov_count;
		total_len = msg->iov_count * sizeof(union sock_iov);
	}

	total_len += sizeof(struct sock_op_send);
	if (flags & FI_REMOTE_CQ_DATA)
		total_len += sizeof(uint64_t);

	sock_tx_ctx_start(tx_ctx);
	if (ofi_rbavail(&tx_ctx->rb) < total_len) {
		ret = -FI_EAGAIN;
		goto err;
	}

	sock_tx_ctx_write_op_send(tx_ctx, &tx_op, flags,
				  (uintptr_t) msg->context, msg->addr,
				  (uintptr_t) ((msg->iov_count > 0) ?
					       msg->msg_iov[0].iov_base : NULL),
				  ep_attr, conn);

	if (flags & FI_REMOTE_CQ_DATA)
		sock_tx_ctx_write(tx_ctx, &msg->data, sizeof(msg->data));

	if (flags & FI_INJECT) {
		for (i = 0; i < msg->iov_count; i++) {
			sock_tx_ctx_write(tx_ctx, msg->msg_iov[i].iov_base,
					  msg->msg_iov[i].iov_len);
		}
	} else {
		for (i = 0; i < msg->iov_count; i++) {
			tx_iov.iov.addr = (uintptr_t) msg->msg_iov[i].iov_base;
			tx_iov.iov.len = msg->msg_iov[i].iov_len;
			sock_tx_ctx_write(tx_ctx, &tx_iov, sizeof(tx_iov));
		}
	}

	sock_tx_ctx_commit(tx_ctx);
	return 0;

err:
	sock_tx_ctx_abort(tx_ctx);
	return ret;
}

static ssize_t sock_ep_send(struct fid_ep *ep, const void *buf, size_t len,
			    void *desc, fi_addr_t dest_addr, void *context)
{
	struct iovec msg_iov = {
		.iov_base = (void *) buf,
		.iov_len = len,
	};
	struct fi_msg msg = {
		.msg_iov = &msg_iov,
		.desc = &desc,
		.iov_count = 1,
		.addr = dest_addr,
		.context = context,
		.data = 0,
	};

	return sock_ep_sendmsg(ep, &msg, SOCK_USE_OP_FLAGS);
}

static ssize_t sock_ep_sendv(struct fid_ep *ep, const struct iovec *iov,
			     void **desc, size_t count, fi_addr_t dest_addr,
			     void *context)
{
	struct fi_msg msg = {
		.msg_iov = iov,
		.desc = desc,
		.iov_count = count,
		.addr = dest_addr,
		.context = context,
		.data = 0,
	};

	return sock_ep_sendmsg(ep, &msg, SOCK_USE_OP_FLAGS);
}

static ssize_t sock_ep_senddata(struct fid_ep *ep, const void *buf, size_t len,
				void *desc, uint64_t data, fi_addr_t dest_addr,
				void *context)
{
	struct iovec msg_iov = {
		.iov_base = (void *) buf,
		.iov_len = len,
	};
	struct fi_msg msg = {
		.msg_iov = &msg_iov,
		.desc = desc,
		.iov_count = 1,
		.addr = dest_addr,
		.context = context,
		.data = data,
	};

	return sock_ep_sendmsg(ep, &msg, FI_REMOTE_CQ_DATA | SOCK_USE_OP_FLAGS);
}

static ssize_t sock_ep_inject(struct fid_ep *ep, const void *buf, size_t len,
			      fi_addr_t dest_addr)
{
	struct iovec msg_iov = {
		.iov_base = (void *) buf,
		.iov_len = len,
	};
	struct fi_msg msg = {
		.msg_iov = &msg_iov,
		.desc = NULL,
		.iov_count = 1,
		.addr = dest_addr,
		.context = NULL,
		.data = 0,
	};

	return sock_ep_sendmsg(ep, &msg, FI_INJECT |
			       SOCK_NO_COMPLETION | SOCK_USE_OP_FLAGS);
}

static ssize_t sock_ep_injectdata(struct fid_ep *ep, const void *buf,
				  size_t len, uint64_t data,
				  fi_addr_t dest_addr)
{
	struct iovec msg_iov = {
		.iov_base = (void *) buf,
		.iov_len = len,
	};
	struct fi_msg msg = {
		.msg_iov = &msg_iov,
		.desc = NULL,
		.iov_count = 1,
		.addr = dest_addr,
		.context = NULL,
		.data = data,
	};

	return sock_ep_sendmsg(ep, &msg, FI_REMOTE_CQ_DATA | FI_INJECT |
			       SOCK_NO_COMPLETION | SOCK_USE_OP_FLAGS);
}

struct fi_ops_msg sock_ep_msg_ops = {
	.size = sizeof(struct fi_ops_msg),
	.recv = sock_ep_recv,
	.recvv = sock_ep_recvv,
	.recvmsg = sock_ep_recvmsg,
	.send = sock_ep_send,
	.sendv = sock_ep_sendv,
	.sendmsg = sock_ep_sendmsg,
	.inject = sock_ep_inject,
	.senddata = sock_ep_senddata,
	.injectdata = sock_ep_injectdata
};

ssize_t sock_ep_trecvmsg(struct fid_ep *ep, const struct fi_msg_tagged *msg,
			 uint64_t flags)
{
	int ret;
	size_t i;
	struct sock_rx_ctx *rx_ctx;
	struct sock_rx_entry *rx_entry;
	struct sock_ep *sock_ep;
	uint64_t op_flags;

	switch (ep->fid.fclass) {
	case FI_CLASS_EP:
		sock_ep = container_of(ep, struct sock_ep, ep);
		rx_ctx = sock_ep->attr->rx_ctx;
		op_flags = sock_ep->rx_attr.op_flags;
		break;
	case FI_CLASS_RX_CTX:
	case FI_CLASS_SRX_CTX:
		rx_ctx = container_of(ep, struct sock_rx_ctx, ctx);
		op_flags = rx_ctx->attr.op_flags;
		break;
	default:
		SOCK_LOG_ERROR("Invalid ep type\n");
		return -FI_EINVAL;
	}

#if ENABLE_DEBUG
	if (msg->iov_count > SOCK_EP_MAX_IOV_LIMIT)
		return -FI_EINVAL;
#endif
	if (!rx_ctx->enabled)
		return -FI_EOPBADSTATE;

	if (flags & SOCK_USE_OP_FLAGS)
		flags |= op_flags;
	flags &= ~FI_MULTI_RECV;

	if (flags & FI_TRIGGER) {
		ret = sock_queue_tmsg_op(ep, msg, flags, FI_OP_TRECV);
		if (ret != 1)
			return ret;
	}

	if (flags & FI_PEEK) {
		return sock_rx_peek_recv(rx_ctx, msg->addr, msg->tag,
					 msg->ignore, msg->context, flags, 1);
	} else if (flags & FI_CLAIM) {
		return sock_rx_claim_recv(rx_ctx, msg->context, flags,
					  msg->tag, msg->ignore, 1,
					  msg->msg_iov, msg->iov_count);
	}

	fastlock_acquire(&rx_ctx->lock);
	rx_entry = sock_rx_new_entry(rx_ctx);
	fastlock_release(&rx_ctx->lock);
	if (!rx_entry)
		return -FI_ENOMEM;

	rx_entry->rx_op.op = SOCK_OP_TRECV;
	rx_entry->rx_op.dest_iov_len = msg->iov_count;

	rx_entry->flags = flags;
	rx_entry->context = (uintptr_t) msg->context;
	rx_entry->addr = (rx_ctx->attr.caps & FI_DIRECTED_RECV) ?
			 msg->addr : FI_ADDR_UNSPEC;
	rx_entry->data = msg->data;
	rx_entry->tag = msg->tag;
	rx_entry->ignore = msg->ignore;
	rx_entry->is_tagged = 1;

	for (i = 0; i < msg->iov_count; i++) {
		rx_entry->iov[i].iov.addr = (uintptr_t) msg->msg_iov[i].iov_base;
		rx_entry->iov[i].iov.len = msg->msg_iov[i].iov_len;
		rx_entry->total_len += rx_entry->iov[i].iov.len;
	}

	fastlock_acquire(&rx_ctx->lock);
	SOCK_LOG_DBG("New rx_entry: %p (ctx: %p)\n", rx_entry, rx_ctx);
	dlist_insert_tail(&rx_entry->entry, &rx_ctx->rx_entry_list);
	fastlock_release(&rx_ctx->lock);
	return 0;
}

static ssize_t sock_ep_trecv(struct fid_ep *ep, void *buf, size_t len,
			     void *desc, fi_addr_t src_addr, uint64_t tag,
			     uint64_t ignore, void *context)
{
	struct iovec msg_iov = {
		.iov_base = buf,
		.iov_len = len,
	};
	struct fi_msg_tagged msg = {
		.msg_iov = &msg_iov,
		.desc = &desc,
		.iov_count = 1,
		.addr = src_addr,
		.context = context,
		.tag = tag,
		.ignore = ignore,
		.data = 0,
	};

	return sock_ep_trecvmsg(ep, &msg, SOCK_USE_OP_FLAGS);
}

static ssize_t sock_ep_trecvv(struct fid_ep *ep, const struct iovec *iov,
			      void **desc, size_t count, fi_addr_t src_addr,
			      uint64_t tag, uint64_t ignore, void *context)
{
	struct fi_msg_tagged msg = {
		.msg_iov = iov,
		.desc = desc,
		.iov_count = count,
		.addr = src_addr,
		.context = context,
		.tag = tag,
		.ignore = ignore,
		.data = 0,
	};

	return sock_ep_trecvmsg(ep, &msg, SOCK_USE_OP_FLAGS);
}

ssize_t sock_ep_tsendmsg(struct fid_ep *ep, const struct fi_msg_tagged *msg,
			 uint64_t flags)
{
	int ret;
	size_t i;
	uint64_t total_len, op_flags;
	struct sock_op tx_op;
	union sock_iov tx_iov;
	struct sock_conn *conn;
	struct sock_tx_ctx *tx_ctx;
	struct sock_ep *sock_ep;
	struct sock_ep_attr *ep_attr;

	switch (ep->fid.fclass) {
	case FI_CLASS_EP:
		sock_ep = container_of(ep, struct sock_ep, ep);
		tx_ctx = sock_ep->attr->tx_ctx->use_shared ?
			 sock_ep->attr->tx_ctx->stx_ctx : sock_ep->attr->tx_ctx;
		ep_attr = sock_ep->attr;
		op_flags = sock_ep->tx_attr.op_flags;
		break;
	case FI_CLASS_TX_CTX:
		tx_ctx = container_of(ep, struct sock_tx_ctx, fid.ctx);
		ep_attr = tx_ctx->ep_attr;
		op_flags = tx_ctx->attr.op_flags;
		break;
	default:
		SOCK_LOG_ERROR("Invalid EP type\n");
		return -FI_EINVAL;
	}

#if ENABLE_DEBUG
	if (msg->iov_count > SOCK_EP_MAX_IOV_LIMIT)
		return -FI_EINVAL;
#endif
	if (!tx_ctx->enabled)
		return -FI_EOPBADSTATE;

	if (sock_drop_packet(ep_attr))
		return 0;

	ret = sock_ep_get_conn(ep_attr, tx_ctx, msg->addr, &conn);
	if (ret)
		return ret;

	SOCK_EP_SET_TX_OP_FLAGS(flags);
	if (flags & SOCK_USE_OP_FLAGS)
		flags |= op_flags;

	if (flags & FI_TRIGGER) {
		ret = sock_queue_tmsg_op(ep, msg, flags, FI_OP_TSEND);
		if (ret != 1)
			return ret;
	}

	memset(&tx_op, 0, sizeof(tx_op));
	tx_op.op = SOCK_OP_TSEND;

	total_len = 0;
	if (flags & FI_INJECT) {
		for (i = 0; i < msg->iov_count; i++)
			total_len += msg->msg_iov[i].iov_len;
		tx_op.src_iov_len = total_len;
		if (total_len > SOCK_EP_MAX_INJECT_SZ)
			return -FI_EINVAL;
	} else {
		total_len = msg->iov_count * sizeof(union sock_iov);
		tx_op.src_iov_len = msg->iov_count;
	}

	total_len += sizeof(struct sock_op_tsend);
	if (flags & FI_REMOTE_CQ_DATA)
		total_len += sizeof(uint64_t);

	sock_tx_ctx_start(tx_ctx);
	if (ofi_rbavail(&tx_ctx->rb) < total_len) {
		ret = -FI_EAGAIN;
		goto err;
	}

	sock_tx_ctx_write_op_tsend(tx_ctx, &tx_op, flags,
				   (uintptr_t) msg->context, msg->addr,
				   (uintptr_t) ((msg->iov_count > 0) ?
						msg->msg_iov[0].iov_base : NULL),
				   ep_attr, conn, msg->tag);

	if (flags & FI_REMOTE_CQ_DATA)
		sock_tx_ctx_write(tx_ctx, &msg->data, sizeof(msg->data));

	if (flags & FI_INJECT) {
		for (i = 0; i < msg->iov_count; i++) {
			sock_tx_ctx_write(tx_ctx, msg->msg_iov[i].iov_base,
					  msg->msg_iov[i].iov_len);
		}
	} else {
		for (i = 0; i < msg->iov_count; i++) {
			tx_iov.iov.addr = (uintptr_t) msg->msg_iov[i].iov_base;
			tx_iov.iov.len = msg->msg_iov[i].iov_len;
			sock_tx_ctx_write(tx_ctx, &tx_iov, sizeof(tx_iov));
		}
	}

	sock_tx_ctx_commit(tx_ctx);
	return 0;

err:
	sock_tx_ctx_abort(tx_ctx);
	return ret;
}

static ssize_t sock_ep_tsend(struct fid_ep *ep, const void *buf, size_t len,
			     void *desc, fi_addr_t dest_addr, uint64_t tag,
			     void *context)
{
	struct iovec msg_iov = {
		.iov_base = (void *) buf,
		.iov_len = len,
	};
	struct fi_msg_tagged msg = {
		.msg_iov = &msg_iov,
		.desc = &desc,
		.iov_count = 1,
		.addr = dest_addr,
		.tag = tag,
		.ignore = 0,
		.context = context,
		.data = 0,
	};

	return sock_ep_tsendmsg(ep, &msg, SOCK_USE_OP_FLAGS);
}

static ssize_t sock_ep_tsendv(struct fid_ep *ep, const struct iovec *iov,
			      void **desc, size_t count, fi_addr_t dest_addr,
			      uint64_t tag, void *context)
{
	struct fi_msg_tagged msg = {
		.msg_iov = iov,
		.desc = desc,
		.iov_count = count,
		.addr = dest_addr,
		.tag = tag,
		.ignore = 0,
		.context = context,
		.data = 0,
	};

	return sock_ep_tsendmsg(ep, &msg, SOCK_USE_OP_FLAGS);
}

static ssize_t sock_ep_tsenddata(struct fid_ep *ep, const void *buf, size_t len,
				 void *desc, uint64_t data, fi_addr_t dest_addr,
				 uint64_t tag, void *context)
{
	struct iovec msg_iov = {
		.iov_base = (void *) buf,
		.iov_len = len,
	};
	struct fi_msg_tagged msg = {
		.msg_iov = &msg_iov,
		.desc = desc,
		.iov_count = 1,
		.addr = dest_addr,
		.tag = tag,
		.ignore = 0,
		.context = context,
		.data = data,
	};

	return sock_ep_tsendmsg(ep, &msg, FI_REMOTE_CQ_DATA | SOCK_USE_OP_FLAGS);
}

static ssize_t sock_ep_tinject(struct fid_ep *ep, const void *buf, size_t len,
			       fi_addr_t dest_addr, uint64_t tag)
{
	struct iovec msg_iov = {
		.iov_base = (void *) buf,
		.iov_len = len,
	};
	struct fi_msg_tagged msg = {
		.msg_iov = &msg_iov,
		.desc = NULL,
		.iov_count = 1,
		.addr = dest_addr,
		.tag = tag,
		.ignore = 0,
		.context = NULL,
		.data = 0,
	};

	return sock_ep_tsendmsg(ep, &msg, FI_INJECT |
				SOCK_NO_COMPLETION | SOCK_USE_OP_FLAGS);
}

static ssize_t sock_ep_tinjectdata(struct fid_ep *ep, const void *buf,
				   size_t len, uint64_t data,
				   fi_addr_t dest_addr, uint64_t tag)
{
	struct iovec msg_iov = {
		.iov_base = (void *) buf,
		.iov_len = len,
	};
	struct fi_msg_tagged msg = {
		.msg_iov = &msg_iov,
		.desc = NULL,
		.iov_count = 1,
		.addr = dest_addr,
		.tag = tag,
		.ignore = 0,
		.context = NULL,
		.data = data,
	};

	return sock_ep_tsendmsg(ep, &msg, FI_REMOTE_CQ_DATA | FI_INJECT |
				SOCK_NO_COMPLETION | SOCK_USE_OP_FLAGS);
}

struct fi_ops_tagged sock_ep_tagged = {
	.size = sizeof(struct fi_ops_tagged),
	.recv = sock_ep_trecv,
	.recvv = sock_ep_trecvv,
	.recvmsg = sock_ep_trecvmsg,
	.send = sock_ep_tsend,
	.sendv = sock_ep_tsendv,
	.sendmsg = sock_ep_tsendmsg,
	.inject = sock_ep_tinject,
	.senddata = sock_ep_tsenddata,
	.injectdata = sock_ep_tinjectdata,
};
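/*
 * Illustrative sketch (not part of the provider): how an application would
 * exercise the tagged ops installed in sock_ep_tagged above through the
 * public libfabric API. fi_tsend() dispatches to sock_ep_tsend(), which
 * wraps the arguments in a struct fi_msg_tagged and calls
 * sock_ep_tsendmsg(). The endpoint, peer address, and tag value here are
 * hypothetical.
 */
#if 0
static ssize_t example_tagged_send(struct fid_ep *ep, fi_addr_t peer)
{
	static char buf[64] = "ping";

	/* Posts a tagged send with tag 0xa5; completion is reported
	 * through the CQ bound to the endpoint's TX context. */
	return fi_tsend(ep, buf, sizeof(buf), NULL, peer, 0xa5, NULL);
}
#endif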
ssize_t sock_ep_rma_readmsg(struct fid_ep *ep, const struct fi_msg_rma *msg,
			    uint64_t flags)
{
	int ret, i;
	struct sock_op tx_op;
	union sock_iov tx_iov;
	struct sock_conn *conn;
	struct sock_tx_ctx *tx_ctx;
	uint64_t total_len, src_len, dst_len;
	struct sock_ep *sock_ep;

	switch (ep->fid.fclass) {
	case FI_CLASS_EP:
		sock_ep = container_of(ep, struct sock_ep, ep);
		tx_ctx = sock_ep->tx_ctx;
		break;
	case FI_CLASS_TX_CTX:
		tx_ctx = container_of(ep, struct sock_tx_ctx, fid.ctx);
		sock_ep = tx_ctx->ep;
		break;
	default:
		SOCK_LOG_ERROR("Invalid EP type\n");
		return -FI_EINVAL;
	}

#if ENABLE_DEBUG
	if (msg->iov_count > SOCK_EP_MAX_IOV_LIMIT ||
	    msg->rma_iov_count > SOCK_EP_MAX_IOV_LIMIT)
		return -FI_EINVAL;
#endif
	if (!tx_ctx->enabled)
		return -FI_EOPBADSTATE;

	if (sock_ep->connected) {
		conn = sock_ep_lookup_conn(sock_ep);
	} else {
		conn = sock_av_lookup_addr(sock_ep, tx_ctx->av, msg->addr);
		if (!conn) {
			SOCK_LOG_ERROR("Address lookup failed\n");
			return -errno;
		}
	}
	if (!conn)
		return -FI_EAGAIN;

	SOCK_EP_SET_TX_OP_FLAGS(flags);
	if (flags & SOCK_USE_OP_FLAGS)
		flags |= tx_ctx->attr.op_flags;

	if (sock_ep_is_send_cq_low(&tx_ctx->comp, flags)) {
		SOCK_LOG_ERROR("CQ size low\n");
		return -FI_EAGAIN;
	}

	if (flags & FI_TRIGGER) {
		ret = sock_queue_rma_op(ep, msg, flags, SOCK_OP_READ);
		if (ret != 1)
			return ret;
	}

	total_len = sizeof(struct sock_op_send) +
		    (msg->iov_count * sizeof(union sock_iov)) +
		    (msg->rma_iov_count * sizeof(union sock_iov));

	sock_tx_ctx_start(tx_ctx);
	if (rbfdavail(&tx_ctx->rbfd) < total_len) {
		ret = -FI_EAGAIN;
		goto err;
	}

	memset(&tx_op, 0, sizeof(struct sock_op));
	tx_op.op = SOCK_OP_READ;
	tx_op.src_iov_len = msg->rma_iov_count;
	tx_op.dest_iov_len = msg->iov_count;

	sock_tx_ctx_write_op_send(tx_ctx, &tx_op, flags,
				  (uintptr_t) msg->context, msg->addr,
				  (uintptr_t) msg->msg_iov[0].iov_base,
				  sock_ep, conn);

	src_len = 0;
	for (i = 0; i < msg->rma_iov_count; i++) {
		tx_iov.iov.addr = msg->rma_iov[i].addr;
		tx_iov.iov.key = msg->rma_iov[i].key;
		tx_iov.iov.len = msg->rma_iov[i].len;
		sock_tx_ctx_write(tx_ctx, &tx_iov, sizeof(tx_iov));
		src_len += tx_iov.iov.len;
	}

	dst_len = 0;
	for (i = 0; i < msg->iov_count; i++) {
		tx_iov.iov.addr = (uintptr_t) msg->msg_iov[i].iov_base;
		tx_iov.iov.len = msg->msg_iov[i].iov_len;
		sock_tx_ctx_write(tx_ctx, &tx_iov, sizeof(tx_iov));
		dst_len += tx_iov.iov.len;
	}

#if ENABLE_DEBUG
	if (dst_len != src_len) {
		SOCK_LOG_ERROR("Buffer length mismatch\n");
		ret = -FI_EINVAL;
		goto err;
	}
#endif

	sock_tx_ctx_commit(tx_ctx);
	return 0;

err:
	sock_tx_ctx_abort(tx_ctx);
	return ret;
}
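/*
 * The read TX entry above mirrors the atomic layout but is simpler: a
 * struct sock_op_send header, then rma_iov_count source descriptors
 * (remote address/key/len), then iov_count destination descriptors (local
 * address/len). Note that src_iov_len and dest_iov_len are deliberately
 * swapped relative to a send: for a read, the remote rma_iov is the data
 * source and the local msg_iov is the destination, and the debug check
 * requires both spans to cover the same number of bytes.
 */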
ssize_t sock_ep_tsendmsg(struct fid_ep *ep, const struct fi_msg_tagged *msg,
			 uint64_t flags)
{
	int ret, i;
	uint64_t total_len;
	struct sock_op tx_op;
	union sock_iov tx_iov;
	struct sock_conn *conn;
	struct sock_tx_ctx *tx_ctx;
	struct sock_ep *sock_ep;

	switch (ep->fid.fclass) {
	case FI_CLASS_EP:
		sock_ep = container_of(ep, struct sock_ep, ep);
		tx_ctx = sock_ep->tx_ctx;
		break;
	case FI_CLASS_TX_CTX:
		tx_ctx = container_of(ep, struct sock_tx_ctx, fid.ctx);
		sock_ep = tx_ctx->ep;
		break;
	default:
		SOCK_LOG_ERROR("Invalid EP type\n");
		return -FI_EINVAL;
	}

#if ENABLE_DEBUG
	if (msg->iov_count > SOCK_EP_MAX_IOV_LIMIT)
		return -FI_EINVAL;
#endif
	if (!tx_ctx->enabled)
		return -FI_EOPBADSTATE;

	if (sock_drop_packet(sock_ep))
		return 0;

	ret = sock_ep_get_conn(sock_ep, tx_ctx, msg->addr, &conn);
	if (ret)
		return ret;

	SOCK_EP_SET_TX_OP_FLAGS(flags);
	if (flags & SOCK_USE_OP_FLAGS)
		flags |= tx_ctx->attr.op_flags;

	if (sock_ep_is_send_cq_low(&tx_ctx->comp, flags)) {
		SOCK_LOG_ERROR("CQ size low\n");
		return -FI_EAGAIN;
	}

	if (flags & FI_TRIGGER) {
		ret = sock_queue_tmsg_op(ep, msg, flags, SOCK_OP_TSEND);
		if (ret != 1)
			return ret;
	}

	memset(&tx_op, 0, sizeof(tx_op));
	tx_op.op = SOCK_OP_TSEND;

	total_len = 0;
	if (flags & FI_INJECT) {
		for (i = 0; i < msg->iov_count; i++)
			total_len += msg->msg_iov[i].iov_len;
		tx_op.src_iov_len = total_len;
		/* Return directly here: the TX ring has not been reserved
		 * yet, so aborting it would release an unheld lock. */
		if (total_len > SOCK_EP_MAX_INJECT_SZ)
			return -FI_EINVAL;
	} else {
		total_len = msg->iov_count * sizeof(union sock_iov);
		tx_op.src_iov_len = msg->iov_count;
	}

	total_len += sizeof(struct sock_op_tsend);
	if (flags & FI_REMOTE_CQ_DATA)
		total_len += sizeof(uint64_t);

	sock_tx_ctx_start(tx_ctx);
	if (rbavail(&tx_ctx->rb) < total_len) {
		ret = -FI_EAGAIN;
		goto err;
	}

	sock_tx_ctx_write_op_tsend(tx_ctx, &tx_op, flags,
				   (uintptr_t) msg->context, msg->addr,
				   (uintptr_t) msg->msg_iov[0].iov_base,
				   sock_ep, conn, msg->tag);

	if (flags & FI_REMOTE_CQ_DATA)
		sock_tx_ctx_write(tx_ctx, &msg->data, sizeof(msg->data));

	if (flags & FI_INJECT) {
		for (i = 0; i < msg->iov_count; i++) {
			sock_tx_ctx_write(tx_ctx, msg->msg_iov[i].iov_base,
					  msg->msg_iov[i].iov_len);
		}
	} else {
		for (i = 0; i < msg->iov_count; i++) {
			tx_iov.iov.addr = (uintptr_t) msg->msg_iov[i].iov_base;
			tx_iov.iov.len = msg->msg_iov[i].iov_len;
			sock_tx_ctx_write(tx_ctx, &tx_iov, sizeof(tx_iov));
		}
	}

	sock_tx_ctx_commit(tx_ctx);
	return 0;

err:
	sock_tx_ctx_abort(tx_ctx);
	return ret;
}
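/*
 * Worked example of the space check in sock_ep_tsendmsg above, assuming a
 * hypothetical non-inject send of three iovecs with remote CQ data:
 *
 *   total_len = 3 * sizeof(union sock_iov)    iov descriptors
 *             + sizeof(struct sock_op_tsend)  header (carries the tag)
 *             + sizeof(uint64_t)              FI_REMOTE_CQ_DATA immediate
 *
 * If the ring reports fewer free bytes than total_len, the reservation is
 * aborted and the call returns -FI_EAGAIN so the application can retry.
 */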