예제 #1
0
/*
 * Pairwise-exchange implementation of all-to-all.
 *
 * The exchange runs for `size` phases.  In phase p (1..size) this rank sends
 * the block addressed to rank (rank + p) % size and receives the block coming
 * from rank (rank + size - p) % size.  Because the phases start at 1, the
 * phase in which both peers are this very rank (p == size) comes last.
 *
 * sbuf/scount/sdtype  send buffer, per-peer element count and datatype
 * rbuf/rcount/rdtype  receive buffer, per-peer element count and datatype
 * comm                communicator the exchange runs on
 * module              only forwarded to the in-place implementation
 *
 * An MPI_IN_PLACE send buffer is delegated to
 * mca_coll_base_alltoall_intra_basic_inplace() and its result is returned.
 *
 * Returns MPI_SUCCESS, or the first non-success code reported by the
 * datatype extent query or by ompi_coll_base_sendrecv().
 */
int ompi_coll_base_alltoall_intra_pairwise(const void *sbuf, int scount,
                                            struct ompi_datatype_t *sdtype,
                                            void* rbuf, int rcount,
                                            struct ompi_datatype_t *rdtype,
                                            struct ompi_communicator_t *comm,
                                            mca_coll_base_module_t *module)
{
    int line = -1, err = 0;
    int rank, size, phase, dst, src;
    ptrdiff_t lower_bound, send_extent, recv_extent;
    void *send_at, *recv_at;

    if (MPI_IN_PLACE == sbuf) {
        return mca_coll_base_alltoall_intra_basic_inplace (rbuf, rcount, rdtype,
                                                            comm, module);
    }

    size = ompi_comm_size(comm);
    rank = ompi_comm_rank(comm);

    OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
                 "coll:base:alltoall_intra_pairwise rank %d", rank));

    /* The extents turn element counts into byte strides; the lower bound
     * is fetched but not used. */
    err = ompi_datatype_get_extent (sdtype, &lower_bound, &send_extent);
    if (MPI_SUCCESS != err) {
        line = __LINE__;
        goto err_hndl;
    }
    err = ompi_datatype_get_extent (rdtype, &lower_bound, &recv_extent);
    if (MPI_SUCCESS != err) {
        line = __LINE__;
        goto err_hndl;
    }

    /* One send/receive pair per phase; the phase targeting ourselves is last. */
    for (phase = 1; phase <= size; ++phase) {

        /* Peers for this phase. */
        dst = (rank + phase) % size;
        src = (rank + size - phase) % size;

        /* Byte addresses of the block going out and the block coming in. */
        send_at = (char*)sbuf + (ptrdiff_t)dst * send_extent * (ptrdiff_t)scount;
        recv_at = (char*)rbuf + (ptrdiff_t)src * recv_extent * (ptrdiff_t)rcount;

        err = ompi_coll_base_sendrecv( send_at, scount, sdtype, dst,
                                        MCA_COLL_BASE_TAG_ALLTOALL,
                                        recv_at, rcount, rdtype, src,
                                        MCA_COLL_BASE_TAG_ALLTOALL,
                                        comm, MPI_STATUS_IGNORE, rank);
        if (MPI_SUCCESS != err) {
            line = __LINE__;
            goto err_hndl;
        }
    }

    return MPI_SUCCESS;

 err_hndl:
    OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
                 "%s:%4d\tError occurred %d, rank %2d", __FILE__, line,
                 err, rank));
    (void)line;  /* keeps the compiler quiet when OPAL_OUTPUT does not use it */
    return err;
}
예제 #2
0
/*
 * All-to-all specialised for a communicator of exactly two ranks.
 *
 * The peer is computed as rank ^ 1, so rank 0 talks to rank 1 and vice
 * versa.  The block addressed to the peer is exchanged with a single
 * ompi_coll_base_sendrecv(); the block addressed to this rank is then moved
 * from sbuf to rbuf locally with ompi_datatype_sndrcv().
 *
 * sbuf/scount/sdtype  send buffer, per-peer element count and datatype
 * rbuf/rcount/rdtype  receive buffer, per-peer element count and datatype
 * comm                communicator the exchange runs on
 * module              only forwarded to the in-place implementation
 *
 * An MPI_IN_PLACE send buffer is delegated to
 * mca_coll_base_alltoall_intra_basic_inplace() and its result is returned.
 *
 * Returns MPI_SUCCESS on success, MPI_ERR_UNSUPPORTED_OPERATION when the
 * communicator does not have exactly two ranks, or the first non-success
 * code reported by a datatype or send/receive call.
 */
