Example no. 1
/** Optimized version of GetValueEstimate() for the case that RAVE is used
    and no other estimators are.
    Previously there were more estimators than the move value and the RAVE
    value, and in the future there may be again. GetValueEstimate() is easier
    to extend; this function is optimized for the special case. */
SgUctValue SgUctSearch::GetValueEstimateRave(const SgUctNode& child) const
{
    SG_ASSERT(m_rave);
    SgUctValue value;
    SgUctStatistics uctStats;
    if (child.HasMean())
    {
        uctStats.Initialize(child.Mean(), child.MoveCount());
    }
    SgUctStatistics raveStats;
    if (child.HasRaveValue())
    {
        raveStats.Initialize(child.RaveValue(), child.RaveCount());
    }
    int virtualLossCount = child.VirtualLossCount();
    if (virtualLossCount > 0)
    {
        uctStats.Add(InverseEstimate(0), SgUctValue(virtualLossCount));
        raveStats.Add(0, SgUctValue(virtualLossCount));
    }
    bool hasRave = raveStats.IsDefined();
    
    if (uctStats.IsDefined())
    {
        SgUctValue moveValue = InverseEstimate((SgUctValue)uctStats.Mean());
        if (hasRave)
        {
            SgUctValue moveCount = uctStats.Count();
            SgUctValue raveCount = raveStats.Count();
            SgUctValue weight =
                raveCount
                / (moveCount
                   * (m_raveWeightParam1 + m_raveWeightParam2 * raveCount)
                   + raveCount);
            value = weight * raveStats.Mean() + (1.f - weight) * moveValue;
        }
        else
        {
            // This can happen only in lock-free multi-threading. Normally,
            // each move played in a position should also cause a RAVE value
            // to be added. But in lock-free multi-threading it can happen
            // that the move value was already updated but the RAVE value not yet.
            SG_ASSERT(m_numberThreads > 1 && m_lockFree);
            value = moveValue;
        }
    }
    else if (hasRave)
        value = raveStats.Mean();
    else
        value = m_firstPlayUrgency;
    SG_ASSERT(m_numberThreads > 1
              || fabs(value - GetValueEstimate(m_rave, child)) < 1e-3/*epsilon*/);
    return value;
}
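
The weight formula in GetValueEstimateRave() interpolates between the RAVE mean and the inverted move value. Below is a minimal standalone sketch of that blend; the parameter values and counts are hypothetical placeholders for m_raveWeightParam1, m_raveWeightParam2 and the child statistics, chosen only for illustration.

#include <iostream>

int main()
{
    double raveWeightParam1 = 1.0;   // assumed value, for illustration only
    double raveWeightParam2 = 0.001; // assumed value, for illustration only

    double moveCount = 50;   // uctStats.Count()
    double raveCount = 200;  // raveStats.Count()
    double moveValue = 0.55; // InverseEstimate(uctStats.Mean())
    double raveMean  = 0.60; // raveStats.Mean()

    // Same weight formula as in GetValueEstimateRave(): the RAVE estimate
    // dominates for small move counts and fades as moveCount grows.
    double weight = raveCount
        / (moveCount * (raveWeightParam1 + raveWeightParam2 * raveCount)
           + raveCount);
    double value = weight * raveMean + (1.0 - weight) * moveValue;

    std::cout << "weight=" << weight << " value=" << value << '\n';
    return 0;
}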
Example no. 2
SgUctValue SgUctSearch::GetValueEstimate(bool useRave, const SgUctNode& child) const
{
    SgUctValue value = 0;
    SgUctValue weightSum = 0;
    bool hasValue = false;

    SgUctStatistics uctStats;
    if (child.HasMean())
    {
        uctStats.Initialize(child.Mean(), child.MoveCount());
    }
    int virtualLossCount = child.VirtualLossCount();
    if (virtualLossCount > 0)
    {
        uctStats.Add(InverseEstimate(0), SgUctValue(virtualLossCount));
    }

    if (uctStats.IsDefined())
    {
        SgUctValue weight = static_cast<SgUctValue>(uctStats.Count());
        value += weight * InverseEstimate((SgUctValue)uctStats.Mean());
        weightSum += weight;
        hasValue = true;
    }

    if (useRave)
    {
        SgUctStatistics raveStats;
        if (child.HasRaveValue())
        {
            raveStats.Initialize(child.RaveValue(), child.RaveCount());
        }
        if (virtualLossCount > 0)
        {
            raveStats.Add(0, SgUctValue(virtualLossCount));
        }
        if (raveStats.IsDefined())
        {
            SgUctValue raveCount = raveStats.Count();
            SgUctValue weight =
                raveCount
                / (  m_raveWeightParam1
                     + m_raveWeightParam2 * raveCount
                     );
            value += weight * raveStats.Mean();
            weightSum += weight;
            hasValue = true;
        }
    }
    if (hasValue)
        return value / weightSum;
    else
        return m_firstPlayUrgency;
}
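
In GetValueEstimate() the two estimators are combined additively: each contributes its mean times its weight to value and its weight to weightSum, so the result is a weighted average. A minimal sketch of the same arithmetic, again with hypothetical parameter and count values standing in for the members and child statistics:

#include <iostream>

int main()
{
    double p1 = 1.0, p2 = 0.001;  // stand-ins for m_raveWeightParam1/2

    double moveCount = 50,  moveValue = 0.55;
    double raveCount = 200, raveMean  = 0.60;

    double value = 0, weightSum = 0;

    double uctWeight = moveCount;           // plain move value weighted by its count
    value += uctWeight * moveValue;
    weightSum += uctWeight;

    double raveWeight = raveCount / (p1 + p2 * raveCount);
    value += raveWeight * raveMean;
    weightSum += raveWeight;

    std::cout << value / weightSum << '\n'; // weighted average of both estimators
    return 0;
}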
Example no. 3
SgUctValue SgUctSearch::GetBound(bool useRave, SgUctValue logPosCount,
                                 const SgUctNode& child) const
{
    SgUctValue value;
    if (useRave)
        value = GetValueEstimateRave(child);
    else
        value = GetValueEstimate(false, child);
    if (m_biasTermConstant == 0.0)
        return value;
    else
    {
        SgUctValue moveCount = static_cast<SgUctValue>(child.MoveCount());
        SgUctValue bound =
            value + m_biasTermConstant * sqrt(logPosCount / (moveCount + 1));
        return bound;
    }
}
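
GetBound() adds a UCT-style exploration term on top of the value estimate. A minimal sketch of that computation; the bias constant and counts are hypothetical, and logPosCount is assumed to be the logarithm of the parent position's visit count, as suggested by the parameter name.

#include <cmath>
#include <iostream>

int main()
{
    double biasTermConstant = 0.7;   // assumed value, for illustration only
    double posCount = 1000;          // visits of the parent position
    double moveCount = 50;           // visits of the child move
    double value = 0.58;             // result of GetValueEstimate[Rave]()

    double logPosCount = std::log(posCount);
    double bound = value
        + biasTermConstant * std::sqrt(logPosCount / (moveCount + 1));

    std::cout << bound << '\n';
    return 0;
}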
Example no. 4
/** Recursive function used by SgUctTree::ExtractSubtree and
    SgUctTree::CopyPruneLowCount.
    @param target The target tree.
    @param targetNode The target node; it has already been created, but its
    content has not yet been copied.
    @param node The node in the source tree to be copied.
    @param minCount The minimum count (SgUctNode::MoveCount()) of a non-root
    node in the source tree to copy
    @param currentAllocatorId The current node allocator. Will be incremented
    in each call to CopySubtree to use node allocators of target tree evenly.
    @param warnTruncate Print a warning to SgDebug() if the tree was
    truncated (e.g. due to reassigning nodes to different allocators)
    @param[in,out] abort Flag to abort copying. Must be initialized to false
    by top-level caller
    @param timer Timer used to check against maxTime
    @param maxTime See ExtractSubtree()
*/
void SgUctTree::CopySubtree(SgUctTree& target, SgUctNode& targetNode,
                            const SgUctNode& node, std::size_t minCount,
                            std::size_t& currentAllocatorId,
                            bool warnTruncate, bool& abort, SgTimer& timer,
                            double maxTime) const
{
    SG_ASSERT(Contains(node));
    SG_ASSERT(target.Contains(targetNode));
    targetNode.CopyDataFrom(node);

    if (! node.HasChildren() || node.MoveCount() < minCount)
        return;

    SgUctAllocator& targetAllocator = target.Allocator(currentAllocatorId);
    int nuChildren = node.NuChildren();
    if (! abort)
    {
        if (! targetAllocator.HasCapacity(nuChildren))
        {
            // This can happen even if the target tree has the same maximum
            // number of nodes, because the allocators are used differently.
            if (warnTruncate)
                SgDebug() <<
                "SgUctTree::CopySubtree: Truncated (allocator capacity)\n";
            abort = true;
        }
        if (timer.IsTimeOut(maxTime, 10000))
        {
            if (warnTruncate)
                SgDebug() << "SgUctTree::CopySubtree: Truncated (max time)\n";
            abort = true;
        }
        if (SgUserAbort())
        {
            if (warnTruncate)
                SgDebug() << "SgUctTree::CopySubtree: Truncated (aborted)\n";
            abort = true;
        }
    }
    if (abort)
    {
        // Don't copy the children; set the pos count to zero, since it
        // should reflect the sum of the children's move counts.
        targetNode.SetPosCount(0);
        return;
    }

    SgUctNode* firstTargetChild = targetAllocator.Finish();
    targetNode.SetFirstChild(firstTargetChild);
    targetNode.SetNuChildren(nuChildren);

    // Create target nodes first (must be contiguous in the target tree)
    targetAllocator.CreateN(nuChildren);

    // Recurse
    SgUctNode* targetChild = firstTargetChild;
    for (SgUctChildIterator it(*this, node); it; ++it, ++targetChild)
    {
        const SgUctNode& child = *it;
        ++currentAllocatorId; // Cycle to use allocators uniformly
        if (currentAllocatorId >= target.NuAllocators())
            currentAllocatorId = 0;
        CopySubtree(target, *targetChild, child, minCount, currentAllocatorId,
                    warnTruncate, abort, timer, maxTime);
    }
}
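
The traversal pattern above (copy the node's data, stop at nodes below minCount, create all target children before recursing into them) can be shown on a stripped-down, hypothetical tree type. This is not the Fuego API: the Node type below is invented for illustration, and allocator cycling, time-out and abort handling are omitted.

#include <cstddef>
#include <vector>

// Hypothetical, simplified node type; only the prune-and-copy recursion of
// SgUctTree::CopySubtree is reproduced.
struct Node
{
    std::size_t count = 0;
    std::vector<Node> children;
};

// Copy 'source' into 'target', dropping the children of any node whose
// count is below minCount, analogous to the minCount cut-off above.
void CopyPruned(Node& target, const Node& source, std::size_t minCount)
{
    target.count = source.count;
    if (source.children.empty() || source.count < minCount)
        return;
    // Create all target children first, then recurse into each of them,
    // mirroring the allocate-then-recurse order of CopySubtree.
    target.children.resize(source.children.size());
    for (std::size_t i = 0; i < source.children.size(); ++i)
        CopyPruned(target.children[i], source.children[i], minCount);
}

int main()
{
    Node source;
    source.count = 10;
    source.children.resize(2);
    source.children[0].count = 8;
    source.children[1].count = 1;
    source.children[1].children.resize(3); // below minCount, will be pruned

    Node target;
    CopyPruned(target, source, 5); // the low-count child is copied, but
                                   // its three children are not
    return 0;
}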