// Read solver options from the MATLAB struct `opts_in` into the members
// num_max_iter, rho, eps_gnorm and solver.
//
// Recognised field names: "num_max_iter", "rho", "eps_gnorm" and "solver"
// (the latter must be the string "fistadescent" or "lbfgs").  Only element 0
// of the struct array is inspected.
//
// Any problem (non-struct argument, unknown field name, field without a
// value, negative iteration count, unknown solver name) is reported through
// mexErrMsgTxt.  The `return false` statements after those calls are a
// fallback for the case that mexErrMsgTxt hands control back to this function.
//
// Returns true when every field of `opts_in` was recognised and stored.
bool parse(const mxArray* opts_in)
    {
        if (!mxIsStruct(opts_in)) {
            mexErrMsgTxt("Options parameter must be a structure.\n");
            return (false);
        }
        int num_fields = mxGetNumberOfFields(opts_in);
        
        for (int fn=0; fn<num_fields; fn++) {
            const char* opt_name = mxGetFieldNameByNumber(opts_in, fn);
            std::string opt_name_str = opt_name;
            mxArray *opt_val = mxGetFieldByNumber(opts_in, 0, fn);
            // mxGetFieldByNumber yields a null pointer when element 0 has no
            // value for this field (e.g. an empty struct array); handing that
            // to mxGetScalar / GetMatlabString would dereference null.
            if (opt_val == 0) {
                mexErrMsgTxt("Option has no value assigned.\n");
                return (false);
            }
            if (opt_name_str == "num_max_iter") {
                const double iter_val = mxGetScalar(opt_val);
                // Converting a negative or NaN double to size_t is undefined
                // behaviour, so reject it instead of casting blindly.
                if (!(iter_val >= 0.0)) {
                    mexErrMsgTxt("Option num_max_iter must be non-negative.\n");
                    return (false);
                }
                num_max_iter = static_cast<size_t>(iter_val);
            }
            else if (opt_name_str == "rho") {
                rho = mxGetScalar(opt_val);
            }
            else if (opt_name_str == "eps_gnorm") {
                eps_gnorm = mxGetScalar(opt_val);
            }
            else if (opt_name_str == "solver") {
                std::string solver_name = GetMatlabString(opt_val);
                if (solver_name == "fistadescent") {
                    solver = SOLVER_FISTADESCENT;
                }
                else if (solver_name == "lbfgs") {
                    solver = SOLVER_LBFGS;
                }
                else {
                    mexErrMsgTxt("No solver with this name.\n");
                    return false;
                }
            }
            else {
                mexErrMsgTxt("Name of the option is invalid.\n");
                return (false);
            }
        }

        return true;
    }
