// Recursively walks the subtree rooted at `node` and accumulates statistics
// into the reference parameters (all of them are in/out accumulators):
//   nodes          - incremented once per node reached
//   non_leaf_nodes - incremented for each node whose get_visits() exceeds 1
//   depth_sum      - sum of `depth` over every node reached
//   max_depth      - largest `depth` seen so far
//   children_count - incremented for each child whose first_visit() is false
// `depth` is the depth assigned to `node`; children are visited at depth + 1.
void tree_stats_helper(const UCTNode& node, size_t depth,
                       size_t& nodes, size_t& non_leaf_nodes,
                       size_t& depth_sum, size_t& max_depth,
                       size_t& children_count) {
    ++nodes;
    if (node.get_visits() > 1) {
        ++non_leaf_nodes;
    }

    depth_sum += depth;
    if (max_depth < depth) {
        max_depth = depth;
    }

    const size_t child_depth = depth + 1;
    for (const auto& child : node.get_children()) {
        auto* child_node = child.get();
        if (!child_node->first_visit()) {
            ++children_count;
        }
        // Every child is descended into, whether or not it was counted above.
        tree_stats_helper(*child_node, child_depth, nodes, non_leaf_nodes,
                          depth_sum, max_depth, children_count);
    }
}
// Selects the child with the highest `winrate + puct` score, where
//   puct = cfg_puct * child score * sqrt(sum of child visits) / (1 + child visits).
// Only children for which valid() is true are considered, and at most
// `childbound` of them. Returns nullptr when there is no valid child.
// The whole selection runs after LOCK(get_mutex(), lock).
UCTNode* UCTNode::uct_select_child(int color) {
    LOCK(get_mutex(), lock);

    // Progressive widening (currently disabled, the bound is a constant):
    // int childbound = std::max(2, (int)(((log((double)get_visits()) - 3.0) * 3.0) + 2.0));
    const int childbound = 362;

    // Returns `node` itself if it is valid, otherwise the next valid sibling
    // after it, or nullptr when the sibling list is exhausted.
    auto skip_invalid = [](UCTNode* node) {
        while (node != nullptr && !node->valid()) {
            node = node->m_nextsibling;
        }
        return node;
    };

    // Count parentvisits.
    // We do this manually to avoid issues with transpositions.
    int parentvisits = 0;
    {
        UCTNode* cur = skip_invalid(m_firstchild);
        for (int counted = 0; cur != nullptr && counted < childbound; ++counted) {
            parentvisits += cur->get_visits();
            cur = skip_invalid(cur->m_nextsibling);
        }
    }
    const float numerator =
        static_cast<float>(std::sqrt(static_cast<double>(parentvisits)));

    // Restart from the head of the sibling list for the actual selection.
    UCTNode* candidate = skip_invalid(m_firstchild);
    if (candidate == nullptr) {
        return nullptr;
    }

    // Prune bad probabilities (disabled):
    // auto parent_log = std::log((float)parentvisits);
    // auto cutoff_ratio = cfg_cutoff_offset + cfg_cutoff_ratio * parent_log;
    // auto best_probability = child->get_score();
    // assert(best_probability > 0.001f);

    UCTNode* best = nullptr;
    float best_value = -1000.0f;
    for (int examined = 0; candidate != nullptr && examined < childbound; ++examined) {
        // Prune bad probabilities (disabled):
        // if (child->get_score() * cutoff_ratio < best_probability) {
        //     break;
        // }

        // get_eval() will automatically set first-play-urgency
        const float winrate = candidate->get_eval(color);
        const float psa = candidate->get_score();
        const float denom = 1.0f + candidate->get_visits();
        const float puct = cfg_puct * psa * (numerator / denom);
        const float value = winrate + puct;
        assert(value > -1000.0f);

        // Strict comparison: on ties the earlier sibling is kept.
        if (value > best_value) {
            best_value = value;
            best = candidate;
        }

        candidate = skip_invalid(candidate->m_nextsibling);
    }

    assert(best != nullptr);
    return best;
}