int _svm_struct_learn (int argc, char* argv[]) { SAMPLE sample; /* training sample */ LEARN_PARM learn_parm; KERNEL_PARM kernel_parm; STRUCT_LEARN_PARM struct_parm; STRUCTMODEL structmodel; int alg_type; svm_struct_learn_api_init(argc,argv); _svm_struct_learn_read_input_parameters(argc,argv,trainfile,modelfile,&verbosity, &struct_verbosity,&struct_parm,&learn_parm, &kernel_parm,&alg_type); if(struct_verbosity>=1) { printf("Reading training examples..."); fflush(stdout); } /* read the training examples */ sample=read_struct_examples(trainfile,&struct_parm); if(struct_verbosity>=1) { printf("done\n"); fflush(stdout); } /* Do the learning and return structmodel. */ if(alg_type == 0) svm_learn_struct(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,NSLACK_ALG); else if(alg_type == 1) svm_learn_struct(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,NSLACK_SHRINK_ALG); else if(alg_type == 2) svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_PRIMAL_ALG); else if(alg_type == 3) svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_DUAL_ALG); else if(alg_type == 4) svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_DUAL_CACHE_ALG); else if(alg_type == 9) svm_learn_struct_joint_custom(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel); else exit(1); /* Warning: The model contains references to the original data 'docs'. If you want to free the original data, and only keep the model, you have to make a deep copy of 'model'. */ if(struct_verbosity>=1) { printf("Writing learned model...");fflush(stdout); } write_struct_model(modelfile,&structmodel,&struct_parm); if(struct_verbosity>=1) { printf("done\n");fflush(stdout); } free_struct_sample(sample); free_struct_model(structmodel); svm_struct_learn_api_exit(); return 0; }
int main (int argc, char* argv[]) { SAMPLE sample; /* training sample */ LEARN_PARM learn_parm; KERNEL_PARM kernel_parm; STRUCT_LEARN_PARM struct_parm; STRUCTMODEL structmodel; int alg_type; svm_struct_learn_api_init(argc,argv); read_input_parameters(argc,argv,trainfile,modelfile,&verbosity, &struct_verbosity,&struct_parm,&learn_parm, &kernel_parm,&alg_type); if(struct_verbosity>=1) { //verbose = true; printf("Reading training examples..."); fflush(stdout); } /* read the training examples */ sample=read_struct_examples(trainfile,&struct_parm); if(struct_verbosity>=1) { printf("done\n"); fflush(stdout); } string config_tmp; bool update_loss_function = false; if(Config::Instance()->getParameter("update_loss_function", config_tmp)) { update_loss_function = config_tmp.c_str()[0] == '1'; } printf("[Main] update_loss_function=%d\n", (int)update_loss_function); if(!update_loss_function) { printf("update_loss_function should be true\n"); exit(-1); } eUpdateType updateType = UPDATE_NODE_EDGE; if(Config::Instance()->getParameter("update_type", config_tmp)) { updateType = (eUpdateType)atoi(config_tmp.c_str()); } printf("[Main] update_type=%d\n", (int)updateType); mkdir(loss_dir, 0777); // check if parameter vector files exist const char* parameter_vector_dir = "parameter_vector"; int idx = 0; string parameter_vector_dir_last = findLastFile(parameter_vector_dir, "", &idx); string parameter_vector_file_pattern = parameter_vector_dir_last + "/iteration_"; int idx_1 = 1; string parameter_vector_file_last = findLastFile(parameter_vector_file_pattern, ".txt", &idx_1); printf("[Main] Checking parameter vector file %s\n", parameter_vector_file_last.c_str()); // vectors used to store RF weights vector<double>* alphas = new vector<double>[sample.n]; vector<double>* alphas_edge = 0; if(updateType == UPDATE_NODE_EDGE) { alphas_edge = new vector<double>[sample.n]; } int alphas_idx = 0; if(fileExists("alphas.txt") && fileExists(parameter_vector_file_last)) { // Loading alpha coefficients ifstream ifs("alphas.txt"); string line; int lineIdx = 0; while(lineIdx < sample.n && getline(ifs, line)) { vector<string> tokens; splitString(line, tokens); for(vector<string>::iterator it = tokens.begin(); it != tokens.end(); ++it) { alphas[lineIdx].push_back(atoi(it->c_str())); } ++lineIdx; } ifs.close(); if(lineIdx > 0) { alphas_idx = alphas[0].size(); } // Loading parameters printf("[Main] Found parameter vector file %s\n", parameter_vector_file_last.c_str()); struct_parm.ssvm_iteration = idx + 1; update_output_dir(struct_parm.ssvm_iteration); //EnergyParam param(parameter_vector_file_last.c_str()); //updateCoeffs(sample, param, struct_parm, updateType, alphas, alphas_idx); //alphas_idx = 1; } else { struct_parm.ssvm_iteration = 0; // insert alpha coefficients for first iteration for(int i = 0; i < sample.n; ++i) { alphas[i].push_back(1); } ofstream ofs("alphas.txt", ios::app); int i = 0; for(; i < sample.n - 1; ++i) { ofs << alphas[i][alphas_idx] << " "; } if(i < sample.n) { ofs << alphas[i][alphas_idx]; } ofs << endl; ofs.close(); // edges for(int i = 0; i < sample.n; ++i) { alphas_edge[i].push_back(1); } ofstream ofse("alphas_edge.txt", ios::app); i = 0; for(; i < sample.n - 1; ++i) { ofse << alphas_edge[i][alphas_idx] << " "; } if(i < sample.n) { ofse << alphas_edge[i][alphas_idx]; } ofse << endl; ofse.close(); ++alphas_idx; } const int nMaxIterations = 5; bool bIterate = true; do { printf("-----------------------------------------SSVM-ITERATION-%d-START\n", struct_parm.ssvm_iteration); struct_parm.iterationId = 1; /* Do the learning and return structmodel. */ if(alg_type == 0) svm_learn_struct(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,NSLACK_ALG); else if(alg_type == 1) svm_learn_struct(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,NSLACK_SHRINK_ALG); else if(alg_type == 2) svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_PRIMAL_ALG); else if(alg_type == 3) svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_DUAL_ALG); else if(alg_type == 4) svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_DUAL_CACHE_ALG); else if(alg_type == 9) svm_learn_struct_joint_custom(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel); else exit(1); char _modelfile[BUFFER_SIZE]; //sprintf(_modelfile, "%s_%d", modelfile, struct_parm.ssvm_iteration); sprintf(_modelfile, "%s_%d", modelfile, struct_parm.ssvm_iteration); printf("[Main] Writing learned model to %s\n", _modelfile); write_struct_model(_modelfile, &structmodel, &struct_parm); // Run inference on training data and increase loss for misclassified points printf("[Main] Loading learned model to %s\n", _modelfile); EnergyParam param(_modelfile); updateCoeffs(sample, param, struct_parm, updateType, alphas, alphas_edge, struct_parm.ssvm_iteration + 1); ofstream ofs("alphas.txt", ios::app); int i = 0; for(; i < sample.n - 1; ++i) { ofs << alphas[i][alphas_idx] << " "; } if(i < sample.n) { ofs << alphas[i][alphas_idx]; } ofs << endl; ofs.close(); ofstream ofse("alphas_edge.txt", ios::app); i = 0; for(; i < sample.n - 1; ++i) { ofse << alphas_edge[i][alphas_idx] << " "; } if(i < sample.n) { ofse << alphas_edge[i][alphas_idx]; } ofse << endl; ofse.close(); ++alphas_idx; printf("-----------------------------------------SSVM-ITERATION-%d-END\n", struct_parm.ssvm_iteration); ++struct_parm.ssvm_iteration; bIterate = (nMaxIterations == -1 || struct_parm.ssvm_iteration < nMaxIterations); } while(bIterate); // Output final segmentation for all examples long nExamples = sample.n; int nRFs = struct_parm.ssvm_iteration; double* lossPerLabel = 0; labelType* groundTruthLabels = 0; for(int i = 0; i < nExamples; i++) { /*** example loop ***/ Slice_P* slice = sample.examples[i].x.slice; Feature* feature = sample.examples[i].x.feature; //map<sidType, nodeCoeffType>* nodeCoeffs = sample.examples[i].x.nodeCoeffs; //map<sidType, edgeCoeffType>* edgeCoeffs = sample.examples[i].x.edgeCoeffs; map<sidType, nodeCoeffType>* nodeCoeffs = 0; map<sidType, edgeCoeffType>* edgeCoeffs = 0; int nNodes = slice->getNbSupernodes(); stringstream soutPb; soutPb << loss_dir; soutPb << "pb_"; soutPb << getNameFromPathWithoutExtension(slice->getName()); soutPb << "_"; soutPb << "combined"; //soutPb << setw(5) << setfill('0') << ssvm_iteration; soutPb << ".tif"; printf("[Main] Exporting %s\n", soutPb.str().c_str()); labelType* inferredLabels = computeCombinedLabels(slice, feature, groundTruthLabels, lossPerLabel, nRFs, alphas, i, nodeCoeffs, edgeCoeffs, soutPb.str().c_str()); stringstream sout; sout << loss_dir; sout << getNameFromPathWithoutExtension(slice->getName()); sout << "_"; sout << "combined"; //sout << setw(5) << setfill('0') << ssvm_iteration; sout << ".tif"; printf("[Main] Exporting %s\n", sout.str().c_str()); slice->exportOverlay(sout.str().c_str(), inferredLabels); stringstream soutBW; soutBW << loss_dir; soutBW << getNameFromPathWithoutExtension(slice->getName()); soutBW << "_"; soutBW << "combined"; //soutBW << setw(5) << setfill('0') << ssvm_iteration; soutBW << "_BW.tif"; printf("[Main] Exporting %s\n", soutBW.str().c_str()); slice->exportSupernodeLabels(soutBW.str().c_str(), struct_parm.nClasses, inferredLabels, nNodes, &(struct_parm.labelToClassIdx)); delete[] inferredLabels; } free_struct_sample(sample); free_struct_model(structmodel); svm_struct_learn_api_exit(); return 0; }
void mexFunction (int nout, mxArray ** out, int nin, mxArray const ** in) { SAMPLE sample; /* training sample */ LEARN_PARM learn_parm; KERNEL_PARM kernel_parm; STRUCT_LEARN_PARM struct_parm; STRUCTMODEL structmodel; int alg_type; enum {IN_ARGS=0, IN_SPARM} ; enum {OUT_W=0} ; /* SVM-light is not fully reentrant, so we need to run this patch first */ init_qp_solver() ; verbosity = 0 ; kernel_cache_statistic = 0 ; if (nin != 2) { mexErrMsgTxt("Two arguments required") ; } /* Parse ARGS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */ char arg [1024 + 1] ; int argc ; char ** argv ; if (! uIsString(in[IN_ARGS], -1)) { mexErrMsgTxt("ARGS must be a string") ; } mxGetString(in[IN_ARGS], arg, sizeof(arg) / sizeof(char)) ; arg_split (arg, &argc, &argv) ; svm_struct_learn_api_init(argc+1, argv-1) ; read_input_parameters (argc+1,argv-1, &verbosity, &struct_verbosity, &struct_parm, &learn_parm, &kernel_parm, &alg_type ) ; if (kernel_parm.kernel_type != LINEAR && kernel_parm.kernel_type != CUSTOM) { mexErrMsgTxt ("Only LINEAR or CUSTOM kerneles are supported") ; } /* Parse SPARM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */ mxArray const * sparm_array = in [IN_SPARM] ; mxArray const * patterns_array ; mxArray const * labels_array ; mxArray const * kernelFn_array ; int numExamples, ei ; if (! sparm_array) { mexErrMsgTxt("SPARM must be a structure") ; } struct_parm.mex = sparm_array ; patterns_array = mxGetField(sparm_array, 0, "patterns") ; if (! patterns_array || ! mxIsCell(patterns_array)) { mexErrMsgTxt("SPARM.PATTERNS must be a cell array") ; } numExamples = mxGetNumberOfElements(patterns_array) ; labels_array = mxGetField(sparm_array, 0, "labels") ; if (! labels_array || ! mxIsCell(labels_array) || ! mxGetNumberOfElements(labels_array) == numExamples) { mexErrMsgTxt("SPARM.LABELS must be a cell array " "with the same number of elements of " "SPARM.PATTERNS") ; } sample.n = numExamples ; sample.examples = (EXAMPLE *) my_malloc (sizeof(EXAMPLE) * numExamples) ; for (ei = 0 ; ei < numExamples ; ++ ei) { sample.examples[ei].x.mex = mxGetCell(patterns_array, ei) ; sample.examples[ei].y.mex = mxGetCell(labels_array, ei) ; sample.examples[ei].y.isOwner = 0 ; } if (struct_verbosity >= 1) { mexPrintf("There are %d training examples\n", numExamples) ; } kernelFn_array = mxGetField(sparm_array, 0, "kernelFn") ; if (! kernelFn_array && kernel_parm.kernel_type == CUSTOM) { mexErrMsgTxt("SPARM.KERNELFN must be define for CUSTOM kernels") ; } if (kernelFn_array) { MexKernelInfo * info ; if (mxGetClassID(kernelFn_array) != mxFUNCTION_CLASS) { mexErrMsgTxt("SPARM.KERNELFN must be a valid function handle") ; } info = (MexKernelInfo*) kernel_parm.custom ; info -> structParm = sparm_array ; info -> kernelFn = kernelFn_array ; } /* Learning ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */ switch (alg_type) { case 0: svm_learn_struct(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,NSLACK_ALG) ; break ; case 1: svm_learn_struct(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,NSLACK_SHRINK_ALG); break ; case 2: svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_PRIMAL_ALG); break ; case 3: svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_DUAL_ALG); break ; case 4: svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_DUAL_CACHE_ALG); break ; case 9: svm_learn_struct_joint_custom(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel); break ; default: mexErrMsgTxt("Unknown algorithm type") ; } /* Write output ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */ /* Warning: The model contains references to the original data 'docs'. If you want to free the original data, and only keep the model, you have to make a deep copy of 'model'. */ mxArray * model_array = newMxArrayEncapsulatingSmodel (&structmodel) ; out[OUT_W] = mxDuplicateArray (model_array) ; destroyMxArrayEncapsulatingSmodel (model_array) ; free_struct_sample (sample) ; free_struct_model (structmodel) ; svm_struct_learn_api_exit () ; free_qp_solver () ; }