diff --git a/pytact/data_reader.pyx b/pytact/data_reader.pyx index c5067d0..a14df37 100644 --- a/pytact/data_reader.pyx +++ b/pytact/data_reader.pyx @@ -1360,6 +1360,10 @@ cdef class OnlineDefinitionsReader: """ return Definition._group_by_clusters(self.definitions(full)) + def node_by_id(self, nodeid: NodeId) -> Node: + """Lookup a node inside of this reader by it's local node-id. This is a low-level function.""" + return Node.init(self.graph_index.nodes.size() - 1, nodeid, &self.graph_index) + @contextmanager def online_definitions_initialize(OnlineDefinitionsReader stack, GlobalContextAddition_Reader init) -> Generator[OnlineDefinitionsReader, None, None]: diff --git a/pytact/fake_python_server.py b/pytact/fake_python_server.py index 95cdf85..9f768df 100644 --- a/pytact/fake_python_server.py +++ b/pytact/fake_python_server.py @@ -10,6 +10,7 @@ TacticPredictionGraph, TacticPredictionsGraph, TacticPredictionText, TacticPredictionsText, GlobalContextMessage, CheckAlignmentMessage, CheckAlignmentResponse) +from pytact.visualisation_webserver import wrap_visualization async def text_prediction_loop(context : GlobalContextMessage): tactics = [ 'idtac "is it working?"', 'idtac "yes it is working!"', 'auto' ] @@ -69,8 +70,11 @@ async def graph_prediction_loop(context : GlobalContextMessage, level): raise Exception(f"Capnp protocol error {msg}") + async def run_session(args, record_file, capnp_stream): messages_generator = capnp_message_generator(capnp_stream, args.rpc, record_file) + if args.with_visualization: + messages_generator = await wrap_visualization(messages_generator) if args.mode == 'text': print('Python server running in text mode') await text_prediction_loop(messages_generator) @@ -102,6 +106,8 @@ async def server(): 'replayed through "pytact-fake-coq"') parser.add_argument('--rpc', action='store_true', default = False, help='Communicate through Cap\'n Proto RPC.') + parser.add_argument('--with-visualization', action='store_true', default = False, + help='Launch a visualization webserver') args = parser.parse_args() if args.record_file is not None: diff --git a/pytact/graph_visualize_browse.py b/pytact/graph_visualize_browse.py index 8a4cd4a..6174f71 100644 --- a/pytact/graph_visualize_browse.py +++ b/pytact/graph_visualize_browse.py @@ -62,9 +62,14 @@ class GraphVisualizationData: graphid2path: List[Path] = field(init=False) def __post_init__(self): - self.trans_deps = transitive_closure({d.filename: list(d.dependencies) - for d in self.data.values()}) - self.graphid2path = [d.filename for d in sorted(self.data.values(), key=lambda d: d.graph)] + if len(self.data.values()) == 0: return + if hasattr(list(self.data.values())[0], "dependencies"): + self.trans_deps = transitive_closure({d.filename: list(d.dependencies) + for d in self.data.values()}) + self.graphid2path = [d.filename for d in sorted(self.data.values(), key=lambda d: d.graph)] + else: + self.trans_deps = { p : set() for p in self.data.keys()} + self.graphid2path = list(self.data.keys()) @dataclass class GraphVisualizationOutput: @@ -127,6 +132,12 @@ def render_proof_state_text(ps: ProofState): '
----------------------
' + ps.conclusion_text + '

