コード例 #1
0
// Computes one SHAP value vector per document for the pool documents in the range [start, end).
//
// Each returned row has pool.Docs.GetFactorsCount() + 1 entries, all starting at 0.0:
//   * the row is handed to CalcShapValuesRecursive once per tree, which fills in the contributions;
//   * the last entry additionally accumulates CalcMeanValueForTree(...) over all trees.
//
// Documents are processed through localExecutor.ExecRange over [0, end - start) with WAIT_COMPLETE;
// every invocation writes only to the row of its own document.
//
// The CB_ENSURE at the top rejects models for which HasComplexCtrs(model.ObliviousTrees) is true.
static TVector<TVector<double>> CalcShapValuesForDocumentBlock(const TFullModel& model,
                                                               const TPool& pool,
                                                               size_t start,
                                                               size_t end,
                                                               NPar::TLocalExecutor& localExecutor,
                                                               int dimension) {
    CB_ENSURE(!HasComplexCtrs(model.ObliviousTrees), "Model uses complex Ctr features. This is not allowed for SHAP values calculation");

    const TObliviousTrees& trees = model.ObliviousTrees;
    const size_t docsInBlock = end - start;

    // Binarize the whole block at once, then regroup the bins so that each document owns its own vector.
    TVector<ui8> flatBins = BinarizeFeatures(model, pool, start, end);
    TVector<TVector<ui8>> binsByDoc = TransposeBinarizedFeatures(flatBins, docsInBlock);
    flatBins.clear();

    const int flatFeatureCount = pool.Docs.GetFactorsCount();
    TVector<int> binFeaturesMapping = MapFeatures(trees);

    // One extra trailing slot per document holds the sum of per-tree mean values.
    TVector<TVector<double>> result(docsInBlock, TVector<double>(flatFeatureCount + 1, 0.0));

    const size_t treeCount = trees.GetTreeCount();

    const auto processDocument = [&] (int docIdx) {
        TVector<double>& docShapValues = result[docIdx];
        const TVector<ui8>& docBins = binsByDoc[docIdx];

        for (size_t tree = 0; tree < treeCount; ++tree) {
            TVector<TVector<size_t>> subtreeSizes = CalcSubtreeSizesForTree(trees, tree);
            TVector<TFeaturePathElement> emptyPath;
            CalcShapValuesRecursive(trees, binFeaturesMapping, docBins, tree, /*depth*/ 0, subtreeSizes, dimension,
                                    /*nodeIdx*/ 0, emptyPath, /*zeroPathFraction*/ 1, /*onePathFraction*/ 1, /*feature*/ -1,
                                    &docShapValues);

            docShapValues[flatFeatureCount] += CalcMeanValueForTree(trees, subtreeSizes, tree, dimension);
        }
    };

    NPar::TLocalExecutor::TExecRangeParams rangeParams(0, docsInBlock);
    localExecutor.ExecRange(processDocument, rangeParams, NPar::TLocalExecutor::WAIT_COMPLETE);

    return result;
}
コード例 #2
0
static inline void BinarizeFloatFeature(int featureIdx,
                                        const TDocumentStorage& docStorage,
                                        const TDocSelector& docSelector,
                                        const TVector<float>& borders,
                                        ENanMode nanMode,
                                        NPar::TLocalExecutor& localExecutor,
                                        int floatFeatureIdx,
                                        TAllFeatures* features,
                                        bool* seenNans) {
    size_t docCount = docSelector.GetDocCount();
    const TVector<float>& src = docStorage.Factors[featureIdx];
    TVector<ui8>& hist = features->FloatHistograms[floatFeatureIdx];

    hist.resize(docCount);

    ui8* histData = hist.data();
    const float* featureBorderData = borders.data();
    const int featureBorderSize = borders.ysize();

    localExecutor.ExecRange([&] (int i) {
        const auto& featureVal = src[docSelector(i)];
        if (IsNan(featureVal)) {
            *seenNans = true;
            histData[i] = nanMode == ENanMode::Min ? 0 : featureBorderSize;
        } else {
            int j = 0;
            while (j < featureBorderSize && featureVal > featureBorderData[j]) {
                ++histData[i];
                ++j;
            }
        //    histData[i] = LowerBound(featureBorderData, featureBorderData + featureBorderSize, featureVal) - featureBorderData;
        }
    }
    , NPar::TLocalExecutor::TExecRangeParams(0, docCount).SetBlockSize(1000)
    , NPar::TLocalExecutor::WAIT_COMPLETE);
}