Allow custom softmax in memory efficient attention #530
Conversation
Hi @comaniac, and thanks a lot for putting up this pull request!
This change looks great - just a few things we need to modify: notably, it would be great to support any scale.
Regarding the CUTLASS kernel, it would indeed require modifying the CUDA code. I can guide you through this if you already know C++ and are interested :)
@danthe3rd thanks for the review. Yes, I'm familiar with C++ (but not CUDA...) and could help support the CUTLASS one if the scaling happens outside of its CUDA kernel.
C++ knowledge is enough. CUDA is basically C++, and for this modification you just need to pass another parameter all the way down to this function and replace this value: xformers/xformers/components/attention/csrc/cuda/mem_eff_attention/kernel_forward.h (line 619 at fd21b40).
But I can give more guidance once we are done with this PR :)
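For context, the semantics that kernel change needs to preserve fit in a few lines of plain PyTorch (a reference sketch only, not the xformers kernel; the function name and the optional `scale` argument are illustrative):

```python
import math
import torch

def attention_reference(q, k, v, scale=None):
    # Plain PyTorch reference, not the memory-efficient kernel.
    # With scale=None we fall back to the usual 1/sqrt(head_dim) factor,
    # which is effectively what the kernel applies by default today.
    if scale is None:
        scale = 1.0 / math.sqrt(q.shape[-1])
    attn = (q @ k.transpose(-2, -1)) * scale
    return attn.softmax(dim=-1) @ v
```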
Cool, thanks for pointing that out. I'll make an update later today.
Hi @danthe3rd, I've addressed your comments and also implemented the CUTLASS op (so that I could test locally, as I don't have an A100 on hand). Here are two issues I'm facing:
Any hints/suggestions would be appreciated. Thanks.
I think I know what might happen. When you modify
Thanks! That's exactly the reason... now it works.
Yes, that makes sense :)
I've tested locally that the CUTLASS ops work with a custom scale. I'll let CI run for the Triton kernels to make sure I didn't break anything. Meanwhile, this PR should be ready for review.
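A minimal local parity check along those lines could look like the snippet below (hypothetical code, not the PR's test suite; the exact `memory_efficient_attention` signature with the new `scale` keyword is assumed):

```python
import torch
import xformers.ops as xops

# Queries, keys, values in (batch, seq_len, head_dim) layout.
q, k, v = (torch.randn(2, 197, 64, device="cuda", dtype=torch.float16) for _ in range(3))

# Default path: the op applies 1/sqrt(K) internally.
out_default = xops.memory_efficient_attention(q, k, v)

# Custom-scale path: pre-scale the queries and pass scale=1.0 so the kernel
# does not scale them again; both paths should produce (nearly) the same output.
out_custom = xops.memory_efficient_attention(q * q.shape[-1] ** -0.5, k, v, scale=1.0)

torch.testing.assert_close(out_default, out_custom, atol=2e-3, rtol=2e-3)
```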
This looks good! Great work!
Just a few more nits before we can get this merged.
@danthe3rd all comments were addressed. PTAL.
This looks great, thanks!
I'll wait for the CI to become green, and test performance on A100 before merging just in case.
Looks like some tests are failing - you will need to modify
Yeah, I found that... fixed.
Thanks a lot @comaniac!
I ran some performance tests and it does not seem to affect performance (FW is very slightly faster, BW very slightly slower).
A100 fw
[------------------ attention (attn_bias=<class 'NoneType'>) -----------------]
| pr530_9b93469d | main | eager
1 threads: --------------------------------------------------------------------
f16 B=384, M=197, H=1, K=88 | 123.6 | 124.3 | 852.2
f16 B=384, M=197, H=1, K=80 | 115.2 | 115.8 | 740.9
f16 B=384, M=197, H=1, K=64 | 87.2 | 87.9 | 679.3
f16 B=1024, M=197, H=1, K=88 | 315.2 | 315.7 | 2219.5
f16 B=1024, M=197, H=1, K=80 | 293.5 | 294.9 | 1924.5
f16 B=1024, M=197, H=1, K=64 | 209.3 | 211.7 | 1768.7
f16 B=512, M=197, H=1, K=80 | 152.4 | 153.0 | 978.6
f16 B=32, M=197, H=16, K=80 | 153.2 | 153.9 | 1064.8
f16 B=32, M=197, H=16, K=64 | 112.0 | 113.1 | 981.8
f16 B=32, M=197, H=16, K=128 | 165.1 | 167.0 | 1719.5
f16 B=256, M=197, H=1, K=88 | 86.7 | 86.9 | 576.5
f16 B=16, M=197, H=16, K=88 | 87.3 | 87.6 | 629.6
f16 B=16, M=197, H=16, K=64 | 59.1 | 59.7 | 507.4
f16 B=16, M=197, H=16, K=128 | 88.1 | 88.6 | 879.3
f16 B=1, M=4096, H=160, K=128 | 15071.3 | 15053.1 | 20538.5
f16 B=2, M=4096, H=160, K=128 | 30077.7 | 30037.7 | 41703.8
f16 B=1, M=8192, H=160, K=128 | 60176.1 | 60096.2 |
f16 B=2, M=8192, H=160, K=128 | 120233.4 | 120091.0 |
f16 B=1024, M=82, H=8, K=64 | 447.2 | 450.9 | 1789.4
f16 B=150, M=256, H=16, K=64 | 503.9 | 511.2 | 1990.6
f16 B=64, M=256, H=12, K=64 | 170.3 | 172.0 | 671.8
f16 B=1, M=4096, H=16, K=40 | 873.4 | 882.4 | 1879.5
f16 B=1, M=16384, H=16, K=40 | 12397.6 | 12568.5 | 29433.2
f16 B=256, M=4096, H=16, K=64 | 183892.7 | 187014.7 |
f16 B=16, M=128, H=16, K=16 | 28.6 | 27.9 | 127.0
f16 B=16, M=128, H=16, K=32 | 28.3 | 27.8 | 128.2
f16 B=16, M=128, H=16, K=64 | 28.8 | 28.2 | 127.3
f16 B=16, M=128, H=16, K=128 | 38.9 | 39.1 | 147.5
f16 B=16, M=512, H=16, K=16 | 174.7 | 176.9 | 519.5
f16 B=16, M=512, H=16, K=32 | 180.8 | 182.6 | 568.3
f16 B=16, M=512, H=16, K=64 | 208.2 | 210.7 | 675.8
f16 B=16, M=512, H=16, K=128 | 360.4 | 362.2 | 860.2
f16 B=16, M=1024, H=16, K=16 | 666.0 | 674.1 | 1820.0
f16 B=16, M=1024, H=16, K=32 | 672.2 | 680.5 | 1913.4
f16 B=16, M=1024, H=16, K=64 | 768.9 | 780.0 | 2159.2
f16 B=16, M=1024, H=16, K=128 | 1348.2 | 1352.6 | 2592.8
f16 B=64, M=128, H=16, K=16 | 52.4 | 53.0 | 204.2
f16 B=64, M=128, H=16, K=32 | 58.1 | 58.3 | 250.9
f16 B=64, M=128, H=16, K=64 | 73.4 | 73.6 | 349.7
f16 B=64, M=128, H=16, K=128 | 130.6 | 131.0 | 538.1
f16 B=64, M=512, H=16, K=16 | 679.7 | 687.9 | 1883.3
f16 B=64, M=512, H=16, K=32 | 688.1 | 696.1 | 2071.7
f16 B=64, M=512, H=16, K=64 | 791.1 | 802.8 | 2472.3
f16 B=64, M=512, H=16, K=128 | 1415.9 | 1412.6 | 3258.6
f16 B=64, M=1024, H=16, K=16 | 2604.6 | 2637.9 | 7128.6
f16 B=64, M=1024, H=16, K=32 | 2625.8 | 2658.1 | 7520.6
f16 B=64, M=1024, H=16, K=64 | 3000.2 | 3046.3 | 8521.2
f16 B=64, M=1024, H=16, K=128 | 5358.9 | 5401.8 | 10204.7
Times are in microseconds (us).
A100 bw
[------------- attention backward (attn_bias=<class 'NoneType'>) -------------]
| pr530_9b93469d | main | vanilla
1 threads: --------------------------------------------------------------------
f16 B=384, M=197, H=1, K=88 | 609.7 | 602.9 | 2256.9
f16 B=384, M=197, H=1, K=80 | 577.2 | 571.9 | 1915.9
f16 B=384, M=197, H=1, K=64 | 386.9 | 385.4 | 1807.0
f16 B=1024, M=197, H=1, K=88 | 1548.1 | 1535.4 | 5942.5
f16 B=1024, M=197, H=1, K=80 | 1465.8 | 1456.2 | 5021.6
f16 B=1024, M=197, H=1, K=64 | 862.0 | 861.3 | 4737.1
f16 B=512, M=197, H=1, K=80 | 734.1 | 726.8 | 2539.7
f16 B=32, M=197, H=16, K=80 | 723.5 | 718.5 | 2578.6
f16 B=32, M=197, H=16, K=64 | 457.7 | 457.2 | 2432.8
f16 B=32, M=197, H=16, K=128 | 861.5 | 846.4 | 4497.1
f16 B=256, M=197, H=1, K=88 | 448.8 | 440.4 | 1526.1
f16 B=16, M=197, H=16, K=88 | 440.1 | 435.1 | 1538.6
f16 B=16, M=197, H=16, K=64 | 232.0 | 233.1 | 1245.4
f16 B=16, M=197, H=16, K=128 | 490.0 | 485.6 | 2266.9
f16 B=1, M=4096, H=160, K=128 | 63475.5 | 63526.0 | 46356.4
f16 B=2, M=4096, H=160, K=128 | 100184.6 | 100160.4 |
f16 B=1, M=8192, H=160, K=128 | 251758.1 | 251911.2 |
f16 B=2, M=8192, H=160, K=128 | 394557.1 | 394659.7 |
f16 B=1024, M=82, H=8, K=64 | 1866.4 | 1855.6 | 3826.4
f16 B=150, M=256, H=16, K=64 | 2105.3 | 2102.4 | 4559.9
f16 B=64, M=256, H=12, K=64 | 727.6 | 729.2 | 1499.7
f16 B=1, M=4096, H=16, K=40 | 23539.1 | 23522.2 | 4240.3
f16 B=1, M=16384, H=16, K=40 | 435278.5 | 437517.8 |
f16 B=256, M=4096, H=16, K=64 | 603767.9 | 602664.0 |
f16 B=16, M=128, H=16, K=16 | 146.0 | 141.7 | 492.7
f16 B=16, M=128, H=16, K=32 | 147.6 | 238.3 | 397.8
f16 B=16, M=128, H=16, K=64 | 142.3 | 138.9 | 303.9
f16 B=16, M=128, H=16, K=128 | 178.0 | 178.0 | 301.2
f16 B=16, M=512, H=16, K=16 | 552.9 | 553.3 | 1204.7
f16 B=16, M=512, H=16, K=32 | 651.5 | 650.8 | 1308.3
f16 B=16, M=512, H=16, K=64 | 849.4 | 849.4 | 1545.2
f16 B=16, M=512, H=16, K=128 | 1763.6 | 1762.4 | 1983.9
f16 B=16, M=1024, H=16, K=16 | 2229.1 | 2224.9 | 4262.5
f16 B=16, M=1024, H=16, K=32 | 2438.5 | 2438.5 | 4493.4
f16 B=16, M=1024, H=16, K=64 | 3031.7 | 3025.4 | 5001.1
f16 B=16, M=1024, H=16, K=128 | 6403.9 | 6398.4 | 5961.8
f16 B=64, M=128, H=16, K=16 | 161.7 | 161.6 | 439.2
f16 B=64, M=128, H=16, K=32 | 206.4 | 206.2 | 545.0
f16 B=64, M=128, H=16, K=64 | 327.3 | 326.4 | 766.2
f16 B=64, M=128, H=16, K=128 | 614.2 | 614.8 | 1231.2
f16 B=64, M=512, H=16, K=16 | 1975.7 | 1974.2 | 4487.7
f16 B=64, M=512, H=16, K=32 | 2355.5 | 2333.9 | 4979.9
f16 B=64, M=512, H=16, K=64 | 3075.9 | 3076.1 | 5888.7
f16 B=64, M=512, H=16, K=128 | 6148.0 | 6148.3 | 7706.6
f16 B=64, M=1024, H=16, K=16 | 7849.9 | 7839.7 | 16909.4
f16 B=64, M=1024, H=16, K=32 | 8856.1 | 8797.2 | 17904.8
f16 B=64, M=1024, H=16, K=64 | 11054.6 | 11058.0 | 19959.4
f16 B=64, M=1024, H=16, K=128 | 21944.5 | 21930.4 | 23716.0
Times are in microseconds (us).
What does this PR do?
Implement #522.
- Add `has_custom_scale` to memory efficient attention `forward`. When it is true, we assume the query state weights are scaled in advance so they won't be scaled again in the kernels. This is required especially for T5 models.
- Add `SUPPORTS_CUSTOM_SCALE` to indicate whether a memory efficient op supports a custom scale. In this PR we only enable `MemoryEfficientFlashAttentionOp` to align the API change and experiments. If everything goes well, I'll support `MemoryEfficientAttentionOp` in a follow-up PR (I'm not sure if the CUTLASS one can be supported without first changing the CUTLASS kernel. It'd be great if someone could help confirm).

Updated based on review comments:
- `has_custom_scale: bool = False` in `AttentionOpDispatch`.
- `SUPPORTS_CUSTOM_SCALE: bool = False` in `AttentionOpBase`.
- `scale: Optional[float] = None` in `AttentionOpDispatch.from_argument` and `memory_efficient_attention`. When `None`, the default scale value (`1.0 / q.shape[-1] ** 0.5`) will be used (see the sketch after this list).
- Custom scale is also supported in `MemoryEfficientAttentionOp`.
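Below is a simplified sketch of how the capability flag and the optional scale described above are meant to interact (illustrative classes and helper, not the actual xformers implementation):

```python
from typing import Optional
import torch

class AttentionOpBase:
    # Ops that cannot honor a user-provided scale keep the default False.
    SUPPORTS_CUSTOM_SCALE: bool = False

class FlashAttentionLikeOp(AttentionOpBase):
    SUPPORTS_CUSTOM_SCALE = True

def resolve_scale(op: type, q: torch.Tensor, scale: Optional[float]) -> float:
    if scale is not None and not op.SUPPORTS_CUSTOM_SCALE:
        raise ValueError(f"{op.__name__} does not support a custom scale")
    # scale=None falls back to the default 1.0 / q.shape[-1] ** 0.5.
    return scale if scale is not None else 1.0 / q.shape[-1] ** 0.5

q = torch.randn(2, 197, 64)
assert resolve_scale(FlashAttentionLikeOp, q, None) == 1.0 / 64 ** 0.5
# T5-style models pre-scale (or never scale) the queries, so they pass scale=1.0.
assert resolve_scale(FlashAttentionLikeOp, q, 1.0) == 1.0
```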
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.