[NeMo-UX] checkpointing improvements #10241

Merged on Sep 6, 2024 (41 commits)
Changes from 29 commits

Commits
d98db0a  save model weights and artifacts to separate directories (ashors1, Aug 22, 2024)
ca04f47  add save_artifacts_on_train_end (ashors1, Aug 24, 2024)
9264eee  Apply isort and black reformatting (ashors1, Aug 24, 2024)
99cff3a  Merge branch 'main' of github.com:NVIDIA/NeMo into ashors/nemo-ux-ckp… (ashors1, Aug 24, 2024)
8e7f3f9  Merge branch 'ashors/nemo-ux-ckpt-dir' of github.com:NVIDIA/NeMo into… (ashors1, Aug 24, 2024)
1a5a457  do not save optimizer states in final checkpoint (ashors1, Aug 28, 2024)
c9302fa  WIP support for saving only last k optimizer states (ashors1, Aug 28, 2024)
6da2371  Apply isort and black reformatting (ashors1, Aug 28, 2024)
9b1f93c  minor cleanup (ashors1, Aug 28, 2024)
402471d  Revert support for saving last k optimizer states. This will be addre… (ashors1, Aug 29, 2024)
0a55953  use storage_options to determine when to skip saving optimizer states (ashors1, Aug 30, 2024)
aa67b8a  Apply isort and black reformatting (ashors1, Aug 30, 2024)
11001a0  fix variable names, make checkpoint load work when optimizer states d… (ashors1, Aug 30, 2024)
483c942  Merge branch 'ashors/nemo-ux-ckpt-dir' of github.com:NVIDIA/NeMo into… (ashors1, Aug 30, 2024)
3b3b779  Apply isort and black reformatting (ashors1, Aug 30, 2024)
db7e0de  Merge branch 'main' of github.com:NVIDIA/NeMo into ashors/nemo-ux-ckp… (ashors1, Aug 30, 2024)
1b9696f  FSDP updates, provide option to save optimizer states on_train_end (ashors1, Aug 30, 2024)
21fcb40  Apply isort and black reformatting (ashors1, Aug 30, 2024)
b9a0e9e  simplify implementation, remove save_best_model option (ashors1, Aug 30, 2024)
a740bf9  Merge branch 'ashors/nemo-ux-ckpt-dir' of github.com:NVIDIA/NeMo into… (ashors1, Aug 30, 2024)
3c74851  update default value of ckpt_include_optimizer for fsdp (ashors1, Aug 30, 2024)
3fd76ad  remove unused imports (ashors1, Sep 3, 2024)
2817d56  remove unused import (ashors1, Sep 3, 2024)
8f17991  cleanup (ashors1, Sep 3, 2024)
a4e6954  make storage_options optional again (ashors1, Sep 4, 2024)
10e4f88  fix failing tests (ashors1, Sep 4, 2024)
a78717d  address some comments (ashors1, Sep 4, 2024)
5916ee3  use save_weights_only to determine whether to save optimizer states (ashors1, Sep 4, 2024)
5745105  Apply isort and black reformatting (ashors1, Sep 4, 2024)
b9be6e9  add some comments (ashors1, Sep 4, 2024)
adf418e  Merge branch 'ashors/nemo-ux-ckpt-dir' of github.com:NVIDIA/NeMo into… (ashors1, Sep 4, 2024)
6befd8a  fix tests (ashors1, Sep 5, 2024)
65ebadb  Apply isort and black reformatting (ashors1, Sep 5, 2024)
a95edf3  Merge branch 'main' of github.com:NVIDIA/NeMo into ashors/nemo-ux-ckp… (ashors1, Sep 5, 2024)
9643ef5  Merge branch 'ashors/nemo-ux-ckpt-dir' of github.com:NVIDIA/NeMo into… (ashors1, Sep 5, 2024)
09a38a5  Apply isort and black reformatting (ashors1, Sep 5, 2024)
987e4f6  fixes (ashors1, Sep 5, 2024)
6c5a1c6  Merge branch 'ashors/nemo-ux-ckpt-dir' of github.com:NVIDIA/NeMo into… (ashors1, Sep 5, 2024)
54591a6  Apply isort and black reformatting (ashors1, Sep 5, 2024)
1974e97  remove unnecessary line (ashors1, Sep 5, 2024)
10d619e  Merge branch 'main' of github.com:NVIDIA/NeMo into ashors/nemo-ux-ckp… (ashors1, Sep 5, 2024)
1 change: 0 additions & 1 deletion examples/llm/megatron_gpt_pretraining.py
@@ -71,7 +71,6 @@ def get_args():
     strategy = nl.MegatronStrategy()
     checkpoint_callback = ModelCheckpoint(
         every_n_train_steps=5000,
-        enable_nemo_ckpt_io=False,
     )
     callbacks = [checkpoint_callback]
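
Note that enable_nemo_ckpt_io no longer exists on ModelCheckpoint; context dumping is instead controlled by the always_save_context and save_context_on_train_end options introduced in model_checkpoint.py below. A minimal sketch of an equivalent setup under the new options (values shown are the new defaults; the import path follows the file layout in this PR):

from nemo.lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    every_n_train_steps=5000,
    always_save_context=False,       # do not dump the trainer context with every checkpoint
    save_context_on_train_end=True,  # dump it once, when training ends
)
callbacks = [checkpoint_callback]
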
4 changes: 2 additions & 2 deletions nemo/lightning/io/mixin.py
@@ -140,7 +140,7 @@ def io_dump(self, output: Path):
             will be stored.
         """
         output_path = Path(output)
-        local_artifacts_dir = "artifacts"
+        local_artifacts_dir = "."
         artifacts_dir = output_path / local_artifacts_dir
         artifacts_dir.mkdir(parents=True, exist_ok=True)

@@ -518,7 +518,7 @@ def _io_path_elements_fn(x):
     return x.__io__.__path_elements__()


-def _artifact_transform_save(cfg: fdl.Config, output_path: Path, relative_dir: Path = "artifacts"):
+def _artifact_transform_save(cfg: fdl.Config, output_path: Path, relative_dir: Path = "."):
     for artifact in getattr(cfg.__fn_or_cls__, "__io_artifacts__", []):
         current_val = getattr(cfg, artifact.attr)
         if current_val is None:
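
With local_artifacts_dir set to ".", io_dump() now materializes artifacts directly inside its output directory instead of nesting them under an artifacts/ subdirectory. A standalone illustration of why "." collapses the extra level (hypothetical paths, not NeMo code):

from pathlib import Path

output_path = Path("checkpoint/context")  # hypothetical io_dump() target
artifacts_dir = output_path / "."         # pathlib collapses the single dot
assert artifacts_dir == output_path       # artifacts now sit next to the dumped config
# previously: artifacts_dir == Path("checkpoint/context/artifacts")
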
76 changes: 34 additions & 42 deletions nemo/lightning/pytorch/callbacks/model_checkpoint.py
@@ -46,15 +46,18 @@ class ModelCheckpoint(PTLModelCheckpoint):
every_n_train_steps: Number of train steps between checkpoints.
train_time_interval: After each interval, monitor checkpoints. Not to be used with
``every_n_epochs`` or ``every_n_train_steps``.
save_best_model: When ``True``, reloads and saves the best checkpoint.
save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch
enable_nemo_ckpt_io: Whether to dump the current model model state, including the
config file, to allow for reproducibility of experiments.
save_optim_on_train_end: Whether to include the optimizer states in the final checkpoint
at the end of training. Only applicable when save_weights_only is ``True``.
always_save_context: Whether to dump the artifacts needed to reinitialize the current
model, trainer, and dataloader to allow for reproducibility of experiments.
save_context_on_train_end: Whether to dump the artifacts on_train_end regardless of whether
``always_save_context`` is ``True``.
async_save: Whether to enable asynchronous checkpointing.
try_restore_best_ckpt: Whether to restore the best model path.
"""

UNFINISHED_CHECKPOINT_SUFFIX = "-unfinished"
WEIGHTS_PATH = "weights"

def __init__(
self,
@@ -67,21 +70,21 @@ def __init__(
every_n_epochs: int = None,
every_n_train_steps: Optional[int] = None,
train_time_interval: Optional[timedelta] = None,
save_best_model: bool = False,
save_on_train_epoch_end: Optional[bool] = False, # Save after training, not after validation
enable_nemo_ckpt_io: bool = True,
try_restore_best_ckpt: bool = True,
save_optim_on_train_end: Optional[bool] = False,
always_save_context: bool = False,
save_context_on_train_end: bool = True,
**kwargs,
):
self.save_best_model = save_best_model
self.previous_best_path = ""
self.enable_nemo_ckpt_io = enable_nemo_ckpt_io
self.always_save_context = always_save_context
self.save_context_on_train_end = save_context_on_train_end
self.save_optim_on_train_end = save_optim_on_train_end

# Checkpoints which removal is deferred until async save is done.
# Each element of `deferred_ckpts_to_remove` is a growing list
# that `self._remove_checkpoint` adds to. Once `self._save_checkpoint`
# is called, the last element is frozen and a new element is added.
self.deferred_ckpts_to_remove: List[List[str]] = []
self.try_restore_best_ckpt = try_restore_best_ckpt

# Call the parent class constructor with the remaining kwargs.
super().__init__(
@@ -251,11 +254,9 @@ def setup(self, trainer, *args, **kwargs) -> None:
self.async_save = getattr(trainer.strategy, "async_save", False)
super().setup(trainer, *args, **kwargs)

def on_save_checkpoint(self, trainer, pl_module, checkpoint):
output = super().on_save_checkpoint(trainer, pl_module, checkpoint)
return output

def on_train_end(self, trainer, pl_module):
from nemo.utils.get_rank import is_global_rank_zero

if trainer.fast_dev_run:
return None

@@ -272,26 +273,11 @@ def on_train_end(self, trainer, pl_module):
logging.debug(f'Last checkpoint {self.last_model_path} already saved')
else:
super()._save_last_checkpoint(trainer, monitor_candidates)
if self.save_context_on_train_end and not self.always_save_context and is_global_rank_zero():
TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(self.last_model_path) / "context")
# Call parent on_train_end() to save the -last checkpoint
super().on_train_end(trainer, pl_module)

# Load the best model and then re-save it
if self.save_best_model:
# wait for all processes
trainer.strategy.barrier("SaveBestCheckpointConnector.resume_end")
if self.best_model_path == "":
logging.warning(
f"{self} was told to save the best checkpoint at the end of training, but no saved checkpoints "
"were found. Saving latest model instead."
)

else:
if os.path.isdir(self.best_model_path.split('.ckpt')[0]):
self.best_model_path = self.best_model_path.split('.ckpt')[0]
if self.try_restore_best_ckpt:
self.best_model_path = trainer.strategy.broadcast(self.best_model_path)
trainer._checkpoint_connector.restore(self.best_model_path)

def _del_model_without_trainer(self, filepath: str) -> None:
from nemo.utils.get_rank import is_global_rank_zero

@@ -409,8 +395,11 @@ def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, torch.Tensor]:
return monitor_candidates

def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) -> None:
from nemo.utils.get_rank import is_global_rank_zero

# barrier_after=True, so all ranks continue after the unfinished checkpoint marker is placed.
# if anything goes wrong during checkpointing, we should be able to detect that data is incomplete.
ckpt_filepath = ckpt_to_dir(filepath) / ModelCheckpoint.WEIGHTS_PATH
self.set_checkpoint_unfinished_marker(filepath, barrier_after=True)
ema_callback = self._ema_callback(trainer)

@@ -420,17 +409,22 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str)
if self.async_save:
raise ValueError('async_save with EMA not supported')
with ema_callback.save_original_optimizer_state(trainer):
super()._save_checkpoint(trainer, filepath)
super()._save_checkpoint(trainer, ckpt_filepath)

# save EMA copy of the model as well.
with ema_callback.save_ema_model(trainer):
rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
filepath = self._ema_format_filepath(filepath)
rank_zero_info(f"Saving EMA weights to separate checkpoint {ckpt_filepath}")
ckpt_filepath = self._ema_format_filepath(ckpt_filepath)
if self.verbose:
rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
super()._save_checkpoint(trainer, filepath)
rank_zero_info(f"Saving EMA weights to separate checkpoint {ckpt_filepath}")
super()._save_checkpoint(trainer, ckpt_filepath)
self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True)
else:
## Whether to include optimizer states
save_weights_only = self.save_weights_only or (
not self.save_optim_on_train_end and trainer.global_step == trainer.max_steps
)

# Async save passes the finalization function to checkpoint_io,
# sync save calls the finalization function immediately after save.
finalize_fn = self._get_finalize_save_checkpoint_callback(trainer, filepath, trainer.global_step)
@@ -445,13 +439,11 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str)
self.deferred_ckpts_to_remove.append([])
else:
storage_options = None
trainer.save_checkpoint(filepath, self.save_weights_only, storage_options=storage_options)
trainer.save_checkpoint(ckpt_filepath, save_weights_only, storage_options=storage_options)

## NOTE: saving context happens synchronously always
from nemo.utils.get_rank import is_global_rank_zero
if self.always_save_context and is_global_rank_zero():
TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath) / "context")

if self.enable_nemo_ckpt_io and is_global_rank_zero():
TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath))
if self.async_save:
logging.info(f'Scheduled async checkpoint save for {filepath}')
else:
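
Taken together, the callback now writes model weights under a weights/ subdirectory of each checkpoint (ModelCheckpoint.WEIGHTS_PATH), dumps the serialized trainer context into a context/ subdirectory on global rank 0, and drops optimizer states from the checkpoint written at trainer.max_steps unless save_optim_on_train_end=True. A rough sketch of the resulting per-checkpoint layout (directory names come from the code above; surrounding path components are hypothetical):

# <log_dir>/checkpoints/<checkpoint-name>/
#     weights/   written by trainer.save_checkpoint(...); includes optimizer state
#                during training, but the final (max_steps) checkpoint is saved
#                weights-only when save_optim_on_train_end=False
#     context/   written by TrainerContext.from_trainer(trainer).io_dump(...) on
#                global rank 0, for every checkpoint when always_save_context=True,
#                or once at the end of training when save_context_on_train_end=True
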
14 changes: 7 additions & 7 deletions nemo/lightning/pytorch/strategies/fsdp_strategy.py
@@ -56,7 +56,7 @@ def __init__(
         self,
         auto_wrap_policy={TransformerLayer},
         state_dict_type="sharded",
-        ckpt_include_optimizer=False,
+        ckpt_include_optimizer=True,
         data_sampler=None,
         **kwargs,
     ):
@@ -189,11 +189,9 @@ def save_checkpoint(
         checkpoint["sharded_state_dict"] = pyt_to_mcore_state_dict(checkpoint.pop("state_dict"))
         checkpoint["state_dict"] = OrderedDict([])

-        # TODO: do we still need to keep this?
-        for optim_state in checkpoint['optimizer_states']:
-            optim_state.pop("state")
-
-        if self.trainer.state.fn == TrainerFn.FITTING and self.ckpt_include_optimizer:
+        ## replace unsharded optimizer_states with sharded dict
+        if "optimizer_states" in checkpoint:
+            del checkpoint["optimizer_states"]
             checkpoint['optimizer'] = get_optimizer_state_dict(self.model, self.optimizers)
             pyt_to_mcore_state_dict(checkpoint['optimizer']['state'], prefix="optimizer.state.")

@@ -224,7 +222,9 @@ def load_checkpoint(self, checkpoint_path: str | Path) -> Dict[str, Any]:
         pyt_to_mcore_state_dict(msd)
         sharded_state_dict["sharded_state_dict"] = msd

-        if self.ckpt_include_optimizer and self.trainer.state.fn == TrainerFn.FITTING:
+        if (
+            self.ckpt_include_optimizer and self.trainer.state.fn == TrainerFn.FITTING
+        ): ## TODO: remove ckpt_include_optimizer
             osd = get_optimizer_state_dict(self.model, self.optimizers, options=StateDictOptions(cpu_offload=True))
             pyt_to_mcore_state_dict(osd['state'], prefix="optimizer.state.")
             sharded_state_dict["optimizer"] = osd
7 changes: 5 additions & 2 deletions nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -602,7 +602,10 @@ def save_checkpoint(
         # retrieve `sharded_state_dict` if it has not already been configured in `on_save_checkpoint`
         if "sharded_state_dict" not in checkpoint:
             checkpoint["sharded_state_dict"] = self.megatron_parallel.sharded_state_dict()
-        if self.trainer.state.fn == TrainerFn.FITTING and self.ckpt_include_optimizer:
+
+        ## replace unsharded optimizer_states with sharded dict
+        if "optimizer_states" in checkpoint:
+            del checkpoint["optimizer_states"]
             checkpoint["optimizer"] = [self.optimizer_sharded_state_dict()]

         self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
@@ -630,7 +633,7 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:

     @override
     def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
-        if not self.ckpt_include_optimizer:
+        if not self.ckpt_include_optimizer: ## TODO: remove ckpt_include_optimizer
             return

         optimizer_states = checkpoint["optimizer"]
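
In both the FSDP and Megatron strategies, optimizer state is now sharded and stored only when Lightning actually placed optimizer_states in the checkpoint dict, which is how the callback's save_weights_only decision propagates down; ckpt_include_optimizer remains only on the load path and is marked for removal. A standalone sketch of that contract (simplified names, not NeMo code):

def should_shard_optimizer(checkpoint: dict) -> bool:
    # Lightning adds "optimizer_states" only when a checkpoint is saved with
    # save_weights_only=False, so weights-only saves skip the optimizer path.
    return "optimizer_states" in checkpoint

assert should_shard_optimizer({"state_dict": {}, "optimizer_states": [{}]})
assert not should_shard_optimizer({"state_dict": {}})  # e.g. a final weights-only save
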
5 changes: 4 additions & 1 deletion nemo/lightning/resume.py
@@ -37,6 +37,8 @@ class AutoResume(Resume, io.IOMixin):
     checkpoints in NeMo.
     """

+    WEIGHTS_PATH = "weights"
+
     def __init__(
         self,
         path: Optional[str] = None, ## old resume_from_checkpoint
@@ -169,7 +171,8 @@ def nemo_path(self, model=None) -> Optional[Path]:
         if checkpoint:
             if self.adapter_path:
                 return AdapterPath(checkpoint, adapter_path=Path(self.adapter_path))
-            return Path(checkpoint)
+
+            return Path(checkpoint) / AutoResume.WEIGHTS_PATH

         return None
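
Because weights now live in a subdirectory of each checkpoint, AutoResume resolves the restore path one level deeper. A small sketch of the new resolution (hypothetical paths, adapter handling omitted):

from pathlib import Path

WEIGHTS_PATH = "weights"  # mirrors AutoResume.WEIGHTS_PATH / ModelCheckpoint.WEIGHTS_PATH

def resolved_restore_path(checkpoint_dir: str) -> Path:
    return Path(checkpoint_dir) / WEIGHTS_PATH

print(resolved_restore_path("results/checkpoints/step=5000-last"))
# results/checkpoints/step=5000-last/weights
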
3 changes: 1 addition & 2 deletions tests/collections/llm/test_mnist_model_nemo2.py
@@ -496,13 +496,12 @@ def run_train_mnist_litautoencoder_with_megatron_strategy_single_gpu():
         # Configure our custom Checkpointer
         name = "test_experiment"
         checkpoint_callback = nl_callbacks.ModelCheckpoint(
-            save_best_model=True,
             save_last=True,
             monitor="val_loss",
             save_top_k=1,
             every_n_train_steps=5,
             # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
-            enable_nemo_ckpt_io=True,
+            always_save_context=True,
         )
         root_dir = tmpdir
         save_dir = root_dir / name
3 changes: 1 addition & 2 deletions tests/collections/llm/test_mnist_model_nemo2_fsdp.py
@@ -519,13 +519,12 @@ def run_train_mnist_litautoencoder_with_fsdp_strategy_single_gpu():
         # Configure our custom Checkpointer
         name = "test_experiment"
         checkpoint_callback = nl_callbacks.ModelCheckpoint(
-            save_best_model=True,
             save_last=True,
             monitor="val_loss",
             save_top_k=1,
             every_n_train_steps=5,
             # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
-            enable_nemo_ckpt_io=True,
+            always_save_context=True,
         )
         root_dir = tmpdir
         save_dir = root_dir / name