Skip to content

Commit

Permalink
Merge pull request #342 from WenjieDu/dev
Browse files Browse the repository at this point in the history
Add the attribute `best_epoch` to record the best epoch num
  • Loading branch information
WenjieDu authored Apr 8, 2024
2 parents 6e7982c + 93062a2 commit fb7ec06
Show file tree
Hide file tree
Showing 10 changed files with 39 additions and 10 deletions.
4 changes: 3 additions & 1 deletion pypots/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,8 @@ class BaseNNModel(BaseModel):
The criteria to judge whether the model's performance is the best so far.
Usually the lower, the better.
best_epoch : int, default = -1,
The epoch number when the best loss is got.
Notes
-----
Expand Down Expand Up @@ -494,8 +496,8 @@ def __init__(
self.model = None
self.optimizer = None
self.best_model_dict = None
# WDU: may enable users to customize the criteria in the future
self.best_loss = float("inf")
self.best_epoch = -1

def _print_model_size(self) -> None:
"""Print the number of trainable parameters in the initialized NN model."""
Expand Down
5 changes: 4 additions & 1 deletion pypots/classification/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ def _train_model(
)

if mean_loss < self.best_loss:
self.best_epoch = epoch
self.best_loss = mean_loss
self.best_model_dict = self.model.state_dict()
self.patience = self.original_patience
Expand Down Expand Up @@ -376,7 +377,9 @@ def _train_model(
if np.isnan(self.best_loss):
raise ValueError("Something is wrong. best_loss is Nan after training.")

logger.info("Finished training.")
logger.info(
f"Finished training. The best model is from epoch#{self.best_epoch}."
)

@abstractmethod
def fit(
Expand Down
5 changes: 4 additions & 1 deletion pypots/clustering/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ def _train_model(
)

if mean_loss < self.best_loss:
self.best_epoch = epoch
self.best_loss = mean_loss
self.best_model_dict = self.model.state_dict()
self.patience = self.original_patience
Expand Down Expand Up @@ -369,7 +370,9 @@ def _train_model(
if np.isnan(self.best_loss):
raise ValueError("Something is wrong. best_loss is Nan after training.")

logger.info("Finished training.")
logger.info(
f"Finished training. The best model is from epoch#{self.best_epoch}."
)

@abstractmethod
def fit(
Expand Down
5 changes: 4 additions & 1 deletion pypots/clustering/crli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ def _train_model(
)

if mean_loss < self.best_loss:
self.best_epoch = epoch
self.best_loss = mean_loss
self.best_model_dict = self.model.state_dict()
self.patience = self.original_patience
Expand Down Expand Up @@ -335,7 +336,9 @@ def _train_model(
if np.isnan(self.best_loss):
raise ValueError("Something is wrong. best_loss is Nan after training.")

logger.info("Finished training.")
logger.info(
f"Finished training. The best model is from epoch#{self.best_epoch}."
)

def fit(
self,
Expand Down
5 changes: 4 additions & 1 deletion pypots/clustering/vader/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ def _train_model(
)

if mean_loss < self.best_loss:
self.best_epoch = epoch
self.best_loss = mean_loss
self.best_model_dict = self.model.state_dict()
self.patience = self.original_patience
Expand Down Expand Up @@ -348,7 +349,9 @@ def _train_model(
if np.isnan(self.best_loss):
raise ValueError("Something is wrong. best_loss is Nan after training.")

logger.info("Finished training.")
logger.info(
f"Finished training. The best model is from epoch#{self.best_epoch}."
)

def fit(
self,
Expand Down
5 changes: 4 additions & 1 deletion pypots/forecasting/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ def _train_model(
)

if mean_loss < self.best_loss:
self.best_epoch = epoch
self.best_loss = mean_loss
self.best_model_dict = self.model.state_dict()
self.patience = self.original_patience
Expand Down Expand Up @@ -370,7 +371,9 @@ def _train_model(
if np.isnan(self.best_loss):
raise ValueError("Something is wrong. best_loss is Nan after training.")

logger.info("Finished training.")
logger.info(
f"Finished training. The best model is from epoch#{self.best_epoch}."
)

@abstractmethod
def fit(
Expand Down
5 changes: 4 additions & 1 deletion pypots/imputation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ def _train_model(
)

if mean_loss < self.best_loss:
self.best_epoch = epoch
self.best_loss = mean_loss
self.best_model_dict = self.model.state_dict()
self.patience = self.original_patience
Expand Down Expand Up @@ -373,7 +374,9 @@ def _train_model(
if np.isnan(self.best_loss):
raise ValueError("Something is wrong. best_loss is Nan after training.")

logger.info("Finished training.")
logger.info(
f"Finished training. The best model is from epoch#{self.best_epoch}."
)

@abstractmethod
def fit(
Expand Down
5 changes: 4 additions & 1 deletion pypots/imputation/csdi/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def _train_model(
)

if mean_loss < self.best_loss:
self.best_epoch = epoch
self.best_loss = mean_loss
self.best_model_dict = self.model.state_dict()
self.patience = self.original_patience
Expand Down Expand Up @@ -322,7 +323,9 @@ def _train_model(
if np.isnan(self.best_loss):
raise ValueError("Something is wrong. best_loss is Nan after training.")

logger.info("Finished training.")
logger.info(
f"Finished training. The best model is from epoch#{self.best_epoch}."
)

def fit(
self,
Expand Down
5 changes: 4 additions & 1 deletion pypots/imputation/gpvae/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ def _train_model(
)

if mean_loss < self.best_loss:
self.best_epoch = epoch
self.best_loss = mean_loss
self.best_model_dict = self.model.state_dict()
self.patience = self.original_patience
Expand Down Expand Up @@ -352,7 +353,9 @@ def _train_model(
if np.isnan(self.best_loss):
raise ValueError("Something is wrong. best_loss is Nan after training.")

logger.info("Finished training.")
logger.info(
f"Finished training. The best model is from epoch#{self.best_epoch}."
)

def fit(
self,
Expand Down
5 changes: 4 additions & 1 deletion pypots/imputation/usgan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ def _train_model(
)

if mean_loss < self.best_loss:
self.best_epoch = epoch
self.best_loss = mean_loss
self.best_model_dict = self.model.state_dict()
self.patience = self.original_patience
Expand Down Expand Up @@ -369,7 +370,9 @@ def _train_model(
if np.isnan(self.best_loss):
raise ValueError("Something is wrong. best_loss is Nan after training.")

logger.info("Finished training.")
logger.info(
f"Finished training. The best model is from epoch#{self.best_epoch}."
)

def fit(
self,
Expand Down

0 comments on commit fb7ec06

Please sign in to comment.