Fix: Resolve #3060, preload_module_classes is lost for nested modules (#3248)

* resolve 3060

* format

* add tests

* fix

* fix

* format
wejoncy authored Dec 3, 2024
1 parent f8c77f0 commit 60461ff
Showing 2 changed files with 70 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/accelerate/hooks.py
@@ -436,7 +436,13 @@ def attach_execution_device_hook(
         return
 
     for child in module.children():
-        attach_execution_device_hook(child, execution_device, skip_keys=skip_keys, tied_params_map=tied_params_map)
+        attach_execution_device_hook(
+            child,
+            execution_device,
+            skip_keys=skip_keys,
+            preload_module_classes=preload_module_classes,
+            tied_params_map=tied_params_map,
+        )
 
 
 def attach_align_device_hook(
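In words: before this change, the recursive call inside attach_execution_device_hook forwarded skip_keys and tied_params_map but silently dropped preload_module_classes, so every module below the top level behaved as if no preload classes had been requested. Below is a minimal, self-contained sketch of that pitfall, using hypothetical names (Node, walk_buggy, stop_names) rather than accelerate's API:

# Hypothetical sketch of the bug class fixed above, not accelerate's API:
# a recursive traversal that forwards some keyword arguments but drops another.

class Node:
    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)


def walk_buggy(node, depth=0, stop_names=None):
    # `stop_names` plays the role of `preload_module_classes`.
    if stop_names and node.name in stop_names:
        return
    print("  " * depth + node.name)
    for child in node.children:
        walk_buggy(child, depth=depth + 1)  # BUG: `stop_names` is not forwarded


def walk_fixed(node, depth=0, stop_names=None):
    if stop_names and node.name in stop_names:
        return
    print("  " * depth + node.name)
    for child in node.children:
        walk_fixed(child, depth=depth + 1, stop_names=stop_names)  # forwarded


tree = Node("root", [Node("a", [Node("skip_me")]), Node("skip_me")])
walk_buggy(tree, stop_names={"skip_me"})  # visits every skip_me below the root
walk_fixed(tree, stop_names={"skip_me"})  # prunes skip_me at every depth

The fix is the same in both cases: thread the argument through every recursive call so that the pruning (or, in accelerate's case, the preload marking) applies at every depth, not just the first.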
63 changes: 63 additions & 0 deletions tests/test_accelerator.py
@@ -30,6 +30,7 @@
 from accelerate.state import GradientState, PartialState
 from accelerate.test_utils import (
     require_bnb,
+    require_huggingface_suite,
     require_multi_gpu,
     require_non_cpu,
     require_transformer_engine,
@@ -762,3 +763,65 @@ def test_save_model_with_stateful_dataloader(self, use_safetensors, tied_weights
         assert torch.allclose(original_linear1, new_linear1)
         assert torch.allclose(original_batchnorm, new_batchnorm)
         assert torch.allclose(original_linear2, new_linear2)
+
+    @require_cuda
+    @require_huggingface_suite
+    def test_nested_hook(self, use_safetensors):
+        from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
+
+        class MyLinear(torch.nn.Module):
+            def __init__(self, device=None, dtype=None):
+                factory_kwargs = {"device": device, "dtype": dtype}
+                super().__init__()
+                self.centroid = torch.nn.Embedding(1, 2)
+                self.indices = torch.nn.Parameter(torch.empty((1, 2, 2), **factory_kwargs))
+
+            def forward(self, x):
+                orig_shape = x.shape
+                x = torch.abs(x + self.indices).long()
+                x = x % 2
+                x = x.sum(-1)
+                x = (self.centroid.weight + x).reshape(orig_shape)
+                return x
+
+        class MySubModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.layer = MyLinear()
+
+            def forward(self, x):
+                return self.layer(x)
+
+        class MyModel(PreTrainedModel):
+            def __init__(self, config):
+                super().__init__(config)
+                self.layer = torch.nn.ModuleList([MySubModel() for i in range(4)])
+
+            def forward(self, x):
+                for layer in self.layer:
+                    x = layer(x)
+                return x
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            check_point = tmpdirname
+            offload_folder = check_point + "/offload"
+            os.makedirs(offload_folder, exist_ok=True)
+            config = PretrainedConfig()
+            m = MyModel(config)
+            m.save_pretrained(check_point)
+
+            with init_empty_weights():
+                my_model = MyModel(config)
+            my_model = load_checkpoint_and_dispatch(
+                my_model,
+                checkpoint=check_point,
+                max_memory={"cpu": 60, 0: 60},
+                device_map="auto",
+                no_split_module_classes=["MySubModel"],
+                offload_folder=offload_folder,
+                preload_module_classes=["MyLinear"],
+            )
+            # before the fix, this would raise an error:
+            # weight is on the meta device, we need a `value` to put in on 0
+            x = torch.randn(1, 2)
+            my_model(x)
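A note on why the regression test is shaped this way: per the accelerate documentation, preload_module_classes names classes whose instances should load all of their weights (including those of submodules) at the start of their forward pass; it exists for modules like MyLinear above, whose parameters (centroid.weight, indices) are read as plain attributes inside forward rather than used through a direct submodule call. Since each MySubModel holds its MyLinear one level down, the preloading request only reaches the nested instances if the recursive hook-attachment call forwards preload_module_classes, which is exactly what the hooks.py hunk restores; without it, the dispatched model fails at the first forward pass with the meta-device error quoted in the comment.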
