Example #1
bool RandomForests::train(LabelledClassificationData trainingData){
    
    //Clear any previous model
    clear();
    
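    //Get the number of training samples (M), the input dimensionality (N), and the number of classes (K)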
    const unsigned int M = trainingData.getNumSamples();
    const unsigned int N = trainingData.getNumDimensions();
    const unsigned int K = trainingData.getNumClasses();
    
    if( M == 0 ){
        errorLog << "train(LabelledClassificationData trainingData) - Training data has zero samples!" << endl;
        return false;
    }
    
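    //Store the input dimensionality, number of classes, class labels, and value ranges for this model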
    numInputDimensions = N;
    numClasses = K;
    classLabels = trainingData.getClassLabels();
    ranges = trainingData.getRanges();
    
    //Scale the training data if needed
    if( useScaling ){
        //Scale the training data between 0 and 1
        trainingData.scale(0, 1);
    }
    
    //Train the random forest, using a fixed forest size of 10 trees
    forestSize = 10;
    
    DecisionTree tree;
    tree.enableScaling( false ); //We have already scaled the training data so we do not need to scale it again
    tree.setTrainingMode( DecisionTree::BEST_RANDOM_SPLIT );
    tree.setNumSplittingSteps( numRandomSplits );
    tree.setMinNumSamplesPerNode( minNumSamplesPerNode );
    tree.setMaxDepth( maxDepth );
    
    for(UINT i=0; i<forestSize; i++){
        LabelledClassificationData data = trainingData.getBootstrappedDataset();
        
        if( !tree.train( data ) ){
            errorLog << "train(LabelledClassificationData trainingData) - Failed to train tree at forest index: " << i << endl;
            return false;
        }
        
        //Deep copy the tree into the forest
        forest.push_back( tree.deepCopyTree() );
    }
    
    //Flag that the algorithm has been trained
    trained = true;
    return trained;
}
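A minimal usage sketch for the function above, assuming a GRT-style setup: the umbrella header, the setNumDimensions() and addSample() calls on LabelledClassificationData, and the class labels used here are assumptions for illustration and do not appear in the listing itself.

#include "GRT.h"        //Assumed umbrella header for the GRT library
#include <iostream>
#include <vector>
using namespace GRT;
using namespace std;

int main(){
    //Build a small two-class dataset (hypothetical helper names)
    LabelledClassificationData trainingData;
    trainingData.setNumDimensions( 2 );                    //Assumed setter for the input dimensionality
    for(UINT i=0; i<50; i++){
        vector< double > sample(2);
        sample[0] = i < 25 ? 0.1 : 0.9;                    //Crude class separation, purely for illustration
        sample[1] = i < 25 ? 0.2 : 0.8;
        trainingData.addSample( i < 25 ? 1 : 2, sample );  //Assumed (classLabel, sample) signature
    }

    //Train the forest on the dataset
    RandomForests forest;
    if( !forest.train( trainingData ) ){
        cout << "Failed to train the random forest!" << endl;
        return 1;
    }

    return 0;
}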
Example #2
bool BAG::train(LabelledClassificationData trainingData){
    
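    //Get the number of training samples (M), the input dimensionality (N), and the number of classes (K)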
    const unsigned int M = trainingData.getNumSamples();
    const unsigned int N = trainingData.getNumDimensions();
    const unsigned int K = trainingData.getNumClasses();
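    //Reset the trained flag and clear any class labels from a previous model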
    trained = false;
    classLabels.clear();
    
    if( M == 0 ){
        errorLog << "train(LabelledClassificationData trainingData) - Training data has zero samples!" << endl;
        return false;
    }
    
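    //Store the input dimensionality, number of classes, and value ranges; the class labels are filled in after training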
    numFeatures = N;
    numClasses = K;
    classLabels.resize(K);
    ranges = trainingData.getRanges();
    
    UINT ensembleSize = (UINT)ensemble.size();
    
    if( ensembleSize == 0 ){
        errorLog << "train(LabelledClassificationData trainingData) - The ensemble size is zero! You need to add some classifiers to the ensemble first." << endl;
        return false;
    }
    
    for(UINT i=0; i<ensembleSize; i++){
        if( ensemble[i] == NULL ){
            errorLog << "train(LabelledClassificationData trainingData) - The classifier at ensemble index " << i << " has not been set!" << endl;
            return false;
        }
    }

    //Train the ensemble
    for(UINT i=0; i<ensembleSize; i++){
        LabelledClassificationData bootstrappedDataset = trainingData.getBootstrappedDataset();
        
        //Train the classifier with the bootstrapped dataset
        if( !ensemble[i]->train( bootstrappedDataset ) ){
            errorLog << "train(LabelledClassificationData trainingData) - The classifier at ensemble index " << i << " failed training!" << endl;
            return false;
        }
    }
    
    //Set the class labels
    classLabels = trainingData.getClassLabels();
    
    //Flag that the algorithm has been trained
    trained = true;
    return trained;
}
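A similarly hedged sketch for the BAG classifier, reusing the trainingData built in the sketch after Example #1. The ensemble has to be populated before train() is called; the addClassifierToEnsemble() name and the MinDist/KNN base classifiers below are assumptions based on the ensemble check at the top of the function, not something this listing confirms.

//Populate the ensemble with a couple of base classifiers, then train on the dataset from the previous sketch
BAG bag;
bag.addClassifierToEnsemble( MinDist() );   //Assumed helper for registering a base classifier
bag.addClassifierToEnsemble( KNN() );       //Any other GRT classifier could be used here

if( !bag.train( trainingData ) ){
    cout << "Failed to train the BAG ensemble!" << endl;
}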