diff --git a/src/ast/rewriter/var_subst.cpp b/src/ast/rewriter/var_subst.cpp index ec33bd2650..532a49ec5f 100644 --- a/src/ast/rewriter/var_subst.cpp +++ b/src/ast/rewriter/var_subst.cpp @@ -52,6 +52,20 @@ expr_ref var_subst::operator()(expr * n, unsigned num_args, expr * const * args) rep(n, result); return result; } + if (is_app(n) && all_of(*to_app(n), [&](expr* arg) { return is_ground(arg) || is_var(arg); })) { + ptr_buffer new_args; + for (auto arg : *to_app(n)) { + if (is_ground(arg)) + new_args.push_back(arg); + else { + unsigned idx = to_var(arg)->get_idx(); + new_args.push_back(m_std_order ? args[idx] : args[num_args - idx - 1]); + } + } + result = m.mk_app(to_app(n)->get_decl(), new_args.size(), new_args.data()); + // verbose_stream() << result << "\n"; + return result; + } SASSERT(is_well_sorted(result.m(), n)); m_reducer.reset(); if (m_std_order) diff --git a/src/cmd_context/cmd_context.cpp b/src/cmd_context/cmd_context.cpp index 5f5195bb9d..ca9bbec268 100644 --- a/src/cmd_context/cmd_context.cpp +++ b/src/cmd_context/cmd_context.cpp @@ -33,7 +33,6 @@ Module Name: #include "ast/fpa_decl_plugin.h" #include "ast/special_relations_decl_plugin.h" #include "ast/ast_pp.h" -#include "ast/rewriter/var_subst.h" #include "ast/pp.h" #include "ast/ast_smt2_pp.h" #include "ast/ast_ll_pp.h" @@ -406,8 +405,7 @@ void cmd_context::insert_macro(symbol const& s, unsigned arity, sort*const* doma recfun::promise_def d = p.ensure_def(s, arity, domain, t->get_sort(), false); // recursive functions have opposite calling convention from macros! - var_subst sub(m(), true); - expr_ref tt = sub(t, rvars); + expr_ref tt = std_subst()(t, rvars); p.set_definition(replace, d, true, vars.size(), vars.data(), tt); register_fun(s, d.get_def()->get_decl()); } @@ -461,7 +459,6 @@ bool cmd_context::macros_find(symbol const& s, unsigned n, expr*const* args, exp if (eq) { t = d.m_body; t = sub(t); - verbose_stream() << "macro " << t << "\n"; ptr_buffer domain; for (unsigned i = 0; i < n; ++i) domain.push_back(args[i]->get_sort()); @@ -1257,9 +1254,8 @@ bool cmd_context::try_mk_macro_app(symbol const & s, unsigned num_args, expr * c tout << "s: " << s << "\n"; tout << "body:\n" << mk_ismt2_pp(_t, m()) << "\n"; tout << "args:\n"; for (unsigned i = 0; i < num_args; i++) tout << mk_ismt2_pp(args[i], m()) << "\n" << mk_pp(args[i]->get_sort(), m()) << "\n";); - var_subst subst(m(), false); scoped_rlimit no_limit(m().limit(), 0); - result = subst(_t, coerced_args); + result = rev_subst()(_t, coerced_args); if (well_sorted_check_enabled() && !is_well_sorted(m(), result)) throw cmd_exception("invalid macro application, sort mismatch ", s); return true; @@ -1524,6 +1520,8 @@ void cmd_context::reset(bool finalize) { if (m_own_manager) { dealloc(m_manager); m_manager = nullptr; + m_std_subst = nullptr; + m_rev_subst = nullptr; m_manager_initialized = false; } else { diff --git a/src/cmd_context/cmd_context.h b/src/cmd_context/cmd_context.h index cd43203a7f..dde1d7962e 100644 --- a/src/cmd_context/cmd_context.h +++ b/src/cmd_context/cmd_context.h @@ -33,6 +33,7 @@ Module Name: #include "ast/datatype_decl_plugin.h" #include "ast/recfun_decl_plugin.h" #include "ast/rewriter/seq_rewriter.h" +#include "ast/rewriter/var_subst.h" #include "ast/converters/generic_model_converter.h" #include "solver/solver.h" #include "solver/check_logic.h" @@ -280,6 +281,7 @@ class cmd_context : public progress_callback, public tactic_manager, public ast_ ptr_vector m_assertions; std::vector m_assertion_strings; ptr_vector m_assertion_names; // named assertions are represented using boolean variables. + scoped_ptr m_std_subst, m_rev_subst; struct scope { unsigned m_func_decls_stack_lim; @@ -317,6 +319,9 @@ class cmd_context : public progress_callback, public tactic_manager, public ast_ scoped_ptr m_pp_env; pp_env & get_pp_env() const; + var_subst& std_subst() { if (!m_std_subst) m_std_subst = alloc(var_subst, m(), true); return *m_std_subst; } + var_subst& rev_subst() { if (!m_rev_subst) m_rev_subst = alloc(var_subst, m(), false); return *m_rev_subst; } + void register_builtin_sorts(decl_plugin * p); void register_builtin_ops(decl_plugin * p); void load_plugin(symbol const & name, bool install_names, svector& fids);