// Distributed symmetric matrix-vector update:  y := beta*y + z.
//
// z is accumulated from alpha, A and x by the
// internal::LocalSymvColAccumulate{L,U} / internal::LocalSymvRowAccumulate{L,U}
// helpers, selected by `uplo` and by whether x is stored as a column or a row
// vector.
// NOTE(review): presumably z = alpha*A*x with only the `uplo` triangle of A
// referenced, and A treated as Hermitian when `conjugate` is true. The helpers
// are not visible here; confirm against their definitions.
//
// x and y may each be stored either as a column vector (width 1) or as a row
// vector (height 1). The four combinations are handled by four branches. They
// use different intermediate distributions but follow the same scheme:
//   1. scale y by beta,
//   2. redistribute x into two copies aligned with A,
//   3. accumulate partial results locally into two zero-initialized buffers
//      aligned with A,
//   4. sum-scatter both partial results into a vector with y's distribution,
//      transposing it when x and y have different orientations,
//   5. add that vector to y.
//
// Parameters
//   uplo      - LOWER selects the *L accumulate helper; any other value
//               selects the *U one.
//   alpha     - passed through to the accumulate helper.
//   A         - square matrix; must be on the same grid as x and y.
//   x         - input vector of length A.Height().
//   beta      - factor applied to y (via Scale) before the update is added.
//   y         - input/output vector of length A.Height().
//   conjugate - passed through to the accumulate helper.
//
// Throws (only when RELEASE is not defined)
//   std::logic_error if any of the following holds:
//     - A, x and y are not on the same grid,
//     - A is not square,
//     - x or y is not a vector,
//     - either vector's length differs from A.Height().
//
// NOTE(review): the redistributing assignments and SumScatter* calls presumably
// communicate across the process grid. All processes must then execute them in
// the same order, so keep the statement order intact when editing.
inline void Symv
( UpperOrLower uplo,
  T alpha, const DistMatrix<T>& A,
           const DistMatrix<T>& x,
  T beta,        DistMatrix<T>& y,
  bool conjugate=false )
{
#ifndef RELEASE
    CallStackEntry entry("Symv");
    if( A.Grid() != x.Grid() || x.Grid() != y.Grid() )
        throw std::logic_error
        ("{A,x,y} must be distributed over the same grid");
    if( A.Height() != A.Width() )
        throw std::logic_error("A must be square");
    if( ( x.Width() != 1 && x.Height() != 1 ) ||
        ( y.Width() != 1 && y.Height() != 1 ) )
        throw std::logic_error("x and y are assumed to be vectors");
    // A vector's length is the height of a column vector or the width of a
    // row vector.
    const int xLength = ( x.Width()==1 ? x.Height() : x.Width() );
    const int yLength = ( y.Width()==1 ? y.Height() : y.Width() );
    if( A.Height() != xLength || A.Height() != yLength )
    {
        std::ostringstream msg;
        msg << "Nonconformal Symv: \n"
            << " A ~ " << A.Height() << " x " << A.Width() << "\n"
            << " x ~ " << x.Height() << " x " << x.Width() << "\n"
            << " y ~ " << y.Height() << " x " << y.Width() << "\n";
        throw std::logic_error( msg.str() );
    }
#endif
    const Grid& g = A.Grid();
    if( x.Width() == 1 && y.Width() == 1 )
    {
        // Case 1: x and y are both column vectors.

        // Temporary distributions
        DistMatrix<T,MC,STAR> x_MC_STAR(g), z_MC_STAR(g);
        DistMatrix<T,MR,STAR> x_MR_STAR(g), z_MR_STAR(g);
        DistMatrix<T,MR,MC  > z_MR_MC(g);
        DistMatrix<T> z(g);

        // Begin the algorithm
        Scale( beta, y );
        // The copies of x and the partial-result buffers are aligned with A.
        // Presumably this lets the local helpers pair them with A's local
        // entries without further communication.
        x_MC_STAR.AlignWith( A );
        x_MR_STAR.AlignWith( A );
        z_MC_STAR.AlignWith( A );
        z_MR_STAR.AlignWith( A );
        z.AlignWith( y );
        Zeros( z_MC_STAR, y.Height(), 1 );
        Zeros( z_MR_STAR, y.Height(), 1 );
        //--------------------------------------------------------------------//
        x_MC_STAR = x;
        x_MR_STAR = x_MC_STAR;
        if( uplo == LOWER )
        {
            internal::LocalSymvColAccumulateL
            ( alpha, A, x_MC_STAR, x_MR_STAR, z_MC_STAR, z_MR_STAR, conjugate );
        }
        else
        {
            internal::LocalSymvColAccumulateU
            ( alpha, A, x_MC_STAR, x_MR_STAR, z_MC_STAR, z_MR_STAR, conjugate );
        }

        // Combine both partial results into z (aligned with y), then add z to y.
        // The [MR,*] part goes through [MR,MC] first; the [MC,*] part is then
        // sum-scattered straight into z.
        z_MR_MC.SumScatterFrom( z_MR_STAR );
        z = z_MR_MC;
        z.SumScatterUpdate( T(1), z_MC_STAR );
        Axpy( T(1), z, y );
        //--------------------------------------------------------------------//
        x_MC_STAR.FreeAlignments();
        x_MR_STAR.FreeAlignments();
        z_MC_STAR.FreeAlignments();
        z_MR_STAR.FreeAlignments();
        z.FreeAlignments();
    }
    else if( x.Width() == 1 )
    {
        // Case 2: x is a column vector and y is a row vector, so y's length
        // is y.Width(). The result is formed as a column vector and then
        // transposed into zTrans before being added to y.

        // Temporary distributions
        DistMatrix<T,MC,STAR> x_MC_STAR(g), z_MC_STAR(g);
        DistMatrix<T,MR,STAR> x_MR_STAR(g), z_MR_STAR(g);
        DistMatrix<T,MR,MC  > z_MR_MC(g);
        DistMatrix<T> z(g), zTrans(g);

        // Begin the algorithm
        Scale( beta, y );
        x_MC_STAR.AlignWith( A );
        x_MR_STAR.AlignWith( A );
        z_MC_STAR.AlignWith( A );
        z_MR_STAR.AlignWith( A );
        z.AlignWith( y );
        // NOTE(review): presumably aligned with y so that its transpose
        // (zTrans) lines up with y's distribution. Confirm.
        z_MR_MC.AlignWith( y );
        Zeros( z_MC_STAR, y.Width(), 1 );
        Zeros( z_MR_STAR, y.Width(), 1 );
        //--------------------------------------------------------------------//
        x_MC_STAR = x;
        x_MR_STAR = x_MC_STAR;
        if( uplo == LOWER )
        {
            internal::LocalSymvColAccumulateL
            ( alpha, A, x_MC_STAR, x_MR_STAR, z_MC_STAR, z_MR_STAR, conjugate );
        }
        else
        {
            internal::LocalSymvColAccumulateU
            ( alpha, A, x_MC_STAR, x_MR_STAR, z_MC_STAR, z_MR_STAR, conjugate );
        }

        // Combine both partial results into the column vector z_MR_MC.
        // Transpose it into a row vector to match y, then add it to y.
        z.SumScatterFrom( z_MC_STAR );
        z_MR_MC = z;
        z_MR_MC.SumScatterUpdate( T(1), z_MR_STAR );
        Transpose( z_MR_MC, zTrans );
        Axpy( T(1), zTrans, y );
        //--------------------------------------------------------------------//
        x_MC_STAR.FreeAlignments();
        x_MR_STAR.FreeAlignments();
        z_MC_STAR.FreeAlignments();
        z_MR_STAR.FreeAlignments();
        z.FreeAlignments();
        z_MR_MC.FreeAlignments();
    }
    else if( y.Width() == 1 )
    {
        // Case 3: x is a row vector and y is a column vector. The result is
        // formed as a row vector and then transposed into zTrans before being
        // added to y.

        // Temporary distributions
        DistMatrix<T,STAR,MC> x_STAR_MC(g), z_STAR_MC(g);
        DistMatrix<T,STAR,MR> x_STAR_MR(g), z_STAR_MR(g);
        DistMatrix<T,MR,  MC> z_MR_MC(g);
        DistMatrix<T> z(g), zTrans(g);

        // Begin the algorithm
        Scale( beta, y );
        x_STAR_MC.AlignWith( A );
        x_STAR_MR.AlignWith( A );
        z_STAR_MC.AlignWith( A );
        z_STAR_MR.AlignWith( A );
        z.AlignWith( y );
        z_MR_MC.AlignWith( y );
        Zeros( z_STAR_MC, 1, y.Height() );
        Zeros( z_STAR_MR, 1, y.Height() );
        //--------------------------------------------------------------------//
        // For row vectors, x goes to [*,MR] first and [*,MC] second. This is
        // the reverse of the column-vector order.
        x_STAR_MR = x;
        x_STAR_MC = x_STAR_MR;
        if( uplo == LOWER )
        {
            internal::LocalSymvRowAccumulateL
            ( alpha, A, x_STAR_MC, x_STAR_MR, z_STAR_MC, z_STAR_MR, conjugate );
        }
        else
        {
            internal::LocalSymvRowAccumulateU
            ( alpha, A, x_STAR_MC, x_STAR_MR, z_STAR_MC, z_STAR_MR, conjugate );
        }

        // Combine both partial results into the row vector z_MR_MC.
        // Transpose it into a column vector to match y, then add it to y.
        z.SumScatterFrom( z_STAR_MR );
        z_MR_MC = z;
        z_MR_MC.SumScatterUpdate( T(1), z_STAR_MC );
        Transpose( z_MR_MC, zTrans );
        Axpy( T(1), zTrans, y );
        //--------------------------------------------------------------------//
        x_STAR_MC.FreeAlignments();
        x_STAR_MR.FreeAlignments();
        z_STAR_MC.FreeAlignments();
        z_STAR_MR.FreeAlignments();
        z.FreeAlignments();
        z_MR_MC.FreeAlignments();
    }
    else
    {
        // Case 4: x and y are both row vectors.

        // Temporary distributions
        DistMatrix<T,STAR,MC> x_STAR_MC(g), z_STAR_MC(g);
        DistMatrix<T,STAR,MR> x_STAR_MR(g), z_STAR_MR(g);
        DistMatrix<T,MR,  MC> z_MR_MC(g);
        DistMatrix<T> z(g);

        // Begin the algorithm
        Scale( beta, y );
        x_STAR_MC.AlignWith( A );
        x_STAR_MR.AlignWith( A );
        z_STAR_MC.AlignWith( A );
        z_STAR_MR.AlignWith( A );
        z.AlignWith( y );
        z_MR_MC.AlignWith( y );
        Zeros( z_STAR_MC, 1, y.Width() );
        Zeros( z_STAR_MR, 1, y.Width() );
        //--------------------------------------------------------------------//
        x_STAR_MR = x;
        x_STAR_MC = x_STAR_MR;
        if( uplo == LOWER )
        {
            internal::LocalSymvRowAccumulateL
            ( alpha, A, x_STAR_MC, x_STAR_MR, z_STAR_MC, z_STAR_MR, conjugate );
        }
        else
        {
            internal::LocalSymvRowAccumulateU
            ( alpha, A, x_STAR_MC, x_STAR_MR, z_STAR_MC, z_STAR_MR, conjugate );
        }

        // Combine both partial results into z (aligned with y), then add z to y.
        z_MR_MC.SumScatterFrom( z_STAR_MC );
        z = z_MR_MC;
        z.SumScatterUpdate( T(1), z_STAR_MR );
        Axpy( T(1), z, y );
        //--------------------------------------------------------------------//
        x_STAR_MC.FreeAlignments();
        x_STAR_MR.FreeAlignments();
        z_STAR_MC.FreeAlignments();
        z_STAR_MR.FreeAlignments();
        z.FreeAlignments();
        z_MR_MC.FreeAlignments();
    }
}
// Distributed triangular solve with a lower-triangular L applied in
// transposed or conjugate-transposed form, as chosen by `orientation`.
//
// Overwrites x with the solution of  op(L) * x_new = x_old.  op(L) is upper
// triangular, so x is processed in blocks from the bottom up: PartitionUp /
// PartitionLeft start with an empty trailing part. For each block x1:
//   * the pending contributions of the already-solved entries after x1
//     (accumulated in z) are added to x1,
//   * the diagonal block is solved with a sequential
//     Trsv( LOWER, orientation, diag, ... ) on [*,*] copies of L11 and x1,
//     and the result is copied back into x1,
//   * -op(L10) * x1 is accumulated, using local data only, into the part
//     of z that lies before x1.
// Each step views only L10 (left of the diagonal block) and L11 (the diagonal
// block), so the rest of L is never accessed. The block size is whatever
// RepartitionUp / RepartitionLeft use by default; that is not visible here.
//
// x may be stored as a column vector (width 1) or a row vector (height 1).
// The two cases use different intermediate distributions.
//
// Parameters
//   orientation - must not be NORMAL; this is checked only when RELEASE is
//                 undefined.
//   diag        - passed through to the diagonal-block Trsv.
//   L           - square matrix on the same grid as x.
//   x           - right-hand side on input, solution on output; its length
//                 must equal L.Width().
//
// Throws (only when RELEASE is not defined)
//   std::logic_error on grid mismatch, NORMAL orientation, non-square L,
//   non-vector x, or length mismatch.
//
// NOTE(review): PushCallStack is paired with the PopCallStack at the very end.
// If any check or call in between throws, the entry is left on the call stack.
inline void TrsvLT
( Orientation orientation, UnitOrNonUnit diag,
  const DistMatrix<F>& L, DistMatrix<F>& x )
{
#ifndef RELEASE
    PushCallStack("internal::TrsvLT");
    if( L.Grid() != x.Grid() )
        throw std::logic_error("{L,x} must be distributed over the same grid");
    if( orientation == NORMAL )
        throw std::logic_error("TrsvLT expects a (conjugate-)transpose option");
    if( L.Height() != L.Width() )
        throw std::logic_error("L must be square");
    if( x.Width() != 1 && x.Height() != 1 )
        throw std::logic_error("x must be a vector");
    const int xLength = ( x.Width() == 1 ? x.Height() : x.Width() );
    if( L.Width() != xLength )
        throw std::logic_error("Nonconformal TrsvLT");
#endif
    const Grid& g = L.Grid();

    if( x.Width() == 1 )
    {
        // x is a column vector.

        // Matrix views
        DistMatrix<F> L10(g), L11(g);
        DistMatrix<F>
            xT(g),  x0(g),
            xB(g),  x1(g),
                    x2(g);

        // Temporary distributions
        DistMatrix<F,STAR,STAR> L11_STAR_STAR(g);
        DistMatrix<F,STAR,STAR> x1_STAR_STAR(g);
        DistMatrix<F,MC,  STAR> x1_MC_STAR(g);
        DistMatrix<F,MC,  MR  > z1(g);
        DistMatrix<F,MR,  MC  > z1_MR_MC(g);
        DistMatrix<F,MR,  STAR> z_MR_STAR(g);

        // Views of z[MR,* ]
        DistMatrix<F,MR,STAR> z0_MR_STAR(g),
                              z1_MR_STAR(g);

        // z[MR,*] accumulates the pending updates to x and starts at zero.
        z_MR_STAR.AlignWith( L );
        Zeros( x.Height(), 1, z_MR_STAR );

        // Start the algorithm
        PartitionUp
        ( x, xT,
             xB, 0 );
        while( xT.Height() > 0 )
        {
            RepartitionUp
            ( xT,  x0,
                   x1,
             /**/ /**/
              xB,  x2 );

            const int n0 = x0.Height();
            const int n1 = x1.Height();
            // L10 is the n1 x n0 block left of the diagonal block L11.
            // z0 and z1 are the parts of z matching x0 and x1.
            LockedView( L10, L, n0, 0, n1, n0 );
            LockedView( L11, L, n0, n0, n1, n1 );
            View( z0_MR_STAR, z_MR_STAR, 0, 0, n0, 1 );
            View( z1_MR_STAR, z_MR_STAR, n0, 0, n1, 1 );

            x1_MC_STAR.AlignWith( L10 );
            z1.AlignWith( x1 );
            //----------------------------------------------------------------//
            // On the first pass x2 is empty and nothing has been solved yet,
            // so z1 is still zero and the reduction is skipped. Otherwise,
            // sum the [MR,*] partial updates, move them to x1's distribution
            // and add them to x1.
            if( x2.Height() != 0 )
            {
                z1_MR_MC.SumScatterFrom( z1_MR_STAR );
                z1 = z1_MR_MC;
                Axpy( F(1), z1, x1 );
            }

            // Solve with the diagonal block on [*,*] copies, then write the
            // result back into x1.
            x1_STAR_STAR = x1;
            L11_STAR_STAR = L11;
            Trsv
            ( LOWER, orientation, diag,
              L11_STAR_STAR.LockedMatrix(), x1_STAR_STAR.Matrix() );
            x1 = x1_STAR_STAR;

            // z0 := z0 - op(L10) * x1, using only local data.
            x1_MC_STAR = x1_STAR_STAR;
            Gemv
            ( orientation,
              F(-1), L10.LockedMatrix(),
                     x1_MC_STAR.LockedMatrix(),
              F(1),  z0_MR_STAR.Matrix() );
            //----------------------------------------------------------------//
            x1_MC_STAR.FreeAlignments();
            z1.FreeAlignments();

            SlidePartitionUp
            ( xT,  x0,
             /**/ /**/
                   x1,
              xB,  x2 );
        }
    }
    else
    {
        // x is a row vector.

        // Matrix views
        DistMatrix<F> L10(g), L11(g);
        DistMatrix<F> xL(g), xR(g),
                      x0(g), x1(g), x2(g);

        // Temporary distributions
        DistMatrix<F,STAR,STAR> L11_STAR_STAR(g);
        DistMatrix<F,STAR,STAR> x1_STAR_STAR(g);
        DistMatrix<F,STAR,MC  > x1_STAR_MC(g);
        DistMatrix<F,STAR,MR  > z_STAR_MR(g);

        // Views of z[* ,MR], which will store updates to x
        DistMatrix<F,STAR,MR> z0_STAR_MR(g),
                              z1_STAR_MR(g);

        z_STAR_MR.AlignWith( L );
        Zeros( 1, x.Width(), z_STAR_MR );

        // Start the algorithm
        PartitionLeft( x,  xL, xR, 0 );
        while( xL.Width() > 0 )
        {
            RepartitionLeft
            ( xL,     /**/ xR,
              x0, x1, /**/ x2 );

            const int n0 = x0.Width();
            const int n1 = x1.Width();
            LockedView( L10, L, n0, 0, n1, n0 );
            LockedView( L11, L, n0, n0, n1, n1 );
            View( z0_STAR_MR, z_STAR_MR, 0, 0, 1, n0 );
            View( z1_STAR_MR, z_STAR_MR, 0, n0, 1, n1 );

            x1_STAR_MC.AlignWith( L10 );
            //----------------------------------------------------------------//
            // Skip the still-zero update on the first pass. Otherwise, sum the
            // [*,MR] partial updates directly into x1; unlike the column case,
            // no intermediate z1 is needed.
            if( x2.Width() != 0 )
                x1.SumScatterUpdate( F(1), z1_STAR_MR );

            // Solve with the diagonal block on [*,*] copies, then write the
            // result back into x1.
            x1_STAR_STAR = x1;
            L11_STAR_STAR = L11;
            Trsv
            ( LOWER, orientation, diag,
              L11_STAR_STAR.LockedMatrix(), x1_STAR_STAR.Matrix() );
            x1 = x1_STAR_STAR;

            // z0 := z0 - op(L10) * x1, using only local data.
            x1_STAR_MC = x1_STAR_STAR;
            Gemv
            ( orientation,
              F(-1), L10.LockedMatrix(),
                     x1_STAR_MC.LockedMatrix(),
              F(1),  z0_STAR_MR.Matrix() );
            //----------------------------------------------------------------//
            x1_STAR_MC.FreeAlignments();

            SlidePartitionLeft
            ( xL,     /**/ xR,
              x0, /**/ x1, x2 );
        }
    }
#ifndef RELEASE
    PopCallStack();
#endif
}