Skip to content

Commit

Permalink
Model generation for ADT theory
Browse files Browse the repository at this point in the history
This PR implements the model generation for ADT. The model generation is
done by the casesplit mechanism in `Adt_rel`.

- If `--enable-adts-cs` is present, we do casesplits by choosing a
  delayed destructor as before after asserting.

- If there is no more delayed constructors, we look for a domain which
  contains a tightenable constructor (that is a constructor without
  payload) and we choose the domain as small as possible as it does in
  the `Enum` theory.

- If we are in model generation, we can also choose constructors with
  payload (but we still choose first constructors without payload).

- The termination of the model generation is a bit tricky in the case of
  mutually recursive ADT. Please see the tests added for some
  complicated examples. I hope that I caught all the corner cases.
  To ensure the termination, the basic idea is to sort ADT's
  constructors in the module `Ty` during the parsing and to use the fact
  that the SMT-LIB standard only accepts well-founded ADT.

  We choose constructors in domains with the following order:
  - Constructor with the less destructors using the same nest;
  - Constructor with the less destructors using another nest of the same
    mutually recursive declaration;
  • Loading branch information
Halbaroth committed Apr 22, 2024
1 parent 05e98a4 commit c4b46cb
Show file tree
Hide file tree
Showing 18 changed files with 409 additions and 74 deletions.
21 changes: 13 additions & 8 deletions src/lib/reasoners/adt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -405,12 +405,17 @@ module Shostak (X : ALIEN) = struct


let assign_value _ _ _ =
Printer.print_err
"[ADTs.models] assign_value currently not implemented";
raise (Util.Not_implemented "Models for ADTs")

let to_model_term _r =
Printer.print_err
"[ADTs.models] to_model_term currently not implemented";
raise (Util.Not_implemented "Models for ADTs")
(* Model generation is performed by the casesplit mechanism
in [Adt_rel]. *)
None

let to_model_term r =
match embed r with
| Constr { c_name; c_ty; c_args } ->
let args = Lists.try_map (fun (_, arg) -> X.to_model_term arg) c_args in
Option.bind args @@ fun args ->
Some (E.mk_term Sy.(Op (Constr c_name)) args c_ty)

| Select _ -> None
| Alien a -> X.to_model_term a
end
98 changes: 80 additions & 18 deletions src/lib/reasoners/adt_rel.ml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ module Domain = struct

exception Inconsistent of Ex.t

let[@inline always] cardinal { constrs; _ } = HSS.cardinal constrs

let[@inline always] choose { constrs; _ } = HSS.choose constrs

let[@inline always] as_singleton { constrs; ex } =
if HSS.cardinal constrs = 1 then
Some (HSS.choose constrs, ex)
Expand Down Expand Up @@ -205,6 +209,8 @@ module Domains = struct
) t.changed acc
in
acc, { t with changed = SX.empty }

let fold f t = MX.fold f t.domains
end

let calc_destructor d e uf =
Expand Down Expand Up @@ -599,38 +605,94 @@ let pick_delayed_destructor env =
Rel_utils.Delayed.iter_delayed
(fun r sy _e ->
match sy with
| Sy.Destruct d ->
raise_notrace @@ Found (r, d)
| Sy.Destruct destr ->
let d = Domains.get r env.domains in
if Domain.cardinal d > 1 then
raise_notrace @@ Found (r, destr)
else
()
| _ ->
()
) env.delayed;
None
with Found (r, d) -> Some (r, d)

let pick_tightenable_domain ~for_model env uf =
Domains.fold
(fun r d best ->
let rr, _ = Uf.find_r uf r in
match Th.embed rr with
| Constr _ ->
best
| _ ->
let cd = Domain.cardinal d in
let c = Domain.choose d in
if for_model || is_tightenable (X.type_info r) c then
match best with
| Some (n, _, _) when n <= cd -> best
| Some _ | None -> Some (cd, r, c)
else
best
) env.domains None

let can_split env n =
let m = Options.get_max_split () in
Numbers.Q.(compare (mult n env.size_splits) m) <= 0 || Numbers.Q.sign m < 0

let (let*) = Option.bind

(* Do a case-split by choosing a semantic value [r] and constructor [c]
for which there are delayed destructor applications and propagate the
literal [(not (_ is c) r)]. *)
let case_split env _uf ~for_model =
if Options.get_disable_adts () || not (Options.get_enable_adts_cs())
then
let case_split_destructor env =
if not @@ Options.get_enable_adts_cs () then
None
else
let* r, d = pick_delayed_destructor env in
let c = constr_of_destr (X.type_info r) d in
Some (LR.mkv_builtin false (Sy.IsConstr c) [r])

