/*
 * psmx_am_rma_handler - PSM active-message handler for RMA requests/replies.
 *
 * The opcode is taken from args[0].u32w0 (masked with PSMX_AM_OP_MASK),
 * together with the PSMX_AM_EOM (last packet) and PSMX_AM_DATA (remote CQ
 * data present) flags.
 *
 * Target side (args[2].u64 = target address, args[3].u64 = MR key):
 *   PSMX_AM_REQ_WRITE      copy this packet's payload (src/len) into the MR;
 *                          on EOM generate the MR CQ event, bump counters and
 *                          reply PSMX_AM_REP_WRITE. Errors are replied at once.
 *   PSMX_AM_REQ_WRITE_LONG validate the whole target range (args[0].u32w1
 *                          bytes) and hand a request to psmx_am_enqueue_rma().
 *   PSMX_AM_REQ_READ       reply PSMX_AM_REP_READ carrying at most one chunk
 *                          of the MR contents.
 *   PSMX_AM_REQ_READ_LONG  validate the whole source range (args[0].u32w1
 *                          bytes) for remote read and hand a request to
 *                          psmx_am_enqueue_rma().
 * Initiator side (args[1].u64 = local struct psmx_am_request pointer):
 *   PSMX_AM_REP_WRITE      record the status; on EOM complete and free req.
 *   PSMX_AM_REP_READ       copy returned data to buf + args[2].u64; on EOM
 *                          complete and free req.
 *
 * Returns 0, the value returned by psm_am_reply_short(), -ENOMEM when an
 * event or request could not be allocated, or -EINVAL for an unknown opcode.
 */
