From c4ae27a1853d4c25fa5764b5e2fe401c01a760fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20B=C5=82a=C5=BC?= Date: Wed, 25 Sep 2024 20:06:31 +0200 Subject: [PATCH 1/3] Add unit test to test strict load --- tests/core/test_dist_ckpt.py | 71 ++++++++++++++++++++++++++++++++++-- 1 file changed, 68 insertions(+), 3 deletions(-) diff --git a/tests/core/test_dist_ckpt.py b/tests/core/test_dist_ckpt.py index 8fe21a316854..64a22307f9ac 100644 --- a/tests/core/test_dist_ckpt.py +++ b/tests/core/test_dist_ckpt.py @@ -8,6 +8,8 @@ import torch from lightning_fabric.plugins import TorchCheckpointIO from pytorch_lightning.demos.boring_classes import BoringModel +from pytorch_lightning.trainer import call +from torch import Tensor from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy from nemo.utils.callbacks.dist_ckpt_io import ( @@ -33,7 +35,8 @@ def on_validation_epoch_end(self) -> None: class ExampleMCoreModel(ExampleModel): def sharded_state_dict(self): return { - 'a': ShardedTensor.from_rank_offsets('a', self.layer.weight, replica_id=torch.distributed.get_rank()), + 'layer.weight': ShardedTensor.from_rank_offsets('a', self.layer.weight, replica_id=torch.distributed.get_rank()), + 'layer.bias': ShardedTensor.from_rank_offsets('a.bias', self.layer.bias, replica_id=torch.distributed.get_rank()), 'const': 3, } @@ -68,7 +71,7 @@ def _get_nlp_strategy_without_optimizer_state(): strategy = NLPDDPStrategy() # this ensures optimizer sharded state creation is skipped strategy.optimizer_sharded_state_dict = types.MethodType( - lambda self, unsharded_optim_state: unsharded_optim_state, strategy + lambda self, unsharded_optim_state={}, is_loading=False: unsharded_optim_state, strategy ) return strategy @@ -176,4 +179,66 @@ def test_async_save_produces_same_checkpoints_as_sync(self, tmp_path): ) assert sync_state_dict['sharded_state_dict']['const'] == async_state_dict['sharded_state_dict']['const'] - assert torch.all(sync_state_dict['sharded_state_dict']['a'] == async_state_dict['sharded_state_dict']['a']) + assert torch.all(sync_state_dict['sharded_state_dict']['layer.weight'] == async_state_dict['sharded_state_dict']['layer.weight']) + + +class TestLoadStrictness: + class ExampleMCoreModelExtraHead(ExampleMCoreModel): + def __init__(self): + super().__init__() + self.extra_head = torch.nn.Linear(2, 4) + + def forward(self, x: Tensor) -> Tensor: + x = super().forward(x) + return self.extra_head(x) + + def sharded_state_dict(self): + sharded_sd = super().sharded_state_dict() + sharded_sd['extra_head.weight'] = ShardedTensor.from_rank_offsets('extra_head.weight', self.extra_head.weight, replica_id=torch.distributed.get_rank()) + sharded_sd['extra_head.bias'] = ShardedTensor.from_rank_offsets('extra_head.bias', self.extra_head.bias, replica_id=torch.distributed.get_rank()) + return sharded_sd + + def on_load_checkpoint(self, checkpoint): + self.load_state_dict(checkpoint['state_dict'], strict=False) + + @pytest.mark.run_only_on('GPU') + def test_load_strictness(self, tmp_path): + strategy = NLPDDPStrategy() + sync_checkpoint_io = DistributedCheckpointIO('torch_dist', load_strictness='log_all') + + model = ExampleMCoreModel() + + # dummy_trainer just to initialize NCCL + dummy_trainer = pl.Trainer( + enable_checkpointing=False, + logger=False, + max_epochs=1, + strategy=NLPDDPStrategy(), + plugins=[sync_checkpoint_io], + ) + dummy_trainer.fit(model) + tmp_path = strategy.broadcast(tmp_path) + + sync_ckpt_dir = tmp_path / 'sync_checkpoints' + + test_trainer = pl.Trainer( + enable_checkpointing=True, + logger=False, + max_epochs=1, + strategy=NLPDDPStrategy(), + plugins=[sync_checkpoint_io], + default_root_dir=sync_ckpt_dir, + ) + test_trainer.fit(model) + + # Simulate finetuning with an extra head + extra_head_model = TestLoadStrictness.ExampleMCoreModelExtraHead() + finetuning_trainer = pl.Trainer( + enable_checkpointing=True, + logger=False, + max_epochs=2, + strategy=NLPDDPStrategy(), + plugins=[sync_checkpoint_io], + default_root_dir=sync_ckpt_dir, + ) + finetuning_trainer.fit(extra_head_model, ckpt_path=_get_last_checkpoint_dir(sync_ckpt_dir, model)) From 1babc4bcb7416ec062de618288c7b7892e4b2ae6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20B=C5=82a=C5=BC?= Date: Fri, 27 Sep 2024 20:00:45 +0200 Subject: [PATCH 2/3] Add check in on_load_checkpoint for GPT --- .../nlp/models/language_modeling/megatron_gpt_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 2e842b5c6f7b..5ccd178f27cf 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -1919,7 +1919,8 @@ def on_load_checkpoint(self, checkpoint) -> None: key.replace('model.', ''): checkpoint_state_dict.pop(key) for key in list(checkpoint_state_dict.keys()) } - module.load_state_dict(checkpoint_state_dict, strict=True) + dist_ckpt_strict = self.cfg.get('dist_ckpt_load_strictness', None) + module.load_state_dict(checkpoint_state_dict, strict=dist_ckpt_strict is None or dist_ckpt_strict != 'log_all') else: # when restoring a distributed checkpoint from a ptl checkpoint we need to defer loading the state_dict # see NLPModel.on_load_checkpoint From 1fcabbc4c3aad1a60d5d2892147db03ee22544ea Mon Sep 17 00:00:00 2001 From: mikolajblaz Date: Fri, 27 Sep 2024 18:04:32 +0000 Subject: [PATCH 3/3] Apply isort and black reformatting Signed-off-by: mikolajblaz --- .../language_modeling/megatron_gpt_model.py | 4 +++- tests/core/test_dist_ckpt.py | 21 ++++++++++++++----- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 5ccd178f27cf..95ea6cbd68b9 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -1920,7 +1920,9 @@ def on_load_checkpoint(self, checkpoint) -> None: for key in list(checkpoint_state_dict.keys()) } dist_ckpt_strict = self.cfg.get('dist_ckpt_load_strictness', None) - module.load_state_dict(checkpoint_state_dict, strict=dist_ckpt_strict is None or dist_ckpt_strict != 'log_all') + module.load_state_dict( + checkpoint_state_dict, strict=dist_ckpt_strict is None or dist_ckpt_strict != 'log_all' + ) else: # when restoring a distributed checkpoint from a ptl checkpoint we need to defer loading the state_dict # see NLPModel.on_load_checkpoint diff --git a/tests/core/test_dist_ckpt.py b/tests/core/test_dist_ckpt.py index 64a22307f9ac..981ecf9ccf1d 100644 --- a/tests/core/test_dist_ckpt.py +++ b/tests/core/test_dist_ckpt.py @@ -35,8 +35,12 @@ def on_validation_epoch_end(self) -> None: class ExampleMCoreModel(ExampleModel): def sharded_state_dict(self): return { - 'layer.weight': ShardedTensor.from_rank_offsets('a', self.layer.weight, replica_id=torch.distributed.get_rank()), - 'layer.bias': ShardedTensor.from_rank_offsets('a.bias', self.layer.bias, replica_id=torch.distributed.get_rank()), + 'layer.weight': ShardedTensor.from_rank_offsets( + 'a', self.layer.weight, replica_id=torch.distributed.get_rank() + ), + 'layer.bias': ShardedTensor.from_rank_offsets( + 'a.bias', self.layer.bias, replica_id=torch.distributed.get_rank() + ), 'const': 3, } @@ -179,7 +183,10 @@ def test_async_save_produces_same_checkpoints_as_sync(self, tmp_path): ) assert sync_state_dict['sharded_state_dict']['const'] == async_state_dict['sharded_state_dict']['const'] - assert torch.all(sync_state_dict['sharded_state_dict']['layer.weight'] == async_state_dict['sharded_state_dict']['layer.weight']) + assert torch.all( + sync_state_dict['sharded_state_dict']['layer.weight'] + == async_state_dict['sharded_state_dict']['layer.weight'] + ) class TestLoadStrictness: @@ -194,8 +201,12 @@ def forward(self, x: Tensor) -> Tensor: def sharded_state_dict(self): sharded_sd = super().sharded_state_dict() - sharded_sd['extra_head.weight'] = ShardedTensor.from_rank_offsets('extra_head.weight', self.extra_head.weight, replica_id=torch.distributed.get_rank()) - sharded_sd['extra_head.bias'] = ShardedTensor.from_rank_offsets('extra_head.bias', self.extra_head.bias, replica_id=torch.distributed.get_rank()) + sharded_sd['extra_head.weight'] = ShardedTensor.from_rank_offsets( + 'extra_head.weight', self.extra_head.weight, replica_id=torch.distributed.get_rank() + ) + sharded_sd['extra_head.bias'] = ShardedTensor.from_rank_offsets( + 'extra_head.bias', self.extra_head.bias, replica_id=torch.distributed.get_rank() + ) return sharded_sd def on_load_checkpoint(self, checkpoint):