Skip to content

Commit

Permalink
various small code improvements
Browse files Browse the repository at this point in the history
Signed-off-by: Rudi Horn <dyn-git@rudi-horn.de>
  • Loading branch information
Rudi Horn authored and Rudi Horn committed Nov 13, 2018
1 parent 2497757 commit 46361b2
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 47 deletions.
79 changes: 37 additions & 42 deletions src/memoization.ml
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
open !Stdune
open Import
(* open Staged *)
open Fiber.O

(* type used for names of memoized functions *)
Expand Down Expand Up @@ -72,16 +71,11 @@ module Output = struct
type packed_output = Packed : _ output_cache -> packed_output

let unpack : type a. a Kind.t -> packed_output -> a run_state =
fun kind packed ->
match packed with
| Packed output ->
let Type_eq.T = Kind.eq (output.kind) kind |> Option.value_exn in
output.state

let pack kind v = Packed {
state = v;
kind = kind;
}
fun kind (Packed output) ->
let Type_eq.T = Kind.eq (output.kind) kind |> Option.value_exn in
output.state

let pack kind state = Packed { state; kind; }
end

type global_cache_info = {
Expand Down Expand Up @@ -125,25 +119,30 @@ module Cycle_error = struct

let filter ~(name:name) ex =
ex.cycle
|> List.filter ~f:(fun (f : dep_node) -> (Dag.get f).name = name)
|> List.map ~f:(fun (f : dep_node) -> (Dag.get f).input)
|> List.filter_map ~f:(fun f ->
let n = Dag.get f in
Option.some_if (n.name = name) n.input)

let string_of_dep_info di =
Format.sprintf "%s %s" di.name di.input

let serialize ex =
ex.cycle
|> List.map ~f:Dag.get
|> List.map ~f:(fun (f : dep_info) -> Format.sprintf "%s %s" f.name f.input)
|> List.map ~f:(fun n -> Dag.get n |> string_of_dep_info)

let serialize_stack ex =
ex.stack
|> List.map ~f:(fun v -> Stack_frame.di v)
|> List.map ~f:(fun (f : dep_info) -> Format.sprintf "%s %s" f.name f.input)
|> List.map ~f:(fun v -> Stack_frame.di v |> string_of_dep_info)
end

module Function_name = Interned.Make(struct
let initial_size = 1024
let resize_policy = Interned.Greedy
let order = Interned.Fast
end) ()

let global_dep_dag = Dag.create ()
let global_dep_table : (FnId.t * ser_input, dep_node) Hashtbl.t = Hashtbl.create 256

let global_id_table : (name, FnId.t) Hashtbl.t = Hashtbl.create 256
let global_dep_table : (Function_name.t * ser_input, dep_node) Hashtbl.t = Hashtbl.create 1024

module Fdecl = struct
type ('a, 'b) t = 'a -> 'b Fiber.t
Expand Down Expand Up @@ -172,12 +171,13 @@ end
module Memoize = struct
(* fiber context variable keys *)
let call_stack_key = Fiber.Var.create ()
let get_call_stack_int =
let get_call_stack =
Fiber.Var.get call_stack_key (* get call stack *)
>>| Option.value ~default:([], Id.Set.empty) (* default call stack is empty *)
>>| Option.value ~default:[] (* default call stack is empty *)

let get_call_stack =
get_call_stack_int >>| (fun (stack,_) -> stack)
let add_stack_entry frame stack = frame :: stack

let update_stack f stack = Fiber.Var.set call_stack_key stack f

let run_id = Fiber.Var.create ()
let run_memoize fiber =
Expand All @@ -188,20 +188,23 @@ module Memoize = struct

(* set up the fiber so that it both has an up to date call stack
as well as an empty dependency table *)
let wrap_fiber (stack_frame : Stack_frame.t) id (f : 'a Fiber.t) =
let wrap_fiber (stack_frame : Stack_frame.t) (f : 'a Fiber.t) =
(* set the context so that f has the call stack *)
get_call_stack_int
>>| (fun (stack, set) -> stack_frame :: stack, Id.Set.add set id) (* add top entry *)
>>= (fun stack -> Fiber.Var.set call_stack_key stack f) (* update *)
get_call_stack >>| add_stack_entry stack_frame >>= update_stack f

let dump_stack v =
get_call_stack
>>|
(Printf.printf "Memoized function stack:\n";
List.iter ~f:(fun st -> Printf.printf " %s %s\n" (Stack_frame.name st) (Stack_frame.input st)))
List.iter ~f:(
fun st -> Printf.printf " %s %s\n"
(Stack_frame.name st)
(Stack_frame.input st)
)
)
>>| (fun _ -> v)

let get_dependency_node (fn_id : FnId.t) (name : name) (inp : ser_input) : dep_node =
let get_dependency_node (fn_id : Function_name.t) (name : name) (inp : ser_input) : dep_node =
Hashtbl.find global_dep_table (fn_id, inp)
|> function
| None ->
Expand Down Expand Up @@ -236,11 +239,9 @@ module Memoize = struct
} |> raise

let get_name_id (name : name) =
match Hashtbl.find global_id_table name with
match Function_name.get name with
| None ->
let id = FnId.gen () in
Hashtbl.replace global_id_table ~key:name ~data:id;
id
Function_name.make name
| Some id -> id

let get_deps (name : name) (inp : ser_input) =
Expand All @@ -250,11 +251,6 @@ module Memoize = struct
let node = Dag.get n in
node.name,node.input)) (Dag.get dep_info).cache

let rec list_any l =
match l with
| [] -> false
| x :: xs -> if x then true else list_any xs

let dependencies_updated rinfo =
get_run_id
>>= fun run_id ->
Expand All @@ -280,7 +276,7 @@ module Memoize = struct
)
(* if any of the dependencies has changed we need to update *)
>>| (fun dat ->
let rerun = list_any dat in
let rerun = List.exists ~f:(fun x -> x) dat in
if not rerun then
rinfo.last_run <- run_id;
rerun
Expand Down Expand Up @@ -318,15 +314,14 @@ module Memoize = struct

(* the computation that force computes the fiber *)
let recompute inp dep_node comp updatefn =
let di : dep_info = Dag.get dep_node in
(* create an ivar so other threads can wait for the computation to finish *)
let ivar : 'b Fiber.Ivar.t = Fiber.Ivar.create () in
(* create an output cache entry with our ivar *)
set_running_output_cache kind dep_node ivar
(* define the function to update / double check intermediate result *)
(* set context of computation then run it *)
>>= (fun _ ->
comp inp |> wrap_fiber { dep_node } di.id)
comp inp |> wrap_fiber { dep_node })
(* update the output cache with the correct value *)
>>= update_caches kind dep_node out_spec updatefn
(* fill the ivar for any waiting threads *)
Expand Down
3 changes: 0 additions & 3 deletions src/stdune/list.ml
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,6 @@ let rec find l ~f =
| [] -> None
| x :: l -> if f x then Some x else find l ~f

let exists l ~f =
find l ~f |> Option.is_some

let find_exn l ~f =
match find l ~f with
| Some x -> x
Expand Down
2 changes: 0 additions & 2 deletions src/stdune/list.mli
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ val find : 'a t -> f:('a -> bool ) -> 'a option
val find_exn : 'a t -> f:('a -> bool ) -> 'a
val find_map : 'a t -> f:('a -> 'b option) -> 'b option

val exists : 'a t -> f:('a -> bool ) -> bool

val last : 'a t -> 'a option

val sort : 'a t -> compare:('a -> 'a -> Ordering.t) -> 'a t
Expand Down

0 comments on commit 46361b2

Please sign in to comment.