コード例 #1
0
/*
 * iter2 front end for the FISTA solver (variant taking image_truth and an
 * objective-evaluation callback).
 *
 * Only a single proximal operator is supported (D must be 1; prox_ops[0] is
 * used). `ops` and `xupdate_op` are accepted for interface compatibility and
 * ignored. The tolerance handed to fista() is conf->tol scaled by the norm of
 * image_adj; if checkeps() reports that norm as non-zero-worthy (returns
 * true), the function returns without calling fista().
 */
void iter2_fista(void* _conf,
		const struct operator_s* normaleq_op,
		unsigned int D,
		const struct operator_p_s** prox_ops,
		const struct linop_s** ops,
		const struct operator_p_s* xupdate_op,
		long size, float* image, const float* image_adj,
		const float* image_truth,
		void* obj_eval_data,
		float (*obj_eval)(const void*, const float*))
{
	struct iter_fista_conf* fconf = _conf;

	assert(1 == D);

	/* A stricter check, assert(NULL == ops), is intentionally not enforced. */
	UNUSED(ops);
	UNUSED(xupdate_op);

	/* Norm of the adjoint data; used both as an early-out test and to scale tol. */
	float adj_norm = md_norm(1, MD_DIMS(size), image_adj);

	if (checkeps(adj_norm))
		return;

	assert((0. <= fconf->continuation) && (fconf->continuation <= 1.));

	fista(fconf->maxiter, adj_norm * fconf->tol, fconf->step,
		fconf->continuation, fconf->hogwild, size,
		(void*)normaleq_op, select_vecops(image_adj),
		operator_iter, operator_p_iter, (void*)prox_ops[0],
		image, image_adj, image_truth, obj_eval_data, obj_eval);
}
コード例 #2
0
ファイル: iter2.c プロジェクト: mrirecon/bart
/*
 * iter2 front end for the FISTA solver (iter_conf / iter_monitor_s variant).
 *
 * Only a single proximal operator is supported (D must be 1; prox_ops[0] is
 * used) and `biases` must be NULL. `ops` and `xupdate_op` are accepted for
 * interface compatibility and ignored. The tolerance handed to fista() is
 * conf->tol scaled by the norm of image_adj; if checkeps() returns true for
 * that norm, the function returns without calling fista().
 */
void iter2_fista(iter_conf* _conf,
		const struct operator_s* normaleq_op,
		unsigned int D,
		const struct operator_p_s* prox_ops[D],
		const struct linop_s* ops[D],
		const float* biases[D],
		const struct operator_p_s* xupdate_op,
		long size, float* image, const float* image_adj,
		struct iter_monitor_s* monitor)
{
	auto fconf = CAST_DOWN(iter_fista_conf, _conf);

	assert(1 == D);
	assert(NULL == biases);

	/* A stricter check, assert(NULL == ops), is intentionally not enforced. */
	UNUSED(ops);
	UNUSED(xupdate_op);

	/* Norm of the adjoint data; used both as an early-out test and to scale tol. */
	float adj_norm = md_norm(1, MD_DIMS(size), image_adj);

	if (checkeps(adj_norm))
		return;

	assert((0. <= fconf->continuation) && (fconf->continuation <= 1.));

	fista(fconf->maxiter, adj_norm * fconf->tol, fconf->step,
		fconf->continuation, fconf->hogwild, size,
		select_vecops(image_adj),
		OPERATOR2ITOP(normaleq_op), OPERATOR_P2ITOP(prox_ops[0]),
		image, image_adj, monitor);
}
コード例 #3
0
ファイル: main.c プロジェクト: albertotb/fista
/*
 * Choose the L1 penalty lambda by sweeping a log2-spaced grid from lambda_max
 * down to lambda_min = epsilon * lambda_max.
 *
 * At each grid point the model is trained with fista() (warm-started from the
 * previous solution, iter = 1000, tol = 0, lambda_2 = 0) and scored by its MAE
 * on the training set itself (not cross-validated). The lambda with the lowest
 * MAE is returned. Per-lambda progress, the weight vector and the best result
 * are printed to stdout.
 *
 * prob    - training problem; X, y, n and dim are read.
 * epsilon - ratio lambda_min / lambda_max; expected in (0, 1].
 * nval    - number of grid intervals: nval + 1 lambdas are evaluated with both
 *           endpoints included. If nval <= 0 or lambda_min <= 0 (degenerate
 *           grid, e.g. X'y == 0), only lambda_max is evaluated.
 *
 * Returns the best lambda. Falls back to lambda_max if no evaluated lambda
 * produced an MAE below DBL_MAX (e.g. every MAE was NaN).
 */
