Exemple #1
0
/*
 * Sweep the L1 penalty over a log2-spaced grid that goes from lambda_max down
 * to epsilon*lambda_max, refit with fista() at every grid point (warm start:
 * the weights of the previous point are reused) and return the lambda that
 * produced the lowest MAE.
 *
 *   prob     problem to fit; the fields X, y, n and dim are read
 *   epsilon  ratio lambda_min / lambda_max
 *   nval     number of steps the log2 range is divided into; values below 1
 *            are treated as 1
 *
 * Returns the best lambda found (lambda_max if no grid point improved on the
 * initial error bound).
 *
 * NOTE: the error is computed on the same samples the model was fitted on;
 * the cross-validated evaluation is not used here.
 */
double regularization_path(problem *prob, double epsilon, int nval)
{
   double llog, llog_max, llog_min, lstep;
   double lmax, lmin, lambda, error;
   double best_error = DBL_MAX, best_lambda;
   int nactive, best_nactive = 0;
   double *y_hat = dvector(1, prob->n);
   double *w = dvector(1, prob->dim);

   /* a grid needs at least one step; also avoids dividing by zero below */
   if (nval < 1)
      nval = 1;

   /* compute maximum lambda for which all weights are 0 (Osborne et al. 1999)
    * lambda_max = ||X'y||_inf. According to scikit-learn source code, you can
    * divide by npatterns and it still works */
   dmvtransmult(prob->X, prob->n, prob->dim, prob->y, prob->n, w);
   lmax = dvnorm(w, prob->dim, INF) / prob->n;
   lmin = epsilon*lmax;
   llog_max = log2(lmax);
   llog_min = log2(lmin);
   lstep = (llog_max - llog_min)/nval;

   /* defined fallback in case no grid point is ever recorded as best */
   best_lambda = lmax;

   fprintf(stdout, "lmax=%g lmin=%g epsilon=%g nval=%d\n",
           lmax, lmin, epsilon, nval);

   /* warm-starts: weights are set to 0 only at the begining */
   dvset(w, prob->dim, 0);
   for(llog = llog_max; llog >= llog_min; llog -= lstep)
   {
      /* iter is passed by address to fista(), so it is reset on every pass */
      int iter = 1000;
      double tol = 0, fret;

      lambda = pow(2, llog);

      fista(prob, w, lambda, 0, tol, 0, &iter, &fret);
      fista_predict(prob, w, y_hat);

      error = mae(prob->y, prob->n, y_hat);
      nactive = dvnotzero(w, prob->dim);
      fprintf(stdout, "   lambda %10.6lf   MAE %7.6lf   active weights %d/%d\n",
              lambda, error, nactive, prob->dim);

      dvprint(stdout, w, prob->dim);

      if (error < best_error)
      {
         best_error = error;
         best_lambda = lambda;
         /* remember the sparsity of the best model: w keeps changing */
         best_nactive = nactive;
      }

      /* a zero, negative, NaN or infinite step can never move llog below
       * llog_min in a meaningful way; stop after this single evaluation
       * instead of looping forever */
      if (!(lstep > 0) || lstep > DBL_MAX)
         break;
   }

   print_line(60);
   fprintf(stdout, "\nBest: lambda=%g MAE=%g active weights=%d/%d\n",
           best_lambda, best_error, best_nactive, prob->dim);

   /* release the work vectors only after their last use */
   free_dvector(y_hat, 1, prob->n);
   free_dvector(w, 1, prob->dim);

   return best_lambda;
}
Exemple #2
0
int main(int argc, char** argv)
{
    
clock_t tstart=0, tstop=0;
//
// the strart time counter
tstart = clock();

    string a (argv[2]);
    double theta = boost::lexical_cast<double>(a);
    ifstream inFile ( argv[1], ifstream::in); 
    string line;
    int linenum = 0;
    string item;
    
    int nbObj  = 0;
    int nbAttr = 0;

    string me (argv[3]);
    string mi (argv[4]);
    string mae (argv[5]);
    string mai (argv[6]);

    int minExt =  boost::lexical_cast<int>(me);
    int minInt =  boost::lexical_cast<int>(mi);
    int maxExt =  boost::lexical_cast<int>(mae);
    int maxInt =  boost::lexical_cast<int>(mai);

        
    /// couting objects, attributes ///
    getline (inFile, line);
    nbAttr = std::count(line.begin(), line.end(), '\t');
    while (getline (inFile, line))
          nbObj++;
    cout << "Found " << nbObj << " objects and " << nbAttr << " attributes." << endl;
    
    /// init data ///
    std::set<double> sorted_domain;
   
    
    //string objNames[nbObj];
    //string attrNames[nbAttr];
    vector<string> objNames (nbObj);
    vector<string> attrNames (nbAttr);
    
    multi_double data(boost::extents[nbObj][nbAttr]);
     

    /// reading data ///
    inFile.clear(); 
    inFile.seekg (0, std::ios::beg);
    std::istringstream stm;

    int itemnum = 0;
    getline (inFile, line);
    istringstream linestream(line);
    getline (linestream, item, '\t');
    while (getline (linestream, item, '\t'))
    {
          attrNames[itemnum++] = item;
    }

    while (getline (inFile, line))
    {
        istringstream linestream(line);
        itemnum = 0;
        getline (linestream, item, '\t');
        objNames[linenum] = item;
	
        while (getline (linestream, item, '\t'))
        {
          data[linenum][itemnum] =  boost::lexical_cast<double>(item);
          sorted_domain.insert(data[linenum][itemnum]);
          itemnum++;
        }
        linenum++;
    }
    inFile.close();
      


    // Building domain  of value
     vector<double> W;
     set<double>::iterator myIterator;
     for(myIterator =  sorted_domain.begin();
         myIterator !=  sorted_domain.end();
             myIterator++)
     {
            W.push_back(*myIterator);
            //cout << *myIterator << endl;
     }

    vector<double> classesL;
    vector<double> classesR;
    
    /// Computing blocks of tolerance over W ///
    double curL,curR;
    unsigned int i, j;
    double k = -999999999;
    for (i = 0 ; i != W.size(); i++)
    {
       	curL= W[i];
        curR= W[i];
	for (j = i; j != W.size(); j++) 
	{
	if (W[j] - curL <= theta)
         		    curR = W[j];
			else break;
		}
		if (! (curR<= k) ) 
                {
                   classesL.push_back(curL);
                   classesR.push_back(curR);
                }
		k = curR;
    }


//// computing 'proximity' of a class, ie classes having a non empty intersection with the current
vector<int> beginL (classesR.size());
vector<int> beginR (classesR.size());

for (unsigned int l=0; l != classesL.size(); l++)
	{
		beginL[l] = -1;
		beginR[l] = -1;

		for (unsigned int prox=0; prox != classesL.size(); prox++)
		{
			double interL = (classesL[l] > classesL[prox] ? classesL[l] : classesL[prox]);
			double interR = (classesR[l] < classesR[prox] ? classesR[l] : classesR[prox]);
			if (interR >= interL) {
				if (beginL[l] == -1) beginL[l] = prox;
				beginR[l] = prox;
			}
		}
	}

   /// Building and mining each context. ///
   unsigned int totalBiclusters = 0;
   unsigned int totalCandidates = 0;
 //  unsigned long density = 0;

   for (unsigned int l = 0; l != classesL.size() ; l++) 
   {
       //cout << "[" << classesL[l] << ";" << classesR[l] << "], " << endl;
       InClose *algo;
       algo = new InClose(data,classesL, classesR, l, beginL, beginR, minExt, minInt, maxExt, maxInt);
       algo->initContextFromTol(classesL[l], classesR[l], data, nbObj, nbAttr);
       totalBiclusters += algo->ttmain();
      totalCandidates += algo->nbCandidates;
       //algo->outputConcepts();
      // density += algo->density;
       delete algo;
   } 
  tstop = clock();

// For experiments.
// cout << totalBiclusters << " total maximal biclusters." << endl;
// cout << totalCandidates << " candidates generated as either non maximal or redundant biclusters." << endl;

//double c = boost::lexical_cast<double>( classesL.size());
//double o = boost::lexical_cast<double>( nbObj);
//double at = boost::lexical_cast<double>( nbAttr);
//double d = boost::lexical_cast<double>( density);
//cout << d / c / (o * at);

cout << totalBiclusters << " max. biclusters\n"
     << totalCandidates << " candidates\n"
     << mstimer(tstart,tstop) << " ms\n" ;

/**cout << theta  << "\t"
     << nbObj  << "\t"
     << minExt << "\t"
     << maxExt << "\t"
     << minInt << "\t"
     << maxInt << "\t"
     << totalBiclusters << "\t"
     << totalCandidates << "\t" 
     << classesL.size() << "\t"
     << mstimer(tstart,tstop) << "\n";
**/
   return EXIT_SUCCESS;
}
/**
 * Evaluates every method listed in `des` through Evaluate_() and writes a
 * MATLAB-style plotting script to `resName`:
 *   figure 1: precision/recall curves,
 *   figure 2: ROC curves (FPR vs. TPR),
 *   figure 3: bar chart of mean F-measure, max F-measure and AUC,
 *   figure 4: bar chart of MAE.
 *
 * @param gtW     forwarded to Evaluate_() (presumably the ground-truth file pattern)
 * @param salDir  forwarded to Evaluate_() (presumably the saliency-map directory)
 * @param resName path of the script file to create (overwritten)
 * @param des     method names; "_<name>.png" is the suffix handed to Evaluate_()
 *
 * Does nothing except print a notice when `des` is empty.
 */
