From f583530bf70cf08920ade3da137a1e83806f7d82 Mon Sep 17 00:00:00 2001 From: Pierre Villemot Date: Mon, 22 Apr 2024 13:09:30 +0200 Subject: [PATCH] New algorithm for termination The previous algorithm wasn't sufficient to ensure the termination of model generation for any mutually recursive ADT. This new algorithm consists in ordering the nests of an ADT declaration using a topological sorts. --- src/lib/frontend/d_cnf.ml | 146 ++++++++++++++++---------------- src/lib/frontend/typechecker.ml | 10 ++- src/lib/reasoners/adt_rel.ml | 9 +- src/lib/structures/ty.ml | 130 +++++++++++++--------------- src/lib/structures/ty.mli | 17 +++- 5 files changed, 160 insertions(+), 152 deletions(-) diff --git a/src/lib/frontend/d_cnf.ml b/src/lib/frontend/d_cnf.ml index 671eb5e486..9f963cfbae 100644 --- a/src/lib/frontend/d_cnf.ml +++ b/src/lib/frontend/d_cnf.ml @@ -625,9 +625,7 @@ let mk_ty_decl (ty_c: DE.ty_cst) = let ty = Ty.tsum name cstrs in Cache.store_ty (DE.Ty.Const.hash ty_c) ty else - let body = Some (List.rev rev_cs) in - let ty = Ty.t_adt ~body name tyvl in - Cache.store_ty (DE.Ty.Const.hash ty_c) ty + Ty.declare_adts [name, List.rev rev_cs, tyvl] | None | Some Abstract -> let name = get_basename ty_c.path in @@ -662,12 +660,10 @@ let mk_term_decl ({ id_ty; path; tags; _ } as tcst: DE.term_cst) = - If one of the types is an ADT, the ADTs that have only one case are considered as ADTs as well and not as records. *) let mk_mr_ty_decls (tdl: DE.ty_cst list) = - let handle_ty_decl (ty: Ty.t) (tdef: DE.Ty.def option) = + let handle_record_ty_decl (ty: Ty.t) (tdef: DE.Ty.def) = match ty, tdef with | Trecord { args; name; record_constr; _ }, - Some ( - Adt { cases = [| { dstrs; _ } |]; ty = ty_c; _ } - ) -> + Adt { cases = [| { dstrs; _ } |]; ty = ty_c; _ } -> let rev_lbs = Array.fold_left ( fun acc c -> @@ -690,31 +686,6 @@ let mk_mr_ty_decls (tdl: DE.ty_cst list) = in Cache.store_ty (DE.Ty.Const.hash ty_c) ty - | Tadt (hs, tyl), - Some ( - Adt { cases; ty = ty_c; _ } - ) -> - let rev_cs = - Array.fold_left ( - fun accl DE.{ cstr = { path; _ }; dstrs; _ } -> - let rev_fields = - Array.fold_left ( - fun acc tc_o -> - match tc_o with - | Some DE.{ id_ty; path; _ } -> - (get_basename path, dty_to_ty id_ty) :: acc - | None -> assert false - ) [] dstrs - in - let name = get_basename path in - (name, List.rev rev_fields) :: accl - ) [] cases - in - let body = Some (List.rev rev_cs) in - let args = tyl in - let ty = Ty.t_adt ~body (Hstring.view hs) args in - Cache.store_ty (DE.Ty.Const.hash ty_c) ty - | _ -> assert false in (* If there are adts in the list of type declarations then records are @@ -730,51 +701,80 @@ let mk_mr_ty_decls (tdl: DE.ty_cst list) = | df -> df :: acc, ca ) ([], false) tdl in - let rev_l = - List.fold_left ( - fun acc tdef -> - match tdef with - | Some ( - (DE.Adt { cases; record; ty = ty_c; }) as adt - ) -> - let tyvl = Cache.store_ty_vars_ret cases.(0).cstr.id_ty in - let name = get_basename ty_c.path in - - let cns, is_enum = - Array.fold_right ( - fun DE.{ dstrs; cstr = { path; _ }; _ } (nacc, is_enum) -> - get_basename path :: nacc, - Array.length dstrs = 0 && is_enum - ) cases ([], true) - in - if is_enum && not contains_adts - then ( - let ty = Ty.tsum name cns in + + if contains_adts then ( + let rev_l = + List.fold_left ( + fun acc tdef -> + match tdef with + | Some ( + (DE.Adt { cases; ty = ty_c; _ }) as adt + ) -> + let tyvl = Cache.store_ty_vars_ret cases.(0).cstr.id_ty in + let name = get_basename ty_c.path in + let ty = Ty.t_adt name tyvl in Cache.store_ty (DE.Ty.Const.hash ty_c) ty; - (* If it's an enum we don't need the second iteration. *) - acc - ) - else ( + (ty, adt) :: acc + | None + | Some Abstract -> + assert false (* unreachable in the second iteration *) + ) [] (List.rev rev_tdefs) + in + + let adts = + List.map ( + fun (t, d) -> + match t, d with + | Ty.Tadt (hs, params), DE.Adt { cases; _ } -> + let cases = + Array.map ( + fun DE.{ cstr = { path; _ }; dstrs; _ } -> + let fields = + Array.map (function + | Some DE.{ id_ty; path; _ } -> + get_basename path, dty_to_ty id_ty + | None -> assert false + ) dstrs + |> Array.to_list + in + get_basename path, fields + ) cases + |> Array.to_list + in + Hstring.view hs, cases, params + | _ -> assert false + ) (List.rev rev_l) + in + Ty.declare_adts adts + + ) else ( + let rev_l = + List.fold_left ( + fun acc tdef -> + match tdef with + | Some ( + (DE.Adt { cases; record; ty = ty_c; }) as adt + ) -> + let tyvl = Cache.store_ty_vars_ret cases.(0).cstr.id_ty in + let name = get_basename ty_c.path in + assert (record || Array.length cases = 1); let ty = - if (record || Array.length cases = 1) && not contains_adts - then - let record_constr = - Format.asprintf "%a" DStd.Path.print ty_c.path - in - Ty.trecord ~record_constr tyvl name [] - else Ty.t_adt name tyvl + let record_constr = + Format.asprintf "%a" DStd.Path.print ty_c.path + in + Ty.trecord ~record_constr tyvl name [] in Cache.store_ty (DE.Ty.Const.hash ty_c) ty; - (ty, Some adt) :: acc - ) - | None - | Some Abstract -> - assert false (* unreachable in the second iteration *) - ) [] (List.rev rev_tdefs) - in - List.iter ( - fun (t, d) -> handle_ty_decl t d - ) (List.rev rev_l) + (ty, adt) :: acc + | None + | Some Abstract -> + assert false (* unreachable in the second iteration *) + ) [] (List.rev rev_tdefs) + in + List.iter ( + fun (t, d) -> handle_record_ty_decl t d + ) (List.rev rev_l) + ) (** Helper function hadle variables that are encoutered in patterns. *) let handle_patt_var name (DE.{ term_descr; _ } as term) = diff --git a/src/lib/frontend/typechecker.ml b/src/lib/frontend/typechecker.ml index deea8cb931..8d9154eefa 100644 --- a/src/lib/frontend/typechecker.ml +++ b/src/lib/frontend/typechecker.ml @@ -170,11 +170,13 @@ module Types = struct List.map (fun (field, pp) -> field, ty_of_pp loc env None pp) l ) l in - let body = - if l == [] then None (* in initialization step, no body *) - else Some l + let ty = Ty.t_adt id ty_vars in + let () = + match l with + | [] -> (* in initialization step, no body *) () + | _ :: _ -> + Ty.declare_adts [id, l, ty_vars] in - let ty = Ty.t_adt ~body id ty_vars in ty, { env with to_ty = MString.add id ty env.to_ty } let add_builtin env id ty = diff --git a/src/lib/reasoners/adt_rel.ml b/src/lib/reasoners/adt_rel.ml index 222b7753a1..ae56eaccd9 100644 --- a/src/lib/reasoners/adt_rel.ml +++ b/src/lib/reasoners/adt_rel.ml @@ -617,6 +617,13 @@ let pick_delayed_destructor env = None with Found (r, d) -> Some (r, d) +let is_enum ty c = + match ty with + | Ty.Tadt (name, params) -> + let Adt cases = Ty.type_body name params in + Lists.is_empty @@ Ty.assoc_destrs c cases + | _ -> assert false + let pick_tightenable_domain ~for_model env uf = Domains.fold (fun r d best -> @@ -627,7 +634,7 @@ let pick_tightenable_domain ~for_model env uf = | _ -> let cd = Domain.cardinal d in let c = Domain.choose d in - if for_model || is_tightenable (X.type_info r) c then + if for_model || is_enum (X.type_info r) c then match best with | Some (n, _, _) when n <= cd -> best | Some _ | None -> Some (cd, r, c) diff --git a/src/lib/structures/ty.ml b/src/lib/structures/ty.ml index 7ac9e1fea8..ee0f7b97e5 100644 --- a/src/lib/structures/ty.ml +++ b/src/lib/structures/ty.ml @@ -525,88 +525,78 @@ let fresh_empty_text = let tsum s lc = Tsum (Hstring.make s, List.map Hstring.make lc) -(* Count the number of occurences of the nest of [name] in the signature - of payload [l]. *) -let count_same_nest name l = - Lists.sum - (fun (_, ty) -> - match ty with - | Tadt (name', _) when Hstring.equal name name' -> 1 - | _ -> 0 - ) l +let[@inline always] t_adt name params = + Tadt (Hstring.make name, params) -(* Count the number of occurences of the (mutually recursive) ADT type of - [name] in the signature of payload [l]. *) -let count_same_adt name l = +type case = string * (string * t) list +type adt = string * case list * t list + +let find_map (type c) = + let exception Found of c + in fun (p : 'a -> 'b -> c option) (h : ('a, 'b) Hashtbl.t) -> + try + Hashtbl.iter + (fun k v -> + match p k v with Some c -> raise @@ Found c | None -> () + ) h; + raise Not_found + with Found c -> c + +let count_in_set (set : (string, case list * t list) Hashtbl.t) c = Lists.sum (fun (_, ty) -> match ty with - | Tadt (name', params) -> - (* TODO: this is a hackish way to check that `name'` and `name` are - two nests of the same mutually recursive ADT. We should store - ADT's nests in a data structure as it's done in Dolmen. *) - begin try - let Adt cases = Decls.body name' params in - List.exists - (fun { destrs; _ } -> - List.exists - (fun (_, ty') -> - match ty' with - | Tadt (name'', _) -> Hstring.equal name name'' - | _ -> false - ) destrs - ) cases - |> Bool.to_int - with Not_found -> - (* If we haven't already register the nest [name'], it means that - [name] and [name'] are two nests of the same ADT. *) - 1 - end + | Tadt (name, _) when Hashtbl.mem set (Hstring.view name) -> 1 | _ -> 0 - ) l + ) (snd c) -(* Comparison function used to ensure the termination of the model - generation for recursive ADT values. *) -let cons_weight name (_, l1) (_, l2) = - let c = count_same_nest name l1 - count_same_nest name l2 in +let cons_weight (set : (string, case list * t list) Hashtbl.t) c1 c2 = + let c = count_in_set set c1 - count_in_set set c2 in if c <> 0 then c else - let c = count_same_adt name l1 - count_same_adt name l2 in - if c <> 0 then c - else List.compare_lengths l1 l2 - -let t_adt ?(body=None) s ty_vars = - let hs = Hstring.make s in - let ty = Tadt (hs, ty_vars) in - begin match body with - | None -> () - | Some [] -> assert false - | Some ([_] as cases) -> - if Options.get_debug_adt () then - Printer.print_dbg ~module_name:"Ty" - "should be registered as a record"; + List.compare_lengths (snd c1) (snd c2) + +let topological_sort (adts : adt list) = + let seens : adt list ref = ref [] in + let set : (string, case list * t list) Hashtbl.t = Hashtbl.create 17 in + List.iter + (fun (name, cases, params) -> Hashtbl.add set name (cases, params)) adts; + while Hashtbl.length set > 0 do + let adt, (cases, params) = + find_map (fun adt (cases, params) -> + let b = + List.exists (fun (_ , case) -> + List.for_all (fun (_, ty) -> + match ty with + | Tadt (n, _) -> + not @@ Hashtbl.mem set (Hstring.view n) + | _ -> true + ) case + ) cases + in + if b then Some (adt, (cases, params)) else None + ) set + in + let cases = List.stable_sort (cons_weight set) cases in + Hashtbl.remove set adt; + seens := (adt, cases, params) :: !seens + done; + List.rev !seens + +let declare_adts adts = + let adts = topological_sort adts in + List.iter (fun (name, cases, params) -> + let name = Hstring.make name in let cases = - List.map (fun (s, l) -> - let l = - List.map (fun (d, e) -> Hstring.make d, e) l + List.map (fun (c, destrs) -> + let destrs = + List.map (fun (d, dty) -> Hstring.make d, dty) destrs in - {constr = Hstring.make s ; destrs = l} + { constr = Hstring.make c; destrs } ) cases in - Decls.add hs ty_vars (Adt cases) - | Some cases -> - let cases = - List.stable_sort (cons_weight hs) cases |> - List.map (fun (c, l) -> - let l = - List.map (fun (d, e) -> Hstring.make d, e) l - in - {constr = Hstring.make c; destrs = l} - ) - in - Decls.add hs ty_vars (Adt cases) - end; - ty + Decls.add name params (Adt cases) + ) adts let trecord ?(sort_fields = false) ~record_constr lv n lbs = let lbs = List.map (fun (l,ty) -> Hstring.make l, ty) lbs in diff --git a/src/lib/structures/ty.mli b/src/lib/structures/ty.mli index 476b5c82c1..feba5d3156 100644 --- a/src/lib/structures/ty.mli +++ b/src/lib/structures/ty.mli @@ -176,10 +176,7 @@ val tsum : string -> string list -> t (** Create an enumeration type. [tsum name enums] creates an enumeration named [name], with constructors [enums]. *) -val t_adt : - ?body: ((string * (string * t) list) list) option -> - string -> t list -> - t +val t_adt : string -> t list -> t (** Create an algebraic datatype. The body is a list of constructors, where each constructor is associated with the list of its destructors with their respective types. If [body] is none, @@ -187,6 +184,18 @@ val t_adt : argument is the name of the type. The third one provides its list of arguments. *) +type case = string * (string * t) list + +val declare_adts : (string * case list * t list) list -> unit +(** Declare a mutually recursive algebraic datatype. + + The body is a list of + constructors, where each constructor is associated with the list of + its destructors with their respective types. If [body] is none, + then no definition will be registered for this type. The second + argument is the name of the type. The third one provides its list + of arguments. *) + val trecord : ?sort_fields:bool -> record_constr:string -> t list -> string -> (string * t) list -> t