Skip to content

Commit

Permalink
Handle closures without capture state. Implement core::array::from_fn
Browse files Browse the repository at this point in the history
  • Loading branch information
msprotz committed Feb 23, 2024
1 parent 7780d2b commit e065f59
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 116 deletions.
1 change: 1 addition & 0 deletions bin/main.ml
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ Supported options:|}
Eurydice.Logging.log "Phase2.5" "%a" pfiles files;
let files = Krml.Inlining.cross_call_analysis files in
let files = Krml.Simplify.remove_unused files in
let files = Eurydice.Cleanup2.remove_array_from_fn#visit_files () files in
(* Macros stemming from globals *)
let files, macros = Eurydice.Cleanup2.build_macros files in

Expand Down
258 changes: 142 additions & 116 deletions lib/AstOfLlbc.ml
Original file line number Diff line number Diff line change
Expand Up @@ -294,11 +294,10 @@ let rec typ_of_ty (env: env) (ty: Charon.Types.ty): K.typ =

| TAdt (TTuple, { types = args; const_generics; _ }) ->
assert (const_generics = []);
if args = [] then
TUnit
else begin
assert (List.length args > 1);
TTuple (List.map (typ_of_ty env) args)
begin match args with
| [] -> TUnit
| [ t ] -> typ_of_ty env t (* happens with closures *)
| _ -> TTuple (List.map (typ_of_ty env) args)
end

| TAdt (TAssumed TArray, { types = [ t ]; const_generics = [ cg ]; _ }) ->
Expand Down Expand Up @@ -587,6 +586,114 @@ let maybe_addrof (env: env) (ty: C.ty) (e: K.expr) =
| _ ->
K.(with_type (TBuf (e.typ, false)) (EAddrOf e))

type lookup_result = {
name: K.lident;
n_type_args: int; (* just for a sanity check *)
cg_types: K.typ list;
arg_types: K.typ list;
ret_type: K.typ;
is_assumed: bool;
}

let lookup_fun (env: env) (f: C.fn_ptr): lookup_result =
let open RustNames in
let matches p = Charon.NameMatcher.match_fn_ptr env.name_ctx RustNames.config p f in
let builtin b =
let { Builtin.name; typ; n_type_args; cg_args; _ } = b in
let ret_type, arg_types = Krml.Helpers.flatten_arrow typ in
{ name; n_type_args; arg_types; ret_type; cg_types = cg_args; is_assumed = true }
in
match List.find_opt (fun (p, _) -> matches p) known_builtins with
| Some (_, b) ->
builtin b
| None ->
let regular f =
let { C.name; signature = { generics = { types = type_params; const_generics; _ }; inputs; output; _ }; _ } = env.get_nth_function f in
L.log "Calls" "--> name: %s" (string_of_name env name);
L.log "Calls" "--> args: %s, ret: %s"
(String.concat " ++ " (List.map (Charon.PrintTypes.ty_to_string env.format_env) inputs))
(Charon.PrintTypes.ty_to_string env.format_env output);
let env = push_cg_binders env const_generics in
let env = push_type_binders env type_params in
{
name = lid_of_name env name;
n_type_args = List.length type_params;
cg_types = List.map (fun (v: C.const_generic_var) -> typ_of_literal_ty env v.ty) const_generics;
arg_types = List.map (typ_of_ty env) inputs;
ret_type = typ_of_ty env output;
is_assumed = false
}
in
match f.func with
| FunId (FRegular f) ->
regular f

| FunId (FAssumed f) ->
Krml.Warn.fatal_error "unknown assumed function: %s" (C.show_assumed_fun_id f)

| TraitMethod (trait_ref, method_name, _trait_opaque_signature) ->
match trait_ref.trait_id with
| TraitImpl id ->
let trait = env.get_nth_trait_impl id in
let f = List.assoc method_name trait.required_methods in
regular f
| _ ->
Krml.Warn.fatal_error "Error looking trait ref: %s %s"
(Charon.PrintTypes.trait_ref_to_string env.format_env trait_ref) method_name

let expression_of_fn_ptr env (fn_ptr: C.fn_ptr) =
let {
C.generics = { types = type_args; const_generics = const_generic_args; trait_refs; _ };
trait_and_method_generic_args;
_
} = fn_ptr in

(* General case for function calls and trait method calls. *)
L.log "Calls" "Visiting call: %s" (Charon.PrintExpressions.fn_ptr_to_string env.format_env fn_ptr);
L.log "Calls" "is_array_map: %b" (RustNames.is_array_map env fn_ptr);
L.log "Calls" "--> %d type_args, %d const_generics, %d trait_refs"
(List.length type_args) (List.length const_generic_args) (List.length trait_refs);

