Example #1
0
 // Maps a sub-range of Keys onto the matching span of values: from the first
 // element of Words[first key index] to the end of Words[last key index].
 // `keys` must be a view into Keys itself; this is checked in debug builds only.
 // NOTE(review): assumes consecutive Words entries are stored contiguously so the
 // resulting [begin, end) span is valid — confirm.
 TConstArrayRef<T> Remap(const TConstArrayRef<ui64>& keys) {
     if (!keys.empty()) {
         Y_ASSERT(keys.begin() >= Keys.begin() && keys.begin() <= Keys.end());
         Y_ASSERT(keys.end() >= Keys.begin() && keys.end() <= Keys.end());
         const auto firstWordIdx = keys.begin() - Keys.begin();
         const auto lastWordIdx = keys.end() - Keys.begin() - 1;
         return TConstArrayRef<T>(Words[firstWordIdx].begin(), Words[lastWordIdx].end());
     }
     return TConstArrayRef<T>();
 }
Example #2
0
// Converts raw model approximations into the representation requested by
// `predictionType`:
//  - Probability:   CalcSoftmax over all dimensions when IsMulticlass(approx),
//                   otherwise a single row with CalcSigmoid(approx[0]);
//  - Class:         a single row of labels — SelectBestClass for multiclass,
//                   otherwise 1.0/0.0 depending on whether approx[0][doc] > 0;
//  - RawFormulaVal: `approx` returned unchanged.
// Any other prediction type fails a debug assertion and yields an empty result.
// NOTE(review): approx appears to be indexed [dimension][document] (approx[0] is
// iterated per document) — confirm.
TVector<TVector<double>> PrepareEval(const EPredictionType predictionType,
                                     const TVector<TVector<double>>& approx,
                                     NPar::TLocalExecutor* localExecutor) {
    TVector<TVector<double>> result;
    switch (predictionType) {
        case EPredictionType::Probability:
            if (IsMulticlass(approx)) {
                result = CalcSoftmax(approx, localExecutor);
            } else {
                result = {CalcSigmoid(approx[0])};
            }
            break;
        case EPredictionType::Class:
            result.resize(1);
            if (IsMulticlass(approx)) {
                const TVector<int> predictions = SelectBestClass(approx, localExecutor);
                result[0].assign(predictions.begin(), predictions.end());
            } else {
                // One label per document: size the buffer by the document count
                // (approx[0].size()), not by the number of dimensions (approx.size()).
                result[0].reserve(approx[0].size());
                for (const double prediction : approx[0]) {
                    result[0].push_back(prediction > 0);
                }
            }
            break;
        case EPredictionType::RawFormulaVal:
            result = approx;
            break;
        default:
            Y_ASSERT(false);
    }
    return result;
}
// Appends to `candList` one candidate group with every online-CTR split that can
// be built for projection `proj`: one candidate per (CTR description, target
// border, prior) triple, as listed by ctx.CtrsHelper.GetCtrInfo(proj).
// The group is appended even when it ends up with no candidates.
static void AddCtrsToCandList(const TFold& fold,
                              const TLearnContext& ctx,
                              const TProjection& proj,
                              TCandidateList* candList) {
    TCandidatesInfoList ctrSplits;
    const auto& ctrInfo = ctx.CtrsHelper.GetCtrInfo(proj);

    for (ui32 ctrIdx = 0; ctrIdx < ctrInfo.size(); ++ctrIdx) {
        const auto& info = ctrInfo[ctrIdx];
        const ui32 targetClassesCount = fold.TargetClassesCount[info.TargetClassifierIdx];
        const int targetBorderCount = GetTargetBorderCount(info, targetClassesCount);
        const int priorsCount = info.Priors.size();
        // NOTE(review): presumably the border index is stored in an 8-bit field of TCtr — confirm.
        Y_ASSERT(targetBorderCount < 256);
        for (int border = 0; border < targetBorderCount; ++border) {
            for (int prior = 0; prior < priorsCount; ++prior) {
                TCandidateInfo split;
                split.SplitCandidate.Type = ESplitType::OnlineCtr;
                split.SplitCandidate.Ctr = TCtr(proj, ctrIdx, border, prior, info.BorderCount);
                // `split` is a fresh local per iteration: move instead of copying it.
                ctrSplits.Candidates.emplace_back(std::move(split));
            }
        }
    }

    // ctrSplits is not used afterwards, so hand the whole group over without a copy.
    candList->push_back(std::move(ctrSplits));
}
Example #4
0
// Builds the final CTR value table for one CTR type over the learn documents.
//
// `hashArr` holds one hash per learn document. ReindexHash rewrites it in place,
// and the leaf count is presumably capped by `ctrLeafCountLimit` (TODO confirm).
// Afterwards every entry is used as a leaf index into the blob allocated in
// `result`, and the hash -> index mapping is stored in result's index hash.
//
// Per leaf the table accumulates:
//  - BinarizedTargetMeanValue: mean history of targetClass / (targetClassesCount - 1);
//  - FloatTargetMeanValue:     mean history of the float target;
//  - Counter / FeatureFreq:    number of documents falling into the leaf;
//  - any other type:           targetClassesCount per-class document counts.
//
// CounterDenominator is the largest leaf count for Counter (0 when there are no
// leaves) and the learn sample count for FeatureFreq.
void CalcFinalCtrsImpl(
    const ECtrType ctrType,
    const ui64 ctrLeafCountLimit,
    const TVector<int>& permutedTargetClass,
    const TVector<float>& permutedTargets,
    const ui64 learnSampleCount,
    int targetClassesCount,
    TVector<ui64>* hashArr,
    TCtrValueTable* result
) {
    TDenseHash<ui64, ui32> tmpHash;
    auto leafCount = ReindexHash(
        learnSampleCount,
        ctrLeafCountLimit,
        hashArr,
        &tmpHash).first;
    auto hashIndexBuilder = result->GetIndexHashBuilder(leafCount);
    for (const auto& kv : tmpHash) {
        hashIndexBuilder.SetIndex(kv.Key(), kv.Value());
    }
    // NOTE(review): the accumulators below are updated in place, which relies on
    // AllocateBlobAndGetArrayRef returning zero/default-initialised storage — confirm.
    TArrayRef<int> ctrIntArray;
    TArrayRef<TCtrMeanHistory> ctrMean;
    if (ctrType == ECtrType::BinarizedTargetMeanValue || ctrType == ECtrType::FloatTargetMeanValue) {
        ctrMean = result->AllocateBlobAndGetArrayRef<TCtrMeanHistory>(leafCount);
    } else if (ctrType == ECtrType::Counter || ctrType == ECtrType::FeatureFreq) {
        ctrIntArray = result->AllocateBlobAndGetArrayRef<int>(leafCount);
        result->CounterDenominator = 0;
    } else {
        result->TargetClassesCount = targetClassesCount;
        ctrIntArray = result->AllocateBlobAndGetArrayRef<int>(leafCount * targetClassesCount);
    }

    Y_ASSERT(hashArr->size() == learnSampleCount);
    // Used only by BinarizedTargetMeanValue to normalise class indices into [0, 1].
    int targetBorderCount = targetClassesCount - 1;
    auto hashArrPtr = hashArr->data();
    // ui64 index to match learnSampleCount: a ui32 counter would wrap for more than
    // 2^32 - 1 documents and never terminate.
    for (ui64 z = 0; z < learnSampleCount; ++z) {
        const ui64 elemId = hashArrPtr[z];
        if (ctrType == ECtrType::BinarizedTargetMeanValue) {
            TCtrMeanHistory& elem = ctrMean[elemId];
            elem.Add(static_cast<float>(permutedTargetClass[z]) / targetBorderCount);
        } else if (ctrType == ECtrType::Counter || ctrType == ECtrType::FeatureFreq) {
            ++ctrIntArray[elemId];
        } else if (ctrType == ECtrType::FloatTargetMeanValue) {
            TCtrMeanHistory& elem = ctrMean[elemId];
            elem.Add(permutedTargets[z]);
        } else {
            // Per-class counters of this leaf: targetClassesCount consecutive ints.
            TArrayRef<int> elem = MakeArrayRef(ctrIntArray.data() + targetClassesCount * elemId, targetClassesCount);
            ++elem[permutedTargetClass[z]];
        }
    }

    if (ctrType == ECtrType::Counter && !ctrIntArray.empty()) {
        // MaxElement on an empty range returns end(), which must not be dereferenced;
        // with no leaves the denominator keeps the 0 assigned above.
        result->CounterDenominator = *MaxElement(ctrIntArray.begin(), ctrIntArray.end());
    }
    if (ctrType == ECtrType::FeatureFreq) {
        result->CounterDenominator = static_cast<int>(learnSampleCount);
    }
}
// Resolves the paths of all training output files from `params`.
//
// When file writing is disabled, nothing is resolved. Debug builds only check
// that no output path has been set.
//
// Otherwise:
//  - The train directory is created if it is non-empty and missing.
//  - Each file name is combined with the directory via AlignFilePath.
//  - Time-left, learn-error, test-error, snapshot and meta files use
//    `namesPrefix`; the JSON and profile logs use no prefix.
//  - The test-error file is optional (skipped when its name is empty).
//  - The snapshot path is set only when snapshotting is enabled.
//  - All other file names must be non-empty (CB_ENSURE).
void TOutputFiles::InitializeFiles(const NCatboostOptions::TOutputFilesOptions& params, const TString& namesPrefix) {
    if (!params.AllowWriteFiles()) {
        Y_ASSERT(TimeLeftLogFile.empty());
        Y_ASSERT(LearnErrorLogFile.empty());
        Y_ASSERT(TestErrorLogFile.empty());
        Y_ASSERT(MetaFile.empty());
        Y_ASSERT(SnapshotFile.empty());
        // These two are resolved below as well, so they must be unset here too.
        Y_ASSERT(JsonLogFile.empty());
        Y_ASSERT(ProfileLogFile.empty());
        return;
    }

    const auto& trainDir = params.GetTrainDir();
    TFsPath trainDirPath(trainDir);
    if (!trainDir.empty() && !trainDirPath.Exists()) {
        trainDirPath.MkDir();
    }
    NamesPrefix = namesPrefix;
    CB_ENSURE(!params.GetTimeLeftLogFilename().empty(), "empty time_left filename");
    TimeLeftLogFile = TOutputFiles::AlignFilePath(trainDir, params.GetTimeLeftLogFilename(), NamesPrefix);

    CB_ENSURE(!params.GetLearnErrorFilename().empty(), "empty learn_error filename");
    LearnErrorLogFile = TOutputFiles::AlignFilePath(trainDir, params.GetLearnErrorFilename(), NamesPrefix);
    // Optional: only resolved when a test-error file name is configured.
    if (!params.GetTestErrorFilename().empty()) {
        TestErrorLogFile = TOutputFiles::AlignFilePath(trainDir, params.GetTestErrorFilename(), NamesPrefix);
    }
    if (params.SaveSnapshot()) {
        SnapshotFile = TOutputFiles::AlignFilePath(trainDir, params.GetSnapshotFilename(), NamesPrefix);
    }
    const TString& metaFileFilename = params.GetMetaFileFilename();
    CB_ENSURE(!metaFileFilename.empty(), "empty meta filename");
    MetaFile = TOutputFiles::AlignFilePath(trainDir, metaFileFilename, NamesPrefix);

    const TString& jsonLogFilename = params.GetJsonLogFilename();
    CB_ENSURE(!jsonLogFilename.empty(), "empty json_log filename");
    JsonLogFile = TOutputFiles::AlignFilePath(trainDir, jsonLogFilename, "");

    const TString& profileLogFilename = params.GetProfileLogFilename();
    CB_ENSURE(!profileLogFilename.empty(), "empty profile_log filename");
    ProfileLogFile = TOutputFiles::AlignFilePath(trainDir, profileLogFilename, "");
}
Example #6
0
// Returns a human-readable description of this split feature:
//  - OnlineCtr:    projection description followed by the prior numerator and
//                  denominator, the target border index and the CTR type;
//  - FloatFeature: the float feature's description from `layout`;
//  - otherwise:    must be OneHotFeature (checked in debug builds only) and is
//                  described as a categorical feature.
TString TFeature::BuildDescription(const TFeaturesLayout& layout) const {
    TStringBuilder description;
    switch (Type) {
        case ESplitType::OnlineCtr:
            description << ::BuildDescription(layout, Ctr.Base.Projection);
            description << " prior_num=" << Ctr.PriorNum;
            description << " prior_denom=" << Ctr.PriorDenom;
            description << " targetborder=" << Ctr.TargetBorderIdx;
            description << " type=" << Ctr.Base.CtrType;
            break;
        case ESplitType::FloatFeature:
            description << BuildFeatureDescription(layout, FeatureIdx, EFeatureType::Float);
            break;
        default:
            Y_ASSERT(Type == ESplitType::OneHotFeature);
            description << BuildFeatureDescription(layout, FeatureIdx, EFeatureType::Categorical);
            break;
    }
    return description;
}
Example #7
0
// Computes feature effects of `model` on `pool` and returns them as a dense
// vector with one slot per pool feature, indexed by the flat feature index that
// TFeaturesLayout::GetFeature yields. Slots without an effect entry stay 0.
// Throws (CB_ENSURE) if the pool has fewer features than the model expects.
TVector<double> CalcRegularFeatureEffect(const TFullModel& model, const TPool& pool, int threadCount/*= 1*/) {
    const int featureCount = pool.Docs.GetFactorsCount();
    CB_ENSURE(static_cast<size_t>(featureCount) >= model.ObliviousTrees.GetFlatFeatureVectorExpectedSize(), "Insufficient features count in pool");

    const int catFeaturesCount = pool.CatFeatures.ysize();
    const int floatFeaturesCount = featureCount - catFeaturesCount;
    TFeaturesLayout layout(featureCount, pool.CatFeatures, pool.FeatureId);

    const TVector<TFeatureEffect> regularEffect = CalcRegularFeatureEffect(
        CalcFeatureEffect(model, pool, threadCount),
        catFeaturesCount,
        floatFeaturesCount);

    TVector<double> effect(featureCount);
    for (const auto& item : regularEffect) {
        const int flatIdx = layout.GetFeature(item.Feature.Index, item.Feature.Type);
        Y_ASSERT(flatIdx < featureCount);
        effect[flatIdx] = item.Score;
    }
    return effect;
}
Example #8
0
// Renders an option's switches for help output.
//  - Short names ("-x") come first, then long names ("--name"); each is
//    wrapped in colors.GreenColor() / colors.OldColor().
//  - With more than one name, the list is enclosed in braces and separated
//    by '|'.
//  - An optional argument is appended as " [TITLE]" or " [TITLE:default]";
//    a required one as " TITLE".
//  - TITLE is the option's arg title, or "VAL" when that title is empty.
static TString FormatOption(const TOpt* option, const NColorizer::TColors& colors) {
    TStringStream out;
    const TOpt::TShortNames& shortNames = option->GetShortNames();
    const TOpt::TLongNames& longNames = option->GetLongNames();

    const bool braced = shortNames.size() + longNames.size() > 1;
    if (braced) {
        out << '{';
    }

    // A separator is only ever needed once a name has been printed, i.e. when
    // there are at least two names (and therefore braces).
    bool needSeparator = false;
    for (const auto& shortName : shortNames) {
        if (needSeparator) {
            out << '|';
        }
        out << colors.GreenColor() << '-' << shortName << colors.OldColor();
        needSeparator = true;
    }
    for (const auto& longName : longNames) {
        if (needSeparator) {
            out << '|';
        }
        out << colors.GreenColor() << "--" << longName << colors.OldColor();
        needSeparator = true;
    }

    if (braced) {
        out << '}';
    }

    static const TString metavarDef("VAL");
    const TString& title = option->GetArgTitle();
    const TString& metavar = title.Empty() ? metavarDef : title;

    const auto hasArg = option->GetHasArg();
    if (hasArg == REQUIRED_ARGUMENT) {
        out << ' ' << metavar;
    } else if (hasArg == OPTIONAL_ARGUMENT) {
        out << " [" << metavar;
        if (option->HasOptionalValue()) {
            out << ':' << option->GetOptionalValue();
        }
        out << ']';
    } else {
        Y_ASSERT(hasArg == NO_ARGUMENT);
    }

    return out.Str();
}
Example #9
0
// Writes a reference to `pObject`. The first time a given object is seen, it
// is also queued in ObjectQueue and its registered type id is written.
//
// The reference written is a ui64 id:
//  - 0 for nullptr;
//  - otherwise the object's address;
//  - with StableOutput, a 1-based ordinal assigned in order of first
//    appearance, so the output does not depend on memory layout.
//
// Aborts the process if the object's type is not registered.
void IBinSaver::StoreObject(IObjectBase *pObject)
{
    if (pObject) {
        Y_ASSERT(pSaverClasses->GetObjectTypeID(pObject) != -1 && "trying to save unregistered object");
    }

    // The address itself serves as the id. reinterpret_cast to a wide enough
    // integer is well-defined; the former `(char*)p - (char*)nullptr`
    // subtracted pointers that do not point into the same array (UB).
    ui64 ptrId = reinterpret_cast<ui64>(pObject);
    if (StableOutput) {
        ui32 id = 0;
        if (pObject) {
            if (!PtrIds.Get()) {
                PtrIds.Reset(new PtrIdHash);
            }
            PtrIdHash::iterator pFound = PtrIds->find(pObject);
            if (pFound != PtrIds->end()) {
                id = pFound->second;
            } else {
                id = PtrIds->ysize() + 1;
                PtrIds->insert(std::make_pair(pObject, id));
            }
        }
        ptrId = id;
    }

    DataChunk(&ptrId, sizeof(ptrId));
    if (!Objects.Get()) {
        Objects.Reset(new CObjectsHash);
    }
    if (ptrId != 0 && Objects->find(ptrId) == Objects->end()) {
        ObjectQueue.push_back(pObject);
        // Mark the id as seen; the mapped value is not needed when saving.
        (*Objects)[ptrId];
        int typeId = pSaverClasses->GetObjectTypeID(pObject);
        if (typeId == -1) {
            fprintf(stderr, "IBinSaver: trying to save unregistered object\n");
            abort();
        }
        DataChunk(&typeId, sizeof(typeId));
    }
}
Example #10
0
// Reads an object reference and returns the referenced object; the format
// mirrors the one written by StoreObject.
//  - An id of 0 means nullptr.
//  - An id seen before returns the instance created for it earlier.
//  - A new id is followed by a type id. A new instance of that type is
//    created, remembered under the id, and queued in ObjectQueue (presumably
//    for later deserialisation).
// Aborts if the type id is not registered.
IObjectBase* IBinSaver::LoadObject()
{
    ui64 ptrId = 0;
    DataChunk(&ptrId, sizeof(ptrId));
    if (ptrId == 0) {
        return nullptr;
    }

    if (!Objects.Get()) {
        Objects.Reset(new CObjectsHash);
    }
    auto known = Objects->find(ptrId);
    if (known != Objects->end()) {
        return known->second;
    }

    int typeId;
    DataChunk(&typeId, sizeof(typeId));
    IObjectBase* created = pSaverClasses->CreateObject(typeId);
    Y_ASSERT(created != nullptr);
    if (created == nullptr) {
        fprintf(stderr, "IBinSaver: trying to load unregistered object\n");
        abort();
    }
    (*Objects)[ptrId] = created;
    ObjectQueue.push_back(created);
    return created;
}
// Turns raw document importances into the final ranked result.
//
// rawImportances is indexed [trainDoc][testDoc] (see the loops below).
//  - Average:         a single result row holding, for each train document,
//                     its importance averaged over all test documents.
//  - PerObject / Raw: one result row per test document (the matrix is
//                     transposed to [testDoc][trainDoc]).
//
// Within each row:
//  - Train documents are ordered by decreasing absolute importance; Raw
//    keeps the original train order.
//  - They are filtered by `importanceValuesSign`.
//  - At most `topSize` values that pass the filter are kept. A negative
//    topSize never matches the counter, so everything is kept.
static TDStrResult GetFinalDocumentImportances(
    const TVector<TVector<double>>& rawImportances,
    EDocumentStrengthType docImpMethod,
    int topSize,
    EImportanceValuesSign importanceValuesSign
) {
    Y_ASSERT(rawImportances.size() != 0);
    const ui32 trainDocCount = rawImportances.size();
    const ui32 testDocCount = rawImportances[0].size();
    TVector<TVector<double>> preprocessedImportances;
    if (docImpMethod == EDocumentStrengthType::Average) {
        preprocessedImportances = TVector<TVector<double>>(1, TVector<double>(trainDocCount));
        for (ui32 trainDocId = 0; trainDocId < trainDocCount; ++trainDocId) {
            for (ui32 testDocId = 0; testDocId < testDocCount; ++testDocId) {
                preprocessedImportances[0][trainDocId] += rawImportances[trainDocId][testDocId];
            }
        }
        for (ui32 trainDocId = 0; trainDocId < trainDocCount; ++trainDocId) {
            preprocessedImportances[0][trainDocId] /= testDocCount;
        }

    } else {
        Y_ASSERT(docImpMethod == EDocumentStrengthType::PerObject || docImpMethod == EDocumentStrengthType::Raw);
        preprocessedImportances = TVector<TVector<double>>(testDocCount, TVector<double>(trainDocCount));
        for (ui32 trainDocId = 0; trainDocId < trainDocCount; ++trainDocId) {
            for (ui32 testDocId = 0; testDocId < testDocCount; ++testDocId) {
                preprocessedImportances[testDocId][trainDocId] = rawImportances[trainDocId][testDocId];
            }
        }
    }

    // The sign filter does not depend on the row, so build it once.
    std::function<bool(double)> predicate;
    if (importanceValuesSign == EImportanceValuesSign::Positive) {
        predicate = [](double v){return v > 0;};
    } else if (importanceValuesSign == EImportanceValuesSign::Negative) {
        predicate = [](double v){return v < 0;};
    } else {
        Y_ASSERT(importanceValuesSign == EImportanceValuesSign::All);
        predicate = [](double){return true;};
    }

    TDStrResult result(preprocessedImportances.size());
    for (ui32 testDocId = 0; testDocId < preprocessedImportances.size(); ++testDocId) {
        TVector<double>& preprocessedImportancesRef = preprocessedImportances[testDocId];

        const ui32 docCount = preprocessedImportancesRef.size();
        TVector<ui32> indices(docCount);
        std::iota(indices.begin(), indices.end(), 0);
        if (docImpMethod != EDocumentStrengthType::Raw) {
            Sort(indices.begin(), indices.end(), [&](ui32 first, ui32 second) {
                return Abs(preprocessedImportancesRef[first]) > Abs(preprocessedImportancesRef[second]);
            });
        }

        int currentSize = 0;
        for (ui32 i = 0; i < docCount; ++i) {
            if (currentSize == topSize) {
                break;
            }
            const double importance = preprocessedImportancesRef[indices[i]];
            if (predicate(importance)) {
                result.Scores[testDocId].push_back(importance);
                result.Indices[testDocId].push_back(indices[i]);
                // Count only kept documents so that topSize bounds the size of the
                // returned list, not the number of inspected documents.
                ++currentSize;
            }
        }
    }
    return result;
}
Example #12
0
 // Returns the BorderCount-long slice of BucketData that belongs to element `i`
 // (slices are laid out back to back). `i` is checked against MaxElem in debug
 // builds only.
 inline TArrayRef<int> GetBorders(size_t i) {
     Y_ASSERT(i < MaxElem);
     auto* const sliceStart = BucketData + BorderCount * i;
     return TArrayRef<int>(sliceStart, BorderCount);
 }
