Mamba / FalconMamba: Fix mamba left padding #32677
Conversation
Thanks @younesbelkada for adding the tuning-out of the states! 😁 Left a couple of comments, mostly curious about some situations that were edge cases for Mamba 2.
Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
Can we propagate this to Jamba as well :D thx for this fix ❤️
LGTM! pinging @ArthurZucker for merging 🙂
Thanks for adding a test 🤗
# In case cache is not used, manually add a new column in the attention mask
if not use_cache and attention_mask is not None and input_ids.shape != attention_mask.shape:
    pad_length = input_ids.shape[-1] - attention_mask.shape[-1]
    attention_mask = torch.cat([attention_mask, torch.ones_like(input_ids[:, :pad_length])], dim=-1)
Not sure I understand why we are adding a [1] x batch_size here? (past_length is usually going to be 1 - current_generation_token, so imagine 20 input ids, then -19 to slice the input_ids? Unless the input_ids length is 20, but then it always has the same shape as the mask.)
This is for users that run generation with use_cache=False; it makes sure to manually update the attention mask, because this is done nowhere else except here.
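For context, here is a minimal, self-contained sketch of the situation being handled (the shapes and the 20-token prompt are illustrative, not taken from the PR):

```python
import torch

# With use_cache=False, generate() re-feeds the full sequence at every step,
# while the attention mask kept in model_kwargs still has the original prompt length.
batch_size, prompt_len, generated = 2, 20, 3
input_ids = torch.ones(batch_size, prompt_len + generated, dtype=torch.long)
attention_mask = torch.ones(batch_size, prompt_len, dtype=torch.long)

# The newly generated tokens are real (non-pad) tokens, so the mask is extended with ones.
pad_length = input_ids.shape[-1] - attention_mask.shape[-1]  # 3 in this sketch
attention_mask = torch.cat(
    [attention_mask, torch.ones_like(input_ids[:, :pad_length])], dim=-1
)
assert attention_mask.shape == input_ids.shape
```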
Then this is more a problem with generate, as it should pass the correct attention mask 😓
Will include this in the patch 🤗
# In case cache is not used, manually update the attention mask
if not use_cache and attention_mask is not None and input_ids.shape != attention_mask.shape:
    past_length = input_ids.shape[-1] - attention_mask.shape[-1]
    attention_mask = torch.cat([attention_mask, torch.ones_like(input_ids[:, :past_length])], dim=-1)
That's the only thing bothering me, as generate with use_cache=False should not alter the attention mask being passed.
Yes, fixed it!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@@ -557,6 +574,7 @@ def set_input_embeddings(self, new_embeddings):
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
This is breaking (having it as the second argument).
Yes, fixed it.
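For readers wondering why the argument position matters, a hypothetical sketch of the breakage (the function names and call site below are illustrative, not the actual model code):

```python
# Old signature: a cache-like object is the second positional parameter.
def forward_old(input_ids, cache_params=None):
    return {"input_ids": input_ids, "cache_params": cache_params}

# New signature with attention_mask inserted in second position.
def forward_new(input_ids, attention_mask=None, cache_params=None):
    return {"input_ids": input_ids, "attention_mask": attention_mask, "cache_params": cache_params}

# A caller written against the old signature, passing the cache positionally:
call = lambda f: f([1, 2, 3], "my_cache_object")

print(call(forward_old))  # cache bound as intended
print(call(forward_new))  # the cache now silently binds to attention_mask -> breaking
```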
if "attention_mask" in model_kwargs: | ||
attention_mask = model_kwargs["attention_mask"] | ||
model_kwargs["attention_mask"] = torch.cat( | ||
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 | ||
) |
Good catch!
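For context, a small illustrative sketch (not the library code) of what this update keeps in sync during cached decoding:

```python
import torch

# During cached decoding each step emits exactly one new token,
# so the attention mask must grow by one column of ones per step.
attention_mask = torch.tensor([[0, 0, 1, 1, 1]])  # left-padded prompt of length 5

for _ in range(3):  # three decoding steps
    attention_mask = torch.cat(
        [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
    )

print(attention_mask)  # tensor([[0, 0, 1, 1, 1, 1, 1, 1]])
```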
* fix mamba left padding
* Apply suggestions from code review
  Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
* fix copies
* test with `inputs_embeds`
* Update src/transformers/models/falcon_mamba/modeling_falcon_mamba.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* copies
* clairfy
* fix last comments
* remove
---------
Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This reverts commit 91b799b.
What does this PR do?
As pointed out in #32080 (comment), it is important to zero out the hidden states that correspond to the pad tokens before and after the causal convolution, so that the pad tokens have no impact on the computed hidden states.
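A minimal sketch of the idea, assuming a (batch, seq_len, hidden) layout; the helper name and shapes are illustrative, not the exact modeling code:

```python
import torch

def mask_padding_states(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Zero out hidden states at padded positions so the causal convolution
    cannot mix pad activations into real tokens (applied before and after the conv)."""
    if attention_mask is None:
        return hidden_states
    # hidden_states: (batch, seq_len, hidden); attention_mask: (batch, seq_len) with 1 = real token
    return hidden_states * attention_mask.unsqueeze(-1).to(hidden_states.dtype)

# Left-padded batch: the first two positions of sample 0 are padding.
hidden_states = torch.randn(2, 5, 8)
attention_mask = torch.tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]])
masked = mask_padding_states(hidden_states, attention_mask)
assert torch.all(masked[0, :2] == 0)
```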
This can be empirically verified by comparing generation quality before and after this fix (note that FalconMamba uses left padding by default):
Before the fix:
After the fix:
Propagated the changes to Mamba1 as well.
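For reference, a left-padded batched generation along these lines can be used to compare outputs before and after the fix (the checkpoint name and prompts below are illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-mamba-7b"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.padding_side = "left"
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

# Prompts of different lengths: the shorter ones get left-padded.
prompts = ["Hello today I am going to", "The capital of France is"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

# Without the fix, pad tokens leak into the convolution/SSM states and degrade the generations.
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```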
cc @ArthurZucker @molbap