void CmEvaluation::Evaluate(CStr gtW, CStr &salDir, CStr &resName, vecS &des)
{
	const int NumMethod = static_cast<int>(des.size()); // Number of different methods
	if (NumMethod == 0) {
		// des[0] is read below and the legend string assumes at least one entry
		printf("CmEvaluation::Evaluate: no methods given, nothing to evaluate\n");
		return;
	}

	vector<vecD> precision(NumMethod), recall(NumMethod), tpr(NumMethod), fpr(NumMethod);
	static const int CN = 21; // Color Number 
	static const char* c[CN] = {"'k'", "'b'", "'g'", "'r'", "'c'", "'m'", "'y'",
		"':k'", "':b'", "':g'", "':r'", "':c'", "':m'", "':y'", 
		"'--k'", "'--b'", "'--g'", "'--r'", "'--c'", "'--m'", "'--y'"
	};
	FILE* f = fopen(_S(resName), "w");
	CV_Assert(f != NULL);
	fprintf(f, "clear;\nclose all;\nclc;\n\n\n%%%%\nfigure(1);\nhold on;\n");
	vecD thr(NUM_THRESHOLD);
	for (int i = 0; i < NUM_THRESHOLD; i++)
		thr[i] = i * STEP;
	PrintVector(f, thr, "Threshold");
	fprintf(f, "\n");
	
	// Per-method curves are filled by Evaluate_(); its return value is kept as MAE
	vecD mae(NumMethod);
	for (int i = 0; i < NumMethod; i++)
		mae[i] = Evaluate_(gtW, salDir, "_" + des[i] + ".png", precision[i], recall[i], tpr[i], fpr[i]);

	// Figure 1: one precision/recall curve per method
	string leglendStr("legend(");
	vecS strPre(NumMethod), strRecall(NumMethod), strTpr(NumMethod), strFpr(NumMethod);
	for (int i = 0; i < NumMethod; i++){
		strPre[i] = format("Precision_%s", _S(des[i]));
		strRecall[i] = format("Recall_%s", _S(des[i]));
		strTpr[i] = format("TPR_%s", _S(des[i]));
		strFpr[i] = format("FPR_%s", _S(des[i]));
		PrintVector(f, recall[i], strRecall[i]);
		PrintVector(f, precision[i], strPre[i]);
		PrintVector(f, tpr[i], strTpr[i]);
		PrintVector(f, fpr[i], strFpr[i]);
		// line styles wrap around after CN methods; wrapped ones are drawn thinner
		fprintf(f, "plot(%s, %s, %s, 'linewidth', %d);\n", _S(strRecall[i]), _S(strPre[i]), c[i % CN], i < CN ? 2 : 1);
		leglendStr += format("'%s', ",  _S(des[i]));
	}
	leglendStr.resize(leglendStr.size() - 2); // drop the trailing ", "
	leglendStr += ");";
	string xLabel = "label('Recall');\n";
	string yLabel = "label('Precision')\n";
	fprintf(f, "hold off;\nx%sy%s\n%s\ngrid on;\naxis([0 1 0 1]);\ntitle('Precision recall curve');\n", _S(xLabel), _S(yLabel), _S(leglendStr));


	// Figure 2: ROC curves
	fprintf(f, "\n\n\n%%%%\nfigure(2);\nhold on;\n");
	for (int i = 0; i < NumMethod; i++)
		fprintf(f, "plot(%s, %s,  %s, 'linewidth', %d);\n", _S(strFpr[i]), _S(strTpr[i]), c[i % CN], i < CN ? 2 : 1);
	xLabel = "label('False positive rate');\n";
	yLabel = "label('True positive rate')\n";
	// The title is written before figure(3) is opened so that it labels the ROC plot
	fprintf(f, "hold off;\nx%sy%s\n%s\ngrid on;\naxis([0 1 0 1]);\ntitle('ROC curve');\n\n\n%%%%\nfigure(3);\n", _S(xLabel), _S(yLabel), _S(leglendStr));

	double betaSqr = 0.3; // As suggested by most papers for salient object detection
	vecD areaROC(NumMethod, 0), avgFMeasure(NumMethod, 0), maxFMeasure(NumMethod, 0);
	for (int i = 0; i < NumMethod; i++){
		CV_Assert(fpr[i].size() == tpr[i].size() && precision[i].size() == recall[i].size() && fpr[i].size() == precision[i].size());
		for (size_t t = 0; t < fpr[i].size(); t++){
			// Precision and recall both 0 would give 0/0 = NaN and spoil the mean; count it as 0
			const double denom = betaSqr * precision[i][t] + recall[i][t];
			const double fMeasure = denom > 0 ? (1+betaSqr) * precision[i][t] * recall[i][t] / denom : 0.0;
			avgFMeasure[i] += fMeasure/fpr[i].size();
			maxFMeasure[i] = max(maxFMeasure[i], fMeasure);
			if (t > 0){
				// trapezoid between consecutive thresholds
				areaROC[i] += (tpr[i][t] + tpr[i][t - 1]) * (fpr[i][t - 1] - fpr[i][t]) / 2.0;
			}
		}
		fprintf(f, "%%%5s: AUC = %5.3f, MeanF = %5.3f, MaxF = %5.3f, MAE = %5.3f\n", _S(des[i]), areaROC[i], avgFMeasure[i], maxFMeasure[i], mae[i]);
	}
	PrintVector(f, areaROC, "AUC");
	PrintVector(f, avgFMeasure, "MeanFMeasure");
	PrintVector(f, maxFMeasure, "MaxFMeasure");
	PrintVector(f, mae, "MAE");

	// methodLabels = {'AC', 'SR', 'DRFI', 'GU', 'GB'};
	fprintf(f, "methodLabels = {'%s'", _S(des[0]));
	for (int i = 1; i < NumMethod; i++)
		fprintf(f, ", '%s'", _S(des[i]));
	fprintf(f, "};\n\nbar([MeanFMeasure; MaxFMeasure; AUC]');\nlegend('Mean F_\\beta', 'Max F_\\beta', 'AUC');xlim([0 %d]);\ngrid on;\n", NumMethod+1);
	fprintf(f, "xticklabel_rotate([1:%d],90, methodLabels,'interpreter','none');\n", NumMethod);
	fprintf(f, "\n\nfigure(4);\nbar(MAE);\ntitle('MAE');\ngrid on;\nxlim([0 %d]);", NumMethod+1);
	fprintf(f, "xticklabel_rotate([1:%d],90, methodLabels,'interpreter','none');\n", NumMethod);
	fclose(f);
	printf("%-70s\r", ""); // blank out the current console line
}
/**
 * Compares binary result masks with ground-truth masks and appends the
 * per-method averages (precision, recall, F-measure, intersection over union,
 * MAE) together with bar-plot commands to `resFile`.
 *
 * For image name N and method D the result mask is read from
 * maskDir + N + "_" + D + ".png", or "..._" + suffix + ".png" when `suffix`
 * is not empty. Both images are thresholded at 128 before comparison.
 *
 * @param gtW      handed to CmFile::GetNamesNE() to enumerate ground-truth images
 * @param maskDir  prefix of the result mask paths
 * @param des      method names
 * @param resFile  file the results are appended to; skipped if it cannot be opened
 * @param betaSqr  beta^2 weight of the F-measure
 * @param alertNul print a line for every image pair that is missing or differs in size
 * @param suffix   optional file-name suffix, also appended to the variable names written
 * @param title    title of the generated bar plot
 */
