/*
 * build_tree (logger overload) on the mock NUTS sampler.
 *
 * Grows a depth-3 subtree from q = 0, p = 1.5 with nominal step size 1 and
 * no jitter. Verifies the leapfrog count, the accumulated momentum rho, the
 * p_sharp values at both ends, the final phase-space point, the
 * log-sum-weight and Metropolis-probability accumulators, and that the
 * logger received no output on any channel.
 */
TEST(McmcNutsBaseNuts, build_tree_test) {
  rng_t rng(0);

  int dim = 1;
  double p0 = 1.5;

  // Initial phase-space point: q = 0, p = p0.
  stan::mcmc::ps_point z_init(dim);
  z_init.q(0) = 0;
  z_init.p(0) = p0;

  stan::mcmc::ps_point z_propose(dim);

  Eigen::VectorXd p_sharp_left = Eigen::VectorXd::Zero(dim);
  Eigen::VectorXd p_sharp_right = Eigen::VectorXd::Zero(dim);

  // rho starts out already containing the initial momentum.
  Eigen::VectorXd rho = z_init.p;

  // Accumulators updated in place by build_tree.
  double log_sum_weight = -std::numeric_limits<double>::infinity();
  double H0 = -0.1;
  int n_leapfrog = 0;
  double sum_metro_prob = 0;

  stan::mcmc::mock_model model(dim);
  stan::mcmc::mock_nuts sampler(model, rng);

  // Nominal step size 1 with no jitter.
  sampler.set_nominal_stepsize(1);
  sampler.set_stepsize_jitter(0);
  sampler.sample_stepsize();
  sampler.z() = z_init;

  std::stringstream debug_log, info_log, warn_log, error_log, fatal_log;
  stan::callbacks::stream_logger logger(debug_log, info_log, warn_log,
                                        error_log, fatal_log);

  const bool subtree_ok
      = sampler.build_tree(3, z_propose, p_sharp_left, p_sharp_right, rho,
                           H0, 1, n_leapfrog, log_sum_weight, sum_metro_prob,
                           logger);

  EXPECT_TRUE(subtree_ok);
  EXPECT_EQ(p0 * (n_leapfrog + 1), rho(0));

  EXPECT_EQ(1, p_sharp_left(0));
  EXPECT_EQ(1, p_sharp_right(0));

  EXPECT_EQ(8 * p0, sampler.z().q(0));
  EXPECT_EQ(p0, sampler.z().p(0));

  EXPECT_EQ(8, n_leapfrog);
  EXPECT_FLOAT_EQ(H0 + std::log(n_leapfrog), log_sum_weight);
  EXPECT_FLOAT_EQ(std::exp(H0) * n_leapfrog, sum_metro_prob);

  // Nothing may be logged on any channel.
  EXPECT_EQ("", debug_log.str());
  EXPECT_EQ("", info_log.str());
  EXPECT_EQ("", warn_log.str());
  EXPECT_EQ("", error_log.str());
  EXPECT_EQ("", fatal_log.str());
}
/*
 * Slice / divergence handling of a depth-0 build_tree call (nuts_util
 * overload) on divergent_nuts, with log_u = 0 and H0 = 0.
 *
 * The potential V of the sampler's current point is overwritten before each
 * call. The test expects:
 *   V = -750 -> one valid point, no divergence recorded
 *   V = -250 -> no valid point,  no divergence recorded
 *   V =  750 -> no valid point,  one divergence recorded
 * Neither output stream may receive anything.
 */
