Ejemplo n.º 1
0
inline void callFunction(mxArray* plhs[], const mxArray*prhs[]) {
   // Masked lasso gateway: fills plhs[0] with the sparse result of
   // lasso_mask<T>(X, D, alpha, mask, ...).
   // Inputs (validated in this order; the first failing check aborts):
   //   prhs[0] X     : full matrix of class T, n x M
   //   prhs[1] D     : full matrix of class T with n rows (K columns)
   //   prhs[2] mask  : full boolean matrix, n x M
   //   prhs[3] param : struct; "lambda" is read without a default, while
   //                   lambda2, L, numThreads, pos, verbose, mode have defaults
   const mxArray* const argX = prhs[0];
   const mxArray* const argD = prhs[1];
   const mxArray* const argMask = prhs[2];
   const mxArray* const argParam = prhs[3];

   if (!mexCheckType<T>(argX)) mexErrMsgTxt("type of argument 1 is not consistent");
   if (mxIsSparse(argX)) mexErrMsgTxt("argument 1 should be full");
   if (!mexCheckType<T>(argD)) mexErrMsgTxt("type of argument 2 is not consistent");
   if (mxIsSparse(argD)) mexErrMsgTxt("argument 2 should be full");
   if (!mexCheckType<bool>(argMask)) mexErrMsgTxt("type of argument 3 should be boolean");
   if (mxIsSparse(argMask)) mexErrMsgTxt("argument 3 should be full");
   if (!mxIsStruct(argParam)) mexErrMsgTxt("argument 4 should be struct");

   // Signals: n rows, M columns.
   T* dataX = reinterpret_cast<T*>(mxGetPr(argX));
   const mwSize* const shapeX = mxGetDimensions(argX);
   INTM n = static_cast<INTM>(shapeX[0]);
   INTM M = static_cast<INTM>(shapeX[1]);

   // Dictionary: must share the row count of X; K atoms.
   T* dataD = reinterpret_cast<T*>(mxGetPr(argD));
   const mwSize* const shapeD = mxGetDimensions(argD);
   const INTM rowsD = static_cast<INTM>(shapeD[0]);
   INTM K = static_cast<INTM>(shapeD[1]);
   if (rowsD != n) mexErrMsgTxt("argument sizes are not consistent");

   // Mask: same shape as X.
   bool* dataMask = reinterpret_cast<bool*>(mxGetPr(argMask));
   const mwSize* const shapeMask = mxGetDimensions(argMask);
   const INTM rowsMask = static_cast<INTM>(shapeMask[0]);
   const INTM colsMask = static_cast<INTM>(shapeMask[1]);
   if (rowsMask != n || colsMask != M) mexErrMsgTxt("argument sizes are not consistent");

   T lambda = getScalarStruct<T>(argParam,"lambda");
   T lambda2 = getScalarStructDef<T>(argParam,"lambda2",0);
   int L = getScalarStructDef<int>(argParam,"L",K);
   int numThreads = getScalarStructDef<int>(argParam,"numThreads",-1);
   bool pos = getScalarStructDef<bool>(argParam,"pos",false);
   const bool verbose = getScalarStructDef<bool>(argParam,"verbose",true);
   constraint_type mode = static_cast<constraint_type>(
         getScalarStructDef<int>(argParam,"mode",PENALTY));

   // L is capped at n, except for the one parameter combination below,
   // which is exempt from that cap; it is always capped at K.
   if (L > n) {
      const bool exemptFromRowCap =
         mode == PENALTY && isZero(lambda) && !pos && lambda2 > 0;
      if (!exemptFromRowCap) {
         if (verbose) printf("L is changed to %d\n",static_cast<int>(n));
         L = static_cast<int>(n);
      }
   }
   if (L > K) {
      if (verbose) printf("L is changed to %d\n",static_cast<int>(K));
      L = static_cast<int>(K);
   }

   Matrix<T> X(dataX,n,M);
   Matrix<T> D(dataD,n,K);
   Matrix<bool> mask(dataMask,n,M);
   SpMatrix<T> alpha;

   lasso_mask<T>(X,D,alpha,mask,L,lambda,lambda2,mode,pos,numThreads);
   convertSpMatrix(plhs[0],alpha.m(),alpha.n(),alpha.n(),
         alpha.nzmax(),alpha.v(),alpha.r(),alpha.pB());
}
inline void callFunction(mxArray* plhs[], const mxArray*prhs[]) {
    // Weighted lasso gateway: fills plhs[0] with the sparse result of
    // lassoWeight<T>(X, D, weight, alpha, ...).
    // Inputs:
    //   prhs[0] X      : full matrix of class T, n x M
    //   prhs[1] D      : full matrix of class T with n rows (K columns)
    //   prhs[2] weight : full matrix of class T, K x M
    //   prhs[3] param  : struct; "lambda" is read without a default, while
    //                    L, numThreads, pos, mode have defaults
    if (!mexCheckType<T>(prhs[0]))
        mexErrMsgTxt("type of argument 1 is not consistent");
    if (mxIsSparse(prhs[0]))
        mexErrMsgTxt("argument 1 should be full");
    if (!mexCheckType<T>(prhs[1]))
        mexErrMsgTxt("type of argument 2 is not consistent");
    if (mxIsSparse(prhs[1]))
        mexErrMsgTxt("argument 2 should be full");
    // The weight data is reinterpreted as T* below, so its class must
    // match T exactly like X and D; without this check a mismatched class
    // would be read as garbage.
    if (!mexCheckType<T>(prhs[2]))
        mexErrMsgTxt("type of argument 3 is not consistent");
    if (mxIsSparse(prhs[2]))
        mexErrMsgTxt("argument 3 should be full");
    if (!mxIsStruct(prhs[3]))
        mexErrMsgTxt("argument 4 should be struct");


    T* prX = reinterpret_cast<T*>(mxGetPr(prhs[0]));
    const mwSize* dimsX=mxGetDimensions(prhs[0]);
    long n=static_cast<long>(dimsX[0]);
    long M=static_cast<long>(dimsX[1]);

    T* prD = reinterpret_cast<T*>(mxGetPr(prhs[1]));
    const mwSize* dimsD=mxGetDimensions(prhs[1]);
    long nD=static_cast<long>(dimsD[0]);
    long K=static_cast<long>(dimsD[1]);
    if (n != nD) mexErrMsgTxt("argument sizes are not consistent");

    T lambda = getScalarStruct<T>(prhs[3],"lambda");
    long L = getScalarStructDef<long>(prhs[3],"L",K);
    long numThreads = getScalarStructDef<long>(prhs[3],"numThreads",-1);
    bool pos = getScalarStructDef<bool>(prhs[3],"pos",false);
    constraint_type mode = (constraint_type)getScalarStructDef<long>(prhs[3],"mode",PENALTY);
    // L is capped at both the signal length n and the number of atoms K.
    if (L > n) {
        printf("L is changed to %ld\n",n);
        L=n;
    }
    if (L > K) {
        printf("L is changed to %ld\n",K);
        L=K;
    }
    Matrix<T> X(prX,n,M);
    Matrix<T> D(prD,n,K);


    // One weight per (atom, signal) pair: must be K x M.
    T* prWeight = reinterpret_cast<T*>(mxGetPr(prhs[2]));
    const mwSize* dimsW=mxGetDimensions(prhs[2]);
    long KK=static_cast<long>(dimsW[0]);
    long MM=static_cast<long>(dimsW[1]);
    if (K != KK || M != MM) mexErrMsgTxt("argument sizes are not consistent");


    Matrix<T> weight(prWeight,KK,MM);

    SpMatrix<T> alpha;
    lassoWeight<T>(X,D,weight,alpha,L,lambda,mode,pos,numThreads);
    convertSpMatrix(plhs[0],alpha.m(),alpha.n(),alpha.n(),
                    alpha.nzmax(),alpha.v(),alpha.r(),alpha.pB());
}
Ejemplo n.º 3
0
inline void callFunction(mxArray* plhs[], const mxArray*prhs[],const int nlhs) {
  // Converts a MATLAB cell describing groups (prhs[0]) into the structures
  // produced by _treeOfGroupStruct<double>.
  // Outputs (all three are always written; `nlhs` is not consulted):
  //   plhs[0] : int32 nb_perm x 1 array copied from `permutations`
  //   plhs[1] : single-element struct with fields eta_g, groups (sparse),
  //             own_variables, N_own_variables
  //   plhs[2] : int32 scalar holding the value returned by
  //             _treeOfGroupStruct (nb_vars)
  if (!mxIsCell(prhs[0])) 
    mexErrMsgTxt("argument 1 should be a cell");
  std::vector<NodeElem *> *gstruct = mexMatlabToCgroups(prhs[0]);
  mwSize dims[1] = {1};
  const char *names[] = {"eta_g", "groups", "own_variables","N_own_variables"};
  plhs[1]=mxCreateStructArray(1,dims,4,names);
  SpMatrix<bool> *groups;
  Vector<int> *own_variables;
  Vector<int> *N_own_variables;
  Vector<double> *eta_g;
  int *permutations;
  int nb_perm;
  int nb_vars = _treeOfGroupStruct<double>(gstruct,&permutations,&nb_perm,&eta_g,&groups,&own_variables,&N_own_variables);
  del_gstruct(gstruct);
  // Copy the C++ results into MATLAB arrays, then release the C++ objects.
  mxArray* mxeta_g = makeVector<double>(eta_g);
  mxArray* mxown_variables = makeVector<int>(own_variables);
  mxArray* mxN_own_variables = makeVector<int>(N_own_variables);
  mxArray* mxgroups[1];
  convertSpMatrix<bool>(mxgroups[0],groups->m(),groups->n(),groups->n(),
			groups->nzmax(),groups->v(),groups->r(),groups->pB());
  delete eta_g;
  delete groups;
  delete own_variables;
  delete N_own_variables;
  mxSetField(plhs[1],0,"eta_g",mxeta_g);
  mxSetField(plhs[1],0,"groups",mxgroups[0]);
  mxSetField(plhs[1],0,"own_variables",mxown_variables);
  mxSetField(plhs[1],0,"N_own_variables",mxN_own_variables);
  dims[0] = nb_perm;
  mxArray *mxperm = mxCreateNumericArray((mwSize)1,dims,mxINT32_CLASS,mxREAL);
  // mxGetData (not mxGetPr, the double accessor) is the accessor for
  // non-double arrays such as this int32 one.
  if(nb_perm > 0)
    memcpy(mxGetData(mxperm),permutations,nb_perm * sizeof(int));
  // NOTE(review): `permutations` is never released in this function;
  // confirm how _treeOfGroupStruct allocates it and free it to match.
  plhs[0] = mxperm;
  dims[0] = 1;
  plhs[2]=mxCreateNumericArray((mwSize)1,dims,mxINT32_CLASS,mxREAL);
  int* pr_out=static_cast<int *>(mxGetData(plhs[2]));
  *pr_out = nb_vars;
}
Ejemplo n.º 4
0
inline void callFunction(mxArray* plhs[], const mxArray*prhs[],const int nrhs,
      const int nlhs) {
   // Lasso gateway with two calling conventions:
   //   nrhs == 3 : (X, D, param)       X: n x M, D: n x K, both full class T
   //   otherwise : (X, G, DtR, param)  G: K x K, DtR: K x M, full class T
   // plhs[0] receives the sparse coefficients alpha. When nlhs == 2, plhs[1]
   // receives the first `length` columns of the K x length_path path buffer,
   // where length = 1 + (number of columns after the first whose 2-norm is
   // non-zero). lasso<T> is used when param.ols or param.cholesky is set,
   // lasso2<T> otherwise.
   if (nrhs==3) {
      if (!mexCheckType<T>(prhs[0])) 
         mexErrMsgTxt("type of argument 1 is not consistent");
      if (mxIsSparse(prhs[0])) 
         mexErrMsgTxt("argument 1 should be full");
      if (!mexCheckType<T>(prhs[1])) 
         mexErrMsgTxt("type of argument 2 is not consistent");
      if (mxIsSparse(prhs[1])) 
         mexErrMsgTxt("argument 2 should be full");
      if (!mxIsStruct(prhs[2])) 
         mexErrMsgTxt("argument 3 should be struct");

      T* prX = reinterpret_cast<T*>(mxGetPr(prhs[0]));
      const mwSize* dimsX=mxGetDimensions(prhs[0]);
      int n=static_cast<int>(dimsX[0]);
      int M=static_cast<int>(dimsX[1]);

      T* prD = reinterpret_cast<T*>(mxGetPr(prhs[1]));
      const mwSize* dimsD=mxGetDimensions(prhs[1]);
      int nD=static_cast<int>(dimsD[0]);
      int K=static_cast<int>(dimsD[1]);
      if (n != nD) mexErrMsgTxt("argument sizes are not consistent");
      T lambda = getScalarStruct<T>(prhs[2],"lambda");
      T lambda2 = getScalarStructDef<T>(prhs[2],"lambda2",0);
      int L = getScalarStructDef<int>(prhs[2],"L",K);
      // At least 2 columns so the path buffer is never empty.
      int length_path = MAX(2,getScalarStructDef<int>(prhs[2],"length_path",4*L));
      int numThreads = getScalarStructDef<int>(prhs[2],"numThreads",-1);
      bool pos = getScalarStructDef<bool>(prhs[2],"pos",false);
      //bool verbose = getScalarStructDef<bool>(prhs[2],"verbose",false);
      bool ols = getScalarStructDef<bool>(prhs[2],"ols",false);
      bool cholesky = ols || getScalarStructDef<bool>(prhs[2],"cholesky",false);
      constraint_type mode = (constraint_type)getScalarStructDef<int>(prhs[2],"mode",PENALTY);
      if (L > n && !(mode == PENALTY && isZero(lambda) && !pos && lambda2 > 0)) {
//         if (verbose)
//            printf("L is changed to %d\n",n);
         L=n;
      }
      if (L > K) {
//         if (verbose)
//            printf("L is changed to %d\n",K);
         L=K;
      }
      Matrix<T> X(prX,n,M);
      Matrix<T> D(prD,n,K);
      SpMatrix<T> alpha;

      if (nlhs == 2) {
         Matrix<T> norm(K,length_path);
         norm.setZeros();
         if (cholesky) {
            lasso<T>(X,D,alpha,L,lambda,lambda2,mode,pos,ols,numThreads,&norm,length_path);
         } else {
            lasso2<T>(X,D,alpha,L,lambda,lambda2,mode,pos,numThreads,&norm,length_path);
         }
         // Column 0 is always kept; count the further non-zero columns.
         Vector<T> norms_col;
         norm.norm_2_cols(norms_col);
         int length=1;
         for (int i = 1; i<norms_col.n(); ++i)
            if (norms_col[i]) ++length;
         plhs[1]=createMatrix<T>(K,length);
         T* pr_norm=reinterpret_cast<T*>(mxGetPr(plhs[1]));
         Matrix<T> norm2(pr_norm,K,length);
         Vector<T> col;
         for (int i = 0; i<length; ++i) {
            norm2.refCol(i,col);
            norm.copyCol(i,col);
         }
      } else {
         if (cholesky) {
            lasso<T>(X,D,alpha,L,lambda,lambda2,mode,pos,ols,numThreads,NULL,length_path);
         } else {
            lasso2<T>(X,D,alpha,L,lambda,lambda2,mode,pos,numThreads,NULL,length_path);
         }
      }
      convertSpMatrix(plhs[0],alpha.m(),alpha.n(),alpha.n(),
            alpha.nzmax(),alpha.v(),alpha.r(),alpha.pB());
   } else {
      if (!mexCheckType<T>(prhs[0])) 
         mexErrMsgTxt("type of argument 1 is not consistent");
      if (mxIsSparse(prhs[0])) 
         mexErrMsgTxt("argument 1 should be full");
      if (!mexCheckType<T>(prhs[1])) 
         mexErrMsgTxt("type of argument 2 is not consistent");
      if (mxIsSparse(prhs[1])) 
         mexErrMsgTxt("argument 2 should be full");
      if (!mexCheckType<T>(prhs[2])) 
         mexErrMsgTxt("type of argument 3 is not consistent");
      if (mxIsSparse(prhs[2])) 
         mexErrMsgTxt("argument 3 should be full");
      if (!mxIsStruct(prhs[3])) 
         mexErrMsgTxt("argument 4 should be struct");

      T* prX = reinterpret_cast<T*>(mxGetPr(prhs[0]));
      const mwSize* dimsX=mxGetDimensions(prhs[0]);
      int n=static_cast<int>(dimsX[0]);
      int M=static_cast<int>(dimsX[1]);

      T* prG = reinterpret_cast<T*>(mxGetPr(prhs[1]));
      const mwSize* dimsD=mxGetDimensions(prhs[1]);
      int K1=static_cast<int>(dimsD[0]);
      int K2=static_cast<int>(dimsD[1]);
      if (K1 != K2) mexErrMsgTxt("argument sizes are not consistent");
      int K=K1;

      T* prDtR = reinterpret_cast<T*>(mxGetPr(prhs[2]));
      const mwSize* dimsDtR=mxGetDimensions(prhs[2]);
      int K3=static_cast<int>(dimsDtR[0]);
      int M2=static_cast<int>(dimsDtR[1]);
      if (K1 != K3) mexErrMsgTxt("argument sizes are not consistent");
      if (M != M2) mexErrMsgTxt("argument sizes are not consistent");

      T lambda = getScalarStruct<T>(prhs[3],"lambda");
      T lambda2 = getScalarStructDef<T>(prhs[3],"lambda2",0);
      int L = getScalarStructDef<int>(prhs[3],"L",K1);
      // Clamp like the nrhs==3 branch: with length_path <= 0 (e.g. L == 0,
      // so 4*L == 0) the path copy below would read column 0 of an empty
      // matrix.
      int length_path = MAX(2,getScalarStructDef<int>(prhs[3],"length_path",4*L));
      int numThreads = getScalarStructDef<int>(prhs[3],"numThreads",-1);
      bool pos = getScalarStructDef<bool>(prhs[3],"pos",false);
//      bool verbose = getScalarStructDef<bool>(prhs[3],"verbose",true);
      bool ols = getScalarStructDef<bool>(prhs[3],"ols",false);
      bool cholesky = ols || getScalarStructDef<bool>(prhs[3],"cholesky",false);
      constraint_type mode = (constraint_type)getScalarStructDef<int>(prhs[3],"mode",PENALTY);
      if (L > n && !(mode == PENALTY && isZero(lambda) && !pos && lambda2 > 0)) {
//         if (verbose)
//            printf("L is changed to %d\n",n);
         L=n;
      }
      if (L > K) {
//         if (verbose)
//            printf("L is changed to %d\n",K);
         L=K;
      }
      Matrix<T> X(prX,n,M);
      Matrix<T> G(prG,K,K);
      Matrix<T> DtR(prDtR,K,M);
      SpMatrix<T> alpha;

      if (nlhs == 2) {
         Matrix<T> norm(K,length_path);
         norm.setZeros();
         if (cholesky) {
            lasso<T>(X,G,DtR,alpha,L,lambda,mode,pos,ols,numThreads,&norm,length_path);
         } else {
            lasso2<T>(X,G,DtR,alpha,L,lambda,mode,pos,numThreads,&norm,length_path);
         }
         // Column 0 is always kept; count the further non-zero columns.
         Vector<T> norms_col;
         norm.norm_2_cols(norms_col);
         int length=1;
         for (int i = 1; i<norms_col.n(); ++i)
            if (norms_col[i]) ++length;
         plhs[1]=createMatrix<T>(K,length);
         T* pr_norm=reinterpret_cast<T*>(mxGetPr(plhs[1]));
         Matrix<T> norm2(pr_norm,K,length);
         Vector<T> col;
         for (int i = 0; i<length; ++i) {
            norm2.refCol(i,col);
            norm.copyCol(i,col);
         }
      } else {
         if (cholesky) {
            lasso<T>(X,G,DtR,alpha,L,lambda,mode,pos,ols,numThreads,NULL,length_path);
         } else {
            lasso2<T>(X,G,DtR,alpha,L,lambda,mode,pos,numThreads,NULL,length_path);
         }
      }
      convertSpMatrix(plhs[0],alpha.m(),alpha.n(),alpha.n(),
            alpha.nzmax(),alpha.v(),alpha.r(),alpha.pB());
   }
}
inline void callFunction(mxArray* plhs[], const mxArray*prhs[],
      const int nlhs) {
   // Evaluates a graph-path regularizer on alpha0 via FISTA::EvalGraphPath.
   // Inputs:
   //   prhs[0] alpha0 : full matrix of class T, pAlpha x nAlpha
   //   prhs[1] graph  : struct with fields
   //                    weights       - sparse pAlpha x pAlpha matrix
   //                    start_weights - full, pAlpha entries
   //                    stop_weights  - full, pAlpha entries
   //   prhs[2] param  : struct; param.regul must name GRAPH_PATH_L0 or
   //                    GRAPH_PATH_CONV
   // Outputs:
   //   plhs[0] : 1 x val.n() matrix with the values computed by EvalGraphPath
   //   plhs[1] : sparse `path` matrix, only when nlhs == 2
   if (!mexCheckType<T>(prhs[0])) 
      mexErrMsgTxt("type of argument 1 is not consistent");
   if (mxIsSparse(prhs[0])) 
      mexErrMsgTxt("argument 1 should not be sparse");

   if (!mxIsStruct(prhs[1])) 
      mexErrMsgTxt("argument 2 should be struct");
   if (!mxIsStruct(prhs[2])) 
      mexErrMsgTxt("argument 3 should be struct");

   T* pr_alpha0 = reinterpret_cast<T*>(mxGetPr(prhs[0]));
   const mwSize* dimsAlpha=mxGetDimensions(prhs[0]);
   int pAlpha=static_cast<int>(dimsAlpha[0]);
   int nAlpha=static_cast<int>(dimsAlpha[1]);
   Matrix<T> alpha0(pr_alpha0,pAlpha,nAlpha);

   // mxGetField returns NULL when the field is absent, so each field is
   // checked before it is used.
   // NOTE(review): the field data below is reinterpreted as T* without a
   // class check; confirm callers always pass matching types.
   mxArray* ppr_GG = mxGetField(prhs[1],0,"weights");
   if (!ppr_GG)
      mexErrMsgTxt("Missing field weights in argument 2");
   if (!mxIsSparse(ppr_GG)) 
      mexErrMsgTxt("field weights should be sparse");
   T* graph_weights = reinterpret_cast<T*>(mxGetPr(ppr_GG));
   mwSize* GG_r=mxGetIr(ppr_GG);
   mwSize* GG_pB=mxGetJc(ppr_GG);
   const mwSize* dims_GG=mxGetDimensions(ppr_GG);
   int GGm=static_cast<int>(dims_GG[0]);
   int GGn=static_cast<int>(dims_GG[1]);
   if (GGm != GGn || GGm != pAlpha)
      mexErrMsgTxt("size of field weights is not consistent");

   mxArray* ppr_weights = mxGetField(prhs[1],0,"start_weights");
   if (!ppr_weights)
      mexErrMsgTxt("Missing field start_weights in argument 2");
   if (mxIsSparse(ppr_weights)) 
      mexErrMsgTxt("field start_weights should not be sparse");
   T* start_weights = reinterpret_cast<T*>(mxGetPr(ppr_weights));
   const mwSize* dims_weights=mxGetDimensions(ppr_weights);
   int nweights=static_cast<int>(dims_weights[0])*static_cast<int>(dims_weights[1]);
   if (nweights != pAlpha)
      mexErrMsgTxt("size of field start_weights is not consistent");

   mxArray* ppr_weights2 = mxGetField(prhs[1],0,"stop_weights");
   if (!ppr_weights2)
      mexErrMsgTxt("Missing field stop_weights in argument 2");
   if (mxIsSparse(ppr_weights2)) 
      mexErrMsgTxt("field stop_weights should not be sparse");
   T* stop_weights = reinterpret_cast<T*>(mxGetPr(ppr_weights2));
   const mwSize* dims_weights2=mxGetDimensions(ppr_weights2);
   int nweights2=static_cast<int>(dims_weights2[0])*static_cast<int>(dims_weights2[1]);
   if (nweights2 != pAlpha)
      mexErrMsgTxt("size of field stop_weights is not consistent");


   FISTA::ParamFISTA<T> param;
   param.num_threads = getScalarStructDef<int>(prhs[2],"numThreads",-1);
   getStringStruct(prhs[2],"regul",param.name_regul,param.length_names);
   param.regul = regul_from_string(param.name_regul);
   if (param.regul==INCORRECT_REG)
      mexErrMsgTxt("Unknown regularization");
   param.intercept = getScalarStructDef<bool>(prhs[2],"intercept",false);
   param.verbose = getScalarStructDef<bool>(prhs[2],"verbose",false);
   param.eval_dual_norm = getScalarStructDef<bool>(prhs[2],"dual_norm",false);

   if (param.regul != GRAPH_PATH_L0 && param.regul != GRAPH_PATH_CONV)
      mexErrMsgTxt("Use a different mexEvalGraphPath function");

   // numThreads == -1: one thread, or all processors (capped at
   // MAX_THREADS) when built with OpenMP.
   if (param.num_threads == -1) {
      param.num_threads=1;
#ifdef _OPENMP
      param.num_threads =  MIN(MAX_THREADS,omp_get_num_procs());
#endif
   }
   
   // The graph borrows the MATLAB arrays' storage (CSC arrays of weights).
   GraphPathStruct<T> graph;
   graph.n=pAlpha;
   graph.m=GG_pB[graph.n]-GG_pB[0];
   graph.weights=graph_weights;
   graph.start_weights=start_weights;
   graph.stop_weights=stop_weights;
   graph.ir=GG_r;
   graph.jc=GG_pB;
   graph.precision = getScalarStructDef<long long>(prhs[2],"precision",100000000000000000);

   Vector<T> val;
   SpMatrix<T> path;
   if (nlhs==1) {
      FISTA::EvalGraphPath<T>(alpha0,param,val,&graph);
   } else {
      FISTA::EvalGraphPath<T>(alpha0,param,val,&graph,&path);
   }

   plhs[0]=createMatrix<T>(1,val.n());
   T* pr_val=reinterpret_cast<T*>(mxGetPr(plhs[0]));
   for (int i = 0; i<val.n(); ++i) pr_val[i]=val[i];

   if (nlhs==2)
      convertSpMatrix(plhs[1],path.m(),path.n(),path.n(),
            path.nzmax(),path.v(),path.r(),path.pB());
}
Ejemplo n.º 6
0
   inline void callFunction(mxArray* plhs[], const mxArray*prhs[], 
         const long nlhs) {
      // Masked OMP gateway: fills plhs[0] with the sparse K x M result of
      // omp_mask<T>(X, D, alpha, mask, ...).
      // Inputs:
      //   prhs[0] X     : full matrix of class T, n x M
      //   prhs[1] D     : full matrix of class T with n rows (K columns)
      //   prhs[2] mask  : full boolean matrix, n x M
      //   prhs[3] param : struct with at least one of L, eps, lambda; each
      //                   may be a scalar or an array (arrays are passed to
      //                   omp_mask as pointers with the matching flag set)
      // When nlhs == 2, plhs[1] receives a K x L path matrix, zero-filled
      // before the solver runs.
      if (!mexCheckType<T>(prhs[0])) 
         mexErrMsgTxt("type of argument 1 is not consistent");
      if (mxIsSparse(prhs[0])) 
         mexErrMsgTxt("argument 1 should be full");
      if (!mexCheckType<T>(prhs[1])) 
         mexErrMsgTxt("type of argument 2 is not consistent");
      if (mxIsSparse(prhs[1])) 
         mexErrMsgTxt("argument 2 should be full");
      if (!mexCheckType<bool>(prhs[2])) 
         mexErrMsgTxt("type of argument 3 should be boolean");
      if (mxIsSparse(prhs[2])) 
         mexErrMsgTxt("argument 3 should be full");

      if (!mxIsStruct(prhs[3])) 
         mexErrMsgTxt("argument 4 should be struct");
      
      T* prX = reinterpret_cast<T*>(mxGetPr(prhs[0]));
      const mwSize* dimsX=mxGetDimensions(prhs[0]);
      long n=static_cast<long>(dimsX[0]);
      long M=static_cast<long>(dimsX[1]);

      T* prD = reinterpret_cast<T*>(mxGetPr(prhs[1]));
      const mwSize* dimsD=mxGetDimensions(prhs[1]);
      long nD=static_cast<long>(dimsD[0]);
      long K=static_cast<long>(dimsD[1]);
      if (n != nD) mexErrMsgTxt("argument sizes are not consistent");

      bool* prmask = reinterpret_cast<bool*>(mxGetPr(prhs[2]));
      const mwSize* dimsM=mxGetDimensions(prhs[2]);
      long nM=static_cast<long>(dimsM[0]);
      long mM=static_cast<long>(dimsM[1]);
      if (nM != n || mM != M) mexErrMsgTxt("argument sizes are not consistent");

      Matrix<T> X(prX,n,M);
      Matrix<bool> mask(prmask,n,M);
      Matrix<T> D(prD,n,K);
      SpMatrix<T> alpha;

      long numThreads = getScalarStructDef<long>(prhs[3],"numThreads",-1);
      mxArray* pr_L=mxGetField(prhs[3],0,"L");
      mxArray* pr_eps=mxGetField(prhs[3],0,"eps");
      mxArray* pr_lambda=mxGetField(prhs[3],0,"lambda");
      if (!pr_L && !pr_eps && !pr_lambda) mexErrMsgTxt("You should either provide L, eps or lambda");
      
      // L: scalar (also used to size the path output) or per-signal array.
      long sizeL = 1;
      long L=MIN(n,K);
      long *pL = &L;
      if (pr_L) {
         const mwSize* dimsL= mxGetDimensions(pr_L);
         sizeL=static_cast<long>(dimsL[0])*static_cast<long>(dimsL[1]);
         if (sizeL > 1) {
            if (!mexCheckType<long>(pr_L)) 
               mexErrMsgTxt("Type of param.L should be int32");
            pL = reinterpret_cast<long*>(mxGetPr(pr_L));
         }
         L=MIN(L,static_cast<long>(mxGetScalar(pr_L)));
      }

      long sizeE=1;
      T eps=0;
      T* pE=&eps;
      if (pr_eps) {
         const mwSize* dimsE=mxGetDimensions(pr_eps);
         sizeE=static_cast<long>(dimsE[0])*static_cast<long>(dimsE[1]);
         eps=static_cast<T>(mxGetScalar(pr_eps));
         if (sizeE > 1) {
            // The array data is reinterpreted as T*, so its class must match.
            if (!mexCheckType<T>(pr_eps))
               mexErrMsgTxt("Type of param.eps is not consistent");
            pE = reinterpret_cast<T*>(mxGetPr(pr_eps));
         }
      }

      T lambda=0;
      long sizeLambda=1;
      T* pLambda=&lambda;
      if (pr_lambda) {
         const mwSize* dimsLambda=mxGetDimensions(pr_lambda);
         sizeLambda=static_cast<long>(dimsLambda[0])*static_cast<long>(dimsLambda[1]);
         lambda=static_cast<T>(mxGetScalar(pr_lambda));
         if (sizeLambda > 1) {
            // The array data is reinterpreted as T*, so its class must match.
            if (!mexCheckType<T>(pr_lambda))
               mexErrMsgTxt("Type of param.lambda is not consistent");
            pLambda = reinterpret_cast<T*>(mxGetPr(pr_lambda));
         }
      }

      if (nlhs == 2) {
         plhs[1]=createMatrix<T>(K,L);
         T* pr_path=reinterpret_cast<T*>(mxGetPr(plhs[1]));
         Matrix<T> path(pr_path,K,L);
         path.setZeros();
         // `path` lives only in this scope, so the solver is called here.
         // Previously its address escaped the block and was used after
         // `path` had been destroyed.
         omp_mask<T>(X,D,alpha,mask,pL,pE,pLambda,sizeL > 1,sizeE > 1,sizeLambda > 1,
               numThreads,&path);
      } else {
         Matrix<T>* noPath=NULL;
         omp_mask<T>(X,D,alpha,mask,pL,pE,pLambda,sizeL > 1,sizeE > 1,sizeLambda > 1,
               numThreads,noPath);
      }
      convertSpMatrix(plhs[0],K,M,alpha.n(),alpha.nzmax(),alpha.v(),alpha.r(),
            alpha.pB());
   }
