void VNode::PrintPolicyTree(int depth, ostream& os) { if (depth != -1 && this->depth() > depth) return; vector<QNode*>& qnodes = children(); if (qnodes.size() == 0) { int astar = this->default_move().action; os << this << "-a=" << astar << endl; } else { QNode* qstar = NULL; for (int a = 0; a < qnodes.size(); a++) { QNode* qnode = qnodes[a]; if (qstar == NULL || qnode->lower_bound() > qstar->lower_bound()) { qstar = qnode; } } os << this << "-a=" << qstar->edge() << endl; vector<OBS_TYPE> labels; map<OBS_TYPE, VNode*>& vnodes = qstar->children(); for (map<OBS_TYPE, VNode*>::iterator it = vnodes.begin(); it != vnodes.end(); it++) { labels.push_back(it->first); } for (int i = 0; i < labels.size(); i++) { if (depth == -1 || this->depth() + 1 <= depth) { os << repeat("| ", this->depth()) << "| o=" << labels[i] << ": "; qstar->Child(labels[i])->PrintPolicyTree(depth, os); } } } }
/**
 * Runs one DESPOT trial: a single root-to-leaf descent through the search tree.
 *
 * At each visited node the function:
 *  - updates the longest-trial statistic;
 *  - calls ExploitBlockers();
 *  - stops if the node's Gap() is zero;
 *  - expands the node if it is a leaf;
 *  - moves to the child chosen by SelectBestUpperBoundNode() followed by
 *    SelectBestWEUNode().
 *
 * The descent also ends when no child is selected, when the search depth limit
 * is reached, or when the new node's WEU() is not positive. Every
 * (action, observation) pair pushed onto @p history during the trial is removed
 * before returning.
 *
 * @param statistics optional; when non-NULL, timing and expansion counters are
 *                   accumulated into it.
 * @return the last node reached by the trial.
 */
VNode* DESPOT::Trial(VNode* root, RandomStreams& streams,
	ScenarioLowerBound* lower_bound, ScenarioUpperBound* upper_bound,
	const DSPOMDP* model, History& history, SearchStatistics* statistics) {
	// Remember the history length so this trial's additions can be rolled back.
	const int saved_history_size = history.Size();
	VNode* node = root;

	for (;;) {
		if (statistics != NULL
			&& node->depth() > statistics->longest_trial_length) {
			statistics->longest_trial_length = node->depth();
		}

		ExploitBlockers(node);

		if (Gap(node) == 0)
			break;

		if (node->IsLeaf()) {
			double expand_begin = clock();
			Expand(node, lower_bound, upper_bound, model, streams, history);
			if (statistics != NULL) {
				statistics->time_node_expansion +=
					(double) (clock() - expand_begin) / CLOCKS_PER_SEC;
				statistics->num_expanded_nodes++;
				statistics->num_tree_particles += node->particles().size();
			}
		}

		double select_begin = clock();
		QNode* chosen_action = SelectBestUpperBoundNode(node);
		VNode* chosen_child = SelectBestWEUNode(chosen_action);
		if (statistics != NULL)
			statistics->time_path += (clock() - select_begin) / CLOCKS_PER_SEC;

		if (chosen_child == NULL)
			break;

		node = chosen_child;
		history.Add(chosen_action->edge(), node->edge());

		// Continue only while within the depth limit and WEU is positive.
		if (!(node->depth() < Globals::config.search_depth && WEU(node) > 0))
			break;
	}

	history.Truncate(saved_history_size);
	return node;
}
/**
 * Builds a pruned copy of @p vnode.
 *
 * Each child QNode is pruned recursively. Only the best-valued pruned child is
 * kept, and only if its value is not below this node's default move. Otherwise
 * the node falls back to its default move, and the pruned copy has no children.
 *
 * @param vnode          node to prune (not modified).
 * @param pruned_action  out: action chosen at this node (best child's edge, or
 *                       the default move's action).
 * @param pruned_value   out: value associated with @p pruned_action.
 * @return a newly allocated VNode with no particles, at the same depth and with
 *         the same edge as @p vnode. Pruned children that are not kept are
 *         deleted here.
 */
VNode* DESPOT::Prune(VNode* vnode, int& pruned_action, double& pruned_value) {
	vector<State*> empty;
	VNode* pruned_v = new VNode(empty, vnode->depth(), NULL, vnode->edge());

	vector<QNode*>& children = vnode->children();
	int astar = -1;
	double nustar = Globals::NEG_INFTY;
	QNode* qstar = NULL;
	for (size_t i = 0; i < children.size(); i++) {
		QNode* qnode = children[i];
		double nu = Globals::NEG_INFTY;
		QNode* pruned_q = Prune(qnode, nu);
		if (nu > nustar) {
			nustar = nu;
			astar = qnode->edge();
			delete qstar; // previous best is superseded; deleting NULL is a no-op
			qstar = pruned_q;
		} else {
			delete pruned_q;
		}
	}

	// Fall back to the default move when it is better, or when no child was
	// selected at all. Without the NULL check, a node with no children (or only
	// NEG_INFTY-valued ones) and a NEG_INFTY default value would push NULL into
	// children() and then dereference it.
	if (qstar == NULL || nustar < vnode->default_move().value) {
		nustar = vnode->default_move().value;
		astar = vnode->default_move().action;
		delete qstar;
	} else {
		pruned_v->children().push_back(qstar);
		qstar->parent(pruned_v);
	}

	// Bounds are copied from the original node (kept for debugging).
	pruned_v->lower_bound(vnode->lower_bound());
	pruned_v->upper_bound(vnode->upper_bound());

	pruned_action = astar;
	pruned_value = nustar;
	return pruned_v;
}
void VNode::PrintTree(int depth, ostream& os) { if (depth != -1 && this->depth() > depth) return; if (this->depth() == 0) { os << "d - default value" << endl << "l - lower bound" << endl << "u - upper bound" << endl << "r - totol weighted one step reward" << endl << "w - total particle weight" << endl; } os << "(" << "d:" << this->default_move().value << " l:" << this->lower_bound() << ", u:" << this->upper_bound() << ", w:" << this->Weight() << ", weu:" << DESPOT::WEU(this) << ")" << endl; vector<QNode*>& qnodes = children(); for (int a = 0; a < qnodes.size(); a++) { QNode* qnode = qnodes[a]; vector<OBS_TYPE> labels; map<OBS_TYPE, VNode*>& vnodes = qnode->children(); for (map<OBS_TYPE, VNode*>::iterator it = vnodes.begin(); it != vnodes.end(); it++) { labels.push_back(it->first); } os << repeat("| ", this->depth()) << "a=" << qnode->edge() << ": " << "(d:" << qnode->default_value << ", l:" << qnode->lower_bound() << ", u:" << qnode->upper_bound() << ", r:" << qnode->step_reward << ")" << endl; for (int i = 0; i < labels.size(); i++) { if (depth == -1 || this->depth() + 1 <= depth) { os << repeat("| ", this->depth()) << "| o=" << labels[i] << ": "; qnode->Child(labels[i])->PrintTree(depth, os); } } } }