void CmEvaluation::EvalueMask(CStr gtW, CStr &maskDir, vecS &des, CStr resFile, double betaSqr, bool alertNul, CStr suffix, CStr title)
{
	const int methodNum = (int)des.size();
	if (methodNum == 0)
		return; // nothing to evaluate; des[0] is read when writing the labels

	vecS namesNS; 
	string gtDir, gtExt;
	int imgNum = CmFile::GetNamesNE(gtW, namesNS, gtDir, gtExt);
	vecD pr(methodNum), rec(methodNum), count(methodNum), fm(methodNum);
	vecD intUnio(methodNum), mae(methodNum);
	for (int i = 0; i < imgNum; i++){
		Mat truM = imread(gtDir + namesNS[i] + gtExt, CV_LOAD_IMAGE_GRAYSCALE);
		for (int m = 0; m < methodNum; m++)	{
			string mapName = maskDir + namesNS[i] + "_" + des[m];
			mapName += suffix.empty() ? ".png" : "_" + suffix + ".png";
			Mat res = imread(mapName, CV_LOAD_IMAGE_GRAYSCALE);
			// Unusable pairs are skipped and do not count towards the averages
			if (truM.data == NULL || res.data == NULL || truM.size != res.size){
				if (alertNul)
					printf("Truth(%d, %d), Res(%d, %d): %s\n", truM.cols, truM.rows, res.cols, res.rows, _S(mapName));
				continue;
			}
			// Binarize both masks in place (pixels >= 128 become foreground)
			compare(truM, 128, truM, CMP_GE);
			compare(res, 128, res, CMP_GE);
			Mat commMat, unionMat, diff1f;
			bitwise_and(truM, res, commMat);
			bitwise_or(truM, res, unionMat);
			double commV = sum(commMat).val[0];
			double p = commV/(sum(res).val[0] + EPS);
			double r = commV/(sum(truM).val[0] + EPS);
			pr[m] += p;
			rec[m] += r;
			intUnio[m] += commV / (sum(unionMat).val[0] + EPS);
			absdiff(truM, res, diff1f);
			mae[m] += sum(diff1f).val[0]/(diff1f.rows * diff1f.cols * 255);
			count[m]++;
		}
	}

	for (int m = 0; m < methodNum; m++){
		// Without any usable image the sums are all 0; keep them instead of producing 0/0 = NaN
		if (count[m] > 0){
			pr[m] /= count[m];
			rec[m] /= count[m];
			intUnio[m] /= count[m];
			mae[m] /= count[m];
		}
		fm[m] = (1 + betaSqr) * pr[m] * rec[m] / (betaSqr * pr[m] + rec[m] + EPS);
	}

#ifndef fopen_s
	FILE *f = fopen(_S(resFile), "a");
#else
	FILE *f; 
	fopen_s(&f, _S(resFile), "a");
#endif
	if (f != NULL){
		fprintf(f, "\n%%%%\n");
		CmEvaluation::PrintVector(f, pr, "PrecisionMask" + suffix);
		CmEvaluation::PrintVector(f, rec, "RecallMask" + suffix);
		CmEvaluation::PrintVector(f, fm, "FMeasureMask" + suffix);
		CmEvaluation::PrintVector(f, intUnio, "IntUnion" + suffix);
		CmEvaluation::PrintVector(f, mae, "MAE" + suffix); // was writing intUnio under the MAE name
		fprintf(f, "bar([%s]');\ngrid on\n", _S("PrecisionMask" + suffix + "; RecallMask" + suffix + "; FMeasureMask" + suffix + "; IntUnion" + suffix));
		// methodNum is an int, matching the %d conversions (des.size() is a size_t)
		fprintf(f, "title('%s');\naxis([0 %d 0.8 1]);\nmethodLabels = { '%s'", _S(title), methodNum + 1, _S(des[0]));
		for (size_t i = 1; i < des.size(); i++)
			fprintf(f, ", '%s'", _S(des[i]));
		fprintf(f, " };\nlegend('Precision', 'Recall', 'FMeasure', 'IntUnion');\n");
		fprintf(f, "xticklabel_rotate([1:%d], 90, methodLabels, 'interpreter', 'none');\n", methodNum);
		fclose(f);
	}

	if (des.size() == 1)
		printf("Precision = %g, recall = %g, F-Measure = %g, intUnion = %g, mae = %g\n", pr[0], rec[0], fm[0], intUnio[0], mae[0]);
}
Exemple #5
0
/*
 * Command-line driver: parses the options, reads the training problem (and
 * an optional test problem), optionally standardizes them, optionally picks
 * lambda_1 with regularization_path(), trains with fista()/fista_nocov() and
 * prints iteration count, number of active weights, MAE figures and the
 * elapsed wall-clock time.
 *
 * Options (short form): -h help, -v verbose, -b use fista_nocov instead of
 * fista, -o skip standardization, -t FILE test set, -l VALUE lambda_1,
 * -r VALUE lambda_2, -c[N] cross-validation with N folds, -e[TOL] tolerance,
 * -p[N] regularization path with N values, -i[N] maximum iterations.
 * The training file is the first non-option operand.
 */
