static inline bool IsConstCatValue(int featureIdx, const TDocumentStorage& docStorage, const TDocSelector& docSelector) { size_t docCount = docSelector.GetDocCount(); if (docCount == 0) { return true; } const TVector<float>& src = docStorage.Factors[featureIdx]; int src0 = ConvertFloatCatFeatureToIntHash(src[docSelector(0)]); for (size_t i = 1; i < docCount; ++i) { if (ConvertFloatCatFeatureToIntHash(src[docSelector(i)]) != src0) { return false; } } return true; }
void CalcFinalCtrs(const ECtrType ctrType, const TFeatureCombination& projection, const TPool& pool, ui64 sampleCount, const TVector<int>& permutedTargetClass, const TVector<float>& permutedTargets, int targetClassesCount, ui64 ctrLeafCountLimit, bool storeAllSimpleCtr, TCtrValueTable* result) { TVector<ui64> hashArr; CalcHashes( projection, [&pool] (int floatFeatureIdx, size_t docId) -> float { return pool.Docs.Factors[floatFeatureIdx][docId]; }, [&pool] (int floatFeatureIdx, size_t docId) -> int { return ConvertFloatCatFeatureToIntHash(pool.Docs.Factors[floatFeatureIdx][docId]); }, sampleCount, &hashArr); if (projection.IsSingleCatFeature() && storeAllSimpleCtr) { ctrLeafCountLimit = Max<ui64>(); } CalcFinalCtrsImpl(ctrType, ctrLeafCountLimit, permutedTargetClass, permutedTargets, sampleCount, targetClassesCount, &hashArr, result); }
static inline void BinarizeCatFeature(int featureIdx, const TDocumentStorage& docStorage, const TDocSelector& docSelector, int catFeatureIdx, TAllFeatures* features) { size_t docCount = docSelector.GetDocCount(); const TVector<float>& src = docStorage.Factors[featureIdx]; TVector<int>& dstRemapped = features->CatFeaturesRemapped[catFeatureIdx]; TVector<int>& dstValues = features->OneHotValues[catFeatureIdx]; bool dstIsOneHot = features->IsOneHot[catFeatureIdx]; dstRemapped.resize(docCount); using TCatFeaturesRemap = THashMap<int, int>; TCatFeaturesRemap uniqueFeaturesRemap; if (dstValues.empty()) { // Processing learn data for (size_t i = 0; i < docCount; ++i) { const auto val = ConvertFloatCatFeatureToIntHash(src[docSelector(i)]); TCatFeaturesRemap::insert_ctx ctx = nullptr; TCatFeaturesRemap::iterator it = uniqueFeaturesRemap.find(val, ctx); if (it == uniqueFeaturesRemap.end()) { it = uniqueFeaturesRemap.emplace_direct(ctx, val, (int)uniqueFeaturesRemap.size()); } dstRemapped[i] = it->second; } dstValues.resize(uniqueFeaturesRemap.size()); for (const auto& kv : uniqueFeaturesRemap) { dstValues[kv.second] = kv.first; } // Cases `dstValues.size() == 1` and `> oneHotMaxSize` are up to the caller. } else { for (size_t i = 0; i < dstValues.size(); ++i) { uniqueFeaturesRemap.emplace(dstValues[i], static_cast<int>(i)); } if (dstIsOneHot) { for (size_t i = 0; i < docCount; ++i) { const auto val = ConvertFloatCatFeatureToIntHash(src[docSelector(i)]); TCatFeaturesRemap::iterator it = uniqueFeaturesRemap.find(val); if (it == uniqueFeaturesRemap.end()) { dstRemapped[i] = static_cast<int>(uniqueFeaturesRemap.size()); } else { dstRemapped[i] = it->second; } } } else { for (size_t i = 0; i < docCount; ++i) { const auto val = ConvertFloatCatFeatureToIntHash(src[docSelector(i)]); TCatFeaturesRemap::insert_ctx ctx = nullptr; TCatFeaturesRemap::iterator it = uniqueFeaturesRemap.find(val, ctx); if (it == uniqueFeaturesRemap.end()) { int remap = static_cast<int>(uniqueFeaturesRemap.size()); dstValues.push_back(val); it = uniqueFeaturesRemap.emplace_direct(ctx, val, remap); dstRemapped[i] = it->second; } else { dstRemapped[i] = it->second; } } } } }