コード例 #1
0
ファイル: Test_Knn.cpp プロジェクト: geekzph/UMAD
/**
 *@ test training model of classification algorithm.
 *@ param classifyMethod  classifymethod: "knn" , "naviebayes".
 *@ param testdata: the primary trainging data.
 *@ param testDataLabel: the class label of primary training data.
 *@ param metric: distance function.
 *@ param pivotsAndTrainModelFileName: the file to store the selected pivots information and training model.
 *@ param testModelFileName: the file to store the training model.
 *@ param status: determine the source of the test data. 0: TestDataFromTrainData; 1: TestDataFromTestData.
 *@ param k: the number of nearest neighbors
*/
void CTest_Knn::TestModel(char *classifyMethod,vector<shared_ptr<CMetricData> > *testdata,vector<string> testDataLabel, CMetricDistance *metric,char *pivotsAndTrainModelFileName,char *testModelFileName,int status,int k,int splitRatio)
{
	CDatasetInMetricSpace getTestData;
	GetMetricData M_testdata;
	if(status==0)
	{
		M_testdata=getTestData.getMetricTestData_fromTrainData(classifyMethod,testdata,testDataLabel,metric,pivotsAndTrainModelFileName,splitRatio);
		pivotNumber=getTestData.pivotsNum;
	}
	else if(status==1)
	{
		M_testdata=getTestData.getMetricTestData_fromTestData(classifyMethod,testdata,testDataLabel,metric,pivotsAndTrainModelFileName);
		pivotNumber=getTestData.pivotsNum; 
	}
	showClassiciationResult(pivotsAndTrainModelFileName,M_testdata,testModelFileName,pivotNumber,k);
}
コード例 #2
0
ファイル: classification.cpp プロジェクト: ruimao/UMAD
void Train(char *dataType,char *disfun,char *pivotSelectionMethod,int numPivot,char *classifyMethod,char *trainDataFileName,int initialSize,int dim,char *pivotsAndTrainModelFileName,int coordinate)
{
#ifdef __GNUC__
	char *newTrainDataFileName="../../../SourceFiles/util/data/";
#endif

#ifdef _WIN32
	char *newTrainDataFileName="./SourceFiles/util/data/";
#endif
	joinCharArray(newTrainDataFileName,trainDataFileName);

#ifdef __GNUC__
	char *newPivotsAndTrainModelFileName="../../../SourceFiles/util/result/";
#endif

#ifdef _WIN32
	char *newPivotsAndTrainModelFileName="./SourceFiles/util/result/";
#endif
	joinCharArray(newPivotsAndTrainModelFileName,pivotsAndTrainModelFileName);

	vector<shared_ptr<CMetricData> > *traindata=0; //store the traindata
	vector<string> trainDataLabel;

	if(strcmp(dataType,"vector")==0)
	{
		traindata=CDoubleVector::loadData(newTrainDataFileName,initialSize,dim);

		CReadLabel readTrainDataLabel;
		trainDataLabel = readTrainDataLabel.loadLabel(newTrainDataFileName,initialSize,dataType);
	}

	else if(strcmp(dataType,"string")==0)
	{
		traindata=CStringObject::loadData(newTrainDataFileName,initialSize);

		CReadLabel readTrainDataLabel;
		trainDataLabel = readTrainDataLabel.loadLabel(newTrainDataFileName,initialSize,dataType);
	}
	else if(strcmp(dataType,"dna")==0)
	{
		traindata=CDNA_CLASSIFY::loadData(newTrainDataFileName,initialSize,dim);

		CReadLabel readTrainDataLabel;
		trainDataLabel = readTrainDataLabel.loadLabel(newTrainDataFileName,initialSize/57,dataType);
	}

	else if(strcmp(dataType,"time_series")==0)
	{
		traindata=CTimeSeries::loadData(newTrainDataFileName,initialSize,dim);

		CReadLabel readTrainDataLabel;
		trainDataLabel = readTrainDataLabel.loadLabel(newTrainDataFileName,initialSize,dataType);
	}

	else if(strcmp(dataType,"image")==0)
	{
		traindata=CImage::loadData(newTrainDataFileName,initialSize,dim);
	}
	
	else if(strcmp(dataType,"peptide")==0)
	{
		traindata=CPeptide::loadData(newTrainDataFileName,initialSize,dim);
	}

	CMetricDistance *metric=0;
	if(strcmp(disfun,"EuclideanDistance")==0)
	{
		metric = new CEuclideanDistance;
	}
	else if(strcmp(disfun,"EditDistance")==0)
	{
		metric = new CEditDistance;
	}
	else if(strcmp(disfun,"TimeSeriesMetric")==0)
	{
		metric = new CTimeSeriesMetric;
	}
	else if(strcmp(disfun,"ImageMetric")==0)
	{
		metric = new CImageMetric;
	}
	else if(strcmp(disfun,"DNAMetric")==0)
	{
		metric = new CDNAMetric;
	}
	else if(strcmp(disfun,"PeptideMetric")==0)
	{
		metric = new CPeptideMetric;
	}
	else if(strcmp(disfun,"LInfinityDistance")==0)
	{
		metric = new CLInfinityDistance;
	}

	CountedMetric *cmetric=new CountedMetric(metric);
	CPivotSelectionMethod *pivotselectionmethod=0;
	if(strcmp(pivotSelectionMethod,"fft")==0)
	{
		pivotselectionmethod = new CFFTPivotSelectionMethod;
	}
	else if(strcmp(pivotSelectionMethod,"pcaonfft")==0)
	{
		pivotselectionmethod = new CPCAOnFFT(2);
	}
	else if(strcmp(pivotSelectionMethod,"incremental")==0)
	{
		pivotselectionmethod = new CIncrementalSelection(1,2);
	}
	
	if(numPivot > initialSize)
	{
		cout << "The pivot number larger than the traindata size. Please reset the number of pivot !" << endl;
		exit(0);
	}

	if(strcmp(classifyMethod,"knn")==0)
	{
		CTrain_Knn train_knn;
		train_knn.TrainModel(classifyMethod,traindata,trainDataLabel,cmetric,pivotselectionmethod,numPivot,newPivotsAndTrainModelFileName,dim,coordinate);
	}
	else if(strcmp(classifyMethod,"knnCrossValid")==0)
	{
		CTrain_Knn train_knn;
		train_knn.TrainModelUseCrossValidation(classifyMethod,traindata,trainDataLabel,cmetric,pivotselectionmethod,numPivot,newPivotsAndTrainModelFileName,dim,coordinate);
	}
	else if(strcmp(classifyMethod,"naviebayes")==0)
	{
		CTrain_NavieBayes train_naviebayes;
		train_naviebayes.TrainModel(classifyMethod,traindata,trainDataLabel,cmetric,pivotselectionmethod,numPivot,newPivotsAndTrainModelFileName,newTrainDataFileName,dim,coordinate);
	}
	else if(strcmp(classifyMethod,"c4.5")==0)
	{
		C4_5 c45;
		c45.TrainModel(classifyMethod,traindata,trainDataLabel,cmetric,pivotselectionmethod,numPivot,newPivotsAndTrainModelFileName,newTrainDataFileName,dim,coordinate);
	}
	else if(strcmp(classifyMethod,"svm")==0)
	{
		CDatasetInMetricSpace getTrainData;
		GetMetricData M_traindata;
		M_traindata=getTrainData.getMetricTrainData(classifyMethod,traindata,trainDataLabel,cmetric,pivotselectionmethod,numPivot,newPivotsAndTrainModelFileName,dim,coordinate);
		CTrain_SVM svm(M_traindata);
		svm.TrainModel(newPivotsAndTrainModelFileName);
	}
}
コード例 #3
0
ファイル: classification.cpp プロジェクト: ruimao/UMAD
void Test(char *dataType,char *disfun,char *classifyMethod,char *testDataFileName,int finalSize,int dim,char *pivotsAndTrainModelFileName,char *testModelFileName,int status,int k,int splitRatio)
{
#ifdef __GNUC__
	char *newTestDataFileName="../../../SourceFiles/util/data/";
#endif

#ifdef  _WIN32
	char *newTestDataFileName="./SourceFiles/util/data/";
#endif
	joinCharArray(newTestDataFileName,testDataFileName);
	
#ifdef __GNUC__
	char *newPivotsAndTrainModelFileName="../../../SourceFiles/util/result/";
#endif

#ifdef _WIN32
	char *newPivotsAndTrainModelFileName="./SourceFiles/util/result/";
#endif
	joinCharArray(newPivotsAndTrainModelFileName,pivotsAndTrainModelFileName);

#ifdef __GNUC__
	char *newTestModelFileName="../../../SourceFiles/util/result/";
#endif

#ifdef _WIN32
	char *newTestModelFileName="./SourceFiles/util/result/";	
#endif
	joinCharArray(newTestModelFileName,testModelFileName);

	vector<shared_ptr<CMetricData> > *testdata=0; //store the testdata
	vector<string> testDataLabel;

	if(strcmp(dataType,"vector")==0)
	{
		testdata=CDoubleVector::loadData(newTestDataFileName,finalSize,dim);

		CReadLabel readTestDataLabel;
		testDataLabel=readTestDataLabel.loadLabel(newTestDataFileName,finalSize,dataType);
	}

	else if(strcmp(dataType,"string")==0)
	{
		testdata=CStringObject::loadData(newTestDataFileName,finalSize);

		CReadLabel readTestDataLabel;
		testDataLabel=readTestDataLabel.loadLabel(newTestDataFileName,finalSize,dataType);
	}
	else if(strcmp(dataType,"dna")==0)
	{
		testdata=CDNA_CLASSIFY::loadData(newTestDataFileName,finalSize,dim);

		CReadLabel readTestDataLabel;
		testDataLabel=readTestDataLabel.loadLabel(newTestDataFileName,finalSize/57,dataType);
	}

	else if(strcmp(dataType,"time_series")==0)
	{
		testdata=CTimeSeries::loadData(newTestDataFileName,finalSize,dim);

		CReadLabel readTestDataLabel;
		testDataLabel=readTestDataLabel.loadLabel(newTestDataFileName,finalSize,dataType);
	}

	else if(strcmp(dataType,"image")==0)
	{
		testdata=CImage::loadData(newTestDataFileName ,finalSize, dim);
	}
	
	else if(strcmp(dataType,"peptide")==0)
	{
		testdata=CPeptide::loadData(newTestDataFileName,finalSize,dim);
	}	

	CMetricDistance *metric=0;
	if(strcmp(disfun,"EuclideanDistance")==0)
	{
		metric = new CEuclideanDistance;
	}
	else if(strcmp(disfun,"EditDistance")==0)
	{
		metric = new CEditDistance;
	}
	else if(strcmp(disfun,"TimeSeriesMetric")==0)
	{
		metric = new CTimeSeriesMetric;
	}
	else if(strcmp(disfun,"ImageMetric")==0)
	{
		metric = new CImageMetric;
	}
	else if(strcmp(disfun,"DNAMetric")==0)
	{
		metric = new CDNAMetric;
	}
	else if(strcmp(disfun,"PeptideMetric")==0)
	{
		metric = new CPeptideMetric;
	}
	else if(strcmp(disfun,"LInfinityDistance")==0)
	{
		metric = new CLInfinityDistance;
	}

	CountedMetric *cmetric=new CountedMetric(metric);
	if(strcmp(classifyMethod,"knn")==0)
	{
		CTest_Knn test_knn;
		test_knn.TestModel(classifyMethod,testdata,testDataLabel,cmetric,newPivotsAndTrainModelFileName,newTestModelFileName,status,k,splitRatio);
	}
	else if(strcmp(classifyMethod,"naviebayes")==0)
	{
		CTest_NavieBayes test_naviebayes;
		test_naviebayes.TestModel(classifyMethod,testdata,testDataLabel,cmetric,newPivotsAndTrainModelFileName,newTestModelFileName,status,splitRatio);
	}
	else if(strcmp(classifyMethod,"c4.5")==0)
	{
		CTest_C4_5 test_c45;
		test_c45.TestModel(classifyMethod,testdata,testDataLabel,cmetric,newPivotsAndTrainModelFileName,newTestModelFileName,status,splitRatio);
	}
	else if(strcmp(classifyMethod,"svm")==0)
	{
		CDatasetInMetricSpace getTestData;
		GetMetricData M_testdata;
		if(status==0)
		{
			M_testdata=getTestData.getMetricTestData_fromTrainData(classifyMethod,testdata,testDataLabel,cmetric,newPivotsAndTrainModelFileName,splitRatio);
			printf("\n\nEvaluation on training data (%d items):\n", M_testdata.metricData.size());
		}
		else if(status==1)
		{
			M_testdata=getTestData.getMetricTestData_fromTestData(classifyMethod,testdata,testDataLabel,cmetric,newPivotsAndTrainModelFileName);
			printf("\nEvaluation on test data (%d items):\n", M_testdata.metricData.size());
		}
		CTest_SVM svm(M_testdata);
		svm.TestModel(newPivotsAndTrainModelFileName,newTestModelFileName);
	}
}