
fix: extend the unwrap_model function and save unwrapped model state dict instead of wrapped #29780

30 changes: 23 additions & 7 deletions src/transformers/modeling_utils.py
@@ -4661,16 +4661,32 @@ def forward(

 def unwrap_model(model: nn.Module) -> nn.Module:
     """
-    Recursively unwraps a model from potential containers (as used in distributed training).
+    Recursively unwraps a module and its child sublayers.
 
     Args:
-        model (`torch.nn.Module`): The model to unwrap.
+        model (nn.Module): The model to unwrap.
+
+    Returns:
+        nn.Module: The unwrapped module.
     """
-    # since there could be multiple levels of wrapping, unwrap recursively
-    if hasattr(model, "module"):
-        return unwrap_model(model.module)
-    else:
-        return model
+    def recursive_unwrap(module):
+        if hasattr(module, "module"):
+            try:
+                unwrapped_module = recursive_unwrap(getattr(module, "module"))
+            except AttributeError:
+                unwrapped_module = module  # Handle cases where wrapped module is inaccessible
Collaborator (@amyeroberts) commented:

Could you give an example of when this happens? It seems weird we'd have hasattr(module, "module") evaluate as True but then we can't do getattr(module, "module")

Contributor (author) replied:

Yes, you are right @amyeroberts. It does seem weird; I don't remember why I implemented it like this, but thanks for pointing it out. I also can't think of an example.

I am fixing it.

+            return unwrapped_module
+
+        # Unwrap child sublayers recursively
+        for name, child in module.named_children():
+            setattr(module, name, recursive_unwrap(child))
+
+        return module
+
+    # Start with top-level unwrapping
+    unwrapped_model = recursive_unwrap(model)
+    return unwrapped_model
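To illustrate what the extended `unwrap_model` does differently, here is a minimal, dependency-free sketch of the same recursion. `TinyModule` and `Wrapper` are hypothetical stand-ins (not part of transformers or PyTorch): `TinyModule` mimics just enough of `nn.Module` to expose `named_children()`, and `Wrapper` mimics containers like `nn.DataParallel` that hold the real model under a `.module` attribute.

```python
class TinyModule:
    """Hypothetical stand-in for nn.Module: children are plain attributes."""
    def named_children(self):
        return [(n, v) for n, v in vars(self).items() if isinstance(v, TinyModule)]

class Wrapper(TinyModule):
    """Hypothetical stand-in for containers (e.g. DataParallel) exposing `.module`."""
    def __init__(self, module):
        self.module = module

def unwrap_model(model):
    def recursive_unwrap(module):
        # Strip multiple levels of wrapping at this node first
        if hasattr(module, "module"):
            return recursive_unwrap(module.module)
        # Then unwrap child sublayers in place (the new behavior in this PR)
        for name, child in module.named_children():
            setattr(module, name, recursive_unwrap(child))
        return module
    return recursive_unwrap(model)

inner = TinyModule()
parent = TinyModule()
parent.layer = Wrapper(inner)      # a wrapped child sublayer
model = Wrapper(Wrapper(parent))   # two levels of top-level wrapping

unwrapped = unwrap_model(model)
assert unwrapped is parent         # both outer wrappers removed
assert unwrapped.layer is inner    # child sublayer unwrapped in place
```

The old implementation only followed `.module` at the top level, so `parent.layer` would have stayed wrapped; the recursion over `named_children()` is what this PR adds.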


def expand_device_map(device_map, param_names, start_prefix):
2 changes: 1 addition & 1 deletion src/transformers/trainer.py
@@ -3174,7 +3174,7 @@ def _save_tpu(self, output_dir: Optional[str] = None):
         unwrap_model(model).save_pretrained(
             output_dir,
             is_main_process=self.args.should_save,
-            state_dict=model.state_dict(),
+            state_dict=unwrap_model(model).state_dict(),
             save_function=xm.save,
             safe_serialization=self.args.save_safetensors,
         )
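The trainer change matters because wrappers that register the model under `.module` (such as `nn.DataParallel`) produce `state_dict` keys prefixed with `module.`, which do not match the key names `save_pretrained` expects. A minimal sketch of that mismatch, with hypothetical `Wrapper` and `Inner` classes standing in for the container and the real model:

```python
class Inner:
    """Hypothetical stand-in for the real model."""
    def state_dict(self):
        return {"weight": [1.0, 2.0]}

class Wrapper:
    """Hypothetical stand-in for DataParallel-style containers: the wrapped
    model lives under `.module`, so its parameters pick up a `module.` prefix."""
    def __init__(self, module):
        self.module = module
    def state_dict(self):
        return {f"module.{k}": v for k, v in self.module.state_dict().items()}

def unwrap_model(model):
    # Since there could be multiple levels of wrapping, unwrap repeatedly
    while hasattr(model, "module"):
        model = model.module
    return model

wrapped = Wrapper(Inner())
assert list(wrapped.state_dict()) == ["module.weight"]
assert list(unwrap_model(wrapped).state_dict()) == ["weight"]
```

Passing `unwrap_model(model).state_dict()` instead of `model.state_dict()` ensures the saved keys match the unprefixed names of the underlying model.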