示例#1
0
/// Debug-only consistency check of this proposer against `cross_cat`.
///
/// Does nothing unless LOOM_DEBUG_LEVEL >= 1. Otherwise it asserts that:
///   - there is exactly one proposer kind per CrossCat kind;
///   - every proposer kind's model schema equals the CrossCat's full
///     schema, and its mixture validates against that model;
///   - the k-th proposer kind's clustering has as many groups as the
///     k-th CrossCat kind's clustering.
inline void KindProposer::validate (const CrossCat & cross_cat) const
{
    if (LOOM_DEBUG_LEVEL < 1) {
        return;
    }

    LOOM_ASSERT_EQ(kinds.size(), cross_cat.kinds.size());

    // Every proposer kind is checked against the *full* CrossCat schema,
    // not a per-kind sub-schema.
    for (const auto & kind : kinds) {
        LOOM_ASSERT_EQ(kind.model.schema, cross_cat.schema);
        kind.mixture.validate(kind.model);
    }

    // Group counts must agree kind-by-kind (paired by index).
    const size_t kind_count = kinds.size();
    for (size_t kindid = 0; kindid != kind_count; ++kindid) {
        const size_t proposer_group_count =
            kinds[kindid].mixture.clustering.counts().size();
        const size_t cross_cat_group_count =
            cross_cat.kinds[kindid].mixture.clustering.counts().size();
        LOOM_ASSERT_EQ(proposer_group_count, cross_cat_group_count);
    }
}
示例#2
0
/// Pass each kind's partial diff to that kind's schema `simplify()`.
///
/// `partial_diffs` is expected to hold one diff per kind, in kind order.
/// That size match is only asserted when LOOM_DEBUG_LEVEL >= 1, so in
/// release builds the caller must uphold it: the loop below dereferences
/// one diff per kind without a bounds check.
///
/// The work is gated on LOOM_SIMPLIFY_DURING_INFERENCE, which is enabled
/// here by default. It is defined only if not already defined, so a build
/// that passes -DLOOM_SIMPLIFY_DURING_INFERENCE (which defines it as 1)
/// no longer hits a non-identical macro redefinition. The definition stays
/// visible to all code that follows in this translation unit.
inline void CrossCat::simplify (
        std::vector<ProductValue::Diff> & partial_diffs) const
{
    if (LOOM_DEBUG_LEVEL >= 1) {
        LOOM_ASSERT_EQ(partial_diffs.size(), kinds.size());
    }
#ifndef LOOM_SIMPLIFY_DURING_INFERENCE
#define LOOM_SIMPLIFY_DURING_INFERENCE
#endif // LOOM_SIMPLIFY_DURING_INFERENCE
#ifdef LOOM_SIMPLIFY_DURING_INFERENCE
    // Diffs and kinds are paired positionally.
    auto diff = partial_diffs.begin();
    for (const auto & kind : kinds) {
        kind.model.schema.simplify(*diff++);
    }
#endif // LOOM_SIMPLIFY_DURING_INFERENCE
}
示例#3
0
inline void CrossCat::validate () const
{
    if (LOOM_DEBUG_LEVEL >= 1) {
        LOOM_ASSERT_LT(0, schema.total_size());
        ValueSchema expected_schema;
        for (const auto & kind : kinds) {
            kind.model.validate();
            kind.mixture.validate(kind.model);
            expected_schema += kind.model.schema;
        }
        LOOM_ASSERT_EQ(schema, expected_schema);
        for (auto & tare : tares) {
            schema.validate(tare);
        }
    }
    if (LOOM_DEBUG_LEVEL >= 2) {
        splitter.validate(schema, featureid_to_kindid, kinds.size());
        for (size_t f = 0; f < featureid_to_kindid.size(); ++f) {
            size_t k = featureid_to_kindid[f];
            const auto & featureids = kinds[k].featureids;
            LOOM_ASSERT(
                featureids.find(f) != featureids.end(),
                "kind.featureids is missing " << f);
        }
        for (size_t k = 0; k < kinds.size(); ++k) {
            for (size_t f : kinds[k].featureids) {
                LOOM_ASSERT_EQ(featureid_to_kindid[f], k);
            }
        }
        for (size_t k = 0; k < kinds.size(); ++k) {
            LOOM_ASSERT_EQ(kinds[k].model.tares.size(), tares.size());
        }
    }
    if (LOOM_DEBUG_LEVEL >= 3) {
        std::vector<size_t> row_counts;
        for (const auto & kind : kinds) {
            row_counts.push_back(kind.mixture.count_rows());
        }
        for (size_t k = 1; k < kinds.size(); ++k) {
            LOOM_ASSERT_EQ(row_counts[k], row_counts[0]);
            LOOM_ASSERT_EQ(
                kinds[k].mixture.maintaining_cache,
                kinds[0].mixture.maintaining_cache);
        }
        std::vector<ProductValue> partial_tares;
        for (size_t id = 0; id < tares.size(); ++id) {
            splitter.split(tares[id], partial_tares);
            for (size_t k = 0; k < kinds.size(); ++k) {
                LOOM_ASSERT_EQ(partial_tares[k], kinds[k].model.tares[id]);
            }
        }
    }
}