Beispiel #1
0
void tree_stats_helper(const UCTNode& node, size_t depth,
                       size_t& nodes, size_t& non_leaf_nodes,
                       size_t& depth_sum, size_t& max_depth,
                       size_t& children_count) {
    nodes += 1;
    non_leaf_nodes += node.get_visits() > 1;
    depth_sum += depth;
    if (depth > max_depth) max_depth = depth;

    for (const auto& child : node.get_children()) {
        if (!child->first_visit()) children_count += 1;

        tree_stats_helper(*(child.get()), depth+1,
                          nodes, non_leaf_nodes, depth_sum,
                          max_depth, children_count);
    }
}
Beispiel #2
0
// Picks the child of this node with the highest selection value
//     value = winrate + cfg_puct * score * sqrt(parentvisits) / (1 + visits)
// where `winrate` is child->get_eval(color), `score` is child->get_score()
// and `visits` is child->get_visits().
//
// Only children for which valid() returns true are considered, and at most
// `childbound` of them. Returns nullptr when there is no valid child at all;
// otherwise the best child found (asserted to be non-null).
UCTNode* UCTNode::uct_select_child(int color) {
    // NOTE(review): LOCK is a project macro; presumably it holds this node's
    // mutex until the end of the function - confirm against its definition.
    LOCK(get_mutex(), lock);

    // Advances from `node` along the sibling list to the first node for
    // which valid() is true; yields nullptr if the list runs out.
    const auto skip_invalid = [](UCTNode* node) -> UCTNode* {
        while (node != nullptr && !node->valid()) {
            node = node->m_nextsibling;
        }
        return node;
    };

    // Progressive widening is currently disabled. The former bound was:
    //   std::max(2, (int)(((log((double)get_visits()) - 3.0) * 3.0) + 2.0))
    // NOTE(review): 362 is presumably "all points of a 19x19 board plus
    // pass", i.e. effectively no limit - confirm.
    const int childbound = 362;

    // First pass: total the visits of the considered children. This is
    // counted by hand (rather than taken from this node) to avoid issues
    // with transpositions.
    int parentvisits = 0;
    {
        int counted = 0;
        for (UCTNode* cur = skip_invalid(m_firstchild);
             cur != nullptr && counted < childbound;
             cur = skip_invalid(cur->m_nextsibling), ++counted) {
            parentvisits += cur->get_visits();
        }
    }
    const float numerator =
        static_cast<float>(std::sqrt(static_cast<double>(parentvisits)));

    // Second pass starts again from the first valid child; bail out if
    // there is none.
    UCTNode* cur = skip_invalid(m_firstchild);
    if (cur == nullptr) {
        return nullptr;
    }

    // Pruning of children with a low prior probability (relative to the
    // first child's score) used to happen here and is currently disabled.

    UCTNode* best = nullptr;
    float best_value = -1000.0f;

    for (int examined = 0;
         cur != nullptr && examined < childbound;
         ++examined) {
        // Per the original note: get_eval() automatically applies
        // first-play-urgency.
        const float winrate = cur->get_eval(color);
        const float psa = cur->get_score();
        const float denom = 1.0f + cur->get_visits();
        const float puct = cfg_puct * psa * (numerator / denom);
        const float value = winrate + puct;
        assert(value > -1000.0f);

        // Strict comparison: on ties the earliest child wins.
        if (value > best_value) {
            best = cur;
            best_value = value;
        }

        cur = skip_invalid(cur->m_nextsibling);
    }

    assert(best != nullptr);

    return best;
}