TEST(McmcBaseNuts, slice_criterion) {
  rng_t rng(0);

  int dim = 1;
  double p0 = 1.5;

  Eigen::VectorXd rho = Eigen::VectorXd::Zero(dim);

  stan::mcmc::ps_point z_init(dim);
  z_init.q(0) = 0;
  z_init.p(0) = p0;

  stan::mcmc::ps_point z_propose(dim);

  stan::mcmc::nuts_util util;
  util.log_u = 0;
  util.H0 = 0;
  util.sign = 1;
  util.n_tree = 0;
  util.sum_prob = 0;

  std::stringstream out_stream, err_stream;

  stan::mcmc::mock_model model(dim);
  stan::mcmc::divergent_nuts sampler(model, rng, &out_stream, &err_stream);

  sampler.set_nominal_stepsize(1);
  sampler.set_stepsize_jitter(0);
  sampler.sample_stepsize();
  sampler.z() = z_init;

  // Case 1: strongly negative potential.
  sampler.z().V = -750;
  int n_valid = sampler.build_tree(0, rho, &z_init, z_propose, util);
  EXPECT_EQ(1, n_valid);
  EXPECT_EQ(0, sampler.n_divergent_);

  // Case 2: moderately negative potential.
  sampler.z().V = -250;
  n_valid = sampler.build_tree(0, rho, &z_init, z_propose, util);
  EXPECT_EQ(0, n_valid);
  EXPECT_EQ(0, sampler.n_divergent_);

  // Case 3: large positive potential.
  sampler.z().V = 750;
  n_valid = sampler.build_tree(0, rho, &z_init, z_propose, util);
  EXPECT_EQ(0, n_valid);
  EXPECT_EQ(1, sampler.n_divergent_);

  EXPECT_EQ("", out_stream.str());
  EXPECT_EQ("", err_stream.str());
}
/*
 * build_tree (writer overload) on the mock NUTS sampler.
 *
 * Grows a depth-3 subtree from q = 0, p = 1.5 with nominal step size 1 and
 * no jitter. Verifies the accumulated momentum rho, the final phase-space
 * point, the leapfrog count, the log-sum-weight and Metropolis-probability
 * accumulators, and that neither writer received any output.
 */
TEST(McmcNutsBaseNuts, build_tree) {
  rng_t rng(0);

  int dim = 1;
  double p0 = 1.5;

  // Initial phase-space point: q = 0, p = p0.
  stan::mcmc::ps_point z_init(dim);
  z_init.q(0) = 0;
  z_init.p(0) = p0;

  stan::mcmc::ps_point z_propose(dim);

  // rho starts out already containing the initial momentum.
  Eigen::VectorXd rho = z_init.p;

  // Accumulators updated in place by build_tree.
  double log_sum_weight = -std::numeric_limits<double>::infinity();
  double H0 = -0.1;
  int n_leapfrog = 0;
  double sum_metro_prob = 0;

  stan::mcmc::mock_model model(dim);
  stan::mcmc::mock_nuts sampler(model, rng);

  sampler.set_nominal_stepsize(1);
  sampler.set_stepsize_jitter(0);
  sampler.sample_stepsize();
  sampler.z() = z_init;

  std::stringstream info_stream;
  stan::interface_callbacks::writer::stream_writer info_writer(info_stream);
  std::stringstream err_stream;
  stan::interface_callbacks::writer::stream_writer err_writer(err_stream);

  const bool subtree_ok
      = sampler.build_tree(3, rho, z_propose, H0, 1, n_leapfrog,
                           log_sum_weight, sum_metro_prob, info_writer,
                           err_writer);

  EXPECT_TRUE(subtree_ok);
  EXPECT_EQ(p0 * (n_leapfrog + 1), rho(0));

  EXPECT_EQ(8 * p0, sampler.z().q(0));
  EXPECT_EQ(p0, sampler.z().p(0));

  EXPECT_EQ(8, n_leapfrog);
  EXPECT_FLOAT_EQ(H0 + std::log(n_leapfrog), log_sum_weight);
  EXPECT_FLOAT_EQ(std::exp(H0) * n_leapfrog, sum_metro_prob);

  EXPECT_EQ("", info_stream.str());
  EXPECT_EQ("", err_stream.str());
}
/*
 * build_tree (nuts_util overload) on the mock NUTS sampler.
 *
 * Grows a depth-3 subtree from q = 0, p = 1.5 with nominal step size 1 and
 * no jitter, log_u = -1 and H0 = -0.1. Verifies the valid-point and tree
 * counts, the summed acceptance probability, the accumulated momentum, the
 * point written back through the z_init pointer, and the sampler's final
 * point. Neither output stream may receive anything.
 */
