/*
 Trains every classifier in the bagging ensemble, each on its own dataset obtained from
 trainingData.getBootstrappedDataset().

 @param trainingData: the labelled training data. If useScaling is enabled, scale(0, 1) is
        called on this object, so the caller's data is presumably rescaled in place.
 @return true if every classifier in the ensemble trained successfully, false otherwise
*/
bool BAG::train_(ClassificationData &trainingData){

    //Clear any previous models
    clear();

    const unsigned int M = trainingData.getNumSamples();
    const unsigned int N = trainingData.getNumDimensions();
    const unsigned int K = trainingData.getNumClasses();

    if( M == 0 ){
        errorLog << "train_(ClassificationData &trainingData) - Training data has zero samples!" << endl;
        return false;
    }

    //Validate the ensemble BEFORE setting any model state or scaling the training data,
    //so a configuration error leaves both this model and the caller's data untouched
    const UINT ensembleSize = (UINT)ensemble.size();

    if( ensembleSize == 0 ){
        errorLog << "train_(ClassificationData &trainingData) - The ensemble size is zero! You need to add some classifiers to the ensemble first." << endl;
        return false;
    }

    for(UINT i=0; i<ensembleSize; i++){
        if( ensemble[i] == NULL ){
            errorLog << "train_(ClassificationData &trainingData) - The classifier at ensemble index " << i << " has not been set!" << endl;
            return false;
        }
    }

    numInputDimensions = N;
    numClasses = K;
    //Store the ranges before scaling so new inputs can be scaled the same way later
    ranges = trainingData.getRanges();

    //Scale the training data if needed
    if( useScaling ){
        //Scale the training data between 0 and 1
        trainingData.scale(0, 1);
    }

    //Train each classifier in the ensemble on its own bootstrapped dataset
    for(UINT i=0; i<ensembleSize; i++){
        ClassificationData bootstrappedDataset = trainingData.getBootstrappedDataset();

        trainingLog << "Training ensemble " << i+1 << ". Ensemble type: " << ensemble[i]->getClassType() << endl;

        if( !ensemble[i]->train( bootstrappedDataset ) ){
            errorLog << "train_(ClassificationData &trainingData) - The classifier at ensemble index " << i << " failed training!" << endl;
            //Reset via clear() so a failed training does not leave a half-populated model
            //behind (the same recovery RandomForests::train_ uses)
            clear();
            return false;
        }
    }

    //Set the class labels
    classLabels = trainingData.getClassLabels();

    //Flag that the model has been trained
    trained = true;

    return trained;
}
/*
 Trains a random forest of forestSize decision trees. Each tree is trained on its own
 dataset obtained from trainingData.getBootstrappedDataset( datasetSize, true ), where
 datasetSize = numSamples * bootstrappedDatasetWeight.

 @param trainingData: the labelled training data. If useScaling is enabled, scale(0, 1) is
        called on this object, so the caller's data is presumably rescaled in place.
 @return true if every tree in the forest trained successfully, false otherwise (in which
         case the model is reset with clear())
*/
bool RandomForests::train_(ClassificationData &trainingData){

    //Clear any previous model
    clear();

    const unsigned int M = trainingData.getNumSamples();
    const unsigned int N = trainingData.getNumDimensions();
    const unsigned int K = trainingData.getNumClasses();

    if( M == 0 ){
        errorLog << "train_(ClassificationData &trainingData) - Training data has zero samples!" << endl;
        return false;
    }

    if( bootstrappedDatasetWeight <= 0.0 || bootstrappedDatasetWeight > 1.0 ){
        errorLog << "train_(ClassificationData &trainingData) - Bootstrapped Dataset Weight must be [> 0.0 and <= 1.0]" << endl;
        return false;
    }

    //An empty forest cannot make predictions, so do not report it as a trained model
    if( forestSize == 0 ){
        errorLog << "train_(ClassificationData &trainingData) - The forest size must be greater than zero!" << endl;
        return false;
    }

    //Every tree is configured from *decisionTreeNode below; dereferencing NULL would be undefined behaviour
    if( decisionTreeNode == NULL ){
        errorLog << "train_(ClassificationData &trainingData) - The decision tree node has not been set!" << endl;
        return false;
    }

    numInputDimensions = N;
    numClasses = K;
    classLabels = trainingData.getClassLabels();
    //Store the ranges before scaling so new inputs can be scaled the same way later
    ranges = trainingData.getRanges();

    //Scale the training data if needed
    if( useScaling ){
        //Scale the training data between 0 and 1
        trainingData.scale(0, 1);
    }

    //Flag that the main algorithm has been trained in case we need to trigger any callbacks
    trained = true;

    //The bootstrapped dataset size is the same for every tree, so compute it once
    const UINT datasetSize = (UINT)(trainingData.getNumSamples() * bootstrappedDatasetWeight);

    //Train the random forest
    forest.reserve( forestSize );
    for(UINT i=0; i<forestSize; i++){

        //Get a balanced bootstrapped dataset
        ClassificationData data = trainingData.getBootstrappedDataset( datasetSize, true );

        DecisionTree tree;
        tree.setDecisionTreeNode( *decisionTreeNode );
        tree.enableScaling( false ); //We have already scaled the training data so we do not need to scale it again
        tree.setTrainingMode( trainingMode );
        tree.setNumSplittingSteps( numRandomSplits );
        tree.setMinNumSamplesPerNode( minNumSamplesPerNode );
        tree.setMaxDepth( maxDepth );
        tree.enableNullRejection( useNullRejection );
        tree.setRemoveFeaturesAtEachSpilt( removeFeaturesAtEachSpilt );

        trainingLog << "Training forest " << i+1 << "/" << forestSize << "..." << endl;

        //Train this tree
        if( !tree.train( data ) ){
            errorLog << "train_(ClassificationData &trainingData) - Failed to train tree at forest index: " << i << endl;
            clear();
            return false;
        }

        //Deep copy the tree into the forest
        forest.push_back( tree.deepCopyTree() );
    }

    return true;
}
/*
 Trains a random forest of forestSize decision trees. Each tree is trained on its own
 dataset obtained from trainingData.getBootstrappedDataset( datasetSize, true ), where
 datasetSize = numSamples * bootstrappedDatasetWeight. If useValidationSet is enabled,
 the per-tree validation accuracy, precision and recall are averaged over the forest.

 @param trainingData: the labelled training data. If useScaling is enabled, scale(0, 1) is
        called on this object, so the caller's data is presumably rescaled in place.
 @return true if every tree in the forest trained successfully, false otherwise (in which
         case the model is reset with clear())
*/
bool RandomForests::train_(ClassificationData &trainingData){

    //Clear any previous model
    clear();

    const unsigned int M = trainingData.getNumSamples();
    const unsigned int N = trainingData.getNumDimensions();
    const unsigned int K = trainingData.getNumClasses();

    if( M == 0 ){
        errorLog << "train_(ClassificationData &trainingData) - Training data has zero samples!" << std::endl;
        return false;
    }

    if( bootstrappedDatasetWeight <= 0.0 || bootstrappedDatasetWeight > 1.0 ){
        errorLog << "train_(ClassificationData &trainingData) - Bootstrapped Dataset Weight must be [> 0.0 and <= 1.0]" << std::endl;
        return false;
    }

    //An empty forest cannot make predictions, and the validation metrics below are divided by
    //forestSize, so reject it up front
    if( forestSize == 0 ){
        errorLog << "train_(ClassificationData &trainingData) - The forest size must be greater than zero!" << std::endl;
        return false;
    }

    //Every tree is configured from *decisionTreeNode below; dereferencing NULL would be undefined behaviour
    if( decisionTreeNode == NULL ){
        errorLog << "train_(ClassificationData &trainingData) - The decision tree node has not been set!" << std::endl;
        return false;
    }

    numInputDimensions = N;
    numClasses = K;
    classLabels = trainingData.getClassLabels();
    //Store the ranges before scaling so new inputs can be scaled the same way later
    ranges = trainingData.getRanges();

    //Scale the training data if needed
    if( useScaling ){
        //Scale the training data between 0 and 1
        trainingData.scale(0, 1);
    }

    if( useValidationSet ){
        //One precision/recall entry per class, plus one extra when null rejection is enabled
        const UINT numMetrics = useNullRejection ? K+1 : K;
        validationSetAccuracy = 0;
        validationSetPrecision.resize( numMetrics, 0 );
        validationSetRecall.resize( numMetrics, 0 );
        //A std::vector-style resize(n, value) only sets newly added elements, so explicitly zero
        //every entry; otherwise values left from a previous training run would be accumulated into
        for(UINT j=0; j<numMetrics; j++){
            validationSetPrecision[j] = 0;
            validationSetRecall[j] = 0;
        }
    }

    //Flag that the main algorithm has been trained in case we need to trigger any callbacks
    trained = true;

    //Loop-invariant values shared by every tree
    const UINT datasetSize = (UINT)(trainingData.getNumSamples() * bootstrappedDatasetWeight);
    const Float forestNorm = 1.0 / forestSize;

    //Train the random forest
    forest.reserve( forestSize );
    for(UINT i=0; i<forestSize; i++){

        //Get a balanced bootstrapped dataset
        ClassificationData data = trainingData.getBootstrappedDataset( datasetSize, true );

        Timer timer;
        timer.start();

        DecisionTree tree;
        tree.setDecisionTreeNode( *decisionTreeNode );
        tree.enableScaling( false ); //We have already scaled the training data so we do not need to scale it again
        tree.setUseValidationSet( useValidationSet );
        tree.setValidationSetSize( validationSetSize );
        tree.setTrainingMode( trainingMode );
        tree.setNumSplittingSteps( numRandomSplits );
        tree.setMinNumSamplesPerNode( minNumSamplesPerNode );
        tree.setMaxDepth( maxDepth );
        tree.enableNullRejection( useNullRejection );
        tree.setRemoveFeaturesAtEachSpilt( removeFeaturesAtEachSpilt );

        trainingLog << "Training decision tree " << i+1 << "/" << forestSize << "..." << std::endl;

        //Train this tree
        if( !tree.train_( data ) ){
            errorLog << "train_(ClassificationData &trainingData) - Failed to train tree at forest index: " << i << std::endl;
            clear();
            return false;
        }

        Float computeTime = timer.getMilliSeconds();
        trainingLog << "Decision tree trained in " << (computeTime*0.001)/60.0 << " minutes" << std::endl;

        if( useValidationSet ){
            //Accuracy is summed here and divided by forestSize after the loop; precision and
            //recall are averaged incrementally with forestNorm
            validationSetAccuracy += tree.getValidationSetAccuracy();
            VectorFloat precision = tree.getValidationSetPrecision();
            VectorFloat recall = tree.getValidationSetRecall();

            grt_assert( precision.getSize() == validationSetPrecision.getSize() );
            grt_assert( recall.getSize() == validationSetRecall.getSize() );

            //Use j so the outer tree index i is not shadowed
            for(UINT j=0; j<validationSetPrecision.getSize(); j++){
                validationSetPrecision[j] += precision[j] * forestNorm;
            }

            for(UINT j=0; j<validationSetRecall.getSize(); j++){
                validationSetRecall[j] += recall[j] * forestNorm;
            }
        }

        //Deep copy the tree into the forest
        forest.push_back( tree.deepCopyTree() );
    }

    if( useValidationSet ){
        validationSetAccuracy /= forestSize;
        trainingLog << "Validation set accuracy: " << validationSetAccuracy << std::endl;

        trainingLog << "Validation set precision: ";
        for(UINT j=0; j<validationSetPrecision.getSize(); j++){
            trainingLog << validationSetPrecision[j] << " ";
        }
        trainingLog << std::endl;

        trainingLog << "Validation set recall: ";
        for(UINT j=0; j<validationSetRecall.getSize(); j++){
            trainingLog << validationSetRecall[j] << " ";
        }
        trainingLog << std::endl;
    }

    return true;
}