Ejemplo n.º 7
0
   inline void callFunction(mxArray* plhs[], const mxArray*prhs[], 
         const int nlhs) {
      // Masked OMP gateway (int variant): fills plhs[0] with the sparse K x M
      // result of omp_mask<T>(X, D, alpha, mask, ...).
      // Inputs:
      //   prhs[0] X     : full matrix of class T, n x M
      //   prhs[1] D     : full matrix of class T with n rows (K columns)
      //   prhs[2] mask  : full boolean matrix, n x M
      //   prhs[3] param : struct with required fields L and eps (scalar or
      //                   array) and optional numThreads
      // When nlhs == 2, plhs[1] receives a K x min(n, L, K) path matrix and
      // only the scalar values of L and eps are used.
      if (!mexCheckType<T>(prhs[0])) 
         mexErrMsgTxt("type of argument 1 is not consistent");
      if (mxIsSparse(prhs[0])) 
         mexErrMsgTxt("argument 1 should be full");
      if (!mexCheckType<T>(prhs[1])) 
         mexErrMsgTxt("type of argument 2 is not consistent");
      if (mxIsSparse(prhs[1])) 
         mexErrMsgTxt("argument 2 should be full");
      if (!mexCheckType<bool>(prhs[2])) 
         mexErrMsgTxt("type of argument 3 should be boolean");
      if (mxIsSparse(prhs[2])) 
         mexErrMsgTxt("argument 3 should be full");

      // prhs[3] is the fourth argument.
      if (!mxIsStruct(prhs[3])) 
         mexErrMsgTxt("argument 4 should be struct");

      T* prX = reinterpret_cast<T*>(mxGetPr(prhs[0]));
      const mwSize* dimsX=mxGetDimensions(prhs[0]);
      int n=static_cast<int>(dimsX[0]);
      int M=static_cast<int>(dimsX[1]);

      T* prD = reinterpret_cast<T*>(mxGetPr(prhs[1]));
      const mwSize* dimsD=mxGetDimensions(prhs[1]);
      int nD=static_cast<int>(dimsD[0]);
      int K=static_cast<int>(dimsD[1]);
      if (n != nD) mexErrMsgTxt("argument sizes are not consistent");

      bool* prmask = reinterpret_cast<bool*>(mxGetPr(prhs[2]));
      const mwSize* dimsM=mxGetDimensions(prhs[2]);
      int nM=static_cast<int>(dimsM[0]);
      int mM=static_cast<int>(dimsM[1]);
      if (nM != n || mM != M) mexErrMsgTxt("argument sizes are not consistent");


      Matrix<T> X(prX,n,M);
      Matrix<T> D(prD,n,K);
      SpMatrix<T> alpha;

      mxArray* pr_L=mxGetField(prhs[3],0,"L");
      if (!pr_L) mexErrMsgTxt("Missing field L in param");
      const mwSize* dimsL=mxGetDimensions(pr_L);
      int sizeL=static_cast<int>(dimsL[0])*static_cast<int>(dimsL[1]);

      mxArray* pr_eps=mxGetField(prhs[3],0,"eps");
      if (!pr_eps) mexErrMsgTxt("Missing field eps in param");
      const mwSize* dimsE=mxGetDimensions(pr_eps);
      int sizeE=static_cast<int>(dimsE[0])*static_cast<int>(dimsE[1]);
      int numThreads = getScalarStructDef<int>(prhs[3],"numThreads",-1);
      Matrix<bool> mask(prmask,n,M);

      if (nlhs == 2) {
         int L=MIN(n,MIN(static_cast<int>(mxGetScalar(pr_L)),K));
         plhs[1]=createMatrix<T>(K,L);
         T* pr_path=reinterpret_cast<T*>(mxGetPr(plhs[1]));
         Matrix<T> path(pr_path,K,L);
         path.setZeros();
         T eps=static_cast<T>(mxGetScalar(pr_eps));
         omp_mask<T>(X,D,alpha,mask,L,eps,numThreads,path);
      } else {

         if (sizeL == 1) {
            int L=static_cast<int>(mxGetScalar(pr_L));
            if (sizeE == 1) {
               T eps=static_cast<T>(mxGetScalar(pr_eps));
               omp_mask<T>(X,D,alpha,mask,L,eps,numThreads);
            } else {
               // eps array data is reinterpreted as T*: its class must match.
               if (!mexCheckType<T>(pr_eps))
                  mexErrMsgTxt("Type of param.eps is not consistent");
               T* pE = reinterpret_cast<T*>(mxGetPr(pr_eps));
               omp_mask<T>(X,D,alpha,mask,L,pE,numThreads);
            }
         } else {
            if (!mexCheckType<int>(pr_L)) 
               mexErrMsgTxt("Type of param.L should be int32");
            int* pL = reinterpret_cast<int*>(mxGetPr(pr_L));
            if (sizeE == 1) {
               T eps=static_cast<T>(mxGetScalar(pr_eps));
               omp_mask<T>(X,D,alpha,mask,pL,eps,numThreads);
            } else {
               // eps array data is reinterpreted as T*: its class must match.
               if (!mexCheckType<T>(pr_eps))
                  mexErrMsgTxt("Type of param.eps is not consistent");
               T* pE = reinterpret_cast<T*>(mxGetPr(pr_eps));
               omp_mask<T>(X,D,alpha,mask,pL,pE,numThreads,true,true);
            }
         }
      }

      convertSpMatrix(plhs[0],K,M,alpha.n(),alpha.nzmax(),alpha.v(),alpha.r(),
            alpha.pB());
   }