// -----------------------------------------------------------------------------
// Evaluate the contact force/torque kernel once for every contact pair.
//
// For each contact index in [0, data_container->num_contacts) this forwards
// the index, the integration step size, raw pointers to the body-state,
// material and contact arrays stored in data_container->host_data, and raw
// pointers to the three caller-supplied vectors to function_CalcContactForces.
// The iterations are distributed over OpenMP threads.
//
//   ext_body_id      handed to the kernel as ext_body_id.data()
//   ext_body_force   handed to the kernel as ext_body_force.data()
//   ext_body_torque  handed to the kernel as ext_body_torque.data()
//
// NOTE(review): presumably the kernel fills the ext_* arrays at slots derived
// from the contact index, so the caller must size them beforehand -- confirm
// against function_CalcContactForces.
// -----------------------------------------------------------------------------
void ChLcpSolverParallelDEM::host_CalcContactForces(custom_vector<int>& ext_body_id,
                                                    custom_vector<real3>& ext_body_force,
                                                    custom_vector<real3>& ext_body_torque) {
#pragma omp parallel for
  for (int contact = 0; contact < data_container->num_contacts; ++contact) {
    function_CalcContactForces(contact,
                               data_container->step_size,
                               // Body state.
                               data_container->host_data.pos_data.data(),
                               data_container->host_data.rot_data.data(),
                               data_container->host_data.vel_data.data(),
                               data_container->host_data.omg_data.data(),
                               // Material parameters.
                               data_container->host_data.elastic_moduli.data(),
                               data_container->host_data.mu.data(),
                               data_container->host_data.alpha.data(),
                               data_container->host_data.cr.data(),
                               data_container->host_data.cohesion_data.data(),
                               // Per-contact collision data.
                               data_container->host_data.bids_rigid_rigid.data(),
                               data_container->host_data.cpta_rigid_rigid.data(),
                               data_container->host_data.cptb_rigid_rigid.data(),
                               data_container->host_data.norm_rigid_rigid.data(),
                               data_container->host_data.dpth_rigid_rigid.data(),
                               data_container->host_data.erad_rigid_rigid.data(),
                               // Caller-supplied buffers.
                               ext_body_id.data(),
                               ext_body_force.data(),
                               ext_body_torque.data());
  }
}
// Run the two Schur stages on x, timing each one under its own label.
//
//   x       input vector; its buffer is passed to both shurA and shurB
//   output  vector whose buffer is passed to shurB as the second argument
//           (presumably the destination of the product -- confirm in shurB)
void ChSolverParallel::ShurProduct(custom_vector<real>& x, custom_vector<real>& output) {
   const char* const stage_a_label = "ChSolverParallel_shurA";
   const char* const stage_b_label = "ChSolverParallel_shurB";

   // Stage A: consumes x only.
   data_container->system_timer.start(stage_a_label);
   shurA(x.data());
   data_container->system_timer.stop(stage_a_label);

   // Stage B: consumes x and receives the output buffer.
   data_container->system_timer.start(stage_b_label);
   shurB(x.data(), output.data());
   data_container->system_timer.stop(stage_b_label);
}
// -----------------------------------------------------------------------------
// Compute candidate body positions and orientations for the multipliers x.
//
// Does nothing when rigid_rigid->solve_sliding or rigid_rigid->solve_spinning
// is set. Otherwise:
//   1. calls shurA(x.data()) (presumably this refreshes QXYZ_data / QUVW_data
//      in host_data -- confirm in shurA),
//   2. forms vel_new_data = vel_data + QXYZ_data and
//            omg_new_data = omg_data + QUVW_data,
//   3. for every body, advances the position by vel_new * step_size and
//      composes the old rotation with an angle-axis increment built from the
//      new angular velocity, storing results in pos_new_data / rot_new_data.
//
//   x  multiplier vector; only its raw buffer is passed on to shurA
// -----------------------------------------------------------------------------
void ChSolverParallel::UpdatePosition(custom_vector<real> &x) {

   if (rigid_rigid->solve_sliding == true || rigid_rigid->solve_spinning == true) {
      return;
   }
   shurA(x.data());

   data_container->host_data.vel_new_data = data_container->host_data.vel_data + data_container->host_data.QXYZ_data;
   // This used to read `omg_new_data + omg_data + QUVW_data;`, an expression
   // statement whose result was discarded, so omg_new_data was never updated
   // before being read in the loop below. It must be an assignment, mirroring
   // the velocity update above.
   data_container->host_data.omg_new_data = data_container->host_data.omg_data + data_container->host_data.QUVW_data;

#pragma omp parallel for
   for (int i = 0; i < data_container->num_bodies; i++) {

      // Position update over one step with the new linear velocity.
      data_container->host_data.pos_new_data[i] = data_container->host_data.pos_data[i] + data_container->host_data.vel_new_data[i] * step_size;

      real4 moldrot = data_container->host_data.rot_data[i];
      real3 newwel = data_container->host_data.omg_new_data[i];

      // Transform the angular velocity with the matrix built from the old
      // rotation, then split it into a rotation angle (magnitude * step_size)
      // and a unit axis.
      M33 A = AMat(moldrot);
      real3 newwel_abs = MatMult(A, newwel);
      real mangle = length(newwel_abs) * step_size;
      // NOTE(review): if newwel_abs is exactly zero, the result depends on how
      // normalize() treats a zero-length vector -- confirm it cannot yield NaN.
      newwel_abs = normalize(newwel_abs);
      real4 mdeltarot = Q_from_AngAxis(mangle, newwel_abs);
      // Compose the incremental rotation with the old one.
      real4 mnewrot = mdeltarot % moldrot;
      data_container->host_data.rot_new_data[i] = mnewrot;
   }
}
// -----------------------------------------------------------------------------
// Evaluate the contact force/torque kernel once for every rigid contact.
//
// For each contact index in [0, data_manager->num_rigid_contacts) this
// forwards the index, the solver settings, the step size, raw pointers to the
// body-state, material, contact and shear-history arrays stored in
// data_manager->host_data, and raw pointers to the caller-supplied vectors to
// function_CalcContactForces. Iterations are distributed over OpenMP threads.
//
//   ext_body_id      handed to the kernel as ext_body_id.data()
//   ext_body_force   handed to the kernel as ext_body_force.data()
//   ext_body_torque  handed to the kernel as ext_body_torque.data()
//   shape_pairs      handed to the kernel as shape_pairs.data()
//   shear_touch      handed to the kernel as shear_touch.data()
//
// NOTE(review): presumably the kernel writes the ext_* and shear_touch arrays
// at slots derived from the contact index, so the caller must size them
// beforehand -- confirm against function_CalcContactForces.
// -----------------------------------------------------------------------------
void ChLcpSolverParallelDEM::host_CalcContactForces(custom_vector<int>& ext_body_id,
                                                    custom_vector<real3>& ext_body_force,
                                                    custom_vector<real3>& ext_body_torque,
                                                    custom_vector<int2>& shape_pairs,
                                                    custom_vector<bool>& shear_touch) {
#pragma omp parallel for
  for (int contact = 0; contact < data_manager->num_rigid_contacts; ++contact) {
    function_CalcContactForces(
        contact,
        // Solver settings.
        data_manager->settings.solver.contact_force_model,
        data_manager->settings.solver.tangential_displ_mode,
        data_manager->settings.solver.use_material_properties,
        data_manager->settings.solver.characteristic_vel,
        data_manager->settings.solver.min_slip_vel,
        data_manager->settings.step_size,
        // Body state.
        data_manager->host_data.mass_rigid.data(),
        data_manager->host_data.pos_rigid.data(),
        data_manager->host_data.rot_rigid.data(),
        data_manager->host_data.v.data(),
        // Material parameters.
        data_manager->host_data.elastic_moduli.data(),
        data_manager->host_data.cr.data(),
        data_manager->host_data.dem_coeffs.data(),
        data_manager->host_data.mu.data(),
        data_manager->host_data.cohesion_data.data(),
        // Per-contact collision data.
        data_manager->host_data.bids_rigid_rigid.data(),
        shape_pairs.data(),
        data_manager->host_data.cpta_rigid_rigid.data(),
        data_manager->host_data.cptb_rigid_rigid.data(),
        data_manager->host_data.norm_rigid_rigid.data(),
        data_manager->host_data.dpth_rigid_rigid.data(),
        data_manager->host_data.erad_rigid_rigid.data(),
        // Shear history.
        data_manager->host_data.shear_neigh.data(),
        shear_touch.data(),
        data_manager->host_data.shear_disp.data(),
        // Caller-supplied buffers.
        ext_body_id.data(),
        ext_body_force.data(),
        ext_body_torque.data());
  }
}
uint ChSolverAPGDBlaze::SolveAPGDBlaze(const uint max_iter,
                                       const uint size,
                                       custom_vector<real> &b,
                                       custom_vector<real> &x) {


   data_container->system_timer.start("ChSolverParallel_solverA");
   blaze::DynamicVector<real> one(size, 1.0);

   ms.resize(size);
   mg_tmp2.resize(size);
   mb_tmp.resize(size);
   mg_tmp.resize(size);
   mg_tmp1.resize(size);
   mg.resize(size);
   ml.resize(size);
   mx.resize(size);
   my.resize(size);
   mb.resize(size);
   mso.resize(size);

   lastgoodres = 10e30;
   theta_k = init_theta_k;
   theta_k1 = theta_k;
   beta_k1 = 0.0;
   mb_tmp_norm = 0, mg_tmp_norm = 0;
   obj1 = 0.0, obj2 = 0.0;
   dot_mg_ms = 0, norm_ms = 0;
   delta_obj = 1e8;

#pragma omp parallel for
   for (int i = 0; i < size; i++) {
      ml[i] = x[i];
      mb[i] = b[i];
   }
   Project(ml.data());
   ml_candidate = ml;
   mg = data_container->host_data.Nshur * ml;
   mg = mg - mb;
   mb_tmp = ml - one;
   mg_tmp = data_container->host_data.Nshur * mb_tmp;

   mb_tmp_norm = sqrt((mb_tmp, trans(mb_tmp)));
   mg_tmp_norm = sqrt((mg_tmp, trans(mg_tmp)));

   if (mb_tmp_norm == 0) {
      L_k = 1;
   } else {
      L_k = mg_tmp_norm / mb_tmp_norm;
   }

   t_k = 1.0 / L_k;
   my = ml;
   mx = ml;
   data_container->system_timer.stop("ChSolverParallel_solverA");

   for (current_iteration = 0; current_iteration < max_iter; current_iteration++) {

      data_container->system_timer.start("ChSolverParallel_solverB");

      mg_tmp1 = data_container->host_data.Nshur * my;
      data_container->system_timer.stop("ChSolverParallel_solverB");
      data_container->system_timer.start("ChSolverParallel_solverC");

      mg = mg_tmp1 - mb;
      mx = -t_k * mg + my;

      Project(mx.data());
      mg_tmp = data_container->host_data.Nshur * mx;
      data_container->system_timer.stop("ChSolverParallel_solverC");
      data_container->system_timer.start("ChSolverParallel_solverD");

      //mg_tmp2 = mg_tmp - mb;
      mso = .5 * mg_tmp - mb;
      obj1 = (mx, trans(mso));
      ms = .5 * mg_tmp1 - mb;
      obj2 = (my, trans(ms));
      ms = mx - my;
      dot_mg_ms = (mg, trans(ms));
      norm_ms = (ms, trans(ms));

      data_container->system_timer.stop("ChSolverParallel_solverD");
      while (obj1 > obj2 + dot_mg_ms + 0.5 * L_k * norm_ms) {
         data_container->system_timer.start("ChSolverParallel_solverE");
         L_k = 2.0 * L_k;
         t_k = 1.0 / L_k;
         mx = -t_k * mg + my;
         Project(mx.data());

         mg_tmp = data_container->host_data.Nshur * mx;

         mso = .5 * mg_tmp - mb;
         obj1 = (mx, trans(mso));
         ms = mx - my;
         dot_mg_ms = (mg, trans(ms));
         norm_ms = (ms, trans(ms));

         data_container->system_timer.stop("ChSolverParallel_solverE");
      }
      data_container->system_timer.start("ChSolverParallel_solverF");

      theta_k1 = (-pow(theta_k, 2) + theta_k * sqrt(pow(theta_k, 2) + 4)) / 2.0;
      beta_k1 = theta_k * (1.0 - theta_k) / (pow(theta_k, 2) + theta_k1);

      ms = mx - ml;
      my = beta_k1 * ms + mx;
      real dot_mg_ms = (mg, trans(ms));

      if (dot_mg_ms > 0) {
         my = mx;
         theta_k1 = 1.0;
      }
      L_k = 0.9 * L_k;
      t_k = 1.0 / L_k;
      ml = mx;
      step_grow = 2.0;
      theta_k = theta_k1;
      //if (current_iteration % 2 == 0) {
      mg_tmp2 = mg_tmp - mb;
      real g_proj_norm = Res4(num_unilaterals, mg_tmp2, ml, mb_tmp);

      if (num_bilaterals > 0) {
         real resid_bilat = -1;
         for (int i = num_unilaterals; i < x.size(); i++) {
            resid_bilat = std::max(resid_bilat, std::abs(mg_tmp2[i]));
         }
         g_proj_norm = std::max(g_proj_norm, resid_bilat);
      }

      bool update = false;
      if (g_proj_norm < lastgoodres) {
         lastgoodres = g_proj_norm;
         ml_candidate = ml;
         objective_value = (ml_candidate, mso);  //maxdeltalambda = GetObjectiveBlaze(ml_candidate, mb);
         update = true;
      }

      residual = lastgoodres;
      //if (update_rhs) {
      //ComputeSRhs(ml_candidate, rhs, vel_data, omg_data, b);
      //}

      AtIterationEnd(residual, objective_value, iter_hist.size());
      if (tol_objective) {
         if (objective_value <= tolerance) {
            break;
         }
      } else {
         if (residual < tolerance) {
            break;
         }
      }
      data_container->system_timer.stop("ChSolverParallel_solverF");
   }
#pragma omp parallel for
   for (int i = 0; i < size; i++) {
      x[i] = ml_candidate[i];
   }
   return current_iteration;
}
// -----------------------------------------------------------------------------
// Iteratively solve the bilateral system applied through ShurBilaterals().
//
// The loop body is organised as a Lanczos step, a QR update via Givens
// rotations, and a solution update (the structure matches a MINRES-type
// method -- NOTE(review): confirm). The relative residual estimate
// norm_rMR / norm_r0 is stored in `residual` and compared with `tolerance`.
//
//   max_iter  upper bound on the number of iterations
//   size      unused here; vector lengths come from mb.size() and x.size()
//   mb        right-hand side (read only)
//   x         initial guess on entry, updated in place
//
// Returns 0 without touching the iteration counters when the initial residual
// norm is zero. Otherwise adds the loop counter to total_iteration, stores the
// sum in current_iteration, and returns it.
// -----------------------------------------------------------------------------
uint ChSolverParallel::SolveStab(const uint max_iter,
                                 const uint size,
                                 const custom_vector<real> &mb,
                                 custom_vector<real> &x) {
   uint N = mb.size();

   custom_vector<real> v(N, 0), v_hat(x.size()), w(N, 0), w_old, v_old, Av(x.size()), w_oold;
   real beta, c = 1, eta, norm_rMR, norm_r0, c_old = 1, s_old = 0, s = 0, alpha, beta_old, c_oold, s_oold, r1_hat, r1, r2, r3;

   // Initial residual v_hat = mb - A * x and its norm.
   ShurBilaterals(x, v_hat);
   v_hat = mb - v_hat;
   beta = Norm(v_hat);
   w_old = w;
   eta = beta;
   norm_rMR = beta;
   norm_r0 = beta;

   if (beta == 0 || norm_rMR / norm_r0 < tolerance) {
      return 0;
   }

   for (current_iteration = 0; current_iteration < max_iter; current_iteration++) {
      //// Lanczos
      v_old = v;
      v = 1.0 / beta * v_hat;
      ShurBilaterals(v, Av);
      alpha = Dot(v, Av);
      v_hat = Av - alpha * v - beta * v_old;
      beta_old = beta;
      beta = Norm(v_hat);
      //// QR factorization
      c_oold = c_old;
      c_old = c;
      s_oold = s_old;
      s_old = s;
      r1_hat = c_old * alpha - c_oold * s_old * beta_old;
      r1 = 1 / sqrt(r1_hat * r1_hat + beta * beta);
      r2 = s_old * alpha + c_oold * c_old * beta_old;
      r3 = s_oold * beta_old;
      //// Givens Rotation
      c = r1_hat * r1;
      s = beta * r1;
      //// update
      w_oold = w_old;
      w_old = w;
      w = r1 * (v - r3 * w_oold - r2 * w_old);
      x = x + c * eta * w;
      // fabs instead of unqualified abs: depending on which headers and
      // using-directives are active, abs() can bind to the integer overload,
      // which would truncate |s| < 1 to 0 and report false convergence.
      norm_rMR = norm_rMR * fabs(s);
      eta = -s * eta;
      residual = norm_rMR / norm_r0;

      real maxdeltalambda = CompRes(mb, num_contacts);
      AtIterationEnd(residual, maxdeltalambda, iter_hist.size() + current_iteration);

      if (residual < tolerance) {
         break;
      }
   }
   total_iteration += current_iteration;
   current_iteration = total_iteration;
   return current_iteration;
}