From efbf731bff5298d95a5ba86ea139641e26abebee Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Mon, 8 Apr 2024 23:10:52 +0800 Subject: [PATCH 1/2] feat: add the attribute best_epoch to record the epoch num when the best loss is got; --- pypots/base.py | 4 +++- pypots/classification/base.py | 1 + pypots/clustering/base.py | 1 + pypots/clustering/crli/model.py | 1 + pypots/clustering/vader/model.py | 1 + pypots/forecasting/base.py | 1 + pypots/imputation/base.py | 1 + pypots/imputation/csdi/model.py | 1 + pypots/imputation/gpvae/model.py | 1 + pypots/imputation/usgan/model.py | 1 + 10 files changed, 12 insertions(+), 1 deletion(-) diff --git a/pypots/base.py b/pypots/base.py index 6fb7d1a6..ac3287c5 100644 --- a/pypots/base.py +++ b/pypots/base.py @@ -450,6 +450,8 @@ class BaseNNModel(BaseModel): The criteria to judge whether the model's performance is the best so far. Usually the lower, the better. + best_epoch : int, default = -1, + The epoch number when the best loss is got. Notes ----- @@ -494,8 +496,8 @@ def __init__( self.model = None self.optimizer = None self.best_model_dict = None - # WDU: may enable users to customize the criteria in the future self.best_loss = float("inf") + self.best_epoch = -1 def _print_model_size(self) -> None: """Print the number of trainable parameters in the initialized NN model.""" diff --git a/pypots/classification/base.py b/pypots/classification/base.py index 19f73d2b..7adee801 100644 --- a/pypots/classification/base.py +++ b/pypots/classification/base.py @@ -337,6 +337,7 @@ def _train_model( ) if mean_loss < self.best_loss: + self.best_epoch = epoch self.best_loss = mean_loss self.best_model_dict = self.model.state_dict() self.patience = self.original_patience diff --git a/pypots/clustering/base.py b/pypots/clustering/base.py index b0bb3336..62d93acd 100644 --- a/pypots/clustering/base.py +++ b/pypots/clustering/base.py @@ -336,6 +336,7 @@ def _train_model( ) if mean_loss < self.best_loss: + self.best_epoch = epoch self.best_loss = mean_loss self.best_model_dict = self.model.state_dict() self.patience = self.original_patience diff --git a/pypots/clustering/crli/model.py b/pypots/clustering/crli/model.py index 2eff7647..3886fe15 100644 --- a/pypots/clustering/crli/model.py +++ b/pypots/clustering/crli/model.py @@ -296,6 +296,7 @@ def _train_model( ) if mean_loss < self.best_loss: + self.best_epoch = epoch self.best_loss = mean_loss self.best_model_dict = self.model.state_dict() self.patience = self.original_patience diff --git a/pypots/clustering/vader/model.py b/pypots/clustering/vader/model.py index eafbfddb..8436100f 100644 --- a/pypots/clustering/vader/model.py +++ b/pypots/clustering/vader/model.py @@ -309,6 +309,7 @@ def _train_model( ) if mean_loss < self.best_loss: + self.best_epoch = epoch self.best_loss = mean_loss self.best_model_dict = self.model.state_dict() self.patience = self.original_patience diff --git a/pypots/forecasting/base.py b/pypots/forecasting/base.py index 1ece900c..31a6fcc7 100644 --- a/pypots/forecasting/base.py +++ b/pypots/forecasting/base.py @@ -331,6 +331,7 @@ def _train_model( ) if mean_loss < self.best_loss: + self.best_epoch = epoch self.best_loss = mean_loss self.best_model_dict = self.model.state_dict() self.patience = self.original_patience diff --git a/pypots/imputation/base.py b/pypots/imputation/base.py index 30a87a42..c381ef4d 100644 --- a/pypots/imputation/base.py +++ b/pypots/imputation/base.py @@ -334,6 +334,7 @@ def _train_model( ) if mean_loss < self.best_loss: + self.best_epoch = epoch self.best_loss = mean_loss self.best_model_dict = self.model.state_dict() self.patience = self.original_patience diff --git a/pypots/imputation/csdi/model.py b/pypots/imputation/csdi/model.py index 62911931..38157a5a 100644 --- a/pypots/imputation/csdi/model.py +++ b/pypots/imputation/csdi/model.py @@ -283,6 +283,7 @@ def _train_model( ) if mean_loss < self.best_loss: + self.best_epoch = epoch self.best_loss = mean_loss self.best_model_dict = self.model.state_dict() self.patience = self.original_patience diff --git a/pypots/imputation/gpvae/model.py b/pypots/imputation/gpvae/model.py index e1bf5120..df38f3bb 100644 --- a/pypots/imputation/gpvae/model.py +++ b/pypots/imputation/gpvae/model.py @@ -313,6 +313,7 @@ def _train_model( ) if mean_loss < self.best_loss: + self.best_epoch = epoch self.best_loss = mean_loss self.best_model_dict = self.model.state_dict() self.patience = self.original_patience diff --git a/pypots/imputation/usgan/model.py b/pypots/imputation/usgan/model.py index cb90f092..4db3c2c5 100644 --- a/pypots/imputation/usgan/model.py +++ b/pypots/imputation/usgan/model.py @@ -330,6 +330,7 @@ def _train_model( ) if mean_loss < self.best_loss: + self.best_epoch = epoch self.best_loss = mean_loss self.best_model_dict = self.model.state_dict() self.patience = self.original_patience From 93062a244dbcee3e02776ab4611a4351fedbef82 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Mon, 8 Apr 2024 23:20:46 +0800 Subject: [PATCH 2/2] feat: remind the best epoch num after training is finished; --- pypots/classification/base.py | 4 +++- pypots/clustering/base.py | 4 +++- pypots/clustering/crli/model.py | 4 +++- pypots/clustering/vader/model.py | 4 +++- pypots/forecasting/base.py | 4 +++- pypots/imputation/base.py | 4 +++- pypots/imputation/csdi/model.py | 4 +++- pypots/imputation/gpvae/model.py | 4 +++- pypots/imputation/usgan/model.py | 4 +++- 9 files changed, 27 insertions(+), 9 deletions(-) diff --git a/pypots/classification/base.py b/pypots/classification/base.py index 7adee801..50bd5afd 100644 --- a/pypots/classification/base.py +++ b/pypots/classification/base.py @@ -377,7 +377,9 @@ def _train_model( if np.isnan(self.best_loss): raise ValueError("Something is wrong. best_loss is Nan after training.") - logger.info("Finished training.") + logger.info( + f"Finished training. The best model is from epoch#{self.best_epoch}." + ) @abstractmethod def fit( diff --git a/pypots/clustering/base.py b/pypots/clustering/base.py index 62d93acd..47f70a18 100644 --- a/pypots/clustering/base.py +++ b/pypots/clustering/base.py @@ -370,7 +370,9 @@ def _train_model( if np.isnan(self.best_loss): raise ValueError("Something is wrong. best_loss is Nan after training.") - logger.info("Finished training.") + logger.info( + f"Finished training. The best model is from epoch#{self.best_epoch}." + ) @abstractmethod def fit( diff --git a/pypots/clustering/crli/model.py b/pypots/clustering/crli/model.py index 3886fe15..e6b8c23f 100644 --- a/pypots/clustering/crli/model.py +++ b/pypots/clustering/crli/model.py @@ -336,7 +336,9 @@ def _train_model( if np.isnan(self.best_loss): raise ValueError("Something is wrong. best_loss is Nan after training.") - logger.info("Finished training.") + logger.info( + f"Finished training. The best model is from epoch#{self.best_epoch}." + ) def fit( self, diff --git a/pypots/clustering/vader/model.py b/pypots/clustering/vader/model.py index 8436100f..1d3eaa73 100644 --- a/pypots/clustering/vader/model.py +++ b/pypots/clustering/vader/model.py @@ -349,7 +349,9 @@ def _train_model( if np.isnan(self.best_loss): raise ValueError("Something is wrong. best_loss is Nan after training.") - logger.info("Finished training.") + logger.info( + f"Finished training. The best model is from epoch#{self.best_epoch}." + ) def fit( self, diff --git a/pypots/forecasting/base.py b/pypots/forecasting/base.py index 31a6fcc7..2cdf641d 100644 --- a/pypots/forecasting/base.py +++ b/pypots/forecasting/base.py @@ -371,7 +371,9 @@ def _train_model( if np.isnan(self.best_loss): raise ValueError("Something is wrong. best_loss is Nan after training.") - logger.info("Finished training.") + logger.info( + f"Finished training. The best model is from epoch#{self.best_epoch}." + ) @abstractmethod def fit( diff --git a/pypots/imputation/base.py b/pypots/imputation/base.py index c381ef4d..284d1af2 100644 --- a/pypots/imputation/base.py +++ b/pypots/imputation/base.py @@ -374,7 +374,9 @@ def _train_model( if np.isnan(self.best_loss): raise ValueError("Something is wrong. best_loss is Nan after training.") - logger.info("Finished training.") + logger.info( + f"Finished training. The best model is from epoch#{self.best_epoch}." + ) @abstractmethod def fit( diff --git a/pypots/imputation/csdi/model.py b/pypots/imputation/csdi/model.py index 38157a5a..b30e8de9 100644 --- a/pypots/imputation/csdi/model.py +++ b/pypots/imputation/csdi/model.py @@ -323,7 +323,9 @@ def _train_model( if np.isnan(self.best_loss): raise ValueError("Something is wrong. best_loss is Nan after training.") - logger.info("Finished training.") + logger.info( + f"Finished training. The best model is from epoch#{self.best_epoch}." + ) def fit( self, diff --git a/pypots/imputation/gpvae/model.py b/pypots/imputation/gpvae/model.py index df38f3bb..9d33d275 100644 --- a/pypots/imputation/gpvae/model.py +++ b/pypots/imputation/gpvae/model.py @@ -353,7 +353,9 @@ def _train_model( if np.isnan(self.best_loss): raise ValueError("Something is wrong. best_loss is Nan after training.") - logger.info("Finished training.") + logger.info( + f"Finished training. The best model is from epoch#{self.best_epoch}." + ) def fit( self, diff --git a/pypots/imputation/usgan/model.py b/pypots/imputation/usgan/model.py index 4db3c2c5..89a674f3 100644 --- a/pypots/imputation/usgan/model.py +++ b/pypots/imputation/usgan/model.py @@ -370,7 +370,9 @@ def _train_model( if np.isnan(self.best_loss): raise ValueError("Something is wrong. best_loss is Nan after training.") - logger.info("Finished training.") + logger.info( + f"Finished training. The best model is from epoch#{self.best_epoch}." + ) def fit( self,