コード例 #1
0
/**
 * Two-state HM-SVM example: simulate sequence data with CTwoStateModel,
 * train a CPrimalMosekSOSVM on it and apply the trained machine back to
 * the training examples.
 *
 * Everything between shogun's init and exit is compiled only when
 * USE_MOSEK is defined; without it the program merely initialises and
 * shuts down shogun.
 */
int main(int argc, char ** argv)
{
	init_shogun_with_defaults();
#ifdef USE_MOSEK

	// Parameters of the simulated data set.
	const int32_t n_examples = 10;
	const int32_t seq_len = 250;
	const int32_t n_feats = 10;
	const int32_t n_noise_feats = 2;
	CHMSVMModel* hm_model = CTwoStateModel::simulate_data(
			n_examples, seq_len, n_feats, n_noise_feats);

	// These getters increase the reference count; both are released below.
	CStructuredLabels* truth = hm_model->get_labels();
	CFeatures* feats = hm_model->get_features();

	CPrimalMosekSOSVM* solver = new CPrimalMosekSOSVM(hm_model, truth);
	SG_REF(solver);

	solver->train();
	// To inspect the learned weights: solver->get_w().display_vector("w");

	CStructuredLabels* predictions = CLabelsFactory::to_structured(solver->apply());

	// One prediction is expected per training example.
	ASSERT( predictions->get_num_labels() == truth->get_num_labels() );

	// Fetch each predicted / true sequence pair and release it again.
	// NOTE(review): the pairs are never compared, so this loop currently
	// only exercises label access.
	int32_t k = 0;
	while ( k < predictions->get_num_labels() )
	{
		CSequence* predicted_seq = CSequence::obtain_from_generic( predictions->get_label(k) );
		CSequence* expected_seq = CSequence::obtain_from_generic( truth->get_label(k) );
		SG_UNREF(predicted_seq);
		SG_UNREF(expected_seq);
		++k;
	}

	SG_UNREF(predictions);
	SG_UNREF(feats);
	SG_UNREF(truth);
	SG_UNREF(solver);

#endif /* USE_MOSEK */
	exit_shogun();

	return 0;
}
コード例 #2
0
ファイル: so_multiclass.cpp プロジェクト: AlexBinder/shogun
int main(int argc, char ** argv)
{
	init_shogun_with_defaults();
	
	SGVector< float64_t > labs(NUM_CLASSES*NUM_SAMPLES);
	SGMatrix< float64_t > feats(DIMS, NUM_CLASSES*NUM_SAMPLES);

	gen_rand_data(labs, feats);
	//read_data(labs, feats);

	// Create train labels
	CMulticlassSOLabels* labels = new CMulticlassSOLabels(labs);
	CMulticlassLabels*  mlabels = new CMulticlassLabels(labs);

	// Create train features
	CDenseFeatures< float64_t >* features = new CDenseFeatures< float64_t >(feats);

	// Create structured model
	CMulticlassModel* model = new CMulticlassModel(features, labels);

	// Create loss function
	CHingeLoss* loss = new CHingeLoss();

	// Create SO-SVM
	CPrimalMosekSOSVM* sosvm = new CPrimalMosekSOSVM(model, loss, labels);
	CDualLibQPBMSOSVM* bundle = new CDualLibQPBMSOSVM(model, loss, labels, 1000);
	bundle->set_verbose(false);
	SG_REF(sosvm);
	SG_REF(bundle);

	CTime start;
	float64_t t1;
	sosvm->train();
	SG_SPRINT(">>>> PrimalMosekSOSVM trained in %9.4f\n", (t1 = start.cur_time_diff(false)));
	bundle->train();
	SG_SPRINT(">>>> BMRM trained in %9.4f\n", start.cur_time_diff(false)-t1);
	CStructuredLabels* out = CStructuredLabels::obtain_from_generic(sosvm->apply());
	CStructuredLabels* bout = CStructuredLabels::obtain_from_generic(bundle->apply());

	// Create liblinear svm classifier with L2-regularized L2-loss
	CLibLinear* svm = new CLibLinear(L2R_L2LOSS_SVC);

	// Add some configuration to the svm
	svm->set_epsilon(EPSILON);
	svm->set_bias_enabled(false);

	// Create a multiclass svm classifier that consists of several of the previous one
	CLinearMulticlassMachine* mc_svm = 
			new CLinearMulticlassMachine( new CMulticlassOneVsRestStrategy(), 
			(CDotFeatures*) features, svm, mlabels);
	SG_REF(mc_svm);

	// Train the multiclass machine using the data passed in the constructor
	mc_svm->train();
	CMulticlassLabels* mout = CMulticlassLabels::obtain_from_generic(mc_svm->apply());

	SGVector< float64_t > w = sosvm->get_w();
	for ( int32_t i = 0 ; i < w.vlen ; ++i )
		SG_SPRINT("%10f ", w[i]);
	SG_SPRINT("\n\n");

	for ( int32_t i = 0 ; i < NUM_CLASSES ; ++i )
	{
		CLinearMachine* lm = (CLinearMachine*) mc_svm->get_machine(i);
		SGVector< float64_t > mw = lm->get_w();
		for ( int32_t j = 0 ; j < mw.vlen ; ++j )
			SG_SPRINT("%10f ", mw[j]);

		SG_UNREF(lm); // because of CLinearMulticlassMachine::get_machine()
	}
	SG_SPRINT("\n");

	CStructuredAccuracy* structured_evaluator = new CStructuredAccuracy();
	CMulticlassAccuracy* multiclass_evaluator = new CMulticlassAccuracy();
	SG_REF(structured_evaluator);
	SG_REF(multiclass_evaluator);

	SG_SPRINT("SO-SVM: %5.2f%\n", 100.0*structured_evaluator->evaluate(out, labels));
	SG_SPRINT("BMRM:   %5.2f%\n", 100.0*structured_evaluator->evaluate(bout, labels));
	SG_SPRINT("MC:     %5.2f%\n", 100.0*multiclass_evaluator->evaluate(mout, mlabels));

	// Free memory
	SG_UNREF(multiclass_evaluator);
	SG_UNREF(structured_evaluator);
	SG_UNREF(mout);
	SG_UNREF(mc_svm);
	SG_UNREF(bundle);
	SG_UNREF(sosvm);
	SG_UNREF(bout);
	SG_UNREF(out);
	exit_shogun();

	return 0;
}