[NeMo-UX] checkpointing improvements (NVIDIA#10241)
* save model weights and artifacts to separate directories

Signed-off-by: ashors1 <ashors@nvidia.com>

* add save_artifacts_on_train_end

Signed-off-by: ashors1 <ashors@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: ashors1 <ashors1@users.noreply.github.com>

* do not save optimizer states in final checkpoint

Signed-off-by: ashors1 <ashors@nvidia.com>

* WIP support for saving only last k optimizer states

Signed-off-by: ashors1 <ashors@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: ashors1 <ashors1@users.noreply.github.com>

* minor cleanup

Signed-off-by: ashors1 <ashors@nvidia.com>

* Revert support for saving last k optimizer states. This will be addressed in a subsequent PR.

* use storage_options to determine when to skip saving optimizer states

Signed-off-by: ashors1 <ashors@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: ashors1 <ashors1@users.noreply.github.com>

* fix variable names, make checkpoint load work when optimizer states don't exist in the checkpoint

Signed-off-by: ashors1 <ashors@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: ashors1 <ashors1@users.noreply.github.com>

* FSDP updates, provide option to save optimizer states on_train_end

Signed-off-by: ashors1 <ashors@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: ashors1 <ashors1@users.noreply.github.com>

* simplify implementation, remove save_best_model option

Signed-off-by: ashors1 <ashors@nvidia.com>

* update default value of ckpt_include_optimizer for fsdp

Signed-off-by: ashors1 <ashors@nvidia.com>

* remove unused imports

Signed-off-by: ashors1 <ashors@nvidia.com>

* remove unused import

Signed-off-by: ashors1 <ashors@nvidia.com>

* cleanup

Signed-off-by: ashors1 <ashors@nvidia.com>

* make storage_options optional again

Signed-off-by: ashors1 <ashors@nvidia.com>

* fix failing tests

Signed-off-by: ashors1 <ashors@nvidia.com>

* address some comments

Signed-off-by: ashors1 <ashors@nvidia.com>

* use save_weights_only to determine whether to save optimizer states

Signed-off-by: ashors1 <ashors@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: ashors1 <ashors1@users.noreply.github.com>

* add some comments

Signed-off-by: ashors1 <ashors@nvidia.com>

* fix tests

Signed-off-by: ashors1 <ashors@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: ashors1 <ashors1@users.noreply.github.com>

* Apply isort and black reformatting

Signed-off-by: ashors1 <ashors1@users.noreply.github.com>

* fixes

Signed-off-by: ashors1 <ashors@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: ashors1 <ashors1@users.noreply.github.com>

* remove unnecessary line

Signed-off-by: ashors1 <ashors@nvidia.com>

---------

Signed-off-by: ashors1 <ashors@nvidia.com>
Signed-off-by: ashors1 <ashors1@users.noreply.github.com>
Co-authored-by: ashors1 <ashors1@users.noreply.github.com>
Signed-off-by: adityavavre <aditya.vavre@gmail.com>
2 people authored and adityavavre committed Sep 15, 2024
1 parent f76944f commit 2bfcda1
Showing 9 changed files with 93 additions and 60 deletions.
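Before the file-by-file diff, a minimal usage sketch of the callback after this change, mirroring the updated tests further down. The import alias and argument values are assumptions taken from those tests, not part of the PR itself:

from nemo.lightning.pytorch import callbacks as nl_callbacks  # alias assumed, as in the tests below

# Sketch only: `always_save_context` replaces the removed `enable_nemo_ckpt_io`, and
# `save_optim_on_train_end` controls whether the final checkpoint keeps optimizer states.
checkpoint_callback = nl_callbacks.ModelCheckpoint(
    save_last=True,
    monitor="val_loss",
    save_top_k=1,
    every_n_train_steps=5,
    always_save_context=True,        # dump config/artifacts alongside every checkpoint
    save_optim_on_train_end=False,   # drop optimizer states from the final checkpoint
)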
1 change: 0 additions & 1 deletion examples/llm/megatron_gpt_pretraining.py
@@ -71,7 +71,6 @@ def get_args():
strategy = nl.MegatronStrategy()
checkpoint_callback = ModelCheckpoint(
every_n_train_steps=5000,
enable_nemo_ckpt_io=False,
)
callbacks = [checkpoint_callback]

4 changes: 2 additions & 2 deletions nemo/lightning/io/mixin.py
@@ -141,7 +141,7 @@ def io_dump(self, output: Path):
will be stored.
"""
output_path = Path(output)
local_artifacts_dir = "artifacts"
local_artifacts_dir = "."
artifacts_dir = output_path / local_artifacts_dir
artifacts_dir.mkdir(parents=True, exist_ok=True)

@@ -518,7 +518,7 @@ def _io_path_elements_fn(x):
return x.__io__.__path_elements__()


def _artifact_transform_save(cfg: fdl.Config, output_path: Path, relative_dir: Path = "artifacts"):
def _artifact_transform_save(cfg: fdl.Config, output_path: Path, relative_dir: Path = "."):
for artifact in getattr(cfg.__fn_or_cls__, "__io_artifacts__", []):
current_val = getattr(cfg, artifact.attr)
if current_val is None:
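Together with the checkpoint-callback changes below, the two hunks above give each checkpoint a flatter layout: artifacts are no longer nested under an extra "artifacts/" subdirectory, and the serialized context sits next to the model weights. A hypothetical sketch (directory names taken from this diff, the checkpoint name is made up):

from pathlib import Path

ckpt_dir = Path("results/checkpoints/megatron--val_loss=1.23--step=5000--last")  # hypothetical
weights_dir = ckpt_dir / "weights"   # model state, plus optimizer state for intermediate checkpoints
context_dir = ckpt_dir / "context"   # config and artifacts written by TrainerContext.io_dump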
83 changes: 40 additions & 43 deletions nemo/lightning/pytorch/callbacks/model_checkpoint.py
@@ -40,21 +40,25 @@ class ModelCheckpoint(PTLModelCheckpoint):
verbose: Verbosity mode.
save_last: When ``True``, saves a `*-last` copy whenever a checkpoint file gets saved.
save_top_k: Saves the best ``save_top_k`` checkpoints according to ``monitor``.
save_weights_only: if ``True``, then only the model's weights will be saved.
save_weights_only: if ``True``, then only the model's weights will be saved. Optimizer states will
be omitted from all checkpoints.
mode: One of {min, max}. Whether the objective is to minimize or maximize the monitored quantity.
every_n_epochs: Number of epochs between checkpoints.
every_n_train_steps: Number of train steps between checkpoints.
train_time_interval: After each interval, monitor checkpoints. Not to be used with
``every_n_epochs`` or ``every_n_train_steps``.
save_best_model: When ``True``, reloads and saves the best checkpoint.
save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch
enable_nemo_ckpt_io: Whether to dump the current model state, including the
config file, to allow for reproducibility of experiments.
save_optim_on_train_end: Whether to include the optimizer states in the final checkpoint
at the end of training. Only applicable when ``save_weights_only`` is ``False``.
always_save_context: Whether to dump the artifacts needed to reinitialize the current
model, trainer, and dataloader to allow for reproducibility of experiments.
save_context_on_train_end: Whether to dump the artifacts in ``on_train_end``, regardless of
the value of ``always_save_context``.
async_save: Whether to enable asynchronous checkpointing.
try_restore_best_ckpt: Whether to restore the best model path.
"""

UNFINISHED_CHECKPOINT_SUFFIX = "-unfinished"
WEIGHTS_PATH = "weights"

def __init__(
self,
@@ -67,21 +71,21 @@ def __init__(
every_n_epochs: int = None,
every_n_train_steps: Optional[int] = None,
train_time_interval: Optional[timedelta] = None,
save_best_model: bool = False,
save_on_train_epoch_end: Optional[bool] = False, # Save after training, not after validation
enable_nemo_ckpt_io: bool = True,
try_restore_best_ckpt: bool = True,
save_optim_on_train_end: Optional[bool] = False,
always_save_context: bool = False,
save_context_on_train_end: bool = True,
**kwargs,
):
self.save_best_model = save_best_model
self.previous_best_path = ""
self.enable_nemo_ckpt_io = enable_nemo_ckpt_io
self.always_save_context = always_save_context
self.save_context_on_train_end = save_context_on_train_end
self.save_optim_on_train_end = save_optim_on_train_end

# Checkpoints whose removal is deferred until async save is done.
# Each element of `deferred_ckpts_to_remove` is a growing list
# that `self._remove_checkpoint` adds to. Once `self._save_checkpoint`
# is called, the last element is frozen and a new element is added.
self.deferred_ckpts_to_remove: List[List[str]] = []
self.try_restore_best_ckpt = try_restore_best_ckpt

# Call the parent class constructor with the remaining kwargs.
super().__init__(
@@ -251,11 +255,9 @@ def setup(self, trainer, *args, **kwargs) -> None:
self.async_save = getattr(trainer.strategy, "async_save", False)
super().setup(trainer, *args, **kwargs)

def on_save_checkpoint(self, trainer, pl_module, checkpoint):
output = super().on_save_checkpoint(trainer, pl_module, checkpoint)
return output

def on_train_end(self, trainer, pl_module):
from nemo.utils.get_rank import is_global_rank_zero

if trainer.fast_dev_run:
return None

@@ -272,26 +274,11 @@ def on_train_end(self, trainer, pl_module):
logging.debug(f'Last checkpoint {self.last_model_path} already saved')
else:
super()._save_last_checkpoint(trainer, monitor_candidates)
if self.save_context_on_train_end and not self.always_save_context and is_global_rank_zero():
TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(self.last_model_path) / "context")
# Call parent on_train_end() to save the -last checkpoint
super().on_train_end(trainer, pl_module)

# Load the best model and then re-save it
if self.save_best_model:
# wait for all processes
trainer.strategy.barrier("SaveBestCheckpointConnector.resume_end")
if self.best_model_path == "":
logging.warning(
f"{self} was told to save the best checkpoint at the end of training, but no saved checkpoints "
"were found. Saving latest model instead."
)

else:
if os.path.isdir(self.best_model_path.split('.ckpt')[0]):
self.best_model_path = self.best_model_path.split('.ckpt')[0]
if self.try_restore_best_ckpt:
self.best_model_path = trainer.strategy.broadcast(self.best_model_path)
trainer._checkpoint_connector.restore(self.best_model_path)

def _del_model_without_trainer(self, filepath: str) -> None:
from nemo.utils.get_rank import is_global_rank_zero

@@ -409,8 +396,11 @@ def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, torch.Tensor]:
return monitor_candidates

def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) -> None:
from nemo.utils.get_rank import is_global_rank_zero

# barrier_after=True, so all ranks continue after the unfinished checkpoint marker is placed.
# if anything goes wrong during checkpointing, we should be able to detect that data is incomplete.
ckpt_filepath = ckpt_to_dir(filepath) / ModelCheckpoint.WEIGHTS_PATH
self.set_checkpoint_unfinished_marker(filepath, barrier_after=True)
ema_callback = self._ema_callback(trainer)

@@ -420,17 +410,26 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str)
if self.async_save:
raise ValueError('async_save with EMA not supported')
with ema_callback.save_original_optimizer_state(trainer):
super()._save_checkpoint(trainer, filepath)
super()._save_checkpoint(trainer, ckpt_filepath)

# save EMA copy of the model as well.
with ema_callback.save_ema_model(trainer):
rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
filepath = self._ema_format_filepath(filepath)
rank_zero_info(f"Saving EMA weights to separate checkpoint {ckpt_filepath}")
ckpt_filepath = self._ema_format_filepath(ckpt_filepath)
if self.verbose:
rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
super()._save_checkpoint(trainer, filepath)
rank_zero_info(f"Saving EMA weights to separate checkpoint {ckpt_filepath}")
super()._save_checkpoint(trainer, ckpt_filepath)
self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True)
else:
## Determine whether to include optimizer states in the checkpoint
## optimizer states are included when
## 1. save_weights_only is False and
## 2. either save_optim_on_train_end is True, or save_optim_on_train_end is False but the checkpoint
## is an intermediate checkpoint.
save_weights_only = self.save_weights_only or (
not self.save_optim_on_train_end and trainer.global_step == trainer.max_steps
)

# Async save passes the finalization function to checkpoint_io,
# sync save calls the finalization function immediately after save.
finalize_fn = self._get_finalize_save_checkpoint_callback(trainer, filepath, trainer.global_step)
@@ -445,13 +444,11 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str)
self.deferred_ckpts_to_remove.append([])
else:
storage_options = None
trainer.save_checkpoint(filepath, self.save_weights_only, storage_options=storage_options)
trainer.save_checkpoint(ckpt_filepath, save_weights_only, storage_options=storage_options)

## NOTE: saving the context always happens synchronously
from nemo.utils.get_rank import is_global_rank_zero
if self.always_save_context and is_global_rank_zero():
TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath) / "context")

if self.enable_nemo_ckpt_io and is_global_rank_zero():
TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath))
if self.async_save:
logging.info(f'Scheduled async checkpoint save for {filepath}')
else:
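As a standalone illustration (not part of the PR) of the rule `_save_checkpoint` now applies: optimizer states are written for intermediate checkpoints unless weights-only saving was requested, and are kept in the final checkpoint only when `save_optim_on_train_end` is set. A small sketch of that decision with worked examples:

def weights_only_for_step(save_weights_only: bool, save_optim_on_train_end: bool,
                          global_step: int, max_steps: int) -> bool:
    # Mirrors the expression in _save_checkpoint above.
    return save_weights_only or (not save_optim_on_train_end and global_step == max_steps)

assert weights_only_for_step(False, False, 500, 1000) is False  # intermediate: keep optimizer
assert weights_only_for_step(False, False, 1000, 1000) is True  # final: drop optimizer by default
assert weights_only_for_step(False, True, 1000, 1000) is False  # final, opted in: keep optimizer
assert weights_only_for_step(True, False, 500, 1000) is True    # weights-only always wins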
14 changes: 9 additions & 5 deletions nemo/lightning/pytorch/strategies/fsdp_strategy.py
@@ -208,11 +208,15 @@ def save_checkpoint(
checkpoint["sharded_state_dict"] = pyt_to_mcore_state_dict(checkpoint.pop("state_dict"))
checkpoint["state_dict"] = OrderedDict([])

# TODO: do we still need to keep this?
for optim_state in checkpoint['optimizer_states']:
optim_state.pop("state")

if self.trainer.state.fn == TrainerFn.FITTING and self.ckpt_save_optimizer:
## replace unsharded optimizer_states with sharded dict.
## note that if trainer.save_checkpoint(path, save_weights_only=True) is called,
## the checkpoint will contain only model weights. Optimizer states will be omitted.
if (
"optimizer_states" in checkpoint
and self.trainer.state.fn == TrainerFn.FITTING
and self.ckpt_save_optimizer
):
del checkpoint["optimizer_states"]
checkpoint['optimizer'] = get_optimizer_state_dict(self.model, self.optimizers)
pyt_to_mcore_state_dict(checkpoint['optimizer']['state'], prefix="optimizer.state.")

11 changes: 10 additions & 1 deletion nemo/lightning/pytorch/strategies/megatron_strategy.py
Expand Up @@ -625,7 +625,16 @@ def save_checkpoint(
# retrieve `sharded_state_dict` if it has not already been configured in `on_save_checkpoint`
if "sharded_state_dict" not in checkpoint:
checkpoint["sharded_state_dict"] = self.megatron_parallel.sharded_state_dict()
if self.trainer.state.fn == TrainerFn.FITTING and self.ckpt_save_optimizer:

## replace unsharded optimizer_states with sharded dict.
## note that if trainer.save_checkpoint(path, save_weights_only=True) is called,
## the checkpoint will contain only model weights. Optimizer states will be omitted.
if (
"optimizer_states" in checkpoint
and self.trainer.state.fn == TrainerFn.FITTING
and self.ckpt_save_optimizer
):
del checkpoint["optimizer_states"]
checkpoint["optimizer"] = [self.optimizer_sharded_state_dict()]

self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
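Both strategies now apply the same guard before replacing the unsharded optimizer states with a sharded dict. A minimal sketch of that shared check (the TrainerFn import path is an assumption):

from pytorch_lightning.trainer.states import TrainerFn  # import path assumed

def should_save_sharded_optimizer(checkpoint: dict, trainer_fn, ckpt_save_optimizer: bool) -> bool:
    # "optimizer_states" is absent when Lightning was asked to save weights only,
    # so a weights-only save skips the sharded-optimizer branch entirely.
    return (
        "optimizer_states" in checkpoint
        and trainer_fn == TrainerFn.FITTING
        and ckpt_save_optimizer
    )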
16 changes: 16 additions & 0 deletions nemo/lightning/resume.py
@@ -66,6 +66,11 @@ class AutoResume:
resume_past_end: bool = False
resume_ignore_no_checkpoint: bool = False

WEIGHTS_PATH = "weights"

def get_model_weights_path(self, path):
return Path(path) / self.WEIGHTS_PATH

def setup(self, trainer: Union[pl.Trainer, fl.Fabric], model=None):
if isinstance(trainer, fl.Fabric):
raise NotImplementedError("Fabric is not supported yet.")
@@ -90,6 +95,7 @@ def setup(self, trainer: Union[pl.Trainer, fl.Fabric], model=None):
def _try_import_model(
self, model: Optional[io.ConnectorMixin], path: str, adapter_path: Optional[str] = None
) -> BasePath:

if model is None:
raise ValueError("Model is needed to import checkpoint from HF or other non-NeMo checkpoint format.")
try:
@@ -99,6 +105,11 @@ def _try_import_model(
new_path = path

if adapter_path:

maybe_model_weights_path = self.get_model_weights_path(adapter_path)
if os.path.isdir(maybe_model_weights_path):
adapter_path = maybe_model_weights_path

new_path = AdapterPath(Path(adapter_path), base_model_path=new_path)

if isinstance(new_path, str):
@@ -211,6 +222,11 @@ def get_trainer_ckpt_path(self, model: Optional[io.ConnectorMixin] = None) -> Op
if self.resume_if_exists:
checkpoint = self._find_trainer_ckpt_path()

if checkpoint:
maybe_model_weights_path = self.get_model_weights_path(checkpoint)
if os.path.isdir(maybe_model_weights_path):
checkpoint = maybe_model_weights_path

if checkpoint:
if self.adapter_path:
return AdapterPath(Path(self.adapter_path), base_model_path=checkpoint)
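Illustrative only, with hypothetical paths: the new resolution prefers the "weights" subdirectory (AutoResume.WEIGHTS_PATH) when it exists and otherwise falls back to the checkpoint directory itself, so resuming from pre-PR checkpoints keeps working:

from pathlib import Path

checkpoint = Path("results/checkpoints/mymodel2--last")  # hypothetical resolved checkpoint
weights = checkpoint / "weights"
resume_path = weights if weights.is_dir() else checkpoint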
6 changes: 4 additions & 2 deletions tests/collections/llm/test_mnist_model_nemo2.py
@@ -496,13 +496,12 @@ def run_train_mnist_litautoencoder_with_megatron_strategy_single_gpu():
# Configure our custom Checkpointer
name = "test_experiment"
checkpoint_callback = nl_callbacks.ModelCheckpoint(
save_best_model=True,
save_last=True,
monitor="val_loss",
save_top_k=1,
every_n_train_steps=5,
# Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
enable_nemo_ckpt_io=True,
always_save_context=True,
)
root_dir = tmpdir
save_dir = root_dir / name
@@ -571,6 +570,9 @@ def run_train_mnist_litautoencoder_with_megatron_strategy_single_gpu():
ckpt_path = checkpoint_callback.last_model_path.replace(
".ckpt", ""
) # strip .ckpt off the end of the last path
ckpt_path = (
Path(ckpt_path) / "weights"
) ## weights are saved to the "weights" directory within the checkpoint

assert Path(
ckpt_path
6 changes: 4 additions & 2 deletions tests/collections/llm/test_mnist_model_nemo2_fsdp.py
@@ -519,13 +519,12 @@ def run_train_mnist_litautoencoder_with_fsdp_strategy_single_gpu():
# Configure our custom Checkpointer
name = "test_experiment"
checkpoint_callback = nl_callbacks.ModelCheckpoint(
save_best_model=True,
save_last=True,
monitor="val_loss",
save_top_k=1,
every_n_train_steps=5,
# Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
enable_nemo_ckpt_io=True,
always_save_context=True,
)
root_dir = tmpdir
save_dir = root_dir / name
@@ -583,6 +582,9 @@ def run_train_mnist_litautoencoder_with_fsdp_strategy_single_gpu():
ckpt_path = checkpoint_callback.last_model_path.replace(
".ckpt", ""
) # strip .ckpt off the end of the last path
ckpt_path = (
Path(ckpt_path) / "weights"
) ## weights are saved to the "weights" directory within the checkpoint

assert Path(
ckpt_path
12 changes: 8 additions & 4 deletions tests/lightning/test_nemo_logger.py
@@ -13,6 +13,7 @@
# limitations under the License.

import os
import shutil
import time
from pathlib import Path
from unittest.mock import patch
@@ -159,7 +160,10 @@ def test_resume(self, trainer, tmp_path):
## if there are multiple "-last" checkpoints, choose the most recent one
Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel--last").mkdir()
time.sleep(1) ## sleep for a second so the checkpoints are created at different times
Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last").mkdir()
## make a "weights" dir within the checkpoint
Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last" / "weights").mkdir(
parents=True
)
time.sleep(1)
# unfinished last, that should be ignored
Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel3--last").mkdir()
@@ -169,19 +173,19 @@
resume_from_directory=Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints"),
resume_if_exists=True,
).setup(trainer)
## if "weights" exists, we should restore from there
assert str(trainer.ckpt_path) == str(
Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last")
Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last" / "weights")
)

# Finally succeed
logger = nl.NeMoLogger(
name="default",
dir=str(tmp_path) + "/test_resume",
version="version_0",
use_datetime_version=False,
)
logger.setup(trainer)
Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last").rmdir()
shutil.rmtree(Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last"))
nl.AutoResume(
resume_if_exists=True,
).setup(trainer)
