// Computes SHAP values for documents [start, end) of `pool` with respect to model output `dimension`.
//
// Returns one vector per document in the block (size end - start). Each inner vector has
// `pool.Docs.GetFactorsCount() + 1` entries: one per flat feature, plus a final slot that
// accumulates CalcMeanValueForTree(...) over all trees.
//
// Throws (via CB_ENSURE) if the model uses complex CTR features.
static TVector<TVector<double>> CalcShapValuesForDocumentBlock(
    const TFullModel& model,
    const TPool& pool,
    size_t start,
    size_t end,
    NPar::TLocalExecutor& localExecutor,
    int dimension
) {
    CB_ENSURE(!HasComplexCtrs(model.ObliviousTrees), "Model uses complex Ctr features. This is not allowed for SHAP values calculation");

    const TObliviousTrees& forest = model.ObliviousTrees;
    const size_t documentCount = end - start;

    TVector<TVector<ui8>> binarizedFeaturesByDocument;
    {
        TVector<ui8> allBinarizedFeatures = BinarizeFeatures(model, pool, start, end);
        binarizedFeaturesByDocument = TransposeBinarizedFeatures(allBinarizedFeatures, documentCount);
        // allBinarizedFeatures is destroyed at the end of this scope, which actually frees its
        // buffer (a plain clear() would keep the capacity allocated).
    }

    const int flatFeatureCount = pool.Docs.GetFactorsCount();
    TVector<int> binFeaturesMapping = MapFeatures(forest);

    // Subtree sizes and mean values depend only on the tree, not on the document, so compute
    // them once per tree instead of once per (document, tree) pair.
    const size_t treeCount = forest.GetTreeCount();
    TVector<TVector<TVector<size_t>>> subtreeSizesByTree(treeCount);
    TVector<double> meanValueByTree(treeCount);
    for (size_t treeIdx = 0; treeIdx < treeCount; ++treeIdx) {
        subtreeSizesByTree[treeIdx] = CalcSubtreeSizesForTree(forest, treeIdx);
        meanValueByTree[treeIdx] = CalcMeanValueForTree(forest, subtreeSizesByTree[treeIdx], treeIdx, dimension);
    }

    TVector<TVector<double>> shapValues(documentCount, TVector<double>(flatFeatureCount + 1, 0.0));

    NPar::TLocalExecutor::TExecRangeParams blockParams(0, static_cast<int>(documentCount));
    localExecutor.ExecRange([&] (int documentIdx) {
        // Each task writes only to its own document's row, so rows are not shared between tasks.
        TVector<double>& documentShapValues = shapValues[documentIdx];
        for (size_t treeIdx = 0; treeIdx < treeCount; ++treeIdx) {
            TVector<TFeaturePathElement> initialFeaturePath;
            CalcShapValuesRecursive(
                forest,
                binFeaturesMapping,
                binarizedFeaturesByDocument[documentIdx],
                treeIdx,
                /*depth*/ 0,
                subtreeSizesByTree[treeIdx],
                dimension,
                /*nodeIdx*/ 0,
                initialFeaturePath,
                /*zeroPathFraction*/ 1,
                /*onePathFraction*/ 1,
                /*feature*/ -1,
                &documentShapValues
            );
            // Same per-tree accumulation order as before: recursive contribution first, then the tree mean.
            documentShapValues[flatFeatureCount] += meanValueByTree[treeIdx];
        }
    }, blockParams, NPar::TLocalExecutor::WAIT_COMPLETE);

    return shapValues;
}
static inline void BinarizeFloatFeature(int featureIdx, const TDocumentStorage& docStorage, const TDocSelector& docSelector, const TVector<float>& borders, ENanMode nanMode, NPar::TLocalExecutor& localExecutor, int floatFeatureIdx, TAllFeatures* features, bool* seenNans) { size_t docCount = docSelector.GetDocCount(); const TVector<float>& src = docStorage.Factors[featureIdx]; TVector<ui8>& hist = features->FloatHistograms[floatFeatureIdx]; hist.resize(docCount); ui8* histData = hist.data(); const float* featureBorderData = borders.data(); const int featureBorderSize = borders.ysize(); localExecutor.ExecRange([&] (int i) { const auto& featureVal = src[docSelector(i)]; if (IsNan(featureVal)) { *seenNans = true; histData[i] = nanMode == ENanMode::Min ? 0 : featureBorderSize; } else { int j = 0; while (j < featureBorderSize && featureVal > featureBorderData[j]) { ++histData[i]; ++j; } // histData[i] = LowerBound(featureBorderData, featureBorderData + featureBorderSize, featureVal) - featureBorderData; } } , NPar::TLocalExecutor::TExecRangeParams(0, docCount).SetBlockSize(1000) , NPar::TLocalExecutor::WAIT_COMPLETE); }