
[NeMo-UX] checkpointing improvements #10241

Merged
merged 41 commits into main from ashors/nemo-ux-ckpt-dir on Sep 6, 2024
Changes from 12 commits
Commits (41)
d98db0a
save model weights and artifacts to separate directories
ashors1 Aug 22, 2024
ca04f47
add save_artifacts_on_train_end
ashors1 Aug 24, 2024
9264eee
Apply isort and black reformatting
ashors1 Aug 24, 2024
99cff3a
Merge branch 'main' of github.com:NVIDIA/NeMo into ashors/nemo-ux-ckp…
ashors1 Aug 24, 2024
8e7f3f9
Merge branch 'ashors/nemo-ux-ckpt-dir' of github.com:NVIDIA/NeMo into…
ashors1 Aug 24, 2024
1a5a457
do not save optimizer states in final checkpoint
ashors1 Aug 28, 2024
c9302fa
WIP support for saving only last k optimizer states
ashors1 Aug 28, 2024
6da2371
Apply isort and black reformatting
ashors1 Aug 28, 2024
9b1f93c
minor cleanup
ashors1 Aug 28, 2024
402471d
Revert support for saving last k optimizer states. This will be addre…
ashors1 Aug 29, 2024
0a55953
use storage_options to determine when to skip saving optimizer states
ashors1 Aug 30, 2024
aa67b8a
Apply isort and black reformatting
ashors1 Aug 30, 2024
11001a0
fix variable names, make checkpoint load work when optimizer states d…
ashors1 Aug 30, 2024
483c942
Merge branch 'ashors/nemo-ux-ckpt-dir' of github.com:NVIDIA/NeMo into…
ashors1 Aug 30, 2024
3b3b779
Apply isort and black reformatting
ashors1 Aug 30, 2024
db7e0de
Merge branch 'main' of github.com:NVIDIA/NeMo into ashors/nemo-ux-ckp…
ashors1 Aug 30, 2024
1b9696f
FSDP updates, provide option to save optimizer states on_train_end
ashors1 Aug 30, 2024
21fcb40
Apply isort and black reformatting
ashors1 Aug 30, 2024
b9a0e9e
simplify implementation, remove save_best_model option
ashors1 Aug 30, 2024
a740bf9
Merge branch 'ashors/nemo-ux-ckpt-dir' of github.com:NVIDIA/NeMo into…
ashors1 Aug 30, 2024
3c74851
update default value of ckpt_include_optimizer for fsdp
ashors1 Aug 30, 2024
3fd76ad
remove unused imports
ashors1 Sep 3, 2024
2817d56
remove unused import
ashors1 Sep 3, 2024
8f17991
cleanup
ashors1 Sep 3, 2024
a4e6954
make storage_options optional again
ashors1 Sep 4, 2024
10e4f88
fix failing tests
ashors1 Sep 4, 2024
a78717d
address some comments
ashors1 Sep 4, 2024
5916ee3
use save_weights_only to determine whether to save optimizer states
ashors1 Sep 4, 2024
5745105
Apply isort and black reformatting
ashors1 Sep 4, 2024
b9be6e9
add some comments
ashors1 Sep 4, 2024
adf418e
Merge branch 'ashors/nemo-ux-ckpt-dir' of github.com:NVIDIA/NeMo into…
ashors1 Sep 4, 2024
6befd8a
fix tests
ashors1 Sep 5, 2024
65ebadb
Apply isort and black reformatting
ashors1 Sep 5, 2024
a95edf3
Merge branch 'main' of github.com:NVIDIA/NeMo into ashors/nemo-ux-ckp…
ashors1 Sep 5, 2024
9643ef5
Merge branch 'ashors/nemo-ux-ckpt-dir' of github.com:NVIDIA/NeMo into…
ashors1 Sep 5, 2024
09a38a5
Apply isort and black reformatting
ashors1 Sep 5, 2024
987e4f6
fixes
ashors1 Sep 5, 2024
6c5a1c6
Merge branch 'ashors/nemo-ux-ckpt-dir' of github.com:NVIDIA/NeMo into…
ashors1 Sep 5, 2024
54591a6
Apply isort and black reformatting
ashors1 Sep 5, 2024
1974e97
remove unnecessary line
ashors1 Sep 5, 2024
10d619e
Merge branch 'main' of github.com:NVIDIA/NeMo into ashors/nemo-ux-ckp…
ashors1 Sep 5, 2024
4 changes: 2 additions & 2 deletions nemo/lightning/io/mixin.py
@@ -140,7 +140,7 @@ def io_dump(self, output: Path):
will be stored.
"""
output_path = Path(output)
local_artifacts_dir = "artifacts"
local_artifacts_dir = "."
artifacts_dir = output_path / local_artifacts_dir
artifacts_dir.mkdir(parents=True, exist_ok=True)

@@ -523,7 +523,7 @@ def _io_path_elements_fn(x):
return x.__io__.__path_elements__()


def _artifact_transform_save(cfg: fdl.Config, output_path: Path, relative_dir: Path = "artifacts"):
def _artifact_transform_save(cfg: fdl.Config, output_path: Path, relative_dir: Path = "."):
for artifact in getattr(cfg.__fn_or_cls__, "__io_artifacts__", []):
current_val = getattr(cfg, artifact.attr)
if current_val is None:
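With the default changed from "artifacts" to ".", io_dump() now writes artifact files directly into the directory it is given, and callers such as the checkpoint callback below choose the subdirectory explicitly. A minimal sketch of what the new default means for artifact placement; the helper function is illustrative and only mirrors the `artifacts_dir = output_path / local_artifacts_dir` line above, and the example path is hypothetical:

from pathlib import Path


def artifacts_dir_for(output_path: Path, local_artifacts_dir: str) -> Path:
    # Mirrors io_dump(): artifacts_dir = output_path / local_artifacts_dir
    return output_path / local_artifacts_dir


out = Path("/results/checkpoints/step=1000-last/artifacts")  # hypothetical io_dump() target
print(artifacts_dir_for(out, "."))          # new default: files land in `out` itself
print(artifacts_dir_for(out, "artifacts"))  # old default: files landed in out/artifacts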
54 changes: 32 additions & 22 deletions nemo/lightning/pytorch/callbacks/model_checkpoint.py
@@ -48,13 +48,16 @@ class ModelCheckpoint(PTLModelCheckpoint):
``every_n_epochs`` or ``every_n_train_steps``.
save_best_model: When ``True``, reloads and saves the best checkpoint.
save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch
enable_nemo_ckpt_io: Whether to dump the current model model state, including the
config file, to allow for reproducibility of experiments.
always_save_artifacts: Whether to dump the artifacts needed to reinitialize the current
model, trainer, and dataloader to allow for reproducibility of experiments.
save_artifacts_on_train_end: Whether to dump the artifacts at the end of training (``on_train_end``),
regardless of the value of ``always_save_artifacts``.
async_save: Whether to enable asynchronous checkpointing.
try_restore_best_ckpt: Whether to restore the best model path.
"""

UNFINISHED_CHECKPOINT_SUFFIX = "-unfinished"
MODEL_WEIGHTS_PATH = "model_weights"

def __init__(
self,
@@ -69,13 +72,16 @@ def __init__(
train_time_interval: Optional[timedelta] = None,
save_best_model: bool = False,
save_on_train_epoch_end: Optional[bool] = False, # Save after training, not after validation
enable_nemo_ckpt_io: bool = True,
always_save_artifacts: bool = False,
save_artifacts_on_train_end: bool = True,
try_restore_best_ckpt: bool = True,
**kwargs,
):
self.save_best_model = save_best_model
self.previous_best_path = ""
self.enable_nemo_ckpt_io = enable_nemo_ckpt_io
self.always_save_artifacts = always_save_artifacts
self.save_artifacts_on_train_end = save_artifacts_on_train_end

# Checkpoints which removal is deferred until async save is done.
# Each element of `deferred_ckpts_to_remove` is a growing list
# that `self._remove_checkpoint` adds to. Once `self._save_checkpoint`
@@ -251,11 +257,9 @@ def setup(self, trainer, *args, **kwargs) -> None:
self.async_save = getattr(trainer.strategy, "async_save", False)
super().setup(trainer, *args, **kwargs)

def on_save_checkpoint(self, trainer, pl_module, checkpoint):
output = super().on_save_checkpoint(trainer, pl_module, checkpoint)
return output

def on_train_end(self, trainer, pl_module):
from nemo.utils.get_rank import is_global_rank_zero

if trainer.fast_dev_run:
return None

@@ -272,6 +276,8 @@ def on_train_end(self, trainer, pl_module):
logging.debug(f'Last checkpoint {self.last_model_path} already saved')
else:
super()._save_last_checkpoint(trainer, monitor_candidates)
if self.save_artifacts_on_train_end and not self.always_save_artifacts and is_global_rank_zero():
TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(self.last_model_path) / "artifacts")
# Call parent on_train_end() to save the -last checkpoint
super().on_train_end(trainer, pl_module)

@@ -287,7 +293,9 @@ def on_train_end(self, trainer, pl_module):

else:
if os.path.isdir(self.best_model_path.split('.ckpt')[0]):
self.best_model_path = self.best_model_path.split('.ckpt')[0]
self.best_model_path = (
Path(self.best_model_path.split('.ckpt')[0]) / ModelCheckpoint.MODEL_WEIGHTS_PATH
)
if self.try_restore_best_ckpt:
self.best_model_path = trainer.strategy.broadcast(self.best_model_path)
trainer._checkpoint_connector.restore(self.best_model_path)
@@ -409,8 +417,11 @@ def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, torch.Tensor]:
return monitor_candidates

def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) -> None:
from nemo.utils.get_rank import is_global_rank_zero

# barrier_after=True, so all ranks continue after the unfinished checkpoint marker is placed.
# if anything goes wrong during checkpointing, we should be able to detect that data is incomplete.
ckpt_filepath = ckpt_to_dir(filepath) / ModelCheckpoint.MODEL_WEIGHTS_PATH
self.set_checkpoint_unfinished_marker(filepath, barrier_after=True)
ema_callback = self._ema_callback(trainer)

@@ -420,17 +431,20 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str)
if self.async_save:
raise ValueError('async_save with EMA not supported')
with ema_callback.save_original_optimizer_state(trainer):
super()._save_checkpoint(trainer, filepath)
super()._save_checkpoint(trainer, ckpt_filepath)

# save EMA copy of the model as well.
with ema_callback.save_ema_model(trainer):
rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
filepath = self._ema_format_filepath(filepath)
rank_zero_info(f"Saving EMA weights to separate checkpoint {ckpt_filepath}")
ckpt_filepath = self._ema_format_filepath(ckpt_filepath)
if self.verbose:
rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
super()._save_checkpoint(trainer, filepath)
rank_zero_info(f"Saving EMA weights to separate checkpoint {ckpt_filepath}")
super()._save_checkpoint(trainer, ckpt_filepath)
self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True)
else:
## Do not include optimizer states in final checkpoint
storage_options = dict(include_optimizer=(trainer.global_step == trainer.max_steps))

# Async save passes the finalization function to checkpoint_io,
# sync save calls the finalization function immediately after save.
finalize_fn = self._get_finalize_save_checkpoint_callback(trainer, filepath, trainer.global_step)
@@ -440,18 +454,14 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str)

if not isinstance(checkpoint_io, AsyncFinalizableCheckpointIO):
raise ValueError('Async save requires async compatible CheckpointIO')
storage_options = dict(finalize_fn=finalize_fn)
storage_options["finalize_fn"] = finalize_fn
# Each upcoming ckpt removal request will be executed as part of this save finalization
self.deferred_ckpts_to_remove.append([])
else:
storage_options = None
trainer.save_checkpoint(filepath, self.save_weights_only, storage_options=storage_options)
trainer.save_checkpoint(ckpt_filepath, self.save_weights_only, storage_options=storage_options)

## NOTE: saving context happens synchronously always
from nemo.utils.get_rank import is_global_rank_zero
if self.always_save_artifacts and is_global_rank_zero():
TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath) / "artifacts")

if self.enable_nemo_ckpt_io and is_global_rank_zero():
TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath))
if self.async_save:
logging.info(f'Scheduled async checkpoint save for {filepath}')
else:
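Taken together, the callback changes put model weights under a `model_weights/` subdirectory of each checkpoint and dump reinitialization artifacts either at every save (`always_save_artifacts=True`) or once with the final checkpoint (`save_artifacts_on_train_end=True`). A minimal configuration sketch, assuming the import path follows the file location above; `always_save_artifacts`, `save_artifacts_on_train_end`, and `try_restore_best_ckpt` come from this diff, while the remaining keyword arguments are standard PyTorch Lightning `ModelCheckpoint` options assumed to be forwarded through `**kwargs`:

from nemo.lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    save_last=True,                    # standard PTL option, assumed forwarded via **kwargs
    every_n_train_steps=1000,          # standard PTL option, assumed forwarded via **kwargs
    always_save_artifacts=False,       # from this diff: set True to dump artifacts at every save
    save_artifacts_on_train_end=True,  # from this diff: dump artifacts once, in on_train_end
    try_restore_best_ckpt=True,        # from this diff: broadcast/restore best_model_path (used with save_best_model)
)

# Expected layout of a checkpoint directory produced with these settings
# ("model_weights" comes from MODEL_WEIGHTS_PATH, "artifacts" from the io_dump() call above):
#   <checkpoint_dir>/
#       model_weights/   distributed model (and, when included, optimizer) state
#       artifacts/       config dumped via TrainerContext.io_dump()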
10 changes: 9 additions & 1 deletion nemo/lightning/pytorch/strategies.py
@@ -610,11 +610,19 @@ def save_checkpoint(
# retrieve `sharded_state_dict` if it has not already been configured in `on_save_checkpoint`
if "sharded_state_dict" not in checkpoint:
checkpoint["sharded_state_dict"] = self.megatron_parallel.sharded_state_dict()
if self.trainer.state.fn == TrainerFn.FITTING and self.ckpt_include_optimizer:

## only save optimizer states if self.ckpt_include_optimizer and storage_options["include_optimizer"]
## are both True
include_optimizer = self.ckpt_include_optimizer
if "include_optimizer" in storage_options:
include_optimizer = include_optimizer and storage_options["include_optimizer"]

if self.trainer.state.fn == TrainerFn.FITTING and include_optimizer:
checkpoint["optimizer"] = [self.optimizer_sharded_state_dict()]

self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)

## TODO: only load the optimizer states when they exist in the checkpoint?
@override
def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
"""PTL method which we override to integrate distributed checkpoints for model parallel models.
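The `include_optimizer` gate above combines the strategy-level `ckpt_include_optimizer` flag with the per-call `storage_options["include_optimizer"]` entry, as the inline comment describes. A standalone sketch of that rule; the helper name is hypothetical, and a `None` guard is added for calls that omit `storage_options`:

from typing import Any, Dict, Optional


def should_save_optimizer_states(ckpt_include_optimizer: bool,
                                 storage_options: Optional[Dict[str, Any]]) -> bool:
    # Optimizer states are saved only when both the strategy flag and the
    # per-call storage_options entry allow it.
    include_optimizer = ckpt_include_optimizer
    if storage_options and "include_optimizer" in storage_options:
        include_optimizer = include_optimizer and storage_options["include_optimizer"]
    return include_optimizer


assert should_save_optimizer_states(True, {"include_optimizer": True}) is True
assert should_save_optimizer_states(True, {"include_optimizer": False}) is False
assert should_save_optimizer_states(False, {"include_optimizer": True}) is False
assert should_save_optimizer_states(True, None) is True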
6 changes: 6 additions & 0 deletions nemo/lightning/resume.py
@@ -37,6 +37,8 @@ class AutoResume(Resume, io.IOMixin):
checkpoints in NeMo.
"""

MAYBE_MODEL_WEIGHTS_PATH = "model_weights"

def __init__(
self,
path: Optional[str] = None, ## old resume_from_checkpoint
@@ -151,6 +153,10 @@ def nemo_path(self, model=None) -> Optional[Path]:
if checkpoint:
if self.adapter_path:
return AdapterPath(checkpoint, adapter_path=Path(self.adapter_path))

model_weights_path = Path(checkpoint) / AutoResume.MAYBE_MODEL_WEIGHTS_PATH
if os.path.isdir(model_weights_path):
return model_weights_path
return Path(checkpoint)

return None
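The resume change is backward compatible: when the checkpoint directory contains the new `model_weights/` subdirectory, resumption uses it; otherwise the checkpoint directory itself is returned, as before. A small sketch of that resolution step; the helper function is illustrative (NeMo performs this inside `AutoResume.nemo_path()`), and the directory name mirrors `MAYBE_MODEL_WEIGHTS_PATH`:

import os
from pathlib import Path


def resolve_resume_path(checkpoint: Path) -> Path:
    # Prefer the new layout (<checkpoint>/model_weights), fall back to the old flat layout.
    model_weights_path = checkpoint / "model_weights"
    if os.path.isdir(model_weights_path):
        return model_weights_path
    return checkpoint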