// Updates B with alpha times data contributed through A.
//
// Steps, as performed below:
//   1. A's local buffer is packed into one package per process of B's
//      distribution grid (colStride x rowStride packages of recvSize entries).
//   2. mpi::ReduceScatter over B.DistComm() combines the packages in place,
//      leaving this process's package at the front of the buffer.
//   3. The received package is scaled by alpha and accumulated into B's
//      local buffer via axpy::util::InterleaveMatrixUpdate.
//
// alpha: scalar applied to the received data when updating B.
// A:     source matrix; must live on the same grid and have the same global
//        dimensions as B (LogicError otherwise).
// B:     destination matrix, updated in place. Processes for which
//        B.Participating() is false return without communicating.
void Scatter
( T alpha,
  const ElementalMatrix<T>& A,
        ElementalMatrix<T>& B )
{
    DEBUG_ONLY(CSE cse("axpy_contract::Scatter"))
    AssertSameGrids( A, B );
    if( A.Height() != B.Height() || A.Width() != B.Width() )
        LogicError("Sizes of A and B must match");
    if( !B.Participating() )
        return;

    const Int colStride = B.ColStride();
    const Int rowStride = B.RowStride();
    const Int colAlign = B.ColAlign();
    const Int rowAlign = B.RowAlign();

    const Int height = B.Height();
    const Int width = B.Width();
    const Int localHeight = B.LocalHeight();
    const Int localWidth = B.LocalWidth();
    const Int maxLocalHeight = MaxLength(height,colStride);
    const Int maxLocalWidth = MaxLength(width,rowStride);

    // Every package is padded to the same size so that all processes
    // exchange equally sized messages.
    const Int recvSize = mpi::Pad( maxLocalHeight*maxLocalWidth );
    const Int sendSize = colStride*rowStride*recvSize;

    // The buffer must actually contain sendSize elements: reserve() only
    // allocates capacity, and writing through data() into a vector whose
    // size() is zero is undefined behaviour (no elements have been
    // constructed, which matters for non-trivial T). Sizing the vector also
    // value-initializes the padding entries that take part in the reduction.
    vector<T> buffer( sendSize );

    // Pack
    copy::util::StridedPack
    ( height, width,
      colAlign, colStride,
      rowAlign, rowStride,
      A.LockedBuffer(), A.LDim(),
      buffer.data(), recvSize );

    // Communicate
    mpi::ReduceScatter( buffer.data(), recvSize, B.DistComm() );

    // Unpack our received data
    axpy::util::InterleaveMatrixUpdate
    ( alpha, localHeight, localWidth,
      buffer.data(), 1, localHeight,
      B.Buffer(),    1, B.LDim() );
}
void Scatter ( const DistMatrix<T,CIRC,CIRC>& A, ElementalMatrix<T>& B ) { DEBUG_CSE AssertSameGrids( A, B ); const Int m = A.Height(); const Int n = A.Width(); const Int colStride = B.ColStride(); const Int rowStride = B.RowStride(); B.Resize( m, n ); if( B.CrossSize() != 1 || B.RedundantSize() != 1 ) { // TODO: // Broadcast over the redundant communicator and use mpi::Translate // rank to determine whether a process is the root of the broadcast. GeneralPurpose( A, B ); return; } const Int pkgSize = mpi::Pad(MaxLength(m,colStride)*MaxLength(n,rowStride)); const Int recvSize = pkgSize; const Int sendSize = B.DistSize()*pkgSize; // Translate the root of A into the DistComm of B (if possible) const Int root = A.Root(); const Int target = mpi::Translate( A.CrossComm(), root, B.DistComm() ); if( target == mpi::UNDEFINED ) return; if( B.DistSize() == 1 ) { Copy( A.LockedMatrix(), B.Matrix() ); return; } vector<T> buffer; T* recvBuf=0; // some compilers (falsely) warn otherwise if( A.CrossRank() == root ) { FastResize( buffer, sendSize+recvSize ); T* sendBuf = &buffer[0]; recvBuf = &buffer[sendSize]; // Pack the send buffer copy::util::StridedPack ( m, n, B.ColAlign(), colStride, B.RowAlign(), rowStride, A.LockedBuffer(), A.LDim(), sendBuf, pkgSize ); // Scatter from the root mpi::Scatter ( sendBuf, pkgSize, recvBuf, pkgSize, target, B.DistComm() ); } else { FastResize( buffer, recvSize ); recvBuf = &buffer[0]; // Perform the receiving portion of the scatter from the non-root mpi::Scatter ( static_cast<T*>(0), pkgSize, recvBuf, pkgSize, target, B.DistComm() ); } // Unpack copy::util::InterleaveMatrix ( B.LocalHeight(), B.LocalWidth(), recvBuf, 1, B.LocalHeight(), B.Buffer(), 1, B.LDim() ); }