Fix infer_auto_device_map when tied weights share the same prefix name (#2324)

* fix auto device map with tied weights sharing a prefix name

Co-authored-by: Giuseppe Franco <giuseppefranco4@gmail.com>
Co-authored-by: Nick Fraser <icanlosh@gmail.com>

* precise comment

---------

Co-authored-by: Giuseppe Franco <giuseppefranco4@gmail.com>
Co-authored-by: Nick Fraser <icanlosh@gmail.com>
3 people authored Jan 10, 2024
1 parent 456afd9 commit e3e9b87
Showing 2 changed files with 39 additions and 2 deletions.
src/accelerate/utils/modeling.py (11 changes: 9 additions & 2 deletions)

@@ -1061,15 +1061,22 @@ def infer_auto_device_map(

         # We keep relevant tied parameters only: one of the tied parameters in the group is inside the current module
         # and the other is not.
+        # Note: If we are currently processing the name `compute.weight`, another parameter named e.g. `compute.weight_submodule.parameter`
+        # needs to be considered outside the current module, hence the check with additional dots.
         tied_param_goups = [
             tied_group
             for tied_group in tied_parameters
-            if any(name in k for k in tied_group) and not all(name in k for k in tied_group)
+            if any(name + "." in k + "." for k in tied_group) and not all(name + "." in k + "." for k in tied_group)
         ]

         if verbose and len(tied_param_goups) > 0:
             print(f"  Found the relevant tied param groups {tied_param_goups}")

         # Then we keep track of all the parameters that are tied to the current module, but not in the current module
-        tied_params = sum([[p for p in tied_group if name not in p] for tied_group in tied_param_goups], [])
+        tied_params = sum(
+            [[p for p in tied_group if name + "." not in p + "."] for tied_group in tied_param_goups], []
+        )

         if verbose and len(tied_params) > 0:
             print(f"  So those parameters need to be taken into account {tied_params}")

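Why the extra dots work (a minimal standalone sketch, not part of the commit; the names mirror the example in the comment above):

    # `name` is the module currently being processed; the group ties
    # `compute.weight` to a parameter whose name merely extends that prefix.
    name = "compute.weight"
    tied_group = ["compute.weight", "compute.weight_submodule.parameter"]

    # Old check: plain substring containment matches both names, so
    # `all(...)` is True and the tied group is wrongly discarded.
    print([name in k for k in tied_group])              # [True, True]

    # New check: the appended dots only match full path components, so the
    # sibling parameter is correctly treated as outside the module.
    print([name + "." in k + "." for k in tied_group])  # [True, False]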
tests/test_modeling_utils.py (30 changes: 30 additions & 0 deletions)

@@ -545,6 +545,36 @@ def test_infer_auto_device_map_with_tied_weights(self):
         expected = {"linear1": 0, "linear2": 1, "linear3": 0, "linear4": 1}
         self.assertDictEqual(device_map, expected)
+
+        # With tied weights sharing the same prefix name (`compute.weight` vs `compute.weight_submodule.parameter`)
+        class SubModule(torch.nn.Module):
+            def __init__(self, ref_to_parameter):
+                super().__init__()
+                self.parameter = ref_to_parameter
+
+            def forward(self, x):
+                return x + torch.max(self.parameter)
+
+        class LinearModuleAndSubModule(torch.nn.Linear):
+            def __init__(self, in_features, out_features):
+                super().__init__(in_features, out_features)
+                self.weight_submodule = SubModule(self.weight)
+
+            def forward(self, x):
+                return torch.nn.functional.linear(self.weight_submodule(x), self.weight)
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.compute = LinearModuleAndSubModule(3, 8)
+
+            def forward(self, x):
+                return self.compute(x)
+
+        model = Model()
+
+        device_memory = {0: 4, "cpu": 96000}  # Low memory device, just to force splitting and trigger the error
+        infer_auto_device_map(model, device_memory)

     @require_huggingface_suite
     def test_infer_auto_device_map_on_t0pp(self):
         from transformers import AutoConfig, AutoModelForSeq2SeqLM
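For context, a short usage sketch (not part of the commit) of what the new test exercises; it assumes the `Model` class from the diff above and accelerate's public `find_tied_parameters` and `infer_auto_device_map` helpers. Before this fix, the last call below raised, because the tied group was filtered out by the prefix collision:

    from accelerate import infer_auto_device_map
    from accelerate.utils import find_tied_parameters

    model = Model()  # the test model defined in the diff above

    # Both names refer to one shared tensor, despite the common prefix.
    print(find_tied_parameters(model))
    # e.g. [['compute.weight', 'compute.weight_submodule.parameter']]

    # A tiny GPU budget forces a split; with the fix, the tied pair is
    # accounted for instead of raising an error.
    print(infer_auto_device_map(model, max_memory={0: 4, "cpu": 96000}))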
