コード例 #1
0
ファイル: classification.cpp プロジェクト: ruimao/UMAD
void Train(char *dataType,char *disfun,char *pivotSelectionMethod,int numPivot,char *classifyMethod,char *trainDataFileName,int initialSize,int dim,char *pivotsAndTrainModelFileName,int coordinate)
{
#ifdef __GNUC__
	char *newTrainDataFileName="../../../SourceFiles/util/data/";
#endif

#ifdef _WIN32
	char *newTrainDataFileName="./SourceFiles/util/data/";
#endif
	joinCharArray(newTrainDataFileName,trainDataFileName);

#ifdef __GNUC__
	char *newPivotsAndTrainModelFileName="../../../SourceFiles/util/result/";
#endif

#ifdef _WIN32
	char *newPivotsAndTrainModelFileName="./SourceFiles/util/result/";
#endif
	joinCharArray(newPivotsAndTrainModelFileName,pivotsAndTrainModelFileName);

	vector<shared_ptr<CMetricData> > *traindata=0; //store the traindata
	vector<string> trainDataLabel;

	if(strcmp(dataType,"vector")==0)
	{
		traindata=CDoubleVector::loadData(newTrainDataFileName,initialSize,dim);

		CReadLabel readTrainDataLabel;
		trainDataLabel = readTrainDataLabel.loadLabel(newTrainDataFileName,initialSize,dataType);
	}

	else if(strcmp(dataType,"string")==0)
	{
		traindata=CStringObject::loadData(newTrainDataFileName,initialSize);

		CReadLabel readTrainDataLabel;
		trainDataLabel = readTrainDataLabel.loadLabel(newTrainDataFileName,initialSize,dataType);
	}
	else if(strcmp(dataType,"dna")==0)
	{
		traindata=CDNA_CLASSIFY::loadData(newTrainDataFileName,initialSize,dim);

		CReadLabel readTrainDataLabel;
		trainDataLabel = readTrainDataLabel.loadLabel(newTrainDataFileName,initialSize/57,dataType);
	}

	else if(strcmp(dataType,"time_series")==0)
	{
		traindata=CTimeSeries::loadData(newTrainDataFileName,initialSize,dim);

		CReadLabel readTrainDataLabel;
		trainDataLabel = readTrainDataLabel.loadLabel(newTrainDataFileName,initialSize,dataType);
	}

	else if(strcmp(dataType,"image")==0)
	{
		traindata=CImage::loadData(newTrainDataFileName,initialSize,dim);
	}
	
	else if(strcmp(dataType,"peptide")==0)
	{
		traindata=CPeptide::loadData(newTrainDataFileName,initialSize,dim);
	}

	CMetricDistance *metric=0;
	if(strcmp(disfun,"EuclideanDistance")==0)
	{
		metric = new CEuclideanDistance;
	}
	else if(strcmp(disfun,"EditDistance")==0)
	{
		metric = new CEditDistance;
	}
	else if(strcmp(disfun,"TimeSeriesMetric")==0)
	{
		metric = new CTimeSeriesMetric;
	}
	else if(strcmp(disfun,"ImageMetric")==0)
	{
		metric = new CImageMetric;
	}
	else if(strcmp(disfun,"DNAMetric")==0)
	{
		metric = new CDNAMetric;
	}
	else if(strcmp(disfun,"PeptideMetric")==0)
	{
		metric = new CPeptideMetric;
	}
	else if(strcmp(disfun,"LInfinityDistance")==0)
	{
		metric = new CLInfinityDistance;
	}

	CountedMetric *cmetric=new CountedMetric(metric);
	CPivotSelectionMethod *pivotselectionmethod=0;
	if(strcmp(pivotSelectionMethod,"fft")==0)
	{
		pivotselectionmethod = new CFFTPivotSelectionMethod;
	}
	else if(strcmp(pivotSelectionMethod,"pcaonfft")==0)
	{
		pivotselectionmethod = new CPCAOnFFT(2);
	}
	else if(strcmp(pivotSelectionMethod,"incremental")==0)
	{
		pivotselectionmethod = new CIncrementalSelection(1,2);
	}
	
	if(numPivot > initialSize)
	{
		cout << "The pivot number larger than the traindata size. Please reset the number of pivot !" << endl;
		exit(0);
	}

	if(strcmp(classifyMethod,"knn")==0)
	{
		CTrain_Knn train_knn;
		train_knn.TrainModel(classifyMethod,traindata,trainDataLabel,cmetric,pivotselectionmethod,numPivot,newPivotsAndTrainModelFileName,dim,coordinate);
	}
	else if(strcmp(classifyMethod,"knnCrossValid")==0)
	{
		CTrain_Knn train_knn;
		train_knn.TrainModelUseCrossValidation(classifyMethod,traindata,trainDataLabel,cmetric,pivotselectionmethod,numPivot,newPivotsAndTrainModelFileName,dim,coordinate);
	}
	else if(strcmp(classifyMethod,"naviebayes")==0)
	{
		CTrain_NavieBayes train_naviebayes;
		train_naviebayes.TrainModel(classifyMethod,traindata,trainDataLabel,cmetric,pivotselectionmethod,numPivot,newPivotsAndTrainModelFileName,newTrainDataFileName,dim,coordinate);
	}
	else if(strcmp(classifyMethod,"c4.5")==0)
	{
		C4_5 c45;
		c45.TrainModel(classifyMethod,traindata,trainDataLabel,cmetric,pivotselectionmethod,numPivot,newPivotsAndTrainModelFileName,newTrainDataFileName,dim,coordinate);
	}
	else if(strcmp(classifyMethod,"svm")==0)
	{
		CDatasetInMetricSpace getTrainData;
		GetMetricData M_traindata;
		M_traindata=getTrainData.getMetricTrainData(classifyMethod,traindata,trainDataLabel,cmetric,pivotselectionmethod,numPivot,newPivotsAndTrainModelFileName,dim,coordinate);
		CTrain_SVM svm(M_traindata);
		svm.TrainModel(newPivotsAndTrainModelFileName);
	}
}