/* ****************************************************************** */ int MPIDO_Allgather_alltoall(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, MPI_Aint send_true_lb, MPI_Aint recv_true_lb, size_t send_size, size_t recv_size, MPID_Comm * comm_ptr, int *mpierrno) { int i, rc; void *a2a_sendbuf = NULL; char *destbuf=NULL; char *startbuf=NULL; const int size = comm_ptr->local_size; const int rank = comm_ptr->rank; int a2a_sendcounts[size]; int a2a_senddispls[size]; int a2a_recvcounts[size]; int a2a_recvdispls[size]; for (i = 0; i < size; ++i) { a2a_sendcounts[i] = send_size; a2a_senddispls[i] = 0; a2a_recvcounts[i] = recvcount; a2a_recvdispls[i] = recvcount * i; } if (sendbuf != MPI_IN_PLACE) { a2a_sendbuf = (char *)sendbuf + send_true_lb; } else { startbuf = (char *) recvbuf + recv_true_lb; destbuf = startbuf + rank * send_size; a2a_sendbuf = destbuf; a2a_sendcounts[rank] = 0; a2a_recvcounts[rank] = 0; } /* Switch to comm->coll_fns->fn() */ rc = MPIDO_Alltoallv((const void *)a2a_sendbuf, a2a_sendcounts, a2a_senddispls, MPI_CHAR, recvbuf, a2a_recvcounts, a2a_recvdispls, recvtype, comm_ptr, mpierrno); return rc; }
/*
 * Scatterv expressed as a single MPIDO_Alltoallv call.
 *
 * The root passes the caller's sendbuf/sendcounts/displs straight through.
 * Every other rank sends nothing: all-zero counts and displacements over a
 * small dummy buffer.
 *
 * On every rank, rbytes bytes (recvcount elements of recvtype) are received
 * from the root as MPI_CHAR into a temporary buffer. That buffer is then
 * memcpy'd into recvbuf, except at the root when recvbuf is MPI_IN_PLACE.
 * NOTE(review): the memcpy treats recvbuf as a contiguous rbytes-byte region,
 * and an MPI_IN_PLACE root still sizes its self-receive from
 * recvcount/recvtype. Confirm that callers guarantee both.
 *
 * This path needs several temporary buffers; reusing the alltoall buffers
 * cached on comm_ptr could be a future improvement.
 *
 * Returns the MPIDO_Alltoallv result, or a "**nomem" error code if a
 * temporary allocation fails. In that case no communication is attempted.
 * Every temporary buffer is released on every return path.
 */
int MPIDO_Scatterv_alltoallv(void * sendbuf, int * sendcounts, int * displs,
                             MPI_Datatype sendtype, void * recvbuf, int recvcount,
                             MPI_Datatype recvtype, int root,
                             MPID_Comm * comm_ptr, int *mpierrno)
{
   const int rank = comm_ptr->rank;
   const int size = comm_ptr->local_size;
   const int is_root = (rank == root);
   int *sdispls = NULL, *scounts = NULL;
   int *rdispls = NULL, *rcounts = NULL;
   char *sbuf = NULL, *rbuf = NULL;
   size_t buf_bytes;
   int rbytes;
   int rc;

   MPIDI_Datatype_get_data_size(recvcount, recvtype, rbytes);

   /* Every rdispls[] entry is 0 and only rcounts[root] is non-zero, so the
    * receive buffer needs rbytes bytes, not size * rbytes. Ask for at least
    * one byte so that a zero-byte message is not mistaken for an allocation
    * failure where malloc(0) returns NULL. */
   buf_bytes = (rbytes > 0) ? (size_t) rbytes : 1;

   rbuf    = MPIU_Malloc(buf_bytes);
   rdispls = MPIU_Malloc(size * sizeof(int));
   rcounts = MPIU_Malloc(size * sizeof(int));
   if (!rbuf || !rdispls || !rcounts)
      goto fn_nomem;

   if (is_root) {
      /* Caller-owned arrays: used as-is and never freed here. */
      sdispls = displs;
      scounts = sendcounts;
      sbuf    = sendbuf;
   } else {
      /* Non-root ranks contribute nothing to the exchange. */
      sdispls = MPIU_Malloc(size * sizeof(int));
      scounts = MPIU_Malloc(size * sizeof(int));
      sbuf    = MPIU_Malloc(buf_bytes);
      if (!sdispls || !scounts || !sbuf)
         goto fn_nomem;
      memset(sdispls, 0, size * sizeof(int));
      memset(scounts, 0, size * sizeof(int));
   }

   memset(rdispls, 0, size * sizeof(int));
   memset(rcounts, 0, size * sizeof(int));
   rcounts[root] = rbytes;

   /* TODO: switch to comm->coll_fns->fn() */
   rc = MPIDO_Alltoallv(sbuf, scounts, sdispls, sendtype,
                        rbuf, rcounts, rdispls, MPI_CHAR,
                        comm_ptr, mpierrno);

   if (!(is_root && recvbuf == MPI_IN_PLACE))
      memcpy(recvbuf, rbuf, rbytes);
   goto fn_exit;

fn_nomem:
   rc = MPIR_Err_create_code(MPI_SUCCESS, MPIR_ERR_RECOVERABLE, __FUNCTION__,
                             __LINE__, MPI_ERR_OTHER, "**nomem", 0);
fn_exit:
   if (rbuf)    MPIU_Free(rbuf);
   if (rdispls) MPIU_Free(rdispls);
   if (rcounts) MPIU_Free(rcounts);
   if (!is_root) {
      /* Only non-root ranks own their send-side buffers. */
      if (sbuf)    MPIU_Free(sbuf);
      if (sdispls) MPIU_Free(sdispls);
      if (scounts) MPIU_Free(scounts);
   }
   return rc;
}
int MPIDO_Allgatherv_alltoall(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int *recvcounts, int buffer_sum, const int *displs, MPI_Datatype recvtype, MPI_Aint send_true_lb, MPI_Aint recv_true_lb, size_t send_size, size_t recv_size, MPID_Comm * comm_ptr, int *mpierrno) { TRACE_ERR("Entering MPIDO_Allgatherv_alltoallv\n"); size_t total_send_size; char *startbuf; char *destbuf; int i, rc; int my_recvcounts = -1; void *a2a_sendbuf = NULL; const int size = comm_ptr->local_size; int a2a_sendcounts[size]; int a2a_senddispls[size]; const int rank = comm_ptr->rank; total_send_size = recvcounts[rank] * recv_size; for (i = 0; i < size; ++i) { a2a_sendcounts[i] = total_send_size; a2a_senddispls[i] = 0; } if (sendbuf != MPI_IN_PLACE) { a2a_sendbuf = (char *)sendbuf + send_true_lb; } else { startbuf = (char *) recvbuf + recv_true_lb; destbuf = startbuf + displs[rank] * recv_size; a2a_sendbuf = destbuf; a2a_sendcounts[rank] = 0; my_recvcounts = recvcounts[rank]; recvcounts[rank] = 0; } TRACE_ERR("Calling alltoallv in MPIDO_Allgatherv_alltoallv\n"); /* Switch to comm->coll_fns->fn() */ rc = MPIDO_Alltoallv(a2a_sendbuf, a2a_sendcounts, a2a_senddispls, MPI_CHAR, recvbuf, recvcounts, displs, recvtype, comm_ptr, mpierrno); if (sendbuf == MPI_IN_PLACE) recvcounts[rank] = my_recvcounts; TRACE_ERR("Leaving MPIDO_Allgatherv_alltoallv\n"); return rc; }