Пример #2
0
    // Read options from the MATLAB struct `opts_in` into the corresponding
    // members of this object.
    //
    // Recognised field names:
    //   scalars  : rho_start, rho_end, eps_dkl, eps_obj, eps_mp, eps_entropy,
    //              rho_schedule_constant, initial_lp_improvement_ratio,
    //              initial_lp_rho_start, initial_rho_factor_kl_smaller
    //   counts   : num_max_iter_dc, num_max_iter_mp (must be non-negative)
    //   flags    : do_round, initial_lp_active, initial_rho_similar_values,
    //              skip_if_increase (any non-zero scalar is true)
    //   string   : solver_sdd ("fistadescent" or "lbfgs")
    // Only element 0 of the struct array is inspected.
    //
    // Any problem (non-struct argument, unknown field name, field without a
    // value, negative count, unknown solver name) is reported through
    // mexErrMsgTxt.  The `return false` statements after those calls are a
    // fallback for the case that mexErrMsgTxt hands control back here.
    //
    // Returns true when every field of `opts_in` was recognised and stored.
    bool parse(const mxArray* opts_in)
    {
        if (!mxIsStruct(opts_in)) {
            mexErrMsgTxt("Options parameter must be a structure.\n");
            return (false);
        }
        int num_fields = mxGetNumberOfFields(opts_in);
        
        for (int fn=0; fn<num_fields; fn++) {
            const char* opt_name = mxGetFieldNameByNumber(opts_in, fn);
            std::string opt_name_str = opt_name;
            mxArray *opt_val = mxGetFieldByNumber(opts_in, 0, fn);
            // mxGetFieldByNumber yields a null pointer when element 0 has no
            // value for this field (e.g. an empty struct array); handing that
            // to mxGetScalar / GetMatlabString would dereference null.
            if (opt_val == 0) {
                mexErrMsgTxt("Option has no value assigned.\n");
                return (false);
            }
            if (opt_name_str == "rho_start") {
                rho_start = mxGetScalar(opt_val);
            }
            else if (opt_name_str == "rho_end") {
                rho_end = mxGetScalar(opt_val);
            }
            else if (opt_name_str == "eps_dkl") {
                eps_dkl = mxGetScalar(opt_val);
            }
            else if (opt_name_str == "eps_obj") {
                eps_obj = mxGetScalar(opt_val);
            }
            else if (opt_name_str == "eps_mp") {
                eps_mp = mxGetScalar(opt_val);
            }
            else if (opt_name_str == "eps_entropy") {
                eps_entropy = mxGetScalar(opt_val);
            }
            else if (opt_name_str == "num_max_iter_dc") {
                const double iter_val = mxGetScalar(opt_val);
                // Converting a negative or NaN double to size_t is undefined
                // behaviour, so reject it instead of casting blindly.
                if (!(iter_val >= 0.0)) {
                    mexErrMsgTxt("Option num_max_iter_dc must be non-negative.\n");
                    return (false);
                }
                num_max_iter_dc = static_cast<size_t>(iter_val);
            }
            else if (opt_name_str == "num_max_iter_mp") {
                const double iter_val = mxGetScalar(opt_val);
                // Same guard as for num_max_iter_dc.
                if (!(iter_val >= 0.0)) {
                    mexErrMsgTxt("Option num_max_iter_mp must be non-negative.\n");
                    return (false);
                }
                num_max_iter_mp = static_cast<size_t>(iter_val);
            }
            else if (opt_name_str == "rho_schedule_constant") {
                rho_schedule_constant = mxGetScalar(opt_val);
            }
            else if (opt_name_str == "do_round") {
                do_round = static_cast<bool>(mxGetScalar(opt_val));
            }
            else if (opt_name_str == "initial_lp_active") {
                initial_lp_active = static_cast<bool>(mxGetScalar(opt_val));
            }
            else if (opt_name_str == "initial_lp_improvement_ratio") {
                initial_lp_improvement_ratio = mxGetScalar(opt_val);
            }
            else if (opt_name_str == "initial_lp_rho_start") {
                initial_lp_rho_start = mxGetScalar(opt_val);
            }
            else if (opt_name_str == "initial_rho_similar_values") {
                initial_rho_similar_values = static_cast<bool>(mxGetScalar(opt_val));
            }
            else if (opt_name_str == "initial_rho_factor_kl_smaller") {
                initial_rho_factor_kl_smaller = mxGetScalar(opt_val);
            }
            else if (opt_name_str == "skip_if_increase") {
                skip_if_increase = static_cast<bool>(mxGetScalar(opt_val));
            }
            else if (opt_name_str == "solver_sdd") {
                std::string solver_name = GetMatlabString(opt_val);
                if (solver_name == "fistadescent") {
                    solver_sdd = LPQPSDD::FISTADESCENT;
                }
                else if (solver_name == "lbfgs") {
                    solver_sdd = LPQPSDD::LBFGS;
                }
                else {
                    mexErrMsgTxt("No solver with this name.\n");
                    return false;
                }
            }
            else {
                mexErrMsgTxt("Name of the option is invalid.\n");
                return (false);
            }
        }

        return true;
    }
