// Decide which SearchNodeType a child created under 'parent' should carry.
// 'prevState' is consulted only for whether both players are able to move.
size_t UCTSearch::getChildNodeType(UCTNode & parent, const GameState & prevState) const
{
    // when the two players cannot both move, the child is a solo node
    if (!prevState.bothCanMove())
    {
        return SearchNodeType::SoloNode;
    }

    const auto parentType = parent.getNodeType();

    // root, solo and second-sim parents all lead to a first sim node
    if (parentType == SearchNodeType::RootNode
        || parentType == SearchNodeType::SoloNode
        || parentType == SearchNodeType::SecondSimNode)
    {
        return SearchNodeType::FirstSimNode;
    }

    // a first sim node is followed by its second sim node
    if (parentType == SearchNodeType::FirstSimNode)
    {
        return SearchNodeType::SecondSimNode;
    }

    // any other parent type
    return SearchNodeType::Default;
}
// Print one statistics line per child of 'parent' (visits, eval, score and
// principal variation), then the tree statistics.
// Prints nothing when cfg_quiet is set, when 'parent' has no children, or
// when the top child after sorting has not been visited yet.
void UCTSearch::dump_stats(FastState & state, UCTNode & parent) {
    const bool nothing_to_show = cfg_quiet || !parent.has_children();
    if (nothing_to_show) {
        return;
    }

    const int to_move = state.get_to_move();

    // Order the children so that the best move comes first.
    parent.sort_children(to_move);

    if (parent.get_first_child()->first_visit()) {
        return;
    }

    int shown = 0;
    for (const auto& child : parent.get_children()) {
        ++shown;
        // The first two children are always listed, so that even with a
        // single searched move the user can see an alternative. After that,
        // stop at the first child without visits.
        if (shown > 2 && !child->get_visits()) {
            break;
        }

        const std::string move_text = state.move_to_text(child->get_move());

        // Play the move on a copy so 'state' itself stays untouched while
        // the principal variation is collected.
        FastState scratch = state;
        scratch.play_move(child->get_move());
        const std::string pv_text = move_text + " " + get_pv(scratch, *child);

        myprintf("%4s -> %7d (V: %5.2f%%) (N: %5.2f%%) PV: %s\n",
                 move_text.c_str(),
                 child->get_visits(),
                 child->get_visits() ? child->get_eval(to_move)*100.0f : 0.0f,
                 child->get_score() * 100.0f,
                 pv_text.c_str());
    }

    tree_stats(parent);
}
// Run one search iteration through 'node', advancing 'currentState' along
// the way, and return the evaluation obtained at the end of the path.
// Every node on the path gets its visit count incremented and is credited
// with 1 win for a positive value or 0.5 for a value of exactly zero.
StateEvalScore UCTSearch::traverse(UCTNode & node, GameState & currentState)
{
    StateEvalScore result;

    _results.totalVisits++;

    const bool firstVisit = (node.numVisits() == 0);

    if (firstVisit)
    {
        // never visited: apply this node's moves as a leaf and evaluate the state
        updateState(node, currentState, true);

        result = currentState.eval(_params.maxPlayer(),
                                   _params.evalMethod(),
                                   _params.simScript(Players::Player_One),
                                   _params.simScript(Players::Player_Two));

        _results.nodesVisited++;
    }
    else
    {
        // seen before: apply this node's moves as a non-leaf
        updateState(node, currentState, false);

        if (currentState.isTerminal())
        {
            result = currentState.eval(_params.maxPlayer(), EvaluationMethods::LTD2);
        }
        else
        {
            // create the children lazily, the first time they are needed
            if (!node.hasChildren())
            {
                generateChildren(node, currentState);
            }

            // descend into the selected child
            result = traverse(UCTNodeSelect(node), currentState);
        }
    }

    node.incVisits();

    if (result.val() > 0)
    {
        node.addWins(1);
    }
    else if (result.val() == 0)
    {
        node.addWins(0.5);
    }

    return result;
}
// Work out which player acts next from 'node' in 'state'.
// If only one player can move, that player is returned. When both can move,
// the choice depends on the node (root / first sim node) and otherwise on
// the configured player-to-move policy.
IDType UCTSearch::getPlayerToMove(UCTNode & node, const GameState & state) const
{
    const IDType movers(state.whoCanMove());

    // not a simultaneous situation: whoever can move does so
    if (movers != Players::Player_Both)
    {
        return movers;
    }

    // at the root the max player always chooses
    if (isRoot(node))
    {
        return _params.maxPlayer();
    }

    // the second mover of a simultaneous move is the enemy of the first
    const IDType nodeType = node.getNodeType();
    if (nodeType == SearchNodeType::FirstSimNode)
    {
        return state.getEnemy(node.getPlayer());
    }

    // otherwise the policy decides who goes first in a simultaneous move
    const IDType policy(_params.playerToMoveMethod());

    if (policy == SparCraft::PlayerToMove::Alternate)
    {
        return state.getEnemy(node.getPlayer());
    }

    if (policy == SparCraft::PlayerToMove::Not_Alternate)
    {
        return node.getPlayer();
    }

    if (policy == SparCraft::PlayerToMove::Random)
    {
        return rand() % 2;
    }

    // no known policy matched; this is not expected to happen
    System::FatalError("UCT Error: Nobody can move for some reason");
    return Players::Player_None;
}
// Apply the moves stored in 'node' to 'state'.
// A first sim node that is not treated as a leaf applies nothing here; a
// second sim node applies its parent's moves first, because they have not
// been made on the state yet.
void UCTSearch::updateState(UCTNode & node, GameState & state, bool isLeaf)
{
    const bool isFirstSim = (node.getNodeType() == SearchNodeType::FirstSimNode);

    // non-leaf first sim node: leave the state untouched
    if (isFirstSim && !isLeaf)
    {
        return;
    }

    // second sim node: perform the pending moves of the parent
    if (node.getNodeType() == SearchNodeType::SecondSimNode)
    {
        state.makeMoves(node.getParent()->getMove());
    }

    // perform this node's own moves and signal that moving is finished
    state.makeMoves(node.getMove());
    state.finishedMoving();
}
// Recursively accumulate statistics over the subtree rooted at 'node'.
//   depth           depth of 'node' itself
//   nodes           incremented once per node encountered
//   non_leaf_nodes  incremented for every node with more than one visit
//   depth_sum       sum of the depths of all nodes encountered
//   max_depth       largest depth encountered
//   children_count  incremented for every child that is past its first visit
void tree_stats_helper(const UCTNode& node, size_t depth,
                       size_t& nodes, size_t& non_leaf_nodes,
                       size_t& depth_sum, size_t& max_depth,
                       size_t& children_count) {
    ++nodes;
    if (node.get_visits() > 1) {
        ++non_leaf_nodes;
    }
    depth_sum += depth;
    if (max_depth < depth) {
        max_depth = depth;
    }

    const size_t child_depth = depth + 1;
    for (const auto& child : node.get_children()) {
        if (!child->first_visit()) {
            ++children_count;
        }
        tree_stats_helper(*(child.get()), child_depth,
                          nodes, non_leaf_nodes, depth_sum,
                          max_depth, children_count);
    }
}
// Build the principal variation below 'parent' as a space-separated string.
// Each move is rendered in SAN when 'use_san' is true and through UCI::move
// otherwise. Every move made on 'state' while descending is undone again
// before returning. Returns an empty string when 'parent' has no children.
std::string UCTSearch::get_pv(BoardHistory& state, UCTNode& parent, bool use_san) {
    if (!parent.has_children()) {
        return std::string();
    }

    auto& top_child = parent.get_best_root_child(state.cur().side_to_move());
    auto top_move = top_child.get_move();

    // The text has to be produced before the move is made on the board.
    auto line = use_san ? state.cur().move_to_san(top_move)
                        : UCI::move(top_move);

    // Make the move, collect the continuation, then take the move back.
    StateInfo scratch_info;
    state.cur().do_move(top_move, scratch_info);
    auto continuation = get_pv(state, top_child, use_san);
    if (!continuation.empty()) {
        line.append(" ").append(continuation);
    }
    state.cur().undo_move(top_move);

    return line;
}
/* sort children by converting linked list to vector, sorting the vector, and reconstructing to linked list again Requires node mutex to be held. */ void UCTNode::sort_children() { assert(get_mutex().is_held()); std::vector<std::tuple<float, UCTNode*>> tmp; UCTNode * child = m_firstchild; while (child != nullptr) { tmp.emplace_back(child->get_score(), child); child = child->m_nextsibling; } std::sort(begin(tmp), end(tmp)); m_firstchild = nullptr; for (auto& sortnode : tmp) { link_child(std::get<1>(sortnode)); } }
// Delete every child whose move, played on a copy of 'state', makes
// superko() report true. Pass moves are never tested and always kept.
void UCTNode::kill_superkos(KoState & state) {
    UCTNode * cur = m_firstchild;

    while (cur != nullptr) {
        // Remember the successor first: 'cur' may be deleted below.
        UCTNode * const following = cur->m_nextsibling;

        const int move = cur->get_move();
        if (move != FastBoard::PASS) {
            // Try the move on a copy so the caller's state is untouched.
            KoState trial = state;
            trial.play_move(move);

            if (trial.superko()) {
                delete_child(cur);
            }
        }

        cur = following;
    }
}
// Build the principal variation below 'parent' as a space-separated string
// of move texts. The moves are played on 'state' and are not taken back, so
// callers that need their state preserved must pass a copy.
// Returns an empty string when 'parent' has no children or when its best
// child has not been visited yet.
std::string UCTSearch::get_pv(FastState & state, UCTNode& parent) {
    if (!parent.has_children()) {
        return std::string();
    }

    auto& top_child = parent.get_best_root_child(state.get_to_move());
    if (top_child.first_visit()) {
        return std::string();
    }

    auto top_move = top_child.get_move();

    // The text has to be produced before the move is played.
    auto line = state.move_to_text(top_move);
    state.play_move(top_move);

    auto continuation = get_pv(state, top_child);
    if (!continuation.empty()) {
        line.append(" ").append(continuation);
    }

    return line;
}
// generate the children of state 'node' // state is the GameState after node's moves have been performed void UCTSearch::generateChildren(UCTNode & node, GameState & state) { // figure out who is next to move in the game const IDType playerToMove(getPlayerToMove(node, state)); // generate all the moves possible from this state state.generateMoves(_moveArray, playerToMove); _moveArray.shuffleMoveActions(); // generate the 'ordered moves' for move ordering generateOrderedMoves(state, _moveArray, playerToMove); // for each child of this state, add a child to the current node for (size_t child(0); (child < _params.maxChildren()) && getNextMove(playerToMove, _moveArray, child, _actionVec); ++child) { // add the child to the tree node.addChild(&node, playerToMove, getChildNodeType(node, state), _actionVec, _params.maxChildren(), _memoryPool ? _memoryPool->alloc() : NULL); _results.nodesCreated++; } }
void UCTSearch::printSubTreeGraphViz(UCTNode & node, GraphViz::Graph & g, GameState state) { if (node.getNodeType() == SearchNodeType::FirstSimNode && node.hasChildren()) { // don't make any moves if it is a first simnode } else { if (node.getNodeType() == SearchNodeType::SecondSimNode) { state.makeMoves(node.getParent()->getMove()); } state.makeMoves(node.getMove()); state.finishedMoving(); } std::stringstream label; std::stringstream move; for (size_t a(0); a<node.getMove().size(); ++a) { move << node.getMove()[a].moveString() << "\\n"; } if (node.getMove().size() == 0) { move << "root"; } std::string firstSim = SearchNodeType::getName(node.getNodeType()); Unit p1 = state.getUnit(0,0); Unit p2 = state.getUnit(1,0); label << move.str() << "\\nVal: " << node.getUCTVal() << "\\nWins: " << node.numWins() << "\\nVisits: " << node.numVisits() << "\\nChildren: " << node.numChildren() << "\\n" << firstSim << "\\nPtr: " << &node << "\\n---------------" << "\\nFrame: " << state.getTime() << "\\nHP: " << p1.currentHP() << " " << p2.currentHP() << "\\nAtk: " << p1.nextAttackActionTime() << " " << p2.nextAttackActionTime() << "\\nMove: " << p1.nextMoveActionTime() << " " << p2.nextMoveActionTime() << "\\nPrev: " << p1.previousActionTime() << " " << p2.previousActionTime(); std::string fillcolor ("#aaaaaa"); if (node.getPlayer() == Players::Player_One) { fillcolor = "#ff0000"; } else if (node.getPlayer() == Players::Player_Two) { fillcolor = "#00ff00"; } GraphViz::Node n(getNodeIDString(node)); n.set("label", label.str()); n.set("fillcolor", fillcolor); n.set("color", "#000000"); n.set("fontcolor", "#000000"); n.set("style", "filled,bold"); n.set("shape", "box"); g.addNode(n); // recurse for each child for (size_t c(0); c<node.numChildren(); ++c) { UCTNode & child = node.getChild(c); if (child.numVisits() > 0) { GraphViz::Edge edge(getNodeIDString(node), getNodeIDString(child)); g.addEdge(edge); printSubTreeGraphViz(child, g, state); } } }
void UCTSearch::dump_stats(BoardHistory& state, UCTNode& parent) { if (cfg_quiet || !parent.has_children()) { return; } myprintf("\n"); const Color color = state.cur().side_to_move(); // sort children, put best move on top m_root->sort_root_children(color); if (parent.get_first_child()->first_visit()) { return; } auto root_temperature = get_root_temperature(); auto accum_vector = m_root->calc_proportional(root_temperature, color); for (const auto& node : boost::adaptors::reverse(parent.get_children())) { std::string tmp = state.cur().move_to_san(node->get_move()); std::string pvstring(tmp); std::string moveprob(10, '\0'); if (cfg_randomize) { auto move_probability = accum_vector.back(); accum_vector.pop_back(); if (accum_vector.size() > 0) { move_probability -= accum_vector.back(); } move_probability *= 100.0f; // The following code expects percentage. if (move_probability > 0.01f) { std::snprintf(&moveprob[0], moveprob.size(), "(%6.2f%%)", move_probability); } else if (move_probability > 0.00001f) { std::snprintf(&moveprob[0], moveprob.size(), "%s", "(> 0.00%)"); } else { std::snprintf(&moveprob[0], moveprob.size(), "%s", "( 0.00%)"); } } else { auto needed = std::snprintf(&moveprob[0], moveprob.size(), "%s", " "); moveprob.resize(needed+1); } myprintf_so("info string %5s -> %7d %s (V: %5.2f%%) (N: %5.2f%%) PV: ", tmp.c_str(), node->get_visits(), moveprob.c_str(), node->get_eval(color)*100.0f, node->get_score() * 100.0f); StateInfo si; state.cur().do_move(node->get_move(), si); // Since this is just a string, set use_san=true pvstring += " " + get_pv(state, *node, true); state.cur().undo_move(node->get_move()); myprintf_so("%s\n", pvstring.c_str()); } // winrate separate info string since it's not UCI spec float feval = m_root->get_eval(color); myprintf_so("info string stm %s winrate %5.2f%%\n", color == Color::WHITE ? "White" : "Black", feval * 100.f); myprintf("\n"); }
// Pick the valid child with the highest selection value
//     value = winrate + cfg_puct * score * sqrt(total visits) / (1 + child visits)
// where "total visits" is summed over the valid children considered.
// Returns nullptr when there is no valid child. The node mutex is taken for
// the duration of the call.
UCTNode* UCTNode::uct_select_child(int color) {
    LOCK(get_mutex(), lock);

    // Step forward from 'n' to the nearest valid node at or after it;
    // yields nullptr when the list runs out.
    auto first_valid_from = [](UCTNode* n) {
        while (n != nullptr && !n->valid()) {
            n = n->m_nextsibling;
        }
        return n;
    };

    // Progressive widening is disabled: the bound is a fixed constant.
    const int childbound = 362;

    // Count parentvisits.
    // We do this manually to avoid issues with transpositions.
    int parentvisits = 0;
    {
        int counted = 0;
        UCTNode* cur = first_valid_from(m_firstchild);
        while (cur != nullptr && counted < childbound) {
            parentvisits += cur->get_visits();
            cur = first_valid_from(cur->m_nextsibling);
            ++counted;
        }
    }

    const float numerator = std::sqrt(static_cast<double>(parentvisits));

    UCTNode* child = first_valid_from(m_firstchild);
    if (child == nullptr) {
        return nullptr;
    }

    UCTNode* best = nullptr;
    float best_value = -1000.0f;

    for (int seen = 0; child != nullptr && seen < childbound; ++seen) {
        // get_eval() will automatically set first-play-urgency
        float winrate = child->get_eval(color);
        float psa = child->get_score();
        float denom = 1.0f + child->get_visits();
        float puct = cfg_puct * psa * (numerator / denom);
        float value = winrate + puct;
        assert(value > -1000.0f);

        if (value > best_value) {
            best_value = value;
            best = child;
        }

        child = first_valid_from(child->m_nextsibling);
    }

    assert(best != nullptr);
    return best;
}
// Choose the move to play from the root after the search has finished.
// Starts from the first child after sorting (optionally randomized in the
// early game), then applies the pass rules selected by 'passflag' and
// cfg_dumbpass, and finally considers resigning.
// Returns the chosen move, which may be FastBoard::PASS or FastBoard::RESIGN.
int UCTSearch::get_best_move(passflag_t passflag) {
    const int color = m_rootstate.board.get_to_move();

    // Make sure best is first
    m_root->sort_children(color);

    // Check whether to randomize the best move proportional
    // to the playout counts, early game only.
    auto movenum = int(m_rootstate.get_movenum());
    if (movenum < cfg_random_cnt) {
        m_root->randomize_first_proportionally();
    }

    auto first_child = m_root->get_first_child();
    assert(first_child != nullptr);

    auto bestmove = first_child->get_move();
    auto bestscore = first_child->get_eval(color);

    // Replace the current choice by 'alt'. A child that has never been
    // visited gets a score of 1.0 instead of its eval.
    auto adopt = [&](UCTNode* alt) {
        bestmove = alt->get_move();
        if (alt->first_visit()) {
            bestscore = 1.0f;
        } else {
            bestscore = alt->get_eval(color);
        }
    };

    // True when the final score of the root position (full count, dead
    // stones included) goes against the side to move.
    auto passing_loses = [&]() {
        const float score = m_rootstate.final_score();
        return (score > 0.0f && color == FastBoard::WHITE)
            || (score < 0.0f && color == FastBoard::BLACK);
    };

    // do we want to fiddle with the best move because of the rule set?
    if (passflag & UCTSearch::NOPASS) {
        // were we going to pass?
        if (bestmove == FastBoard::PASS) {
            UCTNode * nopass = m_root->get_nopass_child(m_rootstate);

            if (nopass != nullptr) {
                myprintf("Preferring not to pass.\n");
                adopt(nopass);
            } else {
                myprintf("Pass is the only acceptable move.\n");
            }
        }
    } else if (!cfg_dumbpass) {
        if (bestmove == FastBoard::PASS) {
            // Passing ended up on top, either forced or by coincidence.
            // Check whether passing would lose on the spot.
            //
            // A network trained purely by reinforcement learning can learn
            // that a position reached after passes is only won when no dead
            // stones are left in its own territory, so for such a network
            // this heuristic is not strictly needed; that is why a command
            // line option can turn it off during learning. A network trained
            // by supervised learning is expected to pass as soon as the game
            // looks finished, dead stones or not, because its training games
            // were scored with dead stone removal. For such networks this
            // heuristic lets the engine "clean up" the board, but only as
            // far as needed to win. Full dead stone removal requires
            // kgs-genmove_cleanup and the NOPASS mode.
            if (passing_loses()) {
                myprintf("Passing loses :-(\n");
                // Find a valid non-pass move.
                UCTNode * nopass = m_root->get_nopass_child(m_rootstate);
                if (nopass != nullptr) {
                    myprintf("Avoiding pass because it loses.\n");
                    adopt(nopass);
                } else {
                    myprintf("No alternative to passing.\n");
                }
            } else {
                myprintf("Passing wins :-)\n");
            }
        } else if (m_rootstate.get_last_move() == FastBoard::PASS) {
            // The opponent's last move was a pass and we were not going to
            // pass ourselves. Check whether passing now would end the game
            // in our favour.
            if (passing_loses()) {
                myprintf("Passing loses, I'll play on.\n");
            } else {
                myprintf("Passing wins, I'll pass out.\n");
                bestmove = FastBoard::PASS;
            }
        }
    }

    // if we aren't passing, should we consider resigning?
    if (bestmove != FastBoard::PASS && should_resign(passflag, bestscore)) {
        myprintf("Eval (%.2f%%) looks bad. Resigning.\n",
                 100.0f * bestscore);
        bestmove = FastBoard::RESIGN;
    }

    return bestmove;
}