Example #1
0
// Distributed Riccati driver built on the matrix sign function.
//
// Steps performed, in order:
//   1. Access WPre through a [MC,MR] read proxy (W).
//   2. Apply Sign to W in place (W then holds sgn(W), per the notes below).
//   3. Subtract the identity from W via ShiftDiagonal.
//   4. Split W column-wise into (ML, MR) with PartitionRight and n.
//   5. Negate MR and solve ML X = -MR with ls::Overwrite.
//
// WPre : input matrix; its contents are consumed through the read proxy
// X    : receives the solution computed by ls::Overwrite
// ctrl : control structure forwarded unchanged to Sign
//
// NOTE(review): n is W.Height()/2, so W presumably must have an even
// height -- no check is performed here; confirm against callers.
void Riccati
( ElementalMatrix<F>& WPre,
  ElementalMatrix<F>& X, 
  SignCtrl<Base<F>> ctrl )
{
    DEBUG_CSE

    DistMatrixReadProxy<F,F,MC,MR> WProx( WPre );
    auto& W = WProx.Get();

    const Grid& g = W.Grid();
    Sign( W, ctrl );
    const Int n = W.Height()/2;

    // (ML, MR) = sgn(W) - I
    ShiftDiagonal( W, F(-1) );

    // Solve for X in ML X = -MR
    // ML and MR are partitions of W, and ls::Overwrite is handed both.
    DistMatrix<F> ML(g), MR(g);
    PartitionRight( W, ML, MR, n );
    MR *= -1;
    ls::Overwrite( NORMAL, ML, MR, X );
}
Example #2
0
// Sequential Riccati driver built on the matrix sign function.
//
// W    : on entry the input matrix; modified in place (Sign, then a
//        diagonal shift by -1, then handed to ls::Overwrite as ML/MR)
// X    : receives the solution of ML X = -MR
// ctrl : control structure forwarded unchanged to Sign
//
// NOTE(review): n is W.Height()/2, so W presumably must have an even
// height -- no check is performed here; confirm against callers.
void Ricatti( Matrix<F>& W, Matrix<F>& X, SignCtrl<Base<F>> ctrl )
{
    DEBUG_ONLY(CallStackEntry cse("Ricatti"))
    Sign( W, ctrl );
    const Int n = W.Height()/2;

    // (ML, MR) = sgn(W) - I
    ShiftDiagonal( W, F(-1) );

    // Solve for X in ML X = -MR
    // ML and MR are partitions of W; MR is negated before the solve.
    Matrix<F> ML, MR;
    PartitionRight( W, ML, MR, n );
    Scale( F(-1), MR );
    ls::Overwrite( NORMAL, ML, MR, X );
}
Example #3
0
// Sequential Riccati driver built on the matrix sign function.
//
// W    : on entry the input matrix; modified in place (Sign, then a
//        diagonal shift by -1, then handed to ls::Overwrite as ML/MR)
// X    : receives the solution of ML X = -MR
// ctrl : control structure forwarded unchanged to Sign
//
// NOTE(review): n is W.Height()/2, so W presumably must have an even
// height -- no check is performed here; confirm against callers.
void Riccati
( Matrix<F>& W, Matrix<F>& X, SignCtrl<Base<F>> ctrl )
{
    DEBUG_CSE
    Sign( W, ctrl );
    const Int n = W.Height()/2;

    // (ML, MR) = sgn(W) - I
    ShiftDiagonal( W, F(-1) );

    // Solve for X in ML X = -MR
    // ML and MR are partitions of W; MR is negated before the solve.
    Matrix<F> ML, MR;
    PartitionRight( W, ML, MR, n );
    MR *= -1;
    ls::Overwrite( NORMAL, ML, MR, X );
}
Example #4
0
// Ridge-regularized least squares for sequential matrices.
//
// With op(A) = A when orientation is NORMAL and A^H otherwise, this computes
//
//     X = ( op(A)^H op(A) + gamma^2 I )^{-1} op(A)^H B
//
// using one of three strategies selected by `alg`:
//   RIDGE_CHOLESKY : form op(A)^H op(A) + gamma^2 I explicitly and factor it
//   RIDGE_QR       : triangularize the stacked matrix [op(A); gamma I]
//   anything else  : filter the singular values of op(A) through
//                    sigma / (sigma^2 + gamma^2)
//
// orientation : NORMAL, or ADJOINT/TRANSPOSE (TRANSPOSE is rejected for
//               complex Field)
// A, B        : system matrix and right-hand sides (not modified)
// gamma       : regularization parameter; enters as gamma^2 on the diagonal
// X           : overwritten with the solution
// alg         : algorithm selector described above
//
// Only the case where op(A) has at least as many rows as columns is
// implemented; otherwise LogicError is invoked.
void Ridge
( Orientation orientation,
  const Matrix<Field>& A,
  const Matrix<Field>& B,
        Base<Field> gamma,
        Matrix<Field>& X,
  RidgeAlg alg )
{
    EL_DEBUG_CSE

    // m x n are the dimensions of op(A)
    const bool normal = ( orientation==NORMAL );
    const Int m = ( normal ? A.Height() : A.Width()  );
    const Int n = ( normal ? A.Width()  : A.Height() );
    if( orientation == TRANSPOSE && IsComplex<Field>::value )
        LogicError("Transpose version of complex Ridge not yet supported");

    if( m >= n )
    {
        Matrix<Field> Z;
        if( alg == RIDGE_CHOLESKY )
        {
            // Z := op(A)^H op(A) + gamma^2 I (lower triangle), then Z = L L^H
            if( normal )
                Herk( LOWER, ADJOINT, Base<Field>(1), A, Z );
            else
                Herk( LOWER, NORMAL, Base<Field>(1), A, Z );
            ShiftDiagonal( Z, Field(gamma*gamma) );
            Cholesky( LOWER, Z );

            // X := op(A)^H B, then X := (L L^H)^{-1} X
            if( normal )
                Gemm( ADJOINT, NORMAL, Field(1), A, B, X );
            else
                Gemm( NORMAL, NORMAL, Field(1), A, B, X );
            cholesky::SolveAfter( LOWER, NORMAL, Z, X );
        }
        else if( alg == RIDGE_QR )
        {
            // Z := [op(A); gamma I], an (m+n) x n matrix
            Zeros( Z, m+n, n );
            auto ZT = Z( IR(0,m),   IR(0,n) );
            auto ZB = Z( IR(m,m+n), IR(0,n) );
            if( normal )
                ZT = A;
            else
                Adjoint( A, ZT );
            FillDiagonal( ZB, Field(gamma) );
            // NOTE: This QR factorization could exploit the upper-triangular
            //       structure of the diagonal matrix ZB
            qr::ExplicitTriang( Z );

            // X := op(A)^H B
            if( normal )
                Gemm( ADJOINT, NORMAL, Field(1), A, B, X );
            else
                Gemm( NORMAL, NORMAL, Field(1), A, B, X );

            // Z now holds the upper-triangular factor R, which satisfies
            // R^H R = op(A)^H op(A) + gamma^2 I. That is the UPPER (U^H U)
            // form of a Cholesky factorization; requesting LOWER would make
            // the triangular solves read only the diagonal of R.
            cholesky::SolveAfter( UPPER, NORMAL, Z, X );
        }
        else
        {
            // SVD path: op(A) = U diag(s) V^H
            Matrix<Field> U, V;
            Matrix<Base<Field>> s;
            if( normal )
            {
                // A is const, so the SVD must not overwrite its input
                SVDCtrl<Base<Field>> ctrl;
                ctrl.overwrite = false;
                SVD( A, U, s, V, ctrl );
            }
            else
            {
                // AAdj is a scratch copy, so it may be overwritten
                Matrix<Field> AAdj;
                Adjoint( A, AAdj );

                SVDCtrl<Base<Field>> ctrl;
                ctrl.overwrite = true;
                SVD( AAdj, U, s, V, ctrl );
            }

            // s_i := s_i / (s_i^2 + gamma^2)
            auto sigmaMap =
              [=]( const Base<Field>& sigma )
              { return sigma / (sigma*sigma + gamma*gamma); };
            EntrywiseMap( s, MakeFunction(sigmaMap) );

            // X := V diag(s) U^H B; U is reused as a temporary for the
            // intermediate product before the final multiplication by V
            Gemm( ADJOINT, NORMAL, Field(1), U, B, X );
            DiagonalScale( LEFT, NORMAL, s, X );
            U = X;
            Gemm( NORMAL, NORMAL, Field(1), V, U, X );
        }
    }
    else
    {
        LogicError("This case not yet supported");
    }
}