/* Visit a let-expression.

   If m_cfg.m_zeta is set, the let is eliminated: the body is instantiated with the
   value and the result is visited.

   Otherwise each nested let is opened with a temporary let-local. The type, value and
   final body are visited, and the result is rebuilt with locals.mk_lambda only when
   some visited component is not pointer-equal to its input; otherwise `e` itself is
   returned. */
expr dsimplify_core_fn::visit_let(expr const & e) {
    if (m_cfg.m_zeta)
        return visit(instantiate(let_body(e), let_value(e)));

    type_context::tmp_locals locals(m_ctx);
    bool changed = false;
    expr curr    = e;
    for (; is_let(curr); curr = let_body(curr)) {
        // Replace the loose bound variables with the locals created so far.
        expr type      = instantiate_rev(let_type(curr), locals.size(), locals.data());
        expr value     = instantiate_rev(let_value(curr), locals.size(), locals.data());
        expr new_type  = visit(type);
        expr new_value = visit(value);
        if (!(is_eqp(type, new_type) && is_eqp(value, new_value)))
            changed = true;
        locals.push_let(let_name(curr), new_type, new_value);
    }

    expr body     = instantiate_rev(curr, locals.size(), locals.data());
    expr new_body = visit(body);
    if (!is_eqp(body, new_body))
        changed = true;

    if (!changed)
        return e;
    return locals.mk_lambda(new_body);
}
/* Structural equality test for `a` and `b`.

   Fast paths: pointer equality succeeds immediately; different hash codes or kinds fail
   immediately; bound variables are compared by index. A pair accepted by m_cache.check
   is treated as equal. Names of binders/locals and binder annotations are compared only
   when CompareBinderInfo is true. */
bool apply(expr const & a, expr const & b) {
    if (is_eqp(a, b))
        return true;
    if (a.hash() != b.hash() || a.kind() != b.kind())
        return false;
    if (is_var(a))
        return var_idx(a) == var_idx(b);
    if (m_cache.check(a, b))
        return true;

    switch (a.kind()) {
    case expr_kind::Var:
        lean_unreachable(); // LCOV_EXCL_LINE
    case expr_kind::Constant:
        if (!(const_name(a) == const_name(b)))
            return false;
        return compare(const_levels(a), const_levels(b),
                       [](level const & l1, level const & l2) { return l1 == l2; });
    case expr_kind::Meta:
        if (!(mlocal_name(a) == mlocal_name(b)))
            return false;
        return apply(mlocal_type(a), mlocal_type(b));
    case expr_kind::Local:
        if (!(mlocal_name(a) == mlocal_name(b)))
            return false;
        if (!apply(mlocal_type(a), mlocal_type(b)))
            return false;
        if (CompareBinderInfo && !(local_pp_name(a) == local_pp_name(b)))
            return false;
        return !CompareBinderInfo || local_info(a) == local_info(b);
    case expr_kind::App:
        check_system();
        if (!apply(app_fn(a), app_fn(b)))
            return false;
        return apply(app_arg(a), app_arg(b));
    case expr_kind::Lambda:
    case expr_kind::Pi:
        check_system();
        if (!apply(binding_domain(a), binding_domain(b)))
            return false;
        if (!apply(binding_body(a), binding_body(b)))
            return false;
        if (CompareBinderInfo && !(binding_name(a) == binding_name(b)))
            return false;
        return !CompareBinderInfo || binding_info(a) == binding_info(b);
    case expr_kind::Let:
        check_system();
        if (!apply(let_type(a), let_type(b)))
            return false;
        if (!apply(let_value(a), let_value(b)))
            return false;
        if (!apply(let_body(a), let_body(b)))
            return false;
        return !CompareBinderInfo || let_name(a) == let_name(b);
    case expr_kind::Sort:
        return sort_level(a) == sort_level(b);
    case expr_kind::Macro:
        check_system();
        if (macro_def(a) != macro_def(b) || macro_num_args(a) != macro_num_args(b))
            return false;
        for (unsigned i = 0; i < macro_num_args(a); i++) {
            if (!apply(macro_arg(a, i), macro_arg(b, i)))
                return false;
        }
        return true;
    }
    lean_unreachable(); // LCOV_EXCL_LINE
}
/* Strict ordering on expressions; returns true when `a` is smaller than `b`.

   Comparison keys, in order: weight (get_weight), expression kind, hash code (only when
   `use_hash` is true), and finally a component-wise structural comparison.
   Pointer-equal or structurally equal expressions are never smaller than each other. */
