void NMF ( const ElementalMatrix<Real>& APre, ElementalMatrix<Real>& XPre, ElementalMatrix<Real>& YPre, const NMFCtrl<Real>& ctrl ) { DEBUG_ONLY(CSE cse("NMF")) DistMatrixReadProxy<Real,Real,MC,MR> AProx( APre ); DistMatrixReadWriteProxy<Real,Real,MC,MR> XProx( XPre ); DistMatrixWriteProxy<Real,Real,MC,MR> YProx( YPre ); auto& A = AProx.GetLocked(); auto& X = XProx.Get(); auto& Y = YProx.Get(); DistMatrix<Real> AAdj(A.Grid()), XAdj(A.Grid()), YAdj(A.Grid()); Adjoint( A, AAdj ); for( Int iter=0; iter<ctrl.maxIter; ++iter ) { NNLS( X, A, YAdj, ctrl.nnlsCtrl ); Adjoint( YAdj, Y ); NNLS( Y, AAdj, XAdj, ctrl.nnlsCtrl ); Adjoint( XAdj, X ); } }
// Medium-sized distributed variant of the left/upper/normal quasi-triangular
// solve: overwrites X with U^{-1} X, where U is upper quasi-triangular, i.e.
// upper triangular except for 2x2 diagonal blocks, which show up as nonzero
// subdiagonal entries U(k,k-1).
//
// Block rows are processed from the bottom up. For each diagonal block U11
// the matching rows X1 are solved in transposed form (X1^T held as [MR,* ]),
// written back into X, and then used to update the rows above:
// X0 -= U01 X1. A block boundary is moved up by one row whenever it would
// otherwise split a 2x2 diagonal block.
//
// checkIfSingular is forwarded unchanged to LocalQuasiTrsm.
void LUNMedium
( const AbstractDistMatrix<F>& UPre,
        AbstractDistMatrix<F>& XPre,
  bool checkIfSingular )
{
    DEBUG_CSE
    const Int m = XPre.Height();
    const Int bsize = Blocksize();
    const Grid& g = UPre.Grid();

    // [MC,MR] views: U is read-only, X is read and written back.
    DistMatrixReadProxy<F,F,MC,MR> UProx( UPre );
    DistMatrixReadWriteProxy<F,F,MC,MR> XProx( XPre );
    auto& U = UProx.GetLocked();
    auto& X = XProx.Get();

    // Redistribution buffers, reused across iterations.
    DistMatrix<F,MC,  STAR> U01_MC_STAR(g);
    DistMatrix<F,STAR,STAR> U11_STAR_STAR(g);
    DistMatrix<F,MR,  STAR> X1Trans_MR_STAR(g);

    // The current block is rows [k,kOld). LastOffset presumably yields the
    // first row of the bottom bsize-aligned block.
    const Int kLast = LastOffset( m, bsize );
    Int k=kLast, kOld=m;
    while( true )
    {
        // If row k is the second row of a 2x2 diagonal block, pull the
        // boundary up one row so that block is processed whole.
        const bool in2x2 = ( k>0 && U.Get(k,k-1) != F(0) );
        if( in2x2 )
            --k;
        const Int nb = kOld-k;

        const Range<Int> ind0( 0, k ), ind1( k, k+nb );

        auto U01 = U( ind0, ind1 );
        auto U11 = U( ind1, ind1 );

        auto X0 = X( ind0, ALL );
        auto X1 = X( ind1, ALL );

        U11_STAR_STAR = U11; // U11[* ,* ] <- U11[MC,MR]

        // Aligned with X0 before being filled so that the same buffer can be
        // used directly in the local update of X0 below.
        X1Trans_MR_STAR.AlignWith( X0 );
        Transpose( X1, X1Trans_MR_STAR );

        // X1^T[MR,* ] := X1^T[MR,* ] U11^-T[* ,* ]
        //             = (U11^-1[* ,* ] X1[* ,MR])^T
        LocalQuasiTrsm
        ( RIGHT, UPPER, TRANSPOSE, F(1), U11_STAR_STAR, X1Trans_MR_STAR,
          checkIfSingular );
        Transpose( X1Trans_MR_STAR, X1 );

        U01_MC_STAR.AlignWith( X0 );
        U01_MC_STAR = U01; // U01[MC,* ] <- U01[MC,MR]

        // X0[MC,MR] -= U01[MC,* ] X1[* ,MR]
        LocalGemm
        ( NORMAL, TRANSPOSE, F(-1), U01_MC_STAR, X1Trans_MR_STAR, F(1), X0 );

        // Stop once the top block has been processed; otherwise step up by
        // at most bsize rows.
        if( k == 0 )
            break;
        kOld = k;
        k -= Min(bsize,k);
    }
}
void LLNMedium ( const AbstractDistMatrix<F>& LPre, AbstractDistMatrix<F>& XPre, bool checkIfSingular ) { DEBUG_CSE const Int m = XPre.Height(); const Int bsize = Blocksize(); const Grid& g = LPre.Grid(); DistMatrixReadProxy<F,F,MC,MR> LProx( LPre ); DistMatrixReadWriteProxy<F,F,MC,MR> XProx( XPre ); auto& L = LProx.GetLocked(); auto& X = XProx.Get(); DistMatrix<F,STAR,STAR> L11_STAR_STAR(g); DistMatrix<F,MC, STAR> L21_MC_STAR(g); DistMatrix<F,MR, STAR> X1Trans_MR_STAR(g); for( Int k=0; k<m; k+=bsize ) { const Int nbProp = Min(bsize,m-k); const bool in2x2 = ( k+nbProp<m && L.Get(k+nbProp-1,k+nbProp) != F(0) ); const Int nb = ( in2x2 ? nbProp+1 : nbProp ); const Range<Int> ind1( k, k+nb ), ind2( k+nb, m ); auto L11 = L( ind1, ind1 ); auto L21 = L( ind2, ind1 ); auto X1 = X( ind1, ALL ); auto X2 = X( ind2, ALL ); L11_STAR_STAR = L11; // L11[* ,* ] <- L11[MC,MR] X1Trans_MR_STAR.AlignWith( X2 ); Transpose( X1, X1Trans_MR_STAR ); // X1^T[MR,* ] := X1^T[MR,* ] L11^-T[* ,* ] // = (L11^-1[* ,* ] X1[* ,MR])^T LocalQuasiTrsm ( RIGHT, LOWER, TRANSPOSE, F(1), L11_STAR_STAR, X1Trans_MR_STAR, checkIfSingular ); Transpose( X1Trans_MR_STAR, X1 ); L21_MC_STAR.AlignWith( X2 ); L21_MC_STAR = L21; // L21[MC,* ] <- L21[MC,MR] // X2[MC,MR] -= L21[MC,* ] X1[* ,MR] LocalGemm ( NORMAL, TRANSPOSE, F(-1), L21_MC_STAR, X1Trans_MR_STAR, F(1), X2 ); } }
void LUNMedium ( UnitOrNonUnit diag, const AbstractDistMatrix<F>& UPre, AbstractDistMatrix<F>& XPre, bool checkIfSingular ) { EL_DEBUG_CSE const Int m = XPre.Height(); const Int bsize = Blocksize(); const Grid& g = UPre.Grid(); DistMatrixReadProxy<F,F,MC,MR> UProx( UPre ); DistMatrixReadWriteProxy<F,F,MC,MR> XProx( XPre ); auto& U = UProx.GetLocked(); auto& X = XProx.Get(); DistMatrix<F,MC, STAR> U01_MC_STAR(g); DistMatrix<F,STAR,STAR> U11_STAR_STAR(g); DistMatrix<F,MR, STAR> X1Trans_MR_STAR(g); const Int kLast = LastOffset( m, bsize ); for( Int k=kLast; k>=0; k-=bsize ) { const Int nb = Min(bsize,m-k); const Range<Int> ind0( 0, k ), ind1( k, k+nb ); auto U01 = U( ind0, ind1 ); auto U11 = U( ind1, ind1 ); auto X0 = X( ind0, ALL ); auto X1 = X( ind1, ALL ); U11_STAR_STAR = U11; // U11[* ,* ] <- U11[MC,MR] X1Trans_MR_STAR.AlignWith( X0 ); Transpose( X1, X1Trans_MR_STAR ); // X1^T[MR,* ] := X1^T[MR,* ] U11^-T[* ,* ] // = (U11^-1[* ,* ] X1[* ,MR])^T LocalTrsm ( RIGHT, UPPER, TRANSPOSE, diag, F(1), U11_STAR_STAR, X1Trans_MR_STAR, checkIfSingular ); Transpose( X1Trans_MR_STAR, X1 ); U01_MC_STAR.AlignWith( X0 ); U01_MC_STAR = U01; // U01[MC,* ] <- U01[MC,MR] // X0[MC,MR] -= U01[MC,* ] X1[* ,MR] LocalGemm ( NORMAL, TRANSPOSE, F(-1), U01_MC_STAR, X1Trans_MR_STAR, F(1), X0 ); } }
// Large distributed variant of the left/upper/normal triangular solve:
// overwrites X with U^{-1} X. diag selects a unit or non-unit diagonal and
// checkIfSingular is forwarded to LocalTrsm.
//
// Block rows are processed from the bottom up. Unlike the Medium variant,
// each block row X1 is solved in a [* ,VR] distribution, so the right-hand
// sides are spread over all processes. It is then redistributed to [* ,MR]
// for the update X0 -= U01 X1.
void LUNLarge
( UnitOrNonUnit diag,
  const AbstractDistMatrix<F>& UPre,
        AbstractDistMatrix<F>& XPre,
  bool checkIfSingular )
{
    EL_DEBUG_CSE
    const Int m = XPre.Height();
    const Int bsize = Blocksize();
    const Grid& g = UPre.Grid();

    // [MC,MR] views: U is read-only, X is read and written back.
    DistMatrixReadProxy<F,F,MC,MR> UProx( UPre );
    DistMatrixReadWriteProxy<F,F,MC,MR> XProx( XPre );
    auto& U = UProx.GetLocked();
    auto& X = XProx.Get();

    // Redistribution buffers, reused across iterations.
    DistMatrix<F,MC,  STAR> U01_MC_STAR(g);
    DistMatrix<F,STAR,STAR> U11_STAR_STAR(g);
    DistMatrix<F,STAR,MR  > X1_STAR_MR(g);
    DistMatrix<F,STAR,VR  > X1_STAR_VR(g);

    const Int kLast = LastOffset( m, bsize );
    for( Int k=kLast; k>=0; k-=bsize )
    {
        const Int nb = Min(bsize,m-k);

        const Range<Int> ind0( 0, k ), ind1( k, k+nb );

        auto U01 = U( ind0, ind1 );
        auto U11 = U( ind1, ind1 );

        auto X0 = X( ind0, ALL );
        auto X1 = X( ind1, ALL );

        U11_STAR_STAR = U11; // U11[* ,* ] <- U11[MC,MR]
        X1_STAR_VR    = X1;  // X1[* ,VR] <- X1[MC,MR]

        // X1[* ,VR] := U11^-1[* ,* ] X1[* ,VR]
        LocalTrsm
        ( LEFT, UPPER, NORMAL, diag, F(1), U11_STAR_STAR, X1_STAR_VR,
          checkIfSingular );

        // Align with X0 before filling so the buffer feeds the local update.
        X1_STAR_MR.AlignWith( X0 );
        X1_STAR_MR  = X1_STAR_VR; // X1[* ,MR] <- X1[* ,VR]
        X1          = X1_STAR_MR; // X1[MC,MR] <- X1[* ,MR]
        U01_MC_STAR.AlignWith( X0 );
        U01_MC_STAR = U01;        // U01[MC,* ] <- U01[MC,MR]

        // X0[MC,MR] -= U01[MC,* ] X1[* ,MR]
        LocalGemm( NORMAL, NORMAL, F(-1), U01_MC_STAR, X1_STAR_MR, F(1), X0 );
    }
}
void LLNLarge ( const AbstractDistMatrix<F>& LPre, AbstractDistMatrix<F>& XPre, bool checkIfSingular ) { DEBUG_CSE const Int m = XPre.Height(); const Int bsize = Blocksize(); const Grid& g = LPre.Grid(); DistMatrixReadProxy<F,F,MC,MR> LProx( LPre ); DistMatrixReadWriteProxy<F,F,MC,MR> XProx( XPre ); auto& L = LProx.GetLocked(); auto& X = XProx.Get(); DistMatrix<F,STAR,STAR> L11_STAR_STAR(g); DistMatrix<F,MC, STAR> L21_MC_STAR(g); DistMatrix<F,STAR,MR > X1_STAR_MR(g); DistMatrix<F,STAR,VR > X1_STAR_VR(g); for( Int k=0; k<m; k+=bsize ) { const Int nbProp = Min(bsize,m-k); const bool in2x2 = ( k+nbProp<m && L.Get(k+nbProp-1,k+nbProp) != F(0) ); const Int nb = ( in2x2 ? nbProp+1 : nbProp ); const Range<Int> ind1( k, k+nb ), ind2( k+nb, m ); auto L11 = L( ind1, ind1 ); auto L21 = L( ind2, ind1 ); auto X1 = X( ind1, ALL ); auto X2 = X( ind2, ALL ); // X1[* ,VR] := L11^-1[* ,* ] X1[* ,VR] L11_STAR_STAR = L11; X1_STAR_VR = X1; LocalQuasiTrsm ( LEFT, LOWER, NORMAL, F(1), L11_STAR_STAR, X1_STAR_VR, checkIfSingular ); X1_STAR_MR.AlignWith( X2 ); X1_STAR_MR = X1_STAR_VR; // X1[* ,MR] <- X1[* ,VR] X1 = X1_STAR_MR; // X1[MC,MR] <- X1[* ,MR] L21_MC_STAR.AlignWith( X2 ); L21_MC_STAR = L21; // L21[MC,* ] <- L21[MC,MR] // X2[MC,MR] -= L21[MC,* ] X1[* ,MR] LocalGemm( NORMAL, NORMAL, F(-1), L21_MC_STAR, X1_STAR_MR, F(1), X2 ); } }
// One divide step based on the inverse-free sign iteration (see LAWN91).
//
// XPre must be 2n x n and is treated as the stacked pair X = [B; A]: rows
// [0,n) are B and rows [n,2n) are A. Otherwise a LogicError is thrown.
// X is modified in place. The views B and A alias X, so InverseFreeSign and
// the subsequent factorizations overwrite it. On return, the A part holds
// the original A transformed as Q^H A Q via the RQ reflectors stored in B.
//
// Returns the ComputePartition result for the transformed A, with its value
// divided by the one-norm of the original A.
ValueInt<Base<F>> InverseFreeSignDivide( ElementalMatrix<F>& XPre )
{
    DEBUG_CSE

    DistMatrixReadWriteProxy<F,F,MC,MR> XProx( XPre );
    auto& X = XProx.Get();

    typedef Base<F> Real;
    const Grid& g = X.Grid();
    const Int n = X.Width();
    if( X.Height() != 2*n )
        LogicError("Matrix should be 2n x n");

    // Expose A and B, and then copy A
    auto B = X( IR(0,n  ), ALL );
    auto A = X( IR(n,2*n), ALL );
    DistMatrix<F> ACopy( A );

    // Run the inverse-free alternative to Sign (updates A and B through X)
    InverseFreeSign( X );

    // Compute the pivoted QR decomp of inv(A + B) A [See LAWN91]
    // 1) B := A + B
    // 2) [Q,R,Pi] := QRP(A)
    // 3) B := Q^H B
    // 4) [R,Q] := RQ(B)
    B += A;
    DistMatrix<F,MD,STAR> t(g);
    DistMatrix<Base<F>,MD,STAR> d(g);
    DistMatrix<Int,VR,STAR> p(g);
    QR( A, t, d, p );
    qr::ApplyQ( LEFT, ADJOINT, A, t, d, B );
    RQ( B, t, d );

    // A := Q^H A Q (restoring the original A first; t and d now hold the
    // RQ reflector data from the call above)
    A = ACopy;
    rq::ApplyQ( LEFT, ADJOINT, B, t, d, A );
    rq::ApplyQ( RIGHT, NORMAL, B, t, d, A );

    // Return || E21 ||1 / || A ||1
    ValueInt<Real> part = ComputePartition( A );
    part.value /= OneNorm(ACopy);
    return part;
}
void QP ( const ElementalMatrix<Real>& APre, const ElementalMatrix<Real>& BPre, ElementalMatrix<Real>& XPre, const qp::direct::Ctrl<Real>& ctrl ) { DEBUG_CSE DistMatrixReadProxy<Real,Real,MC,MR> AProx( APre ), BProx( BPre ); DistMatrixWriteProxy<Real,Real,MC,MR> XProx( XPre ); auto& A = AProx.GetLocked(); auto& B = BProx.GetLocked(); auto& X = XProx.Get(); const Int n = A.Width(); const Int k = B.Width(); const Grid& g = A.Grid(); DistMatrix<Real> Q(g), AHat(g), bHat(g), c(g); Herk( LOWER, ADJOINT, Real(1), A, Q ); Zeros( AHat, 0, n ); Zeros( bHat, 0, 1 ); Zeros( X, n, k ); DistMatrix<Real> y(g), z(g); for( Int j=0; j<k; ++j ) { auto x = X( ALL, IR(j) ); auto b = B( ALL, IR(j) ); Zeros( c, n, 1 ); Gemv( ADJOINT, Real(-1), A, b, Real(0), c ); El::QP( Q, AHat, bHat, c, x, y, z, ctrl ); } }
// Inverse-free alternative to the matrix sign iteration.
//
// XPre must be 2n x n and is treated as X = [B; A] (rows [0,n) are B, rows
// [n,2n) are A). Otherwise a LogicError("X must be 2n x n") is thrown.
// Each iteration QR-factors X, forms the left n columns of Q in the scratch
// matrix Q = [Q12; Q22], and replaces A := Q12^H A and B := Q22^H B.
// A's sign is flipped on entry and flipped back (on the final A) on exit.
//
// Convergence: stop once the one-norm of the change in the upper-triangular
// R factor between consecutive iterations is at most tau times the one-norm
// of the previous R. If tau == 0 it defaults to n * machine epsilon.
//
// Returns the number of iterations performed. Hitting maxIts without
// converging is not reported separately; the caller sees numIts == maxIts.
// NOTE(review): numIts is an Int but the return type is int, so the value
// is implicitly narrowed.
int InverseFreeSign( ElementalMatrix<F>& XPre, Int maxIts=100, Base<F> tau=0 )
{
    DEBUG_CSE

    DistMatrixReadWriteProxy<F,F,MC,MR> XProx( XPre );
    auto& X = XProx.Get();

    typedef Base<F> Real;
    const Grid& g = X.Grid();
    const Int n = X.Width();
    if( X.Height() != 2*n )
        LogicError("X must be 2n x n");
    // Compute the tolerance if it is unset
    if( tau == Real(0) )
        tau = n*limits::Epsilon<Real>();

    // Expose A and B in the original and temporary
    DistMatrix<F> XAlt( 2*n, n, g );
    auto B = X( IR(0,n  ), ALL );
    auto A = X( IR(n,2*n), ALL );
    auto BAlt = XAlt( IR(0,n  ), ALL );
    auto AAlt = XAlt( IR(n,2*n), ALL );

    // Flip the sign of A
    A *= -1;

    // Set up the space for explicitly computing the left half of Q
    DistMatrix<F,MD,STAR> t(g);
    DistMatrix<Base<F>,MD,STAR> d(g);
    DistMatrix<F> Q( 2*n, n, g );
    auto Q12 = Q( IR(0,n  ), ALL );
    auto Q22 = Q( IR(n,2*n), ALL );

    // Run the iterative algorithm
    Int numIts=0;
    DistMatrix<F> R(g), RLast(g);
    while( numIts < maxIts )
    {
        // Factor a copy of the current iterate: XAlt := QR(X)
        XAlt = X;
        QR( XAlt, t, d );

        // Form the left half of Q by applying Q to [0; I]
        Zero( Q12 );
        MakeIdentity( Q22 );
        qr::ApplyQ( LEFT, NORMAL, XAlt, t, d, Q );

        // Save a copy of R (the upper triangle of the top n rows of XAlt)
        R = BAlt;
        MakeTrapezoidal( UPPER, R );

        // Form the new iterate in XAlt, then copy it into X (which A and B
        // view)
        Gemm( ADJOINT, NORMAL, F(1), Q12, A, F(0), AAlt );
        Gemm( ADJOINT, NORMAL, F(1), Q22, B, F(0), BAlt );
        X = XAlt;

        // Use the difference in the iterates to test for convergence
        ++numIts;
        if( numIts > 1 )
        {
            const Real oneRLast = OneNorm(RLast);
            // RLast := RLast - R (upper trapezoid only)
            AxpyTrapezoid( UPPER, F(-1), R, RLast );
            const Real oneRDiff = OneNorm(RLast);
            if( oneRDiff <= tau*oneRLast )
                break;
        }
        RLast = R;
    }

    // Revert the sign of A and return
    A *= -1;
    return numIts;
}
void Ridge ( Orientation orientation, const AbstractDistMatrix<Field>& APre, const AbstractDistMatrix<Field>& BPre, Base<Field> gamma, AbstractDistMatrix<Field>& XPre, RidgeAlg alg ) { EL_DEBUG_CSE DistMatrixReadProxy<Field,Field,MC,MR> AProx( APre ), BProx( BPre ); DistMatrixWriteProxy<Field,Field,MC,MR> XProx( XPre ); auto& A = AProx.GetLocked(); auto& B = BProx.GetLocked(); auto& X = XProx.Get(); const bool normal = ( orientation==NORMAL ); const Int m = ( normal ? A.Height() : A.Width() ); const Int n = ( normal ? A.Width() : A.Height() ); if( orientation == TRANSPOSE && IsComplex<Field>::value ) LogicError("Transpose version of complex Ridge not yet supported"); if( m >= n ) { DistMatrix<Field> Z(A.Grid()); if( alg == RIDGE_CHOLESKY ) { if( orientation == NORMAL ) Herk( LOWER, ADJOINT, Base<Field>(1), A, Z ); else Herk( LOWER, NORMAL, Base<Field>(1), A, Z ); ShiftDiagonal( Z, Field(gamma*gamma) ); Cholesky( LOWER, Z ); if( orientation == NORMAL ) Gemm( ADJOINT, NORMAL, Field(1), A, B, X ); else Gemm( NORMAL, NORMAL, Field(1), A, B, X ); cholesky::SolveAfter( LOWER, NORMAL, Z, X ); } else if( alg == RIDGE_QR ) { Zeros( Z, m+n, n ); auto ZT = Z( IR(0,m), IR(0,n) ); auto ZB = Z( IR(m,m+n), IR(0,n) ); if( orientation == NORMAL ) ZT = A; else Adjoint( A, ZT ); FillDiagonal( ZB, Field(gamma) ); // NOTE: This QR factorization could exploit the upper-triangular // structure of the diagonal matrix ZB qr::ExplicitTriang( Z ); if( orientation == NORMAL ) Gemm( ADJOINT, NORMAL, Field(1), A, B, X ); else Gemm( NORMAL, NORMAL, Field(1), A, B, X ); cholesky::SolveAfter( LOWER, NORMAL, Z, X ); } else { DistMatrix<Field> U(A.Grid()), V(A.Grid()); DistMatrix<Base<Field>,VR,STAR> s(A.Grid()); if( orientation == NORMAL ) { SVDCtrl<Base<Field>> ctrl; ctrl.overwrite = false; SVD( A, U, s, V, ctrl ); } else { DistMatrix<Field> AAdj(A.Grid()); Adjoint( A, AAdj ); SVDCtrl<Base<Field>> ctrl; ctrl.overwrite = true; SVD( AAdj, U, s, V ); } auto sigmaMap = [=]( const Base<Field>& sigma ) 
{ return sigma / (sigma*sigma + gamma*gamma); }; EntrywiseMap( s, MakeFunction(sigmaMap) ); Gemm( ADJOINT, NORMAL, Field(1), U, B, X ); DiagonalScale( LEFT, NORMAL, s, X ); U = X; Gemm( NORMAL, NORMAL, Field(1), V, U, X ); } } else { LogicError("This case not yet supported"); } }
// Large distributed variant of the left/upper/normal quasi-triangular solve:
// overwrites X with U^{-1} X, where U is upper quasi-triangular (2x2 diagonal
// blocks show up as nonzero subdiagonal entries U(k,k-1)).
//
// Block rows are processed from the bottom up. Each block row X1 is solved in
// a [* ,VR] distribution, then redistributed to [* ,MR] for the update
// X0 -= U01 X1. A block boundary is moved up by one row whenever it would
// otherwise split a 2x2 diagonal block.
//
// checkIfSingular is forwarded unchanged to LocalQuasiTrsm.
void LUNLarge
( const ElementalMatrix<F>& UPre,
        ElementalMatrix<F>& XPre,
  bool checkIfSingular )
{
    DEBUG_CSE
    const Int m = XPre.Height();
    const Int bsize = Blocksize();
    const Grid& g = UPre.Grid();

    // [MC,MR] views: U is read-only, X is read and written back.
    DistMatrixReadProxy<F,F,MC,MR> UProx( UPre );
    DistMatrixReadWriteProxy<F,F,MC,MR> XProx( XPre );
    auto& U = UProx.GetLocked();
    auto& X = XProx.Get();

    // Redistribution buffers, reused across iterations.
    DistMatrix<F,MC,  STAR> U01_MC_STAR(g);
    DistMatrix<F,STAR,STAR> U11_STAR_STAR(g);
    DistMatrix<F,STAR,MR  > X1_STAR_MR(g);
    DistMatrix<F,STAR,VR  > X1_STAR_VR(g);

    // The current block is rows [k,kOld), starting from the bottom block.
    const Int kLast = LastOffset( m, bsize );
    Int k=kLast, kOld=m;
    while( true )
    {
        // Keep 2x2 diagonal blocks whole: if row k is the second row of
        // one, move the boundary up one row.
        const bool in2x2 = ( k>0 && U.Get(k,k-1) != F(0) );
        if( in2x2 )
            --k;
        const Int nb = kOld-k;

        const Range<Int> ind0( 0, k ), ind1( k, k+nb );

        auto U01 = U( ind0, ind1 );
        auto U11 = U( ind1, ind1 );

        auto X0 = X( ind0, ALL );
        auto X1 = X( ind1, ALL );

        U11_STAR_STAR = U11; // U11[* ,* ] <- U11[MC,MR]
        X1_STAR_VR    = X1;  // X1[* ,VR] <- X1[MC,MR]

        // X1[* ,VR] := U11^-1[* ,* ] X1[* ,VR]
        LocalQuasiTrsm
        ( LEFT, UPPER, NORMAL, F(1), U11_STAR_STAR, X1_STAR_VR,
          checkIfSingular );

        // Align with X0 before filling so the buffer feeds the local update.
        X1_STAR_MR.AlignWith( X0 );
        X1_STAR_MR  = X1_STAR_VR; // X1[* ,MR] <- X1[* ,VR]
        X1          = X1_STAR_MR; // X1[MC,MR] <- X1[* ,MR]
        U01_MC_STAR.AlignWith( X0 );
        U01_MC_STAR = U01;        // U01[MC,* ] <- U01[MC,MR]

        // X0[MC,MR] -= U01[MC,* ] X1[* ,MR]
        LocalGemm( NORMAL, NORMAL, F(-1), U01_MC_STAR, X1_STAR_MR, F(1), X0 );

        // Stop after the top block; otherwise step up by at most bsize rows.
        if( k == 0 )
            break;
        kOld = k;
        k -= Min(bsize,k);
    }
}
// Distributed left/lower/normal triangular solve with a lower-triangular
// right-hand side (Trstrm): X is treated as lower triangular (n x n, with n
// taken from L). Its lower trapezoid is scaled by alpha and then overwritten
// with L^{-1} X. diag selects a unit or non-unit diagonal for L and
// checkIfSingular is forwarded to the local solves.
//
// For each diagonal block k, the block row [X10 X11] is solved against L11
// (X10 as a general block in [* ,VR], X11 recursively as a triangular block
// in [* ,* ]). The solved blocks then update the rows below:
// X20 -= L21 X10 and X21 -= L21 tril(X11).
void LLN
( UnitOrNonUnit diag,
  F alpha,
  const AbstractDistMatrix<F>& LPre,
        AbstractDistMatrix<F>& XPre,
  bool checkIfSingular )
{
    DEBUG_CSE
    const Int n = LPre.Height();
    const Int bsize = Blocksize();
    const Grid& g = LPre.Grid();

    // [MC,MR] views: L is read-only, X is read and written back.
    DistMatrixReadProxy<F,F,MC,MR> LProx( LPre );
    DistMatrixReadWriteProxy<F,F,MC,MR> XProx( XPre );
    auto& L = LProx.GetLocked();
    auto& X = XProx.Get();

    // Temporary distributions
    DistMatrix<F,STAR,STAR> L11_STAR_STAR(g), X11_STAR_STAR(g);
    DistMatrix<F,MC,  STAR> L21_MC_STAR(g);
    DistMatrix<F,STAR,MR  > X10_STAR_MR(g), X11_STAR_MR(g);
    DistMatrix<F,STAR,VR  > X10_STAR_VR(g);

    // Only the lower trapezoid of X participates.
    ScaleTrapezoid( alpha, LOWER, X );
    for( Int k=0; k<n; k+=bsize )
    {
        const Int nb = Min(bsize,n-k);

        const Range<Int> ind0( 0, k ), ind1( k, k+nb ), ind2( k+nb, n );

        auto L11 = L( ind1, ind1 );
        auto L21 = L( ind2, ind1 );

        auto X10 = X( ind1, ind0 );
        auto X11 = X( ind1, ind1 );
        auto X20 = X( ind2, ind0 );
        auto X21 = X( ind2, ind1 );

        L11_STAR_STAR = L11;
        X11_STAR_STAR = X11;
        X10_STAR_VR = X10;

        // X10 := L11^{-1} X10 (general block) and
        // X11 := L11^{-1} X11 (triangular block, solved recursively)
        LocalTrsm
        ( LEFT, LOWER, NORMAL, diag, F(1), L11_STAR_STAR, X10_STAR_VR,
          checkIfSingular );
        Trstrm
        ( LEFT, LOWER, NORMAL, diag, F(1), L11_STAR_STAR, X11_STAR_STAR,
          checkIfSingular );
        X11 = X11_STAR_STAR;
        X11_STAR_MR.AlignWith( X21 );
        X11_STAR_MR = X11_STAR_STAR;
        // Zero the strictly upper part so the X21 update uses tril(X11) only.
        MakeTrapezoidal( LOWER, X11_STAR_MR );

        X10_STAR_MR.AlignWith( X20 );
        X10_STAR_MR = X10_STAR_VR;
        X10 = X10_STAR_MR;
        L21_MC_STAR.AlignWith( X20 );
        L21_MC_STAR = L21;

        // X20 -= L21 X10 ; X21 -= L21 tril(X11)
        LocalGemm
        ( NORMAL, NORMAL, F(-1), L21_MC_STAR, X10_STAR_MR, F(1), X20 );
        LocalGemm
        ( NORMAL, NORMAL, F(-1), L21_MC_STAR, X11_STAR_MR, F(1), X21 );
    }
}
void SolveAfter ( Orientation orientation, const ElementalMatrix<F>& APre, const ElementalMatrix<F>& householderScalars, const ElementalMatrix<Base<F>>& signature, const ElementalMatrix<F>& B, ElementalMatrix<F>& XPre ) { DEBUG_CSE const Int m = APre.Height(); const Int n = APre.Width(); if( m > n ) LogicError("Must have full row rank"); DistMatrixReadProxy<F,F,MC,MR> AProx( APre ); DistMatrixWriteProxy<F,F,MC,MR> XProx( XPre ); auto& A = AProx.GetLocked(); auto& X = XProx.Get(); X.Resize( n, B.Width() ); // TODO: Add scaling auto AL = A( IR(0,m), IR(0,m) ); if( orientation == NORMAL ) { if( m != B.Height() ) LogicError("A and B do not conform"); // Copy B into X auto XT = X( IR(0,m), ALL ); auto XB = X( IR(m,n), ALL ); XT = B; Zero( XB ); if( orientation == TRANSPOSE ) Conjugate( XT ); // Solve against L (checking for singularities) Trsm( LEFT, LOWER, NORMAL, NON_UNIT, F(1), AL, XT, true ); // Apply Q' to X lq::ApplyQ( LEFT, ADJOINT, A, householderScalars, signature, X ); if( orientation == TRANSPOSE ) Conjugate( X ); } else { // Copy B into X X = B; if( orientation == TRANSPOSE ) Conjugate( X ); // Apply Q to X lq::ApplyQ( LEFT, NORMAL, A, householderScalars, signature, X ); // Shrink X to its new height X.Resize( m, X.Width() ); // Solve against L' (check for singularities) Trsm( LEFT, LOWER, ADJOINT, NON_UNIT, F(1), AL, X, true ); if( orientation == TRANSPOSE ) Conjugate( X ); } }
void Tikhonov ( Orientation orientation, const ElementalMatrix<F>& APre, const ElementalMatrix<F>& BPre, const ElementalMatrix<F>& G, ElementalMatrix<F>& XPre, TikhonovAlg alg ) { DEBUG_CSE DistMatrixReadProxy<F,F,MC,MR> AProx( APre ), BProx( BPre ); DistMatrixWriteProxy<F,F,MC,MR> XProx( XPre ); auto& A = AProx.GetLocked(); auto& B = BProx.GetLocked(); auto& X = XProx.Get(); const bool normal = ( orientation==NORMAL ); const Int m = ( normal ? A.Height() : A.Width() ); const Int n = ( normal ? A.Width() : A.Height() ); if( G.Width() != n ) LogicError("Tikhonov matrix was the wrong width"); if( orientation == TRANSPOSE && IsComplex<F>::value ) LogicError("Transpose version of complex Tikhonov not yet supported"); if( m >= n ) { DistMatrix<F> Z(A.Grid()); if( alg == TIKHONOV_CHOLESKY ) { if( orientation == NORMAL ) Herk( LOWER, ADJOINT, Base<F>(1), A, Z ); else Herk( LOWER, NORMAL, Base<F>(1), A, Z ); Herk( LOWER, ADJOINT, Base<F>(1), G, Base<F>(1), Z ); Cholesky( LOWER, Z ); } else { const Int mG = G.Height(); Zeros( Z, m+mG, n ); auto ZT = Z( IR(0,m), IR(0,n) ); auto ZB = Z( IR(m,m+mG), IR(0,n) ); if( orientation == NORMAL ) ZT = A; else Adjoint( A, ZT ); ZB = G; qr::ExplicitTriang( Z ); } if( orientation == NORMAL ) Gemm( ADJOINT, NORMAL, F(1), A, B, X ); else Gemm( NORMAL, NORMAL, F(1), A, B, X ); cholesky::SolveAfter( LOWER, NORMAL, Z, X ); } else { LogicError("This case not yet supported"); } }