Skip to content

Commit

Permalink
New algorithm for termination
Browse files Browse the repository at this point in the history
The previous algorithm wasn't sufficient to ensure the termination of
model generation for any mutually recursive ADT. This new algorithm
consists in ordering the nests of an ADT declaration using a topological
sorts.
  • Loading branch information
Halbaroth committed Apr 22, 2024
1 parent c5f5834 commit f583530
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 152 deletions.
146 changes: 73 additions & 73 deletions src/lib/frontend/d_cnf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -625,9 +625,7 @@ let mk_ty_decl (ty_c: DE.ty_cst) =
let ty = Ty.tsum name cstrs in
Cache.store_ty (DE.Ty.Const.hash ty_c) ty
else
let body = Some (List.rev rev_cs) in
let ty = Ty.t_adt ~body name tyvl in
Cache.store_ty (DE.Ty.Const.hash ty_c) ty
Ty.declare_adts [name, List.rev rev_cs, tyvl]

| None | Some Abstract ->
let name = get_basename ty_c.path in
Expand Down Expand Up @@ -662,12 +660,10 @@ let mk_term_decl ({ id_ty; path; tags; _ } as tcst: DE.term_cst) =
- If one of the types is an ADT, the ADTs that have only one case are
considered as ADTs as well and not as records. *)
let mk_mr_ty_decls (tdl: DE.ty_cst list) =
let handle_ty_decl (ty: Ty.t) (tdef: DE.Ty.def option) =
let handle_record_ty_decl (ty: Ty.t) (tdef: DE.Ty.def) =
match ty, tdef with
| Trecord { args; name; record_constr; _ },
Some (
Adt { cases = [| { dstrs; _ } |]; ty = ty_c; _ }
) ->
Adt { cases = [| { dstrs; _ } |]; ty = ty_c; _ } ->
let rev_lbs =
Array.fold_left (
fun acc c ->
Expand All @@ -690,31 +686,6 @@ let mk_mr_ty_decls (tdl: DE.ty_cst list) =
in
Cache.store_ty (DE.Ty.Const.hash ty_c) ty

| Tadt (hs, tyl),
Some (
Adt { cases; ty = ty_c; _ }
) ->
let rev_cs =
Array.fold_left (
fun accl DE.{ cstr = { path; _ }; dstrs; _ } ->
let rev_fields =
Array.fold_left (
fun acc tc_o ->
match tc_o with
| Some DE.{ id_ty; path; _ } ->
(get_basename path, dty_to_ty id_ty) :: acc
| None -> assert false
) [] dstrs
in
let name = get_basename path in
(name, List.rev rev_fields) :: accl
) [] cases
in
let body = Some (List.rev rev_cs) in
let args = tyl in
let ty = Ty.t_adt ~body (Hstring.view hs) args in
Cache.store_ty (DE.Ty.Const.hash ty_c) ty

| _ -> assert false
in
(* If there are adts in the list of type declarations then records are
Expand All @@ -730,51 +701,80 @@ let mk_mr_ty_decls (tdl: DE.ty_cst list) =
| df -> df :: acc, ca
) ([], false) tdl
in
let rev_l =
List.fold_left (
fun acc tdef ->
match tdef with
| Some (
(DE.Adt { cases; record; ty = ty_c; }) as adt
) ->
let tyvl = Cache.store_ty_vars_ret cases.(0).cstr.id_ty in
let name = get_basename ty_c.path in

let cns, is_enum =
Array.fold_right (
fun DE.{ dstrs; cstr = { path; _ }; _ } (nacc, is_enum) ->
get_basename path :: nacc,
Array.length dstrs = 0 && is_enum
) cases ([], true)
in
if is_enum && not contains_adts
then (
let ty = Ty.tsum name cns in

if contains_adts then (
let rev_l =
List.fold_left (
fun acc tdef ->
match tdef with
| Some (
(DE.Adt { cases; ty = ty_c; _ }) as adt
) ->
let tyvl = Cache.store_ty_vars_ret cases.(0).cstr.id_ty in
let name = get_basename ty_c.path in
let ty = Ty.t_adt name tyvl in
Cache.store_ty (DE.Ty.Const.hash ty_c) ty;
(* If it's an enum we don't need the second iteration. *)
acc
)
else (
(ty, adt) :: acc
| None
| Some Abstract ->
assert false (* unreachable in the second iteration *)
) [] (List.rev rev_tdefs)
in

let adts =
List.map (
fun (t, d) ->
match t, d with
| Ty.Tadt (hs, params), DE.Adt { cases; _ } ->
let cases =
Array.map (
fun DE.{ cstr = { path; _ }; dstrs; _ } ->
let fields =
Array.map (function
| Some DE.{ id_ty; path; _ } ->
get_basename path, dty_to_ty id_ty
| None -> assert false
) dstrs
|> Array.to_list
in
get_basename path, fields
) cases
|> Array.to_list
in
Hstring.view hs, cases, params
| _ -> assert false
) (List.rev rev_l)
in
Ty.declare_adts adts

) else (
let rev_l =
List.fold_left (
fun acc tdef ->
match tdef with
| Some (
(DE.Adt { cases; record; ty = ty_c; }) as adt
) ->
let tyvl = Cache.store_ty_vars_ret cases.(0).cstr.id_ty in
let name = get_basename ty_c.path in
assert (record || Array.length cases = 1);
let ty =
if (record || Array.length cases = 1) && not contains_adts
then
let record_constr =
Format.asprintf "%a" DStd.Path.print ty_c.path
in
Ty.trecord ~record_constr tyvl name []
else Ty.t_adt name tyvl
let record_constr =
Format.asprintf "%a" DStd.Path.print ty_c.path
in
Ty.trecord ~record_constr tyvl name []
in
Cache.store_ty (DE.Ty.Const.hash ty_c) ty;
(ty, Some adt) :: acc
)
| None
| Some Abstract ->
assert false (* unreachable in the second iteration *)
) [] (List.rev rev_tdefs)
in
List.iter (
fun (t, d) -> handle_ty_decl t d
) (List.rev rev_l)
(ty, adt) :: acc
| None
| Some Abstract ->
assert false (* unreachable in the second iteration *)
) [] (List.rev rev_tdefs)
in
List.iter (
fun (t, d) -> handle_record_ty_decl t d
) (List.rev rev_l)
)

(** Helper function hadle variables that are encoutered in patterns. *)
let handle_patt_var name (DE.{ term_descr; _ } as term) =
Expand Down
10 changes: 6 additions & 4 deletions src/lib/frontend/typechecker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,13 @@ module Types = struct
List.map (fun (field, pp) -> field, ty_of_pp loc env None pp) l
) l
in
let body =
if l == [] then None (* in initialization step, no body *)
else Some l
let ty = Ty.t_adt id ty_vars in
let () =
match l with
| [] -> (* in initialization step, no body *) ()
| _ :: _ ->
Ty.declare_adts [id, l, ty_vars]
in
let ty = Ty.t_adt ~body id ty_vars in
ty, { env with to_ty = MString.add id ty env.to_ty }

