Skip to content

Commit

Permalink
Improve error messages
Browse files Browse the repository at this point in the history
  • Loading branch information
LasseBlaauwbroek committed Oct 23, 2023
1 parent c5eabe8 commit b308728
Showing 1 changed file with 39 additions and 20 deletions.
59 changes: 39 additions & 20 deletions src/neural_learner.ml
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,25 @@ let log_annotation () =
Stm.(get_ast ~doc (get_current_state ~doc)) in
Pp.string_of_ppcmds loc

let classify_response_message =
let module Response = Api.Reader.PredictionProtocol.Response in
function
| Response.Initialized -> Pp.str "initialized"
| Response.Prediction _ -> Pp.str "prediction"
| Response.TextPrediction _ -> Pp.str "textPrediction"
| Response.Synchronized _ -> Pp.str "synchronized"
| Response.Alignment _ -> Pp.str "alignment"
| Response.Undefined _ -> Pp.str "unknown"

let protocol_error resp exp =
CErrors.user_err Pp.(str "Cap'n Proto protocol error while communicating with proving server. " ++
str "Expected message of type " ++ quote (str exp) ++ str " but received message of type " ++
quote (classify_response_message resp))

let protocol_early_terminate () =
CErrors.user_err Pp.(str "Cap'n Proto protocol error while communicating with proving server. " ++
str "No response was received.")

let init_predict_text capnp_conn =
let module Request = Api.Builder.PredictionProtocol.Request in
let module Response = Api.Reader.PredictionProtocol.Response in
Expand All @@ -244,12 +263,12 @@ let init_predict_text capnp_conn =
Request.Initialize.log_annotation_set init @@ log_annotation ();
ignore(Request.Initialize.data_version_set_reader init Api.Reader.current_version);
match write_read_capnp_message_uninterrupted capnp_conn @@ Request.to_message request with
| None -> CErrors.anomaly Pp.(str "Capnp protocol error 1")
| None -> protocol_early_terminate ()
| Some response ->
let response = Response.of_message response in
match Response.get response with
let response = Response.get @@ Response.of_message response in
match response with
| Response.Initialized -> ()
| _ -> CErrors.anomaly Pp.(str "Capnp protocol error 2")
| _ -> protocol_error response "initialized"

module SDCmap = Symmetric_diff.HMapMake(
struct
Expand Down Expand Up @@ -415,12 +434,12 @@ let sync_context_stack capnp_connection =
if debug_option () then
Feedback.msg_notice Pp.(str "writing message id " ++ int id);
(match write_read_capnp_message_uninterrupted capnp_connection @@ Request.to_message request with
| None -> CErrors.anomaly Pp.(str "Capnp protocol error 1")
| None -> protocol_early_terminate ()
| Some response ->
let response = Response.of_message response in
match Response.get response with
let response = Response.get @@ Response.of_message response in
match response with
| Response.Initialized -> ()
| _ -> CErrors.anomaly Pp.(str "Capnp protocol error 2"));
| _ -> protocol_error response "initialized");
id::rrem
| _, _ -> assert false in
remote_state := loop (curtailed_remote_state, stack);
Expand Down Expand Up @@ -469,10 +488,10 @@ let check_neural_alignment () =
let request = Request.init_root () in
Request.check_alignment_set request;
match write_read_capnp_message_uninterrupted capnp_connection @@ Request.to_message request with
| None -> CErrors.anomaly Pp.(str "Capnp protocol error 1")
| None -> protocol_early_terminate ()
| Some response ->
let response = Response.of_message response in
match Response.get response with
let response = Response.get @@ Response.of_message response in
match response with
| Response.Alignment alignment ->
let find_global_argument = find_global_argument state in
let unaligned_tacs = List.map (fun t -> fst @@ find_tactic tacs t) @@
Expand Down Expand Up @@ -507,7 +526,7 @@ let check_neural_alignment () =
int def_count ++ str " unaligned definitions." ++
tacs_msg ++ defs_msg
)
| _ -> CErrors.anomaly Pp.(str "Capnp protocol error 2")
| _ -> protocol_error response "alignment"

let push_cache () =
if textmode_option () then () (* No caching needed for the text model at the moment *) else
Expand Down Expand Up @@ -547,10 +566,10 @@ module NeuralLearner : TacticianOnlineLearnerType = functor (TS : TacticianStruc
let concl = term_repr @@ proof_state_goal ps in
ProofState.text_set state @@ Graph_extractor.proof_state_to_string_safe (hyps, concl) env Evd.empty;
match write_read_capnp_message_uninterrupted capnp_connection @@ Request.to_message request with
| None -> CErrors.anomaly Pp.(str "Capnp protocol error 3a")
| None -> protocol_early_terminate ()
| Some response ->
let response = Response.of_message response in
match Response.get response with
let response = Response.get @@ Response.of_message response in
match response with
| Response.TextPrediction preds ->
let preds = Capnp.Array.to_list preds in
let preds = List.filter_map (fun p ->
Expand All @@ -563,7 +582,7 @@ module NeuralLearner : TacticianOnlineLearnerType = functor (TS : TacticianStruc
None
) preds in
preds
| _ -> CErrors.anomaly Pp.(str "Capnp protocol error 4")
| _ -> protocol_error response "textPrediction"

let predict capnp_connection find_global_argument stack_size state tacs env ps =
let module Tactic = Api.Reader.Tactic in
Expand Down Expand Up @@ -605,10 +624,10 @@ module NeuralLearner : TacticianOnlineLearnerType = functor (TS : TacticianStruc
; node_local_index = (fun n -> node_local_index (fst @@ G.lower n)) } state ps
~include_metadata:(include_metadata_option ());
match write_read_capnp_message_uninterrupted capnp_connection @@ Request.to_message request with
| None -> CErrors.anomaly Pp.(str "Capnp protocol error 3b")
| None -> protocol_early_terminate ()
| Some response ->
let response = Response.of_message response in
match Response.get response with
let response = Response.get @@ Response.of_message response in
match response with
| Response.Prediction preds ->
let preds = Capnp.Array.to_list preds in
let preds = CList.filter_map (fun (i, p) ->
Expand Down Expand Up @@ -647,7 +666,7 @@ module NeuralLearner : TacticianOnlineLearnerType = functor (TS : TacticianStruc
Option.map (fun tac -> tac, conf) @@ Tactic_one_variable.tactic_substitute args tac
) @@ CList.mapi (fun i x -> i, x) preds in
preds
| _ -> CErrors.anomaly Pp.(str "Capnp protocol error 4")
| _ -> protocol_error response "prediction"

type model =
{ tactics : (glob_tactic_expr * int) TacticMap.t }
Expand Down

0 comments on commit b308728

Please sign in to comment.