Skip to content

Commit

Permalink
Refactor -- lookup_fun was not taking into account trait clauses
Browse files Browse the repository at this point in the history
  • Loading branch information
protz committed May 15, 2024
1 parent 5323c3d commit 9caf19f
Showing 1 changed file with 113 additions and 98 deletions.
211 changes: 113 additions & 98 deletions lib/AstOfLlbc.ml
Original file line number Diff line number Diff line change
Expand Up @@ -462,8 +462,8 @@ type clause_binder = {

let push_clause_binder env { clause_id; item_name; t; sig_info; _ } =
{ env with
(* (1* important for diff calculations *1) *)
(* cg_binders = (C.ConstGenericVarId.of_int max_int, t) :: env.cg_binders; *)
(* Perhaps (?) important for diff calculations *)
cg_binders = (C.ConstGenericVarId.of_int max_int, t) :: env.cg_binders;
binders = (TraitClauseMethod (clause_id, item_name, sig_info), t) :: env.binders }

let push_clause_binders env bs =
Expand Down Expand Up @@ -679,6 +679,105 @@ let maybe_addrof (env: env) (ty: C.ty) (e: K.expr) =
| _ ->
K.(with_type (TBuf (e.typ, false)) (EAddrOf e))

(** Handling trait clauses as dictionaries *)

(* Using tests/where_clauses_simple as an example.
fn double<T: Ops + Copy, U: Ops+Copy> (...)
this gets desugared to fn double<T,U> where
T: Ops, <-- ClauseId 0 (required_methods: add, of_u32)
T: Copy, <-- ClauseId 1 (builtin, so neither required nor provided methods)
U: Ops, <-- ClauseId 2 (required_methods: add, of_u32)
U: Copy, <-- ClauseId 3 (builtin, so neither required nor provided methods)
the types we obtain by looking up the trait declaration have Self as 0
(DeBruijn).
*)
let build_trait_clause_mapping env (trait_clauses: C.trait_clause list) =
List.concat_map (fun tc ->
let { C.clause_id; trait_id; clause_generics; _ } = tc in
let trait_decl = env.get_nth_trait_decl trait_id in

(* FYI, some clauses like Copy have neither required nor provided methods. *)
Krml.KPrint.bprintf "clause id %d: FYI, clause_generic type is %s (required: %d) (provided: %d)\n"
(C.TraitClauseId.to_int clause_id)
(String.concat " ++ " (List.map C.show_ty (clause_generics.C.types)))
(List.length trait_decl.C.required_methods)
(List.length trait_decl.C.provided_methods);

(* The type variable from the function that the trait clause is applied to *)
let tvarid =
match clause_generics.C.types with
| [ TVar tvarid ] -> tvarid
| _ -> failwith "Clause above does not refer to a type var? confused"
in

List.map (fun (item_name, decl_id) ->
let decl = env.get_nth_function decl_id in
(clause_id, item_name), (tvarid, trait_decl.C.name, decl.C.signature)
) trait_decl.C.required_methods @
List.map (fun (item_name, decl_id) ->
match decl_id with
| Some decl_id ->
let decl = env.get_nth_function decl_id in
(clause_id, item_name), (tvarid, trait_decl.C.name, decl.C.signature)
| None ->
failwith ("TODO: handle provided trait methods, like " ^ item_name)
) trait_decl.C.provided_methods
) trait_clauses


(* Assumes type variables have been suitably bound in the environment *)
let rec mk_clause_binders_and_args env clause_mapping: clause_binder list =
List.map (fun ((clause_id, item_name), (tvarid, trait_name, (signature: C.fun_sig))) ->
let tvar = lookup_typ env tvarid in
let t = thd3 (typ_of_signature env signature) in
let t = Krml.DeBruijn.subst_t (TBound tvar) 0 t in
let pretty_name = string_of_name env trait_name ^ "_" ^ item_name in
let sig_info = {
n_type_args = List.length signature.generics.types - 1;
n_cg_args = List.length signature.generics.const_generics;
} in
{ pretty_name; t; clause_id; item_name; sig_info }
) clause_mapping


and typ_of_signature env signature =
let { C.generics = { types = type_params; const_generics; _ }; inputs; output ; _ } = signature in
let const_generics_ts = List.map (fun (c: C.const_generic_var) -> typ_of_literal_ty env c.ty) const_generics in
let env = push_cg_binders env const_generics in
let env = push_type_binders env type_params in

let clause_mapping = build_trait_clause_mapping env signature.C.generics.trait_clauses in
debug_trait_clause_mapping env clause_mapping;
let clause_binders = mk_clause_binders_and_args env clause_mapping in
let clause_ts = List.map (fun { t; _ } -> t) clause_binders in

let inputs = List.map (typ_of_ty env) inputs in
let output = typ_of_ty env output in
let adjusted_inputs =
if const_generics_ts = [] && inputs = [] then [ K.TUnit ] else const_generics_ts @ clause_ts @ inputs
in

let t = Krml.Helpers.fold_arrow adjusted_inputs output in
List.length const_generics_ts, List.length type_params, t


and debug_trait_clause_mapping env mapping =
let open Krml in
if mapping <> [] then
KPrint.bprintf "In this function, calls to trait bound methods are as follows:\n";
List.iter (fun ((clause_id, item_name), (_, trait_name, signature)) ->
let t = thd3 (typ_of_signature env signature) in
KPrint.bprintf "TraitClause %d (a.k.a. %s)::%s: %a\n"
(C.TraitClauseId.to_int clause_id) (string_of_name env trait_name) item_name PrintAst.Ops.ptyp t
) mapping


(* Compilation function instantiations *)


type lookup_result = {
f: K.expr';
n_type_args: int; (* just for a sanity check *)
Expand All @@ -701,18 +800,24 @@ let lookup_fun (env: env) depth (f: C.fn_ptr): lookup_result =
builtin b
| None ->
let lookup_result_of_signature f signature =
let { C.generics = { types = type_params; const_generics; _ }; inputs; output; _ } = signature in
let { C.generics = { types = type_params; const_generics; trait_clauses; _ }; inputs; output; _ } = signature in
L.log "Calls" "%s--> name: %a" depth Krml.PrintAst.Ops.pexpr (K.with_type TUnit f);
L.log "Calls" "%s--> args: %s, ret: %s" depth
(String.concat " ++ " (List.map (Charon.PrintTypes.ty_to_string env.format_env) inputs))
(Charon.PrintTypes.ty_to_string env.format_env output);
let env = push_cg_binders env const_generics in
let env = push_type_binders env type_params in

let clause_mapping = build_trait_clause_mapping env trait_clauses in
debug_trait_clause_mapping env clause_mapping;
let clause_binders = mk_clause_binders_and_args env clause_mapping in
let clause_ts = List.map (fun { t; _ } -> t) clause_binders in

{
f;
n_type_args = List.length type_params;
cg_types = List.map (fun (v: C.const_generic_var) -> typ_of_literal_ty env v.ty) const_generics;
arg_types = List.map (typ_of_ty env) inputs;
arg_types = clause_ts @ List.map (typ_of_ty env) inputs;
ret_type = typ_of_ty env output;
is_assumed = false
}
Expand Down Expand Up @@ -809,7 +914,7 @@ let rec expression_of_fn_ptr env depth (fn_ptr: C.fn_ptr) =
fst3 (expression_of_fn_ptr env (depth ^ " ") fn_ptr)
) (trait_impl.required_methods @ trait_impl.provided_methods)
) trait_refs in
L.log "Calls" "%s--> trait method impls: %d\n" depth (List.length fn_ptrs);
L.log "Calls" "%s--> trait method impls: %d" depth (List.length fn_ptrs);
let const_generic_args = const_generic_args @ fn_ptrs in

