/* Solves the augmented system of the restoration-phase problem by reducing
 * it to an augmented system of the original problem.
 *
 * The restoration primal quantities (D_x, rhs_x, sol_x, and the columns of
 * J_c / J_d) are compound objects whose components are
 *   (0) x, (1) n_c, (2) p_c, (3) n_d, (4) p_d.
 * The n/p blocks are diagonal (stored as vectors and applied element-wise),
 * so they are eliminated with the inverses Sig~^{-1} computed below.  The
 * helpers Neg_Omega_*_plus_D_* and Rhs_*R (defined elsewhere) presumably fold
 * the eliminated blocks into the D_c/D_d diagonals and rhs_c/rhs_d -- confirm
 * against their definitions.  The reduced system (x-component of sol_x plus
 * sol_s, sol_c, sol_d) is solved by orig_aug_solver_.
 *
 * When that solve reports SYMSOLVER_SUCCESS, the eliminated parts of sol_x
 * are recovered from sol_c and sol_d:
 *   sol_n_c = Sig~_{n_c}^{-1} (rhs_{n_c} - sol_c)
 *   sol_p_c = Sig~_{p_c}^{-1} (rhs_{p_c} + sol_c)
 *   sol_n_d = Sig~_{n_d}^{-1} (rhs_{n_d} - pd_l^T sol_d)
 *   sol_p_d = Sig~_{p_d}^{-1} (rhs_{p_d} - neg_pd_u^T sol_d)
 * A component whose Sig~^{-1} is not available is set to zero.  On any other
 * status the n/p components of sol_x are left untouched.
 *
 * W, J_c, J_d, rhs_x and sol_x are static_cast to their compound types
 * without runtime checks; callers must pass restoration-problem objects.
 * Returns the status reported by orig_aug_solver_.
 */
