コード例 #1
0
/**
 * Runs stan::services::optimization::do_bfgs_optimize with an L-BFGS
 * line-search optimizer, starting from (-1, 1), and checks that it returns 0
 * and that the mock callback's counter ends at 35.
 */
TEST(Services, do_bfgs_optimize__lbfgs) {
  std::vector<double> cont_vector(2);
  cont_vector[0] = -1;
  cont_vector[1] = 1;
  std::vector<int> disc_vector;

  // The model is constructed from an empty dump (no data values).
  static const std::string DATA("");
  std::stringstream data_stream(DATA);
  stan::io::dump dummy_context(data_stream);
  Model model(dummy_context);

  typedef stan::optimization::BFGSLineSearch<
      Model, stan::optimization::LBFGSUpdate<> > Optimizer_LBFGS;
  Optimizer_LBFGS lbfgs(model, cont_vector, disc_vector, &std::cout);

  double lp = 0;
  bool save_iterations = true;
  int refresh = 0;
  unsigned int random_seed = 0;
  rng_t base_rng(random_seed);

  // NOTE(review): a null output stream presumably disables writing
  // iterations to a file — confirm against do_bfgs_optimize.
  std::fstream* output_stream = 0;
  mock_callback callback;

  int return_code = stan::services::optimization::do_bfgs_optimize(
      model, lbfgs, base_rng, lp, cont_vector, disc_vector,
      output_stream, &std::cout, save_iterations, refresh, callback);

  // return_code is an int, so compare it exactly. EXPECT_FLOAT_EQ would
  // convert it to float and do a ULP-based comparison.
  EXPECT_EQ(0, return_code);
  // Expected value of the mock callback's counter for this model and start.
  EXPECT_EQ(35, callback.n);
}
コード例 #2
0
/**
 * Builds the reject_func_call_generated_quantities model, evaluates log_prob
 * at the all-zero parameter vector, and expects
 * stan::services::io::write_iteration to throw std::domain_error whose
 * message contains error_msg. The test fails if a domain_error with a
 * different message is thrown, or if no exception is thrown at all.
 */
TEST(StanCommon, reject_func_call_generated_quantities) {
  std::string error_msg = "user-specified rejection";

  std::fstream empty_data_stream(std::string("").c_str());
  stan::io::dump empty_data_context(empty_data_stream);
  empty_data_stream.close();
  std::stringstream model_output;
  model_output.str("");

  // Instantiate the model with automatic storage so it is destroyed on every
  // exit path (FAIL(), the early return, and normal fall-through). The
  // original `new` was never matched by a `delete`, so the model leaked.
  reject_func_call_generated_quantities_model_namespace::
      reject_func_call_generated_quantities_model
          model(empty_data_context, &model_output);

  // Arguments to log_prob: every continuous parameter set to zero, and no
  // discrete parameters.
  Eigen::VectorXd cont_params = Eigen::VectorXd::Zero(model.num_params_r());
  std::vector<double> cont_vector(cont_params.size());
  for (int i = 0; i < cont_params.size(); ++i)
    cont_vector.at(i) = cont_params(i);
  std::vector<int> disc_vector;

  boost::ecuyer1988 base_rng;
  base_rng.seed(123456);

  double lp = model.log_prob<false, false>(cont_vector, disc_vector,
                                           &std::cout);
  try {
    stan::services::io::write_iteration(model_output, model, base_rng,
                                        lp, cont_vector, disc_vector);
  } catch (const std::domain_error& e) {
    if (std::string(e.what()).find(error_msg) == std::string::npos) {
      FAIL() << std::endl << "---------------------------------" << std::endl
             << "--- EXPECTED: error_msg=" << error_msg << std::endl
             << "--- FOUND: e.what()=" << e.what() << std::endl
             << "--------------------------------*" << std::endl
             << std::endl;
    }
    return;
  }
  // Reaching here means write_iteration did not throw.
  FAIL() << "model failed to do reject" << std::endl;
}
コード例 #3
0
ファイル: advi.hpp プロジェクト: housian0724/stan
    /**
     * Runs the algorithm and writes to output.
     *
     * @param  tol_rel_obj    relative tolerance parameter for convergence
     * @param  max_iterations max number of iterations to run algorithm
     */
    int run(double tol_rel_obj, int max_iterations) const {
        if (diag_stream_) {
            *diag_stream_ << "iter,time_in_seconds,ELBO" << std::endl;
        }

        // initialize variational approximation
        Q variational = Q(cont_params_);

        // run inference algorithm
        robbins_monro_adagrad(variational, tol_rel_obj, max_iterations);

        // get mean of posterior approximation and write on first output line
        cont_params_ = variational.mean();
        std::vector<double> cont_vector(cont_params_.size());
        for (int i = 0; i < cont_params_.size(); ++i)
            cont_vector.at(i) = cont_params_(i);
        std::vector<int> disc_vector;

        if (out_stream_) {
            services::io::write_iteration(*out_stream_, model_, rng_,
                                          0.0, cont_vector, disc_vector,
                                          print_stream_);
        }

        // draw more samples from posterior and write on subsequent lines
        if (out_stream_) {
            for (int n = 0; n < n_posterior_samples_; ++n) {
                cont_params_ = variational.sample(rng_);
                for (int i = 0; i < cont_params_.size(); ++i)
                    cont_vector.at(i) = cont_params_(i);
                services::io::write_iteration(*out_stream_, model_, rng_,
                                              0.0, cont_vector, disc_vector, print_stream_);
            }
        }

        return stan::services::error_codes::OK;
    }