/**
 * Perform one No-U-Turn Sampler transition starting from init_sample.
 *
 * After seeding this->_z from init_sample, the momentum is resampled
 * (sample_p) and the Hamiltonian initialized. The trajectory is then
 * repeatedly doubled in a randomly chosen direction in time by building a
 * balanced binary subtree of depth this->_depth with build_tree(). Doubling
 * stops when:
 *   - build_tree() clears util.criterion (the new subtree is discarded), or
 *   - _compute_criterion() fails on the whole trajectory, or
 *   - this->_depth exceeds this->_max_depth.
 *
 * @param init_sample starting point; its continuous and discrete parameters
 *        are passed to seed().
 * @return the selected point (q, r), -V at that point, and the acceptance
 *         statistic util.sum_prob / util.n_tree.
 *
 * On return this->_depth holds the depth of the last subtree that was built.
 */
sample transition(sample& init_sample) {
  // Initialize the algorithm
  this->_sample_stepsize();
  nuts_util util;
  this->seed(init_sample.cont_params(), init_sample.disc_params());
  this->_hamiltonian.sample_p(this->_z, this->_rand_int);
  this->_hamiltonian.init(this->_z);

  // Both trajectory endpoints, the currently selected state and the
  // per-subtree proposal all start at the initial point.
  ps_point z_plus(static_cast<ps_point>(this->_z));
  ps_point z_minus(z_plus);
  ps_point z_sample(z_plus);
  ps_point z_propose(z_plus);

  // rho_plus / rho_minus are handed to build_tree() for the forward and
  // backward directions (presumably momentum sums over each half); the
  // initial momentum is kept separately in rho_init.
  int n_cont = init_sample.cont_params().size();
  Eigen::VectorXd rho_init = this->_z.p;
  Eigen::VectorXd rho_plus = Eigen::VectorXd::Zero(n_cont);
  Eigen::VectorXd rho_minus = Eigen::VectorXd::Zero(n_cont);

  util.H0 = this->_hamiltonian.H(this->_z);

  // Sample the slice variable
  util.log_u = std::log(this->_rand_uniform());

  // Build a balanced binary tree until the NUTS criterion fails
  util.criterion = true;
  int n_valid = 0;
  this->_depth = 0;
  // Set when build_tree() rejects a subtree. The loop is then left via
  // break, i.e. before the end-of-loop increment of _depth.
  bool subtree_rejected = false;

  while (util.criterion && (this->_depth <= this->_max_depth)) {
    // NOTE(review): these statistics are reset for every subtree, so the
    // accept_prob computed below summarizes only the final subtree rather
    // than the whole trajectory -- confirm this is intended.
    util.n_tree = 0;
    util.sum_prob = 0;

    // Randomly sample a direction in time
    ps_point* z = 0;
    Eigen::VectorXd* rho = 0;
    if (this->_rand_uniform() > 0.5) {
      z = &z_plus;
      rho = &rho_plus;
      util.sign = 1;
    } else {
      z = &z_minus;
      rho = &rho_minus;
      util.sign = -1;
    }

    // And build a new subtree in that direction
    this->_z.copy_base(*z);
    int n_valid_subtree = build_tree(this->_depth, *rho, 0, z_propose, util);
    *z = static_cast<ps_point>(this->_z);

    // A subtree rejected by build_tree() never gets to propose a sample.
    if (!util.criterion) {
      subtree_rejected = true;
      break;
    }

    // Metropolis-Hastings sample the fresh subtree
    double subtree_prob = 0;
    if (n_valid) {
      subtree_prob = static_cast<double>(n_valid_subtree)
                     / static_cast<double>(n_valid);
    } else {
      subtree_prob = n_valid_subtree ? 1 : 0;
    }
    if (this->_rand_uniform() < subtree_prob) z_sample = z_propose;
    n_valid += n_valid_subtree;

    // Check validity of completed tree
    this->_z.copy_base(z_plus);
    Eigen::VectorXd delta_rho = rho_minus + rho_init + rho_plus;
    util.criterion = _compute_criterion(z_minus, this->_z, delta_rho);

    ++(this->_depth);
  }

  // If the loop ended through its condition, _depth was incremented once past
  // the last subtree built, so undo that. After a break no increment happened.
  // Decrementing there would under-report the depth, reaching -1 when the
  // very first subtree is rejected.
  if (!subtree_rejected) --(this->_depth);

  this->_z.copy_base(z_sample);
  double accept_prob = util.sum_prob / static_cast<double>(util.n_tree);
  return sample(this->_z.q, this->_z.r, - this->_hamiltonian.V(this->_z), accept_prob);
}
/**
 * Trivial transition: no dynamics are simulated.
 *
 * Loads init_sample's continuous and discrete parameters into this->_z via
 * seed() and returns the resulting point. The reported value is the
 * negated potential -V(_z), and the acceptance statistic is always 0.
 *
 * @param init_sample point whose parameters are passed to seed().
 * @return the seeded point (q, r), -V at that point, and 0.
 */
sample transition(sample& init_sample) {
  this->seed(init_sample.cont_params(), init_sample.disc_params());

  const double neg_potential = -this->_hamiltonian.V(this->_z);
  return sample(this->_z.q, this->_z.r, neg_potential, 0);
}