TEST(McmcBaseNuts, build_tree) {
  rng_t rng(0);

  int dim = 1;
  double p0 = 1.5;

  Eigen::VectorXd rho = Eigen::VectorXd::Zero(dim);

  stan::mcmc::ps_point z_init(dim);
  z_init.q(0) = 0;
  z_init.p(0) = p0;

  stan::mcmc::ps_point z_propose(dim);

  stan::mcmc::nuts_util util;
  util.log_u = -1;
  util.H0 = -0.1;
  util.sign = 1;
  util.n_tree = 0;
  util.sum_prob = 0;

  std::stringstream out_stream, err_stream;

  stan::mcmc::mock_model model(dim);
  stan::mcmc::mock_nuts sampler(model, rng, &out_stream, &err_stream);

  sampler.set_nominal_stepsize(1);
  sampler.set_stepsize_jitter(0);
  sampler.sample_stepsize();
  sampler.z() = z_init;

  const int n_valid = sampler.build_tree(3, rho, &z_init, z_propose, util);

  EXPECT_EQ(8, n_valid);
  EXPECT_EQ(8, util.n_tree);
  EXPECT_FLOAT_EQ(std::exp(util.H0) * util.n_tree, util.sum_prob);
  EXPECT_EQ(p0 * util.n_tree, rho(0));

  // z_init is passed by pointer; afterwards it is expected to hold q = p0.
  EXPECT_EQ(p0, z_init.q(0));
  EXPECT_EQ(p0, z_init.p(0));

  EXPECT_EQ(8 * p0, sampler.z().q(0));
  EXPECT_EQ(p0, sampler.z().p(0));

  EXPECT_EQ("", out_stream.str());
  EXPECT_EQ("", err_stream.str());
}
/*
 * Momentum (rho) aggregation inside build_tree (logger overload).
 *
 * rho_inspector_mock_nuts records a rho value during the recursion. For a
 * depth-3 subtree from q = 0, p = 1.5 with unit step size, the test expects
 * seven recorded values, in the order 2p, 2p, 4p, 2p, 2p, 4p, 8p
 * (p = initial momentum).
 */
TEST(McmcNutsBaseNuts, rho_aggregation_test) {
  rng_t base_rng(0);

  int model_size = 1;
  double init_momentum = 1.5;

  stan::mcmc::ps_point z_init(model_size);
  z_init.q(0) = 0;
  z_init.p(0) = init_momentum;

  stan::mcmc::ps_point z_propose(model_size);

  Eigen::VectorXd p_sharp_left = Eigen::VectorXd::Zero(model_size);
  Eigen::VectorXd p_sharp_right = Eigen::VectorXd::Zero(model_size);
  Eigen::VectorXd rho = z_init.p;

  double log_sum_weight = -std::numeric_limits<double>::infinity();
  double H0 = -0.1;
  int n_leapfrog = 0;
  double sum_metro_prob = 0;

  stan::mcmc::mock_model model(model_size);
  stan::mcmc::rho_inspector_mock_nuts sampler(model, base_rng);

  sampler.set_nominal_stepsize(1);
  sampler.set_stepsize_jitter(0);
  sampler.sample_stepsize();
  sampler.z() = z_init;

  std::stringstream debug, info, warn, error, fatal;
  stan::callbacks::stream_logger logger(debug, info, warn, error, fatal);

  sampler.build_tree(3, z_propose, p_sharp_left, p_sharp_right, rho, H0, 1,
                     n_leapfrog, log_sum_weight, sum_metro_prob, logger);

  // rho_values.size() returns an unsigned size type (assumes rho_values is a
  // std::vector -- confirm). An unsigned literal avoids a signed/unsigned
  // comparison (-Wsign-compare) inside EXPECT_EQ.
  EXPECT_EQ(7u, sampler.rho_values.size());
  EXPECT_EQ(2 * init_momentum, sampler.rho_values.at(0));
  EXPECT_EQ(2 * init_momentum, sampler.rho_values.at(1));
  EXPECT_EQ(4 * init_momentum, sampler.rho_values.at(2));
  EXPECT_EQ(2 * init_momentum, sampler.rho_values.at(3));
  EXPECT_EQ(2 * init_momentum, sampler.rho_values.at(4));
  EXPECT_EQ(4 * init_momentum, sampler.rho_values.at(5));
  EXPECT_EQ(8 * init_momentum, sampler.rho_values.at(6));
}
/**
 * Performs one NUTS transition starting from the continuous parameters of
 * init_sample.
 *
 * Steps:
 * - Draw a fresh momentum and a slice variable (log_u = log of a uniform draw).
 * - Repeatedly build subtrees in a randomly chosen direction, each with the
 *   current _depth. This continues until util.criterion becomes false or
 *   _depth exceeds _max_depth.
 * - Propose each completed subtree's z_propose in place of the current
 *   selection with probability n_valid_subtree / n_valid. When no valid
 *   points exist yet, it is accepted iff the subtree has any.
 *
 * On return, _depth holds the depth argument of the last subtree that was
 * built. This is the same on every exit path.
 *
 * @param init_sample sample whose cont_params() seed the position
 * @return sample at the selected point, carrying -V of that point and
 *         accept_prob = util.sum_prob / util.n_tree
 */
