int psmx2_ep_open(struct fid_domain *domain, struct fi_info *info, struct fid_ep **ep, void *context) { struct psmx2_fid_domain *domain_priv; struct psmx2_fid_ep *ep_priv; struct psmx2_ep_name ep_name; struct psmx2_ep_name *src_addr; struct psmx2_trx_ctxt *trx_ctxt = NULL; int err = -FI_EINVAL; int alloc_trx_ctxt = 1; domain_priv = container_of(domain, struct psmx2_fid_domain, util_domain.domain_fid.fid); if (!domain_priv) goto errout; if (info && info->ep_attr && info->ep_attr->rx_ctx_cnt == FI_SHARED_CONTEXT) return -FI_ENOSYS; if (info && info->ep_attr && info->ep_attr->tx_ctx_cnt == FI_SHARED_CONTEXT && !ofi_recv_allowed(info->caps) && !ofi_rma_target_allowed(info->caps)) { alloc_trx_ctxt = 0; FI_INFO(&psmx2_prov, FI_LOG_EP_CTRL, "Tx only endpoint with STX context.\n"); } src_addr = NULL; if (info && info->src_addr) { if (info->addr_format == FI_ADDR_STR) src_addr = psmx2_string_to_ep_name(info->src_addr); else src_addr = info->src_addr; } if (alloc_trx_ctxt) { trx_ctxt = psmx2_trx_ctxt_alloc(domain_priv, src_addr, 0); if (!trx_ctxt) goto errout; } err = psmx2_ep_open_internal(domain_priv, info, &ep_priv, context, trx_ctxt); if (err) goto errout_free_ctxt; ep_priv->type = PSMX2_EP_REGULAR; ep_priv->service = PSMX2_ANY_SERVICE; if (src_addr) { ep_priv->service = src_addr->service; if (info->addr_format == FI_ADDR_STR) free(src_addr); } if (ep_priv->service == PSMX2_ANY_SERVICE) ep_priv->service = ((getpid() & 0x7FFF) << 16) + ((uintptr_t)ep_priv & 0xFFFF); if (alloc_trx_ctxt) { trx_ctxt->ep = ep_priv; psmx2_lock(&domain_priv->trx_ctxt_lock, 1); dlist_insert_before(&trx_ctxt->entry, &domain_priv->trx_ctxt_list); psmx2_unlock(&domain_priv->trx_ctxt_lock, 1); ep_name.epid = trx_ctxt->psm2_epid; ep_name.type = ep_priv->type; ofi_ns_add_local_name(&domain_priv->fabric->name_server, &ep_priv->service, &ep_name); } *ep = &ep_priv->ep; return 0; errout_free_ctxt: psmx2_trx_ctxt_free(trx_ctxt); errout: return err; }
int psmx2_sep_open(struct fid_domain *domain, struct fi_info *info, struct fid_ep **sep, void *context) { struct psmx2_fid_domain *domain_priv; struct psmx2_fid_ep *ep_priv; struct psmx2_fid_sep *sep_priv; struct psmx2_ep_name ep_name; struct psmx2_ep_name *src_addr; struct psmx2_trx_ctxt *trx_ctxt; size_t ctxt_cnt = 1; size_t ctxt_size; int err = -FI_EINVAL; int i; domain_priv = container_of(domain, struct psmx2_fid_domain, util_domain.domain_fid.fid); if (!domain_priv) goto errout; if (info && info->ep_attr) { if (info->ep_attr->tx_ctx_cnt > psmx2_env.sep_trx_ctxt) { FI_WARN(&psmx2_prov, FI_LOG_EP_CTRL, "tx_ctx_cnt %"PRIu64" exceed limit %d.\n", info->ep_attr->tx_ctx_cnt, psmx2_env.sep_trx_ctxt); goto errout; } if (info->ep_attr->rx_ctx_cnt > psmx2_env.sep_trx_ctxt) { FI_WARN(&psmx2_prov, FI_LOG_EP_CTRL, "rx_ctx_cnt %"PRIu64" exceed limit %d.\n", info->ep_attr->rx_ctx_cnt, psmx2_env.sep_trx_ctxt); goto errout; } ctxt_cnt = info->ep_attr->tx_ctx_cnt; if (ctxt_cnt < info->ep_attr->rx_ctx_cnt) ctxt_cnt = info->ep_attr->rx_ctx_cnt; if (ctxt_cnt == 0) { FI_INFO(&psmx2_prov, FI_LOG_EP_CTRL, "tx_ctx_cnt and rx_ctx_cnt are 0, use 1.\n"); ctxt_cnt = 1; } } ctxt_size = ctxt_cnt * sizeof(struct psmx2_sep_ctxt); sep_priv = (struct psmx2_fid_sep *) calloc(1, sizeof(*sep_priv) + ctxt_size); if (!sep_priv) { err = -FI_ENOMEM; goto errout; } sep_priv->ep.fid.fclass = FI_CLASS_SEP; sep_priv->ep.fid.context = context; sep_priv->ep.fid.ops = &psmx2_fi_ops_sep; sep_priv->ep.ops = &psmx2_sep_ops; sep_priv->ep.cm = &psmx2_cm_ops; sep_priv->domain = domain_priv; sep_priv->ctxt_cnt = ctxt_cnt; ofi_atomic_initialize32(&sep_priv->ref, 0); src_addr = NULL; if (info && info->src_addr) { if (info->addr_format == FI_ADDR_STR) src_addr = psmx2_string_to_ep_name(info->src_addr); else src_addr = info->src_addr; } for (i = 0; i < ctxt_cnt; i++) { trx_ctxt = psmx2_trx_ctxt_alloc(domain_priv, src_addr, i); if (!trx_ctxt) { err = -FI_ENOMEM; goto errout_free_ctxt; } sep_priv->ctxts[i].trx_ctxt = trx_ctxt; err = psmx2_ep_open_internal(domain_priv, info, &ep_priv, context, trx_ctxt); if (err) goto errout_free_ctxt; /* override the ops so the fid can't be closed individually */ ep_priv->ep.fid.ops = &psmx2_fi_ops_sep_ctxt; trx_ctxt->ep = ep_priv; sep_priv->ctxts[i].ep = ep_priv; } sep_priv->type = PSMX2_EP_SCALABLE; sep_priv->service = PSMX2_ANY_SERVICE; if (src_addr) { sep_priv->service = src_addr->service; if (info->addr_format == FI_ADDR_STR) free(src_addr); } if (sep_priv->service == PSMX2_ANY_SERVICE) sep_priv->service = ((getpid() & 0x7FFF) << 16) + ((uintptr_t)sep_priv & 0xFFFF); sep_priv->id = ofi_atomic_inc32(&domain_priv->sep_cnt); psmx2_lock(&domain_priv->sep_lock, 1); dlist_insert_before(&sep_priv->entry, &domain_priv->sep_list); psmx2_unlock(&domain_priv->sep_lock, 1); psmx2_lock(&domain_priv->trx_ctxt_lock, 1); for (i = 0; i< ctxt_cnt; i++) { dlist_insert_before(&sep_priv->ctxts[i].trx_ctxt->entry, &domain_priv->trx_ctxt_list); } psmx2_unlock(&domain_priv->trx_ctxt_lock, 1); ep_name.epid = sep_priv->ctxts[0].trx_ctxt->psm2_epid; ep_name.sep_id = sep_priv->id; ep_name.type = sep_priv->type; ofi_ns_add_local_name(&domain_priv->fabric->name_server, &sep_priv->service, &ep_name); psmx2_domain_acquire(domain_priv); *sep = &sep_priv->ep; /* Make sure the AM handler is installed to answer SEP query */ psmx2_am_init(sep_priv->ctxts[0].trx_ctxt); return 0; errout_free_ctxt: while (i) { if (sep_priv->ctxts[i].ep) psmx2_ep_close_internal(sep_priv->ctxts[i].ep); if (sep_priv->ctxts[i].trx_ctxt) psmx2_trx_ctxt_free(sep_priv->ctxts[i].trx_ctxt); i--; } free(sep_priv); errout: return err; }
int psmx_ep_open(struct fid_domain *domain, struct fi_info *info, struct fid_ep **ep, void *context) { struct psmx_fid_domain *domain_priv; struct psmx_fid_ep *ep_priv; int err; uint64_t ep_cap; if (info) ep_cap = info->caps; else ep_cap = FI_TAGGED; domain_priv = container_of(domain, struct psmx_fid_domain, util_domain.domain_fid.fid); if (!domain_priv) return -FI_EINVAL; if (info && info->ep_attr && info->ep_attr->auth_key) { if (info->ep_attr->auth_key_size != sizeof(psm_uuid_t)) { FI_WARN(&psmx_prov, FI_LOG_EP_CTRL, "Invalid auth_key_len %d, should be %d.\n", info->ep_attr->auth_key_size, sizeof(psm_uuid_t)); return -FI_EINVAL; } if (memcmp(domain_priv->fabric->uuid, info->ep_attr->auth_key, sizeof(psm_uuid_t))) { FI_WARN(&psmx_prov, FI_LOG_EP_CTRL, "Invalid auth_key: %s\n", psmx_uuid_to_string((void *)info->ep_attr->auth_key)); return -FI_EINVAL; } } err = psmx_domain_check_features(domain_priv, ep_cap); if (err) return err; ep_priv = (struct psmx_fid_ep *) calloc(1, sizeof *ep_priv); if (!ep_priv) return -FI_ENOMEM; ep_priv->ep.fid.fclass = FI_CLASS_EP; ep_priv->ep.fid.context = context; ep_priv->ep.fid.ops = &psmx_fi_ops; ep_priv->ep.ops = &psmx_ep_ops; ep_priv->ep.cm = &psmx_cm_ops; ep_priv->domain = domain_priv; ofi_atomic_initialize32(&ep_priv->ref, 0); PSMX_CTXT_TYPE(&ep_priv->nocomp_send_context) = PSMX_NOCOMP_SEND_CONTEXT; PSMX_CTXT_EP(&ep_priv->nocomp_send_context) = ep_priv; PSMX_CTXT_TYPE(&ep_priv->nocomp_recv_context) = PSMX_NOCOMP_RECV_CONTEXT; PSMX_CTXT_EP(&ep_priv->nocomp_recv_context) = ep_priv; if (ep_cap & FI_TAGGED) ep_priv->ep.tagged = &psmx_tagged_ops; if (ep_cap & FI_MSG) ep_priv->ep.msg = &psmx_msg_ops; if ((ep_cap & FI_MSG) && psmx_env.am_msg) ep_priv->ep.msg = &psmx_msg2_ops; if (ep_cap & FI_RMA) ep_priv->ep.rma = &psmx_rma_ops; if (ep_cap & FI_ATOMICS) ep_priv->ep.atomic = &psmx_atomic_ops; ep_priv->caps = ep_cap; err = psmx_domain_enable_ep(domain_priv, ep_priv); if (err) { free(ep_priv); return err; } psmx_domain_acquire(domain_priv); if (info) { if (info->tx_attr) ep_priv->tx_flags = info->tx_attr->op_flags; if (info->rx_attr) ep_priv->rx_flags = info->rx_attr->op_flags; } psmx_ep_optimize_ops(ep_priv); ep_priv->service = PSMX_ANY_SERVICE; if (info && info->src_addr) ep_priv->service = ((struct psmx_src_name *)info->src_addr)->service; if (ep_priv->service == PSMX_ANY_SERVICE) ep_priv->service = ((getpid() & 0x7FFF) << 16) + ((uintptr_t)ep_priv & 0xFFFF); ofi_ns_add_local_name(&ep_priv->domain->fabric->name_server, &ep_priv->service, &ep_priv->domain->psm_epid); *ep = &ep_priv->ep; return 0; }