コード例 #1
0
ファイル: UCTSearch.cpp プロジェクト: eras44/sparcraft
 // Determines the SearchNodeType of a child about to be created under 'parent'.
 // 'prevState' is the game state the child's moves are generated from.
 // Falls back to SearchNodeType::Default when the parent's type matches none
 // of the types handled below.
 size_t UCTSearch::getChildNodeType(UCTNode & parent, const GameState & prevState) const
{
    // when the two players cannot both move, the child is a solo node
    if (!prevState.bothCanMove())
    {
        return SearchNodeType::SoloNode;
    }

    const size_t parentType = parent.getNodeType();

    // the child of a first sim node is the second half of the simultaneous move
    if (parentType == SearchNodeType::FirstSimNode)
    {
        return SearchNodeType::SecondSimNode;
    }

    // root, solo and second sim parents all start a new simultaneous move
    if ((parentType == SearchNodeType::RootNode)
        || (parentType == SearchNodeType::SoloNode)
        || (parentType == SearchNodeType::SecondSimNode))
    {
        return SearchNodeType::FirstSimNode;
    }

    return SearchNodeType::Default;
}
コード例 #2
0
ファイル: UCTSearch.cpp プロジェクト: TFiFiE/lizzie
// Prints one line per child of `parent` (move, visit count, evaluation,
// score and principal variation), then calls tree_stats() on `parent`.
// Does nothing when cfg_quiet is set, when `parent` has no children, or
// when the first child after sorting is still on its first visit.
void UCTSearch::dump_stats(FastState & state, UCTNode & parent) {
    const bool nothing_to_print = cfg_quiet || !parent.has_children();
    if (nothing_to_print) {
        return;
    }

    const int to_move = state.get_to_move();

    // Sort the children so the best move ends up first.
    parent.sort_children(to_move);

    if (parent.get_first_child()->first_visit()) {
        return;
    }

    int printed = 0;
    for (const auto& child : parent.get_children()) {
        ++printed;
        // The first two moves are always shown, so that with a single
        // searched move the user still gets an idea why. After that,
        // stop at the first child without visits.
        if (printed > 2 && !child->get_visits()) {
            break;
        }

        const auto child_move = child->get_move();
        std::string move_text = state.move_to_text(child_move);

        // Play the move on a copy so `state` itself is left untouched.
        FastState after_move = state;
        after_move.play_move(child_move);
        std::string pv = move_text + " " + get_pv(after_move, *child);

        myprintf("%4s -> %7d (V: %5.2f%%) (N: %5.2f%%) PV: %s\n",
            move_text.c_str(),
            child->get_visits(),
            child->get_visits() ? child->get_eval(to_move)*100.0f : 0.0f,
            child->get_score() * 100.0f,
            pv.c_str());
    }
    tree_stats(parent);
}
コード例 #3
0
ファイル: UCTSearch.cpp プロジェクト: eras44/sparcraft
// Runs one search iteration through 'node' and returns the resulting
// evaluation score. An unvisited node is evaluated directly via
// GameState::eval with the configured method and scripts; a visited node
// either evaluates a terminal state or recurses into the child picked by
// UCTNodeSelect. On the way back the node's visit count is incremented and
// its wins are credited: 1 for a positive score, 0.5 for a score of zero.
StateEvalScore UCTSearch::traverse(UCTNode & node, GameState & currentState)
{
    StateEvalScore result;

    _results.totalVisits++;

    const bool firstVisit = (node.numVisits() == 0);

    if (firstVisit)
    {
        // first time here: apply this node's moves as a leaf and do the playout
        updateState(node, currentState, true);

        result = currentState.eval(_params.maxPlayer(), _params.evalMethod(), _params.simScript(Players::Player_One), _params.simScript(Players::Player_Two));

        _results.nodesVisited++;
    }
    else
    {
        // seen before: apply this node's moves as a non-leaf
        updateState(node, currentState, false);

        if (!currentState.isTerminal())
        {
            // children are generated lazily, the first time they are needed
            if (!node.hasChildren())
            {
                generateChildren(node, currentState);
            }

            UCTNode & selected = UCTNodeSelect(node);
            result = traverse(selected, currentState);
        }
        else
        {
            result = currentState.eval(_params.maxPlayer(), EvaluationMethods::LTD2);
        }
    }

    // back up the result into this node
    node.incVisits();

    if (result.val() > 0)
    {
        node.addWins(1);
    }
    else if (result.val() == 0)
    {
        node.addWins(0.5);
    }

    return result;
}
コード例 #4
0
ファイル: UCTSearch.cpp プロジェクト: eras44/sparcraft
 // Decides which player moves next from 'node' in 'state'.
 // If only one player can move, that player is returned. If both can move,
 // the max player moves at the root, the enemy of the node's player moves
 // after a first sim node, and otherwise the configured player-to-move
 // policy decides.
 IDType UCTSearch::getPlayerToMove(UCTNode & node, const GameState & state) const
{
    const IDType canMove(state.whoCanMove());

    // nothing to decide unless both players can move
    if (canMove != Players::Player_Both)
    {
        return canMove;
    }

    const IDType policy(_params.playerToMoveMethod());
    const IDType maxPlayer(_params.maxPlayer());

    // the max player always chooses at the root
    if (isRoot(node))
    {
        return maxPlayer;
    }

    const IDType nodeType = node.getNodeType();

    // the 2nd player in a sim move is always the enemy of the first
    if (nodeType == SearchNodeType::FirstSimNode)
    {
        return state.getEnemy(node.getPlayer());
    }

    // otherwise the policy says who goes first in a sim move state
    if (policy == SparCraft::PlayerToMove::Alternate)
    {
        return state.getEnemy(node.getPlayer());
    }

    if (policy == SparCraft::PlayerToMove::Not_Alternate)
    {
        return node.getPlayer();
    }

    if (policy == SparCraft::PlayerToMove::Random)
    {
        return rand() % 2;
    }

    // no known policy matched; this is not expected to happen
    System::FatalError("UCT Error: Nobody can move for some reason");
    return Players::Player_None;
}
コード例 #5
0
ファイル: UCTSearch.cpp プロジェクト: eras44/sparcraft
// Applies the moves stored in 'node' to 'state' and calls finishedMoving().
// A first sim node that is not treated as a leaf applies nothing. A second
// sim node applies its parent's moves first, then its own.
void UCTSearch::updateState(UCTNode & node, GameState & state, bool isLeaf)
{
    // a non-leaf first sim node makes no moves at this point
    if ((node.getNodeType() == SearchNodeType::FirstSimNode) && !isLeaf)
    {
        return;
    }

    // for a second sim node the parent's moves haven't been made yet,
    // so make them before this node's own moves
    if (node.getNodeType() == SearchNodeType::SecondSimNode)
    {
        state.makeMoves(node.getParent()->getMove());
    }

    // make this node's moves and finish the move phase
    state.makeMoves(node.getMove());
    state.finishedMoving();
}
コード例 #6
0
ファイル: UCTSearch.cpp プロジェクト: TFiFiE/lizzie
// Recursively walks the subtree rooted at `node` and accumulates statistics
// into the reference parameters:
//   nodes          - number of nodes seen
//   non_leaf_nodes - number of nodes with more than one visit
//   depth_sum      - sum of the depth of every node seen
//   max_depth      - largest depth seen
//   children_count - number of children that are past their first visit
// `depth` is the depth of `node` itself; children are visited at depth + 1.
void tree_stats_helper(const UCTNode& node, size_t depth,
                       size_t& nodes, size_t& non_leaf_nodes,
                       size_t& depth_sum, size_t& max_depth,
                       size_t& children_count) {
    ++nodes;
    if (node.get_visits() > 1) {
        ++non_leaf_nodes;
    }
    depth_sum += depth;
    if (max_depth < depth) {
        max_depth = depth;
    }

    const size_t child_depth = depth + 1;
    for (const auto& child : node.get_children()) {
        if (!child->first_visit()) {
            ++children_count;
        }

        tree_stats_helper(*child, child_depth,
                          nodes, non_leaf_nodes, depth_sum,
                          max_depth, children_count);
    }
}
コード例 #7
0
// Returns the principal variation below `parent` as a space-separated list
// of moves, following get_best_root_child() at every level. Moves are
// written via move_to_san() when `use_san` is true and via UCI::move()
// otherwise. Each move is played on `state` for the recursive call and
// undone again before returning. Returns an empty string when `parent` has
// no children.
std::string UCTSearch::get_pv(BoardHistory& state, UCTNode& parent, bool use_san) {
    if (!parent.has_children()) {
        return {};
    }

    auto& top_child = parent.get_best_root_child(state.cur().side_to_move());
    const auto top_move = top_child.get_move();
    auto line = use_san ? state.cur().move_to_san(top_move) : UCI::move(top_move);

    // Step into the move, collect the continuation, then step back out.
    StateInfo scratch_info;
    state.cur().do_move(top_move, scratch_info);
    const auto continuation = get_pv(state, top_child, use_san);
    state.cur().undo_move(top_move);

    if (!continuation.empty()) {
        line += " ";
        line += continuation;
    }
    return line;
}
コード例 #8
0
ファイル: UCTNode.cpp プロジェクト: intfrr/leela-zero
/*
    sort children by converting linked list to vector,
    sorting the vector, and reconstructing to linked list again
    Requires node mutex to be held.
*/
void UCTNode::sort_children() {
    assert(get_mutex().is_held());
    std::vector<std::tuple<float, UCTNode*>> tmp;

    UCTNode * child = m_firstchild;

    while (child != nullptr) {
        tmp.emplace_back(child->get_score(), child);
        child = child->m_nextsibling;
    }

    std::sort(begin(tmp), end(tmp));

    m_firstchild = nullptr;

    for (auto& sortnode : tmp) {
        link_child(std::get<1>(sortnode));
    }
}
コード例 #9
0
ファイル: UCTNode.cpp プロジェクト: intfrr/leela-zero
// Removes every non-pass child whose move, played on a copy of `state`,
// makes superko() report true. Pass moves are never removed.
void UCTNode::kill_superkos(KoState & state) {
    UCTNode * cur = m_firstchild;

    while (cur != nullptr) {
        // Remember the successor first: `cur` may be deleted below.
        UCTNode * const following = cur->m_nextsibling;
        const int move = cur->get_move();

        if (move != FastBoard::PASS) {
            // Try the move on a copy so `state` itself is left unchanged.
            KoState trial = state;
            trial.play_move(move);

            if (trial.superko()) {
                delete_child(cur);
            }
        }

        cur = following;
    }
}
コード例 #10
0
ファイル: UCTSearch.cpp プロジェクト: TFiFiE/lizzie
// Returns the principal variation below `parent` as a space-separated list
// of moves in move_to_text() form, following get_best_root_child() at every
// level. The variation stops at a node without children or at a best child
// that is still on its first visit. Note that `state` is advanced in place
// by every move of the variation and is not rewound here.
std::string UCTSearch::get_pv(FastState & state, UCTNode& parent) {
    if (!parent.has_children()) {
        return {};
    }

    auto& top_child = parent.get_best_root_child(state.get_to_move());
    if (top_child.first_visit()) {
        return {};
    }

    const auto top_move = top_child.get_move();
    auto line = state.move_to_text(top_move);

    state.play_move(top_move);

    const auto continuation = get_pv(state, top_child);
    if (!continuation.empty()) {
        line += " ";
        line += continuation;
    }
    return line;
}
コード例 #11
0
ファイル: UCTSearch.cpp プロジェクト: eras44/sparcraft
// Creates the children of 'node'.
// 'state' is the GameState after node's moves have been performed.
// Moves for the player to move are generated into _moveArray, shuffled and
// ordered; one child is then added per move returned by getNextMove(), up
// to _params.maxChildren() children.
void UCTSearch::generateChildren(UCTNode & node, GameState & state)
{
    // who moves next from this node
    const IDType mover(getPlayerToMove(node, state));

    // all moves available to that player, in shuffled order
    state.generateMoves(_moveArray, mover);
    _moveArray.shuffleMoveActions();

    // build the 'ordered moves' used for move ordering
    generateOrderedMoves(state, _moveArray, mover);

    // add children until the cap is reached or getNextMove has no more moves
    size_t childIndex(0);
    while ((childIndex < _params.maxChildren()) && getNextMove(mover, _moveArray, childIndex, _actionVec))
    {
        node.addChild(&node, mover, getChildNodeType(node, state), _actionVec, _params.maxChildren(), _memoryPool ? _memoryPool->alloc() : NULL);
        _results.nodesCreated++;

        ++childIndex;
    }
}
コード例 #12
0
ファイル: UCTSearch.cpp プロジェクト: eras44/sparcraft
// Adds a GraphViz node describing 'node' to graph 'g', then recurses into
// every child with at least one visit, adding an edge to each.
// 'state' is passed by value: this call applies the node's moves to its own
// copy, and each child receives a copy of the result.
// The label shows the node's moves, UCT statistics, node type and address,
// followed by the state's time and HP / action times of units (0,0) and (1,0).
void UCTSearch::printSubTreeGraphViz(UCTNode & node, GraphViz::Graph & g, GameState state)
{
    // a first sim node that has children makes no moves here
    const bool skipMoves = (node.getNodeType() == SearchNodeType::FirstSimNode) && node.hasChildren();

    if (!skipMoves)
    {
        // a second sim node applies its parent's moves before its own
        if (node.getNodeType() == SearchNodeType::SecondSimNode)
        {
            state.makeMoves(node.getParent()->getMove());
        }

        state.makeMoves(node.getMove());
        state.finishedMoving();
    }

    // textual list of this node's moves, or "root" when it has none
    std::stringstream moveText;

    if (node.getMove().size() == 0)
    {
        moveText << "root";
    }
    else
    {
        for (size_t m(0); m<node.getMove().size(); ++m)
        {
            moveText << node.getMove()[m].moveString() << "\\n";
        }
    }

    std::string typeName = SearchNodeType::getName(node.getNodeType());

    Unit unitOne = state.getUnit(0,0);
    Unit unitTwo = state.getUnit(1,0);

    std::stringstream labelText;

    // search statistics of the node
    labelText << moveText.str();
    labelText << "\\nVal: "       << node.getUCTVal();
    labelText << "\\nWins: "      << node.numWins();
    labelText << "\\nVisits: "    << node.numVisits();
    labelText << "\\nChildren: "  << node.numChildren();
    labelText << "\\n"            << typeName;
    labelText << "\\nPtr: "       << &node;
    labelText << "\\n---------------";

    // snapshot of the game state
    labelText << "\\nFrame: " << state.getTime();
    labelText << "\\nHP: "    << unitOne.currentHP()            << "  " << unitTwo.currentHP();
    labelText << "\\nAtk: "   << unitOne.nextAttackActionTime() << "  " << unitTwo.nextAttackActionTime();
    labelText << "\\nMove: "  << unitOne.nextMoveActionTime()   << "  " << unitTwo.nextMoveActionTime();
    labelText << "\\nPrev: "  << unitOne.previousActionTime()   << "  " << unitTwo.previousActionTime();

    // fill colour depends on the node's player; grey for anyone else
    std::string fillColour("#aaaaaa");

    if (node.getPlayer() == Players::Player_One)
    {
        fillColour = "#ff0000";
    }
    else if (node.getPlayer() == Players::Player_Two)
    {
        fillColour = "#00ff00";
    }

    GraphViz::Node graphNode(getNodeIDString(node));
    graphNode.set("label",      labelText.str());
    graphNode.set("fillcolor",  fillColour);
    graphNode.set("color",      "#000000");
    graphNode.set("fontcolor",  "#000000");
    graphNode.set("style",      "filled,bold");
    graphNode.set("shape",      "box");
    g.addNode(graphNode);

    // recurse into the visited children only
    for (size_t c(0); c<node.numChildren(); ++c)
    {
        UCTNode & child = node.getChild(c);

        if (!(child.numVisits() > 0))
        {
            continue;
        }

        GraphViz::Edge link(getNodeIDString(node), getNodeIDString(child));
        g.addEdge(link);
        printSubTreeGraphViz(child, g, state);
    }
}
コード例 #13
0
// Prints search statistics as "info string" lines: one line per child of
// `parent` (move, visits, optional move probability, eval, score and
// principal variation), followed by a line with the side to move and the
// root's winrate. Does nothing when cfg_quiet is set or `parent` has no
// children; stops after the leading blank line when the first child is
// still on its first visit.
// Note that sorting, the proportional vector and the final winrate are all
// taken from m_root, while the children listed are those of `parent`.
void UCTSearch::dump_stats(BoardHistory& state, UCTNode& parent) {
    if (cfg_quiet || !parent.has_children()) {
        return;
    }
    myprintf("\n");

    const Color stm = state.cur().side_to_move();

    // Sort the root children so the best move ends up on top.
    m_root->sort_root_children(stm);

    if (parent.get_first_child()->first_visit()) {
        return;
    }

    auto root_temperature = get_root_temperature();
    auto cumulative = m_root->calc_proportional(root_temperature, stm);

    // Children are walked in reverse while the cumulative vector is
    // consumed from its back, one entry per child.
    for (const auto& child : boost::adaptors::reverse(parent.get_children())) {
        std::string san = state.cur().move_to_san(child->get_move());
        std::string pv_line(san);
        std::string prob_text(10, '\0');

        if (!cfg_randomize) {
            auto needed = std::snprintf(&prob_text[0], prob_text.size(), "%s", " ");
            prob_text.resize(needed+1);
        } else {
            // The difference between neighbouring cumulative entries is
            // the share belonging to this child.
            auto share = cumulative.back();
            cumulative.pop_back();
            if (cumulative.size() > 0) {
                share -= cumulative.back();
            }
            share *= 100.0f; // The formatting below expects a percentage.

            if (share > 0.01f) {
                std::snprintf(&prob_text[0], prob_text.size(), "(%6.2f%%)", share);
            } else if (share > 0.00001f) {
                std::snprintf(&prob_text[0], prob_text.size(), "%s", "(> 0.00%)");
            } else {
                std::snprintf(&prob_text[0], prob_text.size(), "%s", "(  0.00%)");
            }
        }

        myprintf_so("info string %5s -> %7d %s (V: %5.2f%%) (N: %5.2f%%) PV: ",
                san.c_str(),
                child->get_visits(),
                prob_text.c_str(),
                child->get_eval(stm)*100.0f,
                child->get_score() * 100.0f);

        // Play the move, collect the PV behind it, then take it back.
        // The PV is only printed as text, so use_san is set to true.
        StateInfo scratch_info;
        state.cur().do_move(child->get_move(), scratch_info);
        pv_line += " " + get_pv(state, *child, true);
        state.cur().undo_move(child->get_move());

        myprintf_so("%s\n", pv_line.c_str());
    }

    // The winrate goes into its own info string since it's not UCI spec.
    float root_eval = m_root->get_eval(stm);
    myprintf_so("info string stm %s winrate %5.2f%%\n",
        stm == Color::WHITE ? "White" : "Black", root_eval * 100.f);
    myprintf("\n");
}
コード例 #14
0
ファイル: UCTNode.cpp プロジェクト: intfrr/leela-zero
// Selects the valid child with the highest value, where
//   value = get_eval(color) + cfg_puct * get_score() * sqrt(N) / (1 + visits)
// and N is the sum of the visits of the valid children considered.
// At most `childbound` valid children are considered. Returns nullptr when
// there is no valid child. The node mutex is taken for the whole call.
UCTNode* UCTNode::uct_select_child(int color) {
    LOCK(get_mutex(), lock);

    // Returns `n` itself if it is valid, otherwise the nearest following
    // valid sibling, or nullptr if there is none.
    const auto first_valid_from = [](UCTNode* n) {
        while (n != nullptr && !n->valid()) {
            n = n->m_nextsibling;
        }
        return n;
    };

    // Fixed bound on the number of children looked at (progressive
    // widening is not in use).
    const int childbound = 362;

    // First pass: total the visits of the children. This is counted by
    // hand to avoid issues with transpositions.
    int parentvisits = 0;
    {
        int counted = 0;
        UCTNode* cur = first_valid_from(m_firstchild);
        while (cur != nullptr && counted < childbound) {
            parentvisits += cur->get_visits();
            cur = first_valid_from(cur->m_nextsibling);
            ++counted;
        }
    }
    const float numerator = std::sqrt(static_cast<double>(parentvisits));

    // Second pass: score every candidate and keep the best one.
    UCTNode* candidate = first_valid_from(m_firstchild);
    if (candidate == nullptr) {
        return nullptr;
    }

    UCTNode* best = nullptr;
    float best_value = -1000.0f;

    for (int examined = 0;
         candidate != nullptr && examined < childbound;
         ++examined) {
        // get_eval() will automatically set first-play-urgency
        const float winrate = candidate->get_eval(color);
        const float prior = candidate->get_score();
        const float denom = 1.0f + candidate->get_visits();
        const float exploration = cfg_puct * prior * (numerator / denom);
        const float value = winrate + exploration;
        assert(value > -1000.0f);

        if (value > best_value) {
            best_value = value;
            best = candidate;
        }

        candidate = first_valid_from(candidate->m_nextsibling);
    }

    assert(best != nullptr);

    return best;
}
コード例 #15
0
ファイル: UCTSearch.cpp プロジェクト: TFiFiE/lizzie
// Picks the move to play from the root. The first child after sorting is
// the starting choice (optionally randomized while the move number is below
// cfg_random_cnt). The choice is then adjusted for passing according to
// `passflag` and cfg_dumbpass, and a non-pass move may be replaced by
// FastBoard::RESIGN when should_resign() says so.
int UCTSearch::get_best_move(passflag_t passflag) {
    const int color = m_rootstate.board.get_to_move();

    // Make sure best is first
    m_root->sort_children(color);

    // Early game only: randomize the best move proportionally to the
    // playout counts.
    const auto movenum = int(m_rootstate.get_movenum());
    if (movenum < cfg_random_cnt) {
        m_root->randomize_first_proportionally();
    }

    auto first_child = m_root->get_first_child();
    assert(first_child != nullptr);

    auto bestmove = first_child->get_move();
    auto bestscore = first_child->get_eval(color);

    // Switches the choice to `alt`. A child that is still on its first
    // visit is given a score of 1.0.
    const auto adopt = [&](UCTNode * alt) {
        bestmove = alt->get_move();
        if (alt->first_visit()) {
            bestscore = 1.0f;
        } else {
            bestscore = alt->get_eval(color);
        }
    };

    // True when the final score of the root position goes against the
    // side to move (full count including dead stones).
    const auto passing_loses = [&]() {
        const float score = m_rootstate.final_score();
        return (score > 0.0f && color == FastBoard::WHITE)
               ||
               (score < 0.0f && color == FastBoard::BLACK);
    };

    // do we want to fiddle with the best move because of the rule set?
    if (passflag & UCTSearch::NOPASS) {
        // were we going to pass?
        if (bestmove == FastBoard::PASS) {
            UCTNode * nopass = m_root->get_nopass_child(m_rootstate);

            if (nopass == nullptr) {
                myprintf("Pass is the only acceptable move.\n");
            } else {
                myprintf("Preferring not to pass.\n");
                adopt(nopass);
            }
        }
    } else if (!cfg_dumbpass) {
        if (bestmove == FastBoard::PASS) {
            // Passing is on top, either by forcing or by coincidence.
            // Check whether passing loses instantly.
            //
            // In a reinforcement learning setup the network can learn
            // that, after passing in the tree, the two last positions are
            // identical, so the position is only won if there are no dead
            // stones in our own territory (Tromp-Taylor scoring is used
            // there). Strictly speaking a pure RL network doesn't need
            // this heuristic, and a commandline option disables it during
            // learning.
            //
            // With a supervised learning setup the engine is fully
            // expected to pass out anything that looks like a finished
            // game, even with dead stones on the board, because the
            // training games were scored with dead stone removal. To play
            // games with an SL network this heuristic is needed so the
            // engine can "clean up" the board. It still only cleans up
            // the bare necessity to win; for full dead stone removal,
            // kgs-genmove_cleanup and the NOPASS mode must be used.
            if (passing_loses()) {
                myprintf("Passing loses :-(\n");
                // Find a valid non-pass move.
                UCTNode * nopass = m_root->get_nopass_child(m_rootstate);
                if (nopass == nullptr) {
                    myprintf("No alternative to passing.\n");
                } else {
                    myprintf("Avoiding pass because it loses.\n");
                    adopt(nopass);
                }
            } else {
                myprintf("Passing wins :-)\n");
            }
        } else if (m_rootstate.get_last_move() == FastBoard::PASS) {
            // The opponent's last move was a pass and we didn't consider
            // passing. Should we have, ending the game immediately?
            if (passing_loses()) {
                myprintf("Passing loses, I'll play on.\n");
            } else {
                myprintf("Passing wins, I'll pass out.\n");
                bestmove = FastBoard::PASS;
            }
        }
    }

    // if we aren't passing, should we consider resigning?
    if (bestmove != FastBoard::PASS && should_resign(passflag, bestscore)) {
        myprintf("Eval (%.2f%%) looks bad. Resigning.\n",
                 100.0f * bestscore);
        bestmove = FastBoard::RESIGN;
    }

    return bestmove;
}