Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: exporting alt-ergo FPA built-in primitives #876

Merged
merged 16 commits into from
Oct 30, 2023
231 changes: 168 additions & 63 deletions src/lib/frontend/d_cnf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ end
(** Builtins *)
type _ DStd.Builtin.t +=
| Float
| AERound of int * int
(** Equivalent of Float for the SMT2 format. *)
| Integer_round
| Abs_real
| Sqrt_real
Expand All @@ -193,6 +195,49 @@ type _ DStd.Builtin.t +=
(* Internal use for semantic triggers -- do not expose outside of theories *)
| Not_theory_constant | Is_theory_constant | Linear_dependency

let builtin_term t = Dl.Typer.T.builtin_term t

let builtin_ty t = Dl.Typer.T.builtin_ty t

let ty name ty =
Id.Map.add { name = DStd.Name.simple name; ns = Sort } @@
fun env s ->
builtin_ty @@
Dolmen_type.Base.app0 (module Dl.Typer.T) env s ty

let builtin_enum = function
| Ty.Tsum (name, cstrs) as ty_ ->
let ty_cst =
DStd.Expr.Id.mk ~builtin:B.Base
(DStd.Path.global (Hstring.view name))
DStd.Expr.{ arity = 0; alias = No_alias }
in
let cstrs =
List.map (fun c -> DStd.Path.global (Hstring.view c), []) cstrs
in
let _, cstrs = DStd.Expr.Term.define_adt ty_cst [] cstrs in
let dty = DT.apply ty_cst [] in
let add_cstrs map =
List.fold_left (fun map ((c : DE.term_cst), _) ->
let name = get_basename c.path in
Id.Map.add { name = DStd.Name.simple name; ns = Term } (fun env _ ->
builtin_term @@
Dolmen_type.Base.term_app_cst
(module Dl.Typer.T) env c) map)
map cstrs
in
Cache.store_ty (DE.Ty.Const.hash ty_cst) ty_;
dty,
cstrs,
fun map ->
map
|> ty (Hstring.view name) dty
|> add_cstrs
| _ -> assert false

let fpa_rounding_mode, rounding_modes, add_rounding_modes =
builtin_enum Fpa_rounding.fpa_rounding_mode

module Const = struct
open DE

Expand All @@ -207,6 +252,15 @@ module Const = struct
let name = "int2bv" in
Id.mk ~name ~builtin:(Int2BV n)
(DStd.Path.global name) Ty.(arrow [int] (bitv n)))

let smt_round =
with_cache (fun (n, m) ->
let name = "ae.round" in
Id.mk
~name
~builtin:(AERound (n, m))
(DStd.Path.global name)
Ty.(arrow [fpa_rounding_mode; real] real))
end

let bv2nat t =
Expand All @@ -220,6 +274,9 @@ let bv2nat t =
let int2bv n t =
DE.Term.apply_cst (Const.int2bv n) [] [t]

let smt_round n m rm t =
DE.Term.apply_cst (Const.smt_round (n, m)) [] [rm; t]

