예제 #1
0
/**
 * Picks the child with the highest Eval() score among the first
 * children_number entries of `children`.
 *
 * Every inspected child is evaluated exactly once, in index order. Only a
 * strictly greater score replaces the current leader, so on ties the child
 * with the lowest index wins.
 *
 * @param children_number number of children to inspect; asserted to be > 0.
 * @return pointer to the highest-scoring child.
 */
MCTSNode* MCTSNode::FindChildToSelect(uint children_number) {
	ASSERT(children_number > 0);
	uint leader = 0;
	float leader_score = children[0].Eval();
	for (uint idx = 1; idx < children_number; ++idx) {
		const float score = children[idx].Eval();
		if (score > leader_score) {
			leader_score = score;
			leader = idx;
		}
	}
	return &(children[leader]);
}
예제 #2
0
/**
 * Returns the child whose GetMu() value is largest among the first
 * children_number entries of `children`.
 *
 * GetMu() is queried once per inspected child, in index order. The first
 * child holding the maximum is kept when several children share it.
 *
 * @param children_number number of children to inspect; asserted to be > 0.
 * @return pointer to the child with the largest GetMu().
 */
MCTSNode* MCTSNode::FindBestChild(uint children_number) {
	ASSERT(children_number > 0);
	uint winner = 0;
	float winner_mu = children[0].GetMu();
	uint candidate = 1;
	while (candidate < children_number) {
		const float candidate_mu = children[candidate].GetMu();
		if (candidate_mu > winner_mu) {
			winner_mu = candidate_mu;
			winner = candidate;
		}
		++candidate;
	}
	return &(children[winner]);
}
/**
 * Runs a Monte Carlo tree search from this position until the timer reports
 * a timeout, then prints the move stored in the root node.
 *
 * Each iteration works on a fresh copy of this position and performs
 * select -> expand -> simulate -> update on the tree. The timeout is only
 * polled every 100 iterations.
 *
 * @param depth unused; present only to satisfy the interface.
 *
 * Output (stdout): the found move, the elapsed time in seconds, and a
 * "bestmove <move>" line.
 */
void PositionMonteCarloTreeSearch::go(unsigned int depth) {
    (void)depth; // depth is not needed here
    srand(time(NULL));
    MCTSNode *rootNode = new MCTSNode();
    int iterations=0;
    while(1) {
        // Every iteration plays on its own copy of the current position.
        PositionMonteCarloTreeSearch *rootPos = new PositionMonteCarloTreeSearch(*this);
        MCTSNode *selected = rootNode->select(rootPos);
        MCTSNode *expanded = selected->expand(rootPos);
        double result = expanded->simulate(rootPos);
        expanded->update(result);
        // Release the copy before the timeout check: previously the `break`
        // below skipped the delete and leaked the last iteration's copy.
        delete rootPos;
        iterations++;
        if (iterations%100==0) {
            if (timer.isTimeout() || timer.checkTimeout())
                break;
        }
    }
    string foundMove=rootNode->getMove();
    delete rootNode;
    
    /**********************************
     * Print the result of the search. *
     **********************************/
    timer.stopTimer();
    double seconds = timer.getStartEndDiffSeconds();
    cout << "Move: " << foundMove << endl;
    cout << "Time: " << seconds << " seconds" << endl;
    cout << "bestmove " << foundMove << endl;
}
/**
 * Returns the child with the largest Best() value among the `count`
 * entries of `children`.
 *
 * Best() is called once per child, in index order; ties keep the child
 * with the lowest index. Asserts that count > 0.
 *
 * @return pointer to the child with the largest Best() value.
 */
