Skip to content

Commit

Permalink
Merge pull request #514 from WenjieDu/(fix)model_saving
Browse files Browse the repository at this point in the history
`model_saving_strategy=best` does not work
  • Loading branch information
WenjieDu authored Sep 22, 2024
2 parents 9c18cdc + f28353a commit d8ade79
Show file tree
Hide file tree
Showing 42 changed files with 56 additions and 56 deletions.
4 changes: 2 additions & 2 deletions pypots/classification/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,8 +347,8 @@ def _train_model(

# save the model if necessary
self._auto_save_model_if_necessary(
confirm_saving=self.best_epoch == epoch,
saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
confirm_saving=self.best_epoch == epoch and self.model_saving_strategy == "better",
saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss:.4f}",
)

if os.getenv("enable_tuning", False):
Expand Down
2 changes: 1 addition & 1 deletion pypots/classification/brits/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
2 changes: 1 addition & 1 deletion pypots/classification/grud/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
2 changes: 1 addition & 1 deletion pypots/classification/raindrop/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
4 changes: 2 additions & 2 deletions pypots/clustering/crli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def _train_model(

# save the model if necessary
self._auto_save_model_if_necessary(
confirm_saving=self.best_epoch == epoch,
confirm_saving=self.best_epoch == epoch and self.model_saving_strategy == "better",
saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss:.4f}",
)

Expand Down Expand Up @@ -354,7 +354,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
6 changes: 3 additions & 3 deletions pypots/clustering/vader/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,8 @@ def _train_model(

# save the model if necessary
self._auto_save_model_if_necessary(
confirm_saving=self.best_epoch == epoch,
saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
confirm_saving=self.best_epoch == epoch and self.model_saving_strategy == "better",
saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss:.4f}",
)

if os.getenv("enable_tuning", False):
Expand Down Expand Up @@ -367,7 +367,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
4 changes: 2 additions & 2 deletions pypots/forecasting/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,8 +346,8 @@ def _train_model(

# save the model if necessary
self._auto_save_model_if_necessary(
confirm_saving=self.best_epoch == epoch,
saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
confirm_saving=self.best_epoch == epoch and self.model_saving_strategy == "better",
saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss:.4f}",
)

if os.getenv("enable_tuning", False):
Expand Down
6 changes: 3 additions & 3 deletions pypots/forecasting/csdi/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,8 @@ def _train_model(

# save the model if necessary
self._auto_save_model_if_necessary(
confirm_saving=self.best_epoch == epoch,
saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
confirm_saving=self.best_epoch == epoch and self.model_saving_strategy == "better",
saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss:.4f}",
)

if os.getenv("enable_tuning", False):
Expand Down Expand Up @@ -379,7 +379,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/autoformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
4 changes: 2 additions & 2 deletions pypots/imputation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,8 +346,8 @@ def _train_model(

# save the model if necessary
self._auto_save_model_if_necessary(
confirm_saving=self.best_epoch == epoch,
saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
confirm_saving=self.best_epoch == epoch and self.model_saving_strategy == "better",
saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss:.4f}",
)

if os.getenv("enable_tuning", False):
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/brits/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/crossformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
6 changes: 3 additions & 3 deletions pypots/imputation/csdi/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,8 @@ def _train_model(

# save the model if necessary
self._auto_save_model_if_necessary(
confirm_saving=self.best_epoch == epoch,
saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
confirm_saving=self.best_epoch == epoch and self.model_saving_strategy == "better",
saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss:.4f}",
)

if os.getenv("enable_tuning", False):
Expand Down Expand Up @@ -363,7 +363,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/dlinear/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/etsformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/fedformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/film/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/frets/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
6 changes: 3 additions & 3 deletions pypots/imputation/gpvae/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,8 +312,8 @@ def _train_model(

# save the model if necessary
self._auto_save_model_if_necessary(
confirm_saving=self.best_epoch == epoch,
saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
confirm_saving=self.best_epoch == epoch and self.model_saving_strategy == "better",
saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss:.4f}",
)

if os.getenv("enable_tuning", False):
Expand Down Expand Up @@ -377,7 +377,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/grud/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/imputeformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/informer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/itransformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/koopa/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/micn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/moderntcn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/mrnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/nonstationary_transformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/patchtst/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/pyraformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/reformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/revinscinet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/saits/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/scinet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/stemgnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/tcn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/tefn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def fit(
self.model.eval() # set the model as eval status to freeze it.

# Step 3: save the model if necessary
self._auto_save_model_if_necessary(confirm_saving=True)
self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

def predict(
self,
Expand Down
Loading

0 comments on commit d8ade79

Please sign in to comment.