void Contract ( const BlockMatrix<T>& A, BlockMatrix<T>& B ) { DEBUG_ONLY(CSE cse("Contract")) AssertSameGrids( A, B ); const Dist U = B.ColDist(); const Dist V = B.RowDist(); // TODO: Shorten this implementation? if( A.ColDist() == U && A.RowDist() == V ) { Copy( A, B ); } else if( A.ColDist() == U && A.RowDist() == Partial(V) ) { B.AlignAndResize ( A.BlockHeight(), A.BlockWidth(), A.ColAlign(), A.RowAlign(), A.ColCut(), A.RowCut(), A.Height(), A.Width(), false, false ); Zeros( B.Matrix(), B.LocalHeight(), B.LocalWidth() ); AxpyContract( T(1), A, B ); } else if( A.ColDist() == Partial(U) && A.RowDist() == V ) { B.AlignAndResize ( A.BlockHeight(), A.BlockWidth(), A.ColAlign(), A.RowAlign(), A.ColCut(), A.RowCut(), A.Height(), A.Width(), false, false ); Zeros( B.Matrix(), B.LocalHeight(), B.LocalWidth() ); AxpyContract( T(1), A, B ); } else if( A.ColDist() == U && A.RowDist() == Collect(V) ) { B.AlignColsAndResize ( A.BlockHeight(), A.ColAlign(), A.ColCut(), A.Height(), A.Width(), false, false ); Zeros( B.Matrix(), B.LocalHeight(), B.LocalWidth() ); AxpyContract( T(1), A, B ); } else if( A.ColDist() == Collect(U) && A.RowDist() == V ) { B.AlignRowsAndResize ( A.BlockWidth(), A.RowAlign(), A.RowCut(), A.Height(), A.Width(), false, false ); Zeros( B.Matrix(), B.LocalHeight(), B.LocalWidth() ); AxpyContract( T(1), A, B ); } else if( A.ColDist() == Collect(U) && A.RowDist() == Collect(V) ) { Zeros( B, A.Height(), A.Width() ); AxpyContract( T(1), A, B ); } else LogicError("Incompatible distributions"); }
void LT_Dot ( T alpha, const AbstractDistMatrix<T>& APre, AbstractDistMatrix<T>& CPre, const bool conjugate, Int blockSize=2000 ) { EL_DEBUG_CSE const Int n = CPre.Height(); const Grid& g = APre.Grid(); const Orientation orient = ( conjugate ? ADJOINT : TRANSPOSE ); DistMatrixReadProxy<T,T,VC,STAR> AProx( APre ); auto& A = AProx.GetLocked(); DistMatrixReadWriteProxy<T,T,MC,MR> CProx( CPre ); auto& C = CProx.Get(); DistMatrix<T,STAR,STAR> Z( blockSize, blockSize, g ); Zero( Z ); for( Int kOuter=0; kOuter<n; kOuter+=blockSize ) { const Int nbOuter = Min(blockSize,n-kOuter); const Range<Int> indOuter( kOuter, kOuter+nbOuter ); auto A1 = A( ALL, indOuter ); auto C11 = C( indOuter, indOuter ); Z.Resize( nbOuter, nbOuter ); Syrk( LOWER, TRANSPOSE, alpha, A1.Matrix(), Z.Matrix(), conjugate ); AxpyContract( T(1), Z, C11 ); for( Int kInner=kOuter+nbOuter; kInner<n; kInner+=blockSize ) { const Int nbInner = Min(blockSize,n-kInner); const Range<Int> indInner( kInner, kInner+nbInner ); auto A2 = A( ALL, indInner ); auto C21 = C( indInner, indOuter ); LocalGemm( orient, NORMAL, alpha, A1, A2, Z ); AxpyContract( T(1), Z, C21 ); } } }
void SUMMA_NTDot ( Orientation orientB, T alpha, const AbstractDistMatrix<T>& APre, const AbstractDistMatrix<T>& BPre, AbstractDistMatrix<T>& CPre, Int blockSize=2000 ) { EL_DEBUG_CSE const Int m = CPre.Height(); const Int n = CPre.Width(); const Grid& g = APre.Grid(); DistMatrixReadProxy<T,T,STAR,VC> AProx( APre ); auto& A = AProx.GetLocked(); ElementalProxyCtrl BCtrl; BCtrl.rowConstrain = true; BCtrl.rowAlign = A.RowAlign(); DistMatrixReadProxy<T,T,STAR,VC> BProx( BPre, BCtrl ); auto& B = BProx.GetLocked(); DistMatrixReadWriteProxy<T,T,MC,MR> CProx( CPre ); auto& C = CProx.Get(); DistMatrix<T,STAR,STAR> C11_STAR_STAR(g); for( Int kOuter=0; kOuter<m; kOuter+=blockSize ) { const Int nbOuter = Min(blockSize,m-kOuter); const Range<Int> indOuter( kOuter, kOuter+nbOuter ); auto A1 = A( indOuter, ALL ); for( Int kInner=0; kInner<n; kInner+=blockSize ) { const Int nbInner = Min(blockSize,n-kInner); const Range<Int> indInner( kInner, kInner+nbInner ); auto B1 = B( indInner, ALL ); auto C11 = C( indOuter, indInner ); LocalGemm( NORMAL, orientB, alpha, A1, B1, C11_STAR_STAR ); AxpyContract( T(1), C11_STAR_STAR, C11 ); } } }
void SUMMA_NTA ( Orientation orientB, T alpha, const AbstractDistMatrix<T>& APre, const AbstractDistMatrix<T>& BPre, AbstractDistMatrix<T>& CPre ) { EL_DEBUG_CSE const Int n = CPre.Width(); const Int bsize = Blocksize(); const Grid& g = APre.Grid(); const bool conjugate = ( orientB == ADJOINT ); DistMatrixReadProxy<T,T,MC,MR> AProx( APre ); DistMatrixReadProxy<T,T,MC,MR> BProx( BPre ); DistMatrixReadWriteProxy<T,T,MC,MR> CProx( CPre ); auto& A = AProx.GetLocked(); auto& B = BProx.GetLocked(); auto& C = CProx.Get(); // Temporary distributions DistMatrix<T,MR,STAR> B1Trans_MR_STAR(g); DistMatrix<T,MC,STAR> D1_MC_STAR(g); B1Trans_MR_STAR.AlignWith( A ); D1_MC_STAR.AlignWith( A ); for( Int k=0; k<n; k+=bsize ) { const Int nb = Min(bsize,n-k); auto B1 = B( IR(k,k+nb), ALL ); auto C1 = C( ALL, IR(k,k+nb) ); // C1[MC,*] := alpha A[MC,MR] (B1^[T/H])[MR,*] Transpose( B1, B1Trans_MR_STAR, conjugate ); LocalGemm( NORMAL, NORMAL, alpha, A, B1Trans_MR_STAR, D1_MC_STAR ); // C1[MC,MR] += scattered result of D1[MC,*] summed over grid rows AxpyContract( T(1), D1_MC_STAR, C1 ); } }