コード例 #1
0
ファイル: RandomForests.cpp プロジェクト: hal2001/grt
bool RandomForests::combineModels( const RandomForests &forest ){

    if( !getTrained() ){
        errorLog << "combineModels( const RandomForests &forest ) - This instance has not been trained!" << endl;
        return false;
    }

    if( !forest.getTrained() ){
        errorLog << "combineModels( const RandomForests &forest ) - This external forest instance has not been trained!" << endl;
        return false;
    }

    if( this->getNumInputDimensions() != forest.getNumInputDimensions() ) {
        errorLog << "combineModels( const RandomForests &forest ) - The number of input dimensions of the external forest (";
        errorLog << forest.getNumInputDimensions() << ") does not match the number of input dimensions of this instance (";
        errorLog << this->getNumInputDimensions() << ")!" << endl;
        return false;
    }

    //Add the trees in the other forest to this model
    DecisionTreeNode *node;
    for(UINT i=0; i<forest.getForestSize(); i++){
        node = forest.getTree(i);
        if( node ){
            this->forest.push_back( node->deepCopy() );
            forestSize++;
        }
    }

    return true;
}
コード例 #2
0
ファイル: grt-rf-tool.cpp プロジェクト: nickgillian/grt
/**
 * Command-line action: loads a saved GestureRecognitionPipeline whose classifier is a
 * trained RandomForests model and writes the model's feature weights to a results file.
 *
 * Options read from the parser:
 *  - "model-filename"  (required) the pipeline model file to load
 *  - "filename"        (required) the file the results are written to
 *  - "combine-weights" (optional) selects the output format below
 *
 * If combine-weights is true, the vector returned by RandomForests::getFeatureWeights()
 * is written as plain text, one value per line. Otherwise a K x N matrix is saved via
 * MatrixDouble::save, where K is the number of trees and N the number of input dimensions.
 * Row i holds the feature weights of tree i, normalised to sum to 1. A row whose raw
 * weights sum to zero (or less) is written as all zeros.
 *
 * @param parser the parsed command line
 * @return true on success; false if a required option is missing, the model cannot be
 *         loaded, it does not contain a trained RandomForests classifier, a tree is missing,
 *         or the results cannot be written
 */
bool computeFeatureWeights( CommandLineParser &parser ){

    infoLog << "Computing feature weights..." << endl;

    string resultsFilename = "";
    string modelFilename = "";
    bool combineWeights = false;

    //Get the model filename
    if( !parser.get("model-filename",modelFilename) ){
        errorLog << "Failed to parse filename from command line! You can set the model filename using the --model." << endl;
        printUsage();
        return false;
    }

    //Get the results filename
    if( !parser.get("filename",resultsFilename) ){
        errorLog << "Failed to parse results filename from command line! You can set the results filename using the -f." << endl;
        printUsage();
        return false;
    }

    //Optional flag choosing combined (forest-wide) vs per-tree weights. The return value is
    //ignored on purpose; presumably combineWeights stays false when the option is absent.
    parser.get("combine-weights",combineWeights);

    //Load the model
    GestureRecognitionPipeline pipeline;

    if( !pipeline.load( modelFilename ) ){
        errorLog << "Failed to load model from file: " << modelFilename << endl;
        printUsage();
        return false;
    }

    //Make sure the pipeline contains a random forest model and that it is trained
    RandomForests *forest = pipeline.getClassifier< RandomForests >();

    if( !forest ){
        errorLog << "Model loaded, but the pipeline does not contain a RandomForests classifier!" << endl;
        printUsage();
        return false;
    }

    if( !forest->getTrained() ){
        errorLog << "Model loaded, but the RandomForests classifier is not trained!" << endl;
        printUsage();
        return false;
    }

    //Compute the feature weights
    if( combineWeights ){
        VectorFloat weights = forest->getFeatureWeights();
        if( weights.getSize() == 0 ){
            errorLog << "Failed to compute feature weights!" << endl;
            printUsage();
            return false;
        }

        //Save the results to a file, one weight per line
        fstream file;
        file.open( resultsFilename.c_str(), fstream::out );
        if( !file.is_open() ){
            errorLog << "Failed to open results file: " << resultsFilename << endl;
            return false;
        }

        const unsigned int N = weights.getSize();
        for(unsigned int i=0; i<N; i++){
            file << weights[i] << '\n'; //'\n' instead of endl: no flush per line
        }

        //close() flushes the buffered output; failbit/badbit catch failed writes or a failed close
        file.close();
        if( file.fail() ){
            errorLog << "Failed to write feature weights to file: " << resultsFilename << endl;
            return false;
        }
    }else{

        const unsigned int K = forest->getForestSize();
        const unsigned int N = forest->getNumInputDimensions();
        VectorFloat tmp( N, 0.0 );
        MatrixDouble weights(K,N);

        for(unsigned int i=0; i<K; i++){

            DecisionTreeNode *tree = forest->getTree(i);
            if( !tree ){
                errorLog << "Failed to get tree " << i << " from the RandomForests model!" << endl;
                return false;
            }
            tree->computeFeatureWeights( tmp );

            //Normalise the row to sum to 1. Guard against a zero sum (presumably a tree that
            //never split on any feature), which previously produced inf/NaN rows.
            const double total = Util::sum( tmp );
            const double norm = total > 0.0 ? 1.0 / total : 0.0;
            for(unsigned int j=0; j<N; j++){
                weights[i][j] = tmp[j] * norm;
                tmp[j] = 0; //tmp is reused for the next tree, so clear it
            }
        }

        //Save the results to a file
        if( !weights.save( resultsFilename ) ){
            errorLog << "Failed to save feature weights to file: " << resultsFilename << endl;
            return false;
        }
    }

    return true;
}