diff --git a/src/neural_learner.ml b/src/neural_learner.ml index e1b657a..b9adc8c 100644 --- a/src/neural_learner.ml +++ b/src/neural_learner.ml @@ -699,6 +699,40 @@ module NeuralLearner : TacticianOnlineLearnerType = functor (TS : TacticianStruc let tactics = add_tactic_info (Global.env ()) tactics tac in tactics in + (* TODO: Filtering out bad proof states: + Occasionally, proof states refer to section variables that have been filtered out by Coq during section + discharge. We filter out such bad proof states. Really, a better solution for this is needed. + To reproduce: + + Section S. Let K (x : Type) := x. + Lemma t : forall x:Type, forall _:x, x. + change (K (forall x:Type, forall _:x, x)). + change (forall x:Type, forall _:x, x). + intros; assumption. + Qed. + End S. + *) + let cache_type name = + let dirp = Global.current_dirpath () in + if Libnames.is_dirpath_prefix_of dirp (Libnames.dirpath name) then `File else `Dependency in + let is_malformed ps = + let ids = Context.Named.to_vars (TS.proof_state_hypotheses ps) in + let rec aux status c = match Constr.kind c with + | Constr.Var id -> + if not @@ Id.Set.mem id ids then + ((match cache_type path with + | `File -> Feedback.msg_warning (Pp.str "Malformed proof state detected") + | _ -> ()); true) + else false + | _ -> Constr.fold aux status c in + let status = aux false (TS.term_repr @@ TS.proof_state_goal ps) in + let status = Context.Named.fold_outside + (fun x status -> match x with + | Context.Named.Declaration.LocalAssum (_, t) -> aux status (TS.term_repr t) + | Context.Named.Declaration.LocalDef (_, c, t) -> aux (aux status (TS.term_repr t)) (TS.term_repr c)) + (TS.proof_state_hypotheses ps) ~init:status in + status in + (* TODO: Drop-in shadowing replacement for mk_outcome. For now, we don't need the proof term and after states. We butcher them to make the payload smaller and faster to compute. *) let mk_outcome before result = @@ -708,7 +742,8 @@ module NeuralLearner : TacticianOnlineLearnerType = functor (TS : TacticianStruc mk_proof_state before, Constr.mkProp, Evd.empty, f, [] in let constant = Constant.make1 kn in - let proof_step = List.map (fun outcome -> mk_outcome outcome.before outcome.result) outcomes, + let proof_step = List.filter_map (fun outcome -> + if is_malformed outcome.before then None else Some (mk_outcome outcome.before outcome.result)) outcomes, Option.map tactic_repr tac in (* TODO: The list of proofs should be reversed at some point downstream *) let update proof = Some (Option.default [proof_step] @@ Option.map (fun prf -> proof_step :: prf) proof) in