Exemplo n.º 1
0
/*
    sort children by converting linked list to vector,
    sorting the vector, and reconstructing to linked list again
    Requires node mutex to be held.
*/
void UCTNode::sort_children() {
    assert(get_mutex().is_held());
    std::vector<std::tuple<float, UCTNode*>> tmp;

    UCTNode * child = m_firstchild;

    while (child != nullptr) {
        tmp.emplace_back(child->get_score(), child);
        child = child->m_nextsibling;
    }

    std::sort(begin(tmp), end(tmp));

    m_firstchild = nullptr;

    for (auto& sortnode : tmp) {
        link_child(std::get<1>(sortnode));
    }
}
Exemplo n.º 2
0
/*
    Select the child to descend into during tree search using PUCT:

        value = winrate + cfg_puct * psa * sqrt(parentvisits) / (1 + visits)

    where winrate = child->get_eval(color), psa = child->get_score() and
    parentvisits is the sum of get_visits() over the considered children.
    Only valid() children are considered, and at most `childbound` of them.
    Holds the node mutex for the duration of the call.

    Returns the child with the highest value, or nullptr when this node has
    no valid children.
*/
UCTNode* UCTNode::uct_select_child(int color) {
    // Initial "best value"; every real value is asserted to exceed it.
    const float no_value = -1000.0f;

    LOCK(get_mutex(), lock);

    // Returns `node` itself if it is valid, otherwise the first valid node
    // further along the sibling list, or nullptr if there is none.
    auto skip_invalid = [](UCTNode* node) {
        while (node != nullptr && !node->valid()) {
            node = node->m_nextsibling;
        }
        return node;
    };

    // Progressive widening
    // int childbound = std::max(2, (int)(((log((double)get_visits()) - 3.0) * 3.0) + 2.0));
    const int childbound = 362;

    // Count parentvisits.
    // We do this manually to avoid issues with transpositions.
    int parentvisits = 0;
    int childcount = 0;
    for (UCTNode* child = skip_invalid(m_firstchild);
         child != nullptr && childcount < childbound;
         child = skip_invalid(child->m_nextsibling), ++childcount) {
        parentvisits += child->get_visits();
    }
    const float numerator =
        static_cast<float>(std::sqrt(static_cast<double>(parentvisits)));

    UCTNode* child = skip_invalid(m_firstchild);
    if (child == nullptr) {
        return nullptr;
    }

    // Prune bad probabilities
    // auto parent_log = std::log((float)parentvisits);
    // auto cutoff_ratio = cfg_cutoff_offset + cfg_cutoff_ratio * parent_log;
    // auto best_probability = child->get_score();
    // assert(best_probability > 0.001f);

    UCTNode* best = nullptr;
    float best_value = no_value;

    for (childcount = 0; child != nullptr && childcount < childbound;
         child = skip_invalid(child->m_nextsibling), ++childcount) {
        // Prune bad probabilities
        // if (child->get_score() * cutoff_ratio < best_probability) {
        //     break;
        // }

        // get_eval() will automatically set first-play-urgency
        const float winrate = child->get_eval(color);
        const float psa = child->get_score();
        const float denom = 1.0f + child->get_visits();
        const float puct = cfg_puct * psa * (numerator / denom);
        const float value = winrate + puct;
        assert(value > no_value);

        if (value > best_value) {
            best_value = value;
            best = child;
        }
    }

    assert(best != nullptr);

    return best;
}