int ompi_coll_base_alltoall_intra_two_procs(const void *sbuf, int scount,
                                             struct ompi_datatype_t *sdtype,
                                             void* rbuf, int rcount,
                                             struct ompi_datatype_t *rdtype,
                                             struct ompi_communicator_t *comm,
                                             mca_coll_base_module_t *module)
{
    int line = -1, err = 0, rank, remote;
    void * tmpsend, *tmprecv;
    ptrdiff_t sext, rext, lb;

    if (MPI_IN_PLACE == sbuf) {
        return mca_coll_base_alltoall_intra_basic_inplace (rbuf, rcount, rdtype,
                                                            comm, module);
    }

    rank = ompi_comm_rank(comm);

    OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
                 "ompi_coll_base_alltoall_intra_two_procs rank %d", rank));

    /* rank ^ 1 is only a valid peer when there are exactly two ranks; with
     * any other size it names a rank that does not exist (e.g. size 1) or
     * leaves blocks unexchanged, so refuse instead of communicating. */
    if (2 != ompi_comm_size(comm)) {
        return MPI_ERR_UNSUPPORTED_OPERATION;
    }

    /* Extents turn element counts into byte strides; lb is not used. */
    err = ompi_datatype_get_extent (sdtype, &lb, &sext);
    if (err != MPI_SUCCESS) { line = __LINE__; goto err_hndl; }

    err = ompi_datatype_get_extent (rdtype, &lb, &rext);
    if (err != MPI_SUCCESS) { line = __LINE__; goto err_hndl; }

    /* exchange data with the only other rank */
    remote  = rank ^ 1;

    tmpsend = (char*)sbuf + (ptrdiff_t)remote * sext * (ptrdiff_t)scount;
    tmprecv = (char*)rbuf + (ptrdiff_t)remote * rext * (ptrdiff_t)rcount;

    /* send and receive */
    err = ompi_coll_base_sendrecv ( tmpsend, scount, sdtype, remote,
                                     MCA_COLL_BASE_TAG_ALLTOALL,
                                     tmprecv, rcount, rdtype, remote,
                                     MCA_COLL_BASE_TAG_ALLTOALL,
                                     comm, MPI_STATUS_IGNORE, rank );
    if (err != MPI_SUCCESS) { line = __LINE__; goto err_hndl;  }

    /* move our own block from the send buffer to the receive buffer */
    err = ompi_datatype_sndrcv((char*) sbuf + (ptrdiff_t)rank * sext * (ptrdiff_t)scount,
                               (int32_t) scount, sdtype,
                               (char*) rbuf + (ptrdiff_t)rank * rext * (ptrdiff_t)rcount,
                               (int32_t) rcount, rdtype);
    if (err != MPI_SUCCESS) { line = __LINE__; goto err_hndl;  }

    /* done */
    return MPI_SUCCESS;

 err_hndl:
    OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
                 "%s:%4d\tError occurred %d, rank %2d", __FILE__, line, err,
                 rank));
    (void)line;  /* keeps the compiler quiet when OPAL_OUTPUT does not use it */
    return err;
}
/*
 * Pairwise-exchange implementation of all-to-all with per-peer counts and
 * displacements (alltoallv).
 *
 * The exchange runs for `size` steps.  In step s (0..size-1) this rank sends
 * to rank (rank + s) % size and receives from rank (rank + size - s) % size,
 * so step 0 is the exchange with itself.
 *
 * sbuf                send buffer
 * scounts/sdisps      per-peer send counts and displacements; displacements
 *                     are multiplied by the extent of sdtype to obtain bytes
 * sdtype              send datatype
 * rbuf                receive buffer
 * rcounts/rdisps      per-peer receive counts and displacements; displacements
 *                     are multiplied by the extent of rdtype to obtain bytes
 * rdtype              receive datatype
 * comm                communicator the exchange runs on
 * module              only forwarded to the in-place implementation
 *
 * An MPI_IN_PLACE send buffer is delegated to
 * mca_coll_base_alltoallv_intra_basic_inplace() and its result is returned.
 *
 * Returns MPI_SUCCESS, or the first non-success code reported by the
 * datatype extent query or by ompi_coll_base_sendrecv().
 */
