void PushPairInto ( DistMultiVec<Real>& s, DistMultiVec<Real>& z, const DistMultiVec<Real>& w, const DistMultiVec<Int>& orders, const DistMultiVec<Int>& firstInds, Real wMaxNormLimit, Int cutoff ) { DEBUG_ONLY(CSE cse("soc::PushPairInto")) DistMultiVec<Real> sLower(s.Comm()), zLower(z.Comm()); soc::LowerNorms( s, sLower, orders, firstInds, cutoff ); soc::LowerNorms( z, zLower, orders, firstInds, cutoff ); const int localHeight = s.LocalHeight(); for( Int iLoc=0; iLoc<localHeight; ++iLoc ) { const Int i = s.GlobalRow(iLoc); const Real w0 = w.GetLocal(iLoc,0); if( i == firstInds.GetLocal(iLoc,0) && w0 > wMaxNormLimit ) { // TODO: Switch to a non-adhoc modification s.UpdateLocal( iLoc, 0, Real(1)/wMaxNormLimit ); z.UpdateLocal( iLoc, 0, Real(1)/wMaxNormLimit ); } } }
void SOCSquareRoot ( const DistMultiVec<Real>& x, DistMultiVec<Real>& xRoot, const DistMultiVec<Int>& orders, const DistMultiVec<Int>& firstInds, Int cutoff ) { DEBUG_ONLY(CSE cse("SOCSquareRoot")) DistMultiVec<Real> d(x.Comm()); SOCDets( x, d, orders, firstInds ); ConeBroadcast( d, orders, firstInds ); auto roots = x; ConeBroadcast( roots, orders, firstInds ); const Int localHeight = x.LocalHeight(); xRoot.SetComm( x.Comm() ); Zeros( xRoot, x.Height(), 1 ); for( Int iLoc=0; iLoc<localHeight; ++iLoc ) { const Int i = x.GlobalRow(iLoc); const Real x0 = roots.GetLocal(iLoc,0); const Real det = d.GetLocal(iLoc,0); const Real eta0 = Sqrt(x0+Sqrt(det))/Sqrt(Real(2)); if( i == firstInds.GetLocal(iLoc,0) ) xRoot.SetLocal( iLoc, 0, eta0 ); else xRoot.SetLocal( iLoc, 0, x.GetLocal(iLoc,0)/(2*eta0) ); } }
// Ensures each cone's root entry exceeds the norm of the cone's remaining
// entries by at least minDist.
//
// soc::LowerNorms fills d. At every root entry (global row equal to
// firstInds(iLoc)), let gap = x0 - d(iLoc). If gap < minDist, x0 is increased
// by (minDist - gap), which leaves x0 == d(iLoc) + minDist. Non-root entries
// are never modified.
//
// NOTE(review): this assumes LowerNorms stores, at each root position, the
// 2-norm of that cone's non-root entries. Confirm against soc::LowerNorms.
//
// Fix: localHeight is now Int (matching LocalHeight() and the loop index)
// instead of int, which could narrow for very large local heights.
void PushInto
(       DistMultiVec<Real>& x,
  const DistMultiVec<Int>& orders,
  const DistMultiVec<Int>& firstInds,
  Real minDist, Int cutoff )
{
    DEBUG_ONLY(CSE cse("soc::PushInto"))

    // d is computed from x before any update below, so the loop's
    // modifications do not feed back into the lower norms.
    DistMultiVec<Real> d(x.Comm());
    soc::LowerNorms( x, d, orders, firstInds, cutoff );

    const Int localHeight = x.LocalHeight();
    for( Int iLoc=0; iLoc<localHeight; ++iLoc )
    {
        const Int i = x.GlobalRow(iLoc);
        if( i != firstInds.GetLocal(iLoc,0) )
            continue; // only cone roots are adjusted

        const Real gap = x.GetLocal(iLoc,0) - d.GetLocal(iLoc,0);
        if( gap < minDist )
            x.UpdateLocal( iLoc, 0, minDist - gap );
    }
}
void SOCApply ( const DistMultiVec<Real>& x, const DistMultiVec<Real>& y, DistMultiVec<Real>& z, const DistMultiVec<Int>& orders, const DistMultiVec<Int>& firstInds, Int cutoff ) { DEBUG_ONLY(CSE cse("SOCApply")) SOCDots( x, y, z, orders, firstInds ); auto xRoots = x; auto yRoots = y; ConeBroadcast( xRoots, orders, firstInds ); ConeBroadcast( yRoots, orders, firstInds ); const Int localHeight = x.LocalHeight(); for( Int iLoc=0; iLoc<localHeight; ++iLoc ) { const Int i = x.GlobalRow(iLoc); const Int firstInd = firstInds.GetLocal(iLoc,0); if( i != firstInd ) z.UpdateLocal ( iLoc, 0, xRoots.GetLocal(iLoc,0)*y.GetLocal(iLoc,0) + yRoots.GetLocal(iLoc,0)*x.GetLocal(iLoc,0) ); } }
void Tikhonov ( Orientation orientation, const DistSparseMatrix<F>& A, const DistMultiVec<F>& B, const DistSparseMatrix<F>& G, DistMultiVec<F>& X, const LeastSquaresCtrl<Base<F>>& ctrl ) { DEBUG_CSE mpi::Comm comm = A.Comm(); // Explicitly form W := op(A) // ========================== DistSparseMatrix<F> W(comm); if( orientation == NORMAL ) W = A; else if( orientation == TRANSPOSE ) Transpose( A, W ); else Adjoint( A, W ); const Int m = W.Height(); const Int n = W.Width(); const Int numRHS = B.Width(); // Embed into a higher-dimensional problem via appending regularization // ==================================================================== DistSparseMatrix<F> WEmb(comm); if( m >= n ) VCat( W, G, WEmb ); else HCat( W, G, WEmb ); DistMultiVec<F> BEmb(comm); Zeros( BEmb, WEmb.Height(), numRHS ); if( m >= n ) { // BEmb := [B; 0] // -------------- const Int mLocB = B.LocalHeight(); BEmb.Reserve( mLocB*numRHS ); for( Int iLoc=0; iLoc<mLocB; ++iLoc ) { const Int i = B.GlobalRow(iLoc); for( Int j=0; j<numRHS; ++j ) BEmb.QueueUpdate( i, j, B.GetLocal(iLoc,j) ); } BEmb.ProcessQueues(); } else BEmb = B; // Solve the higher-dimensional problem // ==================================== DistMultiVec<F> XEmb(comm); LeastSquares( NORMAL, WEmb, BEmb, XEmb, ctrl ); // Extract the solution // ==================== if( m >= n ) X = XEmb; else GetSubmatrix( XEmb, IR(0,n), IR(0,numRHS), X ); }