Skip to content

Commit

Permalink
patch_tied_tensors_bug: support malformed model definitions (#1014)
Browse files Browse the repository at this point in the history
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
  • Loading branch information
kylesayrs authored Jan 2, 2025
1 parent a358598 commit 1b8c7bf
Showing 1 changed file with 4 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,10 @@ def patch_tied_tensors_bug(model: torch.nn.Module):
input_embed = model.get_input_embeddings()
output_embed = model.get_output_embeddings()

if input_embed is None or output_embed is None:
# some models fail to properly override the abstract methods
return

if storage_ptr(input_embed.weight) == storage_ptr(output_embed.weight):
for module in (input_embed, output_embed):
if not is_module_offloaded(module):
Expand Down

0 comments on commit 1b8c7bf

Please sign in to comment.