void SUMMA_TNA ( Orientation orientA, T alpha, const AbstractDistMatrix<T>& APre, const AbstractDistMatrix<T>& BPre, AbstractDistMatrix<T>& CPre ) { DEBUG_CSE const Int n = CPre.Width(); const Int bsize = Blocksize(); const Grid& g = APre.Grid(); DistMatrixReadProxy<T,T,MC,MR> AProx( APre ); DistMatrixReadProxy<T,T,MC,MR> BProx( BPre ); DistMatrixReadWriteProxy<T,T,MC,MR> CProx( CPre ); auto& A = AProx.GetLocked(); auto& B = BProx.GetLocked(); auto& C = CProx.Get(); // Temporary distributions DistMatrix<T,MC,STAR> B1_MC_STAR(g); DistMatrix<T,MR,STAR> D1_MR_STAR(g); DistMatrix<T,MR,MC > D1_MR_MC(g); B1_MC_STAR.AlignWith( A ); D1_MR_STAR.AlignWith( A ); for( Int k=0; k<n; k+=bsize ) { const Int nb = Min(bsize,n-k); auto B1 = B( ALL, IR(k,k+nb) ); auto C1 = C( ALL, IR(k,k+nb) ); // D1[MR,*] := alpha (A1[MC,MR])^T B1[MC,*] // = alpha (A1^T)[MR,MC] B1[MC,*] B1_MC_STAR = B1; LocalGemm( orientA, NORMAL, alpha, A, B1_MC_STAR, D1_MR_STAR ); // C1[MC,MR] += scattered & transposed D1[MR,*] summed over grid cols Contract( D1_MR_STAR, D1_MR_MC ); Axpy( T(1), D1_MR_MC, C1 ); } }
void SUMMA_NTB ( Orientation orientB, T alpha, const AbstractDistMatrix<T>& APre, const AbstractDistMatrix<T>& BPre, AbstractDistMatrix<T>& CPre ) { EL_DEBUG_CSE const Int m = CPre.Height(); const Int bsize = Blocksize(); const Grid& g = APre.Grid(); DistMatrixReadProxy<T,T,MC,MR> AProx( APre ); DistMatrixReadProxy<T,T,MC,MR> BProx( BPre ); DistMatrixReadWriteProxy<T,T,MC,MR> CProx( CPre ); auto& A = AProx.GetLocked(); auto& B = BProx.GetLocked(); auto& C = CProx.Get(); // Temporary distributions DistMatrix<T,MR,STAR> A1Trans_MR_STAR(g); DistMatrix<T,STAR,MC> D1_STAR_MC(g); DistMatrix<T,MR,MC> D1_MR_MC(g); A1Trans_MR_STAR.AlignWith( B ); D1_STAR_MC.AlignWith( B ); for( Int k=0; k<m; k+=bsize ) { const Int nb = Min(bsize,m-k); auto A1 = A( IR(k,k+nb), ALL ); auto C1 = C( IR(k,k+nb), ALL ); // D1[*,MC] := alpha A1[*,MR] (B[MC,MR])^T // = alpha (A1^T)[MR,*] (B^T)[MR,MC] Transpose( A1, A1Trans_MR_STAR ); LocalGemm( TRANSPOSE, orientB, alpha, A1Trans_MR_STAR, B, D1_STAR_MC ); // C1[MC,MR] += scattered & transposed D1[*,MC] summed over grid rows Contract( D1_STAR_MC, D1_MR_MC ); Axpy( T(1), D1_MR_MC, C1 ); } }
inline void GemmTTA ( Orientation orientationOfA, Orientation orientationOfB, T alpha, const DistMatrix<T>& A, const DistMatrix<T>& B, T beta, DistMatrix<T>& C ) { #ifndef RELEASE PushCallStack("internal::GemmTTA"); if( A.Grid() != B.Grid() || B.Grid() != C.Grid() ) throw std::logic_error ("{A,B,C} must be distributed over the same grid"); if( orientationOfA == NORMAL || orientationOfB == NORMAL ) throw std::logic_error ("GemmTTA expects A and B to be (Conjugate)Transposed"); if( A.Width() != C.Height() || B.Height() != C.Width() || A.Height() != B.Width() ) { std::ostringstream msg; msg << "Nonconformal GemmTTA: \n" << " A ~ " << A.Height() << " x " << A.Width() << "\n" << " B ~ " << B.Height() << " x " << B.Width() << "\n" << " C ~ " << C.Height() << " x " << C.Width() << "\n"; throw std::logic_error( msg.str().c_str() ); } #endif const Grid& g = A.Grid(); // Matrix views DistMatrix<T> BT(g), B0(g), BB(g), B1(g), B2(g); DistMatrix<T> CL(g), CR(g), C0(g), C1(g), C2(g); // Temporary distributions DistMatrix<T,STAR,MC > B1_STAR_MC(g); DistMatrix<T,MR, STAR> D1_MR_STAR(g); DistMatrix<T,MR, MC > D1_MR_MC(g); DistMatrix<T> D1(g); B1_STAR_MC.AlignWith( A ); D1_MR_STAR.AlignWith( A ); // Start the algorithm Scale( beta, C ); LockedPartitionDown ( B, BT, BB, 0 ); PartitionRight( C, CL, CR, 0 ); while( BB.Height() > 0 ) { LockedRepartitionDown ( BT, B0, /**/ /**/ B1, BB, B2 ); RepartitionRight ( CL, /**/ CR, C0, /**/ C1, C2 ); D1.AlignWith( C1 ); Zeros( C1.Height(), C1.Width(), D1_MR_STAR ); //--------------------------------------------------------------------// B1_STAR_MC = B1; // B1[*,MC] <- B1[MC,MR] // D1[MR,*] := alpha (A[MC,MR])^T (B1[*,MC])^T // = alpha (A^T)[MR,MC] (B1^T)[MC,*] LocalGemm ( orientationOfA, orientationOfB, alpha, A, B1_STAR_MC, T(0), D1_MR_STAR ); // C1[MC,MR] += scattered & transposed D1[MR,*] summed over grid cols D1_MR_MC.SumScatterFrom( D1_MR_STAR ); D1 = D1_MR_MC; Axpy( T(1), D1, C1 ); //--------------------------------------------------------------------// D1.FreeAlignments(); SlideLockedPartitionDown ( BT, B0, B1, /**/ /**/ BB, B2 ); SlidePartitionRight ( CL, /**/ CR, C0, C1, /**/ C2 ); } #ifndef RELEASE PopCallStack(); #endif }
inline void GemmTTB ( Orientation orientationOfA, Orientation orientationOfB, T alpha, const DistMatrix<T>& A, const DistMatrix<T>& B, T beta, DistMatrix<T>& C ) { #ifndef RELEASE PushCallStack("internal::GemmTTB"); if( A.Grid() != B.Grid() || B.Grid() != C.Grid() ) throw std::logic_error ("{A,B,C} must be distributed over the same grid"); if( orientationOfA == NORMAL || orientationOfB == NORMAL ) throw std::logic_error ("GemmTTB expects A and B to be (Conjugate)Transposed"); if( A.Width() != C.Height() || B.Height() != C.Width() || A.Height() != B.Width() ) { std::ostringstream msg; msg << "Nonconformal GemmTTB: \n" << " A ~ " << A.Height() << " x " << A.Width() << "\n" << " B ~ " << B.Height() << " x " << B.Width() << "\n" << " C ~ " << C.Height() << " x " << C.Width() << "\n"; throw std::logic_error( msg.str().c_str() ); } #endif const Grid& g = A.Grid(); // Matrix views DistMatrix<T> AL(g), AR(g), A0(g), A1(g), A2(g); DistMatrix<T> CT(g), C0(g), CB(g), C1(g), C2(g); // Temporary distributions DistMatrix<T,VR, STAR> A1_VR_STAR(g); DistMatrix<T,STAR,MR > A1AdjOrTrans_STAR_MR(g); DistMatrix<T,STAR,MC > D1_STAR_MC(g); DistMatrix<T,MR, MC > D1_MR_MC(g); DistMatrix<T> D1(g); A1_VR_STAR.AlignWith( B ); A1AdjOrTrans_STAR_MR.AlignWith( B ); D1_STAR_MC.AlignWith( B ); // Start the algorithm Scale( beta, C ); LockedPartitionRight( A, AL, AR, 0 ); PartitionDown ( C, CT, CB, 0 ); while( AR.Width() > 0 ) { LockedRepartitionRight ( AL, /**/ AR, A0, /**/ A1, A2 ); RepartitionDown ( CT, C0, /**/ /**/ C1, CB, C2 ); D1.AlignWith( C1 ); Zeros( C1.Height(), C1.Width(), D1_STAR_MC ); //--------------------------------------------------------------------// A1_VR_STAR = A1; if( orientationOfA == ADJOINT ) A1AdjOrTrans_STAR_MR.AdjointFrom( A1_VR_STAR ); else A1AdjOrTrans_STAR_MR.TransposeFrom( A1_VR_STAR ); // D1[*,MC] := alpha (A1[MR,*])^[T/H] (B[MC,MR])^[T/H] // = alpha (A1^[T/H])[*,MR] (B^[T/H])[MR,MC] LocalGemm ( NORMAL, orientationOfB, alpha, A1AdjOrTrans_STAR_MR, B, T(0), D1_STAR_MC ); // C1[MC,MR] += scattered & transposed D1[*,MC] summed over grid rows D1_MR_MC.SumScatterFrom( D1_STAR_MC ); D1 = D1_MR_MC; Axpy( T(1), D1, C1 ); //--------------------------------------------------------------------// D1.FreeAlignments(); SlideLockedPartitionRight ( AL, /**/ AR, A0, A1, /**/ A2 ); SlidePartitionDown ( CT, C0, C1, /**/ /**/ CB, C2 ); } #ifndef RELEASE PopCallStack(); #endif }
inline void internal::GemmTNA ( Orientation orientationOfA, T alpha, const DistMatrix<T,MC,MR>& A, const DistMatrix<T,MC,MR>& B, T beta, DistMatrix<T,MC,MR>& C ) { #ifndef RELEASE PushCallStack("internal::GemmTNA"); if( A.Grid() != B.Grid() || B.Grid() != C.Grid() ) throw std::logic_error ("{A,B,C} must be distributed over the same grid"); if( orientationOfA == NORMAL ) throw std::logic_error("GemmTNA assumes A is (Conjugate)Transposed"); if( A.Width() != C.Height() || B.Width() != C.Width() || A.Height() != B.Height() ) { std::ostringstream msg; msg << "Nonconformal GemmTNA: \n" << " A ~ " << A.Height() << " x " << A.Width() << "\n" << " B ~ " << B.Height() << " x " << B.Width() << "\n" << " C ~ " << C.Height() << " x " << C.Width() << "\n"; throw std::logic_error( msg.str().c_str() ); } #endif const Grid& g = A.Grid(); // Matrix views DistMatrix<T,MC,MR> BL(g), BR(g), B0(g), B1(g), B2(g); DistMatrix<T,MC,MR> CL(g), CR(g), C0(g), C1(g), C2(g); // Temporary distributions DistMatrix<T,MC,STAR> B1_MC_STAR(g); DistMatrix<T,MR,STAR> D1_MR_STAR(g); DistMatrix<T,MR,MC > D1_MR_MC(g); DistMatrix<T,MC,MR > D1(g); // Start the algorithm Scal( beta, C ); LockedPartitionRight( B, BL, BR, 0 ); PartitionRight( C, CL, CR, 0 ); while( BR.Width() > 0 ) { LockedRepartitionRight ( BL, /**/ BR, B0, /**/ B1, B2 ); RepartitionRight ( CL, /**/ CR, C0, /**/ C1, C2 ); B1_MC_STAR.AlignWith( A ); D1_MR_STAR.AlignWith( A ); D1_MR_STAR.ResizeTo( C1.Height(), C1.Width() ); D1.AlignWith( C1 ); //--------------------------------------------------------------------// B1_MC_STAR = B1; // B1[MC,*] <- B1[MC,MR] // D1[MR,*] := alpha (A1[MC,MR])^T B1[MC,*] // = alpha (A1^T)[MR,MC] B1[MC,*] internal::LocalGemm ( orientationOfA, NORMAL, alpha, A, B1_MC_STAR, (T)0, D1_MR_STAR ); // C1[MC,MR] += scattered & transposed D1[MR,*] summed over grid cols D1_MR_MC.SumScatterFrom( D1_MR_STAR ); D1 = D1_MR_MC; Axpy( (T)1, D1, C1 ); //--------------------------------------------------------------------// B1_MC_STAR.FreeAlignments(); D1_MR_STAR.FreeAlignments(); D1.FreeAlignments(); SlideLockedPartitionRight ( BL, /**/ BR, B0, B1, /**/ B2 ); SlidePartitionRight ( CL, /**/ CR, C0, C1, /**/ C2 ); } #ifndef RELEASE PopCallStack(); #endif }