Exemple #1
0
// Counts, for every tree of the model, how many documents of `pool` are
// mapped to each leaf index by BuildIndicesForBinTree.
//
// The result has one row per entry of model.ObliviousTrees.TreeSizes; row i
// holds LeafValues[i].size() / ApproxDimension counters, all starting at zero.
// A tree for which BuildIndicesForBinTree returns an empty vector keeps
// all-zero counters.
static TVector<TVector<ui64>> CollectLeavesStatistics(const TPool& pool, const TFullModel& model) {
    const auto& forest = model.ObliviousTrees;
    const size_t numTrees = forest.TreeSizes.size();

    TVector<TVector<ui64>> hitsPerTree(numTrees);

    auto binarized = BinarizeFeatures(model, pool);
    const auto numDocs = pool.Docs.GetDocCount();

    for (size_t tree = 0; tree < numTrees; ++tree) {
        // Size the row first so that it exists even when the tree is skipped below.
        TVector<ui64>& hits = hitsPerTree[tree];
        hits.resize(forest.LeafValues[tree].size() / forest.ApproxDimension);

        TVector<TIndexType> leafOfDoc = BuildIndicesForBinTree(model, binarized, tree);
        if (!leafOfDoc.empty()) {
            for (size_t docIdx = 0; docIdx < numDocs; ++docIdx) {
                ++hits[leafOfDoc[docIdx]];
            }
        }
    }

    return hitsPerTree;
}
// Computes SHAP values for the documents [start, end) of `pool`.
//
// Returns (end - start) rows, one per document in the block. Each row has
// pool.Docs.GetFactorsCount() + 1 entries: the first GetFactorsCount() slots
// are filled by CalcShapValuesRecursive, and the last slot accumulates
// CalcMeanValueForTree over all trees (presumably the model's expected value
// -- confirm against the callers).
//
// `dimension` is forwarded unchanged to the per-tree helpers.
// Documents are processed in parallel through `localExecutor`; each task
// writes only to its own row of the result.
//
// Fails via CB_ENSURE when the model uses complex Ctr features or when
// start > end.
static TVector<TVector<double>> CalcShapValuesForDocumentBlock(const TFullModel& model,
                                                               const TPool& pool,
                                                               size_t start,
                                                               size_t end,
                                                               NPar::TLocalExecutor& localExecutor,
                                                               int dimension) {
    CB_ENSURE(!HasComplexCtrs(model.ObliviousTrees), "Model uses complex Ctr features. This is not allowed for SHAP values calculation");
    // end - start is unsigned: an inverted range would silently wrap to a huge count.
    CB_ENSURE(start <= end, "Invalid document range for SHAP values calculation");

    const TObliviousTrees& forest = model.ObliviousTrees;
    const size_t documentCount = end - start;

    TVector<ui8> allBinarizedFeatures = BinarizeFeatures(model, pool, start, end);
    TVector<TVector<ui8>> binarizedFeaturesByDocument = TransposeBinarizedFeatures(allBinarizedFeatures, documentCount);
    // The flat buffer is no longer needed once the per-document layout exists.
    allBinarizedFeatures.clear();

    const int flatFeatureCount = pool.Docs.GetFactorsCount();
    TVector<int> binFeaturesMapping = MapFeatures(forest);
    TVector<TVector<double>> shapValues(documentCount, TVector<double>(flatFeatureCount + 1, 0.0));

    // Subtree sizes and the mean value depend only on the tree (and `dimension`),
    // not on the document, so compute them once here instead of once per
    // (document, tree) pair inside the parallel loop.
    // NOTE(review): assumes CalcShapValuesRecursive only reads subtreeSizes,
    // since the same vectors are now shared by all worker tasks -- confirm.
    const size_t treeCount = forest.GetTreeCount();
    TVector<TVector<TVector<size_t>>> subtreeSizesByTree(treeCount);
    TVector<double> meanValueByTree(treeCount, 0.0);
    for (size_t treeIdx = 0; treeIdx < treeCount; ++treeIdx) {
        subtreeSizesByTree[treeIdx] = CalcSubtreeSizesForTree(forest, treeIdx);
        meanValueByTree[treeIdx] = CalcMeanValueForTree(forest, subtreeSizesByTree[treeIdx], treeIdx, dimension);
    }

    NPar::TLocalExecutor::TExecRangeParams blockParams(0, documentCount);
    localExecutor.ExecRange([&] (int documentIdx) {
        // Per-document order of operations (recursive pass, then mean value,
        // tree by tree) is kept as before so the accumulated sums are unchanged.
        for (size_t treeIdx = 0; treeIdx < treeCount; ++treeIdx) {
            TVector<TFeaturePathElement> initialFeaturePath;
            CalcShapValuesRecursive(forest, binFeaturesMapping, binarizedFeaturesByDocument[documentIdx], treeIdx, /*depth*/ 0, subtreeSizesByTree[treeIdx], dimension,
                                    /*nodeIdx*/ 0, initialFeaturePath, /*zeroPathFraction*/ 1, /*onePathFraction*/ 1, /*feature*/ -1,
                                    &shapValues[documentIdx]);

            shapValues[documentIdx][flatFeatureCount] += meanValueByTree[treeIdx];
        }
    }, blockParams, NPar::TLocalExecutor::WAIT_COMPLETE);

    return shapValues;
}