bool is_lt(expr const & a, expr const & b, bool use_hash) {
    if (is_eqp(a, b))
        return false;
    unsigned const weight_a = get_weight(a);
    unsigned const weight_b = get_weight(b);
    if (weight_a != weight_b)
        return weight_a < weight_b;
    if (a.kind() != b.kind())
        return a.kind() < b.kind();
    if (use_hash && a.hash() != b.hash())
        return a.hash() < b.hash();
    if (a == b)
        return false;

    // Same weight and kind, but structurally different: the first differing component decides.
    switch (a.kind()) {
    case expr_kind::Var:
        return var_idx(a) < var_idx(b);
    case expr_kind::Constant:
        if (const_name(a) != const_name(b))
            return const_name(a) < const_name(b);
        return is_lt(const_levels(a), const_levels(b), use_hash);
    case expr_kind::App:
        if (app_fn(a) != app_fn(b))
            return is_lt(app_fn(a), app_fn(b), use_hash);
        return is_lt(app_arg(a), app_arg(b), use_hash);
    case expr_kind::Lambda:
    case expr_kind::Pi:
        if (binding_domain(a) != binding_domain(b))
            return is_lt(binding_domain(a), binding_domain(b), use_hash);
        return is_lt(binding_body(a), binding_body(b), use_hash);
    case expr_kind::Let:
        if (let_type(a) != let_type(b))
            return is_lt(let_type(a), let_type(b), use_hash);
        if (let_value(a) != let_value(b))
            return is_lt(let_value(a), let_value(b), use_hash);
        return is_lt(let_body(a), let_body(b), use_hash);
    case expr_kind::Sort:
        return is_lt(sort_level(a), sort_level(b), use_hash);
    case expr_kind::Local:
    case expr_kind::Meta:
        if (mlocal_name(a) != mlocal_name(b))
            return mlocal_name(a) < mlocal_name(b);
        return is_lt(mlocal_type(a), mlocal_type(b), use_hash);
    case expr_kind::Macro:
        if (macro_def(a) != macro_def(b))
            return macro_def(a) < macro_def(b);
        if (macro_num_args(a) != macro_num_args(b))
            return macro_num_args(a) < macro_num_args(b);
        for (unsigned i = 0; i < macro_num_args(a); i++) {
            if (macro_arg(a, i) != macro_arg(b, i))
                return is_lt(macro_arg(a, i), macro_arg(b, i), use_hash);
        }
        return false;
    }
    lean_unreachable(); // LCOV_EXCL_LINE
}
/* Structural equality test for `a` and `b`.

   Pointer equality, hash codes and kinds are used as fast paths, and bound variables are
   compared by index. Once m_counter has reached LEAN_EQ_CACHE_THRESHOLD, pairs of shared
   cells are recorded in m_eq_visited (created on first use); a pair that is already
   recorded is reported as equal without being traversed again.
   Binder annotations are compared only when m_compare_binder_info is set. */
bool expr_eq_fn::apply(expr const & a, expr const & b) {
    if (is_eqp(a, b))
        return true;
    if (a.hash() != b.hash() || a.kind() != b.kind())
        return false;
    if (is_var(a))
        return var_idx(a) == var_idx(b);

    bool const use_cache = m_counter >= LEAN_EQ_CACHE_THRESHOLD && is_shared(a) && is_shared(b);
    if (use_cache) {
        auto key = std::make_pair(a.raw(), b.raw());
        if (!m_eq_visited)
            m_eq_visited.reset(new expr_cell_pair_set);
        if (m_eq_visited->find(key) != m_eq_visited->end())
            return true;
        m_eq_visited->insert(key);
    }
    check_system("expression equality test");

    switch (a.kind()) {
    case expr_kind::Var:
        lean_unreachable(); // LCOV_EXCL_LINE
    case expr_kind::Constant:
        if (!(const_name(a) == const_name(b)))
            return false;
        return compare(const_levels(a), const_levels(b),
                       [](level const & l1, level const & l2) { return l1 == l2; });
    case expr_kind::Local:
    case expr_kind::Meta:
        if (!(mlocal_name(a) == mlocal_name(b)))
            return false;
        return apply(mlocal_type(a), mlocal_type(b));
    case expr_kind::App:
        m_counter++;
        if (!apply(app_fn(a), app_fn(b)))
            return false;
        return apply(app_arg(a), app_arg(b));
    case expr_kind::Lambda:
    case expr_kind::Pi:
        m_counter++;
        if (!apply(binding_domain(a), binding_domain(b)))
            return false;
        if (!apply(binding_body(a), binding_body(b)))
            return false;
        return !m_compare_binder_info || binding_info(a) == binding_info(b);
    case expr_kind::Sort:
        return sort_level(a) == sort_level(b);
    case expr_kind::Macro:
        m_counter++;
        if (macro_def(a) != macro_def(b) || macro_num_args(a) != macro_num_args(b))
            return false;
        for (unsigned i = 0; i < macro_num_args(a); i++) {
            if (!apply(macro_arg(a, i), macro_arg(b, i)))
                return false;
        }
        return true;
    case expr_kind::Let:
        m_counter++;
        if (!apply(let_type(a), let_type(b)))
            return false;
        if (!apply(let_value(a), let_value(b)))
            return false;
        return apply(let_body(a), let_body(b));
    }
    lean_unreachable(); // LCOV_EXCL_LINE
}
/* Visit the type, value and body of the let-expression `e` (in that order) and combine
   the results with update_let. */