sample transition(sample& init_sample) {
  // Initialize the algorithm
  this->sample_stepsize();

  nuts_util util;

  this->seed(init_sample.cont_params());

  this->_hamiltonian.sample_p(this->_z, this->_rand_int);
  this->_hamiltonian.init(this->_z);

  // Trajectory endpoints (forward / backward), the currently selected point
  // and the per-subtree proposal all start at the initial point.
  ps_point z_plus(this->_z);
  ps_point z_minus(z_plus);

  ps_point z_sample(z_plus);
  ps_point z_propose(z_plus);

  int n_cont = init_sample.cont_params().size();

  // rho_init holds the initial momentum. rho_plus / rho_minus accumulate the
  // momenta added by build_tree in each direction.
  Eigen::VectorXd rho_init = this->_z.p;
  Eigen::VectorXd rho_plus(n_cont);
  rho_plus.setZero();
  Eigen::VectorXd rho_minus(n_cont);
  rho_minus.setZero();

  util.H0 = this->_hamiltonian.H(this->_z);

  // Sample the slice variable
  util.log_u = std::log(this->_rand_uniform());

  // Build a balanced binary tree until the NUTS criterion fails
  util.criterion = true;
  int n_valid = 0;

  this->_depth = 0;
  this->_n_divergent = 0;

  util.n_tree = 0;
  util.sum_prob = 0;

  while (util.criterion && (this->_depth <= this->_max_depth)) {
    // Randomly sample a direction in time
    ps_point* z = 0;
    Eigen::VectorXd* rho = 0;

    if (this->_rand_uniform() > 0.5) {
      z = &z_plus;
      rho = &rho_plus;
      util.sign = 1;
    } else {
      z = &z_minus;
      rho = &rho_minus;
      util.sign = -1;
    }

    // And build a new subtree in that direction
    this->_z.ps_point::operator=(*z);

    int n_valid_subtree
        = build_tree(this->_depth, *rho, 0, z_propose, util);

    // Count the subtree as soon as it has been built, before the possible
    // break below. The decrement after the loop then yields the same value
    // whether the loop ends via this break or via the loop condition. In
    // particular it never goes negative when the first subtree fails.
    ++(this->_depth);

    *z = this->_z;

    // Metropolis-Hastings sample the fresh subtree
    if (!util.criterion)
      break;

    double subtree_prob = 0;

    if (n_valid) {
      subtree_prob = static_cast<double>(n_valid_subtree)
                     / static_cast<double>(n_valid);
    } else {
      subtree_prob = n_valid_subtree ? 1 : 0;
    }

    if (this->_rand_uniform() < subtree_prob)
      z_sample = z_propose;

    n_valid += n_valid_subtree;

    // Check validity of completed tree
    this->_z.ps_point::operator=(z_plus);
    Eigen::VectorXd delta_rho = rho_minus + rho_init + rho_plus;

    util.criterion = compute_criterion(z_minus, this->_z, delta_rho);
  }

  --(this->_depth);  // Correct for the per-subtree increment in the loop

  double accept_prob = util.sum_prob / static_cast<double>(util.n_tree);

  this->_z.ps_point::operator=(z_sample);
  return sample(this->_z.q, - this->_z.V, accept_prob);
}
/**
 * Performs one NUTS transition starting from the continuous parameters of
 * init_sample, forwarding info_writer / error_writer to the Hamiltonian's
 * init and to build_tree.
 *
 * Steps:
 * - Draw a fresh momentum and a slice variable (log_u = log of a
 *   rand_uniform_ draw).
 * - Repeatedly build subtrees in a randomly chosen direction until
 *   util.criterion becomes false or depth_ exceeds max_depth_.
 * - Propose each completed subtree's z_propose in place of the current
 *   selection with probability n_valid_subtree / n_valid. When no valid
 *   points exist yet, it is accepted iff the subtree has any
 *   (presumably rand_uniform_ draws from U(0, 1) -- confirm).
 *
 * Side effects: sets depth_, divergent_, n_leapfrog_ (= util.n_tree) and
 * energy_ (H at the selected point), and leaves z_ at the selected point.
 *
 * @return sample at the selected point, carrying -V of that point and
 *         accept_prob = util.sum_prob / util.n_tree
 */
