Ejemplo n.º 1
0
// Combines A across B's partial-union column communicator and folds the
// piece owned by this process into B.
//
// A's data is packed into a send buffer, reduced and scattered over
// B.PartialUnionColComm() with mpi::ReduceScatter, and the received portion is
// handed to axpy::util::InterleaveMatrixUpdate together with alpha and B's
// local buffer (presumably B += alpha * received -- confirm against
// InterleaveMatrixUpdate).
//
// alpha: scalar forwarded to the update of B.
// A:     input matrix; checked against B with AssertSameGrids and required to
//        have B's height and width.
// B:     matrix whose local buffer is updated in place.
//
// Calls LogicError when the sizes differ, or when
// B.ColAlign() % A.ColStride() != A.ColAlign() (unaligned case unimplemented).
void PartialColScatter
( T alpha,
  const ElementalMatrix<T>& A,
        ElementalMatrix<T>& B )
{
    DEBUG_ONLY(CSE cse("axpy_contract::PartialColScatter"))
    AssertSameGrids( A, B );
    if( A.Height() != B.Height() || A.Width() != B.Width() )
        LogicError("A and B must be the same size");

#ifdef EL_CACHE_WARNINGS
    // NOTE(review): the message below names "PartialColScatterUpdate", which
    // is not this function's name; the text is left untouched here.
    if( A.Width() != 1 && A.Grid().Rank() == 0 )
    {
        cerr <<
          "axpy_contract::PartialColScatterUpdate potentially causes a large "
          "amount of cache-thrashing. If possible, avoid it by forming the "
          "(conjugate-)transpose of the [UGath,* ] matrix instead."
          << endl;
    }
#endif
    if( B.ColAlign() % A.ColStride() == A.ColAlign() )
    {
        const Int colStride = B.ColStride();
        const Int colStridePart = B.PartialColStride();
        const Int colStrideUnion = B.PartialUnionColStride();
        const Int colRankPart = B.PartialColRank();
        const Int colAlign = B.ColAlign();

        const Int height = B.Height();
        const Int width = B.Width();
        const Int localHeight = B.LocalHeight();
        const Int maxLocalHeight = MaxLength( height, colStride );
        // One padded chunk per member of the partial-union communicator.
        const Int recvSize = mpi::Pad( maxLocalHeight*width );
        const Int sendSize = colStrideUnion*recvSize;

        // The buffer must actually contain sendSize elements: reserve() alone
        // leaves size() at zero, so touching the storage through data() would
        // be undefined behaviour (and the elements of a non-trivial T would
        // never be constructed).
        vector<T> buffer;
        buffer.resize( sendSize );

        // Pack
        copy::util::PartialColStridedPack
        ( height, width,
          colAlign, colStride,
          colStrideUnion, colStridePart, colRankPart,
          A.ColShift(),
          A.LockedBuffer(), A.LDim(),
          buffer.data(),    recvSize );

        // Communicate
        mpi::ReduceScatter( buffer.data(), recvSize, B.PartialUnionColComm() );

        // Unpack our received data
        axpy::util::InterleaveMatrixUpdate
        ( alpha, localHeight, width,
          buffer.data(), 1, localHeight,
          B.Buffer(),    1, B.LDim() );
    }
    else
        LogicError("Unaligned PartialColScatter not implemented");
}
Ejemplo n.º 2
0
// Gathers the distributed matrix A into the [CIRC,CIRC] matrix B, whose data
// is unpacked on the process satisfying B.CrossRank() == B.Root().
//
// When A.DistSize() and A.CrossSize() are both 1, B is resized and A's local
// matrix is copied directly on the root. Otherwise three gathers over
// B.CrossComm() are performed: the (colShift,rowShift) pair of every process,
// the number of entries each process contributes, and finally the packed
// local data. The root then places each contribution into B using the
// sender's shifts and A's strides.
//
// A: source matrix (read only); checked against B with AssertSameGrids.
// B: destination; its grid is set to A's and it is resized to A's dimensions.
void Gather
( const ElementalMatrix<T>& A,
        DistMatrix<T,CIRC,CIRC>& B )
{
    DEBUG_ONLY(CSE cse("copy::Gather"))
    AssertSameGrids( A, B );
    if( A.DistSize() == 1 && A.CrossSize() == 1 )
    {
        B.Resize( A.Height(), A.Width() );
        if( B.CrossRank() == B.Root() )
            Copy( A.LockedMatrix(), B.Matrix() );
        return;
    }

    const Int height = A.Height();
    const Int width = A.Width();
    B.SetGrid( A.Grid() );
    B.Resize( height, width );

    // Gather the colShifts and rowShifts
    // ==================================
    Int myShifts[2];
    myShifts[0] = A.ColShift();
    myShifts[1] = A.RowShift();
    vector<Int> shifts;
    const Int crossSize = B.CrossSize();
    // Only the root needs room for the gathered values.
    if( B.CrossRank() == B.Root() )
        shifts.resize( 2*crossSize );
    mpi::Gather( myShifts, 2, shifts.data(), 2, B.Root(), B.CrossComm() );

    // Gather the payload data
    // =======================
    // Processes flagged as irrelevant contribute zero entries.
    const bool irrelevant = ( A.RedundantRank()!=0 || A.CrossRank()!=A.Root() );
    int totalSend = ( irrelevant ? 0 : A.LocalHeight()*A.LocalWidth() );
    vector<int> recvCounts, recvOffsets;
    if( B.CrossRank() == B.Root() )
        recvCounts.resize( crossSize );
    mpi::Gather( &totalSend, 1, recvCounts.data(), 1, B.Root(), B.CrossComm() );
    int totalRecv = Scan( recvCounts, recvOffsets );

    // Size the buffers for real: reserve() alone leaves size() at zero, which
    // makes every access through data()/operator[] undefined behaviour (and
    // leaves the elements of a non-trivial T unconstructed).
    vector<T> sendBuf, recvBuf;
    sendBuf.resize( totalSend );
    recvBuf.resize( totalRecv );
    if( !irrelevant )
        copy::util::InterleaveMatrix
        ( A.LocalHeight(), A.LocalWidth(),
          A.LockedBuffer(), 1, A.LDim(),
          sendBuf.data(),   1, A.LocalHeight() );
    mpi::Gather
    ( sendBuf.data(), totalSend,
      recvBuf.data(), recvCounts.data(), recvOffsets.data(), 
      B.Root(), B.CrossComm() );

    // Unpack
    // ======
    if( B.Root() == B.CrossRank() )
    {
        // The strides do not depend on the sending process.
        const Int colStride = A.ColStride();
        const Int rowStride = A.RowStride();
        for( Int q=0; q<crossSize; ++q )
        {
            // Nothing was received from this process.
            if( recvCounts[q] == 0 )
                continue;
            const Int colShift = shifts[2*q+0];
            const Int rowShift = shifts[2*q+1];
            // Dimensions of the sender's contribution, recomputed from its
            // shifts and A's strides.
            const Int localHeight = Length( height, colShift, colStride );
            const Int localWidth = Length( width, rowShift, rowStride );
            copy::util::InterleaveMatrix
            ( localHeight, localWidth,
              recvBuf.data()+recvOffsets[q], 1,         localHeight,
              B.Buffer(colShift,rowShift),   colStride, rowStride*B.LDim() );
        }
    }
}