expr replace_visitor::visit_let(expr const & e) {
    lean_assert(is_let(e));
    expr type  = visit(let_type(e));
    expr value = visit(let_value(e));
    expr body  = visit(let_body(e));
    return update_let(e, type, value, body);
}
/* Build a new expression from the top-level components of `a`.
   Only the outermost node is rebuilt through the mk_* constructors; the children are
   passed to them as-is. */
expr copy(expr const & a) {
    switch (a.kind()) {
    case expr_kind::Var: {
        return mk_var(var_idx(a));
    }
    case expr_kind::Constant: {
        return mk_constant(const_name(a));
    }
    case expr_kind::Type: {
        return mk_type(ty_level(a));
    }
    case expr_kind::Value: {
        // The payload is read directly from the underlying value cell.
        auto * cell = static_cast<expr_value*>(a.raw());
        return mk_value(cell->m_val);
    }
    case expr_kind::App: {
        return mk_app(num_args(a), begin_args(a));
    }
    case expr_kind::Eq: {
        return mk_eq(eq_lhs(a), eq_rhs(a));
    }
    case expr_kind::Lambda: {
        return mk_lambda(abst_name(a), abst_domain(a), abst_body(a));
    }
    case expr_kind::Pi: {
        return mk_pi(abst_name(a), abst_domain(a), abst_body(a));
    }
    case expr_kind::Let: {
        return mk_let(let_name(a), let_type(a), let_value(a), let_body(a));
    }
    case expr_kind::MetaVar: {
        return mk_metavar(metavar_idx(a), metavar_ctx(a));
    }
    }
    lean_unreachable();
}
/* Return a representative for `a` whose sub-expressions are taken from m_expr_cache
   whenever an equal expression has already been processed.

   The cache is consulted first; on a miss the children are processed recursively, the
   node is rebuilt with the matching update_* helper, and the result is stored in the
   cache before being returned. check_system("max_sharing") is invoked on every call. */
expr apply(expr const & a) {
    check_system("max_sharing");
    auto r = m_expr_cache.find(a);
    if (r != m_expr_cache.end())
        return *r;
    expr res;
    switch (a.kind()) {
    case expr_kind::Var:
        res = a;
        break;
    case expr_kind::Constant:
        res = update_constant(a, map(const_levels(a), [&](level const & l) { return apply(l); }));
        break;
    case expr_kind::Sort:
        res = update_sort(a, apply(sort_level(a)));
        break;
    case expr_kind::App: {
        expr new_f = apply(app_fn(a));
        expr new_a = apply(app_arg(a));
        res = update_app(a, new_f, new_a);
        break;
    }
    case expr_kind::Lambda:
    case expr_kind::Pi: {
        expr new_d = apply(binding_domain(a));
        expr new_b = apply(binding_body(a));
        res = update_binding(a, new_d, new_b);
        break;
    }
    case expr_kind::Let: {
        expr new_t = apply(let_type(a));
        expr new_v = apply(let_value(a));
        expr new_b = apply(let_body(a));
        res = update_let(a, new_t, new_v, new_b);
        break;
    }
    case expr_kind::Meta:
    case expr_kind::Local:
        res = update_mlocal(a, apply(mlocal_type(a)));
        break;
    case expr_kind::Macro: {
        // Macro arguments must be processed recursively like every other child;
        // copying them unchanged would leave the terms below a macro unshared.
        buffer<expr> new_args;
        for (unsigned i = 0; i < macro_num_args(a); i++)
            new_args.push_back(apply(macro_arg(a, i)));
        res = update_macro(a, new_args.size(), new_args.data());
        break;
    }}
    m_expr_cache.insert(res);
    return res;
}
/* Insert into `ls` every local constant reachable from `e`.

   Sub-terms for which has_local is false are skipped, and each sub-term is processed at
   most once. When `restricted` is true, metavariables are not entered and the types of
   local constants are not traversed. */
