/**
 * Estimates, by simulation, the average discounted return obtained when
 * acting according to the tree rooted at `root`.
 *
 * Each particle is copied and stepped through the model until the random
 * streams are exhausted or Step() reports a terminal transition. While the
 * simulation is still on a tree node, the action is OptimalAction(cur).action;
 * once the simulation has left the tree (cur == NULL), the action comes from
 * prior->GetAction(). After each particle, `streams` and `prior` are rewound
 * by exactly the number of non-terminal steps taken, so both are handed back
 * in the position they were received in.
 *
 * @param root       Root of the tree whose actions are followed.
 * @param particles  States to simulate from. Side effect: scenario_id of
 *                   every particle is overwritten with its index here.
 * @param streams    Random streams; advanced during a rollout and rewound
 *                   afterwards.
 * @param prior      Supplies actions off-tree and records the
 *                   action/observation history (Add / PopLast).
 * @param model      Model used to copy, step and free states.
 * @return ValuedAction holding OptimalAction(root).action and the mean of
 *         the per-particle discounted returns.
 *
 * NOTE(review): with an empty `particles` the final division is 0/0; this
 * function does not guard against it, so callers are presumably expected
 * to pass at least one particle - confirm.
 */
ValuedAction DESPOT::Evaluate(VNode* root, vector<State*>& particles,
    RandomStreams& streams, POMCPPrior* prior, const DSPOMDP* model) {
    // Read the size once as an int: scenario_id is assigned from the index,
    // and this avoids a signed/unsigned comparison on every iteration.
    const int num_particles = static_cast<int>(particles.size());

    for (int i = 0; i < num_particles; i++) {
        particles[i]->scenario_id = i;
    }

    double value = 0;

    for (int i = 0; i < num_particles; i++) {
        State* particle = particles[i];
        VNode* cur = root;
        State* copy = model->Copy(particle);
        double discount = 1.0;
        double val = 0;
        int steps = 0; // Non-terminal steps taken; drives the rewind below.

        while (!streams.Exhausted()) {
            int action = (cur != NULL)
                ? OptimalAction(cur).action
                : prior->GetAction(*copy);

            assert(action != -1);

            double reward;
            OBS_TYPE obs;
            bool terminal = model->Step(*copy,
                streams.Entry(copy->scenario_id), action, reward, obs);

            val += discount * reward;
            discount *= Discount();

            if (terminal) {
                // A terminal step is not recorded in prior/streams, so it
                // must not be counted in `steps` either.
                break;
            }

            prior->Add(action, obs);
            streams.Advance();
            steps++;

            if (cur != NULL && !cur->IsLeaf()) {
                QNode* qnode = cur->Child(action);
                map<OBS_TYPE, VNode*>& vnodes = qnode->children();
                // Single lookup: reuse the iterator instead of following
                // find() with operator[].
                map<OBS_TYPE, VNode*>::iterator child = vnodes.find(obs);
                cur = (child != vnodes.end()) ? child->second : NULL;
            }
        }

        // Undo exactly the Advance()/Add() calls made for this particle.
        for (int undone = 0; undone < steps; undone++) {
            streams.Back();
            prior->PopLast();
        }

        model->Free(copy);

        value += val;
    }

    return ValuedAction(OptimalAction(root).action, value / particles.size());
}
/**
 * Performs one exploration pass starting at `root` and returns the node at
 * which the pass stopped.
 *
 * At every visited node the pass:
 *   1. updates statistics->longest_trial_length if this node is deeper,
 *   2. calls ExploitBlockers() on the node,
 *   3. stops if Gap(node) is exactly 0,
 *   4. expands the node via Expand() if it is a leaf,
 *   5. picks a QNode with SelectBestUpperBoundNode() and then a child VNode
 *      with SelectBestWEUNode(), stopping if that child is NULL,
 *   6. moves to the child and appends its (action edge, observation edge)
 *      pair to `history`.
 * The pass continues while the current node's depth is below
 * Globals::config.search_depth and WEU(node) > 0.
 *
 * `history` is truncated back to its entry size before returning, so the
 * caller sees it unchanged in length.
 *
 * @param root         Node to start from.
 * @param streams      Forwarded to Expand().
 * @param lower_bound  Forwarded to Expand().
 * @param upper_bound  Forwarded to Expand().
 * @param model        Forwarded to Expand().
 * @param history      Extended along the path, then truncated on exit.
 * @param statistics   Optional (may be NULL); receives timing and counters.
 * @return The last node reached by the pass.
 */
VNode* DESPOT::Trial(VNode* root, RandomStreams& streams,
    ScenarioLowerBound* lower_bound, ScenarioUpperBound* upper_bound,
    const DSPOMDP* model, History& history, SearchStatistics* statistics) {
    const bool record_stats = (statistics != NULL);
    const int entry_history_size = history.Size();

    VNode* node = root;

    for (;;) {
        if (record_stats) {
            if (node->depth() > statistics->longest_trial_length) {
                statistics->longest_trial_length = node->depth();
            }
        }

        ExploitBlockers(node);

        if (Gap(node) == 0) {
            break;
        }

        if (node->IsLeaf()) {
            double expand_begin = clock();
            Expand(node, lower_bound, upper_bound, model, streams, history);

            if (record_stats) {
                statistics->time_node_expansion +=
                    (double) (clock() - expand_begin) / CLOCKS_PER_SEC;
                statistics->num_expanded_nodes++;
                statistics->num_tree_particles += node->particles().size();
            }
        }

        double select_begin = clock();
        QNode* chosen_action = SelectBestUpperBoundNode(node);
        VNode* successor = SelectBestWEUNode(chosen_action);

        if (record_stats) {
            statistics->time_path += (clock() - select_begin) / CLOCKS_PER_SEC;
        }

        if (successor == NULL) {
            break;
        }

        node = successor;
        history.Add(chosen_action->edge(), node->edge());

        // Continuation test of the original do/while, evaluated after the
        // move; WEU() is only consulted when the depth limit allows it.
        const bool keep_descending =
            node->depth() < Globals::config.search_depth && WEU(node) > 0;
        if (!keep_descending) {
            break;
        }
    }

    // Drop whatever this pass appended so the caller's history is restored.
    history.Truncate(entry_history_size);

    return node;
}