Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Effects: keep track of CPS calls and distinguish CPS calls and trampolined calls #1648

Merged
merged 1 commit into from
Aug 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions compiler/lib/driver.ml
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ let phi p =

let ( +> ) f g x = g (f x)

let map_fst f (x, y) = f x, y
let map_fst f (x, y, z) = f x, y, z

let effects ~deadcode_sentinal p =
if Config.Flag.effects ()
Expand All @@ -104,9 +104,11 @@ let effects ~deadcode_sentinal p =
Deadcode.f p
else p, live_vars
in
let p, cps = p |> Effects.f ~flow_info:info ~live_vars +> map_fst Lambda_lifting.f in
p, cps)
else p, (Code.Var.Set.empty : Effects.cps_calls)
p |> Effects.f ~flow_info:info ~live_vars +> map_fst Lambda_lifting.f)
else
( p
, (Code.Var.Set.empty : Effects.trampolined_calls)
, (Code.Var.Set.empty : Effects.in_cps) )

let exact_calls profile ~deadcode_sentinal p =
if not (Config.Flag.effects ())
Expand Down Expand Up @@ -193,14 +195,14 @@ let generate
~wrap_with_fun
~warn_on_unhandled_effect
~deadcode_sentinal
((p, live_vars), cps_calls) =
((p, live_vars), trampolined_calls, _) =
if times () then Format.eprintf "Start Generation...@.";
let should_export = should_export wrap_with_fun in
Generate.f
p
~exported_runtime
~live_vars
~cps_calls
~trampolined_calls
~should_export
~warn_on_unhandled_effect
~deadcode_sentinal
Expand Down
41 changes: 29 additions & 12 deletions compiler/lib/effects.ml
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,9 @@ let jump_closures blocks_to_transform idom : jump_closures =
idom
{ closure_of_jump = Addr.Map.empty; closures_of_alloc_site = Addr.Map.empty }

type cps_calls = Var.Set.t
type trampolined_calls = Var.Set.t

type in_cps = Var.Set.t

type st =
{ mutable new_blocks : Code.block Addr.Map.t * Code.Addr.t
Expand All @@ -263,7 +265,8 @@ type st =
; block_order : (Addr.t, int) Hashtbl.t
; live_vars : Deadcode.variable_uses
; flow_info : Global_flow.info
; cps_calls : cps_calls ref
; trampolined_calls : trampolined_calls ref
; in_cps : in_cps ref
}

let add_block st block =
Expand All @@ -280,10 +283,11 @@ let allocate_closure ~st ~params ~body ~branch loc =
let name = Var.fresh () in
[ Let (name, Closure (params, (pc, []))), loc ], name

let tail_call ~st ?(instrs = []) ~exact ~check ~f args loc =
let tail_call ~st ?(instrs = []) ~exact ~in_cps ~check ~f args loc =
assert (exact || check);
let ret = Var.fresh () in
if check then st.cps_calls := Var.Set.add ret !(st.cps_calls);
if check then st.trampolined_calls := Var.Set.add ret !(st.trampolined_calls);
if in_cps then st.in_cps := Var.Set.add ret !(st.in_cps);
instrs @ [ Let (ret, Apply { f; args; exact }), loc ], (Return ret, loc)

let cps_branch ~st ~src (pc, args) loc =
Expand All @@ -302,7 +306,15 @@ let cps_branch ~st ~src (pc, args) loc =
(* We check the stack depth only for backward edges (so, at
least once per loop iteration) *)
let check = Hashtbl.find st.block_order src >= Hashtbl.find st.block_order pc in
tail_call ~st ~instrs ~exact:true ~check ~f:(closure_of_pc ~st pc) args loc
tail_call
~st
~instrs
~exact:true
~in_cps:false
~check
~f:(closure_of_pc ~st pc)
args
loc

