Respect load strictness when calling load_state_dict #10665

Draft · wants to merge 3 commits into base: main
@@ -1919,7 +1919,10 @@ def on_load_checkpoint(self, checkpoint) -> None:
key.replace('model.', ''): checkpoint_state_dict.pop(key)
for key in list(checkpoint_state_dict.keys())
}
- module.load_state_dict(checkpoint_state_dict, strict=True)
+ dist_ckpt_strict = self.cfg.get('dist_ckpt_load_strictness', None)
+ module.load_state_dict(
+     checkpoint_state_dict, strict=dist_ckpt_strict is None or dist_ckpt_strict != 'log_all'
+ )
else:
# when restoring a distributed checkpoint from a ptl checkpoint we need to defer loading the state_dict
# see NLPModel.on_load_checkpoint
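For reference, here is a minimal sketch (not part of the diff) of the strictness rule the change above encodes. The helper name `resolve_strict_flag` and the plain-dict `cfg` are hypothetical stand-ins for the model's config; note that `x is None or x != 'log_all'` reduces to `x != 'log_all'`, so loading stays strict unless the config explicitly sets `dist_ckpt_load_strictness: log_all`.

```python
# Sketch only: mirrors the expression introduced above, with hypothetical names.
def resolve_strict_flag(cfg: dict) -> bool:
    dist_ckpt_strict = cfg.get('dist_ckpt_load_strictness', None)
    # None also compares unequal to 'log_all', so this is equivalent to
    # `dist_ckpt_strict != 'log_all'`.
    return dist_ckpt_strict is None or dist_ckpt_strict != 'log_all'


assert resolve_strict_flag({}) is True                                         # default: strict load
assert resolve_strict_flag({'dist_ckpt_load_strictness': 'log_all'}) is False  # lenient load
```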
tests/core/test_dist_ckpt.py (82 changes: 79 additions & 3 deletions)
@@ -8,6 +8,8 @@
import torch
from lightning_fabric.plugins import TorchCheckpointIO
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.trainer import call

[Check notice, Code scanning / CodeQL, on the import above: Unused import. "Import of 'call' is not used."]
from torch import Tensor

from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
from nemo.utils.callbacks.dist_ckpt_io import (
@@ -33,7 +35,12 @@
class ExampleMCoreModel(ExampleModel):
def sharded_state_dict(self):
return {
- 'a': ShardedTensor.from_rank_offsets('a', self.layer.weight, replica_id=torch.distributed.get_rank()),
+ 'layer.weight': ShardedTensor.from_rank_offsets(
+     'a', self.layer.weight, replica_id=torch.distributed.get_rank()
+ ),
+ 'layer.bias': ShardedTensor.from_rank_offsets(
+     'a.bias', self.layer.bias, replica_id=torch.distributed.get_rank()
+ ),
'const': 3,
}

@@ -68,7 +75,7 @@
strategy = NLPDDPStrategy()
# this ensures optimizer sharded state creation is skipped
strategy.optimizer_sharded_state_dict = types.MethodType(
- lambda self, unsharded_optim_state: unsharded_optim_state, strategy
+ lambda self, unsharded_optim_state={}, is_loading=False: unsharded_optim_state, strategy
)
return strategy

@@ -176,4 +183,73 @@
)

assert sync_state_dict['sharded_state_dict']['const'] == async_state_dict['sharded_state_dict']['const']
- assert torch.all(sync_state_dict['sharded_state_dict']['a'] == async_state_dict['sharded_state_dict']['a'])
+ assert torch.all(
+     sync_state_dict['sharded_state_dict']['layer.weight']
+     == async_state_dict['sharded_state_dict']['layer.weight']
+ )


class TestLoadStrictness:
    class ExampleMCoreModelExtraHead(ExampleMCoreModel):
        def __init__(self):
            super().__init__()
            self.extra_head = torch.nn.Linear(2, 4)

        def forward(self, x: Tensor) -> Tensor:
            x = super().forward(x)
            return self.extra_head(x)

        def sharded_state_dict(self):
            sharded_sd = super().sharded_state_dict()
            sharded_sd['extra_head.weight'] = ShardedTensor.from_rank_offsets(
                'extra_head.weight', self.extra_head.weight, replica_id=torch.distributed.get_rank()
            )
            sharded_sd['extra_head.bias'] = ShardedTensor.from_rank_offsets(
                'extra_head.bias', self.extra_head.bias, replica_id=torch.distributed.get_rank()
            )
            return sharded_sd

        def on_load_checkpoint(self, checkpoint):
            self.load_state_dict(checkpoint['state_dict'], strict=False)

    @pytest.mark.run_only_on('GPU')
    def test_load_strictness(self, tmp_path):
        strategy = NLPDDPStrategy()
        sync_checkpoint_io = DistributedCheckpointIO('torch_dist', load_strictness='log_all')

        model = ExampleMCoreModel()

        # dummy_trainer just to initialize NCCL
        dummy_trainer = pl.Trainer(
            enable_checkpointing=False,
            logger=False,
            max_epochs=1,
            strategy=NLPDDPStrategy(),
            plugins=[sync_checkpoint_io],
        )
        dummy_trainer.fit(model)
        tmp_path = strategy.broadcast(tmp_path)

        sync_ckpt_dir = tmp_path / 'sync_checkpoints'

        test_trainer = pl.Trainer(
            enable_checkpointing=True,
            logger=False,
            max_epochs=1,
            strategy=NLPDDPStrategy(),
            plugins=[sync_checkpoint_io],
            default_root_dir=sync_ckpt_dir,
        )
        test_trainer.fit(model)

        # Simulate finetuning with an extra head
        extra_head_model = TestLoadStrictness.ExampleMCoreModelExtraHead()
        finetuning_trainer = pl.Trainer(
            enable_checkpointing=True,
            logger=False,
            max_epochs=2,
            strategy=NLPDDPStrategy(),
            plugins=[sync_checkpoint_io],
            default_root_dir=sync_ckpt_dir,
        )
        finetuning_trainer.fit(extra_head_model, ckpt_path=_get_last_checkpoint_dir(sync_ckpt_dir, model))
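As background for why `ExampleMCoreModelExtraHead.on_load_checkpoint` passes `strict=False`, here is a standalone plain-PyTorch sketch (not part of the PR, no distributed checkpointing involved) of the same situation: the pretraining checkpoint has no weights for the newly added head, so only a non-strict load succeeds, and the missing keys are merely reported rather than raising.

```python
import torch

# Hypothetical stand-ins: 'base' plays the role of the pretrained layers,
# 'extra_head' the layer added for finetuning.
pretrained = torch.nn.Linear(2, 2)
finetune_model = torch.nn.ModuleDict({'base': torch.nn.Linear(2, 2), 'extra_head': torch.nn.Linear(2, 4)})

# Checkpoint produced by the pretrained model: it only covers the 'base' parameters.
checkpoint_state_dict = {f'base.{k}': v for k, v in pretrained.state_dict().items()}
result = finetune_model.load_state_dict(checkpoint_state_dict, strict=False)
print(result.missing_keys)  # ['extra_head.weight', 'extra_head.bias']
```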