/* sort children by converting linked list to vector, sorting the vector, and reconstructing to linked list again Requires node mutex to be held. */ void UCTNode::sort_children() { assert(get_mutex().is_held()); std::vector<std::tuple<float, UCTNode*>> tmp; UCTNode * child = m_firstchild; while (child != nullptr) { tmp.emplace_back(child->get_score(), child); child = child->m_nextsibling; } std::sort(begin(tmp), end(tmp)); m_firstchild = nullptr; for (auto& sortnode : tmp) { link_child(std::get<1>(sortnode)); } }
/*
    Pick the child to descend into.

    Every valid child is given the value
        get_eval(color) + cfg_puct * get_score() * sqrt(parentvisits) / (1 + get_visits())
    and the child with the largest value is returned. parentvisits is the
    sum of the visit counts of the valid children considered.

    Returns nullptr when this node has no valid child.
    The node mutex is taken for the duration of the call.
*/
UCTNode* UCTNode::uct_select_child(int color) {
    LOCK(get_mutex(), lock);

    // Upper bound on the number of children examined. The progressive
    // widening formula that used to derive this from the visit count is
    // disabled; a fixed cap is used instead.
    const int childbound = 362;

    // Advance from `node` to the first valid node at or after it.
    auto first_valid_from = [](UCTNode* node) -> UCTNode* {
        while (node != nullptr && !node->valid()) {
            node = node->m_nextsibling;
        }
        return node;
    };

    // Count the parent visits by hand, summing over the children, to avoid
    // issues with transpositions.
    int parentvisits = 0;
    UCTNode* cursor = first_valid_from(m_firstchild);
    for (int seen = 0; cursor != nullptr && seen < childbound; ++seen) {
        parentvisits += cursor->get_visits();
        cursor = first_valid_from(cursor->m_nextsibling);
    }
    float numerator = std::sqrt(static_cast<double>(parentvisits));

    // Restart from the first valid child for the selection pass.
    cursor = first_valid_from(m_firstchild);
    if (cursor == nullptr) {
        return nullptr;
    }

    // Note: pruning of low-probability children is disabled, so every valid
    // child (up to childbound) takes part in the selection.
    UCTNode* best = nullptr;
    float best_value = -1000.0f;
    for (int seen = 0; cursor != nullptr && seen < childbound; ++seen) {
        // get_eval() will automatically set first-play-urgency.
        float winrate = cursor->get_eval(color);
        float psa = cursor->get_score();
        float denom = 1.0f + cursor->get_visits();
        float puct = cfg_puct * psa * (numerator / denom);
        float value = winrate + puct;
        assert(value > -1000.0f);

        if (value > best_value) {
            best_value = value;
            best = cursor;
        }

        cursor = first_valid_from(cursor->m_nextsibling);
    }

    assert(best != nullptr);
    return best;
}