static int psmx_atomic_readwritevalid(struct fid_ep *ep, enum fi_datatype datatype, enum fi_op op, size_t *count) { int chunk_size; if (datatype < 0 || datatype >= FI_DATATYPE_LAST) return -FI_EOPNOTSUPP; switch (op) { case FI_MIN: case FI_MAX: case FI_SUM: case FI_PROD: case FI_LOR: case FI_LAND: case FI_BOR: case FI_BAND: case FI_LXOR: case FI_BXOR: case FI_ATOMIC_READ: case FI_ATOMIC_WRITE: break; default: return -FI_EOPNOTSUPP; } if (count) { chunk_size = MIN(PSMX_AM_CHUNK_SIZE, psmx_am_param.max_request_short); *count = chunk_size / fi_datatype_size(datatype); } return 0; }
static int sock_ep_atomic_valid(struct fid_ep *ep, enum fi_datatype datatype, enum fi_op op, size_t *count) { size_t datatype_sz; switch (datatype) { case FI_FLOAT: case FI_DOUBLE: case FI_LONG_DOUBLE: if (op == FI_BOR || op == FI_BAND || op == FI_BXOR || op == FI_MSWAP) return -FI_ENOENT; break; case FI_FLOAT_COMPLEX: case FI_DOUBLE_COMPLEX: case FI_LONG_DOUBLE_COMPLEX: if (op == FI_BOR || op == FI_BAND || op == FI_BXOR || op == FI_MSWAP || op == FI_MIN || op == FI_MAX || op == FI_CSWAP_LE || op == FI_CSWAP_LT || op == FI_CSWAP_GE || op == FI_CSWAP_GT) return -FI_ENOENT; break; default: break; } datatype_sz = fi_datatype_size(datatype); if (datatype_sz == 0) return -FI_ENOENT; *count = (SOCK_EP_MAX_ATOMIC_SZ/datatype_sz); return 0; }
static int psmx_atomic_compwritevalid(struct fid_ep *ep, enum fi_datatype datatype, enum fi_op op, size_t *count) { int chunk_size; if (datatype < 0 || datatype >= FI_DATATYPE_LAST) return -FI_EOPNOTSUPP; switch (op) { case FI_CSWAP: case FI_CSWAP_NE: break; case FI_CSWAP_LE: case FI_CSWAP_LT: case FI_CSWAP_GE: case FI_CSWAP_GT: if (datatype == FI_FLOAT_COMPLEX || datatype == FI_DOUBLE_COMPLEX || datatype == FI_LONG_DOUBLE_COMPLEX) return -FI_EOPNOTSUPP; break; case FI_MSWAP: if (datatype == FI_FLOAT_COMPLEX || datatype == FI_DOUBLE_COMPLEX || datatype == FI_LONG_DOUBLE_COMPLEX || datatype == FI_FLOAT || datatype == FI_DOUBLE || datatype == FI_LONG_DOUBLE) return -FI_EOPNOTSUPP; break; default: return -FI_EOPNOTSUPP; } if (count) { chunk_size = MIN(PSMX_AM_CHUNK_SIZE, psmx_am_param.max_request_short); *count = chunk_size / (2 * fi_datatype_size(datatype)); } return 0; }
/*
 * TX-descriptor completion callback for AMO requests.
 *
 * On transport failure the error path is delegated to
 * __gnix_amo_post_err().  On success, FI_ATOMIC_READ results are copied
 * from the result area into the operand buffer (the NIC only wrote the
 * result buffer), then the user completion is generated and the request
 * and descriptor are released.
 */
static int __gnix_amo_txd_complete(void *arg, gni_return_t tx_status)
{
	struct gnix_tx_descriptor *txd = (struct gnix_tx_descriptor *)arg;
	struct gnix_fab_req *req = txd->req;
	int ret = FI_SUCCESS;
	uint32_t val32;
	uint64_t val64;

	if (tx_status != GNI_RC_SUCCESS)
		return __gnix_amo_post_err(txd);

	/* FI_ATOMIC_READ data is delivered to operand buffer in addition to
	 * the results buffer. */
	if (req->amo.op == FI_ATOMIC_READ) {
		switch (fi_datatype_size(req->amo.datatype)) {
		case sizeof(uint32_t):
			val32 = *(uint32_t *)req->amo.loc_addr;
			*(uint32_t *)req->amo.read_buf = val32;
			break;
		case sizeof(uint64_t):
			val64 = *(uint64_t *)req->amo.loc_addr;
			*(uint64_t *)req->amo.read_buf = val64;
			break;
		default:
			GNIX_WARN(FI_LOG_EP_DATA,
				  "Invalid datatype: %d\n",
				  req->amo.datatype);
			assert(0);
			break;
		}
	}

	/* complete request */
	ret = __gnix_amo_send_completion(req->vc->ep, req);
	if (ret != FI_SUCCESS)
		GNIX_WARN(FI_LOG_EP_DATA,
			  "__gnix_amo_send_completion() failed: %d\n",
			  ret);

	__gnix_amo_fr_complete(req, txd);

	return FI_SUCCESS;
}
/*
 * Common send path for sockets-provider atomic operations (write, fetch,
 * compare).  Validates the request, reserves ring-buffer space, then
 * streams the op header followed by operand / RMA / result / compare
 * iovecs into the TX ring.  The entry is staged with sock_tx_ctx_start()
 * and either committed whole or aborted, so the ring never sees a partial
 * message.  Returns 0 on success or a negative fi_errno value.
 */
