// Distributed Riccati driver built on the matrix sign function.
//
// Steps, in order:
//   1. obtain an [MC,MR] view/copy of WPre through a read proxy,
//   2. replace it with sgn(W) via Sign(),
//   3. subtract the identity from it,
//   4. split the result into left/right column blocks (ML, MR) at column
//      n = Height/2 and solve ML X = -MR with ls::Overwrite.
//
// Parameters:
//   WPre - input matrix; assumed to be 2n x 2n (TODO confirm with callers,
//          the height is halved with integer division and not validated).
//   X    - receives the solution of ML X = -MR.
//   ctrl - forwarded unchanged to Sign().
//
// NOTE(review): because a *read* proxy is used, whether WPre itself is
// modified depends on whether the proxy aliases it -- confirm before relying
// on WPre after the call.
void Riccati
( ElementalMatrix<F>& WPre,
  ElementalMatrix<F>& X,
  SignCtrl<Base<F>> ctrl )
{
    DEBUG_CSE

    DistMatrixReadProxy<F,F,MC,MR> WProx( WPre );
    auto& W = WProx.Get();
    const Grid& g = W.Grid();

    Sign( W, ctrl );
    const Int n = W.Height()/2;

    // (ML, MR) = sgn(W) - I
    ShiftDiagonal( W, F(-1) );

    // Solve for X in ML X = -MR
    // ML and MR are column blocks of W, so negating MR negates those
    // entries of W as well.
    DistMatrix<F> ML(g), MR(g);
    PartitionRight( W, ML, MR, n );
    MR *= -1;
    ls::Overwrite( NORMAL, ML, MR, X );
}
// Sequential Riccati driver built on the matrix sign function.
//
// W is used as workspace: it is replaced by sgn(W) via Sign(), the identity
// is subtracted from it, and its right column block is negated before the
// least-squares solve.
//
// Parameters:
//   W    - input matrix, modified in place; assumed to be 2n x 2n (TODO
//          confirm with callers, the height is halved with integer division
//          and not validated).
//   X    - receives the solution of ML X = -MR.
//   ctrl - forwarded unchanged to Sign().
void Ricatti( Matrix<F>& W, Matrix<F>& X, SignCtrl<Base<F>> ctrl )
{
    DEBUG_ONLY(CallStackEntry cse("Ricatti"))

    Sign( W, ctrl );
    const Int n = W.Height()/2;

    // (ML, MR) = sgn(W) - I
    ShiftDiagonal( W, F(-1) );

    // Solve for X in ML X = -MR
    // ML and MR are the left/right column blocks of W split at column n.
    Matrix<F> ML, MR;
    PartitionRight( W, ML, MR, n );
    Scale( F(-1), MR );
    ls::Overwrite( NORMAL, ML, MR, X );
}
// Sequential Riccati driver built on the matrix sign function.
//
// W is used as workspace: it is replaced by sgn(W) via Sign(), the identity
// is subtracted from it, and its right column block is negated before the
// least-squares solve.
//
// Parameters:
//   W    - input matrix, modified in place; assumed to be 2n x 2n (TODO
//          confirm with callers, the height is halved with integer division
//          and not validated).
//   X    - receives the solution of ML X = -MR.
//   ctrl - forwarded unchanged to Sign().
void Riccati
( Matrix<F>& W,
  Matrix<F>& X,
  SignCtrl<Base<F>> ctrl )
{
    DEBUG_CSE

    Sign( W, ctrl );
    const Int n = W.Height()/2;

    // (ML, MR) = sgn(W) - I
    ShiftDiagonal( W, F(-1) );

    // Solve for X in ML X = -MR
    // ML and MR are the left/right column blocks of W split at column n.
    Matrix<F> ML, MR;
    PartitionRight( W, ML, MR, n );
    MR *= -1;
    ls::Overwrite( NORMAL, ML, MR, X );
}
void Ridge ( Orientation orientation, const Matrix<Field>& A, const Matrix<Field>& B, Base<Field> gamma, Matrix<Field>& X, RidgeAlg alg ) { EL_DEBUG_CSE const bool normal = ( orientation==NORMAL ); const Int m = ( normal ? A.Height() : A.Width() ); const Int n = ( normal ? A.Width() : A.Height() ); if( orientation == TRANSPOSE && IsComplex<Field>::value ) LogicError("Transpose version of complex Ridge not yet supported"); if( m >= n ) { Matrix<Field> Z; if( alg == RIDGE_CHOLESKY ) { if( orientation == NORMAL ) Herk( LOWER, ADJOINT, Base<Field>(1), A, Z ); else Herk( LOWER, NORMAL, Base<Field>(1), A, Z ); ShiftDiagonal( Z, Field(gamma*gamma) ); Cholesky( LOWER, Z ); if( orientation == NORMAL ) Gemm( ADJOINT, NORMAL, Field(1), A, B, X ); else Gemm( NORMAL, NORMAL, Field(1), A, B, X ); cholesky::SolveAfter( LOWER, NORMAL, Z, X ); } else if( alg == RIDGE_QR ) { Zeros( Z, m+n, n ); auto ZT = Z( IR(0,m), IR(0,n) ); auto ZB = Z( IR(m,m+n), IR(0,n) ); if( orientation == NORMAL ) ZT = A; else Adjoint( A, ZT ); FillDiagonal( ZB, Field(gamma) ); // NOTE: This QR factorization could exploit the upper-triangular // structure of the diagonal matrix ZB qr::ExplicitTriang( Z ); if( orientation == NORMAL ) Gemm( ADJOINT, NORMAL, Field(1), A, B, X ); else Gemm( NORMAL, NORMAL, Field(1), A, B, X ); cholesky::SolveAfter( LOWER, NORMAL, Z, X ); } else { Matrix<Field> U, V; Matrix<Base<Field>> s; if( orientation == NORMAL ) { SVDCtrl<Base<Field>> ctrl; ctrl.overwrite = false; SVD( A, U, s, V, ctrl ); } else { Matrix<Field> AAdj; Adjoint( A, AAdj ); SVDCtrl<Base<Field>> ctrl; ctrl.overwrite = true; SVD( AAdj, U, s, V, ctrl ); } auto sigmaMap = [=]( const Base<Field>& sigma ) { return sigma / (sigma*sigma + gamma*gamma); }; EntrywiseMap( s, MakeFunction(sigmaMap) ); Gemm( ADJOINT, NORMAL, Field(1), U, B, X ); DiagonalScale( LEFT, NORMAL, s, X ); U = X; Gemm( NORMAL, NORMAL, Field(1), V, U, X ); } } else { LogicError("This case not yet supported"); } }