let type_args, const_generic_args, trait_refs =
match trait_and_method_generic_args with
| None ->
type_args, const_generic_args, trait_refs
| Some { types; const_generics; trait_refs; _ } ->
types @ type_args, const_generics @ const_generic_args, trait_refs @ trait_refs
in
L.log "Calls" "--> %d type_args, %d const_generics, %d trait_refs"
(List.length type_args) (List.length const_generic_args) (List.length trait_refs);
L.log "Calls" "--> trait_refs: %s\n"
(String.concat " ++ " (List.map (Charon.PrintTypes.trait_ref_to_string env.format_env) trait_refs));
L.log "Calls" "--> pattern: %s" (string_of_fn_ptr env fn_ptr);

let type_args = List.map (typ_of_ty env) type_args in
let const_generic_args = List.map (expression_of_const_generic env) const_generic_args in
let { name; n_type_args = n_type_params; arg_types = inputs; ret_type = output; cg_types = cg_inputs; is_assumed } =
lookup_fun env fn_ptr
in

let inputs = if inputs = [] && not is_assumed then [ K.TUnit ] else inputs in
if not (n_type_params = List.length type_args) then
Krml.Warn.fatal_error "%a: n_type_params %d != type_args %d"
Krml.PrintAst.Ops.plid name
n_type_params (List.length type_args);
let poly_t_sans_cgs = Krml.Helpers.fold_arrow inputs output in
let poly_t = Krml.Helpers.fold_arrow cg_inputs poly_t_sans_cgs in
let output, t =
Krml.DeBruijn.(subst_ctn (List.length env.binders) const_generic_args (subst_tn type_args output)),
Krml.DeBruijn.(subst_ctn (List.length env.binders) const_generic_args (subst_tn type_args poly_t_sans_cgs))
in
let hd =
let hd = K.with_type poly_t (K.EQualified name) in
if type_args <> [] || const_generic_args <> [] then
K.with_type t (K.ETApp (hd, const_generic_args, type_args))
else
hd
in
hd, is_assumed, output


let expression_of_rvalue (env: env) (p: C.rvalue): K.expr =
match p with
| Use op ->
Expand Down Expand Up @@ -632,8 +739,21 @@ let expression_of_rvalue (env: env) (p: C.rvalue): K.expr =
end
| Aggregate (AggregatedAdt (TAssumed _, _, _), _) ->
failwith "unsupported: AggregatedAdt / TAssume"
| Aggregate (AggregatedClosure _, _) ->
failwith "unsupported: AggregatedClosure"

| Aggregate (AggregatedClosure (func, generics), ops) ->
if ops <> [] then
failwith (Printf.sprintf "unsupported: AggregatedClosure (TODO: closure conversion): %d" (List.length ops))
else
let fun_ptr = { C.func = C.FunId (FRegular func); generics; trait_and_method_generic_args = None } in
let e, _, _ = expression_of_fn_ptr env fun_ptr in
begin match e.typ with
| TArrow (TBuf (TUnit, _) as t_state, t) ->
(* Empty closure block, passed by address...? TBD *)
K.(with_type t (EApp (e, [ with_type t_state (EAddrOf Krml.Helpers.eunit) ])))
| _ ->
assert false
end

| Aggregate (AggregatedArray (t, cg), ops) ->
K.with_type (TArray (typ_of_ty env t, constant_of_scalar_value (assert_cg_scalar cg))) (K.EBufCreateL (Stack, List.map (expression_of_operand env) ops))
| Global id ->
Expand All @@ -652,61 +772,6 @@ let expression_of_assertion (env: env) ({ cond; expected }: C.assertion): K.expr
with_type TAny (EAbort (None, Some "assert failure")),
Krml.Helpers.eunit)))

type lookup_result = {
name: K.lident;
n_type_args: int; (* just for a sanity check *)
cg_types: K.typ list;
arg_types: K.typ list;
ret_type: K.typ;
is_assumed: bool;
}

let lookup_fun (env: env) (f: C.fn_ptr): lookup_result =
let open RustNames in
let matches p = Charon.NameMatcher.match_fn_ptr env.name_ctx RustNames.config p f in
let builtin b =
let { Builtin.name; typ; n_type_args; cg_args; _ } = b in
let ret_type, arg_types = Krml.Helpers.flatten_arrow typ in
{ name; n_type_args; arg_types; ret_type; cg_types = cg_args; is_assumed = true }
in
match List.find_opt (fun (p, _) -> matches p) known_builtins with
| Some (_, b) ->
builtin b
| None ->
let regular f =
let { C.name; signature = { generics = { types = type_params; const_generics; _ }; inputs; output; _ }; _ } = env.get_nth_function f in
L.log "Calls" "--> name: %s" (string_of_name env name);
L.log "Calls" "--> args: %s, ret: %s"
(String.concat " ++ " (List.map (Charon.PrintTypes.ty_to_string env.format_env) inputs))
(Charon.PrintTypes.ty_to_string env.format_env output);
let env = push_cg_binders env const_generics in
let env = push_type_binders env type_params in
{
name = lid_of_name env name;
n_type_args = List.length type_params;
cg_types = List.map (fun (v: C.const_generic_var) -> typ_of_literal_ty env v.ty) const_generics;
arg_types = List.map (typ_of_ty env) inputs;
ret_type = typ_of_ty env output;
is_assumed = false
}
in
match f.func with
| FunId (FRegular f) ->
regular f

