From d4ddbbe197cbe4c7b3044b602173eb1ad234827c Mon Sep 17 00:00:00 2001
From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com>
Date: Fri, 12 Jul 2024 16:56:43 +0200
Subject: [PATCH] Allows non-strict load with distributed checkpoints (#9613) (#9715)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Allow non-strict load

* Point to non-strict load MCore branch

* Avoid module level StrictHandling

* Use MCore fork

* Update to MCore fix

* Restore backward compatibility

* Update flag defaults

* Update MCore tag

* Update PyT Dist interface

* Update to latest core_r0.8.0

---------

Signed-off-by: Mikołaj Błaż
Co-authored-by: mikolajblaz
Signed-off-by: Malay Nagda
---
 Dockerfile.ci                                  |  6 +--
 .../conf/megatron_gpt_config.yaml              |  1 +
 nemo/utils/callbacks/dist_ckpt_io.py           | 50 ++++++++-----------
 3 files changed, 26 insertions(+), 31 deletions(-)

diff --git a/Dockerfile.ci b/Dockerfile.ci
index 12e0a3af7cd2..2a7006c057f1 100644
--- a/Dockerfile.ci
+++ b/Dockerfile.ci
@@ -34,7 +34,7 @@ WORKDIR /workspace
 # Install NeMo requirements
 ARG TE_TAG=7d576ed25266a17a7b651f2c12e8498f67e0baea
 ARG MODELOPT_VERSION=0.13.0
-ARG MCORE_TAG=de1b7c223303f6ba21e0540f27361334116efcbc
+ARG MCORE_TAG=c0164bcfd4f8213a10a6b1e47ef80721a68b4fb6
 ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c
 RUN \
 --mount=type=bind,source=requirements,target=requirements \
@@ -69,14 +69,14 @@ git clone https://github.com/state-spaces/mamba.git && \
 git checkout v2.0.3 && \
 python setup.py install && \
 cd .. && \
-rm -rf mamba
+rm -rf mamba

 git clone https://github.com/Dao-AILab/causal-conv1d && \
 cd causal-conv1d && \
 git checkout v1.2.2.post1 && \
 python setup.py install && \
 cd .. && \
-rm -rf causal-conv1d
+rm -rf causal-conv1d

 EOF

diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
index 1599f38cbfa8..809ca30ca5ed 100755
--- a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
+++ b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
@@ -182,6 +182,7 @@ model:
   dist_ckpt_torch_dist_multiproc: 2 # number of extra processes per rank used during ckpt save with PyTorch distributed format
   dist_ckpt_assume_constant_structure: False # set to True only if the state dict structure doesn't change within a single job. Allows caching some computation across checkpoint saves.
   dist_ckpt_parallel_dist_opt: True # parallel save/load of a DistributedOptimizer. 'True' allows performant save and reshardable checkpoints. Set to 'False' only in order to minimize the number of checkpoint files.
+  dist_ckpt_load_strictness: null # defines checkpoint keys mismatch behavior (only during dist-ckpt load). Choices: assume_ok_unexpected (default - try loading without any check), log_all (log mismatches), raise_all (raise mismatches)

 ## Activation Checkpointing
 # NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed.
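Note on the new config key: dist_ckpt_load_strictness defaults to null, which keeps the previous behavior (load without extra key checks). Below is a minimal sketch of setting the option from Python via OmegaConf, assuming the usual NeMo/Hydra config flow; the config fragment is illustrative, not taken from this patch, and the accepted values mirror the choices listed in the comment above.

from omegaconf import OmegaConf

# Illustrative config fragment; only the new key is shown.
cfg = OmegaConf.create({"model": {"dist_ckpt_load_strictness": None}})

# 'log_all' reports missing/unexpected keys but still loads the checkpoint;
# 'raise_all' turns any key mismatch into an error;
# 'assume_ok_unexpected' (the default) skips the extra checks entirely.
cfg.model.dist_ckpt_load_strictness = "log_all"
print(OmegaConf.to_yaml(cfg))
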
diff --git a/nemo/utils/callbacks/dist_ckpt_io.py b/nemo/utils/callbacks/dist_ckpt_io.py
index ad2ad1eebec0..9348779051bb 100644
--- a/nemo/utils/callbacks/dist_ckpt_io.py
+++ b/nemo/utils/callbacks/dist_ckpt_io.py
@@ -17,7 +17,7 @@
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from time import time
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union

 import pytorch_lightning as pl
 from lightning_fabric.plugins import CheckpointIO
@@ -44,6 +44,7 @@
         FullyParallelSaveStrategyWrapper,
     )
     from megatron.core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy
+    from megatron.core.dist_checkpointing.validation import StrictHandling
     from megatron.core.parallel_state import get_data_parallel_group

     HAVE_MEGATRON_CORE = True
@@ -188,6 +189,9 @@ class DistributedCheckpointIO(AsyncCompatibleCheckpointIO):
         load_directly_on_device (bool, optional): if True, loads the weights directly
             on GPU. Has effect only for `zarr` based checkpoints
             (PyT Distributed always loads on device). Defaults to True.
+        load_strictness (StrictHandling, optional): defines loading strictness.
+            If not None, overwrites the `strict` flag passed to `load_checkpoint`.
+            Defaults to None.
         async_save (bool): whether to save asynchronously. Should be set to True if
             this class will be wrapped with AsyncFinalizableCheckpointIO.
         torch_dist_multiproc (int, optional): number of extra processes per rank
@@ -202,6 +206,7 @@ def __init__(
         self,
         save_ckpt_format: str,
         load_directly_on_device: bool = True,
+        load_strictness: Optional['StrictHandling'] = None,
         async_save: bool = False,
         torch_dist_multiproc: Optional[int] = None,
         assume_constant_structure: bool = False,
@@ -215,6 +220,7 @@

         self.save_ckpt_format = save_ckpt_format
         self.load_directly_on_device = load_directly_on_device
+        self.load_strictness = load_strictness
         self.async_save = async_save
         self.torch_dist_multiproc = torch_dist_multiproc
         self.assume_constant_structure = assume_constant_structure
@@ -238,6 +244,7 @@ def from_config(cls, model_cfg: dict, async_save: bool = False):
         return cls(
             save_ckpt_format=model_cfg.get('dist_ckpt_format', 'zarr'),
             load_directly_on_device=model_cfg.get('dist_ckpt_load_on_device', True),
+            load_strictness=model_cfg.get('dist_ckpt_load_strictness', None),
             async_save=async_save,
             torch_dist_multiproc=model_cfg.get('dist_ckpt_torch_dist_multiproc', None),
             parallel_save=model_cfg.get('dist_ckpt_parallel_save', False),
@@ -275,7 +282,7 @@ def load_checkpoint(
         path: _PATH,
         map_location: Optional[Any] = None,
         sharded_state_dict: Dict[str, Any] = None,
-        strict: Optional[bool] = True,
+        strict: Union[None, bool, 'StrictHandling'] = None,
         validate_access_integrity: Optional[bool] = True,
     ) -> Dict[str, Any]:
         """Loads a distributed checkpoint.
@@ -287,6 +294,10 @@
                 defines the loading procedure for the distributed checkpoint.
                 Defaults to None to comply with the CheckpointIO interface,
                 but it's a required argument.
+            strict (bool, StrictHandling, optional): adjust load strictness. bool value
+                is translated to StrictHandling instance. Gets overwritten by
+                `self.load_strictness`. Defaults to None. If `self.load_strictness`
+                is also None, strict becomes StrictHandling.ASSUME_OK_UNEXPECTED.

         Returns:
             Dict[str, Any]: loaded checkpoint.
@@ -311,40 +322,23 @@ def load_checkpoint(
         if sharded_strategy is not None:
             logging.info(f'Using {sharded_strategy} dist-ckpt load strategy.')

-        if not strict:
-            sharded_state_dict = self.adjust_non_strict_load(path, sharded_state_dict)
+        if isinstance(strict, bool):
+            strict = StrictHandling.ASSUME_OK_UNEXPECTED if strict else StrictHandling.LOG_ALL
+        if self.load_strictness is not None:
+            # Overwrites function argument
+            strict = self.load_strictness
+        if strict is None:
+            # Default behavior
+            strict = StrictHandling.ASSUME_OK_UNEXPECTED

         return dist_checkpointing.load(
             sharded_state_dict=sharded_state_dict,
             checkpoint_dir=path,
             sharded_strategy=sharded_strategy,
             validate_access_integrity=validate_access_integrity,
+            strict=strict,
         )

-    def adjust_non_strict_load(self, path: _PATH, sharded_state_dict: Dict[str, Any]):
-        ckpt_sharded_metadata = dist_checkpointing.load_tensors_metadata(path)
-        loaded_keys = []
-        missing_keys = []
-        unexpected_keys = []
-
-        def should_remove_missing_sharded_base(x: Any):
-            if isinstance(x, ShardedBase):
-                if x.key in ckpt_sharded_metadata:
-                    loaded_keys.append(x.key)
-                    return False
-                else:
-                    unexpected_keys.append(x.key)
-                    return True
-            return False
-
-        _, sharded_state_dict = extract_matching_values(sharded_state_dict, should_remove_missing_sharded_base)
-        logging.info(f'The following keys are not in the checkpoint and will not be loaded: {unexpected_keys}')
-
-        # TODO: compute missing_keys by:
-        # 1. all_gather_object of loaded_keys
-        # 2. missing_keys = ckpt_sharded_metadata.keys() - loaded_keys
-        return sharded_state_dict
-
     @_debug_time('DistributedCheckpointIO.remove_checkpoint')
     def remove_checkpoint(self, path: _PATH) -> None:
         """Remove a distributed checkpoint.
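
For reference, the strict resolution implemented in load_checkpoint above follows a fixed precedence: a bool passed by the caller is first translated to a StrictHandling value (True -> ASSUME_OK_UNEXPECTED, False -> LOG_ALL, replacing the removed adjust_non_strict_load path), a non-None load_strictness configured on the IO plugin then overrides it, and None falls back to ASSUME_OK_UNEXPECTED. The following is a standalone sketch of that precedence only; resolve_strictness is a hypothetical helper, not part of NeMo, and it assumes megatron.core is importable.

from typing import Optional, Union

from megatron.core.dist_checkpointing.validation import StrictHandling


def resolve_strictness(
    strict: Union[None, bool, StrictHandling],
    load_strictness: Optional[StrictHandling],
) -> StrictHandling:
    # Same order as DistributedCheckpointIO.load_checkpoint:
    # bool -> StrictHandling, then the instance-level setting wins, then the default.
    if isinstance(strict, bool):
        strict = StrictHandling.ASSUME_OK_UNEXPECTED if strict else StrictHandling.LOG_ALL
    if load_strictness is not None:
        strict = load_strictness
    if strict is None:
        strict = StrictHandling.ASSUME_OK_UNEXPECTED
    return strict


# Backward compatibility: callers that passed strict=True/False keep working,
# while dist_ckpt_load_strictness from the model config can enforce a global policy.
assert resolve_strictness(True, None) is StrictHandling.ASSUME_OK_UNEXPECTED
assert resolve_strictness(False, None) is StrictHandling.LOG_ALL
assert resolve_strictness(None, StrictHandling.RAISE_ALL) is StrictHandling.RAISE_ALL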