(* Do a case-split by choosing a semantic value [r] whose the domain is as
small as possible.
If [for_model] is [false], we consider only constructors without payload. In
particular, if [r] has no constructors without payload in its domain, we
don't case-split on it. *)
let case_split_best ~for_model env uf =
let* n, r, c = pick_tightenable_domain ~for_model env uf in
let n = Numbers.Q.from_int n in
if for_model || can_split env n then
let _, cons = Option.get @@ build_constr_eq r c in
let nr, ctx = X.make cons in
assert (Lists.is_empty ctx);
Some (LR.mkv_eq r nr)
else
None

let case_split env uf ~for_model =
if Options.get_disable_adts () then
[]
else begin
assert (not for_model);
if Options.get_debug_adt () then Debug.pp_env "before cs" env;
match pick_delayed_destructor env with
| Some (r, d) ->
if Options.get_debug_adt () then
else
let cs =
match case_split_destructor env with
| Some _ as cs ->
assert (not (Options.get_enable_adts_cs ()) || not for_model);
cs
| None -> case_split_best ~for_model env uf
in
match cs with
| Some cs ->
if Options.get_debug_adt () then begin
Debug.pp_env "before cs" env;
Printer.print_dbg ~flushed:false
~module_name:"Adt_rel" ~function_name:"case_split"
"found r = %a and d = %a@ " X.print r Hstring.print d;
(* CS on negative version would be better in general. *)
let c = constr_of_destr (X.type_info r) d in
let cs = LR.mkv_builtin false (Sy.IsConstr c) [r] in
[ cs, true, Th_util.CS (Th_util.Th_adt, two) ]
"Assume %a" (Xliteral.print_view X.print) cs
end;
[ cs, true, Th_util.CS (Th_util.Th_adt, two)]
| None ->
Debug.no_case_split ();
if Options.get_debug_adt () then
Debug.no_case_split ();
[]
end

let optimizing_objective _env _uf _o = None

Expand Down
3 changes: 2 additions & 1 deletion src/lib/reasoners/shostak.ml
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,8 @@ struct
not (Options.get_restricted ()) &&
(RECORDS.is_mine_symb sb ty ||
BITV.is_mine_symb sb ty ||
ENUM.is_mine_symb sb ty)
ENUM.is_mine_symb sb ty ||
ADT.is_mine_symb sb ty)

let is_a_leaf r = match r.v with
| Term _ | Ac _ -> true
Expand Down
10 changes: 7 additions & 3 deletions src/lib/reasoners/uf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1068,6 +1068,10 @@ type cache = {
to ensure we don't generate twice an abstract value for a given symbol. *)
}

let is_destructor = function
| Sy.Op (Destruct _) -> true
| _ -> false