ESymSolverStatus AugRestoSystemSolver::Solve(const SymMatrix* W,
      double W_factor,
      const Vector* D_x,
      double delta_x,
      const Vector* D_s,
      double delta_s,
      const Matrix* J_c,
      const Vector* D_c,
      double delta_c,
      const Matrix* J_d,
      const Vector* D_d,
      double delta_d,
      const Vector& rhs_x,
      const Vector& rhs_s,
      const Vector& rhs_c,
      const Vector& rhs_d,
      Vector& sol_x,
      Vector& sol_s,
      Vector& sol_c,
      Vector& sol_d,
      bool check_NegEVals,
      Index numberOfNegEVals)
  {
    DBG_START_METH("AugRestoSystemSolver::Solve",dbg_verbosity);
    DBG_ASSERT(J_c && J_d); // should pass these by ref

    // W is always required (its structure is needed).  D_x may be NULL and
    // is checked explicitly below; D_s, D_c and D_d are forwarded unchanged,
    // so their NULL handling is up to the callees.
    DBG_ASSERT(W);

    SmartPtr<const CompoundSymMatrix> CW =
      static_cast<const CompoundSymMatrix*>(W);

    SmartPtr<const CompoundVector> CD_x =
      static_cast<const CompoundVector*>(D_x);

    SmartPtr<const CompoundMatrix> CJ_c =
      static_cast<const CompoundMatrix*>(J_c);
    DBG_ASSERT(IsValid(CJ_c));

    SmartPtr<const CompoundMatrix> CJ_d =
      static_cast<const CompoundMatrix*>(J_d);
    DBG_ASSERT(IsValid(CJ_d));

    SmartPtr<const CompoundVector> Crhs_x =
      static_cast<const CompoundVector*>(&rhs_x);
    DBG_ASSERT(IsValid(Crhs_x));

    SmartPtr<CompoundVector> Csol_x = static_cast<CompoundVector*>(&sol_x);
    DBG_ASSERT(IsValid(Csol_x));

    // Diagonal entries of D_x for the n/p blocks (all NULL if D_x is NULL)
    SmartPtr<const Vector> sigma_n_c;
    SmartPtr<const Vector> sigma_p_c;
    SmartPtr<const Vector> sigma_n_d;
    SmartPtr<const Vector> sigma_p_d;

    if (IsValid(CD_x)) {
      sigma_n_c = CD_x->GetComp(1);
      sigma_p_c = CD_x->GetComp(2);
      sigma_n_d = CD_x->GetComp(3);
      sigma_p_d = CD_x->GetComp(4);
    }

    // Inverses of the (regularized) n/p diagonals; each may be NULL, which
    // is checked before use in the back-substitution below.
    SmartPtr<const Vector> sigma_tilde_n_c_inv =
      Sigma_tilde_n_c_inv(sigma_n_c, delta_x, *Crhs_x->GetComp(1));
    SmartPtr<const Vector> sigma_tilde_p_c_inv =
      Sigma_tilde_p_c_inv(sigma_p_c, delta_x, *Crhs_x->GetComp(2));
    SmartPtr<const Vector> sigma_tilde_n_d_inv =
      Sigma_tilde_n_d_inv(sigma_n_d, delta_x, *Crhs_x->GetComp(3));
    SmartPtr<const Vector> sigma_tilde_p_d_inv =
      Sigma_tilde_p_d_inv(sigma_p_d, delta_x, *Crhs_x->GetComp(4));

    // Pull out the expansion matrices for d (the n_d and p_d columns of J_d)
    SmartPtr<const Matrix> pd_l = CJ_d->GetComp(0,3);
    SmartPtr<const Matrix> neg_pd_u = CJ_d->GetComp(0,4);

    // Now map the correct entries into the Solve method
    // pull out the parts of the hessian h_orig + diag
    DBG_PRINT_MATRIX(2, "CW", *CW);
    SmartPtr<const SymMatrix> h_orig;
    SmartPtr<const Vector> D_xR;
    SmartPtr<const SumSymMatrix> WR_sum =
      dynamic_cast<const SumSymMatrix*>(GetRawPtr(CW->GetComp(0,0)));
    Number orig_W_factor = W_factor;
    if (IsValid(WR_sum)) {
      // We seem to be in the regular situation with exact second
      // derivatives: the x-block is a sum whose first term is the original
      // Hessian and whose second term is a diagonal matrix.
      double temp_factor;
      WR_sum->GetTerm(0, temp_factor, h_orig);
      DBG_ASSERT(temp_factor == 1. || temp_factor == 0.);
      orig_W_factor = temp_factor * W_factor;
      SmartPtr<const SymMatrix> eta_DR;
      double factor;
      WR_sum->GetTerm(1, factor, eta_DR);
      SmartPtr<const Vector> wr_d =
        static_cast<const DiagMatrix*>(GetRawPtr(eta_DR))->GetDiag();

      // Merge the diagonal term into the x-part of D_x
      if (IsValid(CD_x)) {
        D_xR = D_x_plus_wr_d(CD_x->GetComp(0), factor, *wr_d);
      }
      else {
        D_xR = D_x_plus_wr_d(NULL, factor, *wr_d);
      }
    }
    else {
      // Looks like limited memory quasi-Newton stuff
      const LowRankUpdateSymMatrix* LR_W =
        static_cast<const LowRankUpdateSymMatrix*>(GetRawPtr(CW->GetComp(0,0)));
      DBG_ASSERT(LR_W);
      h_orig = LR_W;
      // D_xR stays NULL when D_x itself is NULL
      if (IsValid(CD_x)) {
        D_xR = CD_x->GetComp(0);
      }
    }

    // Assemble the reduced system for the original problem
    Number delta_xR = delta_x;
    SmartPtr<const Vector> D_sR = D_s;
    Number delta_sR = delta_s;
    SmartPtr<const Matrix> J_cR = CJ_c->GetComp(0,0);
    SmartPtr<const Vector> D_cR =
      Neg_Omega_c_plus_D_c(sigma_tilde_n_c_inv, sigma_tilde_p_c_inv,
                           D_c, rhs_c);
    // D_cR may be NULL (e.g. when no n_c/p_c inverse and no D_c exist);
    // only trace its tag when there is one to avoid a NULL dereference.
    if (IsValid(D_cR)) {
      DBG_PRINT((1,"D_cR tag = %d\n",D_cR->GetTag()));
    }
    Number delta_cR = delta_c;
    SmartPtr<const Matrix> J_dR = CJ_d->GetComp(0,0);
    SmartPtr<const Vector> D_dR =
      Neg_Omega_d_plus_D_d(*pd_l, sigma_tilde_n_d_inv, *neg_pd_u,
                           sigma_tilde_p_d_inv, D_d, rhs_d);
    Number delta_dR = delta_d;
    SmartPtr<const Vector> rhs_xR = Crhs_x->GetComp(0);
    SmartPtr<const Vector> rhs_sR = &rhs_s;
    SmartPtr<const Vector> rhs_cR = Rhs_cR(rhs_c, sigma_tilde_n_c_inv,
                                           *Crhs_x->GetComp(1),
                                           sigma_tilde_p_c_inv,
                                           *Crhs_x->GetComp(2));
    SmartPtr<const Vector> rhs_dR = Rhs_dR(rhs_d, sigma_tilde_n_d_inv,
                                           *Crhs_x->GetComp(3), *pd_l,
                                           sigma_tilde_p_d_inv,
                                           *Crhs_x->GetComp(4), *neg_pd_u);
    SmartPtr<Vector> sol_xR = Csol_x->GetCompNonConst(0);
    Vector& sol_sR = sol_s;
    Vector& sol_cR = sol_c;
    Vector& sol_dR = sol_d;

    ESymSolverStatus status = orig_aug_solver_->Solve(GetRawPtr(h_orig),
                              orig_W_factor,
                              GetRawPtr(D_xR), delta_xR,
                              GetRawPtr(D_sR), delta_sR,
                              GetRawPtr(J_cR), GetRawPtr(D_cR),
                              delta_cR,
                              GetRawPtr(J_dR), GetRawPtr(D_dR),
                              delta_dR,
                              *rhs_xR, *rhs_sR, *rhs_cR, *rhs_dR,
                              *sol_xR, sol_sR, sol_cR, sol_dR,
                              check_NegEVals,
                              numberOfNegEVals);

    if (status == SYMSOLVER_SUCCESS) {
      // Now back out the solutions for the n and p variables

      // sol_n_c = Sig~_{n_c}^{-1} (rhs_{n_c} - sol_c)
      SmartPtr<Vector> sol_n_c = Csol_x->GetCompNonConst(1);
      sol_n_c->Set(0.0);
      if (IsValid(sigma_tilde_n_c_inv)) {
        sol_n_c->AddTwoVectors(1., *Crhs_x->GetComp(1), -1.0, sol_cR, 0.);
        sol_n_c->ElementWiseMultiply(*sigma_tilde_n_c_inv);
      }

      // sol_p_c = Sig~_{p_c}^{-1} (rhs_{p_c} + sol_c)
      SmartPtr<Vector> sol_p_c = Csol_x->GetCompNonConst(2);
      sol_p_c->Set(0.0);
      if (IsValid(sigma_tilde_p_c_inv)) {
        DBG_PRINT_VECTOR(2, "rhs_pc", *Crhs_x->GetComp(2));
        DBG_PRINT_VECTOR(2, "delta_y_c", sol_cR);
        DBG_PRINT_VECTOR(2, "Sig~_{p_c}^{-1}", *sigma_tilde_p_c_inv);
        sol_p_c->AddTwoVectors(1., *Crhs_x->GetComp(2), 1.0, sol_cR, 0.);
        sol_p_c->ElementWiseMultiply(*sigma_tilde_p_c_inv);
      }

      // sol_n_d = Sig~_{n_d}^{-1} (rhs_{n_d} - pd_l^T sol_d)
      SmartPtr<Vector> sol_n_d = Csol_x->GetCompNonConst(3);
      sol_n_d->Set(0.0);
      if (IsValid(sigma_tilde_n_d_inv)) {
        pd_l->TransMultVector(-1.0, sol_dR, 0.0, *sol_n_d);
        sol_n_d->Axpy(1.0, *Crhs_x->GetComp(3));
        sol_n_d->ElementWiseMultiply(*sigma_tilde_n_d_inv);
      }

      // sol_p_d = Sig~_{p_d}^{-1} (rhs_{p_d} - neg_pd_u^T sol_d)
      SmartPtr<Vector> sol_p_d = Csol_x->GetCompNonConst(4);
      sol_p_d->Set(0.0);
      if (IsValid(sigma_tilde_p_d_inv)) {
        neg_pd_u->TransMultVector(-1.0, sol_dR, 0.0, *sol_p_d);
        sol_p_d->Axpy(1.0, *Crhs_x->GetComp(4));
        sol_p_d->ElementWiseMultiply(*sigma_tilde_p_d_inv);
      }
    }

    return status;

  }