Example #13
0
 // Returns a mutable reference to the total slot of element `i` in Data.
 // `i` is checked against MaxElem in debug builds only.
 inline int& GetTotal(size_t i) {
     Y_ASSERT(i < MaxElem);
     return Data[i];
 }
Example #14
0
// Computes online CTR feature values for every CTR description attached to
// projection `proj` and stores them in `dst`. dst->Feature gets one entry per
// description, each indexed [targetBorder][prior].
//
// Steps:
//  1. Hash the projection's feature combination into a thread-local buffer
//     with CalcHashes (using fold.EffectiveDocCount and fold.LearnPermutation).
//  2. Reindex the hashes with ReindexHash, using a thread-local TDenseHash as
//     scratch. The leaf limit is CtrLeafCountLimit, lifted to Max<ui64>() for
//     a single categorical feature when StoreAllSimpleCtrs is set.
//  3. For each CTR description, size and clear its output buffers, then
//     dispatch on CTR type / target class count to a CalcOnlineCTR* routine.
//
// `totalLeafCount` receives `.second` of ReindexHash's result, the same value
// passed as the leaf count to the CalcOnlineCTR* routines.
void ComputeOnlineCTRs(const TTrainData& data,
                       const TFold& fold,
                       const TProjection& proj,
                       TLearnContext* ctx,
                       TOnlineCTR* dst,
                       size_t* totalLeafCount) {

    const TCtrHelper& ctrHelper = ctx->CtrsHelper;
    const auto& ctrInfo = ctrHelper.GetCtrInfo(proj);
    dst->Feature.resize(ctrInfo.size());

    using THashArr = TVector<ui64>;
    using TRehashHash = TDenseHash<ui64, ui32>;
    // Thread-local scratch storage, kept alive across calls on the same thread.
    Y_STATIC_THREAD(THashArr) tlsHashArr;
    Y_STATIC_THREAD(TRehashHash) rehashHashTlsVal;
    TVector<ui64>& hashArr = tlsHashArr.Get();
    CalcHashes(proj, data.AllFeatures, fold.EffectiveDocCount, fold.LearnPermutation, &hashArr);
    // Reset the scratch hash before reuse; the argument is presumably a size hint — confirm.
    rehashHashTlsVal.Get().MakeEmpty(fold.LearnPermutation.size());
    ui64 topSize = ctx->Params.CatFeatureParams->CtrLeafCountLimit;
    // Simple (single cat feature) CTRs may be kept unlimited when StoreAllSimpleCtrs is set.
    if (proj.IsSingleCatFeature() && ctx->Params.CatFeatureParams->StoreAllSimpleCtrs) {
        topSize = Max<ui64>();
    }
    auto leafCount = ReindexHash(
        fold.LearnPermutation.size(),
        topSize,
        &hashArr,
        rehashHashTlsVal.GetPtr());
    *totalLeafCount = leafCount.second;

    for (int ctrIdx = 0; ctrIdx < dst->Feature.ysize(); ++ctrIdx) {
        const ECtrType ctrType = ctrInfo[ctrIdx].Type;
        const ui32 classifierId = ctrInfo[ctrIdx].TargetClassifierIdx;
        int targetClassesCount = fold.TargetClassesCount[classifierId];

        const ui32 targetBorderCount = GetTargetBorderCount(ctrInfo[ctrIdx], targetClassesCount);
        const ui32 ctrBorderCount = ctrInfo[ctrIdx].BorderCount;
        const auto& priors = ctrInfo[ctrIdx].Priors;
        dst->Feature[ctrIdx].SetSizes(priors.size(), targetBorderCount);

        // Prepare every [border][prior] output column for data.GetSampleCount() documents
        // (Clear presumably resizes and zeroes it — confirm).
        for (ui32 border = 0; border < targetBorderCount; ++border) {
            for (int prior = 0; prior < priors.ysize(); ++prior) {
                Clear(&dst->Feature[ctrIdx][border][prior], data.GetSampleCount());
            }
        }

        // Dispatch by CTR type; Borders goes to the "simple" path only when the
        // target has exactly SIMPLE_CLASSES_COUNT classes.
        if (ctrType == ECtrType::Borders && targetClassesCount == SIMPLE_CLASSES_COUNT) {
            CalcOnlineCTRSimple(
                data,
                hashArr,
                leafCount.second,
                fold.LearnTargetClass[classifierId],
                priors,
                ctrBorderCount,
                &dst->Feature[ctrIdx]);

        } else if (ctrType == ECtrType::BinarizedTargetMeanValue) {
            CalcOnlineCTRMean(
                data,
                hashArr,
                leafCount.second,
                fold.LearnTargetClass[classifierId],
                targetClassesCount - 1,
                priors,
                ctrBorderCount,
                &dst->Feature[ctrIdx]);

        } else if (ctrType == ECtrType::Buckets ||
                   (ctrType == ECtrType::Borders && targetClassesCount > SIMPLE_CLASSES_COUNT)) {
            // NOTE(review): the GetTargetBorderCount(...) argument below equals the
            // `targetBorderCount` local computed above.
            CalcOnlineCTRClasses(
                data,
                hashArr,
                leafCount.second,
                fold.LearnTargetClass[classifierId],
                targetClassesCount,
                GetTargetBorderCount(ctrInfo[ctrIdx], targetClassesCount),
                priors,
                ctrBorderCount,
                ctrType,
                &dst->Feature[ctrIdx]);
        } else {
            // Any remaining type must be Counter (checked in debug builds only).
            Y_ASSERT(ctrType == ECtrType::Counter);
            CalcOnlineCTRCounter(
                data,
                hashArr,
                leafCount.second,
                priors,
                ctrBorderCount,
                ctx->Params.CatFeatureParams->CounterCalcMethod,
                &dst->Feature[ctrIdx]);
        }
    }
}
Example #15
0
// Computes Counter CTR values (how many documents share a document's leaf) for
// every prior. Results are written to (*feature)[0][prior]; only target border 0
// is used.
//
// `enumeratedCatFeatures` holds one leaf index per document, and documents
// [LearnSampleCount, docCount) follow the learn documents. Leaf totals live in
// the buffer returned by TCtrCalcer::GetCtrArrTotal(leafCount).
// NOTE(review): that buffer is assumed to be zero-initialised scratch — confirm.
//
//  - ECounterCalc::Full: all documents [0, docCount) are counted first. Each
//    document then gets CalcCTR(total of its leaf, largest leaf total).
//  - ECounterCalc::SkipTest: only learn documents are counted. Test documents
//    are evaluated against the learn-only totals and denominator, which they
//    do not update.
//
// Within a range, every document sees the totals after the whole range has
// been counted, not a prefix-only count.
static void CalcOnlineCTRCounter(const TTrainData& data,
                                 const TVector<ui64>& enumeratedCatFeatures,
                                 size_t leafCount,
                                 const TVector<float>& priors,
                                 int ctrBorderCount,
                                 ECounterCalc counterCalc,
                                 TArray2D<TVector<ui8>>* feature) {
    const auto docCount = enumeratedCatFeatures.ysize();
    auto ctrArrTotal = TCtrCalcer::GetCtrArrTotal(leafCount);
    TVector<float> shift;
    TVector<float> norm;
    CalcNormalization(priors, &shift, &norm);
    Y_ASSERT(docCount >= data.LearnSampleCount);

    // Full: update leaf totals for the range before computing CTRs.
    // Skip: compute CTRs from the current totals without updating them.
    enum ECalcMode {
        Skip,
        Full
    };

    // Computes CTRs for documents [firstId, lastId). `*denominator` carries the
    // running maximum leaf total between calls.
    auto CalcCTRs = [](const TVector<ui64>& enumeratedCatFeatures,
                       const TVector<float>& priors,
                       const TVector<float>& shift,
                       const TVector<float>& norm,
                       int ctrBorderCount,
                       int firstId, int lastId, ECalcMode mode,
                       int* denominator,
                       int* ctrArrTotal,
                       TArray2D<TVector<ui8>>* feature) {
        int currentDenominator = *denominator;

        // Pass 1 (Full only): count every document of the range into its leaf
        // and track the largest leaf total as the common denominator.
        if (mode == ECalcMode::Full) {
            for (int docId = firstId; docId < lastId; ++docId) {
                const auto elemId = enumeratedCatFeatures[docId];
                ++ctrArrTotal[elemId];
                currentDenominator = Max(currentDenominator, ctrArrTotal[elemId]);
            }
        }

        // Pass 2: process the range in blocks of `blockSize` documents. Per block,
        // gather leaf totals and the denominator into small buffers, then compute
        // the CTR for every prior.
        const int blockSize = 1000;
        TVector<int> ctrTotal(blockSize);
        TVector<int> ctrDenominator(blockSize);
        for (int blockStart = firstId; blockStart < lastId; blockStart += blockSize) {
            const int blockEnd = Min(lastId, blockStart + blockSize);
            for (int docId = blockStart; docId < blockEnd; ++docId) {
                const auto elemId = enumeratedCatFeatures[docId];
                ctrTotal[docId - blockStart] = ctrArrTotal[elemId];
                ctrDenominator[docId - blockStart] = currentDenominator;
            }

            for (int prior = 0; prior < priors.ysize(); ++prior) {
                const float priorX = priors[prior];
                const float shiftX = shift[prior];
                const float normX = norm[prior];
                ui8* featureData = (*feature)[0][prior].data();
                for (int docId = blockStart; docId < blockEnd; ++docId) {
                    featureData[docId] = CalcCTR(ctrTotal[docId - blockStart], ctrDenominator[docId - blockStart], priorX, shiftX, normX, ctrBorderCount);
                }
            }
        }

        *denominator = currentDenominator;
    };

    int denominator = 0;
    if (counterCalc == ECounterCalc::Full) {
        // Count and evaluate learn and test documents together.
        CalcCTRs(enumeratedCatFeatures,
                 priors, shift, norm,
                 ctrBorderCount,
                 0, docCount, ECalcMode::Full,
                 &denominator,
                 ctrArrTotal,
                 feature);
    } else {
        Y_ASSERT(counterCalc == ECounterCalc::SkipTest);
        // Count learn documents only...
        CalcCTRs(enumeratedCatFeatures,
                 priors, shift, norm,
                 ctrBorderCount,
                 0, data.LearnSampleCount, ECalcMode::Full,
                 &denominator,
                 ctrArrTotal,
                 feature);
        // ...then evaluate test documents against the learn-only statistics.
        CalcCTRs(enumeratedCatFeatures,
                 priors, shift, norm,
                 ctrBorderCount,
                 data.LearnSampleCount, docCount, ECalcMode::Skip,
                 &denominator,
                 ctrArrTotal,
                 feature);
    }
}
// Greedily builds one oblivious tree structure (one split per depth) for `fold`
// and stores it in `resSplitTree`.
//
// For every depth up to ObliviousTreeOptions->MaxDepth:
//  1. Collect split candidates: float features, one-hot features, simple CTRs
//     and tree CTRs.
//  2. In parallel over candidate groups:
//     - compute missing online CTRs;
//     - score every candidate with CalcScore;
//     - keep, per candidate, the bin border whose score is best after random
//       perturbation with standard deviation `scoreStDev` (TRandomScore).
//  3. Pick the candidate with the highest randomly instantiated score. CTR
//     candidates whose (type, projection) has not been used yet are penalised
//     by the relative size of their CTR table (ModelSizeReg).
//  4. Apply the split to the document leaf indices and log it.
// The search stops early when no candidate has a valid score, or when the new
// split makes an existing one redundant. Profiling checkpoints go to `profile`.
void GreedyTensorSearch(const TTrainData& data,
                        const TVector<int>& splitCounts,
                        double modelLength,
                        TProfileInfo& profile,
                        TFold* fold,
                        TLearnContext* ctx,
                        TSplitTree* resSplitTree) {
    TSplitTree currentSplitTree;
    TrimOnlineCTRcache({fold});

    TVector<TIndexType> indices(data.LearnSampleCount);
    MATRIXNET_INFO_LOG << "\n";

    const bool useStatsFromPrevTree = AreStatsFromPrevTreeUsed(ctx->Params.ObliviousTreeOptions.Get());
    if (useStatsFromPrevTree) {
        AssignRandomWeights(data.LearnSampleCount, ctx, fold);
        ctx->StatsFromPrevTree.GarbageCollect();
    }

    for (ui32 curDepth = 0; curDepth < ctx->Params.ObliviousTreeOptions->MaxDepth; ++curDepth) {
        TCandidateList candList;
        AddFloatFeatures(data, ctx, &ctx->StatsFromPrevTree, &candList);
        AddOneHotFeatures(data, ctx, &ctx->StatsFromPrevTree, &candList);
        AddSimpleCtrs(data, fold, ctx, &ctx->StatsFromPrevTree, &candList);
        AddTreeCtrs(data, currentSplitTree, fold, ctx, &ctx->StatsFromPrevTree, &candList);

        // NOTE(review): despite its name, this returns true when the CTR for `proj`
        // has NOT been computed yet (its Feature buffer is empty). Kept as-is;
        // confirm the intended meaning against SelectCtrsToDropAfterCalc.
        auto IsInCache = [&fold](const TProjection& proj) -> bool {return fold->GetCtrRef(proj).Feature.empty();};
        SelectCtrsToDropAfterCalc(ctx->Params.SystemOptions->CpuUsedRamLimit, data.GetSampleCount(), ctx->Params.SystemOptions->NumThreads, IsInCache, &candList);

        CheckInterrupted(); // check after long-lasting operation
        if (!useStatsFromPrevTree) {
            AssignRandomWeights(data.LearnSampleCount, ctx, fold);
        }
        profile.AddOperation(TStringBuilder() << "AssignRandomWeights, depth " << curDepth);
        double scoreStDev = ctx->Params.ObliviousTreeOptions->RandomStrength * CalcScoreStDev(*fold) * CalcScoreStDevMult(data.LearnSampleCount, modelLength);

        TVector<size_t> candLeafCount(candList.ysize(), 1);
        // Per-group RNGs are seeded from one draw so results do not depend on scheduling.
        const ui64 randSeed = ctx->Rand.GenRand();
        ctx->LocalExecutor.ExecRange([&](int id) {
            auto& candidate = candList[id];
            if (candidate.Candidates[0].SplitCandidate.Type == ESplitType::OnlineCtr) {
                const auto& proj = candidate.Candidates[0].SplitCandidate.Ctr.Projection;
                if (fold->GetCtrRef(proj).Feature.empty()) {
                    ComputeOnlineCTRs(data,
                                      *fold,
                                      proj,
                                      ctx,
                                      &fold->GetCtrRef(proj),
                                      &candidate.ResultingCtrTableSize);
                    candLeafCount[id] = candidate.ResultingCtrTableSize;
                }
            }
            TVector<TVector<double>> allScores(candidate.Candidates.size());
            ctx->LocalExecutor.ExecRange([&](int oneCandidate) {
                if (candidate.Candidates[oneCandidate].SplitCandidate.Type == ESplitType::OnlineCtr) {
                    const auto& proj = candidate.Candidates[oneCandidate].SplitCandidate.Ctr.Projection;
                    Y_ASSERT(!fold->GetCtrRef(proj).Feature.empty());
                }
                allScores[oneCandidate] = CalcScore(data.AllFeatures,
                                                    splitCounts,
                                                    *fold,
                                                    indices,
                                                    ctx->ParamsUsedWithStatsFromPrevTree,
                                                    ctx->Params,
                                                    candidate.Candidates[oneCandidate].SplitCandidate,
                                                    currentSplitTree.GetDepth(),
                                                    &ctx->StatsFromPrevTree);
            }, NPar::TLocalExecutor::TExecRangeParams(0, candidate.Candidates.ysize())
             , NPar::TLocalExecutor::WAIT_COMPLETE);
            if (candidate.Candidates[0].SplitCandidate.Type == ESplitType::OnlineCtr && candidate.ShouldDropCtrAfterCalc) {
                fold->GetCtrRef(candidate.Candidates[0].SplitCandidate.Ctr.Projection).Feature.clear();
            }
            TFastRng64 rand(randSeed + id);
            rand.Advance(10); // reduce correlation between RNGs in different threads
            for (size_t i = 0; i < allScores.size(); ++i) {
                double bestScoreInstance = MINIMAL_SCORE;
                auto& splitInfo = candidate.Candidates[i];
                const auto& scores = allScores[i];
                for (int binFeatureIdx = 0; binFeatureIdx < scores.ysize(); ++binFeatureIdx) {
                    const double score = scores[binFeatureIdx];
                    const double scoreInstance = TRandomScore(score, scoreStDev).GetInstance(rand);
                    if (scoreInstance > bestScoreInstance) {
                        bestScoreInstance = scoreInstance;
                        splitInfo.BestScore = TRandomScore(score, scoreStDev);
                        splitInfo.BestBinBorderId = binFeatureIdx;
                    }
                }
            }
        }, 0, candList.ysize(), NPar::TLocalExecutor::WAIT_COMPLETE);
        size_t maxLeafCount = 1;
        for (size_t leafCount : candLeafCount) {
            maxLeafCount = Max(maxLeafCount, leafCount);
        }
        fold->DropEmptyCTRs();
        CheckInterrupted(); // check after long-lasting operation
        profile.AddOperation(TStringBuilder() << "Calc scores " << curDepth);

        const TCandidateInfo* bestSplitCandidate = nullptr;
        double bestScore = MINIMAL_SCORE;
        for (const auto& subList : candList) {
            for (const auto& candidate : subList.Candidates) {
                // Drawn for every candidate, in the same order as always, so the
                // ctx->Rand sequence does not depend on candidate types.
                double score = candidate.BestScore.GetInstance(ctx->Rand);
                // Model-size regularisation only applies to CTR splits. Only they
                // carry a CTR projection/index to look up in CtrsHelper (float and
                // one-hot candidates presumably leave Ctr default-constructed), and
                // ResultingCtrTableSize is only filled in for CTR groups above.
                if (candidate.SplitCandidate.Type == ESplitType::OnlineCtr && score != MINIMAL_SCORE) {
                    const TProjection& projection = candidate.SplitCandidate.Ctr.Projection;
                    const ECtrType ctrType = ctx->CtrsHelper.GetCtrInfo(projection)[candidate.SplitCandidate.Ctr.CtrIdx].Type;
                    if (!ctx->LearnProgress.UsedCtrSplits.has(std::make_pair(ctrType, projection))) {
                        score *= pow(1 / (1 + subList.ResultingCtrTableSize / static_cast<double>(maxLeafCount)), ctx->Params.ObliviousTreeOptions->ModelSizeReg.Get());
                    }
                }
                if (score > bestScore) {
                    bestScore = score;
                    bestSplitCandidate = &candidate;
                }
            }
        }
        if (bestScore == MINIMAL_SCORE) {
            break;
        }
        if (bestSplitCandidate->SplitCandidate.Type == ESplitType::OnlineCtr) {
            const TProjection& projection = bestSplitCandidate->SplitCandidate.Ctr.Projection;
            ECtrType ctrType = ctx->CtrsHelper.GetCtrInfo(projection)[bestSplitCandidate->SplitCandidate.Ctr.CtrIdx].Type;

            ctx->LearnProgress.UsedCtrSplits.insert(std::make_pair(ctrType, projection));
        }
        auto bestSplit = TSplit(bestSplitCandidate->SplitCandidate, bestSplitCandidate->BestBinBorderId);
        if (bestSplit.Type == ESplitType::OnlineCtr) {
            const auto& proj = bestSplit.Ctr.Projection;
            // The CTR may have been dropped after scoring to save memory; recompute it.
            if (fold->GetCtrRef(proj).Feature.empty()) {
                size_t totalLeafCount;
                ComputeOnlineCTRs(data,
                                  *fold,
                                  proj,
                                  ctx,
                                  &fold->GetCtrRef(proj),
                                  &totalLeafCount);
                DropStatsForProjection(*fold, *ctx, proj, &ctx->StatsFromPrevTree);
            }
        } else if (bestSplit.Type == ESplitType::OneHotFeature) {
            // Translate the bin index into the actual one-hot value.
            bestSplit.BinBorder = data.AllFeatures.OneHotValues[bestSplit.FeatureIdx][bestSplit.BinBorder];
        }
        SetPermutedIndices(bestSplit, data.AllFeatures, curDepth + 1, *fold, &indices, ctx);
        if (useStatsFromPrevTree) {
            ctx->ParamsUsedWithStatsFromPrevTree.SelectParametersForSmallestSplitSide(curDepth + 1, *fold, indices, &ctx->LocalExecutor);
        }
        currentSplitTree.AddSplit(bestSplit);
        MATRIXNET_INFO_LOG << BuildDescription(ctx->Layout, bestSplit);
        MATRIXNET_INFO_LOG << " score " << bestScore << "\n";


        profile.AddOperation(TStringBuilder() << "Select best split " << curDepth);

        int redundantIdx = GetRedundantSplitIdx(curDepth + 1, indices);
        if (redundantIdx != -1) {
            DeleteSplit(curDepth + 1, redundantIdx, &currentSplitTree, &indices);
            MATRIXNET_INFO_LOG << "  tensor " << redundantIdx << " is redundant, remove it and stop\n";
            break;
        }
    }
    *resSplitTree = std::move(currentSplitTree);
}