(* The environment of the union-find contains almost a first-order model.
There are two situations that require some computations to retrieve an
appropriate model value:
Expand All @@ -1084,9 +1088,9 @@ let compute_concrete_model_of_val cache =
and get_abstract_for = Cache.get_abstract_for cache.abstracts
in fun env t ((mdl, mrepr) as acc) ->
let { E.f; xs; ty; _ } = E.term_view t in
if X.is_solvable_theory_symbol f ty
|| Sy.is_internal f || E.is_internal_name t || E.is_internal_skolem t
|| E.equal t E.vrai || E.equal t E.faux
if X.is_solvable_theory_symbol f ty || is_destructor f
|| Sy.is_internal f || E.is_internal_name t || E.is_internal_skolem t
|| E.equal t E.vrai || E.equal t E.faux
then
(* These terms are built-in interpreted ones and we don't have
to produce a definition for them. *)
Expand Down
140 changes: 96 additions & 44 deletions src/lib/structures/ty.ml
Original file line number Diff line number Diff line change
Expand Up @@ -463,54 +463,55 @@ module Decls = struct
MH.add name {decl = (params, body); instances = MTY.empty} !decls

let body name args =
let {decl = (params, body); instances} = MH.find name !decls in
try
let {decl = (params, body); instances} = MH.find name !decls in
try
if compare_list params args = 0 then body
else MTY.find args instances
(* should I instantiate if not found ?? *)
with Not_found ->
let params, body = fresh_type params body in
(*if true || get_debug_adt () then*)
let sbt =
try
List.fold_left2
(fun sbt vty ty ->
let vty = shorten vty in
match vty with
| Tvar { value = Some _ ; _ } -> assert false
| Tvar {v ; value = None} ->
if equal vty ty then sbt else M.add v ty sbt
| _ ->
Printer.print_err "vty = %a and ty = %a"
print vty print ty;
assert false
)M.empty params args
with Invalid_argument _ -> assert false
in
let body = match body with
| Adt cases ->
Adt(
List.map
(fun {constr; destrs} ->
{constr;
destrs =
List.map (fun (d, ty) -> d, apply_subst sbt ty) destrs }
) cases
)
in
let params = List.map (fun ty -> apply_subst sbt ty) params in
add name params body;
body
if compare_list params args = 0 then body
else MTY.find args instances
(* should I instantiate if not found ?? *)
with Not_found ->
Printer.print_err "%a not found" Hstring.print name;
assert false
let params, body = fresh_type params body in
(*if true || get_debug_adt () then*)
let sbt =
try
List.fold_left2
(fun sbt vty ty ->
let vty = shorten vty in
match vty with
| Tvar { value = Some _ ; _ } -> assert false
| Tvar {v ; value = None} ->
if equal vty ty then sbt else M.add v ty sbt
| _ ->
Printer.print_err "vty = %a and ty = %a"
print vty print ty;
assert false
)M.empty params args
with Invalid_argument _ -> assert false
in
let body = match body with
| Adt cases ->
Adt(
List.map
(fun {constr; destrs} ->
{constr;
destrs =
List.map (fun (d, ty) -> d, apply_subst sbt ty) destrs }
) cases
)
in
let params = List.map (fun ty -> apply_subst sbt ty) params in
add name params body;
body

let reinit () = decls := MH.empty

end

let type_body name args = Decls.body name args
let type_body name args =
try
Decls.body name args
with Not_found ->
Printer.print_err "%a not found" Hstring.print name;
assert false


(* smart constructors *)
Expand All @@ -524,6 +525,56 @@ let fresh_empty_text =

let tsum s lc = Tsum (Hstring.make s, List.map Hstring.make lc)

(* Count the number of occurences of the nest of [name] in the signature
of payload [l]. *)
let count_same_nest name l =
Lists.sum
(fun (_, ty) ->
match ty with
| Tadt (name', _) when Hstring.equal name name' -> 1
| _ -> 0
) l

(* Count the number of occurences of the (mutually recursive) ADT type of
[name] in the signature of payload [l]. *)
let count_same_adt name l =
Lists.sum
(fun (_, ty) ->
match ty with
| Tadt (name', params) ->
(* TODO: this is a hackish way to check that `name'` and `name` are
two nests of the same mutually recursive ADT. We should store
ADT's nests in a data structure as it's done in Dolmen. *)
begin try
let Adt cases = Decls.body name' params in
List.exists
(fun { destrs; _ } ->
List.exists
(fun (_, ty') ->
match ty' with
| Tadt (name'', _) -> Hstring.equal name name''
| _ -> false
) destrs
) cases
|> Bool.to_int
with Not_found ->
(* If we haven't already register the nest [name'], it means that
[name] and [name'] are two nests of the same ADT. *)
1
end
| _ -> 0
) l

(* Comparison function used to ensure the termination of the model
generation for recursive ADT values. *)
let cons_weight name (_, l1) (_, l2) =
let c = count_same_nest name l1 - count_same_nest name l2 in
if c <> 0 then c
else
let c = count_same_adt name l1 - count_same_adt name l2 in
if c <> 0 then c
else List.compare_lengths l1 l2

let t_adt ?(body=None) s ty_vars =
let hs = Hstring.make s in
let ty = Tadt (hs, ty_vars) in
Expand All @@ -545,12 +596,13 @@ let t_adt ?(body=None) s ty_vars =
Decls.add hs ty_vars (Adt cases)
| Some cases ->
let cases =
List.map (fun (s, l) ->
List.stable_sort (cons_weight hs) cases |>
List.map (fun (c, l) ->
let l =
List.map (fun (d, e) -> Hstring.make d, e) l
in
{constr = Hstring.make s; destrs = l}
) cases
{constr = Hstring.make c; destrs = l}
)
in
Decls.add hs ty_vars (Adt cases)
end;
Expand Down
3 changes: 3 additions & 0 deletions src/lib/util/lists.ml
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,6 @@ let rec is_sorted cmp l =
match l with
| x :: y :: xs -> cmp x y <= 0 && is_sorted cmp (y :: xs)
| [_] | [] -> true

let sum f l =
List.fold_left (fun sum i -> sum + f i) 0 l
4 changes: 4 additions & 0 deletions src/lib/util/lists.mli
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,7 @@ val try_map : ('a -> 'b option) -> 'a list -> 'b list option
val is_sorted : ('a -> 'a -> int) -> 'a list -> bool
(** [is_sorted cmp l] checks that [l] is sorted for the comparison function
[cmp]. *)

val sum : ('a -> int) -> 'a list -> int
(** [sum f l] computes the sum [f a1 + f a2 + ...] where
[l = [a1; a2; ...]]. *)
Loading

0 comments on commit c4b46cb

Please sign in to comment.