Example 1
SgUctValue SgUctSearch::GetBound(bool useRave, const SgUctNode& node,
                                 const SgUctNode& child) const
{
    SgUctValue posCount = node.PosCount();
    int virtualLossCount = node.VirtualLossCount();
    if (virtualLossCount > 0)
    {
        posCount += SgUctValue(virtualLossCount);
    }
    return GetBound(useRave, Log(posCount), child);
}
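Example 2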
void MoHexPlayer::CopyKnowledgeData(const SgUctTree& tree,
                                    const SgUctNode& node,
                                    HexColor color, MoveSequence& sequence,
                                    const MoHexSharedData& oldData,
                                    MoHexSharedData& newData) const
{
    // This check would fail at the root when we are reusing the entire
    // tree, so only do it when not at the root.
    if (sequence != oldData.gameSequence)
    {
        SgHashCode hash = SequenceHash::Hash(sequence);
        StoneBoard stones;
        if (!oldData.stones.Get(hash, stones))
            return;
        newData.stones.Add(hash, stones);
    }
    if (!node.HasChildren())
        return;
    for (SgUctChildIterator it(tree, node); it; ++it)
    {
        sequence.push_back(Move(color, static_cast<HexPoint>((*it).Move())));
        CopyKnowledgeData(tree, *it, !color, sequence, oldData, newData);
        sequence.pop_back();
    }
}
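The function above walks the reused subtree recursively and, for each position reachable from the new root, looks up the knowledge stored under the hash of its move sequence and carries it over to the new shared data. Below is a minimal, self-contained sketch of the same traversal pattern under simplifying assumptions: hypothetical stand-in types (a plain Node tree and a std::map keyed directly by the move sequence instead of by SequenceHash::Hash()), with the root special case omitted.

// Sketch only: Node, Sequence and KnowledgeStore are hypothetical stand-ins
// for the SgUctTree / MoHexSharedData types used above.
#include <cstddef>
#include <map>
#include <vector>

struct Node { int move; std::vector<Node> children; };
typedef std::vector<int> Sequence;
typedef std::map<Sequence, int> KnowledgeStore; // int stands in for StoneBoard

void CopyKnowledge(const Node& node, Sequence& sequence,
                   const KnowledgeStore& oldData, KnowledgeStore& newData)
{
    KnowledgeStore::const_iterator pos = oldData.find(sequence);
    if (pos == oldData.end())
        return;                 // nothing stored for this position
    newData.insert(*pos);
    // Recurse into the subtree, extending the sequence per child move.
    for (std::size_t i = 0; i < node.children.size(); ++i)
    {
        sequence.push_back(node.children[i].move);
        CopyKnowledge(node.children[i], sequence, oldData, newData);
        sequence.pop_back();
    }
}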
Example 3
void SgUctTree::SetChildren(std::size_t allocatorId, const SgUctNode& node,
                            const vector<SgMove>& moves)
{
    SG_ASSERT(Contains(node));
    SG_ASSERT(Allocator(allocatorId).HasCapacity(moves.size()));
    SG_ASSERT(node.HasChildren());

    SgUctAllocator& allocator = Allocator(allocatorId);
    const SgUctNode* firstChild = allocator.Finish();

    int nuChildren = 0;
    for (size_t i = 0; i < moves.size(); ++i)
    {
        bool found = false;
        for (SgUctChildIterator it(*this, node); it; ++it)
        {
            SgMove move = (*it).Move();
            if (move == moves[i])
            {
                found = true;
                SgUctNode* child = allocator.CreateOne(move);
                child->CopyDataFrom(*it);
                int childNuChildren = (*it).NuChildren();
                child->SetNuChildren(childNuChildren);
                if (childNuChildren > 0)
                    child->SetFirstChild((*it).FirstChild());
                ++nuChildren;
                break;
            }
        }
        if (! found)
        {
            allocator.CreateOne(moves[i]);
            ++nuChildren;
        }
    }
    SG_ASSERT((size_t)nuChildren == moves.size());

    SgUctNode& nonConstNode = const_cast<SgUctNode&>(node);
    // Write order dependency: SgUctSearch in lock-free mode assumes that
    // m_firstChild is valid if m_nuChildren is greater than zero
    SgSynchronizeThreadMemory();
    nonConstNode.SetFirstChild(firstChild);
    SgSynchronizeThreadMemory();
    nonConstNode.SetNuChildren(nuChildren);
}
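The two writes at the end of SetChildren() follow a publication pattern that recurs in the examples below: the first-child pointer is written before the child count, with a memory barrier in between, so that a lock-free reader that observes a non-zero count also observes a valid pointer. The sketch below expresses the same ordering guarantee with C++11 atomics standing in for plain fields plus SgSynchronizeThreadMemory(); it illustrates the idea and is not the actual Fuego implementation.

// Sketch only: NodeStub and ChildBlock are hypothetical stand-ins.
#include <atomic>

struct ChildBlock { int dummy; }; // stands in for a block of child nodes

struct NodeStub
{
    std::atomic<const ChildBlock*> firstChild;
    std::atomic<int> nuChildren;
};

// Writer: publish the pointer no later than the count (the release store
// pairs with the reader's acquire load).
void Publish(NodeStub& node, const ChildBlock* children, int count)
{
    node.firstChild.store(children, std::memory_order_relaxed);
    node.nuChildren.store(count, std::memory_order_release);
}

// Reader: if a non-zero count is observed, the pointer stored before it is
// guaranteed to be visible as well.
const ChildBlock* Children(const NodeStub& node, int& count)
{
    count = node.nuChildren.load(std::memory_order_acquire);
    return count > 0 ? node.firstChild.load(std::memory_order_relaxed) : 0;
}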
Example 4
void SgUctTree::SetMustplay(const SgUctNode& node,
                            const std::vector<SgUctMoveInfo>& moves,
                            bool deleteChildTrees)
{
    for (SgUctChildIterator it(*this, node); it; ++it)
    {
        SgUctNode* child = const_cast<SgUctNode*>(&(*it));
        bool found = false;
        for (size_t j = 0; j < moves.size(); ++j) 
        {
            if (child->Move() == moves[j].m_move)
            {
                found = true;
                if (moves[j].m_count > 0)
                    child->AddGameResults(moves[j].m_value, moves[j].m_count);
                if (moves[j].m_raveCount > 0)
                    child->AddRaveValue(moves[j].m_raveValue, 
                                        moves[j].m_raveCount);
                if (deleteChildTrees) 
                {
                    // Write order dependency: clear the child count before
                    // the first-child pointer, so lock-free readers never
                    // see a non-zero count with an invalid pointer
                    child->SetNuChildren(0);
                    SgSynchronizeThreadMemory();                    
                    child->SetFirstChild(0);
                }
                break;
            }
        }
        if (!found)
            child->SetProvenType(SG_PROVEN_WIN); // mark as loss
    }
    SgSynchronizeThreadMemory();
}
Example 5
void SgUctTree::ApplyFilter(std::size_t allocatorId, const SgUctNode& node,
                            const vector<SgMove>& rootFilter)
{
    SG_ASSERT(Contains(node));
    SG_ASSERT(Allocator(allocatorId).HasCapacity(node.NuChildren()));
    if (! node.HasChildren())
        return;

    SgUctAllocator& allocator = Allocator(allocatorId);
    const SgUctNode* firstChild = allocator.Finish();

    int nuChildren = 0;
    for (SgUctChildIterator it(*this, node); it; ++it)
    {
        SgMove move = (*it).Move();
        if (find(rootFilter.begin(), rootFilter.end(), move)
            == rootFilter.end())
        {
            SgUctNode* child = allocator.CreateOne(move);
            child->CopyDataFrom(*it);
            int childNuChildren = (*it).NuChildren();
            child->SetNuChildren(childNuChildren);
            if (childNuChildren > 0)
                child->SetFirstChild((*it).FirstChild());
            ++nuChildren;
        }
    }

    SgUctNode& nonConstNode = const_cast<SgUctNode&>(node);
    // Write order dependency: SgUctSearch in lock-free mode assumes that
    // m_firstChild is valid if m_nuChildren is greater than zero
    nonConstNode.SetFirstChild(firstChild);
    SgSynchronizeThreadMemory();
    nonConstNode.SetNuChildren(nuChildren);
}
Example 6
/** Optimized version of GetValueEstimate() if RAVE and no other
    estimators are used.
    Previously there were more estimators than the move value and the RAVE
    value, and in the future there may be again. GetValueEstimate() is easier
    to extend; this function is optimized for the special case. */
SgUctValue SgUctSearch::GetValueEstimateRave(const SgUctNode& child) const
{
    SG_ASSERT(m_rave);
    SgUctValue value;
    SgUctStatistics uctStats;
    if (child.HasMean())
    {
        uctStats.Initialize(child.Mean(), child.MoveCount());
    }
    SgUctStatistics raveStats;
    if (child.HasRaveValue())
    {
        raveStats.Initialize(child.RaveValue(), child.RaveCount());
    }
    int virtualLossCount = child.VirtualLossCount();
    if (virtualLossCount > 0)
    {
        uctStats.Add(InverseEstimate(0), SgUctValue(virtualLossCount));
        raveStats.Add(0, SgUctValue(virtualLossCount));
    }
    bool hasRave = raveStats.IsDefined();
    
    if (uctStats.IsDefined())
    {
        SgUctValue moveValue = InverseEstimate((SgUctValue)uctStats.Mean());
        if (hasRave)
        {
            SgUctValue moveCount = uctStats.Count();
            SgUctValue raveCount = raveStats.Count();
            SgUctValue weight =
                raveCount
                / (moveCount
                   * (m_raveWeightParam1 + m_raveWeightParam2 * raveCount)
                   + raveCount);
            value = weight * raveStats.Mean() + (1.f - weight) * moveValue;
        }
        else
        {
            // This can happen only in lock-free multi-threading. Normally,
            // each move played in a position should also cause a RAVE value
            // to be added. But in lock-free multi-threading it can happen
            // that the move value has already been updated while the RAVE
            // value has not.
            SG_ASSERT(m_numberThreads > 1 && m_lockFree);
            value = moveValue;
        }
    }
    else if (hasRave)
        value = raveStats.Mean();
    else
        value = m_firstPlayUrgency;
    SG_ASSERT(m_numberThreads > 1
              || fabs(value - GetValueEstimate(m_rave, child)) < 1e-3/*epsilon*/);
    return value;
}
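The blend computed in the RAVE branch above is weight * raveMean + (1 - weight) * moveMean with weight = raveCount / (moveCount * (c1 + c2 * raveCount) + raveCount): the weight is close to 1 while the move has few real playouts and decays toward 0 as the move count grows, so the RAVE estimate dominates early and the sampled move value takes over later. The standalone sketch below reproduces the formula; c1 and c2 stand in for m_raveWeightParam1 and m_raveWeightParam2, and the numeric values are arbitrary illustrations, not Fuego defaults.

// Sketch only: parameter and sample values below are illustrative.
#include <cstdio>

double BlendRave(double moveMean, double moveCount,
                 double raveMean, double raveCount,
                 double c1, double c2)
{
    double weight = raveCount
        / (moveCount * (c1 + c2 * raveCount) + raveCount);
    return weight * raveMean + (1.0 - weight) * moveMean;
}

int main()
{
    // Few real playouts, many RAVE samples: estimate leans on the RAVE mean.
    std::printf("%f\n", BlendRave(0.40, 2, 0.55, 50, 1.0, 0.01));
    // Many real playouts: the RAVE weight fades and the move value dominates.
    std::printf("%f\n", BlendRave(0.40, 500, 0.55, 50, 1.0, 0.01));
    return 0;
}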
Example 7
SgUctValue SgUctSearch::GetValueEstimate(bool useRave, const SgUctNode& child) const
{
    SgUctValue value = 0;
    SgUctValue weightSum = 0;
    bool hasValue = false;

    SgUctStatistics uctStats;
    if (child.HasMean())
    {
        uctStats.Initialize(child.Mean(), child.MoveCount());
    }
    int virtualLossCount = child.VirtualLossCount();
    if (virtualLossCount > 0)
    {
        uctStats.Add(InverseEstimate(0), SgUctValue(virtualLossCount));
    }

    if (uctStats.IsDefined())
    {
        SgUctValue weight = static_cast<SgUctValue>(uctStats.Count());
        value += weight * InverseEstimate((SgUctValue)uctStats.Mean());
        weightSum += weight;
        hasValue = true;
    }

    if (useRave)
    {
        SgUctStatistics raveStats;
        if (child.HasRaveValue())
        {
            raveStats.Initialize(child.RaveValue(), child.RaveCount());
        }
        if (virtualLossCount > 0)
        {
            raveStats.Add(0, SgUctValue(virtualLossCount));
        }
        if (raveStats.IsDefined())
        {
            SgUctValue raveCount = raveStats.Count();
            SgUctValue weight =
                raveCount
                / (  m_raveWeightParam1
                     + m_raveWeightParam2 * raveCount
                     );
            value += weight * raveStats.Mean();
            weightSum += weight;
            hasValue = true;
        }
    }
    if (hasValue)
        return value / weightSum;
    else
        return m_firstPlayUrgency;
}
Example 8
SgUctValue SgUctSearch::GetBound(bool useRave, SgUctValue logPosCount,
                                 const SgUctNode& child) const
{
    SgUctValue value;
    if (useRave)
        value = GetValueEstimateRave(child);
    else
        value = GetValueEstimate(false, child);
    if (m_biasTermConstant == 0.0)
        return value;
    else
    {
        SgUctValue moveCount = static_cast<SgUctValue>(child.MoveCount());
        SgUctValue bound =
            value + m_biasTermConstant * sqrt(logPosCount / (moveCount + 1));
        return bound;
    }
}
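When m_biasTermConstant is non-zero, the returned bound is the value estimate plus a UCB-style exploration bonus, value + c * sqrt(log(posCount) / (moveCount + 1)), where log(posCount) is the (virtual-loss adjusted) logarithm of the parent position count computed in Example 1. A small numeric sketch follows; the constant c below is an arbitrary illustrative value, not a Fuego default.

// Sketch only: demonstrates how the exploration bonus shrinks with the
// child's move count for a fixed parent position count.
#include <cmath>
#include <cstdio>

double Bound(double value, double logPosCount, double moveCount, double c)
{
    return value + c * std::sqrt(logPosCount / (moveCount + 1));
}

int main()
{
    double logPosCount = std::log(1000.0); // parent visited 1000 times
    // Rarely tried child: a large bonus keeps it attractive for exploration.
    std::printf("%f\n", Bound(0.45, logPosCount, 3, 0.7));
    // Heavily tried child: the bonus nearly vanishes and the value dominates.
    std::printf("%f\n", Bound(0.45, logPosCount, 400, 0.7));
    return 0;
}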
Example 9
void SgUctTree::MergeChildren(std::size_t allocatorId, const SgUctNode& node,
                              const std::vector<SgMoveInfo>& moves,
                              bool deleteChildTrees)
{
    SG_ASSERT(Contains(node));
    // Parameters are const-references, because only the tree is allowed
    // to modify nodes
    SgUctNode& nonConstNode = const_cast<SgUctNode&>(node);
    size_t nuNewChildren = moves.size();

    if (nuNewChildren == 0)
    {
        // Write order dependency: clear the child count before the
        // first-child pointer, so lock-free readers never see a non-zero
        // count with an invalid pointer
        nonConstNode.SetNuChildren(0);
        SgSynchronizeThreadMemory();
        nonConstNode.SetFirstChild(0);
        return;
    }

    SgUctAllocator& allocator = Allocator(allocatorId);
    SG_ASSERT(allocator.HasCapacity(nuNewChildren));

    const SgUctNode* newFirstChild = allocator.Finish();
    std::size_t parentCount = allocator.Create(moves);
    
    // Update new children with data in old children
    for (std::size_t i = 0; i < moves.size(); ++i) 
    {
        SgUctNode* newChild = const_cast<SgUctNode*>(&newFirstChild[i]);
        for (SgUctChildIterator it(*this, node); it; ++it)
        {
            const SgUctNode& oldChild = *it;
            if (oldChild.Move() == moves[i].m_move)
            {
                newChild->MergeResults(oldChild);
                newChild->SetKnowledgeCount(oldChild.KnowledgeCount());
                if (! deleteChildTrees)
                {
                    newChild->SetPosCount(oldChild.PosCount());
                    parentCount += oldChild.MoveCount();
                    if (oldChild.HasChildren())
                    {
                        newChild->SetFirstChild(oldChild.FirstChild());
                        newChild->SetNuChildren(oldChild.NuChildren());
                    }
                }
                break;
            }
        }
    }
    nonConstNode.SetPosCount(parentCount);

    // Write order dependency: We do not want an SgUctChildIterator to
    // run past the end of a node's children, which can happen if one
    // is created between the two statements below. We order the two
    // writes to node so as to avoid that.
    if (nonConstNode.NuChildren() < (int)nuNewChildren)
    {
        nonConstNode.SetFirstChild(newFirstChild);
        SgSynchronizeThreadMemory();
        nonConstNode.SetNuChildren(nuNewChildren);
    }
    else
    {
        nonConstNode.SetNuChildren(nuNewChildren);
        SgSynchronizeThreadMemory();
        nonConstNode.SetFirstChild(newFirstChild);
    }
}
Example 10
/** Recursive function used by SgUctTree::ExtractSubtree and
    SgUctTree::CopyPruneLowCount.
    @param target The target tree.
    @param targetNode The target node; it has already been created, but its
    content has not yet been copied.
    @param node The node in the source tree to be copied.
    @param minCount The minimum count (SgUctNode::MoveCount()) of a non-root
    node in the source tree to copy.
    @param currentAllocatorId The current node allocator. Will be incremented
    in each call to CopySubtree to use the node allocators of the target tree
    evenly.
    @param warnTruncate Print a warning to SgDebug() if the tree was
    truncated (e.g. due to reassigning nodes to different allocators).
    @param[in,out] abort Flag to abort copying. Must be initialized to false
    by the top-level caller.
    @param timer Timer used to check against maxTime.
    @param maxTime See ExtractSubtree()
*/
void SgUctTree::CopySubtree(SgUctTree& target, SgUctNode& targetNode,
                            const SgUctNode& node, std::size_t minCount,
                            std::size_t& currentAllocatorId,
                            bool warnTruncate, bool& abort, SgTimer& timer,
                            double maxTime) const
{
    SG_ASSERT(Contains(node));
    SG_ASSERT(target.Contains(targetNode));
    targetNode.CopyDataFrom(node);

    if (! node.HasChildren() || node.MoveCount() < minCount)
        return;

    SgUctAllocator& targetAllocator = target.Allocator(currentAllocatorId);
    int nuChildren = node.NuChildren();
    if (! abort)
    {
        if (! targetAllocator.HasCapacity(nuChildren))
        {
            // This can happen even if the target tree has the same maximum
            // number of nodes, because the allocators are used differently.
            if (warnTruncate)
                SgDebug() <<
                "SgUctTree::CopySubtree: Truncated (allocator capacity)\n";
            abort = true;
        }
        if (timer.IsTimeOut(maxTime, 10000))
        {
            if (warnTruncate)
                SgDebug() << "SgUctTree::CopySubtree: Truncated (max time)\n";
            abort = true;
        }
        if (SgUserAbort())
        {
            if (warnTruncate)
                SgDebug() << "SgUctTree::CopySubtree: Truncated (aborted)\n";
            abort = true;
        }
    }
    if (abort)
    {
        // Don't copy the children and set the pos count to zero (it should
        // reflect the sum of the children's move counts)
        targetNode.SetPosCount(0);
        return;
    }

    SgUctNode* firstTargetChild = targetAllocator.Finish();
    targetNode.SetFirstChild(firstTargetChild);
    targetNode.SetNuChildren(nuChildren);

    // Create target nodes first (must be contiguous in the target tree)
    targetAllocator.CreateN(nuChildren);

    // Recurse
    SgUctNode* targetChild = firstTargetChild;
    for (SgUctChildIterator it(*this, node); it; ++it, ++targetChild)
    {
        const SgUctNode& child = *it;
        ++currentAllocatorId; // Cycle to use allocators uniformly
        if (currentAllocatorId >= target.NuAllocators())
            currentAllocatorId = 0;
        CopySubtree(target, *targetChild, child, minCount, currentAllocatorId,
                    warnTruncate, abort, timer, maxTime);
    }
}
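Stripped of allocators, the time limit, and the abort handling, the copy above reduces to a recursion that duplicates a node's data and descends into its children only if the node has any and its move count reaches minCount. A simplified sketch with a hypothetical self-contained Node type:

// Sketch only: Node is a hypothetical stand-in for SgUctNode/SgUctTree.
#include <cstddef>
#include <vector>

struct Node
{
    int move;
    std::size_t moveCount;
    std::vector<Node> children;
};

// Copy 'source' into 'target'; subtrees below nodes whose move count is
// less than minCount are dropped, mirroring the cut made by CopySubtree().
void CopyPruned(const Node& source, Node& target, std::size_t minCount)
{
    target.move = source.move;
    target.moveCount = source.moveCount;
    target.children.clear();
    if (source.children.empty() || source.moveCount < minCount)
        return;
    target.children.resize(source.children.size());
    for (std::size_t i = 0; i < source.children.size(); ++i)
        CopyPruned(source.children[i], target.children[i], minCount);
}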