(* This needs to match what is done in the FunGroup case (i.e. when we extract
Expand All @@ -828,6 +933,8 @@ let rec expression_of_fn_ptr env depth (fn_ptr: C.fn_ptr) =
n_type_params (List.length type_args);
let poly_t_sans_cgs = Krml.Helpers.fold_arrow inputs output in
let poly_t = Krml.Helpers.fold_arrow cg_inputs poly_t_sans_cgs in
L.log "Calls" "%s--> poly_t_sans_cgs: %a" depth Krml.PrintAst.Ops.ptyp poly_t_sans_cgs;
L.log "Calls" "%s--> poly_t: %a" depth Krml.PrintAst.Ops.ptyp poly_t;
let output, t =
let offset = List.length env.binders - List.length env.cg_binders in
Krml.DeBruijn.(subst_ctn offset const_generic_args (subst_tn type_args output)),
Expand All @@ -840,6 +947,7 @@ let rec expression_of_fn_ptr env depth (fn_ptr: C.fn_ptr) =
else
hd
in
L.log "Calls" "%s--> hd: %a" depth Krml.PrintAst.Ops.pexpr hd;
hd, is_assumed, output

let expression_of_fn_ptr env (fn_ptr: C.fn_ptr) =
Expand Down Expand Up @@ -1185,99 +1293,6 @@ let rec expression_of_raw_statement (env: env) (ret_var: C.var_id) (s: C.raw_sta

(** Top-level declarations: orchestration *)

(* Using tests/where_clauses_simple as an example.
fn double<T: Ops + Copy, U: Ops+Copy> (...)
this gets desugared to fn double<T,U> where
T: Ops, <-- ClauseId 0 (required_methods: add, of_u32)
T: Copy, <-- ClauseId 1 (builtin, so neither required nor provided methods)
U: Ops, <-- ClauseId 2 (required_methods: add, of_u32)
U: Copy, <-- ClauseId 3 (builtin, so neither required nor provided methods)
the types we obtain by looking up the trait declaration have Self as 0
(DeBruijn).
*)
let build_trait_clause_mapping env (trait_clauses: C.trait_clause list) =
List.concat_map (fun tc ->
let { C.clause_id; trait_id; clause_generics; _ } = tc in
let trait_decl = env.get_nth_trait_decl trait_id in

(* FYI, some clauses like Copy have neither required nor provided methods. *)
Krml.KPrint.bprintf "clause id %d: FYI, clause_generic type is %s (required: %d) (provided: %d)\n"
(C.TraitClauseId.to_int clause_id)
(String.concat " ++ " (List.map C.show_ty (clause_generics.C.types)))
(List.length trait_decl.C.required_methods)
(List.length trait_decl.C.provided_methods);

(* The type variable from the function that the trait clause is applied to *)
let tvarid =
match clause_generics.C.types with
| [ TVar tvarid ] -> tvarid
| _ -> failwith "Clause above does not refer to a type var? confused"
in

List.map (fun (item_name, decl_id) ->
let decl = env.get_nth_function decl_id in
(clause_id, item_name), (tvarid, trait_decl.C.name, decl.C.signature)
) trait_decl.C.required_methods @
List.map (fun (item_name, decl_id) ->
match decl_id with
| Some decl_id ->
let decl = env.get_nth_function decl_id in
(clause_id, item_name), (tvarid, trait_decl.C.name, decl.C.signature)
| None ->
failwith ("TODO: handle provided trait methods, like " ^ item_name)
) trait_decl.C.provided_methods
) trait_clauses


(* Assumes type variables have been suitably bound in the environment *)
let rec mk_clause_binders_and_args env clause_mapping: clause_binder list =
List.map (fun ((clause_id, item_name), (tvarid, trait_name, (signature: C.fun_sig))) ->
let tvar = lookup_typ env tvarid in
let t = thd3 (typ_of_signature env signature) in
let t = Krml.DeBruijn.subst_t (TBound tvar) 0 t in
let pretty_name = string_of_name env trait_name ^ "_" ^ item_name in
let sig_info = {
n_type_args = List.length signature.generics.types - 1;
n_cg_args = List.length signature.generics.const_generics;
} in
{ pretty_name; t; clause_id; item_name; sig_info }
) clause_mapping


and typ_of_signature env signature =
let { C.generics = { types = type_params; const_generics; _ }; inputs; output ; _ } = signature in
let const_generics_ts = List.map (fun (c: C.const_generic_var) -> typ_of_literal_ty env c.ty) const_generics in
let env = push_cg_binders env const_generics in
let env = push_type_binders env type_params in

let clause_mapping = build_trait_clause_mapping env signature.C.generics.trait_clauses in
debug_trait_clause_mapping env clause_mapping;
let clause_binders = mk_clause_binders_and_args env clause_mapping in
let clause_ts = List.map (fun { t; _ } -> t) clause_binders in

let inputs = List.map (typ_of_ty env) inputs in
let output = typ_of_ty env output in
let adjusted_inputs =
if const_generics_ts = [] && inputs = [] then [ K.TUnit ] else const_generics_ts @ clause_ts @ inputs
in

let t = Krml.Helpers.fold_arrow adjusted_inputs output in
List.length const_generics_ts, List.length type_params, t


and debug_trait_clause_mapping env mapping =
let open Krml in
if mapping <> [] then
KPrint.bprintf "In this function, calls to trait bound methods are as follows:\n";
List.iter (fun ((clause_id, item_name), (_, trait_name, signature)) ->
let t = thd3 (typ_of_signature env signature) in
KPrint.bprintf "TraitClause %d (a.k.a. %s)::%s: %a\n"
(C.TraitClauseId.to_int clause_id) (string_of_name env trait_name) item_name PrintAst.Ops.ptyp t
) mapping



let of_declaration_group (dg: 'id C.g_declaration_group) (f: 'id -> 'a): 'a list =
Expand Down

0 comments on commit 9caf19f

Please sign in to comment.