diff --git a/cli/main.cpp b/cli/main.cpp index 36b92f83be..77a4c49c85 100644 --- a/cli/main.cpp +++ b/cli/main.cpp @@ -45,31 +45,32 @@ int main(int argc, char** argv) { // clang-format off auto cli = lyra::cli() | lyra::help(show_help) - | lyra::opt(show_version )["-v"]["--version" ]("Display version info and exit.") - | lyra::opt(list_search_paths )["-l"]["--list-search-paths" ]("List search paths in order and exit.") - | lyra::opt(clang, "clang" )["-c"]["--clang" ]("Path to clang executable (default: '" THORIN_WHICH " clang').") - | lyra::opt(plugins, "plugin" )["-p"]["--plugin" ]("Dynamically load plugin.") - | lyra::opt(search_paths, "path" )["-P"]["--plugin-path" ]("Path to search for plugins.") - | lyra::opt(inc_verbose )["-V"]["--verbose" ]("Verbose mode. Multiple -V options increase the verbosity. The maximum is 4.").cardinality(0, 4) - | lyra::opt(opt, "level" )["-O"]["--optimize" ]("Optimization level (default: 2).") - | lyra::opt(output[Dot ], "file" ) ["--output-dot" ]("Emits the Thorin program as a graph using Graphviz' DOT language.") - | lyra::opt(output[H ], "file" ) ["--output-h" ]("Emits a header file to be used to interface with a plugin in C++.") - | lyra::opt(output[LL ], "file" ) ["--output-ll" ]("Compiles the Thorin program to LLVM.") - | lyra::opt(output[Md ], "file" ) ["--output-md" ]("Emits the input formatted as Markdown.") - | lyra::opt(output[Thorin], "file" )["-o"]["--output-thorin" ]("Emits the Thorin program again.") - | lyra::opt(flags.bootstrap ) ["--bootstrap" ]("Puts thorin into \"bootstrap mode\". This means a '.plugin' directive has the same effect as an '.import' and will not load a library. In addition, no standard plugins will be loaded.") - | lyra::opt(flags.dump_gid, "level" ) ["--dump-gid" ]("Dumps gid of inline expressions as a comment in output if > 0. Use a of 2 to also emit the gid of trivial defs.") - | lyra::opt(flags.dump_recursive ) ["--dump-recursive" ]("Dumps Thorin program with a simple recursive algorithm that is not readable again from Thorin but is less fragile and also works for broken Thorin programs.") - | lyra::opt(flags.aggressive_lam_spec) ["--aggr-lam-spec" ]("Overrides LamSpec behavior to follow recursive calls.") + | lyra::opt(show_version )["-v"]["--version" ]("Display version info and exit.") + | lyra::opt(list_search_paths )["-l"]["--list-search-paths" ]("List search paths in order and exit.") + | lyra::opt(clang, "clang" )["-c"]["--clang" ]("Path to clang executable (default: '" THORIN_WHICH " clang').") + | lyra::opt(plugins, "plugin" )["-p"]["--plugin" ]("Dynamically load plugin.") + | lyra::opt(search_paths, "path" )["-P"]["--plugin-path" ]("Path to search for plugins.") + | lyra::opt(inc_verbose )["-V"]["--verbose" ]("Verbose mode. Multiple -V options increase the verbosity. The maximum is 4.").cardinality(0, 4) + | lyra::opt(opt, "level" )["-O"]["--optimize" ]("Optimization level (default: 2).") + | lyra::opt(output[Dot ], "file" ) ["--output-dot" ]("Emits the Thorin program as a graph using Graphviz' DOT language.") + | lyra::opt(output[H ], "file" ) ["--output-h" ]("Emits a header file to be used to interface with a plugin in C++.") + | lyra::opt(output[LL ], "file" ) ["--output-ll" ]("Compiles the Thorin program to LLVM.") + | lyra::opt(output[Md ], "file" ) ["--output-md" ]("Emits the input formatted as Markdown.") + | lyra::opt(output[Thorin], "file" )["-o"]["--output-thorin" ]("Emits the Thorin program again.") + | lyra::opt(flags.bootstrap ) ["--bootstrap" ]("Puts thorin into \"bootstrap mode\". This means a '.plugin' directive has the same effect as an '.import' and will not load a library. In addition, no standard plugins will be loaded.") + | lyra::opt(flags.dump_gid, "level" ) ["--dump-gid" ]("Dumps gid of inline expressions as a comment in output if > 0. Use a of 2 to also emit the gid of trivial defs.") + | lyra::opt(flags.dump_recursive ) ["--dump-recursive" ]("Dumps Thorin program with a simple recursive algorithm that is not readable again from Thorin but is less fragile and also works for broken Thorin programs.") + | lyra::opt(flags.aggressive_lam_spec ) ["--aggr-lam-spec" ]("Overrides LamSpec behavior to follow recursive calls.") + | lyra::opt(flags.scalerize_threshold, "threshold") ["--scalerize-threshold" ]("Thorin will not scalerize tuples/packs/sigmas/arrays with a number of elements greater than or equal this threshold.") #ifdef THORIN_ENABLE_CHECKS - | lyra::opt(breakpoints, "gid" )["-b"]["--break" ]("*Triggers breakpoint upon construction of node with global id . Useful when running in a debugger.") - | lyra::opt(flags.reeval_breakpoints ) ["--reeval-breakpoints" ]("*Triggers breakpoint even upon unfying a node that has already been built.") - | lyra::opt(flags.break_on_alpha_unequal) ["--break-on-alpha-unequal"]("*Triggers breakpoint as soon as two expressions turn out to be not alpha-equivalent.") - | lyra::opt(flags.break_on_error ) ["--break-on-error" ]("*Triggers breakpoint on ELOG.") - | lyra::opt(flags.break_on_warn ) ["--break-on-warn" ]("*Triggers breakpoint on WLOG.") - | lyra::opt(flags.trace_gids ) ["--trace-gids" ]("*Output gids during World::unify/insert.") + | lyra::opt(breakpoints, "gid" )["-b"]["--break" ]("*Triggers breakpoint upon construction of node with global id . Useful when running in a debugger.") + | lyra::opt(flags.reeval_breakpoints ) ["--reeval-breakpoints" ]("*Triggers breakpoint even upon unfying a node that has already been built.") + | lyra::opt(flags.break_on_alpha_unequal ) ["--break-on-alpha-unequal"]("*Triggers breakpoint as soon as two expressions turn out to be not alpha-equivalent.") + | lyra::opt(flags.break_on_error ) ["--break-on-error" ]("*Triggers breakpoint on ELOG.") + | lyra::opt(flags.break_on_warn ) ["--break-on-warn" ]("*Triggers breakpoint on WLOG.") + | lyra::opt(flags.trace_gids ) ["--trace-gids" ]("*Output gids during World::unify/insert.") #endif - | lyra::arg(input, "file" ) ("Input file.") + | lyra::arg(input, "file" ) ("Input file.") ; // clang-format on diff --git a/dialects/affine/passes/lower_for.cpp b/dialects/affine/passes/lower_for.cpp index 74f7e9d7cf..994b35f8e4 100644 --- a/dialects/affine/passes/lower_for.cpp +++ b/dialects/affine/passes/lower_for.cpp @@ -8,6 +8,26 @@ namespace thorin::affine { +namespace { + +const Def* merge_s(World& w, Ref elem, Ref sigma, Ref mem) { + if (mem) { + auto elems = sigma->projs(); + return merge_sigma(elem, elems); + } + return w.sigma({elem, sigma}); +} + +const Def* merge_t(World& w, Ref elem, Ref tuple, Ref mem) { + if (mem) { + auto elems = tuple->projs(); + return merge_tuple(elem, elems); + } + return w.tuple({elem, tuple}); +} + +} // namespace + Ref LowerFor::rewrite(Ref def) { if (auto i = rewritten_.find(def); i != rewritten_.end()) return i->second; @@ -19,28 +39,25 @@ Ref LowerFor::rewrite(Ref def) { auto exit_lam = exit->isa_mut(); if (!body_lam || !exit_lam) return def; - auto init_types = init->type()->projs(); - auto head_lam = world().mut_lam(world().cn(merge_sigma(begin->type(), init_types)))->set("head"); - auto phis = head_lam->vars(); - auto iter = phis.front(); - auto acc = world().tuple(phis.skip_front()); - auto mem_phi = mem::mem_var(head_lam); - auto bb_type = world().cn(mem_phi ? mem_phi->type() : world().sigma()); - auto new_body = world().mut_lam(bb_type)->set("new_body"); - auto new_exit = world().mut_lam(bb_type)->set("new_exit"); - auto new_yield = world().mut_lam(world().cn(init->type()))->set("new_yield"); - auto cmp = world().call(core::icmp::ul, Defs{iter, end}); - auto new_iter = world().call(core::wrap::add, core::Mode::nusw, Defs{iter, step}); - - head_lam->branch(false, cmp, new_body, new_exit, mem_phi); - - auto new_yield_vars = new_yield->vars(); - new_yield->app(false, head_lam, merge_tuple(new_iter, new_yield_vars)); + auto mem = mem::mem_def(init); + auto head_lam = world().mut_lam(world().cn(merge_s(world(), begin->type(), init->type(), mem)))->set("head"); + auto phis = head_lam->vars(); + auto iter = phis.front(); + auto acc = world().tuple(phis.skip_front()); + mem = mem::mem_var(head_lam); + auto bb_type = world().cn(mem ? mem->type() : world().sigma()); + auto new_body = world().mut_lam(bb_type)->set("new_body"); + auto new_exit = world().mut_lam(bb_type)->set("new_exit"); + auto new_yield = world().mut_lam(world().cn(init->type()))->set("new_yield"); + auto cmp = world().call(core::icmp::ul, Defs{iter, end}); + auto new_iter = world().call(core::wrap::add, core::Mode::nusw, Defs{iter, step}); + head_lam->branch(false, cmp, new_body, new_exit, mem); + new_yield->app(false, head_lam, merge_t(world(), new_iter, new_yield->var(), mem)); new_body->set(false, body->reduce(world().tuple({iter, acc, new_yield})).back()); new_exit->set(false, exit->reduce(acc).back()); - return rewritten_[def] = world().app(head_lam, merge_tuple(begin, init->projs())); + return rewritten_[def] = world().app(head_lam, merge_t(world(), begin, init, mem)); } return def; diff --git a/dialects/clos/clos.cpp b/dialects/clos/clos.cpp index d752932f9f..89cf352755 100644 --- a/dialects/clos/clos.cpp +++ b/dialects/clos/clos.cpp @@ -84,11 +84,11 @@ Ref clos_pack(Ref env, Ref lam, Ref ct) { auto pi = lam->type()->as(); assert(env->type() == pi->dom(Clos_Env_Param)); ct = (ct) ? ct : clos_type(w.cn(clos_remove_env(pi->dom()))); - return w.tuple(ct, {env->type(), lam, env})->isa(); + return w.tuple({env->type(), lam, env})->isa(); } std::tuple clos_unpack(Ref c) { - assert(c && isa_clos_type(c->type())); + assert(c); // && isa_clos_type(c->type())); // auto& w = c->world(); // auto env_type = c->proj(0_u64); // // auto pi = clos_type_to_pi(c->type(), env_type); diff --git a/dialects/clos/clos.thorin b/dialects/clos/clos.thorin index a80775b900..c7450b42e6 100644 --- a/dialects/clos/clos.thorin +++ b/dialects/clos/clos.thorin @@ -47,14 +47,14 @@ %compile.pass_phase (%compile.pass_list eta_red eta_exp - (%compile.scalerize_pass (eta_exp, %compile.scalerize_threshold)) + (%compile.scalerize_pass eta_exp) ) }; .let clos_opt2_phase = { .let nullptr = %compile.nullptr_pass; %compile.pass_phase (%compile.pass_list nullptr - (%compile.scalerize_pass (nullptr, %compile.scalerize_threshold)) + (%compile.scalerize_pass nullptr) %clos.branch_clos_pass (%mem.copy_prop_pass (nullptr, nullptr, .tt)) %clos.lower_typed_clos_prep_pass diff --git a/dialects/compile/compile.cpp b/dialects/compile/compile.cpp index 3b7682397f..6189b7f766 100644 --- a/dialects/compile/compile.cpp +++ b/dialects/compile/compile.cpp @@ -77,16 +77,7 @@ extern "C" THORIN_EXPORT thorin::Plugin thorin_get_plugin() { register_pass(passes); register_pass_with_arg(passes); - - passes[flags_t(Annex::Base)] - = [&](World& world, PipelineBuilder& builder, const Def* app) { - auto [eta_exp, scalerize_threshold] = app->as()->args<2>(); - auto ee = (EtaExp*)builder.pass(eta_exp); - auto threshold = scalerize_threshold->as()->get(); - world.DLOG("registering Scalerize with ee = {}, scalerize_threshold = {}", ee, threshold); - builder.add_pass(app, ee, threshold); - }; - + register_pass_with_arg(passes); register_pass_with_arg(passes); }, nullptr}; diff --git a/dialects/compile/compile.thorin b/dialects/compile/compile.thorin index d54d992ad1..52d4d2f497 100644 --- a/dialects/compile/compile.thorin +++ b/dialects/compile/compile.thorin @@ -103,9 +103,8 @@ .ax %compile.eta_red_pass: %compile.Pass; /// Eta expansion expects an instance of eta reduction as argument. .ax %compile.eta_exp_pass: %compile.Pass -> %compile.Pass; -/// Scalerize expects an instance of eta expansion as argument and a threshold where scalarize should stop. -.ax %compile.scalerize_pass: [%compile.Pass, .Nat] -> %compile.Pass; -.let %compile.scalerize_threshold = 32; +/// Scalerize expects an instance of eta expansion as argument. +.ax %compile.scalerize_pass: %compile.Pass -> %compile.Pass; /// Tail recursion elimination expects an instance of eta reduction as argument. .ax %compile.tail_rec_elim_pass: %compile.Pass -> %compile.Pass; .ax %compile.lam_spec_pass: %compile.Pass; @@ -124,7 +123,7 @@ %compile.beta_red_pass eta_red eta_exp - (%compile.scalerize_pass (eta_exp, %compile.scalerize_threshold)) + (%compile.scalerize_pass eta_exp) (%compile.tail_rec_elim_pass eta_red) }; .let optimization_phase = { @@ -137,7 +136,7 @@ .let nullptr = %compile.nullptr_pass; %compile.pipe (%compile.single_pass_phase nullptr) - (%compile.single_pass_phase (%compile.scalerize_pass (nullptr, %compile.scalerize_threshold))) + (%compile.single_pass_phase (%compile.scalerize_pass nullptr)) (%compile.single_pass_phase %compile.eta_red_pass) (%compile.single_pass_phase (%compile.tail_rec_elim_pass nullptr)) optimization_phase diff --git a/dialects/core/be/ll.cpp b/dialects/core/be/ll.cpp index 530f7dccb5..4c1f552d76 100644 --- a/dialects/core/be/ll.cpp +++ b/dialects/core/be/ll.cpp @@ -347,11 +347,12 @@ void Emitter::emit_epilogue(Lam* lam) { // each callees type should agree with the argument type (should be checked by type checking). // Especially, the number of vars should be the number of arguments. // TODO: does not hold for complex arguments that are not tuples. - assert(callee->num_vars() == app->num_args()); - for (size_t i = 0, e = callee->num_vars(); i != e; ++i) { + assert(callee->num_tvars() == app->num_targs()); + size_t n = callee->num_tvars(); + for (size_t i = 0; i != n; ++i) { // emits the arguments one by one (TODO: handle together like before) - if (auto arg = emit_unsafe(app->arg(i)); !arg.empty()) { - auto phi = callee->var(i); + if (auto arg = emit_unsafe(app->arg(n, i)); !arg.empty()) { + auto phi = callee->var(n, i); assert(!match(phi->type())); lam2bb_[callee].phis[phi].emplace_back(arg, id(lam, true)); locals_[phi] = id(phi); @@ -373,9 +374,10 @@ void Emitter::emit_epilogue(Lam* lam) { } else if (app->callee()->isa()) { return bb.tail("ret ; bottom: unreachable"); } else if (auto callee = Lam::isa_mut_basicblock(app->callee())) { // ordinary jump - for (size_t i = 0, e = callee->num_vars(); i != e; ++i) { - if (auto arg = emit_unsafe(app->arg(i)); !arg.empty()) { - auto phi = callee->var(i); + size_t n = callee->num_tvars(); + for (size_t i = 0; i != n; ++i) { + if (auto arg = emit_unsafe(app->arg(n, i)); !arg.empty()) { + auto phi = callee->var(n, i); assert(!match(phi->type())); lam2bb_[callee].phis[phi].emplace_back(arg, id(lam, true)); locals_[phi] = id(phi); diff --git a/dialects/mem/mem.h b/dialects/mem/mem.h index bdb7a0e3ea..fd6af15d57 100644 --- a/dialects/mem/mem.h +++ b/dialects/mem/mem.h @@ -33,13 +33,14 @@ inline const Pi* fn_mem(Ref domain, Ref codomain) { /// Returns the (first) element of type mem::M from the given tuple. inline Ref mem_def(Ref def) { if (match(def->type())) return def; + if (def->type()->isa()) return {}; // don't look into possibly gigantic arrays if (def->num_projs() > 1) { for (auto proj : def->projs()) if (auto mem = mem_def(proj)) return mem; } - return nullptr; + return {}; } /// Returns the memory argument of a function if it has one. diff --git a/dialects/mem/normalizers.cpp b/dialects/mem/normalizers.cpp index 2875571a71..9773ccc308 100644 --- a/dialects/mem/normalizers.cpp +++ b/dialects/mem/normalizers.cpp @@ -23,8 +23,7 @@ Ref normalize_load(Ref type, Ref callee, Ref arg) { if (ptr->isa()) return world.tuple({mem, world.bot(type->as()->op(1))}); // loading an empty tuple can only result in an empty tuple - if (auto sigma = pointee->isa(); sigma && sigma->num_ops() == 0) - return world.tuple({mem, world.tuple(sigma->type(), {})}); + if (auto sigma = pointee->isa(); sigma && sigma->num_ops() == 0) return world.tuple({mem, world.tuple()}); return world.raw_app(type, callee, {mem, ptr}); } diff --git a/dialects/mem/passes/fp/copy_prop.cpp b/dialects/mem/passes/fp/copy_prop.cpp index d1b067a673..4bc23231fe 100644 --- a/dialects/mem/passes/fp/copy_prop.cpp +++ b/dialects/mem/passes/fp/copy_prop.cpp @@ -11,7 +11,7 @@ Ref CopyProp::rewrite(Ref def) { auto [app, var_lam] = isa_apped_mut_lam(def); if (!isa_workable(var_lam) || (bb_only_ && Lam::isa_returning(var_lam))) return def; - auto n = app->num_args(); + auto n = app->num_targs(); if (n == 0) return app; auto [it, _] = lam2info_.emplace(var_lam, std::tuple(Lattices(n), (Lam*)nullptr, DefArray(n))); @@ -28,20 +28,20 @@ Ref CopyProp::rewrite(Ref def) { switch (lattice[i]) { case Lattice::Dead: break; case Lattice::Prop: - if (app->arg(i)->has_dep(Dep::Proxy)) { + if (app->arg(n, i)->has_dep(Dep::Proxy)) { world().DLOG("found proxy within app: {}@{} - wait till proxy is gone", var_lam, app); return app; } else if (args[i] == nullptr) { - args[i] = app->arg(i); - } else if (args[i] != app->arg(i)) { + args[i] = app->arg(n, i); + } else if (args[i] != app->arg(n, i)) { appxy_ops.emplace_back(world().lit_nat(i)); } else { - assert(args[i] == app->arg(i)); + assert(args[i] == app->arg(n, i)); } break; case Lattice::Keep: - new_doms.emplace_back(var_lam->var(i)->type()); - new_args.emplace_back(app->arg(i)); + new_doms.emplace_back(var_lam->var(n, i)->type()); + new_args.emplace_back(app->arg(n, i)); break; default: unreachable(); } @@ -71,7 +71,7 @@ Ref CopyProp::rewrite(Ref def) { size_t j = 0; DefArray new_vars(n, [&, prop_lam = prop_lam](size_t i) -> Ref { switch (lattice[i]) { - case Lattice::Dead: return proxy(var_lam->var(i)->type(), {var_lam, world().lit_nat(i)}, Varxy); + case Lattice::Dead: return proxy(var_lam->var(n, i)->type(), {var_lam, world().lit_nat(i)}, Varxy); case Lattice::Prop: return args[i]; case Lattice::Keep: return prop_lam->var(j++); default: unreachable(); diff --git a/dialects/mem/passes/fp/ssa_constr.cpp b/dialects/mem/passes/fp/ssa_constr.cpp index 183db10d1c..5a748b8ac1 100644 --- a/dialects/mem/passes/fp/ssa_constr.cpp +++ b/dialects/mem/passes/fp/ssa_constr.cpp @@ -125,7 +125,7 @@ Ref SSAConstr::mem2phi(const App* app, Lam* mem_lam) { auto traxy = proxy(phi_lam->var()->type(), traxy_ops, Traxy); DefArray new_vars(num_mem_vars, [&](size_t i) { return traxy->proj(i); }); - phi_lam->set(mem_lam->reduce(world().tuple(mem_lam->dom(), new_vars))); + phi_lam->set(mem_lam->reduce(world().tuple(new_vars))); } else { world().DLOG("reuse phi_lam '{}'", phi_lam); } diff --git a/dialects/opt/opt.thorin b/dialects/opt/opt.thorin index 98fcb7d4e0..9ab438d5f0 100644 --- a/dialects/opt/opt.thorin +++ b/dialects/opt/opt.thorin @@ -27,7 +27,7 @@ .let nullphase = %compile.single_pass_phase nullptr; %compile.pipe nullphase - (%compile.single_pass_phase (%compile.scalerize_pass (nullptr, %compile.scalerize_threshold))) + (%compile.single_pass_phase (%compile.scalerize_pass nullptr)) (%compile.single_pass_phase %compile.eta_red_pass) (%compile.single_pass_phase (%compile.tail_rec_elim_pass nullptr)) (%compile.single_pass_phase (plugin_cond_pass (%compile.regex_plugin, %regex.lower_regex))) diff --git a/gtest/test.cpp b/gtest/test.cpp index 1fe6fa4a71..77db4be732 100644 --- a/gtest/test.cpp +++ b/gtest/test.cpp @@ -45,7 +45,7 @@ TEST(World, simplify_one_tuple) { type->set(Defs{w.type_nat(), w.type_nat()}); ASSERT_EQ(type, w.sigma({type})) << "constant fold [mut] -> mut"; - auto v = w.tuple(type, {w.lit_idx(42), w.lit_idx(1337)}); + auto v = w.tuple({w.lit_idx(42), w.lit_idx(1337)}); ASSERT_EQ(v, w.tuple({v})) << "constant fold ({42, 1337}) -> {42, 1337}"; } diff --git a/lit/tuple_type_bug.thorin b/lit/tuple_type_bug.thorin new file mode 100644 index 0000000000..de56933714 --- /dev/null +++ b/lit/tuple_type_bug.thorin @@ -0,0 +1,21 @@ +// RUN: %thorin %s -o - +.plugin core; +.plugin mem; + +.Sigma Num: □, 2 = [T: *, _0: T]; // could be anything +.Sigma Shp: *, 2 = [D: [.Nat, .Nat], N: Num]; +.let I64: Num = (%core.I64, 0:%core.I64); +.ax %bug.Mat: Shp -> *; + +.ax %bug.matmul: Π [m: .Nat, n: .Nat, o: .Nat] [N: Num] + [X: %bug.Mat ((m, n), N), Y: %bug.Mat ((n, o), N)] -> %bug.Mat ((m, o), N); + +.fun .extern main [mem: %mem.M, + pX: %mem.Ptr0 (%bug.Mat ((3, 4), I64)), + pY: %mem.Ptr0 (%bug.Mat ((4, 5), I64)), + pZ: %mem.Ptr0 (%bug.Mat ((3, 5), I64))]: %mem.M = + .let (`mem, X) = %mem.load (mem, pX); + .let (`mem, Y) = %mem.load (mem, pY); + .let Z = %bug.matmul (3, 4, 5) I64 (X, Y); + .let `mem = %mem.store (mem, pZ, Z); + return mem; diff --git a/thorin/def.cpp b/thorin/def.cpp index 43b22bdaa8..a25c95b54b 100644 --- a/thorin/def.cpp +++ b/thorin/def.cpp @@ -101,7 +101,7 @@ Ref Sigma ::rebuild(World& w, Ref , Defs o) const { return w.sigma(o) Ref Singleton::rebuild(World& w, Ref , Defs o) const { return w.singleton(o[0]) ->set(dbg()); } Ref Type ::rebuild(World& w, Ref , Defs o) const { return w.type(o[0]) ->set(dbg()); } Ref Test ::rebuild(World& w, Ref , Defs o) const { return w.test(o[0], o[1], o[2], o[3]) ->set(dbg()); } -Ref Tuple ::rebuild(World& w, Ref t, Defs o) const { return w.tuple(t, o) ->set(dbg()); } +Ref Tuple ::rebuild(World& w, Ref , Defs o) const { return w.tuple(o) ->set(dbg()); } Ref UInc ::rebuild(World& w, Ref , Defs o) const { return w.uinc(o[0], offset()) ->set(dbg()); } Ref UMax ::rebuild(World& w, Ref , Defs o) const { return w.umax(o) ->set(dbg()); } Ref Var ::rebuild(World& w, Ref t, Defs o) const { return w.var(t, o[0]->as_mut()) ->set(dbg()); } @@ -415,6 +415,11 @@ void Def::make_internal() { return world().make_internal(this); } std::string Def::unique_name() const { return *sym() + "_"s + std::to_string(gid()); } +nat_t Def::num_tprojs() const { + if (auto a = isa_lit_arity(); a && *a < world().flags().scalerize_threshold) return *a; + return 1; +} + const Def* Def::proj(nat_t a, nat_t i) const { if (a == 1) { if (!type()) return this; diff --git a/thorin/def.h b/thorin/def.h index bd85042fec..72562e74a4 100644 --- a/thorin/def.h +++ b/thorin/def.h @@ -147,10 +147,14 @@ THORIN_ENUM_OPERATORS(Dep) /// Use as mixin to wrap all kind of Def::proj and Def::projs variants. #define THORIN_PROJ(NAME, CONST) \ nat_t num_##NAME##s() CONST { return ((const Def*)NAME())->num_projs(); } \ + nat_t num_t##NAME##s() CONST { return ((const Def*)NAME())->num_tprojs(); } \ Ref NAME(nat_t a, nat_t i) CONST { return ((const Def*)NAME())->proj(a, i); } \ Ref NAME(nat_t i) CONST { return ((const Def*)NAME())->proj(i); } \ + Ref t##NAME(nat_t i) CONST { return ((const Def*)NAME())->tproj(i); } \ template auto NAME##s(F f) CONST { return ((const Def*)NAME())->projs(f); } \ + template auto t##NAME##s(F f) CONST { return ((const Def*)NAME())->tprojs(f); } \ template auto NAME##s() CONST { return ((const Def*)NAME())->projs(); } \ + auto t##NAME##s() CONST { return ((const Def*)NAME())->tprojs(); } \ template auto NAME##s(nat_t a, F f) CONST { return ((const Def*)NAME())->projs(a, f); } \ auto NAME##s(nat_t a) CONST { return ((const Def*)NAME())->projs(a); } @@ -338,18 +342,19 @@ class Def : public RuntimeCast { /// Yields Def::as_lit_arity(), if it is in fact a Lit, or `1` otherwise. nat_t num_projs() const { return isa_lit_arity().value_or(1); } + nat_t num_tprojs() const; /// Similar to World::extract while assuming an arity of @p a, but also works on Sigma%s and Arr%ays. const Def* proj(nat_t a, nat_t i) const; - /// Same as above but takes Def::num_projs as arity. - const Def* proj(nat_t i) const { return proj(num_projs(), i); } + const Def* proj(nat_t i) const { return proj(num_projs(), i); } /// As above but takes Def::num_projs as arity. + const Def* tproj(nat_t i) const { return proj(num_tprojs(), i); } /// As above but takes Def::num_tprojs. - /// Splits this Def via Def::proj%ections into an Array (if `A == -1_s`) or `std::array` (otherwise). + /// Splits this Def via Def::proj%ections into an Array (if `A == -1_n`) or `std::array` (otherwise). /// Applies @p f to each element. - template auto projs(F f) const { + template auto projs(F f) const { using R = std::decay_t; - if constexpr (A == -1_s) { + if constexpr (A == -1_n) { return projs(num_projs(), f); } else { assert(A == as_lit_arity()); @@ -359,13 +364,18 @@ class Def : public RuntimeCast { } } + template auto tprojs(F f) const { return projs(num_tprojs(), f); } + template auto projs(nat_t a, F f) const { using R = std::decay_t; return Array(a, [&](nat_t i) { return f(proj(a, i)); }); } - template auto projs() const { + template auto projs() const { return projs([](const Def* def) { return def; }); } + auto tprojs() const { + return tprojs([](const Def* def) { return def; }); + } auto projs(nat_t a) const { return projs(a, [](const Def* def) { return def; }); } diff --git a/thorin/dump.cpp b/thorin/dump.cpp index 80f6d794ad..45fd18e339 100644 --- a/thorin/dump.cpp +++ b/thorin/dump.cpp @@ -287,7 +287,7 @@ void Dumper::dump_ptrn(const Def* def, const Def* type) { if (!def) { os << type; } else { - auto projs = def->projs(); + auto projs = def->tprojs(); if (projs.size() == 1 || std::ranges::all_of(projs, [](auto def) { return !def; })) { print(os, "{}: {}", def->unique_name(), type); } else { diff --git a/thorin/flags.h b/thorin/flags.h index 75e2dc81c3..d8dd823d04 100644 --- a/thorin/flags.h +++ b/thorin/flags.h @@ -1,16 +1,19 @@ #pragma once +#include + #include "thorin/config.h" namespace thorin { // Compiler switches that must be saved and looked up in later phases of compilation. struct Flags { - int dump_gid = 0; - bool dump_recursive = false; - bool disable_type_checking = false; // TODO implement this flag - bool bootstrap = false; - bool aggressive_lam_spec = false; // HACK makes LamSpec more agressive but potentially non-terminating + uint32_t dump_gid = 0; + uint64_t scalerize_threshold = 32; + bool dump_recursive = false; + bool disable_type_checking = false; // TODO implement this flag + bool bootstrap = false; + bool aggressive_lam_spec = false; // HACK makes LamSpec more agressive but potentially non-terminating #ifdef THORIN_ENABLE_CHECKS bool reeval_breakpoints = false; bool trace_gids = false; diff --git a/thorin/pass/rw/ret_wrap.cpp b/thorin/pass/rw/ret_wrap.cpp index 467645b6a0..78c0506371 100644 --- a/thorin/pass/rw/ret_wrap.cpp +++ b/thorin/pass/rw/ret_wrap.cpp @@ -14,7 +14,7 @@ void RetWrap::enter() { auto new_vars = curr_mut()->vars(); assert(new_vars.back() == ret_var && "we assume that the last element is the ret_var"); new_vars.back() = ret_cont; - auto new_var = world().tuple(curr_mut()->dom(), new_vars); + auto new_var = world().tuple(new_vars); curr_mut()->reset(curr_mut()->reduce(new_var)); } diff --git a/thorin/pass/rw/scalarize.cpp b/thorin/pass/rw/scalarize.cpp index 22b8f6991e..5dd8c7bcae 100644 --- a/thorin/pass/rw/scalarize.cpp +++ b/thorin/pass/rw/scalarize.cpp @@ -30,7 +30,7 @@ Lam* Scalerize::make_scalar(Ref def) { auto arg_sz = std::vector(); bool todo = false; for (size_t i = 0, e = tup_lam->num_doms(); i != e; ++i) { - auto n = flatten(threshold_, types, tup_lam->dom(i), false); + auto n = flatten(types, tup_lam->dom(i), false); arg_sz.push_back(n); todo |= n != 1 || types.back() != tup_lam->dom(i); } @@ -44,7 +44,7 @@ Lam* Scalerize::make_scalar(Ref def) { world().DLOG("type {} ~> {}", tup_lam->type(), pi); auto new_vars = world().tuple(DefArray(tup_lam->num_doms(), [&](auto i) { auto tuple = DefArray(arg_sz.at(i), [&](auto) { return sca_lam->var(n++); }); - return unflatten(threshold_, tuple, tup_lam->dom(i), false); + return unflatten(tuple, tup_lam->dom(i), false); })); sca_lam->set(tup_lam->reduce(new_vars)); tup2sca_[sca_lam] = sca_lam; @@ -74,7 +74,7 @@ Ref Scalerize::rewrite(Ref def) { if (sca_callee != app->callee()) { auto new_args = DefVec(); - flatten(threshold_, new_args, app->arg(), false); + flatten(new_args, app->arg(), false); return world().app(sca_callee, new_args); } } diff --git a/thorin/pass/rw/scalarize.h b/thorin/pass/rw/scalarize.h index df214ea0d9..2ede8465d8 100644 --- a/thorin/pass/rw/scalarize.h +++ b/thorin/pass/rw/scalarize.h @@ -18,10 +18,9 @@ class EtaExp; /// It will not flatten mutable @p Sigma%s or @p Arr%ays. class Scalerize : public RWPass { public: - Scalerize(PassMan& man, EtaExp* eta_exp, nat_t threshold) + Scalerize(PassMan& man, EtaExp* eta_exp) : RWPass(man, "scalerize") - , eta_exp_(eta_exp) - , threshold_(threshold) {} + , eta_exp_(eta_exp) {} Ref rewrite(Ref) override; @@ -30,7 +29,6 @@ class Scalerize : public RWPass { Lam* make_scalar(Ref def); EtaExp* eta_exp_; - nat_t threshold_; Lam2Lam tup2sca_; }; diff --git a/thorin/tuple.cpp b/thorin/tuple.cpp index 6e6351b433..c36bf2d464 100644 --- a/thorin/tuple.cpp +++ b/thorin/tuple.cpp @@ -10,11 +10,11 @@ namespace thorin { namespace { -bool should_flatten(nat_t threshold, const Def* def) { +bool should_flatten(const Def* def) { auto type = (def->is_term() ? def->type() : def); if (type->isa()) return true; if (auto arr = type->isa()) { - if (auto a = arr->isa_lit_arity(); a && *a > threshold) return false; + if (auto a = arr->isa_lit_arity(); a && *a > def->world().flags().scalerize_threshold) return false; return true; } return false; @@ -25,12 +25,13 @@ bool mut_val_or_typ(const Def* def) { return typ->isa_mut(); } -const Def* unflatten(nat_t threshold, Defs defs, const Def* type, size_t& j, bool flatten_muts) { +const Def* unflatten(Defs defs, const Def* type, size_t& j, bool flatten_muts) { if (!defs.empty() && defs[0]->type() == type) return defs[j++]; - if (auto a = type->isa_lit_arity(); flatten_muts == mut_val_or_typ(type) && a && *a != 1 && a <= threshold) { + if (auto a = type->isa_lit_arity(); + flatten_muts == mut_val_or_typ(type) && a && *a != 1 && a <= type->world().flags().scalerize_threshold) { auto& world = type->world(); - DefArray ops(*a, [&](size_t i) { return unflatten(threshold, defs, type->proj(*a, i), j, flatten_muts); }); - return world.tuple(type, ops); + DefArray ops(*a, [&](size_t i) { return unflatten(defs, type->proj(*a, i), j, flatten_muts); }); + return world.tuple(ops); } return defs[j++]; @@ -52,11 +53,10 @@ std::string tuple2str(const Def* def) { return std::string(array.begin(), array.end()); } -size_t flatten(nat_t threshold, DefVec& ops, const Def* def, bool flatten_muts) { - if (auto a = def->isa_lit_arity(); - a && *a != 1 && should_flatten(threshold, def) && flatten_muts == mut_val_or_typ(def)) { +size_t flatten(DefVec& ops, const Def* def, bool flatten_muts) { + if (auto a = def->isa_lit_arity(); a && *a != 1 && should_flatten(def) && flatten_muts == mut_val_or_typ(def)) { auto n = 0; - for (size_t i = 0; i != *a; ++i) n += flatten(threshold, ops, def->proj(*a, i), flatten_muts); + for (size_t i = 0; i != *a; ++i) n += flatten(ops, def->proj(*a, i), flatten_muts); return n; } else { ops.emplace_back(def); @@ -64,23 +64,21 @@ size_t flatten(nat_t threshold, DefVec& ops, const Def* def, bool flatten_muts) } } -const Def* flatten(nat_t threshold, const Def* def) { - if (!should_flatten(threshold, def)) return def; +const Def* flatten(const Def* def) { + if (!should_flatten(def)) return def; DefVec ops; - flatten(threshold, ops, def); - return def->is_term() ? def->world().tuple(def->type(), ops) : def->world().sigma(ops); + flatten(ops, def); + return def->is_term() ? def->world().tuple(ops) : def->world().sigma(ops); } -const Def* unflatten(nat_t threshold, Defs defs, const Def* type, bool flatten_muts) { +const Def* unflatten(Defs defs, const Def* type, bool flatten_muts) { size_t j = 0; - auto def = unflatten(threshold, defs, type, j, flatten_muts); + auto def = unflatten(defs, type, j, flatten_muts); assert(j == defs.size()); return def; } -const Def* unflatten(nat_t threshold, const Def* def, const Def* type) { - return unflatten(threshold, def->projs(Lit::as(def->arity())), type); -} +const Def* unflatten(const Def* def, const Def* type) { return unflatten(def->projs(Lit::as(def->arity())), type); } DefArray merge(const Def* def, Defs defs) { return DefArray(defs.size() + 1, [&](auto i) { return i == 0 ? def : defs[i - 1]; }); diff --git a/thorin/tuple.h b/thorin/tuple.h index e4b27e4a74..960601af18 100644 --- a/thorin/tuple.h +++ b/thorin/tuple.h @@ -153,14 +153,14 @@ bool is_unit(const Def*); std::string tuple2str(const Def*); /// Flattens a sigma/array/pack/tuple. -const Def* flatten(nat_t threshold, const Def* def); +const Def* flatten(const Def* def); /// Same as unflatten, but uses the operands of a flattened pack/tuple directly. -size_t flatten(nat_t threshold, DefVec& ops, const Def* def, bool flatten_sigmas = true); +size_t flatten(DefVec& ops, const Def* def, bool flatten_sigmas = true); /// Applies the reverse transformation on a pack/tuple, given the original type. -const Def* unflatten(nat_t threshold, const Def* def, const Def* type); +const Def* unflatten(const Def* def, const Def* type); /// Same as unflatten, but uses the operands of a flattened pack/tuple directly. -const Def* unflatten(nat_t threshold, Defs ops, const Def* type, bool flatten_muts = true); +const Def* unflatten(Defs ops, const Def* type, bool flatten_muts = true); DefArray merge(Defs, Defs); DefArray merge(const Def* def, Defs defs); diff --git a/thorin/world.cpp b/thorin/world.cpp index dca0d1e64c..2b8a61a7c9 100644 --- a/thorin/world.cpp +++ b/thorin/world.cpp @@ -239,14 +239,15 @@ Ref World::tuple(Defs ops) { if (ops.size() == 1) return ops[0]; auto sigma = infer_sigma(*this, ops); - auto t = tuple(sigma, ops); + auto t = tuple_(sigma, ops); if (!Check::assignable(sigma, t)) error(t, "cannot assign tuple '{}' of type '{}' to incompatible tuple type '{}'", t, t->type(), sigma); return t; } -Ref World::tuple(Ref type, Defs ops) { +Ref World::tuple_(Ref type, Defs ops) { + assert(type->isa_imm()); // TODO type-check type vs inferred type auto n = ops.size(); @@ -364,18 +365,18 @@ Ref World::insert(Ref d, Ref index, Ref val) { error(val, "value of type {} is not assignable to type {}", val->type(), target_type); } - if (auto l = Lit::isa(size); l && *l == 1) - return tuple(d, {val}); // d could be mut - that's why the tuple ctor is needed + if (auto l = Lit::isa(size); l && *l == 1) return val; + // return tuple_(d, {val}); // d could be mut - that's why the tuple ctor is needed // insert((a, b, c, d), 2, x) -> (a, b, x, d) if (auto t = d->isa(); t && lidx) return t->refine(*lidx, val); // insert(‹4; x›, 2, y) -> (x, x, y, x) if (auto pack = d->isa(); pack && lidx) { - if (auto a = pack->isa_lit_arity()) { + if (auto a = pack->isa_lit_arity(); a && *a < flags().scalerize_threshold) { DefArray new_ops(*a, pack->body()); new_ops[*lidx] = val; - return tuple(type, new_ops); + return tuple_(type, new_ops); } } @@ -464,7 +465,7 @@ const Lit* World::lit(Ref type, u64 val) { template Ref World::ext(Ref type) { if (auto arr = type->isa()) return pack(arr->shape(), ext(arr->body())); if (auto sigma = type->isa()) - return tuple(sigma, DefArray(sigma->num_ops(), [&](size_t i) { return ext(sigma->op(i)); })); + return tuple_(sigma, DefArray(sigma->num_ops(), [&](size_t i) { return ext(sigma->op(i)); })); return unify>(0, type); } diff --git a/thorin/world.h b/thorin/world.h index b4102fae58..fa4ef109fd 100644 --- a/thorin/world.h +++ b/thorin/world.h @@ -288,7 +288,7 @@ class World { ///@{ Ref tuple(Defs ops); /// Ascribes @p type to this tuple - needed for dependently typed and mutable Sigma%s. - Ref tuple(Ref type, Defs ops); + Ref tuple_(Ref type, Defs ops); const Tuple* tuple() { return data_.tuple; } ///< the unit value of type `[]` Ref tuple(Sym sym); ///< Converts @p sym to a tuple of type '«n; I8»'. ///@}