Exemplo n.º 1
/**
 * Trains the random forest on a labelled classification dataset.
 *
 * Stores the dataset's dimension count, class count, class labels and ranges
 * on this object, optionally scales the data to [0, 1], and then trains
 * forestSize decision trees. Each tree is trained on its own bootstrapped
 * dataset and a deep copy of it is appended to the forest.
 *
 * @param trainingData the labelled data to train on. It is taken by value, so
 *                     the scaling applied here works on this function's copy.
 * @return true if every tree was trained; false if the dataset has no samples
 *         or if any tree fails to train.
 */
bool RandomForests::train(LabelledClassificationData trainingData){

    //Mark the model as untrained up front (as BAG::train does), so that a failed
    //retrain on a previously trained instance never leaves the flag set to true
    trained = false;

    //NOTE(review): the forest is only appended to below and is not emptied in this
    //function - presumably any previous model should be cleared here; confirm
    //against the rest of the class

    const unsigned int M = trainingData.getNumSamples();
    const unsigned int N = trainingData.getNumDimensions();
    const unsigned int K = trainingData.getNumClasses();

    if( M == 0 ){
        errorLog << "train(LabelledClassificationData labelledTrainingData) - Training data has zero samples!" << endl;
        return false;
    }

    numInputDimensions = N;
    numClasses = K;
    classLabels = trainingData.getClassLabels();
    ranges = trainingData.getRanges();

    //Scale the training data if needed
    if( useScaling ){
        //Scale the training data between 0 and 1
        trainingData.scale(0, 1);
    }

    //Train the random forest
    //NOTE(review): this overwrites the forestSize member with a fixed value on
    //every call - confirm this is intended and not a leftover
    forestSize = 10;

    //NOTE(review): this object is not referenced anywhere else in this function -
    //confirm whether constructing it is still required
    Random random;

    //A single tree object is configured once and retrained on every iteration
    DecisionTree tree;
    tree.enableScaling( false ); //We have already scaled the training data so we do not need to scale it again
    tree.setTrainingMode( DecisionTree::BEST_RANDOM_SPLIT );
    tree.setNumSplittingSteps( numRandomSplits );
    tree.setMinNumSamplesPerNode( minNumSamplesPerNode );
    tree.setMaxDepth( maxDepth );

    for(UINT i=0; i<forestSize; i++){
        //Each tree gets its own bootstrapped dataset drawn from the training data
        LabelledClassificationData data = trainingData.getBootstrappedDataset();

        if( !tree.train( data ) ){
            errorLog << "train(LabelledClassificationData labelledTrainingData) - Failed to train tree at forest index: " << i << endl;
            return false;
        }

        //Deep copy the tree into the forest
        forest.push_back( tree.deepCopyTree() );
    }

    //Flag that the algorithm has been trained
    trained = true;
    return trained;
}
Exemplo n.º 2
/**
 * Trains every classifier in the ensemble on the supplied labelled dataset.
 *
 * The trained flag is reset first. The dataset must contain at least one
 * sample, the ensemble must contain at least one classifier, and no ensemble
 * slot may be NULL. Each classifier is then trained on its own bootstrapped
 * dataset. The class labels are recorded only after all classifiers have
 * trained successfully.
 *
 * @param trainingData the labelled data to train the ensemble with
 * @return true if every classifier trained; false on any validation or
 *         training failure (an error is written to errorLog in each case)
 */
bool BAG::train(LabelledClassificationData trainingData){

    const unsigned int sampleCount = trainingData.getNumSamples();
    const unsigned int dimensionCount = trainingData.getNumDimensions();
    const unsigned int classCount = trainingData.getNumClasses();

    //Until training completes, the model counts as untrained
    trained = false;

    //Guard: nothing can be learned from an empty dataset
    if( sampleCount == 0 ){
        errorLog << "train(LabelledClassificationData trainingData) - Training data has zero samples!" << endl;
        return false;
    }

    //Record the shape of the training data
    numFeatures = dimensionCount;
    numClasses = classCount;
    ranges = trainingData.getRanges();

    //Guard: there must be at least one classifier to train
    const UINT classifierCount = static_cast<UINT>( ensemble.size() );
    if( classifierCount == 0 ){
        errorLog << "train(LabelledClassificationData trainingData) - The ensemble size is zero! You need to add some classifiers to the ensemble first." << endl;
        return false;
    }

    //Guard: validate every slot before any classifier is trained
    for(UINT slot=0; slot<classifierCount; ++slot){
        if( ensemble[slot] == NULL ){
            errorLog << "train(LabelledClassificationData trainingData) - The classifier at ensemble index " << slot << " has not been set!" << endl;
            return false;
        }
    }

    //Train each classifier on a freshly bootstrapped dataset
    for(UINT slot=0; slot<classifierCount; ++slot){
        LabelledClassificationData bootstrapped = trainingData.getBootstrappedDataset();

        if( !ensemble[slot]->train( bootstrapped ) ){
            errorLog << "train(LabelledClassificationData trainingData) - The classifier at ensemble index " << slot << " failed training!" << endl;
            return false;
        }
    }

    //Every classifier trained, so store the class labels and flag success
    classLabels = trainingData.getClassLabels();
    trained = true;
    return trained;
}