int
ompi_coll_base_alltoallv_intra_pairwise(const void *sbuf, const int *scounts, const int *sdisps,
                                         struct ompi_datatype_t *sdtype,
                                         void* rbuf, const int *rcounts, const int *rdisps,
                                         struct ompi_datatype_t *rdtype,
                                         struct ompi_communicator_t *comm,
                                         mca_coll_base_module_t *module)
{
    int line = -1, err = 0, rank, size, step = 0, sendto, recvfrom;
    void *psnd, *prcv;
    ptrdiff_t sext, rext;

    if (MPI_IN_PLACE == sbuf) {
        return mca_coll_base_alltoallv_intra_basic_inplace (rbuf, rcounts, rdisps,
                                                             rdtype, comm, module);
    }

    size = ompi_comm_size(comm);
    rank = ompi_comm_rank(comm);

    OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
                 "coll:base:alltoallv_intra_pairwise rank %d", rank));

    /* The extents scale the displacements below; do not continue with
     * unset values if either query reports a failure. */
    err = ompi_datatype_type_extent(sdtype, &sext);
    if (MPI_SUCCESS != err) { line = __LINE__; goto err_hndl; }
    err = ompi_datatype_type_extent(rdtype, &rext);
    if (MPI_SUCCESS != err) { line = __LINE__; goto err_hndl; }

    /* Perform pairwise exchange; step 0 is the exchange with ourselves */
    for (step = 0; step < size; step++) {

        /* Determine sender and receiver for this step. */
        sendto  = (rank + step) % size;
        recvfrom = (rank + size - step) % size;

        /* Determine sending and receiving locations */
        psnd = (char*)sbuf + (ptrdiff_t)sdisps[sendto] * sext;
        prcv = (char*)rbuf + (ptrdiff_t)rdisps[recvfrom] * rext;

        /* send and receive */
        err = ompi_coll_base_sendrecv( psnd, scounts[sendto], sdtype, sendto,
                                        MCA_COLL_BASE_TAG_ALLTOALLV,
                                        prcv, rcounts[recvfrom], rdtype, recvfrom,
                                        MCA_COLL_BASE_TAG_ALLTOALLV,
                                        comm, MPI_STATUS_IGNORE, rank);
        if (MPI_SUCCESS != err) { line = __LINE__; goto err_hndl;  }
    }

    return MPI_SUCCESS;

 err_hndl:
    /* step is 0 when the failure happened before the exchange loop */
    OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
                 "%s:%4d\tError occurred %d, rank %2d at step %d", __FILE__, line,
                 err, rank, step));
    (void)line;  // silence compiler warning
    return err;
}
예제 #4
0
/*
 * All-to-all using a log-step (Bruck-style) exchange through a temporary
 * buffer.
 *
 * Step 1: the send blocks are rotated into a temporary buffer so that
 *         temporary block j holds send block (rank + j) % size.
 * Step 2: for every power-of-two distance below the communicator size, all
 *         temporary blocks whose index has that bit set are sent to
 *         (rank + distance) % size as one indexed datatype, the matching
 *         blocks are received from (rank - distance + size) % size into
 *         rbuf, and then copied back from rbuf into the temporary buffer.
 * Step 3: temporary block i is copied to rbuf block (rank - i + size) % size.
 *
 * sbuf/scount/sdtype  send buffer, per-peer element count and datatype
 * rbuf/rcount/rdtype  receive buffer, per-peer element count and datatype
 * comm                communicator the exchange runs on
 * module              only forwarded to the in-place implementation
 *
 * An MPI_IN_PLACE send buffer is delegated to
 * mca_coll_base_alltoall_intra_basic_inplace() and its result is returned.
 *
 * Returns OMPI_SUCCESS on success, -1 on allocation or local-copy failure,
 * or the non-success code reported by a datatype or send/receive call.
 * All temporary memory, and any derived datatype still alive, is released
 * on every exit path.
 */
