예제 #1
0
/*
 * Training driver: parses the command line, reads the training sample from
 * `trainfile`, runs the learner selected by the parsed algorithm id and
 * writes the resulting model to `modelfile`.
 *
 * Returns 0 on completion. The process exits with status 1 when the
 * algorithm id is not one of 0, 1, 2, 3, 4 or 9.
 */
int _svm_struct_learn (int argc, char* argv[])
{
  SAMPLE training_set;          /* examples read from trainfile */
  STRUCTMODEL learned_model;    /* filled in by the learner */
  STRUCT_LEARN_PARM sparm;
  LEARN_PARM lparm;
  KERNEL_PARM kparm;
  int algorithm;

  svm_struct_learn_api_init(argc, argv);

  _svm_struct_learn_read_input_parameters(argc, argv, trainfile, modelfile,
                                          &verbosity, &struct_verbosity,
                                          &sparm, &lparm, &kparm, &algorithm);

  /* Load the training examples, with optional progress output. */
  if (struct_verbosity >= 1) {
    printf("Reading training examples...");
    fflush(stdout);
  }
  training_set = read_struct_examples(trainfile, &sparm);
  if (struct_verbosity >= 1) {
    printf("done\n");
    fflush(stdout);
  }

  /* Dispatch to the requested learning algorithm. */
  switch (algorithm) {
  case 0:
    svm_learn_struct(training_set, &sparm, &lparm, &kparm, &learned_model,
                     NSLACK_ALG);
    break;
  case 1:
    svm_learn_struct(training_set, &sparm, &lparm, &kparm, &learned_model,
                     NSLACK_SHRINK_ALG);
    break;
  case 2:
    svm_learn_struct_joint(training_set, &sparm, &lparm, &kparm, &learned_model,
                           ONESLACK_PRIMAL_ALG);
    break;
  case 3:
    svm_learn_struct_joint(training_set, &sparm, &lparm, &kparm, &learned_model,
                           ONESLACK_DUAL_ALG);
    break;
  case 4:
    svm_learn_struct_joint(training_set, &sparm, &lparm, &kparm, &learned_model,
                           ONESLACK_DUAL_CACHE_ALG);
    break;
  case 9:
    svm_learn_struct_joint_custom(training_set, &sparm, &lparm, &kparm,
                                  &learned_model);
    break;
  default:
    exit(1);
  }

  /* Note: the model keeps references into the training data, so the model
     is written out before either of them is released. */
  if (struct_verbosity >= 1) {
    printf("Writing learned model...");
    fflush(stdout);
  }
  write_struct_model(modelfile, &learned_model, &sparm);
  if (struct_verbosity >= 1) {
    printf("done\n");
    fflush(stdout);
  }

  free_struct_sample(training_set);
  free_struct_model(learned_model);

  svm_struct_learn_api_exit();

  return 0;
}
/* Appends one text line to `path` containing coeffs[0][column] ...
   coeffs[count-1][column], separated by single spaces. Aborts with a message
   instead of reading past the end of a coefficient vector. */
static void main_appendCoeffRow(const char* path,
                                const std::vector<double>* coeffs,
                                long count, int column)
{
  std::ofstream out(path, std::ios::app);
  for(long i = 0; i < count; ++i) {
    if(column < 0 || (size_t)column >= coeffs[i].size()) {
      fprintf(stderr, "[Main] Missing coefficient %d for example %ld while writing %s\n",
              column, i, path);
      exit(-1);
    }
    if(i > 0) {
      out << " ";
    }
    out << coeffs[i][column];
  }
  out << std::endl;
}

/* Reads a file produced by main_appendCoeffRow: every line holds one value
   per example, so token j of each line is appended to coeffs[j]. Tokens
   beyond `count` and empty lines are ignored. */
static void main_loadCoeffRows(const char* path,
                               std::vector<double>* coeffs, long count)
{
  std::ifstream in(path);
  std::string line;
  while(std::getline(in, line)) {
    std::vector<std::string> tokens;
    splitString(line, tokens);
    for(size_t j = 0; j < tokens.size() && (long)j < count; ++j) {
      // Coefficients are stored as doubles, so parse them as such.
      coeffs[j].push_back(atof(tokens[j].c_str()));
    }
  }
}