void collect_locals(expr const & e, collected_locals & ls, bool restricted) {
    if (!has_local(e))
        return;
    expr_set seen;
    std::function<void(expr const & e)> go = [&](expr const & e) {
        if (!has_local(e))
            return;
        if (restricted && is_meta(e))
            return;
        if (seen.find(e) != seen.end())
            return;
        seen.insert(e);
        switch (e.kind()) {
        case expr_kind::Var:
        case expr_kind::Constant:
        case expr_kind::Sort:
            // No sub-terms to traverse.
            break;
        case expr_kind::Local:
            if (!restricted)
                go(mlocal_type(e));
            ls.insert(e);
            break;
        case expr_kind::Meta:
            // In restricted mode metavariables were filtered out above.
            lean_assert(!restricted);
            go(mlocal_type(e));
            break;
        case expr_kind::Macro:
            for (unsigned i = 0; i < macro_num_args(e); i++)
                go(macro_arg(e, i));
            break;
        case expr_kind::App:
            go(app_fn(e));
            go(app_arg(e));
            break;
        case expr_kind::Lambda:
        case expr_kind::Pi:
            go(binding_domain(e));
            go(binding_body(e));
            break;
        case expr_kind::Let:
            go(let_type(e));
            go(let_value(e));
            go(let_body(e));
            break;
        }
    };
    go(e);
}
/* Visit a telescope of nested lambdas and lets.

   Every binder is opened with a temporary local. Let-values and the innermost body are
   visited; binder types are only instantiated. The result is rebuilt with
   locals.mk_lambda and passed through copy_tag together with the original `e`. */
expr compiler_step_visitor::visit_lambda_let(expr const & e) {
    type_context::tmp_locals locals(m_ctx);
    expr curr = e;
    /* Types are ignored in compilation steps. So, we do not invoke visit for d. */
    while (is_lambda(curr) || is_let(curr)) {
        if (is_lambda(curr)) {
            expr d = instantiate_rev(binding_domain(curr), locals.size(), locals.data());
            locals.push_local(binding_name(curr), d, binding_info(curr));
            curr = binding_body(curr);
        } else {
            expr d = instantiate_rev(let_type(curr), locals.size(), locals.data());
            expr v = visit(instantiate_rev(let_value(curr), locals.size(), locals.data()));
            locals.push_let(let_name(curr), d, v);
            curr = let_body(curr);
        }
    }
    expr body = instantiate_rev(curr, locals.size(), locals.data());
    body      = visit(body);
    return copy_tag(e, locals.mk_lambda(body));
}
/* Visit a let-expression by instantiating its body with its value and visiting the
   resulting term. */
expr visit_let(expr const & e) {
    // TODO(Leo): improve
    expr const expanded = instantiate(let_body(e), let_value(e));
    return visit(expanded);
}
/* Return `e` itself when the new type, value and body are all pointer-equal to the
   ones stored in `e`; otherwise build a new let-expression that keeps the name of `e`. */
expr update_let(expr const & e, expr const & new_type, expr const & new_value, expr const & new_body) {
    bool const unchanged =
        is_eqp(let_type(e), new_type) &&
        is_eqp(let_value(e), new_value) &&
        is_eqp(let_body(e), new_body);
    if (unchanged)
        return e;
    return mk_let(let_name(e), new_type, new_value, new_body);
}
/* If should_visit(e) holds, instantiate the body of the let-expression with its value
   and visit the result. Otherwise do nothing. */
void visit_let(expr const & e) {
    if (!should_visit(e))
        return;
    visit(instantiate(let_body(e), let_value(e)));
}
/* Strict ordering on expressions; returns true when `a` is smaller than `b`.

   Comparison keys, in order: depth (get_depth), expression kind, hash code (only when
   `use_hash` is true), and finally a component-wise structural comparison.
   Pointer-equal or structurally equal expressions are never smaller than each other. */