let bv_builtins env s =
let term_app1 f =
Dl.Typer.T.builtin_term @@
Expand All @@ -241,54 +298,49 @@ let bv_builtins env s =
end
| _ -> `Not_found

let fpa_builtins =
(** Takes a dolmen identifier [id] and injects it in Alt-Ergo's registered
identifiers.
It transforms "fpa_rounding_mode", the Alt-Ergo builtin type into the SMT2
rounding type "RoundingMode". Also injects each constructor into their SMT2
equivalent *)
let inject_ae_to_smt2 id =
match id with
| Id.{name = Simple n; _} ->
begin
if String.equal n Fpa_rounding.fpa_rounding_mode_ae_type_name then
(* Injecting the type name as the SMT2 Type name. *)
let name =
Dolmen_std.Name.simple Fpa_rounding.fpa_rounding_mode_type_name
in
{id with name}
else
match Fpa_rounding.rounding_mode_of_ae_hs (Hstring.make n) with
| rm ->
let name =
Dolmen_std.Name.simple (Fpa_rounding.string_of_rounding_mode rm)
in
{id with name}
| exception (Failure _) ->
id
end
| id -> id

let ae_fpa_builtins =
let (->.) args ret = (args, ret) in
let builtin_term t = Dl.Typer.T.builtin_term t in
let builtin_ty t = Dl.Typer.T.builtin_ty t in
let dterm name f =
Id.Map.add { name = DStd.Name.simple name; ns = Term } @@
fun env s ->
builtin_term @@
Dolmen_type.Base.term_app1 (module Dl.Typer.T) env s f
in
let ty name ty =
Id.Map.add { name = DStd.Name.simple name; ns = Sort } @@
fun env s ->
builtin_ty @@
Dolmen_type.Base.app0 (module Dl.Typer.T) env s ty
in
let builtin_enum = function
| Ty.Tsum (name, cstrs) as ty_ ->
let ty_cst =
DStd.Expr.Id.mk ~builtin:B.Base
(DStd.Path.global (Hstring.view name))
DStd.Expr.{ arity = 0; alias = No_alias }
in
let cstrs =
List.map (fun c -> DStd.Path.global (Hstring.view c), []) cstrs
in
let _, cstrs = DStd.Expr.Term.define_adt ty_cst [] cstrs in
let dty = DT.apply ty_cst [] in
let add_cstrs map =
List.fold_left (fun map ((c : DE.term_cst), _) ->
let name = get_basename c.path in
Id.Map.add { name = DStd.Name.simple name; ns = Term } (fun env _ ->
builtin_term @@
Dolmen_type.Base.term_app_cst
(module Dl.Typer.T) env c) map)
map cstrs
in
Cache.store_ty (DE.Ty.Const.hash ty_cst) ty_;
dty,
cstrs,
fun map ->
map
|> ty (Hstring.view name) dty
|> add_cstrs
| _ -> assert false
in
let fpa_rounding_mode, rounding_modes, add_rounding_modes =
builtin_enum Fpa_rounding.fpa_rounding_mode
let op ?(tyvars = []) name builtin (args, ret) =
let ty = DT.pi tyvars @@ DT.arrow args ret in
let cst = DE.Id.mk ~name ~builtin (DStd.Path.global name) ty in
Id.Map.add { name = DStd.Name.simple name; ns = Term } @@
fun env _ ->
builtin_term @@
Dolmen_type.Base.term_app_cst
(module Dl.Typer.T) env cst
in
let float_cst =
let ty = DT.(arrow [int; int; fpa_rounding_mode; real] real) in
Expand All @@ -311,15 +363,6 @@ let fpa_builtins =
let float32d x = float32 (mode "NearestTiesToEven") x in
let float64 = float (DE.Term.int "53") (DE.Term.int "1074") in
let float64d x = float64 (mode "NearestTiesToEven") x in
let op ?(tyvars = []) name builtin (args, ret) =
let ty = DT.pi tyvars @@ DT.arrow args ret in
let cst = DE.Id.mk ~name ~builtin (DStd.Path.global name) ty in
Id.Map.add { name = DStd.Name.simple name; ns = Term } @@
fun env _ ->
builtin_term @@
Dolmen_type.Base.term_app_cst
(module Dl.Typer.T) env cst
in
let partial1 name f =
Id.Map.add { name = DStd.Name.simple name; ns = Term } @@
fun env s ->
Expand All @@ -335,7 +378,11 @@ let fpa_builtins =
let is_theory_constant =
let open DT in
let a = Var.mk "alpha" in
op ~tyvars:[a] "is_theory_constant" Is_theory_constant ([of_var a] ->. prop)
op
~tyvars:[a]
"is_theory_constant"
Is_theory_constant
([of_var a] ->. prop)
in
let fpa_builtins =
let open DT in
Expand Down Expand Up @@ -409,24 +456,61 @@ let fpa_builtins =
|> op "not_theory_constant" Not_theory_constant ([real] ->. prop)
|> is_theory_constant
|> op "linear_dependency" Linear_dependency ([real; real] ->. prop)

in
fun env s ->
begin match s with
| Dl.Typer.T.Id id ->
begin
try
Id.Map.find_exn id fpa_builtins env s
with Not_found -> `Not_found
end
| Builtin _ -> `Not_found
end
let search_id id =
try
Id.Map.find_exn id fpa_builtins env s
with Not_found -> `Not_found
in
match s with
| Dl.Typer.T.Id id ->
let new_id = inject_ae_to_smt2 id in
search_id new_id
| Builtin _ -> `Not_found

let smt_fpa_builtins =
let term_app env s f =
Dl.Typer.T.builtin_term @@
Dolmen_type.Base.term_app2 (module Dl.Typer.T) env s f
in
let other_builtins =
Id.Map.empty
|> add_rounding_modes
in
fun env s ->
match s with
| Dl.Typer.T.Id {
ns = Term ;
name = Indexed {
basename = "ae.round" ;
indexes = [ i; j ] } } ->
begin match
int_of_string i,
int_of_string j
with
| n, m -> term_app env s (smt_round n m)
| exception Failure _ -> `Not_found
end
| Dl.Typer.T.Id id -> begin
match Id.Map.find_exn id other_builtins env s with
| e -> e
| exception Not_found -> `Not_found
end
| _ -> `Not_found