/*
 * Iterated structured-SVM training driver.
 *
 * Reads the command line and the training sample, then repeatedly
 *   1. trains a model with the algorithm selected by alg_type,
 *   2. writes it to "<modelfile>_<iteration>",
 *   3. calls updateCoeffs() with the freshly written model, and
 *   4. appends the new per-example coefficients to alphas.txt
 *      (and alphas_edge.txt when edge coefficients are in use),
 * until nMaxIterations is reached. Finally it exports the combined
 * segmentation of every training example into loss_dir.
 *
 * If alphas.txt and a previous parameter vector file exist, the saved
 * coefficients are reloaded and the iteration counter is set from the
 * index reported by findLastFile(); otherwise training starts at iteration 0
 * with all coefficients set to 1.
 *
 * Exits with -1 when the "update_loss_function" config parameter is not
 * enabled and with 1 on an unknown algorithm type. Returns 0 on success.
 */
int main (int argc, char* argv[])
{  
  SAMPLE sample;  /* training sample */
  LEARN_PARM learn_parm;
  KERNEL_PARM kernel_parm;
  STRUCT_LEARN_PARM struct_parm;
  STRUCTMODEL structmodel;
  int alg_type;

  svm_struct_learn_api_init(argc,argv);

  read_input_parameters(argc,argv,trainfile,modelfile,&verbosity,
			&struct_verbosity,&struct_parm,&learn_parm,
			&kernel_parm,&alg_type);

  if(struct_verbosity>=1) {
    printf("Reading training examples..."); fflush(stdout);
  }
  /* read the training examples */
  sample=read_struct_examples(trainfile,&struct_parm);
  if(struct_verbosity>=1) {
    printf("done\n"); fflush(stdout);
  }

  // This driver only makes sense when the loss function is updated between
  // iterations, so refuse to run otherwise.
  string config_tmp;
  bool update_loss_function = false;
  if(Config::Instance()->getParameter("update_loss_function", config_tmp)) {
    update_loss_function = config_tmp.c_str()[0] == '1';
  }
  printf("[Main] update_loss_function=%d\n", (int)update_loss_function);
  if(!update_loss_function) {
    printf("update_loss_function should be true\n");
    exit(-1);
  }  

  eUpdateType updateType = UPDATE_NODE_EDGE;
  if(Config::Instance()->getParameter("update_type", config_tmp)) {
    updateType = (eUpdateType)atoi(config_tmp.c_str());
  }
  printf("[Main] update_type=%d\n", (int)updateType);

  // Best effort: the directory may already exist.
  mkdir(loss_dir, 0777);

  // check if parameter vector files exist
  const char* parameter_vector_dir = "parameter_vector";
  int idx = 0;
  string parameter_vector_dir_last = findLastFile(parameter_vector_dir, "", &idx);
  string parameter_vector_file_pattern = parameter_vector_dir_last + "/iteration_";
  int idx_1 = 1;
  string parameter_vector_file_last = findLastFile(parameter_vector_file_pattern, ".txt", &idx_1);
  printf("[Main] Checking parameter vector file %s\n", parameter_vector_file_last.c_str());

  // vectors used to store RF weights: one vector per example, one entry per
  // iteration. alphas_edge stays null unless edge coefficients are updated.
  vector<double>* alphas = new vector<double>[sample.n];
  vector<double>* alphas_edge = 0;
  if(updateType == UPDATE_NODE_EDGE) {
    alphas_edge = new vector<double>[sample.n];
  }
  int alphas_idx = 0;

  if(fileExists("alphas.txt") && fileExists(parameter_vector_file_last)) {

    // Loading alpha coefficients (one file line per saved iteration)
    main_loadCoeffRows("alphas.txt", alphas, sample.n);
    if(alphas_edge && fileExists("alphas_edge.txt")) {
      main_loadCoeffRows("alphas_edge.txt", alphas_edge, sample.n);
    }
    if(sample.n > 0) {
      alphas_idx = (int)alphas[0].size();
    }

    // Loading parameters
    printf("[Main] Found parameter vector file %s\n", parameter_vector_file_last.c_str());
    struct_parm.ssvm_iteration = idx + 1;
    update_output_dir(struct_parm.ssvm_iteration);

  } else {
    struct_parm.ssvm_iteration = 0;

    // insert alpha coefficients for first iteration
    for(int i = 0; i < sample.n; ++i) {
      alphas[i].push_back(1);
    }
    main_appendCoeffRow("alphas.txt", alphas, sample.n, alphas_idx);

    // edges (only when edge coefficients were allocated)
    if(alphas_edge) {
      for(int i = 0; i < sample.n; ++i) {
        alphas_edge[i].push_back(1);
      }
      main_appendCoeffRow("alphas_edge.txt", alphas_edge, sample.n, alphas_idx);
    }

    ++alphas_idx;
  }

  const int nMaxIterations = 5;
  bool bIterate = true;
  do {

    printf("-----------------------------------------SSVM-ITERATION-%d-START\n",
           struct_parm.ssvm_iteration);

    struct_parm.iterationId = 1;

    /* Do the learning and return structmodel. */
    if(alg_type == 0)
      svm_learn_struct(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,NSLACK_ALG);
    else if(alg_type == 1)
      svm_learn_struct(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,NSLACK_SHRINK_ALG);
    else if(alg_type == 2)
      svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_PRIMAL_ALG);
    else if(alg_type == 3)
      svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_DUAL_ALG);
    else if(alg_type == 4)
      svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_DUAL_CACHE_ALG);
    else if(alg_type == 9)
      svm_learn_struct_joint_custom(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel);
    else {
      fprintf(stderr, "[Main] Unknown algorithm type %d\n", alg_type);
      exit(1);
    }

    // Bounded formatting: modelfile plus the suffix must not overflow the buffer.
    char _modelfile[BUFFER_SIZE];
    snprintf(_modelfile, sizeof(_modelfile), "%s_%d", modelfile, struct_parm.ssvm_iteration);
    printf("[Main] Writing learned model to %s\n", _modelfile);
    write_struct_model(_modelfile, &structmodel, &struct_parm);

    // Run inference on training data and increase loss for misclassified points
    printf("[Main] Loading learned model to %s\n", _modelfile);
    EnergyParam param(_modelfile);

    updateCoeffs(sample, param, struct_parm, updateType, alphas, alphas_edge,
                 struct_parm.ssvm_iteration + 1);

    // Persist the coefficients of this iteration.
    // NOTE(review): assumes updateCoeffs() adds entry alphas_idx to every
    // vector; main_appendCoeffRow aborts with a message if it did not.
    main_appendCoeffRow("alphas.txt", alphas, sample.n, alphas_idx);
    if(alphas_edge) {
      main_appendCoeffRow("alphas_edge.txt", alphas_edge, sample.n, alphas_idx);
    }

    ++alphas_idx;

    printf("-----------------------------------------SSVM-ITERATION-%d-END\n",
           struct_parm.ssvm_iteration);

    ++struct_parm.ssvm_iteration;

    bIterate = (nMaxIterations == -1 || struct_parm.ssvm_iteration < nMaxIterations);

  } while(bIterate);

  // Output final segmentation for all examples
  long nExamples = sample.n;
  int nRFs = struct_parm.ssvm_iteration;
  double* lossPerLabel = 0;
  labelType* groundTruthLabels = 0;

  for(int i = 0; i < nExamples; i++) { /*** example loop ***/

    Slice_P* slice = sample.examples[i].x.slice;
    Feature* feature = sample.examples[i].x.feature;
    // Per-node / per-edge coefficients are deliberately not passed here.
    map<sidType, nodeCoeffType>* nodeCoeffs = 0;
    map<sidType, edgeCoeffType>* edgeCoeffs = 0;
    int nNodes = slice->getNbSupernodes();

    stringstream soutPb;
    soutPb << loss_dir;
    soutPb << "pb_";
    soutPb << getNameFromPathWithoutExtension(slice->getName());
    soutPb << "_";
    soutPb << "combined";
    soutPb << ".tif";
    printf("[Main] Exporting %s\n", soutPb.str().c_str());

    labelType* inferredLabels =
      computeCombinedLabels(slice, feature, groundTruthLabels,
                            lossPerLabel, nRFs, alphas, i,
                            nodeCoeffs, edgeCoeffs, soutPb.str().c_str());

    stringstream sout;
    sout << loss_dir;
    sout << getNameFromPathWithoutExtension(slice->getName());
    sout << "_";
    sout << "combined";
    sout << ".tif";
    printf("[Main] Exporting %s\n", sout.str().c_str());
    slice->exportOverlay(sout.str().c_str(), inferredLabels);

    stringstream soutBW;
    soutBW << loss_dir;
    soutBW << getNameFromPathWithoutExtension(slice->getName());
    soutBW << "_";
    soutBW << "combined";
    soutBW << "_BW.tif";
    printf("[Main] Exporting %s\n", soutBW.str().c_str());        
    slice->exportSupernodeLabels(soutBW.str().c_str(),
                                 struct_parm.nClasses,
                                 inferredLabels, nNodes,
                                 &(struct_parm.labelToClassIdx));

    delete[] inferredLabels;
  }

  // Release the coefficient arrays allocated above (delete[] on null is a no-op).
  delete[] alphas;
  delete[] alphas_edge;

  free_struct_sample(sample);
  free_struct_model(structmodel);

  svm_struct_learn_api_exit();

  return 0;
}
/*
 * MATLAB MEX entry point for structured SVM training.
 *
 * in[IN_ARGS]  : string with command-line style options.
 * in[IN_SPARM] : structure with the fields "patterns" and "labels"
 *                (cell arrays of equal length) and, for CUSTOM kernels,
 *                "kernelFn" (a function handle).
 * out[OUT_W]   : copy of the array encapsulating the learned model.
 *
 * Invalid input is reported through mexErrMsgTxt.
 */