// Example #2
// 0
/* Recomputes the data used by LowRankAugSystemSolver for the low-rank part
 * of W.
 *
 * W must be a LowRankUpdateSymMatrix (static_cast, not checked) and W_factor
 * must be 0 or 1.  The method performs these steps:
 *  - Sets the diagonal Wdiag_ for the underlying augmented system solver.
 *    It uses B0 = LR_W->GetDiag() when W_factor == 1 and zeros otherwise.
 *    A reduced diagonal is expanded to the full x-space via P_LowRank.
 *  - If W has a V part, solves the base system for the columns of V, giving
 *    Vtilde1_, V_x and Vtilde1_x.  It then Cholesky-factors
 *    M1 = I + Vtilde1_x^T V_x into J1_.
 *  - If W has a U part, solves the base system for the columns of U.  When
 *    V is present, it corrects the result with the V part:
 *    Utilde2_ = Utilde1 - Vtilde1_ * (M1^{-1} Vtilde1_x^T U_x).
 *    It then Cholesky-factors M2 = I - Utilde2_x^T U_x into J2_.
 *  - Members belonging to an absent V or U part are reset to NULL.
 * NOTE(review): this looks like the setup for a Sherman-Morrison-Woodbury
 * style solve in Solve(); confirm there.
 *
 * Returns SYMSOLVER_SUCCESS on success.  If a multi-vector solve fails, its
 * status is returned.  If a Cholesky factorization fails,
 * SYMSOLVER_WRONG_INERTIA is returned and num_neg_evals_ is incremented.
 */
