// Read a column-major binary file holding a height x width matrix of
// entries of type T into a [CIRC,CIRC] matrix.
//
// Every process opens the file and checks that it contains exactly
// height*width*sizeof(T) bytes (RuntimeError otherwise); only the root of
// the cross communicator actually reads the entries into its buffer.
// A failed read is reported through RuntimeError rather than silently
// leaving the matrix partially filled.
inline void BinaryFlat
( DistMatrix<T,CIRC,CIRC>& A, Int height, Int width,
  const std::string filename )
{
    DEBUG_ONLY(CallStackEntry cse("read::BinaryFlat"))
    std::ifstream file( filename.c_str(), std::ios::binary );
    if( !file.is_open() )
        RuntimeError("Could not open ",filename);

    const Int numBytes = FileSize( file );
    const Int numBytesExp = height*width*sizeof(T);
    if( numBytes != numBytesExp )
        RuntimeError
        ("Expected file to be ",numBytesExp," bytes but found ",numBytes);

    A.Resize( height, width );
    if( A.CrossRank() == A.Root() )
    {
        if( A.Height() == A.LDim() )
        {
            // Columns are stored back-to-back: read everything at once
            file.read
            ( reinterpret_cast<char*>(A.Buffer()), height*width*sizeof(T) );
        }
        else
        {
            // Padded leading dimension: read one column at a time
            for( Int j=0; j<width; ++j )
                file.read
                ( reinterpret_cast<char*>(A.Buffer(0,j)), height*sizeof(T) );
        }
        // failbit is sticky, so one check covers every read above
        if( !file )
            RuntimeError("Failed to read matrix entries from ",filename);
    }
}
inline void BinaryFlat ( DistMatrix<T,U,V>& A, Int height, Int width, const std::string filename ) { DEBUG_ONLY(CallStackEntry cse("read::BinaryFlat")) std::ifstream file( filename.c_str(), std::ios::binary ); if( !file.is_open() ) RuntimeError("Could not open ",filename); const Int numBytes = FileSize( file ); const Int numBytesExp = height*width*sizeof(T); if( numBytes != numBytesExp ) RuntimeError ("Expected file to be ",numBytesExp," bytes but found ",numBytes); A.Resize( height, width ); if( U == A.UGath && V == A.VGath ) { if( A.CrossRank() == A.Root() ) { if( A.Height() == A.LDim() ) file.read( (char*)A.Buffer(), height*width*sizeof(T) ); else for( Int j=0; j<width; ++j ) file.read( (char*)A.Buffer(0,j), height*sizeof(T) ); } } else if( U == A.UGath ) { const Int localWidth = A.LocalWidth(); for( Int jLoc=0; jLoc<localWidth; ++jLoc ) { const Int j = A.GlobalCol(jLoc); const Int localIndex = j*height; const std::streamoff pos = localIndex*sizeof(T); file.seekg( pos ); file.read( (char*)A.Buffer(0,jLoc), height*sizeof(T) ); } } else { const Int localHeight = A.LocalHeight(); const Int localWidth = A.LocalWidth(); for( Int jLoc=0; jLoc<localWidth; ++jLoc ) { const Int j = A.GlobalCol(jLoc); for( Int iLoc=0; iLoc<localHeight; ++iLoc ) { const Int i = A.GlobalRow(iLoc); const Int localIndex = i+j*height; const std::streamoff pos = localIndex*sizeof(T); file.seekg( pos ); file.read( (char*)A.Buffer(iLoc,jLoc), sizeof(T) ); } } } }
void Gather ( const BlockMatrix<T>& A, DistMatrix<T,CIRC,CIRC,BLOCK>& B ) { DEBUG_ONLY(CSE cse("copy::Gather")) AssertSameGrids( A, B ); if( A.DistSize() == 1 && A.CrossSize() == 1 ) { B.Resize( A.Height(), A.Width() ); if( B.CrossRank() == B.Root() ) Copy( A.LockedMatrix(), B.Matrix() ); return; } const Int height = A.Height(); const Int width = A.Width(); B.SetGrid( A.Grid() ); B.Resize( height, width ); // Gather the colShifts and rowShifts // ================================== Int myShifts[2]; myShifts[0] = A.ColShift(); myShifts[1] = A.RowShift(); vector<Int> shifts; const Int crossSize = B.CrossSize(); if( B.CrossRank() == B.Root() ) shifts.resize( 2*crossSize ); mpi::Gather( myShifts, 2, shifts.data(), 2, B.Root(), B.CrossComm() ); // Gather the payload data // ======================= const bool irrelevant = ( A.RedundantRank()!=0 || A.CrossRank()!=A.Root() ); int totalSend = ( irrelevant ? 0 : A.LocalHeight()*A.LocalWidth() ); vector<int> recvCounts, recvOffsets; if( B.CrossRank() == B.Root() ) recvCounts.resize( crossSize ); mpi::Gather( &totalSend, 1, recvCounts.data(), 1, B.Root(), B.CrossComm() ); int totalRecv = Scan( recvCounts, recvOffsets ); //vector<T> sendBuf(totalSend), recvBuf(totalRecv); vector<T> sendBuf, recvBuf; sendBuf.reserve( totalSend ); recvBuf.reserve( totalRecv ); if( !irrelevant ) copy::util::InterleaveMatrix ( A.LocalHeight(), A.LocalWidth(), A.LockedBuffer(), 1, A.LDim(), sendBuf.data(), 1, A.LocalHeight() ); mpi::Gather ( sendBuf.data(), totalSend, recvBuf.data(), recvCounts.data(), recvOffsets.data(), B.Root(), B.CrossComm() ); // Unpack // ====== const Int mb = A.BlockHeight(); const Int nb = A.BlockWidth(); const Int colCut = A.ColCut(); const Int rowCut = A.RowCut(); if( B.Root() == B.CrossRank() ) { for( Int q=0; q<crossSize; ++q ) { if( recvCounts[q] == 0 ) continue; const Int colShift = shifts[2*q+0]; const Int rowShift = shifts[2*q+1]; const Int colStride = A.ColStride(); const Int rowStride = A.RowStride(); const Int 
localHeight = BlockedLength( height, colShift, mb, colCut, colStride ); const Int localWidth = BlockedLength( width, rowShift, nb, rowCut, rowStride ); const T* data = &recvBuf[recvOffsets[q]]; for( Int jLoc=0; jLoc<localWidth; ++jLoc ) { const Int jBefore = rowShift*nb - rowCut; const Int jLocAdj = ( rowShift==0 ? jLoc+rowCut : jLoc ); const Int numFilledLocalBlocks = jLocAdj / nb; const Int jMid = numFilledLocalBlocks*nb*rowStride; const Int jPost = jLocAdj-numFilledLocalBlocks*nb; const Int j = jBefore + jMid + jPost; const T* sourceCol = &data[jLoc*localHeight]; for( Int iLoc=0; iLoc<localHeight; ++iLoc ) { const Int iBefore = colShift*mb - colCut; const Int iLocAdj = (colShift==0 ? iLoc+colCut : iLoc); const Int numFilledLocalBlocks = iLocAdj / mb; const Int iMid = numFilledLocalBlocks*mb*colStride; const Int iPost = iLocAdj-numFilledLocalBlocks*mb; const Int i = iBefore + iMid + iPost; B.SetLocal(i,j,sourceCol[iLoc]); } } } } }
void Gather ( const ElementalMatrix<T>& A, DistMatrix<T,CIRC,CIRC>& B ) { DEBUG_ONLY(CSE cse("copy::Gather")) AssertSameGrids( A, B ); if( A.DistSize() == 1 && A.CrossSize() == 1 ) { B.Resize( A.Height(), A.Width() ); if( B.CrossRank() == B.Root() ) Copy( A.LockedMatrix(), B.Matrix() ); return; } const Int height = A.Height(); const Int width = A.Width(); B.SetGrid( A.Grid() ); B.Resize( height, width ); // Gather the colShifts and rowShifts // ================================== Int myShifts[2]; myShifts[0] = A.ColShift(); myShifts[1] = A.RowShift(); vector<Int> shifts; const Int crossSize = B.CrossSize(); if( B.CrossRank() == B.Root() ) shifts.resize( 2*crossSize ); mpi::Gather( myShifts, 2, shifts.data(), 2, B.Root(), B.CrossComm() ); // Gather the payload data // ======================= const bool irrelevant = ( A.RedundantRank()!=0 || A.CrossRank()!=A.Root() ); int totalSend = ( irrelevant ? 0 : A.LocalHeight()*A.LocalWidth() ); vector<int> recvCounts, recvOffsets; if( B.CrossRank() == B.Root() ) recvCounts.resize( crossSize ); mpi::Gather( &totalSend, 1, recvCounts.data(), 1, B.Root(), B.CrossComm() ); int totalRecv = Scan( recvCounts, recvOffsets ); //vector<T> sendBuf(totalSend), recvBuf(totalRecv); vector<T> sendBuf, recvBuf; sendBuf.reserve( totalSend ); recvBuf.reserve( totalRecv ); if( !irrelevant ) copy::util::InterleaveMatrix ( A.LocalHeight(), A.LocalWidth(), A.LockedBuffer(), 1, A.LDim(), sendBuf.data(), 1, A.LocalHeight() ); mpi::Gather ( sendBuf.data(), totalSend, recvBuf.data(), recvCounts.data(), recvOffsets.data(), B.Root(), B.CrossComm() ); // Unpack // ====== if( B.Root() == B.CrossRank() ) { for( Int q=0; q<crossSize; ++q ) { if( recvCounts[q] == 0 ) continue; const Int colShift = shifts[2*q+0]; const Int rowShift = shifts[2*q+1]; const Int colStride = A.ColStride(); const Int rowStride = A.RowStride(); const Int localHeight = Length( height, colShift, colStride ); const Int localWidth = Length( width, rowShift, rowStride ); 
copy::util::InterleaveMatrix ( localHeight, localWidth, &recvBuf[recvOffsets[q]], 1, localHeight, B.Buffer(colShift,rowShift), colStride, rowStride*B.LDim() ); } } }
void AllGather ( const DistMatrix<T, U, V >& A, DistMatrix<T,Collect<U>(),Collect<V>()>& B ) { DEBUG_ONLY(CSE cse("copy::AllGather")) AssertSameGrids( A, B ); const Int height = A.Height(); const Int width = A.Width(); B.SetGrid( A.Grid() ); B.Resize( height, width ); if( A.Participating() ) { const Int colStride = A.ColStride(); const Int rowStride = A.RowStride(); const Int distStride = colStride*rowStride; const Int maxLocalHeight = MaxLength(height,colStride); const Int maxLocalWidth = MaxLength(width,rowStride); const Int portionSize = mpi::Pad( maxLocalHeight*maxLocalWidth ); vector<T> buf( (distStride+1)*portionSize ); T* sendBuf = &buf[0]; T* recvBuf = &buf[portionSize]; // Pack util::InterleaveMatrix ( A.LocalHeight(), A.LocalWidth(), A.LockedBuffer(), 1, A.LDim(), sendBuf, 1, A.LocalHeight() ); // Communicate mpi::AllGather ( sendBuf, portionSize, recvBuf, portionSize, A.DistComm() ); // Unpack util::StridedUnpack ( height, width, A.ColAlign(), colStride, A.RowAlign(), rowStride, recvBuf, portionSize, B.Buffer(), B.LDim() ); } if( A.Grid().InGrid() && A.CrossComm() != mpi::COMM_SELF ) { // Pack from the root const Int BLocalHeight = B.LocalHeight(); const Int BLocalWidth = B.LocalWidth(); vector<T> buf(BLocalHeight*BLocalWidth); if( A.CrossRank() == A.Root() ) util::InterleaveMatrix ( BLocalHeight, BLocalWidth, B.LockedBuffer(), 1, B.LDim(), buf.data(), 1, BLocalHeight ); // Broadcast from the root mpi::Broadcast ( buf.data(), BLocalHeight*BLocalWidth, A.Root(), A.CrossComm() ); // Unpack if not the root if( A.CrossRank() != A.Root() ) util::InterleaveMatrix ( BLocalHeight, BLocalWidth, buf.data(), 1, BLocalHeight, B.Buffer(), 1, B.LDim() ); } }
void Scatter ( const DistMatrix<T,CIRC,CIRC>& A, ElementalMatrix<T>& B ) { DEBUG_CSE AssertSameGrids( A, B ); const Int m = A.Height(); const Int n = A.Width(); const Int colStride = B.ColStride(); const Int rowStride = B.RowStride(); B.Resize( m, n ); if( B.CrossSize() != 1 || B.RedundantSize() != 1 ) { // TODO: // Broadcast over the redundant communicator and use mpi::Translate // rank to determine whether a process is the root of the broadcast. GeneralPurpose( A, B ); return; } const Int pkgSize = mpi::Pad(MaxLength(m,colStride)*MaxLength(n,rowStride)); const Int recvSize = pkgSize; const Int sendSize = B.DistSize()*pkgSize; // Translate the root of A into the DistComm of B (if possible) const Int root = A.Root(); const Int target = mpi::Translate( A.CrossComm(), root, B.DistComm() ); if( target == mpi::UNDEFINED ) return; if( B.DistSize() == 1 ) { Copy( A.LockedMatrix(), B.Matrix() ); return; } vector<T> buffer; T* recvBuf=0; // some compilers (falsely) warn otherwise if( A.CrossRank() == root ) { FastResize( buffer, sendSize+recvSize ); T* sendBuf = &buffer[0]; recvBuf = &buffer[sendSize]; // Pack the send buffer copy::util::StridedPack ( m, n, B.ColAlign(), colStride, B.RowAlign(), rowStride, A.LockedBuffer(), A.LDim(), sendBuf, pkgSize ); // Scatter from the root mpi::Scatter ( sendBuf, pkgSize, recvBuf, pkgSize, target, B.DistComm() ); } else { FastResize( buffer, recvSize ); recvBuf = &buffer[0]; // Perform the receiving portion of the scatter from the non-root mpi::Scatter ( static_cast<T*>(0), pkgSize, recvBuf, pkgSize, target, B.DistComm() ); } // Unpack copy::util::InterleaveMatrix ( B.LocalHeight(), B.LocalWidth(), recvBuf, 1, B.LocalHeight(), B.Buffer(), 1, B.LDim() ); }