-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathopt_bb_eval.ml
137 lines (132 loc) · 4.43 KB
/
opt_bb_eval.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
open Interval
open Expr
open Opt_common
(* Compiled form of an [Expr.expr], produced by [expr'_of_expr] and
   evaluated in interval arithmetic by [eval_expr']. *)
type expr' =
(* Reference to an earlier subexpression: [eval_expr'] returns [refs.(i)],
   which [eval_expr'_list] stores before evaluating later expressions. *)
| Ref' of int
(* Interval constant. *)
| Const' of interval
(* Variable, given as an index into the variable box [arr] passed to
   [eval_expr']. *)
| Var' of int
(* [Pown (x, n)]: [x] raised to the integer power [n] with [pow_I_i].
   [expr'_of_expr] builds it from [x * x] (n = 2) and from [Op_nat_pow]
   with a non-negative integer constant exponent. *)
| Pown of expr' * int
(* Unary, binary and general operators; they use the same operator types
   as [Expr]. *)
| U_op' of u_op_type * expr'
| Bin_op' of bin_op_type * expr' * expr'
| Gen_op' of gen_op_type * expr' list
(* [Ulp_op' (p, e, x)]: [Func.goldberg_ulp_I (p, e)] applied to [x]; built
   from [Op_ulp] whose first two arguments are constants. *)
| Ulp_op' of int * int * expr'
(** [expr'_of_expr var_index e] compiles [e] into the [expr'] form that
    [eval_expr'] evaluates.

    - [var_index] maps a variable name to its position in the variable box
      later passed to [eval_expr'].
    - Variables recognised by [is_ref_var] become [Ref'] nodes, indexed by
      [index_of_ref_var].
    - [x * x] with syntactically equal factors (per [eq_expr]) becomes
      [Pown (x, 2)].
    - [Op_nat_pow] becomes [Pown (x, n)]. The exponent must evaluate to a
      degenerate interval [[n, n]] with [n] a non-negative integer.
    - [Op_ulp] with constant precision and exponent becomes [Ulp_op'].

    @raise Failure if [e] contains a [Rounding] node, or an [Op_nat_pow]
    whose exponent is not a non-negative integer constant. *)
let expr'_of_expr var_index =
  let rec of_expr = function
    | Const c -> Const' (Const.to_interval c)
    | Var _ as ex when is_ref_var ex -> Ref' (index_of_ref_var ex)
    | Var v -> Var' (var_index v)
    | Rounding _ -> failwith "Rounding is not supported"
    | U_op (op, arg) -> U_op' (op, of_expr arg)
    | Bin_op (Op_mul, arg1, arg2) when eq_expr arg1 arg2 ->
      (* Interval multiplication treats its two factors as independent;
         a single power node knows they are the same value. *)
      Pown (of_expr arg1, 2)
    | Bin_op (Op_nat_pow, arg1, arg2) ->
      let e = Eval.eval_interval_const_expr arg2 in
      let n = truncate e.low in
      (* Accept only a degenerate interval [n, n] with n a non-negative
         integer; anything else cannot be expressed as [Pown]. *)
      if n < 0 || e.low <> e.high || float n <> e.low then
        failwith "expr'_of_expr: Op_nat_pow"
      else
        Pown (of_expr arg1, n)
    | Bin_op (op, arg1, arg2) -> Bin_op' (op, of_expr arg1, of_expr arg2)
    | Gen_op (Op_ulp, [Const p; Const e; arg]) ->
      Ulp_op' (Const.to_int p, Const.to_int e, of_expr arg)
    | Gen_op (op, args) -> Gen_op' (op, List.map of_expr args)
  in
  of_expr
(* Interval value of the unary operator [op] applied to the already
   evaluated argument [x]. *)
let apply_u_op' (op : u_op_type) x =
  match op with
  | Op_neg -> ~-$ x
  | Op_abs -> abs_I x
  | Op_inv -> inv_I x
  | Op_sqrt -> sqrt_I x
  | Op_sin -> sin_I x
  | Op_cos -> cos_I x
  | Op_tan -> tan_I x
  | Op_asin -> asin_I x
  | Op_acos -> acos_I x
  | Op_atan -> atan_I x
  | Op_exp -> exp_I x
  | Op_log -> log_I x
  | Op_sinh -> sinh_I x
  | Op_cosh -> cosh_I x
  | Op_tanh -> tanh_I x
  | Op_asinh -> Func.asinh_I x
  | Op_acosh -> Func.acosh_I x
  | Op_atanh -> Func.atanh_I x
  | Op_floor_power2 -> Func.floor_power2_I x

(* Interval value of the binary operator [op] applied to the already
   evaluated arguments [x1] and [x2]. *)
let apply_bin_op' (op : bin_op_type) x1 x2 =
  match op with
  | Op_add -> x1 +$ x2
  | Op_sub -> x1 -$ x2
  | Op_mul -> x1 *$ x2
  | Op_div -> x1 /$ x2
  | Op_max -> max_I_I x1 x2
  | Op_min -> min_I_I x1 x2
  (* Uses only the upper end of the exponent. [expr'_of_expr] turns every
     [Op_nat_pow] into [Pown], so this arm is only a fallback. *)
  | Op_nat_pow -> x1 **$. x2.high
  | Op_sub2 -> Func.sub2_I (x1, x2)
  | Op_abs_err -> Func.abs_err_I (x1, x2)

(** [eval_expr' refs arr ex] evaluates the compiled expression [ex] in
    interval arithmetic. [Var' v] reads [arr.(v)] and [Ref' i] reads
    [refs.(i)]. Operands are evaluated left to right.
    @raise Failure on any [Gen_op'] other than a three-argument [Op_fma]. *)
let rec eval_expr' refs arr = function
  | Ref' i -> refs.(i)
  | Const' c -> c
  | Var' v -> arr.(v)
  | Pown (arg, n) -> pow_I_i (eval_expr' refs arr arg) n
  | Ulp_op' (p, e, arg) -> Func.goldberg_ulp_I (p, e) (eval_expr' refs arr arg)
  | U_op' (op, arg) -> apply_u_op' op (eval_expr' refs arr arg)
  | Bin_op' (op, lhs, rhs) ->
    let lo_val = eval_expr' refs arr lhs in
    let hi_val = eval_expr' refs arr rhs in
    apply_bin_op' op lo_val hi_val
  | Gen_op' (op, args) ->
    let values = List.map (eval_expr' refs arr) args in
    begin match op, values with
    (* fma is evaluated as a plain interval product followed by a sum. *)
    | Op_fma, [a; b; c] -> (a *$ b) +$ c
    | _ ->
      failwith ("eval_expr': Unsupported general operation: " ^ gen_op_name op)
    end
(** [eval_expr'_list refs vars i exs] evaluates the compiled expressions
    [exs] in order over the variable box [vars]. The value of every
    expression except the last is written to [refs], starting at index [i]
    and moving up by one each time, so that later expressions can read it
    through [Ref'] nodes. The result is the value of the last expression.
    @raise Failure if [exs] is empty. *)
let rec eval_expr'_list refs vars i exs =
  match exs with
  | [] -> failwith "eval_expr'_list: empty list"
  | ex :: rest ->
    let value = eval_expr' refs vars ex in
    begin match rest with
    | [] -> value
    | _ :: _ ->
      refs.(i) <- value;
      eval_expr'_list refs vars (i + 1) rest
    end
(** [min_max_expr pars max_only cs e] bounds [e] over the box formed by
    [cs.var_interval] of each variable of [e], using [Opt0.opt].

    Returns [(rmin, rmax)]:
    - [rmax] comes from running [Opt0.opt] on [e] itself.
    - [rmin] comes from running [Opt0.opt] on [-e] and negating both the
      result and the bound it returns.
    - When [max_only] is true, the second run is skipped and [rmin] has
      [result = 0.], [lower_bound = 0.] and [iters = 0].
    - The [time] field of both results is always [0.]. *)
let min_max_expr (pars : Opt_common.opt_pars) max_only (cs : constraints) e =
  if Config.debug () then
    Log.report `Debug "bb-eval_opt: x_abs_tol = %e, f_rel_tol = %e, f_abs_tol = %e, iters = %d"
      pars.x_abs_tol pars.f_rel_tol pars.f_abs_tol pars.max_iters;
  let vars = vars_in_expr e in
  let box = Array.of_list (List.map cs.var_interval vars) in
  (* Tolerance on the box: relative part scaled by [size_max_X box]
     (presumably the widest side of the box), plus the absolute part. *)
  let x_tol = size_max_X box *. pars.x_rel_tol +. pars.x_abs_tol in
  (* Position of each variable in [box]; [Var'] nodes index by it. *)
  let var_pos = Hashtbl.create 8 in
  List.iteri (fun i v -> Hashtbl.add var_pos v i) vars;
  (* [e] is compiled as a list of subexpressions. All but the last are
     stored in [refs] during evaluation; the last one is the value of [e]. *)
  let compiled =
    List.map (expr'_of_expr (Hashtbl.find var_pos)) (expr_ref_list_of_expr e) in
  let refs = Array.make (List.length compiled - 1) Interval.zero_I in
  let eval_e point = eval_expr'_list refs point 0 compiled in
  let maximize g =
    Opt0.opt g box x_tol pars.f_rel_tol pars.f_abs_tol pars.max_iters in
  let mk_result result lower_bound iters =
    { result; lower_bound; iters; time = 0. } in
  let fmax, lower_max, iter_max = maximize eval_e in
  let rmin =
    if max_only then mk_result 0. 0. 0
    else begin
      (* min e = -(max (-e)) *)
      let neg_max, neg_bound, neg_iters =
        maximize (fun point -> ~-$ (eval_e point)) in
      mk_result (-. neg_max) (-. neg_bound) neg_iters
    end
  in
  let rmax = mk_result fmax lower_max iter_max in
  rmin, rmax