diff --git a/lib/bap_core_theory/bap_core_theory_manager.ml b/lib/bap_core_theory/bap_core_theory_manager.ml index 8d38a9883..b9ed8d9cd 100644 --- a/lib/bap_core_theory/bap_core_theory_manager.ml +++ b/lib/bap_core_theory/bap_core_theory_manager.ml @@ -80,52 +80,43 @@ let declare } let (++) x y = - x >>= fun x -> - y >>| fun y -> + let+ x = x and+ y = y in Value.merge x y [@@inline] + let join1 p q x = - x >>= fun x -> + let* x = x in p !!x ++ q !!x [@@inline] let join1s p q s x = - x >>= fun x -> + let* x = x in p s !!x ++ q s !!x [@@inline] let join2 p q x y = - x >>= fun x -> - y >>= fun y -> + let* x = x and+ y = y in p !!x !!y ++ q !!x !!y [@@inline] let join2s p q s x y = - x >>= fun x -> - y >>= fun y -> + let* x = x and+ y = y in p s !!x !!y ++ q s !!x !!y [@@inline] let join3 p q x y z = - x >>= fun x -> - y >>= fun y -> - z >>= fun z -> + let* x = x and* y = y and* z = z in p !!x !!y !!z ++ q !!x !!y !!z [@@inline] let join3s p q s x y z = - x >>= fun x -> - y >>= fun y -> - z >>= fun z -> + let* x = x and* y = y and* z = z in p s !!x !!y !!z ++ q s !!x !!y !!z [@@inline] let join4 p q r x y z = - r >>= fun r -> - x >>= fun x -> - y >>= fun y -> - z >>= fun z -> + let* r = r and* x = x and* y = y and* z = z in p !!r !!x !!y !!z ++ q !!r !!x !!y !!z [@@inline] diff --git a/lib/bap_primus/bap_primus_machine.ml b/lib/bap_primus/bap_primus_machine.ml index c9b94e905..3565abb45 100644 --- a/lib/bap_primus/bap_primus_machine.ml +++ b/lib/bap_primus/bap_primus_machine.ml @@ -53,30 +53,32 @@ let start,started = ~desc:"Occurs after the system start." module Make(M : Monad.S) = struct - module PE = struct - type t = (project, exn) Monad.Result.result - end - module SM = struct - include Monad.State.Multi.T2(M) - include Monad.State.Multi.Make2(M) - end + type id = Monad.State.Multi.id + module Id = Monad.State.Multi.Id + + type r = (exit_status * project) M.t - type 'a t = (('a,exn) Monad.Result.result,PE.t sm) Monad.Cont.t - and 'a sm = ('a,state) SM.t - and state = { + type state = { + self : id; args : string array; envp : string array; - curr : unit -> unit t; proj : project; - local : State.Bag.t; + conts : (unit -> unit machine) Map.M(Id).t; + local : State.Bag.t Map.M(Id).t; + parent : id Map.M(Id).t; global : State.Bag.t; deathrow : id list; - observations : unit t Observation.observations; + observations : unit machine Observation.observations; restricted : bool; } + and 'a machine = { + run : + reject:(exn -> state -> r) -> + accept:('a -> state -> r) -> + state -> r + } - type 'a machine = 'a t - type 'a c = 'a t + type 'a t = 'a machine type 'a m = 'a M.t type 'a e = ?boot:unit t -> @@ -84,71 +86,92 @@ module Make(M : Monad.S) = struct ?fini:unit t -> (exit_status * project) m effect - module C = Monad.Cont.Make(PE)(struct - type 'a t = 'a sm - include Monad.Make(struct - type 'a t = 'a sm - let return = SM.return - let bind m f = SM.bind m ~f - let map = `Custom SM.map - end) - end) + type _ error = exn - module CM = Monad.Result.Make(Exn)(struct - type 'a t = ('a, PE.t sm) Monad.Cont.t - include C + module Machine = Monad.Make(struct + type 'a t = 'a machine + let return x : 'a t = { + run = fun ~reject:_ ~accept s -> + accept x s + } [@@inline] + + let bind : 'a t -> ('a -> 'b t) -> 'b t = fun x f -> { + run = fun ~reject ~accept s -> + x.run s ~reject ~accept:(fun x s -> + (f x).run ~reject ~accept s) + } [@@inline] + + let map : 'a t -> f:('a -> 'b) -> 'b t = fun x ~f -> { + run = fun ~reject ~accept s -> + x.run s ~reject ~accept:(fun x s -> + accept (f x) s) + } [@@inline] + + let map = `Custom map end) + open Machine - type _ error = exn - open CM + open Machine.Syntax + open Machine.Let - type id = Monad.State.Multi.id - module Id = Monad.State.Multi.Id + let get () = { + run = fun ~reject:_ ~accept s -> accept s s + } [@@inline] + + let gets f = { + run = fun ~reject:_ ~accept s -> accept (f s) s + } [@@inline] + + let put s = { + run = fun ~reject:_ ~accept _ -> accept () s + } [@@inline] - (* lifts state monad to the outer monad *) - let lifts x = CM.lift (C.lift x) + let update f = { + run = fun ~reject:_ ~accept s -> accept () (f s) + } [@@inline] - let with_context cid (f : (unit -> 'a t)) = - lifts (SM.current ()) >>= fun id -> - lifts (SM.switch cid) >>= fun () -> - f () >>= fun r -> - lifts (SM.switch id) >>| fun () -> - r + let current () = gets (fun s -> s.self) + let switch self = update (fun s -> {s with self}) + + let with_context cid (f : (unit -> 'a t)) = { + run = fun ~reject ~accept s -> + (f ()).run {s with self=cid} + ~reject + ~accept:(fun r s'-> + accept r {s' with self=s.self}) + } - let with_global_context (f : (unit -> 'a t)) = - with_context SM.global f + let get_local () : _ t = gets (fun s -> + match Map.find s.local s.self with + | Some bag -> bag + | None -> State.Bag.empty) - let get_local () : _ t = lifts (SM.gets @@ fun s -> s.local) - let get_global () : _ t = with_global_context @@ fun () -> - lifts (SM.gets @@ fun s -> s.global) + let get_global () : _ t = gets (fun s -> s.global) - let set_local local = lifts @@ SM.update @@ fun s -> - {s with local} + let set_local local = update @@ fun s -> { + s with local = Map.set s.local s.self local + } - let set_global global = with_global_context @@ fun () -> - lifts (SM.update @@ fun s -> {s with global}) + let set_global global = update @@ fun s -> { + s with global + } - let is_restricted () : bool t = - with_global_context @@ fun () -> - lifts (SM.gets @@ fun s -> s.restricted) + let is_restricted () : bool t = gets @@ fun s -> s.restricted module Observation = struct type 'a m = 'a t type nonrec 'a observation = 'a observation type nonrec 'a statement = 'a statement - let observations () = lifts (SM.gets @@ fun s -> s.observations) - let unrestricted () = - with_global_context @@ fun () -> - lifts (SM.gets @@ fun s -> - if s.restricted then None - else Some s.observations) + let observations () = gets @@ fun s -> s.observations + let unrestricted () = gets @@ fun s -> + if s.restricted then None + else Some s.observations - let set_observations observations = with_global_context @@ fun () -> - lifts (SM.update @@ fun s -> {s with observations}) + let set_observations observations = + update @@ fun s -> {s with observations} let subscribe key observer = - with_global_context @@ fun () -> observations () >>= fun os -> let obs,sub = Observation.add_observer os key observer in set_observations obs >>| fun () -> sub @@ -157,13 +180,11 @@ module Make(M : Monad.S) = struct subscribe key observer >>| fun _ -> () let watch prov watcher = - with_global_context @@ fun () -> observations () >>= fun os -> let obs,_ = Observation.add_watcher os prov watcher in set_observations obs let cancel sub = - with_global_context @@ fun () -> observations () >>= fun os -> set_observations (Observation.cancel sub os) @@ -171,33 +192,31 @@ module Make(M : Monad.S) = struct module Observation = Observation.Make(struct type 'a t = 'a machine - include CM + include Machine end) - let make key obs = - with_global_context unrestricted >>= function + let make key obs = unrestricted () >>= function | None -> return () | Some os -> Observation.notify os key obs let make_even_if_restricted key obs = - with_global_context observations >>= fun os -> + observations () >>= fun os -> Observation.notify os key obs - let post key ~f = - with_global_context unrestricted >>= function + unrestricted () >>= function | None -> return () | Some os -> Observation.notify_if_observed os key @@ fun k -> f (fun x -> k x) end - let make_get get state = - get () >>= fun states -> + let make_get states state = + states () >>= fun states -> State.Bag.with_state states state ~ready:return ~create:(fun make -> - lifts (SM.get ()) >>= fun {proj} -> + get () >>= fun {proj} -> return (make proj)) let make_put get set state x = @@ -230,73 +249,110 @@ module Make(M : Monad.S) = struct let update s = make_update get put s end - let put proj = with_global_context @@ fun () -> - lifts @@ SM.update @@ fun s -> {s with proj} - let get () = with_global_context @@ fun () -> - lifts (SM.gets @@ fun s -> s.proj) - let project : project t = get () - let gets f = get () >>| f - let update f = get () >>= fun s -> put (f s) - let modify m f = m >>= fun x -> update f >>= fun () -> return x - - - let fork_state () = lifts (SM.fork ()) - let switch_state id : unit c = lifts (SM.switch id) - let store_curr k = - lifts (SM.update (fun s -> {s with curr = fun () -> k (Ok ())})) - - let lift x = lifts (SM.lift x) - let status x = lifts (SM.status x) - let forks () = lifts (SM.forks ()) - let ancestor x = lifts (SM.ancestor x) - let parent () = lifts (SM.parent ()) - let global = SM.global - let current () = lifts (SM.current ()) + let global = Id.zero + + let next_self s = + match Map.max_elt s.parent with + | None -> Id.succ global + | Some (id,_) -> Id.succ id + + let fork_state () = update @@ fun s -> + let child = next_self s in { + s with + local = Map.change s.local child ~f:(fun _ -> + Map.find s.local s.self); + self = child; + parent = Map.set s.parent child s.self; + } + + let switch_state self : unit t = update @@ fun s -> {s with self} + let store_curr k = update @@ fun s -> { + s with conts = Map.set s.conts s.self k + } + + let get_curr = gets @@ fun {conts; self} -> + match Map.find conts self with + | None -> return + | Some k -> k + + + let status x = gets @@ fun {self; conts} -> + if Id.equal self x then `Current else + if Map.mem conts x then `Live else `Dead + + + let forks () = gets @@ fun s -> + let fs = Map.to_sequence s.conts |> Sequence.map ~f:fst in + Sequence.shift_right fs global + + let ancestor xs = gets @@ fun {parent=ps; conts} -> + let is_alive = Map.mem conts in + let rec parent i = match Map.find ps i with + | None -> global + | Some p when is_alive p -> p + | Some p -> parent p in + let rec common xs = + match Base.List.map xs ~f:parent with + | [] -> global + | p :: ps when Base.List.for_all ps ~f:(Id.equal p) -> p + | ps -> common ps in + common xs + + let current () = gets @@ fun s -> s.self + + let parent () = + let* id = current () in + ancestor [id] let notify_fork pid = current () >>= fun cid -> Observation.make forked (pid,cid) - let restrict x = - with_global_context @@ fun () -> - lifts @@ SM.update (fun s -> { - s with restricted = x; - }) + let call : f:(cc:('a -> _ t) -> 'a t) -> 'a t = fun ~f -> { + run = fun ~reject ~accept s -> + let cc x = {run = fun ~reject:_ ~accept:_ s -> accept x s} in + (f ~cc).run ~reject ~accept s + } - let sentence_to_death id = - with_global_context (fun () -> - lifts @@ SM.update (fun s -> { - s with deathrow = id :: s.deathrow - })) + let restrict x = update @@ fun s -> { + s with restricted = x + } - let do_kill id = - lifts @@ SM.kill id + let sentence_to_death id = update @@ fun s -> { + s with deathrow = id :: s.deathrow + } - let execute_sentenced = - with_global_context (fun () -> - lifts @@ SM.get () >>= fun s -> - CM.List.iter s.deathrow ~f:do_kill >>= fun () -> - lifts @@ SM.put {s with deathrow = []}) + let drop id s = { + s with + conts = Map.remove s.conts id; + local = Map.remove s.local id; + } + + let do_kill id = update (drop id) + + let execute_sentenced = update @@ fun s -> + Base.List.fold s.deathrow ~init:{s with deathrow=[]} ~f:(fun s id -> + drop id s) - let switch id : unit c = + let switch id : unit t = is_restricted () >>= function | true -> failwith "switch in the restricted mode" | false -> - C.call ~f:(fun ~cc:k -> + call ~f:(fun ~cc:k -> current () >>= fun pid -> store_curr k >>= fun () -> switch_state id >>= fun () -> - lifts (SM.get ()) >>= fun s -> + get_curr >>= fun k -> execute_sentenced >>= fun () -> Observation.make switched (pid,id) >>= fun () -> - s.curr ()) + k ()) - let fork () : unit c = + let fork () : unit t = is_restricted () >>= function | true -> failwith "fork in the restricted mode" | false -> - C.call ~f:(fun ~cc:k -> + call ~f:(fun ~cc:k -> current () >>= fun pid -> store_curr k >>= fork_state >>= fun () -> @@ -332,34 +388,60 @@ module Make(M : Monad.S) = struct restrict true >>= fun () -> Observation.make_even_if_restricted stopped sys + let fail exn = { + run = fun ~reject ~accept:_ s -> reject exn s + } + + let catch x except = { + run = fun ~reject ~accept s -> + x.run s + ~accept + ~reject:(fun exn s -> + (except exn).run ~reject ~accept s) + } + + let lift : 'a m -> 'a t = fun x -> { + run = fun ~reject:_ ~accept s -> + M.bind x ~f:(fun x -> accept x s) + } + let raise exn = Observation.make raise_exn exn >>= fun () -> fail exn let catch = catch + let args = gets @@ fun s -> s.args + let envp = gets @@ fun s -> s.envp + + let put proj = update @@ fun s -> {s with proj} + let get () = gets @@ fun s -> s.proj + let project : project t = get () + let gets f = get () >>| f + let update f = get () >>= fun s -> put (f s) + let modify m f = m >>= fun x -> update f >>= fun () -> return x + let project = get () let program = project >>| Project.program let arch = project >>| Project.arch - let args = lifts (SM.gets @@ fun s -> s.args) - let envp = lifts (SM.gets @@ fun s -> s.envp) let init_state proj ~args ~envp = { args; envp; - curr = return; + self = global; + conts = Map.empty (module Id); + local = Map.empty (module Id); + parent = Map.empty (module Id); global = State.Bag.empty; - local = State.Bag.empty; observations = Bap_primus_observation.empty; deathrow = []; proj; restricted = true; } - let extract f = - let open SM.Syntax in - SM.switch SM.global >>= fun () -> - SM.gets f + let exec m s = m.run s + ~accept:(fun _ s -> M.return (Normal, s.proj)) + ~reject:(fun exn s -> M.return (Exn exn,s.proj)) let run : 'a t -> 'a e = fun user @@ -383,22 +465,14 @@ module Make(M : Monad.S) = struct fini >>= fun () -> stop sys >>= fun () -> return x in - M.bind - (SM.run - (C.run machine (function - | Ok _ -> extract @@ fun s -> Ok s.proj - | Error err -> extract @@ fun _ -> Error err)) - (init_state proj ~args ~envp)) - (fun (r,{proj}) -> match r with - | Ok _ -> M.return (Normal, proj) - | Error e -> M.return (Exn e, proj)) + exec machine (init_state proj ~args ~envp) module Syntax = struct - include CM.Syntax + include Machine.Syntax let (>>>) = Observation.observe end - include (CM : Monad.S with type 'a t := 'a t - and module Syntax := Syntax) + include (Machine : Monad.S with type 'a t := 'a t + and module Syntax := Syntax) end diff --git a/lib/knowledge/bap_knowledge.ml b/lib/knowledge/bap_knowledge.ml index b4af75207..60178844d 100644 --- a/lib/knowledge/bap_knowledge.ml +++ b/lib/knowledge/bap_knowledge.ml @@ -1248,7 +1248,7 @@ module Dict = struct type 'r visitor = { visit : 'a. 'a key -> 'a -> 'r -> 'r; - } [@@unboxed] + } let rec foreach x ~init f = match x with | T0 -> init @@ -1274,7 +1274,7 @@ module Dict = struct type ('b,'r) app = { app : 'a. 'a key -> 'a -> 'b -> 'r - } [@@unboxed] + } let cmp x y = Key.compare x y [@@inline] let eq x y = Key.compare x y = 0 [@@inline] @@ -1517,7 +1517,7 @@ module Dict = struct (* [merge k x y] *) type merge = { merge : 'a. 'a key -> 'a -> 'a -> 'a - } [@@unboxed] + } (* we could use a GADT, but it couldn't be unboxed, an since we could create merge functions a lot, it is better not to allocate them.*) @@ -2030,17 +2030,42 @@ module Knowledge = struct context = Dict.empty; } - module State = struct - include Monad.State.T1(Env)(Monad.Ident) - include Monad.State.Make(Env)(Monad.Ident) - end + + type 'a knowledge = { + run : 'r. reject:(conflict -> 'r) -> accept:('a -> state -> 'r) -> state -> 'r + } module Knowledge = struct - type 'a t = ('a,conflict) Result.t State.t - include Monad.Result.Make(Conflict)(State) - end + type 'a t = 'a knowledge + type _ error = conflict + let fail p : 'a t = {run = fun ~reject ~accept:_ _ -> reject p} + let catch x err = { + run = fun ~reject ~accept s -> x.run s + ~accept + ~reject:(fun p -> (err p).run ~reject ~accept s) + } - type 'a knowledge = 'a Knowledge.t + include Monad.Make(struct + type 'a t = 'a knowledge + let return x : 'a t = { + run = fun ~reject:_ ~accept s -> accept x s + } + + let bind : 'a t -> ('a -> 'b t) -> 'b t = fun x f -> { + run = fun ~reject ~accept s -> x.run s + ~reject + ~accept:(fun x s -> + (f x).run ~reject ~accept s) + } + + let map : 'a t -> f:('a -> 'b) -> 'b t = fun x ~f -> { + run = fun ~reject ~accept s -> x.run s + ~reject + ~accept:(fun x s -> accept (f x) s) + } + let map = `Custom map + end) + end open Knowledge.Syntax @@ -2206,10 +2231,21 @@ module Knowledge = struct end end - let get () = Knowledge.lift (State.get ()) - let put s = Knowledge.lift (State.put s) - let gets f = Knowledge.lift (State.gets f) - let update f = Knowledge.lift (State.update f) + let get () : state knowledge = { + run = fun ~reject:_ ~accept s -> accept s s + } [@@inline] + + let put s = { + run = fun ~reject:_ ~accept _ -> accept () s + } [@@inline] + + let gets f = { + run = fun ~reject:_ ~accept s -> accept (f s) s + } [@@inline] + + let update f = { + run = fun ~reject:_ ~accept s -> accept () (f s) + } [@@inline] let objects {Class.name} = get () >>| fun {classes} -> @@ -2823,19 +2859,42 @@ module Knowledge = struct | None -> Knowledge.fail (Empty p) | Some x -> !!x - let (>>=?) x f = - x >>= function - | None -> Knowledge.return None - | Some x -> f x - - let (>>|?) x f = - x >>| function - | None -> None - | Some x -> f x + let (>>=?) x f = { + run = fun ~reject ~accept s -> + x.run s + ~reject + ~accept:(fun x s -> + match x with + | None -> accept None s + | Some x -> + (f x).run ~accept ~reject s) + } [@@inline] + + let (>>|?) x f = { + run = fun ~reject ~accept s -> + x.run s + ~reject + ~accept:(fun x s -> match x with + | None -> accept None s + | Some x -> accept (f x) s) + } [@@inline] let (let*?) = (>>=?) let (let+?) = (>>|?) + let (and+) x y = { + run = fun ~reject ~accept s -> + x.run s + ~reject + ~accept:(fun x s -> + y.run s + ~reject + ~accept:(fun y s -> + accept (x,y) s)) + } [@@inline] + + let (and*) = (and+) + let (.$[]) v s = Value.get s v let (.$[]<-) v s x = Value.put s v x @@ -2876,11 +2935,9 @@ module Knowledge = struct | None -> Value.empty cls | Some x -> Value.create cls x - let run cls obj s = - match State.run (obj >>= get_value cls) s with - | Ok x,s -> Ok (x,s) - | Error err,_ -> Error err - + let run cls obj s = (obj >>= get_value cls).run s + ~reject:(fun err -> Error err) + ~accept:(fun x s -> Ok (x,s)) let pp_fullname ~package ppf {package=p; name} = if String.equal package p diff --git a/lib/monads/monads.mli b/lib/monads/monads.mli index 52088771c..a5ebc7504 100644 --- a/lib/monads/monads.mli +++ b/lib/monads/monads.mli @@ -1890,8 +1890,22 @@ module Std : sig type 'a contexts type id + + (** Fork Identifiers. *) module Id : sig include Identifiable with type t = id + + + (** [zero] is the initial identifier. + + @since 2.4.0 *) + val zero : t + + + (** [succ id] the successor of [id]. + + @since 2.4.0 *) + val succ : t -> t val pp : Format.formatter -> id -> unit end diff --git a/lib/monads/monads_monad.mli b/lib/monads/monads_monad.mli index d28b7a2fa..6cf4bf828 100644 --- a/lib/monads/monads_monad.mli +++ b/lib/monads/monads_monad.mli @@ -311,6 +311,8 @@ module State : sig type id module Id : sig include Identifiable.S with type t = id + val zero : t + val succ : t -> t val pp : Caml.Format.formatter -> id -> unit end