コード例 #1
0
ファイル: HPSDCholesky.hpp プロジェクト: ahmadia/elemental
// Makes the distributed matrix A explicitly Hermitian by walking its diagonal
// blocks and mirroring (via the adjoint) the triangle selected by `uplo` onto
// the opposite triangle.
//
//   uplo - LOWER: A21 is adjointed into A12 and the adjoint of A11, masked
//          with MakeTrapezoidal( LEFT, UPPER, 1, ... ), is added to A11.
//          Otherwise: A12 is adjointed into A21 and the adjoint of A11,
//          masked with MakeTrapezoidal( LEFT, LOWER, -1, ... ), is added.
//   A    - [MC,MR]-distributed matrix, updated in place.
//
// NOTE(review): the A11 update *adds* the masked adjoint instead of
// overwriting it, so the opposite triangle of A11 is presumably expected to be
// zero on entry -- confirm against the callers.
void MakeExplicitlyHermitian( UpperOrLower uplo, DistMatrix<F,MC,MR>& A )
{
    const Grid& g = A.Grid();
    // Views used by the diagonal partitioning of A
    DistMatrix<F,MC,MR> ATL(g), ATR(g),  A00(g), A01(g), A02(g),
                        ABL(g), ABR(g),  A10(g), A11(g), A12(g),
                                         A20(g), A21(g), A22(g);
    // Temporaries: the adjoint of A11 and [MR,MC] redistributions of the blocks
    DistMatrix<F,MC,MR> A11Adj(g);
    DistMatrix<F,MR,MC> A11_MR_MC(g);
    DistMatrix<F,MR,MC> A21_MR_MC(g);
    DistMatrix<F,MR,MC> A12_MR_MC(g);

    PartitionDownDiagonal
    ( A, ATL, ATR,
         ABL, ABR, 0 );
    while( ATL.Height() < A.Height() )
    {
        RepartitionDownDiagonal
        ( ATL, /**/ ATR,  A00, /**/ A01, A02,
         /*************/ /******************/
               /**/       A10, /**/ A11, A12,
          ABL, /**/ ABR,  A20, /**/ A21, A22 );

        // Each [MR,MC] temporary is aligned with the block that will receive
        // its adjoint (note the deliberate A12 <-> A21 crossover), since the
        // adjoints below are formed purely on the local matrices.
        A11Adj.AlignWith( A11 );
        A11_MR_MC.AlignWith( A11 );
        A12_MR_MC.AlignWith( A21 );
        A21_MR_MC.AlignWith( A12 );
        //--------------------------------------------------------------------//
        // Form A11Adj := A11^H: redistribute A11 to [MR,MC], then adjoint the
        // local data directly into the local data of the [MC,MR] matrix.
        A11_MR_MC = A11;
        A11Adj.ResizeTo( A11.Height(), A11.Width() );
        Adjoint( A11_MR_MC.LocalMatrix(), A11Adj.LocalMatrix() );

        if( uplo == LOWER )
        {
            // Mask A11^H with offset +1 (presumably keeping only the strictly
            // upper part -- confirm MakeTrapezoidal semantics) and add to A11.
            MakeTrapezoidal( LEFT, UPPER, 1, A11Adj );
            Axpy( (F)1, A11Adj, A11 );

            // A12 := A21^H (local adjoint of the redistributed copy)
            A21_MR_MC = A21;
            Adjoint( A21_MR_MC.LocalMatrix(), A12.LocalMatrix() ); 
        }
        else
        {
            // Mask A11^H with offset -1 (presumably keeping only the strictly
            // lower part -- confirm MakeTrapezoidal semantics) and add to A11.
            MakeTrapezoidal( LEFT, LOWER, -1, A11Adj );
            Axpy( (F)1, A11Adj, A11 );

            // A21 := A12^H (local adjoint of the redistributed copy)
            A12_MR_MC = A12;
            Adjoint( A12_MR_MC.LocalMatrix(), A21.LocalMatrix() );
        }
        //--------------------------------------------------------------------//
        // Release the alignments so they can be re-established for the next
        // (differently positioned) set of blocks.
        A21_MR_MC.FreeAlignments();
        A12_MR_MC.FreeAlignments();
        A11_MR_MC.FreeAlignments();
        A11Adj.FreeAlignments();

        SlidePartitionDownDiagonal
        ( ATL, /**/ ATR,  A00, A01, /**/ A02,
               /**/       A10, A11, /**/ A12,
         /*************/ /******************/
          ABL, /**/ ABR,  A20, A21, /**/ A22 );
    }
}
コード例 #2
0
void Transform2x2
( const Matrix<T>& G,
        AbstractDistMatrix<T>& a1,
        AbstractDistMatrix<T>& a2 )
{
    DEBUG_CSE
    typedef unique_ptr<AbstractDistMatrix<T>> ADMPtr;

    // TODO: Optimize by attempting SendRecv when possible

    ADMPtr a1_like_a2( a2.Construct( a2.Grid(), a2.Root() ) );
    a1_like_a2->AlignWith( DistData(a2) );
    Copy( a1, *a1_like_a2 );

    ADMPtr a2_like_a1( a1.Construct( a1.Grid(), a1.Root() ) );
    a2_like_a1->AlignWith( DistData(a1) );
    Copy( a2, *a2_like_a1 );

    // TODO: Generalized axpy?
    Scale( G(0,0), a1 );
    Axpy( G(0,1), *a2_like_a1, a1 );

    // TODO: Generalized axpy?
    Scale( G(1,1), a2 );
    Axpy( G(1,0), *a1_like_a2, a2 );
}
コード例 #3
0
ファイル: IpVector.cpp プロジェクト: Gjacquenot/simbody
 /* Prototype implementation for specialized functions:
  * this := a*v1 + b*v2 + c*this, special-casing coefficients equal to 0 or 1
  * so that unnecessary Copy/Scal/Axpy calls are skipped. */
 void Vector::AddTwoVectorsImpl(Number a, const Vector& v1,
                                Number b, const Vector& v2, Number c)
 {
   /* c != 0: the current contents contribute, so update in place. */
   if (c!=0.) {
     if (c!=1.) {
       Scal(c);
     }
     if (a!=0.) {
       Axpy(a, v1);
     }
     if (b!=0.) {
       Axpy(b, v2);
     }
     return;
   }

   /* From here on c == 0: the current contents are overwritten. */
   if (a==1.) {
     Copy(v1);
     if (b!=0.) {
       Axpy(b, v2);
     }
     return;
   }

   if (a==0.) {
     if (b==0.) {
       Set(0.);
       return;
     }
     Copy(v2);
     if (b!=1.) {
       Scal(b);
     }
     return;
   }

   /* General a (neither 0 nor 1). */
   if (b==1.) {
     Copy(v2);
     Axpy(a, v1);
     return;
   }

   Copy(v1);
   Scal(a);
   if (b!=0.) {
     Axpy(b, v2);
   }
 }
コード例 #4
0
ファイル: Sign.hpp プロジェクト: khalid-hasanov/Elemental
// Performs one (optionally scaled) Newton step of the matrix sign iteration,
//
//     XNew := ( mu X + inv(X)/mu ) / 2,
//
// where the scaling factor mu is selected by `scaling`:
//   DETERMINANT - mu = 1/Exp(det.kappa), with det the SafeProduct computed
//                 from the LU factorization of X,
//   FROB_NORM   - mu = Sqrt( ||inv(X)||_F / ||X||_F ),
//   NONE        - mu = 1.
// Any other value results in a LogicError.
//
// X is left untouched; XNew is overwritten with the new iterate.
inline void
NewtonStep
( const DistMatrix<F>& X, DistMatrix<F>& XNew, Scaling scaling=FROB_NORM )
{
#ifndef RELEASE
    CallStackEntry entry("sign::NewtonStep");
#endif
    typedef BASE(F) Real;

    // Calculate mu while forming B := inv(X)
    // (initialized so that it is never read uninitialized on any path)
    Real mu = 1;
    DistMatrix<Int,VC,STAR> p( X.Grid() );
    XNew = X;
    LU( XNew, p );
    if( scaling == DETERMINANT )
    {
        // The determinant must be extracted from the LU factors *before* they
        // are overwritten with the inverse.
        SafeProduct<F> det = determinant::AfterLUPartialPiv( XNew, p );
        mu = Real(1)/Exp(det.kappa);
    }
    inverse::AfterLUPartialPiv( XNew, p );
    if( scaling == FROB_NORM )
        mu = Sqrt( FrobeniusNorm(XNew)/FrobeniusNorm(X) );
    else if( scaling == NONE )
        mu = 1;
    else if( scaling != DETERMINANT )
        // DETERMINANT was already handled above; previously it fell through to
        // this error, which made determinant scaling unusable.
        LogicError("Scaling case not handled");

    // Overwrite XNew with the new iterate
    const Real halfMu = mu/Real(2);
    const Real halfMuInv = Real(1)/(2*mu); 
    Scale( halfMuInv, XNew );
    Axpy( halfMu, X, XNew );
}
コード例 #5
0
ファイル: Sign.cpp プロジェクト: elemental/Elemental
void
NewtonStep
( const DistMatrix<Field>& X,
        DistMatrix<Field>& XNew,
  SignScaling scaling=SIGN_SCALE_FROB )
{
    EL_DEBUG_CSE
    typedef Base<Field> Real;

    // Calculate mu while forming B := inv(X)
    Real mu=1;
    DistPermutation P( X.Grid() );
    XNew = X;
    LU( XNew, P );
    if( scaling == SIGN_SCALE_DET )
    {
        SafeProduct<Field> det = det::AfterLUPartialPiv( XNew, P );
        mu = Real(1)/Exp(det.kappa);
    }
    inverse::AfterLUPartialPiv( XNew, P );
    if( scaling == SIGN_SCALE_FROB )
        mu = Sqrt( FrobeniusNorm(XNew)/FrobeniusNorm(X) );

    // Overwrite XNew with the new iterate
    const Real halfMu = mu/Real(2);
    const Real halfMuInv = Real(1)/(2*mu);
    XNew *= halfMuInv;
    Axpy( halfMu, X, XNew );
}
コード例 #6
0
ファイル: BlockCyclic.cpp プロジェクト: bluehope/Elemental
// In-place subtraction: *this := *this - A, implemented as an Axpy with a
// coefficient of -1. Returns a const reference to the updated matrix.
const BlockCyclicMatrix<T>&
BlockCyclicMatrix<T>::operator-=( const BlockCyclicMatrix<T>& A )
{
    DEBUG_ONLY(CSE cse("BCM::operator-="))
    BlockCyclicMatrix<T>& self = *this;
    Axpy( T(-1), A, self );
    return self;
}
コード例 #7
0
ファイル: MakeSymmetric.hpp プロジェクト: jimgoo/Elemental
// Overwrites the square matrix A with a symmetric matrix built from the
// triangle selected by `uplo`: the off-diagonal part of that triangle is kept,
// its transpose is added, and the original diagonal is restored.
// Throws std::logic_error if A is not square.
inline void
MakeSymmetric( UpperOrLower uplo, DistMatrix<T>& A )
{
#ifndef RELEASE
    PushCallStack("MakeSymmetric");
#endif
    if( A.Height() != A.Width() )
        throw std::logic_error("Cannot make non-square matrix symmetric");

    const Grid& grid = A.Grid();

    // Stash the diagonal; it is written back once the sum has been formed
    DistMatrix<T,MD,STAR> savedDiag(grid);
    A.GetDiagonal( savedDiag );

    // Trim A down to the requested triangle, shifted one off the diagonal
    const bool fromLower = ( uplo == LOWER );
    MakeTrapezoidal
    ( LEFT, ( fromLower ? LOWER : UPPER ), ( fromLower ? -1 : +1 ), A );

    // A := A + A^T
    DistMatrix<T> transposed(grid);
    Transpose( A, transposed );
    Axpy( T(1), transposed, A );

    A.SetDiagonal( savedDiag );
#ifndef RELEASE
    PopCallStack();
#endif
}
コード例 #8
0
ファイル: Local.hpp プロジェクト: certik/Elemental
// Unblocked, reflector-based sweep over the upper triangle of the square
// matrix A (per its name, the upper-storage Hermitian tridiagonalization for
// the real type R). A is updated in place, moving up the diagonal one column
// at a time.
//
// Each step generates a reflector from the column a01 above the current
// diagonal entry and applies the corresponding two-sided rank-2 update to the
// leading block A00. The reflector scalar tau is used locally only; this
// routine does not return it.
//
// In non-RELEASE builds, a std::logic_error is thrown if A is not square.
inline void HermitianTridiagU( Matrix<R>& A )
{
#ifndef RELEASE
    PushCallStack("HermitianTridiagU");
    if( A.Height() != A.Width() )
        throw std::logic_error( "A must be square." );
#endif
    // Matrix views
    Matrix<R>
        ATL, ATR,  A00, a01,     A02,  a01T,
        ABL, ABR,  a10, alpha11, a12,  alpha01B,
                   A20, a21,     A22;

    // Temporary matrices
    Matrix<R> w01;

    // Force a blocksize of one so each repartition exposes a single column
    PushBlocksizeStack( 1 );
    PartitionUpDiagonal
    ( A, ATL, ATR,
         ABL, ABR, 0 );
    // Stop one step early: the last remaining column has nothing above it to
    // annihilate
    while( ABR.Height()+1 < A.Height() )
    {
        RepartitionUpDiagonal
        ( ATL, /**/ ATR,  A00, a01,     /**/ A02,
               /**/       a10, alpha11, /**/ a12,
         /*************/ /**********************/
          ABL, /**/ ABR,  A20, a21,     /**/ A22 );

        // Split a01 into its top part and its bottom entry
        PartitionUp
        ( a01, a01T,
               alpha01B, 1 );

        w01.ResizeTo( a01.Height(), 1 );
        //--------------------------------------------------------------------//
        const R tau = Reflector( alpha01B, a01T );
        // Save the bottom entry of a01 and temporarily replace it with one so
        // that a01 can be used as the full reflector vector below
        const R epsilon1 = alpha01B.Get(0,0);
        alpha01B.Set(0,0,R(1));

        // w01 := tau A00 a01, then w01 += alpha a01 with
        // alpha = -tau (w01 . a01) / 2, followed by the symmetric rank-2
        // update A00 := A00 - (a01 w01^T + w01 a01^T)
        Symv( UPPER, tau, A00, a01, R(0), w01 );
        const R alpha = -tau*Dot( w01, a01 )/R(2);
        Axpy( alpha, a01, w01 );
        Syr2( UPPER, R(-1), a01, w01, A00 );
        // Restore the saved entry
        alpha01B.Set(0,0,epsilon1);
        //--------------------------------------------------------------------//

        SlidePartitionUpDiagonal
        ( ATL, /**/ ATR,  A00, /**/ a01,     A02,
         /*************/ /**********************/
               /**/       a10, /**/ alpha11, a12,
          ABL, /**/ ABR,  A20, /**/ a21,     A22 );
    }
    PopBlocksizeStack();
#ifndef RELEASE
    PopCallStack();
#endif
}
コード例 #9
0
ファイル: L.hpp プロジェクト: khalid-hasanov/Elemental
// Unblocked, reflector-based sweep over the lower triangle of the square
// matrix A (per the call-stack label, hermitian_tridiag::L). A is updated in
// place, moving down the diagonal one column at a time.
//
//   A - square matrix, overwritten. In non-RELEASE builds a LogicError is
//       raised if it is not square.
//   t - resized to max(height(A)-1,0) x 1; entry k receives the reflector
//       scalar tau generated at step k.
void L( Matrix<F>& A, Matrix<F>& t )
{
#ifndef RELEASE
    CallStackEntry entry("hermitian_tridiag::L");
    if( A.Height() != A.Width() )
        LogicError("A must be square");
#endif
    typedef BASE(F) R;
    const Int tHeight = Max(A.Height()-1,0);
    t.ResizeTo( tHeight, 1 );

    // Matrix views
    Matrix<F>
        ATL, ATR,  A00, a01,     A02,  alpha21T,
        ABL, ABR,  a10, alpha11, a12,  a21B,
                   A20, a21,     A22;

    // Temporary matrices
    Matrix<F> w21;

    PartitionDownDiagonal
    ( A, ATL, ATR,
         ABL, ABR, 0 );
    // Stop one step early: the last column has nothing below the diagonal
    while( ATL.Height()+1 < A.Height() )
    {
        // Expose a single column (explicit blocksize of 1)
        RepartitionDownDiagonal
        ( ATL, /**/ ATR,  A00, /**/ a01,     A02,
         /*************/ /**********************/
               /**/       a10, /**/ alpha11, a12,
          ABL, /**/ ABR,  A20, /**/ a21,     A22, 1 );

        // Split a21 into its top entry and the remainder
        PartitionDown
        ( a21, alpha21T,
               a21B,     1 );

        //--------------------------------------------------------------------//
        const F tau = Reflector( alpha21T, a21B );
        // Save the real part of the top entry of a21, record tau, and
        // temporarily set that entry to one so that a21 can be used as the
        // full reflector vector below
        const R epsilon1 = alpha21T.GetRealPart(0,0);
        t.Set(A00.Height(),0,tau);
        alpha21T.Set(0,0,F(1));

        // w21 := tau A22 a21, then w21 += alpha a21 with
        // alpha = -tau Dot(w21,a21) / 2, followed by the Hermitian rank-2
        // update of A22 with coefficient -1
        Zeros( w21, a21.Height(), 1 );
        Hemv( LOWER, tau, A22, a21, F(0), w21 );
        const F alpha = -tau*Dot( w21, a21 )/F(2);
        Axpy( alpha, a21, w21 );
        Her2( LOWER, F(-1), a21, w21, A22 );
        // Restore the saved (real) entry
        alpha21T.Set(0,0,epsilon1);
        //--------------------------------------------------------------------//

        SlidePartitionDownDiagonal
        ( ATL, /**/ ATR,  A00, a01,     /**/ A02,
               /**/       a10, alpha11, /**/ a12,
         /*************/ /**********************/
          ABL, /**/ ABR,  A20, a21,     /**/ A22 );
    }
}
コード例 #10
0
ファイル: xaxpy.hpp プロジェクト: dividiti/CLBlast
 // Runs the CLBlast Axpy routine on the supplied buffers and, when it was
 // enqueued successfully, blocks until it has completed. Returns the status
 // reported by Axpy.
 static StatusCode RunRoutine(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) {
   auto queue_plain = queue();
   auto event = cl_event{};
   auto status = Axpy(args.n, args.alpha,
                      buffers.x_vec(), args.x_offset, args.x_inc,
                      buffers.y_vec(), args.y_offset, args.y_inc,
                      &queue_plain, &event);

   // Nothing to wait for (or release) if the routine failed
   if (status != StatusCode::kSuccess) {
     return status;
   }

   clWaitForEvents(1, &event);
   clReleaseEvent(event);
   return status;
 }