int main(int argc, char *argv[])
{
   char *ftest = NULL;
   struct timeval t0, t1, diff;
   problem *train = NULL, *test = NULL;
   int regpath_flag = 0, backtracking_flag = 0, std_flag = 1, verbose_flag = 0;
   int iter = 1000, c, crossval_flag = 0, nr_folds = 10, nval = 100, nzerow;
   double *w = NULL, *y_hat = NULL, *mean = NULL, *var = NULL;
   double lambda_1 = 1e-6, lambda_2 = 0, tol = 1e-9, epsilon, fret;

   while (1)
   {
      static struct option long_options[] =
      {
         /* These options don't set a flag.
          We distinguish them by their indices. */
         {"help",                   no_argument, 0, 'h'},
         {"verbose",                no_argument, 0, 'v'},
         {"backtracking",           no_argument, 0, 'b'},
         {"original",               no_argument, 0, 'o'},
         {"test",             required_argument, 0, 't'},
         {"l1",               required_argument, 0, 'l'},
         {"l2",               required_argument, 0, 'r'},
         {"cross-validation", optional_argument, 0, 'c'},
         /* the name must not carry padding, or --tolerance never matches */
         {"tolerance",        optional_argument, 0, 'e'},
         {"regpath",          optional_argument, 0, 'p'},
         /*{"stop",             optional_argument, 0, 's'},*/
         {"max-iters",        optional_argument, 0, 'i'},
         {0, 0, 0, 0}
      };

      int option_index = 0;

      c = getopt_long (argc, argv, "vhbot:r:l:p::c::e::s::i::", long_options, &option_index);

      /* Detect the end of the options. */
      if (c == -1)
         break;

      switch(c)
      {
         case 'h':
            exit_with_help(argv[PROG]);
            break;

         case 'b':
            backtracking_flag = 1;
            break;

         case 'v':
            verbose_flag = 1;
            break;

         case 'o':
            std_flag = 0;
            break;

         case 't':
            ftest = optarg;
            break;

         case 'c':
            crossval_flag = 1;
            if (optarg)
               if (sscanf(optarg, "%d", &nr_folds) != 1)
               {
                  fprintf(stderr, "%s: option -c requires an int\n", argv[PROG]);
                  exit_without_help(argv[PROG]);
               }
            break;

         case 'e':
            if (optarg)
               if (sscanf(optarg, "%lf", &tol) != 1)
               {
                  fprintf(stderr, "%s: option -e requires a double\n", argv[PROG]);
                  exit_without_help(argv[PROG]);
               }
            break;

         case 'p':
            regpath_flag = 1;
            if (optarg)
               if (sscanf(optarg, "%d", &nval) != 1)
               {
                  fprintf(stderr, "%s: option -p requires an int\n", argv[PROG]);
                  exit_without_help(argv[PROG]);
               }
            break;

         /* disabled option: explicit lambda search range
         case 's':
            search_flag = 1;
            if (optarg)
               if (sscanf(optarg, "%lf:%d:%lf", &lmax, &nval, &lmin) != 3)
               {
                  printf("%s\n", optarg);
                  fprintf(stderr, "%s: option -s requires a range in the format MAX:NVAL:MIN\n", argv[PROG]);
                  exit_without_help(argv[PROG]);
               }
            break;
         */

         case 'l':
            if (sscanf(optarg, "%lf", &lambda_1) != 1)
            {
               fprintf(stderr, "%s: option -l requires a float\n", argv[PROG]);
               exit_without_help(argv[PROG]);
            }
            break;

         case 'r':
            if (sscanf(optarg, "%lf", &lambda_2) != 1)
            {
               fprintf(stderr, "%s: option -r requires a float\n", argv[PROG]);
               exit_without_help(argv[PROG]);
            }
            break;

         case 'i':
            if (optarg)
               if (sscanf(optarg, "%d", &iter) != 1)
               {
                  fprintf(stderr, "%s: option -i requires an int\n", argv[PROG]);
                  exit_without_help(argv[PROG]);
               }
            break;

         case '?':
            /* getopt_long already printed an error message. */
            exit_without_help(argv[PROG]);
            break;

         default:
            printf("?? getopt returned character code 0%o ??\n", c); 
            break;
      }
   }

   if ((argc - optind) < ARGC_MIN || (argc - optind) > ARGC_MAX)
   {
      fprintf(stderr, "%s: missing file operand\n", argv[PROG]);
      exit_without_help(argv[PROG]);
   }

   /* start time */
   gettimeofday(&t0, 0);

   train = read_problem(argv[optind]);

   fprintf(stdout, "n:%d dim:%d\n", train->n, train->dim);

   /* alloc vector for means and variances, plus 1 for output */
   if (std_flag)
   {
      fprintf(stdout, "Standarizing train set...\n");
      mean = dvector(1, train->dim+1);
      var = dvector(1, train->dim+1);
      standarize(train, 1, mean, var);
   }

   if (ftest)
   {
      test = read_problem(ftest);
      /* the test set reuses the statistics computed on the train set */
      if (std_flag)
         standarize(test, 0, mean, var);
   }

   if (regpath_flag)
   {
      fprintf(stdout, "Regularization path...\n");
      /* in glmnet package they use 0.0001 instead of 0.001 ? */
      epsilon = train->n > train->dim ? 0.001 : 0.01;
      lambda_1 = regularization_path(train, epsilon, nval);
   }

   fprintf(stdout, "lambda_1: %g\n", lambda_1);
   fprintf(stdout, "lambda_2: %g\n", lambda_2);

   /* initialize weight vector to 0 */
   w = dvector(1, train->dim);
   dvset(w, train->dim, 0);

   fprintf(stdout, "Training model...\n");
   if (backtracking_flag)
      /*fista_backtrack(train, w, lambda_1, lambda_2, tol, &iter, &fret);*/
      fista_nocov(train, w, lambda_1, lambda_2, tol, &iter, &fret);
   else
      fista(train, w, lambda_1, lambda_2, tol, verbose_flag, &iter, &fret);

   y_hat = dvector(1, train->n);
   fista_predict(train, w, y_hat);

   nzerow = dvnotzero(w, train->dim);

   fprintf(stdout, "Iterations: %d\n", iter);
   fprintf(stdout, "Active weights: %d/%d\n", nzerow, train->dim);
   if (std_flag)
      fprintf(stdout, "MAE train: %g\n", var[train->dim+1]*mae(train->y, train->n, y_hat));
   fprintf(stdout, "MAE train (standarized): %g\n", mae(train->y, train->n, y_hat));

   free_dvector(y_hat, 1, train->n);

   if (crossval_flag)
   {
      dvset(w, train->dim, 0);
      y_hat = dvector(1, train->n);
      cross_validation(train, w, lambda_1, lambda_2, nr_folds, y_hat);
      fprintf(stdout, "MAE cross-validation: %lf\n",
              mae(train->y, train->n, y_hat));
      free_dvector(y_hat, 1, train->n);
   }

   if (ftest)
   {
      /* we alloc memory again since test size is different from train size */
      y_hat = dvector(1, test->n);
      fista_predict(test, w, y_hat);
      fprintf(stdout, "MAE test: %g\n", mae(test->y, test->n, y_hat));
      free_dvector(y_hat, 1, test->n);
   }

   /* stop time */
   gettimeofday(&t1, 0);
   timeval_subtract(&t1, &t0, &diff);
   /* minutes are reduced modulo 60; the casts make the arguments match the
    * conversion specifiers whatever the width of time_t / suseconds_t */
   fprintf(stdout, "Time(h:m:s.us): %02d:%02d:%02d.%06ld\n",
           (int)(diff.tv_sec/3600), (int)((diff.tv_sec/60)%60),
           (int)(diff.tv_sec%60), (long)diff.tv_usec);

   if (verbose_flag)
   {
      fprintf(stdout, "Weights: ");
      dvprint(stdout, w, train->dim);
   }

   free_dvector(w, 1, train->dim);

   if (std_flag)
   {
      free_dvector(mean, 1, train->dim+1);
      free_dvector(var, 1, train->dim+1);
   }

   if (ftest)
   {
      free_dvector(test->y, 1, test->n);
      free_dmatrix(test->X, 1, test->n, 1, test->dim);
      free(test);
   }

   free_dvector(train->y, 1, train->n);
   free_dmatrix(train->X, 1, train->n, 1, train->dim);
   free(train);

   return 0;
}