
benchmark script for simple_gla vs mamba2 kernel #50

Merged · 1 commit into fla-org:main on Aug 18, 2024

Conversation

learning-chip (Contributor) commented on Aug 18, 2024

Follow-up to #49.

Amazingly, chunk_simple_gla seems to be much faster than mamba_chunk_scan_combined:

$ python ./benchmark_simple_gla_vs_mamba2.py

Performance:
         T  chunk_simple_gla  mamba2_ssd
0     64.0          0.084992    0.840208
1    128.0          0.100352    0.847920
2    256.0          0.100368    0.848896
3    512.0          0.174080    0.873472
4   1024.0          0.399360    0.880208
5   2048.0          0.776352    1.596416
6   4096.0          1.526784    3.160064
7   8192.0          3.067904    6.251520
8  16384.0          6.220800   12.452864

[Performance plot of the table above]
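For reference, a minimal sketch of how such a timing comparison could be set up (this is not the script added in this PR; the import paths, tensor layouts, and call signatures of chunk_simple_gla and mamba_chunk_scan_combined are assumptions and may differ across fla / mamba_ssm versions):

```python
import torch
import triton

from fla.ops.simple_gla import chunk_simple_gla                           # assumed path
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined   # assumed path

bsz, nheads, headdim, dstate, chunk = 4, 8, 64, 64, 64
device, dtype = 'cuda', torch.bfloat16

for T in [64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384]:
    # chunk_simple_gla inputs: assumed (batch, heads, seq, dim) layout,
    # with a per-token scalar log decay g
    q = torch.randn(bsz, nheads, T, headdim, dtype=dtype, device=device)
    k, v = torch.randn_like(q), torch.randn_like(q)
    g = torch.rand(bsz, nheads, T, dtype=torch.float32, device=device).log()

    # mamba_chunk_scan_combined inputs: assumed (batch, seq, heads, dim) layout
    x = torch.randn(bsz, T, nheads, headdim, dtype=dtype, device=device)
    dt = torch.rand(bsz, T, nheads, dtype=torch.float32, device=device)
    A = -torch.rand(nheads, dtype=torch.float32, device=device)
    B = torch.randn(bsz, T, 1, dstate, dtype=dtype, device=device)
    C = torch.randn_like(B)

    # median runtime in ms, as reported by triton's built-in benchmarking helper
    t_gla = triton.testing.do_bench(lambda: chunk_simple_gla(q, k, v, g))
    t_ssd = triton.testing.do_bench(
        lambda: mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size=chunk))
    print(f"T={T:6d}  chunk_simple_gla {t_gla:8.4f} ms  mamba2_ssd {t_ssd:8.4f} ms")
```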

I left many TODO and NOTE comments in the benchmark script, including:

  • Testing more input shapes
  • Tuning block sizes
  • Analyzing the impact of input memory layout

More importantly:

  • More detailed profiling to understand exactly why it is faster.

Maybe the Mamba-2 kernel incurs more memory I/O (is less "fused")? And why does the short-sequence performance (T < 256) differ by so much?

yzhangcs (Member) commented:

@learning-chip Great job! Appreciate your quick actions.

yzhangcs merged commit c60ada3 into fla-org:main on Aug 18, 2024 (1 check passed)
sustcsonglin (Collaborator) commented:

@learning-chip Mamba2’s official kernel involves three main steps: 1) computation of each chunk’s last hidden state, 2) recurrence at the chunk level, and 3) output computation.

For steps 1) and 2), it stores/loads the hidden state in FP32, which incurs significant I/O costs.

FLA's implementation fuses steps 1) and 2), avoids materializing the FP32 hidden state after step 1), and stores only the BF16 hidden state after step 2), thus reducing I/O costs.
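To make the three steps concrete, here is an illustrative, decay-free chunkwise linear-attention reference in plain PyTorch (a sketch of the structure described above, not of either kernel; the function name and the omission of decay are simplifications). The kv and states intermediates are the tensors that, per the comment above, the Mamba-2 kernel writes to global memory in FP32, while the fused FLA kernel avoids materializing the step-1 output and keeps only a BF16 state after step 2.

```python
import torch

def chunked_linear_attn(q, k, v, chunk=64):
    # q, k, v: (batch, heads, seq, dim); decay is omitted for brevity
    B, H, T, D = q.shape
    assert T % chunk == 0
    q, k, v = (x.view(B, H, T // chunk, chunk, D) for x in (q, k, v))

    # Step 1: each chunk's contribution to the hidden state, K_i^T V_i.
    # In the Mamba-2 kernel this intermediate is materialized in FP32.
    kv = k.transpose(-1, -2) @ v                         # (B, H, N, D, D)

    # Step 2: chunk-level recurrence (an exclusive prefix sum of the
    # per-chunk states). These running states are the other FP32 buffer.
    S = torch.zeros(B, H, D, D, dtype=torch.float32, device=q.device)
    states = []
    for i in range(kv.shape[2]):
        states.append(S)
        S = S + kv[:, :, i].float()
    states = torch.stack(states, dim=2)                  # (B, H, N, D, D)

    # Step 3: output = inter-chunk term (queries against the carried state)
    # plus intra-chunk causal attention within each chunk.
    inter = q @ states.to(q.dtype)
    mask = torch.ones(chunk, chunk, dtype=torch.bool, device=q.device).tril()
    attn = (q @ k.transpose(-1, -2)).masked_fill(~mask, 0)
    intra = attn @ v
    return (inter + intra).reshape(B, H, T, D)
```

Fusing steps 1) and 2) means the per-chunk kv never leaves registers/shared memory, and only the carried state (in BF16) is written out, which is the I/O saving described above.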
