
Commit

Merge branch 'r2.0.0rc1' of github.com:NVIDIA/NeMo into ashors/nemo-ux-load-on-device
ashors1 committed Jul 10, 2024
2 parents f0f1ca9 + 5e91378 commit 67e05c1
Showing 4 changed files with 18 additions and 3 deletions.
3 changes: 2 additions & 1 deletion examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
@@ -176,7 +176,8 @@ model:
# Distributed checkpoint setup
dist_ckpt_format: 'zarr' # Set to 'torch_dist' to use PyTorch distributed checkpoint format.
dist_ckpt_load_on_device: True # whether to load checkpoint weights directly on GPU or to CPU
- dist_ckpt_parallel_save: False # if true, each worker will write its own part of the dist checkpoint
+ dist_ckpt_parallel_save: True # if true, each worker will write its own part of the dist checkpoint
+ dist_ckpt_parallel_save_within_dp: False # if true, save will be parallelized only within a DP group (whole world otherwise), which might slightly reduce the save overhead
dist_ckpt_parallel_load: False # if true, each worker will load part of the dist checkpoint and exchange with NCCL. Might use some extra GPU memory
dist_ckpt_torch_dist_multiproc: 2 # number of extra processes per rank used during ckpt save with PyTorch distributed format
dist_ckpt_assume_constant_structure: False # set to True only if the state dict structure doesn't change within a single job. Allows caching some computation across checkpoint saves.
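The parallel-save keys above are ordinary OmegaConf/Hydra config entries, so they can be read or overridden like any other model option. A minimal sketch (not part of this commit), assuming the YAML path shown above and OmegaConf, the config library used by NeMo's example scripts:

from omegaconf import OmegaConf

# Load the example config and read the distributed-checkpoint options touched by this commit.
cfg = OmegaConf.load("examples/nlp/language_modeling/conf/megatron_gpt_config.yaml")
model_cfg = cfg.model

parallel_save = model_cfg.get("dist_ckpt_parallel_save", True)  # new default: True
parallel_save_within_dp = model_cfg.get("dist_ckpt_parallel_save_within_dp", False)
print(parallel_save, parallel_save_within_dp)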
7 changes: 6 additions & 1 deletion nemo/lightning/io/pl.py
@@ -77,6 +77,7 @@ def __init__(
torch_dist_multiproc: Optional[int] = None,
assume_constant_structure: bool = False,
parallel_save: bool = True,
+ parallel_save_within_dp: bool = False,
parallel_load: bool = False,
):
self.save_ckpt_format = save_ckpt_format
@@ -85,6 +86,7 @@ def __init__(
self.torch_dist_multiproc = torch_dist_multiproc
self.assume_constant_structure = assume_constant_structure
self.parallel_save = parallel_save
+ self.parallel_save_within_dp = parallel_save_within_dp
self.parallel_load = parallel_load

self._save_sharded_strategy = None
@@ -216,8 +218,11 @@ def _determine_dist_ckpt_save_strategy(self):
save_strategy.use_cached_ckpt_structure = self.assume_constant_structure

if self.parallel_save:
+ parallelization_group = (
+     get_data_parallel_group(with_context_parallel=True) if self.parallel_save_within_dp else None
+ )
save_strategy = FullyParallelSaveStrategyWrapper(
-     save_strategy, get_data_parallel_group(with_context_parallel=True), self.assume_constant_structure
+     save_strategy, parallelization_group, self.assume_constant_structure
)

logging.info(f'Using {save_strategy} dist-ckpt save strategy.')
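The hunk above changes which process group the fully parallel save uses. A hedged sketch of the resulting selection logic, pulled out into a standalone helper for illustration (the import paths are assumptions based on Megatron-Core; only the call shapes come from the diff):

from megatron.core.parallel_state import get_data_parallel_group
from megatron.core.dist_checkpointing.strategies.fully_parallel import FullyParallelSaveStrategyWrapper

def wrap_save_strategy(save_strategy, parallel_save, parallel_save_within_dp, assume_constant_structure):
    # Without parallel save, keep the base strategy untouched.
    if not parallel_save:
        return save_strategy
    # parallel_save_within_dp=True restricts shard-writing parallelism to the
    # data-parallel (+ context-parallel) group; None means the whole world.
    parallelization_group = (
        get_data_parallel_group(with_context_parallel=True) if parallel_save_within_dp else None
    )
    return FullyParallelSaveStrategyWrapper(save_strategy, parallelization_group, assume_constant_structure)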
3 changes: 3 additions & 0 deletions nemo/lightning/pytorch/strategies.py
@@ -108,6 +108,7 @@ def __init__(
ckpt_torch_dist_multiproc=None, ## TODO(ashors): put elsewhere?
ckpt_assume_constant_structure=False,
ckpt_parallel_save=True,
+ ckpt_parallel_save_within_dp=False,
ckpt_parallel_load=False,
ckpt_parallel_save_optim=True,
ckpt_load_directly_on_device=True,
@@ -144,6 +145,7 @@ def __init__(
self.torch_dist_multiproc = ckpt_torch_dist_multiproc
self.assume_constant_structure = ckpt_assume_constant_structure
self.parallel_save = ckpt_parallel_save
+ self.parallel_save_within_dp = ckpt_parallel_save_within_dp
self.parallel_load = ckpt_parallel_load
self.parallel_save_optim = ckpt_parallel_save_optim
self.load_directly_on_device = ckpt_load_directly_on_device
@@ -580,6 +582,7 @@ def checkpoint_io(self) -> CheckpointIO:
torch_dist_multiproc=self.torch_dist_multiproc,
assume_constant_structure=self.assume_constant_structure,
parallel_save=self.parallel_save,
+ parallel_save_within_dp=self.parallel_save_within_dp,
parallel_load=self.parallel_load,
load_directly_on_device=self.load_directly_on_device,
)
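On the nemo.lightning side, the new flag is threaded through the strategy's ckpt_* keyword arguments. An illustrative construction only (the class name MegatronStrategy is inferred from the file path and is not shown in this diff; the keyword names come from the hunk above):

from nemo.lightning.pytorch.strategies import MegatronStrategy

strategy = MegatronStrategy(
    ckpt_parallel_save=True,            # each worker writes its own shard of the dist checkpoint
    ckpt_parallel_save_within_dp=True,  # parallelize the save only within the DP group
    ckpt_parallel_load=False,
)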
8 changes: 7 additions & 1 deletion nemo/utils/callbacks/dist_ckpt_io.py
@@ -206,6 +206,7 @@ def __init__(
torch_dist_multiproc: Optional[int] = None,
assume_constant_structure: bool = False,
parallel_save: bool = False,
+ parallel_save_within_dp: bool = False,
parallel_load: bool = False,
):
super().__init__()
@@ -218,6 +219,7 @@ def __init__(
self.torch_dist_multiproc = torch_dist_multiproc
self.assume_constant_structure = assume_constant_structure
self.parallel_save = parallel_save
+ self.parallel_save_within_dp = parallel_save_within_dp
self.parallel_load = parallel_load

self._save_sharded_strategy = None
@@ -239,6 +241,7 @@ def from_config(cls, model_cfg: dict, async_save: bool = False):
async_save=async_save,
torch_dist_multiproc=model_cfg.get('dist_ckpt_torch_dist_multiproc', None),
parallel_save=model_cfg.get('dist_ckpt_parallel_save', False),
+ parallel_save_within_dp=model_cfg.get('dist_ckpt_parallel_save_within_dp', False),
parallel_load=model_cfg.get('dist_ckpt_parallel_load', False),
)

@@ -377,8 +380,11 @@ def _determine_dist_ckpt_save_strategy(self):
save_strategy.use_cached_ckpt_structure = self.assume_constant_structure

if self.parallel_save:
+ parallelization_group = (
+     get_data_parallel_group(with_context_parallel=True) if self.parallel_save_within_dp else None
+ )
save_strategy = FullyParallelSaveStrategyWrapper(
-     save_strategy, get_data_parallel_group(with_context_parallel=True), self.assume_constant_structure
+     save_strategy, parallelization_group, self.assume_constant_structure
)

logging.info(f'Using {save_strategy} dist-ckpt save strategy.')
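On the config-driven path, the same key is picked up from the model config by from_config, as shown in the hunk above. A hedged usage sketch (the class name DistributedCheckpointIO is inferred from the file and not visible in this diff; the model_cfg dict below is illustrative and mirrors the YAML defaults in this commit):

from nemo.utils.callbacks.dist_ckpt_io import DistributedCheckpointIO

model_cfg = {
    "dist_ckpt_format": "zarr",
    "dist_ckpt_parallel_save": True,
    "dist_ckpt_parallel_save_within_dp": False,  # set True to limit save parallelism to the DP group
    "dist_ckpt_parallel_load": False,
}
checkpoint_io = DistributedCheckpointIO.from_config(model_cfg, async_save=False)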