Пример #3
0
// [states] = grante_sample(model, fg, method, sample_count, options);
void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
	// Option structure
	const mxArray* opt_s = 0;
	if (nrhs >= 4 && mxIsEmpty(prhs[4]) == false)
		opt_s = prhs[4];

	MatlabCPPInitialize(GetScalarDefaultOption(opt_s, "verbose", 0) > 0);

	if (nrhs < 4 || nrhs > 5 || nlhs != 1) {
		mexErrMsgTxt("Wrong number of arguments.\n");
		MatlabCPPExit();
		return;
	}

	// Master model
	Grante::FactorGraphModel model;
	if (matlab_parse_factorgraphmodel(prhs[0], model) == false) {
		MatlabCPPExit();
		return;
	}

	// Parse factor graph
	std::vector<Grante::FactorGraph*> FG;
	bool fgs_parsed = matlab_parse_factorgraphs(model, prhs[1], FG);
	if (fgs_parsed == false) {
		MatlabCPPExit();
		return;
	}
	size_t num_fgs = FG.size();
	if (num_fgs != 1) {
		mexErrMsgTxt("mex_grante_sample supports only one "
			"factor graph at a time.\n");
		for (unsigned int fgi = 0; fgi < FG.size(); ++fgi)
			delete (FG[fgi]);

		MatlabCPPExit();
		return;
	}
	// Compute energies
	Grante::FactorGraph* fg = FG[0];
	fg->ForwardMap();

	// Parse sampling method
	std::string method_name = GetMatlabString(prhs[2]);
	Grante::InferenceMethod* inf = 0;
	if (method_name == "treeinf") {
		if (Grante::FactorGraphStructurizer::IsForestStructured(fg) == false) {
			mexErrMsgTxt("Exact sampling is currently only "
				"possible for tree-structured factor graphs.\n");
			MatlabCPPExit();
			return;
		}
		inf = new Grante::TreeInference(fg);
	} else if (method_name == "gibbs") {
		Grante::GibbsInference* ginf = new Grante::GibbsInference(fg);
		ginf->SetSamplingParameters(
			GetIntegerDefaultOption(opt_s, "gibbs_burnin", 100),
			GetIntegerDefaultOption(opt_s, "gibbs_spacing", 0),
			GetIntegerDefaultOption(opt_s, "gibbs_samples", 1));
		inf = ginf;
	} else if (method_name == "mcgibbs") {
		Grante::MultichainGibbsInference* mcginf =
			new Grante::MultichainGibbsInference(fg);

		mcginf->SetSamplingParameters(
			GetIntegerDefaultOption(opt_s, "mcgibbs_chains", 5),
			GetScalarDefaultOption(opt_s, "mcgibbs_maxpsrf", 1.01),
			GetIntegerDefaultOption(opt_s, "mcgibbs_spacing", 0),
			GetIntegerDefaultOption(opt_s, "mcgibbs_samples", 1000));
		inf = mcginf;
	} else {
		mexErrMsgTxt("Unknown sampling method.  Use 'treeinf', "
			"'gibbs', or 'mcgibbs'.\n");
		MatlabCPPExit();
		return;
	}

	// Parse sample_count
	if (mxIsDouble(prhs[3]) == false || mxGetNumberOfElements(prhs[3]) != 1) {
		mexErrMsgTxt("sample_count must be a (1,1) double array.\n");
		MatlabCPPExit();
		return;
	}
	unsigned int sample_count = mxGetScalar(prhs[3]);
	assert(sample_count > 0);

	// Perform inference
	mexPrintf("[Grante] performing sampling using method: '%s'\n",
		method_name.c_str());

	unsigned int var_count = fg->Cardinalities().size();
	plhs[0] = mxCreateNumericMatrix(var_count, sample_count,
		mxDOUBLE_CLASS, mxREAL);
	double* sample_p = mxGetPr(plhs[0]);

	// Sample
	std::vector<std::vector<unsigned int> > states;
	inf->Sample(states, sample_count);
	assert(states.size() == sample_count);
	for (unsigned int si = 0; si < states.size(); ++si) {
		// Add 1 for Matlab indexing
		std::transform(states[si].begin(), states[si].end(),
			&sample_p[si * var_count],
			std::bind2nd(std::plus<double>(), 1.0));
	}
	delete (inf);
	delete (fg);
	MatlabCPPExit();
}