From f91502cd469b60d979e0eb452859d80b6ae6d58e Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Mon, 25 Jan 2021 09:54:56 -0800 Subject: [PATCH] Consistent metric tracker (#4928) * Makes the metric tracker more consistent * Turns out we need best_epoch_metrics after all. * Backwards compatibility * Formatting --- allennlp/training/metric_tracker.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/allennlp/training/metric_tracker.py b/allennlp/training/metric_tracker.py index bef1424c977..c2ee482c8e7 100644 --- a/allennlp/training/metric_tracker.py +++ b/allennlp/training/metric_tracker.py @@ -31,13 +31,13 @@ def __init__( metric_name: Union[str, List[str]], patience: Optional[int] = None, ) -> None: - self._best_so_far: Optional[float] = None self._patience = patience + self._best_so_far: Optional[float] = None self._epochs_with_no_improvement = 0 self._is_best_so_far = True - self.best_epoch_metrics: Dict[str, float] = {} self._epoch_number = 0 self.best_epoch: Optional[int] = None + self.best_epoch_metrics: Dict[str, float] = {} if isinstance(metric_name, str): metric_name = [metric_name] @@ -59,6 +59,7 @@ def clear(self) -> None: self._is_best_so_far = True self._epoch_number = 0 self.best_epoch = None + self.best_epoch_metrics.clear() def state_dict(self) -> Dict[str, Any]: """ @@ -66,12 +67,11 @@ def state_dict(self) -> Dict[str, Any]: """ return { "best_so_far": self._best_so_far, - "patience": self._patience, "epochs_with_no_improvement": self._epochs_with_no_improvement, "is_best_so_far": self._is_best_so_far, - "best_epoch_metrics": self.best_epoch_metrics, "epoch_number": self._epoch_number, "best_epoch": self.best_epoch, + "best_epoch_metrics": self.best_epoch_metrics, } def load_state_dict(self, state_dict: Dict[str, Any]) -> None: @@ -79,13 +79,15 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: A `Trainer` can use this to hydrate a metric tracker from a serialized state. """ self._best_so_far = state_dict["best_so_far"] - self._patience = state_dict["patience"] self._epochs_with_no_improvement = state_dict["epochs_with_no_improvement"] self._is_best_so_far = state_dict["is_best_so_far"] - self.best_epoch_metrics = state_dict["best_epoch_metrics"] self._epoch_number = state_dict["epoch_number"] self.best_epoch = state_dict["best_epoch"] + # Even though we don't promise backwards compatibility for the --recover flag, + # it's particularly easy and harmless to provide it here, so we do it. + self.best_epoch_metrics = state_dict.get("best_epoch_metrics", {}) + def add_metrics(self, metrics: Dict[str, float]) -> None: """ Record a new value of the metric and update the various things that depend on it. @@ -103,13 +105,13 @@ def add_metrics(self, metrics: Dict[str, float]) -> None: new_best = (self._best_so_far is None) or (combined_score > self._best_so_far) if new_best: - self.best_epoch = self._epoch_number - self._is_best_so_far = True self._best_so_far = combined_score self._epochs_with_no_improvement = 0 + self._is_best_so_far = True + self.best_epoch = self._epoch_number else: - self._is_best_so_far = False self._epochs_with_no_improvement += 1 + self._is_best_so_far = False self._epoch_number += 1 def is_best_so_far(self) -> bool: