ClassificationData ClassificationData::getBootstrappedDataset(UINT numSamples,bool balanceDataset) const{
    
    Random rand;
    ClassificationData newDataset;
    newDataset.setNumDimensions( getNumDimensions() );
    newDataset.setAllowNullGestureClass( allowNullGestureClass );
    newDataset.setExternalRanges( externalRanges, useExternalRanges );
    
    if( numSamples == 0 ) numSamples = totalNumSamples;
    
    newDataset.reserve( numSamples );

    const UINT K = getNumClasses(); 
    
    //Add all the class labels to the new dataset to ensure the dataset has a list of all the labels
    for(UINT k=0; k<K; k++){
        newDataset.addClass( classTracker[k].classLabel );
    }

    if( balanceDataset ){
        //Group the sample indices by class
        Vector< Vector< UINT > > classIndexs( K );
        for(UINT i=0; i<totalNumSamples; i++){
            classIndexs[ getClassLabelIndexValue( data[i].getClassLabel() ) ].push_back( i );
        }

        //Work out how many samples to select for each class
        UINT numSamplesPerClass = (UINT)floor( numSamples / Float(K) );

        //Randomly select the training samples from each class
        UINT classIndex = 0;
        UINT classCounter = 0;
        UINT randomIndex = 0;
        for(UINT i=0; i<numSamples; i++){
            randomIndex = rand.getRandomNumberInt(0, (UINT)classIndexs[ classIndex ].size() );
            randomIndex = classIndexs[ classIndex ][ randomIndex ];
            newDataset.addSample(data[ randomIndex ].getClassLabel(), data[ randomIndex ].getSample());
            if( ++classCounter >= numSamplesPerClass && classIndex+1 < K ){ //move to the next class once numSamplesPerClass samples have been drawn
                classCounter = 0;
                classIndex++;
            }
        }

    }else{
        //Randomly select the training samples to add to the new data set
        UINT randomIndex;
        for(UINT i=0; i<numSamples; i++){
            randomIndex = rand.getRandomNumberInt(0, totalNumSamples);
            newDataset.addSample( data[randomIndex].getClassLabel(), data[randomIndex].getSample() );
        }
    }

    //Sort the class labels so they are in order
    newDataset.sortClassLabels();
    
    return newDataset;
}
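// Usage sketch (not part of the library): drawing a balanced bootstrap sample from a
// populated ClassificationData instance. The synthetic data setup below is illustrative.
#include <GRT/GRT.h>
#include <iostream>
using namespace GRT;

int main(){
    //Build a small synthetic dataset: 2 classes, 2 dimensions
    ClassificationData data;
    data.setNumDimensions( 2 );
    Random rand;
    VectorFloat sample( 2 );
    for(UINT i=0; i<100; i++){
        const UINT classLabel = (i % 2) + 1;
        sample[0] = rand.getRandomNumberUniform(0,1) + classLabel;
        sample[1] = rand.getRandomNumberUniform(0,1) + classLabel;
        data.addSample( classLabel, sample );
    }

    //Draw a balanced bootstrap of 50 samples (sampling with replacement)
    ClassificationData bootstrapped = data.getBootstrappedDataset( 50, true );

    std::cout << "bootstrapped samples: " << bootstrapped.getNumSamples() << std::endl;
    return 0;
}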
MatrixFloat ClassificationData::getClassMean() const{
	
	MatrixFloat mean(getNumClasses(),numDimensions);
	VectorFloat counter(getNumClasses(),0);
	
	mean.setAllValues( 0 );
	
	for(UINT i=0; i<totalNumSamples; i++){
		UINT classIndex = getClassLabelIndexValue( data[i].getClassLabel() );
		for(UINT j=0; j<numDimensions; j++){
			mean[classIndex][j] += data[i][j];
		}
		counter[ classIndex ]++;
	}
	
	for(UINT k=0; k<getNumClasses(); k++){
		for(UINT j=0; j<numDimensions; j++){
			mean[k][j] = counter[k] > 0 ? mean[k][j]/counter[k] : 0;
		}
	}
	
	return mean;
}
MatrixFloat ClassificationData::getClassStdDev() const{

	MatrixFloat mean = getClassMean();
	MatrixFloat stdDev(getNumClasses(),numDimensions);
	VectorFloat counter(getNumClasses(),0);
	
	stdDev.setAllValues( 0 );
	
	for(UINT i=0; i<totalNumSamples; i++){
		UINT classIndex = getClassLabelIndexValue( data[i].getClassLabel() );
		for(UINT j=0; j<numDimensions; j++){
			stdDev[classIndex][j] += SQR(data[i][j]-mean[classIndex][j]);
		}
		counter[ classIndex  ]++;
	}
	
	for(UINT k=0; k<getNumClasses(); k++){
		for(UINT j=0; j<numDimensions; j++){
			stdDev[k][j] = counter[k] > 1 ? sqrt( stdDev[k][j] / Float(counter[k]-1) ) : 0; //guard against classes with fewer than 2 samples
		}
	}
	
	return stdDev;
}
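// Usage sketch (not part of the library): reading back the per-class statistics computed
// above, continuing from the data instance built in the first sketch. Both matrices have
// one row per class and one column per input dimension.
MatrixFloat classMean = data.getClassMean();
MatrixFloat classStdDev = data.getClassStdDev();

for(UINT k=0; k<data.getNumClasses(); k++){
    for(UINT j=0; j<data.getNumDimensions(); j++){
        std::cout << "class " << k << " dim " << j
                  << " mean: " << classMean[k][j]
                  << " stddev: " << classStdDev[k][j] << std::endl;
    }
}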
bool ParticleClassifier::predict_( VectorDouble &inputVector ){

    if( !trained ){
        errorLog << "predict_(VectorDouble &inputVector) - The model has not been trained!" << endl;
        return false;
    }
    
    if( numInputDimensions != inputVector.size() ){
        errorLog << "predict_(VectorDouble &inputVector) - The number of features in the model " << numInputDimensions << " does not match that of the input vector " << inputVector.size() << endl;
        return false;
    }
    
    //Scale the input data if needed
    if( useScaling ){
        for(unsigned int j=0; j<numInputDimensions; j++){
            inputVector[j] = scale(inputVector[j],ranges[j].minValue,ranges[j].maxValue,0,1);
        }
    }
    
    predictedClassLabel = 0;
    maxLikelihood = 0;
    std::fill(classLikelihoods.begin(),classLikelihoods.end(),0);
    std::fill(classDistances.begin(),classDistances.end(),0);

    //Update the particle filter
    particleFilter.filter( inputVector );
    
    //Count the number of particles per class
    unsigned int gestureTemplate = 0;
    unsigned int gestureLabel = 0;
    unsigned int gestureIndex = 0;
    for(unsigned int i=0; i<numParticles; i++){
        gestureTemplate = (unsigned int)particleFilter[i].x[0]; //The first element in the state vector is the gesture template index
        gestureLabel = particleFilter.gestureTemplates[ gestureTemplate ].classLabel;
        gestureIndex = getClassLabelIndexValue( gestureLabel );
        
        classDistances[ gestureIndex ] += particleFilter[i].w;
    }
    
    bool rejectPrediction = false;
    if( useNullRejection ){
        if( particleFilter.getWeightSum() < 1.0e-5 ){
            rejectPrediction = true;
        }
    }
    
    //Compute the class likelihoods
    for(unsigned int i=0; i<numClasses; i++){

        classLikelihoods[ i ] = rejectPrediction ? 0 : classDistances[i];

        if( classLikelihoods[i] > maxLikelihood ){
            predictedClassLabel = classLabels[i];
            maxLikelihood = classLikelihoods[i];
        }
    }
    
    //Estimate the phase
    phase = particleFilter.getStateEstimation()[1]; //The 2nd element in the state vector is the estimated phase
    
    return true;

}
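// Usage sketch (not part of the library): a typical real-time loop around the prediction
// method above. The public predict() wrapper is assumed from the GRT Classifier interface,
// and readSensorFrame() is a hypothetical data source.
ParticleClassifier classifier;
//... train the classifier on a TimeSeriesClassificationData set first ...

VectorDouble frame( classifier.getNumInputDimensions() );
while( readSensorFrame( frame ) ){
    if( classifier.predict( frame ) ){ //calls predict_ internally
        const UINT label = classifier.getPredictedClassLabel();
        //the phase estimate indicates progress through the matched gesture template
    }
}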
bool DecisionTreeClusterNode::computeError( const ClassificationData &trainingData, MatrixFloat &data, const Vector< UINT > &classLabels, Vector< MinMax > ranges, Vector< UINT > groupIndex, const UINT featureIndex, Float &threshold, Float &error ){

    error = 0;
    threshold = 0;

    const UINT M = trainingData.getNumSamples();
    const UINT K = (UINT)classLabels.size();

    Float giniIndexL = 0;
    Float giniIndexR = 0;
    Float weightL = 0;
    Float weightR = 0;
    VectorFloat groupCounter(2,0);
    MatrixFloat classProbabilities(K,2);

    //Use this data to train a KMeans cluster with 2 clusters
    KMeans kmeans;
    kmeans.setNumClusters( 2 );
    kmeans.setComputeTheta( true );
    kmeans.setMinChange( 1.0e-5 );
    kmeans.setMinNumEpochs( 1 );
    kmeans.setMaxNumEpochs( 100 );

    //Disable the logging to clean things up
    kmeans.setTrainingLoggingEnabled( false );

    if( !kmeans.train_( data ) ){
        errorLog << __GRT_LOG__ << " Failed to train KMeans model for feature: " << featureIndex << std::endl;
        return false;
    }

    //Set the split threshold as the mid point between the two clusters
    const MatrixFloat &clusters = kmeans.getClusters();
    threshold = 0;
    for(UINT i=0; i<clusters.getNumRows(); i++){
        threshold += clusters[i][0];
    }
    threshold /= clusters.getNumRows();

    //Iterate over each sample and work out if it should be in the lhs (0) or rhs (1) group based on the current threshold
    groupCounter[0] = groupCounter[1] = 0;
    classProbabilities.setAllValues(0);
    for(UINT i=0; i<M; i++){
        groupIndex[i] = trainingData[ i ][ featureIndex ] >= threshold ? 1 : 0;
        groupCounter[ groupIndex[i] ]++;
        classProbabilities[ getClassLabelIndexValue(trainingData[i].getClassLabel(),classLabels) ][ groupIndex[i] ]++;
    }

    //Compute the class probabilities for the lhs group and rhs group
    for(UINT k=0; k<K; k++){
        classProbabilities[k][0] = groupCounter[0]>0 ? classProbabilities[k][0]/groupCounter[0] : 0;
        classProbabilities[k][1] = groupCounter[1]>0 ? classProbabilities[k][1]/groupCounter[1] : 0;
    }

    //Compute the Gini index for the lhs and rhs groups
    giniIndexL = giniIndexR = 0;
    for(UINT k=0; k<K; k++){
        giniIndexL += classProbabilities[k][0] * (1.0-classProbabilities[k][0]);
        giniIndexR += classProbabilities[k][1] * (1.0-classProbabilities[k][1]);
    }
    weightL = groupCounter[0]/M;
    weightR = groupCounter[1]/M;
    error = (giniIndexL*weightL) + (giniIndexR*weightR);

    return true;
}
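// Standalone sketch (not part of the library) of the split score computed above: the
// size-weighted Gini impurity of the two child groups produced by a threshold.
// counts[k][g] = number of samples of class k that fell into group g (0 = lhs, 1 = rhs).
#include <vector>
#include <cstddef>

double weightedGiniImpurity( const std::vector< std::vector<double> > &counts ){
    double n[2] = {0,0};
    for(std::size_t k=0; k<counts.size(); k++){ n[0] += counts[k][0]; n[1] += counts[k][1]; }
    const double total = n[0] + n[1];
    if( total == 0 ) return 0;
    double gini[2] = {0,0};
    for(std::size_t k=0; k<counts.size(); k++){
        for(int g=0; g<2; g++){
            const double p = n[g] > 0 ? counts[k][g] / n[g] : 0; //class probability within group g
            gini[g] += p * (1.0 - p);
        }
    }
    //Weight each group's impurity by its share of the samples
    return (gini[0] * n[0] / total) + (gini[1] * n[1] / total);
}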
bool TimeSeriesClassificationData::spiltDataIntoKFolds(const UINT K,const bool useStratifiedSampling){

    crossValidationSetup = false;
    crossValidationIndexs.clear();

    //K cannot be zero
    if( K == 0 ){
        errorLog << "spiltDataIntoKFolds(UINT K,bool useStratifiedSampling) - K cannot be zero!" << std::endl;
        return false;
    }

    //K cannot be larger than the number of examples
    if( K > totalNumSamples ){
        errorLog << "spiltDataIntoKFolds(UINT K,bool useStratifiedSampling) - K cannot be larger than the total number of samples in the dataset!" << std::endl;
        return false;
    }

    //K cannot be larger than the number of examples in a specific class if the stratified sampling option is true
    if( useStratifiedSampling ){
        for(UINT c=0; c<classTracker.size(); c++){
            if( K > classTracker[c].counter ){
                errorLog << "spiltDataIntoKFolds(UINT K,bool useStratifiedSampling) - K cannot be larger than the number of samples in any given class!" << std::endl;
                return false;
            }
        }
    }

    //Setup the dataset for k-fold cross validation
    kFoldValue = K;
    Vector< UINT > indexs( totalNumSamples );

    //Work out how many samples are in each fold; the last fold might have more samples than the others
    UINT numSamplesPerFold = (UINT) floor( totalNumSamples/Float(K) );

    //Resize the cross validation indices buffer
    crossValidationIndexs.resize( K );

    //Create the random partition indices
    Random random;
    UINT randomIndex = 0;

    if( useStratifiedSampling ){
        //Break the data into separate classes
        Vector< Vector< UINT > > classData( getNumClasses() );

        //Add the indices to their respective classes
        for(UINT i=0; i<totalNumSamples; i++){
            classData[ getClassLabelIndexValue( data[i].getClassLabel() ) ].push_back( i );
        }

        //Randomize the order of the indices in each of the class index buffers
        for(UINT c=0; c<getNumClasses(); c++){
            UINT numSamples = (UINT)classData[c].size();
            for(UINT x=0; x<numSamples; x++){
                //Pick a random index
                randomIndex = random.getRandomNumberInt(0,numSamples);

                //Swap the indices
                SWAP( classData[c][ x ] , classData[c][ randomIndex ] );
            }
        }

        //Loop over each of the classes and add the data equally to each of the k folds until there is no data left
        Vector< UINT >::iterator iter;
        for(UINT c=0; c<getNumClasses(); c++){
            iter = classData[ c ].begin();
            UINT k = 0;
            while( iter != classData[c].end() ){
                crossValidationIndexs[ k ].push_back( *iter );
                iter++;
                k++;
                k = k % K;
            }
        }

    }else{
        //Randomize the order of the data
        for(UINT i=0; i<totalNumSamples; i++) indexs[i] = i;
        for(UINT x=0; x<totalNumSamples; x++){
            //Pick a random index
            randomIndex = random.getRandomNumberInt(0,totalNumSamples);

            //Swap the indices
            SWAP( indexs[ x ] , indexs[ randomIndex ] );
        }

        UINT counter = 0;
        UINT foldIndex = 0;
        for(UINT i=0; i<totalNumSamples; i++){
            //Add the index to the current fold
            crossValidationIndexs[ foldIndex ].push_back( indexs[i] );

            //Move to the next fold if ready
            if( ++counter == numSamplesPerFold && foldIndex < K-1 ){
                foldIndex++;
                counter = 0;
            }
        }
    }

    crossValidationSetup = true;
    return true;

}
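// Usage sketch (not part of the library): the k-fold loop this method sets up. The
// getTrainingFoldData/getTestFoldData accessors are assumed from the GRT dataset API.
TimeSeriesClassificationData dataset;
//... populate dataset ...
const UINT K = 10;
if( dataset.spiltDataIntoKFolds( K, true ) ){ //note: the API keeps the "spilt" spelling
    for(UINT fold=0; fold<K; fold++){
        TimeSeriesClassificationData trainData = dataset.getTrainingFoldData( fold );
        TimeSeriesClassificationData testData = dataset.getTestFoldData( fold );
        //... train a model on trainData and evaluate it on testData ...
    }
}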
TimeSeriesClassificationData TimeSeriesClassificationData::split(const UINT trainingSizePercentage,const bool useStratifiedSampling){

    //Partitions the dataset into a training dataset (which is kept by this instance of the TimeSeriesClassificationData) and
    //a testing/validation dataset (which is returned as a new instance of TimeSeriesClassificationData). The trainingSizePercentage
    //therefore sets the size of the data which remains in this instance; the remaining percentage of data is added to
    //the testing/validation dataset

    //The dataset has changed, so flag that any previous cross validation setup is no longer valid
    crossValidationSetup = false;
    crossValidationIndexs.clear();

    TimeSeriesClassificationData trainingSet(numDimensions);
    TimeSeriesClassificationData testSet(numDimensions);
    trainingSet.setAllowNullGestureClass(allowNullGestureClass);
    testSet.setAllowNullGestureClass(allowNullGestureClass);
    Vector< UINT > indexs( totalNumSamples );

    //Create the random partition indices
    Random random;
    UINT randomIndex = 0;

    if( useStratifiedSampling ){
        //Break the data into separate classes
        Vector< Vector< UINT > > classData( getNumClasses() );

        //Add the indices to their respective classes
        for(UINT i=0; i<totalNumSamples; i++){
            classData[ getClassLabelIndexValue( data[i].getClassLabel() ) ].push_back( i );
        }

        //Randomize the order of the indices in each of the class index buffers
        for(UINT k=0; k<getNumClasses(); k++){
            UINT numSamples = (UINT)classData[k].size();
            for(UINT x=0; x<numSamples; x++){
                //Pick a random index
                randomIndex = random.getRandomNumberInt(0,numSamples);

                //Swap the indices
                SWAP( classData[k][ x ] ,classData[k][ randomIndex ] );
            }
        }

        //Loop over each class and add the data to the trainingSet and testSet
        for(UINT k=0; k<getNumClasses(); k++){
            UINT numTrainingExamples = (UINT) floor( Float(classData[k].size()) / 100.0 * Float(trainingSizePercentage) );

            //Add the data to the training and test sets
            for(UINT i=0; i<numTrainingExamples; i++){
                trainingSet.addSample( data[ classData[k][i] ].getClassLabel(), data[ classData[k][i] ].getData() );
            }
            for(UINT i=numTrainingExamples; i<classData[k].size(); i++){
                testSet.addSample( data[ classData[k][i] ].getClassLabel(), data[ classData[k][i] ].getData() );
            }
        }

        //Overwrite the training data in this instance with the training data of the trainingSet
        data = trainingSet.getClassificationData();
        totalNumSamples = trainingSet.getNumSamples();
    }else{

        const UINT numTrainingExamples = (UINT) floor( Float(totalNumSamples) / 100.0 * Float(trainingSizePercentage) );

        //Create the random partition indices
        for(UINT i=0; i<totalNumSamples; i++) indexs[i] = i;
        for(UINT x=0; x<totalNumSamples; x++){
            //Pick a random index
            randomIndex = random.getRandomNumberInt(0,totalNumSamples);

            //Swap the indices
            SWAP( indexs[ x ] , indexs[ randomIndex ] );
        }

        //Add the data to the training and test sets
        for(UINT i=0; i<numTrainingExamples; i++){
            trainingSet.addSample( data[ indexs[i] ].getClassLabel(), data[ indexs[i] ].getData() );
        }
        for(UINT i=numTrainingExamples; i<totalNumSamples; i++){
            testSet.addSample( data[ indexs[i] ].getClassLabel(), data[ indexs[i] ].getData() );
        }

        //Overwrite the training data in this instance with the training data of the trainingSet
        data = trainingSet.getClassificationData();
        totalNumSamples = trainingSet.getNumSamples();
    }

    return testSet;
}
bool BAG::predict_(VectorDouble &inputVector){
    
    if( !trained ){
        errorLog << "predict_(VectorDouble &inputVector) - Model Not Trained!" << endl;
        return false;
    }
    
    predictedClassLabel = 0;
    maxLikelihood = -10000;

    if( inputVector.size() != numInputDimensions ){
        errorLog << "predict_(VectorDouble &inputVector) - The size of the input vector (" << inputVector.size() << ") does not match the num features in the model (" << numInputDimensions << ")!" << endl;
        return false;
    }
    
    if( useScaling ){
        for(UINT n=0; n<numInputDimensions; n++){
            inputVector[n] = scale(inputVector[n], ranges[n].minValue, ranges[n].maxValue, 0, 1);
        }
    }
    
    if( classLikelihoods.size() != numClasses ) classLikelihoods.resize(numClasses);
    if( classDistances.size() != numClasses ) classDistances.resize(numClasses);
    
    //Reset the likelihoods and distances
    for(UINT k=0; k<numClasses; k++){
        classLikelihoods[k] = 0;
        classDistances[k] = 0;
    }
    
    //Run the prediction for each classifier
    double sum = 0;
    UINT ensembleSize = (UINT)ensemble.size();
    for(UINT i=0; i<ensembleSize; i++){
        
        if( !ensemble[i]->predict(inputVector) ){
            errorLog << "predict_(VectorDouble &inputVector) - The " << i << " classifier in the ensemble failed prediction!" << endl;
            return false;
        }
        
        classLikelihoods[ getClassLabelIndexValue( ensemble[i]->getPredictedClassLabel() ) ] += weights[i];
        classDistances[ getClassLabelIndexValue( ensemble[i]->getPredictedClassLabel() ) ] += ensemble[i]->getMaximumLikelihood() * weights[i];
        
        sum += weights[i];
    }
    
    //Set the predicted class label as the class with the highest weighted vote
    double maxCount = 0;
    UINT maxIndex = 0;
    for(UINT i=0; i<numClasses; i++){
        if( classLikelihoods[i] > maxCount ){
            maxIndex = i;
            maxCount = classLikelihoods[i];
        }
        classLikelihoods[i] /= sum;
        classDistances[i] /= double(ensembleSize);
    }
    
    predictedClassLabel = classLabels[ maxIndex ];
    maxLikelihood = classLikelihoods[ maxIndex ];
    
    return true;
}
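// Usage sketch (not part of the library): building and running a BAG ensemble. The
// addClassifierToEnsemble call is assumed from the GRT BAG API; trainingData is a
// populated ClassificationData instance.
BAG bag;
bag.addClassifierToEnsemble( DecisionTree() ); //assumed API; ensemble weights default to 1
bag.addClassifierToEnsemble( KNN() );
bag.addClassifierToEnsemble( MinDist() );

if( bag.train( trainingData ) ){
    VectorDouble query( bag.getNumInputDimensions() ); //fill with a live sample to classify
    if( bag.predict( query ) ){
        std::cout << "predicted: " << bag.getPredictedClassLabel()
                  << " likelihood: " << bag.getMaximumLikelihood() << std::endl;
    }
}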
ClassificationData ClassificationData::split(const UINT trainingSizePercentage,const bool useStratifiedSampling){

    //Partitions the dataset into a training dataset (which is kept by this instance of the ClassificationData) and
    //a testing/validation dataset (which is returned as a new instance of ClassificationData). The trainingSizePercentage
    //therefore sets the size of the data which remains in this instance; the remaining percentage of data is added to
    //the testing/validation dataset

    //The dataset has changed, so flag that any previous cross validation setup is no longer valid
    crossValidationSetup = false;
    crossValidationIndexs.clear();

    ClassificationData trainingSet(numDimensions);
    ClassificationData testSet(numDimensions);
    trainingSet.setAllowNullGestureClass( allowNullGestureClass );
    testSet.setAllowNullGestureClass( allowNullGestureClass );

    const UINT K = getNumClasses();

    if( useStratifiedSampling ){
        //Break the data into separate classes
        Vector< Vector< UINT > > classData( K );

        //Add the indices to their respective classes
        for(UINT i=0; i<totalNumSamples; i++){
            classData[ getClassLabelIndexValue( data[i].getClassLabel() ) ].push_back( i );
        }

        //Randomize the order of the indices in each of the class index buffers
        for(UINT k=0; k<K; k++){
            std::random_shuffle(classData[k].begin(), classData[k].end());
        }
        
        //Reserve the memory
        UINT numTrainingSamples = 0;
        UINT numTestSamples = 0;
        
        for(UINT k=0; k<K; k++){
            UINT numTrainingExamples = (UINT) floor( Float(classData[k].size()) / 100.0 * Float(trainingSizePercentage) );
            UINT numTestExamples = ((UINT)classData[k].size())-numTrainingExamples;
            numTrainingSamples += numTrainingExamples;
            numTestSamples += numTestExamples;
        }
        
        trainingSet.reserve( numTrainingSamples );
        testSet.reserve( numTestSamples );

        //Loop over each class and add the data to the trainingSet and testSet
        for(UINT k=0; k<K; k++){
            UINT numTrainingExamples = (UINT) floor( Float(classData[k].getSize()) / 100.0 * Float(trainingSizePercentage) );

            //Add the data to the training and test sets
            for(UINT i=0; i<numTrainingExamples; i++){
                trainingSet.addSample( data[ classData[k][i] ].getClassLabel(), data[ classData[k][i] ].getSample() );
            }
            for(UINT i=numTrainingExamples; i<classData[k].getSize(); i++){
                testSet.addSample( data[ classData[k][i] ].getClassLabel(), data[ classData[k][i] ].getSample() );
            }
        }
    }else{

        const UINT numTrainingExamples = (UINT) floor( Float(totalNumSamples) / 100.0 * Float(trainingSizePercentage) );

        //Create the random partition indices
        Vector< UINT > indexs( totalNumSamples );
        for(UINT i=0; i<totalNumSamples; i++) indexs[i] = i;
        std::random_shuffle(indexs.begin(), indexs.end());
        
        //Reserve the memory
        trainingSet.reserve( numTrainingExamples );
        testSet.reserve( totalNumSamples-numTrainingExamples );

        //Add the data to the training and test sets
        for(UINT i=0; i<numTrainingExamples; i++){
            trainingSet.addSample( data[ indexs[i] ].getClassLabel(), data[ indexs[i] ].getSample() );
        }
        for(UINT i=numTrainingExamples; i<totalNumSamples; i++){
            testSet.addSample( data[ indexs[i] ].getClassLabel(), data[ indexs[i] ].getSample() );
        }
    }

    //Overwrite the training data in this instance with the training data of the trainingSet
    *this = trainingSet;

    //Sort the class labels in this dataset
    sortClassLabels();

    //Sort the class labels of the test dataset
    testSet.sortClassLabels();

    return testSet;
}
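// Usage sketch (not part of the library): an 80/20 stratified train/test partition using
// the split method above. Stratified sampling preserves the per-class proportions in both
// partitions.
ClassificationData dataset;
//... populate dataset ...

//Keep 80% for training in this instance; the remaining 20% is returned as the test set
ClassificationData testData = dataset.split( 80, true );

std::cout << "training samples: " << dataset.getNumSamples()
          << " test samples: " << testData.getNumSamples() << std::endl;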
bool DecisionTreeThresholdNode::computeBestSplitBestRandomSplit( const UINT &numSplittingSteps, const ClassificationData &trainingData, const Vector< UINT > &features, const Vector< UINT > &classLabels, UINT &featureIndex, Float &minError ){
    
    const UINT M = trainingData.getNumSamples();
    const UINT N = (UINT)features.size();
    const UINT K = (UINT)classLabels.size();
    
    if( N == 0 ) return false;
    
    minError = grt_numeric_limits< Float >::max();
    UINT bestFeatureIndex = 0;
    Float bestThreshold = 0;
    Float error = 0;
    Float giniIndexL = 0;
    Float giniIndexR = 0;
    Float weightL = 0;
    Float weightR = 0;
    Random random;
    Vector< UINT > groupIndex(M);
    VectorFloat groupCounter(2,0);
    
    MatrixFloat classProbabilities(K,2);

    //Randomly sample feature/threshold pairs and keep the best split point
    UINT m,n;
    for(m=0; m<numSplittingSteps; m++){
        //Choose a random feature
        n = random.getRandomNumberInt(0,N);
        featureIndex = features[n];
        
        //Randomly choose the threshold; it is based on a randomly selected sample with some random scaling
        threshold = trainingData[ random.getRandomNumberInt(0,M) ][ featureIndex ] * random.getRandomNumberUniform(0.8,1.2);
        
        //Iterate over each sample and work out if it should be in the lhs (0) or rhs (1) group
        groupCounter[0] = groupCounter[1] = 0;
        classProbabilities.setAllValues(0);
        for(UINT i=0; i<M; i++){
            groupIndex[i] = trainingData[ i ][ featureIndex ] >= threshold ? 1 : 0;
            groupCounter[ groupIndex[i] ]++;
            classProbabilities[ getClassLabelIndexValue(trainingData[i].getClassLabel(),classLabels) ][ groupIndex[i] ]++;
        }
        
        //Compute the class probabilities for the lhs group and rhs group
        for(UINT k=0; k<K; k++){
            classProbabilities[k][0] = groupCounter[0]>0 ? classProbabilities[k][0]/groupCounter[0] : 0;
            classProbabilities[k][1] = groupCounter[1]>0 ? classProbabilities[k][1]/groupCounter[1] : 0;
        }
        
        //Compute the Gini index for the lhs and rhs groups
        giniIndexL = giniIndexR = 0;
        for(UINT k=0; k<K; k++){
            giniIndexL += classProbabilities[k][0] * (1.0-classProbabilities[k][0]);
            giniIndexR += classProbabilities[k][1] * (1.0-classProbabilities[k][1]);
        }
        weightL = groupCounter[0]/M;
        weightR = groupCounter[1]/M;
        error = (giniIndexL*weightL) + (giniIndexR*weightR);
        
        //Store the best threshold and feature index
        if( error < minError ){
            minError = error;
            bestThreshold = threshold;
            bestFeatureIndex = featureIndex;
        }
    }
    
    //Set the best feature index that will be returned to the DecisionTree that called this function
    featureIndex = bestFeatureIndex;
    
    //Store the node size, feature index, best threshold and class probabilities for this node
    set(M,featureIndex,bestThreshold,trainingData.getClassProbabilities(classLabels));
    
    return true;
}
bool DecisionTreeThresholdNode::computeBestSplitBestIterativeSplit( const UINT &numSplittingSteps, const ClassificationData &trainingData, const Vector< UINT > &features, const Vector< UINT > &classLabels, UINT &featureIndex, Float &minError ){
    
    const UINT M = trainingData.getNumSamples();
    const UINT N = features.getSize();
    const UINT K = classLabels.getSize();
    
    if( N == 0 ) return false;
    
    minError = grt_numeric_limits< Float >::max();
    UINT bestFeatureIndex = 0;
    Float bestThreshold = 0;
    Float error = 0;
    Float minRange = 0;
    Float maxRange = 0;
    Float step = 0;
    Float giniIndexL = 0;
    Float giniIndexR = 0;
    Float weightL = 0;
    Float weightR = 0;
    Vector< UINT > groupIndex(M);
    VectorFloat groupCounter(2,0);
    Vector< MinMax > ranges = trainingData.getRanges();
    
    MatrixFloat classProbabilities(K,2);
    
    //Loop over each feature and try to find the best split point
    for(UINT n=0; n<N; n++){
        featureIndex = features[n];
        minRange = ranges[ featureIndex ].minValue; //index ranges by the actual feature, as features may be a subset of all dimensions
        maxRange = ranges[ featureIndex ].maxValue;
        step = (maxRange-minRange)/Float(numSplittingSteps);
        threshold = minRange;
        while( threshold <= maxRange ){
            
            //Iterate over each sample and work out if it should be in the lhs (0) or rhs (1) group
            groupCounter[0] = groupCounter[1] = 0;
            classProbabilities.setAllValues(0);
            for(UINT i=0; i<M; i++){
                groupIndex[i] = trainingData[ i ][ featureIndex ] >= threshold ? 1 : 0;
                groupCounter[ groupIndex[i] ]++;
                classProbabilities[ getClassLabelIndexValue(trainingData[i].getClassLabel(),classLabels) ][ groupIndex[i] ]++;
            }
            
            //Compute the class probabilities for the lhs group and rhs group
            for(UINT k=0; k<K; k++){
                classProbabilities[k][0] = groupCounter[0]>0 ? classProbabilities[k][0]/groupCounter[0] : 0;
                classProbabilities[k][1] = groupCounter[1]>0 ? classProbabilities[k][1]/groupCounter[1] : 0;
            }
            
            //Compute the Gini index for the lhs and rhs groups
            giniIndexL = giniIndexR = 0;
            for(UINT k=0; k<K; k++){
                giniIndexL += classProbabilities[k][0] * (1.0-classProbabilities[k][0]);
                giniIndexR += classProbabilities[k][1] * (1.0-classProbabilities[k][1]);
            }
            weightL = groupCounter[0]/M;
            weightR = groupCounter[1]/M;
            error = (giniIndexL*weightL) + (giniIndexR*weightR);
            
            //Store the best threshold and feature index
            if( error < minError ){
                minError = error;
                bestThreshold = threshold;
                bestFeatureIndex = featureIndex;
            }
            
            //Update the threshold
            threshold += step;
        }
    }
    
    //Set the best feature index that will be returned to the DecisionTree that called this function
    featureIndex = bestFeatureIndex;
    
    //Store the node size, feature index, best threshold and class probabilities for this node
    set(M,featureIndex,bestThreshold,trainingData.getClassProbabilities(classLabels));
    
    return true;
}
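// For reference, both threshold-search strategies above minimize the same split score:
// the size-weighted Gini impurity of the two child groups,
//   error = (|L|/M) * sum_k pL_k*(1-pL_k) + (|R|/M) * sum_k pR_k*(1-pR_k)
// where pL_k and pR_k are the class-k proportions in the left and right groups.
// The random version trades split quality for speed and extra decorrelation between
// trees (similar in spirit to Extremely Randomized Trees), while the iterative version
// scans numSplittingSteps evenly spaced thresholds per feature.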