Update checkpoint connector nemo_save to match current folder hierarchy
Signed-off-by: John St John <jstjohn@nvidia.com>
jstjohn committed Sep 27, 2024
1 parent 23c7de1 commit a4ad2d7
Showing 2 changed files with 19 additions and 13 deletions.
23 changes: 14 additions & 9 deletions nemo/lightning/io/connector.py
@@ -7,6 +7,8 @@
 
 import pytorch_lightning as pl
 from filelock import FileLock, Timeout
+from pytorch_lightning.trainer.states import TrainerFn
+
 
 # Dynamically inherit from the correct Path subclass based on the operating system.
 if os.name == 'nt':
@@ -139,26 +141,25 @@ def nemo_setup(self, model: pl.LightningModule, trainer: Optional[pl.Trainer] =
         Args:
             model (pl.LightningModule): The model to be set up.
             trainer (Optional[pl.Trainer]): The trainer to be used, if not provided a new one will be created.
         Returns
         -------
             pl.Trainer: The trainer configured with the model and strategy.
         """
-        from nemo.lightning import MegatronStrategy, Trainer
-        from nemo.lightning._strategy_lib import megatron_lazy_init_context
+        from nemo.lightning import MegatronStrategy, Trainer, _strategy_lib
 
         _trainer = trainer or Trainer(
-            devices=1, accelerator="cpu", strategy=MegatronStrategy(store_optimizer_states=False)
+            devices=1,
+            accelerator="cpu",
+            strategy=MegatronStrategy(ckpt_save_optimizer=False, always_save_context=True),
         )
 
+        _trainer.state.fn = TrainerFn.FITTING  # needed for proper save.
         _trainer.strategy.connect(model)
         _trainer.strategy.setup_environment()
 
         if not model.state_dict():
             _trainer.strategy.lazy_init = True
-            with _trainer.init_module(), megatron_lazy_init_context(model.config):
+            with _trainer.init_module(), _strategy_lib.megatron_lazy_init_context(model.config):
                 model.configure_model()
 
         return _trainer
 
     def nemo_save(self, output_path: Path, trainer: pl.Trainer, dump_io: bool = True) -> None:
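
For reference, a minimal standalone sketch of the trainer that nemo_setup now constructs; it only restates the calls visible in the hunk above and is not itself part of the change:

    from pytorch_lightning.trainer.states import TrainerFn

    from nemo.lightning import MegatronStrategy, Trainer

    # One CPU device, no optimizer state in the saved checkpoint, and the trainer
    # context always written alongside the weights.
    trainer = Trainer(
        devices=1,
        accelerator="cpu",
        strategy=MegatronStrategy(ckpt_save_optimizer=False, always_save_context=True),
    )
    trainer.state.fn = TrainerFn.FITTING  # needed for a proper save
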
@@ -170,16 +171,20 @@ def nemo_save(self, output_path: Path, trainer: pl.Trainer, dump_io: bool = True
             trainer (pl.Trainer): The trainer with the strategy to save the model.
             dump_io (bool): If True, the IO configuration will be saved to the output path.
         """
+        # Import here to avoid circular import
+        from nemo.lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
+
[Code scanning / CodeQL notice: Cyclic import. Import of module nemo.lightning.pytorch.callbacks.model_checkpoint begins an import cycle.]
         trainer.strategy._setup_optimizers = False
         trainer.strategy._init_model_parallel = False
         trainer.strategy.setup(trainer)
-        trainer.save_checkpoint(output_path)
+        output_path = Path(output_path)
+        trainer.save_checkpoint(output_path / ModelCheckpoint.WEIGHTS_PATH)
 
         from nemo.lightning.io.pl import TrainerContext
         from nemo.utils.get_rank import is_global_rank_zero
 
         if is_global_rank_zero() and dump_io:
-            TrainerContext.from_trainer(trainer).io_dump(output_path)
+            TrainerContext.from_trainer(trainer).io_dump(output_path / ModelCheckpoint.CONTEXT_PATH)
 
     def nemo_load(
         self, path: Path, trainer: Optional[pl.Trainer] = None, cpu: bool = True
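
To illustrate the folder hierarchy this commit targets, a hedged sketch of an export helper; the function name export_to_nemo, the connector and model arguments, and the output path are placeholders, while the nemo_setup/nemo_save calls and the weights/context subfolders come from the diff above:

    from pathlib import Path

    import pytorch_lightning as pl


    def export_to_nemo(connector, model: pl.LightningModule, output_path: Path) -> None:
        """Hypothetical helper: run a connector's setup/save pair with the new layout."""
        trainer = connector.nemo_setup(model)
        connector.nemo_save(output_path, trainer, dump_io=True)
        # Expected contents of output_path afterwards:
        #   output_path/weights  - distributed model checkpoint (trainer.save_checkpoint)
        #   output_path/context  - serialized TrainerContext, written on global rank 0 only
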
9 changes: 5 additions & 4 deletions nemo/lightning/pytorch/callbacks/model_checkpoint.py
@@ -57,8 +57,9 @@ class ModelCheckpoint(PTLModelCheckpoint):
         async_save: Whether to enable asynchronous checkpointing.
     """
 
-    UNFINISHED_CHECKPOINT_SUFFIX = "-unfinished"
-    WEIGHTS_PATH = "weights"
+    UNFINISHED_CHECKPOINT_SUFFIX: str = "-unfinished"
+    WEIGHTS_PATH: str = "weights"
+    CONTEXT_PATH: str = "context"
 
     def __init__(
         self,
@@ -277,7 +278,7 @@ def on_train_end(self, trainer, pl_module):
         else:
             super()._save_last_checkpoint(trainer, monitor_candidates)
         if self.save_context_on_train_end and not self.always_save_context and is_global_rank_zero():
-            TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(self.last_model_path) / "context")
+            TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(self.last_model_path) / self.CONTEXT_PATH)
         # Call parent on_train_end() to save the -last checkpoint
         super().on_train_end(trainer, pl_module)

@@ -449,7 +450,7 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str)
             trainer.save_checkpoint(ckpt_filepath, save_weights_only, storage_options=storage_options)
 
         if self.always_save_context and is_global_rank_zero():
-            TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath) / "context")
+            TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath) / self.CONTEXT_PATH)
 
         if self.async_save:
             logging.info(f'Scheduled async checkpoint save for {filepath}')
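
As a usage note, the new class-level constants make the on-disk layout discoverable from code. A small hedged sketch follows; the helper name checkpoint_parts is hypothetical, and the assumption that weights sit under WEIGHTS_PATH mirrors nemo_save above:

    from pathlib import Path

    from nemo.lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint


    def checkpoint_parts(ckpt_dir: Path) -> dict[str, Path]:
        """Hypothetical helper: resolve the subfolders of a saved checkpoint directory."""
        return {
            "weights": ckpt_dir / ModelCheckpoint.WEIGHTS_PATH,  # "weights"
            "context": ckpt_dir / ModelCheckpoint.CONTEXT_PATH,  # "context"
        }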