let cps_jump_cont ~st ~src ((pc, _) as cont) loc =
match Addr.Set.mem pc st.blocks_to_transform with
Expand Down Expand Up @@ -365,7 +377,7 @@ let cps_last ~st ~alloc_jump_closures pc ((last, last_loc) : last * loc) ~k :
(* Is the number of successive 'returns' is unbounded is CPS, it
means that we have an unbounded of calls in direct style
(even with tail call optimization) *)
tail_call ~st ~exact:true ~check:false ~f:k [ x ] last_loc
tail_call ~st ~exact:true ~in_cps:false ~check:false ~f:k [ x ] last_loc
| Raise (x, rmode) -> (
assert (List.is_empty alloc_jump_closures);
match Hashtbl.find_opt st.matching_exn_handler pc with
Expand Down Expand Up @@ -401,6 +413,7 @@ let cps_last ~st ~alloc_jump_closures pc ((last, last_loc) : last * loc) ~k :
~instrs:
((Let (exn_handler, Prim (Extern "caml_pop_trap", [])), noloc) :: instrs)
~exact:true
~in_cps:false
~check:false
~f:exn_handler
[ x ]
Expand Down Expand Up @@ -463,6 +476,7 @@ let cps_instr ~st (instr : instr) : instr =
(* Add the continuation parameter, and change the initial block if
needed *)
let k, cont = Hashtbl.find st.closure_info pc in
st.in_cps := Var.Set.add x !(st.in_cps);
Let (x, Closure (params @ [ k ], cont))
| Let (x, Prim (Extern "caml_alloc_dummy_function", [ size; arity ])) -> (
match arity with
Expand Down Expand Up @@ -532,7 +546,7 @@ let cps_block ~st ~k pc block =
let exact =
exact || Global_flow.exact_call st.flow_info f (List.length args)
in
tail_call ~st ~exact ~check:true ~f (args @ [ k ]) loc)
tail_call ~st ~exact ~in_cps:true ~check:true ~f (args @ [ k ]) loc)
| Prim (Extern "%resume", [ Pv stack; Pv f; Pv arg ]) ->
Some
(fun ~k ->
Expand All @@ -542,6 +556,7 @@ let cps_block ~st ~k pc block =
~instrs:
[ Let (k', Prim (Extern "caml_resume_stack", [ Pv stack; Pv k ])), noloc ]
~exact:(Global_flow.exact_call st.flow_info f 1)
~in_cps:true
~check:true
~f
[ arg; k' ]
Expand Down Expand Up @@ -599,7 +614,8 @@ let cps_block ~st ~k pc block =

let cps_transform ~live_vars ~flow_info ~cps_needed p =
let closure_info = Hashtbl.create 16 in
let cps_calls = ref Var.Set.empty in
let trampolined_calls = ref Var.Set.empty in
let in_cps = ref Var.Set.empty in
let p =
Code.fold_closures_innermost_first
p
Expand Down Expand Up @@ -658,7 +674,8 @@ let cps_transform ~live_vars ~flow_info ~cps_needed p =
; block_order = cfg.block_order
; flow_info
; live_vars
; cps_calls
; trampolined_calls
; in_cps
}
in
let function_needs_cps =
Expand Down Expand Up @@ -735,7 +752,7 @@ let cps_transform ~live_vars ~flow_info ~cps_needed p =
in
{ start = new_start; blocks; free_pc = new_start + 1 }
in
p, !cps_calls
p, !trampolined_calls, !in_cps

(****)

Expand Down Expand Up @@ -927,7 +944,7 @@ let f ~flow_info ~live_vars p =
let cps_needed = Partial_cps_analysis.f p flow_info in
let p, cps_needed = rewrite_toplevel ~cps_needed p in
let p = split_blocks ~cps_needed p in
let p, cps_calls = cps_transform ~live_vars ~flow_info ~cps_needed p in
let p, trampolined_calls, in_cps = cps_transform ~live_vars ~flow_info ~cps_needed p in
hhugo marked this conversation as resolved.
Show resolved Hide resolved
if Debug.find "times" () then Format.eprintf " effects: %a@." Timer.print t;
Code.invariant p;
p, cps_calls
p, trampolined_calls, in_cps
6 changes: 4 additions & 2 deletions compiler/lib/effects.mli
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@
* Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
*)

type cps_calls = Code.Var.Set.t
type trampolined_calls = Code.Var.Set.t

val remove_empty_blocks : live_vars:Deadcode.variable_uses -> Code.program -> Code.program

type in_cps = Code.Var.Set.t

val f :
flow_info:Global_flow.info
-> live_vars:Deadcode.variable_uses
-> Code.program
-> Code.program * cps_calls
-> Code.program * trampolined_calls * in_cps
47 changes: 24 additions & 23 deletions compiler/lib/generate.ml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ type fall_through =
type application_description =
{ arity : int
; exact : bool
; cps : bool
; trampolined : bool
}

module Share = struct
Expand Down Expand Up @@ -133,7 +133,7 @@ module Share = struct
| _ -> t)

let get
~cps_calls
~trampolined_calls
?alias_strings
?(alias_prims = false)
?(alias_apply = true)
Expand All @@ -150,9 +150,9 @@ module Share = struct
match i with
| Let (_, Constant c) -> get_constant c share
| Let (x, Apply { args; exact; _ }) ->
let cps = Var.Set.mem x cps_calls in
if (not exact) || cps
then add_apply { arity = List.length args; exact; cps } share
let trampolined = Var.Set.mem x trampolined_calls in
if (not exact) || trampolined
then add_apply { arity = List.length args; exact; trampolined } share
else share
| Let (_, Special (Alias_prim name)) ->
let name = Primitive.resolve name in
Expand Down Expand Up @@ -244,11 +244,11 @@ module Share = struct
try J.EVar (AppMap.find desc t.vars.applies)
with Not_found ->
let x =
let { arity; exact; cps } = desc in
let { arity; exact; trampolined } = desc in
Var.fresh_n
(Printf.sprintf
"caml_%scall%d"
(match exact, cps with
(match exact, trampolined with
| true, false -> assert false
| true, true -> "cps_exact_"
| false, false -> ""
Expand All @@ -269,7 +269,7 @@ module Ctx = struct
; exported_runtime : (Code.Var.t * bool ref) option
; should_export : bool
; effect_warning : bool ref
; cps_calls : Effects.cps_calls
; trampolined_calls : Effects.trampolined_calls
; deadcode_sentinal : Var.t
; mutated_vars : Code.Var.Set.t Code.Addr.Map.t
; freevars : Code.Var.Set.t Code.Addr.Map.t
Expand All @@ -284,7 +284,7 @@ module Ctx = struct
~freevars
blocks
live
cps_calls
trampolined_calls
share
debug =
{ blocks
Expand All @@ -294,7 +294,7 @@ module Ctx = struct
; exported_runtime
; should_export
; effect_warning = ref (not warn_on_unhandled_effect)
; cps_calls
; trampolined_calls
; deadcode_sentinal
; mutated_vars
; freevars
Expand Down Expand Up @@ -773,7 +773,7 @@ let parallel_renaming back_edge params args continuation queue =

(****)

let apply_fun_raw ctx f params exact cps =
let apply_fun_raw ctx f params exact trampolined =
let n = List.length params in
let apply_directly =
(* Make sure we are performing a regular call, not a (slower)
Expand Down Expand Up @@ -802,7 +802,7 @@ let apply_fun_raw ctx f params exact cps =
, apply_directly
, J.call (runtime_fun ctx "caml_call_gen") [ f; J.array params ] J.N )
in
if cps
if trampolined
then (
assert (Config.Flag.effects ());
(* When supporting effect, we systematically perform tailcall
Expand All @@ -815,7 +815,7 @@ let apply_fun_raw ctx f params exact cps =
, J.call (runtime_fun ctx "caml_trampoline_return") [ f; J.array params ] J.N ))
else apply

let generate_apply_fun ctx { arity; exact; cps } =
let generate_apply_fun ctx { arity; exact; trampolined } =
let f' = Var.fresh_n "f" in
let f = J.V f' in
let params =
Expand All @@ -830,23 +830,24 @@ let generate_apply_fun ctx { arity; exact; cps } =
( None
, J.fun_
(f :: params)
[ J.Return_statement (Some (apply_fun_raw ctx f' params' exact cps)), J.N ]
[ J.Return_statement (Some (apply_fun_raw ctx f' params' exact trampolined)), J.N
]
J.N )

let apply_fun ctx f params exact cps loc =
let apply_fun ctx f params exact trampolined loc =
(* We always go through an intermediate function when doing CPS
calls. This function first checks the stack depth to prevent
a stack overflow. This makes the code smaller than inlining
the test, and we expect the performance impact to be low
since the function should get inlined by the JavaScript
engines. *)
if Config.Flag.inline_callgen () || (exact && not cps)
then apply_fun_raw ctx f params exact cps
if Config.Flag.inline_callgen () || (exact && not trampolined)
then apply_fun_raw ctx f params exact trampolined
else
let y =
Share.get_apply
(generate_apply_fun ctx)
{ arity = List.length params; exact; cps }
{ arity = List.length params; exact; trampolined }
ctx.Ctx.share
in
J.call y (f :: params) loc
Expand Down Expand Up @@ -1029,7 +1030,7 @@ let throw_statement ctx cx k loc =
let rec translate_expr ctx queue loc x e level : _ * J.statement_list =
match e with
| Apply { f; args; exact } ->
let cps = Var.Set.mem x ctx.Ctx.cps_calls in
let trampolined = Var.Set.mem x ctx.Ctx.trampolined_calls in
let args, prop, queue =
List.fold_right
~f:(fun x (args, prop, queue) ->
Expand All @@ -1040,7 +1041,7 @@ let rec translate_expr ctx queue loc x e level : _ * J.statement_list =
in
let (prop', f), queue = access_queue queue f in
let prop = or_p prop prop' in
let e = apply_fun ctx f args exact cps loc in
let e = apply_fun ctx f args exact trampolined loc in
(e, prop, queue), []
| Block (tag, a, array_or_not, _mut) ->
let contents, prop, queue =
Expand Down Expand Up @@ -1949,13 +1950,13 @@ let f
(p : Code.program)
~exported_runtime
~live_vars
~cps_calls
~trampolined_calls
~should_export
~warn_on_unhandled_effect
~deadcode_sentinal
debug =
let t' = Timer.make () in
let share = Share.get ~cps_calls ~alias_prims:exported_runtime p in
let share = Share.get ~trampolined_calls ~alias_prims:exported_runtime p in
let exported_runtime =
if exported_runtime then Some (Code.Var.fresh_n "runtime", ref false) else None
in
Expand All @@ -1971,7 +1972,7 @@ let f
~freevars
p.blocks
live_vars
cps_calls
trampolined_calls
share
debug
in
Expand Down
2 changes: 1 addition & 1 deletion compiler/lib/generate.mli
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ val f :
Code.program
-> exported_runtime:bool
-> live_vars:Deadcode.variable_uses
-> cps_calls:Effects.cps_calls
-> trampolined_calls:Effects.trampolined_calls
-> should_export:bool
-> warn_on_unhandled_effect:bool
-> deadcode_sentinal:Code.Var.t
Expand Down
Loading