/// Maps a sub-range of the member array `Keys` onto the corresponding span of words.
/// @param keys  a (possibly empty) sub-range of `Keys`; bounds are checked only in debug builds.
/// @return      a view starting at the first word of the first key and ending at the end of the
///              last key's words. NOTE(review): this assumes the entries of `Words` are laid out
///              contiguously in memory — confirm against the owning class.
TConstArrayRef<T> Remap(const TConstArrayRef<ui64>& keys) {
    // Nothing to map: hand back an empty view.
    if (keys.empty()) {
        return TConstArrayRef<T>();
    }
    Y_ASSERT(keys.begin() >= Keys.begin() && keys.begin() <= Keys.end());
    Y_ASSERT(keys.end() >= Keys.begin() && keys.end() <= Keys.end());

    const auto firstWordIdx = keys.begin() - Keys.begin();
    const auto lastWordIdx = keys.end() - Keys.begin() - 1;
    return TConstArrayRef<T>(Words[firstWordIdx].begin(), Words[lastWordIdx].end());
}
/// Converts raw model approxes into the representation requested by `predictionType`.
/// @param predictionType  Probability (softmax / sigmoid), Class (predicted label), or RawFormulaVal (as-is).
/// @param approx          approx[dimension][document]; treated as multiclass when IsMulticlass(approx) holds,
///                        otherwise only approx[0] is used.
/// @param localExecutor   executor passed through to the multiclass helpers.
/// @return                one row per output dimension, one value per document.
TVector<TVector<double>> PrepareEval(const EPredictionType predictionType, const TVector<TVector<double>>& approx, NPar::TLocalExecutor* localExecutor) {
    TVector<TVector<double>> result;
    switch (predictionType) {
        case EPredictionType::Probability:
            if (IsMulticlass(approx)) {
                result = CalcSoftmax(approx, localExecutor);
            } else {
                result = {CalcSigmoid(approx[0])};
            }
            break;
        case EPredictionType::Class:
            result.resize(1);
            if (IsMulticlass(approx)) {
                TVector<int> predictions = {SelectBestClass(approx, localExecutor)};
                result[0].assign(predictions.begin(), predictions.end());
            } else {
                // One prediction per document: reserve by the document count (approx[0].size()),
                // not by approx.size(), which is the number of dimensions.
                result[0].reserve(approx[0].size());
                for (const double prediction : approx[0]) {
                    result[0].push_back(prediction > 0);
                }
            }
            break;
        case EPredictionType::RawFormulaVal:
            result = approx;
            break;
        default:
            Y_ASSERT(false);
    }
    return result;
}
static void AddCtrsToCandList(const TFold& fold, const TLearnContext& ctx, const TProjection& proj, TCandidateList* candList) { TCandidatesInfoList ctrSplits; const auto& ctrsHelper = ctx.CtrsHelper; const auto& ctrInfo = ctrsHelper.GetCtrInfo(proj); for (ui32 ctrIdx = 0; ctrIdx < ctrInfo.size(); ++ctrIdx) { const ui32 targetClassesCount = fold.TargetClassesCount[ctrInfo[ctrIdx].TargetClassifierIdx]; int targetBorderCount = GetTargetBorderCount(ctrInfo[ctrIdx], targetClassesCount); const auto& priors = ctrInfo[ctrIdx].Priors; int priorsCount = priors.size(); Y_ASSERT(targetBorderCount < 256); for (int border = 0; border < targetBorderCount; ++border) { for (int prior = 0; prior < priorsCount; ++prior) { TCandidateInfo split; split.SplitCandidate.Type = ESplitType::OnlineCtr; split.SplitCandidate.Ctr = TCtr(proj, ctrIdx, border, prior, ctrInfo[ctrIdx].BorderCount); ctrSplits.Candidates.emplace_back(split); } } } candList->push_back(ctrSplits); }
void CalcFinalCtrsImpl( const ECtrType ctrType, const ui64 ctrLeafCountLimit, const TVector<int>& permutedTargetClass, const TVector<float>& permutedTargets, const ui64 learnSampleCount, int targetClassesCount, TVector<ui64>* hashArr, TCtrValueTable* result ) { TDenseHash<ui64, ui32> tmpHash; auto leafCount = ReindexHash( learnSampleCount, ctrLeafCountLimit, hashArr, &tmpHash).first; auto hashIndexBuilder = result->GetIndexHashBuilder(leafCount); for (const auto& kv : tmpHash) { hashIndexBuilder.SetIndex(kv.Key(), kv.Value()); } TArrayRef<int> ctrIntArray; TArrayRef<TCtrMeanHistory> ctrMean; if (ctrType == ECtrType::BinarizedTargetMeanValue || ctrType == ECtrType::FloatTargetMeanValue) { ctrMean = result->AllocateBlobAndGetArrayRef<TCtrMeanHistory>(leafCount); } else if (ctrType == ECtrType::Counter || ctrType == ECtrType::FeatureFreq) { ctrIntArray = result->AllocateBlobAndGetArrayRef<int>(leafCount); result->CounterDenominator = 0; } else { result->TargetClassesCount = targetClassesCount; ctrIntArray = result->AllocateBlobAndGetArrayRef<int>(leafCount * targetClassesCount); } Y_ASSERT(hashArr->size() == learnSampleCount); int targetBorderCount = targetClassesCount - 1; auto hashArrPtr = hashArr->data(); for (ui32 z = 0; z < learnSampleCount; ++z) { const ui64 elemId = hashArrPtr[z]; if (ctrType == ECtrType::BinarizedTargetMeanValue) { TCtrMeanHistory& elem = ctrMean[elemId]; elem.Add(static_cast<float>(permutedTargetClass[z]) / targetBorderCount); } else if (ctrType == ECtrType::Counter || ctrType == ECtrType::FeatureFreq) { ++ctrIntArray[elemId]; } else if (ctrType == ECtrType::FloatTargetMeanValue) { TCtrMeanHistory& elem = ctrMean[elemId]; elem.Add(permutedTargets[z]); } else { TArrayRef<int> elem = MakeArrayRef(ctrIntArray.data() + targetClassesCount * elemId, targetClassesCount); ++elem[permutedTargetClass[z]]; } } if (ctrType == ECtrType::Counter) { result->CounterDenominator = *MaxElement(ctrIntArray.begin(), ctrIntArray.end()); } if (ctrType == 
ECtrType::FeatureFreq) { result->CounterDenominator = static_cast<int>(learnSampleCount); } }
/// Resolves all output file paths from `params` relative to the train directory.
/// When file writing is disabled, only checks (in debug builds) that no path was set and returns.
/// Otherwise it creates the train directory if it is missing, stores `namesPrefix`, and fills every
/// path member. The time-left, learn-error and meta paths use NamesPrefix; the JSON and profile logs
/// use no prefix.
/// Throws via CB_ENSURE when a mandatory file name is empty.
void TOutputFiles::InitializeFiles(const NCatboostOptions::TOutputFilesOptions& params, const TString& namesPrefix) {
    if (!params.AllowWriteFiles()) {
        Y_ASSERT(TimeLeftLogFile.empty());
        Y_ASSERT(LearnErrorLogFile.empty());
        Y_ASSERT(TestErrorLogFile.empty());
        Y_ASSERT(MetaFile.empty());
        Y_ASSERT(SnapshotFile.empty());
        return;
    }

    const auto& trainDir = params.GetTrainDir();
    TFsPath trainDirPath(trainDir);
    const bool mustCreateTrainDir = !trainDir.empty() && !trainDirPath.Exists();
    if (mustCreateTrainDir) {
        trainDirPath.MkDir();
    }

    NamesPrefix = namesPrefix;

    // Places `fileName` inside the train directory using the given name prefix.
    const auto inTrainDir = [&](const auto& fileName, const auto& prefix) {
        return TOutputFiles::AlignFilePath(trainDir, fileName, prefix);
    };

    CB_ENSURE(!params.GetTimeLeftLogFilename().empty(), "empty time_left filename");
    TimeLeftLogFile = inTrainDir(params.GetTimeLeftLogFilename(), NamesPrefix);

    CB_ENSURE(!params.GetLearnErrorFilename().empty(), "empty learn_error filename");
    LearnErrorLogFile = inTrainDir(params.GetLearnErrorFilename(), NamesPrefix);

    // Optional outputs: only set when configured.
    if (params.GetTestErrorFilename()) {
        TestErrorLogFile = inTrainDir(params.GetTestErrorFilename(), NamesPrefix);
    }
    if (params.SaveSnapshot()) {
        SnapshotFile = inTrainDir(params.GetSnapshotFilename(), NamesPrefix);
    }

    const TString& metaFileFilename = params.GetMetaFileFilename();
    CB_ENSURE(!metaFileFilename.empty(), "empty meta filename");
    MetaFile = inTrainDir(metaFileFilename, NamesPrefix);

    const TString& jsonLogFilename = params.GetJsonLogFilename();
    CB_ENSURE(!jsonLogFilename.empty(), "empty json_log filename");
    JsonLogFile = inTrainDir(jsonLogFilename, "");

    const TString& profileLogFilename = params.GetProfileLogFilename();
    CB_ENSURE(!profileLogFilename.empty(), "empty profile_log filename");
    ProfileLogFile = inTrainDir(profileLogFilename, "");
}
/// Human-readable description of this feature for logs.
/// OnlineCtr features list their projection plus prior numerator/denominator, target border and CTR type.
/// Float and one-hot features are described through the features layout.
TString TFeature::BuildDescription(const TFeaturesLayout& layout) const {
    TStringBuilder result;
    switch (Type) {
        case ESplitType::OnlineCtr:
            result << ::BuildDescription(layout, Ctr.Base.Projection)
                   << " prior_num=" << Ctr.PriorNum
                   << " prior_denom=" << Ctr.PriorDenom
                   << " targetborder=" << Ctr.TargetBorderIdx
                   << " type=" << Ctr.Base.CtrType;
            break;
        case ESplitType::FloatFeature:
            result << BuildFeatureDescription(layout, FeatureIdx, EFeatureType::Float);
            break;
        default:
            // The only remaining split type is a one-hot categorical feature.
            Y_ASSERT(Type == ESplitType::OneHotFeature);
            result << BuildFeatureDescription(layout, FeatureIdx, EFeatureType::Categorical);
            break;
    }
    return result;
}
/// Computes per-feature effect scores for `model` on `pool`, laid out in the pool's flat feature order.
/// @param threadCount  forwarded to CalcFeatureEffect.
/// @return             vector of size pool.Docs.GetFactorsCount(); features with no effect entry stay 0.
/// Throws via CB_ENSURE when the pool has fewer features than the model expects.
TVector<double> CalcRegularFeatureEffect(const TFullModel& model, const TPool& pool, int threadCount/*= 1*/) {
    const int featureCount = pool.Docs.GetFactorsCount();
    CB_ENSURE(static_cast<size_t>(featureCount) >= model.ObliviousTrees.GetFlatFeatureVectorExpectedSize(), "Insufficient features count in pool");

    const int catFeaturesCount = pool.CatFeatures.ysize();
    const int floatFeaturesCount = featureCount - catFeaturesCount;
    TFeaturesLayout layout(featureCount, pool.CatFeatures, pool.FeatureId);

    const TVector<TFeatureEffect> regularEffect = CalcRegularFeatureEffect(
        CalcFeatureEffect(model, pool, threadCount),
        catFeaturesCount,
        floatFeaturesCount);

    // Scatter each (typed index, score) pair into its flat feature position.
    TVector<double> effect(featureCount);
    for (size_t i = 0; i < regularEffect.size(); ++i) {
        const TFeatureEffect& current = regularEffect[i];
        const int flatIdx = layout.GetFeature(current.Feature.Index, current.Feature.Type);
        Y_ASSERT(flatIdx < featureCount);
        effect[flatIdx] = current.Score;
    }
    return effect;
}
/// Renders the usage form of `option`: every short (`-x`) and long (`--name`) spelling in green.
/// Multiple spellings are wrapped as `{a|b}`. The argument placeholder follows: ` [META[:default]]` for
/// an optional argument, ` META` for a required one. META is the option's arg title, or "VAL" when empty.
static TString FormatOption(const TOpt* option, const NColorizer::TColors& colors) {
    TStringStream result;
    const TOpt::TShortNames& shorts = option->GetShortNames();
    const TOpt::TLongNames& longs = option->GetLongNames();

    const size_t nameCount = shorts.size() + longs.size();
    const bool hasAlternatives = nameCount > 1;
    if (hasAlternatives) {
        result << '{';
    }
    for (size_t idx = 0; idx < nameCount; ++idx) {
        if (hasAlternatives && idx > 0) {
            result << '|';
        }
        result << colors.GreenColor();
        if (idx < shorts.size()) {
            result << '-' << shorts[idx];
        } else {
            result << "--" << longs[idx - shorts.size()];
        }
        result << colors.OldColor();
    }
    if (hasAlternatives) {
        result << '}';
    }

    static const TString defaultMetavar("VAL");
    const TString& title = option->GetArgTitle();
    const TString& metavar = title.Empty() ? defaultMetavar : title;
    switch (option->GetHasArg()) {
        case OPTIONAL_ARGUMENT:
            result << " [" << metavar;
            if (option->HasOptionalValue()) {
                result << ':' << option->GetOptionalValue();
            }
            result << ']';
            break;
        case REQUIRED_ARGUMENT:
            result << ' ' << metavar;
            break;
        default:
            Y_ASSERT(option->GetHasArg() == NO_ARGUMENT);
            break;
    }
    return result.Str();
}
/// Writes a reference to `pObject` into the stream.
/// The object's identity is written as a ui64 id. That is its address, or a dense 1-based id per object in
/// StableOutput mode; null is always 0. The first time a non-null id is seen, the object is queued in
/// ObjectQueue and its registered type id is written as well. Aborts on an unregistered object type.
void IBinSaver::StoreObject(IObjectBase* pObject) {
    if (pObject) {
        Y_ASSERT(pSaverClasses->GetObjectTypeID(pObject) != -1 && "trying to save unregistered object");
    }

    // Take the address with reinterpret_cast. The former `(char*)pObject - (char*)nullptr` subtracted
    // pointers that do not point into the same object, which is undefined behavior.
    // Null is mapped to 0 explicitly.
    ui64 ptrId = pObject ? reinterpret_cast<ui64>(pObject) : 0;
    if (StableOutput) {
        ui32 id = 0;
        if (pObject) {
            if (!PtrIds.Get()) {
                PtrIds.Reset(new PtrIdHash);
            }
            PtrIdHash::iterator pFound = PtrIds->find(pObject);
            if (pFound != PtrIds->end()) {
                id = pFound->second;
            } else {
                id = PtrIds->ysize() + 1;
                PtrIds->insert(std::make_pair(pObject, id));
            }
        }
        ptrId = id;
    }
    DataChunk(&ptrId, sizeof(ptrId));

    if (!Objects.Get()) {
        Objects.Reset(new CObjectsHash);
    }
    if (ptrId != 0 && Objects->find(ptrId) == Objects->end()) {
        ObjectQueue.push_back(pObject);
        (*Objects)[ptrId];
        int typeId = pSaverClasses->GetObjectTypeID(pObject);
        if (typeId == -1) {
            fprintf(stderr, "IBinSaver: trying to save unregistered object\n");
            abort();
        }
        DataChunk(&typeId, sizeof(typeId));
    }
}
/// Reads an object reference written by StoreObject.
/// @return nullptr for id 0, the already-loaded object for a known id, or a freshly created object of
///         the stored type id; a new object is registered under its id and queued in ObjectQueue.
///         Aborts when the type id is not registered.
IObjectBase* IBinSaver::LoadObject() {
    ui64 ptrId = 0;
    DataChunk(&ptrId, sizeof(ptrId));
    if (ptrId == 0) {
        return nullptr;
    }

    if (!Objects.Get()) {
        Objects.Reset(new CObjectsHash);
    }
    const CObjectsHash::iterator known = Objects->find(ptrId);
    if (known != Objects->end()) {
        return known->second;
    }

    // First occurrence of this id: the type id follows in the stream.
    int typeId;
    DataChunk(&typeId, sizeof(typeId));
    IObjectBase* created = pSaverClasses->CreateObject(typeId);
    Y_ASSERT(created != nullptr);
    if (created == nullptr) {
        fprintf(stderr, "IBinSaver: trying to load unregistered object\n");
        abort();
    }
    (*Objects)[ptrId] = created;
    ObjectQueue.push_back(created);
    return created;
}
/// Turns the raw train-by-test document importance matrix into per-row lists of (score, train doc index).
/// @param rawImportances        rawImportances[trainDocId][testDocId]; must be non-empty.
/// @param docImpMethod          Average: a single row with each train doc's importance averaged over test docs;
///                              PerObject / Raw: one row per test doc.
/// @param topSize               maximum number of entries kept per row; a negative value keeps all of them.
/// @param importanceValuesSign  keep only positive, only negative, or all values.
/// @return                      rows sorted by descending |score| (except for Raw, which keeps train-doc order).
static TDStrResult GetFinalDocumentImportances(
    const TVector<TVector<double>>& rawImportances,
    EDocumentStrengthType docImpMethod,
    int topSize,
    EImportanceValuesSign importanceValuesSign
) {
    const ui32 trainDocCount = rawImportances.size();
    Y_ASSERT(rawImportances.size() != 0);
    const ui32 testDocCount = rawImportances[0].size();

    TVector<TVector<double>> preprocessedImportances;
    if (docImpMethod == EDocumentStrengthType::Average) {
        preprocessedImportances = TVector<TVector<double>>(1, TVector<double>(trainDocCount));
        for (ui32 trainDocId = 0; trainDocId < trainDocCount; ++trainDocId) {
            for (ui32 testDocId = 0; testDocId < testDocCount; ++testDocId) {
                preprocessedImportances[0][trainDocId] += rawImportances[trainDocId][testDocId];
            }
        }
        for (ui32 trainDocId = 0; trainDocId < trainDocCount; ++trainDocId) {
            preprocessedImportances[0][trainDocId] /= testDocCount;
        }
    } else {
        Y_ASSERT(docImpMethod == EDocumentStrengthType::PerObject || docImpMethod == EDocumentStrengthType::Raw);
        // Transpose to [testDocId][trainDocId].
        preprocessedImportances = TVector<TVector<double>>(testDocCount, TVector<double>(trainDocCount));
        for (ui32 trainDocId = 0; trainDocId < trainDocCount; ++trainDocId) {
            for (ui32 testDocId = 0; testDocId < testDocCount; ++testDocId) {
                preprocessedImportances[testDocId][trainDocId] = rawImportances[trainDocId][testDocId];
            }
        }
    }

    Y_ASSERT(importanceValuesSign == EImportanceValuesSign::Positive ||
             importanceValuesSign == EImportanceValuesSign::Negative ||
             importanceValuesSign == EImportanceValuesSign::All);
    // Sign filter is loop-invariant: build it once instead of per row.
    const auto passesSignFilter = [importanceValuesSign](double value) {
        if (importanceValuesSign == EImportanceValuesSign::Positive) {
            return value > 0;
        }
        if (importanceValuesSign == EImportanceValuesSign::Negative) {
            return value < 0;
        }
        return true;
    };

    TDStrResult result(preprocessedImportances.size());
    for (ui32 testDocId = 0; testDocId < preprocessedImportances.size(); ++testDocId) {
        TVector<double>& preprocessedImportancesRef = preprocessedImportances[testDocId];

        const ui32 docCount = preprocessedImportancesRef.size();
        TVector<ui32> indices(docCount);
        std::iota(indices.begin(), indices.end(), 0);
        if (docImpMethod != EDocumentStrengthType::Raw) {
            Sort(indices.begin(), indices.end(), [&](ui32 first, ui32 second) {
                return Abs(preprocessedImportancesRef[first]) > Abs(preprocessedImportancesRef[second]);
            });
        }

        // Count only the entries actually emitted. Counting rejected ones as well made a sign filter
        // return fewer than topSize values even when enough matching documents existed.
        int currentSize = 0;
        for (ui32 i = 0; i < docCount && currentSize != topSize; ++i) {
            const double value = preprocessedImportancesRef[indices[i]];
            if (passesSignFilter(value)) {
                result.Scores[testDocId].push_back(value);
                result.Indices[testDocId].push_back(indices[i]);
                ++currentSize;
            }
        }
    }
    return result;
}
/// Mutable view of the `BorderCount` border counters stored for element `i` in `BucketData`.
/// Element `i` must be below `MaxElem` (checked in debug builds only).
inline TArrayRef<int> GetBorders(size_t i) {
    Y_ASSERT(i < MaxElem);
    auto* const rowStart = BucketData + BorderCount * i;
    return TArrayRef<int>(rowStart, BorderCount);
}
/// Mutable reference to the total counter of element `i` in `Data` (i < MaxElem, debug-checked).
inline int& GetTotal(size_t i) {
    Y_ASSERT(i < MaxElem);
    return Data[i];
}
/// Fills `dst->Feature` with online CTR columns for every CTR configured for projection `proj`.
/// Hashes the projection's categorical values along the fold's learn permutation, reindexes the hashes
/// into a compact leaf-id space with ReindexHash, then dispatches each CTR to a type-specific calculator.
/// @param data            training data whose features are hashed.
/// @param fold            supplies the learn permutation, effective doc count and target classes per classifier.
/// @param proj            the feature combination the CTRs are computed for.
/// @param ctx             learn context: CTR helper and categorical-feature parameters.
/// @param dst             receives one TArray2D of ui8 columns (target border x prior) per CTR.
/// @param totalLeafCount  receives ReindexHash(...).second. NOTE(review): presumably the number of distinct
///                        leaves after reindexing — confirm against ReindexHash.
void ComputeOnlineCTRs(const TTrainData& data, const TFold& fold, const TProjection& proj, TLearnContext* ctx, TOnlineCTR* dst, size_t* totalLeafCount) {
    const TCtrHelper& ctrHelper = ctx->CtrsHelper;
    const auto& ctrInfo = ctrHelper.GetCtrInfo(proj);
    dst->Feature.resize(ctrInfo.size());
    using THashArr = TVector<ui64>;
    using TRehashHash = TDenseHash<ui64, ui32>;
    // Scratch buffers declared with Y_STATIC_THREAD (per-thread statics, judging by the macro and the `tls`
    // names) so they are reused between calls instead of reallocated.
    Y_STATIC_THREAD(THashArr) tlsHashArr;
    Y_STATIC_THREAD(TRehashHash) rehashHashTlsVal;
    TVector<ui64>& hashArr = tlsHashArr.Get();
    CalcHashes(proj, data.AllFeatures, fold.EffectiveDocCount, fold.LearnPermutation, &hashArr);
    rehashHashTlsVal.Get().MakeEmpty(fold.LearnPermutation.size());
    ui64 topSize = ctx->Params.CatFeatureParams->CtrLeafCountLimit;
    // For a single categorical feature with StoreAllSimpleCtrs, the leaf-count limit is lifted entirely.
    if (proj.IsSingleCatFeature() && ctx->Params.CatFeatureParams->StoreAllSimpleCtrs) {
        topSize = Max<ui64>();
    }
    auto leafCount = ReindexHash(
        fold.LearnPermutation.size(),
        topSize,
        &hashArr,
        rehashHashTlsVal.GetPtr());
    *totalLeafCount = leafCount.second;

    for (int ctrIdx = 0; ctrIdx < dst->Feature.ysize(); ++ctrIdx) {
        const ECtrType ctrType = ctrInfo[ctrIdx].Type;
        const ui32 classifierId = ctrInfo[ctrIdx].TargetClassifierIdx;
        int targetClassesCount = fold.TargetClassesCount[classifierId];

        const ui32 targetBorderCount = GetTargetBorderCount(ctrInfo[ctrIdx], targetClassesCount);
        const ui32 ctrBorderCount = ctrInfo[ctrIdx].BorderCount;
        const auto& priors = ctrInfo[ctrIdx].Priors;
        dst->Feature[ctrIdx].SetSizes(priors.size(), targetBorderCount);

        // Size every (border, prior) column to the full sample count before the calculator fills it.
        for (ui32 border = 0; border < targetBorderCount; ++border) {
            for (int prior = 0; prior < priors.ysize(); ++prior) {
                Clear(&dst->Feature[ctrIdx][border][prior], data.GetSampleCount());
            }
        }

        if (ctrType == ECtrType::Borders && targetClassesCount == SIMPLE_CLASSES_COUNT) {
            CalcOnlineCTRSimple(
                data,
                hashArr,
                leafCount.second,
                fold.LearnTargetClass[classifierId],
                priors,
                ctrBorderCount,
                &dst->Feature[ctrIdx]);

        } else if (ctrType == ECtrType::BinarizedTargetMeanValue) {
            // Third data argument: the number of target borders, i.e. targetClassesCount - 1.
            CalcOnlineCTRMean(
                data,
                hashArr,
                leafCount.second,
                fold.LearnTargetClass[classifierId],
                targetClassesCount - 1,
                priors,
                ctrBorderCount,
                &dst->Feature[ctrIdx]);

        } else if (ctrType == ECtrType::Buckets ||
                   (ctrType == ECtrType::Borders && targetClassesCount > SIMPLE_CLASSES_COUNT)) {
            CalcOnlineCTRClasses(
                data,
                hashArr,
                leafCount.second,
                fold.LearnTargetClass[classifierId],
                targetClassesCount,
                GetTargetBorderCount(ctrInfo[ctrIdx], targetClassesCount),
                priors,
                ctrBorderCount,
                ctrType,
                &dst->Feature[ctrIdx]);
        } else {
            // Every other CTR type lands here; only Counter is expected.
            // NOTE(review): Borders with fewer than SIMPLE_CLASSES_COUNT classes would also reach this branch.
            Y_ASSERT(ctrType == ECtrType::Counter);
            CalcOnlineCTRCounter(
                data,
                hashArr,
                leafCount.second,
                priors,
                ctrBorderCount,
                ctx->Params.CatFeatureParams->CounterCalcMethod,
                &dst->Feature[ctrIdx]);
        }
    }
}
/// Computes Counter-type CTR columns into (*feature)[0][prior] for every document.
/// Each document's value is CalcCTR(count of its leaf, denominator, prior, shift, norm, ctrBorderCount).
/// The count is how many counted documents share the leaf id; the denominator is the maximum such count.
/// With ECounterCalc::Full, all documents are counted. With SkipTest, only the first
/// data.LearnSampleCount documents are counted; the remaining ones read the learn counts without updating them.
/// @param enumeratedCatFeatures  per-document leaf ids (must cover at least LearnSampleCount docs).
/// @param leafCount              size of the per-leaf total array.
/// @param priors                 one output column per prior.
/// @param feature                output columns; only row [0] is written.
static void CalcOnlineCTRCounter(const TTrainData& data, const TVector<ui64>& enumeratedCatFeatures, size_t leafCount, const TVector<float>& priors, int ctrBorderCount, ECounterCalc counterCalc, TArray2D<TVector<ui8>>* feature) {
    const auto docCount = enumeratedCatFeatures.ysize();
    auto ctrArrTotal = TCtrCalcer::GetCtrArrTotal(leafCount);

    TVector<float> shift;
    TVector<float> norm;
    CalcNormalization(priors, &shift, &norm);

    Y_ASSERT(docCount >= data.LearnSampleCount);

    // Full: add the range's documents to the per-leaf totals first; Skip: only read existing totals.
    enum ECalcMode {
        Skip,
        Full
    };

    // Processes documents [firstId, lastId). Note that in Full mode the whole range is counted before any
    // value is written, so every document in the range sees the final totals and the final denominator.
    auto CalcCTRs = [](const TVector<ui64>& enumeratedCatFeatures, const TVector<float>& priors, const TVector<float>& shift, const TVector<float>& norm, int ctrBorderCount, int firstId, int lastId, ECalcMode mode, int* denominator, int* ctrArrTotal, TArray2D<TVector<ui8>>* feature) {
        int currentDenominator = *denominator;
        if (mode == ECalcMode::Full) {
            for (int docId = firstId; docId < lastId; ++docId) {
                const auto elemId = enumeratedCatFeatures[docId];
                ++ctrArrTotal[elemId];
                currentDenominator = Max(currentDenominator, ctrArrTotal[elemId]);
            }
        }

        // Gather totals/denominators for a block of documents, then write each prior's column for that block.
        const int blockSize = 1000;
        TVector<int> ctrTotal(blockSize);
        TVector<int> ctrDenominator(blockSize);
        for (int blockStart = firstId; blockStart < lastId; blockStart += blockSize) {
            const int blockEnd = Min(lastId, blockStart + blockSize);
            for (int docId = blockStart; docId < blockEnd; ++docId) {
                const auto elemId = enumeratedCatFeatures[docId];
                ctrTotal[docId - blockStart] = ctrArrTotal[elemId];
                ctrDenominator[docId - blockStart] = currentDenominator;
            }
            for (int prior = 0; prior < priors.ysize(); ++prior) {
                const float priorX = priors[prior];
                const float shiftX = shift[prior];
                const float normX = norm[prior];
                ui8* featureData = (*feature)[0][prior].data();
                for (int docId = blockStart; docId < blockEnd; ++docId) {
                    featureData[docId] = CalcCTR(ctrTotal[docId - blockStart], ctrDenominator[docId - blockStart], priorX, shiftX, normX, ctrBorderCount);
                }
            }
        }
        *denominator = currentDenominator;
    };

    int denominator = 0;
    if (counterCalc == ECounterCalc::Full) {
        CalcCTRs(enumeratedCatFeatures, priors, shift, norm, ctrBorderCount, 0, docCount, ECalcMode::Full, &denominator, ctrArrTotal, feature);
    } else {
        // SkipTest: count learn docs only; test docs reuse learn totals and the learn denominator.
        Y_ASSERT(counterCalc == ECounterCalc::SkipTest);
        CalcCTRs(enumeratedCatFeatures, priors, shift, norm, ctrBorderCount, 0, data.LearnSampleCount, ECalcMode::Full, &denominator, ctrArrTotal, feature);
        CalcCTRs(enumeratedCatFeatures, priors, shift, norm, ctrBorderCount, data.LearnSampleCount, docCount, ECalcMode::Skip, &denominator, ctrArrTotal, feature);
    }
}
void GreedyTensorSearch(const TTrainData& data, const TVector<int>& splitCounts, double modelLength, TProfileInfo& profile, TFold* fold, TLearnContext* ctx, TSplitTree* resSplitTree) { TSplitTree currentSplitTree; TrimOnlineCTRcache({fold}); TVector<TIndexType> indices(data.LearnSampleCount); MATRIXNET_INFO_LOG << "\n"; const bool useStatsFromPrevTree = AreStatsFromPrevTreeUsed(ctx->Params.ObliviousTreeOptions.Get()); if (useStatsFromPrevTree) { AssignRandomWeights(data.LearnSampleCount, ctx, fold); ctx->StatsFromPrevTree.GarbageCollect(); } for (ui32 curDepth = 0; curDepth < ctx->Params.ObliviousTreeOptions->MaxDepth; ++curDepth) { TCandidateList candList; AddFloatFeatures(data, ctx, &ctx->StatsFromPrevTree, &candList); AddOneHotFeatures(data, ctx, &ctx->StatsFromPrevTree, &candList); AddSimpleCtrs(data, fold, ctx, &ctx->StatsFromPrevTree, &candList); AddTreeCtrs(data, currentSplitTree, fold, ctx, &ctx->StatsFromPrevTree, &candList); auto IsInCache = [&fold](const TProjection& proj) -> bool {return fold->GetCtrRef(proj).Feature.empty();}; SelectCtrsToDropAfterCalc(ctx->Params.SystemOptions->CpuUsedRamLimit, data.GetSampleCount(), ctx->Params.SystemOptions->NumThreads, IsInCache, &candList); CheckInterrupted(); // check after long-lasting operation if (!useStatsFromPrevTree) { AssignRandomWeights(data.LearnSampleCount, ctx, fold); } profile.AddOperation(TStringBuilder() << "AssignRandomWeights, depth " << curDepth); double scoreStDev = ctx->Params.ObliviousTreeOptions->RandomStrength * CalcScoreStDev(*fold) * CalcScoreStDevMult(data.LearnSampleCount, modelLength); TVector<size_t> candLeafCount(candList.ysize(), 1); const ui64 randSeed = ctx->Rand.GenRand(); ctx->LocalExecutor.ExecRange([&](int id) { auto& candidate = candList[id]; if (candidate.Candidates[0].SplitCandidate.Type == ESplitType::OnlineCtr) { const auto& proj = candidate.Candidates[0].SplitCandidate.Ctr.Projection; if (fold->GetCtrRef(proj).Feature.empty()) { ComputeOnlineCTRs(data, *fold, proj, ctx, 
&fold->GetCtrRef(proj), &candidate.ResultingCtrTableSize); candLeafCount[id] = candidate.ResultingCtrTableSize; } } TVector<TVector<double>> allScores(candidate.Candidates.size()); ctx->LocalExecutor.ExecRange([&](int oneCandidate) { if (candidate.Candidates[oneCandidate].SplitCandidate.Type == ESplitType::OnlineCtr) { const auto& proj = candidate.Candidates[oneCandidate].SplitCandidate.Ctr.Projection; Y_ASSERT(!fold->GetCtrRef(proj).Feature.empty()); } allScores[oneCandidate] = CalcScore(data.AllFeatures, splitCounts, *fold, indices, ctx->ParamsUsedWithStatsFromPrevTree, ctx->Params, candidate.Candidates[oneCandidate].SplitCandidate, currentSplitTree.GetDepth(), &ctx->StatsFromPrevTree); }, NPar::TLocalExecutor::TExecRangeParams(0, candidate.Candidates.ysize()) , NPar::TLocalExecutor::WAIT_COMPLETE); if (candidate.Candidates[0].SplitCandidate.Type == ESplitType::OnlineCtr && candidate.ShouldDropCtrAfterCalc) { fold->GetCtrRef(candidate.Candidates[0].SplitCandidate.Ctr.Projection).Feature.clear(); } TFastRng64 rand(randSeed + id); rand.Advance(10); // reduce correlation between RNGs in different threads for (size_t i = 0; i < allScores.size(); ++i) { double bestScoreInstance = MINIMAL_SCORE; auto& splitInfo = candidate.Candidates[i]; const auto& scores = allScores[i]; for (int binFeatureIdx = 0; binFeatureIdx < scores.ysize(); ++binFeatureIdx) { const double score = scores[binFeatureIdx]; const double scoreInstance = TRandomScore(score, scoreStDev).GetInstance(rand); if (scoreInstance > bestScoreInstance) { bestScoreInstance = scoreInstance; splitInfo.BestScore = TRandomScore(score, scoreStDev); splitInfo.BestBinBorderId = binFeatureIdx; } } } }, 0, candList.ysize(), NPar::TLocalExecutor::WAIT_COMPLETE); size_t maxLeafCount = 1; for (size_t leafCount : candLeafCount) { maxLeafCount = Max(maxLeafCount, leafCount); } fold->DropEmptyCTRs(); CheckInterrupted(); // check after long-lasting operation profile.AddOperation(TStringBuilder() << "Calc scores " << curDepth); 
const TCandidateInfo* bestSplitCandidate = nullptr; double bestScore = MINIMAL_SCORE; for (const auto& subList : candList) { for (const auto& candidate : subList.Candidates) { double score = candidate.BestScore.GetInstance(ctx->Rand); TProjection projection = candidate.SplitCandidate.Ctr.Projection; ECtrType ctrType = ctx->CtrsHelper.GetCtrInfo(projection)[candidate.SplitCandidate.Ctr.CtrIdx].Type; if (!ctx->LearnProgress.UsedCtrSplits.has(std::make_pair(ctrType, projection)) && score != MINIMAL_SCORE) { score *= pow(1 / (1 + subList.ResultingCtrTableSize / static_cast<double>(maxLeafCount)), ctx->Params.ObliviousTreeOptions->ModelSizeReg.Get()); } if (score > bestScore) { bestScore = score; bestSplitCandidate = &candidate; } } } if (bestScore == MINIMAL_SCORE) { break; } if (bestSplitCandidate->SplitCandidate.Type == ESplitType::OnlineCtr) { TProjection projection = bestSplitCandidate->SplitCandidate.Ctr.Projection; ECtrType ctrType = ctx->CtrsHelper.GetCtrInfo(projection)[bestSplitCandidate->SplitCandidate.Ctr.CtrIdx].Type; ctx->LearnProgress.UsedCtrSplits.insert(std::make_pair(ctrType, projection)); } auto bestSplit = TSplit(bestSplitCandidate->SplitCandidate, bestSplitCandidate->BestBinBorderId); if (bestSplit.Type == ESplitType::OnlineCtr) { const auto& proj = bestSplit.Ctr.Projection; if (fold->GetCtrRef(proj).Feature.empty()) { size_t totalLeafCount; ComputeOnlineCTRs(data, *fold, proj, ctx, &fold->GetCtrRef(proj), &totalLeafCount); DropStatsForProjection(*fold, *ctx, proj, &ctx->StatsFromPrevTree); } } else if (bestSplit.Type == ESplitType::OneHotFeature) { bestSplit.BinBorder = data.AllFeatures.OneHotValues[bestSplit.FeatureIdx][bestSplit.BinBorder]; } SetPermutedIndices(bestSplit, data.AllFeatures, curDepth + 1, *fold, &indices, ctx); if (useStatsFromPrevTree) { ctx->ParamsUsedWithStatsFromPrevTree.SelectParametersForSmallestSplitSide(curDepth + 1, *fold, indices, &ctx->LocalExecutor); } currentSplitTree.AddSplit(bestSplit); MATRIXNET_INFO_LOG << 
BuildDescription(ctx->Layout, bestSplit); MATRIXNET_INFO_LOG << " score " << bestScore << "\n"; profile.AddOperation(TStringBuilder() << "Select best split " << curDepth); int redundantIdx = GetRedundantSplitIdx(curDepth + 1, indices); if (redundantIdx != -1) { DeleteSplit(curDepth + 1, redundantIdx, ¤tSplitTree, &indices); MATRIXNET_INFO_LOG << " tensor " << redundantIdx << " is redundant, remove it and stop\n"; break; } } *resSplitTree = std::move(currentSplitTree); }