static inline bool IsConstCatValue(int featureIdx, const TDocumentStorage& docStorage, const TDocSelector& docSelector) {
    size_t docCount = docSelector.GetDocCount();
    if (docCount == 0) {
        return true;
    }
    const TVector<float>& src = docStorage.Factors[featureIdx];
    int src0 = ConvertFloatCatFeatureToIntHash(src[docSelector(0)]);
    for (size_t i = 1; i < docCount; ++i) {
        if (ConvertFloatCatFeatureToIntHash(src[docSelector(i)]) != src0) {
            return false;
        }
    }
    return true;
}
// Example no. 2
// 0
void CalcFinalCtrs(const ECtrType ctrType, const TFeatureCombination& projection, const TPool& pool, ui64 sampleCount,
                   const TVector<int>& permutedTargetClass, const TVector<float>& permutedTargets,
                   int targetClassesCount, ui64 ctrLeafCountLimit, bool storeAllSimpleCtr, TCtrValueTable* result) {
    TVector<ui64> hashArr;
    CalcHashes(
        projection,
        [&pool] (int floatFeatureIdx, size_t docId) -> float {
            return pool.Docs.Factors[floatFeatureIdx][docId];
        },
        [&pool] (int floatFeatureIdx, size_t docId) -> int {
            return ConvertFloatCatFeatureToIntHash(pool.Docs.Factors[floatFeatureIdx][docId]);
        },
        sampleCount,
        &hashArr);

    if (projection.IsSingleCatFeature() && storeAllSimpleCtr) {
        ctrLeafCountLimit = Max<ui64>();
    }
    CalcFinalCtrsImpl(ctrType, ctrLeafCountLimit, permutedTargetClass, permutedTargets, sampleCount, targetClassesCount, &hashArr, result);
}
static inline void BinarizeCatFeature(int featureIdx,
                                      const TDocumentStorage& docStorage,
                                      const TDocSelector& docSelector,
                                      int catFeatureIdx,
                                      TAllFeatures* features) {
    size_t docCount = docSelector.GetDocCount();
    const TVector<float>& src = docStorage.Factors[featureIdx];
    TVector<int>& dstRemapped = features->CatFeaturesRemapped[catFeatureIdx];
    TVector<int>& dstValues = features->OneHotValues[catFeatureIdx];
    bool dstIsOneHot = features->IsOneHot[catFeatureIdx];

    dstRemapped.resize(docCount);

    using TCatFeaturesRemap = THashMap<int, int>;
    TCatFeaturesRemap uniqueFeaturesRemap;
    if (dstValues.empty()) {
        // Processing learn data
        for (size_t i = 0; i < docCount; ++i) {
            const auto val = ConvertFloatCatFeatureToIntHash(src[docSelector(i)]);
            TCatFeaturesRemap::insert_ctx ctx = nullptr;
            TCatFeaturesRemap::iterator it = uniqueFeaturesRemap.find(val, ctx);
            if (it == uniqueFeaturesRemap.end()) {
                it = uniqueFeaturesRemap.emplace_direct(ctx, val, (int)uniqueFeaturesRemap.size());
            }
            dstRemapped[i] = it->second;
        }
        dstValues.resize(uniqueFeaturesRemap.size());
        for (const auto& kv : uniqueFeaturesRemap) {
            dstValues[kv.second] = kv.first;
        }
        // Cases `dstValues.size() == 1` and `> oneHotMaxSize` are up to the caller.
    } else {
        for (size_t i = 0; i < dstValues.size(); ++i) {
            uniqueFeaturesRemap.emplace(dstValues[i], static_cast<int>(i));
        }
        if (dstIsOneHot) {
            for (size_t i = 0; i < docCount; ++i) {
                const auto val = ConvertFloatCatFeatureToIntHash(src[docSelector(i)]);
                TCatFeaturesRemap::iterator it = uniqueFeaturesRemap.find(val);
                if (it == uniqueFeaturesRemap.end()) {
                    dstRemapped[i] = static_cast<int>(uniqueFeaturesRemap.size());
                } else {
                    dstRemapped[i] = it->second;
                }
            }
        } else {
            for (size_t i = 0; i < docCount; ++i) {
                const auto val = ConvertFloatCatFeatureToIntHash(src[docSelector(i)]);
                TCatFeaturesRemap::insert_ctx ctx = nullptr;
                TCatFeaturesRemap::iterator it = uniqueFeaturesRemap.find(val, ctx);
                if (it == uniqueFeaturesRemap.end()) {
                    int remap = static_cast<int>(uniqueFeaturesRemap.size());
                    dstValues.push_back(val);
                    it = uniqueFeaturesRemap.emplace_direct(ctx, val, remap);
                    dstRemapped[i] = it->second;
                } else {
                    dstRemapped[i] = it->second;
                }
            }
        }
    }
}