Checkpoint connector bugfixes #10647

Open · wants to merge 1 commit into base: main
23 changes: 14 additions & 9 deletions nemo/lightning/io/connector.py
@@ -7,6 +7,8 @@

import pytorch_lightning as pl
from filelock import FileLock, Timeout
+from pytorch_lightning.trainer.states import TrainerFn


# Dynamically inherit from the correct Path subclass based on the operating system.
if os.name == 'nt':
@@ -139,26 +141,25 @@
Args:
model (pl.LightningModule): The model to be set up.
trainer (Optional[pl.Trainer]): The trainer to be used, if not provided a new one will be created.

Returns
-------
pl.Trainer: The trainer configured with the model and strategy.
"""
-from nemo.lightning import MegatronStrategy, Trainer
-from nemo.lightning._strategy_lib import megatron_lazy_init_context
+from nemo.lightning import MegatronStrategy, Trainer, _strategy_lib

_trainer = trainer or Trainer(
-devices=1, accelerator="cpu", strategy=MegatronStrategy(store_optimizer_states=False)
+devices=1,
+accelerator="cpu",
+strategy=MegatronStrategy(ckpt_save_optimizer=False, always_save_context=True),
)

+_trainer.state.fn = TrainerFn.FITTING  # needed for proper save.
Collaborator: @jstjohn please explain what you mean here. What would be missing if fn were not set to fitting?
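For readers following the thread: one plausible explanation (our assumption; the PR does not spell it out) is that Lightning gates optimizer-related setup and state on the trainer's function state, so a checkpoint written while the trainer is not in the fitting state can come out incomplete. A minimal sketch of that kind of guard, illustrative only:

# Hypothetical sketch, not the actual NeMo/Lightning code: save paths and
# strategies commonly branch on the trainer state, so optimizer state only
# exists (and only gets serialized) when the trainer reports it is fitting.
from pytorch_lightning.trainer.states import TrainerFn

def setup_before_save(trainer, strategy):
    # Without trainer.state.fn == TrainerFn.FITTING, guards like this are
    # skipped and the saved checkpoint can be missing expected pieces.
    if trainer.state.fn == TrainerFn.FITTING:
        strategy.setup_optimizers(trainer)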

_trainer.strategy.connect(model)
_trainer.strategy.setup_environment()

if not model.state_dict():
_trainer.strategy.lazy_init = True
-with _trainer.init_module(), megatron_lazy_init_context(model.config):
+with _trainer.init_module(), _strategy_lib.megatron_lazy_init_context(model.config):
model.configure_model()

return _trainer

def nemo_save(self, output_path: Path, trainer: pl.Trainer, dump_io: bool = True) -> None:
@@ -170,16 +171,20 @@
trainer (pl.Trainer): The trainer with the strategy to save the model.
dump_io (bool): If True, the IO configuration will be saved to the output path.
"""
+# Import here to avoid circular import
+from nemo.lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint

Check notice (Code scanning / CodeQL): Cyclic import. Import of module nemo.lightning.pytorch.callbacks.model_checkpoint begins an import cycle.
Collaborator: @jstjohn any way to avoid this?
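One common way out of such a cycle (an illustrative option, not something this PR implements) is to move the shared path constants into a small dependency-free module that both connector.py and model_checkpoint.py import, instead of deferring the import into the function body. For example, a hypothetical nemo/lightning/ckpt_utils.py:

# Hypothetical leaf module; name and location are assumptions for illustration.
# With no heavy imports of its own, both the connector and the checkpoint
# callback can depend on it without creating an import cycle.
WEIGHTS_PATH: str = "weights"
CONTEXT_PATH: str = "context"

connector.py and ModelCheckpoint would then both reference these constants without importing each other.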


trainer.strategy._setup_optimizers = False
trainer.strategy._init_model_parallel = False
trainer.strategy.setup(trainer)
-trainer.save_checkpoint(output_path)
+output_path = Path(output_path)
+trainer.save_checkpoint(output_path / ModelCheckpoint.WEIGHTS_PATH)

from nemo.lightning.io.pl import TrainerContext
from nemo.utils.get_rank import is_global_rank_zero

if is_global_rank_zero() and dump_io:
-TrainerContext.from_trainer(trainer).io_dump(output_path)
+TrainerContext.from_trainer(trainer).io_dump(output_path / ModelCheckpoint.CONTEXT_PATH)

def nemo_load(
self, path: Path, trainer: Optional[pl.Trainer] = None, cpu: bool = True
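Net effect of the connector changes: nemo_save no longer writes the checkpoint flat into output_path but splits it into a weights subdirectory (written by trainer.save_checkpoint) and a context subdirectory (dumped by TrainerContext on global rank zero when dump_io=True). A small sketch of a helper that resolves this layout, relying only on the class constants introduced in model_checkpoint.py below; the helper itself is illustrative:

from pathlib import Path

from nemo.lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint

def checkpoint_subdirs(output_path: str) -> dict[str, Path]:
    # Illustrative helper mirroring the layout nemo_save now produces.
    root = Path(output_path)
    return {
        "weights": root / ModelCheckpoint.WEIGHTS_PATH,  # model weights
        "context": root / ModelCheckpoint.CONTEXT_PATH,  # serialized io/trainer context
    }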
9 changes: 5 additions & 4 deletions nemo/lightning/pytorch/callbacks/model_checkpoint.py
@@ -57,8 +57,9 @@ class ModelCheckpoint(PTLModelCheckpoint):
async_save: Whether to enable asynchronous checkpointing.
"""

-UNFINISHED_CHECKPOINT_SUFFIX = "-unfinished"
-WEIGHTS_PATH = "weights"
+UNFINISHED_CHECKPOINT_SUFFIX: str = "-unfinished"
+WEIGHTS_PATH: str = "weights"
+CONTEXT_PATH: str = "context"

def __init__(
self,
@@ -277,7 +278,7 @@ def on_train_end(self, trainer, pl_module):
else:
super()._save_last_checkpoint(trainer, monitor_candidates)
if self.save_context_on_train_end and not self.always_save_context and is_global_rank_zero():
-TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(self.last_model_path) / "context")
+TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(self.last_model_path) / self.CONTEXT_PATH)
Collaborator: @jstjohn ideally I'd like to avoid changing the checkpoint structure, but if we have to, let's add a comment giving an example of the use-case and cherry-pick this PR so it makes it into the 24.09 release.
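To give the kind of example the reviewer is asking for (our illustration, not from the PR): downstream code can now locate the io context of a saved checkpoint through the class constant rather than a hard-coded "context" string:

from pathlib import Path

from nemo.lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint

def find_context_dir(checkpoint_dir: str) -> Path | None:
    # Illustrative helper: returns the io context subdirectory if the
    # checkpoint has one, otherwise None.
    ctx = Path(checkpoint_dir) / ModelCheckpoint.CONTEXT_PATH
    return ctx if ctx.is_dir() else None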

# Call parent on_train_end() to save the -last checkpoint
super().on_train_end(trainer, pl_module)

@@ -449,7 +450,7 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str)
trainer.save_checkpoint(ckpt_filepath, save_weights_only, storage_options=storage_options)

if self.always_save_context and is_global_rank_zero():
-TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath) / "context")
+TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath) / self.CONTEXT_PATH)

if self.async_save:
logging.info(f'Scheduled async checkpoint save for {filepath}')