// Runs the FISTA solver with a graph-structured regularizer on MEX inputs.
//
// NOTE(review): the scalar type `T` is not declared inside this block;
// presumably a `template <typename T>` header precedes it - confirm.
//
// Inputs (prhs):
//   prhs[0]  dense matrix X (m x n), checked with mexCheckType<T>
//   prhs[1]  matrix D (mD x p), dense or sparse, checked with mexCheckType<T>
//   prhs[2]  dense matrix alpha0 (pAlpha x nAlpha), checked with mexCheckType<T>
//   prhs[3]  struct with the graph description:
//              "groups"      sparse, square (nG x nG)
//              "groups_var"  sparse, nV x nG with nV > 0
//              "eta_g"       dense, 1 x nG
//   prhs[4]  struct of solver options; scalar options fall back to the default
//            passed to getScalarStructDef, "regul" and "loss" are strings
// Outputs (plhs):
//   plhs[0]  alpha, a newly created pAlpha x nAlpha matrix filled by the solver
//   plhs[1]  copy of the solver's duality_gap matrix, only when nlhs == 2
//
// All validation failures are reported with mexErrMsgTxt, which hands control
// back to MATLAB and does not return to this function.
inline void callFunction(mxArray* plhs[], const mxArray*prhs[],
      const int nlhs) {
   // --- type / storage checks on the five arguments ---
   if (!mexCheckType<T>(prhs[0]))
      mexErrMsgTxt("type of argument 1 is not consistent");
   if (mxIsSparse(prhs[0]))
      mexErrMsgTxt("argument 1 should not be sparse");

   if (!mexCheckType<T>(prhs[1]))
      mexErrMsgTxt("type of argument 2 is not consistent");

   if (!mexCheckType<T>(prhs[2]))
      mexErrMsgTxt("type of argument 3 is not consistent");
   if (mxIsSparse(prhs[2]))
      mexErrMsgTxt("argument 3 should not be sparse");

   if (!mxIsStruct(prhs[3]))
         mexErrMsgTxt("argument 4 should be struct");
   if (!mxIsStruct(prhs[4]))
      mexErrMsgTxt("argument 5 should be struct");

   // --- X: wrap the MATLAB buffer without copying ---
   T* prX = reinterpret_cast<T*>(mxGetPr(prhs[0]));
   const mwSize* dimsX=mxGetDimensions(prhs[0]);
   INTM m=static_cast<INTM>(dimsX[0]);
   INTM n=static_cast<INTM>(dimsX[1]);
   Matrix<T> X(prX,m,n);

   // --- D: sparse input is copied via createCopySparse, dense input is wrapped ---
   const mwSize* dimsD=mxGetDimensions(prhs[1]);
   INTM mD=static_cast<INTM>(dimsD[0]);
   INTM p=static_cast<INTM>(dimsD[1]);
   AbstractMatrixB<T>* D;
   double* D_v;
   mwSize* D_r, *D_pB, *D_pE;
   INTM* D_r2, *D_pB2, *D_pE2;
   T* D_v2;
   if (mxIsSparse(prhs[1])) {
      D_v=static_cast<double*>(mxGetPr(prhs[1]));
      D_r=mxGetIr(prhs[1]);
      D_pB=mxGetJc(prhs[1]);
      D_pE=D_pB+1;
      createCopySparse<T>(D_v2,D_r2,D_pB2,D_pE2,
            D_v,D_r,D_pB,D_pE,p);
      D = new SpMatrix<T>(D_v2,D_r2,D_pB2,D_pE2,mD,p,D_pB2[p]);
   } else {
      T* prD = reinterpret_cast<T*>(mxGetPr(prhs[1]));
      // Use D's own row count (mD), not X's (m): the buffer holds mD x p
      // entries, and the size-consistency checks below compare mD against m.
      D = new Matrix<T>(prD,mD,p);
   }

   // --- alpha0: initial point, wrapped without copying ---
   T* pr_alpha0 = reinterpret_cast<T*>(mxGetPr(prhs[2]));
   const mwSize* dimsAlpha=mxGetDimensions(prhs[2]);
   INTM pAlpha=static_cast<INTM>(dimsAlpha[0]);
   INTM nAlpha=static_cast<INTM>(dimsAlpha[1]);
   Matrix<T> alpha0(pr_alpha0,pAlpha,nAlpha);

   // --- graph.groups: must exist, be sparse and square ---
   mxArray* ppr_GG = mxGetField(prhs[3],0,"groups");
   if (!ppr_GG) mexErrMsgTxt("field groups is not provided");
   if (!mxIsSparse(ppr_GG))
      mexErrMsgTxt("field groups should be sparse");
   mwSize* GG_r=mxGetIr(ppr_GG);
   mwSize* GG_pB=mxGetJc(ppr_GG);
   const mwSize* dims_GG=mxGetDimensions(ppr_GG);
   INTM GGm=static_cast<INTM>(dims_GG[0]);
   INTM GGn=static_cast<INTM>(dims_GG[1]);
   if (GGm != GGn)
      mexErrMsgTxt("size of field groups is not consistent");

   // --- graph.groups_var: must exist, be sparse, with as many columns as groups ---
   mxArray* ppr_GV = mxGetField(prhs[3],0,"groups_var");
   if (!ppr_GV) mexErrMsgTxt("field groups_var is not provided");
   if (!mxIsSparse(ppr_GV))
      mexErrMsgTxt("field groups_var should be sparse");
   mwSize* GV_r=mxGetIr(ppr_GV);
   mwSize* GV_pB=mxGetJc(ppr_GV);
   const mwSize* dims_GV=mxGetDimensions(ppr_GV);
   INTM nV=static_cast<INTM>(dims_GV[0]);
   INTM nG=static_cast<INTM>(dims_GV[1]);
   if (nV <= 0 || nG != GGn)
      mexErrMsgTxt("size of field groups-var is not consistent");

   // --- graph.eta_g: must exist, be dense, 1 x nG ---
   mxArray* ppr_weights = mxGetField(prhs[3],0,"eta_g");
   if (!ppr_weights) mexErrMsgTxt("field eta_g is not provided");
   if (mxIsSparse(ppr_weights))
      mexErrMsgTxt("field eta_g should not be sparse");
   T* pr_weights = reinterpret_cast<T*>(mxGetPr(ppr_weights));
   const mwSize* dims_weights=mxGetDimensions(ppr_weights);
   INTM mm1=static_cast<INTM>(dims_weights[0]);
   INTM nnG=static_cast<INTM>(dims_weights[1]);
   if (mm1 != 1 || nnG != nG)
      mexErrMsgTxt("size of field eta_g is not consistent");

   // --- output alpha: allocated here, filled in place by the solver ---
   plhs[0]=createMatrix<T>(pAlpha,nAlpha);
   T* pr_alpha=reinterpret_cast<T*>(mxGetPr(plhs[0]));
   Matrix<T> alpha(pr_alpha,pAlpha,nAlpha);

   // --- solver options, each with its fallback value ---
   FISTA::ParamFISTA<T> param;
   param.num_threads = getScalarStructDef<int>(prhs[4],"numThreads",-1);
   param.max_it = getScalarStructDef<int>(prhs[4],"max_it",1000);
   param.tol = getScalarStructDef<T>(prhs[4],"tol",0.000001);
   param.it0 = getScalarStructDef<int>(prhs[4],"it0",100);
   param.pos = getScalarStructDef<bool>(prhs[4],"pos",false);
   param.compute_gram = getScalarStructDef<bool>(prhs[4],"compute_gram",false);
   param.max_iter_backtracking = getScalarStructDef<int>(prhs[4],"max_iter_backtracking",1000);
   param.L0 = getScalarStructDef<T>(prhs[4],"L0",1.0);
   param.fixed_step = getScalarStructDef<T>(prhs[4],"fixed_step",false);
   // gamma is clamped from below at 1.01
   param.gamma = MAX(1.01,getScalarStructDef<T>(prhs[4],"gamma",1.5));
   param.c= getScalarStructDef<T>(prhs[4],"c",1.0);
   param.lambda= getScalarStructDef<T>(prhs[4],"lambda",1.0);
   param.delta = getScalarStructDef<T>(prhs[4],"delta",1.0);
   param.lambda2= getScalarStructDef<T>(prhs[4],"lambda2",0.0);
   param.lambda3= getScalarStructDef<T>(prhs[4],"lambda3",0.0);
   param.size_group= getScalarStructDef<int>(prhs[4],"size_group",1);
   param.admm = getScalarStructDef<bool>(prhs[4],"admm",false);
   param.lin_admm = getScalarStructDef<bool>(prhs[4],"lin_admm",false);
   param.sqrt_step = getScalarStructDef<bool>(prhs[4],"sqrt_step",true);
   param.is_inner_weights = getScalarStructDef<bool>(prhs[4],"is_inner_weights",false);
   param.transpose = getScalarStructDef<bool>(prhs[4],"transpose",false);
   getStringStruct(prhs[4],"regul",param.name_regul,param.length_names);
   if (param.is_inner_weights) {
      // inner_weights is mandatory once is_inner_weights is set
      mxArray* ppr_inner_weights = mxGetField(prhs[4],0,"inner_weights");
      if (!ppr_inner_weights) mexErrMsgTxt("field inner_weights is not provided");
      if (!mexCheckType<T>(ppr_inner_weights))
         mexErrMsgTxt("type of inner_weights is not correct");
      param.inner_weights = reinterpret_cast<T*>(mxGetPr(ppr_inner_weights));
   }

   // translate the regularizer / loss names into their enum values
   param.regul = regul_from_string(param.name_regul);
   if (param.regul==INCORRECT_REG)
      mexErrMsgTxt("Unknown regularization");
   getStringStruct(prhs[4],"loss",param.name_loss,param.length_names);
   param.loss = loss_from_string(param.name_loss);
   if (param.loss==INCORRECT_LOSS)
      mexErrMsgTxt("Unknown loss");

   param.intercept = getScalarStructDef<bool>(prhs[4],"intercept",false);
   param.resetflow = getScalarStructDef<bool>(prhs[4],"resetflow",false);
   param.verbose = getScalarStructDef<bool>(prhs[4],"verbose",false);
   param.clever = getScalarStructDef<bool>(prhs[4],"clever",false);
   param.ista= getScalarStructDef<bool>(prhs[4],"ista",false);
   param.subgrad= getScalarStructDef<bool>(prhs[4],"subgrad",false);
   param.log= getScalarStructDef<bool>(prhs[4],"log",false);
   param.a= getScalarStructDef<T>(prhs[4],"a",T(1.0));
   param.b= getScalarStructDef<T>(prhs[4],"b",0);

   if (param.log) {
      // logName is mandatory when logging is on; copied into a heap buffer
      // that is released after the solver has run
      mxArray *stringData = mxGetField(prhs[4],0,"logName");
      if (!stringData)
         mexErrMsgTxt("Missing field logName");
      int stringLength = mxGetN(stringData)+1;
      param.logName= new char[stringLength];
      mxGetString(stringData,param.logName,stringLength);
   }

   // --- size consistency, depending on the loss ---
   if ((param.loss != CUR && param.loss != MULTILOG) && (pAlpha != p || nAlpha != n || mD != m)) {
      mexErrMsgTxt("Argument sizes are not consistent");
   } else if (param.loss == MULTILOG) {
      // the values of X are treated as class indices starting at 0
      Vector<T> Xv;
      X.toVect(Xv);
      INTM maxval = static_cast<INTM>(Xv.maxval());
      INTM minval = static_cast<INTM>(Xv.minval());
      if (minval != 0)
         mexErrMsgTxt("smallest class should be 0");
      if (maxval*X.n() > nAlpha || mD != m) {
         cerr << "Number of classes: " << maxval << endl;
         //cerr << "Alpha: " << pAlpha << " x " << nAlpha << endl;
         //cerr << "X: " << X.m() << " x " << X.n() << endl;
         mexErrMsgTxt("Argument sizes are not consistent");
      }
   } else if (param.loss == CUR && (pAlpha != D->n() || nAlpha != D->m())) {
      mexErrMsgTxt("Argument sizes are not consistent");
   }

   if (param.regul==GRAPHMULT && abs<T>(param.lambda2 - 0) < 1e-20) {
      mexErrMsgTxt("Error: with multi-task-graph, lambda2 should be > 0");
   }

   // numThreads == -1 selects the thread count automatically
   if (param.num_threads == -1) {
      param.num_threads=1;
#ifdef _OPENMP
      param.num_threads =  MIN(MAX_THREADS,omp_get_num_procs());
#endif
   }

   if (param.regul==TREE_L0 || param.regul==TREEMULT || param.regul==TREE_L2 || param.regul==TREE_LINF)
      mexErrMsgTxt("Error: mexFistaTree should be used instead");

   // --- graph description handed to the solver (points into MATLAB buffers) ---
   GraphStruct<T> graph;
   graph.Nv=nV;
   graph.Ng=nG;
   graph.weights=pr_weights;
   graph.gg_ir=GG_r;
   graph.gg_jc=GG_pB;
   graph.gv_ir=GV_r;
   graph.gv_jc=GV_pB;


   Matrix<T> duality_gap;
   FISTA::solver<T>(X,*D,alpha0,alpha,param,duality_gap,&graph);
   if (nlhs==2) {
      // second output: element-wise copy of the duality gap matrix
      plhs[1]=createMatrix<T>(duality_gap.m(),duality_gap.n());
      T* pr_dualitygap=reinterpret_cast<T*>(mxGetPr(plhs[1]));
      for (int i = 0; i<duality_gap.n()*duality_gap.m(); ++i) pr_dualitygap[i]=duality_gap[i];
   }
   if (param.logName) delete[](param.logName);

   // --- release what was allocated above ---
   if (mxIsSparse(prhs[1])) {
      deleteCopySparse<T>(D_v2,D_r2,D_pB2,D_pE2,
            D_v,D_r);
   }
   delete(D);
}
示例#2
0
// Runs the FISTA solver (no graph / tree structure) on MEX inputs.
//
// NOTE(review): the scalar type `T` is not declared inside this block;
// presumably a `template <typename T>` header precedes it - confirm.
//
// Inputs (prhs):
//   prhs[0]  dense matrix X (m x n), checked with mexCheckType<T>
//   prhs[1]  matrix D (mD x p), dense or sparse, checked with mexCheckType<T>
//   prhs[2]  dense matrix alpha0 (pAlpha x nAlpha), checked with mexCheckType<T>
//   prhs[3]  struct of solver options; scalar options fall back to the default
//            passed to getScalarStructDef, "regul" and "loss" are strings.
//            Optional field "groups" (checked with mexCheckType<int>, pAlpha
//            entries) replaces "size_group". "shifts", "double_rows" and
//            "center_shifts" wrap D in ShiftMatrix / DoubleRowMatrix.
// Outputs (plhs):
//   plhs[0]  alpha, a newly created pAlpha x nAlpha matrix filled by the solver
//   plhs[1]  copy of the solver's duality_gap matrix, only when nlhs == 2
//
// All validation failures are reported with mexErrMsgTxt, which hands control
// back to MATLAB and does not return to this function.
inline void callFunction(mxArray* plhs[], const mxArray*prhs[],
      const int nlhs) {
   // --- type / storage checks on the four arguments ---
   if (!mexCheckType<T>(prhs[0]))
      mexErrMsgTxt("type of argument 1 is not consistent");
   if (mxIsSparse(prhs[0]))
      mexErrMsgTxt("argument 1 should not be sparse");

   if (!mexCheckType<T>(prhs[1]))
      mexErrMsgTxt("type of argument 2 is not consistent");

   if (!mexCheckType<T>(prhs[2]))
      mexErrMsgTxt("type of argument 3 is not consistent");
   if (mxIsSparse(prhs[2]))
      mexErrMsgTxt("argument 3 should not be sparse");

   if (!mxIsStruct(prhs[3]))
      mexErrMsgTxt("argument 4 should be a struct");

   // --- X: wrap the MATLAB buffer without copying ---
   T* prX = reinterpret_cast<T*>(mxGetPr(prhs[0]));
   const mwSize* dimsX=mxGetDimensions(prhs[0]);
   INTM m=static_cast<INTM>(dimsX[0]);
   INTM n=static_cast<INTM>(dimsX[1]);
   Matrix<T> X(prX,m,n);

   // --- D: base matrix, optionally wrapped below ---
   const mwSize* dimsD=mxGetDimensions(prhs[1]);
   INTM mD=static_cast<INTM>(dimsD[0]);
   INTM p=static_cast<INTM>(dimsD[1]);
   AbstractMatrixB<T>* D;
   // D2 / D3 remember the matrix that a wrapper was built around so that it
   // can still be released at the end; they stay NULL when no wrapper is used.
   AbstractMatrixB<T>* D2 = NULL;
   AbstractMatrixB<T>* D3 = NULL;

   double* D_v;
   mwSize* D_r, *D_pB, *D_pE;
   INTM* D_r2, *D_pB2, *D_pE2;
   T* D_v2 = NULL;

   const int shifts = getScalarStructDef<int>(prhs[3],"shifts",1); // undocumented function

   if (mxIsSparse(prhs[1])) {
      D_v=static_cast<double*>(mxGetPr(prhs[1]));
      D_r=mxGetIr(prhs[1]);
      D_pB=mxGetJc(prhs[1]);
      D_pE=D_pB+1;
      createCopySparse<T>(D_v2,D_r2,D_pB2,D_pE2,
            D_v,D_r,D_pB,D_pE,p);
      D = new SpMatrix<T>(D_v2,D_r2,D_pB2,D_pE2,mD,p,D_pB2[p]);
   } else {
      T* prD = reinterpret_cast<T*>(mxGetPr(prhs[1]));
      D = new Matrix<T>(prD,mD,p);
   }
   const bool double_rows = getScalarStructDef<bool>(prhs[3],"double_rows",false); // undocumented function
   if (double_rows) {
      D2=D;
      D=new DoubleRowMatrix<T>(*D);
   }

   if (shifts > 1) {
      const bool center_shifts = getScalarStructDef<bool>(prhs[3],"center_shifts",false);
      D3=D;
      D=new ShiftMatrix<T>(*D,shifts,center_shifts);
   }

   // --- alpha0: initial point, wrapped without copying ---
   T* pr_alpha0 = reinterpret_cast<T*>(mxGetPr(prhs[2]));
   const mwSize* dimsAlpha=mxGetDimensions(prhs[2]);
   INTM pAlpha=static_cast<INTM>(dimsAlpha[0]);
   INTM nAlpha=static_cast<INTM>(dimsAlpha[1]);
   Matrix<T> alpha0(pr_alpha0,pAlpha,nAlpha);

   // --- output alpha: allocated here, filled in place by the solver ---
   plhs[0]=createMatrix<T>(pAlpha,nAlpha);
   T* pr_alpha=reinterpret_cast<T*>(mxGetPr(plhs[0]));
   Matrix<T> alpha(pr_alpha,pAlpha,nAlpha);

   // --- solver options, each with its fallback value ---
   FISTA::ParamFISTA<T> param;
   param.num_threads = getScalarStructDef<int>(prhs[3],"numThreads",-1);
   param.max_it = getScalarStructDef<int>(prhs[3],"max_it",1000);
   param.tol = getScalarStructDef<T>(prhs[3],"tol",0.000001);
   param.it0 = getScalarStructDef<int>(prhs[3],"it0",100);
   param.pos = getScalarStructDef<bool>(prhs[3],"pos",false);
   param.compute_gram = getScalarStructDef<bool>(prhs[3],"compute_gram",false);
   param.max_iter_backtracking = getScalarStructDef<int>(prhs[3],"max_iter_backtracking",1000);
   param.L0 = getScalarStructDef<T>(prhs[3],"L0",1.0);
   param.fixed_step = getScalarStructDef<T>(prhs[3],"fixed_step",false);
   // gamma is clamped from below at 1.01
   param.gamma = MAX(1.01,getScalarStructDef<T>(prhs[3],"gamma",1.5));
   param.c = getScalarStructDef<T>(prhs[3],"c",1.0);
   param.lambda= getScalarStructDef<T>(prhs[3],"lambda",1.0);
   param.delta = getScalarStructDef<T>(prhs[3],"delta",1.0);
   param.lambda2= getScalarStructDef<T>(prhs[3],"lambda2",0.0);
   param.lambda3= getScalarStructDef<T>(prhs[3],"lambda3",0.0);
   // an explicit group vector takes precedence over size_group
   mxArray* ppr_groups = mxGetField(prhs[3],0,"groups");
   if (ppr_groups) {
      if (!mexCheckType<int>(ppr_groups))
         mexErrMsgTxt("param.groups should be int32 (starting group is 1)");
      int* pr_groups = reinterpret_cast<int*>(mxGetPr(ppr_groups));
      const mwSize* dims_groups =mxGetDimensions(ppr_groups);
      int num_groups=static_cast<int>(dims_groups[0])*static_cast<int>(dims_groups[1]);
      if (num_groups != pAlpha) mexErrMsgTxt("Wrong size of param.groups");
      param.ngroups=num_groups;
      param.groups=pr_groups;
   } else {
      param.size_group= getScalarStructDef<int>(prhs[3],"size_group",1);
   }

   param.admm = getScalarStructDef<bool>(prhs[3],"admm",false);
   param.lin_admm = getScalarStructDef<bool>(prhs[3],"lin_admm",false);
   param.sqrt_step = getScalarStructDef<bool>(prhs[3],"sqrt_step",true);
   param.is_inner_weights = getScalarStructDef<bool>(prhs[3],"is_inner_weights",false);
   param.transpose = getScalarStructDef<bool>(prhs[3],"transpose",false);
   if (param.is_inner_weights) {
      // The option struct of this variant is prhs[3]; there is no fifth
      // argument, so inner_weights must be looked up in prhs[3] as well.
      mxArray* ppr_inner_weights = mxGetField(prhs[3],0,"inner_weights");
      if (!ppr_inner_weights) mexErrMsgTxt("field inner_weights is not provided");
      if (!mexCheckType<T>(ppr_inner_weights))
         mexErrMsgTxt("type of inner_weights is not correct");
      param.inner_weights = reinterpret_cast<T*>(mxGetPr(ppr_inner_weights));
   }

   // translate the regularizer / loss names into their enum values
   getStringStruct(prhs[3],"regul",param.name_regul,param.length_names);
   param.regul = regul_from_string(param.name_regul);
   if (param.regul==INCORRECT_REG)
      mexErrMsgTxt("Unknown regularization");
   getStringStruct(prhs[3],"loss",param.name_loss,param.length_names);
   param.loss = loss_from_string(param.name_loss);
   if (param.loss==INCORRECT_LOSS)
      mexErrMsgTxt("Unknown loss");

   param.intercept = getScalarStructDef<bool>(prhs[3],"intercept",false);
   param.resetflow = getScalarStructDef<bool>(prhs[3],"resetflow",false);
   param.verbose = getScalarStructDef<bool>(prhs[3],"verbose",false);
   param.clever = getScalarStructDef<bool>(prhs[3],"clever",false);
   param.ista= getScalarStructDef<bool>(prhs[3],"ista",false);
   param.linesearch_mode= getScalarStructDef<int>(prhs[3],"linesearch_mode",0);
   param.subgrad= getScalarStructDef<bool>(prhs[3],"subgrad",false);
   param.log= getScalarStructDef<bool>(prhs[3],"log",false);
   param.a= getScalarStructDef<T>(prhs[3],"a",T(1.0));
   param.b= getScalarStructDef<T>(prhs[3],"b",0);

   if (param.log) {
      // logName is mandatory when logging is on; copied into a heap buffer
      // that is released after the solver has run
      mxArray *stringData = mxGetField(prhs[3],0,"logName");
      if (!stringData)
         mexErrMsgTxt("Missing field logName");
      int stringLength = mxGetN(stringData)+1;
      param.logName= new char[stringLength];
      mxGetString(stringData,param.logName,stringLength);
   }

   // --- size consistency; the plain check is skipped when D is wrapped ---
   if ((!double_rows && shifts==1 && param.loss != CUR && param.loss != MULTILOG) && (pAlpha != p || nAlpha != n || mD != m)) {
      mexErrMsgTxt("Argument sizes are not consistent");
   } else if (param.loss == MULTILOG) {
      // the values of X are treated as class indices starting at 0
      Vector<T> Xv;
      X.toVect(Xv);
      INTM maxval = static_cast<INTM>(Xv.maxval());
      INTM minval = static_cast<INTM>(Xv.minval());
      if (minval != 0)
         mexErrMsgTxt("smallest class should be 0");
      if (maxval*X.n() > nAlpha || mD != m) {
         cerr << "Number of classes: " << maxval << endl;
         //cerr << "Alpha: " << pAlpha << " x " << nAlpha << endl;
         //cerr << "X: " << X.m() << " x " << X.n() << endl;
         mexErrMsgTxt("Argument sizes are not consistent");
      }
   } else if (param.loss == CUR && (pAlpha != D->n() || nAlpha != D->m())) {
      mexErrMsgTxt("Argument sizes are not consistent");
   }

   // numThreads == -1 selects the thread count automatically
   if (param.num_threads == -1) {
      param.num_threads=1;
#ifdef _OPENMP
      param.num_threads =  MIN(MAX_THREADS,omp_get_num_procs());
#endif
   }

   // structured regularizers are served by other entry points
   if (param.regul==GRAPH_PATH_L0 || param.regul==GRAPH_PATH_CONV)
      mexErrMsgTxt("Error: mexFistaPathCoding should be used instead");
   if (param.regul==GRAPH || param.regul==GRAPHMULT)
      mexErrMsgTxt("Error: mexFistaGraph should be used instead");
   if (param.regul==TREE_L0 || param.regul==TREEMULT || param.regul==TREE_L2 || param.regul==TREE_LINF)
      mexErrMsgTxt("Error: mexFistaTree should be used instead");

   Matrix<T> duality_gap;
   FISTA::solver<T>(X,*D,alpha0,alpha,param,duality_gap);
   if (nlhs==2) {
      // second output: element-wise copy of the duality gap matrix
      plhs[1]=createMatrix<T>(duality_gap.m(),duality_gap.n());
      T* pr_dualitygap=reinterpret_cast<T*>(mxGetPr(plhs[1]));
      for (int i = 0; i<duality_gap.n()*duality_gap.m(); ++i) pr_dualitygap[i]=duality_gap[i];
   }
   if (param.logName) delete[](param.logName);

   // --- release what was allocated above ---
   if (mxIsSparse(prhs[1])) {
      deleteCopySparse<T>(D_v2,D_r2,D_pB2,D_pE2,
            D_v,D_r);
   }
   // D is the outermost matrix; D3 and D2 are the matrices the wrappers were
   // built around (NULL when the wrapper was not requested). Deleting a null
   // pointer is a no-op, so every allocated matrix is released exactly once
   // whichever combination of shifts / double_rows was active.
   delete(D);
   delete(D3);
   delete(D2);

}
示例#3
0
// Runs the FISTA solver with a tree-structured regularizer on MEX inputs.
//
// NOTE(review): the scalar type `T` is not declared inside this block;
// presumably a `template <typename T>` header precedes it - confirm.
//
// Inputs (prhs):
//   prhs[0]  dense matrix X (m x n), checked with mexCheckType<T>
//   prhs[1]  matrix D (mD x p), dense or sparse, checked with mexCheckType<T>
//   prhs[2]  dense matrix alpha0 (pAlpha x nAlpha), checked with mexCheckType<T>
//   prhs[3]  struct with the tree description:
//              "own_variables"    num_groups entries, checked with mexCheckType<long>
//              "N_own_variables"  num_groups entries, checked with mexCheckType<long>
//              "eta_g"            num_groups entries
//              "groups"           sparse, num_groups x num_groups
//   prhs[4]  struct of solver options; scalar options fall back to the default
//            passed to getScalarStructDef, "regul" and "loss" are strings
// Outputs (plhs):
//   plhs[0]  alpha, a newly created pAlpha x nAlpha matrix filled by the solver
//   plhs[1]  copy of the solver's duality_gap matrix, only when nlhs == 2
//
// All validation failures are reported with mexErrMsgTxt, which hands control
// back to MATLAB and does not return to this function.
inline void callFunction(mxArray* plhs[], const mxArray*prhs[],
      const long nlhs) {
   // --- type / storage checks on the five arguments ---
   if (!mexCheckType<T>(prhs[0]))
      mexErrMsgTxt("type of argument 1 is not consistent");
   if (mxIsSparse(prhs[0]))
      mexErrMsgTxt("argument 1 should not be sparse");

   if (!mexCheckType<T>(prhs[1]))
      mexErrMsgTxt("type of argument 2 is not consistent");

   if (!mexCheckType<T>(prhs[2]))
      mexErrMsgTxt("type of argument 3 is not consistent");
   if (mxIsSparse(prhs[2]))
      mexErrMsgTxt("argument 3 should not be sparse");

   if (!mxIsStruct(prhs[3]))
         mexErrMsgTxt("argument 4 should be struct");
   if (!mxIsStruct(prhs[4]))
      mexErrMsgTxt("argument 5 should be struct");

   // --- X: wrap the MATLAB buffer without copying ---
   T* prX = reinterpret_cast<T*>(mxGetPr(prhs[0]));
   const mwSize* dimsX=mxGetDimensions(prhs[0]);
   long m=static_cast<long>(dimsX[0]);
   long n=static_cast<long>(dimsX[1]);
   Matrix<T> X(prX,m,n);

   // --- D: sparse input is copied via createCopySparse, dense input is wrapped ---
   const mwSize* dimsD=mxGetDimensions(prhs[1]);
   long mD=static_cast<long>(dimsD[0]);
   long p=static_cast<long>(dimsD[1]);
   AbstractMatrixB<T>* D;

   double* D_v;
   mwSize* D_r, *D_pB, *D_pE;
   long* D_r2, *D_pB2, *D_pE2;
   T* D_v2;
   if (mxIsSparse(prhs[1])) {
      D_v=static_cast<double*>(mxGetPr(prhs[1]));
      D_r=mxGetIr(prhs[1]);
      D_pB=mxGetJc(prhs[1]);
      D_pE=D_pB+1;
      createCopySparse<T>(D_v2,D_r2,D_pB2,D_pE2,
            D_v,D_r,D_pB,D_pE,p);
      D = new SpMatrix<T>(D_v2,D_r2,D_pB2,D_pE2,mD,p,D_pB2[p]);
   } else {
      T* prD = reinterpret_cast<T*>(mxGetPr(prhs[1]));
      // Use D's own row count (mD), not X's (m): the buffer holds mD x p
      // entries, and the size-consistency checks below compare mD against m.
      D = new Matrix<T>(prD,mD,p);
   }

   // --- alpha0: initial point, wrapped without copying ---
   T* pr_alpha0 = reinterpret_cast<T*>(mxGetPr(prhs[2]));
   const mwSize* dimsAlpha=mxGetDimensions(prhs[2]);
   long pAlpha=static_cast<long>(dimsAlpha[0]);
   long nAlpha=static_cast<long>(dimsAlpha[1]);
   Matrix<T> alpha0(pr_alpha0,pAlpha,nAlpha);

   // --- tree.own_variables: presence is tested before the field is inspected ---
   // NOTE(review): the message says int32 while the check is against `long` - confirm.
   mxArray* ppr_own_variables = mxGetField(prhs[3],0,"own_variables");
   if (!ppr_own_variables) mexErrMsgTxt("field own_variables is not provided");
   if (!mexCheckType<long>(ppr_own_variables))
         mexErrMsgTxt("own_variables field should be int32");
   long* pr_own_variables = reinterpret_cast<long*>(mxGetPr(ppr_own_variables));
   const mwSize* dims_groups =mxGetDimensions(ppr_own_variables);
   long num_groups=static_cast<long>(dims_groups[0])*static_cast<long>(dims_groups[1]);

   // --- tree.N_own_variables: same number of entries as own_variables ---
   mxArray* ppr_N_own_variables = mxGetField(prhs[3],0,"N_own_variables");
   if (!ppr_N_own_variables) mexErrMsgTxt("field N_own_variables is not provided");
   if (!mexCheckType<long>(ppr_N_own_variables))
         mexErrMsgTxt("N_own_variables field should be int32");
   const mwSize* dims_var =mxGetDimensions(ppr_N_own_variables);
   long num_groups2=static_cast<long>(dims_var[0])*static_cast<long>(dims_var[1]);
   if (num_groups != num_groups2)
      mexErrMsgTxt("Error in tree definition");
   long* pr_N_own_variables = reinterpret_cast<long*>(mxGetPr(ppr_N_own_variables));
   // total number of variables covered by the groups; alpha0 needs at least that many rows
   long num_var=0;
   for (long i = 0; i<num_groups; ++i)
      num_var+=pr_N_own_variables[i];
   if (pAlpha < num_var)
      mexErrMsgTxt("Input alpha is too small");

   // --- tree.eta_g: one entry per group ---
   mxArray* ppr_lambda_g = mxGetField(prhs[3],0,"eta_g");
   if (!ppr_lambda_g) mexErrMsgTxt("field eta_g is not provided");
   const mwSize* dims_weights =mxGetDimensions(ppr_lambda_g);
   long num_groups3=static_cast<long>(dims_weights[0])*static_cast<long>(dims_weights[1]);
   if (num_groups != num_groups3)
      mexErrMsgTxt("Error in tree definition");
   T* pr_lambda_g = reinterpret_cast<T*>(mxGetPr(ppr_lambda_g));

   // --- tree.groups: must exist and be sparse before its index arrays are read ---
   mxArray* ppr_groups = mxGetField(prhs[3],0,"groups");
   if (!ppr_groups) mexErrMsgTxt("field groups is not provided");
   if (!mxIsSparse(ppr_groups))
      mexErrMsgTxt("field groups should be sparse");
   const mwSize* dims_gg =mxGetDimensions(ppr_groups);
   if ((num_groups != static_cast<long>(dims_gg[0])) ||
         (num_groups != static_cast<long>(dims_gg[1])))
      mexErrMsgTxt("Error in tree definition");
   mwSize* pr_groups_ir = reinterpret_cast<mwSize*>(mxGetIr(ppr_groups));
   mwSize* pr_groups_jc = reinterpret_cast<mwSize*>(mxGetJc(ppr_groups));

   // --- output alpha: allocated here, filled in place by the solver ---
   plhs[0]=createMatrix<T>(pAlpha,nAlpha);
   T* pr_alpha=reinterpret_cast<T*>(mxGetPr(plhs[0]));
   Matrix<T> alpha(pr_alpha,pAlpha,nAlpha);

   // --- solver options, each with its fallback value ---
   FISTA::ParamFISTA<T> param;
   param.num_threads = getScalarStructDef<long>(prhs[4],"numThreads",-1);
   param.max_it = getScalarStructDef<long>(prhs[4],"max_it",1000);
   param.tol = getScalarStructDef<T>(prhs[4],"tol",0.000001);
   param.it0 = getScalarStructDef<long>(prhs[4],"it0",100);
   param.pos = getScalarStructDef<bool>(prhs[4],"pos",false);
   param.compute_gram = getScalarStructDef<bool>(prhs[4],"compute_gram",false);
   param.max_iter_backtracking = getScalarStructDef<long>(prhs[4],"max_iter_backtracking",1000);
   param.L0 = getScalarStructDef<T>(prhs[4],"L0",1.0);
   param.fixed_step = getScalarStructDef<T>(prhs[4],"fixed_step",false);
   // gamma is clamped from below at 1.01
   param.gamma = MAX(1.01,getScalarStructDef<T>(prhs[4],"gamma",1.5));
   param.c = getScalarStructDef<T>(prhs[4],"c",1.0);
   param.lambda= getScalarStructDef<T>(prhs[4],"lambda",1.0);
   param.lambda2= getScalarStructDef<T>(prhs[4],"lambda2",0.0);
   param.lambda3= getScalarStructDef<T>(prhs[4],"lambda3",0.0);
   param.size_group= getScalarStructDef<long>(prhs[4],"size_group",1);
   param.delta = getScalarStructDef<T>(prhs[4],"delta",1.0);
   param.admm = getScalarStructDef<bool>(prhs[4],"admm",false);
   param.lin_admm = getScalarStructDef<bool>(prhs[4],"lin_admm",false);
   param.sqrt_step = getScalarStructDef<bool>(prhs[4],"sqrt_step",true);
   getStringStruct(prhs[4],"regul",param.name_regul,param.length_names);
   param.is_inner_weights = getScalarStructDef<bool>(prhs[4],"is_inner_weights",false);
   param.transpose = getScalarStructDef<bool>(prhs[4],"transpose",false);
   if (param.is_inner_weights) {
      // inner_weights is mandatory once is_inner_weights is set
      mxArray* ppr_inner_weights = mxGetField(prhs[4],0,"inner_weights");
      if (!ppr_inner_weights) mexErrMsgTxt("field inner_weights is not provided");
      if (!mexCheckType<T>(ppr_inner_weights))
         mexErrMsgTxt("type of inner_weights is not correct");
      param.inner_weights = reinterpret_cast<T*>(mxGetPr(ppr_inner_weights));
   }

   // translate the regularizer / loss names into their enum values
   param.regul = regul_from_string(param.name_regul);
   if (param.regul==INCORRECT_REG)
      mexErrMsgTxt("Unknown regularization");
   getStringStruct(prhs[4],"loss",param.name_loss,param.length_names);
   param.loss = loss_from_string(param.name_loss);
   if (param.loss==INCORRECT_LOSS)
      mexErrMsgTxt("Unknown loss");

   param.intercept = getScalarStructDef<bool>(prhs[4],"intercept",false);
   param.resetflow = getScalarStructDef<bool>(prhs[4],"resetflow",false);
   param.verbose = getScalarStructDef<bool>(prhs[4],"verbose",false);
   param.clever = getScalarStructDef<bool>(prhs[4],"clever",false);
   param.ista= getScalarStructDef<bool>(prhs[4],"ista",false);
   param.subgrad= getScalarStructDef<bool>(prhs[4],"subgrad",false);
   param.log= getScalarStructDef<bool>(prhs[4],"log",false);
   param.a= getScalarStructDef<T>(prhs[4],"a",T(1.0));
   param.b= getScalarStructDef<T>(prhs[4],"b",0);

   if (param.log) {
      // logName is mandatory when logging is on; copied into a heap buffer
      // that is released after the solver has run
      mxArray *stringData = mxGetField(prhs[4],0,"logName");
      if (!stringData)
         mexErrMsgTxt("Missing field logName");
      long stringLength = mxGetN(stringData)+1;
      param.logName= new char[stringLength];
      mxGetString(stringData,param.logName,stringLength);
   }

   // --- size consistency, depending on the loss ---
   if ((param.loss != CUR && param.loss != MULTILOG) && (pAlpha != p || nAlpha != n || mD != m)) {
      mexErrMsgTxt("Argument sizes are not consistent");
   } else if (param.loss == MULTILOG) {
      // the values of X are treated as class indices starting at 0
      Vector<T> Xv;
      X.toVect(Xv);
      long maxval = static_cast<long>(Xv.maxval());
      long minval = static_cast<long>(Xv.minval());
      if (minval != 0)
         mexErrMsgTxt("smallest class should be 0");
      if (maxval*X.n() > nAlpha || mD != m) {
         cerr << "Number of classes: " << maxval << endl;
         //cerr << "Alpha: " << pAlpha << " x " << nAlpha << endl;
         //cerr << "X: " << X.m() << " x " << X.n() << endl;
         mexErrMsgTxt("Argument sizes are not consistent");
      }
   } else if (param.loss == CUR && (pAlpha != D->n() || nAlpha != D->m())) {
      mexErrMsgTxt("Argument sizes are not consistent");
   }

   if (param.regul==GRAPH || param.regul==GRAPHMULT)
      mexErrMsgTxt("Error: mexFistaGraph should be used instead");

   if (param.regul==TREEMULT && abs<T>(param.lambda2 - 0) < 1e-20) {
      mexErrMsgTxt("Error: with multi-task-tree, lambda2 should be > 0");
   }

   // numThreads == -1 selects the thread count automatically
   if (param.num_threads == -1) {
      param.num_threads=1;
#ifdef _OPENMP
      param.num_threads =  MIN(MAX_THREADS,omp_get_num_procs());
#endif
   }

   // --- tree description handed to the solver (points into MATLAB buffers) ---
   TreeStruct<T> tree;
   tree.Nv=num_var;   // sum of N_own_variables, computed during validation
   tree.Ng=num_groups;
   tree.weights=pr_lambda_g;
   tree.own_variables=pr_own_variables;
   tree.N_own_variables=pr_N_own_variables;
   tree.groups_ir=pr_groups_ir;
   tree.groups_jc=pr_groups_jc;

   Matrix<T> duality_gap;
   FISTA::solver<T>(X,*D,alpha0,alpha,param,duality_gap,NULL,&tree);
   if (nlhs==2) {
      // second output: element-wise copy of the duality gap matrix
      plhs[1]=createMatrix<T>(duality_gap.m(),duality_gap.n());
      T* pr_dualitygap=reinterpret_cast<T*>(mxGetPr(plhs[1]));
      for (long i = 0; i<duality_gap.n()*duality_gap.m(); ++i) pr_dualitygap[i]=duality_gap[i];
   }
   if (param.logName) delete[](param.logName);

   // --- release what was allocated above ---
   if (mxIsSparse(prhs[1])) {
      deleteCopySparse<T>(D_v2,D_r2,D_pB2,D_pE2,
            D_v,D_r);
   }
   delete(D);
}