From 2bfcda153ea949a8a1ea03f463debce1c6e212a8 Mon Sep 17 00:00:00 2001
From: Anna Shors <71393111+ashors1@users.noreply.github.com>
Date: Fri, 6 Sep 2024 00:26:10 -0700
Subject: [PATCH] [NeMo-UX] checkpointing improvements (#10241)

* save model weights and artifacts to separate directories

Signed-off-by: ashors1

* add save_artifacts_on_train_end

Signed-off-by: ashors1

* Apply isort and black reformatting

Signed-off-by: ashors1

* do not save optimizer states in final checkpoint

Signed-off-by: ashors1

* WIP support for saving only last k optimizer states

Signed-off-by: ashors1

* Apply isort and black reformatting

Signed-off-by: ashors1

* minor cleanup

Signed-off-by: ashors1

* Revert support for saving last k optimizer states. This will be addressed in a subsequent PR.

* use storage_options to determine when to skip saving optimizer states

Signed-off-by: ashors1

* Apply isort and black reformatting

Signed-off-by: ashors1

* fix variable names, make checkpoint load work when optimizer states don't exist in the checkpoint

Signed-off-by: ashors1

* Apply isort and black reformatting

Signed-off-by: ashors1

* FSDP updates, provide option to save optimizer states on_train_end

Signed-off-by: ashors1

* Apply isort and black reformatting

Signed-off-by: ashors1

* simplify implementation, remove save_best_model option

Signed-off-by: ashors1

* update default value of ckpt_include_optimizer for fsdp

Signed-off-by: ashors1

* remove unused imports

Signed-off-by: ashors1

* remove unused import

Signed-off-by: ashors1

* cleanup

Signed-off-by: ashors1

* make storage_options optional again

Signed-off-by: ashors1

* fix failing tests

Signed-off-by: ashors1

* address some comments

Signed-off-by: ashors1

* use save_weights_only to determine whether to save optimizer states

Signed-off-by: ashors1

* Apply isort and black reformatting

Signed-off-by: ashors1

* add some comments

Signed-off-by: ashors1

* fix tests

Signed-off-by: ashors1

* Apply isort and black reformatting

Signed-off-by: ashors1

* Apply isort and black reformatting

Signed-off-by: ashors1

* fixes

Signed-off-by: ashors1

* Apply isort and black reformatting

Signed-off-by: ashors1

* remove unnecessary line

Signed-off-by: ashors1

---------

Signed-off-by: ashors1
Signed-off-by: ashors1
Co-authored-by: ashors1
Signed-off-by: adityavavre
---
 examples/llm/megatron_gpt_pretraining.py      |  1 -
 nemo/lightning/io/mixin.py                    |  4 +-
 .../pytorch/callbacks/model_checkpoint.py     | 83 +++++++++----------
 .../pytorch/strategies/fsdp_strategy.py       | 14 ++--
 .../pytorch/strategies/megatron_strategy.py   | 11 ++-
 nemo/lightning/resume.py                      | 16 ++++
 .../collections/llm/test_mnist_model_nemo2.py |  6 +-
 .../llm/test_mnist_model_nemo2_fsdp.py        |  6 +-
 tests/lightning/test_nemo_logger.py           | 12 ++-
 9 files changed, 93 insertions(+), 60 deletions(-)

diff --git a/examples/llm/megatron_gpt_pretraining.py b/examples/llm/megatron_gpt_pretraining.py
index cfdb6a6acb4b..bf36971d35d6 100644
--- a/examples/llm/megatron_gpt_pretraining.py
+++ b/examples/llm/megatron_gpt_pretraining.py
@@ -71,7 +71,6 @@ def get_args():
     strategy = nl.MegatronStrategy()
     checkpoint_callback = ModelCheckpoint(
         every_n_train_steps=5000,
-        enable_nemo_ckpt_io=False,
     )
     callbacks = [checkpoint_callback]
 
diff --git a/nemo/lightning/io/mixin.py b/nemo/lightning/io/mixin.py
index ff6c925a64bb..eee2d9ef751a 100644
--- a/nemo/lightning/io/mixin.py
+++ b/nemo/lightning/io/mixin.py
@@ -141,7 +141,7 @@ def io_dump(self, output: Path):
             will be stored.
         """
         output_path = Path(output)
-        local_artifacts_dir = "artifacts"
+        local_artifacts_dir = "."
         artifacts_dir = output_path / local_artifacts_dir
         artifacts_dir.mkdir(parents=True, exist_ok=True)
 
@@ -518,7 +518,7 @@ def _io_path_elements_fn(x):
     return x.__io__.__path_elements__()
 
 
-def _artifact_transform_save(cfg: fdl.Config, output_path: Path, relative_dir: Path = "artifacts"):
+def _artifact_transform_save(cfg: fdl.Config, output_path: Path, relative_dir: Path = "."):
     for artifact in getattr(cfg.__fn_or_cls__, "__io_artifacts__", []):
         current_val = getattr(cfg, artifact.attr)
         if current_val is None:
diff --git a/nemo/lightning/pytorch/callbacks/model_checkpoint.py b/nemo/lightning/pytorch/callbacks/model_checkpoint.py
index db48ded0d10d..7ebeed138d2c 100644
--- a/nemo/lightning/pytorch/callbacks/model_checkpoint.py
+++ b/nemo/lightning/pytorch/callbacks/model_checkpoint.py
@@ -40,21 +40,25 @@ class ModelCheckpoint(PTLModelCheckpoint):
         verbose: Verbosity mode.
         save_last: When ``True``, saves a `*-last` copy whenever a checkpoint file gets saved.
         save_top_k: When ``True``, saves the top-k checkpoints according to ``monitor``.
-        save_weights_only: if ``True``, then only the model's weights will be saved.
+        save_weights_only: if ``True``, then only the model's weights will be saved. Optimizer states will
+            be omitted from all checkpoints.
         mode: One of {min, max}. Whether the objective is to minimize or maximize the monitored quantity.
         every_n_epochs: Number of epochs between checkpoints.
         every_n_train_steps: Number of train steps between checkpoints.
         train_time_interval: After each interval, monitor checkpoints. Not to be used with
             ``every_n_epochs`` or ``every_n_train_steps``.
-        save_best_model: When ``True``, reloads and saves the best checkpoint.
         save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch
-        enable_nemo_ckpt_io: Whether to dump the current model model state, including the
-            config file, to allow for reproducibility of experiments.
+        save_optim_on_train_end: Whether to include the optimizer states in the final checkpoint
+            at the end of training. Only applicable when save_weights_only is ``False``.
+        always_save_context: Whether to dump the artifacts needed to reinitialize the current
+            model, trainer, and dataloader to allow for reproducibility of experiments.
+        save_context_on_train_end: Whether to dump the artifacts on_train_end regardless of whether
+            ``always_save_context`` is ``True``.
         async_save: Whether to enable asynchronous checkpointing.
-        try_restore_best_ckpt: Whether to restore the best model path.
     """
 
     UNFINISHED_CHECKPOINT_SUFFIX = "-unfinished"
+    WEIGHTS_PATH = "weights"
 
     def __init__(
         self,
@@ -67,21 +71,21 @@ def __init__(
         every_n_epochs: int = None,
         every_n_train_steps: Optional[int] = None,
         train_time_interval: Optional[timedelta] = None,
-        save_best_model: bool = False,
         save_on_train_epoch_end: Optional[bool] = False,  # Save after training, not after validation
-        enable_nemo_ckpt_io: bool = True,
-        try_restore_best_ckpt: bool = True,
+        save_optim_on_train_end: Optional[bool] = False,
+        always_save_context: bool = False,
+        save_context_on_train_end: bool = True,
        **kwargs,
     ):
-        self.save_best_model = save_best_model
-        self.previous_best_path = ""
-        self.enable_nemo_ckpt_io = enable_nemo_ckpt_io
+        self.always_save_context = always_save_context
+        self.save_context_on_train_end = save_context_on_train_end
+        self.save_optim_on_train_end = save_optim_on_train_end
+
         # Checkpoints which removal is deferred until async save is done.
         # Each element of `deferred_ckpts_to_remove` is a growing list
         # that `self._remove_checkpoint` adds to. Once `self._save_checkpoint`
         # is called, the last element is frozen and a new element is added.
         self.deferred_ckpts_to_remove: List[List[str]] = []
-        self.try_restore_best_ckpt = try_restore_best_ckpt
 
         # Call the parent class constructor with the remaining kwargs.
         super().__init__(
@@ -251,11 +255,9 @@ def setup(self, trainer, *args, **kwargs) -> None:
             self.async_save = getattr(trainer.strategy, "async_save", False)
         super().setup(trainer, *args, **kwargs)
 
-    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
-        output = super().on_save_checkpoint(trainer, pl_module, checkpoint)
-        return output
-
     def on_train_end(self, trainer, pl_module):
+        from nemo.utils.get_rank import is_global_rank_zero
+
         if trainer.fast_dev_run:
             return None
 
@@ -272,26 +274,11 @@ def on_train_end(self, trainer, pl_module):
                 logging.debug(f'Last checkpoint {self.last_model_path} already saved')
             else:
                 super()._save_last_checkpoint(trainer, monitor_candidates)
+            if self.save_context_on_train_end and not self.always_save_context and is_global_rank_zero():
+                TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(self.last_model_path) / "context")
         # Call parent on_train_end() to save the -last checkpoint
         super().on_train_end(trainer, pl_module)
 
-        # Load the best model and then re-save it
-        if self.save_best_model:
-            # wait for all processes
-            trainer.strategy.barrier("SaveBestCheckpointConnector.resume_end")
-            if self.best_model_path == "":
-                logging.warning(
-                    f"{self} was told to save the best checkpoint at the end of training, but no saved checkpoints "
-                    "were found. Saving latest model instead."
-                )
-
-            else:
-                if os.path.isdir(self.best_model_path.split('.ckpt')[0]):
-                    self.best_model_path = self.best_model_path.split('.ckpt')[0]
-                if self.try_restore_best_ckpt:
-                    self.best_model_path = trainer.strategy.broadcast(self.best_model_path)
-                trainer._checkpoint_connector.restore(self.best_model_path)
-
     def _del_model_without_trainer(self, filepath: str) -> None:
         from nemo.utils.get_rank import is_global_rank_zero
 
@@ -409,8 +396,11 @@ def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, torch.Tensor]:
         return monitor_candidates
 
     def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) -> None:
+        from nemo.utils.get_rank import is_global_rank_zero
+
         # barrier_after=True, so all ranks continue after the unfinished checkpoint marker is placed.
         # if anything goes wrong during checkpointing, we should be able to detect that data is incomplete.
+        ckpt_filepath = ckpt_to_dir(filepath) / ModelCheckpoint.WEIGHTS_PATH
         self.set_checkpoint_unfinished_marker(filepath, barrier_after=True)
 
         ema_callback = self._ema_callback(trainer)
@@ -420,17 +410,26 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str)
             if self.async_save:
                 raise ValueError('async_save with EMA not supported')
             with ema_callback.save_original_optimizer_state(trainer):
-                super()._save_checkpoint(trainer, filepath)
+                super()._save_checkpoint(trainer, ckpt_filepath)
 
             # save EMA copy of the model as well.
             with ema_callback.save_ema_model(trainer):
-                rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
-                filepath = self._ema_format_filepath(filepath)
+                rank_zero_info(f"Saving EMA weights to separate checkpoint {ckpt_filepath}")
+                ckpt_filepath = self._ema_format_filepath(ckpt_filepath)
                 if self.verbose:
-                    rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
-                super()._save_checkpoint(trainer, filepath)
+                    rank_zero_info(f"Saving EMA weights to separate checkpoint {ckpt_filepath}")
+                super()._save_checkpoint(trainer, ckpt_filepath)
                 self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True)
         else:
+            ## Determine whether to include optimizer states in the checkpoint
+            ## optimizer states are included when
+            ## 1. save_weights_only is False and
+            ## 2. either save_optim_on_train_end is True, or save_optim_on_train_end is False but the checkpoint
+            ##    is an intermediate checkpoint.
+            save_weights_only = self.save_weights_only or (
+                not self.save_optim_on_train_end and trainer.global_step == trainer.max_steps
+            )
+
             # Async save passes the finalization function to checkpoint_io,
             # sync save calls the finalization function immediately after save.
             finalize_fn = self._get_finalize_save_checkpoint_callback(trainer, filepath, trainer.global_step)
@@ -445,13 +444,11 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str)
                 self.deferred_ckpts_to_remove.append([])
             else:
                 storage_options = None
-            trainer.save_checkpoint(filepath, self.save_weights_only, storage_options=storage_options)
+            trainer.save_checkpoint(ckpt_filepath, save_weights_only, storage_options=storage_options)
 
-            ## NOTE: saving context happens synchronously always
-            from nemo.utils.get_rank import is_global_rank_zero
+            if self.always_save_context and is_global_rank_zero():
+                TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath) / "context")
 
-            if self.enable_nemo_ckpt_io and is_global_rank_zero():
-                TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath))
             if self.async_save:
                 logging.info(f'Scheduled async checkpoint save for {filepath}')
             else:
diff --git a/nemo/lightning/pytorch/strategies/fsdp_strategy.py b/nemo/lightning/pytorch/strategies/fsdp_strategy.py
index 24087f80aae4..2a210c9bd7f0 100644
--- a/nemo/lightning/pytorch/strategies/fsdp_strategy.py
+++ b/nemo/lightning/pytorch/strategies/fsdp_strategy.py
@@ -208,11 +208,15 @@ def save_checkpoint(
         checkpoint["sharded_state_dict"] = pyt_to_mcore_state_dict(checkpoint.pop("state_dict"))
         checkpoint["state_dict"] = OrderedDict([])
 
-        # TODO: do we still need to keep this?
-        for optim_state in checkpoint['optimizer_states']:
-            optim_state.pop("state")
-
-        if self.trainer.state.fn == TrainerFn.FITTING and self.ckpt_save_optimizer:
+        ## replace unsharded optimizer_states with sharded dict.
+        ## note that if trainer.save_checkpoint(path, save_weights_only=True) is called,
+        ## the checkpoint will contain only model weights. Optimizer states will be omitted.
+        if (
+            "optimizer_states" in checkpoint
+            and self.trainer.state.fn == TrainerFn.FITTING
+            and self.ckpt_save_optimizer
+        ):
+            del checkpoint["optimizer_states"]
             checkpoint['optimizer'] = get_optimizer_state_dict(self.model, self.optimizers)
             pyt_to_mcore_state_dict(checkpoint['optimizer']['state'], prefix="optimizer.state.")
 
diff --git a/nemo/lightning/pytorch/strategies/megatron_strategy.py b/nemo/lightning/pytorch/strategies/megatron_strategy.py
index fae6df5be207..4bf8c42ece02 100644
--- a/nemo/lightning/pytorch/strategies/megatron_strategy.py
+++ b/nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -625,7 +625,16 @@ def save_checkpoint(
         # retrieve `sharded_state_dict` if it has not already been configured in `on_save_checkpoint`
         if "sharded_state_dict" not in checkpoint:
             checkpoint["sharded_state_dict"] = self.megatron_parallel.sharded_state_dict()
-        if self.trainer.state.fn == TrainerFn.FITTING and self.ckpt_save_optimizer:
+
+        ## replace unsharded optimizer_states with sharded dict.
+        ## note that if trainer.save_checkpoint(path, save_weights_only=True) is called,
+        ## the checkpoint will contain only model weights. Optimizer states will be omitted.
+        if (
+            "optimizer_states" in checkpoint
+            and self.trainer.state.fn == TrainerFn.FITTING
+            and self.ckpt_save_optimizer
+        ):
+            del checkpoint["optimizer_states"]
             checkpoint["optimizer"] = [self.optimizer_sharded_state_dict()]
 
         self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
diff --git a/nemo/lightning/resume.py b/nemo/lightning/resume.py
index bce1964b6699..c8cefb4dd8d3 100644
--- a/nemo/lightning/resume.py
+++ b/nemo/lightning/resume.py
@@ -66,6 +66,11 @@ class AutoResume:
     resume_past_end: bool = False
     resume_ignore_no_checkpoint: bool = False
 
+    WEIGHTS_PATH = "weights"
+
+    def get_model_weights_path(self, path):
+        return Path(path) / self.WEIGHTS_PATH
+
     def setup(self, trainer: Union[pl.Trainer, fl.Fabric], model=None):
         if isinstance(trainer, fl.Fabric):
             raise NotImplementedError("Fabric is not supported yet.")
@@ -90,6 +95,7 @@ def setup(self, trainer: Union[pl.Trainer, fl.Fabric], model=None):
     def _try_import_model(
         self, model: Optional[io.ConnectorMixin], path: str, adapter_path: Optional[str] = None
     ) -> BasePath:
+
         if model is None:
             raise ValueError("Model is needed to import checkpoint from HF or other non-NeMo checkpoint format.")
         try:
@@ -99,6 +105,11 @@ def _try_import_model(
         new_path = path
 
         if adapter_path:
+
+            maybe_model_weights_path = self.get_model_weights_path(adapter_path)
+            if os.path.isdir(maybe_model_weights_path):
+                adapter_path = maybe_model_weights_path
+
             new_path = AdapterPath(Path(adapter_path), base_model_path=new_path)
 
         if isinstance(new_path, str):
@@ -211,6 +222,11 @@ def get_trainer_ckpt_path(self, model: Optional[io.ConnectorMixin] = None) -> Op
         if self.resume_if_exists:
             checkpoint = self._find_trainer_ckpt_path()
 
+        if checkpoint:
+            maybe_model_weights_path = self.get_model_weights_path(checkpoint)
+            if os.path.isdir(maybe_model_weights_path):
+                checkpoint = maybe_model_weights_path
+
         if checkpoint:
             if self.adapter_path:
                 return AdapterPath(Path(self.adapter_path), base_model_path=checkpoint)
diff --git a/tests/collections/llm/test_mnist_model_nemo2.py b/tests/collections/llm/test_mnist_model_nemo2.py
index c9507ab66bb3..616d845f590f 100644
--- a/tests/collections/llm/test_mnist_model_nemo2.py
+++ b/tests/collections/llm/test_mnist_model_nemo2.py
@@ -496,13 +496,12 @@ def run_train_mnist_litautoencoder_with_megatron_strategy_single_gpu():
         # Configure our custom Checkpointer
         name = "test_experiment"
         checkpoint_callback = nl_callbacks.ModelCheckpoint(
-            save_best_model=True,
             save_last=True,
             monitor="val_loss",
             save_top_k=1,
             every_n_train_steps=5,
             # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
-            enable_nemo_ckpt_io=True,
+            always_save_context=True,
         )
         root_dir = tmpdir
         save_dir = root_dir / name
@@ -571,6 +570,9 @@ def run_train_mnist_litautoencoder_with_megatron_strategy_single_gpu():
         ckpt_path = checkpoint_callback.last_model_path.replace(
             ".ckpt", ""
         )  # strip .ckpt off the end of the last path
+        ckpt_path = (
+            Path(ckpt_path) / "weights"
+        )  ## weights are saved to the "weights" directory within the checkpoint
         assert Path(
             ckpt_path
diff --git a/tests/collections/llm/test_mnist_model_nemo2_fsdp.py b/tests/collections/llm/test_mnist_model_nemo2_fsdp.py
index 025f589e2f39..3ef0f14f10d8 100644
--- a/tests/collections/llm/test_mnist_model_nemo2_fsdp.py
+++ b/tests/collections/llm/test_mnist_model_nemo2_fsdp.py
@@ -519,13 +519,12 @@ def run_train_mnist_litautoencoder_with_fsdp_strategy_single_gpu():
         # Configure our custom Checkpointer
         name = "test_experiment"
         checkpoint_callback = nl_callbacks.ModelCheckpoint(
-            save_best_model=True,
             save_last=True,
             monitor="val_loss",
             save_top_k=1,
             every_n_train_steps=5,
             # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
-            enable_nemo_ckpt_io=True,
+            always_save_context=True,
         )
         root_dir = tmpdir
         save_dir = root_dir / name
@@ -583,6 +582,9 @@ def run_train_mnist_litautoencoder_with_fsdp_strategy_single_gpu():
         ckpt_path = checkpoint_callback.last_model_path.replace(
             ".ckpt", ""
         )  # strip .ckpt off the end of the last path
+        ckpt_path = (
+            Path(ckpt_path) / "weights"
+        )  ## weights are saved to the "weights" directory within the checkpoint
         assert Path(
             ckpt_path
diff --git a/tests/lightning/test_nemo_logger.py b/tests/lightning/test_nemo_logger.py
index 54636f56472a..387d3540930f 100644
--- a/tests/lightning/test_nemo_logger.py
+++ b/tests/lightning/test_nemo_logger.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import os
+import shutil
 import time
 from pathlib import Path
 from unittest.mock import patch
@@ -159,7 +160,10 @@ def test_resume(self, trainer, tmp_path):
         ## if there are multiple "-last" checkpoints, choose the most recent one
         Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel--last").mkdir()
         time.sleep(1)  ## sleep for a second so the checkpoints are created at different times
-        Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last").mkdir()
+        ## make a "weights" dir within the checkpoint
+        Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last" / "weights").mkdir(
+            parents=True
+        )
         time.sleep(1)
         # unfinished last, that should be ignored
         Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel3--last").mkdir()
@@ -169,11 +173,11 @@ def test_resume(self, trainer, tmp_path):
             resume_from_directory=Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints"),
             resume_if_exists=True,
         ).setup(trainer)
+        ## if "weights" exists, we should restore from there
         assert str(trainer.ckpt_path) == str(
-            Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last")
+            Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last" / "weights")
         )
 
-        # Finally succeed
         logger = nl.NeMoLogger(
             name="default",
             dir=str(tmp_path) + "/test_resume",
@@ -181,7 +185,7 @@ def test_resume(self, trainer, tmp_path):
             use_datetime_version=False,
         )
         logger.setup(trainer)
-        Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last").rmdir()
+        shutil.rmtree(Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last"))
         nl.AutoResume(
             resume_if_exists=True,
         ).setup(trainer)
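
Note: the sketch below is illustrative and not part of the patch. It shows how the callback flags introduced above are meant to be combined in a pretraining script; the trainer wiring and import paths follow the ones used in examples/llm/megatron_gpt_pretraining.py and are assumptions, not additions made by this change.

    # Illustrative sketch only, assuming the imports used in the NeMo examples.
    from nemo import lightning as nl
    from nemo.lightning.pytorch.callbacks import ModelCheckpoint

    checkpoint_callback = ModelCheckpoint(
        every_n_train_steps=5000,
        save_weights_only=False,         # intermediate checkpoints keep optimizer states
        save_optim_on_train_end=False,   # the final checkpoint drops optimizer states
        always_save_context=True,        # dump the io context next to every checkpoint
        save_context_on_train_end=True,  # default: also dump the context on_train_end
    )

    trainer = nl.Trainer(
        strategy=nl.MegatronStrategy(),
        callbacks=[checkpoint_callback],
    )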
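The optimizer-state decision added to ModelCheckpoint._save_checkpoint reduces to a single boolean expression. A standalone restatement follows; the helper name is hypothetical and used only to exercise the expression from the diff.

    def weights_only(save_weights_only: bool, save_optim_on_train_end: bool, global_step: int, max_steps: int) -> bool:
        # True means the checkpoint omits optimizer states (same expression as in the diff).
        return save_weights_only or (not save_optim_on_train_end and global_step == max_steps)

    # Intermediate checkpoints keep optimizer states when save_weights_only is False...
    assert weights_only(False, False, global_step=500, max_steps=1000) is False
    # ...the final checkpoint drops them unless save_optim_on_train_end=True.
    assert weights_only(False, False, global_step=1000, max_steps=1000) is True
    assert weights_only(False, True, global_step=1000, max_steps=1000) is False
    # save_weights_only=True omits optimizer states from every checkpoint.
    assert weights_only(True, False, global_step=500, max_steps=1000) is True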
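With this patch a checkpoint directory is split into a "weights" subdirectory (model and, optionally, optimizer state) and a "context" subdirectory (the io_dump artifacts), and AutoResume prefers "weights" when it exists. A small sketch of that resolution, assuming the layout exercised by the tests above; the function name and example paths are hypothetical.

    from pathlib import Path

    WEIGHTS_PATH = "weights"  # mirrors ModelCheckpoint.WEIGHTS_PATH / AutoResume.WEIGHTS_PATH in the diff

    def resolve_restore_path(checkpoint_dir) -> Path:
        # Prefer <checkpoint>/weights when present, otherwise fall back to the checkpoint
        # directory itself, as AutoResume.get_trainer_ckpt_path does after this change.
        maybe_weights = Path(checkpoint_dir) / WEIGHTS_PATH
        return maybe_weights if maybe_weights.is_dir() else Path(checkpoint_dir)

    # Example layout written by this version of the callback (paths made up for illustration):
    #   .../checkpoints/epoch=0-step=1000-last/
    #       weights/   <- sharded model (and optionally optimizer) state
    #       context/   <- serialized TrainerContext when context saving is enabled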