Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 34 additions & 31 deletions src/kernel/replace_fn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,59 +23,61 @@ class replace_rec_fn {
std::function<optional<expr>(expr const &, unsigned)> m_f;
bool m_use_cache;

expr save_result(expr const & e, unsigned offset, expr r, bool shared) {
if (shared)
m_cache.insert(mk_pair(mk_pair(e.raw(), offset), r));
// `std::unordered_map` keeps element references stable across rehash, so the
// `slot` pointer obtained on descent remains valid after recursive inserts.
expr save_result(expr r, expr * slot) {
if (slot) *slot = r;
return r;
}

expr apply(expr const & e, unsigned offset) {
bool shared = false;
expr * slot = nullptr;
if (m_use_cache && !is_likely_unshared(e)) {
auto it = m_cache.find(mk_pair(e.raw(), offset));
if (it != m_cache.end())
return it->second;
shared = true;
auto p = m_cache.try_emplace(mk_pair(e.raw(), offset));
if (!p.second)
return p.first->second;
slot = &p.first->second;
}
if (optional<expr> r = m_f(e, offset)) {
return save_result(e, offset, std::move(*r), shared);
return save_result(std::move(*r), slot);
} else {
switch (e.kind()) {
case expr_kind::Const: case expr_kind::Sort:
case expr_kind::BVar: case expr_kind::Lit:
case expr_kind::MVar: case expr_kind::FVar:
return save_result(e, offset, e, shared);
return save_result(e, slot);
case expr_kind::MData: {
expr new_e = apply(mdata_expr(e), offset);
return save_result(e, offset, update_mdata(e, new_e), shared);
return save_result(update_mdata(e, new_e), slot);
}
case expr_kind::Proj: {
expr new_e = apply(proj_expr(e), offset);
return save_result(e, offset, update_proj(e, new_e), shared);
return save_result(update_proj(e, new_e), slot);
}
case expr_kind::App: {
expr new_f = apply(app_fn(e), offset);
expr new_a = apply(app_arg(e), offset);
return save_result(e, offset, update_app(e, new_f, new_a), shared);
return save_result(update_app(e, new_f, new_a), slot);
}
case expr_kind::Pi: case expr_kind::Lambda: {
expr new_d = apply(binding_domain(e), offset);
expr new_b = apply(binding_body(e), offset+1);
return save_result(e, offset, update_binding(e, new_d, new_b), shared);
return save_result(update_binding(e, new_d, new_b), slot);
}
case expr_kind::Let: {
expr new_t = apply(let_type(e), offset);
expr new_v = apply(let_value(e), offset);
expr new_b = apply(let_body(e), offset+1);
return save_result(e, offset, update_let(e, new_t, new_v, new_b), shared);
return save_result(update_let(e, new_t, new_v, new_b), slot);
}
}
lean_unreachable();
}
}
public:
template<typename F>
replace_rec_fn(F const & f, bool use_cache):m_f(f), m_use_cache(use_cache) {}
replace_rec_fn(F const & f, bool use_cache):m_f(f), m_use_cache(use_cache) {
}

expr operator()(expr const & e) { return apply(e, 0); }
};
Expand All @@ -88,19 +90,20 @@ class replace_fn {
lean::unordered_map<lean_object *, expr> m_cache;
lean_object * m_f;

expr save_result(expr const & e, expr const & r, bool shared) {
if (shared)
m_cache.insert(mk_pair(e.raw(), r));
// `std::unordered_map` keeps element references stable across rehash, so the
// `slot` pointer obtained on descent remains valid after recursive inserts.
expr save_result(expr const & r, expr * slot) {
if (slot) *slot = r;
return r;
}

expr apply(expr const & e) {
bool shared = false;
expr * slot = nullptr;
if (is_shared(e)) {
auto it = m_cache.find(e.raw());
if (it != m_cache.end())
return it->second;
shared = true;
auto p = m_cache.try_emplace(e.raw());
if (!p.second)
return p.first->second;
slot = &p.first->second;
}

lean_inc(e.raw());
Expand All @@ -109,37 +112,37 @@ class replace_fn {
if (!lean_is_scalar(r)) {
expr e_new(lean_ctor_get(r, 0), true);
lean_dec_ref(r);
return save_result(e, e_new, shared);
return save_result(e_new, slot);
}

switch (e.kind()) {
case expr_kind::Const: case expr_kind::Sort:
case expr_kind::BVar: case expr_kind::Lit:
case expr_kind::MVar: case expr_kind::FVar:
return save_result(e, e, shared);
return save_result(e, slot);
case expr_kind::MData: {
expr new_e = apply(mdata_expr(e));
return save_result(e, update_mdata(e, new_e), shared);
return save_result(update_mdata(e, new_e), slot);
}
case expr_kind::Proj: {
expr new_e = apply(proj_expr(e));
return save_result(e, update_proj(e, new_e), shared);
return save_result(update_proj(e, new_e), slot);
}
case expr_kind::App: {
expr new_f = apply(app_fn(e));
expr new_a = apply(app_arg(e));
return save_result(e, update_app(e, new_f, new_a), shared);
return save_result(update_app(e, new_f, new_a), slot);
}
case expr_kind::Pi: case expr_kind::Lambda: {
expr new_d = apply(binding_domain(e));
expr new_b = apply(binding_body(e));
return save_result(e, update_binding(e, new_d, new_b), shared);
return save_result(update_binding(e, new_d, new_b), slot);
}
case expr_kind::Let: {
expr new_t = apply(let_type(e));
expr new_v = apply(let_value(e));
expr new_b = apply(let_body(e));
return save_result(e, update_let(e, new_t, new_v, new_b), shared);
return save_result(update_let(e, new_t, new_v, new_b), slot);
}}
lean_unreachable();
}
Expand Down
Loading