void PartialColScatter ( T alpha, const ElementalMatrix<T>& A, ElementalMatrix<T>& B ) { DEBUG_ONLY(CSE cse("axpy_contract::PartialColScatter")) AssertSameGrids( A, B ); if( A.Height() != B.Height() || A.Width() != B.Width() ) LogicError("A and B must be the same size"); #ifdef EL_CACHE_WARNINGS if( A.Width() != 1 && A.Grid().Rank() == 0 ) { cerr << "axpy_contract::PartialColScatterUpdate potentially causes a large " "amount of cache-thrashing. If possible, avoid it by forming the " "(conjugate-)transpose of the [UGath,* ] matrix instead." << endl; } #endif if( B.ColAlign() % A.ColStride() == A.ColAlign() ) { const Int colStride = B.ColStride(); const Int colStridePart = B.PartialColStride(); const Int colStrideUnion = B.PartialUnionColStride(); const Int colRankPart = B.PartialColRank(); const Int colAlign = B.ColAlign(); const Int height = B.Height(); const Int width = B.Width(); const Int localHeight = B.LocalHeight(); const Int maxLocalHeight = MaxLength( height, colStride ); const Int recvSize = mpi::Pad( maxLocalHeight*width ); const Int sendSize = colStrideUnion*recvSize; //vector<T> buffer( sendSize ); vector<T> buffer; buffer.reserve( sendSize ); // Pack copy::util::PartialColStridedPack ( height, width, colAlign, colStride, colStrideUnion, colStridePart, colRankPart, A.ColShift(), A.LockedBuffer(), A.LDim(), buffer.data(), recvSize ); // Communicate mpi::ReduceScatter( buffer.data(), recvSize, B.PartialUnionColComm() ); // Unpack our received data axpy::util::InterleaveMatrixUpdate ( alpha, localHeight, width, buffer.data(), 1, localHeight, B.Buffer(), 1, B.LDim() ); } else LogicError("Unaligned PartialColScatter not implemented"); }
void Gather ( const ElementalMatrix<T>& A, DistMatrix<T,CIRC,CIRC>& B ) { DEBUG_ONLY(CSE cse("copy::Gather")) AssertSameGrids( A, B ); if( A.DistSize() == 1 && A.CrossSize() == 1 ) { B.Resize( A.Height(), A.Width() ); if( B.CrossRank() == B.Root() ) Copy( A.LockedMatrix(), B.Matrix() ); return; } const Int height = A.Height(); const Int width = A.Width(); B.SetGrid( A.Grid() ); B.Resize( height, width ); // Gather the colShifts and rowShifts // ================================== Int myShifts[2]; myShifts[0] = A.ColShift(); myShifts[1] = A.RowShift(); vector<Int> shifts; const Int crossSize = B.CrossSize(); if( B.CrossRank() == B.Root() ) shifts.resize( 2*crossSize ); mpi::Gather( myShifts, 2, shifts.data(), 2, B.Root(), B.CrossComm() ); // Gather the payload data // ======================= const bool irrelevant = ( A.RedundantRank()!=0 || A.CrossRank()!=A.Root() ); int totalSend = ( irrelevant ? 0 : A.LocalHeight()*A.LocalWidth() ); vector<int> recvCounts, recvOffsets; if( B.CrossRank() == B.Root() ) recvCounts.resize( crossSize ); mpi::Gather( &totalSend, 1, recvCounts.data(), 1, B.Root(), B.CrossComm() ); int totalRecv = Scan( recvCounts, recvOffsets ); //vector<T> sendBuf(totalSend), recvBuf(totalRecv); vector<T> sendBuf, recvBuf; sendBuf.reserve( totalSend ); recvBuf.reserve( totalRecv ); if( !irrelevant ) copy::util::InterleaveMatrix ( A.LocalHeight(), A.LocalWidth(), A.LockedBuffer(), 1, A.LDim(), sendBuf.data(), 1, A.LocalHeight() ); mpi::Gather ( sendBuf.data(), totalSend, recvBuf.data(), recvCounts.data(), recvOffsets.data(), B.Root(), B.CrossComm() ); // Unpack // ====== if( B.Root() == B.CrossRank() ) { for( Int q=0; q<crossSize; ++q ) { if( recvCounts[q] == 0 ) continue; const Int colShift = shifts[2*q+0]; const Int rowShift = shifts[2*q+1]; const Int colStride = A.ColStride(); const Int rowStride = A.RowStride(); const Int localHeight = Length( height, colShift, colStride ); const Int localWidth = Length( width, rowShift, rowStride ); copy::util::InterleaveMatrix ( localHeight, localWidth, &recvBuf[recvOffsets[q]], 1, localHeight, B.Buffer(colShift,rowShift), colStride, rowStride*B.LDim() ); } } }