コード例 #11
0
inline void
NewtonStep
( const Matrix<F>& A, const Matrix<F>& X, Matrix<F>& XNew, Matrix<F>& XTmp )
{
#ifndef RELEASE
    CallStackEntry entry("square_root::NewtonStep");
#endif
    // XNew := inv(X) A
    XTmp = X;
    Matrix<Int> p;
    LU( XTmp, p );
    XNew = A;
    lu::SolveAfter( NORMAL, XTmp, p, XNew );

    // XNew := 1/2 ( X + XNew )
    typedef BASE(F) R;
    Axpy( R(1)/R(2), X, XNew );
}
コード例 #12
0
ファイル: NT.hpp プロジェクト: elemental/Elemental
// SUMMA variant for C += alpha A op(B), with op given by `orientB`, that
// sweeps over row panels of A and C while B stays in place. Operands that are
// not already [MC,MR] are accessed through proxies.
void SUMMA_NTB
( Orientation orientB,
  T alpha,
  const AbstractDistMatrix<T>& APre,
  const AbstractDistMatrix<T>& BPre,
        AbstractDistMatrix<T>& CPre )
{
    EL_DEBUG_CSE
    const Int numRows = CPre.Height();
    const Int panelSize = Blocksize();
    const Grid& grid = APre.Grid();

    // [MC,MR] access to the three operands
    DistMatrixReadProxy<T,T,MC,MR> AProx( APre );
    DistMatrixReadProxy<T,T,MC,MR> BProx( BPre );
    DistMatrixReadWriteProxy<T,T,MC,MR> CProx( CPre );
    auto& A = AProx.GetLocked();
    auto& B = BProx.GetLocked();
    auto& C = CProx.Get();

    // Workspace, aligned with B where the local product requires it
    DistMatrix<T,MR,STAR> A1Trans_MR_STAR(grid);
    DistMatrix<T,STAR,MC> D1_STAR_MC(grid);
    DistMatrix<T,MR,MC> D1_MR_MC(grid);

    A1Trans_MR_STAR.AlignWith( B );
    D1_STAR_MC.AlignWith( B );

    for( Int rowBeg=0; rowBeg<numRows; rowBeg+=panelSize )
    {
        // Current row panel of A and C (the last one may be short)
        const Int panelHeight = Min(panelSize,numRows-rowBeg);
        const auto panelRows = IR(rowBeg,rowBeg+panelHeight);
        auto A1 = A( panelRows, ALL );
        auto C1 = C( panelRows, ALL );

        // D1[*,MC] := alpha A1[*,MR] (B[MC,MR])^T
        //           = alpha (A1^T)[MR,*] (B^T)[MR,MC]
        Transpose( A1, A1Trans_MR_STAR );
        LocalGemm( TRANSPOSE, orientB, alpha, A1Trans_MR_STAR, B, D1_STAR_MC );

        // Reduce the partial products into D1[MR,MC] and accumulate the
        // result into C1[MC,MR]
        Contract( D1_STAR_MC, D1_MR_MC );
        Axpy( T(1), D1_MR_MC, C1 );
    }
}
コード例 #13
0
ファイル: TN.hpp プロジェクト: YingzhouLi/Elemental
// SUMMA variant for C += alpha op(A) B, with op given by `orientA`, that
// sweeps over column panels of B and C while A stays in place. Operands that
// are not already [MC,MR] are accessed through proxies.
void SUMMA_TNA
( Orientation orientA,
  T alpha,
  const AbstractDistMatrix<T>& APre,
  const AbstractDistMatrix<T>& BPre,
        AbstractDistMatrix<T>& CPre )
{
    DEBUG_CSE
    const Int numCols = CPre.Width();
    const Int panelSize = Blocksize();
    const Grid& grid = APre.Grid();

    // [MC,MR] access to the three operands
    DistMatrixReadProxy<T,T,MC,MR> AProx( APre );
    DistMatrixReadProxy<T,T,MC,MR> BProx( BPre );
    DistMatrixReadWriteProxy<T,T,MC,MR> CProx( CPre );
    auto& A = AProx.GetLocked();
    auto& B = BProx.GetLocked();
    auto& C = CProx.Get();

    // Workspace, aligned with A where the local product requires it
    DistMatrix<T,MC,STAR> B1_MC_STAR(grid);
    DistMatrix<T,MR,STAR> D1_MR_STAR(grid);
    DistMatrix<T,MR,MC  > D1_MR_MC(grid);

    B1_MC_STAR.AlignWith( A );
    D1_MR_STAR.AlignWith( A );

    for( Int colBeg=0; colBeg<numCols; colBeg+=panelSize )
    {
        // Current column panel of B and C (the last one may be short)
        const Int panelWidth = Min(panelSize,numCols-colBeg);
        const auto panelCols = IR(colBeg,colBeg+panelWidth);
        auto B1 = B( ALL, panelCols );
        auto C1 = C( ALL, panelCols );

        // D1[MR,*] := alpha (A[MC,MR])^T B1[MC,*]
        //           = alpha (A^T)[MR,MC] B1[MC,*]
        B1_MC_STAR = B1;
        LocalGemm( orientA, NORMAL, alpha, A, B1_MC_STAR, D1_MR_STAR );

        // Reduce the partial products into D1[MR,MC] and accumulate the
        // result into C1[MC,MR]
        Contract( D1_MR_STAR, D1_MR_MC );
        Axpy( T(1), D1_MR_MC, C1 );
    }
}
コード例 #14
0
ファイル: Sign.cpp プロジェクト: elemental/Elemental
// Runs the Newton iteration for the matrix sign function, overwriting A with
// the final iterate and returning the number of iterations performed.
//
// The iteration stops after ctrl.maxIts steps or as soon as
//     ||X_k - X_{k+1}||_1 / ||X_{k+1}||_1 <= ||X_{k+1}||_1^power * tol,
// where a zero ctrl.tol is replaced by height(A) * epsilon.
Int
Newton( DistMatrix<Field>& A, const SignCtrl<Base<Field>>& ctrl )
{
    EL_DEBUG_CSE
    typedef Base<Field> Real;

    // Pick a default tolerance when none was requested
    Real tol = ctrl.tol;
    if( tol == Real(0) )
        tol = A.Height()*limits::Epsilon<Real>();

    // Ping-pong between A and a scratch matrix
    DistMatrix<Field> scratch( A.Grid() );
    DistMatrix<Field> *current=&A, *next=&scratch;

    Int numIts=0;
    bool converged=false;
    while( !converged && numIts < ctrl.maxIts )
    {
        // Overwrite *next with the new iterate
        NewtonStep( *current, *next, ctrl.scaling );

        // The old iterate is no longer needed, so reuse its storage for the
        // difference between the iterates
        Axpy( Real(-1), *next, *current );
        const Real oneDiff = OneNorm( *current );
        const Real oneNew = OneNorm( *next );

        // Make `current` refer to the newest iterate again
        ++numIts;
        std::swap( current, next );

        if( ctrl.progress && A.Grid().Rank() == 0 )
            cout << "after " << numIts << " Newton iter's: "
                 << "oneDiff=" << oneDiff << ", oneNew=" << oneNew
                 << ", oneDiff/oneNew=" << oneDiff/oneNew << ", tol="
                 << tol << endl;

        converged = ( oneDiff/oneNew <= Pow(oneNew,ctrl.power)*tol );
    }

    // The final iterate may have ended up in the scratch matrix
    if( current != &A )
        A = *current;
    return numIts;
}
コード例 #15
0
ファイル: Symv.hpp プロジェクト: ahmadia/Elemental-1
// Distributed symmetric matrix-vector update: y is first scaled by beta and
// then a vector z, accumulated by the internal::LocalSymv*Accumulate* helpers
// from alpha, A and x (presumably z = alpha A x using only the `uplo` triangle
// of A -- confirm against the helpers), is added to it.
//
//   uplo      - which triangle of A is passed on to the local accumulation
//   alpha     - scalar handed to the local accumulation
//   A         - square distributed matrix
//   x         - input vector, stored as a column (n x 1) or a row (1 x n)
//   beta      - scalar applied to y before the update
//   y         - input/output vector, stored as a column or a row
//   conjugate - forwarded unchanged to the local accumulation helpers
//
// The four branches below handle every combination of column/row storage for
// x and y; they differ only in which redistributions are used to gather x and
// to reduce the partial results into the layout of y.
//
// In non-RELEASE builds, std::logic_error is thrown for mismatched grids, a
// non-square A, non-vector x or y, or nonconformal dimensions.
inline void
Symv
( UpperOrLower uplo,
  T alpha, const DistMatrix<T>& A,
           const DistMatrix<T>& x,
  T beta,        DistMatrix<T>& y,
  bool conjugate=false )
{
#ifndef RELEASE
    CallStackEntry entry("Symv");
    if( A.Grid() != x.Grid() || x.Grid() != y.Grid() )
        throw std::logic_error
        ("{A,x,y} must be distributed over the same grid");
    if( A.Height() != A.Width() )
        throw std::logic_error("A must be square");
    if( ( x.Width() != 1 && x.Height() != 1 ) ||
        ( y.Width() != 1 && y.Height() != 1 ) )
        throw std::logic_error("x and y are assumed to be vectors");
    const int xLength = ( x.Width()==1 ? x.Height() : x.Width() );
    const int yLength = ( y.Width()==1 ? y.Height() : y.Width() );
    if( A.Height() != xLength || A.Height() != yLength )
    {
        std::ostringstream msg;
        msg << "Nonconformal Symv: \n"
            << "  A ~ " << A.Height() << " x " << A.Width() << "\n"
            << "  x ~ " << x.Height() << " x " << x.Width() << "\n"
            << "  y ~ " << y.Height() << " x " << y.Width() << "\n";
        throw std::logic_error( msg.str() );
    }
#endif
    const Grid& g = A.Grid();

    // Case 1: x and y are both column vectors
    if( x.Width() == 1 && y.Width() == 1 )
    {
        // Temporary distributions
        DistMatrix<T,MC,STAR> x_MC_STAR(g), z_MC_STAR(g);
        DistMatrix<T,MR,STAR> x_MR_STAR(g), z_MR_STAR(g);
        DistMatrix<T,MR,MC  > z_MR_MC(g);
        DistMatrix<T> z(g);

        // Begin the algorithm
        Scale( beta, y );
        x_MC_STAR.AlignWith( A );
        x_MR_STAR.AlignWith( A );
        z_MC_STAR.AlignWith( A );
        z_MR_STAR.AlignWith( A );
        z.AlignWith( y );
        Zeros( z_MC_STAR, y.Height(), 1 );
        Zeros( z_MR_STAR, y.Height(), 1 );
        //--------------------------------------------------------------------//
        // Gather x in both the [MC,*] and [MR,*] layouts
        x_MC_STAR = x;
        x_MR_STAR = x_MC_STAR;
        if( uplo == LOWER )
        {
            internal::LocalSymvColAccumulateL
            ( alpha, A, x_MC_STAR, x_MR_STAR, z_MC_STAR, z_MR_STAR, conjugate );
        }
        else
        {
            internal::LocalSymvColAccumulateU
            ( alpha, A, x_MC_STAR, x_MR_STAR, z_MC_STAR, z_MR_STAR, conjugate );
        }

        // Reduce both partial results into z and add z to y
        z_MR_MC.SumScatterFrom( z_MR_STAR );
        z = z_MR_MC;
        z.SumScatterUpdate( T(1), z_MC_STAR );
        Axpy( T(1), z, y );
        //--------------------------------------------------------------------//
        x_MC_STAR.FreeAlignments();
        x_MR_STAR.FreeAlignments();
        z_MC_STAR.FreeAlignments();
        z_MR_STAR.FreeAlignments();
        z.FreeAlignments();
    }
    // Case 2: x is a column vector, y is a row vector
    else if( x.Width() == 1 )
    {
        // Temporary distributions
        DistMatrix<T,MC,STAR> x_MC_STAR(g), z_MC_STAR(g);
        DistMatrix<T,MR,STAR> x_MR_STAR(g), z_MR_STAR(g);
        DistMatrix<T,MR,MC  > z_MR_MC(g);
        DistMatrix<T> z(g), zTrans(g);

        // Begin the algorithm
        Scale( beta, y );
        x_MC_STAR.AlignWith( A );
        x_MR_STAR.AlignWith( A );
        z_MC_STAR.AlignWith( A );
        z_MR_STAR.AlignWith( A );
        z.AlignWith( y );
        z_MR_MC.AlignWith( y );
        Zeros( z_MC_STAR, y.Width(), 1 );
        Zeros( z_MR_STAR, y.Width(), 1 );
        //--------------------------------------------------------------------//
        // Gather x in both the [MC,*] and [MR,*] layouts
        x_MC_STAR = x;
        x_MR_STAR = x_MC_STAR;
        if( uplo == LOWER )
        {
            internal::LocalSymvColAccumulateL
            ( alpha, A, x_MC_STAR, x_MR_STAR, z_MC_STAR, z_MR_STAR, conjugate );
        }
        else
        {
            internal::LocalSymvColAccumulateU
            ( alpha, A, x_MC_STAR, x_MR_STAR, z_MC_STAR, z_MR_STAR, conjugate );
        }

        // Reduce both partial results into the column z_MR_MC, then transpose
        // it to match the row storage of y before adding
        z.SumScatterFrom( z_MC_STAR );
        z_MR_MC = z;
        z_MR_MC.SumScatterUpdate( T(1), z_MR_STAR );
        Transpose( z_MR_MC, zTrans );
        Axpy( T(1), zTrans, y );
        //--------------------------------------------------------------------//
        x_MC_STAR.FreeAlignments();
        x_MR_STAR.FreeAlignments();
        z_MC_STAR.FreeAlignments();
        z_MR_STAR.FreeAlignments();
        z.FreeAlignments();
        z_MR_MC.FreeAlignments();
    }
    // Case 3: x is a row vector, y is a column vector
    else if( y.Width() == 1 )
    {
        // Temporary distributions
        DistMatrix<T,STAR,MC> x_STAR_MC(g), z_STAR_MC(g);
        DistMatrix<T,STAR,MR> x_STAR_MR(g), z_STAR_MR(g);
        DistMatrix<T,MR,  MC> z_MR_MC(g);
        DistMatrix<T> z(g), zTrans(g);

        // Begin the algorithm
        Scale( beta, y );
        x_STAR_MC.AlignWith( A );
        x_STAR_MR.AlignWith( A );
        z_STAR_MC.AlignWith( A );
        z_STAR_MR.AlignWith( A );
        z.AlignWith( y );
        z_MR_MC.AlignWith( y );
        Zeros( z_STAR_MC, 1, y.Height() );
        Zeros( z_STAR_MR, 1, y.Height() );
        //--------------------------------------------------------------------//
        // Gather x in both the [*,MR] and [*,MC] layouts
        x_STAR_MR = x;
        x_STAR_MC = x_STAR_MR;
        if( uplo == LOWER )
        {
            internal::LocalSymvRowAccumulateL
            ( alpha, A, x_STAR_MC, x_STAR_MR, z_STAR_MC, z_STAR_MR, conjugate );
        }
        else
        {
            internal::LocalSymvRowAccumulateU
            ( alpha, A, x_STAR_MC, x_STAR_MR, z_STAR_MC, z_STAR_MR, conjugate );
        }

        // Reduce both partial results into the row z_MR_MC, then transpose it
        // to match the column storage of y before adding
        z.SumScatterFrom( z_STAR_MR );
        z_MR_MC = z;
        z_MR_MC.SumScatterUpdate( T(1), z_STAR_MC );
        Transpose( z_MR_MC, zTrans );
        Axpy( T(1), zTrans, y );
        //--------------------------------------------------------------------//
        x_STAR_MC.FreeAlignments();
        x_STAR_MR.FreeAlignments();
        z_STAR_MC.FreeAlignments();
        z_STAR_MR.FreeAlignments();
        z.FreeAlignments();
        z_MR_MC.FreeAlignments();
    }
    // Case 4: x and y are both row vectors
    else
    {
        // Temporary distributions
        DistMatrix<T,STAR,MC> x_STAR_MC(g), z_STAR_MC(g);
        DistMatrix<T,STAR,MR> x_STAR_MR(g), z_STAR_MR(g);
        DistMatrix<T,MR,  MC> z_MR_MC(g);
        DistMatrix<T> z(g);

        // Begin the algorithm
        Scale( beta, y );
        x_STAR_MC.AlignWith( A );
        x_STAR_MR.AlignWith( A );
        z_STAR_MC.AlignWith( A );
        z_STAR_MR.AlignWith( A );
        z.AlignWith( y );
        z_MR_MC.AlignWith( y );
        Zeros( z_STAR_MC, 1, y.Width() );
        Zeros( z_STAR_MR, 1, y.Width() );
        //--------------------------------------------------------------------//
        // Gather x in both the [*,MR] and [*,MC] layouts
        x_STAR_MR = x;
        x_STAR_MC = x_STAR_MR;
        if( uplo == LOWER )
        {
            internal::LocalSymvRowAccumulateL
            ( alpha, A, x_STAR_MC, x_STAR_MR, z_STAR_MC, z_STAR_MR, conjugate );
        }
        else
        {
            internal::LocalSymvRowAccumulateU
            ( alpha, A, x_STAR_MC, x_STAR_MR, z_STAR_MC, z_STAR_MR, conjugate );
        }

        // Reduce both partial results into z and add z to y
        z_MR_MC.SumScatterFrom( z_STAR_MC );
        z = z_MR_MC;
        z.SumScatterUpdate( T(1), z_STAR_MR );
        Axpy( T(1), z, y );
        //--------------------------------------------------------------------//
        x_STAR_MC.FreeAlignments();
        x_STAR_MR.FreeAlignments();
        z_STAR_MC.FreeAlignments();
        z_STAR_MR.FreeAlignments();
        z.FreeAlignments();
        z_MR_MC.FreeAlignments();
    }
}
コード例 #16
0
ファイル: Mod.hpp プロジェクト: YingzhouLi/Elemental
// Updates an existing partially-pivoted LU factorization, stored in place in
// A (unit-lower L strictly below the diagonal, U on and above it) together
// with the row permutation P, so that it becomes a factorization of the
// rank-one modified matrix  (original matrix) + u v'.
//
//   A         - m x n packed LU factors with m <= n; overwritten
//   P         - row permutation of the factorization; updated for each swap
//   u         - m x 1 column vector
//   v         - n x 1 column vector
//   conjugate - whether v' denotes the conjugate transpose of v
//   tau       - pivoting threshold: rows i and i+1 are swapped whenever the
//               candidate pivot is smaller than tau times the alternative
//
// The algorithm has three phases:
//   1. w := inv(L) P u is reduced from the bottom up to a multiple of e0 by
//      Gauss transforms, which turns U into an upper-Hessenberg matrix whose
//      subdiagonal is kept in the external vector uSub;
//   2. the rank-one term w(0) v' is added to the first row of U;
//   3. U is swept from the top down back to upper-triangular form.
//
// A LogicError is raised if m > n or if u or v do not conform. An empty
// factorization (m == 0) is returned unchanged.
void LUMod
( Matrix<F>& A,
        Permutation& P, 
  const Matrix<F>& u,
  const Matrix<F>& v,
  bool conjugate,
  Base<F> tau )
{
    DEBUG_CSE
    typedef Base<F> Real;
    const Int m = A.Height();
    const Int n = A.Width();
    const Int minDim = Min(m,n);
    if( minDim != m )
        LogicError("It is assumed that height(A) <= width(A)");
    if( u.Height() != m || u.Width() != 1 )
        LogicError("u is expected to be a conforming column vector");
    if( v.Height() != n || v.Width() != 1 )
        LogicError("v is expected to be a conforming column vector");

    // Nothing to update for an empty factorization; this also avoids
    // requesting a vector of height -1 and reading w(0) below.
    if( m == 0 )
        return;

    // w := inv(L) P u
    auto w( u );
    P.PermuteRows( w );
    Trsv( LOWER, NORMAL, UNIT, A, w );

    // Maintain an external vector for the temporary subdiagonal of U
    Matrix<F> uSub;
    Zeros( uSub, minDim-1, 1 );

    // Reduce w to a multiple of e0
    for( Int i=minDim-2; i>=0; --i )
    {
        // Decide if we should pivot the i'th and i+1'th rows of w
        const F lambdaSub = A(i+1,i);
        const F ups_ii = A(i,i); 
        const F omega_i = w(i);
        const F omega_ip1 = w(i+1);
        const Real rightTerm = Abs(lambdaSub*omega_i+omega_ip1);
        const bool pivot = ( Abs(omega_i) < tau*rightTerm );

        const Range<Int> indi( i, i+1 ),
                         indip1( i+1, i+2 ),
                         indB( i+2, m ),
                         indR( i+1, n );

        auto lBi   = A( indB,   indi   );
        auto lBip1 = A( indB,   indip1 );
        auto uiR   = A( indi,   indR   );
        auto uip1R = A( indip1, indR   );

        if( pivot )
        {
            // P := P_i P
            P.Swap( i, i+1 );

            // Simultaneously perform 
            //   U := P_i U and
            //   L := P_i L P_i^T
            //
            // Then update
            //     L := L T_{i,L}^{-1},
            //     U := T_{i,L} U, 
            //     w := T_{i,L} P_i w,
            // where T_{i,L} is the Gauss transform which zeros (P_i w)_{i+1}.
            // 
            // More succinctly,
            //     gamma    := w(i) / w(i+1),
            //     w(i)     := w(i+1), 
            //     w(i+1)   := 0,
            //     L(:,i)   += gamma L(:,i+1),
            //     U(i+1,:) -= gamma U(i,:).
            const F gamma = omega_i / omega_ip1;
            const F lambda_ii = F(1) + gamma*lambdaSub;
            A(i,  i) = gamma;
            A(i+1,i) = 0;

            auto lBiCopy = lBi;
            Swap( NORMAL, lBi, lBip1 );
            Axpy( gamma, lBiCopy, lBi );

            auto uip1RCopy = uip1R;
            RowSwap( A, i, i+1 );
            Axpy( -gamma, uip1RCopy, uip1R );

            // Force L back to *unit* lower-triangular form via the transform
            //     L := L T_{i,U}^{-1} D^{-1}, 
            // where D is diagonal and responsible for forcing L(i,i) and 
            // L(i+1,i+1) back to 1. The effect on L is:
            //     eta       := L(i,i+1)/L(i,i),
            //     L(:,i+1)  -= eta L(:,i),
            //     delta_i   := L(i,i),
            //     delta_ip1 := L(i+1,i+1),
            //     L(:,i)   /= delta_i,
            //     L(:,i+1) /= delta_ip1,
            // while the effect on U is
            //     U(i,:)   += eta U(i+1,:)
            //     U(i,:)   *= delta_i,
            //     U(i+1,:) *= delta_{i+1},
            // and the effect on w is
            //     w(i) *= delta_i.
            const F eta = lambdaSub/lambda_ii;
            const F delta_i = lambda_ii;
            const F delta_ip1 = F(1) - eta*gamma;

            Axpy( -eta, lBi, lBip1 );
            A(i+1,i) = gamma/delta_i;
            lBi   *= F(1)/delta_i;
            lBip1 *= F(1)/delta_ip1;

            A(i,i) = eta*ups_ii*delta_i;
            Axpy( eta, uip1R, uiR );
            uiR   *= delta_i;
            uip1R *= delta_ip1;
            uSub(i) = ups_ii*delta_ip1;

            // Finally set w(i)
            w(i) = omega_ip1*delta_i;
        }
        else
        {
            // Update
            //     L := L T_{i,L}^{-1},
            //     U := T_{i,L} U, 
            //     w := T_{i,L} w,
            // where T_{i,L} is the Gauss transform which zeros w_{i+1}.
            // 
            // More succinctly,
            //     gamma    := w(i+1) / w(i),
            //     L(:,i)   += gamma L(:,i+1),
            //     U(i+1,:) -= gamma U(i,:),
            //     w(i+1)   := 0.
            //
            // If w(i+1) is already zero there is nothing to eliminate and the
            // transform is the identity; taking gamma = 0 explicitly avoids
            // the 0/0 = NaN that would otherwise arise when w(i) is zero too
            // (e.g., for u = 0 or a w with trailing zeros).
            const F gamma =
              ( omega_ip1 == F(0) ? F(0) : omega_ip1 / omega_i );
            A(i+1,i) += gamma;
            Axpy(  gamma, lBip1, lBi );
            Axpy( -gamma, uiR, uip1R );
            uSub(i) = -gamma*ups_ii;
        }
    }

    // Add the modified w v' into U
    {
        auto a0 = A( IR(0), ALL );
        const F omega_0 = w(0); 
        Matrix<F> vTrans;
        Transpose( v, vTrans, conjugate );
        Axpy( omega_0, vTrans, a0 );
    }

    // Transform U from upper-Hessenberg to upper-triangular form
    for( Int i=0; i<minDim-1; ++i ) 
    {
        // Decide if we should pivot the i'th and i+1'th rows U
        const F lambdaSub = A(i+1,i);
        const F ups_ii = A(i,i);
        const F ups_ip1i = uSub(i);
        const Real rightTerm = Abs(lambdaSub*ups_ii+ups_ip1i);
        const bool pivot = ( Abs(ups_ii) < tau*rightTerm );

        const Range<Int> indi( i, i+1 ),
                         indip1( i+1, i+2 ),
                         indB( i+2, m ),
                         indR( i+1, n );

        auto lBi   = A( indB,   indi   );
        auto lBip1 = A( indB,   indip1 );
        auto uiR   = A( indi,   indR   );
        auto uip1R = A( indip1, indR   );

        if( pivot )
        {
            // P := P_i P
            P.Swap( i, i+1 );

            // Simultaneously perform 
            //   U := P_i U and
            //   L := P_i L P_i^T
            //
            // Then update
            //     L := L T_{i,L}^{-1},
            //     U := T_{i,L} U, 
            // where T_{i,L} is the Gauss transform which zeros the
            // subdiagonal entry of the *swapped* U, i.e., the old U(i,i).
            // 
            // More succinctly, in terms of the entries before the swap,
            //     gamma    := U(i,i) / U(i+1,i),
            //     L(:,i)   += gamma L(:,i+1),
            //     U(i+1,:) -= gamma U(i,:).
            const F gamma = ups_ii / ups_ip1i;
            const F lambda_ii = F(1) + gamma*lambdaSub;
            A(i+1,i) = ups_ip1i;
            A(i,  i) = gamma;

            auto lBiCopy = lBi;
            Swap( NORMAL, lBi, lBip1 );
            Axpy( gamma, lBiCopy, lBi );

            auto uip1RCopy = uip1R;
            RowSwap( A, i, i+1 );
            Axpy( -gamma, uip1RCopy, uip1R );

            // Force L back to *unit* lower-triangular form via the transform
            //     L := L T_{i,U}^{-1} D^{-1}, 
            // where D is diagonal and responsible for forcing L(i,i) and 
            // L(i+1,i+1) back to 1. The effect on L is:
            //     eta       := L(i,i+1)/L(i,i),
            //     L(:,i+1)  -= eta L(:,i),
            //     delta_i   := L(i,i),
            //     delta_ip1 := L(i+1,i+1),
            //     L(:,i)   /= delta_i,
            //     L(:,i+1) /= delta_ip1,
            // while the effect on U is
            //     U(i,:)   += eta U(i+1,:)
            //     U(i,:)   *= delta_i,
            //     U(i+1,:) *= delta_{i+1}.
            const F eta = lambdaSub/lambda_ii;
            const F delta_i = lambda_ii;
            const F delta_ip1 = F(1) - eta*gamma;

            Axpy( -eta, lBi, lBip1 );
            A(i+1,i) = gamma/delta_i;
            lBi   *= F(1)/delta_i;
            lBip1 *= F(1)/delta_ip1;

            A(i,i) = ups_ip1i*delta_i;
            Axpy( eta, uip1R, uiR );
            uiR   *= delta_i;
            uip1R *= delta_ip1;
        }
        else
        {
            // Update
            //     L := L T_{i,L}^{-1},
            //     U := T_{i,L} U, 
            // where T_{i,L} is the Gauss transform which zeros U(i+1,i).
            // 
            // More succinctly,
            //     gamma    := U(i+1,i)/ U(i,i),
            //     L(:,i)   += gamma L(:,i+1),
            //     U(i+1,:) -= gamma U(i,:).
            //
            // As in the first sweep, a subdiagonal entry that is already zero
            // needs no elimination; gamma = 0 avoids 0/0 = NaN when U(i,i)
            // is zero as well.
            const F gamma =
              ( ups_ip1i == F(0) ? F(0) : ups_ip1i / ups_ii );
            A(i+1,i) += gamma;
            Axpy(  gamma, lBip1, lBi );
            Axpy( -gamma, uiR, uip1R );
        }
    }
}
コード例 #17
0
ファイル: UVar5.hpp プロジェクト: jimgoo/Elemental
// Blocked, sequential two-sided triangular multiply with an upper-triangular
// factor (variant 5).
//
// The diagonals of A and U are swept together. Each pass exposes a diagonal
// block (A11, U11), the panel above it (A01, U01) and the leading block that
// has already been visited (A00, U00), and then applies, in this order:
//     A01 := U00 A01 + 1/2 (U01 A11)
//     A00 := A00 + (U01 A01' + A01 U01')
//     A01 := (A01 + 1/2 (U01 A11)) U11'
//     A11 := U11 A11 U11'
// The Hermitian kernels are always called with UPPER.
//
// diag : forwarded to the triangular kernels (unit vs. non-unit diagonal)
// A    : square matrix, updated in place
// U    : square triangular matrix of the same dimension as A, read-only
//
// When RELEASE is not defined the shapes are validated and std::logic_error
// is thrown on a mismatch.
inline void
TwoSidedTrmmUVar5( UnitOrNonUnit diag, Matrix<F>& A, const Matrix<F>& U )
{
#ifndef RELEASE
    PushCallStack("internal::TwoSidedTrmmUVar5");
    if( A.Height() != A.Width() )
    {
        throw std::logic_error("A must be square");
    }
    if( U.Height() != U.Width() )
    {
        throw std::logic_error("Triangular matrices must be square");
    }
    if( A.Height() != U.Height() )
    {
        throw std::logic_error("A and U must be the same size");
    }
#endif
    // Scalars shared by every kernel call below
    const F zero = F(0);
    const F one = F(1);
    const F half = F(1)/F(2);

    // Quadrant and 3x3 block views of A
    Matrix<F> ATL, ATR, ABL, ABR;
    Matrix<F> A00, A01, A02;
    Matrix<F> A10, A11, A12;
    Matrix<F> A20, A21, A22;

    // Quadrant and 3x3 block views of U
    Matrix<F> UTL, UTR, UBL, UBR;
    Matrix<F> U00, U01, U02;
    Matrix<F> U10, U11, U12;
    Matrix<F> U20, U21, U22;

    // Workspace holding the product U01 A11
    Matrix<F> U01A11;

    PartitionDownDiagonal( A, ATL, ATR, ABL, ABR, 0 );
    LockedPartitionDownDiagonal( U, UTL, UTR, UBL, UBR, 0 );
    while( ATL.Height() < A.Height() )
    {
        // Expose the next diagonal block of A and of U
        RepartitionDownDiagonal
        ( ATL, ATR,  A00, A01, A02,
                     A10, A11, A12,
          ABL, ABR,  A20, A21, A22 );
        LockedRepartitionDownDiagonal
        ( UTL, UTR,  U00, U01, U02,
                     U10, U11, U12,
          UBL, UBR,  U20, U21, U22 );

        // Form U01 A11 once; it is added to A01 in two halves
        Zeros( A01.Height(), A01.Width(), U01A11 );
        Hemm( RIGHT, UPPER, one, A11, U01, zero, U01A11 );

        // A01 := U00 A01, followed by the first half of U01 A11
        Trmm( LEFT, UPPER, NORMAL, diag, one, U00, A01 );
        Axpy( half, U01A11, A01 );

        // Rank-2k update of the leading block:
        // A00 := A00 + (U01 A01' + A01 U01')
        Her2k( UPPER, NORMAL, one, U01, A01, one, A00 );

        // Second half of U01 A11, then A01 := A01 U11'
        Axpy( half, U01A11, A01 );
        Trmm( RIGHT, UPPER, ADJOINT, diag, one, U11, A01 );

        // Diagonal block via the unblocked kernel: A11 := U11 A11 U11'
        TwoSidedTrmmUUnb( diag, A11, U11 );

        // Move the exposed blocks into the "already visited" quadrant
        SlidePartitionDownDiagonal
        ( ATL, ATR,  A00, A01, A02,
                     A10, A11, A12,
          ABL, ABR,  A20, A21, A22 );
        SlideLockedPartitionDownDiagonal
        ( UTL, UTR,  U00, U01, U02,
                     U10, U11, U12,
          UBL, UBR,  U20, U21, U22 );
    }
#ifndef RELEASE
    PopCallStack();
#endif
}
コード例 #18
0
ファイル: UVar5.hpp プロジェクト: jimgoo/Elemental
// Distributed blocked two-sided triangular multiply with an upper-triangular
// factor (variant 5).
//
// A and U are traversed down their diagonals in lockstep. As the step
// comments inside the loop state, each iteration updates the exposed blocks
// A00, A01 and A11 using U00, U01 and U11; every Hermitian/triangular kernel
// is called with UPPER.
// NOTE(review): the net effect is presumably A := U A U' on the upper
// triangle -- confirm against the library documentation.
//
// diag : forwarded to Trmm/LocalTrmm/LocalTwoSidedTrmm (unit vs. non-unit)
// A    : square distributed matrix, updated in place
// U    : square distributed triangular matrix of the same size, read-only
//
// When RELEASE is not defined, shape mismatches throw std::logic_error.
inline void
TwoSidedTrmmUVar5
( UnitOrNonUnit diag, DistMatrix<F>& A, const DistMatrix<F>& U )
{
#ifndef RELEASE
    PushCallStack("internal::TwoSidedTrmmUVar5");
    if( A.Height() != A.Width() )
        throw std::logic_error("A must be square");
    if( U.Height() != U.Width() )
        throw std::logic_error("Triangular matrices must be square");
    if( A.Height() != U.Height() )
        throw std::logic_error("A and U must be the same size");
#endif
    const Grid& g = A.Grid();

    // Matrix views
    DistMatrix<F>
        ATL(g), ATR(g),  A00(g), A01(g), A02(g),
        ABL(g), ABR(g),  A10(g), A11(g), A12(g),
                         A20(g), A21(g), A22(g);
    DistMatrix<F>
        UTL(g), UTR(g),  U00(g), U01(g), U02(g),
        UBL(g), UBR(g),  U10(g), U11(g), U12(g),
                         U20(g), U21(g), U22(g);

    // Temporary distributions
    // (the suffix of each name mirrors the distribution pair in its type)
    DistMatrix<F,STAR,STAR> A11_STAR_STAR(g);
    DistMatrix<F,MC,  STAR> A01_MC_STAR(g);
    DistMatrix<F,MR,  STAR> A01_MR_STAR(g);
    DistMatrix<F,VC,  STAR> A01_VC_STAR(g);
    DistMatrix<F,STAR,STAR> U11_STAR_STAR(g);
    DistMatrix<F,MC,  STAR> U01_MC_STAR(g);
    DistMatrix<F,MR,  STAR> U01_MR_STAR(g);
    DistMatrix<F,VC,  STAR> U01_VC_STAR(g);
    DistMatrix<F,VC,  STAR> Y01_VC_STAR(g);
    DistMatrix<F> Y01(g);

    PartitionDownDiagonal
    ( A, ATL, ATR,
         ABL, ABR, 0 );
    LockedPartitionDownDiagonal
    ( U, UTL, UTR,
         UBL, UBR, 0 );
    while( ATL.Height() < A.Height() )
    {
        RepartitionDownDiagonal
        ( ATL, /**/ ATR,  A00, /**/ A01, A02,
         /*************/ /******************/
               /**/       A10, /**/ A11, A12,
          ABL, /**/ ABR,  A20, /**/ A21, A22 );

        LockedRepartitionDownDiagonal
        ( UTL, /**/ UTR,  U00, /**/ U01, U02,
         /*************/ /******************/
               /**/       U10, /**/ U11, U12,
          UBL, /**/ UBR,  U20, /**/ U21, U22 );

        // Constrain the temporaries' alignments for this iteration; they are
        // released again by the FreeAlignments calls at the bottom of the loop
        A01_MC_STAR.AlignWith( A00 );
        A01_MR_STAR.AlignWith( A00 );
        A01_VC_STAR.AlignWith( A00 );
        U01_MC_STAR.AlignWith( A00 );
        U01_MR_STAR.AlignWith( A00 );
        U01_VC_STAR.AlignWith( A00 );
        Y01.AlignWith( A01 );
        Y01_VC_STAR.AlignWith( A01 );
        //--------------------------------------------------------------------//
        // Y01 := U01 A11
        // (computed on the local matrices of the [*,*] and [VC,*] copies, then
        //  assigned to the default-distributed Y01)
        A11_STAR_STAR = A11;
        U01_VC_STAR = U01;
        Y01_VC_STAR.ResizeTo( A01.Height(), A01.Width() );
        Hemm
        ( RIGHT, UPPER,
          F(1), A11_STAR_STAR.LocalMatrix(), U01_VC_STAR.LocalMatrix(),
          F(0), Y01_VC_STAR.LocalMatrix() );
        Y01 = Y01_VC_STAR;

        // A01 := U00 A01
        Trmm( LEFT, UPPER, NORMAL, diag, F(1), U00, A01 );

        // A01 := A01 + 1/2 Y01
        Axpy( F(1)/F(2), Y01, A01 );

        // A00 := A00 + (U01 A01' + A01 U01')
        A01_MC_STAR = A01;
        U01_MC_STAR = U01;
        A01_VC_STAR = A01_MC_STAR;
        A01_MR_STAR = A01_VC_STAR;
        U01_MR_STAR = U01_MC_STAR;
        LocalTrr2k
        ( UPPER, ADJOINT, ADJOINT,
          F(1), U01_MC_STAR, A01_MR_STAR, 
                A01_MC_STAR, U01_MR_STAR,
          F(1), A00 );

        // A01 := A01 + 1/2 Y01
        // (applied to the [VC,*] copy of A01 taken above, i.e. after the first
        //  half-update; A01 itself is overwritten from that copy below)
        Axpy( F(1)/F(2), Y01_VC_STAR, A01_VC_STAR );

        // A01 := A01 U11'
        U11_STAR_STAR = U11;
        LocalTrmm
        ( RIGHT, UPPER, ADJOINT, diag, F(1), U11_STAR_STAR, A01_VC_STAR );
        A01 = A01_VC_STAR;

        // A11 := U11 A11 U11'
        // (A11_STAR_STAR still holds the copy of A11 taken for the Hemm above)
        LocalTwoSidedTrmm( UPPER, diag, A11_STAR_STAR, U11_STAR_STAR );
        A11 = A11_STAR_STAR;
        //--------------------------------------------------------------------//
        A01_MC_STAR.FreeAlignments();
        A01_MR_STAR.FreeAlignments();
        A01_VC_STAR.FreeAlignments();
        U01_MC_STAR.FreeAlignments();
        U01_MR_STAR.FreeAlignments();
        U01_VC_STAR.FreeAlignments();
        Y01.FreeAlignments();
        Y01_VC_STAR.FreeAlignments();

        SlidePartitionDownDiagonal
        ( ATL, /**/ ATR,  A00, A01, /**/ A02,
               /**/       A10, A11, /**/ A12,
         /*************/ /******************/
          ABL, /**/ ABR,  A20, A21, /**/ A22 );

        SlideLockedPartitionDownDiagonal
        ( UTL, /**/ UTR,  U00, U01, /**/ U02,
               /**/       U10, U11, /**/ U12,
         /*************/ /******************/
          UBL, /**/ UBR,  U20, U21, /**/ U22 );
    }
#ifndef RELEASE
    PopCallStack();
#endif
}
コード例 #19
0
ファイル: LL.hpp プロジェクト: jimgoo/Elemental
// Distributed symmetric matrix-matrix multiply, processed one column panel of
// B and C at a time.
//
// C is first scaled by beta. Then, for every column panel (B1, C1), the
// contribution computed by LocalSymmetricAccumulateLL from alpha, A and B1 is
// gathered into Z1 and added to C1.
// NOTE(review): by its name this is presumably the "left, lower" case, i.e.
// C := alpha A B + beta C with only the lower triangle of A referenced --
// confirm against LocalSymmetricAccumulateLL.
//
// alpha, beta : scalars of the update
// A           : distributed matrix handed to LocalSymmetricAccumulateLL
// B           : distributed right-hand operand, read-only
// C           : distributed result, updated in place
//
// When RELEASE is not defined, std::logic_error is thrown if A, B and C do
// not share one grid. No dimension check is performed here.
inline void
SymmLLA
( T alpha, const DistMatrix<T>& A, const DistMatrix<T>& B,
  T beta,        DistMatrix<T>& C )
{
#ifndef RELEASE
    PushCallStack("internal::SymmLLA");
    if( A.Grid() != B.Grid() || B.Grid() != C.Grid() )
        throw std::logic_error
        ("{A,B,C} must be distributed over the same grid");
#endif
    const Grid& g = A.Grid();

    // Column-panel views of B
    DistMatrix<T> 
        BL(g), BR(g),
        B0(g), B1(g), B2(g);

    // Column-panel views of C
    DistMatrix<T>
        CL(g), CR(g),
        C0(g), C1(g), C2(g);

    // Redistributed copies of B1 and accumulation buffers for the product
    DistMatrix<T,MC,STAR> B1_MC_STAR(g);
    DistMatrix<T,VR,STAR> B1_VR_STAR(g);
    DistMatrix<T,STAR,MR> B1Trans_STAR_MR(g);
    DistMatrix<T> Z1(g);
    DistMatrix<T,MC,STAR> Z1_MC_STAR(g);
    DistMatrix<T,MR,STAR> Z1_MR_STAR(g);
    DistMatrix<T,MR,MC  > Z1_MR_MC(g);

    // These alignments depend only on A, so they are set once, outside the loop
    B1_MC_STAR.AlignWith( A );
    B1_VR_STAR.AlignWith( A );
    B1Trans_STAR_MR.AlignWith( A );
    Z1_MC_STAR.AlignWith( A );
    Z1_MR_STAR.AlignWith( A );

    // Apply beta up front; the loop below only adds to C
    Scale( beta, C );
    LockedPartitionRight
    ( B, BL, BR, 0 );
    PartitionRight
    ( C, CL, CR, 0 );
    while( CL.Width() < C.Width() )
    {
        LockedRepartitionRight 
        ( BL, /**/ BR,
          B0, /**/ B1, B2 );

        RepartitionRight
        ( CL, /**/ CR,
          C0, /**/ C1, C2 );

        // Z1 follows the current panel of C; both accumulators start at zero
        Z1.AlignWith( C1 );
        Zeros( C1.Height(), C1.Width(), Z1_MC_STAR );
        Zeros( C1.Height(), C1.Width(), Z1_MR_STAR );
        //--------------------------------------------------------------------//
        // Build the [MC,*] copy of B1 and the [*,MR] copy of its transpose
        B1_MC_STAR = B1;
        B1_VR_STAR = B1_MC_STAR;
        B1Trans_STAR_MR.TransposeFrom( B1_VR_STAR );
        LocalSymmetricAccumulateLL
        ( TRANSPOSE, 
          alpha, A, B1_MC_STAR, B1Trans_STAR_MR, Z1_MC_STAR, Z1_MR_STAR );

        // Combine the two accumulators into Z1 and add the result to C1
        Z1_MR_MC.SumScatterFrom( Z1_MR_STAR );
        Z1 = Z1_MR_MC;
        Z1.SumScatterUpdate( T(1), Z1_MC_STAR );
        Axpy( T(1), Z1, C1 );
        //--------------------------------------------------------------------//
        Z1.FreeAlignments();

        SlideLockedPartitionRight
        ( BL,     /**/ BR,
          B0, B1, /**/ B2 );

        SlidePartitionRight
        ( CL,     /**/ CR,
          C0, C1, /**/ C2 );
    }
#ifndef RELEASE
    PopCallStack();
#endif
}
コード例 #20
0
ファイル: UVar1.hpp プロジェクト: khalid-hasanov/Elemental
// Distributed blocked two-sided triangular solve with an upper-triangular
// factor (variant 1).
//
// A and U are traversed down their diagonals in lockstep. As the step
// comments inside the loop state, each iteration reads A00/U00/U01 and
// updates the exposed panel A01 and diagonal block A11.
// NOTE(review): the net effect is presumably A := inv(U)' A inv(U) on the
// upper triangle -- confirm against the library documentation.
//
// diag : forwarded to Trsm/LocalTrsm/LocalTwoSidedTrsm (unit vs. non-unit)
// A    : square distributed matrix, updated in place
// U    : square distributed triangular matrix of the same size, read-only
//
// When RELEASE is not defined, shape mismatches are reported via LogicError.
//
// NOTE(review): unlike the other variants in this file, the alignments set at
// the top of each iteration are not released with FreeAlignments() before the
// next iteration -- confirm that AlignWith may be re-applied in this version
// of the library.
inline void
TwoSidedTrsmUVar1
( UnitOrNonUnit diag, DistMatrix<F>& A, const DistMatrix<F>& U )
{
#ifndef RELEASE
    CallStackEntry entry("internal::TwoSidedTrsmUVar1");
    if( A.Height() != A.Width() )
        LogicError("A must be square");
    if( U.Height() != U.Width() )
        LogicError("Triangular matrices must be square");
    if( A.Height() != U.Height() )
        LogicError("A and U must be the same size");
#endif
    const Grid& g = A.Grid();

    // Matrix views
    DistMatrix<F>
        ATL(g), ATR(g),  A00(g), A01(g), A02(g),
        ABL(g), ABR(g),  A10(g), A11(g), A12(g),
                         A20(g), A21(g), A22(g);
    DistMatrix<F>
        UTL(g), UTR(g),  U00(g), U01(g), U02(g),
        UBL(g), UBR(g),  U10(g), U11(g), U12(g),
                         U20(g), U21(g), U22(g);

    // Temporary distributions
    DistMatrix<F,STAR,STAR> A11_STAR_STAR(g);
    DistMatrix<F,VC,  STAR> A01_VC_STAR(g);
    DistMatrix<F,STAR,STAR> U11_STAR_STAR(g);
    DistMatrix<F,MC,  STAR> U01_MC_STAR(g);
    DistMatrix<F,VC,  STAR> U01_VC_STAR(g);
    DistMatrix<F,VR,  STAR> U01_VR_STAR(g);
    DistMatrix<F,STAR,MR  > U01Adj_STAR_MR(g);
    DistMatrix<F,STAR,STAR> X11_STAR_STAR(g);
    DistMatrix<F,MR,  MC  > Z01_MR_MC(g);
    DistMatrix<F,MC,  STAR> Z01_MC_STAR(g);
    DistMatrix<F,MR,  STAR> Z01_MR_STAR(g);
    DistMatrix<F> Y01(g);

    PartitionDownDiagonal
    ( A, ATL, ATR,
         ABL, ABR, 0 );
    LockedPartitionDownDiagonal
    ( U, UTL, UTR,
         UBL, UBR, 0 );
    while( ATL.Height() < A.Height() )
    {
        RepartitionDownDiagonal
        ( ATL, /**/ ATR,  A00, /**/ A01, A02,
         /*************/ /******************/
               /**/       A10, /**/ A11, A12,
          ABL, /**/ ABR,  A20, /**/ A21, A22 );

        LockedRepartitionDownDiagonal
        ( UTL, /**/ UTR,  U00, /**/ U01, U02,
         /*************/ /******************/
               /**/       U10, /**/ U11, U12,
          UBL, /**/ UBR,  U20, /**/ U21, U22 );

        // Constrain the temporaries' alignments to the current A00/A01 blocks
        A01_VC_STAR.AlignWith( A01 );
        U01_MC_STAR.AlignWith( A00 );
        U01_VR_STAR.AlignWith( A00 );
        U01_VC_STAR.AlignWith( A00 );
        U01Adj_STAR_MR.AlignWith( A00 );
        Y01.AlignWith( A01 );
        Z01_MR_MC.AlignWith( A01 );
        Z01_MC_STAR.AlignWith( A00 );
        Z01_MR_STAR.AlignWith( A00 );
        //--------------------------------------------------------------------//
        // Y01 := A00 U01
        // (accumulated into the two zeroed Z01 buffers, which are then
        //  sum-scattered into Y01)
        U01_MC_STAR = U01;
        U01_VR_STAR = U01_MC_STAR;
        U01Adj_STAR_MR.AdjointFrom( U01_VR_STAR );
        Zeros( Z01_MC_STAR, A01.Height(), A01.Width() );
        Zeros( Z01_MR_STAR, A01.Height(), A01.Width() );
        LocalSymmetricAccumulateLU
        ( ADJOINT, 
          F(1), A00, U01_MC_STAR, U01Adj_STAR_MR, Z01_MC_STAR, Z01_MR_STAR );
        Z01_MR_MC.SumScatterFrom( Z01_MR_STAR );
        Y01 = Z01_MR_MC;
        Y01.SumScatterUpdate( F(1), Z01_MC_STAR );

        // A01 := inv(U00)' A01
        //
        // This is the bottleneck because A01 only has blocksize columns
        Trsm( LEFT, UPPER, ADJOINT, diag, F(1), U00, A01 );

        // A01 := A01 - 1/2 Y01
        Axpy( F(-1)/F(2), Y01, A01 );

        // A11 := A11 - (U01' A01 + A01' U01)
        // (Her2k runs on the local [VC,*] pieces into a zeroed [*,*] buffer,
        //  which is then sum-scattered onto A11)
        A01_VC_STAR = A01;
        U01_VC_STAR = U01_MC_STAR;
        Zeros( X11_STAR_STAR, A11.Height(), A11.Width() );
        Her2k
        ( UPPER, ADJOINT,
          F(-1), A01_VC_STAR.Matrix(), U01_VC_STAR.Matrix(),
          F(0), X11_STAR_STAR.Matrix() );
        A11.SumScatterUpdate( F(1), X11_STAR_STAR );

        // A11 := inv(U11)' A11 inv(U11)
        A11_STAR_STAR = A11;
        U11_STAR_STAR = U11;
        LocalTwoSidedTrsm( UPPER, diag, A11_STAR_STAR, U11_STAR_STAR );
        A11 = A11_STAR_STAR;

        // A01 := A01 - 1/2 Y01
        Axpy( F(-1)/F(2), Y01, A01 );

        // A01 := A01 inv(U11)
        // (reuses the [*,*] copy of U11 taken for the diagonal block above)
        A01_VC_STAR = A01;
        LocalTrsm
        ( RIGHT, UPPER, NORMAL, diag, F(1), U11_STAR_STAR, A01_VC_STAR );
        A01 = A01_VC_STAR;
        //--------------------------------------------------------------------//

        SlidePartitionDownDiagonal
        ( ATL, /**/ ATR,  A00, A01, /**/ A02,
               /**/       A10, A11, /**/ A12,
         /*************/ /******************/
          ABL, /**/ ABR,  A20, A21, /**/ A22 );

        SlideLockedPartitionDownDiagonal
        ( UTL, /**/ UTR,  U00, U01, /**/ U02,
               /**/       U10, U11, /**/ U12,
         /*************/ /******************/
          UBL, /**/ UBR,  U20, U21, /**/ U22 );
    }
}
コード例 #21
0
ファイル: LLN.hpp プロジェクト: ahmadia/Elemental-1
// Distributed triangular matrix-matrix multiply with a lower-triangular L
// applied from the left, processed from the bottom row panel of X upwards.
//
// X is first scaled by alpha. Each iteration then overwrites the exposed row
// panel X1 with L11 X1 (LocalTrmm) and adds L10 X0 to it, where X0 is the part
// of X above X1 that has not been processed yet.
// NOTE(review): the net effect is presumably X := alpha L X -- confirm.
//
// diag  : forwarded to LocalTrmm (unit vs. non-unit diagonal of L)
// alpha : scalar applied to X before the sweep
// L     : square distributed triangular matrix, read-only
// X     : distributed matrix with L.Width() rows, updated in place
//
// When RELEASE is not defined, std::logic_error is thrown if L and X live on
// different grids or if their dimensions do not conform.
inline void
TrmmLLNCOld
( UnitOrNonUnit diag,
  T alpha, const DistMatrix<T>& L,
                 DistMatrix<T>& X )
{
#ifndef RELEASE
    CallStackEntry entry("internal::TrmmLLNCOld");
    if( L.Grid() != X.Grid() )
        throw std::logic_error
        ("L and X must be distributed over the same grid");
    if( L.Height() != L.Width() || L.Width() != X.Height() )
    {
        std::ostringstream msg;
        msg << "Nonconformal TrmmLLNC: \n"
            << "  L ~ " << L.Height() << " x " << L.Width() << "\n"
            << "  X ~ " << X.Height() << " x " << X.Width() << "\n";
        throw std::logic_error( msg.str().c_str() );
    }
#endif
    const Grid& g = L.Grid();

    // Matrix views
    DistMatrix<T> 
        LTL(g), LTR(g),  L00(g), L01(g), L02(g),
        LBL(g), LBR(g),  L10(g), L11(g), L12(g),
                         L20(g), L21(g), L22(g);
    DistMatrix<T> XT(g),  X0(g),
                  XB(g),  X1(g),
                          X2(g);

    // Temporary distributions
    DistMatrix<T,STAR,MC  > L10_STAR_MC(g);
    DistMatrix<T,STAR,STAR> L11_STAR_STAR(g);
    DistMatrix<T,STAR,VR  > X1_STAR_VR(g);
    DistMatrix<T,MR,  STAR> D1Trans_MR_STAR(g);
    DistMatrix<T,MR,  MC  > D1Trans_MR_MC(g);
    DistMatrix<T,MC,  MR  > D1(g);

    // Start the algorithm
    Scale( alpha, X );
    LockedPartitionUpDiagonal
    ( L, LTL, LTR,
         LBL, LBR, 0 );
    PartitionUp
    ( X, XT,
         XB, 0 );
    while( XT.Height() > 0 )
    {
        LockedRepartitionUpDiagonal
        ( LTL, /**/ LTR,  L00, L01, /**/ L02,
               /**/       L10, L11, /**/ L12,
         /*************/ /******************/
          LBL, /**/ LBR,  L20, L21, /**/ L22 );

        RepartitionUp
        ( XT,  X0,
               X1,
         /**/ /**/
          XB,  X2 );

        // Per-iteration alignments; released at the bottom of the loop
        L10_STAR_MC.AlignWith( X0 );
        D1Trans_MR_STAR.AlignWith( X1 );
        D1Trans_MR_MC.AlignWith( X1 );
        D1.AlignWith( X1 );
        //--------------------------------------------------------------------//
        // X1 := L11 X1, computed on redistributed copies and written back
        L11_STAR_STAR = L11;
        X1_STAR_VR = X1;
        LocalTrmm( LEFT, LOWER, NORMAL, diag, T(1), L11_STAR_STAR, X1_STAR_VR );
        X1 = X1_STAR_VR;

        // X1 := X1 + L10 X0, formed as the transpose of X0^T L10^T:
        // the transposed product is sum-scattered into [MR,MC], its local
        // matrix is transposed into D1, and D1 is added to X1
        L10_STAR_MC = L10;
        LocalGemm
        ( TRANSPOSE, TRANSPOSE, T(1), X0, L10_STAR_MC, D1Trans_MR_STAR );
        D1Trans_MR_MC.SumScatterFrom( D1Trans_MR_STAR );
        Zeros( D1, X1.Height(), X1.Width() );
        Transpose( D1Trans_MR_MC.Matrix(), D1.Matrix() );
        Axpy( T(1), D1, X1 );
        //--------------------------------------------------------------------//
        D1.FreeAlignments();
        D1Trans_MR_MC.FreeAlignments();
        D1Trans_MR_STAR.FreeAlignments();
        L10_STAR_MC.FreeAlignments();

        SlideLockedPartitionUpDiagonal
        ( LTL, /**/ LTR,  L00, /**/ L01, L02,
         /*************/ /******************/
               /**/       L10, /**/ L11, L12, 
          LBL, /**/ LBR,  L20, /**/ L21, L22 );

        SlidePartitionUp
        ( XT,  X0,
         /**/ /**/
               X1,
          XB,  X2 );
    }
}
コード例 #22
0
ファイル: UVar5.hpp プロジェクト: certik/Elemental
// Distributed blocked two-sided triangular solve with an upper-triangular
// factor (variant 5).
//
// A and U are traversed down their diagonals in lockstep. As the step
// comments inside the loop state, each iteration finishes the diagonal block
// A11, updates the panel A12 to its right and applies a rank-2k style update
// to the trailing block A22.
// NOTE(review): the net effect is presumably A := inv(U)' A inv(U) on the
// upper triangle -- confirm against the library documentation.
//
// diag : forwarded to Trsm/LocalTrsm/LocalTwoSidedTrsm (unit vs. non-unit)
// A    : square distributed matrix, updated in place
// U    : square distributed triangular matrix of the same size, read-only
//
// When RELEASE is not defined, shape mismatches throw std::logic_error.
inline void
TwoSidedTrsmUVar5
( UnitOrNonUnit diag, DistMatrix<F>& A, const DistMatrix<F>& U )
{
#ifndef RELEASE
    PushCallStack("internal::TwoSidedTrsmUVar5");
    if( A.Height() != A.Width() )
        throw std::logic_error("A must be square");
    if( U.Height() != U.Width() )
        throw std::logic_error("Triangular matrices must be square");
    if( A.Height() != U.Height() )
        throw std::logic_error("A and U must be the same size");
#endif
    const Grid& g = A.Grid();

    // Matrix views
    DistMatrix<F>
        ATL(g), ATR(g),  A00(g), A01(g), A02(g),
        ABL(g), ABR(g),  A10(g), A11(g), A12(g),
                         A20(g), A21(g), A22(g);
    DistMatrix<F>
        UTL(g), UTR(g),  U00(g), U01(g), U02(g),
        UBL(g), UBR(g),  U10(g), U11(g), U12(g),
                         U20(g), U21(g), U22(g);

    // Temporary distributions
    DistMatrix<F,STAR,STAR> A11_STAR_STAR(g);
    DistMatrix<F,STAR,MC  > A12_STAR_MC(g);
    DistMatrix<F,STAR,MR  > A12_STAR_MR(g);
    DistMatrix<F,STAR,VC  > A12_STAR_VC(g);
    DistMatrix<F,STAR,VR  > A12_STAR_VR(g);
    DistMatrix<F,STAR,STAR> U11_STAR_STAR(g);
    DistMatrix<F,STAR,MC  > U12_STAR_MC(g);
    DistMatrix<F,STAR,MR  > U12_STAR_MR(g);
    DistMatrix<F,STAR,VC  > U12_STAR_VC(g);
    DistMatrix<F,STAR,VR  > U12_STAR_VR(g);
    DistMatrix<F,STAR,VR  > Y12_STAR_VR(g);
    DistMatrix<F> Y12(g);

    PartitionDownDiagonal
    ( A, ATL, ATR,
         ABL, ABR, 0 );
    LockedPartitionDownDiagonal
    ( U, UTL, UTR,
         UBL, UBR, 0 );
    while( ATL.Height() < A.Height() )
    {
        RepartitionDownDiagonal
        ( ATL, /**/ ATR,  A00, /**/ A01, A02,
         /*************/ /******************/
               /**/       A10, /**/ A11, A12,
          ABL, /**/ ABR,  A20, /**/ A21, A22 );

        LockedRepartitionDownDiagonal
        ( UTL, /**/ UTR,  U00, /**/ U01, U02,
         /*************/ /******************/
               /**/       U10, /**/ U11, U12,
          UBL, /**/ UBR,  U20, /**/ U21, U22 );

        // Constrain the temporaries' alignments for this iteration; they are
        // released again by the FreeAlignments calls at the bottom of the loop
        A12_STAR_MC.AlignWith( A22 );
        A12_STAR_MR.AlignWith( A22 );
        A12_STAR_VC.AlignWith( A22 );
        A12_STAR_VR.AlignWith( A22 );
        U12_STAR_MC.AlignWith( A22 );
        U12_STAR_MR.AlignWith( A22 );
        U12_STAR_VC.AlignWith( A22 );
        U12_STAR_VR.AlignWith( A22 );
        Y12.AlignWith( A12 );
        Y12_STAR_VR.AlignWith( A12 );
        //--------------------------------------------------------------------//
        // A11 := inv(U11)' A11 inv(U11)
        U11_STAR_STAR = U11;
        A11_STAR_STAR = A11;
        LocalTwoSidedTrsm( UPPER, diag, A11_STAR_STAR, U11_STAR_STAR );
        A11 = A11_STAR_STAR;

        // Y12 := A11 U12
        // (uses the already updated [*,*] copy of A11; computed on local
        //  matrices and then assigned to the default-distributed Y12)
        U12_STAR_VR = U12;
        Y12_STAR_VR.ResizeTo( A12.Height(), A12.Width() );
        Hemm
        ( LEFT, UPPER,
          F(1), A11_STAR_STAR.LocalMatrix(), U12_STAR_VR.LocalMatrix(),
          F(0), Y12_STAR_VR.LocalMatrix() );
        Y12 = Y12_STAR_VR;

        // A12 := inv(U11)' A12
        A12_STAR_VR = A12;
        LocalTrsm
        ( LEFT, UPPER, ADJOINT, diag, F(1), U11_STAR_STAR, A12_STAR_VR );
        A12 = A12_STAR_VR;

        // A12 := A12 - 1/2 Y12
        Axpy( F(-1)/F(2), Y12, A12 );

        // A22 := A22 - (A12' U12 + U12' A12)
        // (the [*,VR] copies are fanned out to [*,VC], [*,MC] and [*,MR]
        //  before the local rank-2k update)
        A12_STAR_VR = A12;
        A12_STAR_VC = A12_STAR_VR;
        U12_STAR_VC = U12_STAR_VR;
        A12_STAR_MC = A12_STAR_VC;
        U12_STAR_MC = U12_STAR_VC;
        A12_STAR_MR = A12_STAR_VR;
        U12_STAR_MR = U12_STAR_VR;
        LocalTrr2k
        ( UPPER, ADJOINT, ADJOINT,
          F(-1), U12_STAR_MC, A12_STAR_MR,
                 A12_STAR_MC, U12_STAR_MR,
          F(1), A22 );

        // A12 := A12 - 1/2 Y12
        Axpy( F(-1)/F(2), Y12, A12 );

        // A12 := A12 inv(U22)
        //
        // This is the bottleneck because A12 only has blocksize rows
        Trsm( RIGHT, UPPER, NORMAL, diag, F(1), U22, A12 );
        //--------------------------------------------------------------------//
        A12_STAR_MC.FreeAlignments();
        A12_STAR_MR.FreeAlignments();
        A12_STAR_VC.FreeAlignments();
        A12_STAR_VR.FreeAlignments();
        U12_STAR_MC.FreeAlignments();
        U12_STAR_MR.FreeAlignments();
        U12_STAR_VC.FreeAlignments();
        U12_STAR_VR.FreeAlignments();
        Y12.FreeAlignments();
        Y12_STAR_VR.FreeAlignments();

        SlidePartitionDownDiagonal
        ( ATL, /**/ ATR,  A00, A01, /**/ A02,
               /**/       A10, A11, /**/ A12,
         /*************/ /******************/
          ABL, /**/ ABR,  A20, A21, /**/ A22 );

        SlideLockedPartitionDownDiagonal
        ( UTL, /**/ UTR,  U00, U01, /**/ U02,
               /**/       U10, U11, /**/ U12,
         /*************/ /******************/
          UBL, /**/ UBR,  U20, U21, /**/ U22 );
    }
#ifndef RELEASE
    PopCallStack();
#endif
}
コード例 #23
0
ファイル: LVar4.hpp プロジェクト: khalid-hasanov/Elemental
// Blocked, sequential two-sided triangular multiply with a lower-triangular
// factor (variant 4).
//
// The diagonals of A and L are swept together. Each pass exposes a diagonal
// block (A11, L11), the panel to its left (A10, L10), the panel below it
// (A21) and the blocks A00/A20 left of those, and then applies, in this
// order:
//     A10 := A10 + 1/2 (A11 L10)
//     A00 := A00 + (A10' L10 + L10' A10)
//     A10 := L11' (A10 + 1/2 (A11 L10))
//     A20 := A20 + A21 L10
//     A11 := L11' A11 L11
//     A21 := A21 L11
// The Hermitian kernels are always called with LOWER.
//
// diag : forwarded to the triangular kernels (unit vs. non-unit diagonal)
// A    : square matrix, updated in place
// L    : square triangular matrix of the same dimension as A, read-only
//
// When RELEASE is not defined the shapes are validated via LogicError.
inline void
TwoSidedTrmmLVar4( UnitOrNonUnit diag, Matrix<F>& A, const Matrix<F>& L )
{
#ifndef RELEASE
    CallStackEntry entry("internal::TwoSidedTrmmLVar4");
    if( A.Height() != A.Width() )
    {
        LogicError("A must be square");
    }
    if( L.Height() != L.Width() )
    {
        LogicError("Triangular matrices must be square");
    }
    if( A.Height() != L.Height() )
    {
        LogicError("A and L must be the same size");
    }
#endif
    // Scalars shared by every kernel call below
    const F zero = F(0);
    const F one = F(1);
    const F half = F(1)/F(2);

    // Quadrant and 3x3 block views of A
    Matrix<F> ATL, ATR, ABL, ABR;
    Matrix<F> A00, A01, A02;
    Matrix<F> A10, A11, A12;
    Matrix<F> A20, A21, A22;

    // Quadrant and 3x3 block views of L
    Matrix<F> LTL, LTR, LBL, LBR;
    Matrix<F> L00, L01, L02;
    Matrix<F> L10, L11, L12;
    Matrix<F> L20, L21, L22;

    // Workspace holding the product A11 L10
    Matrix<F> A11L10;

    PartitionDownDiagonal( A, ATL, ATR, ABL, ABR, 0 );
    LockedPartitionDownDiagonal( L, LTL, LTR, LBL, LBR, 0 );
    while( ATL.Height() < A.Height() )
    {
        // Expose the next diagonal block of A and of L
        RepartitionDownDiagonal
        ( ATL, ATR,  A00, A01, A02,
                     A10, A11, A12,
          ABL, ABR,  A20, A21, A22 );
        LockedRepartitionDownDiagonal
        ( LTL, LTR,  L00, L01, L02,
                     L10, L11, L12,
          LBL, LBR,  L20, L21, L22 );

        // Form A11 L10 once; it is added to A10 in two halves
        Zeros( A11L10, A10.Height(), A10.Width() );
        Hemm( LEFT, LOWER, one, A11, L10, zero, A11L10 );

        // First half of A11 L10
        Axpy( half, A11L10, A10 );

        // Rank-2k update of the leading block:
        // A00 := A00 + (A10' L10 + L10' A10)
        Her2k( LOWER, ADJOINT, one, A10, L10, one, A00 );

        // Second half of A11 L10, then A10 := L11' A10
        Axpy( half, A11L10, A10 );
        Trmm( LEFT, LOWER, ADJOINT, diag, one, L11, A10 );

        // A20 := A20 + A21 L10 (must precede the in-place update of A21)
        Gemm( NORMAL, NORMAL, one, A21, L10, one, A20 );

        // Diagonal block via the unblocked kernel: A11 := L11' A11 L11
        TwoSidedTrmmLUnb( diag, A11, L11 );

        // A21 := A21 L11
        Trmm( RIGHT, LOWER, NORMAL, diag, one, L11, A21 );

        // Move the exposed blocks into the "already visited" quadrant
        SlidePartitionDownDiagonal
        ( ATL, ATR,  A00, A01, A02,
                     A10, A11, A12,
          ABL, ABR,  A20, A21, A22 );
        SlideLockedPartitionDownDiagonal
        ( LTL, LTR,  L00, L01, L02,
                     L10, L11, L12,
          LBL, LBR,  L20, L21, L22 );
    }
}
コード例 #24
0
ファイル: LVar4.hpp プロジェクト: khalid-hasanov/Elemental
// Distributed blocked two-sided triangular multiply with a lower-triangular
// factor (variant 4).
//
// A and L are traversed down their diagonals in lockstep. As the step
// comments inside the loop state, each iteration updates the exposed blocks
// A00, A10, A20, A11 and A21 using L10 and L11; every Hermitian/triangular
// kernel is called with LOWER.
// NOTE(review): the net effect is presumably A := L' A L on the lower
// triangle -- confirm against the library documentation.
//
// diag : forwarded to LocalTrmm/LocalTwoSidedTrmm (unit vs. non-unit)
// A    : square distributed matrix, updated in place
// L    : square distributed triangular matrix of the same size, read-only
//
// When RELEASE is not defined, shape mismatches are reported via LogicError.
//
// NOTE(review): the alignments set at the top of each iteration are not
// released with FreeAlignments() before the next iteration -- confirm that
// AlignWith may be re-applied in this version of the library.
inline void
TwoSidedTrmmLVar4
( UnitOrNonUnit diag, DistMatrix<F>& A, const DistMatrix<F>& L )
{
#ifndef RELEASE
    CallStackEntry entry("internal::TwoSidedTrmmLVar4");
    if( A.Height() != A.Width() )
        LogicError("A must be square");
    if( L.Height() != L.Width() )
        LogicError("Triangular matrices must be square");
    if( A.Height() != L.Height() )
        LogicError("A and L must be the same size");
#endif
    const Grid& g = A.Grid();

    // Matrix views
    DistMatrix<F>
        ATL(g), ATR(g),  A00(g), A01(g), A02(g),
        ABL(g), ABR(g),  A10(g), A11(g), A12(g),
                         A20(g), A21(g), A22(g);
    DistMatrix<F>
        LTL(g), LTR(g),  L00(g), L01(g), L02(g),
        LBL(g), LBR(g),  L10(g), L11(g), L12(g),
                         L20(g), L21(g), L22(g);

    // Temporary distributions
    DistMatrix<F,STAR,VR  > A10_STAR_VR(g);
    DistMatrix<F,STAR,MR  > A10_STAR_MR(g);
    DistMatrix<F,STAR,MC  > A10_STAR_MC(g);
    DistMatrix<F,STAR,STAR> A11_STAR_STAR(g);
    DistMatrix<F,VC,  STAR> A21_VC_STAR(g);
    DistMatrix<F,MC,  STAR> A21_MC_STAR(g);
    DistMatrix<F,STAR,VR  > L10_STAR_VR(g);
    DistMatrix<F,MR,  STAR> L10Adj_MR_STAR(g);
    DistMatrix<F,STAR,MC  > L10_STAR_MC(g);
    DistMatrix<F,STAR,STAR> L11_STAR_STAR(g);
    DistMatrix<F,STAR,VR  > Y10_STAR_VR(g);

    PartitionDownDiagonal
    ( A, ATL, ATR,
         ABL, ABR, 0 );
    LockedPartitionDownDiagonal
    ( L, LTL, LTR,
         LBL, LBR, 0 );
    while( ATL.Height() < A.Height() )
    {
        RepartitionDownDiagonal
        ( ATL, /**/ ATR,  A00, /**/ A01, A02,
         /*************/ /******************/
               /**/       A10, /**/ A11, A12,
          ABL, /**/ ABR,  A20, /**/ A21, A22 );

        LockedRepartitionDownDiagonal
        ( LTL, /**/ LTR,  L00, /**/ L01, L02,
         /*************/ /******************/
               /**/       L10, /**/ L11, L12,
          LBL, /**/ LBR,  L20, /**/ L21, L22 );

        // Constrain the temporaries' alignments to the current blocks of A
        A10_STAR_VR.AlignWith( A00 );
        A10_STAR_MR.AlignWith( A00 );
        A10_STAR_MC.AlignWith( A00 );
        A21_MC_STAR.AlignWith( A20 );
        L10_STAR_VR.AlignWith( A00 );
        L10Adj_MR_STAR.AlignWith( A00 );
        L10_STAR_MC.AlignWith( A00 );
        Y10_STAR_VR.AlignWith( A10 );
        //--------------------------------------------------------------------//
        // Y10 := A11 L10
        // (L10 reaches [*,VR] via two adjoint redistributions; the [MR,*]
        //  adjoint is kept because later steps reuse it)
        A11_STAR_STAR = A11;
        L10Adj_MR_STAR.AdjointFrom( L10 );
        L10_STAR_VR.AdjointFrom( L10Adj_MR_STAR );
        Zeros( Y10_STAR_VR, A10.Height(), A10.Width() );
        Hemm
        ( LEFT, LOWER,
          F(1), A11_STAR_STAR.LockedMatrix(), L10_STAR_VR.LockedMatrix(),
          F(0), Y10_STAR_VR.Matrix() );

        // A10 := A10 + 1/2 Y10
        // (applied to the [*,VR] copy; A10 is written back after LocalTrmm)
        A10_STAR_VR = A10;
        Axpy( F(1)/F(2), Y10_STAR_VR, A10_STAR_VR );

        // A00 := A00 + (A10' L10 + L10' A10)
        A10_STAR_MR = A10_STAR_VR;
        A10_STAR_MC = A10_STAR_VR;
        L10_STAR_MC = L10_STAR_VR;
        LocalTrr2k
        ( LOWER, ADJOINT, ADJOINT, ADJOINT,
          F(1), A10_STAR_MC, L10Adj_MR_STAR, 
                L10_STAR_MC, A10_STAR_MR, 
          F(1), A00 );

        // A10 := A10 + 1/2 Y10
        Axpy( F(1)/F(2), Y10_STAR_VR, A10_STAR_VR );

        // A10 := L11' A10
        L11_STAR_STAR = L11;
        LocalTrmm
        ( LEFT, LOWER, ADJOINT, diag, F(1), L11_STAR_STAR, A10_STAR_VR );
        A10 = A10_STAR_VR;

        // A20 := A20 + A21 L10
        // (uses the adjoint of the stored L10' and runs before A21 is
        //  overwritten below)
        A21_MC_STAR = A21;
        LocalGemm
        ( NORMAL, ADJOINT, F(1), A21_MC_STAR, L10Adj_MR_STAR, F(1), A20 );

        // A11 := L11' A11 L11
        // (A11_STAR_STAR still holds the copy of A11 taken for the Hemm above)
        LocalTwoSidedTrmm( LOWER, diag, A11_STAR_STAR, L11_STAR_STAR );
        A11 = A11_STAR_STAR;

        // A21 := A21 L11
        A21_VC_STAR = A21_MC_STAR;
        LocalTrmm
        ( RIGHT, LOWER, NORMAL, diag, F(1), L11_STAR_STAR, A21_VC_STAR );
        A21 = A21_VC_STAR;
        //--------------------------------------------------------------------//

        SlidePartitionDownDiagonal
        ( ATL, /**/ ATR,  A00, A01, /**/ A02,
               /**/       A10, A11, /**/ A12,
         /*************/ /******************/
          ABL, /**/ ABR,  A20, A21, /**/ A22 );

        SlideLockedPartitionDownDiagonal
        ( LTL, /**/ LTR,  L00, L01, /**/ L02,
               /**/       L10, L11, /**/ L12,
         /*************/ /******************/
          LBL, /**/ LBR,  L20, L21, /**/ L22 );
    }
}
コード例 #25
0
ファイル: UN.hpp プロジェクト: jimgoo/Elemental
// Distributed triangular solve with an upper-triangular matrix and a single
// vector, overwriting x in place.
//
// x may be a column vector (x.Width() == 1) or a row vector; the two branches
// below mirror each other. x is processed in blocks from the bottom/right end
// towards the top/left. For each block x1:
//   1. updates accumulated from earlier blocks (kept in z) are added to x1;
//   2. x1 is solved against the diagonal block U11 with a local Trsv;
//   3. -U01 x1 is accumulated into the part of z that belongs to the blocks
//      still to be processed.
// NOTE(review): the net effect is presumably x := inv(U) x -- confirm.
//
// diag : forwarded to Trsv (unit vs. non-unit diagonal of U)
// U    : square distributed triangular matrix, read-only
// x    : distributed vector whose length equals U.Width(), updated in place
//
// When RELEASE is not defined, std::logic_error is thrown if U and x live on
// different grids, U is not square, x is not a vector, or the sizes differ.
inline void
TrsvUN( UnitOrNonUnit diag, const DistMatrix<F>& U, DistMatrix<F>& x )
{
#ifndef RELEASE
    PushCallStack("internal::TrsvUN");
    if( U.Grid() != x.Grid() )
        throw std::logic_error("{U,x} must be distributed over the same grid");
    if( U.Height() != U.Width() )
        throw std::logic_error("U must be square");
    if( x.Width() != 1 && x.Height() != 1 )
        throw std::logic_error("x must be a vector");
    const int xLength = ( x.Width() == 1 ? x.Height() : x.Width() );
    if( U.Width() != xLength )
        throw std::logic_error("Nonconformal TrsvUN");
#endif
    const Grid& g = U.Grid();

    if( x.Width() == 1 )
    {
        // Column-vector case
        // Matrix views 
        DistMatrix<F> U01(g),
                      U11(g);
        DistMatrix<F> 
            xT(g),  x0(g),
            xB(g),  x1(g),
                    x2(g);

        // Temporary distributions
        DistMatrix<F,STAR,STAR> U11_STAR_STAR(g);
        DistMatrix<F,STAR,STAR> x1_STAR_STAR(g);
        DistMatrix<F,MR,  STAR> x1_MR_STAR(g);
        DistMatrix<F,MC,  STAR> z_MC_STAR(g);

        // Views of z[MC,* ], which will store updates to x
        DistMatrix<F,MC,STAR> z0_MC_STAR(g),
                              z1_MC_STAR(g);

        z_MC_STAR.AlignWith( U );
        Zeros( x.Height(), 1, z_MC_STAR );

        // Start the algorithm
        PartitionUp
        ( x, xT,
             xB, 0 );
        while( xT.Height() > 0 )
        {
            RepartitionUp
            ( xT,  x0,
                   x1,
             /**/ /**/
              xB,  x2 );

            // n0 = length of the not-yet-processed part, n1 = current block;
            // view the matching blocks of U and of the accumulator z
            const int n0 = x0.Height();
            const int n1 = x1.Height();
            LockedView( U01, U, 0,  n0, n0, n1 );
            LockedView( U11, U, n0, n0, n1, n1 );
            View( z0_MC_STAR, z_MC_STAR, 0,  0, n0, 1 );
            View( z1_MC_STAR, z_MC_STAR, n0, 0, n1, 1 );

            x1_MR_STAR.AlignWith( U01 );
            //----------------------------------------------------------------//
            // Fold in accumulated updates; skipped when x2 is empty, i.e. when
            // no block has been processed yet
            if( x2.Height() != 0 )
                x1.SumScatterUpdate( F(1), z1_MC_STAR );

            // Solve against the diagonal block on [*,*] copies
            x1_STAR_STAR = x1;
            U11_STAR_STAR = U11;
            Trsv
            ( UPPER, NORMAL, diag,
              U11_STAR_STAR.LockedLocalMatrix(),
              x1_STAR_STAR.LocalMatrix() );
            x1 = x1_STAR_STAR;

            // z0 := z0 - U01 x1, on local data only
            x1_MR_STAR = x1_STAR_STAR;
            Gemv
            ( NORMAL, F(-1), 
              U01.LockedLocalMatrix(), 
              x1_MR_STAR.LockedLocalMatrix(),
              F(1), z0_MC_STAR.LocalMatrix() );
            //----------------------------------------------------------------//
            x1_MR_STAR.FreeAlignments();

            SlidePartitionUp
            ( xT,  x0,
             /**/ /**/
                   x1,
              xB,  x2 );
        }
    }
    else
    {
        // Row-vector case
        // Matrix views 
        DistMatrix<F> U01(g),
                      U11(g);
        DistMatrix<F> 
            xL(g), xR(g),
            x0(g), x1(g), x2(g);

        // Temporary distributions
        DistMatrix<F,STAR,STAR> U11_STAR_STAR(g);
        DistMatrix<F,STAR,STAR> x1_STAR_STAR(g);
        DistMatrix<F,STAR,MR  > x1_STAR_MR(g);
        DistMatrix<F,MC,  MR  > z1(g);
        DistMatrix<F,MR,  MC  > z1_MR_MC(g);
        DistMatrix<F,STAR,MC  > z_STAR_MC(g);

        // Views of z[* ,MC]
        DistMatrix<F,STAR,MC>  z0_STAR_MC(g),
                               z1_STAR_MC(g);

        z_STAR_MC.AlignWith( U );
        Zeros( 1, x.Width(), z_STAR_MC );

        // Start the algorithm
        PartitionLeft( x,  xL, xR, 0 );
        while( xL.Width() > 0 )
        {
            RepartitionLeft
            ( xL,     /**/ xR,
              x0, x1, /**/ x2 );

            // n0 = length of the not-yet-processed part, n1 = current block;
            // view the matching blocks of U and of the accumulator z
            const int n0 = x0.Width();
            const int n1 = x1.Width();
            LockedView( U01, U, 0,  n0, n0, n1 );
            LockedView( U11, U, n0, n0, n1, n1 );
            View( z0_STAR_MC, z_STAR_MC, 0, 0,  1, n0 );
            View( z1_STAR_MC, z_STAR_MC, 0, n0, 1, n1 );

            x1_STAR_MR.AlignWith( U01 );
            z1.AlignWith( x1 );
            //----------------------------------------------------------------//
            // Fold in accumulated updates; skipped when x2 is empty, i.e. when
            // no block has been processed yet
            if( x2.Width() != 0 )
            {
                z1_MR_MC.SumScatterFrom( z1_STAR_MC );
                z1 = z1_MR_MC;
                Axpy( F(1), z1, x1 );
            }

            // Solve against the diagonal block on [*,*] copies
            x1_STAR_STAR = x1;
            U11_STAR_STAR = U11;
            Trsv
            ( UPPER, NORMAL, diag,
              U11_STAR_STAR.LockedLocalMatrix(),
              x1_STAR_STAR.LocalMatrix() );
            x1 = x1_STAR_STAR;

            // z0 := z0 - U01 x1, on local data only
            x1_STAR_MR = x1_STAR_STAR;
            Gemv
            ( NORMAL, F(-1), 
              U01.LockedLocalMatrix(), 
              x1_STAR_MR.LockedLocalMatrix(),
              F(1), z0_STAR_MC.LocalMatrix() );
            //----------------------------------------------------------------//
            x1_STAR_MR.FreeAlignments();
            z1.FreeAlignments(); 

            SlidePartitionLeft
            ( xL, /**/ xR,
              x0, /**/ x1, x2 );
        }
    }
#ifndef RELEASE
    PopCallStack();
#endif
}
コード例 #26
0
ファイル: TT.hpp プロジェクト: certik/Elemental
inline void
GemmTTB
( Orientation orientationOfA, 
  Orientation orientationOfB,
  T alpha, const DistMatrix<T>& A,
           const DistMatrix<T>& B,
  T beta,        DistMatrix<T>& C )
{
    // Forms C := alpha op(A) op(B) + beta C with both operands
    // (conjugate-)transposed. The sweep walks A one column panel at a time,
    // which corresponds to one row panel of C per iteration; B is used whole
    // on every pass and is the alignment target for the loop-invariant
    // temporaries.
#ifndef RELEASE
    PushCallStack("internal::GemmTTB");

    const bool gridsDiffer =
        ( A.Grid() != B.Grid() || B.Grid() != C.Grid() );
    if( gridsDiffer )
        throw std::logic_error
        ("{A,B,C} must be distributed over the same grid");

    const bool hasNormalOperand =
        ( orientationOfA == NORMAL || orientationOfB == NORMAL );
    if( hasNormalOperand )
        throw std::logic_error
        ("GemmTTB expects A and B to be (Conjugate)Transposed");

    const bool nonconformal =
        ( A.Width()  != C.Height() ||
          B.Height() != C.Width()  ||
          A.Height() != B.Width() );
    if( nonconformal )
    {
        std::ostringstream report;
        report << "Nonconformal GemmTTB: \n"
               << "  A ~ " << A.Height() << " x " << A.Width() << "\n"
               << "  B ~ " << B.Height() << " x " << B.Width() << "\n"
               << "  C ~ " << C.Height() << " x " << C.Width() << "\n";
        throw std::logic_error( report.str().c_str() );
    }
#endif
    const Grid& grid = A.Grid();

    // Views: A is split left-to-right, C top-to-bottom
    DistMatrix<T> AL(grid), AR(grid), A0(grid), A1(grid), A2(grid);
    DistMatrix<T> CT(grid), C0(grid), CB(grid), C1(grid), C2(grid);

    // Redistributed copies of the current panel and of the partial product
    DistMatrix<T,VR,  STAR> A1Redist_VR_STAR(grid);
    DistMatrix<T,STAR,MR  > A1Op_STAR_MR(grid);
    DistMatrix<T,STAR,MC  > partial_STAR_MC(grid);
    DistMatrix<T,MR,  MC  > reduced_MR_MC(grid);
    DistMatrix<T> update(grid);

    // These three keep their alignment with B for the whole sweep
    A1Redist_VR_STAR.AlignWith( B );
    A1Op_STAR_MR.AlignWith( B );
    partial_STAR_MC.AlignWith( B );

    // Apply beta up front so that each pass only has to accumulate
    Scale( beta, C );

    LockedPartitionRight( A, AL, AR, 0 );
    PartitionDown( C, CT, CB, 0 );
    while( AR.Width() > 0 )
    {
        // Expose the next column panel A1 ...
        LockedRepartitionRight( AL, AR, A0, A1, A2 );
        // ... and the matching row panel C1
        RepartitionDown( CT, C0, C1, CB, C2 );

        update.AlignWith( C1 );
        Zeros( C1.Height(), C1.Width(), partial_STAR_MC );
        //--------------------------------------------------------------------//
        // Bring op(A1) into a [*,MR] distribution via [VR,*]
        A1Redist_VR_STAR = A1;
        if( orientationOfA == ADJOINT )
        {
            A1Op_STAR_MR.AdjointFrom( A1Redist_VR_STAR );
        }
        else
        {
            A1Op_STAR_MR.TransposeFrom( A1Redist_VR_STAR );
        }

        // Local contribution:
        //   partial[*,MC] := alpha (A1^[T/H])[*,MR] (B^[T/H])[MR,MC]
        LocalGemm
        ( NORMAL, orientationOfB,
          alpha, A1Op_STAR_MR, B, T(0), partial_STAR_MC );

        // Sum the contributions over grid rows, move the result into the
        // distribution of C1, and accumulate it there
        reduced_MR_MC.SumScatterFrom( partial_STAR_MC );
        update = reduced_MR_MC;
        Axpy( T(1), update, C1 );
        //--------------------------------------------------------------------//
        // C1 changes on the next pass, so drop the per-iteration alignment
        update.FreeAlignments();

        // Retire the panels that were just processed
        SlideLockedPartitionRight( AL, AR, A0, A1, A2 );
        SlidePartitionDown( CT, C0, C1, CB, C2 );
    }
#ifndef RELEASE
    PopCallStack();
#endif
}
コード例 #27
0
ファイル: UVar1.hpp プロジェクト: khalid-hasanov/Elemental
inline void
TwoSidedTrsmUVar1( UnitOrNonUnit diag, Matrix<F>& A, const Matrix<F>& U )
{
    // Blocked sweep down the diagonals of A and U. Each pass updates the
    // A01 and A11 blocks in place using U00, U01 and U11; the diagonal block
    // is handed to the unblocked TwoSidedTrsmUUnb kernel.
    // NOTE(review): presumably the overall effect is A := inv(U)' A inv(U)
    // on the upper triangle -- confirm against the library documentation.
#ifndef RELEASE
    CallStackEntry entry("internal::TwoSidedTrsmUVar1");
    if( A.Height() != A.Width() )
    {
        LogicError("A must be square");
    }
    if( U.Height() != U.Width() )
    {
        LogicError("Triangular matrices must be square");
    }
    if( A.Height() != U.Height() )
    {
        LogicError("A and U must be the same size");
    }
#endif
    // 2x2 and 3x3 views of A
    Matrix<F> ATL, ATR, A00, A01, A02,
              ABL, ABR, A10, A11, A12,
                        A20, A21, A22;
    // 2x2 and 3x3 views of U (read-only)
    Matrix<F> UTL, UTR, U00, U01, U02,
              UBL, UBR, U10, U11, U12,
                        U20, U21, U22;

    // Workspace holding the product A00 U01
    Matrix<F> A00U01;

    // The product is subtracted in two equal halves, before and after the
    // A11 update
    const F negHalf = F(-1)/F(2);

    PartitionDownDiagonal( A, ATL, ATR, ABL, ABR, 0 );
    LockedPartitionDownDiagonal( U, UTL, UTR, UBL, UBR, 0 );
    while( ATL.Height() < A.Height() )
    {
        RepartitionDownDiagonal
        ( ATL, ATR,  A00, A01, A02,
                     A10, A11, A12,
          ABL, ABR,  A20, A21, A22 );

        LockedRepartitionDownDiagonal
        ( UTL, UTR,  U00, U01, U02,
                     U10, U11, U12,
          UBL, UBR,  U20, U21, U22 );

        //--------------------------------------------------------------------//
        // A00U01 := A00 U01, reading A00 through its upper triangle
        Zeros( A00U01, A01.Height(), A01.Width() );
        Hemm( LEFT, UPPER, F(1), A00, U01, F(0), A00U01 );

        // A01 := inv(U00)' A01
        Trsm( LEFT, UPPER, ADJOINT, diag, F(1), U00, A01 );

        // First half of the correction: A01 := A01 - 1/2 A00U01
        Axpy( negHalf, A00U01, A01 );

        // A11 := A11 - (U01' A01 + A01' U01)
        Her2k( UPPER, ADJOINT, F(-1), U01, A01, F(1), A11 );

        // A11 := inv(U11)' A11 inv(U11)
        TwoSidedTrsmUUnb( diag, A11, U11 );

        // Second half of the correction: A01 := A01 - 1/2 A00U01
        Axpy( negHalf, A00U01, A01 );

        // A01 := A01 inv(U11)
        Trsm( RIGHT, UPPER, NORMAL, diag, F(1), U11, A01 );
        //--------------------------------------------------------------------//

        SlidePartitionDownDiagonal
        ( ATL, ATR,  A00, A01, A02,
                     A10, A11, A12,
          ABL, ABR,  A20, A21, A22 );

        SlideLockedPartitionDownDiagonal
        ( UTL, UTR,  U00, U01, U02,
                     U10, U11, U12,
          UBL, UBR,  U20, U21, U22 );
    }
}
コード例 #28
0
ファイル: TT.hpp プロジェクト: certik/Elemental
inline void
GemmTTA
( Orientation orientationOfA, 
  Orientation orientationOfB,
  T alpha, const DistMatrix<T>& A,
           const DistMatrix<T>& B,
  T beta,        DistMatrix<T>& C )
{
    // Forms C := alpha op(A) op(B) + beta C with both operands
    // (conjugate-)transposed. The sweep walks B one row panel at a time,
    // which corresponds to one column panel of C per iteration; A is used
    // whole on every pass and is the alignment target for the loop-invariant
    // temporaries.
#ifndef RELEASE
    PushCallStack("internal::GemmTTA");

    const bool gridsDiffer =
        ( A.Grid() != B.Grid() || B.Grid() != C.Grid() );
    if( gridsDiffer )
        throw std::logic_error
        ("{A,B,C} must be distributed over the same grid");

    const bool hasNormalOperand =
        ( orientationOfA == NORMAL || orientationOfB == NORMAL );
    if( hasNormalOperand )
        throw std::logic_error
        ("GemmTTA expects A and B to be (Conjugate)Transposed");

    const bool nonconformal =
        ( A.Width()  != C.Height() ||
          B.Height() != C.Width()  ||
          A.Height() != B.Width() );
    if( nonconformal )
    {
        std::ostringstream report;
        report << "Nonconformal GemmTTA: \n"
               << "  A ~ " << A.Height() << " x " << A.Width() << "\n"
               << "  B ~ " << B.Height() << " x " << B.Width() << "\n"
               << "  C ~ " << C.Height() << " x " << C.Width() << "\n";
        throw std::logic_error( report.str().c_str() );
    }
#endif
    const Grid& grid = A.Grid();

    // Views: B is split top-to-bottom, C left-to-right
    DistMatrix<T> BT(grid), B0(grid), BB(grid), B1(grid), B2(grid);
    DistMatrix<T> CL(grid), CR(grid), C0(grid), C1(grid), C2(grid);

    // Redistributed copies of the current panel and of the partial product
    DistMatrix<T,STAR,MC  > B1Redist_STAR_MC(grid);
    DistMatrix<T,MR,  STAR> partial_MR_STAR(grid);
    DistMatrix<T,MR,  MC  > reduced_MR_MC(grid);
    DistMatrix<T> update(grid);

    // These two keep their alignment with A for the whole sweep
    B1Redist_STAR_MC.AlignWith( A );
    partial_MR_STAR.AlignWith( A );

    // Apply beta up front so that each pass only has to accumulate
    Scale( beta, C );

    LockedPartitionDown( B, BT, BB, 0 );
    PartitionRight( C, CL, CR, 0 );
    while( BB.Height() > 0 )
    {
        // Expose the next row panel B1 ...
        LockedRepartitionDown( BT, B0, B1, BB, B2 );
        // ... and the matching column panel C1
        RepartitionRight( CL, CR, C0, C1, C2 );

        update.AlignWith( C1 );
        Zeros( C1.Height(), C1.Width(), partial_MR_STAR );
        //--------------------------------------------------------------------//
        // B1[*,MC] <- B1[MC,MR]
        B1Redist_STAR_MC = B1;

        // Local contribution:
        //   partial[MR,*] := alpha (A^T)[MR,MC] (B1^T)[MC,*]
        LocalGemm
        ( orientationOfA, orientationOfB,
          alpha, A, B1Redist_STAR_MC, T(0), partial_MR_STAR );

        // Sum the contributions over grid columns, move the result into the
        // distribution of C1, and accumulate it there
        reduced_MR_MC.SumScatterFrom( partial_MR_STAR );
        update = reduced_MR_MC;
        Axpy( T(1), update, C1 );
        //--------------------------------------------------------------------//
        // C1 changes on the next pass, so drop the per-iteration alignment
        update.FreeAlignments();

        // Retire the panels that were just processed
        SlideLockedPartitionDown( BT, B0, B1, BB, B2 );
        SlidePartitionRight( CL, CR, C0, C1, C2 );
    }
#ifndef RELEASE
    PopCallStack();
#endif
}
コード例 #29
0
ファイル: LLT.hpp プロジェクト: certik/Elemental
inline void
TrmmLLTCOld
( Orientation orientation, 
  UnitOrNonUnit diag,
  T alpha, 
  const DistMatrix<T>& L,
        DistMatrix<T>& X )
{
    // Left-sided multiply of X by the (conjugate-)transpose of the lower
    // triangular L, scaled by alpha. X is processed top to bottom in row
    // panels: each X1 first receives the diagonal-block product and then
    // the contribution of the rows of X that lie below it.
#ifndef RELEASE
    PushCallStack("internal::TrmmLLTCOld");

    if( L.Grid() != X.Grid() )
        throw std::logic_error
        ("L and X must be distributed over the same grid");

    if( orientation == NORMAL )
        throw std::logic_error("TrmmLLT expects a (Conjugate)Transpose option");

    const bool nonconformal =
        ( L.Height() != L.Width() || L.Height() != X.Height() );
    if( nonconformal )
    {
        std::ostringstream report;
        report << "Nonconformal TrmmLLTC: \n"
               << "  L ~ " << L.Height() << " x " << L.Width() << "\n"
               << "  X ~ " << X.Height() << " x " << X.Width() << "\n";
        throw std::logic_error( report.str().c_str() );
    }
#endif
    const Grid& grid = L.Grid();

    // Views: L is split along its diagonal, X top-to-bottom
    DistMatrix<T>
        LTL(grid), LTR(grid),  L00(grid), L01(grid), L02(grid),
        LBL(grid), LBR(grid),  L10(grid), L11(grid), L12(grid),
                               L20(grid), L21(grid), L22(grid);
    DistMatrix<T> XT(grid), X0(grid), XB(grid), X1(grid), X2(grid);

    // Redistributed copies of the active blocks and of the partial product
    DistMatrix<T,STAR,STAR> L11Copy_STAR_STAR(grid);
    DistMatrix<T,MC,  STAR> L21Redist_MC_STAR(grid);
    DistMatrix<T,STAR,VR  > X1Redist_STAR_VR(grid);
    DistMatrix<T,MR,  STAR> partialOp_MR_STAR(grid);
    DistMatrix<T,MR,  MC  > reducedOp_MR_MC(grid);
    DistMatrix<T,MC,  MR  > update(grid);

    // Fold alpha into X once, before the sweep
    Scale( alpha, X );

    LockedPartitionDownDiagonal( L, LTL, LTR, LBL, LBR, 0 );
    PartitionDown( X, XT, XB, 0 );
    while( XB.Height() > 0 )
    {
        LockedRepartitionDownDiagonal
        ( LTL, LTR,  L00, L01, L02,
                     L10, L11, L12,
          LBL, LBR,  L20, L21, L22 );

        RepartitionDown( XT, X0, X1, XB, X2 );

        // Per-iteration alignments; all of them are released again below
        L21Redist_MC_STAR.AlignWith( X2 );
        partialOp_MR_STAR.AlignWith( X1 );
        reducedOp_MR_MC.AlignWith( X1 );
        update.AlignWith( X1 );
        // The partial product is stored transposed relative to X1
        Zeros( X1.Width(), X1.Height(), partialOp_MR_STAR );
        Zeros( X1.Height(), X1.Width(), update );
        //--------------------------------------------------------------------//
        // Diagonal block: X1 := op(L11) X1, computed on redistributed copies
        X1Redist_STAR_VR = X1;
        L11Copy_STAR_STAR = L11;
        LocalTrmm
        ( LEFT, LOWER, orientation, diag,
          T(1), L11Copy_STAR_STAR, X1Redist_STAR_VR );
        X1 = X1Redist_STAR_VR;

        // Off-diagonal block: form op(X2) L21 locally, then sum-scatter it
        L21Redist_MC_STAR = L21;
        LocalGemm
        ( orientation, NORMAL,
          T(1), X2, L21Redist_MC_STAR, T(0), partialOp_MR_STAR );
        reducedOp_MR_MC.SumScatterFrom( partialOp_MR_STAR );

        // Undo the transposition on the local data so that the result can
        // be added to X1
        if( orientation == TRANSPOSE )
        {
            Transpose( reducedOp_MR_MC.LocalMatrix(), update.LocalMatrix() );
        }
        else
        {
            Adjoint( reducedOp_MR_MC.LocalMatrix(), update.LocalMatrix() );
        }
        Axpy( T(1), update, X1 );
        //--------------------------------------------------------------------//
        update.FreeAlignments();
        reducedOp_MR_MC.FreeAlignments();
        partialOp_MR_STAR.FreeAlignments();
        L21Redist_MC_STAR.FreeAlignments();

        SlideLockedPartitionDownDiagonal
        ( LTL, LTR,  L00, L01, L02,
                     L10, L11, L12,
          LBL, LBR,  L20, L21, L22 );

        SlidePartitionDown( XT, X0, X1, XB, X2 );
    }
#ifndef RELEASE
    PopCallStack();
#endif
}
コード例 #30
0
ファイル: nmf_solver_rank2.hpp プロジェクト: beckgom/smallk
    // ------------------------------------------------------------------------
    //
    //         Per-iteration update: refresh H, then W, then both gradients
    //
    //  Returns false as soon as either linear solve reports failure; in that
    //  case the remaining steps are skipped.  Returns true otherwise.
    //
    //  WtW and WtA are consumed as they stand on entry and are recomputed
    //  from the new W at the very end of the call.
    //
    // ------------------------------------------------------------------------
    bool operator() (const Matrix<T>& A,
                     DenseMatrix<T>& W,
                     DenseMatrix<T>& H,
                     DenseMatrix<T>& gradW,
                     DenseMatrix<T>& gradH)
    {
        // ---- H subproblem:  min_{H>=0} |WH - A|^2 --------------------------
        //
        //      W is mx2, H is 2xn, A is mxn.
        //
        // Stage 1 ignores the sign constraint.  With the objective
        // phi(H) = (WH - A)'(WH - A), the minimizer satisfies
        //
        //      W'W * H = W'A
        //
        if (!SystemSolveH(H, WtW, WtA))
        {
            return false;
        }

        // Stage 2 repairs the unconstrained answer, which may contain
        // negative entries, by searching for the optimal active set so that
        // the objective is hurt as little as possible.
        OptimalActiveSetH(H, WtW, WtA);

        // ---- W subproblem:  min_{W>=0} |H'W' - A'|^2 -----------------------

        // kxk product HHt = H * H'
        Gemm(NORMAL, TRANSPOSE, T(1), H, H, T(0), HHt);

        // mxk product AHt = A * H'
        Gemm(NORMAL, TRANSPOSE, T(1), A, H, T(0), AHt);

        // Stage 1 again ignores the sign constraint.  With the objective
        // phi(W) = (H'W' - A')'(H'W' - A'), the minimizer W' satisfies
        //
        //      HH' * W' = H'A,  i.e. after transposing  W * (HH') = AH'
        //
        if (!SystemSolveW(W, HHt, AHt))
        {
            return false;
        }

        // Stage 2: remove negative entries of W via the optimal active set.
        OptimalActiveSetW(W, HHt, AHt);

        NormalizeAndScale(W, H, ScaleFactors);

        // ---- Gradients -----------------------------------------------------
        //
        // W and H are final from here on.  What remains is
        //      gradW = W*HHt - AHt   and   gradH = WtW*H - WtA.
        //
        // HHt is only 2x2, so it is brought up to date directly from the
        // scale factors produced by NormalizeAndScale rather than by another
        // Gemm.
        const T scale0 = ScaleFactors.Get(0, 0);
        const T scale1 = ScaleFactors.Get(1, 0);

        // Read all old entries before overwriting any of them.
        const T old00 = HHt.Get(0,0);
        const T old01 = HHt.Get(0,1);
        const T old11 = HHt.Get(1,1);

        // HH' is symmetric, so one product serves both off-diagonal slots.
        const T offDiagonal = old01*scale0*scale1;

        HHt.Set(0, 0, old00*scale0*scale0);
        HHt.Set(0, 1, offDiagonal);
        HHt.Set(1, 0, offDiagonal);
        HHt.Set(1, 1, old11*scale1*scale1);

        // AHt = A * H' follows the rescaled H by scaling each of its columns.
        DiagonalScale(RIGHT, NORMAL, ScaleFactors, AHt);

        // gradW = W*HHt - AHt, now that HHt and AHt are both current.
        Gemm(NORMAL, NORMAL, T(1), W, HHt, T(0), gradW);
        Axpy( T(-1), AHt, gradW);

        // gradH = WtW*H - WtA, after recomputing WtW and WtA from the new W.
        Gemm(TRANSPOSE, NORMAL, T(1), W, W, T(0), WtW);
        Gemm(TRANSPOSE, NORMAL, T(1), W, A, T(0), WtA);
        Gemm(NORMAL, NORMAL, T(1), WtW, H, T(0), gradH);
        Axpy( T(-1), WtA, gradH);

        return true;
    }