//void DataSet_load_from_python(DataSet *dataset, float *y, char **x, int len) {dataset->load_from_python(y, x, len);} void thundersvm_train_sub(DataSet& train_dataset, CMDParser& parser, char* model_file_path){ SvmModel *model = nullptr; switch (parser.param_cmd.svm_type) { case SvmParam::C_SVC: model = new SVC(); break; case SvmParam::NU_SVC: model = new NuSVC(); break; case SvmParam::ONE_CLASS: model = new OneClassSVC(); break; case SvmParam::EPSILON_SVR: model = new SVR(); break; case SvmParam::NU_SVR: model = new NuSVR(); break; } //todo add this to check_parameter method if (parser.param_cmd.svm_type == SvmParam::NU_SVC) { train_dataset.group_classes(); for (int i = 0; i < train_dataset.n_classes(); ++i) { int n1 = train_dataset.count()[i]; for (int j = i + 1; j < train_dataset.n_classes(); ++j) { int n2 = train_dataset.count()[j]; if (parser.param_cmd.nu * (n1 + n2) / 2 > min(n1, n2)) { printf("specified nu is infeasible\n"); return; } } } } if (parser.param_cmd.kernel_type != SvmParam::LINEAR) if (!parser.gamma_set) { parser.param_cmd.gamma = 1.f / train_dataset.n_features(); } #ifdef USE_CUDA CUDA_CHECK(cudaSetDevice(parser.gpu_id)); #endif vector<float_type> predict_y, test_y; if (parser.do_cross_validation) { predict_y = model->cross_validation(train_dataset, parser.param_cmd, parser.nr_fold); } else { model->train(train_dataset, parser.param_cmd); model->save_to_file(model_file_path); LOG(INFO) << "evaluating training score"; predict_y = model->predict(train_dataset.instances(), -1); //predict_y = model->predict(train_dataset.instances(), 10000); //test_y = train_dataset.y(); } Metric *metric = nullptr; switch (parser.param_cmd.svm_type) { case SvmParam::C_SVC: case SvmParam::NU_SVC: { metric = new Accuracy(); break; } case SvmParam::EPSILON_SVR: case SvmParam::NU_SVR: { metric = new MSE(); break; } case SvmParam::ONE_CLASS: { } } if (metric) { LOG(INFO) << metric->name() << " = " << metric->score(predict_y, train_dataset.y()) << std::endl; } return; }
void thundersvm_train_matlab(int argc, char **argv) { CMDParser parser; parser.parse_command_line(argc, argv); /* parser.param_cmd.svm_type = SvmParam::NU_SVC; parser.param_cmd.kernel_type = SvmParam::RBF; parser.param_cmd.C = 100; parser.param_cmd.gamma = 0; parser.param_cmd.nu = 0.1; parser.param_cmd.epsilon = 0.001; */ DataSet train_dataset; char input_file_path[1024] = DATASET_DIR; char model_file_path[1024] = DATASET_DIR; strcat(input_file_path, parser.svmtrain_input_file_name); strcat(model_file_path, parser.model_file_name); train_dataset.load_from_file(input_file_path); SvmModel *model = nullptr; switch (parser.param_cmd.svm_type) { case SvmParam::C_SVC: model = new SVC(); break; case SvmParam::NU_SVC: model = new NuSVC(); break; case SvmParam::ONE_CLASS: model = new OneClassSVC(); break; case SvmParam::EPSILON_SVR: model = new SVR(); break; case SvmParam::NU_SVR: model = new NuSVR(); break; } //todo add this to check_parameter method if (parser.param_cmd.svm_type == SvmParam::NU_SVC) { train_dataset.group_classes(); for (int i = 0; i < train_dataset.n_classes(); ++i) { int n1 = train_dataset.count()[i]; for (int j = i + 1; j < train_dataset.n_classes(); ++j) { int n2 = train_dataset.count()[j]; if (parser.param_cmd.nu * (n1 + n2) / 2 > min(n1, n2)) { printf("specified nu is infeasible\n"); return; } } } } #ifdef USE_CUDA CUDA_CHECK(cudaSetDevice(parser.gpu_id)); #endif vector<float_type> predict_y, test_y; if (parser.do_cross_validation) { vector<float_type> test_predict = model->cross_validation(train_dataset, parser.param_cmd, parser.nr_fold); int dataset_size = test_predict.size() / 2; test_y.insert(test_y.end(), test_predict.begin(), test_predict.begin() + dataset_size); predict_y.insert(predict_y.end(), test_predict.begin() + dataset_size, test_predict.end()); } else { model->train(train_dataset, parser.param_cmd); model->save_to_file(model_file_path); //predict_y = model->predict(train_dataset.instances(), 10000); //test_y = train_dataset.y(); } /* //perform svm testing Metric *metric = nullptr; switch (parser.param_cmd.svm_type) { case SvmParam::C_SVC: case SvmParam::NU_SVC: { metric = new Accuracy(); break; } case SvmParam::EPSILON_SVR: case SvmParam::NU_SVR: { metric = new MSE(); break; } case SvmParam::ONE_CLASS: { } } if (metric) { LOG(INFO) << metric->name() << " = " << metric->score(predict_y, test_y); } */ return; }