Beispiel #1
0
void
mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    MatlabCPPInitialize(false);

    // Check for proper number of arguments
    if ( (nrhs < NR_IN) || (nrhs > NR_IN + NR_IN_OPT) || \
            (nlhs < NR_OUT) || (nlhs > NR_OUT + NR_OUT_OPT) ) {
        mexErrMsgTxt("Wrong number of arguments.");
    }

    // fetch the input (we support matrices and cell arrays)
    std::vector< Eigen::VectorXd > theta_unary;
    std::vector< Eigen::VectorXd > theta_pair;

    if (!mxIsCell(THETA_UNARY_IN)) {
        GetMatlabPotentialFromMatrix(THETA_UNARY_IN, theta_unary);
        GetMatlabPotentialFromMatrix(THETA_PAIR_IN, theta_pair);
    }
    else {
        GetMatlabPotential(THETA_UNARY_IN, theta_unary);
        GetMatlabPotential(THETA_PAIR_IN, theta_pair);
    }
    
    Eigen::MatrixXi edges;
    GetMatlabMatrix(EDGES_IN, edges);

    double beta = mxGetScalar(BETA_IN);
    
    TreeInference inf = TreeInference(theta_unary, theta_pair, edges);
    inf.run(beta);
            
    LOGZ_OUT = mxCreateNumericMatrix(1, 1, mxDOUBLE_CLASS, mxREAL);
    double* logz_p = mxGetPr(LOGZ_OUT);
    logz_p[0] = inf.getLogPartitionSum();
    
    // return marginals (if input was provided as a matrix, also return
    // marginals in a matrix)
    if (nlhs > 1) {
        std::vector< Eigen::VectorXd >& mu_unary = inf.getUnaryMarginals();
        if (!mxIsCell(THETA_UNARY_IN)) {
            size_t num_states = mxGetM(THETA_UNARY_IN);
            size_t num_vars = mxGetN(THETA_UNARY_IN);
            MARGINALS_UNARY_OUT = mxCreateNumericMatrix(
                                    num_states, 
                                    num_vars,
                                    mxDOUBLE_CLASS, mxREAL);
            double* mu_res_p = mxGetPr(MARGINALS_UNARY_OUT);
            for (size_t v_idx=0; v_idx<mu_unary.size(); v_idx++) {
                for (size_t idx=0; idx<mu_unary[v_idx].size(); idx++) {
                    mu_res_p[v_idx*num_states+idx] = mu_unary[v_idx](idx);
                }
            }
        }
        else {
            mwSize dim_0 = static_cast<mwSize>(mu_unary.size());
            MARGINALS_UNARY_OUT = mxCreateCellArray(1, &dim_0);
            for (size_t v_idx=0; v_idx<mu_unary.size(); v_idx++) {
                mxArray* m = mxCreateNumericMatrix(
                                static_cast<int>(mu_unary[v_idx].size()),
                                1,
                                mxDOUBLE_CLASS, mxREAL);
                double* mu_res_p = mxGetPr(m);
                for (size_t idx=0; idx<mu_unary[v_idx].size(); idx++) {
                    mu_res_p[idx] = mu_unary[v_idx](idx);
                }

                mxSetCell(MARGINALS_UNARY_OUT, v_idx, m);
            }
        }
    }

    // return pairwise marginals
    if (nlhs > 2) {
        std::vector< Eigen::VectorXd >& mu_pair = inf.getPairwiseMarginals();
        if (!mxIsCell(THETA_UNARY_IN)) {
            size_t num_states = mxGetM(THETA_PAIR_IN);
            size_t num_edges = mxGetN(THETA_PAIR_IN);
            MARGINALS_PAIR_OUT = mxCreateNumericMatrix(
                                    num_states, 
                                    num_edges,
                                    mxDOUBLE_CLASS, mxREAL);
            double* mu_res_p = mxGetPr(MARGINALS_PAIR_OUT);
            for (size_t e_idx=0; e_idx<mu_pair.size(); e_idx++) {
                for (size_t idx=0; idx<mu_pair[e_idx].size(); idx++) {
                    mu_res_p[e_idx*num_states+idx] = mu_pair[e_idx](idx);
                }
            }
        }
        else {
            mxAssert(0, "not implemented yet!");
            // TODO: not implemented yet, see above!
            // see unary case above!
        }
    }

    MatlabCPPExit();
}
Beispiel #2
0
void
mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    MatlabCPPInitialize(false);

    // Check for proper number of arguments
    if ( (nrhs < NR_IN) || (nrhs > NR_IN + NR_IN_OPT) || \
            (nlhs < NR_OUT) || (nlhs > NR_OUT + NR_OUT_OPT) ) {
        mexErrMsgTxt("Wrong number of arguments.");
    }

    // fetch the input (we support matrices and cell arrays)
    std::vector< Eigen::VectorXd > theta_unary;
    std::vector< Eigen::VectorXd > theta_pair;

    if (!mxIsCell(THETA_UNARY_IN)) {
        GetMatlabPotentialFromMatrix(THETA_UNARY_IN, theta_unary);
        GetMatlabPotentialFromMatrix(THETA_PAIR_IN, theta_pair);
    }
    else {
        GetMatlabPotential(THETA_UNARY_IN, theta_unary);
        GetMatlabPotential(THETA_PAIR_IN, theta_pair);
    }
    
    Eigen::MatrixXi edges;
    GetMatlabMatrix(EDGES_IN, edges);

    // decomposition: cell array of edge index vectors
    std::vector< std::vector<size_t> > decomposition;
    if (nrhs >= 5) {
        size_t num = mxGetNumberOfElements(DECOMPOSITION_IN);
        decomposition.resize(num);
        for (size_t d_idx=0; d_idx<num; d_idx++) {
            const mxArray* v = mxGetCell(DECOMPOSITION_IN, d_idx);
            size_t num_edges = mxGetM(v);
            decomposition[d_idx].resize(num_edges);
	        const double* ptr = mxGetPr(v);
            for (size_t e_idx=0; e_idx<num_edges; e_idx++) {
                decomposition[d_idx][e_idx] = static_cast<size_t>(ptr[e_idx]);
            }
        }
    }
    
    // parse options
    MexLPQPOptions options;
    if (nrhs >= 4) {
        bool opts_parsed = options.parse(OPTIONS_IN);
        if (!opts_parsed) {
            MatlabCPPExit();
            return;
        }
    }

    LPQP* lpqp;
    if (nrhs < 5) {
        lpqp = new LPQPNPBP(theta_unary, theta_pair, edges);
    }
    else {
        lpqp = new LPQPSDD(theta_unary, theta_pair, edges, decomposition, options.solver_sdd);
    }

    lpqp->setRhoStart(options.rho_start);
    lpqp->setRhoEnd(options.rho_end);
    lpqp->setEpsilonEntropy(options.eps_entropy);
    lpqp->setEpsilonKullbackLeibler(options.eps_dkl);
    lpqp->setEpsilonObjective(options.eps_obj);
    lpqp->setEpsilonMP(options.eps_mp);
    lpqp->setMaximumNumberOfIterationsDC(options.num_max_iter_dc);
    lpqp->setMaximumNumberOfIterationsMP(options.num_max_iter_mp);
    lpqp->setRhoScheduleConstant(options.rho_schedule_constant);
    lpqp->setInitialLPActive(options.initial_lp_active);
    lpqp->setInitialLPImprovmentRatio(options.initial_lp_improvement_ratio);
    lpqp->setInitialLPRhoStart(options.initial_lp_rho_start);
    lpqp->setInitialRhoSimilarValues(options.initial_rho_similar_values);
    lpqp->setInitialRhoFactorKLSmaller(options.initial_rho_factor_kl_smaller);
    lpqp->setSkipIfIncrease(options.skip_if_increase);

    lpqp->run();
    
    if (options.do_round) {
        double curr_qp_obj = lpqp->computeQPValue();

        lpqp->roundSolution();

        double qp_obj_after_rounding = lpqp->computeQPValue();
        printf("QP objective before rounding: %f after rounding: %f\n",curr_qp_obj, qp_obj_after_rounding);
    }

    // return marginals (if input was provided as a matrix, also return
    // marginals in a matrix)
    std::vector< Eigen::VectorXd >& mu_unary = lpqp->getUnaryMarginals();
    if (!mxIsCell(THETA_UNARY_IN)) {
        size_t num_states = mxGetM(THETA_UNARY_IN);
        size_t num_vars = mxGetN(THETA_UNARY_IN);
        MARGINALS_UNARY_OUT = mxCreateNumericMatrix(
                                num_states, 
                                num_vars,
                                mxDOUBLE_CLASS, mxREAL);
        double* mu_res_p = mxGetPr(MARGINALS_UNARY_OUT);
        for (size_t v_idx=0; v_idx<mu_unary.size(); v_idx++) {
            for (size_t idx=0; idx<mu_unary[v_idx].size(); idx++) {
                mu_res_p[v_idx*num_states+idx] = mu_unary[v_idx](idx);
            }
        }
    }
    else {
        mwSize dim_0 = static_cast<mwSize>(mu_unary.size());
        MARGINALS_UNARY_OUT = mxCreateCellArray(1, &dim_0);
        for (size_t v_idx=0; v_idx<mu_unary.size(); v_idx++) {
            mxArray* m = mxCreateNumericMatrix(
                            static_cast<int>(mu_unary[v_idx].size()),
                            1,
                            mxDOUBLE_CLASS, mxREAL);
            double* mu_res_p = mxGetPr(m);
            for (size_t idx=0; idx<mu_unary[v_idx].size(); idx++) {
                mu_res_p[idx] = mu_unary[v_idx](idx);
            }

            mxSetCell(MARGINALS_UNARY_OUT, v_idx, m);
        }
    }
    
    // return history
    if (nlhs > 1) {
        std::vector<double>& history_obj = lpqp->getHistoryObjective();
        std::vector<double>& history_obj_qp = lpqp->getHistoryObjectiveQP();
        std::vector<double>& history_obj_lp = lpqp->getHistoryObjectiveLP();
        std::vector<double>& history_obj_decoded = lpqp->getHistoryObjectiveDecoded();
        std::vector<size_t>& history_iteration = lpqp->getHistoryIteration();
        std::vector<double>& history_rho = lpqp->getHistoryRho();
        //std::vector< Eigen::VectorXi > history_decoded = lpqp.getHistoryDecoded();

        //const char *field_names[] = {"obj", "obj_qp", "obj_lp", "obj_decoded", "iteration", "beta", "decoded"};
        const char *field_names[] = {"obj", "obj_qp", "obj_lp", "obj_decoded", "iteration", "rho"};
        mwSize dims[2];
        dims[0] = 1;
        dims[1] = history_obj.size();
        HISTORY_OUT = mxCreateStructArray(2, dims, sizeof(field_names)/sizeof(*field_names), field_names);

        int field_obj, field_obj_qp, field_obj_lp, field_obj_decoded, field_iteration, field_rho;
        //int field_decoded;
        field_obj = mxGetFieldNumber(HISTORY_OUT, "obj");
        field_obj_qp = mxGetFieldNumber(HISTORY_OUT, "obj_qp");
        field_obj_lp = mxGetFieldNumber(HISTORY_OUT, "obj_lp");
        field_obj_decoded = mxGetFieldNumber(HISTORY_OUT, "obj_decoded");
        field_iteration = mxGetFieldNumber(HISTORY_OUT, "iteration");
        field_rho = mxGetFieldNumber(HISTORY_OUT, "rho");
        //field_decoded = mxGetFieldNumber(HISTORY_OUT, "decoded");

        for (size_t i=0; i<history_obj.size(); i++) {
            mxArray *field_value;

            field_value = mxCreateDoubleMatrix(1,1,mxREAL);
            *mxGetPr(field_value) = history_obj[i];
            mxSetFieldByNumber(HISTORY_OUT, i, field_obj, field_value);
            
            field_value = mxCreateDoubleMatrix(1,1,mxREAL);
            *mxGetPr(field_value) = history_obj_qp[i];
            mxSetFieldByNumber(HISTORY_OUT, i, field_obj_qp, field_value);
            
            field_value = mxCreateDoubleMatrix(1,1,mxREAL);
            *mxGetPr(field_value) = history_obj_lp[i];
            mxSetFieldByNumber(HISTORY_OUT, i, field_obj_lp, field_value);
            
            field_value = mxCreateDoubleMatrix(1,1,mxREAL);
            *mxGetPr(field_value) = history_obj_decoded[i];
            mxSetFieldByNumber(HISTORY_OUT, i, field_obj_decoded, field_value);
            
            field_value = mxCreateDoubleMatrix(1,1,mxREAL);
            *mxGetPr(field_value) = static_cast<double>(history_iteration[i]);
            mxSetFieldByNumber(HISTORY_OUT, i, field_iteration, field_value);
            
            field_value = mxCreateDoubleMatrix(1,1,mxREAL);
            *mxGetPr(field_value) = history_rho[i];
            mxSetFieldByNumber(HISTORY_OUT, i, field_rho, field_value);
            
            //if (options.with_decoded_history) {
            //    field_value = mxCreateDoubleMatrix(history_decoded[i].size(),1,mxREAL);
            //    double* decoded_res_p = mxGetPr(field_value);
            //    for (size_t idx=0; idx<history_decoded[i].size(); idx++) {
            //        decoded_res_p[idx] = static_cast<double>(history_decoded[i][idx]);
            //    }
            //    mxSetFieldByNumber(HISTORY_OUT, i, field_decoded, field_value);
            //}
        }
    }

    delete lpqp;

    MatlabCPPExit();
}
void
mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    MatlabCPPInitialize(false);

    // Check for proper number of arguments
    if ( (nrhs < NR_IN) || (nrhs > NR_IN + NR_IN_OPT) || \
            (nlhs < NR_OUT) || (nlhs > NR_OUT + NR_OUT_OPT) ) {
        mexErrMsgTxt("Wrong number of arguments.");
    }

    // fetch the input (we support matrices and cell arrays)
    std::vector< Eigen::VectorXd > theta_unary;
    std::vector< Eigen::VectorXd > theta_pair;

    if (!mxIsCell(THETA_UNARY_IN)) {
        GetMatlabPotentialFromMatrix(THETA_UNARY_IN, theta_unary);
        GetMatlabPotentialFromMatrix(THETA_PAIR_IN, theta_pair);
    }
    else {
        GetMatlabPotential(THETA_UNARY_IN, theta_unary);
        GetMatlabPotential(THETA_PAIR_IN, theta_pair);
    }
    
    Eigen::MatrixXi edges;
    GetMatlabMatrix(EDGES_IN, edges);

    // decomposition: cell array of edge index vectors
    std::vector< std::vector<size_t> > decomposition;
    size_t num = mxGetNumberOfElements(DECOMPOSITION_IN);
    decomposition.resize(num);
    for (size_t d_idx=0; d_idx<num; d_idx++) {
        const mxArray* v = mxGetCell(DECOMPOSITION_IN, d_idx);
        size_t num_edges = mxGetM(v);
        decomposition[d_idx].resize(num_edges);
	    const double* ptr = mxGetPr(v);
        for (size_t e_idx=0; e_idx<num_edges; e_idx++) {
            decomposition[d_idx][e_idx] = static_cast<size_t>(ptr[e_idx]);
        }
    }
    
    // parse options
    MexSmoothDDOptions options;
    if (nrhs >= 5) {
        bool opts_parsed = options.parse(OPTIONS_IN);
        if (!opts_parsed) {
            MatlabCPPExit();
            return;
        }
    }
    
    // set algorithm & parameters according to options
    SmoothDualDecomposition* sdd;
    /*if (options.solver == SOLVER_FISTADESCENT) {
        sdd = new SmoothDualDecompositionFistaDescent(theta_unary, theta_pair, edges, decomposition);
    }
    if (options.solver == SOLVER_FISTA) {
        sdd = new SmoothDualDecompositionFista(theta_unary, theta_pair, edges, decomposition);
    }
    else if (options.solver == SOLVER_GRADIENTDESCENT) {
        sdd = new SmoothDualDecompositionGradientDescent(theta_unary, theta_pair, edges, decomposition);
    }
    else if (options.solver == SOLVER_NESTEROV) {
        sdd = new SmoothDualDecompositionNesterov(theta_unary, theta_pair, edges, decomposition);
    }
    else {
        mxAssert(0, "Solver not found. Should not happen!");
    }*/
    // TODO
    sdd = new SmoothDualDecompositionFistaDescent(theta_unary, theta_pair, edges, decomposition);
    //sdd = new SmoothDualDecompositionLBFGS(theta_unary, theta_pair, edges, decomposition);
    sdd->setMaximumNumberOfIterations(options.num_max_iter);
    sdd->setEpsilonGradientNorm(options.eps_gnorm);

    // run the computations
    sdd->run(options.rho);

    // return marginals (if input was provided as a matrix, also return
    // marginals in a matrix)
    std::vector< Eigen::VectorXd >& mu_unary = sdd->getUnaryMarginals();
    if (!mxIsCell(THETA_UNARY_IN)) {
        size_t num_states = mxGetM(THETA_UNARY_IN);
        size_t num_vars = mxGetN(THETA_UNARY_IN);
        MARGINALS_UNARY_OUT = mxCreateNumericMatrix(
                                num_states, 
                                num_vars,
                                mxDOUBLE_CLASS, mxREAL);
        double* mu_res_p = mxGetPr(MARGINALS_UNARY_OUT);
        for (size_t v_idx=0; v_idx<mu_unary.size(); v_idx++) {
            for (size_t idx=0; idx<mu_unary[v_idx].size(); idx++) {
                mu_res_p[v_idx*num_states+idx] = mu_unary[v_idx](idx);
            }
        }
    }
    else {
        mwSize dim_0 = static_cast<mwSize>(mu_unary.size());
        MARGINALS_UNARY_OUT = mxCreateCellArray(1, &dim_0);
        for (size_t v_idx=0; v_idx<mu_unary.size(); v_idx++) {
            mxArray* m = mxCreateNumericMatrix(
                            static_cast<int>(mu_unary[v_idx].size()),
                            1,
                            mxDOUBLE_CLASS, mxREAL);
            double* mu_res_p = mxGetPr(m);
            for (size_t idx=0; idx<mu_unary[v_idx].size(); idx++) {
                mu_res_p[idx] = mu_unary[v_idx](idx);
            }

            mxSetCell(MARGINALS_UNARY_OUT, v_idx, m);
        }
    }

    // return history
    /*if (nlhs > 1) {
        std::vector<double> history_obj = lpqp.getHistoryObjective();
        std::vector<double> history_obj_qp = lpqp.getHistoryObjectiveQP();
        std::vector<double> history_obj_lp = lpqp.getHistoryObjectiveLP();
        std::vector<double> history_obj_decoded = lpqp.getHistoryObjectiveDecoded();
        std::vector<size_t> history_iteration = lpqp.getHistoryIteration();
        std::vector<double> history_beta = lpqp.getHistoryBeta();
        std::vector< Eigen::VectorXi > history_decoded = lpqp.getHistoryDecoded();

        const char *field_names[] = {"obj", "obj_qp", "obj_lp", "obj_decoded", "iteration", "beta", "decoded"};
        mwSize dims[2];
        dims[0] = 1;
        dims[1] = history_obj.size();
        HISTORY_OUT = mxCreateStructArray(2, dims, sizeof(field_names)/sizeof(*field_names), field_names);

        int field_obj, field_obj_qp, field_obj_lp, field_obj_decoded, field_iteration, field_beta, field_decoded;
        field_obj = mxGetFieldNumber(HISTORY_OUT, "obj");
        field_obj_qp = mxGetFieldNumber(HISTORY_OUT, "obj_qp");
        field_obj_lp = mxGetFieldNumber(HISTORY_OUT, "obj_lp");
        field_obj_decoded = mxGetFieldNumber(HISTORY_OUT, "obj_decoded");
        field_iteration = mxGetFieldNumber(HISTORY_OUT, "iteration");
        field_beta = mxGetFieldNumber(HISTORY_OUT, "beta");
        field_decoded = mxGetFieldNumber(HISTORY_OUT, "decoded");

        for (size_t i=0; i<history_obj.size(); i++) {
            mxArray *field_value;

            field_value = mxCreateDoubleMatrix(1,1,mxREAL);
            *mxGetPr(field_value) = history_obj[i];
            mxSetFieldByNumber(HISTORY_OUT, i, field_obj, field_value);
            
            field_value = mxCreateDoubleMatrix(1,1,mxREAL);
            *mxGetPr(field_value) = history_obj_qp[i];
            mxSetFieldByNumber(HISTORY_OUT, i, field_obj_qp, field_value);
            
            field_value = mxCreateDoubleMatrix(1,1,mxREAL);
            *mxGetPr(field_value) = history_obj_lp[i];
            mxSetFieldByNumber(HISTORY_OUT, i, field_obj_lp, field_value);
            
            field_value = mxCreateDoubleMatrix(1,1,mxREAL);
            *mxGetPr(field_value) = history_obj_decoded[i];
            mxSetFieldByNumber(HISTORY_OUT, i, field_obj_decoded, field_value);
            
            field_value = mxCreateDoubleMatrix(1,1,mxREAL);
            *mxGetPr(field_value) = static_cast<double>(history_iteration[i]);
            mxSetFieldByNumber(HISTORY_OUT, i, field_iteration, field_value);
            
            field_value = mxCreateDoubleMatrix(1,1,mxREAL);
            *mxGetPr(field_value) = history_beta[i];
            mxSetFieldByNumber(HISTORY_OUT, i, field_beta, field_value);
            
            if (options.with_decoded_history) {
                field_value = mxCreateDoubleMatrix(history_decoded[i].size(),1,mxREAL);
                double* decoded_res_p = mxGetPr(field_value);
                for (size_t idx=0; idx<history_decoded[i].size(); idx++) {
                    decoded_res_p[idx] = static_cast<double>(history_decoded[i][idx]);
                }
                mxSetFieldByNumber(HISTORY_OUT, i, field_decoded, field_value);
            }
        }
    }*/
    
    // return pairwise marginals
    if (nlhs > 1) {
        std::vector< Eigen::VectorXd >& mu_pair = sdd->getPairwiseMarginals();
        if (!mxIsCell(THETA_UNARY_IN)) {
            size_t num_states = mxGetM(THETA_PAIR_IN);
            size_t num_edges = mxGetN(THETA_PAIR_IN);
            MARGINALS_PAIR_OUT = mxCreateNumericMatrix(
                                    num_states, 
                                    num_edges,
                                    mxDOUBLE_CLASS, mxREAL);
            double* mu_res_p = mxGetPr(MARGINALS_PAIR_OUT);
            for (size_t e_idx=0; e_idx<mu_pair.size(); e_idx++) {
                for (size_t idx=0; idx<mu_pair[e_idx].size(); idx++) {
                    mu_res_p[e_idx*num_states+idx] = mu_pair[e_idx](idx);
                }
            }
        }
        else {
            mwSize dim_0 = static_cast<mwSize>(mu_pair.size());
            MARGINALS_PAIR_OUT = mxCreateCellArray(1, &dim_0);
            for (size_t e_idx=0; e_idx<mu_pair.size(); e_idx++) {
                mxArray* m = mxCreateNumericMatrix(
                                static_cast<int>(mu_pair[e_idx].size()),
                                1,
                                mxDOUBLE_CLASS, mxREAL);
                double* mu_res_p = mxGetPr(m);
                for (size_t idx=0; idx<mu_pair[e_idx].size(); idx++) {
                    mu_res_p[idx] = mu_pair[e_idx](idx);
                }

                mxSetCell(MARGINALS_PAIR_OUT, e_idx, m);
            }
        }
    }

    delete sdd;

    MatlabCPPExit();
}
Beispiel #4
0
// [states] = grante_sample(model, fg, method, sample_count, options);
void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
	// Option structure
	const mxArray* opt_s = 0;
	if (nrhs >= 4 && mxIsEmpty(prhs[4]) == false)
		opt_s = prhs[4];

	MatlabCPPInitialize(GetScalarDefaultOption(opt_s, "verbose", 0) > 0);

	if (nrhs < 4 || nrhs > 5 || nlhs != 1) {
		mexErrMsgTxt("Wrong number of arguments.\n");
		MatlabCPPExit();
		return;
	}

	// Master model
	Grante::FactorGraphModel model;
	if (matlab_parse_factorgraphmodel(prhs[0], model) == false) {
		MatlabCPPExit();
		return;
	}

	// Parse factor graph
	std::vector<Grante::FactorGraph*> FG;
	bool fgs_parsed = matlab_parse_factorgraphs(model, prhs[1], FG);
	if (fgs_parsed == false) {
		MatlabCPPExit();
		return;
	}
	size_t num_fgs = FG.size();
	if (num_fgs != 1) {
		mexErrMsgTxt("mex_grante_sample supports only one "
			"factor graph at a time.\n");
		for (unsigned int fgi = 0; fgi < FG.size(); ++fgi)
			delete (FG[fgi]);

		MatlabCPPExit();
		return;
	}
	// Compute energies
	Grante::FactorGraph* fg = FG[0];
	fg->ForwardMap();

	// Parse sampling method
	std::string method_name = GetMatlabString(prhs[2]);
	Grante::InferenceMethod* inf = 0;
	if (method_name == "treeinf") {
		if (Grante::FactorGraphStructurizer::IsForestStructured(fg) == false) {
			mexErrMsgTxt("Exact sampling is currently only "
				"possible for tree-structured factor graphs.\n");
			MatlabCPPExit();
			return;
		}
		inf = new Grante::TreeInference(fg);
	} else if (method_name == "gibbs") {
		Grante::GibbsInference* ginf = new Grante::GibbsInference(fg);
		ginf->SetSamplingParameters(
			GetIntegerDefaultOption(opt_s, "gibbs_burnin", 100),
			GetIntegerDefaultOption(opt_s, "gibbs_spacing", 0),
			GetIntegerDefaultOption(opt_s, "gibbs_samples", 1));
		inf = ginf;
	} else if (method_name == "mcgibbs") {
		Grante::MultichainGibbsInference* mcginf =
			new Grante::MultichainGibbsInference(fg);

		mcginf->SetSamplingParameters(
			GetIntegerDefaultOption(opt_s, "mcgibbs_chains", 5),
			GetScalarDefaultOption(opt_s, "mcgibbs_maxpsrf", 1.01),
			GetIntegerDefaultOption(opt_s, "mcgibbs_spacing", 0),
			GetIntegerDefaultOption(opt_s, "mcgibbs_samples", 1000));
		inf = mcginf;
	} else {
		mexErrMsgTxt("Unknown sampling method.  Use 'treeinf', "
			"'gibbs', or 'mcgibbs'.\n");
		MatlabCPPExit();
		return;
	}

	// Parse sample_count
	if (mxIsDouble(prhs[3]) == false || mxGetNumberOfElements(prhs[3]) != 1) {
		mexErrMsgTxt("sample_count must be a (1,1) double array.\n");
		MatlabCPPExit();
		return;
	}
	unsigned int sample_count = mxGetScalar(prhs[3]);
	assert(sample_count > 0);

	// Perform inference
	mexPrintf("[Grante] performing sampling using method: '%s'\n",
		method_name.c_str());

	unsigned int var_count = fg->Cardinalities().size();
	plhs[0] = mxCreateNumericMatrix(var_count, sample_count,
		mxDOUBLE_CLASS, mxREAL);
	double* sample_p = mxGetPr(plhs[0]);

	// Sample
	std::vector<std::vector<unsigned int> > states;
	inf->Sample(states, sample_count);
	assert(states.size() == sample_count);
	for (unsigned int si = 0; si < states.size(); ++si) {
		// Add 1 for Matlab indexing
		std::transform(states[si].begin(), states[si].end(),
			&sample_p[si * var_count],
			std::bind2nd(std::plus<double>(), 1.0));
	}
	delete (inf);
	delete (fg);
	MatlabCPPExit();
}