Code example #1
File: RandomForests.cpp  Project: BryanBo-Cao/grt
bool RandomForests::train_(ClassificationData &trainingData){
    
    //Clear any previous model
    clear();
    
    const unsigned int M = trainingData.getNumSamples();
    const unsigned int N = trainingData.getNumDimensions();
    const unsigned int K = trainingData.getNumClasses();
    
    if( M == 0 ){
        errorLog << "train_(ClassificationData &trainingData) - Training data has zero samples!" << std::endl;
        return false;
    }

    if( bootstrappedDatasetWeight <= 0.0 || bootstrappedDatasetWeight > 1.0 ){
        errorLog << "train_(ClassificationData &trainingData) - Bootstrapped Dataset Weight must be [> 0.0 and <= 1.0]" << std::endl;
        return false;
    }
    
    numInputDimensions = N;
    numClasses = K;
    classLabels = trainingData.getClassLabels();
    ranges = trainingData.getRanges();
    
    //Scale the training data if needed
    if( useScaling ){
        //Scale the training data between 0 and 1
        trainingData.scale(0, 1);
    }

    if( useValidationSet ){
        validationSetAccuracy = 0;
        validationSetPrecision.resize( useNullRejection ? K+1 : K, 0 );
        validationSetRecall.resize( useNullRejection ? K+1 : K, 0 );
    }
    
    //Flag that the main algorithm has been trained in case we need to trigger any callbacks
    trained = true;
    
    //Train the random forest
    forest.reserve( forestSize );

    for(UINT i=0; i<forestSize; i++){
        
        //Get a balanced bootstrapped dataset
        UINT datasetSize = (UINT)(M * bootstrappedDatasetWeight); //M was cached above, no need to re-query the dataset
        ClassificationData data = trainingData.getBootstrappedDataset( datasetSize, true );

        Timer timer;
        timer.start();
 
        DecisionTree tree;
        tree.setDecisionTreeNode( *decisionTreeNode );
        tree.enableScaling( false ); //We have already scaled the training data so we do not need to scale it again
        tree.setUseValidationSet( useValidationSet );
        tree.setValidationSetSize( validationSetSize );
        tree.setTrainingMode( trainingMode );
        tree.setNumSplittingSteps( numRandomSplits );
        tree.setMinNumSamplesPerNode( minNumSamplesPerNode );
        tree.setMaxDepth( maxDepth );
        tree.enableNullRejection( useNullRejection );
        tree.setRemoveFeaturesAtEachSpilt( removeFeaturesAtEachSpilt ); //Note: "Spilt" is the spelling used by this version of the GRT API

        trainingLog << "Training decision tree " << i+1 << "/" << forestSize << "..." << std::endl;
        
        //Train this tree
        if( !tree.train_( data ) ){
            errorLog << "train_(ClassificationData &trainingData) - Failed to train tree at forest index: " << i << std::endl;
            clear();
            return false;
        }

        Float computeTime = timer.getMilliSeconds();
        trainingLog << "Decision tree trained in " << (computeTime*0.001)/60.0 << " minutes" << std::endl;

        if( useValidationSet ){
            Float forestNorm = 1.0 / forestSize;
            validationSetAccuracy += tree.getValidationSetAccuracy();
            VectorFloat precision = tree.getValidationSetPrecision();
            VectorFloat recall = tree.getValidationSetRecall();

            grt_assert( precision.getSize() == validationSetPrecision.getSize() );
            grt_assert( recall.getSize() == validationSetRecall.getSize() );

            for(UINT j=0; j<validationSetPrecision.getSize(); j++){
                validationSetPrecision[j] += precision[j] * forestNorm;
            }

            for(UINT j=0; j<validationSetRecall.getSize(); j++){
                validationSetRecall[j] += recall[j] * forestNorm;
            }

        }
        
        //Deep copy the tree into the forest
        forest.push_back( tree.deepCopyTree() );
    }

    if( useValidationSet ){
        validationSetAccuracy /= forestSize;
        trainingLog << "Validation set accuracy: " << validationSetAccuracy << std::endl;

        trainingLog << "Validation set precision: ";
        for(UINT i=0; i<validationSetPrecision.getSize(); i++){
            trainingLog << validationSetPrecision[i] << " ";
        }
        trainingLog << std::endl;

        trainingLog << "Validation set recall: ";
        for(UINT i=0; i<validationSetRecall.getSize(); i++){
            trainingLog << validationSetRecall[i] << " ";
        }
        trainingLog << std::endl;
    }

    return true;
}
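
For context, here is a minimal sketch of how this trainer might be driven from user code. It assumes the standard GRT umbrella header and the usual RandomForests setter names (setForestSize, setNumRandomSplits, setMaxDepth, setBootstrappedDatasetWeight); those setters and the filename are assumptions to check against your GRT version, and train() simply forwards to the train_() implementation above.

#include <GRT/GRT.h>
using namespace GRT;

int main(){
    //Load a labelled dataset in the GRT file format (hypothetical filename)
    ClassificationData trainingData;
    if( !trainingData.load( "TrainingData.grt" ) ) return EXIT_FAILURE;

    //Hold out 20% of the samples for testing
    ClassificationData testData = trainingData.split( 80 );

    //Configure the forest (setter names assumed, see note above)
    RandomForests forest;
    forest.setForestSize( 10 );                 //Number of trees in the ensemble
    forest.setNumRandomSplits( 100 );           //Random splits evaluated per node
    forest.setMaxDepth( 10 );                   //Maximum depth of each tree
    forest.setBootstrappedDatasetWeight( 0.8 ); //Fraction of samples drawn per tree

    //train() wraps the train_() implementation shown above
    if( !forest.train( trainingData ) ) return EXIT_FAILURE;

    //Classify the held-out samples and report the accuracy
    UINT numCorrect = 0;
    for(UINT i=0; i<testData.getNumSamples(); i++){
        if( forest.predict( testData[i].getSample() ) &&
            forest.getPredictedClassLabel() == testData[i].getClassLabel() ) numCorrect++;
    }
    std::cout << "Accuracy: " << 100.0 * numCorrect / testData.getNumSamples() << "%" << std::endl;
    return EXIT_SUCCESS;
}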
Code example #2
File: RandomForests.cpp  Project: hal2001/grt
bool RandomForests::train_(ClassificationData &trainingData){
    
    //Clear any previous model
    clear();
    
    const unsigned int M = trainingData.getNumSamples();
    const unsigned int N = trainingData.getNumDimensions();
    const unsigned int K = trainingData.getNumClasses();
    
    if( M == 0 ){
        errorLog << "train_(ClassificationData &trainingData) - Training data has zero samples!" << endl;
        return false;
    }

    if( bootstrappedDatasetWeight <= 0.0 || bootstrappedDatasetWeight > 1.0 ){
        errorLog << "train_(ClassificationData &trainingData) - Bootstrapped Dataset Weight must be [> 0.0 and <= 1.0]" << endl;
        return false;
    }
    
    numInputDimensions = N;
    numClasses = K;
    classLabels = trainingData.getClassLabels();
    ranges = trainingData.getRanges();
    
    //Scale the training data if needed
    if( useScaling ){
        //Scale the training data between 0 and 1
        trainingData.scale(0, 1);
    }
    
    //Flag that the main algorithm has been trained in case we need to trigger any callbacks
    trained = true;
    
    //Train the random forest
    forest.reserve( forestSize );
    for(UINT i=0; i<forestSize; i++){
        
        //Get a balanced bootstrapped dataset
        UINT datasetSize = (UINT)(M * bootstrappedDatasetWeight); //M was cached above, no need to re-query the dataset
        ClassificationData data = trainingData.getBootstrappedDataset( datasetSize, true );
 
        DecisionTree tree;
        tree.setDecisionTreeNode( *decisionTreeNode );
        tree.enableScaling( false ); //We have already scaled the training data so we do not need to scale it again
        tree.setTrainingMode( trainingMode );
        tree.setNumSplittingSteps( numRandomSplits );
        tree.setMinNumSamplesPerNode( minNumSamplesPerNode );
        tree.setMaxDepth( maxDepth );
        tree.enableNullRejection( useNullRejection );
        tree.setRemoveFeaturesAtEachSpilt( removeFeaturesAtEachSpilt ); //Note: "Spilt" is the spelling used by this version of the GRT API

        trainingLog << "Training forest " << i+1 << "/" << forestSize << "..." << endl;
        
        //Train this tree
        if( !tree.train( data ) ){
            errorLog << "train_(ClassificationData &labelledTrainingData) - Failed to train tree at forest index: " << i << endl;
            clear();
            return false;
        }
        
        //Deep copy the tree into the forest
        forest.push_back( tree.deepCopyTree() );
    }

    return true;
}
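
This second version is an older revision of the same function: the core training loop is identical, but it omits the per-tree timing and validation-set bookkeeping of example #1 and calls tree.train() rather than tree.train_(). Both versions grow each tree from a class-balanced bootstrapped dataset (the second argument of getBootstrappedDataset enables the balancing). As a rough illustration of what that sampling step does, here is a self-contained sketch in plain C++; the Sample struct and balancedBootstrap function are hypothetical stand-ins, not GRT API.

#include <vector>
#include <map>
#include <random>

//Hypothetical stand-in for one labelled sample
struct Sample{ unsigned int classLabel; std::vector<double> features; };

//Draw datasetSize samples with replacement, taking an equal share from each class
std::vector<Sample> balancedBootstrap( const std::vector<Sample> &data,
                                       const size_t datasetSize, std::mt19937 &rng ){
    //Bucket the sample indices by class label
    std::map< unsigned int, std::vector<size_t> > byClass;
    for(size_t i=0; i<data.size(); i++) byClass[ data[i].classLabel ].push_back( i );

    std::vector<Sample> bootstrap;
    bootstrap.reserve( datasetSize );
    const size_t numPerClass = datasetSize / byClass.size();
    for( const auto &kv : byClass ){
        std::uniform_int_distribution<size_t> pick( 0, kv.second.size()-1 );
        for(size_t n=0; n<numPerClass; n++){
            bootstrap.push_back( data[ kv.second[ pick(rng) ] ] ); //Sampling with replacement
        }
    }
    return bootstrap;
}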
Code example #3
File: DecisionTreeExample.cpp  Project: sgrignard/grt
int main(int argc, const char * argv[])
{
    //Parse the data filename from the argument list
    if( argc != 2 ){
        cout << "Error: failed to parse data filename from command line. You should run this example with one argument pointing to the data filename!\n";
        return EXIT_FAILURE;
    }
    const string filename = argv[1];

    //Create a new DecisionTree instance
    DecisionTree dTree;
    
    //Set the node that the DecisionTree will use - different nodes may result in different decision boundaries
    //and some nodes may provide better accuracy than others on specific classification tasks
    //The current node options are:
    //- DecisionTreeClusterNode
    //- DecisionTreeThresholdNode
    dTree.setDecisionTreeNode( DecisionTreeClusterNode() );
    
    //Set the number of steps that will be used to choose the best splitting values
    //More steps will give you a better model, but will take longer to train
    dTree.setNumSplittingSteps( 1000 );
    
    //Set the maximum depth of the tree
    dTree.setMaxDepth( 10 );
    
    //Set the minimum number of samples allowed per node
    dTree.setMinNumSamplesPerNode( 10 );
    
    //Load some training data to train the classifier
    ClassificationData trainingData;
    
    if( !trainingData.load( filename ) ){
        cout << "Failed to load training data: " << filename << endl;
        return EXIT_FAILURE;
    }
    
    //Use 20% of the training dataset to create a test dataset
    ClassificationData testData = trainingData.split( 80 );
    
    //Train the classifier
    if( !dTree.train( trainingData ) ){
        cout << "Failed to train classifier!\n";
        return EXIT_FAILURE;
    }
    
    //Print the tree
    dTree.print();
    
    //Save the model to a file
    if( !dTree.save("DecisionTreeModel.grt") ){
        cout << "Failed to save the classifier model!\n";
        return EXIT_FAILURE;
    }
    
    //Load the model from a file
    if( !dTree.load("DecisionTreeModel.grt") ){
        cout << "Failed to load the classifier model!\n";
        return EXIT_FAILURE;
    }
    
    //Test the accuracy of the model on the test data
    double accuracy = 0;
    for(UINT i=0; i<testData.getNumSamples(); i++){
        //Get the i'th test sample
        UINT classLabel = testData[i].getClassLabel();
        VectorDouble inputVector = testData[i].getSample();
        
        //Perform a prediction using the classifier
        bool predictSuccess = dTree.predict( inputVector );
        
        if( !predictSuccess ){
            cout << "Failed to perform prediction for test sampel: " << i <<"\n";
            return EXIT_FAILURE;
        }
        
        //Get the predicted class label
        UINT predictedClassLabel = dTree.getPredictedClassLabel();
        VectorDouble classLikelihoods = dTree.getClassLikelihoods(); //Per-class likelihoods (retrieved for illustration)
        VectorDouble classDistances = dTree.getClassDistances();     //Per-class distances (retrieved for illustration)
        
        //Update the accuracy
        if( classLabel == predictedClassLabel ) accuracy++;
        
        cout << "TestSample: " << i <<  " ClassLabel: " << classLabel << " PredictedClassLabel: " << predictedClassLabel << endl;
    }
    
    cout << "Test Accuracy: " << accuracy/double(testData.getNumSamples())*100.0 << "%" << endl;
    
    return EXIT_SUCCESS;
}
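
The single accuracy figure above can be broken down into the per-class precision and recall that RandomForests::train_ logs for its validation set. Here is a minimal sketch of that bookkeeping, meant to replace the test loop inside main; it uses only standard C++ (add #include <map>) and the GRT calls already demonstrated in the example.

    //Tally per-class true positives, false positives, and false negatives
    std::map<UINT, UINT> tp, fp, fn;
    for(UINT i=0; i<testData.getNumSamples(); i++){
        const UINT truth = testData[i].getClassLabel();
        if( !dTree.predict( testData[i].getSample() ) ) continue; //Skip failed predictions
        const UINT predicted = dTree.getPredictedClassLabel();
        if( predicted == truth ) tp[truth]++;
        else{ fp[predicted]++; fn[truth]++; }
    }

    //Precision = TP/(TP+FP) and recall = TP/(TP+FN), per class label
    for( const auto &kv : tp ){
        const UINT label = kv.first;
        cout << "Class " << label
             << " precision: " << tp[label] / double( tp[label] + fp[label] )
             << " recall: "    << tp[label] / double( tp[label] + fn[label] ) << endl;
    }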