double regularization_path(problem *prob, double epsilon, int nval)
{
   double best_error = DBL_MAX;
   int best_active = 0;
   double *y_hat = dvector(1, prob->n);
   double *w = dvector(1, prob->dim);

   /* Compute the maximum lambda for which all weights are 0 (Osborne et al.
    * 1999): lambda_max = ||X'y||_inf. According to the scikit-learn source
    * code it can be divided by the number of patterns and still works. */
   dmvtransmult(prob->X, prob->n, prob->dim, prob->y, prob->n, w);
   double lmax = dvnorm(w, prob->dim, INF) / prob->n;
   double lmin = epsilon*lmax;

   /* Step in log2 space. Degenerate grids collapse to the single point lmax,
    * which avoids inf/NaN steps (division by 0, or log2(0) = -inf). */
   int nsteps = 0;
   double lstep = 0.0;
   if (nval > 0 && lmin > 0.0)
   {
      nsteps = nval;
      lstep = log2(lmax / lmin) / nval;
   }

   /* Always a defined return value, even if no lambda beats DBL_MAX. */
   double best_lambda = lmax;

   fprintf(stdout, "lmax=%g lmin=%g epsilon=%g nval=%d\n",
           lmax, lmin, epsilon, nval);

   /* warm-starts: weights are set to 0 only at the beginning */
   dvset(w, prob->dim, 0);

   /* Integer-indexed grid: exactly nsteps + 1 points, so rounding in a running
    * floating-point sum can no longer drop the lmin endpoint. */
   for (int i = 0; i <= nsteps; i++)
   {
      /* lmax * 2^(-i*lstep): exactly lmax at i == 0, ~lmin at i == nsteps */
      double lambda = lmax * pow(2, -i * lstep);

      int iter = 1000;
      double tol = 0, fret;
      fista(prob, w, lambda, 0, tol, 0, &iter, &fret);
      fista_predict(prob, w, y_hat);

      double error = mae(prob->y, prob->n, y_hat);
      int active = dvnotzero(w, prob->dim);
      fprintf(stdout, "   lambda %10.6lf   MAE %7.6lf   active weights %d/%d\n",
              lambda, error, active, prob->dim);

      dvprint(stdout, w, prob->dim);

      if (error < best_error)
      {
         best_error = error;
         best_lambda = lambda;
         /* Remember the sparsity of the best model; w is freed before the summary. */
         best_active = active;
      }
   }

   free_dvector(y_hat, 1, prob->n);
   free_dvector(w, 1, prob->dim);

   print_line(60);
   fprintf(stdout, "\nBest: lambda=%g MAE=%g active weights=%d/%d\n",
           best_lambda, best_error, best_active, prob->dim);

   return best_lambda;
}
コード例 #4
0
ファイル: main.c プロジェクト: albertotb/fista
/*
 * Command-line driver for the FISTA regression tool.
 *
 * Parses options, reads the training set from the first non-option argument
 * (and an optional test set via -t/--test), standardizes the data unless
 * -o/--original is given, optionally chooses lambda_1 via a regularization
 * path (-p), trains the model (fista, or fista_nocov with -b), and then
 * reports iterations, active weights, train MAE, optional cross-validation
 * MAE (-c) and test MAE, and the wall-clock time.
 *
 * Options with optional arguments (-c, -e, -p, -i) follow getopt_long rules:
 * the value must be attached, e.g. "-c5" or "--cross-validation=5".
 */
