[WIP] Add Mamba2 #32027
Conversation
- …al model (causal mask is missing, todo)
- …v to flash_attn impl, make ssd_naive a class function, todos (e.g. rope caching)
- … specific for the used shapes, other small nits on names/comments
- …update Refactor: Causal Mask Update and Prepare for Generate
- Fix a lot of other remaining slow tests
- …reshape instead of view; rope tests too
- Fixes all remaining issues: cache, attention, conversion, ...
- Style and quality
- Add Mamba2ForSequenceClassification
- Some fixes for the stuff introduced in PR #1
- Fix sequence classifier with copies and correct prefix for backbone model
Thanks for the great work 🫡! Quick question about the causal convolution part in … Thanks in advance!
@ruipeterpan That's a great catch! I think a simple transpose(1,2) should fix it on more recent versions (1.4 or later), whereas a squeeze will mostly work on older versions too. As for the generate issue: if you pass an initial cache, it should have the attribute …
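A minimal sketch of the shape handling this refers to (my own illustration, not the exact PR code; tensor names and sizes are assumptions): `nn.Conv1d` expects `(batch, channels, seq_len)`, so the hidden states are transposed before and after the causal depthwise convolution.

```python
import torch
import torch.nn as nn

batch, seq_len, channels, kernel_size = 2, 8, 16, 4

# Depthwise conv with left padding so the convolution stays causal.
conv1d = nn.Conv1d(
    channels, channels, kernel_size,
    groups=channels, padding=kernel_size - 1,
)

hidden_states = torch.randn(batch, seq_len, channels)

# (batch, seq_len, channels) -> (batch, channels, seq_len) for Conv1d
x = hidden_states.transpose(1, 2)
# Drop the extra right-side positions created by the causal padding.
x = conv1d(x)[..., :seq_len]
# Back to (batch, seq_len, channels)
hidden_states = x.transpose(1, 2)
```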
Thanks for the clarification -- for using past_key_values, do you mean we need to manually set …
Yup, that's how you would do it. The problem here is that Mamba can only decode on a one-by-one basis, so expecting it to handle seq_len > 1 on the first pass (with a cache) is incompatible. You basically have to reset the cache. I do admit that it's rather unintuitive to reset such a flag, though. It should be cleaner with a separate method that handles it.
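A hedged sketch of that workflow (the flag name `has_previous_state` is an assumption, since the actual attribute is elided in the thread; adapt it to whatever the merged cache class exposes):

```python
def prepare_cache_for_new_prompt(cache):
    """Reset the decode flag so a multi-token prompt can be prefilled again."""
    # Hypothetical flag name; check the Mamba2 cache implementation for the real one.
    if getattr(cache, "has_previous_state", False):
        cache.has_previous_state = False
    return cache

# Usage idea: reset (or simply recreate) the cache before calling generate
# with a prompt longer than one token, then decode token by token afterwards.
```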
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
What does this PR do?
As per title:
Paper: https://arxiv.org/abs/2405.21060
Repo: https://github.com/state-spaces/mamba
Mamba2 is a successor to Mamba that rethinks SSMs as a special, structured form of attention (akin to the causal attention in decoder-only models). This implementation allows all architecture types, i.e. pure Mamba2, hybrid Mamba2-Attention, and pure attention (we mostly follow the Llama attention definition where possible). Maybe there's more interest after Mistral released their code model yesterday: https://mistral.ai/news/codestral-mamba/ :)
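To illustrate the "SSM as structured attention" view from the paper, here is a rough standalone sketch (not the PR's ssd_naive implementation; shapes and names are simplified assumptions): for a scalar-decay SSD layer the output can be written as y = (L ∘ C Bᵀ) x, where L is a lower-triangular matrix of cumulative decay products acting like a causal attention mask.

```python
import torch

seq_len, d_state, d_head = 6, 4, 8
a = torch.rand(seq_len)              # per-step scalar decay in (0, 1)
B = torch.randn(seq_len, d_state)    # input projection per step
C = torch.randn(seq_len, d_state)    # output projection per step
x = torch.randn(seq_len, d_head)     # inputs

# L[i, j] = prod_{k=j+1..i} a_k for i >= j, else 0: a causal "mask"
# built from the SSM's decay terms.
cum = torch.cumsum(torch.log(a), dim=0)
L = torch.tril(torch.exp(cum[:, None] - cum[None, :]))

# Attention-like matrix and output.
M = L * (C @ B.T)       # (seq_len, seq_len), causal and structured
y = M @ x               # (seq_len, d_head)
```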
There are still some TODOs left, but the overall architecture and functionality should be there:
- … (e.g. test_left_padding_compatibility)

Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please link to it if that's the case. See Add Mamba2 #31204
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings. But see TODOs.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker @amyeroberts @gante @Adibvafa @pglorio