コード例 #1
0
ファイル: detection_thread.hpp プロジェクト: 2php/eblearn
//! Configures 'detect' from the variables found in 'conf': multi-scale
//! resolutions, memory optimization, zero-padding, input size limits,
//! saving of detections, bounding-box post-processing, non-maximum
//! suppression, smoothing, background class and output dumping.
//! \param detect The detector to configure (modified in place).
//! \param conf Configuration holding the detection variables.
//! \param odir Output directory; "detections" (when "save_detections" is set)
//!             and "dump/detect_out" (when "dump_outputs" is set) are
//!             placed under it.
//! \param silent If true, the detector is put in silent mode.
void detection_thread<T>::init_detector(detector<T> &detect,
                                        configuration &conf,
                                        std::string &odir, bool silent) {
  // multi-scaling parameters
  double maxs = conf.try_get_double("max_scale", 2.0);
  double mins = conf.try_get_double("min_scale", 1.0);
  t_scaling scaling_type =
      static_cast<t_scaling>(conf.try_get_uint("scaling_type", SCALES_STEP));
  double scaling = conf.try_get_double("scaling", 1.4);
  std::vector<midxdim> scales;
  switch (scaling_type) {
    case MANUAL:
      if (!conf.exists("scales"))
        eblerror("expected \"scales\" variable to be defined in manual mode");
      scales = string_to_midxdimvector(conf.get_cstring("scales"));
      detect.set_resolutions(scales);
      break;
    case ORIGINAL:
      detect.set_scaling_original();
      break;
    case SCALES_STEP:
      detect.set_resolutions(scaling, maxs, mins);
      break;
    case SCALES_STEP_UP:
      detect.set_resolutions(scaling, maxs, mins);
      detect.set_scaling_type(scaling_type);
      break;
    default:
      detect.set_scaling_type(scaling_type);
      break;
  }
  // remove pads from target scales if requested
  if (conf.exists_true("scaling_remove_pad")) detect.set_scaling_rpad(true);
  // Optimize memory usage by using only 2 buffers for the entire flow.
  // NOTE(review): 'input' and 'output' are locals destroyed when this function
  // returns; if set_mem_optimization() keeps references/pointers to them
  // (instead of sharing their storage), they dangle — confirm against
  // detector<T>::set_mem_optimization.
  state<T> input(1, 1, 1), output(1, 1, 1);
  if (!conf.exists_false("mem_optimization")) {
    // Inputs are always kept (last argument 'true'): detection does not work
    // otherwise. TODO: fix this so inputs are only kept when actually needed
    // (e.g. when "save_detections" is set or when displaying).
    detect.set_mem_optimization(input, output, true);
  }
  // zero padding
  float hzpad = conf.try_get_float("hzpad", 0);
  float wzpad = conf.try_get_float("wzpad", 0);
  detect.set_zpads(hzpad, wzpad);
  if (conf.exists("input_min")) // limit inputs size
    detect.set_min_resolution(conf.get_uint("input_min"));
  if (conf.exists("input_max")) // limit inputs size
    detect.set_max_resolution(conf.get_uint("input_max"));
  if (silent) detect.set_silent();
  if (conf.exists_bool("save_detections")) {
    // Insert a separator only when 'odir' does not already end with one, so
    // "out" and "out/" both give "out/detections" (consistent with the
    // "/dump/detect_out" path built below).
    std::string detdir = odir;
    if (!detdir.empty() && detdir[detdir.size() - 1] != '/')
      detdir += '/';
    detdir += "detections";
    uint nsave = conf.try_get_uint("save_max_per_frame", 0);
    bool diverse = conf.exists_true("save_diverse");
    detect.set_save(detdir, nsave, diverse);
  }
  detect.set_scaler_mode(conf.exists_true("scaler_mode"));
  if (conf.exists("bbox_decision"))
    detect.set_bbox_decision(conf.get_uint("bbox_decision"));
  if (conf.exists("bbox_scalings")) {
    mfidxdim scalings =
        string_to_fidxdimvector(conf.get_cstring("bbox_scalings"));
    detect.set_bbox_scalings(scalings);
  }

  // nms configuration //////////////////////////////////////////////////////
  t_nms nms_type = static_cast<t_nms>(conf.try_get_uint("nms", 0));
  float pre_threshold = conf.try_get_float("pre_threshold", 0.0);
  float post_threshold = conf.try_get_float("post_threshold", 0.0);
  float pre_hfact = conf.try_get_float("pre_hfact", 1.0);
  float pre_wfact = conf.try_get_float("pre_wfact", 1.0);
  float post_hfact = conf.try_get_float("post_hfact", 1.0);
  float post_wfact = conf.try_get_float("post_wfact", 1.0);
  float woverh = conf.try_get_float("woverh", 1.0);
  float max_overlap = conf.try_get_float("max_overlap", 0.0);
  float max_hcenter_dist = conf.try_get_float("max_hcenter_dist", 0.0);
  float max_wcenter_dist = conf.try_get_float("max_wcenter_dist", 0.0);
  float vote_max_overlap = conf.try_get_float("vote_max_overlap", 0.0);
  float vote_mhd = conf.try_get_float("vote_max_hcenter_dist", 0.0);
  float vote_mwd = conf.try_get_float("vote_max_wcenter_dist", 0.0);
  detect.set_nms(nms_type, pre_threshold, post_threshold, pre_hfact,
                 pre_wfact, post_hfact, post_wfact, woverh, max_overlap,
                 max_hcenter_dist, max_wcenter_dist, vote_max_overlap,
                 vote_mhd, vote_mwd);
  if (conf.exists("raw_thresholds")) {
    std::string srt = conf.get_string("raw_thresholds");
    std::vector<float> rt = string_to_floatvector(srt.c_str());
    detect.set_raw_thresholds(rt);
  }
  if (conf.exists("outputs_threshold"))
    detect.set_outputs_threshold(conf.get_double("outputs_threshold"),
                                 conf.try_get_double("outputs_threshold_val",
                                                     -1));
  ///////////////////////////////////////////////////////////////////////////
  if (conf.exists("netdims")) {
    idxdim d = string_to_idxdim(conf.get_string("netdims"));
    detect.set_netdim(d);
  }
  if (conf.exists("smoothing")) {
    idxdim ker; // left default-constructed unless "smoothing_kernel" is set
    if (conf.exists("smoothing_kernel"))
      ker = string_to_idxdim(conf.get_string("smoothing_kernel"));
    // NOTE(review): a pointer to the local 'ker' is passed; confirm that
    // set_smoothing() copies it rather than storing the pointer.
    detect.set_smoothing(conf.get_uint("smoothing"),
                         conf.try_get_double("smoothing_sigma", 1),
                         &ker,
                         conf.try_get_double("smoothing_sigma_scale", 1));
  }
  if (conf.exists("background_name"))
    detect.set_bgclass(conf.get_cstring("background_name"));
  if (conf.exists_true("bbox_ignore_outsiders"))
    detect.set_ignore_outsiders();
  if (conf.exists("corners_inference"))
    detect.set_corners_inference(conf.get_uint("corners_inference"));
  if (conf.exists("input_gain"))
    detect.set_input_gain(conf.get_double("input_gain"));
  if (conf.exists_true("dump_outputs")) {
    std::string fname;
    fname << odir << "/dump/detect_out";
    detect.set_outputs_dumping(fname.c_str());
  }
}
コード例 #2
0
ファイル: train_utils.hpp プロジェクト: athuls/gsra
  void test_and_save(uint iter, configuration &conf, string &conffname,
		     parameter<Tnet> &theparam,
		     supervised_trainer<Tnet,Tdata,Tlabel> &thetrainer,
		     labeled_datasource<Tnet,Tdata,Tlabel> &train_ds,
		     labeled_datasource<Tnet,Tdata,Tlabel> &test_ds,
		     classifier_meter &trainmeter,
		     classifier_meter &testmeter,
		     infer_param &infp, gd_param &gdp, string &shortname) {
    timer ttest;
    ostringstream wname, wfname;

    //   // some code to average several random solutions
    //     cout << "Testing...";
    //     if (original_tests > 1) cout << " (" << original_tests << " times)";
    //     cout << endl;
    //     ttest.restart();
    //     for (uint i = 0; i < original_tests; ++i) {
    //       if (test_only && original_tests > 1) {
    // 	// we obviously wanna test several random solutions
    // 	cout << "Initializing weights from random." << endl;
    // 	thenet.forget(fgp);
    //       }
    //       if (!no_training_test)
    // 	thetrainer.test(train_ds, trainmeter, infp);
    //       thetrainer.test(test_ds, testmeter, infp);
    //       cout << "testing_time="; ttest.pretty_elapsed(); cout << endl;
    //     }
    //     if (test_only && original_tests > 1) {
    //       // display averages over all tests
    //       testmeter.display_average(test_ds.name(), test_ds.lblstr, 
    // 				test_ds.is_test());
    //       trainmeter.display_average(train_ds.name(), train_ds.lblstr, 
    // 				 train_ds.is_test());
    //     }
    cout << "Testing..." << endl;
    uint maxtest = conf.exists("max_testing") ? conf.get_uint("max_testing") :0;
    ttest.start();
    if (!conf.exists_true("no_training_test"))
      thetrainer.test(train_ds, trainmeter, infp, maxtest);	// test
    if (!conf.exists_true("no_testing_test"))
      thetrainer.test(test_ds, testmeter, infp, maxtest);	// test
    cout << "testing_time="; ttest.pretty_elapsed(); cout << endl;
    // save samples picking statistics
    if (conf.exists_true("save_pickings")) {
      string fname; fname << "pickings_" << iter;
      train_ds.save_pickings(fname.c_str());
    }
    // save weights and confusion matrix for test set
    wname.str("");
    if (conf.exists("job_name"))
      wname << conf.get_string("job_name");
    wname << "_net" << setfill('0') << setw(5) << iter;
    wfname.str(""); wfname << wname.str() << ".mat";
    if (conf.exists_false("save_weights"))
      cout << "Not saving weights (save_weights set to 0)." << endl;
    else {
      cout << "saving net to " << wfname.str() << endl;
      theparam.save_x(wfname.str().c_str()); // save trained network
      cout << "saved=" << wfname.str() << endl;
    }
    // detection test
    if (conf.exists_true("detection_test")) {
      uint dt_nthreads = 1;
      if (conf.exists("detection_test_nthreads"))
	dt_nthreads = conf.get_uint("detection_test_nthreads");
      timer dtest;
      dtest.start();
      // copy config file and augment it and detect it
      string cmd, params;
      if (conf.exists("detection_params")) {
	params = conf.get_string("detection_params");
	params = string_replaceall(params, "\\n", "\n");
      }
      cmd << "cp " << conffname << " tmp.conf && echo \"silent=1\n"
	  << "nthreads=" << dt_nthreads << "\nevaluate=1\nweights_file=" 
	  << wfname.str() << "\n" << params << "\" >> tmp.conf && detect tmp.conf";
      if (std::system(cmd.c_str()))
	cerr << "warning: failed to execute: " << cmd << endl;
      cout << "detection_test_time="; dtest.pretty_elapsed(); cout << endl;
    }
    // set retrain to next iteration with current saved weights
    ostringstream progress;
    progress << "retrain_iteration = " << iter + 1 << endl
	     << "retrain_weights = " << wfname.str() << endl;
    // save progress
    job::write_progress(iter + 1, conf.get_uint("iterations"),
			progress.str().c_str());
    // save confusion
    if (conf.exists_true("save_confusion")) {
      string fname; fname << wname.str() << "_confusion_test.mat";
      cout << "saving confusion to " << fname << endl;
      save_matrix(testmeter.get_confusion(), fname.c_str());
    }
#ifdef __GUI__ // display
    static supervised_trainer_gui<Tnet,Tdata,Tlabel> stgui(shortname.c_str());
    static supervised_trainer_gui<Tnet,Tdata,Tlabel> stgui2(shortname.c_str());
    bool display = conf.exists_true("show_train"); // enable/disable display
    uint ninternals = conf.exists("show_train_ninternals") ? 
      conf.get_uint("show_train_ninternals") : 1; // # examples' to display
    bool show_train_errors = conf.exists_true("show_train_errors");
    bool show_train_correct = conf.exists_true("show_train_correct");
    bool show_val_errors = conf.exists_true("show_val_errors");
    bool show_val_correct = conf.exists_true("show_val_correct");
    bool show_raw_outputs = conf.exists_true("show_raw_outputs");
    bool show_all_jitter = conf.exists_true("show_all_jitter");
    uint hsample = conf.exists("show_hsample") ?conf.get_uint("show_hsample"):5;
    uint wsample = conf.exists("show_wsample") ?conf.get_uint("show_wsample"):5;
    if (display) {
      cout << "Displaying training..." << endl;
      if (show_train_errors) {
	stgui2.display_correctness(true, true, thetrainer, train_ds, infp,
				   hsample, wsample, show_raw_outputs,
				   show_all_jitter);
	stgui2.display_correctness(true, false, thetrainer, train_ds, infp,
				   hsample, wsample, show_raw_outputs,
				   show_all_jitter);
      }
      if (show_train_correct) {
	stgui2.display_correctness(false, true, thetrainer, train_ds, infp,
				   hsample, wsample, show_raw_outputs,
				   show_all_jitter);
	stgui2.display_correctness(false, false, thetrainer, train_ds, infp,
				   hsample, wsample, show_raw_outputs,
				   show_all_jitter);
      }
      if (show_val_errors) {
	stgui.display_correctness(true, true, thetrainer, test_ds, infp,
				  hsample, wsample, show_raw_outputs,
				  show_all_jitter);
	stgui.display_correctness(true, false, thetrainer, test_ds, infp,
				  hsample, wsample, show_raw_outputs,
				  show_all_jitter);
      }
      if (show_val_correct) {
	stgui.display_correctness(false, true, thetrainer, test_ds, infp,
				  hsample, wsample, show_raw_outputs,
				  show_all_jitter);
	stgui.display_correctness(false, false, thetrainer, test_ds, infp,
				  hsample, wsample, show_raw_outputs,
				  show_all_jitter);
      }
      stgui.display_internals(thetrainer, test_ds, infp, gdp, ninternals);
    }
#endif
  }