int main(int argc, char *argv[])
{
   char *ftest = NULL;
   struct timeval t0, t1, diff;
   /* test is only read when ftest is set; NULL keeps it defined otherwise */
   problem *train, *test = NULL;
   int regpath_flag = 0, backtracking_flag = 0, std_flag = 1, verbose_flag = 0;
   int iter = 1000, c, crossval_flag = 0, nr_folds = 10, nval = 100, nzerow;
   /* mean/var are only allocated when std_flag is set */
   double *w, *y_hat, *mean = NULL, *var = NULL;
   double lambda_1 = 1e-6, lambda_2 = 0, tol = 1e-9, epsilon, fret;

   while (1)
   {
      static struct option long_options[] =
      {
         /* None of these set a flag directly; getopt_long returns the
          * corresponding short-option character for each. */
         {"help",                   no_argument, 0, 'h'},
         {"verbose",                no_argument, 0, 'v'},
         {"backtracking",           no_argument, 0, 'b'},
         {"original",               no_argument, 0, 'o'},
         {"test",             required_argument, 0, 't'},
         {"l1",               required_argument, 0, 'l'},
         {"l2",               required_argument, 0, 'r'},
         {"cross-validation", optional_argument, 0, 'c'},
         /* name must not contain padding: getopt_long compares it literally */
         {"tolerance",        optional_argument, 0, 'e'},
         {"regpath",          optional_argument, 0, 'p'},
         {"max-iters",        optional_argument, 0, 'i'},
         {0, 0, 0, 0}
      };

      int option_index = 0;

      /* -s (lambda search range) is disabled, so it is not advertised here;
       * getopt_long then reports it as an invalid option. */
      c = getopt_long (argc, argv, "vhbot:r:l:p::c::e::i::", long_options, &option_index);

      /* Detect the end of the options. */
      if (c == -1)
         break;

      switch(c)
      {
         case 'h':
            exit_with_help(argv[PROG]);
            break;

         case 'b':
            backtracking_flag = 1;
            break;

         case 'v':
            verbose_flag = 1;
            break;

         case 'o':
            std_flag = 0;
            break;

         case 't':
            ftest = optarg;
            break;

         case 'c':
            crossval_flag = 1;
            if (optarg && sscanf(optarg, "%d", &nr_folds) != 1)
            {
               fprintf(stderr, "%s: option -c requires an int\n", argv[PROG]);
               exit_without_help(argv[PROG]);
            }
            break;

         case 'e':
            if (optarg && sscanf(optarg, "%lf", &tol) != 1)
            {
               fprintf(stderr, "%s: option -e requires a double\n", argv[PROG]);
               exit_without_help(argv[PROG]);
            }
            break;

         case 'p':
            regpath_flag = 1;
            if (optarg && sscanf(optarg, "%d", &nval) != 1)
            {
               fprintf(stderr, "%s: option -p requires an int\n", argv[PROG]);
               exit_without_help(argv[PROG]);
            }
            break;

         case 'l':
            if (sscanf(optarg, "%lf", &lambda_1) != 1)
            {
               fprintf(stderr, "%s: option -l requires a float\n", argv[PROG]);
               exit_without_help(argv[PROG]);
            }
            break;

         case 'r':
            if (sscanf(optarg, "%lf", &lambda_2) != 1)
            {
               fprintf(stderr, "%s: option -r requires a float\n", argv[PROG]);
               exit_without_help(argv[PROG]);
            }
            break;

         case 'i':
            if (optarg && sscanf(optarg, "%d", &iter) != 1)
            {
               fprintf(stderr, "%s: option -i requires an int\n", argv[PROG]);
               exit_without_help(argv[PROG]);
            }
            break;

         case '?':
            /* getopt_long already printed an error message. */
            exit_without_help(argv[PROG]);
            break;

         default:
            printf("?? getopt returned character code 0%o ??\n", c);
      }
   }

   int noperands = argc - optind;
   if (noperands < ARGC_MIN)
   {
      fprintf(stderr, "%s: missing file operand\n", argv[PROG]);
      exit_without_help(argv[PROG]);
   }
   else if (noperands > ARGC_MAX)
   {
      fprintf(stderr, "%s: too many file operands\n", argv[PROG]);
      exit_without_help(argv[PROG]);
   }

   /* start time */
   gettimeofday(&t0, 0);

   train = read_problem(argv[optind]);

   fprintf(stdout, "n:%d dim:%d\n", train->n, train->dim);

   /* alloc vector for means and variances, plus 1 for output */
   if (std_flag)
   {
      fprintf(stdout, "Standarizing train set...\n");
      mean = dvector(1, train->dim+1);
      var = dvector(1, train->dim+1);
      standarize(train, 1, mean, var);
   }

   if (ftest)
   {
      test = read_problem(ftest);
      if (std_flag)
         standarize(test, 0, mean, var);
   }

   if (regpath_flag)
   {
      fprintf(stdout, "Regularization path...\n");
      /* in glmnet package they use 0.0001 instead of 0.001 ? */
      epsilon = train->n > train->dim ? 0.001 : 0.01;
      lambda_1 = regularization_path(train, epsilon, nval);
   }

   fprintf(stdout, "lambda_1: %g\n", lambda_1);
   fprintf(stdout, "lambda_2: %g\n", lambda_2);

   /* initialize weight vector to 0 */
   w = dvector(1, train->dim);
   dvset(w, train->dim, 0);

   fprintf(stdout, "Training model...\n");
   if (backtracking_flag)
      fista_nocov(train, w, lambda_1, lambda_2, tol, &iter, &fret);
   else
      fista(train, w, lambda_1, lambda_2, tol, verbose_flag, &iter, &fret);

   y_hat = dvector(1, train->n);
   fista_predict(train, w, y_hat);

   nzerow = dvnotzero(w, train->dim);

   fprintf(stdout, "Iterations: %d\n", iter);
   fprintf(stdout, "Active weights: %d/%d\n", nzerow, train->dim);
   if (std_flag)
      fprintf(stdout, "MAE train: %g\n", var[train->dim+1]*mae(train->y, train->n, y_hat));
   fprintf(stdout, "MAE train (standarized): %g\n", mae(train->y, train->n, y_hat));

   free_dvector(y_hat, 1, train->n);

   if (crossval_flag)
   {
      dvset(w, train->dim, 0);
      y_hat = dvector(1, train->n);
      cross_validation(train, w, lambda_1, lambda_2, nr_folds, y_hat);
      fprintf(stdout, "MAE cross-validation: %lf\n",
              mae(train->y, train->n, y_hat));
      free_dvector(y_hat, 1, train->n);
   }

   if (ftest)
   {
      /* we alloc memory again since test size is different from train size */
      y_hat = dvector(1, test->n);
      fista_predict(test, w, y_hat);
      fprintf(stdout, "MAE test: %g\n", mae(test->y, test->n, y_hat));
      free_dvector(y_hat, 1, test->n);
   }

   /* stop time */
   gettimeofday(&t1, 0);
   timeval_subtract(&t1, &t0, &diff);
   /* time_t / suseconds_t have implementation-defined widths: cast to long to
    * match %ld, and reduce minutes modulo 60 so the h:m:s fields are valid. */
   fprintf(stdout, "Time(h:m:s.us): %02ld:%02ld:%02ld.%06ld\n",
           (long)(diff.tv_sec / 3600), (long)((diff.tv_sec / 60) % 60),
           (long)(diff.tv_sec % 60), (long)diff.tv_usec);

   if (verbose_flag)
   {
      fprintf(stdout, "Weights: ");
      dvprint(stdout, w, train->dim);
   }

   free_dvector(w, 1, train->dim);

   if (std_flag)
   {
      free_dvector(mean, 1, train->dim+1);
      free_dvector(var, 1, train->dim+1);
   }

   if (ftest)
   {
      free_dvector(test->y, 1, test->n);
      free_dmatrix(test->X, 1, test->n, 1, test->dim);
      free(test);
   }

   free_dvector(train->y, 1, train->n);
   free_dmatrix(train->X, 1, train->n, 1, train->dim);
   free(train);

   return 0;
}