bool is_lt(expr const & a, expr const & b, bool use_hash) {
    if (is_eqp(a, b))
        return false;
    unsigned const depth_a = get_depth(a);
    unsigned const depth_b = get_depth(b);
    if (depth_a != depth_b)
        return depth_a < depth_b;
    if (a.kind() != b.kind())
        return a.kind() < b.kind();
    if (use_hash && a.hash() != b.hash())
        return a.hash() < b.hash();
    if (a == b)
        return false;
    if (is_var(a))
        return var_idx(a) < var_idx(b);

    switch (a.kind()) {
    case expr_kind::Var:
        lean_unreachable(); // LCOV_EXCL_LINE
    case expr_kind::Constant:
        return const_name(a) < const_name(b);
    case expr_kind::App:
        if (num_args(a) != num_args(b))
            return num_args(a) < num_args(b);
        for (unsigned i = 0; i < num_args(a); i++) {
            if (arg(a, i) != arg(b, i))
                return is_lt(arg(a, i), arg(b, i), use_hash);
        }
        // a != b was established above, so some argument must differ.
        lean_unreachable(); // LCOV_EXCL_LINE
    case expr_kind::Lambda: // Remark: we ignore get_abs_name because we want alpha-equivalence
    case expr_kind::Pi:
        if (abst_domain(a) != abst_domain(b))
            return is_lt(abst_domain(a), abst_domain(b), use_hash);
        return is_lt(abst_body(a), abst_body(b), use_hash);
    case expr_kind::Type:
        return ty_level(a) < ty_level(b);
    case expr_kind::Value:
        return to_value(a) < to_value(b);
    case expr_kind::Let:
        if (let_type(a) != let_type(b))
            return is_lt(let_type(a), let_type(b), use_hash);
        if (let_value(a) != let_value(b))
            return is_lt(let_value(a), let_value(b), use_hash);
        return is_lt(let_body(a), let_body(b), use_hash);
    case expr_kind::MetaVar: {
        if (metavar_name(a) != metavar_name(b))
            return metavar_name(a) < metavar_name(b);
        // Same name: compare the local contexts entry by entry.
        auto const & lctx_a = metavar_lctx(a);
        auto const & lctx_b = metavar_lctx(b);
        auto it1  = lctx_a.begin();
        auto it2  = lctx_b.begin();
        auto end1 = lctx_a.end();
        auto end2 = lctx_b.end();
        for (; it1 != end1 && it2 != end2; ++it1, ++it2) {
            if (it1->kind() != it2->kind())
                return it1->kind() < it2->kind();
            if (it1->s() != it2->s())
                return it1->s() < it2->s();
            if (it1->is_inst()) {
                if (it1->v() != it2->v())
                    return is_lt(it1->v(), it2->v(), use_hash);
            } else {
                if (it1->n() != it2->n())
                    return it1->n() < it2->n();
            }
        }
        // All common entries agree: `a` is smaller only if its context is a proper prefix.
        return it1 == end1 && it2 != end2;
    }
    }
    lean_unreachable(); // LCOV_EXCL_LINE
}
/* Rebuild `a` from recursively rebuilt children.

   For a shared `a`, m_cache (keyed by the raw cell pointer) is consulted first and a hit
   is returned directly. Every freshly built result is handed to save_result together
   with the flag telling whether `a` was shared. */