int psmx_am_rma_handler(psm_am_token_t token, psm_epaddr_t epaddr,
			psm_amarg_t *args, int nargs, void *src, uint32_t len)
{
	psm_amarg_t rep_args[8];
	void *rma_addr;
	ssize_t rma_len;
	uint64_t key;
	int err = 0;
	int op_error = 0;
	int cmd, eom, has_data;
	struct psmx_am_request *req;
	struct psmx_cq_event *event;
	int chunk_size;
	uint64_t offset;
	struct psmx_fid_mr *mr;

	cmd = args[0].u32w0 & PSMX_AM_OP_MASK;
	eom = args[0].u32w0 & PSMX_AM_EOM;
	has_data = args[0].u32w0 & PSMX_AM_DATA;

	switch (cmd) {
	case PSMX_AM_REQ_WRITE:
		rma_len = args[0].u32w1;
		rma_addr = (void *)(uintptr_t)args[2].u64;
		key = args[3].u64;
		mr = psmx_mr_hash_get(key);
		/* this packet writes exactly len bytes at rma_addr */
		op_error = mr ?
			psmx_mr_validate(mr, (uint64_t)rma_addr, len, FI_REMOTE_WRITE) :
			-EINVAL;
		if (!op_error) {
			rma_addr = (char *)rma_addr + mr->offset;
			memcpy(rma_addr, src, len);
			if (eom) {
				if (mr->cq) {
					/* TODO: report the addr/len of the whole write */
					event = psmx_cq_create_event(
							mr->cq,
							0, /* context */
							rma_addr,
							0, /* flags */
							rma_len,
							has_data ? args[4].u64 : 0,
							0, /* tag */
							0, /* olen */
							0);

					if (event)
						psmx_cq_enqueue_event(mr->cq, event);
					else
						err = -ENOMEM;
				}

				if (mr->cntr)
					psmx_cntr_inc(mr->cntr);

				if (mr->domain->rma_ep->remote_write_cntr)
					psmx_cntr_inc(mr->domain->rma_ep->remote_write_cntr);
			}
		}
		/* acknowledge only the last packet, but report errors immediately */
		if (eom || op_error) {
			rep_args[0].u32w0 = PSMX_AM_REP_WRITE | eom;
			rep_args[0].u32w1 = op_error;
			rep_args[1].u64 = args[1].u64;
			err = psm_am_reply_short(token, PSMX_AM_RMA_HANDLER,
					rep_args, 2, NULL, 0, 0,
					NULL, NULL );
		}
		break;

	case PSMX_AM_REQ_WRITE_LONG:
		rma_len = args[0].u32w1;
		rma_addr = (void *)(uintptr_t)args[2].u64;
		key = args[3].u64;
		mr = psmx_mr_hash_get(key);
		/*
		 * src/len are not used for a long write, so len says nothing about
		 * the size of the write: validate the full range of rma_len bytes.
		 */
		op_error = mr ?
			psmx_mr_validate(mr, (uint64_t)rma_addr, rma_len, FI_REMOTE_WRITE) :
			-EINVAL;
		if (op_error) {
			rep_args[0].u32w0 = PSMX_AM_REP_WRITE | eom;
			rep_args[0].u32w1 = op_error;
			rep_args[1].u64 = args[1].u64;
			err = psm_am_reply_short(token, PSMX_AM_RMA_HANDLER,
					rep_args, 2, NULL, 0, 0,
					NULL, NULL );
			break;
		}

		rma_addr = (char *)rma_addr + mr->offset;

		req = calloc(1, sizeof(*req));
		if (!req) {
			err = -ENOMEM;
		} else {
			req->op = args[0].u32w0;
			req->write.addr = (uint64_t)rma_addr;
			req->write.len = rma_len;
			req->write.key = key;
			req->write.context = (void *)args[4].u64;
			req->write.data = has_data ? args[5].u64 : 0;
			PSMX_CTXT_TYPE(&req->fi_context) = PSMX_REMOTE_WRITE_CONTEXT;
			PSMX_CTXT_USER(&req->fi_context) = mr;
			psmx_am_enqueue_rma(mr->domain, req);
		}
		break;

	case PSMX_AM_REQ_READ:
		rma_len = args[0].u32w1;
		rma_addr = (void *)(uintptr_t)args[2].u64;
		key = args[3].u64;
		offset = args[4].u64;
		mr = psmx_mr_hash_get(key);
		op_error = mr ?
			psmx_mr_validate(mr, (uint64_t)rma_addr, rma_len, FI_REMOTE_READ) :
			-EINVAL;
		if (!op_error) {
			rma_addr = (char *)rma_addr + mr->offset;
		} else {
			/* reply with the error status and no data */
			rma_addr = NULL;
			rma_len = 0;
		}

		chunk_size = MIN(PSMX_AM_CHUNK_SIZE, psmx_am_param.max_reply_short);
		assert(rma_len <= chunk_size);

		rep_args[0].u32w0 = PSMX_AM_REP_READ | eom;
		rep_args[0].u32w1 = op_error;
		rep_args[1].u64 = args[1].u64;
		rep_args[2].u64 = offset;
		err = psm_am_reply_short(token, PSMX_AM_RMA_HANDLER,
				rep_args, 3, rma_addr, rma_len, 0,
				NULL, NULL );

		if (eom && !op_error) {
			if (mr->domain->rma_ep->remote_read_cntr)
				psmx_cntr_inc(mr->domain->rma_ep->remote_read_cntr);
		}
		break;

	case PSMX_AM_REQ_READ_LONG:
		rma_len = args[0].u32w1;
		rma_addr = (void *)(uintptr_t)args[2].u64;
		key = args[3].u64;
		mr = psmx_mr_hash_get(key);
		/*
		 * This is a read: the MR must grant remote *read* access for the
		 * whole range of rma_len bytes (len is the size of this request
		 * packet, not of the data being read).
		 */
		op_error = mr ?
			psmx_mr_validate(mr, (uint64_t)rma_addr, rma_len, FI_REMOTE_READ) :
			-EINVAL;
		if (op_error) {
			rep_args[0].u32w0 = PSMX_AM_REP_READ | eom;
			rep_args[0].u32w1 = op_error;
			rep_args[1].u64 = args[1].u64;
			rep_args[2].u64 = 0;
			err = psm_am_reply_short(token, PSMX_AM_RMA_HANDLER,
					rep_args, 3, NULL, 0, 0,
					NULL, NULL );
			break;
		}

		rma_addr = (char *)rma_addr + mr->offset;

		req = calloc(1, sizeof(*req));
		if (!req) {
			err = -ENOMEM;
		} else {
			req->op = args[0].u32w0;
			req->read.addr = (uint64_t)rma_addr;
			req->read.len = rma_len;
			req->read.key = key;
			req->read.context = (void *)args[4].u64;
			req->read.peer_addr = (void *)epaddr;
			PSMX_CTXT_TYPE(&req->fi_context) = PSMX_REMOTE_READ_CONTEXT;
			PSMX_CTXT_USER(&req->fi_context) = mr;
			psmx_am_enqueue_rma(mr->domain, req);
		}
		break;

	case PSMX_AM_REP_WRITE:
		req = (struct psmx_am_request *)(uintptr_t)args[1].u64;
		assert(req->op == PSMX_AM_REQ_WRITE);
		op_error = (int)args[0].u32w1;
		/* keep the first error reported for this request */
		if (!req->error)
			req->error = op_error;
		if (eom) {
			if (req->ep->send_cq && !req->no_event) {
				event = psmx_cq_create_event(
						req->ep->send_cq,
						req->write.context,
						req->write.buf,
						0, /* flags */
						req->write.len,
						0, /* data */
						0, /* tag */
						0, /* olen */
						req->error);
				if (event)
					psmx_cq_enqueue_event(req->ep->send_cq, event);
				else
					err = -ENOMEM;
			}

			if (req->ep->write_cntr)
				psmx_cntr_inc(req->ep->write_cntr);

			free(req);
		}
		break;

	case PSMX_AM_REP_READ:
		req = (struct psmx_am_request *)(uintptr_t)args[1].u64;
		assert(req->op == PSMX_AM_REQ_READ);
		op_error = (int)args[0].u32w1;
		offset = args[2].u64;
		/* keep the first error reported for this request */
		if (!req->error)
			req->error = op_error;
		if (!op_error) {
			memcpy((char *)req->read.buf + offset, src, len);
			req->read.len_read += len;
		}
		if (eom) {
			if (req->ep->send_cq && !req->no_event) {
				event = psmx_cq_create_event(
						req->ep->send_cq,
						req->read.context,
						req->read.buf,
						0, /* flags */
						req->read.len_read,
						0, /* data */
						0, /* tag */
						req->read.len - req->read.len_read,
						req->error);
				if (event)
					psmx_cq_enqueue_event(req->ep->send_cq, event);
				else
					err = -ENOMEM;
			}

			if (req->ep->read_cntr)
				psmx_cntr_inc(req->ep->read_cntr);

			free(req);
		}
		break;

	default:
		err = -EINVAL;
	}

	return err;
}
int psmx_am_msg_handler(psm_am_token_t token, psm_epaddr_t epaddr, psm_amarg_t *args, int nargs, void *src, uint32_t len) #endif { psm_amarg_t rep_args[8]; struct psmx_am_request *req; struct psmx_cq_event *event; struct psmx_epaddr_context *epaddr_context; struct psmx_fid_domain *domain; int copy_len; uint64_t offset; int cmd, eom; int err = 0; int op_error = 0; struct psmx_unexp *unexp; #if (PSM_VERNO_MAJOR >= 2) psm_epaddr_t epaddr; psm_am_get_source(token, &epaddr); #endif epaddr_context = psm_epaddr_getctxt(epaddr); if (!epaddr_context) { FI_WARN(&psmx_prov, FI_LOG_EP_DATA, "NULL context for epaddr %p\n", epaddr); return -FI_EIO; } domain = epaddr_context->domain; cmd = args[0].u32w0 & PSMX_AM_OP_MASK; eom = args[0].u32w0 & PSMX_AM_EOM; switch (cmd) { case PSMX_AM_REQ_SEND: assert(len == args[0].u32w1); offset = args[3].u64; if (offset == 0) { /* this is the first packet */ req = psmx_am_search_and_dequeue_recv(domain, (const void *)epaddr); if (req) { copy_len = MIN(len, req->recv.len); memcpy(req->recv.buf, src, len); req->recv.len_received += copy_len; } else { unexp = malloc(sizeof(*unexp) + len); if (!unexp) { op_error = -FI_ENOSPC; } else { memcpy(unexp->buf, src, len); unexp->sender_addr = epaddr; unexp->sender_context = args[1].u64; unexp->len_received = len; unexp->done = !!eom; unexp->list_entry.next = NULL; psmx_am_enqueue_unexp(domain, unexp); if (!eom) { /* stop here. 
will reply when recv is posted */ break; } } } if (!op_error && !eom) { /* reply w/ recv req to be used for following packets */ rep_args[0].u32w0 = PSMX_AM_REP_SEND; rep_args[0].u32w1 = 0; rep_args[1].u64 = args[1].u64; rep_args[2].u64 = (uint64_t)(uintptr_t)req; err = psm_am_reply_short(token, PSMX_AM_MSG_HANDLER, rep_args, 3, NULL, 0, 0, NULL, NULL ); } } else { req = (struct psmx_am_request *)(uintptr_t)args[2].u64; if (req) { copy_len = MIN(req->recv.len + offset, len); memcpy(req->recv.buf + offset, src, copy_len); req->recv.len_received += copy_len; } else { FI_WARN(&psmx_prov, FI_LOG_EP_DATA, "NULL recv_req in follow-up packets.\n"); op_error = -FI_ENOMSG; } } if (eom && req) { if (req->ep->recv_cq && !req->no_event) { event = psmx_cq_create_event( req->ep->recv_cq, req->recv.context, req->recv.buf, req->cq_flags, req->recv.len_received, 0, /* data */ 0, /* tag */ req->recv.len - req->recv.len_received, 0 /* err */); if (event) psmx_cq_enqueue_event(req->ep->recv_cq, event); else err = -FI_ENOMEM; } if (req->ep->recv_cntr) psmx_cntr_inc(req->ep->recv_cntr); free(req); } if (eom || op_error) { rep_args[0].u32w0 = PSMX_AM_REP_SEND; rep_args[0].u32w1 = op_error; rep_args[1].u64 = args[1].u64; rep_args[2].u64 = 0; /* done */ err = psm_am_reply_short(token, PSMX_AM_MSG_HANDLER, rep_args, 3, NULL, 0, 0, NULL, NULL ); } break; case PSMX_AM_REP_SEND: req = (struct psmx_am_request *)(uintptr_t)args[1].u64; op_error = (int)args[0].u32w1; assert(req->op == PSMX_AM_REQ_SEND); if (args[2].u64) { /* more to send */ req->send.peer_context = (void *)(uintptr_t)args[2].u64; /* psm_am_request_short() can't be called inside the handler. * put the request into a queue and process it later. 
*/ psmx_am_enqueue_send(req->ep->domain, req); } else { /* done */ if (req->ep->send_cq && !req->no_event) { event = psmx_cq_create_event( req->ep->send_cq, req->send.context, req->send.buf, req->cq_flags, req->send.len, 0, /* data */ 0, /* tag */ 0, /* olen */ op_error); if (event) psmx_cq_enqueue_event(req->ep->send_cq, event); else err = -FI_ENOMEM; } if (req->ep->send_cntr) psmx_cntr_inc(req->ep->send_cntr); free(req); } break; default: err = -FI_EINVAL; } return err; }
int psmx_am_atomic_handler(psm_am_token_t token, psm_epaddr_t epaddr, psm_amarg_t *args, int nargs, void *src, uint32_t len) #endif { psm_amarg_t rep_args[8]; int count; void *addr; uint64_t key; int datatype, op; int err = 0; int op_error = 0; struct psmx_am_request *req; struct psmx_cq_event *event; struct psmx_fid_mr *mr; struct psmx_fid_ep *target_ep; struct psmx_fid_cntr *cntr = NULL; struct psmx_fid_cntr *mr_cntr = NULL; void *tmp_buf; #if (PSM_VERNO_MAJOR >= 2) psm_epaddr_t epaddr; psm_am_get_source(token, &epaddr); #endif switch (args[0].u32w0 & PSMX_AM_OP_MASK) { case PSMX_AM_REQ_ATOMIC_WRITE: count = args[0].u32w1; addr = (void *)(uintptr_t)args[2].u64; key = args[3].u64; datatype = args[4].u32w0; op = args[4].u32w1; assert(len == fi_datatype_size(datatype) * count); mr = psmx_mr_get(psmx_active_fabric->active_domain, key); op_error = mr ? psmx_mr_validate(mr, (uint64_t)addr, len, FI_REMOTE_WRITE) : -FI_EINVAL; if (!op_error) { addr += mr->offset; psmx_atomic_do_write(addr, src, datatype, op, count); target_ep = mr->domain->atomics_ep; cntr = target_ep->remote_write_cntr; mr_cntr = mr->cntr; if (cntr) psmx_cntr_inc(cntr); if (mr_cntr && mr_cntr != cntr) psmx_cntr_inc(mr_cntr); } rep_args[0].u32w0 = PSMX_AM_REP_ATOMIC_WRITE; rep_args[0].u32w1 = op_error; rep_args[1].u64 = args[1].u64; err = psm_am_reply_short(token, PSMX_AM_ATOMIC_HANDLER, rep_args, 2, NULL, 0, 0, NULL, NULL ); break; case PSMX_AM_REQ_ATOMIC_READWRITE: count = args[0].u32w1; addr = (void *)(uintptr_t)args[2].u64; key = args[3].u64; datatype = args[4].u32w0; op = args[4].u32w1; if (op == FI_ATOMIC_READ) len = fi_datatype_size(datatype) * count; assert(len == fi_datatype_size(datatype) * count); mr = psmx_mr_get(psmx_active_fabric->active_domain, key); op_error = mr ? 
psmx_mr_validate(mr, (uint64_t)addr, len, FI_REMOTE_READ|FI_REMOTE_WRITE) : -FI_EINVAL; if (!op_error) { addr += mr->offset; tmp_buf = malloc(len); if (tmp_buf) psmx_atomic_do_readwrite(addr, src, tmp_buf, datatype, op, count); else op_error = -FI_ENOMEM; target_ep = mr->domain->atomics_ep; if (op == FI_ATOMIC_READ) { cntr = target_ep->remote_read_cntr; } else { cntr = target_ep->remote_write_cntr; mr_cntr = mr->cntr; } if (cntr) psmx_cntr_inc(cntr); if (mr_cntr && mr_cntr != cntr) psmx_cntr_inc(mr_cntr); } else { tmp_buf = NULL; } rep_args[0].u32w0 = PSMX_AM_REP_ATOMIC_READWRITE; rep_args[0].u32w1 = op_error; rep_args[1].u64 = args[1].u64; err = psm_am_reply_short(token, PSMX_AM_ATOMIC_HANDLER, rep_args, 2, tmp_buf, (tmp_buf?len:0), 0, psmx_am_atomic_completion, tmp_buf ); break; case PSMX_AM_REQ_ATOMIC_COMPWRITE: count = args[0].u32w1; addr = (void *)(uintptr_t)args[2].u64; key = args[3].u64; datatype = args[4].u32w0; op = args[4].u32w1; len /= 2; assert(len == fi_datatype_size(datatype) * count); mr = psmx_mr_get(psmx_active_fabric->active_domain, key); op_error = mr ? 
psmx_mr_validate(mr, (uint64_t)addr, len, FI_REMOTE_READ|FI_REMOTE_WRITE) : -FI_EINVAL; if (!op_error) { addr += mr->offset; tmp_buf = malloc(len); if (tmp_buf) psmx_atomic_do_compwrite(addr, src, src + len, tmp_buf, datatype, op, count); else op_error = -FI_ENOMEM; target_ep = mr->domain->atomics_ep; cntr = target_ep->remote_write_cntr; mr_cntr = mr->cntr; if (cntr) psmx_cntr_inc(cntr); if (mr_cntr && mr_cntr != cntr) psmx_cntr_inc(mr_cntr); } else { tmp_buf = NULL; } rep_args[0].u32w0 = PSMX_AM_REP_ATOMIC_READWRITE; rep_args[0].u32w1 = op_error; rep_args[1].u64 = args[1].u64; err = psm_am_reply_short(token, PSMX_AM_ATOMIC_HANDLER, rep_args, 2, tmp_buf, (tmp_buf?len:0), 0, psmx_am_atomic_completion, tmp_buf ); break; case PSMX_AM_REP_ATOMIC_WRITE: req = (struct psmx_am_request *)(uintptr_t)args[1].u64; op_error = (int)args[0].u32w1; assert(req->op == PSMX_AM_REQ_ATOMIC_WRITE); if (req->ep->send_cq && !req->no_event) { event = psmx_cq_create_event( req->ep->send_cq, req->atomic.context, req->atomic.buf, req->cq_flags, req->atomic.len, 0, /* data */ 0, /* tag */ 0, /* olen */ op_error); if (event) psmx_cq_enqueue_event(req->ep->send_cq, event); else err = -FI_ENOMEM; } if (req->ep->write_cntr) psmx_cntr_inc(req->ep->write_cntr); free(req); break; case PSMX_AM_REP_ATOMIC_READWRITE: case PSMX_AM_REP_ATOMIC_COMPWRITE: req = (struct psmx_am_request *)(uintptr_t)args[1].u64; op_error = (int)args[0].u32w1; assert(op_error || req->atomic.len == len); if (!op_error) memcpy(req->atomic.result, src, len); if (req->ep->send_cq && !req->no_event) { event = psmx_cq_create_event( req->ep->send_cq, req->atomic.context, req->atomic.buf, req->cq_flags, req->atomic.len, 0, /* data */ 0, /* tag */ 0, /* olen */ op_error); if (event) psmx_cq_enqueue_event(req->ep->send_cq, event); else err = -FI_ENOMEM; } if (req->ep->read_cntr) psmx_cntr_inc(req->ep->read_cntr); free(req); break; default: err = -FI_EINVAL; } return err; }
/*
 * psmx_am_msg_handler - PSM active-message handler for the AM-based
 * send/recv path.
 *
 * PSMX_AM_REQ_SEND (receiver side): args[0].u32w1 is the payload length of
 * this packet and args[3].u64 its byte offset within the message. The first
 * packet (offset 0) is matched against a posted receive with
 * psmx_am_search_and_dequeue_recv(); if none matches, it is stored with
 * psmx_unexp_enqueue(). If that message is incomplete, the reply is deferred
 * until a receive is posted. Follow-up packets carry the receiver's request
 * pointer in args[2].u64, which was handed out in the first PSMX_AM_REP_SEND.
 * Bytes that do not fit in the posted buffer are not copied (truncation). On
 * the last packet (PSMX_AM_EOM) the receive is completed and a final
 * PSMX_AM_REP_SEND with args[2] == 0 is sent.
 *
 * PSMX_AM_REP_SEND (sender side): args[1].u64 is the local send request and
 * args[0].u32w1 the status. A non-zero args[2].u64 means more data may be
 * sent: with PSMX_AM_USE_SEND_QUEUE the request is queued (and the progress
 * thread, if any, is signalled), otherwise req->send.peer_ready is set. Zero
 * means the send is complete. A request still in state PSMX_AM_STATE_QUEUED
 * is marked PSMX_AM_STATE_DONE instead of being freed here.
 *
 * Returns 0, the value of psm_am_reply_short(), or a negative errno value.
 */
