// Exemplo n.º 1 (Example no. 1)
int main(int argc, char **argv)
{

	float priors[] = { 1.0, 10.0 };	// Edible vs poisonos weights

	CvMat *var_type;
	CvMat *data;				// jmh add
	data = cvCreateMat(20, 30, CV_8U);	// jmh add

	var_type = cvCreateMat(data->cols + 1, 1, CV_8U);
	cvSet(var_type, cvScalarAll(CV_VAR_CATEGORICAL));	// all these vars 
	// are categorical
	CvDTree *dtree;
	dtree = new CvDTree;
	dtree->train(data, CV_ROW_SAMPLE, responses, 0, 0, var_type, missing, CvDTreeParams(8,	// max depth
																						10,	// min sample count
																						0,	// regression accuracy: N/A here
																						true,	// compute surrogate split, 
																						//   as we have missing data
																						15,	// max number of categories 
																						//   (use sub-optimal algorithm for
																						//   larger numbers)
																						10,	// cross-validations 
																						true,	// use 1SE rule => smaller tree
																						true,	// throw away the pruned tree branches
																						priors	// the array of priors, the bigger 
																						//   p_weight, the more attention
																						//   to the poisonous mushrooms
				 )
		);

	dtree->save("tree.xml", "MyTree");
	dtree->clear();
	dtree->load("tree.xml", "MyTree");

#define MAX_CLUSTERS 5
	CvScalar color_tab[MAX_CLUSTERS];
	IplImage *img = cvCreateImage(cvSize(500, 500), 8, 3);
	CvRNG rng = cvRNG(0xffffffff);

	color_tab[0] = CV_RGB(255, 0, 0);
	color_tab[1] = CV_RGB(0, 255, 0);
	color_tab[2] = CV_RGB(100, 100, 255);
	color_tab[3] = CV_RGB(255, 0, 255);
	color_tab[4] = CV_RGB(255, 255, 0);

	cvNamedWindow("clusters", 1);

	for (;;) {
		int k, cluster_count = cvRandInt(&rng) % MAX_CLUSTERS + 1;
		int i, sample_count = cvRandInt(&rng) % 1000 + 1;
		CvMat *points = cvCreateMat(sample_count, 1, CV_32FC2);
		CvMat *clusters = cvCreateMat(sample_count, 1, CV_32SC1);

		/* generate random sample from multivariate 
		   Gaussian distribution */
		for (k = 0; k < cluster_count; k++) {
			CvPoint center;
			CvMat point_chunk;
			center.x = cvRandInt(&rng) % img->width;
			center.y = cvRandInt(&rng) % img->height;
			cvGetRows(points, &point_chunk,
					  k * sample_count / cluster_count,
					  k == cluster_count - 1 ? sample_count :
					  (k + 1) * sample_count / cluster_count);
			cvRandArr(&rng, &point_chunk, CV_RAND_NORMAL,
					  cvScalar(center.x, center.y, 0, 0),
					  cvScalar(img->width / 6, img->height / 6, 0, 0));
		}

		/* shuffle samples */
		for (i = 0; i < sample_count / 2; i++) {
			CvPoint2D32f *pt1 = (CvPoint2D32f *) points->data.fl +
				cvRandInt(&rng) % sample_count;
			CvPoint2D32f *pt2 = (CvPoint2D32f *) points->data.fl +
				cvRandInt(&rng) % sample_count;
			CvPoint2D32f temp;
			CV_SWAP(*pt1, *pt2, temp);
		}

		cvKMeans2(points, cluster_count, clusters,
				  cvTermCriteria(CV_TERMCRIT_EPS + CV_TERMCRIT_ITER, 10, 1.0));
		cvZero(img);
		for (i = 0; i < sample_count; i++) {
			CvPoint2D32f pt = ((CvPoint2D32f *) points->data.fl)[i];
			int cluster_idx = clusters->data.i[i];
			cvCircle(img, cvPointFrom32f(pt), 2,
					 color_tab[cluster_idx], CV_FILLED);
		}

		cvReleaseMat(&points);
		cvReleaseMat(&clusters);

		cvShowImage("clusters", img);

		int key = cvWaitKey(0);
		if (key == 27)			// 'ESC'
			break;
	}
}
int main( int argc, char** argv )
{
	Mat img;
   char file[255];
	
	//total no of training samples
	int total_train_samples = 0;
	for(int cl=0; cl<nr_classes; cl++)
	{
		total_train_samples = total_train_samples + train_samples[cl];
	}
	
	// Training Data
	Mat training_data = Mat(total_train_samples,feature_size,CV_32FC1);
	Mat training_label = Mat(total_train_samples,1,CV_32FC1);
	// training data .csv file
	ofstream trainingDataCSV;
	trainingDataCSV.open("./training_data.csv");	
		
	int index = 0;	
	for(int cl=0; cl<nr_classes; cl++)
	{
      for(int ll=0; ll<train_samples[cl]; ll++)
      {
      	//assign sample label
			training_label.at<float>(index+ll,0) = class_labels[cl]; 	
			//image feature extraction
 			sprintf(file, "%s/%d/%d.png", pathToImages, class_labels[cl], ll);
         img = imread(file, 1);
         if (!img.data)
         {
             cout << "File " << file << " not found\n";
             exit(1);
         }
         imshow("sample",img);
         waitKey(1);
         //calculate feature vector
			vector<float> feature = ColorHistFeature(img);
			for(int ft=0; ft<feature.size(); ft++)
			{
				training_data.at<float>(index+ll,ft) = feature[ft];
				trainingDataCSV<<feature[ft]<<",";
			}
			trainingDataCSV<<class_labels[cl]<<"\n";
		}
		index = index + train_samples[cl];
	}	
	
	trainingDataCSV.close();

	/// Decision Tree
	// Training
	float *priors = NULL;
	CvDTreeParams DTParams = CvDTreeParams(25, // max depth
		                                    5, // min sample count
		                                    0, // regression accuracy: N/A here
		                                    false, // compute surrogate split, no missing data
		                                    15, // max number of categories (use sub-optimal algorithm for larger numbers)
		                                    15, // the number of cross-validation folds
		                                    false, // use 1SE rule => smaller tree
		                                    false, // throw away the pruned tree branches
		                                    priors // the array of priors
		                                   );
	CvDTree DTree;
	DTree.train(training_data,CV_ROW_SAMPLE,training_label,Mat(),Mat(),Mat(),Mat(),DTParams);
			
	// save model
	DTree.save("training.model");		
	
	// load model
	CvDTree DT;
	DT.load("training.model");	
	
	// test on sample image
	string filename = string(pathToImages)+"/test.png";
	Mat test_img = imread(filename.c_str());
	vector<float> test_feature = ColorHistFeature(test_img);
	CvDTreeNode* result_node = DT.predict(Mat(test_feature),Mat(),false);
	double predictedClass = result_node->value;
	cout<<"predictedClass "<<predictedClass<<"\n";

/*	
	//CvMLData for calculating error
	CvMLData* MLData;
	MLData = new CvMLData();
	MLData->read_csv("training_data.csv");
	MLData->set_response_idx(feature_size);
//	MLData->change_var_type(feature_size,CV_VAR_CATEGORICAL);
	
	// calculate training error
	float error = DT.calc_error(MLData,CV_TRAIN_ERROR,0);
	cout<<"training error "<<error<<"\n";
*/
	return 0;
}