/*
 * Trains the random forest on the supplied classification data.
 *
 * Any previous model is cleared first. If useScaling is enabled, this function's
 * by-value copy of the data is scaled to [0, 1] before training. For that reason
 * the individual trees are trained with their own scaling disabled. Each of the
 * forestSize trees is trained on a bootstrapped resample of the data using the
 * BEST_RANDOM_SPLIT mode (numRandomSplits, minNumSamplesPerNode and maxDepth are
 * forwarded to the tree). It is then deep-copied into the forest.
 *
 * @param trainingData the labelled training data. It is taken by value because
 *        it may be rescaled in place.
 * @return true if every tree trained successfully. Returns false if the data has
 *         no samples, the forest size is zero, or any tree fails to train.
 */
bool RandomForests::train(LabelledClassificationData trainingData){

    //Clear any previous model
    clear();

    const unsigned int M = trainingData.getNumSamples();
    const unsigned int N = trainingData.getNumDimensions();
    const unsigned int K = trainingData.getNumClasses();

    if( M == 0 ){
        errorLog << "train(LabelledClassificationData trainingData) - Training data has zero samples!" << endl;
        return false;
    }

    //With zero trees the training loop below would never run, yet the model would still be flagged as trained
    if( forestSize == 0 ){
        errorLog << "train(LabelledClassificationData trainingData) - The forest size is zero! You need to set the forest size to a value greater than zero." << endl;
        return false;
    }

    numInputDimensions = N;
    numClasses = K;
    classLabels = trainingData.getClassLabels();
    ranges = trainingData.getRanges();

    //Scale the training data if needed
    if( useScaling ){
        //Scale the training data between 0 and 1
        trainingData.scale(0, 1);
    }

    //Train the random forest. The number of trees comes from the forestSize member;
    //it is deliberately NOT overwritten with a hard-coded value here.
    //NOTE(review): `random` is never used below. It is kept in case constructing it has
    //seeding side effects that the bootstrapping relies on -- confirm before removing.
    Random random;

    DecisionTree tree;
    tree.enableScaling( false ); //We have already scaled the training data so we do not need to scale it again
    tree.setTrainingMode( DecisionTree::BEST_RANDOM_SPLIT );
    tree.setNumSplittingSteps( numRandomSplits );
    tree.setMinNumSamplesPerNode( minNumSamplesPerNode );
    tree.setMaxDepth( maxDepth );

    for(UINT i=0; i<forestSize; i++){
        //Each tree is trained on its own bootstrapped resample of the training data
        LabelledClassificationData data = trainingData.getBootstrappedDataset();

        if( !tree.train( data ) ){
            errorLog << "train(LabelledClassificationData trainingData) - Failed to train tree at forest index: " << i << endl;
            return false;
        }

        //Deep copy the tree into the forest, because the same tree object is retrained on the next iteration
        forest.push_back( tree.deepCopyTree() );
    }

    //Flag that the algorithm has been trained
    trained = true;

    return trained;
}
/*
 * Trains every classifier held in the bagging ensemble.
 *
 * The trained flag and the class labels are reset before anything is validated.
 * The function then checks that the data has samples and that every ensemble
 * slot holds a classifier. Each member is trained on its own bootstrapped
 * resample of the data. The class labels are copied from the data only after
 * every member has trained successfully.
 *
 * @param trainingData the labelled data used to train the ensemble
 * @return true if every classifier in the ensemble trained, false otherwise
 */
bool BAG::train(LabelledClassificationData trainingData){
    const unsigned int sampleCount = trainingData.getNumSamples();
    const unsigned int dimensionCount = trainingData.getNumDimensions();
    const unsigned int classCount = trainingData.getNumClasses();

    //Reset the model state before validating anything
    trained = false;
    classLabels.clear();

    if( sampleCount == 0 ){
        errorLog << "train(LabelledClassificationData trainingData) - Training data has zero samples!" << endl;
        return false;
    }

    numFeatures = dimensionCount;
    numClasses = classCount;
    classLabels.resize( classCount );
    ranges = trainingData.getRanges();

    const UINT memberCount = static_cast<UINT>( ensemble.size() );
    if( memberCount == 0 ){
        errorLog << "train(LabelledClassificationData trainingData) - The ensemble size is zero! You need to add some classifiers to the ensemble first." << endl;
        return false;
    }

    //Every ensemble slot must hold a classifier before any training starts
    UINT memberIndex = 0;
    while( memberIndex < memberCount ){
        if( ensemble[memberIndex] == NULL ){
            errorLog << "train(LabelledClassificationData trainingData) - The classifier at ensemble index " << memberIndex << " has not been set!" << endl;
            return false;
        }
        ++memberIndex;
    }

    //Train each member on a different bootstrap resample of the training data
    for(memberIndex = 0; memberIndex < memberCount; ++memberIndex){
        LabelledClassificationData resample = trainingData.getBootstrappedDataset();
        const bool memberTrained = ensemble[memberIndex]->train( resample );
        if( !memberTrained ){
            errorLog << "train(LabelledClassificationData trainingData) - The classifier at ensemble index " << memberIndex << " failed training!" << endl;
            return false;
        }
    }

    //The class labels are only published once the whole ensemble has trained
    classLabels = trainingData.getClassLabels();

    trained = true;
    return trained;
}