MCTSNode* MCTSNode::SelectBestChild() const {

    ASSERT(count > 0);

    uint top = 0;
    float top_score = children[0].Best();
    for (uint idx = 1; idx < count; ++idx) {
        const float score = children[idx].Best();
        if (score > top_score) {
            top_score = score;
            top = idx;
        }
    }
    return &children[top];
}
/**
 * Returns the entry of `chosen_children` with the largest GetValue()
 * among its first `chosen_count` entries.
 *
 * GetValue() is called once per entry, in index order; ties keep the entry
 * with the lowest index. Asserts that the node has children, that
 * chosen_children holds a non-null pointer, and that chosen_count > 0.
 *
 * @return the chosen child with the largest GetValue().
 */
MCTSNode* MCTSNode::SelectChild() const {

    ASSERT(count > 0);
    ASSERT(chosen_children.GetPointer());
    ASSERT(chosen_count > 0);

    MCTSNode* winner = chosen_children[0];
    float winner_value = winner->GetValue();
    for (uint k = 1; k < chosen_count; ++k) {
        MCTSNode* const candidate = chosen_children[k];
        const float candidate_value = candidate->GetValue();
        if (candidate_value > winner_value) {
            winner_value = candidate_value;
            winner = candidate;
        }
    }
    return winner;
}
예제 #6
0
파일: mcts.cpp 프로젝트: ikoryakovskiy/grl
void MCTSPolicy::act(double time, const Observation &in, Action *out)
{
  // Clear tree at start of episode
  if (time == 0.)
  {
    safe_delete(&root_);
    trunk_ = NULL;
  }

  // Try warm start
  if (trunk_ && trunk_->children())
  {
    double maxdiff = 0;
    MCTSNode *selected = trunk_->select(0);
    Vector predicted = selected->state();
    
    for (size_t ii=0; ii < in.size(); ++ii)
      maxdiff = fmax(maxdiff, fabs(in[ii]-predicted[ii]));
      
    if (maxdiff < 0.05)
    {
      trunk_ = selected;
      selected->orphanize();

      CRAWL("Trunk set to selected state " << trunk_->state());
    }
    else
    {
      safe_delete(&root_);
      trunk_ = NULL;
      TRACE("Cannot use warm start: predicted state " << predicted << " differs from actual state " << in);
    }
  }

  // Allocate new tree if warm start was not possible
  if (!trunk_)
  {
    allocate();
    root_->init(NULL, 0, in, 0, false);
    root_->allocate(discretizer_->size(in));
    trunk_ = root_;
  }
  
  CRAWL("Trunk set to state " << trunk_->state());

  // Search until budget is up
  timer t;
  size_t searches=0;

  while (t.elapsed() < budget_)
  {
    MCTSNode *node = treePolicy(), *it=node;
    size_t depth=0;
    
    while ((it = it->parent()))
      depth++;
    
    double reward = 0;
    
    CRAWL("Tree policy selected node with state " << node->state() << " at depth " << depth);
    
    if (!node->terminal() && depth < horizon_)
      reward = defaultPolicy(node->state(), horizon_-depth);
     
    CRAWL("Default policy got reward " << reward);

    do
    {
      node->update(reward);
      reward = gamma_*reward + node->reward();
    } while ((node = node->parent()));
    
    searches++;
  }
  
  // Select best action
  if (trunk_->children())
  {
    MCTSNode *node = trunk_->select(0);
    *out = discretizer_->at(trunk_->state(), node->action());
    out->type = atGreedy;

    TRACE("Selected action " << *out << " (Q " << node->q()/node->visits() << ") after " << searches << " searches");
  }
  else
  {
    *out = discretizer_->at(in, lrand48()%discretizer_->size(in));
    out->type = atExploratory;

    TRACE("Selected random action " << *out);
  }
}
예제 #7
0
	/**
	 * Runs `iterations` rounds of tree search from `game` and returns the
	 * move of the root child with the most visits.
	 *
	 * Each round works on a state obtained from Determinize(game, r) and
	 * performs selection, expansion (at most one new node), a random
	 * playout, and backpropagation of visit/win counts. Whenever a node is
	 * left during selection or expanded, the availability counter of each
	 * of its children whose move is possible in the current state is
	 * incremented.
	 *
	 * Tree nodes other than the root live in one malloc'ed buffer with room
	 * for `iterations` nodes, which is freed before returning.
	 *
	 * @param game       state to search from; a win is counted when the
	 *                   playout's winner equals game.m_active_player_index.
	 * @param iterations number of search rounds.
	 * @return the move of the most-visited root child.
	 *
	 * NOTE(review): if the root ends up without children (for example
	 * iterations == 0), best_node stays nullptr and is dereferenced; callers
	 * must ensure at least one child gets created. The malloc result is not
	 * checked either.
	 */
	Move ChooseMove(const GameState& game, unsigned iterations)
	{
		std::mt19937 rng(GlobalRandomDevice());
		MCTSNode root;

		// At most one node is added per iteration, so `iterations` slots suffice.
		MCTSNode* const pool = static_cast<MCTSNode*>(malloc(sizeof(MCTSNode) * iterations));
		MCTSNode* next_free = pool;

		// Increments m_availability on every child of `parent` whose move is
		// currently possible in `state`.
		const auto mark_available = [](MCTSNode* parent, GameState& state)
		{
			for (MCTSNode* child = parent->m_child; child; child = child->m_siblings)
			{
				if (state.m_possible_moves.Contains(child->m_move))
				{
					child->m_availability++;
				}
			}
		};

		for (unsigned round = 0; round < iterations; ++round)
		{
			GameState sim_state = Determinize(game, rng);

			// Selection: descend while the node has no untried moves and has children.
			MCTSNode* node = &root;
			while (!node->HasUntriedMoves(sim_state) && node->HasChildren())
			{
				MCTSNode* const picked = node->UCTSelectChild(sim_state);

				mark_available(node, sim_state);

				sim_state.ProcessMove(picked->m_move);
				node = picked;
			}

			// Expansion: add one child for a randomly chosen untried move.
			if (node->HasUntriedMoves(sim_state))
			{
				Move untried = node->ChooseRandomUntriedMove(sim_state, rng);

				mark_available(node, sim_state);

				sim_state.ProcessMove(untried);
				node = node->AddChild(untried, next_free++);
				node->m_availability++;
			}

			// Simulation
			sim_state.PlayOutRandomly(rng);

			// Backpropagation: walk from the reached node up to the root.
			const bool won = sim_state.m_winner == static_cast<Winner>(game.m_active_player_index);
			for (MCTSNode* visited = node; visited; visited = visited->m_parent)
			{
				visited->m_visits++;
				if (won) visited->m_wins++;
			}
		}

		// Pick the root child with strictly the most visits (first one wins ties).
		MCTSNode* best_node = nullptr;
		uint32_t best_visits = 0;

		for (MCTSNode* child = root.m_child; child; child = child->m_siblings)
		{
			if (child->m_visits > best_visits)
			{
				best_visits = child->m_visits;
				best_node = child;
			}
		}

		// Copy the move out before releasing the node storage it lives in.
		Move result = best_node->m_move;
		free(pool);
		return result;
	}
