[NeMo-UX] checkpointing improvements #10241

Merged
merged 41 commits into main from ashors/nemo-ux-ckpt-dir on Sep 6, 2024

Commits (41), changes from all commits:
d98db0a
save model weights and artifacts to separate directories
ashors1 Aug 22, 2024
ca04f47
add save_artifacts_on_train_end
ashors1 Aug 24, 2024
9264eee
Apply isort and black reformatting
ashors1 Aug 24, 2024
99cff3a
Merge branch 'main' of github.com:NVIDIA/NeMo into ashors/nemo-ux-ckp…
ashors1 Aug 24, 2024
8e7f3f9
Merge branch 'ashors/nemo-ux-ckpt-dir' of github.com:NVIDIA/NeMo into…
ashors1 Aug 24, 2024
1a5a457
do not save optimizer states in final checkpoint
ashors1 Aug 28, 2024
c9302fa
WIP support for saving only last k optimizer states
ashors1 Aug 28, 2024
6da2371
Apply isort and black reformatting
ashors1 Aug 28, 2024
9b1f93c
minor cleanup
ashors1 Aug 28, 2024
402471d
Revert support for saving last k optimizer states. This will be addre…
ashors1 Aug 29, 2024
0a55953
use storage_options to determine when to skip saving optimizer states
ashors1 Aug 30, 2024
aa67b8a
Apply isort and black reformatting
ashors1 Aug 30, 2024
11001a0
fix variable names, make checkpoint load work when optimizer states d…
ashors1 Aug 30, 2024
483c942
Merge branch 'ashors/nemo-ux-ckpt-dir' of github.com:NVIDIA/NeMo into…
ashors1 Aug 30, 2024
3b3b779
Apply isort and black reformatting
ashors1 Aug 30, 2024
db7e0de
Merge branch 'main' of github.com:NVIDIA/NeMo into ashors/nemo-ux-ckp…
ashors1 Aug 30, 2024
1b9696f
FSDP updates, provide option to save optimizer states on_train_end
ashors1 Aug 30, 2024
21fcb40
Apply isort and black reformatting
ashors1 Aug 30, 2024
b9a0e9e
simplify implementation, remove save_best_model option
ashors1 Aug 30, 2024
a740bf9
Merge branch 'ashors/nemo-ux-ckpt-dir' of github.com:NVIDIA/NeMo into…
ashors1 Aug 30, 2024
3c74851
update default value of ckpt_include_optimizer for fsdp
ashors1 Aug 30, 2024
3fd76ad
remove unused imports
ashors1 Sep 3, 2024
2817d56
remove unused import
ashors1 Sep 3, 2024
8f17991
cleanup
ashors1 Sep 3, 2024
a4e6954
make storage_options optional again
ashors1 Sep 4, 2024
10e4f88
fix failing tests
ashors1 Sep 4, 2024
a78717d
address some comments
ashors1 Sep 4, 2024
5916ee3
use save_weights_only to determine whether to save optimizer states
ashors1 Sep 4, 2024
5745105
Apply isort and black reformatting
ashors1 Sep 4, 2024
b9be6e9
add some comments
ashors1 Sep 4, 2024
adf418e
Merge branch 'ashors/nemo-ux-ckpt-dir' of github.com:NVIDIA/NeMo into…
ashors1 Sep 4, 2024
6befd8a
fix tests
ashors1 Sep 5, 2024
65ebadb
Apply isort and black reformatting
ashors1 Sep 5, 2024
a95edf3
Merge branch 'main' of github.com:NVIDIA/NeMo into ashors/nemo-ux-ckp…
ashors1 Sep 5, 2024
9643ef5
Merge branch 'ashors/nemo-ux-ckpt-dir' of github.com:NVIDIA/NeMo into…
ashors1 Sep 5, 2024
09a38a5
Apply isort and black reformatting
ashors1 Sep 5, 2024
987e4f6
fixes
ashors1 Sep 5, 2024
6c5a1c6
Merge branch 'ashors/nemo-ux-ckpt-dir' of github.com:NVIDIA/NeMo into…
ashors1 Sep 5, 2024
54591a6
Apply isort and black reformatting
ashors1 Sep 5, 2024
1974e97
remove unnecessary line
ashors1 Sep 5, 2024
10d619e
Merge branch 'main' of github.com:NVIDIA/NeMo into ashors/nemo-ux-ckp…
ashors1 Sep 5, 2024
1 change: 0 additions & 1 deletion examples/llm/megatron_gpt_pretraining.py
@@ -71,7 +71,6 @@ def get_args():
strategy = nl.MegatronStrategy()
checkpoint_callback = ModelCheckpoint(
every_n_train_steps=5000,
enable_nemo_ckpt_io=False,
)
callbacks = [checkpoint_callback]

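Since the example no longer passes a checkpoint-IO flag, re-enabling per-checkpoint context dumps now goes through the callback option added later in this PR. A minimal sketch, assuming `always_save_context` is the intended replacement for the removed flag and that `ModelCheckpoint` is importable from `nemo.lightning.pytorch.callbacks` as in the tests below:

```python
from nemo.lightning.pytorch.callbacks import ModelCheckpoint

# Sketch only: mirrors the updated example script. Per-checkpoint context
# dumping is off by default; opt back in explicitly if reproducibility dumps
# are wanted with every checkpoint.
checkpoint_callback = ModelCheckpoint(
    every_n_train_steps=5000,
    always_save_context=True,  # replaces the removed enable_nemo_ckpt_io flag
)
callbacks = [checkpoint_callback]
```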
4 changes: 2 additions & 2 deletions nemo/lightning/io/mixin.py
@@ -141,7 +141,7 @@ def io_dump(self, output: Path):
will be stored.
"""
output_path = Path(output)
local_artifacts_dir = "artifacts"
local_artifacts_dir = "."
artifacts_dir = output_path / local_artifacts_dir
artifacts_dir.mkdir(parents=True, exist_ok=True)

@@ -518,7 +518,7 @@ def _io_path_elements_fn(x):
return x.__io__.__path_elements__()


def _artifact_transform_save(cfg: fdl.Config, output_path: Path, relative_dir: Path = "artifacts"):
def _artifact_transform_save(cfg: fdl.Config, output_path: Path, relative_dir: Path = "."):
for artifact in getattr(cfg.__fn_or_cls__, "__io_artifacts__", []):
current_val = getattr(cfg, artifact.attr)
if current_val is None:
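With `local_artifacts_dir` switched from `"artifacts"` to `"."`, artifacts referenced by a serialized config are written directly into the dump directory instead of a nested `artifacts/` subfolder. A small sketch using a hypothetical helper that only illustrates the path change (the file name is made up):

```python
from pathlib import Path

# Hypothetical helper: shows where a referenced artifact lands relative to the
# dump root after this change (before: <output>/artifacts/<name>; now: <output>/<name>).
def artifact_destination(output: Path, artifact_name: str) -> Path:
    local_artifacts_dir = "."  # new value from this diff (was "artifacts")
    return output / local_artifacts_dir / artifact_name

print(artifact_destination(Path("ckpt/context"), "tokenizer.model"))
# prints: ckpt/context/tokenizer.model  (previously ckpt/context/artifacts/tokenizer.model)
```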
83 changes: 40 additions & 43 deletions nemo/lightning/pytorch/callbacks/model_checkpoint.py
@@ -40,21 +40,25 @@ class ModelCheckpoint(PTLModelCheckpoint):
verbose: Verbosity mode.
save_last: When ``True``, saves a `*-last` copy whenever a checkpoint file gets saved.
save_top_k: Saves the best ``k`` checkpoints according to ``monitor``.
save_weights_only: if ``True``, then only the model's weights will be saved.
save_weights_only: if ``True``, then only the model's weights will be saved. Optimizer states will
be omitted from all checkpoints.
mode: One of {min, max}. Whether the objective is to minimize or maximize the monitored quantity.
every_n_epochs: Number of epochs between checkpoints.
every_n_train_steps: Number of train steps between checkpoints.
train_time_interval: After each interval, monitor checkpoints. Not to be used with
``every_n_epochs`` or ``every_n_train_steps``.
save_best_model: When ``True``, reloads and saves the best checkpoint.
save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch
enable_nemo_ckpt_io: Whether to dump the current model model state, including the
config file, to allow for reproducibility of experiments.
save_optim_on_train_end: Whether to include the optimizer states in the final checkpoint
at the end of training. Only applicable when save_weights_only is ``False``.
always_save_context: Whether to dump the artifacts needed to reinitialize the current
model, trainer, and dataloader to allow for reproducibility of experiments.
save_context_on_train_end: Whether to dump the artifacts on_train_end regardless of whether
``always_save_context`` is ``True``.
async_save: Whether to enable asynchronous checkpointing.
try_restore_best_ckpt: Whether to restore the best model path.
"""

UNFINISHED_CHECKPOINT_SUFFIX = "-unfinished"
WEIGHTS_PATH = "weights"

def __init__(
self,
@@ -67,21 +71,21 @@ def __init__(
every_n_epochs: int = None,
every_n_train_steps: Optional[int] = None,
train_time_interval: Optional[timedelta] = None,
save_best_model: bool = False,
save_on_train_epoch_end: Optional[bool] = False, # Save after training, not after validation
enable_nemo_ckpt_io: bool = True,
try_restore_best_ckpt: bool = True,
save_optim_on_train_end: Optional[bool] = False,
always_save_context: bool = False,
save_context_on_train_end: bool = True,
**kwargs,
):
self.save_best_model = save_best_model
self.previous_best_path = ""
self.enable_nemo_ckpt_io = enable_nemo_ckpt_io
self.always_save_context = always_save_context
self.save_context_on_train_end = save_context_on_train_end
self.save_optim_on_train_end = save_optim_on_train_end

# Checkpoints whose removal is deferred until async save is done.
# Each element of `deferred_ckpts_to_remove` is a growing list
# that `self._remove_checkpoint` adds to. Once `self._save_checkpoint`
# is called, the last element is frozen and a new element is added.
self.deferred_ckpts_to_remove: List[List[str]] = []
self.try_restore_best_ckpt = try_restore_best_ckpt

# Call the parent class constructor with the remaining kwargs.
super().__init__(
@@ -251,11 +255,9 @@ def setup(self, trainer, *args, **kwargs) -> None:
self.async_save = getattr(trainer.strategy, "async_save", False)
super().setup(trainer, *args, **kwargs)

def on_save_checkpoint(self, trainer, pl_module, checkpoint):
output = super().on_save_checkpoint(trainer, pl_module, checkpoint)
return output

def on_train_end(self, trainer, pl_module):
from nemo.utils.get_rank import is_global_rank_zero

if trainer.fast_dev_run:
return None

@@ -272,26 +274,11 @@ def on_train_end(self, trainer, pl_module):
logging.debug(f'Last checkpoint {self.last_model_path} already saved')
else:
super()._save_last_checkpoint(trainer, monitor_candidates)
if self.save_context_on_train_end and not self.always_save_context and is_global_rank_zero():
TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(self.last_model_path) / "context")
# Call parent on_train_end() to save the -last checkpoint
super().on_train_end(trainer, pl_module)

# Load the best model and then re-save it
if self.save_best_model:
# wait for all processes
trainer.strategy.barrier("SaveBestCheckpointConnector.resume_end")
if self.best_model_path == "":
logging.warning(
f"{self} was told to save the best checkpoint at the end of training, but no saved checkpoints "
"were found. Saving latest model instead."
)

else:
if os.path.isdir(self.best_model_path.split('.ckpt')[0]):
self.best_model_path = self.best_model_path.split('.ckpt')[0]
if self.try_restore_best_ckpt:
self.best_model_path = trainer.strategy.broadcast(self.best_model_path)
trainer._checkpoint_connector.restore(self.best_model_path)

def _del_model_without_trainer(self, filepath: str) -> None:
from nemo.utils.get_rank import is_global_rank_zero

@@ -409,8 +396,11 @@ def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, torch.Tensor]:
return monitor_candidates

def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) -> None:
from nemo.utils.get_rank import is_global_rank_zero

# barrier_after=True, so all ranks continue after the unfinished checkpoint marker is placed.
# if anything goes wrong during checkpointing, we should be able to detect that data is incomplete.
ckpt_filepath = ckpt_to_dir(filepath) / ModelCheckpoint.WEIGHTS_PATH
self.set_checkpoint_unfinished_marker(filepath, barrier_after=True)
ema_callback = self._ema_callback(trainer)

@@ -420,17 +410,26 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str)
if self.async_save:
raise ValueError('async_save with EMA not supported')
with ema_callback.save_original_optimizer_state(trainer):
super()._save_checkpoint(trainer, filepath)
super()._save_checkpoint(trainer, ckpt_filepath)

# save EMA copy of the model as well.
with ema_callback.save_ema_model(trainer):
rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
filepath = self._ema_format_filepath(filepath)
rank_zero_info(f"Saving EMA weights to separate checkpoint {ckpt_filepath}")
ckpt_filepath = self._ema_format_filepath(ckpt_filepath)
if self.verbose:
rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
super()._save_checkpoint(trainer, filepath)
rank_zero_info(f"Saving EMA weights to separate checkpoint {ckpt_filepath}")
super()._save_checkpoint(trainer, ckpt_filepath)
self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True)
else:
## Determine whether to include optimizer states in the checkpoint.
## Optimizer states are included when:
## 1. save_weights_only is False, and
## 2. either save_optim_on_train_end is True, or save_optim_on_train_end is False but the checkpoint
##    is an intermediate checkpoint.
save_weights_only = self.save_weights_only or (
not self.save_optim_on_train_end and trainer.global_step == trainer.max_steps
)

# Async save passes the finalization function to checkpoint_io,
# sync save calls the finalization function immediately after save.
finalize_fn = self._get_finalize_save_checkpoint_callback(trainer, filepath, trainer.global_step)
@@ -445,13 +444,11 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str)
self.deferred_ckpts_to_remove.append([])
else:
storage_options = None
trainer.save_checkpoint(filepath, self.save_weights_only, storage_options=storage_options)
trainer.save_checkpoint(ckpt_filepath, save_weights_only, storage_options=storage_options)

## NOTE: saving context happens synchronously always
from nemo.utils.get_rank import is_global_rank_zero
if self.always_save_context and is_global_rank_zero():
TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath) / "context")

if self.enable_nemo_ckpt_io and is_global_rank_zero():
TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath))
if self.async_save:
logging.info(f'Scheduled async checkpoint save for {filepath}')
else:
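Taken together, these callback changes give each checkpoint directory a `weights/` subdirectory, drop optimizer states from the final checkpoint by default, and make per-checkpoint context dumps opt-in while still dumping the context once at the end of training. A hedged configuration sketch, with the import path assumed from this repository's layout:

```python
from nemo.lightning.pytorch.callbacks import ModelCheckpoint

# Sketch of the new knobs and their defaults as introduced in this diff.
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    save_top_k=1,
    every_n_train_steps=5,
    save_weights_only=False,         # keep optimizer states in intermediate checkpoints
    save_optim_on_train_end=False,   # default: the final checkpoint is saved weights-only
    always_save_context=False,       # default: no context dump with every checkpoint
    save_context_on_train_end=True,  # default: still dump the context once at the end of training
)

# Resulting checkpoint directory layout (names grounded in this diff):
#   <checkpoint_dir>/
#     weights/   <- model weights (plus optimizer states, when included)
#     context/   <- TrainerContext dump, when context saving is enabled
```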
14 changes: 9 additions & 5 deletions nemo/lightning/pytorch/strategies/fsdp_strategy.py
@@ -208,11 +208,15 @@ def save_checkpoint(
checkpoint["sharded_state_dict"] = pyt_to_mcore_state_dict(checkpoint.pop("state_dict"))
checkpoint["state_dict"] = OrderedDict([])

# TODO: do we still need to keep this?
for optim_state in checkpoint['optimizer_states']:
optim_state.pop("state")

if self.trainer.state.fn == TrainerFn.FITTING and self.ckpt_save_optimizer:
## replace unsharded optimizer_states with sharded dict.
## note that if trainer.save_checkpoint(path, save_weights_only=True) is called,
## the checkpoint will contain only model weights. Optimizer states will be omitted.
if (
"optimizer_states" in checkpoint
and self.trainer.state.fn == TrainerFn.FITTING
and self.ckpt_save_optimizer
):
del checkpoint["optimizer_states"]
checkpoint['optimizer'] = get_optimizer_state_dict(self.model, self.optimizers)
pyt_to_mcore_state_dict(checkpoint['optimizer']['state'], prefix="optimizer.state.")

11 changes: 10 additions & 1 deletion nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -625,7 +625,16 @@ def save_checkpoint(
# retrieve `sharded_state_dict` if it has not already been configured in `on_save_checkpoint`
if "sharded_state_dict" not in checkpoint:
checkpoint["sharded_state_dict"] = self.megatron_parallel.sharded_state_dict()
if self.trainer.state.fn == TrainerFn.FITTING and self.ckpt_save_optimizer:

## replace unsharded optimizer_states with sharded dict.
## note that if trainer.save_checkpoint(path, save_weights_only=True) is called,
## the checkpoint will contain only model weights. Optimizer states will be omitted.
if (
"optimizer_states" in checkpoint
and self.trainer.state.fn == TrainerFn.FITTING
and self.ckpt_save_optimizer
):
del checkpoint["optimizer_states"]
checkpoint["optimizer"] = [self.optimizer_sharded_state_dict()]

self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
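Both strategy changes add the same guard on `"optimizer_states" in checkpoint`, so the weights-only flag computed by the callback flows through `trainer.save_checkpoint` and the sharded-optimizer branch is simply skipped when Lightning omits optimizer states. A toy re-statement of that guard (a hypothetical helper, not the real strategy API):

```python
from typing import Any, Callable, Dict

def maybe_shard_optimizer_states(
    checkpoint: Dict[str, Any],
    is_fitting: bool,
    ckpt_save_optimizer: bool,
    build_sharded_optimizer: Callable[[], Any],
) -> None:
    """Toy version of the guard shared by the FSDP and Megatron strategies."""
    if "optimizer_states" in checkpoint and is_fitting and ckpt_save_optimizer:
        # Full save: swap Lightning's unsharded entry for a sharded one.
        del checkpoint["optimizer_states"]
        checkpoint["optimizer"] = build_sharded_optimizer()
    # Weights-only save: "optimizer_states" is absent, so nothing happens here
    # and only the model weights reach checkpoint_io.save_checkpoint(...).
```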
16 changes: 16 additions & 0 deletions nemo/lightning/resume.py
@@ -66,6 +66,11 @@ class AutoResume:
resume_past_end: bool = False
resume_ignore_no_checkpoint: bool = False

WEIGHTS_PATH = "weights"

def get_model_weights_path(self, path):
return Path(path) / self.WEIGHTS_PATH

def setup(self, trainer: Union[pl.Trainer, fl.Fabric], model=None):
if isinstance(trainer, fl.Fabric):
raise NotImplementedError("Fabric is not supported yet.")
@@ -90,6 +95,7 @@ def setup(self, trainer: Union[pl.Trainer, fl.Fabric], model=None):
def _try_import_model(
self, model: Optional[io.ConnectorMixin], path: str, adapter_path: Optional[str] = None
) -> BasePath:

if model is None:
raise ValueError("Model is needed to import checkpoint from HF or other non-NeMo checkpoint format.")
try:
@@ -99,6 +105,11 @@ def _try_import_model(
new_path = path

if adapter_path:

maybe_model_weights_path = self.get_model_weights_path(adapter_path)
if os.path.isdir(maybe_model_weights_path):
adapter_path = maybe_model_weights_path

new_path = AdapterPath(Path(adapter_path), base_model_path=new_path)

if isinstance(new_path, str):
@@ -211,6 +222,11 @@ def get_trainer_ckpt_path(self, model: Optional[io.ConnectorMixin] = None) -> Op
if self.resume_if_exists:
checkpoint = self._find_trainer_ckpt_path()

if checkpoint:
maybe_model_weights_path = self.get_model_weights_path(checkpoint)
if os.path.isdir(maybe_model_weights_path):
checkpoint = maybe_model_weights_path

if checkpoint:
if self.adapter_path:
return AdapterPath(Path(self.adapter_path), base_model_path=checkpoint)
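Resume-path resolution now prefers a `weights/` subdirectory when one exists, so checkpoints written in the new layout and older flat checkpoints both load. A small sketch of the lookup, using a hypothetical standalone helper that mirrors `get_model_weights_path`:

```python
import os
from pathlib import Path

WEIGHTS_PATH = "weights"

def resolve_restore_path(checkpoint_dir: Path) -> Path:
    """Prefer <checkpoint_dir>/weights when present; otherwise fall back to the
    flat (pre-change) layout and use the checkpoint directory itself."""
    maybe_weights = Path(checkpoint_dir) / WEIGHTS_PATH
    return maybe_weights if os.path.isdir(maybe_weights) else Path(checkpoint_dir)

# e.g. checkpoints/mymodel--last          -> used as-is (no weights/ subdir)
#      checkpoints/mymodel2--last/weights -> returned when the subdirectory exists
```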
6 changes: 4 additions & 2 deletions tests/collections/llm/test_mnist_model_nemo2.py
@@ -496,13 +496,12 @@ def run_train_mnist_litautoencoder_with_megatron_strategy_single_gpu():
# Configure our custom Checkpointer
name = "test_experiment"
checkpoint_callback = nl_callbacks.ModelCheckpoint(
save_best_model=True,
save_last=True,
monitor="val_loss",
save_top_k=1,
every_n_train_steps=5,
# Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
enable_nemo_ckpt_io=True,
always_save_context=True,
)
root_dir = tmpdir
save_dir = root_dir / name
@@ -571,6 +570,9 @@ def run_train_mnist_litautoencoder_with_megatron_strategy_single_gpu():
ckpt_path = checkpoint_callback.last_model_path.replace(
".ckpt", ""
) # strip .ckpt off the end of the last path
ckpt_path = (
Path(ckpt_path) / "weights"
) ## weights are saved to the "weights" directory within the checkpoint

assert Path(
ckpt_path
6 changes: 4 additions & 2 deletions tests/collections/llm/test_mnist_model_nemo2_fsdp.py
@@ -519,13 +519,12 @@ def run_train_mnist_litautoencoder_with_fsdp_strategy_single_gpu():
# Configure our custom Checkpointer
name = "test_experiment"
checkpoint_callback = nl_callbacks.ModelCheckpoint(
save_best_model=True,
save_last=True,
monitor="val_loss",
save_top_k=1,
every_n_train_steps=5,
# Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
enable_nemo_ckpt_io=True,
always_save_context=True,
)
root_dir = tmpdir
save_dir = root_dir / name
@@ -583,6 +582,9 @@ def run_train_mnist_litautoencoder_with_fsdp_strategy_single_gpu():
ckpt_path = checkpoint_callback.last_model_path.replace(
".ckpt", ""
) # strip .ckpt off the end of the last path
ckpt_path = (
Path(ckpt_path) / "weights"
) ## weights are saved to the "weights" directory within the checkpoint

assert Path(
ckpt_path
12 changes: 8 additions & 4 deletions tests/lightning/test_nemo_logger.py
@@ -13,6 +13,7 @@
# limitations under the License.

import os
import shutil
import time
from pathlib import Path
from unittest.mock import patch
@@ -159,7 +160,10 @@ def test_resume(self, trainer, tmp_path):
## if there are multiple "-last" checkpoints, choose the most recent one
Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel--last").mkdir()
time.sleep(1) ## sleep for a second so the checkpoints are created at different times
Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last").mkdir()
## make a "weights" dir within the checkpoint
Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last" / "weights").mkdir(
parents=True
)
time.sleep(1)
# unfinished last, that should be ignored
Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel3--last").mkdir()
@@ -169,19 +173,19 @@
resume_from_directory=Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints"),
resume_if_exists=True,
).setup(trainer)
## if "weights" exists, we should restore from there
assert str(trainer.ckpt_path) == str(
Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last")
Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last" / "weights")
)

# Finally succeed
logger = nl.NeMoLogger(
name="default",
dir=str(tmp_path) + "/test_resume",
version="version_0",
use_datetime_version=False,
)
logger.setup(trainer)
Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last").rmdir()
shutil.rmtree(Path(tmp_path / "test_resume" / "default" / "version_0" / "checkpoints" / "mymodel2--last"))
nl.AutoResume(
resume_if_exists=True,
).setup(trainer)