From b308728bf10f017656a300aa9af8d582af2dab62 Mon Sep 17 00:00:00 2001 From: Lasse Blaauwbroek Date: Mon, 23 Oct 2023 02:04:00 +0200 Subject: [PATCH] Improve error messages --- src/neural_learner.ml | 59 ++++++++++++++++++++++++++++--------------- 1 file changed, 39 insertions(+), 20 deletions(-) diff --git a/src/neural_learner.ml b/src/neural_learner.ml index dc5e17c..970e1db 100644 --- a/src/neural_learner.ml +++ b/src/neural_learner.ml @@ -236,6 +236,25 @@ let log_annotation () = Stm.(get_ast ~doc (get_current_state ~doc)) in Pp.string_of_ppcmds loc +let classify_response_message = + let module Response = Api.Reader.PredictionProtocol.Response in + function + | Response.Initialized -> Pp.str "initialized" + | Response.Prediction _ -> Pp.str "prediction" + | Response.TextPrediction _ -> Pp.str "textPrediction" + | Response.Synchronized _ -> Pp.str "synchronized" + | Response.Alignment _ -> Pp.str "alignment" + | Response.Undefined _ -> Pp.str "unknown" + +let protocol_error resp exp = + CErrors.user_err Pp.(str "Cap'n Proto protocol error while communicating with proving server. " ++ + str "Expected message of type " ++ quote (str exp) ++ str " but received message of type " ++ + quote (classify_response_message resp)) + +let protocol_early_terminate () = + CErrors.user_err Pp.(str "Cap'n Proto protocol error while communicating with proving server. " ++ + str "No response was received.") + let init_predict_text capnp_conn = let module Request = Api.Builder.PredictionProtocol.Request in let module Response = Api.Reader.PredictionProtocol.Response in @@ -244,12 +263,12 @@ let init_predict_text capnp_conn = Request.Initialize.log_annotation_set init @@ log_annotation (); ignore(Request.Initialize.data_version_set_reader init Api.Reader.current_version); match write_read_capnp_message_uninterrupted capnp_conn @@ Request.to_message request with - | None -> CErrors.anomaly Pp.(str "Capnp protocol error 1") + | None -> protocol_early_terminate () | Some response -> - let response = Response.of_message response in - match Response.get response with + let response = Response.get @@ Response.of_message response in + match response with | Response.Initialized -> () - | _ -> CErrors.anomaly Pp.(str "Capnp protocol error 2") + | _ -> protocol_error response "initialized" module SDCmap = Symmetric_diff.HMapMake( struct @@ -415,12 +434,12 @@ let sync_context_stack capnp_connection = if debug_option () then Feedback.msg_notice Pp.(str "writing message id " ++ int id); (match write_read_capnp_message_uninterrupted capnp_connection @@ Request.to_message request with - | None -> CErrors.anomaly Pp.(str "Capnp protocol error 1") + | None -> protocol_early_terminate () | Some response -> - let response = Response.of_message response in - match Response.get response with + let response = Response.get @@ Response.of_message response in + match response with | Response.Initialized -> () - | _ -> CErrors.anomaly Pp.(str "Capnp protocol error 2")); + | _ -> protocol_error response "initialized"); id::rrem | _, _ -> assert false in remote_state := loop (curtailed_remote_state, stack); @@ -469,10 +488,10 @@ let check_neural_alignment () = let request = Request.init_root () in Request.check_alignment_set request; match write_read_capnp_message_uninterrupted capnp_connection @@ Request.to_message request with - | None -> CErrors.anomaly Pp.(str "Capnp protocol error 1") + | None -> protocol_early_terminate () | Some response -> - let response = Response.of_message response in - match Response.get response with + let response = Response.get @@ Response.of_message response in + match response with | Response.Alignment alignment -> let find_global_argument = find_global_argument state in let unaligned_tacs = List.map (fun t -> fst @@ find_tactic tacs t) @@ @@ -507,7 +526,7 @@ let check_neural_alignment () = int def_count ++ str " unaligned definitions." ++ tacs_msg ++ defs_msg ) - | _ -> CErrors.anomaly Pp.(str "Capnp protocol error 2") + | _ -> protocol_error response "alignment" let push_cache () = if textmode_option () then () (* No caching needed for the text model at the moment *) else @@ -547,10 +566,10 @@ module NeuralLearner : TacticianOnlineLearnerType = functor (TS : TacticianStruc let concl = term_repr @@ proof_state_goal ps in ProofState.text_set state @@ Graph_extractor.proof_state_to_string_safe (hyps, concl) env Evd.empty; match write_read_capnp_message_uninterrupted capnp_connection @@ Request.to_message request with - | None -> CErrors.anomaly Pp.(str "Capnp protocol error 3a") + | None -> protocol_early_terminate () | Some response -> - let response = Response.of_message response in - match Response.get response with + let response = Response.get @@ Response.of_message response in + match response with | Response.TextPrediction preds -> let preds = Capnp.Array.to_list preds in let preds = List.filter_map (fun p -> @@ -563,7 +582,7 @@ module NeuralLearner : TacticianOnlineLearnerType = functor (TS : TacticianStruc None ) preds in preds - | _ -> CErrors.anomaly Pp.(str "Capnp protocol error 4") + | _ -> protocol_error response "textPrediction" let predict capnp_connection find_global_argument stack_size state tacs env ps = let module Tactic = Api.Reader.Tactic in @@ -605,10 +624,10 @@ module NeuralLearner : TacticianOnlineLearnerType = functor (TS : TacticianStruc ; node_local_index = (fun n -> node_local_index (fst @@ G.lower n)) } state ps ~include_metadata:(include_metadata_option ()); match write_read_capnp_message_uninterrupted capnp_connection @@ Request.to_message request with - | None -> CErrors.anomaly Pp.(str "Capnp protocol error 3b") + | None -> protocol_early_terminate () | Some response -> - let response = Response.of_message response in - match Response.get response with + let response = Response.get @@ Response.of_message response in + match response with | Response.Prediction preds -> let preds = Capnp.Array.to_list preds in let preds = CList.filter_map (fun (i, p) -> @@ -647,7 +666,7 @@ module NeuralLearner : TacticianOnlineLearnerType = functor (TS : TacticianStruc Option.map (fun tac -> tac, conf) @@ Tactic_one_variable.tactic_substitute args tac ) @@ CList.mapi (fun i x -> i, x) preds in preds - | _ -> CErrors.anomaly Pp.(str "Capnp protocol error 4") + | _ -> protocol_error response "prediction" type model = { tactics : (glob_tactic_expr * int) TacticMap.t }