// Trains the random forest from a labelled classification dataset.
//
// The dataset is received by value, so the optional [0, 1] scaling performed below is
// applied to this function's own copy rather than to the caller's object.
//
// @param trainingData the labelled samples used to build every tree in the forest
// @return true if all trees were trained and stored in the forest; false if the dataset
//         has zero samples or if any tree fails to train (clear() is called in that case
//         so that no partially built forest is left behind)
bool RandomForests::train(LabelledClassificationData trainingData){

    //Clear any previous model
    clear();

    const unsigned int M = trainingData.getNumSamples();
    const unsigned int N = trainingData.getNumDimensions();
    const unsigned int K = trainingData.getNumClasses();

    if( M == 0 ){
        errorLog << "train(LabelledClassificationData labelledTrainingData) - Training data has zero samples!" << endl;
        return false;
    }

    //Record the properties of the training data before it is (optionally) scaled
    numInputDimensions = N;
    numClasses = K;
    classLabels = trainingData.getClassLabels();
    ranges = trainingData.getRanges();

    //Scale the training data between 0 and 1 if needed
    if( useScaling ){
        trainingData.scale(0, 1);
    }

    //NOTE(review): the forest size is hard-coded here and overwrites whatever value the
    //member held before this call - confirm this is intended
    forestSize = 10;

    //A single tree object is configured once and re-trained for every entry in the forest,
    //each trained model is deep copied into the forest below
    DecisionTree tree;
    tree.enableScaling( false ); //We have already scaled the training data so we do not need to scale it again
    tree.setTrainingMode( DecisionTree::BEST_RANDOM_SPLIT );
    tree.setNumSplittingSteps( numRandomSplits );
    tree.setMinNumSamplesPerNode( minNumSamplesPerNode );
    tree.setMaxDepth( maxDepth );

    for(UINT i=0; i<forestSize; i++){

        //Each tree is trained on its own bootstrapped version of the training data
        LabelledClassificationData data = trainingData.getBootstrappedDataset();

        if( !tree.train( data ) ){
            errorLog << "train(LabelledClassificationData labelledTrainingData) - Failed to train tree at forest index: " << i << endl;

            //Discard the trees that were already added so a failed call does not
            //leave a partially built forest behind
            clear();
            return false;
        }

        //Deep copy the tree into the forest
        forest.push_back( tree.deepCopyTree() );
    }

    //Flag that the algorithm has been trained
    trained = true;
    return trained;
}
// Trains the random forest from a classification dataset.
//
// NOTE: the dataset is received by reference and, when useScaling is set, it is scaled
// to [0, 1] in place - the caller's object is modified.
//
// @param trainingData the labelled samples used to build every tree in the forest
// @return true if forestSize trees were trained and stored in the forest; false if the
//         dataset has zero samples, the settings are invalid (bootstrapped dataset weight
//         outside (0, 1], zero forest size, no decision tree node set) or any tree fails
//         to train (clear() is called in that case)
bool RandomForests::train_(ClassificationData &trainingData){

    //Clear any previous model
    clear();

    const unsigned int M = trainingData.getNumSamples();
    const unsigned int N = trainingData.getNumDimensions();
    const unsigned int K = trainingData.getNumClasses();

    if( M == 0 ){
        errorLog << "train_(ClassificationData &trainingData) - Training data has zero samples!" << endl;
        return false;
    }

    if( bootstrappedDatasetWeight <= 0.0 || bootstrappedDatasetWeight > 1.0 ){
        errorLog << "train_(ClassificationData &trainingData) - Bootstrapped Dataset Weight must be [> 0.0 and <= 1.0]" << endl;
        return false;
    }

    //An empty forest would otherwise be reported as a successfully trained model
    if( forestSize == 0 ){
        errorLog << "train_(ClassificationData &trainingData) - Forest size must be greater than zero!" << endl;
        return false;
    }

    //The node is dereferenced for every tree below, so it must be valid
    if( !decisionTreeNode ){
        errorLog << "train_(ClassificationData &trainingData) - The decision tree node has not been set!" << endl;
        return false;
    }

    //Record the properties of the training data before it is (optionally) scaled
    numInputDimensions = N;
    numClasses = K;
    classLabels = trainingData.getClassLabels();
    ranges = trainingData.getRanges();

    //Scale the training data between 0 and 1 if needed
    if( useScaling ){
        trainingData.scale(0, 1);
    }

    //Flag that the main algorithm has been trained encase we need to trigger any callbacks
    trained = true;

    //The size of each bootstrapped dataset is the same for every tree, so compute it once
    const UINT datasetSize = static_cast<UINT>( M * bootstrappedDatasetWeight );

    //Train the random forest
    forest.reserve( forestSize );
    for(UINT i=0; i<forestSize; i++){

        //Get a balanced bootstrapped dataset
        ClassificationData data = trainingData.getBootstrappedDataset( datasetSize, true );

        //Configure a new tree using the settings of the forest
        DecisionTree tree;
        tree.setDecisionTreeNode( *decisionTreeNode );
        tree.enableScaling( false ); //We have already scaled the training data so we do not need to scale it again
        tree.setTrainingMode( trainingMode );
        tree.setNumSplittingSteps( numRandomSplits );
        tree.setMinNumSamplesPerNode( minNumSamplesPerNode );
        tree.setMaxDepth( maxDepth );
        tree.enableNullRejection( useNullRejection );
        tree.setRemoveFeaturesAtEachSpilt( removeFeaturesAtEachSpilt );

        trainingLog << "Training forest " << i+1 << "/" << forestSize << "..." << endl;

        //Train this tree
        if( !tree.train( data ) ){
            errorLog << "train_(ClassificationData &trainingData) - Failed to train tree at forest index: " << i << endl;

            //Discard the trees trained so far, clear() is also relied on here to undo the trained flag set above
            clear();
            return false;
        }

        //Deep copy the tree into the forest
        forest.push_back( tree.deepCopyTree() );
    }

    return true;
}
// Trains the random forest from a classification dataset and, when useValidationSet is
// enabled, averages the validation accuracy, precision and recall reported by each tree.
//
// NOTE: the dataset is received by reference and, when useScaling is set, it is scaled
// to [0, 1] in place - the caller's object is modified.
//
// @param trainingData the labelled samples used to build every tree in the forest
// @return true if forestSize trees were trained and stored in the forest; false if the
//         dataset has zero samples, the settings are invalid (bootstrapped dataset weight
//         outside (0, 1], zero forest size, no decision tree node set) or any tree fails
//         to train (clear() is called in that case)
bool RandomForests::train_(ClassificationData &trainingData){

    //Clear any previous model
    clear();

    const unsigned int M = trainingData.getNumSamples();
    const unsigned int N = trainingData.getNumDimensions();
    const unsigned int K = trainingData.getNumClasses();

    if( M == 0 ){
        errorLog << "train_(ClassificationData &trainingData) - Training data has zero samples!" << std::endl;
        return false;
    }

    if( bootstrappedDatasetWeight <= 0.0 || bootstrappedDatasetWeight > 1.0 ){
        errorLog << "train_(ClassificationData &trainingData) - Bootstrapped Dataset Weight must be [> 0.0 and <= 1.0]" << std::endl;
        return false;
    }

    //An empty forest would otherwise be reported as a trained model, and the validation
    //averages below would divide by zero
    if( forestSize == 0 ){
        errorLog << "train_(ClassificationData &trainingData) - Forest size must be greater than zero!" << std::endl;
        return false;
    }

    //The node is dereferenced for every tree below, so it must be valid
    if( !decisionTreeNode ){
        errorLog << "train_(ClassificationData &trainingData) - The decision tree node has not been set!" << std::endl;
        return false;
    }

    //Record the properties of the training data before it is (optionally) scaled
    numInputDimensions = N;
    numClasses = K;
    classLabels = trainingData.getClassLabels();
    ranges = trainingData.getRanges();

    //Scale the training data between 0 and 1 if needed
    if( useScaling ){
        trainingData.scale(0, 1);
    }

    if( useValidationSet ){
        //One slot per class, plus one extra slot when null rejection is enabled
        const unsigned int numValidationClasses = useNullRejection ? K+1 : K;
        validationSetAccuracy = 0;
        validationSetPrecision.resize( numValidationClasses, 0 );
        validationSetRecall.resize( numValidationClasses, 0 );

        //With std::vector style semantics resize only initializes newly added elements, so
        //explicitly zero every slot to stop results from a previous training run leaking
        //into the running sums below
        for(unsigned int k=0; k<numValidationClasses; k++){
            validationSetPrecision[k] = 0;
            validationSetRecall[k] = 0;
        }
    }

    //Flag that the main algorithm has been trained encase we need to trigger any callbacks
    trained = true;

    //These values are the same for every tree, so compute them once
    const UINT datasetSize = static_cast<UINT>( M * bootstrappedDatasetWeight );
    const Float forestNorm = 1.0 / forestSize;

    //Train the random forest
    forest.reserve( forestSize );
    for(UINT i=0; i<forestSize; i++){

        //Get a balanced bootstrapped dataset
        ClassificationData data = trainingData.getBootstrappedDataset( datasetSize, true );

        Timer timer;
        timer.start();

        //Configure a new tree using the settings of the forest
        DecisionTree tree;
        tree.setDecisionTreeNode( *decisionTreeNode );
        tree.enableScaling( false ); //We have already scaled the training data so we do not need to scale it again
        tree.setUseValidationSet( useValidationSet );
        tree.setValidationSetSize( validationSetSize );
        tree.setTrainingMode( trainingMode );
        tree.setNumSplittingSteps( numRandomSplits );
        tree.setMinNumSamplesPerNode( minNumSamplesPerNode );
        tree.setMaxDepth( maxDepth );
        tree.enableNullRejection( useNullRejection );
        tree.setRemoveFeaturesAtEachSpilt( removeFeaturesAtEachSpilt );

        trainingLog << "Training decision tree " << i+1 << "/" << forestSize << "..." << std::endl;

        //Train this tree
        if( !tree.train_( data ) ){
            errorLog << "train_(ClassificationData &trainingData) - Failed to train tree at forest index: " << i << std::endl;

            //Discard the trees trained so far, clear() is also relied on here to undo the trained flag set above
            clear();
            return false;
        }

        //The timer reports milliseconds, the log message is in minutes
        Float computeTime = timer.getMilliSeconds();
        trainingLog << "Decision tree trained in " << (computeTime*0.001)/60.0 << " minutes" << std::endl;

        if( useValidationSet ){
            //Accuracy is summed here and divided by the forest size after the loop,
            //precision and recall are weighted by 1/forestSize as they are added
            validationSetAccuracy += tree.getValidationSetAccuracy();
            VectorFloat precision = tree.getValidationSetPrecision();
            VectorFloat recall = tree.getValidationSetRecall();

            grt_assert( precision.getSize() == validationSetPrecision.getSize() );
            grt_assert( recall.getSize() == validationSetRecall.getSize() );

            //Use a separate index so the forest index i is not shadowed
            for(UINT j=0; j<validationSetPrecision.getSize(); j++){
                validationSetPrecision[j] += precision[j] * forestNorm;
            }

            for(UINT j=0; j<validationSetRecall.getSize(); j++){
                validationSetRecall[j] += recall[j] * forestNorm;
            }
        }

        //Deep copy the tree into the forest
        forest.push_back( tree.deepCopyTree() );
    }

    if( useValidationSet ){
        //Turn the summed accuracy into the mean accuracy over the forest
        validationSetAccuracy /= forestSize;
        trainingLog << "Validation set accuracy: " << validationSetAccuracy << std::endl;

        trainingLog << "Validation set precision: ";
        for(UINT j=0; j<validationSetPrecision.getSize(); j++){
            trainingLog << validationSetPrecision[j] << " ";
        }
        trainingLog << std::endl;

        trainingLog << "Validation set recall: ";
        for(UINT j=0; j<validationSetRecall.getSize(); j++){
            trainingLog << validationSetRecall[j] << " ";
        }
        trainingLog << std::endl;
    }

    return true;
}