/* Compute the psi function of the global Alart-Curnier formulation.
 *
 * With n = problem->M->size0 and m = problem->H->size1, psi is filled as
 * three consecutive segments:
 *   psi[0 .. n)     <- problem->q + problem->H * reaction
 *                      - problem->M * globalVelocity
 *   psi[n .. n+m)   <- problem->b + trans(problem->H) * globalVelocity
 *                      - velocity
 *   psi + n + m     <- first output of fc3d_AlartCurnierFunction evaluated
 *                      at (reaction, velocity, problem->mu, rho); the two
 *                      remaining outputs are not requested (NULL).
 *
 * problem          problem data (M, H, q, b, mu are read)
 * computeACFun3x3  per-contact function forwarded to
 *                  fc3d_AlartCurnierFunction
 * globalVelocity   input vector, n entries are read
 * reaction         input vector, m entries are read
 * velocity         input vector, m entries are read
 * rho              forwarded unchanged to fc3d_AlartCurnierFunction
 * psi              output buffer; sizeOfPsi(M, H) entries are cleared before
 *                  being filled (NOTE(review): presumably n + 2*m - confirm
 *                  against sizeOfPsi)
 */
void ACPsi(
  GlobalFrictionContactProblem* problem,
  AlartCurnierFun3x3Ptr computeACFun3x3,
  double *globalVelocity,
  double *reaction,
  double *velocity,
  double *rho,
  double *psi)
{

  assert(problem->H->size1 == problem->dimension * problem->numberOfContacts);

  unsigned int localProblemSize = problem->H->size1;

  unsigned int ACProblemSize = sizeOfPsi(NM_triplet(problem->M),
                                         NM_triplet(problem->H));

  unsigned int globalProblemSize = problem->M->size0;

  /* Clear psi with explicit stores. Scaling by 0. with cblas_dscal is not a
   * reliable way to zero a buffer: a BLAS that really multiplies keeps any
   * NaN/Inf already present in (possibly uninitialised) memory, and the
   * second segment below is accumulated on top of these values. */
  for(unsigned int i = 0; i < ACProblemSize; ++i)
  {
    psi[i] = 0.;
  }

  /* psi <-
       compute -problem->M * globalVelocity + problem->H * reaction + problem->q
     ... */
  cblas_dcopy(globalProblemSize, problem->q, 1, psi, 1);
  NM_gemv(1., problem->H, reaction, 1, psi);
  NM_gemv(-1., problem->M, globalVelocity, 1, psi);

  /* psi + globalProblemSize <-
     compute -velocity + trans(problem->H) * globalVelocity + problem->b
   ... */
  cblas_daxpy(localProblemSize, -1., velocity, 1, psi + globalProblemSize, 1);
  cblas_daxpy(localProblemSize, 1, problem->b, 1, psi + globalProblemSize, 1);
  NM_tgemv(1., problem->H, globalVelocity, 1, psi + globalProblemSize);

  /* compute AC function: only its first output is stored, right after the
   * two linear segments */
  fc3d_AlartCurnierFunction(localProblemSize,
                            computeACFun3x3,
                            reaction,
                            velocity, problem->mu, rho,
                            psi + globalProblemSize + localProblemSize,
                            NULL, NULL);

}
/* Compute the psi function of the global Alart-Curnier formulation.
 *
 * With n = problem->M->size0 and m = problem->H->size1, the n + 2*m entries
 * of psi are filled as three consecutive segments:
 *   psi[0 .. n)     <- problem->q + problem->H * reaction
 *                      - problem->M * globalVelocity
 *   psi[n .. n+m)   <- problem->b + trans(problem->H) * globalVelocity
 *                      - velocity
 *   psi + n + m     <- first output of fc3d_AlartCurnierFunction evaluated
 *                      at (reaction, velocity, problem->mu, rho); the two
 *                      remaining outputs are not requested (NULL).
 *
 * problem          problem data (M, H, q, b, mu are read)
 * computeACFun3x3  per-contact function forwarded to
 *                  fc3d_AlartCurnierFunction
 * globalVelocity   input vector, n entries are read
 * reaction         input vector, m entries are read
 * velocity         input vector, m entries are read
 * rho              forwarded unchanged to fc3d_AlartCurnierFunction
 * psi              output buffer of at least n + 2*m entries
 */
