Example 1
0
// Solves a sequence of POGS problems that all share the matrix A, writing the
// result of objective i into slot i of the preallocated R output vectors.
//
// A       R numeric matrix; its dimensions (m, n) are read via GET_DIM.
// fin     R list; element i describes the m functions f of objective i
//         (parsed by PopulateFunctionObj). length(fin) is the objective count.
// gin     R list; element i describes the n functions g of objective i.
//         NOTE(review): gin is assumed to have at least length(fin) elements;
//         nothing here checks that — confirm in the R-side caller.
// params  solver parameters, applied once via PopulateParams.
// x, u    numeric outputs holding n values per objective (GetX / GetMu).
// y, v    numeric outputs holding m values per objective (GetY / GetLambda).
// opt     numeric output holding one GetOptval() per objective.
// status  integer output holding one Solve() return code per objective.
void SolverWrap(SEXP A, SEXP fin, SEXP gin, SEXP params, SEXP x, SEXP y,
                SEXP u, SEXP v, SEXP opt, SEXP status) {
  SEXP Adim = GET_DIM(A);
  size_t m = INTEGER(Adim)[0];
  size_t n = INTEGER(Adim)[1];
  unsigned int num_obj = length(fin);

  // Wrap R's matrix storage directly (REAL(A) is passed without copying).
  // NOTE(review): 'c' presumably selects column-major layout to match R's
  // matrix storage — confirm against pogs::MatrixDense.
  pogs::MatrixDense<T> A_dense('c', m, n, REAL(A));

  // A single solver instance is created once and reused for every objective.
  pogs::PogsDirect<T, pogs::MatrixDense<T> > pogs_data(A_dense);
  std::vector<FunctionObj<T> > f, g;

  // The function-object vectors are cleared and refilled per objective, so
  // reserving their final sizes once avoids repeated reallocation.
  f.reserve(m);
  g.reserve(n);

  // Populate parameters.
  PopulateParams(params, &pogs_data);

  for (unsigned int i = 0; i < num_obj; ++i) {
    // Populate function objects for objective i.
    f.clear();
    g.clear();
    PopulateFunctionObj(VECTOR_ELT(fin, i), m, &f);
    PopulateFunctionObj(VECTOR_ELT(gin, i), n, &g);

    // Run solver.
    INTEGER(status)[i] = pogs_data.Solve(f, g);

    // Copy the solution into slot i of each output. R numeric vectors store
    // double, so assign element-wise (converting from T) instead of copying
    // n * sizeof(T) raw bytes, which is only correct when T is double.
    const T* x_sol = pogs_data.GetX();
    const T* mu_sol = pogs_data.GetMu();
    double* x_out = REAL(x) + i * n;
    double* u_out = REAL(u) + i * n;
    for (size_t j = 0; j < n; ++j) {
      x_out[j] = static_cast<double>(x_sol[j]);
      u_out[j] = static_cast<double>(mu_sol[j]);
    }

    const T* y_sol = pogs_data.GetY();
    const T* lambda_sol = pogs_data.GetLambda();
    double* y_out = REAL(y) + i * m;
    double* v_out = REAL(v) + i * m;
    for (size_t j = 0; j < m; ++j) {
      y_out[j] = static_cast<double>(y_sol[j]);
      v_out[j] = static_cast<double>(lambda_sol[j]);
    }

    REAL(opt)[i] = pogs_data.GetOptval();
  }
}
Example 2
0
void SolverWrap(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
  size_t m = mxGetM(prhs[0]);
  size_t n = mxGetN(prhs[0]);

  // Convert column major (matlab) to row major (c++).
  T* A = new T[m * n];
  ColToRowMajor(reinterpret_cast<T*>(mxGetPr(prhs[0])), m, n, A);

  // Initialize Pogs data structure
  PogsData<T, T*> pogs_data(A, m, n);
  pogs_data.f.reserve(m);
  pogs_data.g.reserve(n);
  pogs_data.x = reinterpret_cast<T*>(mxGetPr(plhs[0]));
  if (nlhs >= 2)
    pogs_data.y = reinterpret_cast<T*>(mxGetPr(plhs[1]));

  // Populate parameters.
  int err = 0;
  if (nrhs == 4)
    err = PopulateParams(prhs[3], &pogs_data);

  // Populate function objects.
  if (err == 0)
    err = PopulateFunctionObj("f", prhs[1], m, &pogs_data.f);
  if (err == 0)
    err = PopulateFunctionObj("g", prhs[2], n, &pogs_data.g);

  // Run solver.
  if (err == 0)
    Pogs(&pogs_data);

  if (nlhs >= 3)
    reinterpret_cast<T*>(mxGetPr(plhs[2]))[0] = pogs_data.optval;

  delete [] A;
}