| FunId (FAssumed f) ->
Krml.Warn.fatal_error "unknown assumed function: %s" (C.show_assumed_fun_id f)

| TraitMethod (trait_ref, method_name, _trait_opaque_signature) ->
match trait_ref.trait_id with
| TraitImpl id ->
let trait = env.get_nth_trait_impl id in
let f = List.assoc method_name trait.required_methods in
regular f
| _ ->
Krml.Warn.fatal_error "Error looking trait ref: %s %s"
(Charon.PrintTypes.trait_ref_to_string env.format_env trait_ref) method_name

let lesser t1 t2 =
if t1 = K.TAny then
t2
Expand Down Expand Up @@ -819,63 +884,19 @@ let rec expression_of_raw_statement (env: env) (ret_var: C.var_id) (s: C.raw_sta
H.with_unit (K.EBufWrite (Krml.DeBruijn.lift 1 dest, i,
K.with_type t (K.EBufRead (Krml.DeBruijn.lift 1 src, i))))))

| Call { func = FnOpRegular ({
func;
generics = { types = type_args; const_generics = const_generic_args; trait_refs; _ };
trait_and_method_generic_args
} as fn_ptr); args; dest; _ }
->
(* General case for function calls and trait method calls. *)
L.log "Calls" "Visiting call: %s" (Charon.PrintExpressions.fn_ptr_to_string env.format_env fn_ptr);
L.log "Calls" "is_array_map: %b" (RustNames.is_array_map env fn_ptr);
L.log "Calls" "--> %d type_args, %d const_generics, %d trait_refs"
(List.length type_args) (List.length const_generic_args) (List.length trait_refs);
| Call { func = FnOpRegular fn_ptr; args; dest; _ } ->

(* For now, we take trait type arguments to be part of the code-gen *)
let type_args, const_generic_args, trait_refs =
match trait_and_method_generic_args with
| None ->
type_args, const_generic_args, trait_refs
| Some { types; const_generics; trait_refs; _ } ->
types @ type_args, const_generics @ const_generic_args, trait_refs @ trait_refs
in
L.log "Calls" "--> %d type_args, %d const_generics, %d trait_refs"
(List.length type_args) (List.length const_generic_args) (List.length trait_refs);
L.log "Calls" "--> trait_refs: %s\n"
(String.concat " ++ " (List.map (Charon.PrintTypes.trait_ref_to_string env.format_env) trait_refs));
L.log "Calls" "--> pattern: %s" (string_of_fn_ptr env fn_ptr);

