void L1DistanceMatrix(direction_t dirA, direction_t dirB, T alpha, const El::ElementalMatrix<T> &APre, const El::ElementalMatrix<T> &BPre, T beta, El::ElementalMatrix<T> &CPre) { if (dirA == base::COLUMNS && dirB == base::COLUMNS) { // Use a SUMMA-like routine, with C as stationary // Basically an adaptation of Elementals TN case for stationary C. const El::Int m = CPre.Height(); const El::Int n = CPre.Width(); const El::Int sumDim = BPre.Height(); const El::Int bsize = El::Blocksize(); const El::Grid& g = APre.Grid(); El::DistMatrixReadProxy<T, T, El::MC, El::MR> AProx(APre); El::DistMatrixReadProxy<T, T, El::MC, El::MR> BProx(BPre); El::DistMatrixReadWriteProxy<T, T, El::MC, El::MR> CProx(CPre); auto& A = AProx.GetLocked(); auto& B = BProx.GetLocked(); auto& C = CProx.Get(); // Temporary distributions El::DistMatrix<T, El::STAR, El::MC> A1_STAR_MC(g); El::DistMatrix<T, El::STAR, El::MR> B1_STAR_MR(g); A1_STAR_MC.AlignWith(C); B1_STAR_MR.AlignWith(C); El::Scale(beta, C); for(El::Int k = 0; k < sumDim; k += bsize) { const El::Int nb = std::min(bsize,sumDim-k); auto A1 = A(El::IR(k,k+nb), El::IR(0,m)); auto B1 = B(El::IR(k,k+nb), El::IR(0,n)); A1_STAR_MC = A1; B1_STAR_MR = B1; L1DistanceMatrix(base::COLUMNS, base::COLUMNS, alpha, A1_STAR_MC.LockedMatrix(), B1_STAR_MR.LockedMatrix(), T(1.0), C.Matrix()); } } // TODO the rest of the cases. }
inline void internal::GemmTNC ( Orientation orientationOfA, T alpha, const DistMatrix<T,MC,MR>& A, const DistMatrix<T,MC,MR>& B, T beta, DistMatrix<T,MC,MR>& C ) { #ifndef RELEASE PushCallStack("internal::GemmTNC"); if( A.Grid() != B.Grid() || B.Grid() != C.Grid() ) throw std::logic_error ("{A,B,C} must be distributed over the same grid"); if( orientationOfA == NORMAL ) throw std::logic_error("GemmTNC assumes A is (Conjugate)Transposed"); if( A.Width() != C.Height() || B.Width() != C.Width() || A.Height() != B.Height() ) { std::ostringstream msg; msg << "Nonconformal GemmTNC: \n" << " A ~ " << A.Height() << " x " << A.Width() << "\n" << " B ~ " << B.Height() << " x " << B.Width() << "\n" << " C ~ " << C.Height() << " x " << C.Width() << "\n"; throw std::logic_error( msg.str().c_str() ); } #endif const Grid& g = A.Grid(); // Matrix views DistMatrix<T,MC,MR> AT(g), A0(g), AB(g), A1(g), A2(g); DistMatrix<T,MC,MR> BT(g), B0(g), BB(g), B1(g), B2(g); // Temporary distributions DistMatrix<T,STAR,MC> A1_STAR_MC(g); DistMatrix<T,STAR,MR> B1_STAR_MR(g); // Start the algorithm Scal( beta, C ); LockedPartitionDown ( A, AT, AB, 0 ); LockedPartitionDown ( B, BT, BB, 0 ); while( AB.Height() > 0 ) { LockedRepartitionDown ( AT, A0, /**/ /**/ A1, AB, A2 ); LockedRepartitionDown ( BT, B0, /**/ /**/ B1, BB, B2 ); A1_STAR_MC.AlignWith( C ); B1_STAR_MR.AlignWith( C ); //--------------------------------------------------------------------// A1_STAR_MC = A1; // A1[*,MC] <- A1[MC,MR] B1_STAR_MR = B1; // B1[*,MR] <- B1[MC,MR] // C[MC,MR] += alpha (A1[*,MC])^T B1[*,MR] // = alpha (A1^T)[MC,*] B1[*,MR] internal::LocalGemm ( orientationOfA, NORMAL, alpha, A1_STAR_MC, B1_STAR_MR, (T)1, C ); //--------------------------------------------------------------------// A1_STAR_MC.FreeAlignments(); B1_STAR_MR.FreeAlignments(); SlideLockedPartitionDown ( AT, A0, A1, /**/ /**/ AB, A2 ); SlideLockedPartitionDown ( BT, B0, B1, /**/ /**/ BB, B2 ); } #ifndef RELEASE PopCallStack(); #endif }
inline void RowEchelon( DistMatrix<F>& A, DistMatrix<F>& B ) { #ifndef RELEASE CallStackEntry entry("RowEchelon"); if( A.Grid() != B.Grid() ) LogicError("{A,B} must be distributed over the same grid"); if( A.Height() != B.Height() ) LogicError("A and B must be the same height"); #endif const Grid& g = A.Grid(); // Matrix views DistMatrix<F> ATL(g), ATR(g), A00(g), A01(g), A02(g), APan(g), ABL(g), ABR(g), A10(g), A11(g), A12(g), A20(g), A21(g), A22(g); DistMatrix<F> BT(g), B0(g), BB(g), B1(g), B2(g); // Temporary distributions DistMatrix<F,STAR,STAR> A11_STAR_STAR(g); DistMatrix<F,STAR,VR > A12_STAR_VR(g); DistMatrix<F,STAR,MR > A12_STAR_MR(g); DistMatrix<F,MC, STAR> A21_MC_STAR(g); DistMatrix<F,STAR,VR > B1_STAR_VR(g); DistMatrix<F,STAR,MR > B1_STAR_MR(g); DistMatrix<Int,STAR,STAR> p1_STAR_STAR(g); // In case B's columns are not aligned with A's const bool BAligned = ( B.ColShift() == A.ColShift() ); DistMatrix<F,MC,STAR> A21_MC_STAR_B(g); // Pivot composition std::vector<Int> image, preimage; // Start the algorithm PartitionDownDiagonal ( A, ATL, ATR, ABL, ABR, 0 ); PartitionDown ( B, BT, BB, 0 ); while( ATL.Height() < A.Height() && ATL.Width() < A.Width() ) { RepartitionDownDiagonal ( ATL, /**/ ATR, A00, /**/ A01, A02, /*************/ /******************/ /**/ A10, /**/ A11, A12, ABL, /**/ ABR, A20, /**/ A21, A22 ); RepartitionDown ( BT, B0, /**/ /**/ B1, BB, B2 ); View2x1 ( APan, A12, A22 ); A12_STAR_VR.AlignWith( A22 ); A12_STAR_MR.AlignWith( A22 ); A21_MC_STAR.AlignWith( A22 ); B1_STAR_VR.AlignWith( B1 ); B1_STAR_MR.AlignWith( B1 ); if( ! BAligned ) A21_MC_STAR_B.AlignWith( B2 ); p1_STAR_STAR.ResizeTo( A11.Height(), 1 ); //--------------------------------------------------------------------// A11_STAR_STAR = A11; A21_MC_STAR = A21; lu::Panel( A11_STAR_STAR, A21_MC_STAR, p1_STAR_STAR, A00.Height() ); ComposePivots( p1_STAR_STAR, A00.Height(), image, preimage ); ApplyRowPivots( APan, image, preimage ); ApplyRowPivots( BB, image, preimage ); A12_STAR_VR = A12; B1_STAR_VR = B1; LocalTrsm ( LEFT, LOWER, NORMAL, UNIT, F(1), A11_STAR_STAR, A12_STAR_VR ); LocalTrsm( LEFT, LOWER, NORMAL, UNIT, F(1), A11_STAR_STAR, B1_STAR_VR ); A12_STAR_MR = A12_STAR_VR; B1_STAR_MR = B1_STAR_VR; LocalGemm( NORMAL, NORMAL, F(-1), A21_MC_STAR, A12_STAR_MR, F(1), A22 ); if( BAligned ) { LocalGemm ( NORMAL, NORMAL, F(-1), A21_MC_STAR, B1_STAR_MR, F(1), B2 ); } else { A21_MC_STAR_B = A21_MC_STAR; LocalGemm ( NORMAL, NORMAL, F(-1), A21_MC_STAR_B, B1_STAR_MR, F(1), B2 ); } A11 = A11_STAR_STAR; A12 = A12_STAR_MR; B1 = B1_STAR_MR; //--------------------------------------------------------------------// SlidePartitionDownDiagonal ( ATL, /**/ ATR, A00, A01, /**/ A02, /**/ A10, A11, /**/ A12, /*************/ /******************/ ABL, /**/ ABR, A20, A21, /**/ A22 ); SlidePartitionDown ( BT, B0, B1, /**/ /**/ BB, B2 ); } }