示例#1
0
void Approximator::validate(const Approximation& approx) {
    POMAGMA_ASSERT_EQ(approx.lower.item_dim(), m_item_dim);
    POMAGMA_ASSERT_EQ(approx.upper.item_dim(), m_item_dim);

    std::vector<Ob> set;
    for (auto iter = approx.lower.iter_insn(approx.upper); iter.ok();
         iter.next()) {
        set.push_back(*iter);
    }
    for (auto x : set) {
        for (auto y : set) {
            POMAGMA_ASSERT(not m_nless.find(x, y),
                           "approximation contains distinct obs: " << x << ", "
                                                                   << y);
        }
    }

    Approximation closed = unknown();
    closed = approx;
    {
        DenseSet temp_set(m_item_dim);
        close(closed, temp_set);
    }
    POMAGMA_ASSERT_EQ(closed.ob, approx.ob);
    POMAGMA_ASSERT(closed.upper == approx.upper, "upper set is not closed");
    POMAGMA_ASSERT(closed.lower == approx.lower, "lower set is not closed");
}
示例#2
0
inline void Approximator::map(const BinaryFunction& fun,
                              const DenseSet& lhs_set, const DenseSet& rhs_set,
                              DenseSet& val_set, DenseSet& temp_set) {
    POMAGMA_ASSERT_EQ(lhs_set.item_dim(), m_item_dim);
    POMAGMA_ASSERT_EQ(rhs_set.item_dim(), m_item_dim);
    POMAGMA_ASSERT_EQ(val_set.item_dim(), m_item_dim);
    POMAGMA_ASSERT_EQ(temp_set.item_dim(), m_item_dim);

    for (auto iter = lhs_set.iter(); iter.ok(); iter.next()) {
        Ob lhs = *iter;

        // optimize for special cases of APP and COMP
        if (Ob lhs_top = fun.find(lhs, m_top)) {
            if (Ob lhs_bot = fun.find(lhs, m_bot)) {
                bool lhs_is_constant = (lhs_top == lhs_bot);
                if (lhs_is_constant) {
                    val_set.raw_insert(lhs_top);
                    continue;
                }
            }
        }

        temp_set.set_insn(rhs_set, fun.get_Lx_set(lhs));
        for (auto iter = temp_set.iter(); iter.ok(); iter.next()) {
            Ob rhs = *iter;
            Ob val = fun.find(lhs, rhs);
            val_set.raw_insert(val);
        }
    }
}
示例#3
0
// Consistency check of the carrier: validates the support set, then checks
// that every supported ob has a nonzero rep no larger than itself, that every
// unsupported ob has rep 0, and that the cached item/rep counts match.
// Holds m_mutex (via UniqueLock) for the whole check.
void Carrier::validate () const
{
    UniqueLock lock(m_mutex);

    POMAGMA_INFO("Validating Carrier");

    m_support.validate();

    size_t actual_item_count = 0;
    size_t actual_rep_count = 0;
    for (Ob i = 1; i <= item_dim(); ++i) {
        Ob rep = m_reps[i].load();

        // Unsupported slots must carry no representative.
        if (not contains(i)) {
            POMAGMA_ASSERT(rep == 0, "unsupported object has rep: " << i);
            continue;
        }

        // Supported obs point at a rep that is no larger than themselves;
        // an ob that is its own rep counts as a rep.
        POMAGMA_ASSERT(rep, "supported object has no rep: " << i);
        POMAGMA_ASSERT(rep <= i, "rep out of order: " << rep << "," << i);
        ++actual_item_count;
        actual_rep_count += (rep == i) ? 1 : 0;
    }
    POMAGMA_ASSERT_EQ(item_count(), actual_item_count);
    POMAGMA_ASSERT_EQ(rep_count(), actual_rep_count);
}
示例#4
0
文件: router.cpp 项目: fritzo/pomagma
void Router::update_weights(
    const std::vector<float>& probs,
    const std::unordered_map<std::string, size_t>& symbol_counts,
    const std::unordered_map<Ob, size_t>& ob_counts,
    std::vector<float>& symbol_weights, std::vector<float>& ob_weights,
    float reltol) const {
    POMAGMA_INFO("Updating weights");
    const size_t symbol_count = m_types.size();
    const size_t ob_count = m_carrier.item_count();
    POMAGMA_ASSERT_EQ(probs.size(), 1 + ob_count);
    POMAGMA_ASSERT_EQ(symbol_weights.size(), symbol_count);
    POMAGMA_ASSERT_EQ(ob_weights.size(), 1 + ob_count);
    const float max_increase = 1.0 + reltol;

    std::vector<float> temp_symbol_weights(symbol_weights.size());
    std::vector<float> temp_ob_weights(ob_weights.size());

update_weights_loop : {
    POMAGMA_DEBUG("distributing route weight");

    std::fill(temp_symbol_weights.begin(), temp_symbol_weights.end(), 0);
    for (size_t i = 0; i < symbol_count; ++i) {
        temp_symbol_weights[i] = map_get(symbol_counts, m_types[i].name, 0);
    }

    std::fill(temp_ob_weights.begin(), temp_ob_weights.end(), 0);
    for (const auto& pair : ob_counts) {
        temp_ob_weights[pair.first] = pair.second;
    }

#pragma omp parallel for schedule(dynamic, 1)
    for (size_t i = 0; i < ob_count; ++i) {
        Ob ob = 1 + i;

        const float weight = ob_weights[ob] / probs[ob];
        for (const Segment& segment : iter_val(ob)) {
            float part = weight * get_prob(segment, probs);
            add_weight(part, segment, temp_symbol_weights, temp_ob_weights);
        }
    }

    std::swap(symbol_weights, temp_symbol_weights);
    std::swap(ob_weights, temp_ob_weights);

    for (size_t i = 0; i < symbol_count; ++i) {
        if (symbol_weights[i] > temp_symbol_weights[i] * max_increase) {
            goto update_weights_loop;
        }
    }

    for (size_t i = 0; i < ob_count; ++i) {
        Ob ob = 1 + i;
        if (ob_weights[ob] > temp_ob_weights[ob] * max_increase) {
            goto update_weights_loop;
        }
    }
}
}
示例#5
0
inline void Approximator::map(const InjectiveFunction& fun,
                              const DenseSet& key_set, DenseSet& val_set,
                              DenseSet& temp_set) {
    POMAGMA_ASSERT_EQ(key_set.item_dim(), m_item_dim);
    POMAGMA_ASSERT_EQ(val_set.item_dim(), m_item_dim);
    POMAGMA_ASSERT_EQ(temp_set.item_dim(), m_item_dim);

    temp_set.set_insn(fun.defined(), key_set);
    for (auto iter = temp_set.iter(); iter.ok(); iter.next()) {
        Ob key = *iter;
        Ob val = fun.find(key);
        val_set.raw_insert(val);
    }
}
示例#6
0
// Fill an empty carrier up to item_dim() obs, then remove each ob 1..size
// independently with probability 0.5.
inline void random_init (Carrier & carrier, rng_t & rng)
{
    POMAGMA_ASSERT_EQ(carrier.item_count(), 0);
    const size_t size = carrier.item_dim();

    // Insert `size` obs. NOTE(review): the insertion is performed inside the
    // assert condition, so this relies on POMAGMA_ASSERT always evaluating it.
    for (Ob count = 1; count <= size; ++count) {
        POMAGMA_ASSERT(carrier.unsafe_insert(), "insertion failed");
    }
    POMAGMA_ASSERT_EQ(carrier.item_count(), size);

    // Thin out the support with one fair coin flip per ob.
    std::bernoulli_distribution randomly_remove(0.5);
    for (Ob ob = 1; ob <= size; ++ob) {
        if (not randomly_remove(rng)) {
            continue;
        }
        carrier.unsafe_remove(ob);
    }
}
示例#7
0
文件: router.cpp 项目: fritzo/pomagma
// Recompute each ob's probability as the sum of get_prob() over the segments
// yielded by iter_val(ob). Sweeps repeat until no ob's probability grows by
// more than a factor of (1 + reltol) within a single sweep.
//
// probs:  in/out, indexed by Ob; size must be 1 + item_count. Entry 0 is not
//         updated here.
// reltol: relative growth tolerance that defines convergence.
void Router::update_probs(std::vector<float>& probs, float reltol) const {
    POMAGMA_INFO("Updating ob probs");
    const size_t item_count = m_carrier.item_count();
    POMAGMA_ASSERT_EQ(probs.size(), 1 + item_count);
    const float max_increase = 1.0 + reltol;

    bool changed = true;
    while (changed) {
        changed = false;

        POMAGMA_DEBUG("accumulating route probabilities");

        // `changed` is OR-reduced across threads. Each thread sets only its
        // private copy, so no thread writes the shared non-atomic bool
        // concurrently (which would be a data race).
        // NOTE(review): probs entries are still updated in place while other
        // threads may read them through get_prob(). The sweep presumably
        // tolerates mixing old and new values, but those float accesses are
        // unsynchronized.
#pragma omp parallel for schedule(dynamic, 1) reduction(||: changed)
        for (size_t i = 0; i < item_count; ++i) {
            Ob ob = 1 + i;
            float& prob = probs[ob];

            float temp_prob = 0;
            for (const Segment& segment : iter_val(ob)) {
                temp_prob += get_prob(segment, probs);
            }

            if (temp_prob > prob * max_increase) {
                changed = true;
            }

            prob = temp_prob;
        }
    }
}
示例#8
0
inline void Approximator::map(const SymmetricFunction& fun,
                              const DenseSet& lhs_set, const DenseSet& rhs_set,
                              DenseSet& val_set, DenseSet& temp_set) {
    POMAGMA_ASSERT_EQ(lhs_set.item_dim(), m_item_dim);
    POMAGMA_ASSERT_EQ(rhs_set.item_dim(), m_item_dim);
    POMAGMA_ASSERT_EQ(val_set.item_dim(), m_item_dim);
    POMAGMA_ASSERT_EQ(temp_set.item_dim(), m_item_dim);

    for (auto iter = lhs_set.iter(); iter.ok(); iter.next()) {
        Ob lhs = *iter;
        temp_set.set_insn(rhs_set, fun.get_Lx_set(lhs));
        for (auto iter = temp_set.iter(); iter.ok(); iter.next()) {
            Ob rhs = *iter;
            Ob val = fun.find(lhs, rhs);
            val_set.raw_insert(val);
        }
    }
}
示例#9
0
// Apply try_close() to `approx` until it reports that a pass changed nothing.
// temp_set is scratch space passed through to try_close().
// NOTE(review): there is no iteration cap; termination relies on try_close()
// eventually returning true.
void Approximator::close(Approximation& approx, DenseSet& temp_set) {
    POMAGMA_ASSERT_EQ(temp_set.item_dim(), m_item_dim);
    size_t step = 0;
    bool done = false;
    while (not done) {
        POMAGMA_DEBUG1("close step " << step);
        done = try_close(approx, temp_set);
        ++step;
    }
}
示例#10
0
// Check that `other` has exactly the same support as this relation, and that
// for every supported row the two relations' rows share no element.
void BinaryRelation::validate_disjoint(const BinaryRelation& other) const {
    POMAGMA_INFO("Validating disjoint pair of BinaryRelations");

    // Supports must agree: same dimension, same size, same members.
    POMAGMA_ASSERT_EQ(support().item_dim(), other.support().item_dim());
    POMAGMA_ASSERT_EQ(support().count_items(), other.support().count_items());
    POMAGMA_ASSERT(support() == other.support(),
                   "BinaryRelation supports differ");

    // Two DenseSets built with a null buffer are re-initialised via init()
    // on each row's line data, then compared for disjointness.
    DenseSet this_set(item_dim(), nullptr);
    DenseSet other_set(item_dim(), nullptr);
    auto row = support().iter();
    while (row.ok()) {
        this_set.init(m_lines.Lx(*row));
        other_set.init(other.m_lines.Lx(*row));
        POMAGMA_ASSERT(this_set.disjoint(other_set),
                       "BinaryRelations intersect at row " << *row);
        row.next();
    }
}
示例#11
0
文件: syntax.cpp 项目: fritzo/pomagma
void Structure::assert_valid() {
    POMAGMA_INFO("Validating solver::Structure");

    // Check atoms.
    POMAGMA_ASSERT(term_arity(TermAtom::TOP) == TermArity::TOP, "Missing TOP");
    POMAGMA_ASSERT(term_arity(TermAtom::BOT) == TermArity::BOT, "Missing BOT");
    POMAGMA_ASSERT(term_arity(TermAtom::I) == TermArity::I, "Missing I");
    POMAGMA_ASSERT(term_arity(TermAtom::K) == TermArity::K, "Missing K");
    POMAGMA_ASSERT(term_arity(TermAtom::B) == TermArity::B, "Missing B");
    POMAGMA_ASSERT(term_arity(TermAtom::C) == TermArity::C, "Missing C");
    POMAGMA_ASSERT(term_arity(TermAtom::S) == TermArity::S, "Missing S");

    // Check terms.
    const Term max_term = term_arity_.size() - 1;
    for (Term term = 1; term <= max_term; ++term) {
        const TermArity arity = term_arity(term);
        switch (arity) {
            case TermArity::IVAR: {
                const unsigned rank = ivar_arg(term);
                POMAGMA_ASSERT_EQ(term, ivar(rank));
                break;
            }
            case TermArity::NVAR: {
                const std::string& name = nvar_arg(term);
                POMAGMA_ASSERT_EQ(term, nvar(name));
                break;
            }
            case TermArity::APP: {
                Term lhs;
                Term rhs;
                std::tie(lhs, rhs) = app_arg(term);
                POMAGMA_ASSERT_EQ(term, app(lhs, rhs));
                break;
            }
            case TermArity::JOIN: {
                Term lhs;
                Term rhs;
                std::tie(lhs, rhs) = join_arg(term);
                POMAGMA_ASSERT_EQ(term, join(lhs, rhs));
                break;
            }
            default:
                break;
        }
    }

    // Check literals.
    const Literal max_lit = less_arg_.size() - 1;
    for (Literal lit = 1; lit <= max_lit; ++lit) {
        Term lhs;
        Term rhs;

        std::tie(lhs, rhs) = literal_arg(lit);
        POMAGMA_ASSERT_EQ(lit, less(lhs, rhs));

        std::tie(lhs, rhs) = literal_arg(-lit);
        POMAGMA_ASSERT_EQ(-lit, nless(lhs, rhs));
    }
}
示例#12
0
// Randomly merge pairs of obs in the carrier and check that g_merge_count
// (reset to 0 here, presumably bumped by a merge callback elsewhere) matches
// the number of merges performed.
void test_merge (Carrier & carrier, rng_t & rng)
{
    POMAGMA_INFO("Checking unsafe_merge");
    const DenseSet & support = carrier.support();
    size_t merge_count = 0;
    g_merge_count = 0;
    std::bernoulli_distribution randomly_merge(0.1);
    for (auto outer = support.iter(); outer.ok(); outer.next()) {
        // At most one merge per outer ob: stop scanning after the first one.
        for (auto inner = support.iter(); inner.ok(); inner.next()) {
            Ob dep = carrier.find(*inner);
            Ob rep = carrier.find(*outer);
            // The coin is only flipped for pairs with rep < dep.
            if (not (rep < dep)) {
                continue;
            }
            if (not randomly_merge(rng)) {
                continue;
            }
            carrier.merge(dep, rep);
            ++merge_count;
            break;
        }
    }
    POMAGMA_ASSERT_EQ(merge_count, g_merge_count);
}
示例#13
0
// Repeatedly sweep the support: every ob whose carrier.find() differs from
// itself is merged in `fun` and removed from the carrier. Sweeping stops once
// a full pass removes nothing. Then fun's values are refreshed and validated.
void remove_deps (Carrier & carrier, Function & fun)
{
    POMAGMA_INFO("Merging deps");
    const DenseSet & support = carrier.support();
    for (bool merged = true; merged;) {
        merged = false;
        for (auto it = support.iter(); it.ok(); it.next()) {
            Ob dep = *it;
            if (carrier.find(dep) == dep) {
                continue;  // already its own rep
            }
            fun.unsafe_merge(dep);
            carrier.unsafe_remove(dep);
            merged = true;
        }
    }
    fun.update_values();
    POMAGMA_ASSERT_EQ(carrier.rep_count(), carrier.item_count());
    fun.validate();
}
示例#14
0
文件: router.cpp 项目: fritzo/pomagma
// Alternately update ob probabilities and route weights, then renormalise the
// symbol weights into new symbol probabilities. Each probability is stored in
// both m_types[i].prob and m_language[name]. Repeats until no symbol
// probability grows by more than a factor of (1 + reltol) in one round.
//
// NOTE(review): if every symbol weight is 0, total_weight is 0 and the new
// probabilities become NaN; confirm callers guarantee nonzero counts.
void Router::fit_language(
    const std::unordered_map<std::string, size_t>& symbol_counts,
    const std::unordered_map<Ob, size_t>& ob_counts, float reltol) {
    POMAGMA_INFO("Fitting language");
    const size_t item_count = m_carrier.item_count();
    std::vector<float> ob_probs(1 + item_count, 0);
    std::vector<float> ob_weights(1 + item_count, 0);
    std::vector<float> symbol_weights(m_types.size(), 0);
    POMAGMA_ASSERT_EQ(m_types.size(), m_language.size());
    const float max_increase = 1.0 + reltol;

    bool changed;
    do {
        changed = false;

        update_probs(ob_probs, reltol);
        update_weights(ob_probs, symbol_counts, ob_counts, symbol_weights,
                       ob_weights, reltol);

        POMAGMA_DEBUG("optimizing language");
        float total_weight = 0;
        for (size_t s = 0; s < symbol_weights.size(); ++s) {
            total_weight += symbol_weights[s];
        }

        // Normalise weights into probabilities and detect large increases.
        for (size_t s = 0; s < m_types.size(); ++s) {
            SegmentType& type = m_types[s];
            const float old_prob = type.prob;
            const float new_prob = symbol_weights[s] / total_weight;
            type.prob = new_prob;
            m_language[type.name] = new_prob;
            changed = changed or (new_prob > old_prob * max_increase);
        }
    } while (changed);
}
示例#15
0
// Inference rules, in order of appearance
//
//                LESS x y   LESS x z
//   ----------   -------------------
//   LESS x TOP     LESS x RAND y z
//
//                LESS y x   LESS z x   LESS y x   LESS z x
//   ----------   -------------------   -------------------
//   LESS BOT x     LESS JOIN y z x       LESS RAND y z x
//
// One closure pass over `approx`. It grows approx.upper and approx.lower using
// the rules above plus the rows of m_less. If approx.ob is unset, it is filled
// from an ob common to both sets (set_insn presumably computes an
// intersection).
// Returns true iff the pass left `approx` unchanged, i.e. it was already
// closed; Approximator::close() calls this until that happens.
// temp_set is scratch space and is overwritten.
bool Approximator::try_close(Approximation& approx, DenseSet& temp_set) {
    POMAGMA_ASSERT_EQ(approx.lower.item_dim(), m_item_dim);
    POMAGMA_ASSERT_EQ(approx.upper.item_dim(), m_item_dim);
    POMAGMA_ASSERT_EQ(temp_set.item_dim(), m_item_dim);

    // Snapshot the input so the return value can report whether this pass
    // changed anything.
    Approximation start = unknown();
    start = approx;

    // A known ob is inserted into both the upper and the lower set.
    if (approx.ob) {
        approx.upper.raw_insert(approx.ob);
        approx.lower.raw_insert(approx.ob);
    }

    // Upper set. TOP is always an upper bound (rule LESS x TOP).
    // NOTE(review): approx.upper is extended while it is being iterated; obs
    // inserted during this loop may not be visited until a later pass.
    approx.upper.raw_insert(m_top);
    for (auto iter = approx.upper.iter(); iter.ok(); iter.next()) {
        Ob ob = *iter;
        // Add m_less.get_Lx_set(ob), presumably every ob above `ob`.
        approx.upper += m_less.get_Lx_set(ob);

        // Rule LESS x RAND y z: for upper bounds ob and other on which RAND is
        // defined, RAND ob other is also an upper bound.
        if (m_rand) {
            temp_set.set_insn(approx.upper, m_rand->get_Lx_set(ob));
            // (This inner `iter` shadows the outer one.)
            for (auto iter = temp_set.iter(); iter.ok(); iter.next()) {
                Ob other = *iter;
                // Only pairs with other < ob are handled here. Assuming
                // ascending iteration and a symmetric m_rand, each unordered
                // pair is handled once, when its larger element is `ob`.
                if (other >= ob) {
                    break;
                }
                Ob val = m_rand->find(ob, other);
                approx.upper.raw_insert(val);
            }
        }
    }

    // Lower set. BOT is always a lower bound (rule LESS BOT x).
    // NOTE(review): as above, approx.lower grows while it is being iterated.
    approx.lower.raw_insert(m_bot);
    for (auto iter = approx.lower.iter(); iter.ok(); iter.next()) {
        Ob ob = *iter;
        // Add m_less.get_Rx_set(ob), presumably every ob below `ob`.
        approx.lower += m_less.get_Rx_set(ob);

        // Rule LESS JOIN y z x: the JOIN of two lower bounds is a lower bound.
        if (m_join) {
            temp_set.set_insn(approx.lower, m_join->get_Lx_set(ob));
            for (auto iter = temp_set.iter(); iter.ok(); iter.next()) {
                Ob other = *iter;
                // Same other < ob pairing as in the upper-set RAND rule.
                if (other >= ob) {
                    break;
                }
                Ob val = m_join->find(ob, other);
                approx.lower.raw_insert(val);
            }
        }

        // Rule LESS RAND y z x: the RAND of two lower bounds is a lower bound.
        if (m_rand) {
            temp_set.set_insn(approx.lower, m_rand->get_Lx_set(ob));
            for (auto iter = temp_set.iter(); iter.ok(); iter.next()) {
                Ob other = *iter;
                if (other >= ob) {
                    break;
                }
                Ob val = m_rand->find(ob, other);
                approx.lower.raw_insert(val);
            }
        }
    }

    // If no ob is known yet, adopt the first ob found in both the upper and
    // the lower set.
    if (not approx.ob) {
        temp_set.set_insn(approx.upper, approx.lower);
        for (auto iter = temp_set.iter(); iter.ok(); iter.next()) {
            approx.ob = *iter;
            break;
        }
    }

    // Fixed point reached iff nothing changed during this pass.
    return approx == start;
}