ssize_t sock_ep_tx_atomic(struct fid_ep *ep,
			  const struct fi_msg_atomic *msg,
			  const struct fi_ioc *comparev, void **compare_desc,
			  size_t compare_count, struct fi_ioc *resultv,
			  void **result_desc, size_t result_count,
			  uint64_t flags)
{
	int i, ret;
	size_t datatype_sz;
	struct sock_op tx_op;
	union sock_iov tx_iov;
	struct sock_conn *conn;
	struct sock_tx_ctx *tx_ctx;
	uint64_t total_len, src_len, dst_len;
	struct sock_ep *sock_ep;

	/* Resolve the TX context from either a plain EP or a TX-context fid. */
	switch (ep->fid.fclass) {
	case FI_CLASS_EP:
		sock_ep = container_of(ep, struct sock_ep, ep);
		tx_ctx = sock_ep->tx_ctx;
		break;
	case FI_CLASS_TX_CTX:
		tx_ctx = container_of(ep, struct sock_tx_ctx, fid.ctx);
		sock_ep = tx_ctx->ep;
		break;
	default:
		SOCK_LOG_ERROR("Invalid EP type\n");
		return -FI_EINVAL;
	}

	if (msg->iov_count > SOCK_EP_MAX_IOV_LIMIT ||
	    msg->rma_iov_count > SOCK_EP_MAX_IOV_LIMIT)
		return -FI_EINVAL;

	if (!tx_ctx->enabled)
		return -FI_EOPBADSTATE;

	/* Locate the connection for the destination. */
	if (sock_ep->connected) {
		conn = sock_ep_lookup_conn(sock_ep);
	} else {
		conn = sock_av_lookup_addr(sock_ep, tx_ctx->av, msg->addr);
		if (!conn) {
			SOCK_LOG_ERROR("Address lookup failed\n");
			return -errno;
		}
	}
	if (!conn)
		return -FI_EAGAIN;

	SOCK_EP_SET_TX_OP_FLAGS(flags);
	if (flags & SOCK_USE_OP_FLAGS)
		flags |= tx_ctx->attr.op_flags;

	/* FI_ATOMIC_READ carries no operand data, so inject cannot apply. */
	if (msg->op == FI_ATOMIC_READ) {
		flags &= ~FI_INJECT;
	}

	if (sock_ep_is_send_cq_low(&tx_ctx->comp, flags)) {
		SOCK_LOG_ERROR("CQ size low\n");
		return -FI_EAGAIN;
	}

	/* Triggered op: queue it; a return of 1 means "execute now". */
	if (flags & FI_TRIGGER) {
		ret = sock_queue_atomic_op(ep, msg, comparev, compare_count,
					   resultv, result_count, flags,
					   SOCK_OP_ATOMIC);
		if (ret != 1)
			return ret;
	}

	src_len = 0;
	datatype_sz = fi_datatype_size(msg->datatype);
	if (flags & FI_INJECT) {
		/* Inlined data is copied into the ring itself. */
		for (i = 0; i < msg->iov_count; i++)
			src_len += (msg->msg_iov[i].count * datatype_sz);
		if (src_len > SOCK_EP_MAX_INJECT_SZ)
			return -FI_EINVAL;
		total_len = src_len;
	} else {
		total_len = msg->iov_count * sizeof(union sock_iov);
	}

	total_len += (sizeof(struct sock_op_send) +
		      (msg->rma_iov_count * sizeof(union sock_iov)) +
		      (result_count * sizeof(union sock_iov)));

	sock_tx_ctx_start(tx_ctx);
	if (rbfdavail(&tx_ctx->rbfd) < total_len) {
		ret = -FI_EAGAIN;
		goto err;
	}

	memset(&tx_op, 0, sizeof(tx_op));
	tx_op.op = SOCK_OP_ATOMIC;
	tx_op.dest_iov_len = msg->rma_iov_count;
	tx_op.atomic.op = msg->op;
	tx_op.atomic.datatype = msg->datatype;
	tx_op.atomic.res_iov_len = result_count;
	tx_op.atomic.cmp_iov_len = compare_count;

	/* For inject, src_iov_len holds a byte count, not an iov count. */
	if (flags & FI_INJECT)
		tx_op.src_iov_len = src_len;
	else
		tx_op.src_iov_len = msg->iov_count;

	sock_tx_ctx_write_op_send(tx_ctx, &tx_op, flags,
				  (uintptr_t) msg->context, msg->addr,
				  (uintptr_t) msg->msg_iov[0].addr,
				  sock_ep, conn);

	if (flags & FI_REMOTE_CQ_DATA)
		sock_tx_ctx_write(tx_ctx, &msg->data, sizeof(uint64_t));

	/* Operand data: either inlined bytes or iovec descriptors. */
	src_len = 0;
	if (flags & FI_INJECT) {
		for (i = 0; i < msg->iov_count; i++) {
			sock_tx_ctx_write(tx_ctx, msg->msg_iov[i].addr,
					  msg->msg_iov[i].count * datatype_sz);
			src_len += (msg->msg_iov[i].count * datatype_sz);
		}
	} else {
		for (i = 0; i < msg->iov_count; i++) {
			tx_iov.ioc.addr = (uintptr_t) msg->msg_iov[i].addr;
			tx_iov.ioc.count = msg->msg_iov[i].count;
			sock_tx_ctx_write(tx_ctx, &tx_iov, sizeof(tx_iov));
			src_len += (tx_iov.ioc.count * datatype_sz);
		}
	}

#ifdef ENABLE_DEBUG
	if (src_len > SOCK_EP_MAX_ATOMIC_SZ) {
		ret = -FI_EINVAL;
		goto err;
	}
#endif

	/* Destination (RMA) iovecs; total bytes must match the operand. */
	dst_len = 0;
	for (i = 0; i < msg->rma_iov_count; i++) {
		tx_iov.ioc.addr = msg->rma_iov[i].addr;
		tx_iov.ioc.key = msg->rma_iov[i].key;
		tx_iov.ioc.count = msg->rma_iov[i].count;
		sock_tx_ctx_write(tx_ctx, &tx_iov, sizeof(tx_iov));
		dst_len += (tx_iov.ioc.count * datatype_sz);
	}

	if (msg->iov_count && dst_len != src_len) {
		SOCK_LOG_ERROR("Buffer length mismatch\n");
		ret = -FI_EINVAL;
		goto err;
	} else {
		/* No operand iovecs (e.g. atomic read): use dst as the
		 * reference length for the checks below. */
		src_len = dst_len;
	}

	/* Result iovecs. */
	dst_len = 0;
	for (i = 0; i < result_count; i++) {
		tx_iov.ioc.addr = (uintptr_t) resultv[i].addr;
		tx_iov.ioc.count = resultv[i].count;
		sock_tx_ctx_write(tx_ctx, &tx_iov, sizeof(tx_iov));
		dst_len += (tx_iov.ioc.count * datatype_sz);
	}

#ifdef ENABLE_DEBUG
	if (result_count && (dst_len != src_len)) {
		SOCK_LOG_ERROR("Buffer length mismatch\n");
		ret = -FI_EINVAL;
		goto err;
	}
#endif

	/* Compare iovecs. */
	dst_len = 0;
	for (i = 0; i < compare_count; i++) {
		tx_iov.ioc.addr = (uintptr_t) comparev[i].addr;
		tx_iov.ioc.count = comparev[i].count;
		sock_tx_ctx_write(tx_ctx, &tx_iov, sizeof(tx_iov));
		dst_len += (tx_iov.ioc.count * datatype_sz);
	}

#ifdef ENABLE_DEBUG
	if (compare_count && (dst_len != src_len)) {
		SOCK_LOG_ERROR("Buffer length mismatch\n");
		ret = -FI_EINVAL;
		goto err;
	}
#endif

	sock_tx_ctx_commit(tx_ctx);
	return 0;

err:
	sock_tx_ctx_abort(tx_ctx);
	return ret;
}
ssize_t _gnix_atomic(struct gnix_fid_ep *ep, enum gnix_fab_req_type fr_type, const struct fi_msg_atomic *msg, const struct fi_ioc *comparev, void **compare_desc, size_t compare_count, struct fi_ioc *resultv, void **result_desc, size_t result_count, uint64_t flags) { struct gnix_vc *vc; struct gnix_fab_req *req; struct gnix_fid_mem_desc *md = NULL; int rc, len; struct fid_mr *auto_mr = NULL; void *mdesc = NULL; uint64_t compare_operand = 0; void *loc_addr = NULL; int dt_len, dt_align; if (!ep || !msg || !msg->msg_iov || !msg->msg_iov[0].addr || msg->msg_iov[0].count != 1 || msg->iov_count != 1 || !msg->rma_iov || !msg->rma_iov[0].addr) return -FI_EINVAL; if (fr_type == GNIX_FAB_RQ_CAMO) { if (!comparev || !comparev[0].addr || compare_count != 1) return -FI_EINVAL; compare_operand = *(uint64_t *)comparev[0].addr; } dt_len = fi_datatype_size(msg->datatype); dt_align = dt_len - 1; len = dt_len * msg->msg_iov->count; if (msg->rma_iov->addr & dt_align) { GNIX_INFO(FI_LOG_EP_DATA, "Invalid target alignment: %d (mask 0x%x)\n", msg->rma_iov->addr, dt_align); return -FI_EINVAL; } /* need a memory descriptor for all fetching and comparison AMOs */ if (fr_type == GNIX_FAB_RQ_FAMO || fr_type == GNIX_FAB_RQ_CAMO) { if (!resultv || !resultv[0].addr || result_count != 1) return -FI_EINVAL; loc_addr = resultv[0].addr; if ((uint64_t)loc_addr & dt_align) { GNIX_INFO(FI_LOG_EP_DATA, "Invalid source alignment: %d (mask 0x%x)\n", loc_addr, dt_align); return -FI_EINVAL; } if (!result_desc || !result_desc[0]) { rc = gnix_mr_reg(&ep->domain->domain_fid.fid, loc_addr, len, FI_READ | FI_WRITE, 0, 0, 0, &auto_mr, NULL); if (rc != FI_SUCCESS) { GNIX_INFO(FI_LOG_EP_DATA, "Failed to auto-register local buffer: %d\n", rc); return rc; } flags |= FI_LOCAL_MR; mdesc = (void *)auto_mr; GNIX_INFO(FI_LOG_EP_DATA, "auto-reg MR: %p\n", auto_mr); } else { mdesc = result_desc[0]; } } /* find VC for target */ rc = _gnix_ep_get_vc(ep, msg->addr, &vc); if (rc) { GNIX_INFO(FI_LOG_EP_DATA, "_gnix_ep_get_vc() 
failed, addr: %lx, rc:\n", msg->addr, rc); goto err_get_vc; } /* setup fabric request */ req = _gnix_fr_alloc(ep); if (!req) { GNIX_INFO(FI_LOG_EP_DATA, "_gnix_fr_alloc() failed\n"); rc = -FI_ENOSPC; goto err_fr_alloc; } req->type = fr_type; req->gnix_ep = ep; req->vc = vc; req->user_context = msg->context; req->work_fn = _gnix_amo_post_req; if (mdesc) { md = container_of(mdesc, struct gnix_fid_mem_desc, mr_fid); } req->amo.loc_md = (void *)md; req->amo.loc_addr = (uint64_t)loc_addr; if (msg->op == FI_ATOMIC_READ) { /* Atomic reads are the only AMO which write to the operand * buffer. It's assumed that this is in addition to writing * fetched data to the result buffer. Make the NIC write to * the result buffer, like all other AMOS, and copy read data * to the operand buffer after the completion is received. */ req->amo.first_operand = 0xFFFFFFFFFFFFFFFF; /* operand to FAND */ req->amo.read_buf = msg->msg_iov[0].addr; } else if (msg->op == FI_CSWAP) { req->amo.first_operand = compare_operand; req->amo.second_operand = *(uint64_t *)msg->msg_iov[0].addr; } else if (msg->op == FI_MSWAP) { req->amo.first_operand = ~compare_operand; req->amo.second_operand = *(uint64_t *)msg->msg_iov[0].addr; req->amo.second_operand &= compare_operand; } else { req->amo.first_operand = *(uint64_t *)msg->msg_iov[0].addr; } req->amo.rem_addr = msg->rma_iov->addr; req->amo.rem_mr_key = msg->rma_iov->key; req->amo.len = len; req->amo.imm = msg->data; req->amo.datatype = msg->datatype; req->amo.op = msg->op; req->flags = flags; /* Inject interfaces always suppress completions. If * SELECTIVE_COMPLETION is set, honor any setting. Otherwise, always * deliver a completion. */ if ((flags & GNIX_SUPPRESS_COMPLETION) || (ep->send_selective_completion && !(flags & FI_COMPLETION))) { req->flags &= ~FI_COMPLETION; } else { req->flags |= FI_COMPLETION; } return _gnix_vc_queue_tx_req(req); err_fr_alloc: err_get_vc: if (auto_mr) { fi_close(&auto_mr->fid); } return rc; }
/*
 * Fetch-atomic (read-write) implementation over PSM active messages.
 * Handles triggered deferral, AV-table address resolution, loopback
 * short-circuit, FI_INJECT operand staging, and finally issues the AM
 * request.  Returns 0 on successful submission or a negative fi_errno.
 */
ssize_t _psmx_atomic_readwrite(struct fid_ep *ep,
			       const void *buf,
			       size_t count, void *desc,
			       void *result, void *result_desc,
			       fi_addr_t dest_addr,
			       uint64_t addr, uint64_t key,
			       enum fi_datatype datatype,
			       enum fi_op op, void *context,
			       uint64_t flags)
{
	struct psmx_fid_ep *ep_priv;
	struct psmx_fid_av *av;
	struct psmx_epaddr_context *epaddr_context;
	struct psmx_am_request *req;
	psm_amarg_t args[8];
	int am_flags = PSM_AM_FLAG_ASYNC;
	int chunk_size, len;
	size_t idx;

	ep_priv = container_of(ep, struct psmx_fid_ep, ep);

	/* FI_TRIGGER: capture the call and defer it until the counter
	 * reaches its threshold. */
	if (flags & FI_TRIGGER) {
		struct psmx_trigger *trigger;
		struct fi_triggered_context *ctxt = context;

		trigger = calloc(1, sizeof(*trigger));
		if (!trigger)
			return -FI_ENOMEM;

		trigger->op = PSMX_TRIGGERED_ATOMIC_READWRITE;
		trigger->cntr = container_of(ctxt->trigger.threshold.cntr,
					     struct psmx_fid_cntr, cntr);
		trigger->threshold = ctxt->trigger.threshold.threshold;
		trigger->atomic_readwrite.ep = ep;
		trigger->atomic_readwrite.buf = buf;
		trigger->atomic_readwrite.count = count;
		trigger->atomic_readwrite.desc = desc;
		trigger->atomic_readwrite.result = result;
		trigger->atomic_readwrite.result_desc = result_desc;
		trigger->atomic_readwrite.dest_addr = dest_addr;
		trigger->atomic_readwrite.addr = addr;
		trigger->atomic_readwrite.key = key;
		trigger->atomic_readwrite.datatype = datatype;
		trigger->atomic_readwrite.atomic_op = op;
		trigger->atomic_readwrite.context = context;
		trigger->atomic_readwrite.flags = flags & ~FI_TRIGGER;

		psmx_cntr_add_trigger(trigger->cntr, trigger);
		return 0;
	}

	/* A NULL operand buffer is only legal for atomic read. */
	if (!buf && op != FI_ATOMIC_READ)
		return -FI_EINVAL;

	if (datatype < 0 || datatype >= FI_DATATYPE_LAST)
		return -FI_EINVAL;

	if (op < 0 || op >= FI_ATOMIC_OP_LAST)
		return -FI_EINVAL;

	/* Resolve an FI_AV_TABLE index into a PSM epaddr. */
	av = ep_priv->av;
	if (av && av->type == FI_AV_TABLE) {
		idx = dest_addr;
		if (idx >= av->last)
			return -FI_EINVAL;

		dest_addr = (fi_addr_t) av->psm_epaddrs[idx];
	} else if (!dest_addr) {
		return -FI_EINVAL;
	}

	/* Loopback: execute locally without going through PSM. */
	epaddr_context = psm_epaddr_getctxt((void *)dest_addr);
	if (epaddr_context->epid == ep_priv->domain->psm_epid)
		return psmx_atomic_self(PSMX_AM_REQ_ATOMIC_READWRITE,
					ep_priv, buf, count, desc,
					NULL, NULL, result, result_desc,
					addr, key, datatype, op,
					context, flags);

	/* The whole operand must fit in one short AM request. */
	chunk_size = MIN(PSMX_AM_CHUNK_SIZE, psmx_am_param.max_request_short);
	len = fi_datatype_size(datatype) * count;
	if (len > chunk_size)
		return -FI_EMSGSIZE;

	if ((flags & FI_INJECT) && op != FI_ATOMIC_READ) {
		/* Copy the operand into the request so the caller may reuse
		 * its buffer immediately. */
		req = malloc(sizeof(*req) + len);
		if (!req)
			return -FI_ENOMEM;
		memset((void *)req, 0, sizeof(*req));
		memcpy((void *)req+sizeof(*req), (void *)buf, len);
		buf = (void *)req + sizeof(*req);
	} else {
		req = calloc(1, sizeof(*req));
		if (!req)
			return -FI_ENOMEM;
	}

	req->no_event = (flags & PSMX_NO_COMPLETION) ||
			(ep_priv->send_selective_completion &&
			 !(flags & FI_COMPLETION));

	req->op = PSMX_AM_REQ_ATOMIC_READWRITE;
	req->atomic.buf = (void *)buf;
	req->atomic.len = len;
	req->atomic.addr = addr;
	req->atomic.key = key;
	req->atomic.context = context;
	req->atomic.result = result;
	req->ep = ep_priv;

	if (op == FI_ATOMIC_READ)
		req->cq_flags = FI_READ | FI_ATOMIC;
	else
		req->cq_flags = FI_WRITE | FI_ATOMIC;

	/* AM argument layout: [0]=cmd/count, [1]=req ptr, [2]=addr,
	 * [3]=key, [4]=datatype/op. */
	args[0].u32w0 = PSMX_AM_REQ_ATOMIC_READWRITE;
	args[0].u32w1 = count;
	args[1].u64 = (uint64_t)(uintptr_t)req;
	args[2].u64 = addr;
	args[3].u64 = key;
	args[4].u32w0 = datatype;
	args[4].u32w1 = op;

	/* NOTE(review): the return value of psm_am_request_short() is
	 * ignored; a submission failure would leak req — confirm intended. */
	psm_am_request_short((psm_epaddr_t) dest_addr,
			     PSMX_AM_ATOMIC_HANDLER, args, 5,
			     (void *)buf, (buf?len:0), am_flags, NULL, NULL);

	return 0;
}
/*
 * Execute an atomic operation locally when the destination is the local
 * endpoint (loopback), bypassing the PSM AM path.  Validates the target
 * MR, performs the op, updates remote-side counters, then generates the
 * initiator's CQ event and counters.  Returns 0 or a negative fi_errno.
 */
static int psmx_atomic_self(int am_cmd,
			    struct psmx_fid_ep *ep,
			    const void *buf,
			    size_t count, void *desc,
			    const void *compare, void *compare_desc,
			    void *result, void *result_desc,
			    uint64_t addr, uint64_t key,
			    enum fi_datatype datatype,
			    enum fi_op op, void *context,
			    uint64_t flags)
{
	struct psmx_fid_mr *mr;
	struct psmx_cq_event *event;
	struct psmx_fid_ep *target_ep;
	struct psmx_fid_cntr *cntr = NULL;
	struct psmx_fid_cntr *mr_cntr = NULL;
	void *tmp_buf;
	size_t len;
	int no_event;
	int err = 0;
	int op_error;
	int access;
	uint64_t cq_flags = 0;

	/* Plain writes need only remote-write access; fetch/compare need
	 * read access too. */
	if (am_cmd == PSMX_AM_REQ_ATOMIC_WRITE)
		access = FI_REMOTE_WRITE;
	else
		access = FI_REMOTE_READ | FI_REMOTE_WRITE;

	len = fi_datatype_size(datatype) * count;
	mr = psmx_mr_get(psmx_active_fabric->active_domain, key);
	op_error = mr ? psmx_mr_validate(mr, addr, len, access) : -FI_EINVAL;

	if (op_error)
		goto gen_local_event;

	addr += mr->offset;

	switch (am_cmd) {
	case PSMX_AM_REQ_ATOMIC_WRITE:
		err = psmx_atomic_do_write((void *)addr, (void *)buf,
					   (int)datatype, (int)op,
					   (int)count);
		cq_flags = FI_WRITE | FI_ATOMIC;
		break;

	case PSMX_AM_REQ_ATOMIC_READWRITE:
		/* If the result buffer aliases the operand buffer, stage
		 * the fetch through a bounce buffer. */
		if (result != buf) {
			err = psmx_atomic_do_readwrite((void *)addr,
						       (void *)buf,
						       (void *)result,
						       (int)datatype,
						       (int)op, (int)count);
		} else {
			tmp_buf = malloc(len);
			if (tmp_buf) {
				memcpy(tmp_buf, result, len);
				err = psmx_atomic_do_readwrite((void *)addr,
							       (void *)buf,
							       tmp_buf,
							       (int)datatype,
							       (int)op,
							       (int)count);
				memcpy(result, tmp_buf, len);
				free(tmp_buf);
			} else {
				err = -FI_ENOMEM;
			}
		}
		if (op == FI_ATOMIC_READ)
			cq_flags = FI_READ | FI_ATOMIC;
		else
			cq_flags = FI_WRITE | FI_ATOMIC;
		break;

	case PSMX_AM_REQ_ATOMIC_COMPWRITE:
		/* Same bounce-buffer trick when result aliases either the
		 * operand or the compare buffer. */
		if (result != buf && result != compare) {
			err = psmx_atomic_do_compwrite((void *)addr,
						       (void *)buf,
						       (void *)compare,
						       (void *)result,
						       (int)datatype,
						       (int)op, (int)count);
		} else {
			tmp_buf = malloc(len);
			if (tmp_buf) {
				memcpy(tmp_buf, result, len);
				err = psmx_atomic_do_compwrite((void *)addr,
							       (void *)buf,
							       (void *)compare,
							       tmp_buf,
							       (int)datatype,
							       (int)op,
							       (int)count);
				memcpy(result, tmp_buf, len);
				free(tmp_buf);
			} else {
				err = -FI_ENOMEM;
			}
		}
		cq_flags = FI_WRITE | FI_ATOMIC;
		break;
	}

	/* Target-side (remote) counter updates. */
	target_ep = mr->domain->atomics_ep;
	if (op == FI_ATOMIC_READ) {
		cntr = target_ep->remote_read_cntr;
	} else {
		cntr = target_ep->remote_write_cntr;
		mr_cntr = mr->cntr;
	}

	if (cntr)
		psmx_cntr_inc(cntr);

	if (mr_cntr && mr_cntr != cntr)
		psmx_cntr_inc(mr_cntr);

gen_local_event:
	/* Initiator-side completion: CQ event unless suppressed. */
	no_event = ((flags & PSMX_NO_COMPLETION) ||
		    (ep->send_selective_completion &&
		     !(flags & FI_COMPLETION)));
	if (ep->send_cq && !no_event) {
		event = psmx_cq_create_event(
				ep->send_cq,
				context,
				(void *)buf,
				cq_flags,
				len,
				0, /* data */
				0, /* tag */
				0, /* olen */
				op_error);
		if (event)
			psmx_cq_enqueue_event(ep->send_cq, event);
		else
			err = -FI_ENOMEM;
	}

	/* Initiator-side counters. */
	switch (am_cmd) {
	case PSMX_AM_REQ_ATOMIC_WRITE:
		if (ep->write_cntr)
			psmx_cntr_inc(ep->write_cntr);
		break;

	case PSMX_AM_REQ_ATOMIC_READWRITE:
	case PSMX_AM_REQ_ATOMIC_COMPWRITE:
		if (ep->read_cntr)
			psmx_cntr_inc(ep->read_cntr);
		break;
	}

	return err;
}
/*
 * Active-message handler for all PSMX atomic traffic.  Dispatches on the
 * command encoded in args[0].u32w0: REQ_* cases run the atomic on the
 * target and reply with status (and fetched data for readwrite /
 * compwrite); REP_* cases complete the initiator's request.
 *
 * NOTE(review): the matching "#if" for the "#endif" below (an alternate
 * prototype without the epaddr parameter for PSM2 builds) precedes this
 * chunk and is not visible here.
 */
int psmx_am_atomic_handler(psm_am_token_t token, psm_epaddr_t epaddr,
			   psm_amarg_t *args, int nargs, void *src,
			   uint32_t len)
#endif
{
	psm_amarg_t rep_args[8];
	int count;
	void *addr;
	uint64_t key;
	int datatype, op;
	int err = 0;
	int op_error = 0;
	struct psmx_am_request *req;
	struct psmx_cq_event *event;
	struct psmx_fid_mr *mr;
	struct psmx_fid_ep *target_ep;
	struct psmx_fid_cntr *cntr = NULL;
	struct psmx_fid_cntr *mr_cntr = NULL;
	void *tmp_buf;

#if (PSM_VERNO_MAJOR >= 2)
	/* PSM2: source epaddr is not a parameter; query it from the token. */
	psm_epaddr_t epaddr;

	psm_am_get_source(token, &epaddr);
#endif

	switch (args[0].u32w0 & PSMX_AM_OP_MASK) {
	case PSMX_AM_REQ_ATOMIC_WRITE:
		/* Unpack: [0]=cmd/count, [2]=addr, [3]=key, [4]=dt/op. */
		count = args[0].u32w1;
		addr = (void *)(uintptr_t)args[2].u64;
		key = args[3].u64;
		datatype = args[4].u32w0;
		op = args[4].u32w1;
		assert(len == fi_datatype_size(datatype) * count);
		mr = psmx_mr_get(psmx_active_fabric->active_domain, key);
		op_error = mr ?
			psmx_mr_validate(mr, (uint64_t)addr, len,
					 FI_REMOTE_WRITE) :
			-FI_EINVAL;
		if (!op_error) {
			addr += mr->offset;
			psmx_atomic_do_write(addr, src, datatype, op, count);

			/* Update target-side counters. */
			target_ep = mr->domain->atomics_ep;
			cntr = target_ep->remote_write_cntr;
			mr_cntr = mr->cntr;

			if (cntr)
				psmx_cntr_inc(cntr);

			if (mr_cntr && mr_cntr != cntr)
				psmx_cntr_inc(mr_cntr);
		}

		rep_args[0].u32w0 = PSMX_AM_REP_ATOMIC_WRITE;
		rep_args[0].u32w1 = op_error;
		rep_args[1].u64 = args[1].u64;
		err = psm_am_reply_short(token, PSMX_AM_ATOMIC_HANDLER,
					 rep_args, 2, NULL, 0, 0,
					 NULL, NULL );
		break;

	case PSMX_AM_REQ_ATOMIC_READWRITE:
		count = args[0].u32w1;
		addr = (void *)(uintptr_t)args[2].u64;
		key = args[3].u64;
		datatype = args[4].u32w0;
		op = args[4].u32w1;

		/* Atomic read carries no payload; derive len locally. */
		if (op == FI_ATOMIC_READ)
			len = fi_datatype_size(datatype) * count;

		assert(len == fi_datatype_size(datatype) * count);
		mr = psmx_mr_get(psmx_active_fabric->active_domain, key);
		op_error = mr ?
			psmx_mr_validate(mr, (uint64_t)addr, len,
					 FI_REMOTE_READ|FI_REMOTE_WRITE) :
			-FI_EINVAL;
		if (!op_error) {
			addr += mr->offset;
			/* Fetch old value into tmp_buf; it is handed to the
			 * reply and freed by psmx_am_atomic_completion. */
			tmp_buf = malloc(len);
			if (tmp_buf)
				psmx_atomic_do_readwrite(addr, src, tmp_buf,
							 datatype, op, count);
			else
				op_error = -FI_ENOMEM;

			target_ep = mr->domain->atomics_ep;
			if (op == FI_ATOMIC_READ) {
				cntr = target_ep->remote_read_cntr;
			} else {
				cntr = target_ep->remote_write_cntr;
				mr_cntr = mr->cntr;
			}

			if (cntr)
				psmx_cntr_inc(cntr);

			if (mr_cntr && mr_cntr != cntr)
				psmx_cntr_inc(mr_cntr);
		} else {
			tmp_buf = NULL;
		}

		rep_args[0].u32w0 = PSMX_AM_REP_ATOMIC_READWRITE;
		rep_args[0].u32w1 = op_error;
		rep_args[1].u64 = args[1].u64;
		err = psm_am_reply_short(token, PSMX_AM_ATOMIC_HANDLER,
					 rep_args, 2, tmp_buf,
					 (tmp_buf?len:0), 0,
					 psmx_am_atomic_completion,
					 tmp_buf );
		break;

	case PSMX_AM_REQ_ATOMIC_COMPWRITE:
		count = args[0].u32w1;
		addr = (void *)(uintptr_t)args[2].u64;
		key = args[3].u64;
		datatype = args[4].u32w0;
		op = args[4].u32w1;
		/* Payload holds operand + compare halves. */
		len /= 2;
		assert(len == fi_datatype_size(datatype) * count);
		mr = psmx_mr_get(psmx_active_fabric->active_domain, key);
		op_error = mr ?
			psmx_mr_validate(mr, (uint64_t)addr, len,
					 FI_REMOTE_READ|FI_REMOTE_WRITE) :
			-FI_EINVAL;
		if (!op_error) {
			addr += mr->offset;
			tmp_buf = malloc(len);
			if (tmp_buf)
				/* Second half of src is the compare data. */
				psmx_atomic_do_compwrite(addr, src, src + len,
							 tmp_buf, datatype,
							 op, count);
			else
				op_error = -FI_ENOMEM;

			target_ep = mr->domain->atomics_ep;
			cntr = target_ep->remote_write_cntr;
			mr_cntr = mr->cntr;

			if (cntr)
				psmx_cntr_inc(cntr);

			if (mr_cntr && mr_cntr != cntr)
				psmx_cntr_inc(mr_cntr);
		} else {
			tmp_buf = NULL;
		}

		/* Reply reuses the READWRITE reply code on purpose; the
		 * initiator handles both in the same case below. */
		rep_args[0].u32w0 = PSMX_AM_REP_ATOMIC_READWRITE;
		rep_args[0].u32w1 = op_error;
		rep_args[1].u64 = args[1].u64;
		err = psm_am_reply_short(token, PSMX_AM_ATOMIC_HANDLER,
					 rep_args, 2, tmp_buf,
					 (tmp_buf?len:0), 0,
					 psmx_am_atomic_completion,
					 tmp_buf );
		break;

	case PSMX_AM_REP_ATOMIC_WRITE:
		/* Initiator side: complete an atomic write request. */
		req = (struct psmx_am_request *)(uintptr_t)args[1].u64;
		op_error = (int)args[0].u32w1;
		assert(req->op == PSMX_AM_REQ_ATOMIC_WRITE);
		if (req->ep->send_cq && !req->no_event) {
			event = psmx_cq_create_event(
					req->ep->send_cq,
					req->atomic.context,
					req->atomic.buf,
					req->cq_flags,
					req->atomic.len,
					0, /* data */
					0, /* tag */
					0, /* olen */
					op_error);
			if (event)
				psmx_cq_enqueue_event(req->ep->send_cq,
						      event);
			else
				err = -FI_ENOMEM;
		}

		if (req->ep->write_cntr)
			psmx_cntr_inc(req->ep->write_cntr);

		free(req);
		break;

	case PSMX_AM_REP_ATOMIC_READWRITE:
	case PSMX_AM_REP_ATOMIC_COMPWRITE:
		/* Initiator side: copy fetched data, then complete. */
		req = (struct psmx_am_request *)(uintptr_t)args[1].u64;
		op_error = (int)args[0].u32w1;
		assert(op_error || req->atomic.len == len);

		if (!op_error)
			memcpy(req->atomic.result, src, len);

		if (req->ep->send_cq && !req->no_event) {
			event = psmx_cq_create_event(
					req->ep->send_cq,
					req->atomic.context,
					req->atomic.buf,
					req->cq_flags,
					req->atomic.len,
					0, /* data */
					0, /* tag */
					0, /* olen */
					op_error);
			if (event)
				psmx_cq_enqueue_event(req->ep->send_cq,
						      event);
			else
				err = -FI_ENOMEM;
		}

		if (req->ep->read_cntr)
			psmx_cntr_inc(req->ep->read_cntr);

		free(req);
		break;

	default:
		err = -FI_EINVAL;
	}
	return err;
}
/*
 * Loopback atomic execution (older variant using the MR hash lookup).
 * Validates the local MR, performs the atomic, fires the target-side MR
 * CQ/counters, then the remote counters, and finally the initiator's CQ
 * event and counters.  Returns 0 or a negative errno-style value.
 */
static int psmx_atomic_self(int am_cmd,
			    struct psmx_fid_ep *ep,
			    const void *buf,
			    size_t count, void *desc,
			    const void *compare, void *compare_desc,
			    void *result, void *result_desc,
			    uint64_t addr, uint64_t key,
			    enum fi_datatype datatype,
			    enum fi_op op, void *context,
			    uint64_t flags)
{
	struct psmx_fid_mr *mr;
	struct psmx_cq_event *event;
	struct psmx_fid_ep *target_ep;
	size_t len;
	int no_event;
	int err = 0;
	int op_error;
	int access;

	/* Plain writes need only remote-write access. */
	if (am_cmd == PSMX_AM_REQ_ATOMIC_WRITE)
		access = FI_REMOTE_WRITE;
	else
		access = FI_REMOTE_READ | FI_REMOTE_WRITE;

	len = fi_datatype_size(datatype) * count;
	mr = psmx_mr_hash_get(key);
	op_error = mr ? psmx_mr_validate(mr, addr, len, access) : -EINVAL;

	if (op_error)
		goto gen_local_event;

	addr += mr->offset;

	switch (am_cmd) {
	case PSMX_AM_REQ_ATOMIC_WRITE:
		err = psmx_atomic_do_write((void *)addr, (void *)buf,
					   (int)datatype, (int)op,
					   (int)count);
		break;

	case PSMX_AM_REQ_ATOMIC_READWRITE:
		err = psmx_atomic_do_readwrite((void *)addr, (void *)buf,
					       (void *)result, (int)datatype,
					       (int)op, (int)count);
		break;

	case PSMX_AM_REQ_ATOMIC_COMPWRITE:
		err = psmx_atomic_do_compwrite((void *)addr, (void *)buf,
					       (void *)compare,
					       (void *)result, (int)datatype,
					       (int)op, (int)count);
		break;
	}

	/* Target-side MR completion: everything except a pure read
	 * modifies the registered memory. */
	if (op != FI_ATOMIC_READ) {
		if (mr->cq) {
			event = psmx_cq_create_event(
					mr->cq,
					0, /* context */
					(void *)addr,
					0, /* flags */
					len,
					0, /* data */
					0, /* tag */
					0, /* olen */
					0 /* err */);

			if (event)
				psmx_cq_enqueue_event(mr->cq, event);
			else
				err = -ENOMEM;
		}
		if (mr->cntr)
			psmx_cntr_inc(mr->cntr);
	}

	/* Remote read/write counters on the target endpoint. */
	target_ep = mr->domain->atomics_ep;
	if (op == FI_ATOMIC_WRITE) {
		if (target_ep->remote_write_cntr)
			psmx_cntr_inc(target_ep->remote_write_cntr);
	} else if (op == FI_ATOMIC_READ) {
		if (target_ep->remote_read_cntr)
			psmx_cntr_inc(target_ep->remote_read_cntr);
	} else {
		/* Read-modify-write: bump write, and read too for
		 * fetching/comparison commands (when distinct). */
		if (target_ep->remote_write_cntr)
			psmx_cntr_inc(target_ep->remote_write_cntr);
		if (am_cmd != PSMX_AM_REQ_ATOMIC_WRITE &&
		    target_ep->remote_read_cntr &&
		    target_ep->remote_read_cntr !=
				target_ep->remote_write_cntr)
			psmx_cntr_inc(target_ep->remote_read_cntr);
	}

gen_local_event:
	/* Initiator-side CQ event unless suppressed by inject/event flags. */
	no_event = ((flags & FI_INJECT) ||
		    (ep->send_cq_event_flag && !(flags & FI_EVENT)));
	if (ep->send_cq && !no_event) {
		event = psmx_cq_create_event(
				ep->send_cq,
				context,
				(void *)buf,
				0, /* flags */
				len,
				0, /* data */
				0, /* tag */
				0, /* olen */
				op_error);
		if (event)
			psmx_cq_enqueue_event(ep->send_cq, event);
		else
			err = -ENOMEM;
	}

	/* Initiator-side counters. */
	switch (am_cmd) {
	case PSMX_AM_REQ_ATOMIC_WRITE:
		if (ep->write_cntr)
			psmx_cntr_inc(ep->write_cntr);
		break;

	case PSMX_AM_REQ_ATOMIC_READWRITE:
	case PSMX_AM_REQ_ATOMIC_COMPWRITE:
		if (ep->read_cntr)
			psmx_cntr_inc(ep->read_cntr);
		break;
	}

	return err;
}
/*
 * Older/static variant of the sockets atomic send path.  Differs from the
 * public version by using asserts for precondition checks, carrying the
 * per-iov descriptor in the iovec key field, and applying op_flags after
 * the ring entry is started.  Returns 0 on success or a negative
 * fi_errno value.
 */
static ssize_t sock_ep_tx_atomic(struct fid_ep *ep,
				 const struct fi_msg_atomic *msg,
				 const struct fi_ioc *comparev,
				 void **compare_desc, size_t compare_count,
				 struct fi_ioc *resultv, void **result_desc,
				 size_t result_count, uint64_t flags)
{
	int i, ret;
	size_t datatype_sz;
	struct sock_op tx_op;
	union sock_iov tx_iov;
	struct sock_conn *conn;
	struct sock_tx_ctx *tx_ctx;
	uint64_t total_len, src_len, dst_len;
	struct sock_ep *sock_ep;

	/* Resolve the TX context from either a plain EP or a TX-context fid. */
	switch (ep->fid.fclass) {
	case FI_CLASS_EP:
		sock_ep = container_of(ep, struct sock_ep, ep);
		tx_ctx = sock_ep->tx_ctx;
		break;
	case FI_CLASS_TX_CTX:
		tx_ctx = container_of(ep, struct sock_tx_ctx, fid.ctx);
		sock_ep = tx_ctx->ep;
		break;
	default:
		SOCK_LOG_ERROR("Invalid EP type\n");
		return -FI_EINVAL;
	}

	/* Preconditions checked only in debug builds here. */
	assert(tx_ctx->enabled &&
	       msg->iov_count <= SOCK_EP_MAX_IOV_LIMIT &&
	       msg->rma_iov_count <= SOCK_EP_MAX_IOV_LIMIT);

	if (sock_ep->connected) {
		conn = sock_ep_lookup_conn(sock_ep);
	} else {
		conn = sock_av_lookup_addr(sock_ep, tx_ctx->av, msg->addr);
	}

	if (!conn)
		return -FI_EAGAIN;

	/* Compute the total ring-buffer footprint of this entry. */
	src_len = 0;
	datatype_sz = fi_datatype_size(msg->datatype);
	if (flags & FI_INJECT) {
		for (i=0; i< msg->iov_count; i++) {
			src_len += (msg->msg_iov[i].count * datatype_sz);
		}
		assert(src_len <= SOCK_EP_MAX_INJECT_SZ);
		total_len = src_len;
	} else {
		total_len = msg->iov_count * sizeof(union sock_iov);
	}

	total_len += (sizeof(tx_op) +
		      (msg->rma_iov_count * sizeof(union sock_iov)) +
		      (result_count * sizeof (union sock_iov)));

	sock_tx_ctx_start(tx_ctx);
	if (rbfdavail(&tx_ctx->rbfd) < total_len) {
		ret = -FI_EAGAIN;
		goto err;
	}

	flags |= tx_ctx->attr.op_flags;
	memset(&tx_op, 0, sizeof(tx_op));
	tx_op.op = SOCK_OP_ATOMIC;
	tx_op.dest_iov_len = msg->rma_iov_count;
	tx_op.atomic.op = msg->op;
	tx_op.atomic.datatype = msg->datatype;
	tx_op.atomic.res_iov_len = result_count;
	tx_op.atomic.cmp_iov_len = compare_count;

	/* For inject, src_iov_len holds a byte count, not an iov count. */
	if (flags & FI_INJECT)
		tx_op.src_iov_len = src_len;
	else
		tx_op.src_iov_len = msg->iov_count;

	sock_tx_ctx_write_op_send(tx_ctx, &tx_op, flags,
				  (uintptr_t) msg->context, msg->addr,
				  (uintptr_t) msg->msg_iov[0].addr,
				  sock_ep, conn);

	if (flags & FI_REMOTE_CQ_DATA) {
		sock_tx_ctx_write(tx_ctx, &msg->data, sizeof(uint64_t));
	}

	/* Operand data: inlined bytes or iovec descriptors (key carries
	 * the caller-supplied desc). */
	src_len = 0;
	if (flags & FI_INJECT) {
		for (i=0; i< msg->iov_count; i++) {
			sock_tx_ctx_write(tx_ctx, msg->msg_iov[i].addr,
					  msg->msg_iov[i].count * datatype_sz);
			src_len += (msg->msg_iov[i].count * datatype_sz);
		}
	} else {
		for (i = 0; i< msg->iov_count; i++) {
			tx_iov.ioc.addr = (uintptr_t) msg->msg_iov[i].addr;
			tx_iov.ioc.count = msg->msg_iov[i].count;
			tx_iov.ioc.key = (uintptr_t) msg->desc[i];
			sock_tx_ctx_write(tx_ctx, &tx_iov, sizeof(tx_iov));
			src_len += (tx_iov.ioc.count * datatype_sz);
		}
	}
	assert(src_len <= SOCK_EP_MAX_ATOMIC_SZ);

	/* Destination (RMA) iovecs; total must match the operand length. */
	dst_len = 0;
	for (i = 0; i< msg->rma_iov_count; i++) {
		tx_iov.ioc.addr = msg->rma_iov[i].addr;
		tx_iov.ioc.key = msg->rma_iov[i].key;
		tx_iov.ioc.count = msg->rma_iov[i].count;
		sock_tx_ctx_write(tx_ctx, &tx_iov, sizeof(tx_iov));
		dst_len += (tx_iov.ioc.count * datatype_sz);
	}

	if (dst_len != src_len) {
		SOCK_LOG_ERROR("Buffer length mismatch\n");
		ret = -FI_EINVAL;
		goto err;
	}

	/* Result iovecs. */
	dst_len = 0;
	for (i = 0; i< result_count; i++) {
		tx_iov.ioc.addr = (uintptr_t) resultv[i].addr;
		tx_iov.ioc.count = resultv[i].count;
		sock_tx_ctx_write(tx_ctx, &tx_iov, sizeof(tx_iov));
		dst_len += (tx_iov.ioc.count * datatype_sz);
	}

	if (result_count && (dst_len != src_len)) {
		SOCK_LOG_ERROR("Buffer length mismatch\n");
		ret = -FI_EINVAL;
		goto err;
	}

	/* Compare iovecs. */
	dst_len = 0;
	for (i = 0; i< compare_count; i++) {
		tx_iov.ioc.addr = (uintptr_t) comparev[i].addr;
		tx_iov.ioc.count = comparev[i].count;
		sock_tx_ctx_write(tx_ctx, &tx_iov, sizeof(tx_iov));
		dst_len += (tx_iov.ioc.count * datatype_sz);
	}

	if (compare_count && (dst_len != src_len)) {
		SOCK_LOG_ERROR("Buffer length mismatch\n");
		ret = -FI_EINVAL;
		goto err;
	}

	sock_tx_ctx_commit(tx_ctx);
	return 0;

err:
	SOCK_LOG_INFO("Not enough space for TX entry, try again\n");
	sock_tx_ctx_abort(tx_ctx);
	return ret;
}
int fi_ibv_query_atomic(struct fid_domain *domain_fid, enum fi_datatype datatype, enum fi_op op, struct fi_atomic_attr *attr, uint64_t flags) { struct fi_ibv_domain *domain = container_of(domain_fid, struct fi_ibv_domain, domain_fid); char *log_str_fetch = "fi_fetch_atomic with FI_SUM op"; char *log_str_comp = "fi_compare_atomic"; char *log_str; if (flags & FI_TAGGED) return -FI_ENOSYS; if ((flags & FI_FETCH_ATOMIC) && (flags & FI_COMPARE_ATOMIC)) return -FI_EBADFLAGS; if (!flags) { switch (op) { case FI_ATOMIC_WRITE: break; default: return -FI_ENOSYS; } } else { if (flags & FI_FETCH_ATOMIC) { switch (op) { case FI_ATOMIC_READ: goto check_datatype; case FI_SUM: log_str = log_str_fetch; break; default: return -FI_ENOSYS; } } else if (flags & FI_COMPARE_ATOMIC) { if (op != FI_CSWAP) return -FI_ENOSYS; log_str = log_str_comp; } else { return -FI_EBADFLAGS; } if (domain->info->tx_attr->op_flags & FI_INJECT) { VERBS_INFO(FI_LOG_EP_DATA, "FI_INJECT not supported for %s\n", log_str); return -FI_EINVAL; } } check_datatype: switch (datatype) { case FI_INT64: case FI_UINT64: #if __BITS_PER_LONG == 64 case FI_DOUBLE: case FI_FLOAT: #endif break; default: return -FI_EINVAL; } attr->size = fi_datatype_size(datatype); if (attr->size == 0) return -FI_EINVAL; attr->count = 1; return 0; }