void BoostedClassifier::train(const vector<TrainingExample*>& examples) {
  if(examples.size()==0) return;
  assert(m_trees.size()==0);
  size_t m = examples.size();
  vector<double> distribs(m,1.0/m);
  m_sum_alphas = 0;
  for(size_t b=0; b<m_num_max_trees; b++) {
    DecisionTree* hb = new DecisionTree(m_label);
    hb->train(examples, distribs);
    double epsb = 0;
    size_t num_misclassified=0;
    for(size_t i=0; i<m; i++) {
      double predictprob= hb->predict(*examples[i]);
      bool prediction = (predictprob>0.5)?true:false;
      // cerr << "Actual label: " << egs_for_tree[i]->getLabel() << endl;
      // cerr << ", Predicted: " << m_label << " with prob " << predictprob << endl;
      bool actual = isLabelEqual(examples[i]->getLabel(), m_label);
      if(prediction!=actual) {
        epsb+=distribs[i];
        num_misclassified++;
      }
    }
    // cerr << "Number misclassified: " << num_misclassified << " of " << m << ", my label: " << m_label << endl;
    // cerr << "\tEpsb: " << epsb << endl;
    double epsilon = 0.001;
    if(epsb==0) {
      epsb=epsilon;
    } else if(epsb==1.0) {
      epsb=1-epsilon;
    }
    double alphab = 0.5*log((1-epsb)/epsb)/log(2.0);
    double z = 0.0;
    for(size_t i=0; i<m; i++) {
      double predictprob= hb->predict(*examples[i]);
    //   cerr << "My tree label: " << m_label << ", actual label: " << egs_for_tree[i]->getLabel()<< ", prediction probability: " << predictprob << endl;
      bool prediction = (predictprob>0.5)?true:false;
      bool actual = isLabelEqual(examples[i]->getLabel(),m_label);
      if(prediction!=actual) {
        // if incorrect drum up...
        distribs[i] = distribs[i]*exp(alphab);
      } else {
        // if correct drum down...
        distribs[i] = distribs[i]*exp(-alphab);
      }
      z +=distribs[i];
    }
    // cerr << "z: " << z << endl;
    for(size_t i=0; i<m; i++) {
      distribs[i] = distribs[i]/z;
    } 
    m_trees.push_back(hb);
    m_tree_alpha.push_back(alphab);
    m_sum_alphas += alphab;
    // cerr << "Trained tree with alphab: " << alphab <<endl;
  } // for each tree 
}
Exemplo n.º 2
0
//Trains the forest: forestSize decision trees, each fitted to its own bootstrapped copy of the training data.
//trainingData is taken by value, so the optional in-place scaling below only touches this function's copy.
//Any previous model is cleared first. Returns true when every tree trained; returns false when the
//data has no samples or when a tree fails to train (in which case the partial model is cleared again).
bool RandomForests::train(LabelledClassificationData trainingData){
    
    //Clear any previous model
    clear();
    
    const unsigned int M = trainingData.getNumSamples();
    const unsigned int N = trainingData.getNumDimensions();
    const unsigned int K = trainingData.getNumClasses();
    
    if( M == 0 ){
        errorLog << "train(LabelledClassificationData labelledTrainingData) - Training data has zero samples!" << endl;
        return false;
    }
    
    //Record the shape of the problem before the data is (optionally) rescaled
    numInputDimensions = N;
    numClasses = K;
    classLabels = trainingData.getClassLabels();
    ranges = trainingData.getRanges();
    
    //Scale the training data if needed
    if( useScaling ){
        //Scale the training data between 0 and 1
        trainingData.scale(0, 1);
    }
    
    //Train the random forest
    //NOTE(review): the forest size is fixed at 10 here, overwriting whatever forestSize held before - confirm this is intended
    forestSize = 10;
    //NOTE(review): this object is never used below; kept in case its constructor has side effects - confirm and remove
    Random random;
    
    //A single tree object is configured once and re-trained for every forest entry
    DecisionTree tree;
    tree.enableScaling( false ); //We have already scaled the training data so we do not need to scale it again
    tree.setTrainingMode( DecisionTree::BEST_RANDOM_SPLIT );
    tree.setNumSplittingSteps( numRandomSplits );
    tree.setMinNumSamplesPerNode( minNumSamplesPerNode );
    tree.setMaxDepth( maxDepth );
    
    for(UINT i=0; i<forestSize; i++){
        LabelledClassificationData data = trainingData.getBootstrappedDataset();
        
        if( !tree.train( data ) ){
            errorLog << "train(LabelledClassificationData labelledTrainingData) - Failed to train tree at forest index: " << i << endl;
            //Do not leave a half-built forest (and the model fields set above) behind on failure
            clear();
            return false;
        }
        
        //Deep copy the tree into the forest
        forest.push_back( tree.deepCopyTree() );
    }
    
    //Flag that the algorithm has been trained
    trained = true;
    return trained;
}
Exemplo n.º 3
0
//Trains the forest: forestSize decision trees, each fitted to a bootstrapped dataset drawn from trainingData.
//trainingData is taken by reference and is scaled in place when useScaling is set.
//Any previous model is cleared first. Returns false (after logging) when the data has no samples, when
//bootstrappedDatasetWeight is outside (0, 1], when no decision tree node has been set, or when any tree
//fails to train; returns true once every tree has been added to the forest.
bool RandomForests::train_(ClassificationData &trainingData){
    
    //Clear any previous model
    clear();
    
    const unsigned int M = trainingData.getNumSamples();
    const unsigned int N = trainingData.getNumDimensions();
    const unsigned int K = trainingData.getNumClasses();
    
    if( M == 0 ){
        errorLog << "train_(ClassificationData &trainingData) - Training data has zero samples!" << endl;
        return false;
    }

    if( bootstrappedDatasetWeight <= 0.0 || bootstrappedDatasetWeight > 1.0 ){
        errorLog << "train_(ClassificationData &trainingData) - Bootstrapped Dataset Weight must be [> 0.0 and <= 1.0]" << endl;
        return false;
    }

    //The node is dereferenced for every tree below, so reject a missing node up front instead of crashing
    if( !decisionTreeNode ){
        errorLog << "train_(ClassificationData &trainingData) - The decision tree node has not been set!" << endl;
        return false;
    }
    
    numInputDimensions = N;
    numClasses = K;
    classLabels = trainingData.getClassLabels();
    ranges = trainingData.getRanges();
    
    //Scale the training data if needed
    if( useScaling ){
        //Scale the training data between 0 and 1
        trainingData.scale(0, 1);
    }
    
    //Flag that the main algorithm has been trained in case we need to trigger any callbacks
    trained = true;

    //Every tree is trained on a bootstrapped dataset of the same size: the sample count times the weight, truncated
    const UINT datasetSize = (UINT)(M * bootstrappedDatasetWeight);
    
    //Train the random forest
    forest.reserve( forestSize );
    for(UINT i=0; i<forestSize; i++){
        
        //Get a balanced bootstrapped dataset
        ClassificationData data = trainingData.getBootstrappedDataset( datasetSize, true );
 
        //Configure a fresh tree from this forest's settings
        DecisionTree tree;
        tree.setDecisionTreeNode( *decisionTreeNode );
        tree.enableScaling( false ); //We have already scaled the training data so we do not need to scale it again
        tree.setTrainingMode( trainingMode );
        tree.setNumSplittingSteps( numRandomSplits );
        tree.setMinNumSamplesPerNode( minNumSamplesPerNode );
        tree.setMaxDepth( maxDepth );
        tree.enableNullRejection( useNullRejection );
        tree.setRemoveFeaturesAtEachSpilt( removeFeaturesAtEachSpilt );

        trainingLog << "Training forest " << i+1 << "/" << forestSize << "..." << endl;
        
        //Train this tree
        if( !tree.train( data ) ){
            errorLog << "train_(ClassificationData &labelledTrainingData) - Failed to train tree at forest index: " << i << endl;
            //Drop the partially built forest before reporting the failure
            clear();
            return false;
        }
        
        //Deep copy the tree into the forest
        forest.push_back( tree.deepCopyTree() );
    }

    return true;
}
Exemplo n.º 4
0
int main(int argc, const char * argv[])
{
    //Parse the data filename from the argument list
    if( argc != 2 ){
        cout << "Error: failed to parse data filename from command line. You should run this example with one argument pointing to the data filename!\n";
        return EXIT_FAILURE;
    }
    const string filename = argv[1];

    //Create a new DecisionTree instance
    DecisionTree dTree;
    
    //Set the node that the DecisionTree will use - different nodes may result in different decision boundaries
    //and some nodes may provide better accuracy than others on specific classification tasks
    //The current node options are:
    //- DecisionTreeClusterNode
    //- DecisionTreeThresholdNode
    dTree.setDecisionTreeNode( DecisionTreeClusterNode() );
    
    //Set the number of steps that will be used to choose the best splitting values
    //More steps will give you a better model, but will take longer to train
    dTree.setNumSplittingSteps( 1000 );
    
    //Set the maximum depth of the tree
    dTree.setMaxDepth( 10 );
    
    //Set the minimum number of samples allowed per node
    dTree.setMinNumSamplesPerNode( 10 );
    
    //Load some training data to train the classifier
    ClassificationData trainingData;
    
    if( !trainingData.load( filename ) ){
        cout << "Failed to load training data: " << filename << endl;
        return EXIT_FAILURE;
    }
    
    //Use 20% of the training dataset to create a test dataset
    ClassificationData testData = trainingData.split( 80 );
    
    //Train the classifier
    if( !dTree.train( trainingData ) ){
        cout << "Failed to train classifier!\n";
        return EXIT_FAILURE;
    }
    
    //Print the tree
    dTree.print();
    
    //Save the model to a file
    if( !dTree.save("DecisionTreeModel.grt") ){
        cout << "Failed to save the classifier model!\n";
        return EXIT_FAILURE;
    }
    
    //Load the model from a file
    if( !dTree.load("DecisionTreeModel.grt") ){
        cout << "Failed to load the classifier model!\n";
        return EXIT_FAILURE;
    }
    
    //Test the accuracy of the model on the test data
    double accuracy = 0;
    for(UINT i=0; i<testData.getNumSamples(); i++){
        //Get the i'th test sample
        UINT classLabel = testData[i].getClassLabel();
        VectorDouble inputVector = testData[i].getSample();
        
        //Perform a prediction using the classifier
        bool predictSuccess = dTree.predict( inputVector );
        
        if( !predictSuccess ){
            cout << "Failed to perform prediction for test sampel: " << i <<"\n";
            return EXIT_FAILURE;
        }
        
        //Get the predicted class label
        UINT predictedClassLabel = dTree.getPredictedClassLabel();
        VectorDouble classLikelihoods = dTree.getClassLikelihoods();
        VectorDouble classDistances = dTree.getClassDistances();
        
        //Update the accuracy
        if( classLabel == predictedClassLabel ) accuracy++;
        
        cout << "TestSample: " << i <<  " ClassLabel: " << classLabel << " PredictedClassLabel: " << predictedClassLabel << endl;
    }
    
    cout << "Test Accuracy: " << accuracy/double(testData.getNumSamples())*100.0 << "%" << endl;
    
    return EXIT_SUCCESS;
}