
enhance fla support for RWKV6 #44

Closed
wants to merge 24 commits

Conversation

uniartisan
Contributor

This pull request aims to enhance FLA support for RWKV6, improving both speed and performance in bf16. It also enables FLA on Intel GPUs.

FLA ChunkRWKV6 Optimized Implementation

This repository contains an optimized implementation of ChunkRWKV6 using FLA (Flash Linear Attention) techniques. Our goal is to improve both accuracy and speed compared to the standard CUDA implementations.

Performance Comparison

We've conducted performance tests comparing our FLA BF16 implementation with the standard CUDA BF16 implementation. Here are some key results:

Test Case 1: B=32, T=4096, C=4096, HEAD_SIZE=64

| Implementation | Forward Time | Backward Time |
| --- | --- | --- |
| CUDA BF16 | 32.80 ms | 148.05 ms |
| FLA BF16 | 50.17 ms | 162.42 ms |

Test Case 2: B=8, T=4096, C=4096, HEAD_SIZE=64

| Implementation | Forward Time | Backward Time |
| --- | --- | --- |
| CUDA BF16 | 9.69 ms | 46.41 ms |
| FLA BF16 | 13.06 ms | 40.79 ms |

Where:

- B: Batch size
- T: Token length
- C: Hidden layer dimension
- HEAD_SIZE: Size of attention heads
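
The benchmark scripts that produced the numbers above are not shown in this PR description. As a rough illustration only, here is a minimal timing sketch using CUDA events; the import path `fla.ops.rwkv6.chunk_rwkv6`, the head-first `(B, H, T, HEAD_SIZE)` layout, and the `(q, k, v, w, u)` signature are assumptions about the library's API and should be checked against your installed version:

```python
# Hedged benchmark sketch (not the PR's actual script): times chunk_rwkv6
# forward and forward+backward in bf16 on CUDA.
import torch
from fla.ops.rwkv6 import chunk_rwkv6  # assumed import path

B, T, C, HEAD_SIZE = 8, 4096, 4096, 64
H = C // HEAD_SIZE  # number of heads

def bench(fn, warmup=3, iters=10):
    """Average wall time of fn() in ms, measured with CUDA events."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

shape = (B, H, T, HEAD_SIZE)  # head-first layout: an assumption
q, k, v = (torch.randn(shape, device='cuda', dtype=torch.bfloat16,
                       requires_grad=True) for _ in range(3))
# negative log-space decay and per-head bonus: plausible inputs, not
# necessarily the distributions used in the PR's tests
w = torch.randn(shape, device='cuda',
                dtype=torch.bfloat16).sigmoid().log().requires_grad_()
u = torch.randn(H, HEAD_SIZE, device='cuda', dtype=torch.bfloat16,
                requires_grad=True)

def fwd():
    return chunk_rwkv6(q, k, v, w, u)

def fwd_bwd():
    o = chunk_rwkv6(q, k, v, w, u)
    o = o[0] if isinstance(o, tuple) else o  # some versions also return state
    o.sum().backward()

print(f"forward:          {bench(fwd):.2f} ms")
print(f"forward+backward: {bench(fwd_bwd):.2f} ms")
```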

Accuracy

We've measured the error ratios of various components against the FP32 CUDA implementation. Our ChunkRWKV6 FLA implementation achieves error levels consistent with the CUDA implementations:

```
y:  0.0020138283862787135
gr: 0.00250389610197927
gk: 0.002499128980485113
gv: 0.0028262425242107
gw: 0.0027358097395330894
gu: 0.001821853127644057
```
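
The comparison scripts are not reproduced here; below is a minimal sketch of how such relative-error ratios can be computed, assuming a float32 CUDA reference `ref` and a test output `out` (the exact metric used in the attached scripts may differ):

```python
import torch

def err_ratio(out: torch.Tensor, ref: torch.Tensor) -> float:
    """Ratio of the error norm to the reference norm, computed in fp32."""
    out, ref = out.float(), ref.float()
    return ((out - ref).norm() / ref.norm()).item()

# usage (hypothetical tensor names):
# err_ratio(y_fla_bf16, y_cuda_fp32)
```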

@uniartisan
Contributor Author

Please try to squash merge :)

@yzhangcs
Member

@uniartisan Hello, many thanks for these great contributions!
I will make some checks soon.
However, could you restrict the revisions to the RWKV6 chunk only? You've defined many decorators for other purposes that are unrelated to this PR's title. I think it would be better to create a separate PR for those changes.
Additionally, please note that there are some formatting errors that do not conform to PEP 8.

@yzhangcs
Member

Also, it is not recommended to strip the trailing spaces at the end of lines in the README file, as they are sometimes used as line breaks.

@uniartisan
Contributor Author

uniartisan commented Aug 14, 2024

Your suggestion makes a lot of sense. Some of these changes were introduced by the editor. I'll first limit the changes to ChunkRWKV6 and fix the tests.

@uniartisan force-pushed the enhance branch 5 times, most recently from fb728e7 to b99a7a1 (August 14, 2024 05:46)
@uniartisan
Contributor Author

checkrwkv6.tar.gz
Here is the code that compares CUDA with FLA.

@uniartisan
Contributor Author

Also, this pull request fixes #29.
The problem was introduced by bfloat16 when calculating dq and dk. By converting to float32 when necessary, using TF32 as much as possible, and reordering the group sequence, this pull request speeds things up and achieves the same accuracy as the CUDA implementation (which is pure fp32 internally).
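
As a hypothetical illustration of the two numerical tactics described above (not the PR's actual kernel code): upcast the bf16 operands to float32 before the gradient reductions, and let the remaining fp32 matmuls use TF32 tensor cores. The helper name `dq_dk_fp32` below is made up, and the math covers plain attention-style scores S = q kᵀ only, omitting RWKV6's decay and bonus terms:

```python
import torch

# Allow matmuls that run in float32 to use TF32 tensor cores for speed.
torch.backends.cuda.matmul.allow_tf32 = True

def dq_dk_fp32(do, q, k, v):
    """Accumulate dq/dk in float32 so the small bf16 mantissa does not
    dominate the gradient error, then cast back to the input dtype."""
    do32, q32, k32, v32 = (t.float() for t in (do, q, k, v))
    ds = do32 @ v32.transpose(-1, -2)    # dL/dS, computed in fp32
    dq = ds @ k32                        # dL/dq = dS @ k
    dk = ds.transpose(-1, -2) @ q32      # dL/dk = dS^T @ q
    return dq.to(q.dtype), dk.to(k.dtype)
```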

@yzhangcs
Member

@uniartisan Hi, I just made some reviews, could you have a check?

@uniartisan
Contributor Author

> @uniartisan Hi, I just made some reviews, could you have a check?

Hi, I can't see any comments. Could you tell me where I should check?

@yzhangcs
Member

@uniartisan Can you see the messages in your notification box?

@uniartisan
Contributor Author

[screenshot]
Could you give me a review like this? https://github.com/sustcsonglin/flash-linear-attention/pull/44/files/4a3e2bb1d699c7e41ead7adc2f2403fb3e79ceb6

I can't see your messages :(

@yzhangcs
Member

@uniartisan sure, sorry for my late reply

@yzhangcs
Member

@uniartisan Can you see my updated comments between the lines?

@uniartisan
Contributor Author

> @uniartisan Can you see my updated comments between the lines?

Sorry, I don't know what's going on. I still cannot see your review comments. Maybe you can directly post them here. 😎

@yzhangcs self-requested a review on August 25, 2024 17:22
@uniartisan force-pushed the enhance branch 2 times, most recently from 9926634 to 49a8951 (August 26, 2024 06:05)
@uniartisan
Contributor Author

@yzhangcs Hello,
I hope this finds you well. I have synchronized all the latest changes with your project. Given your expertise and valuable insights, could you kindly take some time to review these updates at your earliest convenience?
Your feedback is crucial to ensuring we're on the right track, and I greatly appreciate your assistance in this matter. :)

@yzhangcs
Member

@uniartisan Thank you for the update. I'm running your code locally, as there is no CI with GPUs. Will sync with you soon.

@yzhangcs
Member

@uniartisan Hi, can you grant me access to this branch so that I can make some updates?

@uniartisan
Contributor Author

> Hi, can you grant me access to this branch so that I can make some updates?

Of course!!! Sorry for my late reply. I will try it :)

@yzhangcs
Member

@uniartisan Hi, closing this PR as the new features are too tightly coupled. @sustcsonglin just pushed some new commits resolving the RWKV6 precision problems; check those out for more details. You can create new PRs if something could still be improved.

Again, thank you for your contributions and hard work!

@yzhangcs closed this on Sep 23, 2024
@uniartisan deleted the enhance branch on September 27, 2024 08:06