-
Notifications
You must be signed in to change notification settings - Fork 617
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
chunk3: Add custom operator to avoid torch.cat in BW #458
Conversation
[ghstack-poisoned]
ghstack-source-id: d4449a3beb49285862c60f631849ca5e272f7578 Pull Request resolved: #458
[ghstack-poisoned]
ghstack-source-id: 49d85372692924617f2afc9d79aaa46a44325115 Pull Request resolved: #458
[ghstack-poisoned]
ghstack-source-id: b7304c42fbede4cf920541aa50a7554bdfbac3b5 Pull Request resolved: #458
Codecov ReportBase: 91.50% // Head: 91.51% // Increases project coverage by
Additional details and impacted files@@ Coverage Diff @@
## gh/danthe3rd/48/base #458 +/- ##
========================================================
+ Coverage 91.50% 91.51% +0.01%
========================================================
Files 75 75
Lines 4412 4429 +17
========================================================
+ Hits 4037 4053 +16
- Misses 375 376 +1
Flags with carried forward coverage won't be shown. Click here to find out more.
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. ☔ View full report at Codecov. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks pretty nice, thanks Daniel!
I think there is a missing condition in the backward to check that the gradients all come from the same storage.
There is also some opportunities to make the code even more generic and support an arbitrary number of input elements through torch.unbind
, but up to you.
Also, can you add tests to Chunk3
for various cases? It would be good to stress-test this with some basic cases on its own test to make sure we are not missing anything else.
attn_bias_type=[type(None), torch.Tensor, xformers.ops.LowerTriangularMask], | ||
dtype=[torch.half, torch.bfloat16, torch.float], | ||
attn_bias_type=[type(None)], | ||
dtype=[torch.half], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: looks like something is missing here? :-)
def T(t): | ||
return t.permute((0, 2, 1, 3)).reshape( | ||
[t.shape[0] * t.shape[2], t.shape[1], t.shape[3]] | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: I wonder if we should still keep the previous benchmark without the permute + contiguous in the hot path.
The improved benchmark you just added is very important as it's the main entry point for the users, but it might also hide some potential improvements to be done in the standard attention because now we are also measuring those extra overheads.
I would maybe this benchmark on top of the previous one, so that we have more numbers when measuring things together. But this is only a suggestion
BTW, if you generalize your function to support The backward implementation currently lives in here in PyTorch |
… avoid torch.cat in BW" **SUMMARY** Also: - updated benchmarks to reflect the real-world scenario where we chunk a packed qkv - for both fw/bw. - added coverage for chunking in tests **PERFORMANCE IMPACT** <details> <summary>A100 bw (new benchmarks)</summary> ``` [---------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------] | 48_chunk3_31735f9 | 45_bwpacked_e53c5f3 | vanilla | 47_bwpackedgrad_9bacdf6 1 threads: -------------------------------------------------------------------------------------------------------------- f16 B=384, M=197, H=1, K=88 | 560.7 | 663.9 | 2265.7 | 710.3 f32 B=384, M=197, H=1, K=88 | 2445.1 | 2540.3 | 1843.3 | 2611.0 f16 B=384, M=197, H=1, K=80 | 530.4 | 619.9 | 1922.8 | 663.0 f32 B=384, M=197, H=1, K=80 | 2326.1 | 2425.2 | 1788.7 | 2476.4 f16 B=384, M=197, H=1, K=64 | 391.7 | 462.2 | 1812.7 | 492.8 f32 B=384, M=197, H=1, K=64 | 1275.0 | 1379.4 | 1675.4 | 1388.4 f16 B=1024, M=197, H=1, K=88 | 1399.5 | 1666.2 | 5965.2 | 1775.5 f32 B=1024, M=197, H=1, K=88 | 6332.5 | 6618.1 | 4559.6 | 6740.5 f16 B=1024, M=197, H=1, K=80 | 1326.2 | 1543.9 | 5041.4 | 1652.3 f32 B=1024, M=197, H=1, K=80 | 6057.1 | 6301.3 | 4411.6 | 6433.6 f16 B=1024, M=197, H=1, K=64 | 876.9 | 1063.1 | 4749.3 | 1133.2 f32 B=1024, M=197, H=1, K=64 | 3360.2 | 3629.0 | 4118.8 | 3652.0 f16 B=512, M=197, H=1, K=80 | 669.0 | 786.4 | 2544.9 | 842.2 f32 B=512, M=197, H=1, K=80 | 3032.3 | 3127.8 | 2287.4 | 3229.8 f16 B=32, M=197, H=16, K=80 | 663.0 | 789.7 | 2569.0 | 837.8 f32 B=32, M=197, H=16, K=80 | 3005.5 | 3166.3 | 2354.1 | 3225.9 f16 B=32, M=197, H=16, K=64 | 459.9 | 553.4 | 2436.3 | 591.9 f32 B=32, M=197, H=16, K=64 | 1814.1 | 1962.5 | 2197.3 | 1962.1 f16 B=32, M=197, H=16, K=128 | 792.5 | 981.9 | 4505.9 | 1056.5 f32 B=32, M=197, H=16, K=128 | 3734.8 | 3995.7 | 2805.8 | 4021.5 f16 B=256, M=197, H=1, K=88 | 413.4 | 482.6 | 1529.5 | 515.5 f32 B=256, M=197, H=1, K=88 | 1741.9 | 1818.3 | 1208.6 | 1852.4 f16 B=16, M=197, H=16, K=88 | 410.3 | 482.9 | 1545.7 | 512.5 f32 B=16, M=197, H=16, K=88 | 1734.9 | 1832.1 | 1250.6 | 1849.4 f16 B=16, M=197, H=16, K=64 | 235.4 | 286.0 | 1247.1 | 305.3 f32 B=16, M=197, H=16, K=64 | 1077.1 | 1143.7 | 1125.9 | 1154.0 f16 B=16, M=197, H=16, K=128 | 455.4 | 554.1 | 2273.1 | 596.0 f32 B=16, M=197, H=16, K=128 | 2028.9 | 2164.5 | 1446.7 | 2175.0 f16 B=1, M=4096, H=160, K=128 | 62454.4 | 63474.5 | 45930.5 | 64052.7 f32 B=1, M=4096, H=160, K=128 | 239035.4 | 232672.1 | | 240073.9 f16 B=2, M=4096, H=160, K=128 | 98791.3 | 101006.4 | | 101942.0 f32 B=2, M=4096, H=160, K=128 | 375914.9 | 368050.6 | | 381280.4 f16 B=1, M=8192, H=160, K=128 | 248498.9 | 250066.9 | | 251500.4 f32 B=1, M=8192, H=160, K=128 | 945102.2 | 922549.3 | | 949256.4 f16 B=2, M=8192, H=160, K=128 | 389207.8 | 394486.6 | | 396190.4 f32 B=2, M=8192, H=160, K=128 | 1496334.3 | 1449974.3 | | 1502215.3 f16 B=1024, M=82, H=8, K=64 | 1872.4 | 2503.8 | 3819.8 | 2693.7 f32 B=1024, M=82, H=8, K=64 | 8734.3 | 9637.8 | 8732.9 | 9672.2 f16 B=150, M=256, H=16, K=64 | 2126.4 | 2713.4 | 4554.3 | 2880.8 f32 B=150, M=256, H=16, K=64 | 6214.3 | 7052.2 | 12943.2 | 7099.2 f16 B=64, M=256, H=12, K=64 | 741.2 | 930.1 | 1493.0 | 990.6 f32 B=64, M=256, H=12, K=64 | 2144.2 | 2408.5 | 4267.7 | 2433.8 f16 B=1, M=4096, H=16, K=40 | 24583.7 | 24224.8 | 4195.2 | 24500.2 f32 B=1, M=4096, H=16, K=40 | 72497.9 | 72070.8 | 17744.1 | 72393.0 f16 B=1, M=16384, H=16, K=40 | 451481.8 | 439027.7 | | 451499.9 f32 B=1, M=16384, H=16, K=40 | 1169509.1 | 1164880.1 | | 1169769.3 f16 B=256, M=4096, H=16, K=64 | 597391.6 | 625921.0 | | 610433.2 f16 B=16, M=128, H=16, K=16 | 93.1 | 126.7 | 241.2 | 132.3 f32 B=16, M=128, H=16, K=16 | 184.1 | 176.5 | 373.8 | 180.7 f16 B=16, M=128, H=16, K=32 | 127.9 | 126.3 | 241.4 | 106.7 f32 B=16, M=128, H=16, K=32 | 194.1 | 216.6 | 412.7 | 225.8 f16 B=16, M=128, H=16, K=64 | 131.4 | 126.8 | 239.8 | 134.5 f32 B=16, M=128, H=16, K=64 | 280.4 | 326.0 | 500.0 | 334.0 f16 B=16, M=128, H=16, K=128 | 175.6 | 236.1 | 298.8 | 261.1 f32 B=16, M=128, H=16, K=128 | 531.8 | 615.8 | 677.2 | 638.0 f16 B=16, M=512, H=16, K=16 | 558.2 | 595.0 | 1201.9 | 607.8 f32 B=16, M=512, H=16, K=16 | 2146.7 | 2169.9 | 4416.1 | 2200.6 f16 B=16, M=512, H=16, K=32 | 653.5 | 732.3 | 1305.1 | 748.5 f32 B=16, M=512, H=16, K=32 | 2296.3 | 2373.9 | 4641.3 | 2400.1 f16 B=16, M=512, H=16, K=64 | 848.8 | 996.9 | 1544.6 | 1022.5 f32 B=16, M=512, H=16, K=64 | 2954.0 | 3117.1 | 5124.7 | 3157.6 f16 B=16, M=512, H=16, K=128 | 1735.4 | 1961.1 | 1982.7 | 2056.9 f32 B=16, M=512, H=16, K=128 | 6218.7 | 6396.4 | 6094.0 | 6600.3 f16 B=16, M=1024, H=16, K=16 | 2236.4 | 2319.4 | 4279.0 | 2331.6 f32 B=16, M=1024, H=16, K=16 | 8379.2 | 8363.9 | 16643.9 | 8503.6 f16 B=16, M=1024, H=16, K=32 | 2430.8 | 2649.6 | 4496.8 | 2608.7 f32 B=16, M=1024, H=16, K=32 | 8864.7 | 8907.8 | 17291.0 | 9074.0 f16 B=16, M=1024, H=16, K=64 | 3007.2 | 3351.3 | 4995.5 | 3351.0 f32 B=16, M=1024, H=16, K=64 | 11355.4 | 11627.1 | 18707.5 | 11694.3 f16 B=16, M=1024, H=16, K=128 | 6296.2 | 6748.7 | 5943.5 | 6967.0 f32 B=16, M=1024, H=16, K=128 | 23425.3 | 23360.0 | 21520.6 | 24169.7 f16 B=64, M=128, H=16, K=16 | 165.5 | 195.9 | 440.3 | 211.5 f32 B=64, M=128, H=16, K=16 | 497.4 | 540.7 | 1270.8 | 550.3 f16 B=64, M=128, H=16, K=32 | 210.4 | 274.9 | 544.8 | 298.5 f32 B=64, M=128, H=16, K=32 | 604.4 | 696.6 | 1428.3 | 710.9 f16 B=64, M=128, H=16, K=64 | 330.4 | 452.3 | 766.0 | 498.1 f32 B=64, M=128, H=16, K=64 | 883.4 | 1060.4 | 1745.2 | 1082.2 f16 B=64, M=128, H=16, K=128 | 605.5 | 847.8 | 1223.6 | 933.9 f32 B=64, M=128, H=16, K=128 | 1847.4 | 2169.7 | 2388.8 | 2236.0 f16 B=64, M=512, H=16, K=16 | 2004.7 | 2120.0 | 4487.0 | 2179.4 f32 B=64, M=512, H=16, K=16 | 6655.4 | 6818.8 | 16993.8 | 6872.1 f16 B=64, M=512, H=16, K=32 | 2379.3 | 2593.1 | 4957.2 | 2704.0 f32 B=64, M=512, H=16, K=32 | 7349.4 | 7644.6 | 17852.2 | 7736.2 f16 B=64, M=512, H=16, K=64 | 3129.6 | 3616.6 | 5888.8 | 3786.2 f32 B=64, M=512, H=16, K=64 | 9432.5 | 10123.9 | 19770.6 | 10178.5 f16 B=64, M=512, H=16, K=128 | 6054.1 | 7019.9 | 7712.6 | 7350.2 f32 B=64, M=512, H=16, K=128 | 21565.6 | 22281.9 | 23653.0 | 23084.4 f16 B=64, M=1024, H=16, K=16 | 7929.4 | 8199.1 | 16876.3 | 8242.5 f32 B=64, M=1024, H=16, K=16 | 26135.2 | 26347.9 | 66351.1 | 26639.0 f16 B=64, M=1024, H=16, K=32 | 8876.8 | 9450.0 | 17869.4 | 9473.5 f32 B=64, M=1024, H=16, K=32 | 27685.3 | 28104.6 | 69105.9 | 28428.7 f16 B=64, M=1024, H=16, K=64 | 11198.7 | 12180.5 | 19932.3 | 12543.4 f32 B=64, M=1024, H=16, K=64 | 34978.2 | 36239.4 | 74813.7 | 36482.4 f16 B=64, M=1024, H=16, K=128 | 21618.9 | 23439.6 | 23741.1 | 24160.1 f32 B=64, M=1024, H=16, K=128 | 80785.3 | 81080.8 | 86003.6 | 84132.9 Times are in microseconds (us). ``` </details> <details> <summary>P100/V100 bw (new benchmarks)</summary> ``` [---------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------------------] | 48_chunk3_31735f94 | 45_bwpacked_e53c5f3a | vanilla | 47_bwpackedgrad_9bacdf65 1 threads: -------------------------------------------------------------------------------------------------------------------------------------- (Quadro_GP100) f16 B=384, M=197, H=1, K=88 | 6846.3 | 7583.8 | 3569.3 | 7599.5 f32 B=384, M=197, H=1, K=88 | 9883.1 | 10107.2 | 4312.8 | 10486.3 f16 B=384, M=197, H=1, K=80 | 6486.4 | 6997.7 | 3418.0 | 7037.3 f32 B=384, M=197, H=1, K=80 | 9330.3 | 9550.6 | 4094.7 | 9893.4 f16 B=384, M=197, H=1, K=64 | 3615.4 | 3930.4 | 2911.0 | 4074.2 f32 B=384, M=197, H=1, K=64 | 6281.4 | 6554.5 | 3431.9 | 6738.1 f16 B=1024, M=197, H=1, K=88 | 17226.8 | 18593.1 | 9733.2 | 18772.9 f32 B=1024, M=197, H=1, K=88 | 26593.3 | 27136.2 | 12033.8 | 28184.2 f16 B=1024, M=197, H=1, K=80 | 16330.1 | 17478.6 | 9270.2 | 17735.3 f32 B=1024, M=197, H=1, K=80 | 25208.9 | 25680.1 | 11224.5 | 26636.1 f16 B=1024, M=197, H=1, K=64 | 8889.1 | 9728.8 | 7646.1 | 10089.7 f32 B=1024, M=197, H=1, K=64 | 16914.7 | 17743.4 | 9383.8 | 18068.4 f16 B=512, M=197, H=1, K=80 | 8227.3 | 8878.4 | 4579.3 | 8953.6 f32 B=512, M=197, H=1, K=80 | 13078.7 | 13346.0 | 5486.4 | 13817.6 f16 B=32, M=197, H=16, K=80 | 8278.9 | 9002.9 | 4816.2 | 9025.6 f32 B=32, M=197, H=16, K=80 | 12913.8 | 13371.2 | 5777.7 | 13667.6 f16 B=32, M=197, H=16, K=64 | 4565.2 | 5000.0 | 4023.4 | 5146.3 f32 B=32, M=197, H=16, K=64 | 8824.0 | 9257.7 | 4797.2 | 9400.5 f16 B=32, M=197, H=16, K=128 | 9770.0 | 10849.7 | 5983.2 | 10932.0 f32 B=32, M=197, H=16, K=128 | 15715.2 | 16559.9 | 7513.6 | 16839.9 f16 B=256, M=197, H=1, K=88 | 5011.2 | 5363.8 | 2444.9 | 5426.0 f32 B=256, M=197, H=1, K=88 | 6918.7 | 7040.8 | 2867.8 | 7303.2 f16 B=16, M=197, H=16, K=88 | 4963.8 | 5343.9 | 2545.2 | 5398.9 f32 B=16, M=197, H=16, K=88 | 6727.9 | 6981.7 | 3040.3 | 7121.2 f16 B=16, M=197, H=16, K=64 | 2586.5 | 2777.1 | 2025.5 | 2905.6 f32 B=16, M=197, H=16, K=64 | 4404.3 | 4607.2 | 2431.1 | 4691.8 f16 B=16, M=197, H=16, K=128 | 5643.2 | 6194.1 | 3016.1 | 6216.3 f32 B=16, M=197, H=16, K=128 | 7887.1 | 8308.3 | 3676.6 | 8456.2 f16 B=1, M=4096, H=160, K=128 | 1087008.7 | 1115355.5 | | 1091596.8 f32 B=1, M=4096, H=160, K=128 | 1220066.8 | 1223422.8 | | 1227912.2 f16 B=2, M=4096, H=160, K=128 | 1734244.4 | 1794068.7 | | 1756266.7 f32 B=2, M=4096, H=160, K=128 | 2437675.5 | 2445780.4 | | 2451957.5 f16 B=1, M=8192, H=160, K=128 | 4367110.4 | 4466170.9 | | 4383747.4 f32 B=1, M=8192, H=160, K=128 | 4865732.9 | 4865708.9 | | 4887066.5 f16 B=2, M=8192, H=160, K=128 | 7002715.1 | 7146077.9 | | 7033922.8 f16 B=1024, M=82, H=8, K=64 | 23247.5 | 24929.5 | 18047.8 | 26928.2 f32 B=1024, M=82, H=8, K=64 | 46463.2 | 48705.6 | 22797.5 | 50736.3 f16 B=150, M=256, H=16, K=64 | 23467.9 | 25647.3 | 24569.2 | 26841.8 f32 B=150, M=256, H=16, K=64 | 36887.7 | 39698.0 | 32050.2 | 40389.0 f16 B=64, M=256, H=12, K=64 | 7723.7 | 8499.0 | 7702.1 | 8694.9 f32 B=64, M=256, H=12, K=64 | 11992.1 | 12819.9 | 9874.5 | 13107.9 f16 B=1, M=4096, H=16, K=40 | 142655.5 | 142899.7 | 28928.6 | 142922.7 f32 B=1, M=4096, H=16, K=40 | 142626.8 | 142685.3 | 37303.2 | 142541.0 f16 B=1, M=16384, H=16, K=40 | 2274095.0 | 2274882.0 | | 2275019.9 f32 B=1, M=16384, H=16, K=40 | 2284027.2 | 2279415.7 | | 2277761.9 f16 B=16, M=128, H=16, K=16 | 513.2 | 547.1 | 571.5 | 570.9 f32 B=16, M=128, H=16, K=16 | 667.4 | 704.3 | 693.1 | 728.0 f16 B=16, M=128, H=16, K=32 | 600.3 | 667.0 | 671.3 | 713.1 f32 B=16, M=128, H=16, K=32 | 823.9 | 888.9 | 823.5 | 937.3 f16 B=16, M=128, H=16, K=64 | 781.0 | 900.6 | 883.1 | 998.9 f32 B=16, M=128, H=16, K=64 | 1173.7 | 1293.8 | 1077.0 | 1393.4 f16 B=16, M=128, H=16, K=128 | 1649.2 | 1877.2 | 1323.2 | 2026.3 f32 B=16, M=128, H=16, K=128 | 2250.5 | 2473.0 | 1654.7 | 2636.6 f16 B=16, M=512, H=16, K=16 | 7709.3 | 7914.6 | 6945.1 | 7928.7 f32 B=16, M=512, H=16, K=16 | 9797.2 | 9950.5 | 8499.4 | 10029.3 f16 B=16, M=512, H=16, K=32 | 8956.9 | 9210.8 | 7517.1 | 9307.0 f32 B=16, M=512, H=16, K=32 | 11480.7 | 11710.9 | 9249.4 | 11884.4 f16 B=16, M=512, H=16, K=64 | 11324.0 | 11829.1 | 8849.5 | 12001.8 f32 B=16, M=512, H=16, K=64 | 15744.1 | 16258.0 | 10954.6 | 16481.1 f16 B=16, M=512, H=16, K=128 | 25320.2 | 26584.0 | 12412.3 | 26725.0 f32 B=16, M=512, H=16, K=128 | 31187.1 | 32290.3 | 15167.5 | 32818.4 f16 B=16, M=1024, H=16, K=16 | 31484.2 | 31601.4 | 26434.6 | 31894.6 f32 B=16, M=1024, H=16, K=16 | 38754.1 | 38900.1 | 32320.0 | 39203.9 f16 B=16, M=1024, H=16, K=32 | 36000.2 | 36672.6 | 28341.4 | 36579.5 f32 B=16, M=1024, H=16, K=32 | 45070.7 | 45262.3 | 34914.2 | 45774.5 f16 B=16, M=1024, H=16, K=64 | 45324.9 | 46540.4 | 32089.9 | 46784.2 f32 B=16, M=1024, H=16, K=64 | 61320.3 | 62411.1 | 39565.0 | 63217.0 f16 B=16, M=1024, H=16, K=128 | 104342.9 | 108469.4 | 43221.9 | 105620.6 f32 B=16, M=1024, H=16, K=128 | 122688.4 | 125050.9 | 51205.7 | 126080.9 f16 B=64, M=128, H=16, K=16 | 1707.9 | 1824.9 | 2106.4 | 1923.2 f32 B=64, M=128, H=16, K=16 | 2487.4 | 2612.5 | 2565.1 | 2707.6 f16 B=64, M=128, H=16, K=32 | 2016.8 | 2254.4 | 2485.4 | 2412.3 f32 B=64, M=128, H=16, K=32 | 3135.8 | 3365.6 | 3063.2 | 3518.5 f16 B=64, M=128, H=16, K=64 | 2700.2 | 3167.0 | 3306.0 | 3478.4 f32 B=64, M=128, H=16, K=64 | 4435.1 | 4944.7 | 4227.6 | 5181.2 f16 B=64, M=128, H=16, K=128 | 5769.1 | 6858.2 | 5299.8 | 7356.1 f32 B=64, M=128, H=16, K=128 | 8577.9 | 9672.0 | 6916.3 | 10093.5 f16 B=64, M=512, H=16, K=16 | 25994.0 | 26782.0 | 27240.9 | 26662.2 f32 B=64, M=512, H=16, K=16 | 36864.9 | 37299.3 | 34159.3 | 37576.7 f16 B=64, M=512, H=16, K=32 | 30680.4 | 32113.8 | 30109.0 | 32419.7 f32 B=64, M=512, H=16, K=32 | 43638.5 | 44557.9 | 37358.5 | 45145.0 f16 B=64, M=512, H=16, K=64 | 39417.5 | 41666.5 | 36004.2 | 42374.9 f32 B=64, M=512, H=16, K=64 | 60049.2 | 63148.0 | 43412.6 | 63286.8 f16 B=64, M=512, H=16, K=128 | 88951.1 | 93087.0 | 51730.1 | 94861.6 f32 B=64, M=512, H=16, K=128 | 119728.7 | 124340.3 | 62413.7 | 126382.2 f16 B=64, M=1024, H=16, K=16 | 108368.3 | 111081.8 | 106479.7 | 108716.1 f32 B=64, M=1024, H=16, K=16 | 145612.0 | 147310.4 | | 147380.7 f16 B=64, M=1024, H=16, K=32 | 124296.1 | 127366.8 | 113905.0 | 126975.3 f32 B=64, M=1024, H=16, K=32 | 171082.3 | 172539.0 | | 173893.9 f16 B=64, M=1024, H=16, K=64 | 155116.3 | 160429.2 | 130759.4 | 161834.0 f32 B=64, M=1024, H=16, K=64 | 234356.0 | 239612.2 | | 239948.3 f16 B=64, M=1024, H=16, K=128 | 349728.3 | 360975.7 | 176158.7 | 371185.2 f32 B=64, M=1024, H=16, K=128 | 468810.0 | 476415.4 | | 481908.5 (Tesla_V100_SXM2_16GB) f16 B=384, M=197, H=1, K=88 | 1700.3 | 1840.0 | 1375.3 | 1930.9 f32 B=384, M=197, H=1, K=88 | 4456.4 | 4579.3 | 2235.5 | 4708.6 f16 B=384, M=197, H=1, K=80 | 1623.3 | 1719.9 | 1279.5 | 1806.9 f32 B=384, M=197, H=1, K=80 | 4031.2 | 4141.9 | 2149.8 | 4252.6 f16 B=384, M=197, H=1, K=64 | 1092.8 | 1187.0 | 1048.5 | 1237.6 f32 B=384, M=197, H=1, K=64 | 2717.5 | 2918.5 | 1738.5 | 2907.9 f16 B=1024, M=197, H=1, K=88 | 4428.7 | 4906.2 | 3723.7 | 5178.2 f32 B=1024, M=197, H=1, K=88 | 10947.5 | 11362.9 | 6052.5 | 11802.1 f16 B=1024, M=197, H=1, K=80 | 4237.1 | 4491.4 | 3331.7 | 4725.6 f32 B=1024, M=197, H=1, K=80 | 9842.6 | 10159.7 | 5682.4 | 10435.6 f16 B=1024, M=197, H=1, K=64 | 2679.2 | 2927.4 | 2674.4 | 3033.0 f32 B=1024, M=197, H=1, K=64 | 6597.6 | 7154.9 | 4489.7 | 7063.1 f16 B=512, M=197, H=1, K=80 | 2239.5 | 2366.5 | 1684.2 | 2472.0 f32 B=512, M=197, H=1, K=80 | 5362.4 | 5519.6 | 2857.9 | 5651.4 f16 B=32, M=197, H=16, K=80 | 2208.1 | 2380.0 | 1803.4 | 2439.4 f32 B=32, M=197, H=16, K=80 | 5503.6 | 5736.7 | 3017.5 | 5796.2 f16 B=32, M=197, H=16, K=64 | 1493.4 | 1620.6 | 1457.2 | 1678.6 f32 B=32, M=197, H=16, K=64 | 3672.6 | 3941.6 | 2415.0 | 3898.2 f16 B=32, M=197, H=16, K=128 | 2634.3 | 2888.0 | 2215.1 | 2991.5 f32 B=32, M=197, H=16, K=128 | 6811.5 | 7334.0 | 4049.3 | 7261.9 f16 B=256, M=197, H=1, K=88 | 1290.3 | 1382.0 | 944.8 | 1449.4 f32 B=256, M=197, H=1, K=88 | 2965.8 | 3043.2 | 1528.7 | 3137.7 f16 B=16, M=197, H=16, K=88 | 1267.3 | 1357.0 | 970.8 | 1395.5 f32 B=16, M=197, H=16, K=88 | 2879.9 | 3014.7 | 1626.5 | 3054.3 f16 B=16, M=197, H=16, K=64 | 737.3 | 799.8 | 771.3 | 836.9 f32 B=16, M=197, H=16, K=64 | 1879.2 | 2000.9 | 1282.5 | 1994.5 f16 B=16, M=197, H=16, K=128 | 1443.9 | 1570.7 | 1142.2 | 1628.8 f32 B=16, M=197, H=16, K=128 | 3480.5 | 3723.6 | 2027.2 | 3714.6 f16 B=1, M=4096, H=160, K=128 | 150006.2 | 151877.5 | | 152570.6 f32 B=1, M=4096, H=160, K=128 | 582870.9 | 583519.8 | | 585570.1 f16 B=2, M=4096, H=160, K=128 | 301231.4 | 304511.7 | | 305801.2 f32 B=2, M=4096, H=160, K=128 | 1174724.1 | 1172498.4 | | 1176814.0 f16 B=1, M=8192, H=160, K=128 | 597461.6 | 600463.4 | | 603066.6 f32 B=1, M=8192, H=160, K=128 | 2333657.8 | 2329212.1 | | 2339766.1 f16 B=2, M=8192, H=160, K=128 | 1196837.5 | 1206932.4 | | 1209012.2 f16 B=1024, M=82, H=8, K=64 | 8926.8 | 9723.4 | 5799.4 | 10084.2 f32 B=1024, M=82, H=8, K=64 | 15920.4 | 17434.4 | 11027.0 | 17492.8 f16 B=150, M=256, H=16, K=64 | 5524.2 | 6363.9 | 7557.9 | 6586.2 f32 B=150, M=256, H=16, K=64 | 17506.9 | 18843.5 | 16263.5 | 18988.6 f16 B=64, M=256, H=12, K=64 | 1800.6 | 2050.3 | 2383.4 | 2139.0 f32 B=64, M=256, H=12, K=64 | 5753.6 | 6196.3 | 4971.2 | 6200.0 f16 B=1, M=4096, H=16, K=40 | 47649.5 | 47836.0 | 8368.4 | 47973.6 f32 B=1, M=4096, H=16, K=40 | 111092.1 | 111027.3 | 19475.9 | 111257.8 f16 B=1, M=16384, H=16, K=40 | 765320.2 | 765686.9 | | 767337.2 f32 B=1, M=16384, H=16, K=40 | 1769169.0 | 1769675.1 | | 1769371.4 f16 B=16, M=128, H=16, K=16 | 178.9 | 196.8 | 445.9 | 188.3 f32 B=16, M=128, H=16, K=16 | 301.3 | 319.1 | 422.5 | 336.3 f16 B=16, M=128, H=16, K=32 | 174.1 | 174.2 | 394.0 | 179.5 f32 B=16, M=128, H=16, K=32 | 395.7 | 433.2 | 580.0 | 440.4 f16 B=16, M=128, H=16, K=64 | 205.0 | 253.5 | 460.6 | 270.9 f32 B=16, M=128, H=16, K=64 | 573.7 | 639.3 | 598.1 | 656.1 f16 B=16, M=128, H=16, K=128 | 399.5 | 484.3 | 515.2 | 521.8 f32 B=16, M=128, H=16, K=128 | 1126.3 | 1260.8 | 1008.1 | 1282.4 f16 B=16, M=512, H=16, K=16 | 1597.6 | 1627.2 | 1901.1 | 1662.1 f32 B=16, M=512, H=16, K=16 | 4458.5 | 4528.8 | 4232.0 | 4559.4 f16 B=16, M=512, H=16, K=32 | 1819.1 | 1868.7 | 2097.2 | 1945.5 f32 B=16, M=512, H=16, K=32 | 5604.2 | 5757.1 | 4566.4 | 5784.8 f16 B=16, M=512, H=16, K=64 | 2345.5 | 2495.6 | 2558.0 | 2573.2 f32 B=16, M=512, H=16, K=64 | 7778.3 | 8017.1 | 5488.2 | 8083.7 f16 B=16, M=512, H=16, K=128 | 4516.6 | 4821.0 | 3386.7 | 4968.2 f32 B=16, M=512, H=16, K=128 | 15412.7 | 15959.2 | 8865.9 | 16047.5 f16 B=16, M=1024, H=16, K=16 | 6195.9 | 6217.6 | 6995.3 | 6326.4 f32 B=16, M=1024, H=16, K=16 | 18136.2 | 18312.0 | 16088.2 | 18354.1 f16 B=16, M=1024, H=16, K=32 | 7072.8 | 7122.3 | 7406.9 | 7297.7 f32 B=16, M=1024, H=16, K=32 | 22108.2 | 22116.7 | 17112.5 | 22436.8 f16 B=16, M=1024, H=16, K=64 | 8868.0 | 9104.6 | 8627.1 | 9311.8 f32 B=16, M=1024, H=16, K=64 | 30710.5 | 31041.3 | 19860.8 | 31338.1 f16 B=16, M=1024, H=16, K=128 | 17091.8 | 17655.5 | 10548.3 | 18083.8 f32 B=16, M=1024, H=16, K=128 | 60317.8 | 61461.7 | 32919.2 | 61548.8 f16 B=64, M=128, H=16, K=16 | 413.6 | 453.8 | 635.5 | 480.6 f32 B=64, M=128, H=16, K=16 | 1033.8 | 1114.3 | 1238.9 | 1119.5 f16 B=64, M=128, H=16, K=32 | 505.7 | 587.9 | 813.6 | 630.1 f32 B=64, M=128, H=16, K=32 | 1423.0 | 1551.4 | 1533.4 | 1581.8 f16 B=64, M=128, H=16, K=64 | 743.3 | 916.8 | 1187.7 | 976.5 f32 B=64, M=128, H=16, K=64 | 2093.3 | 2384.6 | 2156.3 | 2405.4 f16 B=64, M=128, H=16, K=128 | 1408.2 | 1734.3 | 1918.7 | 1859.6 f32 B=64, M=128, H=16, K=128 | 4125.3 | 4671.4 | 3762.0 | 4717.0 f16 B=64, M=512, H=16, K=16 | 5531.2 | 5643.3 | 7454.4 | 5770.8 f32 B=64, M=512, H=16, K=16 | 16214.0 | 16531.2 | 16661.3 | 16540.8 f16 B=64, M=512, H=16, K=32 | 6495.5 | 6725.2 | 8353.7 | 6941.8 f32 B=64, M=512, H=16, K=32 | 20520.6 | 20941.9 | 18352.4 | 21116.8 f16 B=64, M=512, H=16, K=64 | 8686.1 | 9278.6 | 10343.4 | 9593.2 f32 B=64, M=512, H=16, K=64 | 28891.1 | 30003.0 | 22749.4 | 30139.1 f16 B=64, M=512, H=16, K=128 | 15991.4 | 17412.3 | 14633.0 | 17848.2 f32 B=64, M=512, H=16, K=128 | 57526.8 | 59970.8 | 40089.9 | 60016.9 f16 B=64, M=1024, H=16, K=16 | 21552.8 | 21603.1 | 28447.1 | 22030.0 f32 B=64, M=1024, H=16, K=16 | 65321.2 | 65736.8 | | 65932.0 f16 B=64, M=1024, H=16, K=32 | 25695.4 | 25905.9 | 30592.1 | 26644.8 f32 B=64, M=1024, H=16, K=32 | 80213.4 | 80446.7 | | 81363.1 f16 B=64, M=1024, H=16, K=64 | 32465.6 | 33575.1 | 37233.4 | 34370.8 f32 B=64, M=1024, H=16, K=64 | 112996.7 | 115632.0 | | 115970.8 f16 B=64, M=1024, H=16, K=128 | 60363.5 | 62800.2 | 48883.7 | 64505.1 f32 B=64, M=1024, H=16, K=128 | 225023.4 | 230527.4 | | 229851.8 Times are in microseconds (us). ``` </details> [ghstack-poisoned]
ghstack-source-id: b747ec522ded39beee04a77dbe70238877b0245b Pull Request resolved: #458
**SUMMARY** Also: - updated benchmarks to reflect the real-world scenario where we chunk a packed qkv - for both fw/bw. - added coverage for chunking in tests **PERFORMANCE IMPACT** <details> <summary>A100 bw (new benchmarks)</summary> ``` [---------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------] | 48_chunk3_31735f9 | 45_bwpacked_e53c5f3 | vanilla | 47_bwpackedgrad_9bacdf6 1 threads: -------------------------------------------------------------------------------------------------------------- f16 B=384, M=197, H=1, K=88 | 560.7 | 663.9 | 2265.7 | 710.3 f32 B=384, M=197, H=1, K=88 | 2445.1 | 2540.3 | 1843.3 | 2611.0 f16 B=384, M=197, H=1, K=80 | 530.4 | 619.9 | 1922.8 | 663.0 f32 B=384, M=197, H=1, K=80 | 2326.1 | 2425.2 | 1788.7 | 2476.4 f16 B=384, M=197, H=1, K=64 | 391.7 | 462.2 | 1812.7 | 492.8 f32 B=384, M=197, H=1, K=64 | 1275.0 | 1379.4 | 1675.4 | 1388.4 f16 B=1024, M=197, H=1, K=88 | 1399.5 | 1666.2 | 5965.2 | 1775.5 f32 B=1024, M=197, H=1, K=88 | 6332.5 | 6618.1 | 4559.6 | 6740.5 f16 B=1024, M=197, H=1, K=80 | 1326.2 | 1543.9 | 5041.4 | 1652.3 f32 B=1024, M=197, H=1, K=80 | 6057.1 | 6301.3 | 4411.6 | 6433.6 f16 B=1024, M=197, H=1, K=64 | 876.9 | 1063.1 | 4749.3 | 1133.2 f32 B=1024, M=197, H=1, K=64 | 3360.2 | 3629.0 | 4118.8 | 3652.0 f16 B=512, M=197, H=1, K=80 | 669.0 | 786.4 | 2544.9 | 842.2 f32 B=512, M=197, H=1, K=80 | 3032.3 | 3127.8 | 2287.4 | 3229.8 f16 B=32, M=197, H=16, K=80 | 663.0 | 789.7 | 2569.0 | 837.8 f32 B=32, M=197, H=16, K=80 | 3005.5 | 3166.3 | 2354.1 | 3225.9 f16 B=32, M=197, H=16, K=64 | 459.9 | 553.4 | 2436.3 | 591.9 f32 B=32, M=197, H=16, K=64 | 1814.1 | 1962.5 | 2197.3 | 1962.1 f16 B=32, M=197, H=16, K=128 | 792.5 | 981.9 | 4505.9 | 1056.5 f32 B=32, M=197, H=16, K=128 | 3734.8 | 3995.7 | 2805.8 | 4021.5 f16 B=256, M=197, H=1, K=88 | 413.4 | 482.6 | 1529.5 | 515.5 f32 B=256, M=197, H=1, K=88 | 1741.9 | 1818.3 | 1208.6 | 1852.4 f16 B=16, M=197, H=16, K=88 | 410.3 | 482.9 | 1545.7 | 512.5 f32 B=16, M=197, H=16, K=88 | 1734.9 | 1832.1 | 1250.6 | 1849.4 f16 B=16, M=197, H=16, K=64 | 235.4 | 286.0 | 1247.1 | 305.3 f32 B=16, M=197, H=16, K=64 | 1077.1 | 1143.7 | 1125.9 | 1154.0 f16 B=16, M=197, H=16, K=128 | 455.4 | 554.1 | 2273.1 | 596.0 f32 B=16, M=197, H=16, K=128 | 2028.9 | 2164.5 | 1446.7 | 2175.0 f16 B=1, M=4096, H=160, K=128 | 62454.4 | 63474.5 | 45930.5 | 64052.7 f32 B=1, M=4096, H=160, K=128 | 239035.4 | 232672.1 | | 240073.9 f16 B=2, M=4096, H=160, K=128 | 98791.3 | 101006.4 | | 101942.0 f32 B=2, M=4096, H=160, K=128 | 375914.9 | 368050.6 | | 381280.4 f16 B=1, M=8192, H=160, K=128 | 248498.9 | 250066.9 | | 251500.4 f32 B=1, M=8192, H=160, K=128 | 945102.2 | 922549.3 | | 949256.4 f16 B=2, M=8192, H=160, K=128 | 389207.8 | 394486.6 | | 396190.4 f32 B=2, M=8192, H=160, K=128 | 1496334.3 | 1449974.3 | | 1502215.3 f16 B=1024, M=82, H=8, K=64 | 1872.4 | 2503.8 | 3819.8 | 2693.7 f32 B=1024, M=82, H=8, K=64 | 8734.3 | 9637.8 | 8732.9 | 9672.2 f16 B=150, M=256, H=16, K=64 | 2126.4 | 2713.4 | 4554.3 | 2880.8 f32 B=150, M=256, H=16, K=64 | 6214.3 | 7052.2 | 12943.2 | 7099.2 f16 B=64, M=256, H=12, K=64 | 741.2 | 930.1 | 1493.0 | 990.6 f32 B=64, M=256, H=12, K=64 | 2144.2 | 2408.5 | 4267.7 | 2433.8 f16 B=1, M=4096, H=16, K=40 | 24583.7 | 24224.8 | 4195.2 | 24500.2 f32 B=1, M=4096, H=16, K=40 | 72497.9 | 72070.8 | 17744.1 | 72393.0 f16 B=1, M=16384, H=16, K=40 | 451481.8 | 439027.7 | | 451499.9 f32 B=1, M=16384, H=16, K=40 | 1169509.1 | 1164880.1 | | 1169769.3 f16 B=256, M=4096, H=16, K=64 | 597391.6 | 625921.0 | | 610433.2 f16 B=16, M=128, H=16, K=16 | 93.1 | 126.7 | 241.2 | 132.3 f32 B=16, M=128, H=16, K=16 | 184.1 | 176.5 | 373.8 | 180.7 f16 B=16, M=128, H=16, K=32 | 127.9 | 126.3 | 241.4 | 106.7 f32 B=16, M=128, H=16, K=32 | 194.1 | 216.6 | 412.7 | 225.8 f16 B=16, M=128, H=16, K=64 | 131.4 | 126.8 | 239.8 | 134.5 f32 B=16, M=128, H=16, K=64 | 280.4 | 326.0 | 500.0 | 334.0 f16 B=16, M=128, H=16, K=128 | 175.6 | 236.1 | 298.8 | 261.1 f32 B=16, M=128, H=16, K=128 | 531.8 | 615.8 | 677.2 | 638.0 f16 B=16, M=512, H=16, K=16 | 558.2 | 595.0 | 1201.9 | 607.8 f32 B=16, M=512, H=16, K=16 | 2146.7 | 2169.9 | 4416.1 | 2200.6 f16 B=16, M=512, H=16, K=32 | 653.5 | 732.3 | 1305.1 | 748.5 f32 B=16, M=512, H=16, K=32 | 2296.3 | 2373.9 | 4641.3 | 2400.1 f16 B=16, M=512, H=16, K=64 | 848.8 | 996.9 | 1544.6 | 1022.5 f32 B=16, M=512, H=16, K=64 | 2954.0 | 3117.1 | 5124.7 | 3157.6 f16 B=16, M=512, H=16, K=128 | 1735.4 | 1961.1 | 1982.7 | 2056.9 f32 B=16, M=512, H=16, K=128 | 6218.7 | 6396.4 | 6094.0 | 6600.3 f16 B=16, M=1024, H=16, K=16 | 2236.4 | 2319.4 | 4279.0 | 2331.6 f32 B=16, M=1024, H=16, K=16 | 8379.2 | 8363.9 | 16643.9 | 8503.6 f16 B=16, M=1024, H=16, K=32 | 2430.8 | 2649.6 | 4496.8 | 2608.7 f32 B=16, M=1024, H=16, K=32 | 8864.7 | 8907.8 | 17291.0 | 9074.0 f16 B=16, M=1024, H=16, K=64 | 3007.2 | 3351.3 | 4995.5 | 3351.0 f32 B=16, M=1024, H=16, K=64 | 11355.4 | 11627.1 | 18707.5 | 11694.3 f16 B=16, M=1024, H=16, K=128 | 6296.2 | 6748.7 | 5943.5 | 6967.0 f32 B=16, M=1024, H=16, K=128 | 23425.3 | 23360.0 | 21520.6 | 24169.7 f16 B=64, M=128, H=16, K=16 | 165.5 | 195.9 | 440.3 | 211.5 f32 B=64, M=128, H=16, K=16 | 497.4 | 540.7 | 1270.8 | 550.3 f16 B=64, M=128, H=16, K=32 | 210.4 | 274.9 | 544.8 | 298.5 f32 B=64, M=128, H=16, K=32 | 604.4 | 696.6 | 1428.3 | 710.9 f16 B=64, M=128, H=16, K=64 | 330.4 | 452.3 | 766.0 | 498.1 f32 B=64, M=128, H=16, K=64 | 883.4 | 1060.4 | 1745.2 | 1082.2 f16 B=64, M=128, H=16, K=128 | 605.5 | 847.8 | 1223.6 | 933.9 f32 B=64, M=128, H=16, K=128 | 1847.4 | 2169.7 | 2388.8 | 2236.0 f16 B=64, M=512, H=16, K=16 | 2004.7 | 2120.0 | 4487.0 | 2179.4 f32 B=64, M=512, H=16, K=16 | 6655.4 | 6818.8 | 16993.8 | 6872.1 f16 B=64, M=512, H=16, K=32 | 2379.3 | 2593.1 | 4957.2 | 2704.0 f32 B=64, M=512, H=16, K=32 | 7349.4 | 7644.6 | 17852.2 | 7736.2 f16 B=64, M=512, H=16, K=64 | 3129.6 | 3616.6 | 5888.8 | 3786.2 f32 B=64, M=512, H=16, K=64 | 9432.5 | 10123.9 | 19770.6 | 10178.5 f16 B=64, M=512, H=16, K=128 | 6054.1 | 7019.9 | 7712.6 | 7350.2 f32 B=64, M=512, H=16, K=128 | 21565.6 | 22281.9 | 23653.0 | 23084.4 f16 B=64, M=1024, H=16, K=16 | 7929.4 | 8199.1 | 16876.3 | 8242.5 f32 B=64, M=1024, H=16, K=16 | 26135.2 | 26347.9 | 66351.1 | 26639.0 f16 B=64, M=1024, H=16, K=32 | 8876.8 | 9450.0 | 17869.4 | 9473.5 f32 B=64, M=1024, H=16, K=32 | 27685.3 | 28104.6 | 69105.9 | 28428.7 f16 B=64, M=1024, H=16, K=64 | 11198.7 | 12180.5 | 19932.3 | 12543.4 f32 B=64, M=1024, H=16, K=64 | 34978.2 | 36239.4 | 74813.7 | 36482.4 f16 B=64, M=1024, H=16, K=128 | 21618.9 | 23439.6 | 23741.1 | 24160.1 f32 B=64, M=1024, H=16, K=128 | 80785.3 | 81080.8 | 86003.6 | 84132.9 Times are in microseconds (us). ``` </details> <details> <summary>P100/V100 bw (new benchmarks)</summary> ``` [---------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------------------] | 48_chunk3_31735f94 | 45_bwpacked_e53c5f3a | vanilla | 47_bwpackedgrad_9bacdf65 1 threads: -------------------------------------------------------------------------------------------------------------------------------------- (Quadro_GP100) f16 B=384, M=197, H=1, K=88 | 6846.3 | 7583.8 | 3569.3 | 7599.5 f32 B=384, M=197, H=1, K=88 | 9883.1 | 10107.2 | 4312.8 | 10486.3 f16 B=384, M=197, H=1, K=80 | 6486.4 | 6997.7 | 3418.0 | 7037.3 f32 B=384, M=197, H=1, K=80 | 9330.3 | 9550.6 | 4094.7 | 9893.4 f16 B=384, M=197, H=1, K=64 | 3615.4 | 3930.4 | 2911.0 | 4074.2 f32 B=384, M=197, H=1, K=64 | 6281.4 | 6554.5 | 3431.9 | 6738.1 f16 B=1024, M=197, H=1, K=88 | 17226.8 | 18593.1 | 9733.2 | 18772.9 f32 B=1024, M=197, H=1, K=88 | 26593.3 | 27136.2 | 12033.8 | 28184.2 f16 B=1024, M=197, H=1, K=80 | 16330.1 | 17478.6 | 9270.2 | 17735.3 f32 B=1024, M=197, H=1, K=80 | 25208.9 | 25680.1 | 11224.5 | 26636.1 f16 B=1024, M=197, H=1, K=64 | 8889.1 | 9728.8 | 7646.1 | 10089.7 f32 B=1024, M=197, H=1, K=64 | 16914.7 | 17743.4 | 9383.8 | 18068.4 f16 B=512, M=197, H=1, K=80 | 8227.3 | 8878.4 | 4579.3 | 8953.6 f32 B=512, M=197, H=1, K=80 | 13078.7 | 13346.0 | 5486.4 | 13817.6 f16 B=32, M=197, H=16, K=80 | 8278.9 | 9002.9 | 4816.2 | 9025.6 f32 B=32, M=197, H=16, K=80 | 12913.8 | 13371.2 | 5777.7 | 13667.6 f16 B=32, M=197, H=16, K=64 | 4565.2 | 5000.0 | 4023.4 | 5146.3 f32 B=32, M=197, H=16, K=64 | 8824.0 | 9257.7 | 4797.2 | 9400.5 f16 B=32, M=197, H=16, K=128 | 9770.0 | 10849.7 | 5983.2 | 10932.0 f32 B=32, M=197, H=16, K=128 | 15715.2 | 16559.9 | 7513.6 | 16839.9 f16 B=256, M=197, H=1, K=88 | 5011.2 | 5363.8 | 2444.9 | 5426.0 f32 B=256, M=197, H=1, K=88 | 6918.7 | 7040.8 | 2867.8 | 7303.2 f16 B=16, M=197, H=16, K=88 | 4963.8 | 5343.9 | 2545.2 | 5398.9 f32 B=16, M=197, H=16, K=88 | 6727.9 | 6981.7 | 3040.3 | 7121.2 f16 B=16, M=197, H=16, K=64 | 2586.5 | 2777.1 | 2025.5 | 2905.6 f32 B=16, M=197, H=16, K=64 | 4404.3 | 4607.2 | 2431.1 | 4691.8 f16 B=16, M=197, H=16, K=128 | 5643.2 | 6194.1 | 3016.1 | 6216.3 f32 B=16, M=197, H=16, K=128 | 7887.1 | 8308.3 | 3676.6 | 8456.2 f16 B=1, M=4096, H=160, K=128 | 1087008.7 | 1115355.5 | | 1091596.8 f32 B=1, M=4096, H=160, K=128 | 1220066.8 | 1223422.8 | | 1227912.2 f16 B=2, M=4096, H=160, K=128 | 1734244.4 | 1794068.7 | | 1756266.7 f32 B=2, M=4096, H=160, K=128 | 2437675.5 | 2445780.4 | | 2451957.5 f16 B=1, M=8192, H=160, K=128 | 4367110.4 | 4466170.9 | | 4383747.4 f32 B=1, M=8192, H=160, K=128 | 4865732.9 | 4865708.9 | | 4887066.5 f16 B=2, M=8192, H=160, K=128 | 7002715.1 | 7146077.9 | | 7033922.8 f16 B=1024, M=82, H=8, K=64 | 23247.5 | 24929.5 | 18047.8 | 26928.2 f32 B=1024, M=82, H=8, K=64 | 46463.2 | 48705.6 | 22797.5 | 50736.3 f16 B=150, M=256, H=16, K=64 | 23467.9 | 25647.3 | 24569.2 | 26841.8 f32 B=150, M=256, H=16, K=64 | 36887.7 | 39698.0 | 32050.2 | 40389.0 f16 B=64, M=256, H=12, K=64 | 7723.7 | 8499.0 | 7702.1 | 8694.9 f32 B=64, M=256, H=12, K=64 | 11992.1 | 12819.9 | 9874.5 | 13107.9 f16 B=1, M=4096, H=16, K=40 | 142655.5 | 142899.7 | 28928.6 | 142922.7 f32 B=1, M=4096, H=16, K=40 | 142626.8 | 142685.3 | 37303.2 | 142541.0 f16 B=1, M=16384, H=16, K=40 | 2274095.0 | 2274882.0 | | 2275019.9 f32 B=1, M=16384, H=16, K=40 | 2284027.2 | 2279415.7 | | 2277761.9 f16 B=16, M=128, H=16, K=16 | 513.2 | 547.1 | 571.5 | 570.9 f32 B=16, M=128, H=16, K=16 | 667.4 | 704.3 | 693.1 | 728.0 f16 B=16, M=128, H=16, K=32 | 600.3 | 667.0 | 671.3 | 713.1 f32 B=16, M=128, H=16, K=32 | 823.9 | 888.9 | 823.5 | 937.3 f16 B=16, M=128, H=16, K=64 | 781.0 | 900.6 | 883.1 | 998.9 f32 B=16, M=128, H=16, K=64 | 1173.7 | 1293.8 | 1077.0 | 1393.4 f16 B=16, M=128, H=16, K=128 | 1649.2 | 1877.2 | 1323.2 | 2026.3 f32 B=16, M=128, H=16, K=128 | 2250.5 | 2473.0 | 1654.7 | 2636.6 f16 B=16, M=512, H=16, K=16 | 7709.3 | 7914.6 | 6945.1 | 7928.7 f32 B=16, M=512, H=16, K=16 | 9797.2 | 9950.5 | 8499.4 | 10029.3 f16 B=16, M=512, H=16, K=32 | 8956.9 | 9210.8 | 7517.1 | 9307.0 f32 B=16, M=512, H=16, K=32 | 11480.7 | 11710.9 | 9249.4 | 11884.4 f16 B=16, M=512, H=16, K=64 | 11324.0 | 11829.1 | 8849.5 | 12001.8 f32 B=16, M=512, H=16, K=64 | 15744.1 | 16258.0 | 10954.6 | 16481.1 f16 B=16, M=512, H=16, K=128 | 25320.2 | 26584.0 | 12412.3 | 26725.0 f32 B=16, M=512, H=16, K=128 | 31187.1 | 32290.3 | 15167.5 | 32818.4 f16 B=16, M=1024, H=16, K=16 | 31484.2 | 31601.4 | 26434.6 | 31894.6 f32 B=16, M=1024, H=16, K=16 | 38754.1 | 38900.1 | 32320.0 | 39203.9 f16 B=16, M=1024, H=16, K=32 | 36000.2 | 36672.6 | 28341.4 | 36579.5 f32 B=16, M=1024, H=16, K=32 | 45070.7 | 45262.3 | 34914.2 | 45774.5 f16 B=16, M=1024, H=16, K=64 | 45324.9 | 46540.4 | 32089.9 | 46784.2 f32 B=16, M=1024, H=16, K=64 | 61320.3 | 62411.1 | 39565.0 | 63217.0 f16 B=16, M=1024, H=16, K=128 | 104342.9 | 108469.4 | 43221.9 | 105620.6 f32 B=16, M=1024, H=16, K=128 | 122688.4 | 125050.9 | 51205.7 | 126080.9 f16 B=64, M=128, H=16, K=16 | 1707.9 | 1824.9 | 2106.4 | 1923.2 f32 B=64, M=128, H=16, K=16 | 2487.4 | 2612.5 | 2565.1 | 2707.6 f16 B=64, M=128, H=16, K=32 | 2016.8 | 2254.4 | 2485.4 | 2412.3 f32 B=64, M=128, H=16, K=32 | 3135.8 | 3365.6 | 3063.2 | 3518.5 f16 B=64, M=128, H=16, K=64 | 2700.2 | 3167.0 | 3306.0 | 3478.4 f32 B=64, M=128, H=16, K=64 | 4435.1 | 4944.7 | 4227.6 | 5181.2 f16 B=64, M=128, H=16, K=128 | 5769.1 | 6858.2 | 5299.8 | 7356.1 f32 B=64, M=128, H=16, K=128 | 8577.9 | 9672.0 | 6916.3 | 10093.5 f16 B=64, M=512, H=16, K=16 | 25994.0 | 26782.0 | 27240.9 | 26662.2 f32 B=64, M=512, H=16, K=16 | 36864.9 | 37299.3 | 34159.3 | 37576.7 f16 B=64, M=512, H=16, K=32 | 30680.4 | 32113.8 | 30109.0 | 32419.7 f32 B=64, M=512, H=16, K=32 | 43638.5 | 44557.9 | 37358.5 | 45145.0 f16 B=64, M=512, H=16, K=64 | 39417.5 | 41666.5 | 36004.2 | 42374.9 f32 B=64, M=512, H=16, K=64 | 60049.2 | 63148.0 | 43412.6 | 63286.8 f16 B=64, M=512, H=16, K=128 | 88951.1 | 93087.0 | 51730.1 | 94861.6 f32 B=64, M=512, H=16, K=128 | 119728.7 | 124340.3 | 62413.7 | 126382.2 f16 B=64, M=1024, H=16, K=16 | 108368.3 | 111081.8 | 106479.7 | 108716.1 f32 B=64, M=1024, H=16, K=16 | 145612.0 | 147310.4 | | 147380.7 f16 B=64, M=1024, H=16, K=32 | 124296.1 | 127366.8 | 113905.0 | 126975.3 f32 B=64, M=1024, H=16, K=32 | 171082.3 | 172539.0 | | 173893.9 f16 B=64, M=1024, H=16, K=64 | 155116.3 | 160429.2 | 130759.4 | 161834.0 f32 B=64, M=1024, H=16, K=64 | 234356.0 | 239612.2 | | 239948.3 f16 B=64, M=1024, H=16, K=128 | 349728.3 | 360975.7 | 176158.7 | 371185.2 f32 B=64, M=1024, H=16, K=128 | 468810.0 | 476415.4 | | 481908.5 (Tesla_V100_SXM2_16GB) f16 B=384, M=197, H=1, K=88 | 1700.3 | 1840.0 | 1375.3 | 1930.9 f32 B=384, M=197, H=1, K=88 | 4456.4 | 4579.3 | 2235.5 | 4708.6 f16 B=384, M=197, H=1, K=80 | 1623.3 | 1719.9 | 1279.5 | 1806.9 f32 B=384, M=197, H=1, K=80 | 4031.2 | 4141.9 | 2149.8 | 4252.6 f16 B=384, M=197, H=1, K=64 | 1092.8 | 1187.0 | 1048.5 | 1237.6 f32 B=384, M=197, H=1, K=64 | 2717.5 | 2918.5 | 1738.5 | 2907.9 f16 B=1024, M=197, H=1, K=88 | 4428.7 | 4906.2 | 3723.7 | 5178.2 f32 B=1024, M=197, H=1, K=88 | 10947.5 | 11362.9 | 6052.5 | 11802.1 f16 B=1024, M=197, H=1, K=80 | 4237.1 | 4491.4 | 3331.7 | 4725.6 f32 B=1024, M=197, H=1, K=80 | 9842.6 | 10159.7 | 5682.4 | 10435.6 f16 B=1024, M=197, H=1, K=64 | 2679.2 | 2927.4 | 2674.4 | 3033.0 f32 B=1024, M=197, H=1, K=64 | 6597.6 | 7154.9 | 4489.7 | 7063.1 f16 B=512, M=197, H=1, K=80 | 2239.5 | 2366.5 | 1684.2 | 2472.0 f32 B=512, M=197, H=1, K=80 | 5362.4 | 5519.6 | 2857.9 | 5651.4 f16 B=32, M=197, H=16, K=80 | 2208.1 | 2380.0 | 1803.4 | 2439.4 f32 B=32, M=197, H=16, K=80 | 5503.6 | 5736.7 | 3017.5 | 5796.2 f16 B=32, M=197, H=16, K=64 | 1493.4 | 1620.6 | 1457.2 | 1678.6 f32 B=32, M=197, H=16, K=64 | 3672.6 | 3941.6 | 2415.0 | 3898.2 f16 B=32, M=197, H=16, K=128 | 2634.3 | 2888.0 | 2215.1 | 2991.5 f32 B=32, M=197, H=16, K=128 | 6811.5 | 7334.0 | 4049.3 | 7261.9 f16 B=256, M=197, H=1, K=88 | 1290.3 | 1382.0 | 944.8 | 1449.4 f32 B=256, M=197, H=1, K=88 | 2965.8 | 3043.2 | 1528.7 | 3137.7 f16 B=16, M=197, H=16, K=88 | 1267.3 | 1357.0 | 970.8 | 1395.5 f32 B=16, M=197, H=16, K=88 | 2879.9 | 3014.7 | 1626.5 | 3054.3 f16 B=16, M=197, H=16, K=64 | 737.3 | 799.8 | 771.3 | 836.9 f32 B=16, M=197, H=16, K=64 | 1879.2 | 2000.9 | 1282.5 | 1994.5 f16 B=16, M=197, H=16, K=128 | 1443.9 | 1570.7 | 1142.2 | 1628.8 f32 B=16, M=197, H=16, K=128 | 3480.5 | 3723.6 | 2027.2 | 3714.6 f16 B=1, M=4096, H=160, K=128 | 150006.2 | 151877.5 | | 152570.6 f32 B=1, M=4096, H=160, K=128 | 582870.9 | 583519.8 | | 585570.1 f16 B=2, M=4096, H=160, K=128 | 301231.4 | 304511.7 | | 305801.2 f32 B=2, M=4096, H=160, K=128 | 1174724.1 | 1172498.4 | | 1176814.0 f16 B=1, M=8192, H=160, K=128 | 597461.6 | 600463.4 | | 603066.6 f32 B=1, M=8192, H=160, K=128 | 2333657.8 | 2329212.1 | | 2339766.1 f16 B=2, M=8192, H=160, K=128 | 1196837.5 | 1206932.4 | | 1209012.2 f16 B=1024, M=82, H=8, K=64 | 8926.8 | 9723.4 | 5799.4 | 10084.2 f32 B=1024, M=82, H=8, K=64 | 15920.4 | 17434.4 | 11027.0 | 17492.8 f16 B=150, M=256, H=16, K=64 | 5524.2 | 6363.9 | 7557.9 | 6586.2 f32 B=150, M=256, H=16, K=64 | 17506.9 | 18843.5 | 16263.5 | 18988.6 f16 B=64, M=256, H=12, K=64 | 1800.6 | 2050.3 | 2383.4 | 2139.0 f32 B=64, M=256, H=12, K=64 | 5753.6 | 6196.3 | 4971.2 | 6200.0 f16 B=1, M=4096, H=16, K=40 | 47649.5 | 47836.0 | 8368.4 | 47973.6 f32 B=1, M=4096, H=16, K=40 | 111092.1 | 111027.3 | 19475.9 | 111257.8 f16 B=1, M=16384, H=16, K=40 | 765320.2 | 765686.9 | | 767337.2 f32 B=1, M=16384, H=16, K=40 | 1769169.0 | 1769675.1 | | 1769371.4 f16 B=16, M=128, H=16, K=16 | 178.9 | 196.8 | 445.9 | 188.3 f32 B=16, M=128, H=16, K=16 | 301.3 | 319.1 | 422.5 | 336.3 f16 B=16, M=128, H=16, K=32 | 174.1 | 174.2 | 394.0 | 179.5 f32 B=16, M=128, H=16, K=32 | 395.7 | 433.2 | 580.0 | 440.4 f16 B=16, M=128, H=16, K=64 | 205.0 | 253.5 | 460.6 | 270.9 f32 B=16, M=128, H=16, K=64 | 573.7 | 639.3 | 598.1 | 656.1 f16 B=16, M=128, H=16, K=128 | 399.5 | 484.3 | 515.2 | 521.8 f32 B=16, M=128, H=16, K=128 | 1126.3 | 1260.8 | 1008.1 | 1282.4 f16 B=16, M=512, H=16, K=16 | 1597.6 | 1627.2 | 1901.1 | 1662.1 f32 B=16, M=512, H=16, K=16 | 4458.5 | 4528.8 | 4232.0 | 4559.4 f16 B=16, M=512, H=16, K=32 | 1819.1 | 1868.7 | 2097.2 | 1945.5 f32 B=16, M=512, H=16, K=32 | 5604.2 | 5757.1 | 4566.4 | 5784.8 f16 B=16, M=512, H=16, K=64 | 2345.5 | 2495.6 | 2558.0 | 2573.2 f32 B=16, M=512, H=16, K=64 | 7778.3 | 8017.1 | 5488.2 | 8083.7 f16 B=16, M=512, H=16, K=128 | 4516.6 | 4821.0 | 3386.7 | 4968.2 f32 B=16, M=512, H=16, K=128 | 15412.7 | 15959.2 | 8865.9 | 16047.5 f16 B=16, M=1024, H=16, K=16 | 6195.9 | 6217.6 | 6995.3 | 6326.4 f32 B=16, M=1024, H=16, K=16 | 18136.2 | 18312.0 | 16088.2 | 18354.1 f16 B=16, M=1024, H=16, K=32 | 7072.8 | 7122.3 | 7406.9 | 7297.7 f32 B=16, M=1024, H=16, K=32 | 22108.2 | 22116.7 | 17112.5 | 22436.8 f16 B=16, M=1024, H=16, K=64 | 8868.0 | 9104.6 | 8627.1 | 9311.8 f32 B=16, M=1024, H=16, K=64 | 30710.5 | 31041.3 | 19860.8 | 31338.1 f16 B=16, M=1024, H=16, K=128 | 17091.8 | 17655.5 | 10548.3 | 18083.8 f32 B=16, M=1024, H=16, K=128 | 60317.8 | 61461.7 | 32919.2 | 61548.8 f16 B=64, M=128, H=16, K=16 | 413.6 | 453.8 | 635.5 | 480.6 f32 B=64, M=128, H=16, K=16 | 1033.8 | 1114.3 | 1238.9 | 1119.5 f16 B=64, M=128, H=16, K=32 | 505.7 | 587.9 | 813.6 | 630.1 f32 B=64, M=128, H=16, K=32 | 1423.0 | 1551.4 | 1533.4 | 1581.8 f16 B=64, M=128, H=16, K=64 | 743.3 | 916.8 | 1187.7 | 976.5 f32 B=64, M=128, H=16, K=64 | 2093.3 | 2384.6 | 2156.3 | 2405.4 f16 B=64, M=128, H=16, K=128 | 1408.2 | 1734.3 | 1918.7 | 1859.6 f32 B=64, M=128, H=16, K=128 | 4125.3 | 4671.4 | 3762.0 | 4717.0 f16 B=64, M=512, H=16, K=16 | 5531.2 | 5643.3 | 7454.4 | 5770.8 f32 B=64, M=512, H=16, K=16 | 16214.0 | 16531.2 | 16661.3 | 16540.8 f16 B=64, M=512, H=16, K=32 | 6495.5 | 6725.2 | 8353.7 | 6941.8 f32 B=64, M=512, H=16, K=32 | 20520.6 | 20941.9 | 18352.4 | 21116.8 f16 B=64, M=512, H=16, K=64 | 8686.1 | 9278.6 | 10343.4 | 9593.2 f32 B=64, M=512, H=16, K=64 | 28891.1 | 30003.0 | 22749.4 | 30139.1 f16 B=64, M=512, H=16, K=128 | 15991.4 | 17412.3 | 14633.0 | 17848.2 f32 B=64, M=512, H=16, K=128 | 57526.8 | 59970.8 | 40089.9 | 60016.9 f16 B=64, M=1024, H=16, K=16 | 21552.8 | 21603.1 | 28447.1 | 22030.0 f32 B=64, M=1024, H=16, K=16 | 65321.2 | 65736.8 | | 65932.0 f16 B=64, M=1024, H=16, K=32 | 25695.4 | 25905.9 | 30592.1 | 26644.8 f32 B=64, M=1024, H=16, K=32 | 80213.4 | 80446.7 | | 81363.1 f16 B=64, M=1024, H=16, K=64 | 32465.6 | 33575.1 | 37233.4 | 34370.8 f32 B=64, M=1024, H=16, K=64 | 112996.7 | 115632.0 | | 115970.8 f16 B=64, M=1024, H=16, K=128 | 60363.5 | 62800.2 | 48883.7 | 64505.1 f32 B=64, M=1024, H=16, K=128 | 225023.4 | 230527.4 | | 229851.8 Times are in microseconds (us). ``` </details> [ghstack-poisoned]
ghstack-source-id: 370afb5983c34f74dca3a4d324240eed44e78add Pull Request resolved: #458
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks great, thanks!
I have some minor suggestions in the tests to ensure we cover a few more cases, but they can be done in a follow-up PR.
Also, if we are really sure about our implementation and want to enable it to all downstream users without asking them to change their code (which I'm not sure we should do now, might be better to be explicit), we could also use the new torch.library.Library
functionality from PyTorch that allows overriding PyTorch functions directly from Python, so we could override the unbind_backward
function.
An example is as follows (taken from pytorch/pytorch#75905):
def my_sum(*args, **kwargs):
return args[0]
my_lib1 = torch.library.Library("aten", "IMPL")
my_lib1.impl('aten::sum', my_sum)
x = torch.tensor([1, 2])
assert torch.sum(x) == x
del my_lib1
assert torch.sum(x) == torch.tensor(3)
**SUMMARY** Also: - updated benchmarks to reflect the real-world scenario where we chunk a packed qkv - for both fw/bw. - added coverage for chunking in tests **PERFORMANCE IMPACT** <details> <summary>A100 bw (new benchmarks)</summary> ``` [---------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------] | 48_chunk3_31735f9 | 45_bwpacked_e53c5f3 | vanilla | 47_bwpackedgrad_9bacdf6 1 threads: -------------------------------------------------------------------------------------------------------------- f16 B=384, M=197, H=1, K=88 | 560.7 | 663.9 | 2265.7 | 710.3 f32 B=384, M=197, H=1, K=88 | 2445.1 | 2540.3 | 1843.3 | 2611.0 f16 B=384, M=197, H=1, K=80 | 530.4 | 619.9 | 1922.8 | 663.0 f32 B=384, M=197, H=1, K=80 | 2326.1 | 2425.2 | 1788.7 | 2476.4 f16 B=384, M=197, H=1, K=64 | 391.7 | 462.2 | 1812.7 | 492.8 f32 B=384, M=197, H=1, K=64 | 1275.0 | 1379.4 | 1675.4 | 1388.4 f16 B=1024, M=197, H=1, K=88 | 1399.5 | 1666.2 | 5965.2 | 1775.5 f32 B=1024, M=197, H=1, K=88 | 6332.5 | 6618.1 | 4559.6 | 6740.5 f16 B=1024, M=197, H=1, K=80 | 1326.2 | 1543.9 | 5041.4 | 1652.3 f32 B=1024, M=197, H=1, K=80 | 6057.1 | 6301.3 | 4411.6 | 6433.6 f16 B=1024, M=197, H=1, K=64 | 876.9 | 1063.1 | 4749.3 | 1133.2 f32 B=1024, M=197, H=1, K=64 | 3360.2 | 3629.0 | 4118.8 | 3652.0 f16 B=512, M=197, H=1, K=80 | 669.0 | 786.4 | 2544.9 | 842.2 f32 B=512, M=197, H=1, K=80 | 3032.3 | 3127.8 | 2287.4 | 3229.8 f16 B=32, M=197, H=16, K=80 | 663.0 | 789.7 | 2569.0 | 837.8 f32 B=32, M=197, H=16, K=80 | 3005.5 | 3166.3 | 2354.1 | 3225.9 f16 B=32, M=197, H=16, K=64 | 459.9 | 553.4 | 2436.3 | 591.9 f32 B=32, M=197, H=16, K=64 | 1814.1 | 1962.5 | 2197.3 | 1962.1 f16 B=32, M=197, H=16, K=128 | 792.5 | 981.9 | 4505.9 | 1056.5 f32 B=32, M=197, H=16, K=128 | 3734.8 | 3995.7 | 2805.8 | 4021.5 f16 B=256, M=197, H=1, K=88 | 413.4 | 482.6 | 1529.5 | 515.5 f32 B=256, M=197, H=1, K=88 | 1741.9 | 1818.3 | 1208.6 | 1852.4 f16 B=16, M=197, H=16, K=88 | 410.3 | 482.9 | 1545.7 | 512.5 f32 B=16, M=197, H=16, K=88 | 1734.9 | 1832.1 | 1250.6 | 1849.4 f16 B=16, M=197, H=16, K=64 | 235.4 | 286.0 | 1247.1 | 305.3 f32 B=16, M=197, H=16, K=64 | 1077.1 | 1143.7 | 1125.9 | 1154.0 f16 B=16, M=197, H=16, K=128 | 455.4 | 554.1 | 2273.1 | 596.0 f32 B=16, M=197, H=16, K=128 | 2028.9 | 2164.5 | 1446.7 | 2175.0 f16 B=1, M=4096, H=160, K=128 | 62454.4 | 63474.5 | 45930.5 | 64052.7 f32 B=1, M=4096, H=160, K=128 | 239035.4 | 232672.1 | | 240073.9 f16 B=2, M=4096, H=160, K=128 | 98791.3 | 101006.4 | | 101942.0 f32 B=2, M=4096, H=160, K=128 | 375914.9 | 368050.6 | | 381280.4 f16 B=1, M=8192, H=160, K=128 | 248498.9 | 250066.9 | | 251500.4 f32 B=1, M=8192, H=160, K=128 | 945102.2 | 922549.3 | | 949256.4 f16 B=2, M=8192, H=160, K=128 | 389207.8 | 394486.6 | | 396190.4 f32 B=2, M=8192, H=160, K=128 | 1496334.3 | 1449974.3 | | 1502215.3 f16 B=1024, M=82, H=8, K=64 | 1872.4 | 2503.8 | 3819.8 | 2693.7 f32 B=1024, M=82, H=8, K=64 | 8734.3 | 9637.8 | 8732.9 | 9672.2 f16 B=150, M=256, H=16, K=64 | 2126.4 | 2713.4 | 4554.3 | 2880.8 f32 B=150, M=256, H=16, K=64 | 6214.3 | 7052.2 | 12943.2 | 7099.2 f16 B=64, M=256, H=12, K=64 | 741.2 | 930.1 | 1493.0 | 990.6 f32 B=64, M=256, H=12, K=64 | 2144.2 | 2408.5 | 4267.7 | 2433.8 f16 B=1, M=4096, H=16, K=40 | 24583.7 | 24224.8 | 4195.2 | 24500.2 f32 B=1, M=4096, H=16, K=40 | 72497.9 | 72070.8 | 17744.1 | 72393.0 f16 B=1, M=16384, H=16, K=40 | 451481.8 | 439027.7 | | 451499.9 f32 B=1, M=16384, H=16, K=40 | 1169509.1 | 1164880.1 | | 1169769.3 f16 B=256, M=4096, H=16, K=64 | 597391.6 | 625921.0 | | 610433.2 f16 B=16, M=128, H=16, K=16 | 93.1 | 126.7 | 241.2 | 132.3 f32 B=16, M=128, H=16, K=16 | 184.1 | 176.5 | 373.8 | 180.7 f16 B=16, M=128, H=16, K=32 | 127.9 | 126.3 | 241.4 | 106.7 f32 B=16, M=128, H=16, K=32 | 194.1 | 216.6 | 412.7 | 225.8 f16 B=16, M=128, H=16, K=64 | 131.4 | 126.8 | 239.8 | 134.5 f32 B=16, M=128, H=16, K=64 | 280.4 | 326.0 | 500.0 | 334.0 f16 B=16, M=128, H=16, K=128 | 175.6 | 236.1 | 298.8 | 261.1 f32 B=16, M=128, H=16, K=128 | 531.8 | 615.8 | 677.2 | 638.0 f16 B=16, M=512, H=16, K=16 | 558.2 | 595.0 | 1201.9 | 607.8 f32 B=16, M=512, H=16, K=16 | 2146.7 | 2169.9 | 4416.1 | 2200.6 f16 B=16, M=512, H=16, K=32 | 653.5 | 732.3 | 1305.1 | 748.5 f32 B=16, M=512, H=16, K=32 | 2296.3 | 2373.9 | 4641.3 | 2400.1 f16 B=16, M=512, H=16, K=64 | 848.8 | 996.9 | 1544.6 | 1022.5 f32 B=16, M=512, H=16, K=64 | 2954.0 | 3117.1 | 5124.7 | 3157.6 f16 B=16, M=512, H=16, K=128 | 1735.4 | 1961.1 | 1982.7 | 2056.9 f32 B=16, M=512, H=16, K=128 | 6218.7 | 6396.4 | 6094.0 | 6600.3 f16 B=16, M=1024, H=16, K=16 | 2236.4 | 2319.4 | 4279.0 | 2331.6 f32 B=16, M=1024, H=16, K=16 | 8379.2 | 8363.9 | 16643.9 | 8503.6 f16 B=16, M=1024, H=16, K=32 | 2430.8 | 2649.6 | 4496.8 | 2608.7 f32 B=16, M=1024, H=16, K=32 | 8864.7 | 8907.8 | 17291.0 | 9074.0 f16 B=16, M=1024, H=16, K=64 | 3007.2 | 3351.3 | 4995.5 | 3351.0 f32 B=16, M=1024, H=16, K=64 | 11355.4 | 11627.1 | 18707.5 | 11694.3 f16 B=16, M=1024, H=16, K=128 | 6296.2 | 6748.7 | 5943.5 | 6967.0 f32 B=16, M=1024, H=16, K=128 | 23425.3 | 23360.0 | 21520.6 | 24169.7 f16 B=64, M=128, H=16, K=16 | 165.5 | 195.9 | 440.3 | 211.5 f32 B=64, M=128, H=16, K=16 | 497.4 | 540.7 | 1270.8 | 550.3 f16 B=64, M=128, H=16, K=32 | 210.4 | 274.9 | 544.8 | 298.5 f32 B=64, M=128, H=16, K=32 | 604.4 | 696.6 | 1428.3 | 710.9 f16 B=64, M=128, H=16, K=64 | 330.4 | 452.3 | 766.0 | 498.1 f32 B=64, M=128, H=16, K=64 | 883.4 | 1060.4 | 1745.2 | 1082.2 f16 B=64, M=128, H=16, K=128 | 605.5 | 847.8 | 1223.6 | 933.9 f32 B=64, M=128, H=16, K=128 | 1847.4 | 2169.7 | 2388.8 | 2236.0 f16 B=64, M=512, H=16, K=16 | 2004.7 | 2120.0 | 4487.0 | 2179.4 f32 B=64, M=512, H=16, K=16 | 6655.4 | 6818.8 | 16993.8 | 6872.1 f16 B=64, M=512, H=16, K=32 | 2379.3 | 2593.1 | 4957.2 | 2704.0 f32 B=64, M=512, H=16, K=32 | 7349.4 | 7644.6 | 17852.2 | 7736.2 f16 B=64, M=512, H=16, K=64 | 3129.6 | 3616.6 | 5888.8 | 3786.2 f32 B=64, M=512, H=16, K=64 | 9432.5 | 10123.9 | 19770.6 | 10178.5 f16 B=64, M=512, H=16, K=128 | 6054.1 | 7019.9 | 7712.6 | 7350.2 f32 B=64, M=512, H=16, K=128 | 21565.6 | 22281.9 | 23653.0 | 23084.4 f16 B=64, M=1024, H=16, K=16 | 7929.4 | 8199.1 | 16876.3 | 8242.5 f32 B=64, M=1024, H=16, K=16 | 26135.2 | 26347.9 | 66351.1 | 26639.0 f16 B=64, M=1024, H=16, K=32 | 8876.8 | 9450.0 | 17869.4 | 9473.5 f32 B=64, M=1024, H=16, K=32 | 27685.3 | 28104.6 | 69105.9 | 28428.7 f16 B=64, M=1024, H=16, K=64 | 11198.7 | 12180.5 | 19932.3 | 12543.4 f32 B=64, M=1024, H=16, K=64 | 34978.2 | 36239.4 | 74813.7 | 36482.4 f16 B=64, M=1024, H=16, K=128 | 21618.9 | 23439.6 | 23741.1 | 24160.1 f32 B=64, M=1024, H=16, K=128 | 80785.3 | 81080.8 | 86003.6 | 84132.9 Times are in microseconds (us). ``` </details> <details> <summary>P100/V100 bw (new benchmarks)</summary> ``` [---------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------------------] | 48_chunk3_31735f94 | 45_bwpacked_e53c5f3a | vanilla | 47_bwpackedgrad_9bacdf65 1 threads: -------------------------------------------------------------------------------------------------------------------------------------- (Quadro_GP100) f16 B=384, M=197, H=1, K=88 | 6846.3 | 7583.8 | 3569.3 | 7599.5 f32 B=384, M=197, H=1, K=88 | 9883.1 | 10107.2 | 4312.8 | 10486.3 f16 B=384, M=197, H=1, K=80 | 6486.4 | 6997.7 | 3418.0 | 7037.3 f32 B=384, M=197, H=1, K=80 | 9330.3 | 9550.6 | 4094.7 | 9893.4 f16 B=384, M=197, H=1, K=64 | 3615.4 | 3930.4 | 2911.0 | 4074.2 f32 B=384, M=197, H=1, K=64 | 6281.4 | 6554.5 | 3431.9 | 6738.1 f16 B=1024, M=197, H=1, K=88 | 17226.8 | 18593.1 | 9733.2 | 18772.9 f32 B=1024, M=197, H=1, K=88 | 26593.3 | 27136.2 | 12033.8 | 28184.2 f16 B=1024, M=197, H=1, K=80 | 16330.1 | 17478.6 | 9270.2 | 17735.3 f32 B=1024, M=197, H=1, K=80 | 25208.9 | 25680.1 | 11224.5 | 26636.1 f16 B=1024, M=197, H=1, K=64 | 8889.1 | 9728.8 | 7646.1 | 10089.7 f32 B=1024, M=197, H=1, K=64 | 16914.7 | 17743.4 | 9383.8 | 18068.4 f16 B=512, M=197, H=1, K=80 | 8227.3 | 8878.4 | 4579.3 | 8953.6 f32 B=512, M=197, H=1, K=80 | 13078.7 | 13346.0 | 5486.4 | 13817.6 f16 B=32, M=197, H=16, K=80 | 8278.9 | 9002.9 | 4816.2 | 9025.6 f32 B=32, M=197, H=16, K=80 | 12913.8 | 13371.2 | 5777.7 | 13667.6 f16 B=32, M=197, H=16, K=64 | 4565.2 | 5000.0 | 4023.4 | 5146.3 f32 B=32, M=197, H=16, K=64 | 8824.0 | 9257.7 | 4797.2 | 9400.5 f16 B=32, M=197, H=16, K=128 | 9770.0 | 10849.7 | 5983.2 | 10932.0 f32 B=32, M=197, H=16, K=128 | 15715.2 | 16559.9 | 7513.6 | 16839.9 f16 B=256, M=197, H=1, K=88 | 5011.2 | 5363.8 | 2444.9 | 5426.0 f32 B=256, M=197, H=1, K=88 | 6918.7 | 7040.8 | 2867.8 | 7303.2 f16 B=16, M=197, H=16, K=88 | 4963.8 | 5343.9 | 2545.2 | 5398.9 f32 B=16, M=197, H=16, K=88 | 6727.9 | 6981.7 | 3040.3 | 7121.2 f16 B=16, M=197, H=16, K=64 | 2586.5 | 2777.1 | 2025.5 | 2905.6 f32 B=16, M=197, H=16, K=64 | 4404.3 | 4607.2 | 2431.1 | 4691.8 f16 B=16, M=197, H=16, K=128 | 5643.2 | 6194.1 | 3016.1 | 6216.3 f32 B=16, M=197, H=16, K=128 | 7887.1 | 8308.3 | 3676.6 | 8456.2 f16 B=1, M=4096, H=160, K=128 | 1087008.7 | 1115355.5 | | 1091596.8 f32 B=1, M=4096, H=160, K=128 | 1220066.8 | 1223422.8 | | 1227912.2 f16 B=2, M=4096, H=160, K=128 | 1734244.4 | 1794068.7 | | 1756266.7 f32 B=2, M=4096, H=160, K=128 | 2437675.5 | 2445780.4 | | 2451957.5 f16 B=1, M=8192, H=160, K=128 | 4367110.4 | 4466170.9 | | 4383747.4 f32 B=1, M=8192, H=160, K=128 | 4865732.9 | 4865708.9 | | 4887066.5 f16 B=2, M=8192, H=160, K=128 | 7002715.1 | 7146077.9 | | 7033922.8 f16 B=1024, M=82, H=8, K=64 | 23247.5 | 24929.5 | 18047.8 | 26928.2 f32 B=1024, M=82, H=8, K=64 | 46463.2 | 48705.6 | 22797.5 | 50736.3 f16 B=150, M=256, H=16, K=64 | 23467.9 | 25647.3 | 24569.2 | 26841.8 f32 B=150, M=256, H=16, K=64 | 36887.7 | 39698.0 | 32050.2 | 40389.0 f16 B=64, M=256, H=12, K=64 | 7723.7 | 8499.0 | 7702.1 | 8694.9 f32 B=64, M=256, H=12, K=64 | 11992.1 | 12819.9 | 9874.5 | 13107.9 f16 B=1, M=4096, H=16, K=40 | 142655.5 | 142899.7 | 28928.6 | 142922.7 f32 B=1, M=4096, H=16, K=40 | 142626.8 | 142685.3 | 37303.2 | 142541.0 f16 B=1, M=16384, H=16, K=40 | 2274095.0 | 2274882.0 | | 2275019.9 f32 B=1, M=16384, H=16, K=40 | 2284027.2 | 2279415.7 | | 2277761.9 f16 B=16, M=128, H=16, K=16 | 513.2 | 547.1 | 571.5 | 570.9 f32 B=16, M=128, H=16, K=16 | 667.4 | 704.3 | 693.1 | 728.0 f16 B=16, M=128, H=16, K=32 | 600.3 | 667.0 | 671.3 | 713.1 f32 B=16, M=128, H=16, K=32 | 823.9 | 888.9 | 823.5 | 937.3 f16 B=16, M=128, H=16, K=64 | 781.0 | 900.6 | 883.1 | 998.9 f32 B=16, M=128, H=16, K=64 | 1173.7 | 1293.8 | 1077.0 | 1393.4 f16 B=16, M=128, H=16, K=128 | 1649.2 | 1877.2 | 1323.2 | 2026.3 f32 B=16, M=128, H=16, K=128 | 2250.5 | 2473.0 | 1654.7 | 2636.6 f16 B=16, M=512, H=16, K=16 | 7709.3 | 7914.6 | 6945.1 | 7928.7 f32 B=16, M=512, H=16, K=16 | 9797.2 | 9950.5 | 8499.4 | 10029.3 f16 B=16, M=512, H=16, K=32 | 8956.9 | 9210.8 | 7517.1 | 9307.0 f32 B=16, M=512, H=16, K=32 | 11480.7 | 11710.9 | 9249.4 | 11884.4 f16 B=16, M=512, H=16, K=64 | 11324.0 | 11829.1 | 8849.5 | 12001.8 f32 B=16, M=512, H=16, K=64 | 15744.1 | 16258.0 | 10954.6 | 16481.1 f16 B=16, M=512, H=16, K=128 | 25320.2 | 26584.0 | 12412.3 | 26725.0 f32 B=16, M=512, H=16, K=128 | 31187.1 | 32290.3 | 15167.5 | 32818.4 f16 B=16, M=1024, H=16, K=16 | 31484.2 | 31601.4 | 26434.6 | 31894.6 f32 B=16, M=1024, H=16, K=16 | 38754.1 | 38900.1 | 32320.0 | 39203.9 f16 B=16, M=1024, H=16, K=32 | 36000.2 | 36672.6 | 28341.4 | 36579.5 f32 B=16, M=1024, H=16, K=32 | 45070.7 | 45262.3 | 34914.2 | 45774.5 f16 B=16, M=1024, H=16, K=64 | 45324.9 | 46540.4 | 32089.9 | 46784.2 f32 B=16, M=1024, H=16, K=64 | 61320.3 | 62411.1 | 39565.0 | 63217.0 f16 B=16, M=1024, H=16, K=128 | 104342.9 | 108469.4 | 43221.9 | 105620.6 f32 B=16, M=1024, H=16, K=128 | 122688.4 | 125050.9 | 51205.7 | 126080.9 f16 B=64, M=128, H=16, K=16 | 1707.9 | 1824.9 | 2106.4 | 1923.2 f32 B=64, M=128, H=16, K=16 | 2487.4 | 2612.5 | 2565.1 | 2707.6 f16 B=64, M=128, H=16, K=32 | 2016.8 | 2254.4 | 2485.4 | 2412.3 f32 B=64, M=128, H=16, K=32 | 3135.8 | 3365.6 | 3063.2 | 3518.5 f16 B=64, M=128, H=16, K=64 | 2700.2 | 3167.0 | 3306.0 | 3478.4 f32 B=64, M=128, H=16, K=64 | 4435.1 | 4944.7 | 4227.6 | 5181.2 f16 B=64, M=128, H=16, K=128 | 5769.1 | 6858.2 | 5299.8 | 7356.1 f32 B=64, M=128, H=16, K=128 | 8577.9 | 9672.0 | 6916.3 | 10093.5 f16 B=64, M=512, H=16, K=16 | 25994.0 | 26782.0 | 27240.9 | 26662.2 f32 B=64, M=512, H=16, K=16 | 36864.9 | 37299.3 | 34159.3 | 37576.7 f16 B=64, M=512, H=16, K=32 | 30680.4 | 32113.8 | 30109.0 | 32419.7 f32 B=64, M=512, H=16, K=32 | 43638.5 | 44557.9 | 37358.5 | 45145.0 f16 B=64, M=512, H=16, K=64 | 39417.5 | 41666.5 | 36004.2 | 42374.9 f32 B=64, M=512, H=16, K=64 | 60049.2 | 63148.0 | 43412.6 | 63286.8 f16 B=64, M=512, H=16, K=128 | 88951.1 | 93087.0 | 51730.1 | 94861.6 f32 B=64, M=512, H=16, K=128 | 119728.7 | 124340.3 | 62413.7 | 126382.2 f16 B=64, M=1024, H=16, K=16 | 108368.3 | 111081.8 | 106479.7 | 108716.1 f32 B=64, M=1024, H=16, K=16 | 145612.0 | 147310.4 | | 147380.7 f16 B=64, M=1024, H=16, K=32 | 124296.1 | 127366.8 | 113905.0 | 126975.3 f32 B=64, M=1024, H=16, K=32 | 171082.3 | 172539.0 | | 173893.9 f16 B=64, M=1024, H=16, K=64 | 155116.3 | 160429.2 | 130759.4 | 161834.0 f32 B=64, M=1024, H=16, K=64 | 234356.0 | 239612.2 | | 239948.3 f16 B=64, M=1024, H=16, K=128 | 349728.3 | 360975.7 | 176158.7 | 371185.2 f32 B=64, M=1024, H=16, K=128 | 468810.0 | 476415.4 | | 481908.5 (Tesla_V100_SXM2_16GB) f16 B=384, M=197, H=1, K=88 | 1700.3 | 1840.0 | 1375.3 | 1930.9 f32 B=384, M=197, H=1, K=88 | 4456.4 | 4579.3 | 2235.5 | 4708.6 f16 B=384, M=197, H=1, K=80 | 1623.3 | 1719.9 | 1279.5 | 1806.9 f32 B=384, M=197, H=1, K=80 | 4031.2 | 4141.9 | 2149.8 | 4252.6 f16 B=384, M=197, H=1, K=64 | 1092.8 | 1187.0 | 1048.5 | 1237.6 f32 B=384, M=197, H=1, K=64 | 2717.5 | 2918.5 | 1738.5 | 2907.9 f16 B=1024, M=197, H=1, K=88 | 4428.7 | 4906.2 | 3723.7 | 5178.2 f32 B=1024, M=197, H=1, K=88 | 10947.5 | 11362.9 | 6052.5 | 11802.1 f16 B=1024, M=197, H=1, K=80 | 4237.1 | 4491.4 | 3331.7 | 4725.6 f32 B=1024, M=197, H=1, K=80 | 9842.6 | 10159.7 | 5682.4 | 10435.6 f16 B=1024, M=197, H=1, K=64 | 2679.2 | 2927.4 | 2674.4 | 3033.0 f32 B=1024, M=197, H=1, K=64 | 6597.6 | 7154.9 | 4489.7 | 7063.1 f16 B=512, M=197, H=1, K=80 | 2239.5 | 2366.5 | 1684.2 | 2472.0 f32 B=512, M=197, H=1, K=80 | 5362.4 | 5519.6 | 2857.9 | 5651.4 f16 B=32, M=197, H=16, K=80 | 2208.1 | 2380.0 | 1803.4 | 2439.4 f32 B=32, M=197, H=16, K=80 | 5503.6 | 5736.7 | 3017.5 | 5796.2 f16 B=32, M=197, H=16, K=64 | 1493.4 | 1620.6 | 1457.2 | 1678.6 f32 B=32, M=197, H=16, K=64 | 3672.6 | 3941.6 | 2415.0 | 3898.2 f16 B=32, M=197, H=16, K=128 | 2634.3 | 2888.0 | 2215.1 | 2991.5 f32 B=32, M=197, H=16, K=128 | 6811.5 | 7334.0 | 4049.3 | 7261.9 f16 B=256, M=197, H=1, K=88 | 1290.3 | 1382.0 | 944.8 | 1449.4 f32 B=256, M=197, H=1, K=88 | 2965.8 | 3043.2 | 1528.7 | 3137.7 f16 B=16, M=197, H=16, K=88 | 1267.3 | 1357.0 | 970.8 | 1395.5 f32 B=16, M=197, H=16, K=88 | 2879.9 | 3014.7 | 1626.5 | 3054.3 f16 B=16, M=197, H=16, K=64 | 737.3 | 799.8 | 771.3 | 836.9 f32 B=16, M=197, H=16, K=64 | 1879.2 | 2000.9 | 1282.5 | 1994.5 f16 B=16, M=197, H=16, K=128 | 1443.9 | 1570.7 | 1142.2 | 1628.8 f32 B=16, M=197, H=16, K=128 | 3480.5 | 3723.6 | 2027.2 | 3714.6 f16 B=1, M=4096, H=160, K=128 | 150006.2 | 151877.5 | | 152570.6 f32 B=1, M=4096, H=160, K=128 | 582870.9 | 583519.8 | | 585570.1 f16 B=2, M=4096, H=160, K=128 | 301231.4 | 304511.7 | | 305801.2 f32 B=2, M=4096, H=160, K=128 | 1174724.1 | 1172498.4 | | 1176814.0 f16 B=1, M=8192, H=160, K=128 | 597461.6 | 600463.4 | | 603066.6 f32 B=1, M=8192, H=160, K=128 | 2333657.8 | 2329212.1 | | 2339766.1 f16 B=2, M=8192, H=160, K=128 | 1196837.5 | 1206932.4 | | 1209012.2 f16 B=1024, M=82, H=8, K=64 | 8926.8 | 9723.4 | 5799.4 | 10084.2 f32 B=1024, M=82, H=8, K=64 | 15920.4 | 17434.4 | 11027.0 | 17492.8 f16 B=150, M=256, H=16, K=64 | 5524.2 | 6363.9 | 7557.9 | 6586.2 f32 B=150, M=256, H=16, K=64 | 17506.9 | 18843.5 | 16263.5 | 18988.6 f16 B=64, M=256, H=12, K=64 | 1800.6 | 2050.3 | 2383.4 | 2139.0 f32 B=64, M=256, H=12, K=64 | 5753.6 | 6196.3 | 4971.2 | 6200.0 f16 B=1, M=4096, H=16, K=40 | 47649.5 | 47836.0 | 8368.4 | 47973.6 f32 B=1, M=4096, H=16, K=40 | 111092.1 | 111027.3 | 19475.9 | 111257.8 f16 B=1, M=16384, H=16, K=40 | 765320.2 | 765686.9 | | 767337.2 f32 B=1, M=16384, H=16, K=40 | 1769169.0 | 1769675.1 | | 1769371.4 f16 B=16, M=128, H=16, K=16 | 178.9 | 196.8 | 445.9 | 188.3 f32 B=16, M=128, H=16, K=16 | 301.3 | 319.1 | 422.5 | 336.3 f16 B=16, M=128, H=16, K=32 | 174.1 | 174.2 | 394.0 | 179.5 f32 B=16, M=128, H=16, K=32 | 395.7 | 433.2 | 580.0 | 440.4 f16 B=16, M=128, H=16, K=64 | 205.0 | 253.5 | 460.6 | 270.9 f32 B=16, M=128, H=16, K=64 | 573.7 | 639.3 | 598.1 | 656.1 f16 B=16, M=128, H=16, K=128 | 399.5 | 484.3 | 515.2 | 521.8 f32 B=16, M=128, H=16, K=128 | 1126.3 | 1260.8 | 1008.1 | 1282.4 f16 B=16, M=512, H=16, K=16 | 1597.6 | 1627.2 | 1901.1 | 1662.1 f32 B=16, M=512, H=16, K=16 | 4458.5 | 4528.8 | 4232.0 | 4559.4 f16 B=16, M=512, H=16, K=32 | 1819.1 | 1868.7 | 2097.2 | 1945.5 f32 B=16, M=512, H=16, K=32 | 5604.2 | 5757.1 | 4566.4 | 5784.8 f16 B=16, M=512, H=16, K=64 | 2345.5 | 2495.6 | 2558.0 | 2573.2 f32 B=16, M=512, H=16, K=64 | 7778.3 | 8017.1 | 5488.2 | 8083.7 f16 B=16, M=512, H=16, K=128 | 4516.6 | 4821.0 | 3386.7 | 4968.2 f32 B=16, M=512, H=16, K=128 | 15412.7 | 15959.2 | 8865.9 | 16047.5 f16 B=16, M=1024, H=16, K=16 | 6195.9 | 6217.6 | 6995.3 | 6326.4 f32 B=16, M=1024, H=16, K=16 | 18136.2 | 18312.0 | 16088.2 | 18354.1 f16 B=16, M=1024, H=16, K=32 | 7072.8 | 7122.3 | 7406.9 | 7297.7 f32 B=16, M=1024, H=16, K=32 | 22108.2 | 22116.7 | 17112.5 | 22436.8 f16 B=16, M=1024, H=16, K=64 | 8868.0 | 9104.6 | 8627.1 | 9311.8 f32 B=16, M=1024, H=16, K=64 | 30710.5 | 31041.3 | 19860.8 | 31338.1 f16 B=16, M=1024, H=16, K=128 | 17091.8 | 17655.5 | 10548.3 | 18083.8 f32 B=16, M=1024, H=16, K=128 | 60317.8 | 61461.7 | 32919.2 | 61548.8 f16 B=64, M=128, H=16, K=16 | 413.6 | 453.8 | 635.5 | 480.6 f32 B=64, M=128, H=16, K=16 | 1033.8 | 1114.3 | 1238.9 | 1119.5 f16 B=64, M=128, H=16, K=32 | 505.7 | 587.9 | 813.6 | 630.1 f32 B=64, M=128, H=16, K=32 | 1423.0 | 1551.4 | 1533.4 | 1581.8 f16 B=64, M=128, H=16, K=64 | 743.3 | 916.8 | 1187.7 | 976.5 f32 B=64, M=128, H=16, K=64 | 2093.3 | 2384.6 | 2156.3 | 2405.4 f16 B=64, M=128, H=16, K=128 | 1408.2 | 1734.3 | 1918.7 | 1859.6 f32 B=64, M=128, H=16, K=128 | 4125.3 | 4671.4 | 3762.0 | 4717.0 f16 B=64, M=512, H=16, K=16 | 5531.2 | 5643.3 | 7454.4 | 5770.8 f32 B=64, M=512, H=16, K=16 | 16214.0 | 16531.2 | 16661.3 | 16540.8 f16 B=64, M=512, H=16, K=32 | 6495.5 | 6725.2 | 8353.7 | 6941.8 f32 B=64, M=512, H=16, K=32 | 20520.6 | 20941.9 | 18352.4 | 21116.8 f16 B=64, M=512, H=16, K=64 | 8686.1 | 9278.6 | 10343.4 | 9593.2 f32 B=64, M=512, H=16, K=64 | 28891.1 | 30003.0 | 22749.4 | 30139.1 f16 B=64, M=512, H=16, K=128 | 15991.4 | 17412.3 | 14633.0 | 17848.2 f32 B=64, M=512, H=16, K=128 | 57526.8 | 59970.8 | 40089.9 | 60016.9 f16 B=64, M=1024, H=16, K=16 | 21552.8 | 21603.1 | 28447.1 | 22030.0 f32 B=64, M=1024, H=16, K=16 | 65321.2 | 65736.8 | | 65932.0 f16 B=64, M=1024, H=16, K=32 | 25695.4 | 25905.9 | 30592.1 | 26644.8 f32 B=64, M=1024, H=16, K=32 | 80213.4 | 80446.7 | | 81363.1 f16 B=64, M=1024, H=16, K=64 | 32465.6 | 33575.1 | 37233.4 | 34370.8 f32 B=64, M=1024, H=16, K=64 | 112996.7 | 115632.0 | | 115970.8 f16 B=64, M=1024, H=16, K=128 | 60363.5 | 62800.2 | 48883.7 | 64505.1 f32 B=64, M=1024, H=16, K=128 | 225023.4 | 230527.4 | | 229851.8 Times are in microseconds (us). ``` </details> [ghstack-poisoned]
**SUMMARY** Also: - updated benchmarks to reflect the real-world scenario where we chunk a packed qkv - for both fw/bw. - added coverage for chunking in tests **PERFORMANCE IMPACT** <details> <summary>A100 bw (new benchmarks)</summary> ``` [---------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------] | 48_chunk3_31735f9 | 45_bwpacked_e53c5f3 | vanilla | 47_bwpackedgrad_9bacdf6 1 threads: -------------------------------------------------------------------------------------------------------------- f16 B=384, M=197, H=1, K=88 | 560.7 | 663.9 | 2265.7 | 710.3 f32 B=384, M=197, H=1, K=88 | 2445.1 | 2540.3 | 1843.3 | 2611.0 f16 B=384, M=197, H=1, K=80 | 530.4 | 619.9 | 1922.8 | 663.0 f32 B=384, M=197, H=1, K=80 | 2326.1 | 2425.2 | 1788.7 | 2476.4 f16 B=384, M=197, H=1, K=64 | 391.7 | 462.2 | 1812.7 | 492.8 f32 B=384, M=197, H=1, K=64 | 1275.0 | 1379.4 | 1675.4 | 1388.4 f16 B=1024, M=197, H=1, K=88 | 1399.5 | 1666.2 | 5965.2 | 1775.5 f32 B=1024, M=197, H=1, K=88 | 6332.5 | 6618.1 | 4559.6 | 6740.5 f16 B=1024, M=197, H=1, K=80 | 1326.2 | 1543.9 | 5041.4 | 1652.3 f32 B=1024, M=197, H=1, K=80 | 6057.1 | 6301.3 | 4411.6 | 6433.6 f16 B=1024, M=197, H=1, K=64 | 876.9 | 1063.1 | 4749.3 | 1133.2 f32 B=1024, M=197, H=1, K=64 | 3360.2 | 3629.0 | 4118.8 | 3652.0 f16 B=512, M=197, H=1, K=80 | 669.0 | 786.4 | 2544.9 | 842.2 f32 B=512, M=197, H=1, K=80 | 3032.3 | 3127.8 | 2287.4 | 3229.8 f16 B=32, M=197, H=16, K=80 | 663.0 | 789.7 | 2569.0 | 837.8 f32 B=32, M=197, H=16, K=80 | 3005.5 | 3166.3 | 2354.1 | 3225.9 f16 B=32, M=197, H=16, K=64 | 459.9 | 553.4 | 2436.3 | 591.9 f32 B=32, M=197, H=16, K=64 | 1814.1 | 1962.5 | 2197.3 | 1962.1 f16 B=32, M=197, H=16, K=128 | 792.5 | 981.9 | 4505.9 | 1056.5 f32 B=32, M=197, H=16, K=128 | 3734.8 | 3995.7 | 2805.8 | 4021.5 f16 B=256, M=197, H=1, K=88 | 413.4 | 482.6 | 1529.5 | 515.5 f32 B=256, M=197, H=1, K=88 | 1741.9 | 1818.3 | 1208.6 | 1852.4 f16 B=16, M=197, H=16, K=88 | 410.3 | 482.9 | 1545.7 | 512.5 f32 B=16, M=197, H=16, K=88 | 1734.9 | 1832.1 | 1250.6 | 1849.4 f16 B=16, M=197, H=16, K=64 | 235.4 | 286.0 | 1247.1 | 305.3 f32 B=16, M=197, H=16, K=64 | 1077.1 | 1143.7 | 1125.9 | 1154.0 f16 B=16, M=197, H=16, K=128 | 455.4 | 554.1 | 2273.1 | 596.0 f32 B=16, M=197, H=16, K=128 | 2028.9 | 2164.5 | 1446.7 | 2175.0 f16 B=1, M=4096, H=160, K=128 | 62454.4 | 63474.5 | 45930.5 | 64052.7 f32 B=1, M=4096, H=160, K=128 | 239035.4 | 232672.1 | | 240073.9 f16 B=2, M=4096, H=160, K=128 | 98791.3 | 101006.4 | | 101942.0 f32 B=2, M=4096, H=160, K=128 | 375914.9 | 368050.6 | | 381280.4 f16 B=1, M=8192, H=160, K=128 | 248498.9 | 250066.9 | | 251500.4 f32 B=1, M=8192, H=160, K=128 | 945102.2 | 922549.3 | | 949256.4 f16 B=2, M=8192, H=160, K=128 | 389207.8 | 394486.6 | | 396190.4 f32 B=2, M=8192, H=160, K=128 | 1496334.3 | 1449974.3 | | 1502215.3 f16 B=1024, M=82, H=8, K=64 | 1872.4 | 2503.8 | 3819.8 | 2693.7 f32 B=1024, M=82, H=8, K=64 | 8734.3 | 9637.8 | 8732.9 | 9672.2 f16 B=150, M=256, H=16, K=64 | 2126.4 | 2713.4 | 4554.3 | 2880.8 f32 B=150, M=256, H=16, K=64 | 6214.3 | 7052.2 | 12943.2 | 7099.2 f16 B=64, M=256, H=12, K=64 | 741.2 | 930.1 | 1493.0 | 990.6 f32 B=64, M=256, H=12, K=64 | 2144.2 | 2408.5 | 4267.7 | 2433.8 f16 B=1, M=4096, H=16, K=40 | 24583.7 | 24224.8 | 4195.2 | 24500.2 f32 B=1, M=4096, H=16, K=40 | 72497.9 | 72070.8 | 17744.1 | 72393.0 f16 B=1, M=16384, H=16, K=40 | 451481.8 | 439027.7 | | 451499.9 f32 B=1, M=16384, H=16, K=40 | 1169509.1 | 1164880.1 | | 1169769.3 f16 B=256, M=4096, H=16, K=64 | 597391.6 | 625921.0 | | 610433.2 f16 B=16, M=128, H=16, K=16 | 93.1 | 126.7 | 241.2 | 132.3 f32 B=16, M=128, H=16, K=16 | 184.1 | 176.5 | 373.8 | 180.7 f16 B=16, M=128, H=16, K=32 | 127.9 | 126.3 | 241.4 | 106.7 f32 B=16, M=128, H=16, K=32 | 194.1 | 216.6 | 412.7 | 225.8 f16 B=16, M=128, H=16, K=64 | 131.4 | 126.8 | 239.8 | 134.5 f32 B=16, M=128, H=16, K=64 | 280.4 | 326.0 | 500.0 | 334.0 f16 B=16, M=128, H=16, K=128 | 175.6 | 236.1 | 298.8 | 261.1 f32 B=16, M=128, H=16, K=128 | 531.8 | 615.8 | 677.2 | 638.0 f16 B=16, M=512, H=16, K=16 | 558.2 | 595.0 | 1201.9 | 607.8 f32 B=16, M=512, H=16, K=16 | 2146.7 | 2169.9 | 4416.1 | 2200.6 f16 B=16, M=512, H=16, K=32 | 653.5 | 732.3 | 1305.1 | 748.5 f32 B=16, M=512, H=16, K=32 | 2296.3 | 2373.9 | 4641.3 | 2400.1 f16 B=16, M=512, H=16, K=64 | 848.8 | 996.9 | 1544.6 | 1022.5 f32 B=16, M=512, H=16, K=64 | 2954.0 | 3117.1 | 5124.7 | 3157.6 f16 B=16, M=512, H=16, K=128 | 1735.4 | 1961.1 | 1982.7 | 2056.9 f32 B=16, M=512, H=16, K=128 | 6218.7 | 6396.4 | 6094.0 | 6600.3 f16 B=16, M=1024, H=16, K=16 | 2236.4 | 2319.4 | 4279.0 | 2331.6 f32 B=16, M=1024, H=16, K=16 | 8379.2 | 8363.9 | 16643.9 | 8503.6 f16 B=16, M=1024, H=16, K=32 | 2430.8 | 2649.6 | 4496.8 | 2608.7 f32 B=16, M=1024, H=16, K=32 | 8864.7 | 8907.8 | 17291.0 | 9074.0 f16 B=16, M=1024, H=16, K=64 | 3007.2 | 3351.3 | 4995.5 | 3351.0 f32 B=16, M=1024, H=16, K=64 | 11355.4 | 11627.1 | 18707.5 | 11694.3 f16 B=16, M=1024, H=16, K=128 | 6296.2 | 6748.7 | 5943.5 | 6967.0 f32 B=16, M=1024, H=16, K=128 | 23425.3 | 23360.0 | 21520.6 | 24169.7 f16 B=64, M=128, H=16, K=16 | 165.5 | 195.9 | 440.3 | 211.5 f32 B=64, M=128, H=16, K=16 | 497.4 | 540.7 | 1270.8 | 550.3 f16 B=64, M=128, H=16, K=32 | 210.4 | 274.9 | 544.8 | 298.5 f32 B=64, M=128, H=16, K=32 | 604.4 | 696.6 | 1428.3 | 710.9 f16 B=64, M=128, H=16, K=64 | 330.4 | 452.3 | 766.0 | 498.1 f32 B=64, M=128, H=16, K=64 | 883.4 | 1060.4 | 1745.2 | 1082.2 f16 B=64, M=128, H=16, K=128 | 605.5 | 847.8 | 1223.6 | 933.9 f32 B=64, M=128, H=16, K=128 | 1847.4 | 2169.7 | 2388.8 | 2236.0 f16 B=64, M=512, H=16, K=16 | 2004.7 | 2120.0 | 4487.0 | 2179.4 f32 B=64, M=512, H=16, K=16 | 6655.4 | 6818.8 | 16993.8 | 6872.1 f16 B=64, M=512, H=16, K=32 | 2379.3 | 2593.1 | 4957.2 | 2704.0 f32 B=64, M=512, H=16, K=32 | 7349.4 | 7644.6 | 17852.2 | 7736.2 f16 B=64, M=512, H=16, K=64 | 3129.6 | 3616.6 | 5888.8 | 3786.2 f32 B=64, M=512, H=16, K=64 | 9432.5 | 10123.9 | 19770.6 | 10178.5 f16 B=64, M=512, H=16, K=128 | 6054.1 | 7019.9 | 7712.6 | 7350.2 f32 B=64, M=512, H=16, K=128 | 21565.6 | 22281.9 | 23653.0 | 23084.4 f16 B=64, M=1024, H=16, K=16 | 7929.4 | 8199.1 | 16876.3 | 8242.5 f32 B=64, M=1024, H=16, K=16 | 26135.2 | 26347.9 | 66351.1 | 26639.0 f16 B=64, M=1024, H=16, K=32 | 8876.8 | 9450.0 | 17869.4 | 9473.5 f32 B=64, M=1024, H=16, K=32 | 27685.3 | 28104.6 | 69105.9 | 28428.7 f16 B=64, M=1024, H=16, K=64 | 11198.7 | 12180.5 | 19932.3 | 12543.4 f32 B=64, M=1024, H=16, K=64 | 34978.2 | 36239.4 | 74813.7 | 36482.4 f16 B=64, M=1024, H=16, K=128 | 21618.9 | 23439.6 | 23741.1 | 24160.1 f32 B=64, M=1024, H=16, K=128 | 80785.3 | 81080.8 | 86003.6 | 84132.9 Times are in microseconds (us). ``` </details> <details> <summary>P100/V100 bw (new benchmarks)</summary> ``` [---------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------------------] | 48_chunk3_31735f94 | 45_bwpacked_e53c5f3a | vanilla | 47_bwpackedgrad_9bacdf65 1 threads: -------------------------------------------------------------------------------------------------------------------------------------- (Quadro_GP100) f16 B=384, M=197, H=1, K=88 | 6846.3 | 7583.8 | 3569.3 | 7599.5 f32 B=384, M=197, H=1, K=88 | 9883.1 | 10107.2 | 4312.8 | 10486.3 f16 B=384, M=197, H=1, K=80 | 6486.4 | 6997.7 | 3418.0 | 7037.3 f32 B=384, M=197, H=1, K=80 | 9330.3 | 9550.6 | 4094.7 | 9893.4 f16 B=384, M=197, H=1, K=64 | 3615.4 | 3930.4 | 2911.0 | 4074.2 f32 B=384, M=197, H=1, K=64 | 6281.4 | 6554.5 | 3431.9 | 6738.1 f16 B=1024, M=197, H=1, K=88 | 17226.8 | 18593.1 | 9733.2 | 18772.9 f32 B=1024, M=197, H=1, K=88 | 26593.3 | 27136.2 | 12033.8 | 28184.2 f16 B=1024, M=197, H=1, K=80 | 16330.1 | 17478.6 | 9270.2 | 17735.3 f32 B=1024, M=197, H=1, K=80 | 25208.9 | 25680.1 | 11224.5 | 26636.1 f16 B=1024, M=197, H=1, K=64 | 8889.1 | 9728.8 | 7646.1 | 10089.7 f32 B=1024, M=197, H=1, K=64 | 16914.7 | 17743.4 | 9383.8 | 18068.4 f16 B=512, M=197, H=1, K=80 | 8227.3 | 8878.4 | 4579.3 | 8953.6 f32 B=512, M=197, H=1, K=80 | 13078.7 | 13346.0 | 5486.4 | 13817.6 f16 B=32, M=197, H=16, K=80 | 8278.9 | 9002.9 | 4816.2 | 9025.6 f32 B=32, M=197, H=16, K=80 | 12913.8 | 13371.2 | 5777.7 | 13667.6 f16 B=32, M=197, H=16, K=64 | 4565.2 | 5000.0 | 4023.4 | 5146.3 f32 B=32, M=197, H=16, K=64 | 8824.0 | 9257.7 | 4797.2 | 9400.5 f16 B=32, M=197, H=16, K=128 | 9770.0 | 10849.7 | 5983.2 | 10932.0 f32 B=32, M=197, H=16, K=128 | 15715.2 | 16559.9 | 7513.6 | 16839.9 f16 B=256, M=197, H=1, K=88 | 5011.2 | 5363.8 | 2444.9 | 5426.0 f32 B=256, M=197, H=1, K=88 | 6918.7 | 7040.8 | 2867.8 | 7303.2 f16 B=16, M=197, H=16, K=88 | 4963.8 | 5343.9 | 2545.2 | 5398.9 f32 B=16, M=197, H=16, K=88 | 6727.9 | 6981.7 | 3040.3 | 7121.2 f16 B=16, M=197, H=16, K=64 | 2586.5 | 2777.1 | 2025.5 | 2905.6 f32 B=16, M=197, H=16, K=64 | 4404.3 | 4607.2 | 2431.1 | 4691.8 f16 B=16, M=197, H=16, K=128 | 5643.2 | 6194.1 | 3016.1 | 6216.3 f32 B=16, M=197, H=16, K=128 | 7887.1 | 8308.3 | 3676.6 | 8456.2 f16 B=1, M=4096, H=160, K=128 | 1087008.7 | 1115355.5 | | 1091596.8 f32 B=1, M=4096, H=160, K=128 | 1220066.8 | 1223422.8 | | 1227912.2 f16 B=2, M=4096, H=160, K=128 | 1734244.4 | 1794068.7 | | 1756266.7 f32 B=2, M=4096, H=160, K=128 | 2437675.5 | 2445780.4 | | 2451957.5 f16 B=1, M=8192, H=160, K=128 | 4367110.4 | 4466170.9 | | 4383747.4 f32 B=1, M=8192, H=160, K=128 | 4865732.9 | 4865708.9 | | 4887066.5 f16 B=2, M=8192, H=160, K=128 | 7002715.1 | 7146077.9 | | 7033922.8 f16 B=1024, M=82, H=8, K=64 | 23247.5 | 24929.5 | 18047.8 | 26928.2 f32 B=1024, M=82, H=8, K=64 | 46463.2 | 48705.6 | 22797.5 | 50736.3 f16 B=150, M=256, H=16, K=64 | 23467.9 | 25647.3 | 24569.2 | 26841.8 f32 B=150, M=256, H=16, K=64 | 36887.7 | 39698.0 | 32050.2 | 40389.0 f16 B=64, M=256, H=12, K=64 | 7723.7 | 8499.0 | 7702.1 | 8694.9 f32 B=64, M=256, H=12, K=64 | 11992.1 | 12819.9 | 9874.5 | 13107.9 f16 B=1, M=4096, H=16, K=40 | 142655.5 | 142899.7 | 28928.6 | 142922.7 f32 B=1, M=4096, H=16, K=40 | 142626.8 | 142685.3 | 37303.2 | 142541.0 f16 B=1, M=16384, H=16, K=40 | 2274095.0 | 2274882.0 | | 2275019.9 f32 B=1, M=16384, H=16, K=40 | 2284027.2 | 2279415.7 | | 2277761.9 f16 B=16, M=128, H=16, K=16 | 513.2 | 547.1 | 571.5 | 570.9 f32 B=16, M=128, H=16, K=16 | 667.4 | 704.3 | 693.1 | 728.0 f16 B=16, M=128, H=16, K=32 | 600.3 | 667.0 | 671.3 | 713.1 f32 B=16, M=128, H=16, K=32 | 823.9 | 888.9 | 823.5 | 937.3 f16 B=16, M=128, H=16, K=64 | 781.0 | 900.6 | 883.1 | 998.9 f32 B=16, M=128, H=16, K=64 | 1173.7 | 1293.8 | 1077.0 | 1393.4 f16 B=16, M=128, H=16, K=128 | 1649.2 | 1877.2 | 1323.2 | 2026.3 f32 B=16, M=128, H=16, K=128 | 2250.5 | 2473.0 | 1654.7 | 2636.6 f16 B=16, M=512, H=16, K=16 | 7709.3 | 7914.6 | 6945.1 | 7928.7 f32 B=16, M=512, H=16, K=16 | 9797.2 | 9950.5 | 8499.4 | 10029.3 f16 B=16, M=512, H=16, K=32 | 8956.9 | 9210.8 | 7517.1 | 9307.0 f32 B=16, M=512, H=16, K=32 | 11480.7 | 11710.9 | 9249.4 | 11884.4 f16 B=16, M=512, H=16, K=64 | 11324.0 | 11829.1 | 8849.5 | 12001.8 f32 B=16, M=512, H=16, K=64 | 15744.1 | 16258.0 | 10954.6 | 16481.1 f16 B=16, M=512, H=16, K=128 | 25320.2 | 26584.0 | 12412.3 | 26725.0 f32 B=16, M=512, H=16, K=128 | 31187.1 | 32290.3 | 15167.5 | 32818.4 f16 B=16, M=1024, H=16, K=16 | 31484.2 | 31601.4 | 26434.6 | 31894.6 f32 B=16, M=1024, H=16, K=16 | 38754.1 | 38900.1 | 32320.0 | 39203.9 f16 B=16, M=1024, H=16, K=32 | 36000.2 | 36672.6 | 28341.4 | 36579.5 f32 B=16, M=1024, H=16, K=32 | 45070.7 | 45262.3 | 34914.2 | 45774.5 f16 B=16, M=1024, H=16, K=64 | 45324.9 | 46540.4 | 32089.9 | 46784.2 f32 B=16, M=1024, H=16, K=64 | 61320.3 | 62411.1 | 39565.0 | 63217.0 f16 B=16, M=1024, H=16, K=128 | 104342.9 | 108469.4 | 43221.9 | 105620.6 f32 B=16, M=1024, H=16, K=128 | 122688.4 | 125050.9 | 51205.7 | 126080.9 f16 B=64, M=128, H=16, K=16 | 1707.9 | 1824.9 | 2106.4 | 1923.2 f32 B=64, M=128, H=16, K=16 | 2487.4 | 2612.5 | 2565.1 | 2707.6 f16 B=64, M=128, H=16, K=32 | 2016.8 | 2254.4 | 2485.4 | 2412.3 f32 B=64, M=128, H=16, K=32 | 3135.8 | 3365.6 | 3063.2 | 3518.5 f16 B=64, M=128, H=16, K=64 | 2700.2 | 3167.0 | 3306.0 | 3478.4 f32 B=64, M=128, H=16, K=64 | 4435.1 | 4944.7 | 4227.6 | 5181.2 f16 B=64, M=128, H=16, K=128 | 5769.1 | 6858.2 | 5299.8 | 7356.1 f32 B=64, M=128, H=16, K=128 | 8577.9 | 9672.0 | 6916.3 | 10093.5 f16 B=64, M=512, H=16, K=16 | 25994.0 | 26782.0 | 27240.9 | 26662.2 f32 B=64, M=512, H=16, K=16 | 36864.9 | 37299.3 | 34159.3 | 37576.7 f16 B=64, M=512, H=16, K=32 | 30680.4 | 32113.8 | 30109.0 | 32419.7 f32 B=64, M=512, H=16, K=32 | 43638.5 | 44557.9 | 37358.5 | 45145.0 f16 B=64, M=512, H=16, K=64 | 39417.5 | 41666.5 | 36004.2 | 42374.9 f32 B=64, M=512, H=16, K=64 | 60049.2 | 63148.0 | 43412.6 | 63286.8 f16 B=64, M=512, H=16, K=128 | 88951.1 | 93087.0 | 51730.1 | 94861.6 f32 B=64, M=512, H=16, K=128 | 119728.7 | 124340.3 | 62413.7 | 126382.2 f16 B=64, M=1024, H=16, K=16 | 108368.3 | 111081.8 | 106479.7 | 108716.1 f32 B=64, M=1024, H=16, K=16 | 145612.0 | 147310.4 | | 147380.7 f16 B=64, M=1024, H=16, K=32 | 124296.1 | 127366.8 | 113905.0 | 126975.3 f32 B=64, M=1024, H=16, K=32 | 171082.3 | 172539.0 | | 173893.9 f16 B=64, M=1024, H=16, K=64 | 155116.3 | 160429.2 | 130759.4 | 161834.0 f32 B=64, M=1024, H=16, K=64 | 234356.0 | 239612.2 | | 239948.3 f16 B=64, M=1024, H=16, K=128 | 349728.3 | 360975.7 | 176158.7 | 371185.2 f32 B=64, M=1024, H=16, K=128 | 468810.0 | 476415.4 | | 481908.5 (Tesla_V100_SXM2_16GB) f16 B=384, M=197, H=1, K=88 | 1700.3 | 1840.0 | 1375.3 | 1930.9 f32 B=384, M=197, H=1, K=88 | 4456.4 | 4579.3 | 2235.5 | 4708.6 f16 B=384, M=197, H=1, K=80 | 1623.3 | 1719.9 | 1279.5 | 1806.9 f32 B=384, M=197, H=1, K=80 | 4031.2 | 4141.9 | 2149.8 | 4252.6 f16 B=384, M=197, H=1, K=64 | 1092.8 | 1187.0 | 1048.5 | 1237.6 f32 B=384, M=197, H=1, K=64 | 2717.5 | 2918.5 | 1738.5 | 2907.9 f16 B=1024, M=197, H=1, K=88 | 4428.7 | 4906.2 | 3723.7 | 5178.2 f32 B=1024, M=197, H=1, K=88 | 10947.5 | 11362.9 | 6052.5 | 11802.1 f16 B=1024, M=197, H=1, K=80 | 4237.1 | 4491.4 | 3331.7 | 4725.6 f32 B=1024, M=197, H=1, K=80 | 9842.6 | 10159.7 | 5682.4 | 10435.6 f16 B=1024, M=197, H=1, K=64 | 2679.2 | 2927.4 | 2674.4 | 3033.0 f32 B=1024, M=197, H=1, K=64 | 6597.6 | 7154.9 | 4489.7 | 7063.1 f16 B=512, M=197, H=1, K=80 | 2239.5 | 2366.5 | 1684.2 | 2472.0 f32 B=512, M=197, H=1, K=80 | 5362.4 | 5519.6 | 2857.9 | 5651.4 f16 B=32, M=197, H=16, K=80 | 2208.1 | 2380.0 | 1803.4 | 2439.4 f32 B=32, M=197, H=16, K=80 | 5503.6 | 5736.7 | 3017.5 | 5796.2 f16 B=32, M=197, H=16, K=64 | 1493.4 | 1620.6 | 1457.2 | 1678.6 f32 B=32, M=197, H=16, K=64 | 3672.6 | 3941.6 | 2415.0 | 3898.2 f16 B=32, M=197, H=16, K=128 | 2634.3 | 2888.0 | 2215.1 | 2991.5 f32 B=32, M=197, H=16, K=128 | 6811.5 | 7334.0 | 4049.3 | 7261.9 f16 B=256, M=197, H=1, K=88 | 1290.3 | 1382.0 | 944.8 | 1449.4 f32 B=256, M=197, H=1, K=88 | 2965.8 | 3043.2 | 1528.7 | 3137.7 f16 B=16, M=197, H=16, K=88 | 1267.3 | 1357.0 | 970.8 | 1395.5 f32 B=16, M=197, H=16, K=88 | 2879.9 | 3014.7 | 1626.5 | 3054.3 f16 B=16, M=197, H=16, K=64 | 737.3 | 799.8 | 771.3 | 836.9 f32 B=16, M=197, H=16, K=64 | 1879.2 | 2000.9 | 1282.5 | 1994.5 f16 B=16, M=197, H=16, K=128 | 1443.9 | 1570.7 | 1142.2 | 1628.8 f32 B=16, M=197, H=16, K=128 | 3480.5 | 3723.6 | 2027.2 | 3714.6 f16 B=1, M=4096, H=160, K=128 | 150006.2 | 151877.5 | | 152570.6 f32 B=1, M=4096, H=160, K=128 | 582870.9 | 583519.8 | | 585570.1 f16 B=2, M=4096, H=160, K=128 | 301231.4 | 304511.7 | | 305801.2 f32 B=2, M=4096, H=160, K=128 | 1174724.1 | 1172498.4 | | 1176814.0 f16 B=1, M=8192, H=160, K=128 | 597461.6 | 600463.4 | | 603066.6 f32 B=1, M=8192, H=160, K=128 | 2333657.8 | 2329212.1 | | 2339766.1 f16 B=2, M=8192, H=160, K=128 | 1196837.5 | 1206932.4 | | 1209012.2 f16 B=1024, M=82, H=8, K=64 | 8926.8 | 9723.4 | 5799.4 | 10084.2 f32 B=1024, M=82, H=8, K=64 | 15920.4 | 17434.4 | 11027.0 | 17492.8 f16 B=150, M=256, H=16, K=64 | 5524.2 | 6363.9 | 7557.9 | 6586.2 f32 B=150, M=256, H=16, K=64 | 17506.9 | 18843.5 | 16263.5 | 18988.6 f16 B=64, M=256, H=12, K=64 | 1800.6 | 2050.3 | 2383.4 | 2139.0 f32 B=64, M=256, H=12, K=64 | 5753.6 | 6196.3 | 4971.2 | 6200.0 f16 B=1, M=4096, H=16, K=40 | 47649.5 | 47836.0 | 8368.4 | 47973.6 f32 B=1, M=4096, H=16, K=40 | 111092.1 | 111027.3 | 19475.9 | 111257.8 f16 B=1, M=16384, H=16, K=40 | 765320.2 | 765686.9 | | 767337.2 f32 B=1, M=16384, H=16, K=40 | 1769169.0 | 1769675.1 | | 1769371.4 f16 B=16, M=128, H=16, K=16 | 178.9 | 196.8 | 445.9 | 188.3 f32 B=16, M=128, H=16, K=16 | 301.3 | 319.1 | 422.5 | 336.3 f16 B=16, M=128, H=16, K=32 | 174.1 | 174.2 | 394.0 | 179.5 f32 B=16, M=128, H=16, K=32 | 395.7 | 433.2 | 580.0 | 440.4 f16 B=16, M=128, H=16, K=64 | 205.0 | 253.5 | 460.6 | 270.9 f32 B=16, M=128, H=16, K=64 | 573.7 | 639.3 | 598.1 | 656.1 f16 B=16, M=128, H=16, K=128 | 399.5 | 484.3 | 515.2 | 521.8 f32 B=16, M=128, H=16, K=128 | 1126.3 | 1260.8 | 1008.1 | 1282.4 f16 B=16, M=512, H=16, K=16 | 1597.6 | 1627.2 | 1901.1 | 1662.1 f32 B=16, M=512, H=16, K=16 | 4458.5 | 4528.8 | 4232.0 | 4559.4 f16 B=16, M=512, H=16, K=32 | 1819.1 | 1868.7 | 2097.2 | 1945.5 f32 B=16, M=512, H=16, K=32 | 5604.2 | 5757.1 | 4566.4 | 5784.8 f16 B=16, M=512, H=16, K=64 | 2345.5 | 2495.6 | 2558.0 | 2573.2 f32 B=16, M=512, H=16, K=64 | 7778.3 | 8017.1 | 5488.2 | 8083.7 f16 B=16, M=512, H=16, K=128 | 4516.6 | 4821.0 | 3386.7 | 4968.2 f32 B=16, M=512, H=16, K=128 | 15412.7 | 15959.2 | 8865.9 | 16047.5 f16 B=16, M=1024, H=16, K=16 | 6195.9 | 6217.6 | 6995.3 | 6326.4 f32 B=16, M=1024, H=16, K=16 | 18136.2 | 18312.0 | 16088.2 | 18354.1 f16 B=16, M=1024, H=16, K=32 | 7072.8 | 7122.3 | 7406.9 | 7297.7 f32 B=16, M=1024, H=16, K=32 | 22108.2 | 22116.7 | 17112.5 | 22436.8 f16 B=16, M=1024, H=16, K=64 | 8868.0 | 9104.6 | 8627.1 | 9311.8 f32 B=16, M=1024, H=16, K=64 | 30710.5 | 31041.3 | 19860.8 | 31338.1 f16 B=16, M=1024, H=16, K=128 | 17091.8 | 17655.5 | 10548.3 | 18083.8 f32 B=16, M=1024, H=16, K=128 | 60317.8 | 61461.7 | 32919.2 | 61548.8 f16 B=64, M=128, H=16, K=16 | 413.6 | 453.8 | 635.5 | 480.6 f32 B=64, M=128, H=16, K=16 | 1033.8 | 1114.3 | 1238.9 | 1119.5 f16 B=64, M=128, H=16, K=32 | 505.7 | 587.9 | 813.6 | 630.1 f32 B=64, M=128, H=16, K=32 | 1423.0 | 1551.4 | 1533.4 | 1581.8 f16 B=64, M=128, H=16, K=64 | 743.3 | 916.8 | 1187.7 | 976.5 f32 B=64, M=128, H=16, K=64 | 2093.3 | 2384.6 | 2156.3 | 2405.4 f16 B=64, M=128, H=16, K=128 | 1408.2 | 1734.3 | 1918.7 | 1859.6 f32 B=64, M=128, H=16, K=128 | 4125.3 | 4671.4 | 3762.0 | 4717.0 f16 B=64, M=512, H=16, K=16 | 5531.2 | 5643.3 | 7454.4 | 5770.8 f32 B=64, M=512, H=16, K=16 | 16214.0 | 16531.2 | 16661.3 | 16540.8 f16 B=64, M=512, H=16, K=32 | 6495.5 | 6725.2 | 8353.7 | 6941.8 f32 B=64, M=512, H=16, K=32 | 20520.6 | 20941.9 | 18352.4 | 21116.8 f16 B=64, M=512, H=16, K=64 | 8686.1 | 9278.6 | 10343.4 | 9593.2 f32 B=64, M=512, H=16, K=64 | 28891.1 | 30003.0 | 22749.4 | 30139.1 f16 B=64, M=512, H=16, K=128 | 15991.4 | 17412.3 | 14633.0 | 17848.2 f32 B=64, M=512, H=16, K=128 | 57526.8 | 59970.8 | 40089.9 | 60016.9 f16 B=64, M=1024, H=16, K=16 | 21552.8 | 21603.1 | 28447.1 | 22030.0 f32 B=64, M=1024, H=16, K=16 | 65321.2 | 65736.8 | | 65932.0 f16 B=64, M=1024, H=16, K=32 | 25695.4 | 25905.9 | 30592.1 | 26644.8 f32 B=64, M=1024, H=16, K=32 | 80213.4 | 80446.7 | | 81363.1 f16 B=64, M=1024, H=16, K=64 | 32465.6 | 33575.1 | 37233.4 | 34370.8 f32 B=64, M=1024, H=16, K=64 | 112996.7 | 115632.0 | | 115970.8 f16 B=64, M=1024, H=16, K=128 | 60363.5 | 62800.2 | 48883.7 | 64505.1 f32 B=64, M=1024, H=16, K=128 | 225023.4 | 230527.4 | | 229851.8 Times are in microseconds (us). ``` </details> [ghstack-poisoned]
ghstack-source-id: e18e9b73589eac1e003c0b224bbff03c7fbb6445 Pull Request resolved: #458
Stack from ghstack (oldest at bottom):
SUMMARY
Also:
PERFORMANCE IMPACT
A100 bw (new benchmarks)
P100/V100 bw (new benchmarks)