From d98db0a875777b8c8db8efb7846d1d94c4ac7b14 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 22 Aug 2024 11:17:16 -0700 Subject: [PATCH 01/31] save model weights and artifacts to separate directories Signed-off-by: ashors1 --- nemo/lightning/io/mixin.py | 2 +- .../pytorch/callbacks/model_checkpoint.py | 17 +++++++++-------- nemo/lightning/resume.py | 2 +- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/nemo/lightning/io/mixin.py b/nemo/lightning/io/mixin.py index d0d4d0243ff7..33874851af92 100644 --- a/nemo/lightning/io/mixin.py +++ b/nemo/lightning/io/mixin.py @@ -136,7 +136,7 @@ def io_dump(self, output: Path): will be stored. """ output_path = Path(output) - artifacts_dir = output_path / "artifacts" + artifacts_dir = output_path artifacts_dir.mkdir(parents=True, exist_ok=True) # Store artifacts directory in thread-local storage diff --git a/nemo/lightning/pytorch/callbacks/model_checkpoint.py b/nemo/lightning/pytorch/callbacks/model_checkpoint.py index db48ded0d10d..7009423e01f4 100644 --- a/nemo/lightning/pytorch/callbacks/model_checkpoint.py +++ b/nemo/lightning/pytorch/callbacks/model_checkpoint.py @@ -69,7 +69,7 @@ def __init__( train_time_interval: Optional[timedelta] = None, save_best_model: bool = False, save_on_train_epoch_end: Optional[bool] = False, # Save after training, not after validation - enable_nemo_ckpt_io: bool = True, + enable_nemo_ckpt_io: bool = False, try_restore_best_ckpt: bool = True, **kwargs, ): @@ -411,6 +411,7 @@ def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, torch.Tensor]: def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) -> None: # barrier_after=True, so all ranks continue after the unfinished checkpoint marker is placed. # if anything goes wrong during checkpointing, we should be able to detect that data is incomplete. + ckpt_filepath = ckpt_to_dir(filepath) / "model_weights" self.set_checkpoint_unfinished_marker(filepath, barrier_after=True) ema_callback = self._ema_callback(trainer) @@ -420,15 +421,15 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) if self.async_save: raise ValueError('async_save with EMA not supported') with ema_callback.save_original_optimizer_state(trainer): - super()._save_checkpoint(trainer, filepath) + super()._save_checkpoint(trainer, ckpt_filepath) # save EMA copy of the model as well. 
with ema_callback.save_ema_model(trainer): - rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}") - filepath = self._ema_format_filepath(filepath) + rank_zero_info(f"Saving EMA weights to separate checkpoint {ckpt_filepath}") + ckpt_filepath = self._ema_format_filepath(ckpt_filepath) if self.verbose: - rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}") - super()._save_checkpoint(trainer, filepath) + rank_zero_info(f"Saving EMA weights to separate checkpoint {ckpt_filepath}") + super()._save_checkpoint(trainer, ckpt_filepath) self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True) else: # Async save passes the finalization function to checkpoint_io, @@ -445,13 +446,13 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) self.deferred_ckpts_to_remove.append([]) else: storage_options = None - trainer.save_checkpoint(filepath, self.save_weights_only, storage_options=storage_options) + trainer.save_checkpoint(ckpt_filepath, self.save_weights_only, storage_options=storage_options) ## NOTE: saving context happens synchronously always from nemo.utils.get_rank import is_global_rank_zero if self.enable_nemo_ckpt_io and is_global_rank_zero(): - TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath)) + TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath) / "artifacts") if self.async_save: logging.info(f'Scheduled async checkpoint save for {filepath}') else: diff --git a/nemo/lightning/resume.py b/nemo/lightning/resume.py index ca87628d699e..4d48835ec439 100644 --- a/nemo/lightning/resume.py +++ b/nemo/lightning/resume.py @@ -151,7 +151,7 @@ def nemo_path(self, model=None) -> Optional[Path]: if checkpoint: if self.adapter_path: return AdapterPath(checkpoint, adapter_path=Path(self.adapter_path)) - return Path(checkpoint) + return Path(checkpoint) / "model_weights" return None From ca04f47d9806477fd400933c4f887c0c47c7ae6b Mon Sep 17 00:00:00 2001 From: ashors1 Date: Sat, 24 Aug 2024 15:51:53 -0700 Subject: [PATCH 02/31] add save_artifacts_on_train_end Signed-off-by: ashors1 --- .../pytorch/callbacks/model_checkpoint.py | 35 +++++++++++-------- nemo/lightning/resume.py | 8 ++++- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/nemo/lightning/pytorch/callbacks/model_checkpoint.py b/nemo/lightning/pytorch/callbacks/model_checkpoint.py index 7009423e01f4..5bd554cf225b 100644 --- a/nemo/lightning/pytorch/callbacks/model_checkpoint.py +++ b/nemo/lightning/pytorch/callbacks/model_checkpoint.py @@ -31,7 +31,6 @@ from nemo.utils.app_state import AppState from nemo.utils.model_utils import ckpt_to_dir - class ModelCheckpoint(PTLModelCheckpoint): """Light wrapper around Lightning's ModelCheckpoint to force a saved checkpoint on train_end. Adds support for asyncronous checkpointing and provides some additional logic to clean up invalid checkpoints @@ -48,13 +47,16 @@ class ModelCheckpoint(PTLModelCheckpoint): ``every_n_epochs`` or ``every_n_train_steps``. save_best_model: When ``True``, reloads and saves the best checkpoint. save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch - enable_nemo_ckpt_io: Whether to dump the current model model state, including the - config file, to allow for reproducibility of experiments. + always_save_artifacts: Whether to dump the artifacts needed to reinintialize the current + model, trainer, and dataloader to allow for reproducibility of experiments. 
+ save_artifacts_on_train_end: Whether to dump the artifacts on_train_end regardless of whether + ``always_save_artifacts`` is ``True``. async_save: Whether to enable asynchronous checkpointing. try_restore_best_ckpt: Whether to restore the best model path. """ UNFINISHED_CHECKPOINT_SUFFIX = "-unfinished" + MODEL_WEIGHTS_PATH = "model_weights" def __init__( self, @@ -69,13 +71,16 @@ def __init__( train_time_interval: Optional[timedelta] = None, save_best_model: bool = False, save_on_train_epoch_end: Optional[bool] = False, # Save after training, not after validation - enable_nemo_ckpt_io: bool = False, + always_save_artifacts: bool = False, + save_artifacts_on_train_end: bool = True, try_restore_best_ckpt: bool = True, **kwargs, ): self.save_best_model = save_best_model self.previous_best_path = "" - self.enable_nemo_ckpt_io = enable_nemo_ckpt_io + self.always_save_artifacts = always_save_artifacts + self.save_artifacts_on_train_end = save_artifacts_on_train_end + # Checkpoints which removal is deferred until async save is done. # Each element of `deferred_ckpts_to_remove` is a growing list # that `self._remove_checkpoint` adds to. Once `self._save_checkpoint` @@ -251,11 +256,9 @@ def setup(self, trainer, *args, **kwargs) -> None: self.async_save = getattr(trainer.strategy, "async_save", False) super().setup(trainer, *args, **kwargs) - def on_save_checkpoint(self, trainer, pl_module, checkpoint): - output = super().on_save_checkpoint(trainer, pl_module, checkpoint) - return output - def on_train_end(self, trainer, pl_module): + from nemo.utils.get_rank import is_global_rank_zero + if trainer.fast_dev_run: return None @@ -272,6 +275,8 @@ def on_train_end(self, trainer, pl_module): logging.debug(f'Last checkpoint {self.last_model_path} already saved') else: super()._save_last_checkpoint(trainer, monitor_candidates) + if self.save_artifacts_on_train_end and not self.always_save_artifacts and is_global_rank_zero(): + TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(self.last_model_path) / "artifacts") # Call parent on_train_end() to save the -last checkpoint super().on_train_end(trainer, pl_module) @@ -287,7 +292,7 @@ def on_train_end(self, trainer, pl_module): else: if os.path.isdir(self.best_model_path.split('.ckpt')[0]): - self.best_model_path = self.best_model_path.split('.ckpt')[0] + self.best_model_path = Path(self.best_model_path.split('.ckpt')[0]) / ModelCheckpoint.MODEL_WEIGHTS_PATH if self.try_restore_best_ckpt: self.best_model_path = trainer.strategy.broadcast(self.best_model_path) trainer._checkpoint_connector.restore(self.best_model_path) @@ -409,9 +414,11 @@ def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, torch.Tensor]: return monitor_candidates def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) -> None: + from nemo.utils.get_rank import is_global_rank_zero + # barrier_after=True, so all ranks continue after the unfinished checkpoint marker is placed. # if anything goes wrong during checkpointing, we should be able to detect that data is incomplete. 
- ckpt_filepath = ckpt_to_dir(filepath) / "model_weights" + ckpt_filepath = ckpt_to_dir(filepath) / ModelCheckpoint.MODEL_WEIGHTS_PATH self.set_checkpoint_unfinished_marker(filepath, barrier_after=True) ema_callback = self._ema_callback(trainer) @@ -448,11 +455,9 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) storage_options = None trainer.save_checkpoint(ckpt_filepath, self.save_weights_only, storage_options=storage_options) - ## NOTE: saving context happens synchronously always - from nemo.utils.get_rank import is_global_rank_zero - - if self.enable_nemo_ckpt_io and is_global_rank_zero(): + if self.always_save_artifacts and is_global_rank_zero(): TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath) / "artifacts") + if self.async_save: logging.info(f'Scheduled async checkpoint save for {filepath}') else: diff --git a/nemo/lightning/resume.py b/nemo/lightning/resume.py index 4d48835ec439..9afe7feb9c82 100644 --- a/nemo/lightning/resume.py +++ b/nemo/lightning/resume.py @@ -37,6 +37,8 @@ class AutoResume(Resume, io.IOMixin): checkpoints in NeMo. """ + MAYBE_MODEL_WEIGHTS_PATH = "model_weights" + def __init__( self, path: Optional[str] = None, ## old resume_from_checkpoint @@ -151,7 +153,11 @@ def nemo_path(self, model=None) -> Optional[Path]: if checkpoint: if self.adapter_path: return AdapterPath(checkpoint, adapter_path=Path(self.adapter_path)) - return Path(checkpoint) / "model_weights" + + model_weights_path = Path(checkpoint) / AutoResume.MAYBE_MODEL_WEIGHTS_PATH + if os.path.isdir(model_weights_path): + return model_weights_path + return Path(checkpoint) return None From 9264eee7caba849949eb56694d37b3ae1dedfe0e Mon Sep 17 00:00:00 2001 From: ashors1 Date: Sat, 24 Aug 2024 22:52:47 +0000 Subject: [PATCH 03/31] Apply isort and black reformatting Signed-off-by: ashors1 --- nemo/lightning/pytorch/callbacks/model_checkpoint.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/nemo/lightning/pytorch/callbacks/model_checkpoint.py b/nemo/lightning/pytorch/callbacks/model_checkpoint.py index 5bd554cf225b..14c844f0dae8 100644 --- a/nemo/lightning/pytorch/callbacks/model_checkpoint.py +++ b/nemo/lightning/pytorch/callbacks/model_checkpoint.py @@ -31,6 +31,7 @@ from nemo.utils.app_state import AppState from nemo.utils.model_utils import ckpt_to_dir + class ModelCheckpoint(PTLModelCheckpoint): """Light wrapper around Lightning's ModelCheckpoint to force a saved checkpoint on train_end. 
Adds support for asyncronous checkpointing and provides some additional logic to clean up invalid checkpoints @@ -292,7 +293,9 @@ def on_train_end(self, trainer, pl_module): else: if os.path.isdir(self.best_model_path.split('.ckpt')[0]): - self.best_model_path = Path(self.best_model_path.split('.ckpt')[0]) / ModelCheckpoint.MODEL_WEIGHTS_PATH + self.best_model_path = ( + Path(self.best_model_path.split('.ckpt')[0]) / ModelCheckpoint.MODEL_WEIGHTS_PATH + ) if self.try_restore_best_ckpt: self.best_model_path = trainer.strategy.broadcast(self.best_model_path) trainer._checkpoint_connector.restore(self.best_model_path) From 1a5a4574362f658bd104dbb1e5125f985c262dec Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 27 Aug 2024 20:45:28 -0700 Subject: [PATCH 04/31] do not save optimizer states in final checkpoint Signed-off-by: ashors1 --- nemo/lightning/pytorch/callbacks/model_checkpoint.py | 7 +++++++ nemo/lightning/pytorch/strategies.py | 10 +++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/nemo/lightning/pytorch/callbacks/model_checkpoint.py b/nemo/lightning/pytorch/callbacks/model_checkpoint.py index 14c844f0dae8..a0a7a646f5ff 100644 --- a/nemo/lightning/pytorch/callbacks/model_checkpoint.py +++ b/nemo/lightning/pytorch/callbacks/model_checkpoint.py @@ -263,6 +263,9 @@ def on_train_end(self, trainer, pl_module): if trainer.fast_dev_run: return None + ## Do not include optimizer states in final checkpoint + trainer.strategy.ckpt_include_optimizer = False + # check if we need to save a last checkpoint manually as validation isn't always run based on the interval if self.save_last and trainer.val_check_interval != 0: should_save_last_checkpoint = False @@ -427,6 +430,10 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) self._last_global_step_saved = trainer.global_step + ## Do not include optimizer states in final checkpoint + if trainer.global_step == trainer.max_steps: + trainer.strategy.ckpt_include_optimizer = False + if ema_callback is not None: if self.async_save: raise ValueError('async_save with EMA not supported') diff --git a/nemo/lightning/pytorch/strategies.py b/nemo/lightning/pytorch/strategies.py index d6ef18770fa4..f668d7298e69 100644 --- a/nemo/lightning/pytorch/strategies.py +++ b/nemo/lightning/pytorch/strategies.py @@ -177,7 +177,7 @@ def __init__( self.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size self.sequence_parallel = sequence_parallel self.lazy_init = lazy_init - self.ckpt_include_optimizer = ckpt_include_optimizer + self._ckpt_include_optimizer = ckpt_include_optimizer self.pipeline_dtype = pipeline_dtype self._setup_optimizers = setup_optimizers self._init_model_parallel = init_model_parallel @@ -602,6 +602,14 @@ def optimizer_sharded_state_dict(self, is_loading=False): self.megatron_parallel, optimizer, is_loading=is_loading, sharding_type=sharding_type ) + @property + def ckpt_include_optimizer(self): + return self._ckpt_include_optimizer + + @ckpt_include_optimizer.setter + def ckpt_include_optimizer(self, val): + self._ckpt_include_optimizer = val + @override def save_checkpoint( self, checkpoint: Dict[str, Any], filepath: Union[str, Path], storage_options: Optional[Any] = None From c9302fa36bc0634f59325708a639433622541b03 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 28 Aug 2024 16:38:53 -0700 Subject: [PATCH 05/31] WIP support for saving only last k optimizer states Signed-off-by: ashors1 --- .../pytorch/callbacks/model_checkpoint.py | 76 ++++++++++++++++++- 1 file changed, 
72 insertions(+), 4 deletions(-) diff --git a/nemo/lightning/pytorch/callbacks/model_checkpoint.py b/nemo/lightning/pytorch/callbacks/model_checkpoint.py index a0a7a646f5ff..41da35157c42 100644 --- a/nemo/lightning/pytorch/callbacks/model_checkpoint.py +++ b/nemo/lightning/pytorch/callbacks/model_checkpoint.py @@ -15,6 +15,7 @@ import os import re import shutil +from collections import OrderedDict from datetime import timedelta from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Union @@ -48,6 +49,7 @@ class ModelCheckpoint(PTLModelCheckpoint): ``every_n_epochs`` or ``every_n_train_steps``. save_best_model: When ``True``, reloads and saves the best checkpoint. save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch + save_optim_on_train_end: TODO always_save_artifacts: Whether to dump the artifacts needed to reinintialize the current model, trainer, and dataloader to allow for reproducibility of experiments. save_artifacts_on_train_end: Whether to dump the artifacts on_train_end regardless of whether @@ -72,6 +74,8 @@ def __init__( train_time_interval: Optional[timedelta] = None, save_best_model: bool = False, save_on_train_epoch_end: Optional[bool] = False, # Save after training, not after validation + save_last_n_optim_states: int = -1, + save_optim_on_train_end: Optional[bool] = False, always_save_artifacts: bool = False, save_artifacts_on_train_end: bool = True, try_restore_best_ckpt: bool = True, @@ -81,12 +85,15 @@ def __init__( self.previous_best_path = "" self.always_save_artifacts = always_save_artifacts self.save_artifacts_on_train_end = save_artifacts_on_train_end + self.save_last_n_optim_states = save_last_n_optim_states + self.save_optim_on_train_end = save_optim_on_train_end # Checkpoints which removal is deferred until async save is done. # Each element of `deferred_ckpts_to_remove` is a growing list # that `self._remove_checkpoint` adds to. Once `self._save_checkpoint` # is called, the last element is frozen and a new element is added. self.deferred_ckpts_to_remove: List[List[str]] = [] + self.optimizerless_checkpoints = set() self.try_restore_best_ckpt = try_restore_best_ckpt # Call the parent class constructor with the remaining kwargs. 
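The bookkeeping this WIP commit adds in `_drop_optimizer_states` further down reduces to a small selection rule: once more than `save_last_n_optim_states` optimizer-bearing checkpoints exist, the oldest one that still carries optimizer state is rewritten without it. A minimal, self-contained sketch of that rule, with the helper name and the `step=...` entries as placeholders (the real code lists checkpoint directories, sorts them by mtime, and skips `-last` checkpoints):

from typing import List, Optional

def checkpoint_to_strip(checkpoints: List[str], save_last_n_optim_states: int) -> Optional[str]:
    # `checkpoints` is assumed sorted oldest-to-newest and already filtered of
    # `-last` entries and of checkpoints whose optimizer state was dropped earlier.
    if save_last_n_optim_states < 0 or len(checkpoints) <= save_last_n_optim_states:
        return None
    return checkpoints[len(checkpoints) - save_last_n_optim_states - 1]

assert checkpoint_to_strip(["step=100", "step=200", "step=300"], save_last_n_optim_states=2) == "step=100"
assert checkpoint_to_strip(["step=100", "step=200"], save_last_n_optim_states=2) is None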
@@ -263,8 +270,9 @@ def on_train_end(self, trainer, pl_module): if trainer.fast_dev_run: return None - ## Do not include optimizer states in final checkpoint - trainer.strategy.ckpt_include_optimizer = False + if not self.save_optim_on_train_end: + ## Do not include optimizer states in final checkpoint + trainer.strategy.ckpt_include_optimizer = False # check if we need to save a last checkpoint manually as validation isn't always run based on the interval if self.save_last and trainer.val_check_interval != 0: @@ -431,7 +439,7 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) self._last_global_step_saved = trainer.global_step ## Do not include optimizer states in final checkpoint - if trainer.global_step == trainer.max_steps: + if trainer.global_step == trainer.max_steps and not self.save_optim_on_train_end: trainer.strategy.ckpt_include_optimizer = False if ema_callback is not None: @@ -465,6 +473,7 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) storage_options = None trainer.save_checkpoint(ckpt_filepath, self.save_weights_only, storage_options=storage_options) + print(f'{filepath=}') if self.always_save_artifacts and is_global_rank_zero(): TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath) / "artifacts") @@ -472,7 +481,59 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) logging.info(f'Scheduled async checkpoint save for {filepath}') else: finalize_fn() - + + ## TODO: storage_options + def _drop_optimizer_states(self, trainer, filepath: Union[str, Path]) -> None: + # Get list of saved checkpoints + checkpoints = self._get_checkpoints_list(filepath) + + ## TODO: this doesn't work. ModelChck + checkpoint_index = len(checkpoints) - self.save_last_n_optim_states - 1 + if len(checkpoints) > self.save_last_n_optim_states: + checkpoint_path = checkpoints[checkpoint_index] + + logging.info(f"Loading '{checkpoint_path}' checkpoint to drop optimizer states...") + ## TODO: clean + checkpoint = trainer.strategy.load_checkpoint(Path(checkpoint_path) / ModelCheckpoint.MODEL_WEIGHTS_PATH) + + # Remove the checkpoint version with optimizer states + self._remove_checkpoint(trainer, checkpoint_path) + + checkpoint['optimizer'] = [None] + checkpoint['state_dict'] = OrderedDict([]) + checkpoint['sharded_state_dict'] = trainer.lightning_module.sharded_state_dict() + + ## TODO: debug -- only working for some steps right now + # Save the checkpoint without optimizer states + trainer.strategy.save_checkpoint(checkpoint, Path(checkpoint_path) / ModelCheckpoint.MODEL_WEIGHTS_PATH) + self.optimizerless_checkpoints.add(checkpoint_path) + + ## TODO + '''if HAVE_MODELOPT and hasattr(self.lightning_module, "get_model_module_list"): + save_sharded_modelopt_state( + self.lightning_module.get_model_module_list(), + checkpoint_path, + self.unwrapped_checkpoint_io.save_sharded_strategy, + prefix="model.", + )''' + + logging.info(f"Successfully dropped optimizer states for '{checkpoint_path}' checkpoint.") + + def _get_checkpoints_list(self, filepath: Union[str, Path]) -> List[str]: + # Get a list of saved checkpoints + checkpoint_dir = os.path.dirname(filepath) + ## note: we consider only the checkpoints whose optimizer states have not already been dropped + checkpoints = [ + d + for d in os.listdir(checkpoint_dir) + if (os.path.isdir(os.path.join(checkpoint_dir, d)) and '-last' not in d + and os.path.join(checkpoint_dir, d) not in self.optimizerless_checkpoints) + ] + checkpoints = sorted(checkpoints, 
key=lambda x: (checkpoint_dir / Path(x)).lstat().st_mtime) + checkpoints = [os.path.join(checkpoint_dir, checkpoint) for checkpoint in checkpoints] + + return checkpoints + def _get_finalize_save_checkpoint_callback( self, trainer: 'pytorch_lightning.Trainer', filepath: str, global_step: int ): @@ -492,6 +553,9 @@ def _cb(): self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True) if not self.async_save: + ## TODO: clean up + if self.save_last_n_optim_states >= 0 and '-last' in filepath: + self._drop_optimizer_states(trainer=trainer, filepath=filepath) return logging.info(f'Async checkpoint save for step {global_step} ({filepath}) finalized successfully.') @@ -504,6 +568,10 @@ def _cb(): for ckpt_to_remove in ckpts_to_remove: self._remove_checkpoint(trainer, ckpt_to_remove, override_async=True) + if self.save_last_n_optim_states >= 0 and '-last' in filepath: + if is_global_rank_zero(): + self._drop_optimizer_states(trainer=trainer, filepath=filepath) + return _cb def _remove_checkpoint(self, trainer: "pytorch_lightning.Trainer", filepath: str, override_async=False) -> None: From 6da237179d11bd78b3f2d221f787ce0c048c49c0 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 28 Aug 2024 23:39:49 +0000 Subject: [PATCH 06/31] Apply isort and black reformatting Signed-off-by: ashors1 --- .../lightning/pytorch/callbacks/model_checkpoint.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/nemo/lightning/pytorch/callbacks/model_checkpoint.py b/nemo/lightning/pytorch/callbacks/model_checkpoint.py index 41da35157c42..580724b3d7d6 100644 --- a/nemo/lightning/pytorch/callbacks/model_checkpoint.py +++ b/nemo/lightning/pytorch/callbacks/model_checkpoint.py @@ -481,7 +481,7 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) logging.info(f'Scheduled async checkpoint save for {filepath}') else: finalize_fn() - + ## TODO: storage_options def _drop_optimizer_states(self, trainer, filepath: Union[str, Path]) -> None: # Get list of saved checkpoints @@ -493,7 +493,7 @@ def _drop_optimizer_states(self, trainer, filepath: Union[str, Path]) -> None: checkpoint_path = checkpoints[checkpoint_index] logging.info(f"Loading '{checkpoint_path}' checkpoint to drop optimizer states...") - ## TODO: clean + ## TODO: clean checkpoint = trainer.strategy.load_checkpoint(Path(checkpoint_path) / ModelCheckpoint.MODEL_WEIGHTS_PATH) # Remove the checkpoint version with optimizer states @@ -526,14 +526,17 @@ def _get_checkpoints_list(self, filepath: Union[str, Path]) -> List[str]: checkpoints = [ d for d in os.listdir(checkpoint_dir) - if (os.path.isdir(os.path.join(checkpoint_dir, d)) and '-last' not in d - and os.path.join(checkpoint_dir, d) not in self.optimizerless_checkpoints) + if ( + os.path.isdir(os.path.join(checkpoint_dir, d)) + and '-last' not in d + and os.path.join(checkpoint_dir, d) not in self.optimizerless_checkpoints + ) ] checkpoints = sorted(checkpoints, key=lambda x: (checkpoint_dir / Path(x)).lstat().st_mtime) checkpoints = [os.path.join(checkpoint_dir, checkpoint) for checkpoint in checkpoints] return checkpoints - + def _get_finalize_save_checkpoint_callback( self, trainer: 'pytorch_lightning.Trainer', filepath: str, global_step: int ): From 9b1f93c3aeaca866539c0fca0aac1cf71c0ad125 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 28 Aug 2024 16:41:05 -0700 Subject: [PATCH 07/31] minor cleanup Signed-off-by: ashors1 --- nemo/lightning/pytorch/callbacks/model_checkpoint.py | 2 -- 1 file changed, 2 deletions(-) diff --git 
a/nemo/lightning/pytorch/callbacks/model_checkpoint.py b/nemo/lightning/pytorch/callbacks/model_checkpoint.py index 580724b3d7d6..781128663f5e 100644 --- a/nemo/lightning/pytorch/callbacks/model_checkpoint.py +++ b/nemo/lightning/pytorch/callbacks/model_checkpoint.py @@ -473,7 +473,6 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) storage_options = None trainer.save_checkpoint(ckpt_filepath, self.save_weights_only, storage_options=storage_options) - print(f'{filepath=}') if self.always_save_artifacts and is_global_rank_zero(): TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath) / "artifacts") @@ -487,7 +486,6 @@ def _drop_optimizer_states(self, trainer, filepath: Union[str, Path]) -> None: # Get list of saved checkpoints checkpoints = self._get_checkpoints_list(filepath) - ## TODO: this doesn't work. ModelChck checkpoint_index = len(checkpoints) - self.save_last_n_optim_states - 1 if len(checkpoints) > self.save_last_n_optim_states: checkpoint_path = checkpoints[checkpoint_index] From 402471da24f8e91666b47721b1494d0e812ae927 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 29 Aug 2024 10:31:26 -0700 Subject: [PATCH 08/31] Revert support for saving last k optimizer states. This will be addressed in a subsequent PR. --- .../pytorch/callbacks/model_checkpoint.py | 75 +------------------ 1 file changed, 3 insertions(+), 72 deletions(-) diff --git a/nemo/lightning/pytorch/callbacks/model_checkpoint.py b/nemo/lightning/pytorch/callbacks/model_checkpoint.py index 781128663f5e..a0a7a646f5ff 100644 --- a/nemo/lightning/pytorch/callbacks/model_checkpoint.py +++ b/nemo/lightning/pytorch/callbacks/model_checkpoint.py @@ -15,7 +15,6 @@ import os import re import shutil -from collections import OrderedDict from datetime import timedelta from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Union @@ -49,7 +48,6 @@ class ModelCheckpoint(PTLModelCheckpoint): ``every_n_epochs`` or ``every_n_train_steps``. save_best_model: When ``True``, reloads and saves the best checkpoint. save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch - save_optim_on_train_end: TODO always_save_artifacts: Whether to dump the artifacts needed to reinintialize the current model, trainer, and dataloader to allow for reproducibility of experiments. save_artifacts_on_train_end: Whether to dump the artifacts on_train_end regardless of whether @@ -74,8 +72,6 @@ def __init__( train_time_interval: Optional[timedelta] = None, save_best_model: bool = False, save_on_train_epoch_end: Optional[bool] = False, # Save after training, not after validation - save_last_n_optim_states: int = -1, - save_optim_on_train_end: Optional[bool] = False, always_save_artifacts: bool = False, save_artifacts_on_train_end: bool = True, try_restore_best_ckpt: bool = True, @@ -85,15 +81,12 @@ def __init__( self.previous_best_path = "" self.always_save_artifacts = always_save_artifacts self.save_artifacts_on_train_end = save_artifacts_on_train_end - self.save_last_n_optim_states = save_last_n_optim_states - self.save_optim_on_train_end = save_optim_on_train_end # Checkpoints which removal is deferred until async save is done. # Each element of `deferred_ckpts_to_remove` is a growing list # that `self._remove_checkpoint` adds to. Once `self._save_checkpoint` # is called, the last element is frozen and a new element is added. 
self.deferred_ckpts_to_remove: List[List[str]] = [] - self.optimizerless_checkpoints = set() self.try_restore_best_ckpt = try_restore_best_ckpt # Call the parent class constructor with the remaining kwargs. @@ -270,9 +263,8 @@ def on_train_end(self, trainer, pl_module): if trainer.fast_dev_run: return None - if not self.save_optim_on_train_end: - ## Do not include optimizer states in final checkpoint - trainer.strategy.ckpt_include_optimizer = False + ## Do not include optimizer states in final checkpoint + trainer.strategy.ckpt_include_optimizer = False # check if we need to save a last checkpoint manually as validation isn't always run based on the interval if self.save_last and trainer.val_check_interval != 0: @@ -439,7 +431,7 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) self._last_global_step_saved = trainer.global_step ## Do not include optimizer states in final checkpoint - if trainer.global_step == trainer.max_steps and not self.save_optim_on_train_end: + if trainer.global_step == trainer.max_steps: trainer.strategy.ckpt_include_optimizer = False if ema_callback is not None: @@ -481,60 +473,6 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) else: finalize_fn() - ## TODO: storage_options - def _drop_optimizer_states(self, trainer, filepath: Union[str, Path]) -> None: - # Get list of saved checkpoints - checkpoints = self._get_checkpoints_list(filepath) - - checkpoint_index = len(checkpoints) - self.save_last_n_optim_states - 1 - if len(checkpoints) > self.save_last_n_optim_states: - checkpoint_path = checkpoints[checkpoint_index] - - logging.info(f"Loading '{checkpoint_path}' checkpoint to drop optimizer states...") - ## TODO: clean - checkpoint = trainer.strategy.load_checkpoint(Path(checkpoint_path) / ModelCheckpoint.MODEL_WEIGHTS_PATH) - - # Remove the checkpoint version with optimizer states - self._remove_checkpoint(trainer, checkpoint_path) - - checkpoint['optimizer'] = [None] - checkpoint['state_dict'] = OrderedDict([]) - checkpoint['sharded_state_dict'] = trainer.lightning_module.sharded_state_dict() - - ## TODO: debug -- only working for some steps right now - # Save the checkpoint without optimizer states - trainer.strategy.save_checkpoint(checkpoint, Path(checkpoint_path) / ModelCheckpoint.MODEL_WEIGHTS_PATH) - self.optimizerless_checkpoints.add(checkpoint_path) - - ## TODO - '''if HAVE_MODELOPT and hasattr(self.lightning_module, "get_model_module_list"): - save_sharded_modelopt_state( - self.lightning_module.get_model_module_list(), - checkpoint_path, - self.unwrapped_checkpoint_io.save_sharded_strategy, - prefix="model.", - )''' - - logging.info(f"Successfully dropped optimizer states for '{checkpoint_path}' checkpoint.") - - def _get_checkpoints_list(self, filepath: Union[str, Path]) -> List[str]: - # Get a list of saved checkpoints - checkpoint_dir = os.path.dirname(filepath) - ## note: we consider only the checkpoints whose optimizer states have not already been dropped - checkpoints = [ - d - for d in os.listdir(checkpoint_dir) - if ( - os.path.isdir(os.path.join(checkpoint_dir, d)) - and '-last' not in d - and os.path.join(checkpoint_dir, d) not in self.optimizerless_checkpoints - ) - ] - checkpoints = sorted(checkpoints, key=lambda x: (checkpoint_dir / Path(x)).lstat().st_mtime) - checkpoints = [os.path.join(checkpoint_dir, checkpoint) for checkpoint in checkpoints] - - return checkpoints - def _get_finalize_save_checkpoint_callback( self, trainer: 'pytorch_lightning.Trainer', filepath: str, 
global_step: int ): @@ -554,9 +492,6 @@ def _cb(): self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True) if not self.async_save: - ## TODO: clean up - if self.save_last_n_optim_states >= 0 and '-last' in filepath: - self._drop_optimizer_states(trainer=trainer, filepath=filepath) return logging.info(f'Async checkpoint save for step {global_step} ({filepath}) finalized successfully.') @@ -569,10 +504,6 @@ def _cb(): for ckpt_to_remove in ckpts_to_remove: self._remove_checkpoint(trainer, ckpt_to_remove, override_async=True) - if self.save_last_n_optim_states >= 0 and '-last' in filepath: - if is_global_rank_zero(): - self._drop_optimizer_states(trainer=trainer, filepath=filepath) - return _cb def _remove_checkpoint(self, trainer: "pytorch_lightning.Trainer", filepath: str, override_async=False) -> None: From 0a5595363954a79dd0ead688f45e81d1e9732086 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 29 Aug 2024 20:24:33 -0700 Subject: [PATCH 09/31] use storage_options to determine when to skip saving optimizer states Signed-off-by: ashors1 --- .../pytorch/callbacks/model_checkpoint.py | 14 ++++--------- nemo/lightning/pytorch/strategies.py | 20 +++++++++---------- 2 files changed, 14 insertions(+), 20 deletions(-) diff --git a/nemo/lightning/pytorch/callbacks/model_checkpoint.py b/nemo/lightning/pytorch/callbacks/model_checkpoint.py index a0a7a646f5ff..b34af7782e1c 100644 --- a/nemo/lightning/pytorch/callbacks/model_checkpoint.py +++ b/nemo/lightning/pytorch/callbacks/model_checkpoint.py @@ -263,9 +263,6 @@ def on_train_end(self, trainer, pl_module): if trainer.fast_dev_run: return None - ## Do not include optimizer states in final checkpoint - trainer.strategy.ckpt_include_optimizer = False - # check if we need to save a last checkpoint manually as validation isn't always run based on the interval if self.save_last and trainer.val_check_interval != 0: should_save_last_checkpoint = False @@ -430,10 +427,6 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) self._last_global_step_saved = trainer.global_step - ## Do not include optimizer states in final checkpoint - if trainer.global_step == trainer.max_steps: - trainer.strategy.ckpt_include_optimizer = False - if ema_callback is not None: if self.async_save: raise ValueError('async_save with EMA not supported') @@ -449,6 +442,9 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) super()._save_checkpoint(trainer, ckpt_filepath) self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True) else: + ## Do not include optimizer states in final checkpoint + storage_options = dict(include_optimizer=(trainer.global_step == trainer.max_steps)) + # Async save passes the finalization function to checkpoint_io, # sync save calls the finalization function immediately after save. 
finalize_fn = self._get_finalize_save_checkpoint_callback(trainer, filepath, trainer.global_step) @@ -458,11 +454,9 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) if not isinstance(checkpoint_io, AsyncFinalizableCheckpointIO): raise ValueError('Async save requires async compatible CheckpointIO') - storage_options = dict(finalize_fn=finalize_fn) + storage_options["finalize_fn"]=finalize_fn # Each upcoming ckpt removal request will be executed as part of this save finalization self.deferred_ckpts_to_remove.append([]) - else: - storage_options = None trainer.save_checkpoint(ckpt_filepath, self.save_weights_only, storage_options=storage_options) if self.always_save_artifacts and is_global_rank_zero(): diff --git a/nemo/lightning/pytorch/strategies.py b/nemo/lightning/pytorch/strategies.py index f668d7298e69..180ddf8d6423 100644 --- a/nemo/lightning/pytorch/strategies.py +++ b/nemo/lightning/pytorch/strategies.py @@ -177,7 +177,7 @@ def __init__( self.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size self.sequence_parallel = sequence_parallel self.lazy_init = lazy_init - self._ckpt_include_optimizer = ckpt_include_optimizer + self.ckpt_include_optimizer = ckpt_include_optimizer self.pipeline_dtype = pipeline_dtype self._setup_optimizers = setup_optimizers self._init_model_parallel = init_model_parallel @@ -602,14 +602,6 @@ def optimizer_sharded_state_dict(self, is_loading=False): self.megatron_parallel, optimizer, is_loading=is_loading, sharding_type=sharding_type ) - @property - def ckpt_include_optimizer(self): - return self._ckpt_include_optimizer - - @ckpt_include_optimizer.setter - def ckpt_include_optimizer(self, val): - self._ckpt_include_optimizer = val - @override def save_checkpoint( self, checkpoint: Dict[str, Any], filepath: Union[str, Path], storage_options: Optional[Any] = None @@ -618,11 +610,19 @@ def save_checkpoint( # retrieve `sharded_state_dict` if it has not already been configured in `on_save_checkpoint` if "sharded_state_dict" not in checkpoint: checkpoint["sharded_state_dict"] = self.megatron_parallel.sharded_state_dict() - if self.trainer.state.fn == TrainerFn.FITTING and self.ckpt_include_optimizer: + + ## only save optimizer states if self.ckpt_include_optimizer and storage_options["include_optimizer"] + ## are both True + include_optimizer = self.ckpt_include_optimizer + if "include_optimizer" in storage_options: + include_optimizer = not include_optimizer and storage_options["include_optimizer"] + + if self.trainer.state.fn == TrainerFn.FITTING and include_optimizer: checkpoint["optimizer"] = [self.optimizer_sharded_state_dict()] self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options) + ## TODO: only load the optimizer states when they exist in the checkpoint? @override def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: """PTL method which we override to integrate distributed checkpoints for model parallel models. 
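The `storage_options` handshake introduced in this commit (and tightened up in PATCH 11 and PATCH 20 below) amounts to AND-ing the strategy-level `ckpt_include_optimizer` flag with a per-save `include_optimizer` flag supplied by the checkpoint callback. A minimal, self-contained sketch of that rule as it reads after the later fixes, with the helper name as a placeholder (the real check lives inline in `MegatronStrategy.save_checkpoint`):

from typing import Optional

def resolve_include_optimizer(ckpt_include_optimizer: bool, storage_options: Optional[dict]) -> bool:
    # Optimizer state is written only when the strategy allows it AND the callback
    # did not flag this save as the final, optimizer-less checkpoint of the run.
    include_optimizer = ckpt_include_optimizer
    if storage_options is not None and "include_optimizer" in storage_options:
        include_optimizer = include_optimizer and storage_options["include_optimizer"]
    return include_optimizer

# Callback side (after the PATCH 11 fix) marks only the last training step:
#   storage_options = dict(include_optimizer=(trainer.global_step < trainer.max_steps))
assert resolve_include_optimizer(True, {"include_optimizer": True}) is True
assert resolve_include_optimizer(True, {"include_optimizer": False}) is False  # final checkpoint
assert resolve_include_optimizer(False, None) is False                         # strategy opted out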
From aa67b8a191531dbea3e25477dd48d4458080913f Mon Sep 17 00:00:00 2001 From: ashors1 Date: Fri, 30 Aug 2024 03:25:28 +0000 Subject: [PATCH 10/31] Apply isort and black reformatting Signed-off-by: ashors1 --- nemo/lightning/pytorch/callbacks/model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/lightning/pytorch/callbacks/model_checkpoint.py b/nemo/lightning/pytorch/callbacks/model_checkpoint.py index b34af7782e1c..1f04582f8fd0 100644 --- a/nemo/lightning/pytorch/callbacks/model_checkpoint.py +++ b/nemo/lightning/pytorch/callbacks/model_checkpoint.py @@ -454,7 +454,7 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) if not isinstance(checkpoint_io, AsyncFinalizableCheckpointIO): raise ValueError('Async save requires async compatible CheckpointIO') - storage_options["finalize_fn"]=finalize_fn + storage_options["finalize_fn"] = finalize_fn # Each upcoming ckpt removal request will be executed as part of this save finalization self.deferred_ckpts_to_remove.append([]) trainer.save_checkpoint(ckpt_filepath, self.save_weights_only, storage_options=storage_options) From 11001a0c7cd6fc0fa04ef7445b77fec02e7f489f Mon Sep 17 00:00:00 2001 From: ashors1 Date: Fri, 30 Aug 2024 09:22:20 -0700 Subject: [PATCH 11/31] fix variable names, make checkpoint load work when optimizer states don't exist in the checkpoint Signed-off-by: ashors1 --- .../pytorch/callbacks/model_checkpoint.py | 27 ++++++++++--------- nemo/lightning/pytorch/strategies.py | 13 ++++++--- nemo/lightning/resume.py | 4 +-- 3 files changed, 25 insertions(+), 19 deletions(-) diff --git a/nemo/lightning/pytorch/callbacks/model_checkpoint.py b/nemo/lightning/pytorch/callbacks/model_checkpoint.py index b34af7782e1c..7aeefa7c4d3c 100644 --- a/nemo/lightning/pytorch/callbacks/model_checkpoint.py +++ b/nemo/lightning/pytorch/callbacks/model_checkpoint.py @@ -48,16 +48,16 @@ class ModelCheckpoint(PTLModelCheckpoint): ``every_n_epochs`` or ``every_n_train_steps``. save_best_model: When ``True``, reloads and saves the best checkpoint. save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch - always_save_artifacts: Whether to dump the artifacts needed to reinintialize the current + always_save_context: Whether to dump the artifacts needed to reinintialize the current model, trainer, and dataloader to allow for reproducibility of experiments. - save_artifacts_on_train_end: Whether to dump the artifacts on_train_end regardless of whether - ``always_save_artifacts`` is ``True``. + save_context_on_train_end: Whether to dump the artifacts on_train_end regardless of whether + ``always_save_context`` is ``True``. async_save: Whether to enable asynchronous checkpointing. try_restore_best_ckpt: Whether to restore the best model path. 
""" UNFINISHED_CHECKPOINT_SUFFIX = "-unfinished" - MODEL_WEIGHTS_PATH = "model_weights" + WEIGHTS_PATH = "weights" def __init__( self, @@ -72,15 +72,15 @@ def __init__( train_time_interval: Optional[timedelta] = None, save_best_model: bool = False, save_on_train_epoch_end: Optional[bool] = False, # Save after training, not after validation - always_save_artifacts: bool = False, - save_artifacts_on_train_end: bool = True, + always_save_context: bool = False, + save_context_on_train_end: bool = True, try_restore_best_ckpt: bool = True, **kwargs, ): self.save_best_model = save_best_model self.previous_best_path = "" - self.always_save_artifacts = always_save_artifacts - self.save_artifacts_on_train_end = save_artifacts_on_train_end + self.always_save_context = always_save_context + self.save_context_on_train_end = save_context_on_train_end # Checkpoints which removal is deferred until async save is done. # Each element of `deferred_ckpts_to_remove` is a growing list @@ -276,12 +276,13 @@ def on_train_end(self, trainer, pl_module): logging.debug(f'Last checkpoint {self.last_model_path} already saved') else: super()._save_last_checkpoint(trainer, monitor_candidates) - if self.save_artifacts_on_train_end and not self.always_save_artifacts and is_global_rank_zero(): + if self.save_context_on_train_end and not self.always_save_context and is_global_rank_zero(): TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(self.last_model_path) / "artifacts") # Call parent on_train_end() to save the -last checkpoint super().on_train_end(trainer, pl_module) # Load the best model and then re-save it + ## TODO: finish adding support if self.save_best_model: # wait for all processes trainer.strategy.barrier("SaveBestCheckpointConnector.resume_end") @@ -294,7 +295,7 @@ def on_train_end(self, trainer, pl_module): else: if os.path.isdir(self.best_model_path.split('.ckpt')[0]): self.best_model_path = ( - Path(self.best_model_path.split('.ckpt')[0]) / ModelCheckpoint.MODEL_WEIGHTS_PATH + Path(self.best_model_path.split('.ckpt')[0]) / ModelCheckpoint.WEIGHTS_PATH ) if self.try_restore_best_ckpt: self.best_model_path = trainer.strategy.broadcast(self.best_model_path) @@ -421,7 +422,7 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) # barrier_after=True, so all ranks continue after the unfinished checkpoint marker is placed. # if anything goes wrong during checkpointing, we should be able to detect that data is incomplete. - ckpt_filepath = ckpt_to_dir(filepath) / ModelCheckpoint.MODEL_WEIGHTS_PATH + ckpt_filepath = ckpt_to_dir(filepath) / ModelCheckpoint.WEIGHTS_PATH self.set_checkpoint_unfinished_marker(filepath, barrier_after=True) ema_callback = self._ema_callback(trainer) @@ -443,7 +444,7 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True) else: ## Do not include optimizer states in final checkpoint - storage_options = dict(include_optimizer=(trainer.global_step == trainer.max_steps)) + storage_options = dict(include_optimizer=(trainer.global_step < trainer.max_steps)) # Async save passes the finalization function to checkpoint_io, # sync save calls the finalization function immediately after save. 
@@ -459,7 +460,7 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) self.deferred_ckpts_to_remove.append([]) trainer.save_checkpoint(ckpt_filepath, self.save_weights_only, storage_options=storage_options) - if self.always_save_artifacts and is_global_rank_zero(): + if self.always_save_context and is_global_rank_zero(): TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath) / "artifacts") if self.async_save: diff --git a/nemo/lightning/pytorch/strategies.py b/nemo/lightning/pytorch/strategies.py index 180ddf8d6423..07a960a549f9 100644 --- a/nemo/lightning/pytorch/strategies.py +++ b/nemo/lightning/pytorch/strategies.py @@ -14,6 +14,7 @@ import torch.distributed from lightning_fabric.plugins import CheckpointIO, ClusterEnvironment from lightning_fabric.utilities.optimizer import _optimizer_to_device, _optimizers_to_device +from megatron.core import dist_checkpointing from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.optimizer import OptimizerConfig from pytorch_lightning.accelerators import CPUAccelerator @@ -615,14 +616,14 @@ def save_checkpoint( ## are both True include_optimizer = self.ckpt_include_optimizer if "include_optimizer" in storage_options: - include_optimizer = not include_optimizer and storage_options["include_optimizer"] + include_optimizer = include_optimizer and storage_options["include_optimizer"] + del storage_options["include_optimizer"] if self.trainer.state.fn == TrainerFn.FITTING and include_optimizer: checkpoint["optimizer"] = [self.optimizer_sharded_state_dict()] self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options) - ## TODO: only load the optimizer states when they exist in the checkpoint? @override def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: """PTL method which we override to integrate distributed checkpoints for model parallel models. @@ -636,7 +637,11 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: sharded_state_dict = {} sharded_state_dict["state_dict"] = self.megatron_parallel.sharded_state_dict() - if self.ckpt_include_optimizer and self.trainer.state.fn == TrainerFn.FITTING: + common_state_dict = dist_checkpointing.load_common_state_dict(checkpoint_path) + has_optim = "optimizer" in common_state_dict + logging.info('Checkpoint is missing optimizer state. Restoring only the model weights.') + + if has_optim and self.ckpt_include_optimizer and self.trainer.state.fn == TrainerFn.FITTING: if self.lightning_module.optimizers(use_pl_optimizer=False): sharded_state_dict["optimizer"] = [self.optimizer_sharded_state_dict(is_loading=True)] @@ -646,7 +651,7 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: @override def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: - if not self.ckpt_include_optimizer: + if not self.ckpt_include_optimizer or "optimizer" not in checkpoint: return optimizer_states = checkpoint["optimizer"] diff --git a/nemo/lightning/resume.py b/nemo/lightning/resume.py index 9afe7feb9c82..7161c2171ea1 100644 --- a/nemo/lightning/resume.py +++ b/nemo/lightning/resume.py @@ -37,7 +37,7 @@ class AutoResume(Resume, io.IOMixin): checkpoints in NeMo. 
""" - MAYBE_MODEL_WEIGHTS_PATH = "model_weights" + WEIGHTS_PATH = "weights" def __init__( self, @@ -154,7 +154,7 @@ def nemo_path(self, model=None) -> Optional[Path]: if self.adapter_path: return AdapterPath(checkpoint, adapter_path=Path(self.adapter_path)) - model_weights_path = Path(checkpoint) / AutoResume.MAYBE_MODEL_WEIGHTS_PATH + model_weights_path = Path(checkpoint) / AutoResume.WEIGHTS_PATH if os.path.isdir(model_weights_path): return model_weights_path return Path(checkpoint) From 3b3b779621d8d6aee457c7b914a4ecbb38738fda Mon Sep 17 00:00:00 2001 From: ashors1 Date: Fri, 30 Aug 2024 16:23:53 +0000 Subject: [PATCH 12/31] Apply isort and black reformatting Signed-off-by: ashors1 --- nemo/lightning/pytorch/callbacks/model_checkpoint.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/nemo/lightning/pytorch/callbacks/model_checkpoint.py b/nemo/lightning/pytorch/callbacks/model_checkpoint.py index 42fde751d08b..b8578599e7d1 100644 --- a/nemo/lightning/pytorch/callbacks/model_checkpoint.py +++ b/nemo/lightning/pytorch/callbacks/model_checkpoint.py @@ -294,9 +294,7 @@ def on_train_end(self, trainer, pl_module): else: if os.path.isdir(self.best_model_path.split('.ckpt')[0]): - self.best_model_path = ( - Path(self.best_model_path.split('.ckpt')[0]) / ModelCheckpoint.WEIGHTS_PATH - ) + self.best_model_path = Path(self.best_model_path.split('.ckpt')[0]) / ModelCheckpoint.WEIGHTS_PATH if self.try_restore_best_ckpt: self.best_model_path = trainer.strategy.broadcast(self.best_model_path) trainer._checkpoint_connector.restore(self.best_model_path) From 1b9696f74d7bf5d45d0080effe07fa45ddf9f616 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Fri, 30 Aug 2024 10:10:02 -0700 Subject: [PATCH 13/31] FSDP updates, provide option to save optimizer states on_train_end Signed-off-by: ashors1 --- .../pytorch/callbacks/model_checkpoint.py | 6 +++++- .../pytorch/strategies/fsdp_strategy.py | 18 ++++++++++++++++-- .../pytorch/strategies/megatron_strategy.py | 3 ++- 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/nemo/lightning/pytorch/callbacks/model_checkpoint.py b/nemo/lightning/pytorch/callbacks/model_checkpoint.py index b8578599e7d1..a66be4d868d7 100644 --- a/nemo/lightning/pytorch/callbacks/model_checkpoint.py +++ b/nemo/lightning/pytorch/callbacks/model_checkpoint.py @@ -48,6 +48,8 @@ class ModelCheckpoint(PTLModelCheckpoint): ``every_n_epochs`` or ``every_n_train_steps``. save_best_model: When ``True``, reloads and saves the best checkpoint. save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch + save_optim_on_train_end: Whether to include the optimizer states in the final checkpoint + at the end of training. always_save_context: Whether to dump the artifacts needed to reinintialize the current model, trainer, and dataloader to allow for reproducibility of experiments. 
save_context_on_train_end: Whether to dump the artifacts on_train_end regardless of whether @@ -72,6 +74,7 @@ def __init__( train_time_interval: Optional[timedelta] = None, save_best_model: bool = False, save_on_train_epoch_end: Optional[bool] = False, # Save after training, not after validation + save_optim_on_train_end: Optional[bool] = False, always_save_context: bool = False, save_context_on_train_end: bool = True, try_restore_best_ckpt: bool = True, @@ -81,6 +84,7 @@ def __init__( self.previous_best_path = "" self.always_save_context = always_save_context self.save_context_on_train_end = save_context_on_train_end + self.save_optim_on_train_end = save_optim_on_train_end # Checkpoints which removal is deferred until async save is done. # Each element of `deferred_ckpts_to_remove` is a growing list @@ -442,7 +446,7 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True) else: ## Do not include optimizer states in final checkpoint - storage_options = dict(include_optimizer=(trainer.global_step < trainer.max_steps)) + storage_options = dict(include_optimizer=(trainer.global_step < trainer.max_steps or self.save_optim_on_train_end)) # Async save passes the finalization function to checkpoint_io, # sync save calls the finalization function immediately after save. diff --git a/nemo/lightning/pytorch/strategies/fsdp_strategy.py b/nemo/lightning/pytorch/strategies/fsdp_strategy.py index 9bb08b3cbd7a..210f861429a9 100644 --- a/nemo/lightning/pytorch/strategies/fsdp_strategy.py +++ b/nemo/lightning/pytorch/strategies/fsdp_strategy.py @@ -1,3 +1,4 @@ +import logging import shutil from collections import OrderedDict from pathlib import Path @@ -7,6 +8,7 @@ import torch from lightning_fabric.plugins import CheckpointIO from lightning_fabric.strategies.fsdp import _get_sharded_state_dict_context +from megatron.core import dist_checkpointing from megatron.core.transformer.transformer_layer import TransformerLayer from pytorch_lightning.strategies.fsdp import FSDPStrategy as PLFSDPStrategy from pytorch_lightning.trainer.states import TrainerFn @@ -193,7 +195,14 @@ def save_checkpoint( for optim_state in checkpoint['optimizer_states']: optim_state.pop("state") - if self.trainer.state.fn == TrainerFn.FITTING and self.ckpt_include_optimizer: + ## only save optimizer states if self.ckpt_include_optimizer and storage_options["include_optimizer"] + ## are both True + include_optimizer = self.ckpt_include_optimizer + if "include_optimizer" in storage_options: + include_optimizer = include_optimizer and storage_options["include_optimizer"] + del storage_options["include_optimizer"] + + if self.trainer.state.fn == TrainerFn.FITTING and include_optimizer: checkpoint['optimizer'] = get_optimizer_state_dict(self.model, self.optimizers) pyt_to_mcore_state_dict(checkpoint['optimizer']['state'], prefix="optimizer.state.") @@ -224,7 +233,12 @@ def load_checkpoint(self, checkpoint_path: str | Path) -> Dict[str, Any]: pyt_to_mcore_state_dict(msd) sharded_state_dict["sharded_state_dict"] = msd - if self.ckpt_include_optimizer and self.trainer.state.fn == TrainerFn.FITTING: + common_state_dict = dist_checkpointing.load_common_state_dict(checkpoint_path) + has_optim = "optimizer" in common_state_dict + if not has_optim: + logging.warn('Checkpoint is missing optimizer state. 
Restoring only the model weights.') + + if has_optim and self.ckpt_include_optimizer and self.trainer.state.fn == TrainerFn.FITTING: osd = get_optimizer_state_dict(self.model, self.optimizers, options=StateDictOptions(cpu_offload=True)) pyt_to_mcore_state_dict(osd['state'], prefix="optimizer.state.") sharded_state_dict["optimizer"] = osd diff --git a/nemo/lightning/pytorch/strategies/megatron_strategy.py b/nemo/lightning/pytorch/strategies/megatron_strategy.py index 22223b457684..142a9f421e5f 100644 --- a/nemo/lightning/pytorch/strategies/megatron_strategy.py +++ b/nemo/lightning/pytorch/strategies/megatron_strategy.py @@ -631,7 +631,8 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: common_state_dict = dist_checkpointing.load_common_state_dict(checkpoint_path) has_optim = "optimizer" in common_state_dict - logging.info('Checkpoint is missing optimizer state. Restoring only the model weights.') + if not has_optim: + logging.warn('Checkpoint is missing optimizer state. Restoring only the model weights.') if has_optim and self.ckpt_include_optimizer and self.trainer.state.fn == TrainerFn.FITTING: if self.lightning_module.optimizers(use_pl_optimizer=False): From 21fcb4070b5a13f611ffbc1ff844dbd9b3edb2ca Mon Sep 17 00:00:00 2001 From: ashors1 Date: Fri, 30 Aug 2024 17:10:51 +0000 Subject: [PATCH 14/31] Apply isort and black reformatting Signed-off-by: ashors1 --- nemo/lightning/pytorch/callbacks/model_checkpoint.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nemo/lightning/pytorch/callbacks/model_checkpoint.py b/nemo/lightning/pytorch/callbacks/model_checkpoint.py index a66be4d868d7..5c6910434a29 100644 --- a/nemo/lightning/pytorch/callbacks/model_checkpoint.py +++ b/nemo/lightning/pytorch/callbacks/model_checkpoint.py @@ -446,7 +446,9 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True) else: ## Do not include optimizer states in final checkpoint - storage_options = dict(include_optimizer=(trainer.global_step < trainer.max_steps or self.save_optim_on_train_end)) + storage_options = dict( + include_optimizer=(trainer.global_step < trainer.max_steps or self.save_optim_on_train_end) + ) # Async save passes the finalization function to checkpoint_io, # sync save calls the finalization function immediately after save. From b9a0e9e0f843bfd566578c51efa18ac0d0d830a5 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Fri, 30 Aug 2024 13:27:31 -0700 Subject: [PATCH 15/31] simplify implementation, remove save_best_model option Signed-off-by: ashors1 --- .../pytorch/callbacks/model_checkpoint.py | 25 ------------------- .../pytorch/strategies/fsdp_strategy.py | 7 +----- .../pytorch/strategies/megatron_strategy.py | 7 +----- 3 files changed, 2 insertions(+), 37 deletions(-) diff --git a/nemo/lightning/pytorch/callbacks/model_checkpoint.py b/nemo/lightning/pytorch/callbacks/model_checkpoint.py index a66be4d868d7..d33c2b8e2d91 100644 --- a/nemo/lightning/pytorch/callbacks/model_checkpoint.py +++ b/nemo/lightning/pytorch/callbacks/model_checkpoint.py @@ -46,7 +46,6 @@ class ModelCheckpoint(PTLModelCheckpoint): every_n_train_steps: Number of train steps between checkpoints. train_time_interval: After each interval, monitor checkpoints. Not to be used with ``every_n_epochs`` or ``every_n_train_steps``. - save_best_model: When ``True``, reloads and saves the best checkpoint. 
save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch save_optim_on_train_end: Whether to include the optimizer states in the final checkpoint at the end of training. @@ -55,7 +54,6 @@ class ModelCheckpoint(PTLModelCheckpoint): save_context_on_train_end: Whether to dump the artifacts on_train_end regardless of whether ``always_save_context`` is ``True``. async_save: Whether to enable asynchronous checkpointing. - try_restore_best_ckpt: Whether to restore the best model path. """ UNFINISHED_CHECKPOINT_SUFFIX = "-unfinished" @@ -72,16 +70,12 @@ def __init__( every_n_epochs: int = None, every_n_train_steps: Optional[int] = None, train_time_interval: Optional[timedelta] = None, - save_best_model: bool = False, save_on_train_epoch_end: Optional[bool] = False, # Save after training, not after validation save_optim_on_train_end: Optional[bool] = False, always_save_context: bool = False, save_context_on_train_end: bool = True, - try_restore_best_ckpt: bool = True, **kwargs, ): - self.save_best_model = save_best_model - self.previous_best_path = "" self.always_save_context = always_save_context self.save_context_on_train_end = save_context_on_train_end self.save_optim_on_train_end = save_optim_on_train_end @@ -91,7 +85,6 @@ def __init__( # that `self._remove_checkpoint` adds to. Once `self._save_checkpoint` # is called, the last element is frozen and a new element is added. self.deferred_ckpts_to_remove: List[List[str]] = [] - self.try_restore_best_ckpt = try_restore_best_ckpt # Call the parent class constructor with the remaining kwargs. super().__init__( @@ -285,24 +278,6 @@ def on_train_end(self, trainer, pl_module): # Call parent on_train_end() to save the -last checkpoint super().on_train_end(trainer, pl_module) - # Load the best model and then re-save it - ## TODO: finish adding support - if self.save_best_model: - # wait for all processes - trainer.strategy.barrier("SaveBestCheckpointConnector.resume_end") - if self.best_model_path == "": - logging.warning( - f"{self} was told to save the best checkpoint at the end of training, but no saved checkpoints " - "were found. Saving latest model instead." - ) - - else: - if os.path.isdir(self.best_model_path.split('.ckpt')[0]): - self.best_model_path = Path(self.best_model_path.split('.ckpt')[0]) / ModelCheckpoint.WEIGHTS_PATH - if self.try_restore_best_ckpt: - self.best_model_path = trainer.strategy.broadcast(self.best_model_path) - trainer._checkpoint_connector.restore(self.best_model_path) - def _del_model_without_trainer(self, filepath: str) -> None: from nemo.utils.get_rank import is_global_rank_zero diff --git a/nemo/lightning/pytorch/strategies/fsdp_strategy.py b/nemo/lightning/pytorch/strategies/fsdp_strategy.py index 210f861429a9..cb19c50d269c 100644 --- a/nemo/lightning/pytorch/strategies/fsdp_strategy.py +++ b/nemo/lightning/pytorch/strategies/fsdp_strategy.py @@ -233,12 +233,7 @@ def load_checkpoint(self, checkpoint_path: str | Path) -> Dict[str, Any]: pyt_to_mcore_state_dict(msd) sharded_state_dict["sharded_state_dict"] = msd - common_state_dict = dist_checkpointing.load_common_state_dict(checkpoint_path) - has_optim = "optimizer" in common_state_dict - if not has_optim: - logging.warn('Checkpoint is missing optimizer state. 
Restoring only the model weights.') - - if has_optim and self.ckpt_include_optimizer and self.trainer.state.fn == TrainerFn.FITTING: + if self.ckpt_include_optimizer and self.trainer.state.fn == TrainerFn.FITTING: osd = get_optimizer_state_dict(self.model, self.optimizers, options=StateDictOptions(cpu_offload=True)) pyt_to_mcore_state_dict(osd['state'], prefix="optimizer.state.") sharded_state_dict["optimizer"] = osd diff --git a/nemo/lightning/pytorch/strategies/megatron_strategy.py b/nemo/lightning/pytorch/strategies/megatron_strategy.py index 142a9f421e5f..cf622fa19ff2 100644 --- a/nemo/lightning/pytorch/strategies/megatron_strategy.py +++ b/nemo/lightning/pytorch/strategies/megatron_strategy.py @@ -629,12 +629,7 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: sharded_state_dict = {} sharded_state_dict["state_dict"] = self.megatron_parallel.sharded_state_dict() - common_state_dict = dist_checkpointing.load_common_state_dict(checkpoint_path) - has_optim = "optimizer" in common_state_dict - if not has_optim: - logging.warn('Checkpoint is missing optimizer state. Restoring only the model weights.') - - if has_optim and self.ckpt_include_optimizer and self.trainer.state.fn == TrainerFn.FITTING: + if self.ckpt_include_optimizer and self.trainer.state.fn == TrainerFn.FITTING: if self.lightning_module.optimizers(use_pl_optimizer=False): sharded_state_dict["optimizer"] = [self.optimizer_sharded_state_dict(is_loading=True)] From 3c748517cd229882001bff31faf27970412ffaaa Mon Sep 17 00:00:00 2001 From: ashors1 Date: Fri, 30 Aug 2024 13:33:18 -0700 Subject: [PATCH 16/31] update default value of ckpt_include_optimizer for fsdp Signed-off-by: ashors1 --- nemo/lightning/pytorch/strategies/fsdp_strategy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/lightning/pytorch/strategies/fsdp_strategy.py b/nemo/lightning/pytorch/strategies/fsdp_strategy.py index cb19c50d269c..4019a3e09862 100644 --- a/nemo/lightning/pytorch/strategies/fsdp_strategy.py +++ b/nemo/lightning/pytorch/strategies/fsdp_strategy.py @@ -58,7 +58,7 @@ def __init__( self, auto_wrap_policy={TransformerLayer}, state_dict_type="sharded", - ckpt_include_optimizer=False, + ckpt_include_optimizer=True, data_sampler=None, **kwargs, ): From 3fd76adf993d6abe24743d472b8bdab4f0413db6 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 3 Sep 2024 09:36:52 -0700 Subject: [PATCH 17/31] remove unused imports Signed-off-by: ashors1 --- nemo/lightning/pytorch/strategies/fsdp_strategy.py | 1 - nemo/lightning/pytorch/strategies/megatron_strategy.py | 1 - 2 files changed, 2 deletions(-) diff --git a/nemo/lightning/pytorch/strategies/fsdp_strategy.py b/nemo/lightning/pytorch/strategies/fsdp_strategy.py index 4019a3e09862..31e637534e15 100644 --- a/nemo/lightning/pytorch/strategies/fsdp_strategy.py +++ b/nemo/lightning/pytorch/strategies/fsdp_strategy.py @@ -8,7 +8,6 @@ import torch from lightning_fabric.plugins import CheckpointIO from lightning_fabric.strategies.fsdp import _get_sharded_state_dict_context -from megatron.core import dist_checkpointing from megatron.core.transformer.transformer_layer import TransformerLayer from pytorch_lightning.strategies.fsdp import FSDPStrategy as PLFSDPStrategy from pytorch_lightning.trainer.states import TrainerFn diff --git a/nemo/lightning/pytorch/strategies/megatron_strategy.py b/nemo/lightning/pytorch/strategies/megatron_strategy.py index cf622fa19ff2..009954cacfca 100644 --- a/nemo/lightning/pytorch/strategies/megatron_strategy.py +++ 
b/nemo/lightning/pytorch/strategies/megatron_strategy.py @@ -27,7 +27,6 @@ import torch.distributed from lightning_fabric.plugins import CheckpointIO, ClusterEnvironment from lightning_fabric.utilities.optimizer import _optimizer_to_device, _optimizers_to_device -from megatron.core import dist_checkpointing from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.optimizer import OptimizerConfig from pytorch_lightning.accelerators import CPUAccelerator From 2817d56431e455f62fd87e58af2cbed32a600f6d Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 3 Sep 2024 11:30:18 -0700 Subject: [PATCH 18/31] remove unused import Signed-off-by: ashors1 --- nemo/lightning/pytorch/strategies/fsdp_strategy.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nemo/lightning/pytorch/strategies/fsdp_strategy.py b/nemo/lightning/pytorch/strategies/fsdp_strategy.py index 31e637534e15..ab87808d5643 100644 --- a/nemo/lightning/pytorch/strategies/fsdp_strategy.py +++ b/nemo/lightning/pytorch/strategies/fsdp_strategy.py @@ -1,4 +1,3 @@ -import logging import shutil from collections import OrderedDict from pathlib import Path From 8f17991a7c390aab0c905c373539616254d6c96d Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 3 Sep 2024 13:11:46 -0700 Subject: [PATCH 19/31] cleanup Signed-off-by: ashors1 --- nemo/lightning/pytorch/strategies/megatron_strategy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/lightning/pytorch/strategies/megatron_strategy.py b/nemo/lightning/pytorch/strategies/megatron_strategy.py index 009954cacfca..7a492a562f59 100644 --- a/nemo/lightning/pytorch/strategies/megatron_strategy.py +++ b/nemo/lightning/pytorch/strategies/megatron_strategy.py @@ -638,7 +638,7 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: @override def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: - if not self.ckpt_include_optimizer or "optimizer" not in checkpoint: + if not self.ckpt_include_optimizer: return optimizer_states = checkpoint["optimizer"] From a4e695491d63666ab896a4a25448cd8ed299fb0f Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 3 Sep 2024 21:14:29 -0700 Subject: [PATCH 20/31] make storage_options optional again Signed-off-by: ashors1 --- nemo/lightning/pytorch/strategies/fsdp_strategy.py | 2 +- nemo/lightning/pytorch/strategies/megatron_strategy.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo/lightning/pytorch/strategies/fsdp_strategy.py b/nemo/lightning/pytorch/strategies/fsdp_strategy.py index ab87808d5643..12ac404a6672 100644 --- a/nemo/lightning/pytorch/strategies/fsdp_strategy.py +++ b/nemo/lightning/pytorch/strategies/fsdp_strategy.py @@ -196,7 +196,7 @@ def save_checkpoint( ## only save optimizer states if self.ckpt_include_optimizer and storage_options["include_optimizer"] ## are both True include_optimizer = self.ckpt_include_optimizer - if "include_optimizer" in storage_options: + if storage_options is not None and "include_optimizer" in storage_options: include_optimizer = include_optimizer and storage_options["include_optimizer"] del storage_options["include_optimizer"] diff --git a/nemo/lightning/pytorch/strategies/megatron_strategy.py b/nemo/lightning/pytorch/strategies/megatron_strategy.py index 7a492a562f59..85d5ad648d15 100644 --- a/nemo/lightning/pytorch/strategies/megatron_strategy.py +++ b/nemo/lightning/pytorch/strategies/megatron_strategy.py @@ -606,7 +606,7 @@ def save_checkpoint( ## only save optimizer states if self.ckpt_include_optimizer 
and storage_options["include_optimizer"] ## are both True include_optimizer = self.ckpt_include_optimizer - if "include_optimizer" in storage_options: + if storage_options is not None and "include_optimizer" in storage_options: include_optimizer = include_optimizer and storage_options["include_optimizer"] del storage_options["include_optimizer"] From 10e4f8823ca14d18f11fa391d9cd32b002c537bb Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 3 Sep 2024 22:17:25 -0700 Subject: [PATCH 21/31] fix failing tests Signed-off-by: ashors1 --- examples/llm/megatron_gpt_pretraining.py | 1 - tests/collections/llm/test_mnist_model_nemo2.py | 3 +-- tests/collections/llm/test_mnist_model_nemo2_fsdp.py | 3 +-- 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/examples/llm/megatron_gpt_pretraining.py b/examples/llm/megatron_gpt_pretraining.py index cfdb6a6acb4b..bf36971d35d6 100644 --- a/examples/llm/megatron_gpt_pretraining.py +++ b/examples/llm/megatron_gpt_pretraining.py @@ -71,7 +71,6 @@ def get_args(): strategy = nl.MegatronStrategy() checkpoint_callback = ModelCheckpoint( every_n_train_steps=5000, - enable_nemo_ckpt_io=False, ) callbacks = [checkpoint_callback] diff --git a/tests/collections/llm/test_mnist_model_nemo2.py b/tests/collections/llm/test_mnist_model_nemo2.py index c78306201751..219f3f647d93 100644 --- a/tests/collections/llm/test_mnist_model_nemo2.py +++ b/tests/collections/llm/test_mnist_model_nemo2.py @@ -496,13 +496,12 @@ def run_train_mnist_litautoencoder_with_megatron_strategy_single_gpu(): # Configure our custom Checkpointer name = "test_experiment" checkpoint_callback = nl_callbacks.ModelCheckpoint( - save_best_model=True, save_last=True, monitor="val_loss", save_top_k=1, every_n_train_steps=5, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe - enable_nemo_ckpt_io=True, + always_save_context=True, ) root_dir = tmpdir save_dir = root_dir / name diff --git a/tests/collections/llm/test_mnist_model_nemo2_fsdp.py b/tests/collections/llm/test_mnist_model_nemo2_fsdp.py index 32fde23bceb9..446b926f8612 100644 --- a/tests/collections/llm/test_mnist_model_nemo2_fsdp.py +++ b/tests/collections/llm/test_mnist_model_nemo2_fsdp.py @@ -519,13 +519,12 @@ def run_train_mnist_litautoencoder_with_fsdp_strategy_single_gpu(): # Configure our custom Checkpointer name = "test_experiment" checkpoint_callback = nl_callbacks.ModelCheckpoint( - save_best_model=True, save_last=True, monitor="val_loss", save_top_k=1, every_n_train_steps=5, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe - enable_nemo_ckpt_io=True, + always_save_context=True, ) root_dir = tmpdir save_dir = root_dir / name From a78717d52069b71d7ea22aed1e009ff95ee3637d Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 4 Sep 2024 08:44:22 -0700 Subject: [PATCH 22/31] address some comments Signed-off-by: ashors1 --- nemo/lightning/pytorch/callbacks/model_checkpoint.py | 4 ++-- nemo/lightning/resume.py | 5 +---- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/nemo/lightning/pytorch/callbacks/model_checkpoint.py b/nemo/lightning/pytorch/callbacks/model_checkpoint.py index 7fee3a8bc2ab..99bdda3d2540 100644 --- a/nemo/lightning/pytorch/callbacks/model_checkpoint.py +++ b/nemo/lightning/pytorch/callbacks/model_checkpoint.py @@ -274,7 +274,7 @@ def on_train_end(self, trainer, pl_module): else: super()._save_last_checkpoint(trainer, monitor_candidates) if self.save_context_on_train_end and not self.always_save_context and is_global_rank_zero(): - 
TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(self.last_model_path) / "artifacts") + TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(self.last_model_path) / "context") # Call parent on_train_end() to save the -last checkpoint super().on_train_end(trainer, pl_module) @@ -440,7 +440,7 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) trainer.save_checkpoint(ckpt_filepath, self.save_weights_only, storage_options=storage_options) if self.always_save_context and is_global_rank_zero(): - TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath) / "artifacts") + TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath) / "context") if self.async_save: logging.info(f'Scheduled async checkpoint save for {filepath}') diff --git a/nemo/lightning/resume.py b/nemo/lightning/resume.py index 0ba81c41e31b..2278ea9c0e5a 100644 --- a/nemo/lightning/resume.py +++ b/nemo/lightning/resume.py @@ -172,10 +172,7 @@ def nemo_path(self, model=None) -> Optional[Path]: if self.adapter_path: return AdapterPath(checkpoint, adapter_path=Path(self.adapter_path)) - model_weights_path = Path(checkpoint) / AutoResume.WEIGHTS_PATH - if os.path.isdir(model_weights_path): - return model_weights_path - return Path(checkpoint) + return Path(checkpoint) / AutoResume.WEIGHTS_PATH return None From 5916ee34572670c830fd6647029ea6e6cd33f342 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 4 Sep 2024 09:38:22 -0700 Subject: [PATCH 23/31] use save_weights_only to determine whether to save optimizer states Signed-off-by: ashors1 --- .../pytorch/callbacks/model_checkpoint.py | 14 +++++++------- .../pytorch/strategies/fsdp_strategy.py | 17 ++++------------- .../pytorch/strategies/megatron_strategy.py | 13 ++++--------- 3 files changed, 15 insertions(+), 29 deletions(-) diff --git a/nemo/lightning/pytorch/callbacks/model_checkpoint.py b/nemo/lightning/pytorch/callbacks/model_checkpoint.py index 99bdda3d2540..d6abcfdcbccf 100644 --- a/nemo/lightning/pytorch/callbacks/model_checkpoint.py +++ b/nemo/lightning/pytorch/callbacks/model_checkpoint.py @@ -48,7 +48,7 @@ class ModelCheckpoint(PTLModelCheckpoint): ``every_n_epochs`` or ``every_n_train_steps``. save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch save_optim_on_train_end: Whether to include the optimizer states in the final checkpoint - at the end of training. + at the end of training. Only applicable when save_weights_only is ``True``. always_save_context: Whether to dump the artifacts needed to reinintialize the current model, trainer, and dataloader to allow for reproducibility of experiments. save_context_on_train_end: Whether to dump the artifacts on_train_end regardless of whether @@ -420,10 +420,8 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) super()._save_checkpoint(trainer, ckpt_filepath) self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True) else: - ## Do not include optimizer states in final checkpoint - storage_options = dict( - include_optimizer=(trainer.global_step < trainer.max_steps or self.save_optim_on_train_end) - ) + ## Whether to include optimizer states + save_weights_only = self.save_weights_only or (not self.save_optim_on_train_end and trainer.global_step == trainer.max_steps) # Async save passes the finalization function to checkpoint_io, # sync save calls the finalization function immediately after save. 
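The optimizer-state gating added above reduces to a single boolean. Below is a minimal standalone sketch of that rule; the parameter names mirror the callback attributes in the diff, but the helper itself is illustrative and not part of NeMo:

    def resolve_save_weights_only(
        save_weights_only: bool,
        save_optim_on_train_end: bool,
        global_step: int,
        max_steps: int,
    ) -> bool:
        """Mirror of the expression above: optimizer states are dropped when the
        user asked for weights-only checkpoints, or when this is the final
        checkpoint of the run and save_optim_on_train_end is False."""
        is_final_checkpoint = global_step == max_steps
        return save_weights_only or (not save_optim_on_train_end and is_final_checkpoint)

    # Intermediate checkpoint with the defaults: optimizer states are kept.
    assert resolve_save_weights_only(False, False, global_step=100, max_steps=1000) is False
    # Final checkpoint with save_optim_on_train_end=False: weights only.
    assert resolve_save_weights_only(False, False, global_step=1000, max_steps=1000) is True
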
@@ -434,10 +432,12 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) if not isinstance(checkpoint_io, AsyncFinalizableCheckpointIO): raise ValueError('Async save requires async compatible CheckpointIO') - storage_options["finalize_fn"] = finalize_fn + storage_options = dict(finalize_fn=finalize_fn) # Each upcoming ckpt removal request will be executed as part of this save finalization self.deferred_ckpts_to_remove.append([]) - trainer.save_checkpoint(ckpt_filepath, self.save_weights_only, storage_options=storage_options) + else: + storage_options = None + trainer.save_checkpoint(ckpt_filepath, save_weights_only, storage_options=storage_options) if self.always_save_context and is_global_rank_zero(): TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath) / "context") diff --git a/nemo/lightning/pytorch/strategies/fsdp_strategy.py b/nemo/lightning/pytorch/strategies/fsdp_strategy.py index 12ac404a6672..c5127c6d90d2 100644 --- a/nemo/lightning/pytorch/strategies/fsdp_strategy.py +++ b/nemo/lightning/pytorch/strategies/fsdp_strategy.py @@ -189,18 +189,9 @@ def save_checkpoint( checkpoint["sharded_state_dict"] = pyt_to_mcore_state_dict(checkpoint.pop("state_dict")) checkpoint["state_dict"] = OrderedDict([]) - # TODO: do we still need to keep this? - for optim_state in checkpoint['optimizer_states']: - optim_state.pop("state") - - ## only save optimizer states if self.ckpt_include_optimizer and storage_options["include_optimizer"] - ## are both True - include_optimizer = self.ckpt_include_optimizer - if storage_options is not None and "include_optimizer" in storage_options: - include_optimizer = include_optimizer and storage_options["include_optimizer"] - del storage_options["include_optimizer"] - - if self.trainer.state.fn == TrainerFn.FITTING and include_optimizer: + ## replace unsharded optimizer_states with sharded dict + if "optimizer_states" in checkpoint: + del checkpoint["optimizer_states"] checkpoint['optimizer'] = get_optimizer_state_dict(self.model, self.optimizers) pyt_to_mcore_state_dict(checkpoint['optimizer']['state'], prefix="optimizer.state.") @@ -231,7 +222,7 @@ def load_checkpoint(self, checkpoint_path: str | Path) -> Dict[str, Any]: pyt_to_mcore_state_dict(msd) sharded_state_dict["sharded_state_dict"] = msd - if self.ckpt_include_optimizer and self.trainer.state.fn == TrainerFn.FITTING: + if self.ckpt_include_optimizer and self.trainer.state.fn == TrainerFn.FITTING: ## TODO: remove ckpt_include_optimizer osd = get_optimizer_state_dict(self.model, self.optimizers, options=StateDictOptions(cpu_offload=True)) pyt_to_mcore_state_dict(osd['state'], prefix="optimizer.state.") sharded_state_dict["optimizer"] = osd diff --git a/nemo/lightning/pytorch/strategies/megatron_strategy.py b/nemo/lightning/pytorch/strategies/megatron_strategy.py index 85d5ad648d15..30b4b7231769 100644 --- a/nemo/lightning/pytorch/strategies/megatron_strategy.py +++ b/nemo/lightning/pytorch/strategies/megatron_strategy.py @@ -603,14 +603,9 @@ def save_checkpoint( if "sharded_state_dict" not in checkpoint: checkpoint["sharded_state_dict"] = self.megatron_parallel.sharded_state_dict() - ## only save optimizer states if self.ckpt_include_optimizer and storage_options["include_optimizer"] - ## are both True - include_optimizer = self.ckpt_include_optimizer - if storage_options is not None and "include_optimizer" in storage_options: - include_optimizer = include_optimizer and storage_options["include_optimizer"] - del storage_options["include_optimizer"] - - if 
self.trainer.state.fn == TrainerFn.FITTING and include_optimizer: + ## replace unsharded optimizer_states with sharded dict + if "optimizer_states" in checkpoint: + del checkpoint["optimizer_states"] checkpoint["optimizer"] = [self.optimizer_sharded_state_dict()] self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options) @@ -638,7 +633,7 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: @override def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: - if not self.ckpt_include_optimizer: + if not self.ckpt_include_optimizer: ## TODO: remove ckpt_include_optimizer return optimizer_states = checkpoint["optimizer"] From 5745105163c619c5b4ad0a62497b6c41011bf000 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 4 Sep 2024 16:39:34 +0000 Subject: [PATCH 24/31] Apply isort and black reformatting Signed-off-by: ashors1 --- nemo/lightning/pytorch/callbacks/model_checkpoint.py | 4 +++- nemo/lightning/pytorch/strategies/fsdp_strategy.py | 4 +++- nemo/lightning/pytorch/strategies/megatron_strategy.py | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/nemo/lightning/pytorch/callbacks/model_checkpoint.py b/nemo/lightning/pytorch/callbacks/model_checkpoint.py index d6abcfdcbccf..946b0bd6ba81 100644 --- a/nemo/lightning/pytorch/callbacks/model_checkpoint.py +++ b/nemo/lightning/pytorch/callbacks/model_checkpoint.py @@ -421,7 +421,9 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True) else: ## Whether to include optimizer states - save_weights_only = self.save_weights_only or (not self.save_optim_on_train_end and trainer.global_step == trainer.max_steps) + save_weights_only = self.save_weights_only or ( + not self.save_optim_on_train_end and trainer.global_step == trainer.max_steps + ) # Async save passes the finalization function to checkpoint_io, # sync save calls the finalization function immediately after save. 
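Taken together with the directory changes earlier in the series, the callback can now be configured roughly as below. This is a hedged sketch: the constructor arguments and directory names are taken from the diffs in this series, while the import path and the surrounding trainer wiring are assumed rather than verbatim:

    from nemo.lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint

    checkpoint_callback = ModelCheckpoint(
        every_n_train_steps=5000,
        save_optim_on_train_end=False,  # final checkpoint keeps model weights only
        always_save_context=True,       # dump TrainerContext artifacts alongside the weights
    )

    # Approximate on-disk layout of a saved checkpoint directory after this series:
    #
    #   <checkpoints_dir>/<name>--last/
    #       weights/   # distributed model (and, for intermediate saves, optimizer) state
    #       context/   # TrainerContext.io_dump() output used to rebuild model/trainer/data
    #
    # AutoResume checks for the "weights" subdirectory and, when it exists, resumes
    # from <...>--last/weights rather than from the bare checkpoint directory.
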
diff --git a/nemo/lightning/pytorch/strategies/fsdp_strategy.py b/nemo/lightning/pytorch/strategies/fsdp_strategy.py index c5127c6d90d2..77efdd0de792 100644 --- a/nemo/lightning/pytorch/strategies/fsdp_strategy.py +++ b/nemo/lightning/pytorch/strategies/fsdp_strategy.py @@ -222,7 +222,9 @@ def load_checkpoint(self, checkpoint_path: str | Path) -> Dict[str, Any]: pyt_to_mcore_state_dict(msd) sharded_state_dict["sharded_state_dict"] = msd - if self.ckpt_include_optimizer and self.trainer.state.fn == TrainerFn.FITTING: ## TODO: remove ckpt_include_optimizer + if ( + self.ckpt_include_optimizer and self.trainer.state.fn == TrainerFn.FITTING + ): ## TODO: remove ckpt_include_optimizer osd = get_optimizer_state_dict(self.model, self.optimizers, options=StateDictOptions(cpu_offload=True)) pyt_to_mcore_state_dict(osd['state'], prefix="optimizer.state.") sharded_state_dict["optimizer"] = osd diff --git a/nemo/lightning/pytorch/strategies/megatron_strategy.py b/nemo/lightning/pytorch/strategies/megatron_strategy.py index 30b4b7231769..4cb4b062e9db 100644 --- a/nemo/lightning/pytorch/strategies/megatron_strategy.py +++ b/nemo/lightning/pytorch/strategies/megatron_strategy.py @@ -633,7 +633,7 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: @override def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: - if not self.ckpt_include_optimizer: ## TODO: remove ckpt_include_optimizer + if not self.ckpt_include_optimizer: ## TODO: remove ckpt_include_optimizer return optimizer_states = checkpoint["optimizer"] From b9be6e9aa249036e7af41e4cc57c182dc033e43a Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 4 Sep 2024 14:00:56 -0700 Subject: [PATCH 25/31] add some comments Signed-off-by: ashors1 --- nemo/lightning/pytorch/callbacks/model_checkpoint.py | 9 +++++++-- nemo/lightning/pytorch/strategies/fsdp_strategy.py | 4 +++- nemo/lightning/pytorch/strategies/megatron_strategy.py | 4 +++- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/nemo/lightning/pytorch/callbacks/model_checkpoint.py b/nemo/lightning/pytorch/callbacks/model_checkpoint.py index d6abcfdcbccf..b60dba3ef7e1 100644 --- a/nemo/lightning/pytorch/callbacks/model_checkpoint.py +++ b/nemo/lightning/pytorch/callbacks/model_checkpoint.py @@ -40,7 +40,8 @@ class ModelCheckpoint(PTLModelCheckpoint): verbose: Verbosity mode. save_last: When ``True``, saves a `*-last` copy whenever a checkpoint file gets saved. save_top_k: When ``True``, saves the top-k checkpoints according to ``monitor``. - save_weights_only: if ``True``, then only the model's weights will be saved. + save_weights_only: if ``True``, then only the model's weights will be saved. Optimizer states will + be omitted from all checkpoints. mode: One of {min, max}. Whether the objective is to minimize or maximize the monitored quantity. every_n_epochs: Number of epochs between checkpoints. every_n_train_steps: Number of train steps between checkpoints. @@ -420,7 +421,11 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) super()._save_checkpoint(trainer, ckpt_filepath) self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True) else: - ## Whether to include optimizer states + ## Determine whether to include optimizer states in the checkpoint + ## optimizer states are included when + ## 1. save_weights_only is False and + ## 2. either save_optim_on_train_end is True, or save_optim_on_train_end is False but the checkpoint + ## is an intermediate checkpoint. 
save_weights_only = self.save_weights_only or (not self.save_optim_on_train_end and trainer.global_step == trainer.max_steps) # Async save passes the finalization function to checkpoint_io, diff --git a/nemo/lightning/pytorch/strategies/fsdp_strategy.py b/nemo/lightning/pytorch/strategies/fsdp_strategy.py index c5127c6d90d2..22a9bba25b8c 100644 --- a/nemo/lightning/pytorch/strategies/fsdp_strategy.py +++ b/nemo/lightning/pytorch/strategies/fsdp_strategy.py @@ -189,7 +189,9 @@ def save_checkpoint( checkpoint["sharded_state_dict"] = pyt_to_mcore_state_dict(checkpoint.pop("state_dict")) checkpoint["state_dict"] = OrderedDict([]) - ## replace unsharded optimizer_states with sharded dict + ## replace unsharded optimizer_states with sharded dict. + ## note that if trainer.save_checkpoint(path, save_weights_only=True) is called, + ## the checkpoint will contain only model weights. Optimizer states will be omitted. if "optimizer_states" in checkpoint: del checkpoint["optimizer_states"] checkpoint['optimizer'] = get_optimizer_state_dict(self.model, self.optimizers) diff --git a/nemo/lightning/pytorch/strategies/megatron_strategy.py b/nemo/lightning/pytorch/strategies/megatron_strategy.py index 30b4b7231769..dfb071e9cfc5 100644 --- a/nemo/lightning/pytorch/strategies/megatron_strategy.py +++ b/nemo/lightning/pytorch/strategies/megatron_strategy.py @@ -603,7 +603,9 @@ def save_checkpoint( if "sharded_state_dict" not in checkpoint: checkpoint["sharded_state_dict"] = self.megatron_parallel.sharded_state_dict() - ## replace unsharded optimizer_states with sharded dict + ## replace unsharded optimizer_states with sharded dict. + ## note that if trainer.save_checkpoint(path, save_weights_only=True) is called, + ## the checkpoint will contain only model weights. Optimizer states will be omitted. 
if "optimizer_states" in checkpoint: del checkpoint["optimizer_states"] checkpoint["optimizer"] = [self.optimizer_sharded_state_dict()] From 6befd8ab0dee3b2bb078cdc58b36aff4e6dd9a2d Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 5 Sep 2024 09:11:30 -0700 Subject: [PATCH 26/31] fix tests Signed-off-by: ashors1 --- tests/collections/llm/test_mnist_model_nemo2.py | 1 + tests/collections/llm/test_mnist_model_nemo2_fsdp.py | 1 + tests/lightning/test_nemo_logger.py | 6 +++--- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/collections/llm/test_mnist_model_nemo2.py b/tests/collections/llm/test_mnist_model_nemo2.py index 219f3f647d93..60707725c00a 100644 --- a/tests/collections/llm/test_mnist_model_nemo2.py +++ b/tests/collections/llm/test_mnist_model_nemo2.py @@ -571,6 +571,7 @@ def run_train_mnist_litautoencoder_with_megatron_strategy_single_gpu(): ckpt_path = checkpoint_callback.last_model_path.replace( ".ckpt", "" ) # strip .ckpt off the end of the last path + ckpt_path = Path(ckpt_path) / "weights" ## weights are saved to the "weights" directory within the checkpoint assert Path( ckpt_path diff --git a/tests/collections/llm/test_mnist_model_nemo2_fsdp.py b/tests/collections/llm/test_mnist_model_nemo2_fsdp.py index 446b926f8612..caa0e17e4e15 100644 --- a/tests/collections/llm/test_mnist_model_nemo2_fsdp.py +++ b/tests/collections/llm/test_mnist_model_nemo2_fsdp.py @@ -583,6 +583,7 @@ def run_train_mnist_litautoencoder_with_fsdp_strategy_single_gpu(): ckpt_path = checkpoint_callback.last_model_path.replace( ".ckpt", "" ) # strip .ckpt off the end of the last path + ckpt_path = Path(ckpt_path) / "weights" ## weights are saved to the "weights" directory within the checkpoint assert Path( ckpt_path diff --git a/tests/lightning/test_nemo_logger.py b/tests/lightning/test_nemo_logger.py index 3476f1361809..4e204a319464 100644 --- a/tests/lightning/test_nemo_logger.py +++ b/tests/lightning/test_nemo_logger.py @@ -157,7 +157,7 @@ def test_resume(self, trainer, tmp_path): resume_if_exists=True, ).setup(trainer) assert str(trainer.ckpt_path) == str( - Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last") + Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last" / "weights") ) # Finally succeed @@ -172,7 +172,7 @@ def test_resume(self, trainer, tmp_path): nl.AutoResume( resume_if_exists=True, ).setup(trainer) - checkpoint = Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel--last") + checkpoint = Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel--last" / "weights") assert Path(trainer.ckpt_path).resolve() == checkpoint.resolve() trainer = nl.Trainer(accelerator="cpu", logger=False) @@ -180,7 +180,7 @@ def test_resume(self, trainer, tmp_path): dirpath_log_dir = Path(tmp_path / "test_resume" / "dirpath_test" / "logs") dirpath_log_dir.mkdir(parents=True) dirpath_checkpoint_dir = Path(tmp_path / "test_resume" / "dirpath_test" / "ckpts") - dirpath_checkpoint = Path(dirpath_checkpoint_dir / "mymodel--last") + dirpath_checkpoint = Path(dirpath_checkpoint_dir / "mymodel--last" / "weights") dirpath_checkpoint.mkdir(parents=True) logger = nl.NeMoLogger( name="default", From 65ebadb3c6c4deb534d7b2d1abb762530e6d8693 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 5 Sep 2024 16:12:19 +0000 Subject: [PATCH 27/31] Apply isort and black reformatting Signed-off-by: ashors1 --- tests/collections/llm/test_mnist_model_nemo2.py | 4 +++- 
tests/collections/llm/test_mnist_model_nemo2_fsdp.py | 4 +++- tests/lightning/test_nemo_logger.py | 4 +++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/collections/llm/test_mnist_model_nemo2.py b/tests/collections/llm/test_mnist_model_nemo2.py index 60707725c00a..c853ad3ecfb3 100644 --- a/tests/collections/llm/test_mnist_model_nemo2.py +++ b/tests/collections/llm/test_mnist_model_nemo2.py @@ -571,7 +571,9 @@ def run_train_mnist_litautoencoder_with_megatron_strategy_single_gpu(): ckpt_path = checkpoint_callback.last_model_path.replace( ".ckpt", "" ) # strip .ckpt off the end of the last path - ckpt_path = Path(ckpt_path) / "weights" ## weights are saved to the "weights" directory within the checkpoint + ckpt_path = ( + Path(ckpt_path) / "weights" + ) ## weights are saved to the "weights" directory within the checkpoint assert Path( ckpt_path diff --git a/tests/collections/llm/test_mnist_model_nemo2_fsdp.py b/tests/collections/llm/test_mnist_model_nemo2_fsdp.py index caa0e17e4e15..74c327ca345d 100644 --- a/tests/collections/llm/test_mnist_model_nemo2_fsdp.py +++ b/tests/collections/llm/test_mnist_model_nemo2_fsdp.py @@ -583,7 +583,9 @@ def run_train_mnist_litautoencoder_with_fsdp_strategy_single_gpu(): ckpt_path = checkpoint_callback.last_model_path.replace( ".ckpt", "" ) # strip .ckpt off the end of the last path - ckpt_path = Path(ckpt_path) / "weights" ## weights are saved to the "weights" directory within the checkpoint + ckpt_path = ( + Path(ckpt_path) / "weights" + ) ## weights are saved to the "weights" directory within the checkpoint assert Path( ckpt_path diff --git a/tests/lightning/test_nemo_logger.py b/tests/lightning/test_nemo_logger.py index 4e204a319464..7a999e771b09 100644 --- a/tests/lightning/test_nemo_logger.py +++ b/tests/lightning/test_nemo_logger.py @@ -172,7 +172,9 @@ def test_resume(self, trainer, tmp_path): nl.AutoResume( resume_if_exists=True, ).setup(trainer) - checkpoint = Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel--last" / "weights") + checkpoint = Path( + tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel--last" / "weights" + ) assert Path(trainer.ckpt_path).resolve() == checkpoint.resolve() trainer = nl.Trainer(accelerator="cpu", logger=False) From 09a38a5e441992accb58ddcc7e31f3acfac328b6 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 5 Sep 2024 16:41:43 +0000 Subject: [PATCH 28/31] Apply isort and black reformatting Signed-off-by: ashors1 --- nemo/lightning/pytorch/strategies/fsdp_strategy.py | 6 +++++- nemo/lightning/pytorch/strategies/megatron_strategy.py | 6 +++++- nemo/lightning/resume.py | 2 +- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/nemo/lightning/pytorch/strategies/fsdp_strategy.py b/nemo/lightning/pytorch/strategies/fsdp_strategy.py index a4d1dbd7275b..2a210c9bd7f0 100644 --- a/nemo/lightning/pytorch/strategies/fsdp_strategy.py +++ b/nemo/lightning/pytorch/strategies/fsdp_strategy.py @@ -211,7 +211,11 @@ def save_checkpoint( ## replace unsharded optimizer_states with sharded dict. ## note that if trainer.save_checkpoint(path, save_weights_only=True) is called, ## the checkpoint will contain only model weights. Optimizer states will be omitted. 
- if "optimizer_states" in checkpoint and self.trainer.state.fn == TrainerFn.FITTING and self.ckpt_save_optimizer: + if ( + "optimizer_states" in checkpoint + and self.trainer.state.fn == TrainerFn.FITTING + and self.ckpt_save_optimizer + ): del checkpoint["optimizer_states"] checkpoint['optimizer'] = get_optimizer_state_dict(self.model, self.optimizers) pyt_to_mcore_state_dict(checkpoint['optimizer']['state'], prefix="optimizer.state.") diff --git a/nemo/lightning/pytorch/strategies/megatron_strategy.py b/nemo/lightning/pytorch/strategies/megatron_strategy.py index 80b01774afa6..4bf8c42ece02 100644 --- a/nemo/lightning/pytorch/strategies/megatron_strategy.py +++ b/nemo/lightning/pytorch/strategies/megatron_strategy.py @@ -629,7 +629,11 @@ def save_checkpoint( ## replace unsharded optimizer_states with sharded dict. ## note that if trainer.save_checkpoint(path, save_weights_only=True) is called, ## the checkpoint will contain only model weights. Optimizer states will be omitted. - if "optimizer_states" in checkpoint and self.trainer.state.fn == TrainerFn.FITTING and self.ckpt_save_optimizer: + if ( + "optimizer_states" in checkpoint + and self.trainer.state.fn == TrainerFn.FITTING + and self.ckpt_save_optimizer + ): del checkpoint["optimizer_states"] checkpoint["optimizer"] = [self.optimizer_sharded_state_dict()] diff --git a/nemo/lightning/resume.py b/nemo/lightning/resume.py index d98a03888af4..d383cd50d7f7 100644 --- a/nemo/lightning/resume.py +++ b/nemo/lightning/resume.py @@ -109,7 +109,7 @@ def _try_import_model( new_path = path if adapter_path: - + new_path = AdapterPath(Path(adapter_path), base_model_path=new_path) if isinstance(new_path, str): From 987e4f669555de83d89f8fa36cc80bb5f0bfa2bc Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 5 Sep 2024 12:57:59 -0700 Subject: [PATCH 29/31] fixes Signed-off-by: ashors1 --- nemo/lightning/resume.py | 19 ++++++++++--------- tests/lightning/test_nemo_logger.py | 13 ++++++++----- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/nemo/lightning/resume.py b/nemo/lightning/resume.py index d98a03888af4..bf7977f4d480 100644 --- a/nemo/lightning/resume.py +++ b/nemo/lightning/resume.py @@ -69,7 +69,7 @@ class AutoResume: WEIGHTS_PATH = "weights" def get_model_weights_path(self, path): - return Path(path) / AutoResume.MODEL_WEIGHTS_PATH + return Path(path) / self.WEIGHTS_PATH def setup(self, trainer: Union[pl.Trainer, fl.Fabric], model=None): if isinstance(trainer, fl.Fabric): @@ -96,10 +96,6 @@ def _try_import_model( self, model: Optional[io.ConnectorMixin], path: str, adapter_path: Optional[str] = None ) -> BasePath: - maybe_model_weights_path = self.get_model_weights_path(path) - if os.path.isdir(maybe_model_weights_path): - path = maybe_model_weights_path - if model is None: raise ValueError("Model is needed to import checkpoint from HF or other non-NeMo checkpoint format.") try: @@ -109,6 +105,10 @@ def _try_import_model( new_path = path if adapter_path: + + maybe_model_weights_path = self.get_model_weights_path(adapter_path) + if os.path.isdir(maybe_model_weights_path): + adapter_path = maybe_model_weights_path new_path = AdapterPath(Path(adapter_path), base_model_path=new_path) @@ -213,7 +213,7 @@ def _find_trainer_ckpt_path(self) -> Optional[Path]: else: checkpoint = last_checkpoints[0] - return Path(checkpoint) + return checkpoint def get_trainer_ckpt_path(self, model: Optional[io.ConnectorMixin] = None) -> Optional[Path]: checkpoint = None @@ -222,9 +222,10 @@ def get_trainer_ckpt_path(self, model: 
Optional[io.ConnectorMixin] = None) -> Op if self.resume_if_exists: checkpoint = self._find_trainer_ckpt_path() - maybe_model_weights_path = self.get_model_weights_path(checkpoint) - if os.path.isdir(maybe_model_weights_path): - checkpoint = maybe_model_weights_path + if checkpoint: + maybe_model_weights_path = self.get_model_weights_path(checkpoint) + if os.path.isdir(maybe_model_weights_path): + checkpoint = maybe_model_weights_path if checkpoint: if self.adapter_path: diff --git a/tests/lightning/test_nemo_logger.py b/tests/lightning/test_nemo_logger.py index fe592b6abb22..6814ea608ad1 100644 --- a/tests/lightning/test_nemo_logger.py +++ b/tests/lightning/test_nemo_logger.py @@ -13,6 +13,7 @@ # limitations under the License. import os +import shutil import time from pathlib import Path from unittest.mock import patch @@ -159,21 +160,23 @@ def test_resume(self, trainer, tmp_path): ## if there are multiple "-last" checkpoints, choose the most recent one Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel--last").mkdir() time.sleep(1) ## sleep for a second so the checkpoints are created at different times - Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last").mkdir() + ## make a "weights" dir within the checkpoint + Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last" / "weights").mkdir(parents=True) time.sleep(1) # unfinished last, that should be ignored Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel3--last").mkdir() + time.sleep(1) Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel3--last-unfinished").touch() nl.AutoResume( resume_from_directory=Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints"), resume_if_exists=True, ).setup(trainer) + ## if "weights" exists, we should restore from there assert str(trainer.ckpt_path) == str( Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last" / "weights") ) - # Finally succeed logger = nl.NeMoLogger( name="default", dir=str(tmp_path) + "/test_resume", @@ -181,12 +184,12 @@ def test_resume(self, trainer, tmp_path): use_datetime_version=False, ) logger.setup(trainer) - Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last").rmdir() + shutil.rmtree(Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last")) nl.AutoResume( resume_if_exists=True, ).setup(trainer) checkpoint = Path( - tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel--last" / "weights" + tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel--last" ) assert Path(trainer.ckpt_path).resolve() == checkpoint.resolve() @@ -195,7 +198,7 @@ def test_resume(self, trainer, tmp_path): dirpath_log_dir = Path(tmp_path / "test_resume" / "dirpath_test" / "logs") dirpath_log_dir.mkdir(parents=True) dirpath_checkpoint_dir = Path(tmp_path / "test_resume" / "dirpath_test" / "ckpts") - dirpath_checkpoint = Path(dirpath_checkpoint_dir / "mymodel--last" / "weights") + dirpath_checkpoint = Path(dirpath_checkpoint_dir / "mymodel--last") dirpath_checkpoint.mkdir(parents=True) logger = nl.NeMoLogger( name="default", From 54591a6493cc4d08e80b546056ec3f6c828d5a9c Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 5 Sep 2024 20:00:38 +0000 Subject: [PATCH 30/31] Apply isort and black reformatting Signed-off-by: ashors1 --- nemo/lightning/resume.py | 2 +- 
tests/lightning/test_nemo_logger.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/nemo/lightning/resume.py b/nemo/lightning/resume.py index bf7977f4d480..c8cefb4dd8d3 100644 --- a/nemo/lightning/resume.py +++ b/nemo/lightning/resume.py @@ -109,7 +109,7 @@ def _try_import_model( maybe_model_weights_path = self.get_model_weights_path(adapter_path) if os.path.isdir(maybe_model_weights_path): adapter_path = maybe_model_weights_path - + new_path = AdapterPath(Path(adapter_path), base_model_path=new_path) if isinstance(new_path, str): diff --git a/tests/lightning/test_nemo_logger.py b/tests/lightning/test_nemo_logger.py index 6814ea608ad1..536e8344eed3 100644 --- a/tests/lightning/test_nemo_logger.py +++ b/tests/lightning/test_nemo_logger.py @@ -161,7 +161,9 @@ def test_resume(self, trainer, tmp_path): Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel--last").mkdir() time.sleep(1) ## sleep for a second so the checkpoints are created at different times ## make a "weights" dir within the checkpoint - Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last" / "weights").mkdir(parents=True) + Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last" / "weights").mkdir( + parents=True + ) time.sleep(1) # unfinished last, that should be ignored Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel3--last").mkdir() @@ -188,9 +190,7 @@ def test_resume(self, trainer, tmp_path): nl.AutoResume( resume_if_exists=True, ).setup(trainer) - checkpoint = Path( - tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel--last" - ) + checkpoint = Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel--last") assert Path(trainer.ckpt_path).resolve() == checkpoint.resolve() trainer = nl.Trainer(accelerator="cpu", logger=False) From 1974e97bf4d251effc55ff5fa5c31cb875bdfaaa Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 5 Sep 2024 13:02:37 -0700 Subject: [PATCH 31/31] remove unnecessary line Signed-off-by: ashors1 --- tests/lightning/test_nemo_logger.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/lightning/test_nemo_logger.py b/tests/lightning/test_nemo_logger.py index 536e8344eed3..387d3540930f 100644 --- a/tests/lightning/test_nemo_logger.py +++ b/tests/lightning/test_nemo_logger.py @@ -167,7 +167,6 @@ def test_resume(self, trainer, tmp_path): time.sleep(1) # unfinished last, that should be ignored Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel3--last").mkdir() - time.sleep(1) Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel3--last-unfinished").touch() nl.AutoResume(