void
mexFunction (int nout, mxArray ** out, int nin, mxArray const ** in)
{
  SAMPLE sample;  /* training sample */
  LEARN_PARM learn_parm;
  KERNEL_PARM kernel_parm;
  STRUCT_LEARN_PARM struct_parm;
  STRUCTMODEL structmodel;
  int alg_type;

  enum {IN_ARGS=0, IN_SPARM} ;
  enum {OUT_W=0} ;

  /* SVM-light is not fully reentrant, so we need to run this patch first */
  init_qp_solver() ;
  verbosity = 0 ;
  kernel_cache_statistic = 0 ;

  if (nin != 2) {
    mexErrMsgTxt("Two arguments required") ;
  }

  /* Parse ARGS  ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */
  char arg [1024 + 1] ;
  int argc ;
  char ** argv ;

  if (! uIsString(in[IN_ARGS], -1)) {
    mexErrMsgTxt("ARGS must be a string") ;
  }

  /* mxGetString returns non-zero on failure, which includes the string not
     fitting into the buffer; do not continue with silently truncated options */
  if (mxGetString(in[IN_ARGS], arg, sizeof(arg) / sizeof(char)) != 0) {
    mexErrMsgTxt("ARGS is too long or could not be read") ;
  }
  arg_split (arg, &argc, &argv) ;

  /* NOTE(review): argv-1 / argc+1 presumably account for the program-name
     slot expected by the parsers; this relies on arg_split leaving argv[-1]
     addressable -- confirm against arg_split */
  svm_struct_learn_api_init(argc+1, argv-1) ;

  read_input_parameters (argc+1,argv-1,
                         &verbosity, &struct_verbosity,
                         &struct_parm, &learn_parm,
                         &kernel_parm, &alg_type ) ;
  
  if (kernel_parm.kernel_type != LINEAR &&
      kernel_parm.kernel_type != CUSTOM) {
    mexErrMsgTxt ("Only LINEAR or CUSTOM kerneles are supported") ;
  }

  /* Parse SPARM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */
  mxArray const * sparm_array = in [IN_SPARM] ;
  mxArray const * patterns_array ;
  mxArray const * labels_array ;
  mxArray const * kernelFn_array ;
  int numExamples, ei ;

  /* reject anything that is not a struct, as the message promises */
  if (! sparm_array || ! mxIsStruct(sparm_array)) {
    mexErrMsgTxt("SPARM must be a structure") ;
  }
  struct_parm.mex = sparm_array ;

  patterns_array = mxGetField(sparm_array, 0, "patterns") ;
  if (! patterns_array ||
      ! mxIsCell(patterns_array)) {
    mexErrMsgTxt("SPARM.PATTERNS must be a cell array") ;
  }

  numExamples = mxGetNumberOfElements(patterns_array) ;

  /* the element counts must match: every pattern needs its label */
  labels_array = mxGetField(sparm_array, 0, "labels") ;
  if (! labels_array ||
      ! mxIsCell(labels_array) ||
      (int) mxGetNumberOfElements(labels_array) != numExamples) {
    mexErrMsgTxt("SPARM.LABELS must be a cell array "
                 "with the same number of elements of "
                 "SPARM.PATTERNS") ;
  }

  /* the examples only reference the MATLAB cells; they do not own them */
  sample.n = numExamples ;
  sample.examples = (EXAMPLE *) my_malloc (sizeof(EXAMPLE) * numExamples) ;
  for (ei = 0 ; ei < numExamples ; ++ ei) {
    sample.examples[ei].x.mex = mxGetCell(patterns_array, ei) ;
    sample.examples[ei].y.mex = mxGetCell(labels_array,   ei) ;
    sample.examples[ei].y.isOwner = 0 ;
  }

  if (struct_verbosity >= 1) {
    mexPrintf("There are %d training examples\n", numExamples) ;
  }

  kernelFn_array = mxGetField(sparm_array, 0, "kernelFn") ;
  if (! kernelFn_array && kernel_parm.kernel_type == CUSTOM) {
    mexErrMsgTxt("SPARM.KERNELFN must be define for CUSTOM kernels") ;
  }
  if (kernelFn_array) {
    MexKernelInfo * info ;
    if (mxGetClassID(kernelFn_array) != mxFUNCTION_CLASS) {
      mexErrMsgTxt("SPARM.KERNELFN must be a valid function handle") ;
    }
    /* the kernel's custom storage is reused to carry the MEX callback info */
    info = (MexKernelInfo*) kernel_parm.custom ;
    info -> structParm = sparm_array ;
    info -> kernelFn   = kernelFn_array ;
  }

  /* Learning  ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */
  switch (alg_type) {
  case 0:
    svm_learn_struct(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,NSLACK_ALG) ;
    break ;
  case 1:
    svm_learn_struct(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,NSLACK_SHRINK_ALG);
    break ;
  case 2:
    svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_PRIMAL_ALG);
    break ;
  case 3:
    svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_DUAL_ALG);
    break ;
  case 4:
    svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_DUAL_CACHE_ALG);
    break  ;
  case 9:
    svm_learn_struct_joint_custom(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel);
    break ;
  default:
    mexErrMsgTxt("Unknown algorithm type") ;
  }

  /* Write output  ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */

  /* Warning: The model contains references to the original data 'docs'.
     If you want to free the original data, and only keep the model, you
     have to make a deep copy of 'model'. */

  /* hand MATLAB a duplicate so the encapsulating array can be destroyed
     before the model itself is freed */
  mxArray * model_array = newMxArrayEncapsulatingSmodel (&structmodel) ;
  out[OUT_W] = mxDuplicateArray (model_array) ;
  destroyMxArrayEncapsulatingSmodel (model_array) ;
  
  free_struct_sample (sample) ;
  free_struct_model (structmodel) ;
  svm_struct_learn_api_exit () ;
  free_qp_solver () ;
}