diff --git a/src/choice_monad.ml b/src/choice_monad.ml index 62d8e2e21..512e6546f 100644 --- a/src/choice_monad.ml +++ b/src/choice_monad.ml @@ -6,9 +6,11 @@ open Choice_monad_intf type vbool = Symbolic_value.S.vbool +type thread = Thread.t + type vint32 = Symbolic_value.S.int32 -exception Assertion of assertion * Thread.t +exception Assertion of Encoding.Expression.t * Thread.t let check (sym_bool : vbool) (state : Thread.t) : bool = let (S (solver_module, solver)) = Thread.solver state in @@ -19,237 +21,173 @@ let check (sym_bool : vbool) (state : Thread.t) : bool = | Val (Bool no) -> not no | _ -> let check = no :: pc in - Format.printf "CHECK:@."; - List.iter (fun c -> print_endline (Encoding.Expression.to_string c)) check; + Format.printf "CHECK:@.%a" + (Format.pp_print_list ~pp_sep:Format.pp_print_newline + Encoding.Expression.pp ) + check; let module Solver = (val solver_module) in - let r = - Tracing.with_ev Tracing.check (fun () -> Solver.check solver check) - in + let r = Solver.check solver check in let msg = if r then "KO" else "OK" in Format.printf "/CHECK %s@." msg; not r -let eval_choice (sym_bool : vbool) (state : Thread.t) : (bool * Thread.t) list = - let (S (solver_module, s) as solver) = Thread.solver state in - let pc = Thread.pc state in - let memories = Thread.memories state in - let tables = Thread.tables state in - let globals = Thread.globals state in - let sym_bool = Encoding.Expression.simplify sym_bool in - match sym_bool with - | Val (Bool b) -> [ (b, state) ] - | Val (Num (I32 _)) -> assert false - | _ -> ( - let no = Symbolic_value.S.Bool.not sym_bool in - let module Solver = (val solver_module) in - let sat_true = - Tracing.with_ev Tracing.check_true (fun () -> - Solver.check s (sym_bool :: pc) ) - in - let sat_false = - Tracing.with_ev Tracing.check_false (fun () -> Solver.check s (no :: pc)) - in - match (sat_true, sat_false) with - | false, false -> [] - | true, false -> [ (true, state) ] - | false, true -> [ (false, state) ] - | true, true -> - Format.printf "CHOICE: %s@." (Encoding.Expression.to_string sym_bool); - let state1 : Thread.t = - { solver - ; pc = sym_bool :: pc - ; memories = Symbolic_memory.clone memories - ; tables = Symbolic_table.clone tables - ; globals = Symbolic_global.clone globals - } - in - let state2 : Thread.t = - { solver - ; pc = no :: pc - ; memories = Symbolic_memory.clone memories - ; tables = Symbolic_table.clone tables - ; globals = Symbolic_global.clone globals - } - in - [ (true, state1); (false, state2) ] ) - -let fix_symbol (e : Encoding.Expression.t) pc = - let open Encoding in - match e with - | Symbol sym -> (pc, sym) - | _ -> - let sym = Symbol.mk_symbol `I32Type "choice_i32" in - let assign = Expression.Relop (I32 Eq, Symbol sym, e) in - (assign :: pc, sym) - -let clone_if_needed ~orig_pc (cases : (int32 * Thread.t) list) : - (int32 * Thread.t) list = - match cases with - | [] -> [] - | [ (i, state) ] -> [ (i, { state with pc = orig_pc }) ] - | _ :: _ :: _ -> - List.map - (fun (i, state) -> - let solver = Thread.solver state in - let pc = Thread.pc state in - let memories = Thread.memories state in - let tables = Thread.tables state in - let globals = Thread.globals state in - let state : Thread.t = - { solver - ; pc - ; memories = Symbolic_memory.clone memories - ; tables = Symbolic_table.clone tables - ; globals = Symbolic_global.clone globals - } - in - (i, state) ) - cases - -let not_value sym value = - Encoding.Expression.Relop (I32 Ne, Symbol sym, Val (Num (I32 value))) - -let eval_choice_i32 (sym_int : vint32) (state : Thread.t) : - (int32 * Thread.t) list = - let (S (solver_module, solver)) = Thread.solver state in - let pc = Thread.pc state in - let sym_int = Encoding.Expression.simplify sym_int in - let orig_pc = pc in - let pc, sym = fix_symbol sym_int pc in - match sym_int with - | Val (Num (I32 i)) -> [ (i, state) ] - | _ -> - let module Solver = (val solver_module) in - let rec find_values values = - let additionnal = List.map (not_value sym) values in - if - not - (Tracing.with_ev Tracing.check (fun () -> - Solver.check solver (additionnal @ pc) ) ) - then [] - else begin - let model = - Tracing.with_ev Tracing.model (fun () -> - Solver.model ~symbols:[ sym ] solver ) - in - match model with - | None -> assert false (* ? *) - | Some model -> ( - let desc = Encoding.Model.to_string model in - Format.printf "Model:@.%s@." desc; - let v = Encoding.Model.evaluate model sym in - match v with +module Make (M : sig + type 'a t + + val return : 'a -> 'a t + + val empty : 'a t + + val length : 'a t -> int + + val hd : 'a t -> 'a + + val cons : 'a -> 'a t -> 'a t + + val map : ('a -> 'b) -> 'a t -> 'b t + + val flat_map : ('a -> 'b t) -> 'a t -> 'b t + + val to_list : 'a t -> 'a list + + val to_seq : 'a t -> 'a Seq.t +end) = +struct + include M + + type 'a t = thread -> ('a * thread) M.t + + let select b ({ Thread.pc; solver = S (solver_module, s); _ } as state) = + match Encoding.Expression.simplify b with + | Val (Bool b) -> M.return (b, state) + | Val (Num (I32 _)) -> assert false + | e -> ( + let module Solver = (val solver_module) in + let with_e = e :: pc in + let with_not_e = Symbolic_value.S.Bool.not e :: pc in + let sat_true = Solver.check s with_e in + let sat_false = Solver.check s with_not_e in + match (sat_true, sat_false) with + | false, false -> M.empty + | true, false | false, true -> M.return (sat_true, state) + | true, true -> + Format.printf "CHOICE: %a@." Encoding.Expression.pp e; + let state1 = Thread.clone { state with pc = with_e } in + let state2 = Thread.clone { state with pc = with_not_e } in + M.cons (true, state1) (M.return (false, state2)) ) + + let fix_symbol (e : Encoding.Expression.t) pc = + let open Encoding in + match e with + | Symbol sym -> (pc, sym) + | _ -> + let sym = Symbol.mk_symbol `I32Type "choice_i32" in + let assign = Expression.Relop (I32 Eq, Symbol sym, e) in + (assign :: pc, sym) + + let clone_if_needed ~orig_pc cases = + match M.length cases with + | 0 -> cases + | 1 -> + let i, state = M.hd cases in + M.return (i, { state with Thread.pc = orig_pc }) + | _n -> + M.map + (fun (i, state) -> + let state = Thread.clone state in + (i, state) ) + cases + + let not_value sym value = + Encoding.Expression.Relop (I32 Ne, Symbol sym, Val (Num (I32 value))) + + let select_i32 (sym_int : vint32) (state : Thread.t) = + let (S (solver_module, solver)) = Thread.solver state in + let pc = Thread.pc state in + let sym_int = Encoding.Expression.simplify sym_int in + let orig_pc = pc in + let pc, sym = fix_symbol sym_int pc in + match sym_int with + | Val (Num (I32 i)) -> M.return (i, state) + | _ -> + let module Solver = (val solver_module) in + let rec find_values values = + let additionnal = M.to_list @@ M.map (not_value sym) values in + if not (Solver.check solver (additionnal @ pc)) then M.empty + else begin + let model = Solver.model ~symbols:[ sym ] solver in + match model with | None -> assert false (* ? *) - | Some (Num (I32 i) as v) -> begin - let cond = Encoding.Expression.Relop (I32 Eq, Symbol sym, Val v) in - let case = (i, { state with pc = cond :: pc }) in - case :: find_values (i :: values) - end - | Some _ -> assert false ) - end - in - let cases = find_values [] in - clone_if_needed ~orig_pc cases - -module List = struct - type vbool = Symbolic_value.S.vbool - - type thread = Thread.t - - type 'a t = thread -> ('a * thread) list - - let return (v : 'a) : 'a t = fun t -> [ (v, t) ] + | Some model -> ( + Format.printf "Model:@.%a@." Encoding.Model.pp model; + let v = Encoding.Model.evaluate model sym in + match v with + | None -> assert false (* ? *) + | Some (Num (I32 i) as v) -> begin + let cond = + Encoding.Expression.Relop (I32 Eq, Symbol sym, Val v) + in + let pc = cond :: pc in + let case = (i, { state with pc }) in + M.cons case (find_values (M.cons i values)) + end + | Some _ -> assert false ) + end + in + let cases = find_values M.empty in + clone_if_needed ~orig_pc cases let bind (v : 'a t) (f : 'a -> 'b t) : 'b t = - fun t -> - let lst = v t in - match lst with - | [] -> [] - | [ (r, t) ] -> (f r) t - | _ -> List.concat_map (fun (r, t) -> (f r) t) lst - - let select (sym_bool : vbool) : bool t = eval_choice sym_bool + fun t -> M.flat_map (fun (r, t) -> (f r) t) (v t) - let select_i32 (i : Symbolic_value.S.int32) : int32 t = eval_choice_i32 i + let trap = function Trap.Unreachable -> fun _ -> M.empty | _ -> assert false - let trap : Trap.t -> 'a t = function - | Unreachable -> fun _ -> [] - | _ -> assert false + let with_thread f t = M.return (f t, t) - (* raise (Types.Trap "out of bounds memory access") *) - - let with_thread (f : thread -> 'b) : 'b t = fun t -> [ (f t, t) ] - - let add_pc (c : Symbolic_value.S.vbool) : unit t = - fun t -> + let add_pc c t = match c with - | Val (Bool b) -> if b then [ ((), t) ] else [] - | _ -> [ ((), { t with pc = c :: t.pc }) ] - - let assertion c t = - if check c t then [ ((), t) ] + | Encoding.Expression.Val (Bool b) -> + if b then M.return ((), t) else M.empty + | _ -> + let pc = c :: t.Thread.pc in + let t = { t with pc } in + M.return ((), t) + + let assertion c (t : Thread.t) = + if check c t then M.return ((), t) else let no = Symbolic_value.S.Bool.not c in let t = { t with pc = no :: t.pc } in - raise (Assertion (Encoding.Expression.to_string c, t)) - - let run (v : 'a t) (thread : thread) = List.to_seq (v thread) -end - -module Seq = struct - module List = Stdlib.List + raise (Assertion (c, t)) - type vbool = Symbolic_value.S.vbool + let run v t = v t |> M.to_seq - type thread = Thread.t + let return v t = M.return (v, t) +end - type 'a t = thread -> ('a * thread) Seq.t +module CList = Make (struct + include List - let return (v : 'a) : 'a t = fun t -> Seq.return (v, t) + let return v = [ v ] - let bind (v : 'a t) (f : 'a -> 'b t) : 'b t = - fun t -> - let seq = v t in - Seq.flat_map (fun (e, t) -> f e t) seq + let empty = [] - let select (sym_bool : vbool) : bool t = - fun state -> List.to_seq (eval_choice sym_bool state) + let flat_map = List.concat_map - let select_i32 (i : Symbolic_value.S.int32) : int32 t = - fun state -> List.to_seq (eval_choice_i32 i state) + let to_list = Fun.id +end) - let trap : Trap.t -> 'a t = function - | Unreachable -> fun _ -> Seq.empty - | _ -> assert false +module CSeq = Make (struct + include Seq - let assertion c t = - if check c t then Seq.return ((), t) - else - let no = Symbolic_value.S.Bool.not c in - let t = { t with pc = no :: t.pc } in - raise (Assertion (Encoding.Expression.to_string c, t)) + let hd s = Seq.uncons s |> Option.get |> fst - (* raise (Types.Trap "out of bounds memory access") *) + let to_list = List.of_seq - let with_thread (f : thread -> 'b) : 'b t = fun t -> Seq.return (f t, t) - - let add_pc (c : Symbolic_value.S.vbool) : unit t = - fun t -> - match c with - | Val (Bool b) -> if b then Seq.return ((), t) else Seq.empty - | _ -> Seq.return ((), { t with pc = c :: t.pc }) - - let run (v : 'a t) (thread : thread) = v thread -end + let to_seq = Fun.id +end) module Explicit = struct - module List = Stdlib.List - module Seq = Stdlib.Seq - - type vbool = Symbolic_value.S.vbool - - type thread = Thread.t - type 'a st = St of (thread -> 'a * thread) [@@unboxed] type 'a t = @@ -274,15 +212,6 @@ module Explicit = struct | Bind _ -> Bind (v, f) [@@inline] - (* let rec bind : type a b. a t -> (a -> b t) -> b t = - * fun v f -> - * match v with - * | Empty -> Empty - * | Retv v -> f v - * | Ret _ | Choice _ -> Bind (v, f) - * | Bind (v, f1) -> Bind (v, fun x -> bind (f1 x) f) - * [@@inline] *) - let select (cond : vbool) : bool t = match cond with Val (Bool b) -> Retv b | _ -> Choice cond [@@inline] @@ -316,9 +245,9 @@ module Explicit = struct else let no = Symbolic_value.S.Bool.not c in let t = { t with pc = no :: t.pc } in - raise (Assertion (Encoding.Expression.to_string c, t)) - | Choice cond -> List.to_seq (eval_choice cond t) - | Choice_i32 i -> List.to_seq (eval_choice_i32 i t) + raise (Assertion (c, t)) + | Choice cond -> CSeq.select cond t + | Choice_i32 i -> CSeq.select_i32 i t let rec run_and_trap : type a. a t -> thread -> (a eval * thread) Seq.t = fun v t -> @@ -342,11 +271,11 @@ module Explicit = struct else let no = Symbolic_value.S.Bool.not c in let t = { t with pc = no :: t.pc } in - Seq.return (EAssert (Encoding.Expression.to_string c), t) - | Choice cond -> - List.to_seq (List.map (fun (v, t) -> (EVal v, t)) (eval_choice cond t)) - | Choice_i32 i -> - List.to_seq (List.map (fun (v, t) -> (EVal v, t)) (eval_choice_i32 i t)) + Seq.return (EAssert c, t) + | Choice cond -> Seq.map (fun (v, t) -> (EVal v, t)) (CSeq.select cond t) + | Choice_i32 i -> Seq.map (fun (v, t) -> (EVal v, t)) (CSeq.select_i32 i t) + + let () = ignore run_and_trap let rec run_up_to : type a. depth:int -> a t -> thread -> (a * thread) Seq.t = fun ~depth v t -> @@ -368,9 +297,9 @@ module Explicit = struct else let no = Symbolic_value.S.Bool.not c in let t = { t with pc = no :: t.pc } in - raise (Assertion (Encoding.Expression.to_string c, t)) - | Choice cond -> List.to_seq (eval_choice cond t) - | Choice_i32 i -> List.to_seq (eval_choice_i32 i t) + raise (Assertion (c, t)) + | Choice cond -> CSeq.select cond t + | Choice_i32 i -> CSeq.select_i32 i t type 'a cont = { k : thread -> 'a -> unit } [@@unboxed] @@ -453,11 +382,10 @@ module Explicit = struct Condition.broadcast q.cond; Mutex.unlock q.mutex - let with_produce q f ~started = + let with_produce q f = Mutex.lock q.mutex; q.producers <- q.producers + 1; Mutex.unlock q.mutex; - started (); match f () with | () -> Mutex.lock q.mutex; @@ -478,35 +406,10 @@ module Explicit = struct } end - module Counter = struct - type t = - { mutable c : int - ; mutex : Mutex.t - ; cond : Condition.t - } - - let incr t = - Mutex.lock t.mutex; - t.c <- t.c + 1; - Condition.broadcast t.cond; - Mutex.unlock t.mutex - - let wait_n t n = - Mutex.lock t.mutex; - while t.c < n do - Condition.wait t.cond t.mutex - done; - Mutex.unlock t.mutex - - let create () = - { c = 0; mutex = Mutex.create (); cond = Condition.create () } - end - module MT = struct type 'a global_state = { w : hold WQ.t (* work *) ; r : ('a eval * thread) WQ.t (* results *) - ; start_counter : Counter.t } and 'a local_state = @@ -539,10 +442,9 @@ module Explicit = struct if check c { thread with solver = st.solver } then cont.k thread (E_st st) () else - let assertion = Encoding.Expression.to_string c in let no = Symbolic_value.S.Bool.not c in let thread = { thread with pc = no :: thread.pc } in - WQ.push (EAssert assertion, thread) st.global.r + WQ.push (EAssert c, thread) st.global.r | Bind (v, f) -> let k thread (E_st st) v = let r = f v in @@ -550,17 +452,16 @@ module Explicit = struct in step thread v { k } st | Choice cond -> - let cases = eval_choice cond { thread with solver = st.solver } in + let cases = CList.select cond { thread with solver = st.solver } in List.iter (fun (case, thread) -> cont.k thread (E_st st) case) cases | Choice_i32 i -> - let cases = eval_choice_i32 i { thread with solver = st.solver } in + let cases = CList.select_i32 i { thread with solver = st.solver } in List.iter (fun (case, thread) -> cont.k thread (E_st st) case) cases let init_global () = let w = WQ.init () in let r = WQ.init () in - let start_counter = Counter.create () in - { w; r; start_counter } + { w; r } let push_first_work g thread t = let k thread _st v = WQ.push (EVal v, thread) g.r in @@ -582,11 +483,9 @@ module Explicit = struct | None -> () in Domain.spawn (fun () -> - WQ.with_produce global.r - ~started:(fun () -> Counter.incr global.start_counter) - (fun () -> - WQ.produce global.w producer; - ignore i (* Format.printf "@.@.PRODUCER END %i@.@." i; *) ) ) + WQ.with_produce global.r (fun () -> + WQ.produce global.w producer; + ignore i (* Format.printf "@.@.PRODUCER END %i@.@." i; *) ) ) let worker_threads_count = 4 @@ -606,7 +505,6 @@ module Explicit = struct let producers = List.init worker_threads_count (fun i -> spawn_producer i global) in - Counter.wait_n global.start_counter worker_threads_count; loop_and_do (WQ.read_as_seq global.r) (fun () -> List.iter Domain.join producers ) end @@ -625,8 +523,7 @@ module Explicit = struct else let no = Symbolic_value.S.Bool.not c in let thread = { thread with pc = no :: thread.pc } in - let assertion = Encoding.Expression.to_string c in - Queue.push (EAssert assertion, thread) qt + Queue.push (EAssert c, thread) qt | Bind (v, f) -> let k thread v = let r = f v in @@ -634,10 +531,10 @@ module Explicit = struct in step thread v { k } q qt | Choice cond -> - let cases = eval_choice cond thread in + let cases = CList.select cond thread in List.iter (fun (case, thread) -> cont.k thread case) cases | Choice_i32 i -> - let cases = eval_choice_i32 i thread in + let cases = CList.select_i32 i thread in List.iter (fun (case, thread) -> cont.k thread case) cases let init thread t = @@ -660,17 +557,11 @@ module Explicit = struct let tail = sequify q qt in Cons (first, Seq.append head tail) - let () = ignore run_and_trap - - let run_and_trap : type a. a t -> thread -> (a eval * thread) Seq.t = + let _run_and_trap : type a. a t -> thread -> (a eval * thread) Seq.t = fun t thread -> let q, qt = init thread t in sequify q qt - let () = ignore MT.run_and_trap - - let () = ignore run_and_trap - let run_and_trap = MT.run_and_trap end @@ -684,9 +575,9 @@ module type T_trap = with type thread := Thread.t and module V := Symbolic_value.S -let list = (module List : T) +let list = (module CList : T) -let seq = (module Seq : T) +let seq = (module CSeq : T) let explicit = (module Explicit : T) diff --git a/src/choice_monad.mli b/src/choice_monad.mli index 297c31b55..240536dfb 100644 --- a/src/choice_monad.mli +++ b/src/choice_monad.mli @@ -1,4 +1,4 @@ -exception Assertion of Choice_monad_intf.assertion * Thread.t +exception Assertion of Encoding.Expression.expr * Thread.t module type T = Choice_monad_intf.Complete @@ -10,14 +10,14 @@ module type T_trap = with type thread := Thread.t and module V := Symbolic_value.S -module List : T +module CList : T -module Seq : T +module CSeq : T module Explicit : sig include T_trap - val run_up_to : depth:int -> 'a t -> Thread.t -> ('a * Thread.t) Stdlib.Seq.t + val run_up_to : depth:int -> 'a t -> Thread.t -> ('a * Thread.t) Seq.t end val choices : (module T) list diff --git a/src/choice_monad_intf.ml b/src/choice_monad_intf.ml index 9771171ab..a911d3a37 100644 --- a/src/choice_monad_intf.ml +++ b/src/choice_monad_intf.ml @@ -32,12 +32,10 @@ module type Complete = sig val run : 'a t -> thread -> ('a * thread) Seq.t end -type assertion = string - type 'a eval = | EVal of 'a | ETrap of Trap.t - | EAssert of assertion + | EAssert of Encoding.Expression.expr module type Complete_with_trap = sig include Complete diff --git a/src/cmd_sym.ml b/src/cmd_sym.ml index 74b43aa43..47115aa69 100644 --- a/src/cmd_sym.ml +++ b/src/cmd_sym.ml @@ -245,7 +245,7 @@ let cmd profiling debug unsafe optimize files = match result with | Choice_monad_intf.EVal (Ok ()) -> None | EAssert assertion -> - Format.printf "Assert failure: %s@." assertion; + Format.printf "Assert failure: %a@." Encoding.Expression.pp assertion; let model = get_model solver thread in Format.printf "Model:@.%s@." model; Some thread diff --git a/src/thread.ml b/src/thread.ml index 47106e3a5..98e041b22 100644 --- a/src/thread.ml +++ b/src/thread.ml @@ -36,3 +36,9 @@ let create () = ; tables = Symbolic_table.init () ; globals = Symbolic_global.init () } + +let clone { solver; pc; memories; tables; globals } = + let memories = Symbolic_memory.clone memories in + let tables = Symbolic_table.clone tables in + let globals = Symbolic_global.clone globals in + { solver; pc; memories; tables; globals }