let hd, is_assumed, output_t = expression_of_fn_ptr env fn_ptr in
let dest, _ = expression_of_place env dest in
let args = List.map (expression_of_operand env) args in
let original_type_args = type_args in
let type_args = List.map (typ_of_ty env) type_args in
let const_generic_args = List.map (expression_of_const_generic env) const_generic_args in
let { name; n_type_args = n_type_params; arg_types = inputs; ret_type = output; cg_types = cg_inputs; is_assumed } =
lookup_fun env fn_ptr
in
let args = if args = [] && not is_assumed then [ Krml.Helpers.eunit ] else args in
let inputs = if inputs = [] && not is_assumed then [ K.TUnit ] else inputs in
if not (n_type_params = List.length type_args) then
Krml.Warn.fatal_error "%a: n_type_params %d != type_args %d"
Krml.PrintAst.Ops.plid name
n_type_params (List.length type_args);
let poly_t_sans_cgs = Krml.Helpers.fold_arrow inputs output in
let poly_t = Krml.Helpers.fold_arrow cg_inputs poly_t_sans_cgs in
let output, t =
Krml.DeBruijn.(subst_ctn (List.length env.binders) const_generic_args (subst_tn type_args output)),
Krml.DeBruijn.(subst_ctn (List.length env.binders) const_generic_args (subst_tn type_args poly_t_sans_cgs))
in
let hd =
let hd = K.with_type poly_t (K.EQualified name) in
if type_args <> [] || const_generic_args <> [] then
K.with_type t (K.ETApp (hd, const_generic_args, type_args))
else
hd
in
let rhs = K.with_type output (K.EApp (hd, args)) in
let rhs = K.with_type output_t (K.EApp (hd, args)) in
(* This does something similar to maybe_addrof *)
let rhs =
match func, original_type_args with
(* TODO: determine whether extra_types is necessary *)
let extra_types = match fn_ptr.trait_and_method_generic_args with Some { types; _ } -> types | None -> [] in
match fn_ptr.func, fn_ptr.generics.types @ extra_types with
| FunId (FAssumed (SliceIndexShared | SliceIndexMut)), [ TAdt (TAssumed (TArray | TSlice), _) ] ->
(* Will decay. See comment above maybe_addrof *)
rhs
Expand Down Expand Up @@ -966,7 +987,6 @@ let rec expression_of_raw_statement (env: env) (ret_var: C.var_id) (s: C.raw_sta
K.(with_type TUnit (EWhile (Krml.Helpers.etrue,
expression_of_raw_statement env ret_var s.content)))


(** Top-level declarations: orchestration *)

let of_declaration_group (dg: 'id C.g_declaration_group) (f: 'id -> 'a): 'a list =
Expand Down Expand Up @@ -1047,6 +1067,7 @@ let decls_of_declarations (env: env) (d: C.declaration_group): K.decl list =
(Printexc.get_backtrace ());
None
end

| Some { arg_count; locals; body; _ } ->
if is_global_decl_body then
None
Expand All @@ -1061,22 +1082,27 @@ let decls_of_declarations (env: env) (d: C.declaration_group): K.decl list =
let args = List.tl args in

let return_type = typ_of_ty env return_var.var_ty in

(* Note: Rust allows zero-argument functions but the krml internal
representation wants a unit there. *)
let t_unit = C.(TAdt (TTuple, { types = []; const_generics = []; regions = []; trait_refs = [] })) in
let v_unit = { C.index = Charon.Expressions.VarId.of_int max_int; name = None; var_ty = t_unit } in
let args = if args = [] then [ v_unit ] else args in

let arg_binders = List.map (fun (arg: C.const_generic_var) ->
Krml.Helpers.fresh_binder ~mut:true arg.name (typ_of_literal_ty env arg.ty)
) signature.C.generics.const_generics @ List.map (fun (arg: C.var) ->
let name = Option.value ~default:"_" arg.name in
Krml.Helpers.fresh_binder ~mut:true name (typ_of_ty env arg.var_ty)
) args in
(* Note: Rust allows zero-argument functions but the krml internal
representation wants a unit there. *)
let arg_binders = if arg_binders = [] then [ Krml.Helpers.fresh_binder "dummy" K.TUnit ] else arg_binders in
let env = push_binders env args in
let body =
with_locals env return_type (return_var :: locals) (fun env ->
expression_of_raw_statement env return_var.index body.content)
in
Some (K.DFunction (None, [], List.length signature.C.generics.const_generics,
List.length signature.C.generics.types, return_type, name, arg_binders, body))

)
| GlobalGroup id ->
let global = env.get_nth_global id in
Expand Down
36 changes: 36 additions & 0 deletions lib/Cleanup2.ml
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,42 @@ let remove_array_repeats = object(self)
super#visit_ELet env b e1 e2
end

let remove_array_from_fn = object
inherit [_] map as super

val mutable defs = Hashtbl.create 41

method! visit_DFunction _ cc flags n_cgs n t name bs e =
assert (n_cgs = 0 && n = 0);
match bs with
| [{ typ = TInt SizeT; _ }] ->
Hashtbl.add defs name e
| _ ->
()
; ;
super#visit_DFunction () cc flags n_cgs n t name bs e

method! visit_EApp env e es =
match e.node with
| ETApp ({ node = EQualified (["core"; "array"], "from_fn"); _ },
[ len ],
[ t_elements; TArrow (t_index, t_elements') ]) ->
assert (t_elements' = t_elements);
assert (t_index = TInt SizeT);
assert (List.length es = 2);
let closure = Krml.Helpers.assert_elid (List.nth es 0).node in
assert (Hashtbl.mem defs closure);
let dst = List.nth es 1 in
EFor (Krml.Helpers.fresh_binder ~mut:true "i" H.usize, H.zero_usize (* i = 0 *),
H.mk_lt_usize (Krml.DeBruijn.lift 1 len) (* i < len *),
H.mk_incr_usize (* i++ *),
let i = with_type H.usize (EBound 0) in
Krml.Helpers.with_unit (EBufWrite (Krml.DeBruijn.lift 1 dst, i, Hashtbl.find defs closure)))
| _ ->
super#visit_EApp env e es
end


let rewrite_slice_to_array = object(_self)
inherit [_] map as super

Expand Down
Loading

0 comments on commit e065f59

Please sign in to comment.