Esempio n. 1
0
/** Specialized variant of GetValueEstimate() for the case that RAVE is
    enabled and the move value and the RAVE value are the only estimators.
    The result is a blend of both values:
    @verbatim
    weight = raveCount
             / (moveCount * (raveWeightParam1 + raveWeightParam2 * raveCount)
                + raveCount)
    value  = weight * raveMean + (1 - weight) * moveValue
    @endverbatim
    If only one of the two statistics is defined, that one is used alone;
    if neither is defined, the first-play urgency is returned.
    Pending virtual losses of the child are folded into both statistics
    before blending.
    GetValueEstimate() is the more general (and easier to extend) form; in
    single-threaded searches the result is asserted to agree with it.
    @param child The child node to evaluate.
    @return The combined value estimate for the child. */
SgUctValue SgUctSearch::GetValueEstimateRave(const SgUctNode& child) const
{
    SG_ASSERT(m_rave);

    // Snapshot the statistics of the child (move value first, then RAVE)
    SgUctStatistics moveStats;
    if (child.HasMean())
        moveStats.Initialize(child.Mean(), child.MoveCount());
    SgUctStatistics raveStats;
    if (child.HasRaveValue())
        raveStats.Initialize(child.RaveValue(), child.RaveCount());

    // Each pending virtual loss counts as one extra sample in both
    // statistics: InverseEstimate(0) for the move value, 0 for RAVE
    const int numVirtualLosses = child.VirtualLossCount();
    if (numVirtualLosses > 0)
    {
        moveStats.Add(InverseEstimate(0), SgUctValue(numVirtualLosses));
        raveStats.Add(0, SgUctValue(numVirtualLosses));
    }

    const bool raveDefined = raveStats.IsDefined();
    const bool moveDefined = moveStats.IsDefined();

    SgUctValue estimate;
    if (! moveDefined)
    {
        // No move value: fall back to RAVE alone or to first-play urgency
        if (raveDefined)
            estimate = raveStats.Mean();
        else
            estimate = m_firstPlayUrgency;
    }
    else
    {
        const SgUctValue moveValue =
            InverseEstimate(static_cast<SgUctValue>(moveStats.Mean()));
        if (! raveDefined)
        {
            // Only possible in lock-free multi-threading. Normally, every
            // move played in a position also adds a RAVE value, but with
            // lock-free updates the move value may already have been
            // written while the RAVE value has not.
            SG_ASSERT(m_numberThreads > 1 && m_lockFree);
            estimate = moveValue;
        }
        else
        {
            const SgUctValue nuMoves = moveStats.Count();
            const SgUctValue nuRave = raveStats.Count();
            const SgUctValue raveWeight =
                nuRave
                / (nuMoves
                   * (m_raveWeightParam1 + m_raveWeightParam2 * nuRave)
                   + nuRave);
            estimate =
                raveWeight * raveStats.Mean() + (1.f - raveWeight) * moveValue;
        }
    }

    // Cross-check against the generic implementation. Only enforced when
    // running with a single thread.
    SG_ASSERT(m_numberThreads > 1
              || fabs(estimate - GetValueEstimate(m_rave, child))
                 < 1e-3/*epsilon*/);
    return estimate;
}
Esempio n. 2
0
/** Compute the value estimate of a child as a weighted average of the
    available estimators.
    The move value (passed through InverseEstimate()) contributes with a
    weight equal to its sample count. If useRave is true, the RAVE value
    contributes with weight
    raveCount / (raveWeightParam1 + raveWeightParam2 * raveCount).
    Pending virtual losses of the child are folded into the statistics
    before they are combined.
    @param useRave Whether the RAVE value takes part in the estimate.
    @param child The child node to evaluate.
    @return The weighted average of the defined estimators, or the
    first-play urgency if no estimator is defined. */
SgUctValue SgUctSearch::GetValueEstimate(bool useRave, const SgUctNode& child) const
{
    SgUctValue weightedSum = 0;
    SgUctValue totalWeight = 0;
    bool anyEstimator = false;

    // Move value statistics, including pending virtual losses
    SgUctStatistics moveStats;
    if (child.HasMean())
        moveStats.Initialize(child.Mean(), child.MoveCount());
    const int numVirtualLosses = child.VirtualLossCount();
    if (numVirtualLosses > 0)
        moveStats.Add(InverseEstimate(0), SgUctValue(numVirtualLosses));

    if (moveStats.IsDefined())
    {
        const SgUctValue moveWeight =
            static_cast<SgUctValue>(moveStats.Count());
        weightedSum +=
            moveWeight
            * InverseEstimate(static_cast<SgUctValue>(moveStats.Mean()));
        totalWeight += moveWeight;
        anyEstimator = true;
    }

    if (useRave)
    {
        // RAVE statistics; each virtual loss is added as a sample of 0
        SgUctStatistics raveStats;
        if (child.HasRaveValue())
            raveStats.Initialize(child.RaveValue(), child.RaveCount());
        if (numVirtualLosses > 0)
            raveStats.Add(0, SgUctValue(numVirtualLosses));

        if (raveStats.IsDefined())
        {
            const SgUctValue nuRave = raveStats.Count();
            const SgUctValue raveWeight =
                nuRave
                / (  m_raveWeightParam1
                     + m_raveWeightParam2 * nuRave
                     );
            weightedSum += raveWeight * raveStats.Mean();
            totalWeight += raveWeight;
            anyEstimator = true;
        }
    }

    if (! anyEstimator)
        return m_firstPlayUrgency;
    return weightedSum / totalWeight;
}