inline void HemmLUA ( T alpha, const DistMatrix<T>& A, const DistMatrix<T>& B, T beta, DistMatrix<T>& C ) { #ifndef RELEASE PushCallStack("internal::HemmLUA"); if( A.Grid() != B.Grid() || B.Grid() != C.Grid() ) throw std::logic_error ("{A,B,C} must be distributed over the same grid"); #endif const Grid& g = A.Grid(); DistMatrix<T> BL(g), BR(g), B0(g), B1(g), B2(g); DistMatrix<T> CL(g), CR(g), C0(g), C1(g), C2(g); DistMatrix<T,MC,STAR> B1_MC_STAR(g); DistMatrix<T,VR,STAR> B1_VR_STAR(g); DistMatrix<T,STAR,MR> B1Adj_STAR_MR(g); DistMatrix<T,MC,STAR> Z1_MC_STAR(g); DistMatrix<T,MR,STAR> Z1_MR_STAR(g); DistMatrix<T,MR,MC > Z1_MR_MC(g); DistMatrix<T> Z1(g); B1_MC_STAR.AlignWith( A ); B1_VR_STAR.AlignWith( A ); B1Adj_STAR_MR.AlignWith( A ); Z1_MC_STAR.AlignWith( A ); Z1_MR_STAR.AlignWith( A ); Scale( beta, C ); LockedPartitionRight ( B, BL, BR, 0 ); PartitionRight ( C, CL, CR, 0 ); while( CL.Width() < C.Width() ) { LockedRepartitionRight ( BL, /**/ BR, B0, /**/ B1, B2 ); RepartitionRight ( CL, /**/ CR, C0, /**/ C1, C2 ); Z1.AlignWith( C1 ); Zeros( C1.Height(), C1.Width(), Z1_MC_STAR ); Zeros( C1.Height(), C1.Width(), Z1_MR_STAR ); //--------------------------------------------------------------------// B1_MC_STAR = B1; B1_VR_STAR = B1_MC_STAR; B1Adj_STAR_MR.AdjointFrom( B1_VR_STAR ); LocalSymmetricAccumulateLU ( ADJOINT, alpha, A, B1_MC_STAR, B1Adj_STAR_MR, Z1_MC_STAR, Z1_MR_STAR ); Z1_MR_MC.SumScatterFrom( Z1_MR_STAR ); Z1 = Z1_MR_MC; Z1.SumScatterUpdate( T(1), Z1_MC_STAR ); Axpy( T(1), Z1, C1 ); //--------------------------------------------------------------------// Z1.FreeAlignments(); SlideLockedPartitionRight ( BL, /**/ BR, B0, B1, /**/ B2 ); SlidePartitionRight ( CL, /**/ CR, C0, C1, /**/ C2 ); } #ifndef RELEASE PopCallStack(); #endif }
inline void TwoSidedTrsmUVar1 ( UnitOrNonUnit diag, DistMatrix<F>& A, const DistMatrix<F>& U ) { #ifndef RELEASE CallStackEntry entry("internal::TwoSidedTrsmUVar1"); if( A.Height() != A.Width() ) LogicError("A must be square"); if( U.Height() != U.Width() ) LogicError("Triangular matrices must be square"); if( A.Height() != U.Height() ) LogicError("A and U must be the same size"); #endif const Grid& g = A.Grid(); // Matrix views DistMatrix<F> ATL(g), ATR(g), A00(g), A01(g), A02(g), ABL(g), ABR(g), A10(g), A11(g), A12(g), A20(g), A21(g), A22(g); DistMatrix<F> UTL(g), UTR(g), U00(g), U01(g), U02(g), UBL(g), UBR(g), U10(g), U11(g), U12(g), U20(g), U21(g), U22(g); // Temporary distributions DistMatrix<F,STAR,STAR> A11_STAR_STAR(g); DistMatrix<F,VC, STAR> A01_VC_STAR(g); DistMatrix<F,STAR,STAR> U11_STAR_STAR(g); DistMatrix<F,MC, STAR> U01_MC_STAR(g); DistMatrix<F,VC, STAR> U01_VC_STAR(g); DistMatrix<F,VR, STAR> U01_VR_STAR(g); DistMatrix<F,STAR,MR > U01Adj_STAR_MR(g); DistMatrix<F,STAR,STAR> X11_STAR_STAR(g); DistMatrix<F,MR, MC > Z01_MR_MC(g); DistMatrix<F,MC, STAR> Z01_MC_STAR(g); DistMatrix<F,MR, STAR> Z01_MR_STAR(g); DistMatrix<F> Y01(g); PartitionDownDiagonal ( A, ATL, ATR, ABL, ABR, 0 ); LockedPartitionDownDiagonal ( U, UTL, UTR, UBL, UBR, 0 ); while( ATL.Height() < A.Height() ) { RepartitionDownDiagonal ( ATL, /**/ ATR, A00, /**/ A01, A02, /*************/ /******************/ /**/ A10, /**/ A11, A12, ABL, /**/ ABR, A20, /**/ A21, A22 ); LockedRepartitionDownDiagonal ( UTL, /**/ UTR, U00, /**/ U01, U02, /*************/ /******************/ /**/ U10, /**/ U11, U12, UBL, /**/ UBR, U20, /**/ U21, U22 ); A01_VC_STAR.AlignWith( A01 ); U01_MC_STAR.AlignWith( A00 ); U01_VR_STAR.AlignWith( A00 ); U01_VC_STAR.AlignWith( A00 ); U01Adj_STAR_MR.AlignWith( A00 ); Y01.AlignWith( A01 ); Z01_MR_MC.AlignWith( A01 ); Z01_MC_STAR.AlignWith( A00 ); Z01_MR_STAR.AlignWith( A00 ); //--------------------------------------------------------------------// // Y01 := A00 U01 U01_MC_STAR = U01; U01_VR_STAR = U01_MC_STAR; U01Adj_STAR_MR.AdjointFrom( U01_VR_STAR ); Zeros( Z01_MC_STAR, A01.Height(), A01.Width() ); Zeros( Z01_MR_STAR, A01.Height(), A01.Width() ); LocalSymmetricAccumulateLU ( ADJOINT, F(1), A00, U01_MC_STAR, U01Adj_STAR_MR, Z01_MC_STAR, Z01_MR_STAR ); Z01_MR_MC.SumScatterFrom( Z01_MR_STAR ); Y01 = Z01_MR_MC; Y01.SumScatterUpdate( F(1), Z01_MC_STAR ); // A01 := inv(U00)' A01 // // This is the bottleneck because A01 only has blocksize columns Trsm( LEFT, UPPER, ADJOINT, diag, F(1), U00, A01 ); // A01 := A01 - 1/2 Y01 Axpy( F(-1)/F(2), Y01, A01 ); // A11 := A11 - (U01' A01 + A01' U01) A01_VC_STAR = A01; U01_VC_STAR = U01_MC_STAR; Zeros( X11_STAR_STAR, A11.Height(), A11.Width() ); Her2k ( UPPER, ADJOINT, F(-1), A01_VC_STAR.Matrix(), U01_VC_STAR.Matrix(), F(0), X11_STAR_STAR.Matrix() ); A11.SumScatterUpdate( F(1), X11_STAR_STAR ); // A11 := inv(U11)' A11 inv(U11) A11_STAR_STAR = A11; U11_STAR_STAR = U11; LocalTwoSidedTrsm( UPPER, diag, A11_STAR_STAR, U11_STAR_STAR ); A11 = A11_STAR_STAR; // A01 := A01 - 1/2 Y01 Axpy( F(-1)/F(2), Y01, A01 ); // A01 := A01 inv(U11) A01_VC_STAR = A01; LocalTrsm ( RIGHT, UPPER, NORMAL, diag, F(1), U11_STAR_STAR, A01_VC_STAR ); A01 = A01_VC_STAR; //--------------------------------------------------------------------// SlidePartitionDownDiagonal ( ATL, /**/ ATR, A00, A01, /**/ A02, /**/ A10, A11, /**/ A12, /*************/ /******************/ ABL, /**/ ABR, A20, A21, /**/ A22 ); SlideLockedPartitionDownDiagonal ( UTL, /**/ UTR, U00, U01, /**/ U02, /**/ U10, U11, /**/ U12, /*************/ /******************/ UBL, /**/ UBR, U20, U21, /**/ U22 ); } }
inline void TwoSidedTrsmUVar2 ( UnitOrNonUnit diag, DistMatrix<F>& A, const DistMatrix<F>& U ) { #ifndef RELEASE CallStackEntry entry("internal::TwoSidedTrsmUVar2"); if( A.Height() != A.Width() ) LogicError("A must be square"); if( U.Height() != U.Width() ) LogicError("Triangular matrices must be square"); if( A.Height() != U.Height() ) LogicError("A and U must be the same size"); #endif const Grid& g = A.Grid(); // Matrix views DistMatrix<F> ATL(g), ATR(g), A00(g), A01(g), A02(g), ABL(g), ABR(g), A10(g), A11(g), A12(g), A20(g), A21(g), A22(g); DistMatrix<F> UTL(g), UTR(g), U00(g), U01(g), U02(g), UBL(g), UBR(g), U10(g), U11(g), U12(g), U20(g), U21(g), U22(g); // Temporary distributions DistMatrix<F,MC, STAR> A01_MC_STAR(g); DistMatrix<F,VC, STAR> A01_VC_STAR(g); DistMatrix<F,STAR,STAR> A11_STAR_STAR(g); DistMatrix<F,STAR,VR > A12_STAR_VR(g); DistMatrix<F,MC, STAR> F01_MC_STAR(g); DistMatrix<F,MC, STAR> U01_MC_STAR(g); DistMatrix<F,VR, STAR> U01_VR_STAR(g); DistMatrix<F,STAR,MR > U01Adj_STAR_MR(g); DistMatrix<F,STAR,STAR> U11_STAR_STAR(g); DistMatrix<F,STAR,MR > X11_STAR_MR(g); DistMatrix<F,MR, STAR> X12Adj_MR_STAR(g); DistMatrix<F,MR, MC > X12Adj_MR_MC(g); DistMatrix<F,MR, MC > Y01_MR_MC(g); DistMatrix<F,MR, STAR> Y01_MR_STAR(g); DistMatrix<F> X11(g); DistMatrix<F> Y01(g); Matrix<F> X12Local; PartitionDownDiagonal ( A, ATL, ATR, ABL, ABR, 0 ); LockedPartitionDownDiagonal ( U, UTL, UTR, UBL, UBR, 0 ); while( ATL.Height() < A.Height() ) { RepartitionDownDiagonal ( ATL, /**/ ATR, A00, /**/ A01, A02, /*************/ /******************/ /**/ A10, /**/ A11, A12, ABL, /**/ ABR, A20, /**/ A21, A22 ); LockedRepartitionDownDiagonal ( UTL, /**/ UTR, U00, /**/ U01, U02, /*************/ /******************/ /**/ U10, /**/ U11, U12, UBL, /**/ UBR, U20, /**/ U21, U22 ); A01_MC_STAR.AlignWith( U01 ); Y01.AlignWith( A01 ); Y01_MR_STAR.AlignWith( A00 ); U01_MC_STAR.AlignWith( A00 ); U01_VR_STAR.AlignWith( A00 ); U01Adj_STAR_MR.AlignWith( A00 ); X11_STAR_MR.AlignWith( U01 ); X11.AlignWith( A11 ); X12Adj_MR_STAR.AlignWith( A02 ); X12Adj_MR_MC.AlignWith( A12 ); F01_MC_STAR.AlignWith( A00 ); //--------------------------------------------------------------------// // Y01 := A00 U01 U01_MC_STAR = U01; U01_VR_STAR = U01_MC_STAR; U01Adj_STAR_MR.AdjointFrom( U01_VR_STAR ); Zeros( Y01_MR_STAR, A01.Height(), A01.Width() ); Zeros( F01_MC_STAR, A01.Height(), A01.Width() ); LocalSymmetricAccumulateLU ( ADJOINT, F(1), A00, U01_MC_STAR, U01Adj_STAR_MR, F01_MC_STAR, Y01_MR_STAR ); Y01_MR_MC.SumScatterFrom( Y01_MR_STAR ); Y01 = Y01_MR_MC; Y01.SumScatterUpdate( F(1), F01_MC_STAR ); // X11 := U01' A01 LocalGemm( ADJOINT, NORMAL, F(1), U01_MC_STAR, A01, X11_STAR_MR ); // A01 := A01 - Y01 Axpy( F(-1), Y01, A01 ); A01_MC_STAR = A01; // A11 := A11 - triu(X11 + A01' U01) = A11 - (U01 A01 + A01' U01) LocalGemm( ADJOINT, NORMAL, F(1), A01_MC_STAR, U01, F(1), X11_STAR_MR ); X11.SumScatterFrom( X11_STAR_MR ); MakeTriangular( UPPER, X11 ); Axpy( F(-1), X11, A11 ); // A01 := A01 inv(U11) U11_STAR_STAR = U11; A01_VC_STAR = A01_MC_STAR; LocalTrsm ( RIGHT, UPPER, NORMAL, diag, F(1), U11_STAR_STAR, A01_VC_STAR ); A01 = A01_VC_STAR; // A11 := inv(U11)' A11 inv(U11) A11_STAR_STAR = A11; LocalTwoSidedTrsm( UPPER, diag, A11_STAR_STAR, U11_STAR_STAR ); A11 = A11_STAR_STAR; // A12 := A12 - A02' U01 LocalGemm( ADJOINT, NORMAL, F(1), A02, U01_MC_STAR, X12Adj_MR_STAR ); X12Adj_MR_MC.SumScatterFrom( X12Adj_MR_STAR ); Adjoint( X12Adj_MR_MC.LockedMatrix(), X12Local ); Axpy( F(-1), X12Local, A12.Matrix() ); // A12 := inv(U11)' A12 A12_STAR_VR = A12; LocalTrsm ( LEFT, UPPER, ADJOINT, diag, F(1), U11_STAR_STAR, A12_STAR_VR ); A12 = A12_STAR_VR; //--------------------------------------------------------------------// SlidePartitionDownDiagonal ( ATL, /**/ ATR, A00, A01, /**/ A02, /**/ A10, A11, /**/ A12, /*************/ /******************/ ABL, /**/ ABR, A20, A21, /**/ A22 ); SlideLockedPartitionDownDiagonal ( UTL, /**/ UTR, U00, U01, /**/ U02, /**/ U10, U11, /**/ U12, /*************/ /******************/ UBL, /**/ UBR, U20, U21, /**/ U22 ); } }