Raw: ' + ps.text) +def mn(dataset): + if hasattr(dataset, "module_name"): + return dataset.module_name + else: + return "" + class GraphVisualizator: def __init__(self, data: GraphVisualizationData, url_maker: UrlMaker, settings: Settings = Settings()): self.data = data.data @@ -191,7 +202,7 @@ def global_context(self, fname: Path): dataset = self.data[fname] representative = dataset.representative - module_name = dataset.module_name + module_name = mn(dataset) def render_def(dot2, d: Definition): label = make_label(module_name, d.name) @@ -221,7 +232,7 @@ def render_def(dot2, d: Definition): dot.edge(id, id2, arrowtail="odot", dir="both", constraint="false", style="dashed") - for cluster in dataset.clustered_definitions(): + for cluster in dataset.clustered_definitions(full=False): start = str(cluster[0].node) ltail = None @@ -349,7 +360,7 @@ def definition(self, fname: Path, definition: int): proof = [("Proof", self.url_maker.proof(fname, definition))] ext_location = ( location + - [(make_label(self.data[fname].module_name, label), + [(make_label(mn(self.data[fname]), label), self.url_maker.definition(fname, definition))] + proof) return GraphVisualizationOutput(dot.source, ext_location, len(location), text) @@ -405,7 +416,7 @@ def proof(self, fname: Path, definition: int): dot.edge(before_id, qedid) location = (self.path2location(fname) + - [(make_label(self.data[fname].module_name, d.name), + [(make_label(mn(self.data[fname]), d.name), self.url_maker.definition(fname, definition)), ("Proof", self.url_maker.proof(fname, definition))]) return GraphVisualizationOutput(dot.source, location, len(location) - 1) @@ -508,7 +519,7 @@ def nlm(node: Node): dot2.edge('artificial-root', id) location = (self.path2location(fname) + - [(make_label(self.data[fname].module_name, d.name), + [(make_label(mn(self.data[fname]), d.name), self.url_maker.definition(fname, definition)), ("Proof", self.url_maker.proof(fname, definition)), (f"Step {stepi} outcome {outcomei}", diff --git a/pytact/visualisation_webserver.py b/pytact/visualisation_webserver.py index 1755a21..61fe73c 100644 --- a/pytact/visualisation_webserver.py +++ b/pytact/visualisation_webserver.py @@ -10,7 +10,7 @@ except ImportError: import importlib.resources as ilr -from pytact.data_reader import data_reader +from pytact.data_reader import data_reader, GlobalContextMessage, ProofState, CheckAlignmentMessage from pytact.graph_visualize_browse import ( GraphVisualizationData, GraphVisualizator, UrlMaker, Settings, GraphVisualizationOutput) @@ -36,7 +36,10 @@ def create_app(dataset_path: Path) -> Sanic: context_manager = ExitStack() template_path = ilr.files('pytact') / 'templates/' app.config.TEMPLATING_PATH_TO_TEMPLATES = context_manager.enter_context(ilr.as_file(template_path)) - app.ctx.gvd = GraphVisualizationData(context_manager.enter_context(data_reader(dataset_path))) + if isinstance(dataset_path, Path): + app.ctx.gvd = GraphVisualizationData(context_manager.enter_context(data_reader(dataset_path))) + else: + app.ctx.gvd = dataset_path @app.after_server_stop async def teardown(app): @@ -110,6 +113,48 @@ async def root_folder(request, query: Settings): return app + +async def wrap_visualization(context : GlobalContextMessage) -> GlobalContextMessage: + app = create_app(GraphVisualizationData(dict())) + + server = await app.create_server( + port=8000, host="0.0.0.0", return_asyncio_server=True + ) + + await server.startup() + await server.start_serving() + + async def wrapper(context, stack): + data = { Path(f"Slice{i}.bin") : d for i, d in enumerate(stack)} + app.ctx.gvd = GraphVisualizationData(data) + prediction_requests = context.prediction_requests + async for msg in prediction_requests: + # Redirect any exceptions to Coq. Additionally, deal with CancellationError + # thrown when a request from Coq is cancelled + async with context.redirect_exceptions(Exception): + if isinstance(msg, ProofState): + resp = yield msg + yield + await prediction_requests.asend(resp) + elif isinstance(msg, CheckAlignmentMessage): + resp = yield msg + yield + await prediction_requests.asend(resp) + elif isinstance(msg, GlobalContextMessage): + yield GlobalContextMessage(msg.definitions, + msg.tactics, + msg.log_annotation, + wrapper(msg, stack + [msg.definitions]), + msg.redirect_exceptions) + else: + raise Exception(f"Capnp protocol error {msg}") + + return GlobalContextMessage(context.definitions, + context.tactics, + context.log_annotation, + wrapper(context, []), + context.redirect_exceptions) + def main(): parser = argparse.ArgumentParser( diff --git a/src/neural_learner.ml b/src/neural_learner.ml index 001084a..6c712f7 100644 --- a/src/neural_learner.ml +++ b/src/neural_learner.ml @@ -50,8 +50,6 @@ let rpc_option = declare_bool_option ~name:"RPC" ~default:false open Stdint module TacticMap = Map.Make(struct type t = int64 let compare = Stdint.Int64.compare end) -let last_model = Summary.ref ~name:"neural-learner-lastmodel" TacticMap.empty - type location = int module Hashable = struct type t = location @@ -234,7 +232,7 @@ type context_stack = { stack : context_state list ; stack_size : int } -let update_context_stack id tacs env { stack_size; stack } = +let update_context_stack id tacs env_extra env { stack_size; stack } = let state, old_constants, old_inducives, old_section = match stack with | [] -> let (empty_state, ()), _ = CICGraphMonad.run_empty (CICGraphMonad.return ()) @@ -280,7 +278,6 @@ let update_context_stack id tacs env { stack_size; stack } = let open Monad.Make(CICGraphMonad) in let open GB in - let env_extra = Id.Map.empty, Cmap.empty in let updater = let* () = Cset.fold (fun c acc -> acc >> let+ _ = gen_const env env_extra c in ()) new_constants (return ()) in @@ -337,13 +334,13 @@ let sync_context_stack add_global_context = let id = ref 0 in let remote_state = ref [] in let remote_stack_size = ref 0 in - fun ?(keep_cache=true) tacs env -> + fun ?(keep_cache=true) tacs env_extra env -> if debug_option () then Feedback.msg_notice Pp.( str "old remote stack : " ++ prlist_with_sep (fun () -> str "-") int !remote_state ++ fnl () ++ str "old local stack : " ++ prlist_with_sep (fun () -> str "-") (fun { id; _ } -> int id) !context_stack.stack); - let state, ({ stack_size; stack } as cache) = update_context_stack !id tacs env !context_stack in + let state, ({ stack_size; stack } as cache) = update_context_stack !id tacs env_extra env !context_stack in if keep_cache then context_stack := cache; if debug_option () then @@ -414,12 +411,13 @@ let connect_socket my_socket = type connection = { capnp_connection : capnp_connection - ; sync_context_stack : ?keep_cache:bool -> (glob_tactic_expr * location) TacticMap.t -> Environ.env -> + ; sync_context_stack : ?keep_cache:bool -> (glob_tactic_expr * location) TacticMap.t -> + env_extra -> Environ.env -> CICGraphMonad.state * int } type communicator = { add_global_context : (Api.Builder.GlobalContextAddition.t -> unit) -> unit - ; sync_context_stack : ?keep_cache:bool -> (glob_tactic_expr * location) TacticMap.t -> Environ.env -> + ; sync_context_stack : ?keep_cache:bool -> (glob_tactic_expr * location) TacticMap.t -> env_extra -> Environ.env -> CICGraphMonad.state * int ; request_prediction : (Api.Builder.PredictionRequest.t -> unit) -> (Graph_api.ro, Api.Reader.Prediction.t, Api.Reader.array_t) Capnp.Array.t @@ -597,71 +595,15 @@ let get_communicator = | Some comm -> comm -let check_neural_alignment () = - let { sync_context_stack; check_alignment; _ } = get_communicator () in - let module Request = Api.Builder.PredictionProtocol.Request in - let module Response = Api.Reader.PredictionProtocol.Response in - let env = Global.env () in - let tacs = !last_model in - let state, stack_size = sync_context_stack ~keep_cache:false tacs env in - let request = Request.init_root () in - Request.check_alignment_set request; - let unaligned_tacs, unaligned_defs = check_alignment () in - let find_global_argument = find_global_argument state in - let unaligned_tacs = Capnp.Array.map_list ~f:(fun t -> fst @@ find_tactic tacs t) unaligned_tacs in - let unaligned_defs = Capnp.Array.map_list ~f:(fun node -> - let sid = stack_size - 1 - Api.Reader.Node.dep_index_get node in - let nid = Api.Reader.Node.node_index_get_int_exn node in - find_global_argument (sid, nid)) unaligned_defs in - let tacs_msg = if CList.is_empty unaligned_tacs then Pp.mt () else - Pp.(fnl () ++ str "Unaligned tactics: " ++ fnl () ++ - prlist_with_sep fnl (Pptactic.pr_glob_tactic env) unaligned_tacs) in - let defs_msg = - let open Tactic_one_variable in - if CList.is_empty unaligned_defs then Pp.mt () else - Pp.(fnl () ++ str "Unaligned definitions: " ++ fnl () ++ - prlist_with_sep fnl - (function - | TVar id -> Id.print id - | TRef r -> Libnames.pr_path @@ Nametab.path_of_global r - | TOther -> Pp.mt ()) - unaligned_defs) in - let def_count = Id.Map.cardinal state.section_nodes + - Cmap.cardinal state.definition_nodes.constants + - Indmap.cardinal state.definition_nodes.inductives + - Constrmap.cardinal state.definition_nodes.constructors + - ProjMap.cardinal state.definition_nodes.projections in - Feedback.msg_notice Pp.( - str "There are " ++ int (List.length unaligned_tacs) ++ str "/" ++ - int (TacticMap.cardinal tacs) ++ str " unaligned tactics and " ++ - int (List.length unaligned_defs) ++ str "/" ++ - int def_count ++ str " unaligned definitions." ++ - tacs_msg ++ defs_msg - ) - -let push_cache () = - if textmode_option () then () (* No caching needed for the text model at the moment *) else - let { sync_context_stack; _ } = get_communicator () in - (* We don't send the list of tactics, hence the empty list. Tactics are only sent right before - prediction requests are made. *) - let _, stack_size = sync_context_stack TacticMap.empty (Global.env ()) in - if debug_option () then - Feedback.msg_notice Pp.(str "Cache stack size: " ++ int stack_size) - -(* TODO: Hack: Options have the property that they are being read by Coq's stm (multiple times) on every - vernac command. Hence, we can use it to execute arbitrary code. We use to automatically cache. *) -let autocache_option = - let cache = ref false in - Goptions.{ optdepr = false - ; optname = "Tactician Neural Autocache" - ; optkey = ["Tactician"; "Neural"; "Autocache"] - ; optread = (fun () -> (if !cache then push_cache () else ()); !cache) - ; optwrite = (fun v -> cache := v) } -let () = Goptions.declare_bool_option autocache_option +let check_neural_alignment_ref = ref (fun () -> ()) +let check_neural_alignment () = !check_neural_alignment_ref () +let push_cache_ref = ref (fun () -> ()) +let push_cache () = !push_cache_ref () module NeuralLearner : TacticianOnlineLearnerType = functor (TS : TacticianStructures) -> struct module LH = Learner_helper.L(TS) open TS + open Graph_generator_learner.ConvertStructures(TS) let predict_text request_text_prediction env ps = let module Tactic = Api.Reader.Tactic in @@ -767,10 +709,13 @@ module NeuralLearner : TacticianOnlineLearnerType = functor (TS : TacticianStruc preds type model = - { tactics : (glob_tactic_expr * int) TacticMap.t } + { tactics : (glob_tactic_expr * int) TacticMap.t + ; proofs : (((proof_state * tactic_result) list * tactic option) list * data_status * KerName.t) list } + + let last_model = Summary.ref ~name:"neural-learner-lastmodel" { tactics = TacticMap.empty; proofs = [] } let empty () = - { tactics = TacticMap.empty } + { tactics = TacticMap.empty; proofs = [] } let add_tactic_info env map tac = let tac = Tactic_normalize.tactic_normalize @@ Tactic_normalize.tactic_strict tac in @@ -782,22 +727,64 @@ module NeuralLearner : TacticianOnlineLearnerType = functor (TS : TacticianStruc TacticMap.add (Tactic_hash.tactic_hash env base_tactic) (base_tactic, params) map - let learn { tactics } _origin _outcomes tac = - match tac with - | None -> { tactics } + let learn { tactics; proofs } (kn, path, status) outcomes tac = + let tactics = match tac with + | None -> tactics | Some tac -> let tac = tactic_repr tac in let tactics = add_tactic_info (Global.env ()) tactics tac in - last_model := tactics; - { tactics } - - let predict { tactics } = + tactics in + let proofs = + let proof_states = List.map (fun x -> + x.before, x.result) outcomes in + match proofs with + | (ls, pstatus, pkn)::data when KerName.equal kn pkn -> + ((proof_states, tac)::ls, pstatus, pkn)::data + | _ -> ([proof_states, tac], status, kn)::proofs in + let db = { tactics; proofs } in + last_model := db; + db + + let env_extra proofs = + let globrefs = Environ.Globals.view (Global.env ()).env_globals in + let section_vars = Id.Set.of_list @@ + List.map Context.Named.Declaration.get_id @@ Environ.named_context @@ Global.env () in + let constants = globrefs.constants in + (* We are only interested in canonical constants *) + let constants = Cmap_env.fold (fun c _ m -> + let c = Constant.make1 @@ Constant.canonical c in + Cset.add c m) constants Cset.empty in + + + let proof_states = List.map (fun (prf, status, c) -> List.rev prf, status, c) @@ List.rev proofs in + + let env_extra_const = Cset.fold (fun c m -> + let path = Constant.canonical c in + let proof = List.find_opt (fun (p, _, path2) -> KerName.equal path path2) proof_states in + let proof = Option.map (fun (p, _, _) -> p) proof in + let proof = Option.map (List.map (fun (pss, tac) -> + List.map (fun (before, result) -> mk_outcome before result) pss, + Option.map tactic_repr tac)) proof in + Option.fold_left (fun m proof -> Cmap.add c proof m) m proof + ) constants Cmap.empty in + let env_extra_var = Id.Set.fold (fun id m -> + let proof = List.find_opt (fun (_, _, path) -> Id.equal id @@ Label.to_id @@ KerName.label path) proof_states in + let proof = Option.map (fun (p, _, _) -> p) proof in + let proof = Option.map (List.map (fun (pss, tac) -> + List.map (fun (before, result) -> mk_outcome before result) pss, + Option.map tactic_repr tac)) proof in + Option.fold_left (fun m proof -> Id.Map.add id proof m) m proof + ) section_vars Id.Map.empty in + env_extra_var, env_extra_const + + let predict { tactics; proofs } = let { add_global_context; sync_context_stack ; request_prediction; request_text_prediction; _ } = get_communicator () in let env = Global.env () in if not @@ textmode_option () then + let env_extra = env_extra proofs in let state, stack_size = - sync_context_stack ~keep_cache:false tactics env in + sync_context_stack ~keep_cache:false tactics env_extra env in let find_global_argument = find_global_argument state in fun f -> if f = [] then IStream.empty else @@ -815,6 +802,74 @@ module NeuralLearner : TacticianOnlineLearnerType = functor (TS : TacticianStruc IStream.of_list preds let evaluate db _ _ = 0., db + + let check_neural_alignment () = + let { sync_context_stack; check_alignment; _ } = get_communicator () in + let module Request = Api.Builder.PredictionProtocol.Request in + let module Response = Api.Reader.PredictionProtocol.Response in + let env = Global.env () in + let { tactics; proofs } = !last_model in + let env_extra = env_extra (!last_model).proofs in + let state, stack_size = sync_context_stack ~keep_cache:false tactics env_extra env in + let request = Request.init_root () in + Request.check_alignment_set request; + let unaligned_tacs, unaligned_defs = check_alignment () in + let find_global_argument = find_global_argument state in + let unaligned_tacs = Capnp.Array.map_list ~f:(fun t -> fst @@ find_tactic tactics t) unaligned_tacs in + let unaligned_defs = Capnp.Array.map_list ~f:(fun node -> + let sid = stack_size - 1 - Api.Reader.Node.dep_index_get node in + let nid = Api.Reader.Node.node_index_get_int_exn node in + find_global_argument (sid, nid)) unaligned_defs in + let tacs_msg = if CList.is_empty unaligned_tacs then Pp.mt () else + Pp.(fnl () ++ str "Unaligned tactics: " ++ fnl () ++ + prlist_with_sep fnl (Pptactic.pr_glob_tactic env) unaligned_tacs) in + let defs_msg = + let open Tactic_one_variable in + if CList.is_empty unaligned_defs then Pp.mt () else + Pp.(fnl () ++ str "Unaligned definitions: " ++ fnl () ++ + prlist_with_sep fnl + (function + | TVar id -> Id.print id + | TRef r -> Libnames.pr_path @@ Nametab.path_of_global r + | TOther -> Pp.mt ()) + unaligned_defs) in + let def_count = Id.Map.cardinal state.section_nodes + + Cmap.cardinal state.definition_nodes.constants + + Indmap.cardinal state.definition_nodes.inductives + + Constrmap.cardinal state.definition_nodes.constructors + + ProjMap.cardinal state.definition_nodes.projections in + Feedback.msg_notice Pp.( + str "There are " ++ int (List.length unaligned_tacs) ++ str "/" ++ + int (TacticMap.cardinal tactics) ++ str " unaligned tactics and " ++ + int (List.length unaligned_defs) ++ str "/" ++ + int def_count ++ str " unaligned definitions." ++ + tacs_msg ++ defs_msg + ) + + let push_cache () = + if textmode_option () then () (* No caching needed for the text model at the moment *) else + let { sync_context_stack; _ } = get_communicator () in + (* We don't send the list of tactics, hence the empty list. Tactics are only sent right before + prediction requests are made. *) + let env_extra = env_extra (!last_model).proofs in + let _, stack_size = sync_context_stack TacticMap.empty env_extra (Global.env ()) in + if debug_option () then + Feedback.msg_notice Pp.(str "Cache stack size: " ++ int stack_size) + + let () = check_neural_alignment_ref := check_neural_alignment + let () = push_cache_ref := push_cache + + (* TODO: Hack: Options have the property that they are being read by Coq's stm (multiple times) on every + vernac command. Hence, we can use it to execute arbitrary code. We use to automatically cache. *) + let autocache_option = + let cache = ref false in + Goptions.{ optdepr = false + ; optname = "Tactician Neural Autocache" + ; optkey = ["Tactician"; "Neural"; "Autocache"] + ; optread = (fun () -> (if !cache then push_cache () else ()); !cache) + ; optwrite = (fun v -> cache := v) } + let () = Goptions.declare_bool_option autocache_option + end let () = register_online_learner "Neural Learner" (module NeuralLearner)