void SUMMA_NTB ( Orientation orientB, T alpha, const AbstractDistMatrix<T>& APre, const AbstractDistMatrix<T>& BPre, AbstractDistMatrix<T>& CPre ) { EL_DEBUG_CSE const Int m = CPre.Height(); const Int bsize = Blocksize(); const Grid& g = APre.Grid(); DistMatrixReadProxy<T,T,MC,MR> AProx( APre ); DistMatrixReadProxy<T,T,MC,MR> BProx( BPre ); DistMatrixReadWriteProxy<T,T,MC,MR> CProx( CPre ); auto& A = AProx.GetLocked(); auto& B = BProx.GetLocked(); auto& C = CProx.Get(); // Temporary distributions DistMatrix<T,MR,STAR> A1Trans_MR_STAR(g); DistMatrix<T,STAR,MC> D1_STAR_MC(g); DistMatrix<T,MR,MC> D1_MR_MC(g); A1Trans_MR_STAR.AlignWith( B ); D1_STAR_MC.AlignWith( B ); for( Int k=0; k<m; k+=bsize ) { const Int nb = Min(bsize,m-k); auto A1 = A( IR(k,k+nb), ALL ); auto C1 = C( IR(k,k+nb), ALL ); // D1[*,MC] := alpha A1[*,MR] (B[MC,MR])^T // = alpha (A1^T)[MR,*] (B^T)[MR,MC] Transpose( A1, A1Trans_MR_STAR ); LocalGemm( TRANSPOSE, orientB, alpha, A1Trans_MR_STAR, B, D1_STAR_MC ); // C1[MC,MR] += scattered & transposed D1[*,MC] summed over grid rows Contract( D1_STAR_MC, D1_MR_MC ); Axpy( T(1), D1_MR_MC, C1 ); } }
inline void GemmTTB ( Orientation orientationOfA, Orientation orientationOfB, T alpha, const DistMatrix<T>& A, const DistMatrix<T>& B, T beta, DistMatrix<T>& C ) { #ifndef RELEASE PushCallStack("internal::GemmTTB"); if( A.Grid() != B.Grid() || B.Grid() != C.Grid() ) throw std::logic_error ("{A,B,C} must be distributed over the same grid"); if( orientationOfA == NORMAL || orientationOfB == NORMAL ) throw std::logic_error ("GemmTTB expects A and B to be (Conjugate)Transposed"); if( A.Width() != C.Height() || B.Height() != C.Width() || A.Height() != B.Width() ) { std::ostringstream msg; msg << "Nonconformal GemmTTB: \n" << " A ~ " << A.Height() << " x " << A.Width() << "\n" << " B ~ " << B.Height() << " x " << B.Width() << "\n" << " C ~ " << C.Height() << " x " << C.Width() << "\n"; throw std::logic_error( msg.str().c_str() ); } #endif const Grid& g = A.Grid(); // Matrix views DistMatrix<T> AL(g), AR(g), A0(g), A1(g), A2(g); DistMatrix<T> CT(g), C0(g), CB(g), C1(g), C2(g); // Temporary distributions DistMatrix<T,VR, STAR> A1_VR_STAR(g); DistMatrix<T,STAR,MR > A1AdjOrTrans_STAR_MR(g); DistMatrix<T,STAR,MC > D1_STAR_MC(g); DistMatrix<T,MR, MC > D1_MR_MC(g); DistMatrix<T> D1(g); A1_VR_STAR.AlignWith( B ); A1AdjOrTrans_STAR_MR.AlignWith( B ); D1_STAR_MC.AlignWith( B ); // Start the algorithm Scale( beta, C ); LockedPartitionRight( A, AL, AR, 0 ); PartitionDown ( C, CT, CB, 0 ); while( AR.Width() > 0 ) { LockedRepartitionRight ( AL, /**/ AR, A0, /**/ A1, A2 ); RepartitionDown ( CT, C0, /**/ /**/ C1, CB, C2 ); D1.AlignWith( C1 ); Zeros( C1.Height(), C1.Width(), D1_STAR_MC ); //--------------------------------------------------------------------// A1_VR_STAR = A1; if( orientationOfA == ADJOINT ) A1AdjOrTrans_STAR_MR.AdjointFrom( A1_VR_STAR ); else A1AdjOrTrans_STAR_MR.TransposeFrom( A1_VR_STAR ); // D1[*,MC] := alpha (A1[MR,*])^[T/H] (B[MC,MR])^[T/H] // = alpha (A1^[T/H])[*,MR] (B^[T/H])[MR,MC] LocalGemm ( NORMAL, orientationOfB, alpha, A1AdjOrTrans_STAR_MR, B, T(0), D1_STAR_MC ); // C1[MC,MR] += scattered & transposed D1[*,MC] summed over grid rows D1_MR_MC.SumScatterFrom( D1_STAR_MC ); D1 = D1_MR_MC; Axpy( T(1), D1, C1 ); //--------------------------------------------------------------------// D1.FreeAlignments(); SlideLockedPartitionRight ( AL, /**/ AR, A0, A1, /**/ A2 ); SlidePartitionDown ( CT, C0, C1, /**/ /**/ CB, C2 ); } #ifndef RELEASE PopCallStack(); #endif }