sample transition(sample& init_sample, interface_callbacks::writer::base_writer& info_writer, interface_callbacks::writer::base_writer& error_writer) {
  // Initialize the algorithm
  this->sample_stepsize();
  nuts_util util;
  this->seed(init_sample.cont_params());
  this->hamiltonian_.sample_p(this->z_, this->rand_int_);
  this->hamiltonian_.init(this->z_, info_writer, error_writer);
  // Trajectory endpoints (forward / backward), the currently selected point
  // and the per-subtree proposal all start at the initial point.
  ps_point z_plus(this->z_);
  ps_point z_minus(z_plus);
  ps_point z_sample(z_plus);
  ps_point z_propose(z_plus);
  int n_cont = init_sample.cont_params().size();
  // rho_init holds the initial momentum. rho_plus / rho_minus accumulate the
  // momenta added by build_tree in each direction.
  Eigen::VectorXd rho_init = this->z_.p;
  Eigen::VectorXd rho_plus(n_cont);
  rho_plus.setZero();
  Eigen::VectorXd rho_minus(n_cont);
  rho_minus.setZero();
  util.H0 = this->hamiltonian_.H(this->z_);
  // Sample the slice variable
  util.log_u = std::log(this->rand_uniform_());
  // Build a balanced binary tree until the NUTS criterion fails
  util.criterion = true;
  int n_valid = 0;
  this->depth_ = 0;
  this->divergent_ = 0;
  util.n_tree = 0;
  util.sum_prob = 0;
  // NOTE(review): depth_ is incremented right after each build_tree call and
  // the bound is `<=`. So up to max_depth_ + 1 subtrees can be built, and
  // depth_ can end at max_depth_ + 1. Confirm this is the intended meaning
  // of max_depth_.
  while (util.criterion && (this->depth_ <= this->max_depth_)) {
    // Randomly sample a direction in time
    ps_point* z = 0;
    Eigen::VectorXd* rho = 0;
    if (this->rand_uniform_() > 0.5) {
      z = &z_plus;
      rho = &rho_plus;
      util.sign = 1;
    } else {
      z = &z_minus;
      rho = &rho_minus;
      util.sign = -1;
    }
    // And build a new subtree in that direction
    this->z_.ps_point::operator=(*z);
    int n_valid_subtree = build_tree(depth_, *rho, 0, z_propose, util, info_writer, error_writer);
    // Counted before the possible break below, so depth_ equals the number
    // of subtrees built on every exit path.
    ++(this->depth_);
    // Store the new endpoint for the chosen direction.
    *z = this->z_;
    // Metropolis-Hastings sample the fresh subtree
    if (!util.criterion) break;
    double subtree_prob = 0;
    if (n_valid) {
      subtree_prob = static_cast<double>(n_valid_subtree) / static_cast<double>(n_valid);
    } else {
      subtree_prob = n_valid_subtree ? 1 : 0;
    }
    if (this->rand_uniform_() < subtree_prob) z_sample = z_propose;
    n_valid += n_valid_subtree;
    // Check validity of completed tree
    this->z_.ps_point::operator=(z_plus);
    Eigen::VectorXd delta_rho = rho_minus + rho_init + rho_plus;
    util.criterion = compute_criterion(z_minus, this->z_, delta_rho);
  }
  // Record util.n_tree as the leapfrog count (assumes build_tree increments
  // n_tree once per leapfrog step -- confirm).
  this->n_leapfrog_ = util.n_tree;
  double accept_prob = util.sum_prob / static_cast<double>(util.n_tree);
  this->z_.ps_point::operator=(z_sample);
  this->energy_ = this->hamiltonian_.H(this->z_);
  return sample(this->z_.q, - this->z_.V, accept_prob);
}