expr apply(expr const & a) {
    bool shared = false;
    if (is_shared(a)) {
        auto it = m_cache.find(a.raw());
        if (it != m_cache.end())
            return it->second;
        shared = true;
    }
    switch (a.kind()) {
    case expr_kind::Var:
    case expr_kind::Constant:
    case expr_kind::Type:
    case expr_kind::Value:
        // These kinds are handed to copy() without recursing.
        return save_result(a, copy(a), shared);
    case expr_kind::App: {
        buffer<expr> copied_args;
        for (expr const & child : args(a))
            copied_args.push_back(apply(child));
        return save_result(a, mk_app(copied_args), shared);
    }
    case expr_kind::HEq: {
        expr lhs = apply(heq_lhs(a));
        expr rhs = apply(heq_rhs(a));
        return save_result(a, mk_heq(lhs, rhs), shared);
    }
    case expr_kind::Pair: {
        expr first  = apply(pair_first(a));
        expr second = apply(pair_second(a));
        expr type   = apply(pair_type(a));
        return save_result(a, mk_pair(first, second, type), shared);
    }
    case expr_kind::Proj: {
        expr inner = apply(proj_arg(a));
        return save_result(a, mk_proj(proj_first(a), inner), shared);
    }
    case expr_kind::Lambda: {
        expr domain = apply(abst_domain(a));
        expr body   = apply(abst_body(a));
        return save_result(a, mk_lambda(abst_name(a), domain, body), shared);
    }
    case expr_kind::Pi: {
        expr domain = apply(abst_domain(a));
        expr body   = apply(abst_body(a));
        return save_result(a, mk_pi(abst_name(a), domain, body), shared);
    }
    case expr_kind::Sigma: {
        expr domain = apply(abst_domain(a));
        expr body   = apply(abst_body(a));
        return save_result(a, mk_sigma(abst_name(a), domain, body), shared);
    }
    case expr_kind::Let: {
        expr type  = apply(let_type(a));
        expr value = apply(let_value(a));
        expr body  = apply(let_body(a));
        return save_result(a, mk_let(let_name(a), type, value, body), shared);
    }
    case expr_kind::MetaVar:
        // Only instantiation entries carry an expression that needs rebuilding.
        return save_result(a, update_metavar(a, [&](local_entry const & e) -> local_entry {
                    if (e.is_inst())
                        return mk_inst(e.s(), apply(e.v()));
                    else
                        return e;
                }), shared);
    }
    lean_unreachable(); // LCOV_EXCL_LINE
}
/* First-order unification of `e1` and `e2`.

   Pending pairs are kept on a work stack. Each pair is first resolved through the
   current substitution with find(). A metavariable without local context on either side
   is assigned to the other side; equalities are decomposed through get_equality_args;
   other expressions are decomposed structurally.
   Returns the resulting substitution, or an empty optional on a mismatch. */
optional<substitution> fo_unify(expr e1, expr e2) {
    substitution s;
    buffer<expr_pair> todo;
    todo.emplace_back(e1, e2);
    while (!todo.empty()) {
        auto goal = todo.back();
        todo.pop_back();
        e1 = find(s, goal.first);
        e2 = find(s, goal.second);
        bool const differ = e1 != e2;
        if (!differ)
            continue;

        if (is_metavar_wo_local_context(e1)) {
            assign(s, e1, e2);
            continue;
        }
        if (is_metavar_wo_local_context(e2)) {
            assign(s, e2, e1);
            continue;
        }
        if (is_equality(e1) && is_equality(e2)) {
            expr_pair sides1 = get_equality_args(e1);
            expr_pair sides2 = get_equality_args(e2);
            todo.emplace_back(sides1.second, sides2.second);
            todo.emplace_back(sides1.first, sides2.first);
            continue;
        }

        if (e1.kind() != e2.kind())
            return optional<substitution>();
        switch (e1.kind()) {
        case expr_kind::Var:
        case expr_kind::Constant:
        case expr_kind::Type:
        case expr_kind::Value:
        case expr_kind::MetaVar:
            // Distinct atoms of the same kind cannot be unified here.
            return optional<substitution>();
        case expr_kind::App: {
            // Pair the arguments from the right. When one application runs out of
            // arguments first, its head is paired with the remaining prefix of the other.
            unsigned i1 = num_args(e1);
            unsigned i2 = num_args(e2);
            while (i1 > 0 && i2 > 0) {
                --i1;
                --i2;
                if (i1 == 0 && i2 > 0) {
                    todo.emplace_back(arg(e1, i1), mk_app(i2+1, begin_args(e2)));
                } else if (i2 == 0 && i1 > 0) {
                    todo.emplace_back(mk_app(i1+1, begin_args(e1)), arg(e2, i2));
                } else {
                    todo.emplace_back(arg(e1, i1), arg(e2, i2));
                }
            }
            break;
        }
        case expr_kind::Lambda:
        case expr_kind::Pi:
            todo.emplace_back(abst_body(e1), abst_body(e2));
            todo.emplace_back(abst_domain(e1), abst_domain(e2));
            break;
        case expr_kind::Let: {
            todo.emplace_back(let_body(e1), let_body(e2));
            todo.emplace_back(let_value(e1), let_value(e2));
            // The optional type annotation must be present on both sides or on neither.
            bool const has_type1 = static_cast<bool>(let_type(e1));
            bool const has_type2 = static_cast<bool>(let_type(e2));
            if (has_type1 != has_type2)
                return optional<substitution>();
            if (let_type(e1)) {
                lean_assert(let_type(e2));
                todo.emplace_back(*let_type(e1), *let_type(e2));
            }
            break;
        }
        }
    }
    return optional<substitution>(s);
}