/** Compute the UCT selection bound for a child.
    Adjusts the parent's position count by any pending virtual losses
    before taking the logarithm, then delegates to the log-count
    overload of GetBound(). */
SgUctValue SgUctSearch::GetBound(bool useRave, const SgUctNode& node,
                                 const SgUctNode& child) const
{
    const int pendingLosses = node.VirtualLossCount();
    SgUctValue adjustedPosCount = node.PosCount();
    if (pendingLosses > 0)
        adjustedPosCount += SgUctValue(pendingLosses);
    return GetBound(useRave, Log(adjustedPosCount), child);
}
void MoHexPlayer::CopyKnowledgeData(const SgUctTree& tree, const SgUctNode& node, HexColor color, MoveSequence& sequence, const MoHexSharedData& oldData, MoHexSharedData& newData) const { // This check will fail in the root if we are reusing the // entire tree, so only do it when not in the root. if (sequence != oldData.gameSequence) { SgHashCode hash = SequenceHash::Hash(sequence); StoneBoard stones; if (!oldData.stones.Get(hash, stones)) return; newData.stones.Add(hash, stones); } if (!node.HasChildren()) return; for (SgUctChildIterator it(tree, node); it; ++it) { sequence.push_back(Move(color, static_cast<HexPoint>((*it).Move()))); CopyKnowledgeData(tree, *it, !color, sequence, oldData, newData); sequence.pop_back(); } }
// Replace the children of node with exactly the moves in 'moves'.
// Existing children whose move appears in 'moves' are copied, keeping a
// pointer into their old subtree (the subtree is shared, not duplicated);
// moves without an existing child get a fresh empty node. The new,
// contiguous child array is then published to 'node' using the lock-free
// write ordering documented below.
void SgUctTree::SetChildren(std::size_t allocatorId, const SgUctNode& node, const vector<SgMove>& moves)
{
    SG_ASSERT(Contains(node));
    // Caller must guarantee enough room in this allocator for all children
    SG_ASSERT(Allocator(allocatorId).HasCapacity(moves.size()));
    SG_ASSERT(node.HasChildren());
    SgUctAllocator& allocator = Allocator(allocatorId);
    // New children are created contiguously, starting at the allocator's end
    const SgUctNode* firstChild = allocator.Finish();
    int nuChildren = 0;
    for (size_t i = 0; i < moves.size(); ++i)
    {
        bool found = false;
        // Search the current children for one playing the same move
        for (SgUctChildIterator it(*this, node); it; ++it)
        {
            SgMove move = (*it).Move();
            if (move == moves[i])
            {
                found = true;
                SgUctNode* child = allocator.CreateOne(move);
                child->CopyDataFrom(*it);
                int childNuChildren = (*it).NuChildren();
                child->SetNuChildren(childNuChildren);
                if (childNuChildren > 0)
                    // Reuse the old grandchildren: only the child node
                    // itself is copied, its subtree stays in place
                    child->SetFirstChild((*it).FirstChild());
                ++nuChildren;
                break;
            }
        }
        if (! found)
        {
            // Move had no existing child: create an empty node for it
            allocator.CreateOne(moves[i]);
            ++nuChildren;
        }
    }
    SG_ASSERT((size_t)nuChildren == moves.size());
    // Parameter is a const-reference; only the tree may modify its nodes
    SgUctNode& nonConstNode = const_cast<SgUctNode&>(node);
    // Write order dependency: SgUctSearch in lock-free mode assumes that
    // m_firstChild is valid if m_nuChildren is greater zero
    SgSynchronizeThreadMemory();
    nonConstNode.SetFirstChild(firstChild);
    SgSynchronizeThreadMemory();
    nonConstNode.SetNuChildren(nuChildren);
}
// For each child of node: if its move appears in 'moves' (the mustplay
// set), fold in the supplied game/RAVE results and optionally detach its
// subtree; otherwise flag the child with a proven type so the search will
// avoid it.
void SgUctTree::SetMustplay(const SgUctNode& node, const std::vector<SgUctMoveInfo>& moves, bool deleteChildTrees)
{
    for (SgUctChildIterator it(*this, node); it; ++it)
    {
        SgUctNode* child = const_cast<SgUctNode*>(&(*it));
        bool found = false;
        for (size_t j = 0; j < moves.size(); ++j)
        {
            if (child->Move() == moves[j].m_move)
            {
                found = true;
                // Fold in the prior knowledge supplied by the caller;
                // zero counts are skipped to avoid touching the node
                if (moves[j].m_count > 0)
                    child->AddGameResults(moves[j].m_value, moves[j].m_count);
                if (moves[j].m_raveCount > 0)
                    child->AddRaveValue(moves[j].m_raveValue, moves[j].m_raveCount);
                if (deleteChildTrees)
                {
                    // Write order dependency: clear the count before the
                    // pointer so lock-free readers never walk a detached
                    // child array
                    child->SetNuChildren(0);
                    SgSynchronizeThreadMemory();
                    child->SetFirstChild(0);
                }
                break;
            }
        }
        if (!found)
            // mark as loss
            // NOTE(review): SG_PROVEN_WIN is presumably relative to the
            // player to move at the child, which reads as a loss for the
            // mover at 'node' — confirm against SgUctNode's proven-type
            // convention.
            child->SetProvenType(SG_PROVEN_WIN);
    }
    SgSynchronizeThreadMemory();
}
void SgUctTree::ApplyFilter(std::size_t allocatorId, const SgUctNode& node, const vector<SgMove>& rootFilter) { SG_ASSERT(Contains(node)); SG_ASSERT(Allocator(allocatorId).HasCapacity(node.NuChildren())); if (! node.HasChildren()) return; SgUctAllocator& allocator = Allocator(allocatorId); const SgUctNode* firstChild = allocator.Finish(); int nuChildren = 0; for (SgUctChildIterator it(*this, node); it; ++it) { SgMove move = (*it).Move(); if (find(rootFilter.begin(), rootFilter.end(), move) == rootFilter.end()) { SgUctNode* child = allocator.CreateOne(move); child->CopyDataFrom(*it); int childNuChildren = (*it).NuChildren(); child->SetNuChildren(childNuChildren); if (childNuChildren > 0) child->SetFirstChild((*it).FirstChild()); ++nuChildren; } } SgUctNode& nonConstNode = const_cast<SgUctNode&>(node); // Write order dependency: SgUctSearch in lock-free mode assumes that // m_firstChild is valid if m_nuChildren is greater zero nonConstNode.SetFirstChild(firstChild); SgSynchronizeThreadMemory(); nonConstNode.SetNuChildren(nuChildren); }
/** Optimized version of GetValueEstimate() for the case that RAVE is the
    only additional estimator.
    Previously there were more estimators than move value and RAVE value,
    and in the future there may be again. GetValueEstimate() is easier to
    extend; this function is specialized for speed. */
SgUctValue SgUctSearch::GetValueEstimateRave(const SgUctNode& child) const
{
    SG_ASSERT(m_rave);
    // Collect the plain move statistics, if any
    SgUctStatistics moveStats;
    if (child.HasMean())
        moveStats.Initialize(child.Mean(), child.MoveCount());
    // Collect the RAVE statistics, if any
    SgUctStatistics raveStats;
    if (child.HasRaveValue())
        raveStats.Initialize(child.RaveValue(), child.RaveCount());
    // Pending virtual losses count as lost games in both estimators
    const int virtualLossCount = child.VirtualLossCount();
    if (virtualLossCount > 0)
    {
        moveStats.Add(InverseEstimate(0), SgUctValue(virtualLossCount));
        raveStats.Add(0, SgUctValue(virtualLossCount));
    }
    const bool hasRave = raveStats.IsDefined();
    SgUctValue value;
    if (! moveStats.IsDefined())
        // No move value: fall back to RAVE alone, or to first-play urgency
        value = hasRave ? raveStats.Mean() : m_firstPlayUrgency;
    else
    {
        SgUctValue moveValue = InverseEstimate((SgUctValue)moveStats.Mean());
        if (! hasRave)
        {
            // Can happen only in lock-free multi-threading: normally each
            // move played also adds a RAVE value, but another thread may
            // have updated the move value before the RAVE value
            SG_ASSERT(m_numberThreads > 1 && m_lockFree);
            value = moveValue;
        }
        else
        {
            // Blend move value and RAVE value; the RAVE weight decays as
            // the move count grows
            SgUctValue moveCount = moveStats.Count();
            SgUctValue raveCount = raveStats.Count();
            SgUctValue weight =
                raveCount
                / (moveCount
                   * (m_raveWeightParam1 + m_raveWeightParam2 * raveCount)
                   + raveCount);
            value = weight * raveStats.Mean() + (1.f - weight) * moveValue;
        }
    }
    // Cross-check against the general estimator (single-threaded only)
    SG_ASSERT(m_numberThreads > 1
              || fabs(value - GetValueEstimate(m_rave, child)) < 1e-3/*epsilon*/);
    return value;
}
/** General value estimate: weighted combination of the move value and,
    if requested, the RAVE value. Pending virtual losses are folded into
    both statistics as lost games. Returns the first-play urgency when
    neither estimator has any data. */
SgUctValue SgUctSearch::GetValueEstimate(bool useRave, const SgUctNode& child) const
{
    SgUctValue weightedSum = 0;
    SgUctValue totalWeight = 0;
    bool hasValue = false;
    const int virtualLossCount = child.VirtualLossCount();
    // Move-value estimator, weighted by its sample count
    SgUctStatistics moveStats;
    if (child.HasMean())
        moveStats.Initialize(child.Mean(), child.MoveCount());
    if (virtualLossCount > 0)
        moveStats.Add(InverseEstimate(0), SgUctValue(virtualLossCount));
    if (moveStats.IsDefined())
    {
        SgUctValue weight = static_cast<SgUctValue>(moveStats.Count());
        weightedSum += weight * InverseEstimate((SgUctValue)moveStats.Mean());
        totalWeight += weight;
        hasValue = true;
    }
    // RAVE estimator, weighted so its influence saturates with count
    if (useRave)
    {
        SgUctStatistics raveStats;
        if (child.HasRaveValue())
            raveStats.Initialize(child.RaveValue(), child.RaveCount());
        if (virtualLossCount > 0)
            raveStats.Add(0, SgUctValue(virtualLossCount));
        if (raveStats.IsDefined())
        {
            SgUctValue raveCount = raveStats.Count();
            SgUctValue weight =
                raveCount
                / (m_raveWeightParam1 + m_raveWeightParam2 * raveCount);
            weightedSum += weight * raveStats.Mean();
            totalWeight += weight;
            hasValue = true;
        }
    }
    return hasValue ? weightedSum / totalWeight : m_firstPlayUrgency;
}
/** UCT bound for a child given the precomputed log of the parent's
    position count: value estimate plus the exploration bias term.
    With a zero bias constant the bound is the bare value estimate. */
SgUctValue SgUctSearch::GetBound(bool useRave, SgUctValue logPosCount,
                                 const SgUctNode& child) const
{
    const SgUctValue value = useRave
        ? GetValueEstimateRave(child)
        : GetValueEstimate(false, child);
    if (m_biasTermConstant == 0.0)
        return value;
    // +1 keeps the term finite for unvisited children
    const SgUctValue moveCount = static_cast<SgUctValue>(child.MoveCount());
    return value + m_biasTermConstant * sqrt(logPosCount / (moveCount + 1));
}
// Replace the children of node with the nodes described by 'moves',
// merging in the statistics of any old child playing the same move.
// If deleteChildTrees is false, the old subtrees and position counts are
// carried over as well. Publication to 'node' picks one of two write
// orders so that a concurrent SgUctChildIterator never runs past the end
// of the child array.
void SgUctTree::MergeChildren(std::size_t allocatorId, const SgUctNode& node, const std::vector<SgMoveInfo>& moves, bool deleteChildTrees)
{
    SG_ASSERT(Contains(node));
    // Parameters are const-references, because only the tree is allowed
    // to modify nodes
    SgUctNode& nonConstNode = const_cast<SgUctNode&>(node);
    size_t nuNewChildren = moves.size();
    if (nuNewChildren == 0)
    {
        // Becoming childless: clear the count before the pointer
        // Write order dependency
        nonConstNode.SetNuChildren(0);
        SgSynchronizeThreadMemory();
        nonConstNode.SetFirstChild(0);
        return;
    }
    SgUctAllocator& allocator = Allocator(allocatorId);
    SG_ASSERT(allocator.HasCapacity(nuNewChildren));
    // New children are created contiguously, starting at the allocator's end
    const SgUctNode* newFirstChild = allocator.Finish();
    std::size_t parentCount = allocator.Create(moves);
    // Update new children with data in old children
    for (std::size_t i = 0; i < moves.size(); ++i)
    {
        SgUctNode* newChild = const_cast<SgUctNode*>(&newFirstChild[i]);
        for (SgUctChildIterator it(*this, node); it; ++it)
        {
            const SgUctNode& oldChild = *it;
            if (oldChild.Move() == moves[i].m_move)
            {
                newChild->MergeResults(oldChild);
                newChild->SetKnowledgeCount(oldChild.KnowledgeCount());
                if (! deleteChildTrees)
                {
                    // Keep the old subtree and its accumulated counts
                    newChild->SetPosCount(oldChild.PosCount());
                    parentCount += oldChild.MoveCount();
                    if (oldChild.HasChildren())
                    {
                        newChild->SetFirstChild(oldChild.FirstChild());
                        newChild->SetNuChildren(oldChild.NuChildren());
                    }
                }
                break;
            }
        }
    }
    nonConstNode.SetPosCount(parentCount);
    // Write order dependency: We do not want an SgUctChildIterator to
    // run past the end of a node's children, which can happen if one
    // is created between the two statements below. We modify node in
    // such a way so as to avoid that.
    if (nonConstNode.NuChildren() < (int)nuNewChildren)
    {
        // Growing: publish the pointer first, then the larger count
        nonConstNode.SetFirstChild(newFirstChild);
        SgSynchronizeThreadMemory();
        nonConstNode.SetNuChildren(nuNewChildren);
    }
    else
    {
        // Shrinking (or equal): shrink the count first, then swap the pointer
        nonConstNode.SetNuChildren(nuNewChildren);
        SgSynchronizeThreadMemory();
        nonConstNode.SetFirstChild(newFirstChild);
    }
}
/** Recursive function used by SgUctTree::ExtractSubtree and
    SgUctTree::CopyPruneLowCount.
    @param target The target tree.
    @param targetNode The target node; it is already created but the
    content not yet copied
    @param node The node in the source tree to be copied.
    @param minCount The minimum count (SgUctNode::MoveCount()) of a
    non-root node in the source tree to copy
    @param currentAllocatorId The current node allocator. Will be
    incremented in each call to CopySubtree to use node allocators of
    target tree evenly.
    @param warnTruncate Print warning to SgDebug() if tree was truncated
    (e.g due to reassigning nodes to different allocators)
    @param[in,out] abort Flag to abort copying. Must be initialized to
    false by top-level caller
    @param timer
    @param maxTime See ExtractSubtree() */
void SgUctTree::CopySubtree(SgUctTree& target, SgUctNode& targetNode, const SgUctNode& node, std::size_t minCount, std::size_t& currentAllocatorId, bool warnTruncate, bool& abort, SgTimer& timer, double maxTime) const
{
    SG_ASSERT(Contains(node));
    SG_ASSERT(target.Contains(targetNode));
    targetNode.CopyDataFrom(node);
    // Stop descending at leaves and at nodes below the count threshold
    if (! node.HasChildren() || node.MoveCount() < minCount)
        return;
    SgUctAllocator& targetAllocator = target.Allocator(currentAllocatorId);
    int nuChildren = node.NuChildren();
    // Check the various truncation conditions (only until abort is set,
    // so the warning is printed once per cause)
    if (! abort)
    {
        if (! targetAllocator.HasCapacity(nuChildren))
        {
            // This can happen even if target tree has same maximum number of
            // nodes, because allocators are used differently.
            if (warnTruncate)
                SgDebug() << "SgUctTree::CopySubtree: Truncated (allocator capacity)\n";
            abort = true;
        }
        if (timer.IsTimeOut(maxTime, 10000))
        {
            if (warnTruncate)
                SgDebug() << "SgUctTree::CopySubtree: Truncated (max time)\n";
            abort = true;
        }
        if (SgUserAbort())
        {
            if (warnTruncate)
                SgDebug() << "SgUctTree::CopySubtree: Truncated (aborted)\n";
            abort = true;
        }
    }
    if (abort)
    {
        // Don't copy the children and set the pos count to zero (should
        // reflect the sum of children move counts)
        targetNode.SetPosCount(0);
        return;
    }
    SgUctNode* firstTargetChild = targetAllocator.Finish();
    targetNode.SetFirstChild(firstTargetChild);
    targetNode.SetNuChildren(nuChildren);
    // Create target nodes first (must be contiguous in the target tree)
    targetAllocator.CreateN(nuChildren);
    // Recurse
    SgUctNode* targetChild = firstTargetChild;
    for (SgUctChildIterator it(*this, node); it; ++it, ++targetChild)
    {
        const SgUctNode& child = *it;
        ++currentAllocatorId; // Cycle to use allocators uniformly
        if (currentAllocatorId >= target.NuAllocators())
            currentAllocatorId = 0;
        CopySubtree(target, *targetChild, child, minCount,
                    currentAllocatorId, warnTruncate, abort, timer, maxTime);
    }
}