int psmx_am_msg_handler(psm_am_token_t token, psm_epaddr_t epaddr,
			psm_amarg_t *args, int nargs, void *src, uint32_t len)
{
	psm_amarg_t rep_args[8];
	struct psmx_am_request *req;
	struct psmx_cq_event *event;
	struct psmx_epaddr_context *epaddr_context;
	struct psmx_fid_domain *domain;
	int msg_len;
	int copy_len;
	uint64_t offset;
	int cmd, eom;
	int err = 0;
	int op_error = 0;
	struct psmx_unexp *unexp;

	epaddr_context = psm_epaddr_getctxt(epaddr);
	if (!epaddr_context) {
		fprintf(stderr, "%s: NULL context for epaddr %p\n", __func__, epaddr);
		return -EIO;
	}

	domain = epaddr_context->domain;

	cmd = args[0].u32w0 & PSMX_AM_OP_MASK;
	eom = args[0].u32w0 & PSMX_AM_EOM;

	switch (cmd) {
	case PSMX_AM_REQ_SEND:
		msg_len = args[0].u32w1;
		offset = args[3].u64;
		assert(len == msg_len);
		if (offset == 0) {
			/* this is the first packet */
			req = psmx_am_search_and_dequeue_recv(domain, (const void *)epaddr);
			if (req) {
				/* never copy more than the posted buffer can hold */
				copy_len = MIN(len, req->recv.len);
				memcpy(req->recv.buf, src, copy_len);
				req->recv.len_received += copy_len;
			} else {
				unexp = malloc(sizeof(*unexp) + len);
				if (!unexp) {
					op_error = -ENOBUFS;
				} else {
					memcpy(unexp->buf, src, len);
					unexp->sender_addr = epaddr;
					unexp->sender_context = args[1].u64;
					unexp->len_received = len;
					unexp->done = !!eom;
					unexp->next = NULL;
					psmx_unexp_enqueue(unexp);

					if (!eom) {
						/* stop here. will reply when recv is posted */
						break;
					}
				}
			}

			if (!op_error && !eom) {
				/* reply w/ recv req to be used for following packets */
				rep_args[0].u32w0 = PSMX_AM_REP_SEND;
				rep_args[0].u32w1 = 0;
				rep_args[1].u64 = args[1].u64;
				rep_args[2].u64 = (uint64_t)(uintptr_t)req;
				err = psm_am_reply_short(token, PSMX_AM_MSG_HANDLER,
						rep_args, 3, NULL, 0, 0,
						NULL, NULL );
			}
		} else {
			req = (struct psmx_am_request *)(uintptr_t)args[2].u64;
			if (req) {
				/*
				 * Copy only what fits between offset and the end of the
				 * posted buffer; anything beyond it is truncated.
				 */
				if (offset < req->recv.len) {
					copy_len = MIN(len, req->recv.len - offset);
					memcpy((char *)req->recv.buf + offset, src, copy_len);
					req->recv.len_received += copy_len;
				}
			} else {
				fprintf(stderr, "%s: NULL recv_req in follow-up packets.\n", __func__);
				op_error = -EBADMSG;
			}
		}

		if (eom && req) {
			if (req->ep->recv_cq && !req->no_event) {
				event = psmx_cq_create_event(
						req->ep->recv_cq,
						req->recv.context,
						req->recv.buf,
						0, /* flags */
						req->recv.len_received,
						0, /* data */
						0, /* tag */
						req->recv.len - req->recv.len_received,
						0 /* err */);
				if (event)
					psmx_cq_enqueue_event(req->ep->recv_cq, event);
				else
					err = -ENOMEM;
			}

			if (req->ep->recv_cntr)
				psmx_cntr_inc(req->ep->recv_cntr);

			free(req);
		}

		if (eom || op_error) {
			rep_args[0].u32w0 = PSMX_AM_REP_SEND;
			rep_args[0].u32w1 = op_error;
			rep_args[1].u64 = args[1].u64;
			rep_args[2].u64 = 0; /* done */
			err = psm_am_reply_short(token, PSMX_AM_MSG_HANDLER,
					rep_args, 3, NULL, 0, 0,
					NULL, NULL );
		}
		break;

	case PSMX_AM_REP_SEND:
		req = (struct psmx_am_request *)(uintptr_t)args[1].u64;
		op_error = (int)args[0].u32w1;
		assert(req->op == PSMX_AM_REQ_SEND);

		if (args[2].u64) { /* more to send */
			req->send.peer_context = (void *)(uintptr_t)args[2].u64;

#if PSMX_AM_USE_SEND_QUEUE
			/* psm_am_request_short() can't be called inside the handler.
			 * put the request into a queue and process it later.
			 */
			psmx_am_enqueue_send(req->ep->domain, req);
			if (req->ep->domain->progress_thread)
				pthread_cond_signal(&req->ep->domain->progress_cond);
#else
			req->send.peer_ready = 1;
#endif
		} else { /* done */
			if (req->ep->send_cq && !req->no_event) {
				event = psmx_cq_create_event(
						req->ep->send_cq,
						req->send.context,
						req->send.buf,
						0, /* flags */
						req->send.len,
						0, /* data */
						0, /* tag */
						0, /* olen */
						op_error);
				if (event)
					psmx_cq_enqueue_event(req->ep->send_cq, event);
				else
					err = -ENOMEM;
			}

			if (req->ep->send_cntr)
				psmx_cntr_inc(req->ep->send_cntr);

			/* a still-queued request is released by its owner later */
			if (req->state == PSMX_AM_STATE_QUEUED)
				req->state = PSMX_AM_STATE_DONE;
			else
				free(req);
		}
		break;

	default:
		err = -EINVAL;
	}

	return err;
}