diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 8f541e5703e6..162958737312 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -1932,7 +1932,10 @@ def on_load_checkpoint(self, checkpoint) -> None: key.replace('model.', ''): checkpoint_state_dict.pop(key) for key in list(checkpoint_state_dict.keys()) } - module.load_state_dict(checkpoint_state_dict, strict=True) + dist_ckpt_strict = self.cfg.get('dist_ckpt_load_strictness', None) + module.load_state_dict( + checkpoint_state_dict, strict=dist_ckpt_strict is None or dist_ckpt_strict != 'log_all' + ) else: # when restoring a distributed checkpoint from a ptl checkpoint we need to defer loading the state_dict # see NLPModel.on_load_checkpoint diff --git a/tests/core/test_dist_ckpt.py b/tests/core/test_dist_ckpt.py index 8fe21a316854..4f8a64e22c21 100644 --- a/tests/core/test_dist_ckpt.py +++ b/tests/core/test_dist_ckpt.py @@ -8,6 +8,7 @@ import torch from lightning_fabric.plugins import TorchCheckpointIO from pytorch_lightning.demos.boring_classes import BoringModel +from torch import Tensor from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy from nemo.utils.callbacks.dist_ckpt_io import ( @@ -33,7 +34,12 @@ def on_validation_epoch_end(self) -> None: class ExampleMCoreModel(ExampleModel): def sharded_state_dict(self): return { - 'a': ShardedTensor.from_rank_offsets('a', self.layer.weight, replica_id=torch.distributed.get_rank()), + 'layer.weight': ShardedTensor.from_rank_offsets( + 'a', self.layer.weight, replica_id=torch.distributed.get_rank() + ), + 'layer.bias': ShardedTensor.from_rank_offsets( + 'a.bias', self.layer.bias, replica_id=torch.distributed.get_rank() + ), 'const': 3, } @@ -68,7 +74,7 @@ def _get_nlp_strategy_without_optimizer_state(): strategy = NLPDDPStrategy() # this ensures optimizer sharded state creation is skipped strategy.optimizer_sharded_state_dict = types.MethodType( - lambda self, unsharded_optim_state: unsharded_optim_state, strategy + lambda self, unsharded_optim_state={}, is_loading=False: unsharded_optim_state, strategy ) return strategy @@ -176,4 +182,73 @@ def test_async_save_produces_same_checkpoints_as_sync(self, tmp_path): ) assert sync_state_dict['sharded_state_dict']['const'] == async_state_dict['sharded_state_dict']['const'] - assert torch.all(sync_state_dict['sharded_state_dict']['a'] == async_state_dict['sharded_state_dict']['a']) + assert torch.all( + sync_state_dict['sharded_state_dict']['layer.weight'] + == async_state_dict['sharded_state_dict']['layer.weight'] + ) + + +class TestLoadStrictness: + class ExampleMCoreModelExtraHead(ExampleMCoreModel): + def __init__(self): + super().__init__() + self.extra_head = torch.nn.Linear(2, 4) + + def forward(self, x: Tensor) -> Tensor: + x = super().forward(x) + return self.extra_head(x) + + def sharded_state_dict(self): + sharded_sd = super().sharded_state_dict() + sharded_sd['extra_head.weight'] = ShardedTensor.from_rank_offsets( + 'extra_head.weight', self.extra_head.weight, replica_id=torch.distributed.get_rank() + ) + sharded_sd['extra_head.bias'] = ShardedTensor.from_rank_offsets( + 'extra_head.bias', self.extra_head.bias, replica_id=torch.distributed.get_rank() + ) + return sharded_sd + + def on_load_checkpoint(self, checkpoint): + self.load_state_dict(checkpoint['state_dict'], strict=False) + + @pytest.mark.run_only_on('GPU') + def test_load_strictness(self, tmp_path): + strategy = NLPDDPStrategy() + sync_checkpoint_io = DistributedCheckpointIO('torch_dist', load_strictness='log_all') + + model = ExampleMCoreModel() + + # dummy_trainer just to initialize NCCL + dummy_trainer = pl.Trainer( + enable_checkpointing=False, + logger=False, + max_epochs=1, + strategy=NLPDDPStrategy(), + plugins=[sync_checkpoint_io], + ) + dummy_trainer.fit(model) + tmp_path = strategy.broadcast(tmp_path) + + sync_ckpt_dir = tmp_path / 'sync_checkpoints' + + test_trainer = pl.Trainer( + enable_checkpointing=True, + logger=False, + max_epochs=1, + strategy=NLPDDPStrategy(), + plugins=[sync_checkpoint_io], + default_root_dir=sync_ckpt_dir, + ) + test_trainer.fit(model) + + # Simulate finetuning with an extra head + extra_head_model = TestLoadStrictness.ExampleMCoreModelExtraHead() + finetuning_trainer = pl.Trainer( + enable_checkpointing=True, + logger=False, + max_epochs=2, + strategy=NLPDDPStrategy(), + plugins=[sync_checkpoint_io], + default_root_dir=sync_ckpt_dir, + ) + finetuning_trainer.fit(extra_head_model, ckpt_path=_get_last_checkpoint_dir(sync_ckpt_dir, model))