MmaFromSmem[A100]: Accept transposed operand A #540
Conversation
Codecov Report
Base: 89.79% // Head: 89.79% // No change to project coverage 👍

Additional details and impacted files:

@@            Coverage Diff            @@
##  gh/danthe3rd/57/base    #540  +/-  ##
=========================================
  Coverage      89.79%    89.79%
=========================================
  Files             80        80
  Lines           4839      4839
=========================================
  Hits            4345      4345
  Misses           494       494
This is a great speed-up, thanks a lot Daniel!
I just have a high-level question about what was changed in one file vs. the reference implementation.
Also, I suppose our tests stress-test all the configurations needed to validate that the new iterator works well for dimension sizes that are not nice multiples of the tile sizes?
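For context on that last point: sizes that are not multiples of the tile shape (e.g. K=88 in the benchmarks below) are typically handled with predicated accesses. The following is a minimal, hypothetical sketch of that idea, not code taken from this PR; `load_guarded` is an illustrative helper, not an existing function in the codebase.

```cpp
#include <cuda_fp16.h>

// Hypothetical illustration: when K is not a multiple of the tile width,
// each column access is guarded by a bounds predicate so the iterator
// stays correct for sizes like K=88.
__device__ __half load_guarded(const __half* row_ptr, int col, int K) {
  // Out-of-range columns contribute zero instead of reading out of bounds.
  return (col < K) ? row_ptr[col] : __float2half(0.0f);
}
```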
@@ -0,0 +1,241 @@
#pragma once
Can you describe briefly what was changed in the code compared to the baseline implementation?
**SUMMARY**

Load `tmp.transpose()` directly from `tmp` in shared memory (transpose as we load). We no longer need to store both `tmp` and `tmp.T` in shared memory. Because this uses less shared memory, we can fit bigger block sizes: going from 64x128 to 128x128 gives a ~15% perf improvement (for K > 64).

**PERF TEST (A100)**

<details>
<summary>BW A100 (f16)</summary>

```
[------------------------------------ attention backward (attn_bias=<class 'NoneType'>) -------------------------------------]
                              | 57_tmpT_b516aec4[cutlass] | flash[flshatt] | vanilla | 56_base_02bf6b4e[cutlass]
1 threads: -------------------------------------------------------------------------------------------------------------------
f16 B=384, M=197, H=1, K=88   | 578.3    |          | 2263.9  | 609.7
f16 B=384, M=197, H=1, K=80   | 547.6    |          | 1921.9  | 572.6
f16 B=384, M=197, H=1, K=64   | 365.4    | 232.6    | 1808.2  | 386.0
f16 B=1024, M=197, H=1, K=88  | 1456.3   |          | 5964.7  | 1539.8
f16 B=1024, M=197, H=1, K=80  | 1382.9   |          | 5037.8  | 1464.9
f16 B=1024, M=197, H=1, K=64  | 822.1    | 576.4    | 4732.2  | 862.5
f16 B=512, M=197, H=1, K=80   | 695.4    |          | 2543.6  | 730.0
f16 B=32, M=197, H=16, K=80   | 691.2    |          | 2567.7  | 716.3
f16 B=32, M=197, H=16, K=64   | 427.5    | 296.1    | 2428.1  | 456.6
f16 B=32, M=197, H=16, K=128  | 858.1    | 682.6    | 4488.4  | 853.5
f16 B=256, M=197, H=1, K=88   | 422.0    |          | 1528.5  | 442.7
f16 B=16, M=197, H=16, K=88   | 420.4    |          | 1543.0  | 437.3
f16 B=16, M=197, H=16, K=64   | 217.5    | 165.2    | 1243.5  | 232.6
f16 B=16, M=197, H=16, K=128  | 479.8    | 385.5    | 2263.8  | 489.2
f16 B=1, M=4096, H=160, K=128 | 51009.8  | 54670.3  | 45924.2 | 63431.9
f16 B=2, M=4096, H=160, K=128 | 84491.6  | 84261.8  |         | 100393.5
f16 B=1, M=8192, H=160, K=128 | 201456.7 | 215540.9 |         | 251825.4
f16 B=2, M=8192, H=160, K=128 | 329735.0 | 330316.3 |         | 395279.2
f16 B=1024, M=82, H=8, K=64   | 1764.0   | 1620.9   | 3822.6  | 1857.4
f16 B=150, M=256, H=16, K=64  | 2021.6   | 1626.3   | 4557.3  | 2103.9
f16 B=64, M=256, H=12, K=64   | 699.4    | 567.8    | 1498.4  | 730.6
f16 B=1, M=4096, H=16, K=40   | 22788.9  |          | 4195.6  | 23624.7
f16 B=1, M=16384, H=16, K=40  | 408280.7 |          |         | 436163.8
f16 B=256, M=4096, H=16, K=64 | 565651.1 | 439946.4 |         | 602642.6
f16 B=16, M=128, H=16, K=16   | 121.9    | 139.6    | 331.3   | 121.9
f16 B=16, M=128, H=16, K=32   | 121.4    | 139.2    | 331.5   | 122.5
f16 B=16, M=128, H=16, K=64   | 121.9    | 140.0    | 369.9   | 187.9
f16 B=16, M=128, H=16, K=128  | 186.7    | 170.3    | 332.9   | 177.8
f16 B=16, M=512, H=16, K=16   | 518.4    | 322.2    | 1204.6  | 556.4
f16 B=16, M=512, H=16, K=32   | 602.5    | 435.1    | 1306.5  | 652.2
f16 B=16, M=512, H=16, K=64   | 797.0    | 704.9    | 1547.1  | 850.2
f16 B=16, M=512, H=16, K=128  | 1544.8   | 1584.6   | 1985.3  | 1752.3
f16 B=16, M=1024, H=16, K=16  | 2049.1   | 1244.7   | 4261.7  | 2239.9
f16 B=16, M=1024, H=16, K=32  | 2229.0   | 1620.4   | 4492.3  | 2448.0
f16 B=16, M=1024, H=16, K=64  | 2817.6   | 2367.6   | 4998.2  | 3041.0
f16 B=16, M=1024, H=16, K=128 | 5433.3   | 5638.9   | 5958.5  | 6406.4
f16 B=64, M=128, H=16, K=16   | 158.2    | 145.5    | 439.7   | 161.9
f16 B=64, M=128, H=16, K=32   | 205.2    | 212.4    | 545.2   | 206.7
f16 B=64, M=128, H=16, K=64   | 314.6    | 311.5    | 767.7   | 326.0
f16 B=64, M=128, H=16, K=128  | 651.9    | 562.8    | 1227.5  | 613.3
f16 B=64, M=512, H=16, K=16   | 1872.3   | 1204.0   | 4488.6  | 1985.3
f16 B=64, M=512, H=16, K=32   | 2185.4   | 1543.8   | 4971.7  | 2340.3
f16 B=64, M=512, H=16, K=64   | 2940.4   | 2421.0   | 5885.5  | 3077.9
f16 B=64, M=512, H=16, K=128  | 5501.3   | 5446.7   | 7711.0  | 6153.0
f16 B=64, M=1024, H=16, K=16  | 7318.5   | 4711.4   | 16891.1 | 7890.2
f16 B=64, M=1024, H=16, K=32  | 8151.5   | 5697.1   | 17885.4 | 8849.9
f16 B=64, M=1024, H=16, K=64  | 10477.7  | 8155.9   | 19951.2 | 11059.9
f16 B=64, M=1024, H=16, K=128 | 19178.4  | 19198.4  | 23794.0 | 21939.1

Times are in microseconds (us).
```

</details>
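To make the "transpose as we load" idea concrete, here is a minimal CUDA sketch. It is heavily simplified and hypothetical: the actual change implements this as a CUTLASS warp-level iterator, and the kernel name and 128x128 tile size here are illustrative assumptions, not the PR's exact configuration.

```cpp
#include <cuda_fp16.h>

constexpr int kM = 128;  // illustrative tile sizes, not the PR's exact config
constexpr int kN = 128;

__global__ void consume_tmp_transposed(const __half* gmem_tmp, __half* out) {
  // Only ONE copy of the tile lives in shared memory. Previously a second
  // buffer held a materialized tmp.T; dropping it frees the space that
  // allows growing the block size (e.g. 64x128 -> 128x128).
  __shared__ __half smem_tmp[kM][kN];

  // Cooperative copy of tmp (row-major) into shared memory.
  for (int i = threadIdx.x; i < kM * kN; i += blockDim.x) {
    smem_tmp[i / kN][i % kN] = gmem_tmp[i];
  }
  __syncthreads();

  // "Transpose as we load": tmp.T[r][c] == tmp[c][r], so the transposed
  // operand is read with swapped indices instead of from a second,
  // transposed shared-memory buffer.
  for (int i = threadIdx.x; i < kN * kM; i += blockDim.x) {
    int r = i / kM;  // row in tmp.T
    int c = i % kM;  // col in tmp.T
    out[r * kM + c] = smem_tmp[c][r];
  }
}
```

One caveat the sketch glosses over: the naive column-strided reads above would cause shared-memory bank conflicts; a production iterator would arrange the layout (e.g. via swizzling) so that a warp's lanes hit different banks.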