int psmx_am_atomic_handler(psm_am_token_t token, psm_epaddr_t epaddr, psm_amarg_t *args, int nargs, void *src, uint32_t len) #endif { psm_amarg_t rep_args[8]; int count; void *addr; uint64_t key; int datatype, op; int err = 0; int op_error = 0; struct psmx_am_request *req; struct psmx_cq_event *event; struct psmx_fid_mr *mr; struct psmx_fid_ep *target_ep; struct psmx_fid_cntr *cntr = NULL; struct psmx_fid_cntr *mr_cntr = NULL; void *tmp_buf; #if (PSM_VERNO_MAJOR >= 2) psm_epaddr_t epaddr; psm_am_get_source(token, &epaddr); #endif switch (args[0].u32w0 & PSMX_AM_OP_MASK) { case PSMX_AM_REQ_ATOMIC_WRITE: count = args[0].u32w1; addr = (void *)(uintptr_t)args[2].u64; key = args[3].u64; datatype = args[4].u32w0; op = args[4].u32w1; assert(len == fi_datatype_size(datatype) * count); mr = psmx_mr_get(psmx_active_fabric->active_domain, key); op_error = mr ? psmx_mr_validate(mr, (uint64_t)addr, len, FI_REMOTE_WRITE) : -FI_EINVAL; if (!op_error) { addr += mr->offset; psmx_atomic_do_write(addr, src, datatype, op, count); target_ep = mr->domain->atomics_ep; cntr = target_ep->remote_write_cntr; mr_cntr = mr->cntr; if (cntr) psmx_cntr_inc(cntr); if (mr_cntr && mr_cntr != cntr) psmx_cntr_inc(mr_cntr); } rep_args[0].u32w0 = PSMX_AM_REP_ATOMIC_WRITE; rep_args[0].u32w1 = op_error; rep_args[1].u64 = args[1].u64; err = psm_am_reply_short(token, PSMX_AM_ATOMIC_HANDLER, rep_args, 2, NULL, 0, 0, NULL, NULL ); break; case PSMX_AM_REQ_ATOMIC_READWRITE: count = args[0].u32w1; addr = (void *)(uintptr_t)args[2].u64; key = args[3].u64; datatype = args[4].u32w0; op = args[4].u32w1; if (op == FI_ATOMIC_READ) len = fi_datatype_size(datatype) * count; assert(len == fi_datatype_size(datatype) * count); mr = psmx_mr_get(psmx_active_fabric->active_domain, key); op_error = mr ? psmx_mr_validate(mr, (uint64_t)addr, len, FI_REMOTE_READ|FI_REMOTE_WRITE) : -FI_EINVAL; if (!op_error) { addr += mr->offset; tmp_buf = malloc(len); if (tmp_buf) psmx_atomic_do_readwrite(addr, src, tmp_buf, datatype, op, count); else op_error = -FI_ENOMEM; target_ep = mr->domain->atomics_ep; if (op == FI_ATOMIC_READ) { cntr = target_ep->remote_read_cntr; } else { cntr = target_ep->remote_write_cntr; mr_cntr = mr->cntr; } if (cntr) psmx_cntr_inc(cntr); if (mr_cntr && mr_cntr != cntr) psmx_cntr_inc(mr_cntr); } else { tmp_buf = NULL; } rep_args[0].u32w0 = PSMX_AM_REP_ATOMIC_READWRITE; rep_args[0].u32w1 = op_error; rep_args[1].u64 = args[1].u64; err = psm_am_reply_short(token, PSMX_AM_ATOMIC_HANDLER, rep_args, 2, tmp_buf, (tmp_buf?len:0), 0, psmx_am_atomic_completion, tmp_buf ); break; case PSMX_AM_REQ_ATOMIC_COMPWRITE: count = args[0].u32w1; addr = (void *)(uintptr_t)args[2].u64; key = args[3].u64; datatype = args[4].u32w0; op = args[4].u32w1; len /= 2; assert(len == fi_datatype_size(datatype) * count); mr = psmx_mr_get(psmx_active_fabric->active_domain, key); op_error = mr ? psmx_mr_validate(mr, (uint64_t)addr, len, FI_REMOTE_READ|FI_REMOTE_WRITE) : -FI_EINVAL; if (!op_error) { addr += mr->offset; tmp_buf = malloc(len); if (tmp_buf) psmx_atomic_do_compwrite(addr, src, src + len, tmp_buf, datatype, op, count); else op_error = -FI_ENOMEM; target_ep = mr->domain->atomics_ep; cntr = target_ep->remote_write_cntr; mr_cntr = mr->cntr; if (cntr) psmx_cntr_inc(cntr); if (mr_cntr && mr_cntr != cntr) psmx_cntr_inc(mr_cntr); } else { tmp_buf = NULL; } rep_args[0].u32w0 = PSMX_AM_REP_ATOMIC_READWRITE; rep_args[0].u32w1 = op_error; rep_args[1].u64 = args[1].u64; err = psm_am_reply_short(token, PSMX_AM_ATOMIC_HANDLER, rep_args, 2, tmp_buf, (tmp_buf?len:0), 0, psmx_am_atomic_completion, tmp_buf ); break; case PSMX_AM_REP_ATOMIC_WRITE: req = (struct psmx_am_request *)(uintptr_t)args[1].u64; op_error = (int)args[0].u32w1; assert(req->op == PSMX_AM_REQ_ATOMIC_WRITE); if (req->ep->send_cq && !req->no_event) { event = psmx_cq_create_event( req->ep->send_cq, req->atomic.context, req->atomic.buf, req->cq_flags, req->atomic.len, 0, /* data */ 0, /* tag */ 0, /* olen */ op_error); if (event) psmx_cq_enqueue_event(req->ep->send_cq, event); else err = -FI_ENOMEM; } if (req->ep->write_cntr) psmx_cntr_inc(req->ep->write_cntr); free(req); break; case PSMX_AM_REP_ATOMIC_READWRITE: case PSMX_AM_REP_ATOMIC_COMPWRITE: req = (struct psmx_am_request *)(uintptr_t)args[1].u64; op_error = (int)args[0].u32w1; assert(op_error || req->atomic.len == len); if (!op_error) memcpy(req->atomic.result, src, len); if (req->ep->send_cq && !req->no_event) { event = psmx_cq_create_event( req->ep->send_cq, req->atomic.context, req->atomic.buf, req->cq_flags, req->atomic.len, 0, /* data */ 0, /* tag */ 0, /* olen */ op_error); if (event) psmx_cq_enqueue_event(req->ep->send_cq, event); else err = -FI_ENOMEM; } if (req->ep->read_cntr) psmx_cntr_inc(req->ep->read_cntr); free(req); break; default: err = -FI_EINVAL; } return err; }
static int psmx_atomic_self(int am_cmd, struct psmx_fid_ep *ep, const void *buf, size_t count, void *desc, const void *compare, void *compare_desc, void *result, void *result_desc, uint64_t addr, uint64_t key, enum fi_datatype datatype, enum fi_op op, void *context, uint64_t flags) { struct psmx_fid_mr *mr; struct psmx_cq_event *event; struct psmx_fid_ep *target_ep; struct psmx_fid_cntr *cntr = NULL; struct psmx_fid_cntr *mr_cntr = NULL; void *tmp_buf; size_t len; int no_event; int err = 0; int op_error; int access; uint64_t cq_flags = 0; if (am_cmd == PSMX_AM_REQ_ATOMIC_WRITE) access = FI_REMOTE_WRITE; else access = FI_REMOTE_READ | FI_REMOTE_WRITE; len = fi_datatype_size(datatype) * count; mr = psmx_mr_get(psmx_active_fabric->active_domain, key); op_error = mr ? psmx_mr_validate(mr, addr, len, access) : -FI_EINVAL; if (op_error) goto gen_local_event; addr += mr->offset; switch (am_cmd) { case PSMX_AM_REQ_ATOMIC_WRITE: err = psmx_atomic_do_write((void *)addr, (void *)buf, (int)datatype, (int)op, (int)count); cq_flags = FI_WRITE | FI_ATOMIC; break; case PSMX_AM_REQ_ATOMIC_READWRITE: if (result != buf) { err = psmx_atomic_do_readwrite((void *)addr, (void *)buf, (void *)result, (int)datatype, (int)op, (int)count); } else { tmp_buf = malloc(len); if (tmp_buf) { memcpy(tmp_buf, result, len); err = psmx_atomic_do_readwrite((void *)addr, (void *)buf, tmp_buf, (int)datatype, (int)op, (int)count); memcpy(result, tmp_buf, len); free(tmp_buf); } else { err = -FI_ENOMEM; } } if (op == FI_ATOMIC_READ) cq_flags = FI_READ | FI_ATOMIC; else cq_flags = FI_WRITE | FI_ATOMIC; break; case PSMX_AM_REQ_ATOMIC_COMPWRITE: if (result != buf && result != compare) { err = psmx_atomic_do_compwrite((void *)addr, (void *)buf, (void *)compare, (void *)result, (int)datatype, (int)op, (int)count); } else { tmp_buf = malloc(len); if (tmp_buf) { memcpy(tmp_buf, result, len); err = psmx_atomic_do_compwrite((void *)addr, (void *)buf, (void *)compare, tmp_buf, (int)datatype, (int)op, (int)count); memcpy(result, tmp_buf, len); free(tmp_buf); } else { err = -FI_ENOMEM; } } cq_flags = FI_WRITE | FI_ATOMIC; break; } target_ep = mr->domain->atomics_ep; if (op == FI_ATOMIC_READ) { cntr = target_ep->remote_read_cntr; } else { cntr = target_ep->remote_write_cntr; mr_cntr = mr->cntr; } if (cntr) psmx_cntr_inc(cntr); if (mr_cntr && mr_cntr != cntr) psmx_cntr_inc(mr_cntr); gen_local_event: no_event = ((flags & PSMX_NO_COMPLETION) || (ep->send_selective_completion && !(flags & FI_COMPLETION))); if (ep->send_cq && !no_event) { event = psmx_cq_create_event( ep->send_cq, context, (void *)buf, cq_flags, len, 0, /* data */ 0, /* tag */ 0, /* olen */ op_error); if (event) psmx_cq_enqueue_event(ep->send_cq, event); else err = -FI_ENOMEM; } switch (am_cmd) { case PSMX_AM_REQ_ATOMIC_WRITE: if (ep->write_cntr) psmx_cntr_inc(ep->write_cntr); break; case PSMX_AM_REQ_ATOMIC_READWRITE: case PSMX_AM_REQ_ATOMIC_COMPWRITE: if (ep->read_cntr) psmx_cntr_inc(ep->read_cntr); break; } return err; }
static int psmx_atomic_self(int am_cmd, struct psmx_fid_ep *ep, const void *buf, size_t count, void *desc, const void *compare, void *compare_desc, void *result, void *result_desc, uint64_t addr, uint64_t key, enum fi_datatype datatype, enum fi_op op, void *context, uint64_t flags) { struct psmx_fid_mr *mr; struct psmx_cq_event *event; struct psmx_fid_ep *target_ep; size_t len; int no_event; int err = 0; int op_error; int access; if (am_cmd == PSMX_AM_REQ_ATOMIC_WRITE) access = FI_REMOTE_WRITE; else access = FI_REMOTE_READ | FI_REMOTE_WRITE; len = fi_datatype_size(datatype) * count; mr = psmx_mr_hash_get(key); op_error = mr ? psmx_mr_validate(mr, addr, len, access) : -EINVAL; if (op_error) goto gen_local_event; addr += mr->offset; switch (am_cmd) { case PSMX_AM_REQ_ATOMIC_WRITE: err = psmx_atomic_do_write((void *)addr, (void *)buf, (int)datatype, (int)op, (int)count); break; case PSMX_AM_REQ_ATOMIC_READWRITE: err = psmx_atomic_do_readwrite((void *)addr, (void *)buf, (void *)result, (int)datatype, (int)op, (int)count); break; case PSMX_AM_REQ_ATOMIC_COMPWRITE: err = psmx_atomic_do_compwrite((void *)addr, (void *)buf, (void *)compare, (void *)result, (int)datatype, (int)op, (int)count); break; } if (op != FI_ATOMIC_READ) { if (mr->cq) { event = psmx_cq_create_event( mr->cq, 0, /* context */ (void *)addr, 0, /* flags */ len, 0, /* data */ 0, /* tag */ 0, /* olen */ 0 /* err */); if (event) psmx_cq_enqueue_event(mr->cq, event); else err = -ENOMEM; } if (mr->cntr) psmx_cntr_inc(mr->cntr); } target_ep = mr->domain->atomics_ep; if (op == FI_ATOMIC_WRITE) { if (target_ep->remote_write_cntr) psmx_cntr_inc(target_ep->remote_write_cntr); } else if (op == FI_ATOMIC_READ) { if (target_ep->remote_read_cntr) psmx_cntr_inc(target_ep->remote_read_cntr); } else { if (target_ep->remote_write_cntr) psmx_cntr_inc(target_ep->remote_write_cntr); if (am_cmd != PSMX_AM_REQ_ATOMIC_WRITE && target_ep->remote_read_cntr && target_ep->remote_read_cntr != target_ep->remote_write_cntr) psmx_cntr_inc(target_ep->remote_read_cntr); } gen_local_event: no_event = ((flags & FI_INJECT) || (ep->send_cq_event_flag && !(flags & FI_EVENT))); if (ep->send_cq && !no_event) { event = psmx_cq_create_event( ep->send_cq, context, (void *)buf, 0, /* flags */ len, 0, /* data */ 0, /* tag */ 0, /* olen */ op_error); if (event) psmx_cq_enqueue_event(ep->send_cq, event); else err = -ENOMEM; } switch (am_cmd) { case PSMX_AM_REQ_ATOMIC_WRITE: if (ep->write_cntr) psmx_cntr_inc(ep->write_cntr); break; case PSMX_AM_REQ_ATOMIC_READWRITE: case PSMX_AM_REQ_ATOMIC_COMPWRITE: if (ep->read_cntr) psmx_cntr_inc(ep->read_cntr); break; } return err; }