예제 #8
0
/**
 * Runs playouts_per_move playouts from the root and returns the move for
 * `player` at the location of the root child picked by FindBestChild().
 *
 * Each playout descends through already-expanded nodes, optionally expands
 * the node it stopped at, finishes the game with RandomFinish(), and then
 * updates uct_stats along the descended path and rave_stats for the
 * positions recorded in full_path.
 *
 * @param player the player the returned Move is created for.
 * @param board  the position to search from; asserted not to be full. It is
 *               loaded into a scratch board via Load() for every playout.
 * @return Move(player, loc) built from the best root child.
 *
 * Side effects visible here: calls ClearTree() first, may expand the root
 * and set root_children_number, and flips current_player to its opponent
 * before returning.
 */
Move MCTSTree::BestMove(Player player, Board& board) {

	ClearTree();

	ASSERT (!board.IsFull());
	ASSERT (root != NULL);

	MCTSNode* current_node;
	uint current_level;
	typedef MCTSNode* mcts_node_ptr;
	// path[k] is the tree node visited at level k of the current playout;
	// path[0] is always the root.
	mcts_node_ptr path[ultimate_depth + 1];
	// full_path[k] is the board position (loc.GetPos()) played at level k.
	// Entry 0 is never written in this function. Entries above current_level
	// are presumably filled by RandomFinish() - TODO confirm.
	uint full_path[ultimate_depth + 1];
	path[0] = root.GetPointer();
	// Scratch board that every playout is played on.
	Board brd;

	// Expand the root if it has no children yet.
	if (root->children == NULL) {
		root->Expand(board);
		root_children_number = board.MovesLeft();
	}

	for (uint i = 0; i < playouts_per_move; ++i) {

		current_level = 0;
		current_node = root.GetPointer();
		brd.Load(board);
		// Descend while the current node has children, playing each chosen
		// move on the scratch board and recording node and position.
		while (current_node->children != NULL) {
			// With max_depth == 0 a child is drawn at random instead of via
			// FindChildToSelect().
			if (max_depth == 0)
				current_node = &current_node->children[Rand::next_rand(brd.MovesLeft())];
			else current_node = current_node->FindChildToSelect(brd.MovesLeft());
			brd.PlayLegal(Move(brd.CurrentPlayer(), current_node->loc));
			path[++current_level] = current_node;
			full_path[current_level] = current_node->loc.GetPos();
		}

		// Expand the reached node only if it is above max_depth, moves
		// remain, and it has been played often enough; then step into one of
		// the new children.
		if (current_level < max_depth && brd.MovesLeft() > 0 &&
				current_node->uct_stats.played >= visits_to_expand +
				2 * Params::initialization) {
			current_node->Expand(brd);
			current_node = current_node->FindChildToSelect(brd.MovesLeft());
			brd.PlayLegal(Move(brd.CurrentPlayer(), current_node->loc));
			path[++current_level] = current_node;
			full_path[current_level] = current_node->loc.GetPos();
		}

		// `current` is the player to move at the end of the in-tree part;
		// the rest of the game is finished by RandomFinish().
		Player current = brd.CurrentPlayer();
		Player winner = RandomFinish(brd, full_path, current_level);

		// UCT backup from the deepest node to the root. `current` alternates
		// per level, and a node gets a win when the winner differs from it.
		for (int level = current_level; level >= 0; --level) {
			if (winner != current)
				path[level]->uct_stats.won++;
			path[level]->uct_stats.played++;
			path[level]->SetInvalidUCB();
			current = current.Opponent();
		}

		// RAVE backup: for each level from board.MovesLeft() down to 1, take
		// the position stored in full_path[level] and update rave_stats of
		// the child mapped to that position under selected nodes of the path.
		for (int level = board.MovesLeft(); level > 0; --level) {
			uint pos = full_path[level];
			// Start at current_level - 1, capped at level - 1. When not
			// capped, step one node further up if level + tree_level is even.
			int tree_level = current_level - 1;
			if (tree_level >= level)
				tree_level = level - 1;
			else if (((level + tree_level) & 1) == 0)
				tree_level--;
			// Even tree levels are attributed to `player`, odd ones to the
			// opponent.
			if ((tree_level & 1) == 0)
				current = player;
			else current = player.Opponent();
			// Walk up two levels at a time, so every visited node has the
			// same level parity and `current` stays valid for all of them.
			while (tree_level >= 0) {
				// NOTE(review): the mapped child is dereferenced without a
				// null check - presumably always populated; verify.
				MCTSNode* updated = path[tree_level]->pos_to_children_mapping[pos];
				if (winner == current)
					updated->rave_stats.won++;
				updated->rave_stats.played++;
				updated->SetInvalidRAVE();
				tree_level -= 2;
			}
		}
	}

	// Final choice among the root's children, one per move left on `board`.
	MCTSNode* best = root->FindBestChild(board.MovesLeft());
	current_player = current_player.Opponent();

	return Move(player, best->loc);
}