int ompi_coll_base_alltoall_intra_bruck(const void *sbuf, int scount,
                                         struct ompi_datatype_t *sdtype,
                                         void* rbuf, int rcount,
                                         struct ompi_datatype_t *rdtype,
                                         struct ompi_communicator_t *comm,
                                         mca_coll_base_module_t *module)
{
    int i, k, line = -1, rank, size, err = 0;
    int sendto, recvfrom, distance, *displs = NULL, *blen = NULL;
    char *tmpbuf = NULL, *tmpbuf_free = NULL;
    OPAL_PTRDIFF_TYPE sext, rext, span, gap;
    /* NULL whenever no derived datatype is alive, so err_hndl knows
     * whether there is one to release. */
    struct ompi_datatype_t *new_ddt = NULL;

    if (MPI_IN_PLACE == sbuf) {
        return mca_coll_base_alltoall_intra_basic_inplace (rbuf, rcount, rdtype,
                                                            comm, module);
    }

    size = ompi_comm_size(comm);
    rank = ompi_comm_rank(comm);

    OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
                 "coll:base:alltoall_intra_bruck rank %d", rank));

    err = ompi_datatype_type_extent (sdtype, &sext);
    if (err != MPI_SUCCESS) { line = __LINE__; goto err_hndl; }

    err = ompi_datatype_type_extent (rdtype, &rext);
    if (err != MPI_SUCCESS) { line = __LINE__; goto err_hndl; }

    /* Size of the temporary buffer needed for size * scount send elements. */
    span = opal_datatype_span(&sdtype->super, (int64_t)size * scount, &gap);

    displs = (int *) malloc(size * sizeof(int));
    if (displs == NULL) { line = __LINE__; err = -1; goto err_hndl; }
    blen = (int *) malloc(size * sizeof(int));
    if (blen == NULL) { line = __LINE__; err = -1; goto err_hndl; }

    /* tmp buffer allocation for message data */
    tmpbuf_free = (char *)malloc(span);
    if (tmpbuf_free == NULL) { line = __LINE__; err = -1; goto err_hndl; }
    /* tmpbuf is the address handed to the datatype routines; tmpbuf_free is
     * the address that must be passed to free().
     * NOTE(review): gap is presumably the datatype's leading offset as
     * reported by opal_datatype_span - confirm against its documentation. */
    tmpbuf = tmpbuf_free - gap;

    /* Step 1 - local rotation - shift up by rank */
    err = ompi_datatype_copy_content_same_ddt (sdtype,
                                               (int32_t) ((ptrdiff_t)(size - rank) * (ptrdiff_t)scount),
                                               tmpbuf,
                                               ((char*) sbuf) + (ptrdiff_t)rank * (ptrdiff_t)scount * sext);
    if (err<0) {
        line = __LINE__; err = -1; goto err_hndl;
    }

    /* The first `rank` send blocks wrap around to the end of tmpbuf. */
    if (rank != 0) {
        err = ompi_datatype_copy_content_same_ddt (sdtype, (ptrdiff_t)rank * (ptrdiff_t)scount,
                                                   tmpbuf + (ptrdiff_t)(size - rank) * (ptrdiff_t)scount* sext,
                                                   (char*) sbuf);
        if (err<0) {
            line = __LINE__; err = -1; goto err_hndl;
        }
    }

    /* Step 2 - perform communication step */
    for (distance = 1; distance < size; distance<<=1) {

        sendto = (rank + distance) % size;
        recvfrom = (rank - distance + size) % size;
        k = 0;

        /* create indexed datatype: every block whose index has the
         * `distance` bit set takes part in this round */
        for (i = 1; i < size; i++) {
            if (( i & distance) == distance) {
                displs[k] = (ptrdiff_t)i * (ptrdiff_t)scount;
                blen[k] = scount;
                k++;
            }
        }
        /* Set indexes and displacements */
        err = ompi_datatype_create_indexed(k, blen, displs, sdtype, &new_ddt);
        if (err != MPI_SUCCESS) {
            /* Do not trust the output handle of a failed create. */
            new_ddt = NULL;
            line = __LINE__; goto err_hndl;
        }
        /* Commit the new datatype */
        err = ompi_datatype_commit(&new_ddt);
        if (err != MPI_SUCCESS) { line = __LINE__; goto err_hndl;  }

        /* Sendreceive */
        err = ompi_coll_base_sendrecv ( tmpbuf, 1, new_ddt, sendto,
                                         MCA_COLL_BASE_TAG_ALLTOALL,
                                         rbuf, 1, new_ddt, recvfrom,
                                         MCA_COLL_BASE_TAG_ALLTOALL,
                                         comm, MPI_STATUS_IGNORE, rank );
        if (err != MPI_SUCCESS) { line = __LINE__; goto err_hndl; }

        /* Copy back new data from recvbuf to tmpbuf */
        err = ompi_datatype_copy_content_same_ddt(new_ddt, 1,tmpbuf, (char *) rbuf);
        if (err < 0) { line = __LINE__; err = -1; goto err_hndl;  }

        /* free ddt; forget the handle whatever the outcome so that it is
         * never released a second time in err_hndl */
        err = ompi_datatype_destroy(&new_ddt);
        new_ddt = NULL;
        if (err != MPI_SUCCESS) { line = __LINE__; goto err_hndl;  }
    } /* end of for (distance = 1... */

    /* Step 3 - local rotation - */
    for (i = 0; i < size; i++) {

        err = ompi_datatype_copy_content_same_ddt (rdtype, (int32_t) rcount,
                                                   ((char*)rbuf) + ((ptrdiff_t)((rank - i + size) % size) * (ptrdiff_t)rcount * rext),
                                                   tmpbuf + (ptrdiff_t)i * (ptrdiff_t)rcount * rext);
        if (err < 0) { line = __LINE__; err = -1; goto err_hndl;  }
    }

    /* Step 4 - clean up (free(NULL) is a no-op, so no guards are needed) */
    free(tmpbuf_free);
    free(displs);
    free(blen);
    return OMPI_SUCCESS;

 err_hndl:
    OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
                 "%s:%4d\tError occurred %d, rank %2d", __FILE__, line, err,
                 rank));
    (void)line;  // silence compiler warning
    /* A datatype created in the failing round is still alive here; release
     * it best-effort and keep reporting the original error. */
    if (NULL != new_ddt) {
        (void) ompi_datatype_destroy(&new_ddt);
    }
    /* Free the allocation itself, not the shifted alias tmpbuf. */
    free(tmpbuf_free);
    free(displs);
    free(blen);
    return err;
}