Skip to content

Commit

Permalink
Added feature for saving checkpoint every certain epochs (#51)
Browse files Browse the repository at this point in the history
  • Loading branch information
peterli3819 authored Dec 20, 2023
1 parent 9b24b21 commit 179c45d
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions hippynn/experiment/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,17 @@ def setup_and_train(
training_modules: TrainingModules,
database,
setup_params: SetupParams,
store_all_better=False,
store_best=True,
store_every=0
):
"""
:param: training_modules: see :func:`setup_training`
:param: database: see :func:`train_model`
:param: setup_params: see :func:`setup_training`
:param: store_all_better: Save the state dict for each model doing better than a previous one
:param: store_best: Save a checkpoint for the best model
:param: store_every: Save a checkpoint for every certain epochs
:return: See :func:`train_model`
Shortcut for setup_training followed by train_model.
Expand All @@ -134,6 +140,9 @@ def setup_and_train(
metric_tracker=metric_tracker,
callbacks=None,
batch_callbacks=None,
store_all_better=store_all_better,
store_best=store_best,
store_every=store_every
)


Expand Down Expand Up @@ -212,6 +221,7 @@ def train_model(
batch_callbacks,
store_all_better=False,
store_best=True,
store_every=0,
store_structure_file=True,
store_metrics=True,
quiet=False,
Expand All @@ -228,6 +238,7 @@ def train_model(
:param batch_callbacks: callbacks to perform after every batch
:param store_best: Save a checkpoint for the best model
:param store_all_better: Save the state dict for each model doing better than a previous one
:param store_every: Save a checkpoint for every certain epochs
:param store_structure_file: Save the structure file for this experiment
:param store_metrics: Save the metric tracker for this experiment.
:param quiet: If True, disable printing during training (still prints testing results).
Expand Down Expand Up @@ -286,6 +297,7 @@ def train_model(
batch_callbacks=batch_callbacks,
store_best=store_best,
store_all_better=store_all_better,
store_every=store_every,
quiet=quiet,
)

Expand Down Expand Up @@ -364,6 +376,7 @@ def training_loop(
batch_callbacks,
store_all_better,
store_best,
store_every,
quiet,
):
"""
Expand All @@ -377,6 +390,7 @@ def training_loop(
:param batch_callbacks: list of callbacks for each batch
:param store_best: Save a checkpoint for the best model
:param store_all_better: Save the state dict for each model doing better than a previous one
:param store_every: Save a checkpoint for every certain epochs
:param quiet: whether to print information. Setting quiet to true won't prevent progress bars.
:return: metrics -- the state of the experiment after training
Expand Down Expand Up @@ -506,6 +520,17 @@ def training_loop(
# Write the checkpoint
with open("best_checkpoint.pt", "wb") as pfile:
torch.save(state, pfile)

if store_every and epoch != 0 and (epoch % store_every) == 0:
# Save a copy every "store_every" epoch
with open(f"model_epoch_{epoch}.pt", "wb") as pfile:
torch.save(model.state_dict(), pfile)

state = serialization.create_state(model, controller, metric_tracker)

# Write the checkpoint
with open(f"checkpoint_epoch_{epoch}.pt", "wb") as pfile:
torch.save(state, pfile)

epoch += 1

Expand Down

0 comments on commit 179c45d

Please sign in to comment.