
[WIP] Add Mamba2 #32027

Closed
wants to merge 93 commits into from

Conversation

@vasqu (Contributor) commented Jul 17, 2024

What does this PR do?

As per title:
Paper: https://arxiv.org/abs/2405.21060
Repo: https://github.com/state-spaces/mamba

Mamba2 is a successor to Mamba that rethinks SSMs as a special type of attention (i.e. structured attention, analogous to causal attention in decoder-only models). This implementation allows all architecture types: pure Mamba2, hybrid Mamba2-Attention, and pure Attention (we mostly follow the Llama attention implementation where possible). Maybe there's more interest after Mistral released their code model yesterday: https://mistral.ai/news/codestral-mamba/ :)

There are still some TODOs left, but the overall architecture and functionality should be there:

  • Caching with RoPE (unsure if it is even cached); also check whether any transformations to the weights are necessary for RoPE, as done in Llama.
  • Additional warning about AMD compatibility (has been released after some time).
  • Update the causal mask --> there have been changes around static caches in Llama; I doubt it affects us with the hybrid cache, but just to be sure.
  • Hardware differences make it hard to gauge whether some tolerance limits should be this high (see test_left_padding_compatibility).
  • Flash attention tests in general.
  • Integration tests overall.
  • Possibly allow outputting the last SSM state of each block/layer, similar to outputting attention weights.
  • Possibly allow passing initial SSM states for the layers.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker @amyeroberts @gante @Adibvafa @pglorio

@ruipeterpan commented Jul 19, 2024

Thanks for the great work 🫡! Quick question about the causal convolution part in triton_kernels_forward. During prefill (generating the first token), cached_forward is False, so causal_conv1d_fn() is invoked. During autoregressive decoding, cached_forward is True, so causal_conv1d_update() is invoked. At that point, xBC has shape (batch_size, 1, dim), but causal_conv1d_update() requires x=xBC to have shape (batch_size, dim) or (batch, dim, seqlen). Are we missing a reshape operation on xBC ("b l d -> b d l"), like in https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba2.py#L293? A similar issue occurs when passing past_key_values to model.generate(): cached_forward is True, so causal_conv1d_update() is invoked, whereas xBC has shape (batch_size, seq_len, dim).
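
To make the shape mismatch concrete, here is a tiny standalone sketch (plain PyTorch with illustrative shapes only, not the actual PR code):

import torch

# Illustrative shapes for a single autoregressive decode step.
batch_size, seq_len, dim = 2, 1, 256

# xBC comes out of the projection as (batch_size, seq_len, dim) ...
xBC = torch.randn(batch_size, seq_len, dim)

# ... but causal_conv1d_update() wants channels-first, i.e. (batch_size, dim)
# or (batch, dim, seqlen), so a "b l d -> b d l" rearrange seems to be needed:
xBC_channels_first = xBC.transpose(1, 2)
print(xBC_channels_first.shape)  # torch.Size([2, 256, 1])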

Thanks in advance!

@vasqu (Contributor, Author) commented Jul 19, 2024

@ruipeterpan That's a great catch!

I think a simple transpose(1, 2) should fix it on more recent causal-conv1d versions (>= 1.4), whereas a squeeze will mostly work on older versions too. We would likely need a shape check, since the versions handle it differently and output different shapes (i.e. (bsz, dim) vs (bsz, dim, 1)). Nevermind, I missed a squeeze in the new code release.
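
Roughly what the fix could look like on our side, as an untested sketch (the version argument here is just a placeholder for however we end up detecting the installed causal-conv1d):

from packaging import version

import torch


def to_conv_update_layout(xBC: torch.Tensor, causal_conv1d_version: str) -> torch.Tensor:
    # xBC arrives as (batch_size, 1, dim) during single-step decoding.
    xBC = xBC.transpose(1, 2)  # -> (batch_size, dim, 1), channels-first
    if version.parse(causal_conv1d_version) < version.parse("1.4.0"):
        # Older releases only accept (batch_size, dim) on the update path.
        xBC = xBC.squeeze(-1)
    return xBC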

As for the generate issue: if you pass an initial cache, it should have the attribute has_previous_state set to False, so the first pass is a non-cached forward call and we then re-enter the first scenario (which is bugged :D). Or is that not the case, and even the first call already has the flag set to True? (Can't execute code atm.)

@ruipeterpan

Thanks for the clarification -- for using past_key_values, do you mean we need to manually set past_key_values.has_previous_state to False before passing it in? My usage is as follows; not sure if I'm doing this correctly:

out = model.generate(input_ids, return_dict_in_generate=True)
past_key_values = out.past_key_values  # reuse the cache from the first generate call
# past_key_values.has_previous_state = False  # adding this LOC resolves the issue, thanks!
out = model.generate(other_input_ids, past_key_values=past_key_values)

@vasqu (Contributor, Author) commented Jul 20, 2024

Yup, that's how you would do it. The problem here is that Mamba can only decode on a one-by-one basis once a cache is in use, so a first pass with seq_len > 1 is incompatible with a cache that already claims to have previous state. You basically have to reset the cache.

I do admit that it's rather unintuitive, though, to reset such a flag by hand. It would be cleaner with a separate method that handles it.
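
Something like this rough sketch is what I have in mind (class and method names here are placeholders, not necessarily what the PR will end up with):

class HybridMamba2Cache:  # placeholder name, not the actual cache class of this PR
    def __init__(self):
        # conv states, ssm states and the attention key/value cache would live here as well
        self.has_previous_state = False

    def reset_for_new_prompt(self):
        # The next forward pass then takes the full (non-cached) prefill path
        # instead of the single-step update path.
        self.has_previous_state = False

Your snippet above would then call past_key_values.reset_for_new_prompt() instead of flipping the flag by hand.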


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@vasqu closed this Aug 17, 2024