(** Concatenation of builtins handlers. *)
let (++) bt1 bt2 =
fun a b ->
match bt1 a b with
| `Not_found -> bt2 a b
| res -> res

let builtins =
fun _st (lang : Typer.lang) ->
match lang with
| `Logic Alt_ergo -> fpa_builtins
| `Logic (Smtlib2 _) -> bv_builtins
| `Logic Alt_ergo -> ae_fpa_builtins
| `Logic (Smtlib2 _) -> bv_builtins ++ smt_fpa_builtins
| _ -> fun _ _ -> `Not_found

(** Translates dolmen locs to Alt-Ergo's locs *)
Expand Down Expand Up @@ -929,6 +1013,13 @@ let mk_add translate sy ty l =
let args = aux_mk_add l in
E.mk_term sy args ty

let mk_rounding fpar =
let name = Fpa_rounding.string_of_rounding_mode fpar in
let ty = Fpa_rounding.fpa_rounding_mode in
let sy =
Sy.Op (Sy.Constr (Hstring.make name)) in
E.mk_term sy [] ty

(** [mk_expr ~loc ~name_base ~toplevel ~decl_kind term]

Builds an Alt-Ergo hashconsed expression from a dolmen term
Expand Down Expand Up @@ -1355,6 +1446,15 @@ let rec mk_expr
| _ -> unsupported "coercion: %a" DE.Term.print term
end
| Float, _ -> op Float
| AERound(i, j), _ ->
let args =
let i = E.Ints.of_int i in
let j = E.Ints.of_int j in
i :: j :: List.map (fun a -> aux_mk_expr a) args in
E.mk_term
(Sy.Op Float)
args
(dty_to_ty term_ty)
| Integer_round, _ -> op Integer_round
| Abs_real, _ -> op Abs_real
| Sqrt_real, _ -> op Sqrt_real
Expand All @@ -1373,6 +1473,11 @@ let rec mk_expr
| Not_theory_constant, _ -> op Not_theory_constant
| Is_theory_constant, _ -> op Is_theory_constant
| Linear_dependency, _ -> op Linear_dependency
| B.RoundNearestTiesToEven, _ -> mk_rounding NearestTiesToEven
| B.RoundNearestTiesToAway, _ -> mk_rounding NearestTiesToAway
| B.RoundTowardPositive, _ -> mk_rounding Up
| B.RoundTowardNegative, _ -> mk_rounding Down
| B.RoundTowardZero, _ -> mk_rounding ToZero
| _, _ -> unsupported "Application Term %a" DE.Term.print term
end

Expand Down
45 changes: 29 additions & 16 deletions src/lib/frontend/typechecker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -259,18 +259,29 @@ module Env = struct
| [ x; y ] -> f x y
| _ -> assert false

let add_builtin_enum = function
| Ty.Tsum (_, cstrs) as ty ->
let enum m h =
let s = Hstring.view h in
MString.add s (`Term (
Symbols.Op (Constr h),
{ args = []; result = ty },
Other
)) m
in
fun m -> List.fold_left enum m cstrs
| _ -> assert false
let add_fpa_enum map =
let ty = Fpa_rounding.fpa_rounding_mode in
match ty with
| Ty.Tsum (_, cstrs) ->
List.fold_left
(fun m c ->
match Fpa_rounding.translate_smt_rounding_mode c with
| None ->
(* The constructors of the type are expected to be AE rounding
modes. *)
assert false
| Some hs ->
MString.add (Hstring.view hs) (`Term (
Symbols.Op (Constr c),
{ args = []; result = ty },
Other
))
m
)
map
cstrs
| _ -> (* Fpa_rounding.fpa_rounding_mode is a sum type. *)
assert false
Comment on lines +262 to +284
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: while this is correct, I feel like having the generic add_builtin_enum available is useful.

We could do add_builtin_enum ?(rename = Fun.id) map with an optional rename argument that is fun c -> Option.get (Fpa_rounding.translate_smt_rounding_mode c) for the rounding mode. On the other hand we can also do this when (if) we ever have more builtin enums to add before removing the legacy typechecker entirely, so eh.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC there was a more generic version of this function, which I specialized because it was used only once and I did not expect we would add new enums to the legacy syntax. I can revert it, but I already changed it to simplify the reading.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The generic version did not handle renaming so you were right to specialize it. This is a nit, feel free to ignore it :)