ESymSolverStatus LowRankAugSystemSolver::UpdateFactorization(
   const SymMatrix* W,
   double           W_factor,
   const Vector*    D_x,
   double           delta_x,
   const Vector*    D_s,
   double           delta_s,
   const Matrix&    J_c,
   const Vector*    D_c,
   double           delta_c,
   const Matrix&    J_d,
   const Vector*    D_d,
   double           delta_d,
   const Vector&    proto_rhs_x,
   const Vector&    proto_rhs_s,
   const Vector&    proto_rhs_c,
   const Vector&    proto_rhs_d,
   bool             check_NegEVals,
   Index            numberOfNegEVals
   )
{
   DBG_START_METH("LowRankAugSystemSolver::UpdateFactorization",
      dbg_verbosity);

   DBG_ASSERT(W_factor == 0.0 || W_factor == 1.0);
   ESymSolverStatus retval = SYMSOLVER_SUCCESS;

   // Get the low update information out of W
   const LowRankUpdateSymMatrix* LR_W = static_cast<const LowRankUpdateSymMatrix*>(W);
   DBG_ASSERT(LR_W);
   DBG_PRINT_MATRIX(2, "LR_W", *LR_W);

   // With W_factor == 0 the low-rank terms are ignored (V, U, B0 stay NULL)
   SmartPtr<const Vector> B0;
   SmartPtr<const MultiVectorMatrix> V;
   SmartPtr<const MultiVectorMatrix> U;
   if( W_factor == 1.0 )
   {
      V = LR_W->GetV();
      U = LR_W->GetU();
      B0 = LR_W->GetDiag();
   }
   SmartPtr<const Matrix> P_LM = LR_W->P_LowRank();
   SmartPtr<const VectorSpace> LR_VecSpace = LR_W->LowRankVectorSpace();

   // B0 lives in the low-rank space only if there is a projection P_LM AND
   // the diagonal is stored in reduced form; otherwise it is handed directly
   // to the full-x-space Wdiag_ below.
   const bool reduced_B0 = IsValid(P_LM) && LR_W->ReducedDiag();

   if( IsNull(B0) )
   {
      // Use zeros, created in the same space the branches below expect.
      // (Creating them in the low-rank space whenever P_LM was valid would
      // hand a low-rank vector to the full-space Wdiag_ when the diagonal
      // is not reduced.)
      SmartPtr<Vector> zero_B0 = reduced_B0 ? LR_VecSpace->MakeNew() : proto_rhs_x.MakeNew();
      zero_B0->Set(0.0);
      B0 = GetRawPtr(zero_B0);
   }

   // set up the Hessian for the underlying augmented system solver
   // without the low-rank update
   if( reduced_B0 )
   {
      DBG_ASSERT(IsValid(B0));
      // Expand the reduced diagonal to the full x-space
      SmartPtr<Vector> fullx = proto_rhs_x.MakeNew();
      P_LM->MultVector(1., *B0, 0., *fullx);
      Wdiag_->SetDiag(*fullx);
   }
   else
   {
      Wdiag_->SetDiag(*B0);
      DBG_PRINT_VECTOR(2, "B0", *B0);
   }

   SmartPtr<MultiVectorMatrix> Vtilde1_x;
   if( IsValid(V) )
   {
      SmartPtr<MultiVectorMatrix> V_x;
      Index nV = V->NCols();
      //DBG_PRINT((1, "delta_x  = %e\n", delta_x));
      //DBG_PRINT_MATRIX(2, "V", *V);
      retval = SolveMultiVector(D_x, delta_x, D_s, delta_s, J_c, D_c, delta_c, J_d, D_d, delta_d, proto_rhs_x,
         proto_rhs_s, proto_rhs_c, proto_rhs_d, *V, P_LM, V_x, Vtilde1_, Vtilde1_x, check_NegEVals, numberOfNegEVals);
      if( retval != SYMSOLVER_SUCCESS )
      {
         Jnlst().Printf(J_DETAILED, J_SOLVE_PD_SYSTEM,
            "LowRankAugSystemSolver: SolveMultiVector returned retval = %d for V.\n", retval);
         return retval;
      }
      //DBG_PRINT_MATRIX(2, "Vtilde1_x", *Vtilde1_x);

      // M1 = I + Vtilde1_x^T V_x
      SmartPtr<DenseSymMatrixSpace> M1space = new DenseSymMatrixSpace(nV);
      SmartPtr<DenseSymMatrix> M1 = M1space->MakeNewDenseSymMatrix();
      M1->FillIdentity();
      M1->HighRankUpdateTranspose(1., *Vtilde1_x, *V_x, 1.);
      //DBG_PRINT_MATRIX(2, "M1", *M1);
      SmartPtr<DenseGenMatrixSpace> J1space = new DenseGenMatrixSpace(nV, nV);
      J1_ = J1space->MakeNewDenseGenMatrix();
      // M1 must be positive definite; a failed Cholesky is reported as
      // wrong inertia.
      bool retchol = J1_->ComputeCholeskyFactor(*M1);
      if( !retchol )
      {
         Jnlst().Printf(J_DETAILED, J_SOLVE_PD_SYSTEM, "LowRankAugSystemSolver: Cholesky for M1 returned error!\n");
         retval = SYMSOLVER_WRONG_INERTIA;
         num_neg_evals_++;
         return retval;
      }
   }
   else
   {
      Vtilde1_ = NULL;
      J1_ = NULL;
   }

   if( IsValid(U) )
   {
      Index nU = U->NCols();
      SmartPtr<MultiVectorMatrix> U_x;
      SmartPtr<MultiVectorMatrix> Utilde1;
      SmartPtr<MultiVectorMatrix> Utilde1_x;
      SmartPtr<MultiVectorMatrix> Utilde2_x;
      retval = SolveMultiVector(D_x, delta_x, D_s, delta_s, J_c, D_c, delta_c, J_d, D_d, delta_d, proto_rhs_x,
         proto_rhs_s, proto_rhs_c, proto_rhs_d, *U, P_LM, U_x, Utilde1, Utilde1_x, check_NegEVals, numberOfNegEVals);
      if( retval != SYMSOLVER_SUCCESS )
      {
         Jnlst().Printf(J_DETAILED, J_SOLVE_PD_SYSTEM,
            "LowRankAugSystemSolver: SolveMultiVector returned retval = %d for U.\n", retval);
         return retval;
      }

      if( IsNull(Vtilde1_) )
      {
         // No V part: nothing to correct
         Utilde2_ = Utilde1;
         Utilde2_x = Utilde1_x;
      }
      else
      {
         // C = M1^{-1} (Vtilde1_x^T U_x), solved via the Cholesky factor J1_
         Index nV = Vtilde1_->NCols();
         SmartPtr<DenseGenMatrixSpace> Cspace = new DenseGenMatrixSpace(nV, nU);
         SmartPtr<DenseGenMatrix> C = Cspace->MakeNewDenseGenMatrix();
         C->HighRankUpdateTranspose(1., *Vtilde1_x, *U_x, 0.);
         J1_->CholeskySolveMatrix(*C);
         // Utilde2_ = Utilde1 - Vtilde1_ C (updates Utilde1 in place)
         Utilde2_ = Utilde1;
         Utilde2_->AddRightMultMatrix(-1, *Vtilde1_, *C, 1.);
         // Utilde2_x holds the x-components (component 0) of Utilde2_'s columns
         Utilde2_x = Utilde1_x->MakeNewMultiVectorMatrix();
         for( Index i = 0; i < Utilde1_x->NCols(); i++ )
         {
            const CompoundVector* cvec = static_cast<const CompoundVector*>(GetRawPtr(Utilde2_->GetVector(i)));
            DBG_ASSERT(cvec);
            Utilde2_x->SetVector(i, *cvec->GetComp(0));
         }
      }

      // M2 = I - Utilde2_x^T U_x
      SmartPtr<DenseSymMatrixSpace> M2space = new DenseSymMatrixSpace(nU);
      SmartPtr<DenseSymMatrix> M2 = M2space->MakeNewDenseSymMatrix();
      M2->FillIdentity();
      M2->HighRankUpdateTranspose(-1., *Utilde2_x, *U_x, 1.);
      SmartPtr<DenseGenMatrixSpace> J2space = new DenseGenMatrixSpace(nU, nU);
      J2_ = J2space->MakeNewDenseGenMatrix();
      //DBG_PRINT_MATRIX(2, "M2", *M2);
      bool retchol = J2_->ComputeCholeskyFactor(*M2);
      if( !retchol )
      {
         Jnlst().Printf(J_DETAILED, J_SOLVE_PD_SYSTEM, "LowRankAugSystemSolver: Cholesky for M2 returned error.\n");
         retval = SYMSOLVER_WRONG_INERTIA;
         num_neg_evals_++;
         return retval;
      }
   }
   else
   {
      J2_ = NULL;
      Utilde2_ = NULL;
   }

   return retval;
}