diff --git a/nemo/lightning/io/connector.py b/nemo/lightning/io/connector.py index 7d81e631d6f1..be9718797310 100644 --- a/nemo/lightning/io/connector.py +++ b/nemo/lightning/io/connector.py @@ -7,6 +7,8 @@ import pytorch_lightning as pl from filelock import FileLock, Timeout +from pytorch_lightning.trainer.states import TrainerFn + # Dynamically inherit from the correct Path subclass based on the operating system. if os.name == 'nt': @@ -139,26 +141,25 @@ def nemo_setup(self, model: pl.LightningModule, trainer: Optional[pl.Trainer] = Args: model (pl.LightningModule): The model to be set up. trainer (Optional[pl.Trainer]): The trainer to be used, if not provided a new one will be created. - Returns ------- pl.Trainer: The trainer configured with the model and strategy. """ - from nemo.lightning import MegatronStrategy, Trainer - from nemo.lightning._strategy_lib import megatron_lazy_init_context + from nemo.lightning import MegatronStrategy, Trainer, _strategy_lib _trainer = trainer or Trainer( - devices=1, accelerator="cpu", strategy=MegatronStrategy(store_optimizer_states=False) + devices=1, + accelerator="cpu", + strategy=MegatronStrategy(ckpt_save_optimizer=False, always_save_context=True), ) - + _trainer.state.fn = TrainerFn.FITTING # needed for proper save. _trainer.strategy.connect(model) _trainer.strategy.setup_environment() if not model.state_dict(): _trainer.strategy.lazy_init = True - with _trainer.init_module(), megatron_lazy_init_context(model.config): + with _trainer.init_module(), _strategy_lib.megatron_lazy_init_context(model.config): model.configure_model() - return _trainer def nemo_save(self, output_path: Path, trainer: pl.Trainer, dump_io: bool = True) -> None: @@ -170,16 +171,20 @@ def nemo_save(self, output_path: Path, trainer: pl.Trainer, dump_io: bool = True trainer (pl.Trainer): The trainer with the strategy to save the model. dump_io (bool): If True, the IO configuration will be saved to the output path. """ + # Import here to avoid circular import + from nemo.lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint + trainer.strategy._setup_optimizers = False trainer.strategy._init_model_parallel = False trainer.strategy.setup(trainer) - trainer.save_checkpoint(output_path) + output_path = Path(output_path) + trainer.save_checkpoint(output_path / ModelCheckpoint.WEIGHTS_PATH) from nemo.lightning.io.pl import TrainerContext from nemo.utils.get_rank import is_global_rank_zero if is_global_rank_zero() and dump_io: - TrainerContext.from_trainer(trainer).io_dump(output_path) + TrainerContext.from_trainer(trainer).io_dump(output_path / ModelCheckpoint.CONTEXT_PATH) def nemo_load( self, path: Path, trainer: Optional[pl.Trainer] = None, cpu: bool = True diff --git a/nemo/lightning/pytorch/callbacks/model_checkpoint.py b/nemo/lightning/pytorch/callbacks/model_checkpoint.py index 6aee365a3f60..770ee44c43aa 100644 --- a/nemo/lightning/pytorch/callbacks/model_checkpoint.py +++ b/nemo/lightning/pytorch/callbacks/model_checkpoint.py @@ -57,8 +57,9 @@ class ModelCheckpoint(PTLModelCheckpoint): async_save: Whether to enable asynchronous checkpointing. """ - UNFINISHED_CHECKPOINT_SUFFIX = "-unfinished" - WEIGHTS_PATH = "weights" + UNFINISHED_CHECKPOINT_SUFFIX: str = "-unfinished" + WEIGHTS_PATH: str = "weights" + CONTEXT_PATH: str = "context" def __init__( self, @@ -277,7 +278,7 @@ def on_train_end(self, trainer, pl_module): else: super()._save_last_checkpoint(trainer, monitor_candidates) if self.save_context_on_train_end and not self.always_save_context and is_global_rank_zero(): - TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(self.last_model_path) / "context") + TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(self.last_model_path) / self.CONTEXT_PATH) # Call parent on_train_end() to save the -last checkpoint super().on_train_end(trainer, pl_module) @@ -449,7 +450,7 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) trainer.save_checkpoint(ckpt_filepath, save_weights_only, storage_options=storage_options) if self.always_save_context and is_global_rank_zero(): - TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath) / "context") + TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath) / self.CONTEXT_PATH) if self.async_save: logging.info(f'Scheduled async checkpoint save for {filepath}')