Пример #1
0
int main() {
	// This is the INFERENCE/EM engine, derived from the 
    // libDAI example program (example_sprinkler_em)
    // (http://www.cs.ubc.ca/~murphyk/Bayes/bnintro.html)
    //
    // The factor graph file (input.fg) has to be generated first
	// and the data file (sprinkler.tab),
	// as well as the EM commands 

    // Read the factorgraph from the file
    FactorGraph Network;
    Network.ReadFromFile( "input.fg" );

    // Prepare junction-tree object for doing exact inference for E-step
    PropertySet infprops;
    infprops.set( "verbose", (size_t)1 );
    infprops.set( "updates", string("HUGIN") );
    infprops.set( "maxiter", string("1000") );
    infprops.set( "tol", string("0.00001") );
    infprops.set( "logdomain", true);
    infprops.set( "updates", string("SEQFIX") );
    InfAlg* inf = newInfAlg("BP", Network, infprops );
    inf->init();

    // Read sample from file
    Evidence e;
    ifstream estream( "input.tab" );
    e.addEvidenceTabFile( estream, Network );
    cerr << "Number of samples: " << e.nrSamples() << endl;

    // Read EM specification
    ifstream emstream( "input.em" );
    EMAlg em(e, *inf, emstream);

    // Iterate EM until convergence
    while( !em.hasSatisfiedTermConditions() ) {
        Real l = em.iterate();
        cerr << "Iteration " << em.Iterations() << " likelihood: " << l <<endl;
		Real c = inf->logZ();
        cerr << "Iteration infAlg " << em.Iterations() << " likelihood: " << c <<endl;
    }

    cout.precision(6);
    cout << inf->fg();

    delete inf;

    return 0;
}
Пример #2
0
void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[] ) {
    char *filename;


    // Check for proper number of arguments
    if ((nrhs != NR_IN) || (nlhs != NR_OUT)) {
        mexErrMsgTxt("Usage: [psi] = dai_readfg(filename);\n\n"
        "\n"
        "INPUT:  filename   = filename of a .fg file\n"
        "\n"
        "OUTPUT: psi        = linear cell array containing the factors\n"
        "                     (psi{i} is a structure with a Member field\n"
        "                     and a P field, like a CPTAB).\n");
    }

    // Get input parameters
    size_t buflen;
    buflen = mxGetN( FILENAME_IN ) + 1;
    filename = (char *)mxCalloc( buflen, sizeof(char) );
    mxGetString( FILENAME_IN, filename, buflen );


    // Read factorgraph
    FactorGraph fg;
    try {
        fg.ReadFromFile( filename );
    } catch( std::exception &e ) {
        mexErrMsgTxt( e.what() );
    }


    // Save factors
    vector<Factor> psi;
    for( size_t I = 0; I < fg.nrFactors(); I++ )
        psi.push_back(fg.factor(I));


    // Hand over results to MATLAB
    PSI_OUT = Factors2mx(psi);


    return;
}
Пример #3
0
int main( int argc, char** argv ) {
    if( argc != 4 )
        usage("Incorrect number of arguments.");

    FactorGraph fg;
    fg.ReadFromFile( argv[1] );

    PropertySet infprops;
    infprops.set( "verbose", (size_t)0 );
    infprops.set( "updates", string("HUGIN") );
    InfAlg* inf = newInfAlg( "JTREE", fg, infprops );
    inf->init();

    Evidence e;
    ifstream estream( argv[2] );
    e.addEvidenceTabFile( estream, fg );

    cout << "Number of samples: " << e.nrSamples() << endl;
    for( Evidence::iterator ps = e.begin(); ps != e.end(); ps++ )
        cout << "Sample #" << (ps - e.begin()) << " has " << ps->size() << " observations." << endl;

    ifstream emstream( argv[3] );
    EMAlg em(e, *inf, emstream);

    while( !em.hasSatisfiedTermConditions() ) {
        Real l = em.iterate();
        cout << "Iteration " << em.Iterations() << " likelihood: " << l <<endl;
    }

    cout << endl << "Inferred Factor Graph:" << endl << "######################" << endl;
    cout.precision(12);
    cout << inf->fg();

    delete inf;

    return 0;
}
Пример #4
0
int main( int argc, char *argv[] ) {
    // Runs each inference method given by --methods on the factor graph in
    // --filename and prints one table row per method. The first method is the
    // reference: later rows report the max/average marginal error against it,
    // the relative logZ error (logZ / logZ0 - 1) and the final maxdiff.
    // Returns 0 on success, 1 on bad usage or on any exception.
    try {
        string filename;
        string aliases;
        vector<string> methods;
        // The overrides below are only read when the option was given
        double tol = 0.0;
        size_t maxiter = 0;
        size_t verbose = 0;
        bool marginals = false;
        bool report_iters = true;
        bool report_time = true;

        po::options_description opts_required("Required options");
        opts_required.add_options()
            ("filename", po::value< string >(&filename), "Filename of FactorGraph")
            ("methods", po::value< vector<string> >(&methods)->multitoken(), "DAI methods to test")
        ;

        po::options_description opts_optional("Allowed options");
        opts_optional.add_options()
            ("help", "produce help message")
            ("aliases", po::value< string >(&aliases), "Filename for aliases")
            ("tol", po::value< double >(&tol), "Override tolerance")
            ("maxiter", po::value< size_t >(&maxiter), "Override maximum number of iterations")
            ("verbose", po::value< size_t >(&verbose), "Override verbosity")
            ("marginals", po::value< bool >(&marginals), "Output single node marginals?")
            ("report-time", po::value< bool >(&report_time), "Report calculation time")
            ("report-iters", po::value< bool >(&report_iters), "Report iterations needed")
        ;

        po::options_description cmdline_options;
        cmdline_options.add(opts_required).add(opts_optional);

        po::variables_map vm;
        po::store(po::parse_command_line(argc, argv, cmdline_options), vm);
        po::notify(vm);

        if( vm.count("help") || !(vm.count("filename") && vm.count("methods")) ) {
            cout << "Reads factorgraph <filename.fg> and performs the approximate" << endl;
            cout << "inference algorithms <method*>, reporting calculation time, max and average" << endl;
            cout << "error and relative logZ error (comparing with the results of" << endl;
            cout << "<method0>, the base method, for which one can use JTREE_HUGIN)." << endl << endl;
            cout << opts_required << opts_optional << endl;
            return 1;
        }

        // Read aliases: one "key: value" per line; empty lines and lines
        // starting with '#' are skipped
        map<string,string> Aliases;
        if( !aliases.empty() ) {
            ifstream infile;
            infile.open (aliases.c_str());
            if (infile.is_open()) {
                while( true ) {
                    string line;
                    getline(infile,line);
                    if( infile.fail() )
                        break;
                    if( (!line.empty()) && (line[0] != '#') ) {
                        string::size_type pos = line.find(':',0);
                        if( pos == string::npos )
                            throw "Invalid alias";
                        else {
                            // Key: text before ':' without trailing blanks
                            // (npos + 1 == 0 gives an empty key for a blank prefix)
                            string::size_type posl = line.substr(0, pos).find_last_not_of(" \t");
                            string key = line.substr(0, posl + 1);
                            // Value: text after ':' without leading blanks. A blank
                            // value must become "", so npos is checked explicitly
                            // instead of being added to an offset (which wraps around).
                            string::size_type posr = line.find_first_not_of(" \t", pos + 1);
                            string val = (posr == string::npos) ? string() : line.substr(posr);
                            Aliases[key] = val;
                        }
                    }
                }
                infile.close();
            } else
                throw "Error opening aliases file";
        }

        FactorGraph fg;
        fg.ReadFromFile( filename.c_str() );

        // Marginals and log partition sum of the base method
        vector<Factor> q0;
        double logZ0 = 0.0;

        cout.setf( ios_base::scientific );
        cout.precision( 3 );

        cout << "# " << filename << endl;
        cout.width( 40 );
        cout << left << "# METHOD" << "  ";
        if( report_time ) {
            cout.width( 10 );
            cout << right << "SECONDS" << "   ";
        }
        if( report_iters ) {
            cout.width( 10 );
            cout << "ITERS" << "  ";
        }
        cout.width( 10 );
        cout << "MAX ERROR" << "  ";
        cout.width( 10 );
        cout << "AVG ERROR" << "  ";
        cout.width( 10 );
        cout << "LOGZ ERROR" << "  ";
        cout.width( 10 );
        cout << "MAXDIFF" << "  ";
        cout << endl;

        for( size_t m = 0; m < methods.size(); m++ ) {
            pair<string, PropertySet> meth = parseMethod( methods[m], Aliases );

            // Command line overrides take precedence over the method's properties
            if( vm.count("tol") )
                meth.second.Set("tol",tol);
            if( vm.count("maxiter") )
                meth.second.Set("maxiter",maxiter);
            if( vm.count("verbose") )
                meth.second.Set("verbose",verbose);
            TestDAI piet(fg, meth.first, meth.second );
            piet.doDAI();
            if( m == 0 ) {
                q0 = piet.q;
                logZ0 = piet.logZ;
            }
            piet.calcErrs(q0);

            cout.width( 40 );
            cout << left << methods[m] << "  ";
            if( report_time ) {
                cout.width( 10 );
                cout << right << piet.time << "    ";
            }
            if( report_iters ) {
                cout.width( 10 );
                if( piet.has_iters ) {
                    cout << piet.iters << "  ";
                } else {
                    cout << "N/A         ";
                }
            }

            if( m > 0 ) {
                cout.setf( ios_base::scientific );
                cout.precision( 3 );

                cout.width( 10 );
                double me = clipdouble( piet.maxErr(), 1e-9 );
                cout << me << "  ";

                cout.width( 10 );
                double ae = clipdouble( piet.avgErr(), 1e-9 );
                cout << ae << "  ";

                cout.width( 10 );
                if( piet.has_logZ ) {
                    double le = clipdouble( piet.logZ / logZ0 - 1.0, 1e-9 );
                    cout << le << "  ";
                } else
                    cout << "N/A         ";

                cout.width( 10 );
                if( piet.has_maxdiff ) {
                    double md = clipdouble( piet.maxdiff, 1e-9 );
                    // Propagate NaN errors into the MAXDIFF column
                    if( isnan( me ) )
                        md = me;
                    if( isnan( ae ) )
                        md = ae;
                    cout << md << "  ";
                } else
                    cout << "N/A         ";
            }
            cout << endl;

            if( marginals ) {
                for( size_t i = 0; i < piet.q.size(); i++ )
                    cout << "# " << piet.q[i] << endl;
            }
        }
    } catch(const char *e) {
        cerr << "Exception: " << e << endl;
        return 1;
    } catch(exception& e) {
        cerr << "Exception: " << e.what() << endl;
        return 1;
    } catch(...) {
        // An unknown exception is still a failure; do not fall through to return 0
        cerr << "Exception of unknown type!" << endl;
        return 1;
    }

    return 0;
}
Пример #5
0
int main( int argc, char *argv[] ) {
    if ( argc != 2 ) {
        cout << "Usage: " << argv[0] << " <filename.fg>" << endl << endl;
        cout << "Reads factor graph <filename.fg> and verifies" << endl;
        cout << "whether BBP works correctly on it." << endl << endl;
        return 1;
    } else {
        // Read FactorGraph from the file specified by the first command line argument
        FactorGraph fg;
        fg.ReadFromFile(argv[1]);

        // Set some constants
        size_t verbose = 0;
        Real   tol = 1.0e-9;
        size_t maxiter = 10000;
        Real   damping = 0.0;

        // Store the constants in a PropertySet object
        PropertySet opts;
        opts.set("verbose",verbose);  // Verbosity (amount of output generated)
        opts.set("tol",tol);          // Tolerance for convergence
        opts.set("maxiter",maxiter);  // Maximum number of iterations
        opts.set("damping",damping);  // Amount of damping applied

        // Construct a BP (belief propagation) object from the FactorGraph fg
        BP bp(fg, opts("updates",string("SEQFIX"))("logdomain",false));
        bp.recordSentMessages = true;
        bp.init();
        bp.run();

        vector<size_t> state( fg.nrVars(), 0 );

        for( size_t t = 0; t < 45; t++ ) {
            BBP::Properties::UpdateType updates;
            switch( t % 5 ) {
                case BBP::Properties::UpdateType::SEQ_FIX:
                    updates = BBP::Properties::UpdateType::SEQ_FIX;
                    break;
                case BBP::Properties::UpdateType::SEQ_MAX:
                    updates = BBP::Properties::UpdateType::SEQ_MAX;
                    break;
                case BBP::Properties::UpdateType::SEQ_BP_REV:
                    updates = BBP::Properties::UpdateType::SEQ_BP_REV;
                    break;
                case BBP::Properties::UpdateType::SEQ_BP_FWD:
                    updates = BBP::Properties::UpdateType::SEQ_BP_FWD;
                    break;
                case BBP::Properties::UpdateType::PAR:
                    updates = BBP::Properties::UpdateType::PAR;
                    break;
            }
            BBPCostFunction cfn;
            switch( (t / 5) % 9 ) {
                case 0:
                    cfn = BBPCostFunction::CFN_GIBBS_B;
                    break;
                case 1:
                    cfn = BBPCostFunction::CFN_GIBBS_B2;
                    break;
                case 2:
                    cfn = BBPCostFunction::CFN_GIBBS_EXP;
                    break;
                case 3:
                    cfn = BBPCostFunction::CFN_GIBBS_B_FACTOR;
                    break;
                case 4:
                    cfn = BBPCostFunction::CFN_GIBBS_B2_FACTOR;
                    break;
                case 5:
                    cfn = BBPCostFunction::CFN_GIBBS_EXP_FACTOR;
                    break;
                case 6:
                    cfn = BBPCostFunction::CFN_VAR_ENT;
                    break;
                case 7:
                    cfn = BBPCostFunction::CFN_FACTOR_ENT;
                    break;
                case 8:
                    cfn = BBPCostFunction::CFN_BETHE_ENT;
                    break;
            }

            Real h = 1e-4;
            Real result = numericBBPTest( bp, &state, opts("updates",updates), cfn, h );
            cout << "result: " << result << ",\tupdates=" << updates << ", cfn=" << cfn << endl;
        }
    }

    return 0;
}
int main( int argc, char *argv[] ) {
    // Reads a factor graph and writes one assignment per variable to stdout.
    // The first line is the number of variables, followed by "label state"
    // lines with 1-based states (for MATLAB). argv[2] selects the method:
    //   map - junction tree with MAXPROD inference (exact MAP state),
    //   pd  - BP (SEQMAX, log domain) followed by posterior decoding.
    // Diagnostic output written to cerr goes to the file inf.log.
    if ( argc != 3 ) {
        cout << "Usage: " << argv[0] << " <filename.fg> [map|pd]" << endl << endl;
        cout << "Reads factor graph <filename.fg> and runs" << endl;
        cout << "map: Junction tree MAP" << endl;
        cout << "pd : LBP and posterior decoding" << endl << endl;
        return 1;
    } else {
        // Redirect cerr to inf.log. The guard restores cerr's original buffer
        // when it goes out of scope. It is declared after errlog, so it is
        // destroyed first. Otherwise cerr would keep pointing at the buffer of
        // a destroyed stream, which is undefined behavior on any later use or
        // flush of cerr.
        ofstream errlog("inf.log");
        struct CerrRedirect {
            std::streambuf* orig;
            explicit CerrRedirect( std::streambuf* buf ) : orig( std::cerr.rdbuf( buf ) ) {}
            ~CerrRedirect() { std::cerr.rdbuf( orig ); }
        };
        CerrRedirect cerrRedirect( errlog.rdbuf() );

        // Read FactorGraph from the file specified by the first command line argument
        FactorGraph fg;
        fg.ReadFromFile(argv[1]);

        // Set some constants
        size_t maxiter = 10000;
        Real   tol = 1e-9;
        size_t verb = 1;

        // Store the constants in a PropertySet object
        PropertySet opts;
        opts.set("maxiter",maxiter);  // Maximum number of iterations
        opts.set("tol",tol);          // Tolerance for convergence
        opts.set("verbose",verb);     // Verbosity (amount of output generated)

        if (strcmp(argv[2], "map") == 0) {
            // Junction tree with max-product inference, used to find the joint
            // configuration of variables with maximum probability (MAP state)
            JTree jtmap( fg, opts("updates",string("HUGIN"))("inference",string("MAXPROD")) );
            jtmap.init();
            jtmap.run();
            // Joint state of all variables that has maximum probability
            vector<size_t> jtmapstate = jtmap.findMaximum();

            // Report exact MAP joint state
            cerr << "Exact MAP state (log score = " << fg.logScore( jtmapstate ) << "):" << endl;
            cout << fg.nrVars() << endl;
            for( size_t i = 0; i < jtmapstate.size(); i++ )
                cout << fg.var(i).label() << " " << jtmapstate[i] + 1 << endl; // +1 because in MATLAB assignments start at 1
        } else if (strcmp(argv[2], "pd") == 0) {
            // Belief propagation with sequential max-residual updates in the log domain
            BP bp(fg, opts("updates",string("SEQMAX"))("logdomain",true));
            bp.init();
            bp.run();

            // Posterior decoding: pick the most probable state of each marginal
            cerr << "LBP posterior decoding (highest prob assignment in marginal):" << endl;
            cout << fg.nrVars() << endl;
            for( size_t i = 0; i < fg.nrVars(); i++ ) {
                Factor marginal = bp.belief(fg.var(i));
                // Exactly one line per variable, as promised by the count line
                // above. On ties the first maximal state wins, rather than every
                // tied state being printed.
                size_t best = 0;
                for( size_t j = 1; j < marginal.nrStates(); j++ )
                    if( marginal[j] > marginal[best] )
                        best = j;
                cout << fg.var(i).label() << " " << best + 1 << endl; // +1 because in MATLAB assignments start at 1
            }
        } else {
            cerr << "Invalid inference algorithm specified." << endl;
            return 1;
        }
    }

    return 0;
}
Пример #7
0
/// Main function
int main( int argc, char *argv[] ) {
    // Variables to store command line options
    // Filename of factor graph
    string filename;
    // Filename for aliases
    string aliases;
    // Approximate Inference methods to use
    vector<string> methods;
    // Which marginals to output
    MarginalsOutputType marginals;
    // Output number of iterations?
    bool report_iters = true;
    // Output calculation time?
    bool report_time = true;

    // Define required command line options
    po::options_description opts_required("Required options");
    opts_required.add_options()
        ("filename", po::value< string >(&filename), "Filename of factor graph")
        ("methods", po::value< vector<string> >(&methods)->multitoken(), "DAI methods to perform")
    ;

    // Define allowed command line options
    po::options_description opts_optional("Allowed options");
    opts_optional.add_options()
        ("help", "Produce help message")
        ("aliases", po::value< string >(&aliases), "Filename for aliases")
        ("marginals", po::value< MarginalsOutputType >(&marginals), "Output marginals? (NONE/VAR/FAC/VARFAC/ALL, default=NONE)")
        ("report-time", po::value< bool >(&report_time), "Output calculation time (default==1)?")
        ("report-iters", po::value< bool >(&report_iters), "Output iterations needed (default==1)?")
    ;

    // Define all command line options
    po::options_description cmdline_options;
    cmdline_options.add(opts_required).add(opts_optional);

    // Parse command line
    po::variables_map vm;
    po::store(po::parse_command_line(argc, argv, cmdline_options), vm);
    po::notify(vm);

    // Display help message if necessary
    if( vm.count("help") || !(vm.count("filename") && vm.count("methods")) ) {
        cout << "This program is part of libDAI - http://www.libdai.org/" << endl << endl;
        cout << "Usage: ./testdai --filename <filename.fg> --methods <method1> [<method2> <method3> ...]" << endl << endl;
        cout << "Reads factor graph <filename.fg> and performs the approximate inference algorithms" << endl;
        cout << "<method*>, reporting for each method:" << endl;
        cout << "  o the calculation time needed, in seconds (if report-time == 1);" << endl;
        cout << "  o the number of iterations needed (if report-iters == 1);" << endl;
        cout << "  o the maximum (over all variables) total variation error in the variable marginals;" << endl;
        cout << "  o the average (over all variables) total variation error in the variable marginals;" << endl;
        cout << "  o the maximum (over all factors) total variation error in the factor marginals;" << endl;
        cout << "  o the average (over all factors) total variation error in the factor marginals;" << endl;
        cout << "  o the error (difference) of the logarithm of the partition sums;" << endl << endl;
        cout << "All errors are calculated by comparing the results of the current method with" << endl; 
        cout << "the results of the first method (the base method). If marginals==VAR, additional" << endl;
        cout << "output consists of the variable marginals, if marginals==FAC, the factor marginals" << endl;
        cout << "if marginals==VARFAC, both variable and factor marginals, and if marginals==ALL, all" << endl;
        cout << "marginals calculated by the method are reported." << endl << endl;
        cout << "<method*> should be a list of one or more methods, seperated by spaces, in the format:" << endl << endl;
        cout << "    name[key1=val1,key2=val2,key3=val3,...,keyn=valn]" << endl << endl;
        cout << "where name should be the name of an algorithm in libDAI (or an alias, if an alias" << endl;
        cout << "filename is provided), followed by a list of properties (surrounded by rectangular" << endl;
        cout << "brackets), where each property consists of a key=value pair and the properties are" << endl;
        cout << "seperated by commas. If an alias file is specified, alias substitution is performed." << endl;
        cout << "This is done by looking up the name in the alias file and substituting the alias" << endl;
        cout << "by its corresponding method as defined in the alias file. Properties are parsed from" << endl;
        cout << "left to right, so if a property occurs repeatedly, the right-most value is used." << endl << endl;
        cout << opts_required << opts_optional << endl;
#ifdef DAI_DEBUG
        cout << "Note: this is a debugging build of libDAI." << endl << endl;
#endif
        cout << "Example:  ./testdai --filename testfast.fg --aliases aliases.conf --methods JTREE_HUGIN BP_SEQFIX BP_PARALL[maxiter=5]" << endl;
        return 1;
    }

    try {
        // Read aliases
        map<string,string> Aliases;
        if( !aliases.empty() )
            Aliases = readAliasesFile( aliases );

        // Read factor graph
        FactorGraph fg;
        fg.ReadFromFile( filename.c_str() );

        // Declare variables used for storing variable factor marginals and log partition sum of base method
        vector<Factor> varMarginals0;
        vector<Factor> facMarginals0;
        Real logZ0 = 0.0;

        // Output header
        cout.setf( ios_base::scientific );
        cout.precision( 3 );
        cout << "# " << filename << endl;
        cout.width( 39 );
        cout << left << "# METHOD" << "\t";
        if( report_time )
            cout << right << "SECONDS  " << "\t";
        if( report_iters )
            cout << "ITERS" << "\t";
        cout << "MAX VAR ERR" << "\t";
        cout << "AVG VAR ERR" << "\t";
        cout << "MAX FAC ERR" << "\t";
        cout << "AVG FAC ERR" << "\t";
        cout << "LOGZ ERROR" << "\t";
        cout << "MAXDIFF" << "\t";
        cout << endl;

        // For each method...
        for( size_t m = 0; m < methods.size(); m++ ) {
            // Parse method
            pair<string, PropertySet> meth = parseNameProperties( methods[m], Aliases );

            // Construct object for running the method
            TestDAI testdai(fg, meth.first, meth.second );

            // Run the method
            testdai.doDAI();

            // For the base method, store its variable marginals and logarithm of the partition sum
            if( m == 0 ) {
                varMarginals0 = testdai.varMarginals;
                facMarginals0 = testdai.facMarginals;
                logZ0 = testdai.logZ;
            }

            // Calculate errors relative to base method
            testdai.calcErrors( varMarginals0, facMarginals0 );

            // Output method name
            cout.width( 39 );
            cout << left << methods[m] << "\t";
            // Output calculation time, if requested
            if( report_time )
                cout << right << testdai.time << "\t";
            // Output number of iterations, if requested
            if( report_iters ) {
                if( testdai.has_iters ) {
                    cout << testdai.iters << "\t";
                } else {
                    cout << "N/A  \t";
                }
            }

            // If this is not the base method
            if( m > 0 ) {
                cout.setf( ios_base::scientific );
                cout.precision( 3 );

                // Output maximum error in variable marginals
                Real mev = clipReal( testdai.maxVarErr(), 1e-9 );
                cout << mev << "\t";

                // Output average error in variable marginals
                Real aev = clipReal( testdai.avgVarErr(), 1e-9 );
                cout << aev << "\t";

                // Output maximum error in factor marginals
                Real mef = clipReal( testdai.maxFacErr(), 1e-9 );
                if( mef == INFINITY )
                    cout << "N/A       \t";
                else
                    cout << mef << "\t";

                // Output average error in factor marginals
                Real aef = clipReal( testdai.avgFacErr(), 1e-9 );
                if( aef == INFINITY )
                    cout << "N/A       \t";
                else
                    cout << aef << "\t";

                // Output error in log partition sum
                if( testdai.has_logZ ) {
                    cout.setf( ios::showpos );
                    Real le = clipReal( testdai.logZ - logZ0, 1e-9 );
                    cout << le << "\t";
                    cout.unsetf( ios::showpos );
                } else
                    cout << "N/A       \t";

                // Output maximum difference in last iteration
                if( testdai.has_maxdiff ) {
                    Real md = clipReal( testdai.maxdiff, 1e-9 );
                    if( dai::isnan( mev ) )
                        md = mev;
                    if( dai::isnan( aev ) )
                        md = aev;
                    if( md == INFINITY )
                        md = 1.0;
                    cout << md << "\t";
                } else
                    cout << "N/A    \t";
            }
            cout << endl;

            // Output marginals, if requested
            if( marginals == MarginalsOutputType::VAR || marginals == MarginalsOutputType::VARFAC )
                for( size_t i = 0; i < testdai.varMarginals.size(); i++ )
                    cout << "# " << testdai.varMarginals[i] << endl;
            if( marginals == MarginalsOutputType::FAC || marginals == MarginalsOutputType::VARFAC )
                for( size_t I = 0; I < testdai.facMarginals.size(); I++ )
                    cout << "# " << testdai.facMarginals[I] << endl;
            if( marginals == MarginalsOutputType::ALL )
                for( size_t I = 0; I < testdai.allMarginals.size(); I++ )
                    cout << "# " << testdai.allMarginals[I] << endl;
        }

        return 0;
    } catch( string &s ) {
        // Abort with error message
        cerr << "Exception: " << s << endl;
        return 2;
    }
}
Пример #8
0
int main( int argc, char *argv[] ) {
    if( argc != 3 ) {
        cout << "Usage: " << argv[0] << " <in.fg> <tw>" << endl << endl;
        cout << "Reports some characteristics of the .fg network." << endl;
        cout << "Also calculates treewidth (which may take some time) unless <tw> == 0." << endl;
        return 1;
    } else {
        // Read factorgraph
        FactorGraph fg;
        char *infile = argv[1];
        int calc_tw = atoi(argv[2]);
        fg.ReadFromFile( infile );

        cout << "Number of variables:   " << fg.nrVars() << endl;
        cout << "Number of factors:     " << fg.nrFactors() << endl;
        cout << "Connected:             " << fg.isConnected() << endl;
        cout << "Tree:                  " << fg.isTree() << endl;
        cout << "Has short loops:       " << hasShortLoops(fg.factors()) << endl;
        cout << "Has negatives:         " << hasNegatives(fg.factors()) << endl;
        cout << "Binary variables?      " << fg.isBinary() << endl;
        cout << "Pairwise interactions? " << fg.isPairwise() << endl;
        if( calc_tw ) {
            std::pair<size_t,size_t> tw = treewidth(fg);
            cout << "Treewidth:           " << tw.first << endl;
            cout << "Largest cluster for JTree has " << tw.second << " states " << endl;
        }
        double stsp = 1.0;
        for( size_t i = 0; i < fg.nrVars(); i++ )
            stsp *= fg.var(i).states();
        cout << "Total state space:   " << stsp << endl;

        double cavsum_lcbp = 0.0;
        double cavsum_lcbp2 = 0.0;
        size_t max_Delta_size = 0;
        map<size_t,size_t> cavsizes;
        for( size_t i = 0; i < fg.nrVars(); i++ ) {
            VarSet di = fg.delta(i);
            if( cavsizes.count(di.size()) )
                cavsizes[di.size()]++;
            else
                cavsizes[di.size()] = 1;
            size_t Ds = fg.Delta(i).nrStates();
            if( Ds > max_Delta_size )
                max_Delta_size = Ds;
            cavsum_lcbp += di.nrStates();
            for( VarSet::const_iterator j = di.begin(); j != di.end(); j++ )
                cavsum_lcbp2 += j->states();
        }
        cout << "Maximum pancake has " << max_Delta_size << " states" << endl;
        cout << "LCBP with full cavities needs " << cavsum_lcbp << " BP runs" << endl;
        cout << "LCBP with only pairinteractions needs " << cavsum_lcbp2 << " BP runs" << endl;
        cout << "Cavity sizes: ";
        for( map<size_t,size_t>::const_iterator it = cavsizes.begin(); it != cavsizes.end(); it++ ) 
            cout << it->first << "(" << it->second << ") ";
        cout << endl;

        cout << "Type: " << (fg.isPairwise() ? "pairwise" : "higher order") << " interactions, " << (fg.isBinary() ? "binary" : "nonbinary") << " variables" << endl;

        if( fg.isPairwise() ) {
            bool girth_reached = false;
            size_t loopdepth;
            for( loopdepth = 2; loopdepth <= fg.nrVars() && !girth_reached; loopdepth++ ) {
                size_t nr_loops = countLoops( fg, loopdepth );
                cout << "Loops up to " << loopdepth << " variables: " << nr_loops << endl;
                if( nr_loops > 0 )
                    girth_reached = true;
            }
            if( girth_reached )
                cout << "Girth: " << loopdepth-1 << endl;
            else
                cout << "Girth: infinity" << endl;
        }

        return 0;
    }
}