// Computes SHAP values for the documents with indices [start, end) of `pool`.
//
// Returns one row per document. Each row has pool.Docs.GetFactorsCount() + 1 entries:
// entries [0, flatFeatureCount) receive the contributions written by
// CalcShapValuesRecursive, and the last entry accumulates CalcMeanValueForTree over
// all trees of the model. `dimension` is forwarded to both helpers to select which
// approx dimension is explained.
// Documents are processed in parallel on `localExecutor`; each task writes only its
// own row of the result.
//
// Fails (via CB_ENSURE) if the model has complex CTR features or if start > end.
static TVector<TVector<double>> CalcShapValuesForDocumentBlock(
    const TFullModel& model,
    const TPool& pool,
    size_t start,
    size_t end,
    NPar::TLocalExecutor& localExecutor,
    int dimension
) {
    CB_ENSURE(!HasComplexCtrs(model.ObliviousTrees), "Model uses complex Ctr features. This is not allowed for SHAP values calculation");
    // Guards the unsigned subtraction below against wrap-around.
    CB_ENSURE(start <= end, "Invalid document range for SHAP values calculation: start > end");

    const TObliviousTrees& forest = model.ObliviousTrees;
    const size_t documentCount = end - start;

    TVector<TVector<ui8>> binarizedFeaturesByDocument;
    {
        // The flat buffer is only needed for the transpose. Scoping it releases its
        // memory here; clear() alone would keep the capacity allocated.
        TVector<ui8> allBinarizedFeatures = BinarizeFeatures(model, pool, start, end);
        binarizedFeaturesByDocument = TransposeBinarizedFeatures(allBinarizedFeatures, documentCount);
    }

    const int flatFeatureCount = pool.Docs.GetFactorsCount();
    TVector<int> binFeaturesMapping = MapFeatures(forest);

    // Subtree sizes and per-tree mean values depend only on the tree, not on the
    // document, so compute them once per tree instead of once per (document, tree).
    const size_t treeCount = forest.GetTreeCount();
    TVector<TVector<TVector<size_t>>> subtreeSizesByTree;
    TVector<double> meanValueByTree;
    subtreeSizesByTree.reserve(treeCount);
    meanValueByTree.reserve(treeCount);
    for (size_t treeIdx = 0; treeIdx < treeCount; ++treeIdx) {
        subtreeSizesByTree.push_back(CalcSubtreeSizesForTree(forest, treeIdx));
        meanValueByTree.push_back(CalcMeanValueForTree(forest, subtreeSizesByTree.back(), treeIdx, dimension));
    }

    TVector<TVector<double>> shapValues(documentCount, TVector<double>(flatFeatureCount + 1, 0.0));

    NPar::TLocalExecutor::TExecRangeParams blockParams(0, static_cast<int>(documentCount));
    localExecutor.ExecRange([&] (int documentIdx) {
        TVector<double>& documentShapValues = shapValues[documentIdx];
        for (size_t treeIdx = 0; treeIdx < treeCount; ++treeIdx) {
            const TVector<TVector<size_t>>& subtreeSizes = subtreeSizesByTree[treeIdx];
            TVector<TFeaturePathElement> initialFeaturePath;
            CalcShapValuesRecursive(
                forest,
                binFeaturesMapping,
                binarizedFeaturesByDocument[documentIdx],
                treeIdx,
                /*depth*/ 0,
                subtreeSizes,
                dimension,
                /*nodeIdx*/ 0,
                initialFeaturePath,
                /*zeroPathFraction*/ 1,
                /*onePathFraction*/ 1,
                /*feature*/ -1,
                &documentShapValues
            );
            // Added right after this tree's recursion, so the summation order into
            // the last entry is one addition per tree, in tree order.
            documentShapValues[flatFeatureCount] += meanValueByTree[treeIdx];
        }
    }, blockParams, NPar::TLocalExecutor::WAIT_COMPLETE);
    return shapValues;
}
static void AssignRandomWeights(int learnSampleCount, TLearnContext* ctx, TFold* fold) { TVector<float> sampleWeights; sampleWeights.yresize(learnSampleCount); const ui64 randSeed = ctx->Rand.GenRand(); NPar::TLocalExecutor::TExecRangeParams blockParams(0, learnSampleCount); blockParams.SetBlockSize(10000); ctx->LocalExecutor.ExecRange([&](int blockIdx) { TFastRng64 rand(randSeed + blockIdx); rand.Advance(10); // reduce correlation between RNGs in different threads const float baggingTemperature = ctx->Params.ObliviousTreeOptions->BootstrapConfig->GetBaggingTemperature(); float* sampleWeightsData = sampleWeights.data(); NPar::TLocalExecutor::BlockedLoopBody(blockParams, [&rand, sampleWeightsData, baggingTemperature](int i) { const float w = -FastLogf(rand.GenRandReal1() + 1e-100); sampleWeightsData[i] = powf(w, baggingTemperature); })(blockIdx); }, 0, blockParams.GetBlockCount(), NPar::TLocalExecutor::WAIT_COMPLETE); TFold& ff = *fold; ff.AssignPermuted(sampleWeights, &ff.SampleWeights); if (!ff.LearnWeights.empty()) { for (int i = 0; i < learnSampleCount; ++i) { ff.SampleWeights[i] *= ff.LearnWeights[i]; } } const int approxDimension = ff.GetApproxDimension(); for (TFold::TBodyTail& bt : ff.BodyTailArr) { for (int dim = 0; dim < approxDimension; ++dim) { double* weightedDerData = bt.WeightedDer[dim].data(); const double* derData = bt.Derivatives[dim].data(); const float* sampleWeightsData = ff.SampleWeights.data(); ctx->LocalExecutor.ExecRange([=](int z) { weightedDerData[z] = derData[z] * sampleWeightsData[z]; }, NPar::TLocalExecutor::TExecRangeParams(bt.BodyFinish, bt.TailFinish).SetBlockSize(4000) , NPar::TLocalExecutor::WAIT_COMPLETE); } } }