From f28353a24115d6b3742f44084bf239d362e94adf Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 19 Sep 2024 08:18:25 +0800 Subject: [PATCH] fix: model_saving_strategy=best does not work; --- pypots/classification/base.py | 4 ++-- pypots/classification/brits/model.py | 2 +- pypots/classification/grud/model.py | 2 +- pypots/classification/raindrop/model.py | 2 +- pypots/clustering/crli/model.py | 4 ++-- pypots/clustering/vader/model.py | 6 +++--- pypots/forecasting/base.py | 4 ++-- pypots/forecasting/csdi/model.py | 6 +++--- pypots/imputation/autoformer/model.py | 2 +- pypots/imputation/base.py | 4 ++-- pypots/imputation/brits/model.py | 2 +- pypots/imputation/crossformer/model.py | 2 +- pypots/imputation/csdi/model.py | 6 +++--- pypots/imputation/dlinear/model.py | 2 +- pypots/imputation/etsformer/model.py | 2 +- pypots/imputation/fedformer/model.py | 2 +- pypots/imputation/film/model.py | 2 +- pypots/imputation/frets/model.py | 2 +- pypots/imputation/gpvae/model.py | 6 +++--- pypots/imputation/grud/model.py | 2 +- pypots/imputation/imputeformer/model.py | 2 +- pypots/imputation/informer/model.py | 2 +- pypots/imputation/itransformer/model.py | 2 +- pypots/imputation/koopa/model.py | 2 +- pypots/imputation/micn/model.py | 2 +- pypots/imputation/moderntcn/model.py | 2 +- pypots/imputation/mrnn/model.py | 2 +- pypots/imputation/nonstationary_transformer/model.py | 2 +- pypots/imputation/patchtst/model.py | 2 +- pypots/imputation/pyraformer/model.py | 2 +- pypots/imputation/reformer/model.py | 2 +- pypots/imputation/revinscinet/model.py | 2 +- pypots/imputation/saits/model.py | 2 +- pypots/imputation/scinet/model.py | 2 +- pypots/imputation/stemgnn/model.py | 2 +- pypots/imputation/tcn/model.py | 2 +- pypots/imputation/tefn/model.py | 2 +- pypots/imputation/tide/model.py | 2 +- pypots/imputation/timemixer/model.py | 2 +- pypots/imputation/timesnet/model.py | 2 +- pypots/imputation/transformer/model.py | 2 +- pypots/imputation/usgan/model.py | 6 +++--- 42 files changed, 56 insertions(+), 56 deletions(-) diff --git a/pypots/classification/base.py b/pypots/classification/base.py index e1848602..75a3a3bb 100644 --- a/pypots/classification/base.py +++ b/pypots/classification/base.py @@ -347,8 +347,8 @@ def _train_model( # save the model if necessary self._auto_save_model_if_necessary( - confirm_saving=self.best_epoch == epoch, - saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}", + confirm_saving=self.best_epoch == epoch and self.model_saving_strategy == "better", + saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss:.4f}", ) if os.getenv("enable_tuning", False): diff --git a/pypots/classification/brits/model.py b/pypots/classification/brits/model.py index 85ddd798..e4719f05 100644 --- a/pypots/classification/brits/model.py +++ b/pypots/classification/brits/model.py @@ -228,7 +228,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/classification/grud/model.py b/pypots/classification/grud/model.py index 5fb84671..a8b1ed50 100644 --- a/pypots/classification/grud/model.py +++ b/pypots/classification/grud/model.py @@ -204,7 +204,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/classification/raindrop/model.py b/pypots/classification/raindrop/model.py index f599b204..aafac455 100644 --- a/pypots/classification/raindrop/model.py +++ b/pypots/classification/raindrop/model.py @@ -248,7 +248,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/clustering/crli/model.py b/pypots/clustering/crli/model.py index f1838af3..39a18bdc 100644 --- a/pypots/clustering/crli/model.py +++ b/pypots/clustering/crli/model.py @@ -295,7 +295,7 @@ def _train_model( # save the model if necessary self._auto_save_model_if_necessary( - confirm_saving=self.best_epoch == epoch, + confirm_saving=self.best_epoch == epoch and self.model_saving_strategy == "better", saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss:.4f}", ) @@ -354,7 +354,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/clustering/vader/model.py b/pypots/clustering/vader/model.py index 0a6e6418..8e14b93f 100644 --- a/pypots/clustering/vader/model.py +++ b/pypots/clustering/vader/model.py @@ -303,8 +303,8 @@ def _train_model( # save the model if necessary self._auto_save_model_if_necessary( - confirm_saving=self.best_epoch == epoch, - saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}", + confirm_saving=self.best_epoch == epoch and self.model_saving_strategy == "better", + saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss:.4f}", ) if os.getenv("enable_tuning", False): @@ -367,7 +367,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/forecasting/base.py b/pypots/forecasting/base.py index 5113876d..0931c791 100644 --- a/pypots/forecasting/base.py +++ b/pypots/forecasting/base.py @@ -346,8 +346,8 @@ def _train_model( # save the model if necessary self._auto_save_model_if_necessary( - confirm_saving=self.best_epoch == epoch, - saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}", + confirm_saving=self.best_epoch == epoch and self.model_saving_strategy == "better", + saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss:.4f}", ) if os.getenv("enable_tuning", False): diff --git a/pypots/forecasting/csdi/model.py b/pypots/forecasting/csdi/model.py index 8492f87b..ea7d5856 100644 --- a/pypots/forecasting/csdi/model.py +++ b/pypots/forecasting/csdi/model.py @@ -307,8 +307,8 @@ def _train_model( # save the model if necessary self._auto_save_model_if_necessary( - confirm_saving=self.best_epoch == epoch, - saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}", + confirm_saving=self.best_epoch == epoch and self.model_saving_strategy == "better", + saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss:.4f}", ) if os.getenv("enable_tuning", False): @@ -379,7 +379,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/autoformer/model.py b/pypots/imputation/autoformer/model.py index 38a044e5..01213d63 100644 --- a/pypots/imputation/autoformer/model.py +++ b/pypots/imputation/autoformer/model.py @@ -234,7 +234,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/base.py b/pypots/imputation/base.py index 1a20dc72..0f43bc25 100644 --- a/pypots/imputation/base.py +++ b/pypots/imputation/base.py @@ -346,8 +346,8 @@ def _train_model( # save the model if necessary self._auto_save_model_if_necessary( - confirm_saving=self.best_epoch == epoch, - saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}", + confirm_saving=self.best_epoch == epoch and self.model_saving_strategy == "better", + saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss:.4f}", ) if os.getenv("enable_tuning", False): diff --git a/pypots/imputation/brits/model.py b/pypots/imputation/brits/model.py index 06ec6f4e..6311e321 100644 --- a/pypots/imputation/brits/model.py +++ b/pypots/imputation/brits/model.py @@ -216,7 +216,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/crossformer/model.py b/pypots/imputation/crossformer/model.py index 5e8c3016..3b37c849 100644 --- a/pypots/imputation/crossformer/model.py +++ b/pypots/imputation/crossformer/model.py @@ -240,7 +240,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/csdi/model.py b/pypots/imputation/csdi/model.py index 19c3ecfd..b6d22c7f 100644 --- a/pypots/imputation/csdi/model.py +++ b/pypots/imputation/csdi/model.py @@ -287,8 +287,8 @@ def _train_model( # save the model if necessary self._auto_save_model_if_necessary( - confirm_saving=self.best_epoch == epoch, - saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}", + confirm_saving=self.best_epoch == epoch and self.model_saving_strategy == "better", + saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss:.4f}", ) if os.getenv("enable_tuning", False): @@ -363,7 +363,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/dlinear/model.py b/pypots/imputation/dlinear/model.py index ea65df87..28809888 100644 --- a/pypots/imputation/dlinear/model.py +++ b/pypots/imputation/dlinear/model.py @@ -211,7 +211,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/etsformer/model.py b/pypots/imputation/etsformer/model.py index 7ecb0c03..3851c741 100644 --- a/pypots/imputation/etsformer/model.py +++ b/pypots/imputation/etsformer/model.py @@ -234,7 +234,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/fedformer/model.py b/pypots/imputation/fedformer/model.py index 05d8e7cd..4a07104e 100644 --- a/pypots/imputation/fedformer/model.py +++ b/pypots/imputation/fedformer/model.py @@ -248,7 +248,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/film/model.py b/pypots/imputation/film/model.py index 1f505e64..389f527c 100644 --- a/pypots/imputation/film/model.py +++ b/pypots/imputation/film/model.py @@ -228,7 +228,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/frets/model.py b/pypots/imputation/frets/model.py index 0fc730b7..5ff772b8 100644 --- a/pypots/imputation/frets/model.py +++ b/pypots/imputation/frets/model.py @@ -210,7 +210,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/gpvae/model.py b/pypots/imputation/gpvae/model.py index f8ff2193..28e73e61 100644 --- a/pypots/imputation/gpvae/model.py +++ b/pypots/imputation/gpvae/model.py @@ -312,8 +312,8 @@ def _train_model( # save the model if necessary self._auto_save_model_if_necessary( - confirm_saving=self.best_epoch == epoch, - saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}", + confirm_saving=self.best_epoch == epoch and self.model_saving_strategy == "better", + saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss:.4f}", ) if os.getenv("enable_tuning", False): @@ -377,7 +377,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/grud/model.py b/pypots/imputation/grud/model.py index 269888d0..391ad7b3 100644 --- a/pypots/imputation/grud/model.py +++ b/pypots/imputation/grud/model.py @@ -201,7 +201,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/imputeformer/model.py b/pypots/imputation/imputeformer/model.py index 92daf873..1b313686 100644 --- a/pypots/imputation/imputeformer/model.py +++ b/pypots/imputation/imputeformer/model.py @@ -253,7 +253,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/informer/model.py b/pypots/imputation/informer/model.py index 07788534..26a23c50 100644 --- a/pypots/imputation/informer/model.py +++ b/pypots/imputation/informer/model.py @@ -228,7 +228,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/itransformer/model.py b/pypots/imputation/itransformer/model.py index 46774670..a3c179ec 100644 --- a/pypots/imputation/itransformer/model.py +++ b/pypots/imputation/itransformer/model.py @@ -255,7 +255,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/koopa/model.py b/pypots/imputation/koopa/model.py index 60cbc482..6e5285a8 100644 --- a/pypots/imputation/koopa/model.py +++ b/pypots/imputation/koopa/model.py @@ -241,7 +241,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/micn/model.py b/pypots/imputation/micn/model.py index edfa8d3d..cfa925f4 100644 --- a/pypots/imputation/micn/model.py +++ b/pypots/imputation/micn/model.py @@ -222,7 +222,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/moderntcn/model.py b/pypots/imputation/moderntcn/model.py index e408f5eb..5ba2790b 100644 --- a/pypots/imputation/moderntcn/model.py +++ b/pypots/imputation/moderntcn/model.py @@ -252,7 +252,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/mrnn/model.py b/pypots/imputation/mrnn/model.py index 40f8dcac..bcd63861 100644 --- a/pypots/imputation/mrnn/model.py +++ b/pypots/imputation/mrnn/model.py @@ -218,7 +218,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/nonstationary_transformer/model.py b/pypots/imputation/nonstationary_transformer/model.py index 814cff3d..1866a96b 100644 --- a/pypots/imputation/nonstationary_transformer/model.py +++ b/pypots/imputation/nonstationary_transformer/model.py @@ -244,7 +244,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/patchtst/model.py b/pypots/imputation/patchtst/model.py index 81d09fc7..f61d5328 100644 --- a/pypots/imputation/patchtst/model.py +++ b/pypots/imputation/patchtst/model.py @@ -264,7 +264,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/pyraformer/model.py b/pypots/imputation/pyraformer/model.py index 576e7c87..48dbb26f 100644 --- a/pypots/imputation/pyraformer/model.py +++ b/pypots/imputation/pyraformer/model.py @@ -240,7 +240,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/reformer/model.py b/pypots/imputation/reformer/model.py index 76b23cb4..5b795d26 100644 --- a/pypots/imputation/reformer/model.py +++ b/pypots/imputation/reformer/model.py @@ -241,7 +241,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/revinscinet/model.py b/pypots/imputation/revinscinet/model.py index 20a78807..4612bb73 100644 --- a/pypots/imputation/revinscinet/model.py +++ b/pypots/imputation/revinscinet/model.py @@ -246,7 +246,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/saits/model.py b/pypots/imputation/saits/model.py index cecb3cbe..930a4f84 100644 --- a/pypots/imputation/saits/model.py +++ b/pypots/imputation/saits/model.py @@ -268,7 +268,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/scinet/model.py b/pypots/imputation/scinet/model.py index 86caceb8..515dfc2e 100644 --- a/pypots/imputation/scinet/model.py +++ b/pypots/imputation/scinet/model.py @@ -248,7 +248,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/stemgnn/model.py b/pypots/imputation/stemgnn/model.py index 743ed3d5..1bc2ef5f 100644 --- a/pypots/imputation/stemgnn/model.py +++ b/pypots/imputation/stemgnn/model.py @@ -222,7 +222,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/tcn/model.py b/pypots/imputation/tcn/model.py index 8c01981f..9c83a37d 100644 --- a/pypots/imputation/tcn/model.py +++ b/pypots/imputation/tcn/model.py @@ -216,7 +216,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/tefn/model.py b/pypots/imputation/tefn/model.py index ff30eca5..b4c082e4 100644 --- a/pypots/imputation/tefn/model.py +++ b/pypots/imputation/tefn/model.py @@ -188,7 +188,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/tide/model.py b/pypots/imputation/tide/model.py index 949b15fe..e5d644c3 100644 --- a/pypots/imputation/tide/model.py +++ b/pypots/imputation/tide/model.py @@ -228,7 +228,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/timemixer/model.py b/pypots/imputation/timemixer/model.py index 5e274d7f..09a3f74e 100644 --- a/pypots/imputation/timemixer/model.py +++ b/pypots/imputation/timemixer/model.py @@ -253,7 +253,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/timesnet/model.py b/pypots/imputation/timesnet/model.py index e3029e93..5ac1aec6 100644 --- a/pypots/imputation/timesnet/model.py +++ b/pypots/imputation/timesnet/model.py @@ -224,7 +224,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/transformer/model.py b/pypots/imputation/transformer/model.py index 33eefee1..ccc7cf27 100644 --- a/pypots/imputation/transformer/model.py +++ b/pypots/imputation/transformer/model.py @@ -256,7 +256,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self, diff --git a/pypots/imputation/usgan/model.py b/pypots/imputation/usgan/model.py index e329fdf0..aadb3703 100644 --- a/pypots/imputation/usgan/model.py +++ b/pypots/imputation/usgan/model.py @@ -324,8 +324,8 @@ def _train_model( # save the model if necessary self._auto_save_model_if_necessary( - confirm_saving=self.best_epoch == epoch, - saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}", + confirm_saving=self.best_epoch == epoch and self.model_saving_strategy == "better", + saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss:.4f}", ) if os.getenv("enable_tuning", False): @@ -389,7 +389,7 @@ def fit( self.model.eval() # set the model as eval status to freeze it. # Step 3: save the model if necessary - self._auto_save_model_if_necessary(confirm_saving=True) + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( self,