Checkpoint connector bugfixes #10647

Open
wants to merge 1 commit into base: main from jstjohn/nemo_checkpoint_connector_fixes
Conversation

jstjohn (Collaborator) commented on Sep 26, 2024

What does this PR do?

Get the checkpoint connector working for BioNeMo (see https://github.com/NVIDIA/bionemo-fw-ea/pull/180).

Changelog

  • Use the new nemo2-standard /weights and /context subdirectory scheme so that checkpoint loaders work properly with checkpoints created by this method (see the layout sketch after this list).
  • Other changes necessary to allow checkpoint transformation outside of training resumption.
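
For reference, a minimal sketch of the /weights and /context layout described above (the helper function and its name are illustrative, not part of this PR):

from pathlib import Path


def make_nemo2_checkpoint_dirs(ckpt_dir: str) -> tuple[Path, Path]:
    """Hypothetical helper: lay out a checkpoint directory in the nemo2 scheme.

    Model weights live under <ckpt>/weights and the io_dump'd TrainerContext
    under <ckpt>/context, so standard checkpoint loaders can resolve both.
    """
    root = Path(ckpt_dir)
    weights_dir = root / "weights"
    context_dir = root / "context"
    weights_dir.mkdir(parents=True, exist_ok=True)
    context_dir.mkdir(parents=True, exist_ok=True)
    return weights_dir, context_dir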

jstjohn marked this pull request as draft on September 26, 2024 23:00
jstjohn self-assigned this on Sep 26, 2024
jstjohn marked this pull request as ready for review on September 26, 2024 23:31
jstjohn force-pushed the jstjohn/nemo_checkpoint_connector_fixes branch 3 times, most recently from eb4bca5 to 6069101 on September 27, 2024 22:40
Signed-off-by: John St John <jstjohn@nvidia.com>
jstjohn force-pushed the jstjohn/nemo_checkpoint_connector_fixes branch from 6069101 to a4ad2d7 on September 27, 2024 22:42
@@ -170,16 +171,20 @@
trainer (pl.Trainer): The trainer with the strategy to save the model.
dump_io (bool): If True, the IO configuration will be saved to the output path.
"""
# Import here to avoid circular import
from nemo.lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint

Check notice (Code scanning / CodeQL): Cyclic import
Import of module nemo.lightning.pytorch.callbacks.model_checkpoint begins an import cycle.
Collaborator: @jstjohn any way to avoid this?
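
One way to silence the warning is to avoid importing NeMo's model_checkpoint module here at all. A sketch of two alternatives (the helper functions are illustrative, not from this PR; Option 2 assumes NeMo's ModelCheckpoint subclasses Lightning's, which the super() call in the second hunk suggests):

from pytorch_lightning.callbacks import ModelCheckpoint as PLModelCheckpoint


def find_model_checkpoint_by_name(trainer):
    # Option 1: duck-type on the class name, so no NeMo-internal import is needed.
    for cb in trainer.callbacks:
        if type(cb).__name__ == "ModelCheckpoint":
            return cb
    return None


def find_model_checkpoint_by_base(trainer):
    # Option 2: match against Lightning's base ModelCheckpoint; the NeMo subclass
    # also satisfies this isinstance check.
    return next((cb for cb in trainer.callbacks if isinstance(cb, PLModelCheckpoint)), None)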

)

_trainer.state.fn = TrainerFn.FITTING # needed for proper save.
Collaborator: @jstjohn please clarify what you mean here: what would be missing if fn were not set to FITTING?
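
For context, a minimal sketch of the pattern (export_checkpoint is a hypothetical wrapper; only the TrainerFn assignment comes from the diff above). The assumption is that some save paths check trainer.state.fn and skip parts of the training state outside the fitting phase:

from pytorch_lightning.trainer.states import TrainerFn


def export_checkpoint(trainer, output_path: str) -> None:
    # Hypothetical standalone export: outside of fit(), trainer.state.fn is not
    # FITTING, and save paths that check it may skip parts of the training state.
    # Forcing FITTING (as the diff above does) makes a standalone export produce
    # the same on-disk layout as a checkpoint written during training.
    trainer.state.fn = TrainerFn.FITTING
    trainer.save_checkpoint(output_path)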

@@ -277,7 +278,7 @@ def on_train_end(self, trainer, pl_module):
else:
super()._save_last_checkpoint(trainer, monitor_candidates)
if self.save_context_on_train_end and not self.always_save_context and is_global_rank_zero():
TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(self.last_model_path) / "context")
TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(self.last_model_path) / self.CONTEXT_PATH)
Collaborator: @jstjohn ideally I'd like to avoid changing the checkpoint structure, but if we have to do it, let's add a comment giving an example of the use case and cherry-pick this PR so it makes it into the 24.09 release.
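
One way to document the use case the reviewer asks for is to keep the subdirectory names in class-level constants so the training-time callback and offline checkpoint transformations write to the same locations loaders expect. The class below is an illustrative sketch, not code from this PR; only the "weights" and "context" names come from the changelog:

from pathlib import Path


class CheckpointLayout:
    # Single source of truth for the nemo2 checkpoint layout, so on_train_end
    # context dumps and standalone checkpoint transformations stay in sync.
    WEIGHTS_PATH = "weights"
    CONTEXT_PATH = "context"

    @classmethod
    def context_dir(cls, ckpt_dir) -> Path:
        return Path(ckpt_dir) / cls.CONTEXT_PATH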