let find_builtin_cstr ty n =
match ty with
Expand Down Expand Up @@ -298,10 +309,12 @@ module Env = struct
let float prec exp mode x =
TTapp (Symbols.Op Float, [prec; exp; mode; x])
in
let nte = Fpa_rounding.string_of_rounding_mode NearestTiesToEven in
let tname = Fpa_rounding.fpa_rounding_mode_type_name in
let float32 = float (int "24") (int "149") in
let float32d = float32 (mode "NearestTiesToEven") in
let float32d = float32 (mode nte) in
let float64 = float (int "53") (int "1074") in
let float64d = float64 (mode "NearestTiesToEven") in
let float64d = float64 (mode nte) in
let op n op profile =
MString.add n @@ `Term (Symbols.Op op, profile, Other)
in
Expand All @@ -312,8 +325,8 @@ module Env = struct
let any = Ty.fresh_tvar in
let env = {
env with
types = Types.add_builtin env.types "fpa_rounding_mode" rm ;
builtins = add_builtin_enum Fpa_rounding.fpa_rounding_mode env.builtins;
types = Types.add_builtin env.types tname rm ;
builtins = add_fpa_enum env.builtins;
} in
let builtins =
env.builtins
Expand Down
2 changes: 1 addition & 1 deletion src/lib/structures/expr.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2820,7 +2820,7 @@ let const_view t =
end
| { f = Op (Constr c); ty; _ }
when Ty.equal ty Fpa_rounding.fpa_rounding_mode ->
RoundingMode (Fpa_rounding.rounding_mode_of_hs c)
RoundingMode (Fpa_rounding.rounding_mode_of_smt_hs c)
| _ -> Fmt.failwith "unsupported constant: %a" print t

let int_view t =
Expand Down
Loading
Loading