int mca_pml_ucx_send(const void *buf, size_t count, ompi_datatype_t *datatype,
                     int dst, int tag, mca_pml_base_send_mode_t mode,
                     struct ompi_communicator_t* comm)
{
    ompi_request_t *req;
    ucp_ep_h ep;

    PML_UCX_TRACE_SEND("%s", buf, count, datatype, dst, tag, mode, comm, "send");

    /* TODO special care to sync/buffered send */

    ep = mca_pml_ucx_get_ep(comm, dst);
    if (OPAL_UNLIKELY(NULL == ep)) {
        PML_UCX_ERROR("Failed to get ep for rank %d", dst);
        return OMPI_ERROR;
    }

    req = (ompi_request_t*)ucp_tag_send_nb(ep, buf, count,
                                           mca_pml_ucx_get_datatype(datatype),
                                           PML_UCX_MAKE_SEND_TAG(tag, comm),
                                           mca_pml_ucx_send_completion);
    if (OPAL_LIKELY(req == NULL)) {
        return OMPI_SUCCESS;
    } else if (!UCS_PTR_IS_ERR(req)) {
        PML_UCX_VERBOSE(8, "got request %p", (void*)req);
        ucp_worker_progress(ompi_pml_ucx.ucp_worker);
        ompi_request_wait(&req, MPI_STATUS_IGNORE);
        return OMPI_SUCCESS;
    } else {
        PML_UCX_ERROR("ucx send failed: %s",
                      ucs_status_string(UCS_PTR_STATUS(req)));
        return OMPI_ERROR;
    }
}
int mca_pml_ucx_recv(void *buf, size_t count, ompi_datatype_t *datatype, int src,
                     int tag, struct ompi_communicator_t* comm,
                     ompi_status_public_t* mpi_status)
{
    ucp_tag_t ucp_tag, ucp_tag_mask;
    ompi_request_t *req;

    PML_UCX_TRACE_RECV("%s", buf, count, datatype, src, tag, comm, "recv");

    PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
    req = (ompi_request_t*)ucp_tag_recv_nb(ompi_pml_ucx.ucp_worker, buf, count,
                                           mca_pml_ucx_get_datatype(datatype),
                                           ucp_tag, ucp_tag_mask,
                                           mca_pml_ucx_blocking_recv_completion);
    if (UCS_PTR_IS_ERR(req)) {
        PML_UCX_ERROR("ucx recv failed: %s",
                      ucs_status_string(UCS_PTR_STATUS(req)));
        return OMPI_ERROR;
    }

    ucp_worker_progress(ompi_pml_ucx.ucp_worker);
    while (!REQUEST_COMPLETE(req)) {
        opal_progress();
    }

    if (mpi_status != MPI_STATUS_IGNORE) {
        *mpi_status = req->req_status;
    }

    req->req_complete = REQUEST_PENDING;
    ucp_request_release(req);
    return OMPI_SUCCESS;
}
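/*
 * The listings above pass completion callbacks (mca_pml_ucx_send_completion,
 * mca_pml_ucx_blocking_recv_completion) that are not shown. Below is a minimal
 * sketch of the blocking-receive callback, assuming the OMPI request fields
 * used above (req_status, REQUEST_COMPLETE); the real implementation lives in
 * Open MPI's pml_ucx component and may differ in detail.
 */
static void mca_pml_ucx_blocking_recv_completion(void *request, ucs_status_t status,
                                                 ucp_tag_recv_info_t *info)
{
    ompi_request_t *req = request;

    /* Record the UCX completion status and tag info into the MPI status
     * (mca_pml_ucx_set_recv_status() is an assumed helper here), then mark
     * the request complete so the REQUEST_COMPLETE() poll loop in
     * mca_pml_ucx_recv() can exit. */
    mca_pml_ucx_set_recv_status(&req->req_status, status, info);
    ompi_request_complete(req, true);
}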
int mca_spml_ucx_del_procs(ompi_proc_t** procs, size_t nprocs)
{
    int my_rank = oshmem_my_proc_id();
    size_t num_reqs, max_reqs;
    void *dreq, **dreqs;
    ucp_ep_h ep;
    size_t i, n;

    oshmem_shmem_barrier();

    if (!mca_spml_ucx.ucp_peers) {
        return OSHMEM_SUCCESS;
    }

    max_reqs = mca_spml_ucx.num_disconnect;
    if (max_reqs > nprocs) {
        max_reqs = nprocs;
    }

    dreqs = malloc(sizeof(*dreqs) * max_reqs);
    if (dreqs == NULL) {
        return OMPI_ERR_OUT_OF_RESOURCE;
    }

    num_reqs = 0;

    for (i = 0; i < nprocs; ++i) {
        /* Stagger disconnect initiation by rank to avoid all processes
         * hitting the same peer at once */
        n  = (i + my_rank) % nprocs;
        ep = mca_spml_ucx.ucp_peers[n].ucp_conn;
        if (ep == NULL) {
            continue;
        }

        SPML_VERBOSE(10, "disconnecting from peer %d", (int)n);
        dreq = ucp_disconnect_nb(ep);
        if (dreq != NULL) {
            if (UCS_PTR_IS_ERR(dreq)) {
                SPML_ERROR("ucp_disconnect_nb(%d) failed: %s", (int)n,
                           ucs_status_string(UCS_PTR_STATUS(dreq)));
            } else {
                dreqs[num_reqs++] = dreq;
            }
        }

        mca_spml_ucx.ucp_peers[n].ucp_conn = NULL;

        /* Cap the number of outstanding disconnect requests */
        if ((int)num_reqs >= mca_spml_ucx.num_disconnect) {
            mca_spml_ucx_waitall(dreqs, &num_reqs);
        }
    }

    mca_spml_ucx_waitall(dreqs, &num_reqs);
    free(dreqs);

    opal_pmix.fence(NULL, 0);
    free(mca_spml_ucx.ucp_peers);
    return OSHMEM_SUCCESS;
}
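/*
 * mca_spml_ucx_waitall() is referenced above but not shown. A sketch under
 * the assumptions that mca_spml_ucx exposes its ucp_worker field and that
 * disconnect requests support ucp_request_test() (the same call used in
 * ucp_ep_destroy() further below): progress the worker until every pending
 * disconnect request completes, then reset the count so the caller can
 * refill the array.
 */
static void mca_spml_ucx_waitall(void **reqs, size_t *count_p)
{
    ucs_status_t status;
    size_t i;

    for (i = 0; i < *count_p; ++i) {
        do {
            ucp_worker_progress(mca_spml_ucx.ucp_worker); /* assumed field */
            status = ucp_request_test(reqs[i], NULL);
        } while (status == UCS_INPROGRESS);

        if (status != UCS_OK) {
            SPML_ERROR("disconnect request %d failed: %s", (int)i,
                       ucs_status_string(status));
        }
        ucp_request_release(reqs[i]);
    }
    *count_p = 0;
}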
struct ucx_context *launch_send(int msg_len)
{
    struct ucx_context *request;
    static char *msg = NULL;
    static int cur_len = 0;
    static char fill = 'a';

    if (same_buf) {
        /* Reuse one static buffer, growing it when a larger message is sent
         * (free(NULL) is a no-op on the first call) */
        if (cur_len < msg_len) {
            free(msg);
            msg = malloc(msg_len);
            cur_len = msg_len;
        }
    } else {
        msg = malloc(msg_len);
    }
    if (!msg) {
        fprintf(stderr, "unable to allocate memory\n");
        abort();
    }

    if (mem_set) {
        memset(msg, fill, msg_len);
        fill++;
        if ('z' < fill) {
            fill = 'a';
        }
    }

    request = ucp_tag_send_nb(rem_ep, msg, msg_len, ucp_dt_make_contig(1),
                              tag, send_handle);
    if (UCS_PTR_IS_ERR(request)) {
        fprintf(stderr, "unable to send UCX data message\n");
        free(msg);
        abort();
    } else if (UCS_PTR_STATUS(request) != UCS_OK) {
        /* Send is in flight: remember the buffer so it can be freed later */
        request->buf = msg;
    } else {
        /* Send completed immediately */
        request = NULL;
        if (!same_buf) {
            free(msg);
        }
    }
    return request;
}
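/*
 * launch_send() (and launch_recv() below) rely on a per-request context and
 * completion callbacks that the listings do not define. The following sketch
 * follows the pattern of UCX's ucp_hello_world example: the 'completed' and
 * 'buf' fields match their uses in the listings, and the context is assumed
 * to be zero-initialized at ucp_init() time via ucp_params_t.request_size /
 * request_init.
 */
struct ucx_context {
    int   completed; /* set by the callback, polled by the wait loops */
    void *buf;       /* message buffer to free once the operation completes */
};

static void send_handle(void *request, ucs_status_t status)
{
    struct ucx_context *ctx = request;
    ctx->completed = 1;
}

static void recv_handle(void *request, ucs_status_t status,
                        ucp_tag_recv_info_t *info)
{
    struct ucx_context *ctx = request;
    ctx->completed = 1;
}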
static int connect_client()
{
    ucs_status_t status;
    ucp_ep_params_t ep_params;
    struct msg *msg = NULL;
    struct ucx_context *request = NULL;
    size_t msg_len = 0;

    /* Send client UCX address to server */
    ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
    ep_params.address    = peer_addr;

    status = ucp_ep_create(ucp_worker, &ep_params, &rem_ep);
    if (status != UCS_OK) {
        abort();
    }

    msg_len = sizeof(*msg) + local_addr_len;
    msg = calloc(1, msg_len);
    if (!msg) {
        abort();
    }

    msg->data_len = local_addr_len;
    memcpy(msg->data, local_addr, local_addr_len);

    request = ucp_tag_send_nb(rem_ep, msg, msg_len, ucp_dt_make_contig(1),
                              tag, send_handle);
    if (UCS_PTR_IS_ERR(request)) {
        fprintf(stderr, "unable to send UCX address message\n");
        free(msg);
        abort();
    } else if (UCS_PTR_STATUS(request) != UCS_OK) {
        fprintf(stderr, "UCX address message was scheduled for send\n");
        wait(ucp_worker, request);
        request->completed = 0; /* Reset request state before recycling it */
        ucp_request_release(request);
    }

    free(msg);
    return 0;
}
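/*
 * connect_client() and the server-side listings exchange a length-prefixed
 * struct msg and block in a wait() helper; neither is shown. These follow
 * the upstream ucp_hello_world example, so treat the exact layout as an
 * assumption. Note that the helper name wait() shadows POSIX wait(2).
 */
struct msg {
    uint64_t data_len;
    char     data[]; /* data_len bytes of payload (worker address or string) */
};

/* Spin on the request's completion flag while driving UCX progress */
static void wait(ucp_worker_h ucp_worker, struct ucx_context *request)
{
    while (request->completed == 0) {
        ucp_worker_progress(ucp_worker);
    }
}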
struct ucx_context *launch_recv()
{
    ucp_tag_recv_info_t info_tag;
    ucp_tag_message_h msg_tag;
    struct ucx_context *request;
    static char *msg = NULL;
    static size_t cur_len = 0;

    /* Probe for a matching incoming message without blocking */
    msg_tag = ucp_tag_probe_nb(ucp_worker, tag, tag_mask, 1, &info_tag);
    if (msg_tag == NULL) {
        return NULL;
    }

    if (same_buf) {
        /* Reuse one static buffer, growing it when a larger message arrives */
        if (cur_len < info_tag.length) {
            free(msg);
            msg = malloc(info_tag.length);
            cur_len = info_tag.length;
        }
    } else {
        msg = malloc(info_tag.length);
    }
    if (!msg) {
        fprintf(stderr, "unable to allocate memory\n");
        abort();
    }

    request = ucp_tag_msg_recv_nb(ucp_worker, msg, info_tag.length,
                                  ucp_dt_make_contig(1), msg_tag, recv_handle);
    if (UCS_PTR_IS_ERR(request)) {
        fprintf(stderr, "unable to receive UCX data message (%s)\n",
                ucs_status_string(UCS_PTR_STATUS(request)));
        free(msg);
        abort();
    }

    request->buf = msg;
    return request;
}
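/*
 * A hypothetical driver showing how launch_send()/launch_recv() compose with
 * wait(): 'iters' and 'msg_size' are illustrative parameters, and the
 * buffer-release logic mirrors the same_buf handling inside the helpers.
 */
static void run_pingpong(int iters, int msg_size)
{
    struct ucx_context *sreq, *rreq;
    int i;

    for (i = 0; i < iters; ++i) {
        sreq = launch_send(msg_size);

        /* Poll until a matching message is probed and its receive is posted */
        do {
            ucp_worker_progress(ucp_worker);
            rreq = launch_recv();
        } while (rreq == NULL);

        if (sreq != NULL) { /* NULL means the send completed immediately */
            wait(ucp_worker, sreq);
            if (!same_buf) {
                free(sreq->buf);
            }
            sreq->completed = 0; /* reset before the request is recycled */
            ucp_request_release(sreq);
        }

        wait(ucp_worker, rreq);
        if (!same_buf) {
            free(rreq->buf);
        }
        rreq->completed = 0;
        ucp_request_release(rreq);
    }
}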
void ucp_ep_destroy(ucp_ep_h ep)
{
    ucp_worker_h worker = ep->worker;
    ucs_status_ptr_t request;
    ucs_status_t status;

    request = ucp_disconnect_nb(ep);
    if (request == NULL) {
        return;
    } else if (UCS_PTR_IS_ERR(request)) {
        ucs_warn("disconnect failed: %s",
                 ucs_status_string(UCS_PTR_STATUS(request)));
        return;
    } else {
        /* Drive progress until the disconnect request completes */
        do {
            ucp_worker_progress(worker);
            status = ucp_request_test(request, NULL);
        } while (status == UCS_INPROGRESS);
        ucp_request_release(request);
    }
}
int mca_pml_ucx_mrecv(void *buf, size_t count, ompi_datatype_t *datatype,
                      struct ompi_message_t **message,
                      ompi_status_public_t* status)
{
    ompi_request_t *req;

    PML_UCX_TRACE_MRECV("mrecv", buf, count, datatype, message);

    req = (ompi_request_t*)ucp_tag_msg_recv_nb(ompi_pml_ucx.ucp_worker, buf, count,
                                               mca_pml_ucx_get_datatype(datatype),
                                               (*message)->req_ptr,
                                               mca_pml_ucx_recv_completion);
    if (UCS_PTR_IS_ERR(req)) {
        PML_UCX_ERROR("ucx msg recv failed: %s",
                      ucs_status_string(UCS_PTR_STATUS(req)));
        return OMPI_ERROR;
    }

    PML_UCX_MESSAGE_RELEASE(message);

    ompi_request_wait(&req, status);
    return OMPI_SUCCESS;
}
int mca_pml_ucx_improbe(int src, int tag, struct ompi_communicator_t* comm,
                        int *matched, struct ompi_message_t **message,
                        ompi_status_public_t* mpi_status)
{
    ucp_tag_t ucp_tag, ucp_tag_mask;
    ucp_tag_recv_info_t info;
    ucp_tag_message_h ucp_msg;

    PML_UCX_TRACE_PROBE("improbe", src, tag, comm);

    PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
    ucp_msg = ucp_tag_probe_nb(ompi_pml_ucx.ucp_worker, ucp_tag, ucp_tag_mask,
                               1, &info);
    if (ucp_msg != NULL) {
        PML_UCX_MESSAGE_NEW(comm, ucp_msg, &info, message);
        PML_UCX_VERBOSE(8, "got message %p (%p)", (void*)*message, (void*)ucp_msg);
        *matched = 1;
        mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
    } else {
        /* ucp_tag_probe_nb() returns NULL when no message matches, so NULL
         * is the "not matched" case */
        *matched = 0;
    }
    return OMPI_SUCCESS;
}
int mca_pml_ucx_irecv(void *buf, size_t count, ompi_datatype_t *datatype,
                      int src, int tag, struct ompi_communicator_t* comm,
                      struct ompi_request_t **request)
{
    ucp_tag_t ucp_tag, ucp_tag_mask;
    ompi_request_t *req;

    PML_UCX_TRACE_RECV("irecv request *%p", buf, count, datatype, src, tag, comm,
                       (void*)request);

    PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
    req = (ompi_request_t*)ucp_tag_recv_nb(ompi_pml_ucx.ucp_worker, buf, count,
                                           mca_pml_ucx_get_datatype(datatype),
                                           ucp_tag, ucp_tag_mask,
                                           mca_pml_ucx_recv_completion);
    if (UCS_PTR_IS_ERR(req)) {
        PML_UCX_ERROR("ucx recv failed: %s",
                      ucs_status_string(UCS_PTR_STATUS(req)));
        return OMPI_ERROR;
    }

    PML_UCX_VERBOSE(8, "got request %p", (void*)req);
    *request = req;
    return OMPI_SUCCESS;
}
int mca_pml_ucx_start(size_t count, ompi_request_t** requests)
{
    mca_pml_ucx_persistent_request_t *preq;
    ompi_request_t *tmp_req;
    size_t i;

    for (i = 0; i < count; ++i) {
        preq = (mca_pml_ucx_persistent_request_t *)requests[i];

        if ((preq == NULL) || (OMPI_REQUEST_PML != preq->ompi.req_type)) {
            /* Skip irrelevant requests */
            continue;
        }

        PML_UCX_ASSERT(preq->ompi.req_state != OMPI_REQUEST_INVALID);
        preq->ompi.req_state = OMPI_REQUEST_ACTIVE;
        mca_pml_ucx_request_reset(&preq->ompi);

        if (preq->flags & MCA_PML_UCX_REQUEST_FLAG_SEND) {
            /* TODO special care to sync/buffered send */
            PML_UCX_VERBOSE(8, "start send request %p", (void*)preq);
            tmp_req = (ompi_request_t*)ucp_tag_send_nb(preq->send.ep, preq->buffer,
                                                       preq->count, preq->datatype,
                                                       preq->tag,
                                                       mca_pml_ucx_psend_completion);
        } else {
            PML_UCX_VERBOSE(8, "start recv request %p", (void*)preq);
            tmp_req = (ompi_request_t*)ucp_tag_recv_nb(ompi_pml_ucx.ucp_worker,
                                                       preq->buffer, preq->count,
                                                       preq->datatype, preq->tag,
                                                       preq->recv.tag_mask,
                                                       mca_pml_ucx_precv_completion);
        }

        if (tmp_req == NULL) {
            /* Only send can complete immediately */
            PML_UCX_ASSERT(preq->flags & MCA_PML_UCX_REQUEST_FLAG_SEND);

            PML_UCX_VERBOSE(8, "send completed immediately, completing persistent request %p",
                            (void*)preq);
            mca_pml_ucx_set_send_status(&preq->ompi.req_status, UCS_OK);
            ompi_request_complete(&preq->ompi, true);
        } else if (!UCS_PTR_IS_ERR(tmp_req)) {
            if (REQUEST_COMPLETE(tmp_req)) {
                /* tmp_req is already completed */
                PML_UCX_VERBOSE(8, "completing persistent request %p", (void*)preq);
                mca_pml_ucx_persistent_request_complete(preq, tmp_req);
            } else {
                /* tmp_req would be completed by callback and trigger completion
                 * of preq */
                PML_UCX_VERBOSE(8, "temporary request %p will complete persistent request %p",
                                (void*)tmp_req, (void*)preq);
                tmp_req->req_complete_cb_data = preq;
                preq->tmp_req                 = tmp_req;
            }
        } else {
            PML_UCX_ERROR("ucx %s failed: %s",
                          (preq->flags & MCA_PML_UCX_REQUEST_FLAG_SEND) ? "send" : "recv",
                          ucs_status_string(UCS_PTR_STATUS(tmp_req)));
            return OMPI_ERROR;
        }
    }

    return OMPI_SUCCESS;
}
static int run_ucx_server(ucp_worker_h ucp_worker)
{
    ucp_tag_recv_info_t info_tag;
    ucp_tag_message_h msg_tag;
    ucs_status_t status;
    ucp_ep_params_t ep_params;
    ucp_ep_h client_ep;
    struct msg *msg = NULL;
    struct ucx_context *request = NULL;
    size_t msg_len = 0;
    int ret = -1;

    /* Receive client UCX address */
    do {
        /* In the blocking test modes, sleep on the worker's internal file
         * descriptor so the CPU stays idle instead of spin-looping */
        if (ucp_test_mode == TEST_MODE_WAIT) {
            status = ucp_worker_wait(ucp_worker);
            if (status != UCS_OK) {
                goto err;
            }
        } else if (ucp_test_mode == TEST_MODE_EVENTFD) {
            status = test_poll_wait(ucp_worker);
            if (status != UCS_OK) {
                goto err;
            }
        }

        /* Progress before probing to update the worker's state */
        ucp_worker_progress(ucp_worker);

        /* Probe for incoming messages in non-blocking mode */
        msg_tag = ucp_tag_probe_nb(ucp_worker, tag, tag_mask, 1, &info_tag);
    } while (msg_tag == NULL);

    msg = malloc(info_tag.length);
    if (!msg) {
        fprintf(stderr, "unable to allocate memory\n");
        goto err;
    }

    request = ucp_tag_msg_recv_nb(ucp_worker, msg, info_tag.length,
                                  ucp_dt_make_contig(1), msg_tag, recv_handle);
    if (UCS_PTR_IS_ERR(request)) {
        fprintf(stderr, "unable to receive UCX address message (%s)\n",
                ucs_status_string(UCS_PTR_STATUS(request)));
        free(msg);
        goto err;
    } else {
        wait(ucp_worker, request);
        ucp_request_release(request);
        printf("UCX address message was received\n");
    }

    peer_addr = malloc(msg->data_len);
    if (!peer_addr) {
        fprintf(stderr, "unable to allocate memory for peer address\n");
        free(msg);
        goto err;
    }

    peer_addr_len = msg->data_len;
    memcpy(peer_addr, msg->data, peer_addr_len);
    free(msg);

    /* Send test string to client */
    ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
    ep_params.address    = peer_addr;

    status = ucp_ep_create(ucp_worker, &ep_params, &client_ep);
    if (status != UCS_OK) {
        goto err;
    }

    msg_len = sizeof(*msg) + strlen(test_str) + 1;
    msg = calloc(1, msg_len);
    if (!msg) {
        fprintf(stderr, "unable to allocate memory\n");
        goto err_ep;
    }

    msg->data_len = msg_len - sizeof(*msg);
    snprintf((char *)msg->data, msg->data_len, "%s", test_str);

    request = ucp_tag_send_nb(client_ep, msg, msg_len, ucp_dt_make_contig(1),
                              tag, send_handle);
    if (UCS_PTR_IS_ERR(request)) {
        fprintf(stderr, "unable to send UCX data message\n");
        free(msg);
        goto err_ep;
    } else if (UCS_PTR_STATUS(request) != UCS_OK) {
        printf("UCX data message was scheduled for send\n");
        wait(ucp_worker, request);
        ucp_request_release(request);
    }

    ret = 0;
    free(msg);

err_ep:
    ucp_ep_destroy(client_ep);

err:
    return ret;
}
static int connect_server()
{
    ucp_tag_recv_info_t info_tag;
    ucp_tag_message_h msg_tag;
    ucs_status_t status;
    ucp_ep_params_t ep_params;
    struct msg *msg = NULL;
    struct ucx_context *request = NULL;

    /* Receive client UCX address */
    do {
        /* Progress before probing to update the worker's state */
        ucp_worker_progress(ucp_worker);

        /* Probe for incoming messages in non-blocking mode */
        msg_tag = ucp_tag_probe_nb(ucp_worker, tag, tag_mask, 1, &info_tag);
    } while (msg_tag == NULL);

    msg = malloc(info_tag.length);
    if (!msg) {
        fprintf(stderr, "unable to allocate memory\n");
        abort();
    }

    request = ucp_tag_msg_recv_nb(ucp_worker, msg, info_tag.length,
                                  ucp_dt_make_contig(1), msg_tag, recv_handle);
    if (UCS_PTR_IS_ERR(request)) {
        fprintf(stderr, "unable to receive UCX address message (%s)\n",
                ucs_status_string(UCS_PTR_STATUS(request)));
        free(msg);
        abort();
    } else {
        wait(ucp_worker, request);
        request->completed = 0; /* Reset request state before recycling it */
        ucp_request_release(request);
        printf("UCX address message was received\n");
    }

    peer_addr = malloc(msg->data_len);
    if (!peer_addr) {
        fprintf(stderr, "unable to allocate memory for peer address\n");
        free(msg);
        abort();
    }

    peer_addr_len = msg->data_len;
    memcpy(peer_addr, msg->data, peer_addr_len);
    free(msg);

    /* Connect back to the client using the address just received */
    ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
    ep_params.address    = peer_addr;

    status = ucp_ep_create(ucp_worker, &ep_params, &rem_ep);
    if (status != UCS_OK) {
        abort();
    }
    return 0;
}