
Commit

add save checkpoint
armingh2000 committed Feb 2, 2024
1 parent b108f03 commit 285d365
Showing 2 changed files with 7 additions and 4 deletions.
5 changes: 4 additions & 1 deletion src/model/train/optimize.py
@@ -3,6 +3,7 @@
 import src.configs as configs
 from tqdm import tqdm
 from utils import save_checkpoint, load_checkpoint
+from src.utils import mkpath

 cyclic_loss_function = getattr(nn, configs.cyclic_loss)()
 acyclic_loss_function = getattr(nn, configs.acyclic_loss)()
@@ -24,6 +25,8 @@ def get_loss(pred, target):


 def train(model, train_loader, val_loader, logger, checkpoint=None):
+    mkpath(configs.model_checkpoint_dir_path)
+    mkpath(configs.training_state_checkpoint_dir_path)
     optimizer = getattr(torch.optim, configs.optimizer)(
         model.parameters(), lr=configs.learning_rate
     )
@@ -36,7 +39,7 @@ def train(model, train_loader, val_loader, logger, checkpoint=None):
         start_epoch = 0

     else:
-        start_epoch = loaded_start_epoch
+        start_epoch = loaded_start_epoch + 1
         model = loaded_model
         optimizer = loaded_optimizer
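In optimize.py, the two mkpath calls ensure the checkpoint directories exist before the first save, and the +1 means a resumed run begins at the epoch after the last one that was saved, so a completed epoch is never repeated. A minimal sketch of that resume logic, assuming mkpath is a thin wrapper around os.makedirs (the real helper in src.utils is not shown in this diff, and resolve_start_epoch is a hypothetical name used only for illustration):

import os


def mkpath(path):
    # assumed behaviour of src.utils.mkpath: create the directory tree if it is missing
    os.makedirs(path, exist_ok=True)


def resolve_start_epoch(loaded_start_epoch):
    # the saved value is the last completed epoch, so training resumes at the next one
    if loaded_start_epoch is None:
        return 0
    return loaded_start_epoch + 1


assert resolve_start_epoch(None) == 0
assert resolve_start_epoch(4) == 5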
6 changes: 3 additions & 3 deletions src/model/train/utils.py
@@ -83,15 +83,15 @@ def load_checkpoint(checkpoint_path=None, model=None, optimizer=None, logger=Non
     )

     if not model_checkpoints or not training_state_checkpoints:
-        logger.error("No checkpoints found")
+        logger.info("No checkpoints found")
         return None, None, None

     # Sort the checkpoints by epoch to find the most recent one
     latest_model_checkpoint = max(
-        model_checkpoints, key=lambda path: int(path.stem.split("_")[-3])
+        model_checkpoints, key=lambda path: int(path.stem.split("_")[-2])
     )
     latest_training_state_checkpoint = max(
-        training_state_checkpoints, key=lambda path: int(path.stem.split("_")[-3])
+        training_state_checkpoints, key=lambda path: int(path.stem.split("_")[-2])
     )

     logger.info(f"Loading the latest model checkpoint '{latest_model_checkpoint}'")
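In utils.py, the change from split("_")[-3] to split("_")[-2] shifts which underscore-separated field of the checkpoint filename stem is read as the epoch number, and a missing checkpoint is now logged at info level rather than error, since starting from scratch is a normal case. The actual naming convention comes from save_checkpoint and is not shown in this diff; the sketch below uses hypothetical filenames where the epoch is the second-to-last field, just to illustrate the max(..., key=...) selection:

from pathlib import Path

# hypothetical checkpoint names, for illustration only;
# the real convention is defined by save_checkpoint
model_checkpoints = [
    Path("checkpoints/model_epoch_3_final.pt"),
    Path("checkpoints/model_epoch_12_final.pt"),
    Path("checkpoints/model_epoch_7_final.pt"),
]

# Path.stem drops the suffix, e.g. "model_epoch_12_final";
# split("_")[-2] then selects the epoch field ("12")
latest = max(model_checkpoints, key=lambda path: int(path.stem.split("_")[-2]))
print(latest)  # checkpoints/model_epoch_12_final.pt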