let add_builtin env id ty =
Expand Down
9 changes: 8 additions & 1 deletion src/lib/reasoners/adt_rel.ml
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,13 @@ let pick_delayed_destructor env =
None
with Found (r, d) -> Some (r, d)

let is_enum ty c =
match ty with
| Ty.Tadt (name, params) ->
let Adt cases = Ty.type_body name params in
Lists.is_empty @@ Ty.assoc_destrs c cases
| _ -> assert false

let pick_tightenable_domain ~for_model env uf =
Domains.fold
(fun r d best ->
Expand All @@ -627,7 +634,7 @@ let pick_tightenable_domain ~for_model env uf =
| _ ->
let cd = Domain.cardinal d in
let c = Domain.choose d in
if for_model || is_tightenable (X.type_info r) c then
if for_model || is_enum (X.type_info r) c then
match best with
| Some (n, _, _) when n <= cd -> best
| Some _ | None -> Some (cd, r, c)
Expand Down
130 changes: 60 additions & 70 deletions src/lib/structures/ty.ml
Original file line number Diff line number Diff line change
Expand Up @@ -525,88 +525,78 @@ let fresh_empty_text =

let tsum s lc = Tsum (Hstring.make s, List.map Hstring.make lc)

(* Count the number of occurences of the nest of [name] in the signature
of payload [l]. *)
let count_same_nest name l =
Lists.sum
(fun (_, ty) ->
match ty with
| Tadt (name', _) when Hstring.equal name name' -> 1
| _ -> 0
) l
let[@inline always] t_adt name params =
Tadt (Hstring.make name, params)

(* Count the number of occurences of the (mutually recursive) ADT type of
[name] in the signature of payload [l]. *)
let count_same_adt name l =
type case = string * (string * t) list
type adt = string * case list * t list

let find_map (type c) =
let exception Found of c
in fun (p : 'a -> 'b -> c option) (h : ('a, 'b) Hashtbl.t) ->
try
Hashtbl.iter
(fun k v ->
match p k v with Some c -> raise @@ Found c | None -> ()
) h;
raise Not_found
with Found c -> c

let count_in_set (set : (string, case list * t list) Hashtbl.t) c =
Lists.sum
(fun (_, ty) ->
match ty with
| Tadt (name', params) ->
(* TODO: this is a hackish way to check that `name'` and `name` are
two nests of the same mutually recursive ADT. We should store
ADT's nests in a data structure as it's done in Dolmen. *)
begin try
let Adt cases = Decls.body name' params in
List.exists
(fun { destrs; _ } ->
List.exists
(fun (_, ty') ->
match ty' with
| Tadt (name'', _) -> Hstring.equal name name''
| _ -> false
) destrs
) cases
|> Bool.to_int
with Not_found ->
(* If we haven't already register the nest [name'], it means that
[name] and [name'] are two nests of the same ADT. *)
1
end
| Tadt (name, _) when Hashtbl.mem set (Hstring.view name) -> 1
| _ -> 0
) l
) (snd c)

(* Comparison function used to ensure the termination of the model
generation for recursive ADT values. *)
let cons_weight name (_, l1) (_, l2) =
let c = count_same_nest name l1 - count_same_nest name l2 in
let cons_weight (set : (string, case list * t list) Hashtbl.t) c1 c2 =
let c = count_in_set set c1 - count_in_set set c2 in
if c <> 0 then c
else
let c = count_same_adt name l1 - count_same_adt name l2 in
if c <> 0 then c
else List.compare_lengths l1 l2

let t_adt ?(body=None) s ty_vars =
let hs = Hstring.make s in
let ty = Tadt (hs, ty_vars) in
begin match body with
| None -> ()
| Some [] -> assert false
| Some ([_] as cases) ->
if Options.get_debug_adt () then
Printer.print_dbg ~module_name:"Ty"
"should be registered as a record";
List.compare_lengths (snd c1) (snd c2)

let topological_sort (adts : adt list) =
let seens : adt list ref = ref [] in
let set : (string, case list * t list) Hashtbl.t = Hashtbl.create 17 in
List.iter
(fun (name, cases, params) -> Hashtbl.add set name (cases, params)) adts;
while Hashtbl.length set > 0 do
let adt, (cases, params) =
find_map (fun adt (cases, params) ->
let b =
List.exists (fun (_ , case) ->
List.for_all (fun (_, ty) ->
match ty with
| Tadt (n, _) ->
not @@ Hashtbl.mem set (Hstring.view n)
| _ -> true
) case
) cases
in
if b then Some (adt, (cases, params)) else None
) set
in
let cases = List.stable_sort (cons_weight set) cases in
Hashtbl.remove set adt;
seens := (adt, cases, params) :: !seens
done;
List.rev !seens

let declare_adts adts =
let adts = topological_sort adts in
List.iter (fun (name, cases, params) ->
let name = Hstring.make name in
let cases =
List.map (fun (s, l) ->
let l =
List.map (fun (d, e) -> Hstring.make d, e) l
List.map (fun (c, destrs) ->
let destrs =
List.map (fun (d, dty) -> Hstring.make d, dty) destrs
in
{constr = Hstring.make s ; destrs = l}
{ constr = Hstring.make c; destrs }
) cases
in
Decls.add hs ty_vars (Adt cases)
| Some cases ->
let cases =
List.stable_sort (cons_weight hs) cases |>
List.map (fun (c, l) ->
let l =
List.map (fun (d, e) -> Hstring.make d, e) l
in
{constr = Hstring.make c; destrs = l}
)
in
Decls.add hs ty_vars (Adt cases)
end;
ty
Decls.add name params (Adt cases)
) adts

let trecord ?(sort_fields = false) ~record_constr lv n lbs =
let lbs = List.map (fun (l,ty) -> Hstring.make l, ty) lbs in
Expand Down
Loading

0 comments on commit f583530

Please sign in to comment.