Example 1
0
// Contract A into B: each participating process packs its view of A into
// one package per owner of B (following B's alignments and strides), the
// packages are combined with mpi::ReduceScatter over B.DistComm(), and the
// package each process receives is applied to its local part of B together
// with alpha via axpy::util::InterleaveMatrixUpdate.
//
// Parameters:
//   alpha - scalar passed to the local update of B.
//   A     - source matrix; must share B's grid and have B's dimensions.
//           NOTE(review): the pack reads a height x width block from
//           A.LockedBuffer() with A.LDim(), so this appears to assume every
//           process stores all of A locally -- confirm against callers.
//   B     - destination matrix; its local buffer is updated in place.
//
// Errors: grids are checked with AssertSameGrids; a size mismatch raises
// LogicError("Sizes of A and B must match"). Processes that do not
// participate in B return without communicating.
void Scatter
( T alpha,
  const ElementalMatrix<T>& A,
        ElementalMatrix<T>& B )
{
    DEBUG_ONLY(CSE cse("axpy_contract::Scatter"))
    AssertSameGrids( A, B );
    if( A.Height() != B.Height() || A.Width() != B.Width() )
        LogicError("Sizes of A and B must match");
    if( !B.Participating() )
        return;

    const Int colStride = B.ColStride();
    const Int rowStride = B.RowStride();
    const Int colAlign = B.ColAlign();
    const Int rowAlign = B.RowAlign();

    const Int height = B.Height();
    const Int width = B.Width();
    const Int localHeight = B.LocalHeight();
    const Int localWidth = B.LocalWidth();
    const Int maxLocalHeight = MaxLength(height,colStride);
    const Int maxLocalWidth = MaxLength(width,rowStride);

    // Each package is padded to hold the largest possible local piece of B;
    // the send buffer holds colStride*rowStride such packages.
    const Int recvSize = mpi::Pad( maxLocalHeight*maxLocalWidth );
    const Int sendSize = colStride*rowStride*recvSize;

    // The buffer must really contain sendSize elements: StridedPack and
    // ReduceScatter write through buffer.data(), and writing past size()
    // after a bare reserve() is undefined behavior (and, for non-trivial T,
    // assigns into objects that were never constructed).
    vector<T> buffer;
    FastResize( buffer, sendSize );

    // Pack one package per owner of B
    copy::util::StridedPack
    ( height, width,
      colAlign, colStride,
      rowAlign, rowStride,
      A.LockedBuffer(), A.LDim(),
      buffer.data(),    recvSize );

    // Communicate (in place): the unpack below reads this process's combined
    // package from the front of the buffer.
    mpi::ReduceScatter( buffer.data(), recvSize, B.DistComm() );

    // Unpack our received data into the local part of B
    axpy::util::InterleaveMatrixUpdate
    ( alpha, localHeight, localWidth,
      buffer.data(), 1, localHeight,
      B.Buffer(),    1, B.LDim() );
}
Example 2
0
void Scatter
( const DistMatrix<T,CIRC,CIRC>& A,
        ElementalMatrix<T>& B )
{
    DEBUG_CSE
    AssertSameGrids( A, B );

    const Int m = A.Height();
    const Int n = A.Width();
    const Int colStride = B.ColStride();
    const Int rowStride = B.RowStride();
    B.Resize( m, n );
    if( B.CrossSize() != 1 || B.RedundantSize() != 1 )
    {
        // TODO:
        // Broadcast over the redundant communicator and use mpi::Translate
        // rank to determine whether a process is the root of the broadcast.
        GeneralPurpose( A, B ); 
        return;
    }

    const Int pkgSize = mpi::Pad(MaxLength(m,colStride)*MaxLength(n,rowStride));
    const Int recvSize = pkgSize;
    const Int sendSize = B.DistSize()*pkgSize;

    // Translate the root of A into the DistComm of B (if possible)
    const Int root = A.Root();
    const Int target = mpi::Translate( A.CrossComm(), root, B.DistComm() ); 
    if( target == mpi::UNDEFINED )
        return;

    if( B.DistSize() == 1 )
    {
        Copy( A.LockedMatrix(), B.Matrix() );
        return;
    }

    vector<T> buffer;
    T* recvBuf=0; // some compilers (falsely) warn otherwise
    if( A.CrossRank() == root )
    {
        FastResize( buffer, sendSize+recvSize );
        T* sendBuf = &buffer[0];
        recvBuf    = &buffer[sendSize];

        // Pack the send buffer
        copy::util::StridedPack
        ( m, n,
          B.ColAlign(), colStride,
          B.RowAlign(), rowStride,
          A.LockedBuffer(), A.LDim(),
          sendBuf,          pkgSize );

        // Scatter from the root
        mpi::Scatter
        ( sendBuf, pkgSize, recvBuf, pkgSize, target, B.DistComm() );
    }
    else
    {
        FastResize( buffer, recvSize );
        recvBuf = &buffer[0];

        // Perform the receiving portion of the scatter from the non-root
        mpi::Scatter
        ( static_cast<T*>(0), pkgSize,
          recvBuf,            pkgSize, target, B.DistComm() );
    }

    // Unpack
    copy::util::InterleaveMatrix
    ( B.LocalHeight(), B.LocalWidth(),
      recvBuf,    1, B.LocalHeight(),
      B.Buffer(), 1, B.LDim() );
}