void ACPsi(
  GlobalFrictionContactProblem* problem,
  AlartCurnierFun3x3Ptr computeACFun3x3,
  double *globalVelocity,
  double *reaction,
  double *velocity,
  double *rho,
  double *psi)
{

  assert(problem->H->size1 == problem->dimension * problem->numberOfContacts);
  unsigned int m = problem->H->size1;
  unsigned int n = problem->M->size0;
  unsigned int problem_size = n +  2*m ;

  /* Clear psi once, with explicit stores. Scaling by 0. with cblas_dscal is
   * not a reliable way to zero a buffer: a BLAS that really multiplies keeps
   * any NaN/Inf already present in (possibly uninitialised) memory, and the
   * segment psi + n below is accumulated on top of these values. */
  for(unsigned int i = 0; i < problem_size; ++i)
  {
    psi[i] = 0.;
  }

  /* -problem->M * globalVelocity + problem->H * reaction + problem->q ==> psi  */
  cblas_dcopy(n, problem->q, 1, psi, 1);
  NM_gemv(1., problem->H, reaction, 1, psi);
  NM_gemv(-1., problem->M, globalVelocity, 1, psi);

  /* -velocity + trans(problem->H) * globalVelocity + problem->b ==> psi + n */
  cblas_daxpy(m, -1., velocity, 1, psi + n, 1);
  cblas_daxpy(m, 1, problem->b, 1, psi + n, 1);
  NM_tgemv(1., problem->H, globalVelocity, 1, psi + n);

  /* compute AC function: only its first output is stored, right after the
   * two linear segments */
  fc3d_AlartCurnierFunction(m,
                            computeACFun3x3,
                            reaction,
                            velocity, problem->mu, rho,
                            psi+n+m,
                            NULL, NULL);

}
/* Alart & Curnier solver for sparse global problem */
void gfc3d_nonsmooth_Newton_AlartCurnier(
  GlobalFrictionContactProblem* problem,
  double *reaction,
  double *velocity,
  double *globalVelocity,
  int *info,
  SolverOptions* options)
{

  assert(problem);
  assert(reaction);
  assert(velocity);
  assert(info);
  assert(options);

  assert(problem->dimension == 3);

  assert(options->iparam);
  assert(options->dparam);

  assert(problem->q);
  assert(problem->mu);
  assert(problem->M);
  assert(problem->H);

  assert(!problem->M->matrix0);
//  assert(problem->M->matrix1);

  assert(!options->iparam[4]); // only host

  /* M is square */
  assert(problem->M->size0 == problem->M->size1);

  assert(problem->M->size0 == problem->H->size0);

  unsigned int iter = 0;
  unsigned int itermax = options->iparam[0];
  unsigned int erritermax = options->iparam[7];

  if (erritermax == 0)
  {
    /* output a warning here */
    erritermax = 1;
  }

  assert(itermax > 0);
  assert(options->iparam[3] > 0);

  double tolerance = options->dparam[0];
  assert(tolerance > 0);

  if (verbose > 0)
    printf("------------------------ GFC3D - _nonsmooth_Newton_AlartCurnier - Start with tolerance = %g\n", tolerance);


  /* sparse triplet storage */
  NM_triplet(problem->M);
  NM_triplet(problem->H);

  unsigned int ACProblemSize = sizeOfPsi(NM_triplet(problem->M),
                                         NM_triplet(problem->H));

  unsigned int globalProblemSize = (unsigned)NM_triplet(problem->M)->m;

  unsigned int localProblemSize = problem->H->size1;

  assert((int)localProblemSize == problem->numberOfContacts * problem->dimension);

  assert((int)globalProblemSize == problem->H->size0); /* size(velocity) ==
                                                   * Htrans*globalVelocity */


  AlartCurnierFun3x3Ptr computeACFun3x3 = NULL;

  switch (options->iparam[10])
  {
  case 0:
  {
    computeACFun3x3 = &computeAlartCurnierSTD;
    break;
  }
  case 1:
  {
    computeACFun3x3 = &computeAlartCurnierJeanMoreau;
    break;
  };
  case 2:
  {
    computeACFun3x3 = &fc3d_AlartCurnierFunctionGenerated;
    break;
  }
  case 3:
  {
    computeACFun3x3 = &fc3d_AlartCurnierJeanMoreauFunctionGenerated;
    break;
  }
  }

  if(options->iparam[9] == 0)
  {
    /* allocate memory */
    assert(options->dWork == NULL);
    assert(options->iWork == NULL);
    options->dWork = (double *) malloc(
                       (localProblemSize + /* F */
                        3 * localProblemSize + /* A */
                        3 * localProblemSize + /* B */
                        localProblemSize + /* rho */
                        ACProblemSize + /* psi */
                        ACProblemSize + /* rhs */
                        ACProblemSize + /* tmp2 */
                        ACProblemSize + /* tmp3 */
                        ACProblemSize   /* solution */) *  sizeof(double));

    /* XXX big hack here */
    options->iWork = (int *) malloc(
                       (3 * localProblemSize + /* iA */
                        3 * localProblemSize + /* iB */
                        3 * localProblemSize + /* pA */
                        3 * localProblemSize)  /* pB */
                       * sizeof(csi));

    options->iparam[9] = 1;

  }

  assert(options->dWork != NULL);
  assert(options->iWork != NULL);

  double *F = options->dWork;
  double *A = F +   localProblemSize;
  double *B = A +   3 * localProblemSize;
  double *rho = B + 3 * localProblemSize;

  double * psi = rho + localProblemSize;
  double * rhs = psi + ACProblemSize;
  double * tmp2 = rhs + ACProblemSize;
  double * tmp3 = tmp2 + ACProblemSize;
  double * solution = tmp3 + ACProblemSize;

  /* XXX big hack --xhub*/
  csi * iA = (csi *)options->iWork;
  csi * iB = iA + 3 * localProblemSize;
  csi * pA = iB + 3 * localProblemSize;
  csi * pB = pA + 3 * localProblemSize;

  CSparseMatrix A_;
  CSparseMatrix B_;
  CSparseMatrix *J;

  A_.p = pA;
  B_.p = pB;
  A_.i = iA;
  B_.i = iB;

  init3x3DiagBlocks(problem->numberOfContacts, A, &A_);
  init3x3DiagBlocks(problem->numberOfContacts, B, &B_);

  J = cs_spalloc(NM_triplet(problem->M)->n + A_.m + B_.m,
                 NM_triplet(problem->M)->n + A_.m + B_.m,
                 NM_triplet(problem->M)->nzmax + 2*NM_triplet(problem->H)->nzmax +
                 2*A_.n + A_.nzmax + B_.nzmax, 1, 1);

  assert(A_.n == problem->H->size1);
  assert(A_.nz == problem->numberOfContacts * 9);
  assert(B_.n == problem->H->size1);
  assert(B_.nz == problem->numberOfContacts * 9);

  fc3d_AlartCurnierFunction(
    localProblemSize,
    computeACFun3x3,
    reaction, velocity,
    problem->mu, rho,
    F, A, B);

  csi Astart = initACPsiJacobian(NM_triplet(problem->M),
                                 NM_triplet(problem->H),
                                 &A_, &B_, J);

  assert(Astart > 0);

  assert(A_.m == A_.n);
  assert(B_.m == B_.n);

  assert(A_.m == problem->H->size1);

  // compute rho here
  for(unsigned int i = 0; i < localProblemSize; ++i) rho[i] = 1.;

  // direction
  for(unsigned int i = 0; i < ACProblemSize; ++i) rhs[i] = 0.;



  // quick hack to make things work
  // need to use the functions from NumericsMatrix --xhub


  NumericsMatrix *AA_work = createNumericsMatrix(NM_SPARSE,  (int)J->m, (int)J->n);

  NumericsSparseMatrix* SM = newNumericsSparseMatrix();
  SM->triplet = J;
  NumericsMatrix *AA = createNumericsMatrixFromData(NM_SPARSE,  (int)J->m, (int)J->n, SM);

  info[0] = 1;

  /* update local velocity from global velocity */
  /* an assertion ? */
  cblas_dcopy(localProblemSize, problem->b, 1, velocity, 1);
  NM_tgemv(1., problem->H, globalVelocity, 1, velocity);
  double linear_solver_residual=0.0;
  while(iter++ < itermax)
  {

    /* compute psi */
    ACPsi(problem, computeACFun3x3, globalVelocity, reaction, velocity, rho, psi);

    /* compute A & B */
    fc3d_AlartCurnierFunction(localProblemSize,
                              computeACFun3x3,
                              reaction, velocity,
                              problem->mu, rho,
                              F, A, B);
    /* update J */
    updateACPsiJacobian(NM_triplet(problem->M),
                        NM_triplet(problem->H),
                        &A_, &B_, J, Astart);

    /* rhs = -psi */
    cblas_dcopy(ACProblemSize, psi, 1, rhs, 1);
    cblas_dscal(ACProblemSize, -1., rhs, 1);

    /* get compress column storage for linear ops */
    CSparseMatrix* Jcsc = cs_compress(J);

    /* Solve: J X = -psi */

    /* Solve: AWpB X = -F */
    NM_copy(AA, AA_work);

    int info_solver = NM_gesv(AA_work, rhs);
    if (info_solver > 0)
    {
      fprintf(stderr, "------------------------ GFC3D - NSN_AC - solver failed info = %d\n", info_solver);
      break;
      info[0] = 2;
      CHECK_RETURN(!cs_check_triplet(NM_triplet(AA_work)));
    }

    /* Check the quality of the solution */
    if (verbose > 0)
    {
      cblas_dcopy_msan(ACProblemSize, psi, 1, tmp3, 1);
      NM_gemv(1., AA, rhs, 1., tmp3);
      linear_solver_residual = cblas_dnrm2(ACProblemSize, tmp3, 1);
      /* fprintf(stderr, "fc3d esolve: linear equation residual = %g\n", */
      /*         cblas_dnrm2(problemSize, tmp3, 1)); */
      /* for the component wise scaled residual: cf mumps &
       * http://www.netlib.org/lapack/lug/node81.html */
    }

    /* line search */
    double alpha = 1;

    /* set current solution */
    for(unsigned int i = 0; i < globalProblemSize; ++i)
    {
      solution[i] = globalVelocity[i];
    }
    for(unsigned int i = 0; i < localProblemSize; ++i)
    {
      solution[i+globalProblemSize] = velocity[i];
      solution[i+globalProblemSize+localProblemSize] = reaction[i];
    }

    DEBUG_EXPR_WE(
      for(unsigned int i = 0; i < globalProblemSize; ++i)
      {
        printf("globalVelocity[%i] = %6.4e\n",i,globalVelocity[i]);
      }
      for(unsigned int i = 0; i < localProblemSize; ++i)
      {
        printf("velocity[%i] = %6.4e\t",i,velocity[i]);
        printf("reaction[%i] = %6.4e\n",i,reaction[i]);
      }
      );


    int info_ls = _globalLineSearchSparseGP(problem,
                                            computeACFun3x3,
                                            solution,
                                            rhs,
                                            globalVelocity,
                                            reaction, velocity,
                                            problem->mu, rho, F, psi, Jcsc,
                                            tmp2, &alpha, 100);


    cs_spfree(Jcsc);
    if(!info_ls)
    {
      cblas_daxpy(ACProblemSize, alpha, rhs, 1, solution, 1);
    }
    else
    {
      cblas_daxpy(ACProblemSize, 1, rhs, 1., solution, 1);
    }

    for(unsigned int e = 0 ; e < globalProblemSize; ++e)
    {
      globalVelocity[e] = solution[e];
    }

    for(unsigned int e = 0 ; e < localProblemSize; ++e)
    {
      velocity[e] = solution[e+globalProblemSize];
    }

    for(unsigned int e = 0; e < localProblemSize; ++e)
    {
      reaction[e] = solution[e+globalProblemSize+localProblemSize];
    }

    options->dparam[1] = INFINITY;

    if(!(iter % erritermax))
    {

      gfc3d_compute_error(problem,
                          reaction, velocity, globalVelocity,
                          tolerance,
                          &(options->dparam[1]));
    }

    if(verbose > 0)
      printf("------------------------ GFC3D - NSN_AC - iteration %d, residual = %g, linear solver residual = %g, tolerance = %g \n", iter, options->dparam[1],linear_solver_residual, tolerance);

    if(options->dparam[1] < tolerance)
    {
      info[0] = 0;
      break;
    }


  }