void Train(char *dataType,char *disfun,char *pivotSelectionMethod,int numPivot,char *classifyMethod,char *trainDataFileName,int initialSize,int dim,char *pivotsAndTrainModelFileName,int coordinate) { #ifdef __GNUC__ char *newTrainDataFileName="../../../SourceFiles/util/data/"; #endif #ifdef _WIN32 char *newTrainDataFileName="./SourceFiles/util/data/"; #endif joinCharArray(newTrainDataFileName,trainDataFileName); #ifdef __GNUC__ char *newPivotsAndTrainModelFileName="../../../SourceFiles/util/result/"; #endif #ifdef _WIN32 char *newPivotsAndTrainModelFileName="./SourceFiles/util/result/"; #endif joinCharArray(newPivotsAndTrainModelFileName,pivotsAndTrainModelFileName); vector<shared_ptr<CMetricData> > *traindata=0; //store the traindata vector<string> trainDataLabel; if(strcmp(dataType,"vector")==0) { traindata=CDoubleVector::loadData(newTrainDataFileName,initialSize,dim); CReadLabel readTrainDataLabel; trainDataLabel = readTrainDataLabel.loadLabel(newTrainDataFileName,initialSize,dataType); } else if(strcmp(dataType,"string")==0) { traindata=CStringObject::loadData(newTrainDataFileName,initialSize); CReadLabel readTrainDataLabel; trainDataLabel = readTrainDataLabel.loadLabel(newTrainDataFileName,initialSize,dataType); } else if(strcmp(dataType,"dna")==0) { traindata=CDNA_CLASSIFY::loadData(newTrainDataFileName,initialSize,dim); CReadLabel readTrainDataLabel; trainDataLabel = readTrainDataLabel.loadLabel(newTrainDataFileName,initialSize/57,dataType); } else if(strcmp(dataType,"time_series")==0) { traindata=CTimeSeries::loadData(newTrainDataFileName,initialSize,dim); CReadLabel readTrainDataLabel; trainDataLabel = readTrainDataLabel.loadLabel(newTrainDataFileName,initialSize,dataType); } else if(strcmp(dataType,"image")==0) { traindata=CImage::loadData(newTrainDataFileName,initialSize,dim); } else if(strcmp(dataType,"peptide")==0) { traindata=CPeptide::loadData(newTrainDataFileName,initialSize,dim); } CMetricDistance *metric=0; if(strcmp(disfun,"EuclideanDistance")==0) { metric = new CEuclideanDistance; } else if(strcmp(disfun,"EditDistance")==0) { metric = new CEditDistance; } else if(strcmp(disfun,"TimeSeriesMetric")==0) { metric = new CTimeSeriesMetric; } else if(strcmp(disfun,"ImageMetric")==0) { metric = new CImageMetric; } else if(strcmp(disfun,"DNAMetric")==0) { metric = new CDNAMetric; } else if(strcmp(disfun,"PeptideMetric")==0) { metric = new CPeptideMetric; } else if(strcmp(disfun,"LInfinityDistance")==0) { metric = new CLInfinityDistance; } CountedMetric *cmetric=new CountedMetric(metric); CPivotSelectionMethod *pivotselectionmethod=0; if(strcmp(pivotSelectionMethod,"fft")==0) { pivotselectionmethod = new CFFTPivotSelectionMethod; } else if(strcmp(pivotSelectionMethod,"pcaonfft")==0) { pivotselectionmethod = new CPCAOnFFT(2); } else if(strcmp(pivotSelectionMethod,"incremental")==0) { pivotselectionmethod = new CIncrementalSelection(1,2); } if(numPivot > initialSize) { cout << "The pivot number larger than the traindata size. Please reset the number of pivot !" << endl; exit(0); } if(strcmp(classifyMethod,"knn")==0) { CTrain_Knn train_knn; train_knn.TrainModel(classifyMethod,traindata,trainDataLabel,cmetric,pivotselectionmethod,numPivot,newPivotsAndTrainModelFileName,dim,coordinate); } else if(strcmp(classifyMethod,"knnCrossValid")==0) { CTrain_Knn train_knn; train_knn.TrainModelUseCrossValidation(classifyMethod,traindata,trainDataLabel,cmetric,pivotselectionmethod,numPivot,newPivotsAndTrainModelFileName,dim,coordinate); } else if(strcmp(classifyMethod,"naviebayes")==0) { CTrain_NavieBayes train_naviebayes; train_naviebayes.TrainModel(classifyMethod,traindata,trainDataLabel,cmetric,pivotselectionmethod,numPivot,newPivotsAndTrainModelFileName,newTrainDataFileName,dim,coordinate); } else if(strcmp(classifyMethod,"c4.5")==0) { C4_5 c45; c45.TrainModel(classifyMethod,traindata,trainDataLabel,cmetric,pivotselectionmethod,numPivot,newPivotsAndTrainModelFileName,newTrainDataFileName,dim,coordinate); } else if(strcmp(classifyMethod,"svm")==0) { CDatasetInMetricSpace getTrainData; GetMetricData M_traindata; M_traindata=getTrainData.getMetricTrainData(classifyMethod,traindata,trainDataLabel,cmetric,pivotselectionmethod,numPivot,newPivotsAndTrainModelFileName,dim,coordinate); CTrain_SVM svm(M_traindata); svm.TrainModel(newPivotsAndTrainModelFileName); } }