Checkpoint on final step of training even when it doesn't coincide with `save_freq`. (#284)
alexander-soare authored Jun 20, 2024
1 parent 2abef3b commit 9aa4cdb
Showing 2 changed files with 6 additions and 2 deletions.
lerobot/configs/default.yaml (3 changes: 2 additions & 1 deletion)
@@ -39,9 +39,10 @@ training:
   # `online_env_seed` is used for environments for online training data rollouts.
   online_env_seed: ???
   eval_freq: ???
-  save_freq: ???
   log_freq: 250
   save_checkpoint: true
+  # Checkpoint is saved every `save_freq` training iterations and after the last training step.
+  save_freq: ???
   num_workers: 4
   batch_size: ???
   image_transforms:
lerobot/scripts/train.py (5 changes: 4 additions & 1 deletion)
@@ -351,7 +351,10 @@ def evaluate_and_checkpoint_if_needed(step):
             logger.log_video(eval_info["video_paths"][0], step, mode="eval")
             logging.info("Resume training")

-        if cfg.training.save_checkpoint and step % cfg.training.save_freq == 0:
+        if cfg.training.save_checkpoint and (
+            step % cfg.training.save_freq == 0
+            or step == cfg.training.offline_steps + cfg.training.online_steps
+        ):
             logging.info(f"Checkpoint policy after step {step}")
             # Note: Save with step as the identifier, and format it to have at least 6 digits but more if
             # needed (choose 6 as a minimum for consistency without being overkill).
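For reference, here is a minimal, self-contained sketch (not part of the commit) of the updated condition, using hypothetical values for `save_freq`, `offline_steps`, and `online_steps`; it shows that the final training step is now checkpointed even when it does not coincide with the `save_freq` cadence.

# Hypothetical standalone mirror of the condition added in evaluate_and_checkpoint_if_needed.
def should_checkpoint(step, save_freq, offline_steps, online_steps, save_checkpoint=True):
    return save_checkpoint and (
        step % save_freq == 0
        or step == offline_steps + online_steps  # the last training step
    )

# With save_freq=20000 and 50000 offline + 5000 online steps, the last step (55000)
# is not a multiple of save_freq, but it still triggers a checkpoint after this change.
assert should_checkpoint(40000, save_freq=20000, offline_steps=50000, online_steps=5000)      # regular cadence
assert should_checkpoint(55000, save_freq=20000, offline_steps=50000, online_steps=5000)      # final step
assert not should_checkpoint(45000, save_freq=20000, offline_steps=50000, online_steps=5000)  # neither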
