예제 #1
0
bool DecisionTreeClusterNode::computeError( const ClassificationData &trainingData, MatrixFloat &data, const Vector< UINT > &classLabels, Vector< MinMax > ranges, Vector< UINT > groupIndex, const UINT featureIndex, Float &threshold, Float &error ){

    error = 0;
    threshold = 0;

    const UINT M = trainingData.getNumSamples();
    const UINT K = (UINT)classLabels.size();

    Float giniIndexL = 0;
    Float giniIndexR = 0;
    Float weightL = 0;
    Float weightR = 0;
    VectorFloat groupCounter(2,0);
    MatrixFloat classProbabilities(K,2);

    //Use this data to train a KMeans cluster with 2 clusters
    KMeans kmeans;
    kmeans.setNumClusters( 2 );
    kmeans.setComputeTheta( true );
    kmeans.setMinChange( 1.0e-5 );
    kmeans.setMinNumEpochs( 1 );
    kmeans.setMaxNumEpochs( 100 );

    //Disable the logging to clean things up
    kmeans.setTrainingLoggingEnabled( false );

    if( !kmeans.train_( data ) ){
        errorLog << __GRT_LOG__ << " Failed to train KMeans model for feature: " << featureIndex << std::endl;
        return false;
    }

    //Set the split threshold as the mid point between the two clusters
    const MatrixFloat &clusters = kmeans.getClusters();
    threshold = 0;
    for(UINT i=0; i<clusters.getNumRows(); i++){
        threshold += clusters[i][0];
    }
    threshold /= clusters.getNumRows();

    //Iterate over each sample and work out if it should be in the lhs (0) or rhs (1) group based on the current threshold
    groupCounter[0] = groupCounter[1] = 0;
    classProbabilities.setAllValues(0);
    for(UINT i=0; i<M; i++){
        groupIndex[i] = trainingData[ i ][ featureIndex ] >= threshold ? 1 : 0;
        groupCounter[ groupIndex[i] ]++;
        classProbabilities[ getClassLabelIndexValue(trainingData[i].getClassLabel(),classLabels) ][ groupIndex[i] ]++;
    }

    //Compute the class probabilities for the lhs group and rhs group
    for(UINT k=0; k<K; k++){
        classProbabilities[k][0] = groupCounter[0]>0 ? classProbabilities[k][0]/groupCounter[0] : 0;
        classProbabilities[k][1] = groupCounter[1]>0 ? classProbabilities[k][1]/groupCounter[1] : 0;
    }

    //Compute the Gini index for the lhs and rhs groups
    giniIndexL = giniIndexR = 0;
    for(UINT k=0; k<K; k++){
        giniIndexL += classProbabilities[k][0] * (1.0-classProbabilities[k][0]);
        giniIndexR += classProbabilities[k][1] * (1.0-classProbabilities[k][1]);
    }
    weightL = groupCounter[0]/M;
    weightR = groupCounter[1]/M;
    error = (giniIndexL*weightL) + (giniIndexR*weightR);

    return true;
}