Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MMA_FROM_SMEM_IT_RES] Change how residuals are handled #393

Merged
merged 1 commit into from
Sep 15, 2022
Merged

Conversation

danthe3rd
Copy link
Contributor

@danthe3rd danthe3rd commented Sep 15, 2022

Changes how we deal with residuals for the right-hand operand of MM from shared memory. This is used on the fw and bw pass, and enables to support arbitrary sequence lengths.

image

P100/V100 bw (causal)
[-------------- attention backward (attn_bias=<class 'xformers.ops.LowerTriangularMask'>) ---------------]
                                                             |  a10cd6d8_it2  |  f527a89c_main  |  vanilla
1 threads: -----------------------------------------------------------------------------------------------
  (Quadro_GP100)          f32 fwd_gen B=256, M=128, K=16     |       491.5    |        492.0    |    617.9
                          f16 fwd_gen B=256, M=128, K=16     |       421.3    |        420.8    |    503.1
                          f32 fwd_gen B=256, M=128, K=32     |       601.2    |        601.9    |    696.0
                          f16 fwd_gen B=256, M=128, K=32     |       511.6    |        508.4    |    553.4
                          f32 fwd_gen B=256, M=128, K=64     |       841.3    |        850.2    |    850.9
                          f16 fwd_gen B=256, M=128, K=64     |       742.1    |        737.6    |    676.7
                          f32 fwd_gen B=256, M=128, K=128    |      1702.2    |       1711.5    |   1224.1
                          f16 fwd_gen B=256, M=128, K=128    |      1423.3    |       1426.1    |    964.2
                          f32 fwd_gen B=256, M=128, K=256    |      3639.1    |       3601.9    |   2298.3
                          f16 fwd_gen B=256, M=128, K=256    |      2816.3    |       2799.7    |   1754.8
                          f32 fwd_gen B=256, M=512, K=16     |      5445.3    |       5460.3    |   8198.0
                          f16 fwd_gen B=256, M=512, K=16     |      4839.9    |       4771.2    |   6605.2
                          f32 fwd_gen B=256, M=512, K=32     |      6428.9    |       6335.3    |   8687.9
                          f16 fwd_gen B=256, M=512, K=32     |      5554.2    |       5499.4    |   7055.3
                          f32 fwd_gen B=256, M=512, K=64     |      8840.8    |       8815.6    |   9725.9
                          f16 fwd_gen B=256, M=512, K=64     |      7330.8    |       7368.8    |   8046.3
                          f32 fwd_gen B=256, M=512, K=128    |     18242.3    |      18220.8    |  13041.1
                          f16 fwd_gen B=256, M=512, K=128    |     14627.2    |      14588.6    |  10608.1
                          f32 fwd_gen B=256, M=512, K=256    |     38261.6    |      38295.7    |  22055.6
                          f16 fwd_gen B=256, M=512, K=256    |     29581.4    |      29473.5    |  19099.3
                          f32 fwd_gen B=256, M=1024, K=16    |     20568.3    |      20742.2    |  31831.5
                          f16 fwd_gen B=256, M=1024, K=16    |     17906.7    |      17841.9    |  25815.6
                          f32 fwd_gen B=256, M=1024, K=32    |     23494.5    |      23356.2    |  34184.1
                          f16 fwd_gen B=256, M=1024, K=32    |     20496.6    |      20450.6    |  27308.5
                          f32 fwd_gen B=256, M=1024, K=64    |     31756.2    |      31767.4    |  37537.8
                          f16 fwd_gen B=256, M=1024, K=64    |     26557.4    |      26525.8    |  30106.0
                          f32 fwd_gen B=256, M=1024, K=128   |     68112.5    |      69320.7    |  46237.6
                          f16 fwd_gen B=256, M=1024, K=128   |     53740.2    |      53276.6    |  38945.4
                          f32 fwd_gen B=256, M=1024, K=256   |    141404.8    |     140386.0    |  82240.1
                          f16 fwd_gen B=256, M=1024, K=256   |    114221.2    |     111502.8    |  71721.4
                          f32 fwd_gen B=384, M=192, K=88     |      3898.1    |       3892.3    |   3088.5
                          f16 fwd_gen B=384, M=192, K=88     |      3113.6    |       3096.6    |   2585.9
                          f32 fwd_gen B=768, M=256, K=64     |      7524.9    |       7493.2    |   7958.8
                          f16 fwd_gen B=768, M=256, K=64     |      5551.7    |       5508.8    |   6479.4
                          f32 fwd_gen B=1024, M=128, K=16    |      1839.4    |       1845.4    |   2281.5
                          f16 fwd_gen B=1024, M=128, K=16    |      1458.4    |       1449.3    |   1841.8
                          f32 fwd_gen B=1024, M=128, K=32    |      2227.7    |       2229.4    |   2552.9
                          f16 fwd_gen B=1024, M=128, K=32    |      1803.0    |       1797.9    |   2053.4
                          f32 fwd_gen B=1024, M=128, K=64    |      3211.3    |       3209.2    |   3201.2
                          f16 fwd_gen B=1024, M=128, K=64    |      2519.9    |       2505.5    |   2509.3
                          f32 fwd_gen B=1024, M=128, K=128   |      6654.5    |       6687.9    |   4734.8
                          f16 fwd_gen B=1024, M=128, K=128   |      4974.6    |       4953.8    |   3633.1
                          f32 fwd_gen B=1024, M=128, K=256   |     14083.0    |      14076.6    |   9151.0
                          f16 fwd_gen B=1024, M=128, K=256   |     10069.9    |      10112.1    |   6837.4
                          f32 fwd_gen B=1024, M=512, K=16    |     20592.3    |      20653.7    |  32117.0
                          f16 fwd_gen B=1024, M=512, K=16    |     16158.9    |      16017.6    |  26148.7
                          f32 fwd_gen B=1024, M=512, K=32    |     23975.4    |      23959.9    |  34328.8
                          f16 fwd_gen B=1024, M=512, K=32    |     18907.0    |      18695.0    |  27836.2
                          f32 fwd_gen B=1024, M=512, K=64    |     33564.5    |      33562.3    |  39097.4
                          f16 fwd_gen B=1024, M=512, K=64    |     25553.2    |      25478.8    |  31379.5
                          f32 fwd_gen B=1024, M=512, K=128   |     72372.6    |      70788.2    |  50787.9
                          f16 fwd_gen B=1024, M=512, K=128   |     51198.4    |      51024.4    |  41799.5
                          f32 fwd_gen B=1024, M=512, K=256   |    148213.2    |     148020.5    |  90162.0
                          f16 fwd_gen B=1024, M=512, K=256   |    108821.8    |     105928.6    |  78079.1
                          f32 fwd_gen B=1024, M=1024, K=16   |     77081.4    |      77306.8    |         
                          f16 fwd_gen B=1024, M=1024, K=16   |     60747.4    |      59802.1    |         
                          f32 fwd_gen B=1024, M=1024, K=32   |     89675.4    |      88745.7    |         
                          f16 fwd_gen B=1024, M=1024, K=32   |     69476.2    |      69799.9    |         
                          f32 fwd_gen B=1024, M=1024, K=64   |    123574.5    |     123222.1    |         
                          f16 fwd_gen B=1024, M=1024, K=64   |     92629.0    |      91552.0    |         
                          f32 fwd_gen B=1024, M=1024, K=128  |    254993.6    |     256413.6    |         
                          f16 fwd_gen B=1024, M=1024, K=128  |    186495.9    |     185232.1    |         
                          f32 fwd_gen B=1024, M=1024, K=256  |    528065.3    |     528882.8    |         
                          f16 fwd_gen B=1024, M=1024, K=256  |    373521.3    |     372821.6    |         
                          f32 fwd_gen B=2400, M=256, K=64    |     22872.9    |      22901.1    |  24891.9
                          f16 fwd_gen B=2400, M=256, K=64    |     17271.5    |      17289.8    |  19963.9
                          f32 fwd_gen B=8192, M=82, K=64     |     20708.5    |      21386.1    |  16303.6
                          f16 fwd_gen B=8192, M=82, K=64     |     15182.5    |      15532.0    |  13613.3
  (Tesla_V100_SXM2_16GB)  f32 fwd_gen B=256, M=128, K=16     |       223.3    |        222.7    |    312.4
                          f16 fwd_gen B=256, M=128, K=16     |       190.0    |        200.1    |    214.5
                          f32 fwd_gen B=256, M=128, K=32     |       299.3    |        297.6    |    367.1
                          f16 fwd_gen B=256, M=128, K=32     |       184.9    |        197.2    |    215.2
                          f32 fwd_gen B=256, M=128, K=64     |       438.5    |        438.5    |    493.1
                          f16 fwd_gen B=256, M=128, K=64     |       248.5    |        252.7    |    248.8
                          f32 fwd_gen B=256, M=128, K=128    |       862.9    |        862.3    |    804.1
                          f16 fwd_gen B=256, M=128, K=128    |       470.4    |        474.2    |    375.0
                          f32 fwd_gen B=256, M=128, K=256    |      1766.3    |       1747.0    |   1423.4
                          f16 fwd_gen B=256, M=128, K=256    |       934.8    |        945.4    |    625.7
                          f32 fwd_gen B=256, M=512, K=16     |      2482.0    |       2479.2    |   4134.9
                          f16 fwd_gen B=256, M=512, K=16     |      1021.6    |       1050.9    |   1833.9
                          f32 fwd_gen B=256, M=512, K=32     |      3060.3    |       3064.4    |   4423.2
                          f16 fwd_gen B=256, M=512, K=32     |      1194.5    |       1226.1    |   1963.9
                          f32 fwd_gen B=256, M=512, K=64     |      4273.0    |       4273.3    |   5084.4
                          f16 fwd_gen B=256, M=512, K=64     |      1714.6    |       1733.5    |   2236.1
                          f32 fwd_gen B=256, M=512, K=128    |      8589.1    |       8710.6    |   7943.4
                          f16 fwd_gen B=256, M=512, K=128    |      3455.9    |       3500.4    |   2791.1
                          f32 fwd_gen B=256, M=512, K=256    |     17575.7    |      17718.5    |  13739.0
                          f16 fwd_gen B=256, M=512, K=256    |      7501.9    |       7472.3    |   4256.2
                          f32 fwd_gen B=256, M=1024, K=16    |      9279.7    |       9197.9    |  15967.0
                          f16 fwd_gen B=256, M=1024, K=16    |      3578.6    |       3712.8    |   7101.6
                          f32 fwd_gen B=256, M=1024, K=32    |     11143.4    |      11146.1    |  16831.0
                          f16 fwd_gen B=256, M=1024, K=32    |      3987.4    |       4114.0    |   7385.4
                          f32 fwd_gen B=256, M=1024, K=64    |     15383.0    |      15399.9    |  19120.3
                          f16 fwd_gen B=256, M=1024, K=64    |      5489.8    |       5555.1    |   8098.7
                          f32 fwd_gen B=256, M=1024, K=128   |     30973.3    |      31137.8    |  30084.6
                          f16 fwd_gen B=256, M=1024, K=128   |     11406.9    |      11528.1    |   9305.6
                          f32 fwd_gen B=256, M=1024, K=256   |     63638.0    |      63786.4    |  52073.9
                          f16 fwd_gen B=256, M=1024, K=256   |     24738.8    |      25077.1    |  13977.8
                          f32 fwd_gen B=384, M=192, K=88     |      2009.7    |       2011.2    |   1836.4
                          f16 fwd_gen B=384, M=192, K=88     |       925.9    |        937.0    |    791.3
                          f32 fwd_gen B=768, M=256, K=64     |      3465.7    |       3465.3    |   4282.2
                          f16 fwd_gen B=768, M=256, K=64     |      1636.1    |       1656.9    |   1941.1
                          f32 fwd_gen B=1024, M=128, K=16    |       785.4    |        784.2    |   1133.8
                          f16 fwd_gen B=1024, M=128, K=16    |       417.6    |        425.6    |    556.5
                          f32 fwd_gen B=1024, M=128, K=32    |      1048.1    |       1043.7    |   1330.6
                          f16 fwd_gen B=1024, M=128, K=32    |       546.9    |        557.3    |    664.8
                          f32 fwd_gen B=1024, M=128, K=64    |      1564.9    |       1567.2    |   1722.9
                          f16 fwd_gen B=1024, M=128, K=64    |       887.5    |        889.6    |    871.0
                          f32 fwd_gen B=1024, M=128, K=128   |      3146.0    |       3150.2    |   2829.7
                          f16 fwd_gen B=1024, M=128, K=128   |      1710.1    |       1724.7    |   1330.1
                          f32 fwd_gen B=1024, M=128, K=256   |      6434.9    |       6404.2    |   4975.8
                          f16 fwd_gen B=1024, M=128, K=256   |      3524.8    |       3545.7    |   2318.9
                          f32 fwd_gen B=1024, M=512, K=16    |      8690.6    |       8724.7    |  16184.6
                          f16 fwd_gen B=1024, M=512, K=16    |      3551.4    |       3666.2    |   7196.0
                          f32 fwd_gen B=1024, M=512, K=32    |     10998.2    |      10968.6    |  17474.0
                          f16 fwd_gen B=1024, M=512, K=32    |      4214.3    |       4340.3    |   7718.5
                          f32 fwd_gen B=1024, M=512, K=64    |     15607.6    |      15588.3    |  20109.1
                          f16 fwd_gen B=1024, M=512, K=64    |      6278.2    |       6306.9    |   8848.6
                          f32 fwd_gen B=1024, M=512, K=128   |     31843.1    |      31895.7    |  31883.1
                          f16 fwd_gen B=1024, M=512, K=128   |     12949.5    |      13040.6    |  10935.8
                          f32 fwd_gen B=1024, M=512, K=256   |     64232.1    |      64392.1    |  57454.4
                          f16 fwd_gen B=1024, M=512, K=256   |     28485.0    |      28649.6    |  16892.5
                          f32 fwd_gen B=1024, M=1024, K=16   |     32344.7    |      32473.0    |         
                          f16 fwd_gen B=1024, M=1024, K=16   |     12512.2    |      13015.0    |         
                          f32 fwd_gen B=1024, M=1024, K=32   |     39991.8    |      40016.3    |         
                          f16 fwd_gen B=1024, M=1024, K=32   |     14136.2    |      14540.4    |         
                          f32 fwd_gen B=1024, M=1024, K=64   |     56008.0    |      56205.1    |         
                          f16 fwd_gen B=1024, M=1024, K=64   |     20370.2    |      20674.4    |         
                          f32 fwd_gen B=1024, M=1024, K=128  |    115732.2    |     116045.7    |         
                          f16 fwd_gen B=1024, M=1024, K=128  |     42625.9    |      42965.9    |         
                          f32 fwd_gen B=1024, M=1024, K=256  |    229918.5    |     233529.9    |         
                          f16 fwd_gen B=1024, M=1024, K=256  |     95735.7    |      96150.7    |         
                          f32 fwd_gen B=2400, M=256, K=64    |     10758.2    |      10790.8    |  12962.1
                          f16 fwd_gen B=2400, M=256, K=64    |      4957.5    |       5006.3    |   5934.6
                          f32 fwd_gen B=8192, M=82, K=64     |      8913.0    |                 |   7703.6
                          f16 fwd_gen B=8192, M=82, K=64     |      5097.3    |                 |   3926.0

Times are in microseconds (us).
P100/V100 bw
[-------------------------- attention backward (attn_bias=<class 'NoneType'>) ---------------------------]
                                                             |  a10cd6d8_it2  |  f527a89c_main  |  vanilla
1 threads: -----------------------------------------------------------------------------------------------
  (Quadro_GP100)          f32 fwd_gen B=256, M=128, K=16     |       603.6    |        603.5    |    617.8
                          f16 fwd_gen B=256, M=128, K=16     |       562.9    |        558.8    |    504.6
                          f32 fwd_gen B=256, M=128, K=32     |       743.4    |        744.9    |    696.6
                          f16 fwd_gen B=256, M=128, K=32     |       668.8    |        670.4    |    559.7
                          f32 fwd_gen B=256, M=128, K=64     |      1035.6    |       1039.4    |    857.7
                          f16 fwd_gen B=256, M=128, K=64     |       910.1    |        912.4    |    685.7
                          f32 fwd_gen B=256, M=128, K=128    |      2106.5    |       2112.5    |   1230.2
                          f16 fwd_gen B=256, M=128, K=128    |      1727.9    |       1722.6    |    968.2
                          f32 fwd_gen B=256, M=128, K=256    |      4497.4    |       4448.3    |   2328.1
                          f16 fwd_gen B=256, M=128, K=256    |      3479.6    |       3472.7    |   1758.2
                          f32 fwd_gen B=256, M=512, K=16     |      9511.9    |       9686.7    |   8298.2
                          f16 fwd_gen B=256, M=512, K=16     |      8360.9    |       8365.1    |   6684.6
                          f32 fwd_gen B=256, M=512, K=32     |     10726.5    |      10873.5    |   8755.3
                          f16 fwd_gen B=256, M=512, K=32     |      9493.8    |       9437.7    |   7101.4
                          f32 fwd_gen B=256, M=512, K=64     |     14894.7    |      14909.4    |   9917.2
                          f16 fwd_gen B=256, M=512, K=64     |     12422.5    |      12288.5    |   8062.8
                          f32 fwd_gen B=256, M=512, K=128    |     30830.3    |      30693.2    |  13110.4
                          f16 fwd_gen B=256, M=512, K=128    |     24827.0    |      24799.8    |  10969.9
                          f32 fwd_gen B=256, M=512, K=256    |     65225.3    |      65654.4    |  22365.5
                          f16 fwd_gen B=256, M=512, K=256    |     50749.1    |      50223.7    |  19513.8
                          f32 fwd_gen B=256, M=1024, K=16    |     38175.6    |      38241.9    |  32308.1
                          f16 fwd_gen B=256, M=1024, K=16    |     33236.9    |      32765.8    |  25917.5
                          f32 fwd_gen B=256, M=1024, K=32    |     42755.6    |      42754.7    |  33714.9
                          f16 fwd_gen B=256, M=1024, K=32    |     37917.8    |      37597.5    |  27511.7
                          f32 fwd_gen B=256, M=1024, K=64    |     57917.0    |      57943.2    |  37281.9
                          f16 fwd_gen B=256, M=1024, K=64    |     49371.4    |      49014.8    |  30379.2
                          f32 fwd_gen B=256, M=1024, K=128   |    124602.9    |     123653.5    |  47015.5
                          f16 fwd_gen B=256, M=1024, K=128   |    104303.1    |     100350.2    |  39986.0
                          f32 fwd_gen B=256, M=1024, K=256   |    253401.1    |     253415.2    |  83675.5
                          f16 fwd_gen B=256, M=1024, K=256   |    203460.4    |     198098.3    |  73027.3
                          f32 fwd_gen B=384, M=192, K=88     |      5464.6    |       5469.8    |   3155.7
                          f16 fwd_gen B=384, M=192, K=88     |      4354.7    |       4334.4    |   2622.1
                          f32 fwd_gen B=768, M=256, K=64     |     11082.9    |      11042.0    |   8070.5
                          f16 fwd_gen B=768, M=256, K=64     |      8298.8    |       8176.9    |   6491.5
                          f32 fwd_gen B=1024, M=128, K=16    |      2266.9    |       2271.0    |   2309.1
                          f16 fwd_gen B=1024, M=128, K=16    |      1870.7    |       1856.8    |   1863.9
                          f32 fwd_gen B=1024, M=128, K=32    |      2775.9    |       2777.0    |   2582.0
                          f16 fwd_gen B=1024, M=128, K=32    |      2253.3    |       2239.1    |   2076.2
                          f32 fwd_gen B=1024, M=128, K=64    |      3933.5    |       3947.6    |   3227.8
                          f16 fwd_gen B=1024, M=128, K=64    |      3095.3    |       3080.0    |   2515.6
                          f32 fwd_gen B=1024, M=128, K=128   |      8156.7    |       8223.8    |   4789.1
                          f16 fwd_gen B=1024, M=128, K=128   |      6142.1    |       6104.6    |   3666.0
                          f32 fwd_gen B=1024, M=128, K=256   |     17248.9    |      17283.1    |   9183.2
                          f16 fwd_gen B=1024, M=128, K=256   |     12408.5    |      12672.6    |   6923.6
                          f32 fwd_gen B=1024, M=512, K=16    |     35961.7    |      36011.6    |  32782.6
                          f16 fwd_gen B=1024, M=512, K=16    |     28094.9    |      27821.0    |  26241.6
                          f32 fwd_gen B=1024, M=512, K=32    |     40607.6    |      40381.4    |  34627.1
                          f16 fwd_gen B=1024, M=512, K=32    |     32322.2    |      32117.2    |  28196.7
                          f32 fwd_gen B=1024, M=512, K=64    |     56485.5    |      56744.8    |  39378.8
                          f16 fwd_gen B=1024, M=512, K=64    |     43253.6    |      43340.8    |  31732.7
                          f32 fwd_gen B=1024, M=512, K=128   |    120895.1    |     120812.7    |  52358.4
                          f16 fwd_gen B=1024, M=512, K=128   |     91142.4    |      88084.4    |  42365.0
                          f32 fwd_gen B=1024, M=512, K=256   |    249426.5    |     247599.6    |  92642.0
                          f16 fwd_gen B=1024, M=512, K=256   |    179759.4    |     177818.9    |  78412.2
                          f32 fwd_gen B=1024, M=1024, K=16   |    144740.3    |     143821.9    |         
                          f16 fwd_gen B=1024, M=1024, K=16   |    112489.1    |     110988.6    |         
                          f32 fwd_gen B=1024, M=1024, K=32   |    163698.0    |     161774.9    |         
                          f16 fwd_gen B=1024, M=1024, K=32   |    129675.4    |     128456.5    |         
                          f32 fwd_gen B=1024, M=1024, K=64   |    225204.8    |     226110.0    |         
                          f16 fwd_gen B=1024, M=1024, K=64   |    170683.4    |     169544.6    |         
                          f32 fwd_gen B=1024, M=1024, K=128  |    463485.0    |     465248.8    |         
                          f16 fwd_gen B=1024, M=1024, K=128  |    338421.0    |     335817.0    |         
                          f32 fwd_gen B=1024, M=1024, K=256  |    965583.7    |     966059.9    |         
                          f16 fwd_gen B=1024, M=1024, K=256  |    680338.9    |     678913.0    |         
                          f32 fwd_gen B=2400, M=256, K=64    |     33993.4    |      33955.7    |  25264.5
                          f16 fwd_gen B=2400, M=256, K=64    |     25598.8    |      25575.3    |  20195.6
                          f32 fwd_gen B=8192, M=82, K=64     |     39214.4    |      42081.7    |  16501.0
                          f16 fwd_gen B=8192, M=82, K=64     |     26418.9    |      26565.4    |  13631.7
  (Tesla_V100_SXM2_16GB)  f32 fwd_gen B=256, M=128, K=16     |       280.4    |        280.5    |    313.0
                          f16 fwd_gen B=256, M=128, K=16     |       181.9    |        195.8    |    231.9
                          f32 fwd_gen B=256, M=128, K=32     |       366.3    |        364.4    |    367.6
                          f16 fwd_gen B=256, M=128, K=32     |       182.6    |        197.7    |    229.6
                          f32 fwd_gen B=256, M=128, K=64     |       525.3    |        524.0    |    494.2
                          f16 fwd_gen B=256, M=128, K=64     |       277.8    |        282.5    |    252.2
                          f32 fwd_gen B=256, M=128, K=128    |      1042.4    |       1039.9    |    808.4
                          f16 fwd_gen B=256, M=128, K=128    |       529.6    |        534.7    |    375.9
                          f32 fwd_gen B=256, M=128, K=256    |      2128.9    |       2113.6    |   1433.5
                          f16 fwd_gen B=256, M=128, K=256    |      1069.7    |       1084.7    |    623.3
                          f32 fwd_gen B=256, M=512, K=16     |      4282.3    |       4286.6    |   4113.2
                          f16 fwd_gen B=256, M=512, K=16     |      1674.8    |       1730.2    |   1814.9
                          f32 fwd_gen B=256, M=512, K=32     |      5185.2    |       5178.1    |   4375.3
                          f16 fwd_gen B=256, M=512, K=32     |      1893.3    |       1942.5    |   1942.3
                          f32 fwd_gen B=256, M=512, K=64     |      7139.8    |       7135.0    |   5169.7
                          f16 fwd_gen B=256, M=512, K=64     |      2604.7    |       2644.3    |   2240.0
                          f32 fwd_gen B=256, M=512, K=128    |     14396.9    |      14416.8    |   8109.3
                          f16 fwd_gen B=256, M=512, K=128    |      5323.5    |       5401.3    |   2809.1
                          f32 fwd_gen B=256, M=512, K=256    |     29436.6    |      29614.1    |  14103.9
                          f16 fwd_gen B=256, M=512, K=256    |     11986.5    |      11740.7    |   4310.4
                          f32 fwd_gen B=256, M=1024, K=16    |     17131.0    |      16988.4    |  15794.1
                          f16 fwd_gen B=256, M=1024, K=16    |      6436.6    |       6694.4    |   6839.1
                          f32 fwd_gen B=256, M=1024, K=32    |     20354.3    |      20399.3    |  16636.9
                          f16 fwd_gen B=256, M=1024, K=32    |      6977.2    |       7216.9    |   7118.1
                          f32 fwd_gen B=256, M=1024, K=64    |     27911.3    |      27929.5    |  19486.2
                          f16 fwd_gen B=256, M=1024, K=64    |      9351.1    |       9527.2    |   8035.3
                          f32 fwd_gen B=256, M=1024, K=128   |     56224.0    |      56473.3    |  30671.5
                          f16 fwd_gen B=256, M=1024, K=128   |     19651.4    |      19793.0    |   9381.2
                          f32 fwd_gen B=256, M=1024, K=256   |    117527.4    |     117923.9    |  53531.4
                          f16 fwd_gen B=256, M=1024, K=256   |     43035.6    |      44139.4    |  14264.3
                          f32 fwd_gen B=384, M=192, K=88     |      2756.0    |       2763.7    |   1849.5
                          f16 fwd_gen B=384, M=192, K=88     |      1163.0    |       1186.1    |    793.7
                          f32 fwd_gen B=768, M=256, K=64     |      5015.0    |       5014.3    |   4308.8
                          f16 fwd_gen B=768, M=256, K=64     |      2129.3    |       2168.8    |   1939.6
                          f32 fwd_gen B=1024, M=128, K=16    |       983.4    |        978.9    |   1133.9
                          f16 fwd_gen B=1024, M=128, K=16    |       487.2    |        500.8    |    557.2
                          f32 fwd_gen B=1024, M=128, K=32    |      1286.4    |       1282.8    |   1333.6
                          f16 fwd_gen B=1024, M=128, K=32    |       621.8    |        632.5    |    664.9
                          f32 fwd_gen B=1024, M=128, K=64    |      1884.7    |       1883.2    |   1732.0
                          f16 fwd_gen B=1024, M=128, K=64    |       990.3    |       1003.4    |    871.5
                          f32 fwd_gen B=1024, M=128, K=128   |      3800.9    |       3796.6    |   2863.8
                          f16 fwd_gen B=1024, M=128, K=128   |      1920.9    |       1940.0    |   1327.6
                          f32 fwd_gen B=1024, M=128, K=256   |      7725.6    |       7710.6    |   5009.7
                          f16 fwd_gen B=1024, M=128, K=256   |      4018.9    |       4044.9    |   2319.0
                          f32 fwd_gen B=1024, M=512, K=16    |     14987.6    |      15039.7    |  16240.3
                          f16 fwd_gen B=1024, M=512, K=16    |      5892.6    |       6099.3    |   7144.9
                          f32 fwd_gen B=1024, M=512, K=32    |     18639.1    |      18580.0    |  17424.9
                          f16 fwd_gen B=1024, M=512, K=32    |      6648.3    |       6854.6    |   7708.2
                          f32 fwd_gen B=1024, M=512, K=64    |     25932.1    |      25993.6    |  20367.7
                          f16 fwd_gen B=1024, M=512, K=64    |      9585.9    |       9665.9    |   8959.6
                          f32 fwd_gen B=1024, M=512, K=128   |     53173.8    |      53180.3    |  32376.8
                          f16 fwd_gen B=1024, M=512, K=128   |     19888.4    |      20102.3    |  10979.3
                          f32 fwd_gen B=1024, M=512, K=256   |    109147.5    |     108526.9    |  58786.1
                          f16 fwd_gen B=1024, M=512, K=256   |     44590.5    |      44873.8    |  17067.6
                          f32 fwd_gen B=1024, M=1024, K=16   |     59665.2    |      59822.9    |         
                          f16 fwd_gen B=1024, M=1024, K=16   |     22432.7    |      23356.0    |         
                          f32 fwd_gen B=1024, M=1024, K=32   |     73346.2    |      73317.7    |         
                          f16 fwd_gen B=1024, M=1024, K=32   |     24613.5    |      25418.6    |         
                          f32 fwd_gen B=1024, M=1024, K=64   |    102234.3    |     103742.4    |         
                          f16 fwd_gen B=1024, M=1024, K=64   |     34774.8    |      35188.8    |         
                          f32 fwd_gen B=1024, M=1024, K=128  |    209108.3    |     208818.7    |         
                          f16 fwd_gen B=1024, M=1024, K=128  |     73702.9    |      74141.9    |         
                          f32 fwd_gen B=1024, M=1024, K=256  |    418911.9    |     418241.4    |         
                          f16 fwd_gen B=1024, M=1024, K=256  |    165803.9    |     166748.5    |         
                          f32 fwd_gen B=2400, M=256, K=64    |     15618.7    |      15617.5    |  13054.3
                          f16 fwd_gen B=2400, M=256, K=64    |      6499.0    |       6548.9    |   5934.7
                          f32 fwd_gen B=8192, M=82, K=64     |     13620.2    |                 |   7738.1
                          f16 fwd_gen B=8192, M=82, K=64     |      9113.7    |                 |   3928.4

Times are in microseconds (us).
A100 bw
[----- attention backward (attn_bias=<class 'NoneType'>) ------]
                                 |  4ed5c36_it2  |  f527a89_main
1 threads: -----------------------------------------------------
      f16 B=256, M=128, K=16     |       136.6   |       139.0  
      f32 B=256, M=128, K=16     |       146.9   |       148.8  
      f16 B=256, M=128, K=32     |       129.2   |       208.4  
      f32 B=256, M=128, K=32     |       167.7   |       168.9  
      f16 B=256, M=128, K=64     |       140.5   |       164.7  
      f32 B=256, M=128, K=64     |       237.9   |       239.7  
      f16 B=256, M=128, K=128    |       267.4   |       267.7  
      f32 B=256, M=128, K=128    |       490.2   |       489.5  
      f16 B=256, M=128, K=256    |       584.7   |       584.0  
      f32 B=256, M=128, K=256    |      1043.4   |      1047.4  
      f16 B=256, M=512, K=16     |       733.5   |       731.0  
      f32 B=256, M=512, K=16     |      2004.5   |      2001.1  
      f16 B=256, M=512, K=32     |       917.7   |       906.5  
      f32 B=256, M=512, K=32     |      2319.0   |      2321.3  
      f16 B=256, M=512, K=64     |      1360.6   |      1365.4  
      f32 B=256, M=512, K=64     |      2998.1   |      3008.1  
      f16 B=256, M=512, K=128    |      2679.9   |      2684.4  
      f32 B=256, M=512, K=128    |      6313.8   |      6329.4  
      f16 B=256, M=512, K=256    |      5790.6   |      5791.0  
      f32 B=256, M=512, K=256    |     13367.1   |     13479.1  
      f16 B=256, M=1024, K=16    |      2982.1   |      2925.1  
      f32 B=256, M=1024, K=16    |      8512.0   |      8550.8  
      f16 B=256, M=1024, K=32    |      3687.7   |      3704.9  
      f32 B=256, M=1024, K=32    |      9047.1   |      9029.4  
      f16 B=256, M=1024, K=64    |      4871.2   |      4861.3  
      f32 B=256, M=1024, K=64    |     11889.1   |     11911.6  
      f16 B=256, M=1024, K=128   |      9613.1   |      9644.6  
      f32 B=256, M=1024, K=128   |     24137.4   |     24228.0  
      f16 B=256, M=1024, K=256   |     21097.3   |     21065.2  
      f32 B=256, M=1024, K=256   |     51657.7   |     53241.7  
      f16 B=384, M=192, K=88     |       640.5   |       641.6  
      f32 B=384, M=192, K=88     |      1056.0   |      1062.6  
      f16 B=768, M=256, K=64     |      1143.1   |      1143.7  
      f32 B=768, M=256, K=64     |      2120.6   |      2131.2  
      f16 B=1024, M=128, K=16    |       234.0   |       231.9  
      f32 B=1024, M=128, K=16    |       450.2   |       450.6  
      f16 B=1024, M=128, K=32    |       306.4   |       305.1  
      f32 B=1024, M=128, K=32    |       551.0   |       553.7  
      f16 B=1024, M=128, K=64    |       525.1   |       524.8  
      f32 B=1024, M=128, K=64    |       844.0   |       846.5  
      f16 B=1024, M=128, K=128   |      1049.8   |      1047.1  
      f32 B=1024, M=128, K=128   |      1675.4   |      1678.4  
      f16 B=1024, M=128, K=256   |      2199.7   |      2198.1  
      f32 B=1024, M=128, K=256   |      3494.5   |      3509.0  
      f16 B=1024, M=512, K=16    |      2514.5   |      2513.3  
      f32 B=1024, M=512, K=16    |      6424.5   |      6420.9  
      f16 B=1024, M=512, K=32    |      3510.9   |      3493.0  
      f32 B=1024, M=512, K=32    |      7480.0   |      7479.6  
      f16 B=1024, M=512, K=64    |      4834.0   |      4829.3  
      f32 B=1024, M=512, K=64    |      9678.2   |      9746.4  
      f16 B=1024, M=512, K=128   |     10210.0   |     10218.7  
      f32 B=1024, M=512, K=128   |     20086.5   |     20198.2  
      f16 B=1024, M=512, K=256   |     23313.8   |     23221.6  
      f32 B=1024, M=512, K=256   |     43316.7   |     43472.3  
      f16 B=1024, M=1024, K=16   |     11528.7   |     11618.7  
      f32 B=1024, M=1024, K=16   |     26997.9   |     27039.2  
      f16 B=1024, M=1024, K=32   |     13048.1   |     13018.2  
      f32 B=1024, M=1024, K=32   |     28699.4   |     28724.9  
      f16 B=1024, M=1024, K=64   |     17184.7   |     17148.3  
      f32 B=1024, M=1024, K=64   |     36583.4   |     36846.5  
      f16 B=1024, M=1024, K=128  |     37921.1   |     37935.5  
      f32 B=1024, M=1024, K=128  |     76306.8   |     76639.2  
      f16 B=1024, M=1024, K=256  |     85658.0   |     85501.1  
      f32 B=1024, M=1024, K=256  |    166120.6   |    166605.5  
      f16 B=2400, M=256, K=64    |      3341.1   |      3340.8  
      f32 B=2400, M=256, K=64    |      6194.8   |      6220.2  
      f16 B=8192, M=82, K=64     |      2865.4   |      2854.9  
      f32 B=8192, M=82, K=64     |      8268.4   |      8275.1  

Times are in microseconds (us).
P100/V100 fw (causal)
[------------------- attention (attn_bias=<class 'xformers.ops.LowerTriangularMask'>) -------------------]
                                                             |  a10cd6d8_it2  |  f527a89c_main  |  vanilla
1 threads: -----------------------------------------------------------------------------------------------
  (Quadro_GP100)          f32 fwd_gen B=256, M=128, K=16     |       118.4    |        127.9    |    359.5
                          f16 fwd_gen B=256, M=128, K=16     |       129.9    |        140.6    |    293.4
                          f32 fwd_gen B=256, M=128, K=32     |       140.0    |        147.7    |    380.2
                          f16 fwd_gen B=256, M=128, K=32     |       152.9    |        165.1    |    308.0
                          f32 fwd_gen B=256, M=128, K=64     |       183.6    |        191.3    |    425.4
                          f16 fwd_gen B=256, M=128, K=64     |       197.2    |        206.9    |    347.0
                          f32 fwd_gen B=256, M=128, K=128    |       381.6    |        394.4    |    530.7
                          f16 fwd_gen B=256, M=128, K=128    |       381.5    |        390.3    |    432.4
                          f32 fwd_gen B=256, M=128, K=256    |       739.4    |        764.4    |   1006.5
                          f16 fwd_gen B=256, M=128, K=256    |       749.7    |        751.1    |    721.3
                          f32 fwd_gen B=256, M=512, K=16     |      1134.9    |       1232.2    |   4971.4
                          f16 fwd_gen B=256, M=512, K=16     |      1267.1    |       1370.9    |   3832.7
                          f32 fwd_gen B=256, M=512, K=32     |      1340.4    |       1434.2    |   5072.1
                          f16 fwd_gen B=256, M=512, K=32     |      1495.8    |       1600.0    |   4035.2
                          f32 fwd_gen B=256, M=512, K=64     |      1758.5    |       1882.7    |   5525.0
                          f16 fwd_gen B=256, M=512, K=64     |      1958.4    |       2080.1    |   4647.6
                          f32 fwd_gen B=256, M=512, K=128    |      3681.0    |       3924.7    |   6587.5
                          f16 fwd_gen B=256, M=512, K=128    |      3844.5    |       4055.2    |   5767.5
                          f32 fwd_gen B=256, M=512, K=256    |      7654.1    |       8149.3    |  10863.7
                          f16 fwd_gen B=256, M=512, K=256    |      8039.0    |       8283.1    |   9715.8
                          f32 fwd_gen B=256, M=1024, K=16    |      4090.8    |       4485.0    |  19589.8
                          f16 fwd_gen B=256, M=1024, K=16    |      4519.3    |       4971.4    |  15696.2
                          f32 fwd_gen B=256, M=1024, K=32    |      4858.4    |       5235.9    |  20022.6
                          f16 fwd_gen B=256, M=1024, K=32    |      5359.5    |       5835.0    |  16783.6
                          f32 fwd_gen B=256, M=1024, K=64    |      6327.3    |       6751.8    |  21592.8
                          f16 fwd_gen B=256, M=1024, K=64    |      7102.4    |       7526.0    |  18754.6
                          f32 fwd_gen B=256, M=1024, K=128   |     13047.6    |      14068.1    |  25912.1
                          f16 fwd_gen B=256, M=1024, K=128   |     13968.1    |      14685.3    |  23169.4
                          f32 fwd_gen B=256, M=1024, K=256   |     27895.1    |      29908.1    |  42132.2
                          f16 fwd_gen B=256, M=1024, K=256   |     28891.7    |      30121.5    |  38490.4
                          f32 fwd_gen B=384, M=192, K=88     |       920.3    |        986.5    |   1465.9
                          f16 fwd_gen B=384, M=192, K=88     |       927.0    |        979.9    |   1295.7
                          f32 fwd_gen B=768, M=256, K=64     |      1534.4    |       1658.3    |   4271.5
                          f16 fwd_gen B=768, M=256, K=64     |      1693.0    |       1792.6    |   3621.8
                          f32 fwd_gen B=1024, M=128, K=16    |       424.3    |        464.1    |   1333.6
                          f16 fwd_gen B=1024, M=128, K=16    |       462.5    |        502.1    |   1073.3
                          f32 fwd_gen B=1024, M=128, K=32    |       509.8    |        545.9    |   1426.5
                          f16 fwd_gen B=1024, M=128, K=32    |       551.2    |        594.4    |   1147.8
                          f32 fwd_gen B=1024, M=128, K=64    |       674.0    |        718.7    |   1619.1
                          f16 fwd_gen B=1024, M=128, K=64    |       731.2    |        770.9    |   1311.1
                          f32 fwd_gen B=1024, M=128, K=128   |      1508.2    |       1585.9    |   2026.1
                          f16 fwd_gen B=1024, M=128, K=128   |      1494.8    |       1558.3    |   1670.2
                          f32 fwd_gen B=1024, M=128, K=256   |      2883.7    |       3019.9    |   3848.9
                          f16 fwd_gen B=1024, M=128, K=256   |      2916.3    |       2966.1    |   2811.4
                          f32 fwd_gen B=1024, M=512, K=16    |      4443.1    |       4903.4    |  19594.2
                          f16 fwd_gen B=1024, M=512, K=16    |      4912.7    |       5381.8    |  15393.8
                          f32 fwd_gen B=1024, M=512, K=32    |      5294.2    |       5758.7    |  20025.6
                          f16 fwd_gen B=1024, M=512, K=32    |      5809.0    |       6279.6    |  16650.1
                          f32 fwd_gen B=1024, M=512, K=64    |      6997.3    |       7423.7    |  21856.9
                          f16 fwd_gen B=1024, M=512, K=64    |      7822.4    |       8271.2    |  18636.1
                          f32 fwd_gen B=1024, M=512, K=128   |     14821.7    |      15921.5    |  26405.8
                          f16 fwd_gen B=1024, M=512, K=128   |     15506.9    |      16307.0    |  23152.4
                          f32 fwd_gen B=1024, M=512, K=256   |     30528.1    |      32573.2    |  43769.8
                          f16 fwd_gen B=1024, M=512, K=256   |     31600.2    |      32400.5    |  39295.2
                          f32 fwd_gen B=1024, M=1024, K=16   |     16333.9    |      17867.6    |         
                          f16 fwd_gen B=1024, M=1024, K=16   |     18156.0    |      19744.0    |         
                          f32 fwd_gen B=1024, M=1024, K=32   |     19235.7    |      21009.7    |         
                          f16 fwd_gen B=1024, M=1024, K=32   |     21469.3    |      23110.3    |         
                          f32 fwd_gen B=1024, M=1024, K=64   |     25302.9    |      27179.2    |         
                          f16 fwd_gen B=1024, M=1024, K=64   |     28063.9    |      29931.9    |         
                          f32 fwd_gen B=1024, M=1024, K=128  |     51937.5    |      56392.7    |         
                          f16 fwd_gen B=1024, M=1024, K=128  |     55583.7    |      58485.1    |         
                          f32 fwd_gen B=1024, M=1024, K=256  |    112178.2    |     120699.9    |         
                          f16 fwd_gen B=1024, M=1024, K=256  |    118242.4    |     120288.6    |         
                          f32 fwd_gen B=2400, M=256, K=64    |      4804.2    |       5086.1    |  13322.3
                          f16 fwd_gen B=2400, M=256, K=64    |      5222.3    |       5506.6    |  11253.7
                          f32 fwd_gen B=8192, M=82, K=64     |      4565.3    |       4806.2    |   8356.0
                          f16 fwd_gen B=8192, M=82, K=64     |      4986.9    |       5185.4    |   7332.8
  (Tesla_V100_SXM2_16GB)  f32 fwd_gen B=256, M=128, K=16     |        73.7    |         75.2    |    189.5
                          f16 fwd_gen B=256, M=128, K=16     |        48.9    |         52.1    |    115.1
                          f32 fwd_gen B=256, M=128, K=32     |        86.1    |         87.8    |    208.4
                          f16 fwd_gen B=256, M=128, K=32     |        49.0    |         50.6    |    118.9
                          f32 fwd_gen B=256, M=128, K=64     |       113.0    |        115.0    |    260.1
                          f16 fwd_gen B=256, M=128, K=64     |        60.8    |         60.2    |    133.0
                          f32 fwd_gen B=256, M=128, K=128    |       210.6    |        213.8    |    387.0
                          f16 fwd_gen B=256, M=128, K=128    |        98.7    |         96.5    |    170.7
                          f32 fwd_gen B=256, M=128, K=256    |       411.4    |        419.5    |    668.2
                          f16 fwd_gen B=256, M=128, K=256    |       189.6    |        186.5    |    235.3
                          f32 fwd_gen B=256, M=512, K=16     |       656.2    |        676.9    |   2750.1
                          f16 fwd_gen B=256, M=512, K=16     |       332.0    |        326.0    |   1297.6
                          f32 fwd_gen B=256, M=512, K=32     |       778.2    |        820.8    |   2896.1
                          f16 fwd_gen B=256, M=512, K=32     |       346.9    |        340.4    |   1343.1
                          f32 fwd_gen B=256, M=512, K=64     |      1038.9    |       1068.6    |   3027.7
                          f16 fwd_gen B=256, M=512, K=64     |       417.4    |        413.5    |   1419.0
                          f32 fwd_gen B=256, M=512, K=128    |      2048.6    |       2097.0    |   4326.8
                          f16 fwd_gen B=256, M=512, K=128    |       739.6    |        716.4    |   1572.5
                          f32 fwd_gen B=256, M=512, K=256    |      4212.9    |       4332.7    |   7023.2
                          f16 fwd_gen B=256, M=512, K=256    |      1773.2    |       1758.6    |   2125.5
                          f32 fwd_gen B=256, M=1024, K=16    |      2316.0    |       2415.9    |  11002.9
                          f16 fwd_gen B=256, M=1024, K=16    |      1119.7    |       1097.5    |   5462.6
                          f32 fwd_gen B=256, M=1024, K=32    |      2771.9    |       2857.3    |  11312.0
                          f16 fwd_gen B=256, M=1024, K=32    |      1158.0    |       1127.4    |   5543.7
                          f32 fwd_gen B=256, M=1024, K=64    |      3712.1    |       3775.8    |  11875.1
                          f16 fwd_gen B=256, M=1024, K=64    |      1392.7    |       1370.6    |   5681.1
                          f32 fwd_gen B=256, M=1024, K=128   |      7344.6    |       7502.2    |  17256.8
                          f16 fwd_gen B=256, M=1024, K=128   |      2397.1    |       2396.9    |   6053.4
                          f32 fwd_gen B=256, M=1024, K=256   |     15262.3    |      15819.1    |  28129.4
                          f16 fwd_gen B=256, M=1024, K=256   |      6361.1    |       6167.2    |   8011.9
                          f32 fwd_gen B=384, M=192, K=88     |       491.0    |        504.6    |    937.8
                          f16 fwd_gen B=384, M=192, K=88     |       217.3    |        213.9    |    412.6
                          f32 fwd_gen B=768, M=256, K=64     |       901.0    |        926.4    |   2484.0
                          f16 fwd_gen B=768, M=256, K=64     |       394.8    |        393.7    |   1110.1
                          f32 fwd_gen B=1024, M=128, K=16    |       240.6    |        249.0    |    687.3
                          f16 fwd_gen B=1024, M=128, K=16    |       134.6    |        132.9    |    365.3
                          f32 fwd_gen B=1024, M=128, K=32    |       290.4    |        302.1    |    755.1
                          f16 fwd_gen B=1024, M=128, K=32    |       149.1    |        147.3    |    394.1
                          f32 fwd_gen B=1024, M=128, K=64    |       394.3    |        407.3    |    898.4
                          f16 fwd_gen B=1024, M=128, K=64    |       200.6    |        203.8    |    455.9
                          f32 fwd_gen B=1024, M=128, K=128   |       796.9    |        829.1    |   1308.9
                          f16 fwd_gen B=1024, M=128, K=128   |       356.6    |        351.9    |    583.2
                          f32 fwd_gen B=1024, M=128, K=256   |      1561.7    |       1596.0    |   2130.6
                          f16 fwd_gen B=1024, M=128, K=256   |       735.6    |        723.3    |    840.6
                          f32 fwd_gen B=1024, M=512, K=16    |      2476.3    |       2583.2    |  11175.0
                          f16 fwd_gen B=1024, M=512, K=16    |      1239.9    |       1216.7    |   5068.4
                          f32 fwd_gen B=1024, M=512, K=32    |      2978.2    |       3097.7    |  11774.8
                          f16 fwd_gen B=1024, M=512, K=32    |      1300.1    |       1272.7    |   5234.8
                          f32 fwd_gen B=1024, M=512, K=64    |      4045.2    |       4123.4    |  12369.5
                          f16 fwd_gen B=1024, M=512, K=64    |      1605.1    |       1586.2    |   5596.2
                          f32 fwd_gen B=1024, M=512, K=128   |      8145.2    |       8312.3    |  17490.6
                          f16 fwd_gen B=1024, M=512, K=128   |      2865.5    |       2828.0    |   6191.6
                          f32 fwd_gen B=1024, M=512, K=256   |     16675.5    |      17113.6    |  29348.0
                          f16 fwd_gen B=1024, M=512, K=256   |      7081.8    |       7234.6    |   8451.3
                          f32 fwd_gen B=1024, M=1024, K=16   |      9032.5    |       9440.5    |         
                          f16 fwd_gen B=1024, M=1024, K=16   |      4298.2    |       4210.2    |         
                          f32 fwd_gen B=1024, M=1024, K=32   |     10864.5    |      11261.8    |         
                          f16 fwd_gen B=1024, M=1024, K=32   |      4456.1    |       4410.3    |         
                          f32 fwd_gen B=1024, M=1024, K=64   |     14616.0    |      14917.2    |         
                          f16 fwd_gen B=1024, M=1024, K=64   |      5404.9    |       5348.1    |         
                          f32 fwd_gen B=1024, M=1024, K=128  |     29200.1    |      29810.9    |         
                          f16 fwd_gen B=1024, M=1024, K=128  |      9489.6    |       9421.1    |         
                          f32 fwd_gen B=1024, M=1024, K=256  |     61009.8    |      62968.1    |         
                          f16 fwd_gen B=1024, M=1024, K=256  |     24697.3    |      24588.1    |         
                          f32 fwd_gen B=2400, M=256, K=64    |      2761.8    |       2819.1    |   7438.0
                          f16 fwd_gen B=2400, M=256, K=64    |      1224.1    |       1202.6    |   3402.6
                          f32 fwd_gen B=8192, M=82, K=64     |      2555.5    |                 |   3956.6
                          f16 fwd_gen B=8192, M=82, K=64     |      1249.7    |                 |   2115.9

Times are in microseconds (us).
P100/V100 fw
[------------------------------- attention (attn_bias=<class 'NoneType'>) -------------------------------]
                                                             |  a10cd6d8_it2  |  f527a89c_main  |  vanilla
1 threads: -----------------------------------------------------------------------------------------------
  (Quadro_GP100)          f32 fwd_gen B=256, M=128, K=16     |       154.1    |        162.4    |    300.3
                          f16 fwd_gen B=256, M=128, K=16     |       153.8    |        167.4    |    222.1
                          f32 fwd_gen B=256, M=128, K=32     |       166.7    |        174.5    |    284.2
                          f16 fwd_gen B=256, M=128, K=32     |       182.9    |        195.3    |    238.7
                          f32 fwd_gen B=256, M=128, K=64     |       224.7    |        231.9    |    337.0
                          f16 fwd_gen B=256, M=128, K=64     |       237.0    |        247.8    |    287.8
                          f32 fwd_gen B=256, M=128, K=128    |       439.7    |        455.1    |    447.6
                          f16 fwd_gen B=256, M=128, K=128    |       443.1    |        455.7    |    373.4
                          f32 fwd_gen B=256, M=128, K=256    |       879.5    |        928.0    |    924.3
                          f16 fwd_gen B=256, M=128, K=256    |       893.8    |        900.6    |    663.0
                          f32 fwd_gen B=256, M=512, K=16     |      1918.3    |       2064.1    |   3433.2
                          f16 fwd_gen B=256, M=512, K=16     |      2118.7    |       2291.5    |   2951.1
                          f32 fwd_gen B=256, M=512, K=32     |      2252.4    |       2428.4    |   3692.2
                          f16 fwd_gen B=256, M=512, K=32     |      2483.0    |       2668.2    |   3278.6
                          f32 fwd_gen B=256, M=512, K=64     |      2965.7    |       3139.8    |   4342.7
                          f16 fwd_gen B=256, M=512, K=64     |      3284.1    |       3512.1    |   3815.5
                          f32 fwd_gen B=256, M=512, K=128    |      5908.7    |       6295.6    |   5549.1
                          f16 fwd_gen B=256, M=512, K=128    |      6345.7    |       6823.8    |   4981.2
                          f32 fwd_gen B=256, M=512, K=256    |     12703.5    |      13645.3    |   9825.2
                          f16 fwd_gen B=256, M=512, K=256    |     13460.9    |      13978.4    |   8906.3
                          f32 fwd_gen B=256, M=1024, K=16    |      7486.2    |       8453.9    |  13526.3
                          f16 fwd_gen B=256, M=1024, K=16    |      8309.8    |       9188.0    |  11930.1
                          f32 fwd_gen B=256, M=1024, K=32    |      8777.4    |       9527.4    |  14580.4
                          f16 fwd_gen B=256, M=1024, K=32    |      9764.7    |      10824.0    |  13080.3
                          f32 fwd_gen B=256, M=1024, K=64    |     11712.5    |      12658.1    |  16913.6
                          f16 fwd_gen B=256, M=1024, K=64    |     12891.4    |      13780.4    |  15238.5
                          f32 fwd_gen B=256, M=1024, K=128   |     22915.1    |      25009.8    |  21452.1
                          f16 fwd_gen B=256, M=1024, K=128   |     24934.4    |      26785.1    |  19434.1
                          f32 fwd_gen B=256, M=1024, K=256   |     51063.3    |      54789.4    |  37981.7
                          f16 fwd_gen B=256, M=1024, K=256   |     53383.9    |      54688.3    |  35144.6
                          f32 fwd_gen B=384, M=192, K=88     |      1332.1    |       1435.4    |   1228.4
                          f16 fwd_gen B=384, M=192, K=88     |      1382.4    |       1480.4    |   1119.5
                          f32 fwd_gen B=768, M=256, K=64     |      2320.6    |       2486.4    |   3421.9
                          f16 fwd_gen B=768, M=256, K=64     |      2548.4    |       2711.3    |   3028.1
                          f32 fwd_gen B=1024, M=128, K=16    |       529.5    |        574.9    |    971.7
                          f16 fwd_gen B=1024, M=128, K=16    |       569.3    |        632.8    |    841.5
                          f32 fwd_gen B=1024, M=128, K=32    |       650.2    |        696.0    |   1074.0
                          f16 fwd_gen B=1024, M=128, K=32    |       685.5    |        753.0    |    926.6
                          f32 fwd_gen B=1024, M=128, K=64    |       851.3    |        898.1    |   1286.2
                          f16 fwd_gen B=1024, M=128, K=64    |       916.1    |        963.2    |   1100.9
                          f32 fwd_gen B=1024, M=128, K=128   |      1735.3    |       1838.7    |   1718.4
                          f16 fwd_gen B=1024, M=128, K=128   |      1767.7    |       1857.8    |   1439.4
                          f32 fwd_gen B=1024, M=128, K=256   |      3443.7    |       3687.9    |   3511.7
                          f16 fwd_gen B=1024, M=128, K=256   |      3509.5    |       3625.0    |   2598.4
                          f32 fwd_gen B=1024, M=512, K=16    |      7621.4    |       8546.1    |  13752.0
                          f16 fwd_gen B=1024, M=512, K=16    |      8448.1    |       9391.3    |  11978.2
                          f32 fwd_gen B=1024, M=512, K=32    |      9062.4    |       9854.4    |  14948.0
                          f16 fwd_gen B=1024, M=512, K=32    |     10080.9    |      10926.1    |  13216.4
                          f32 fwd_gen B=1024, M=512, K=64    |     11908.9    |      12856.5    |  17414.6
                          f16 fwd_gen B=1024, M=512, K=64    |     13226.1    |      13974.8    |  15488.8
                          f32 fwd_gen B=1024, M=512, K=128   |     23919.8    |      25776.4    |  22168.6
                          f16 fwd_gen B=1024, M=512, K=128   |     25744.3    |      27217.2    |  19775.3
                          f32 fwd_gen B=1024, M=512, K=256   |     51508.8    |      55604.2    |  39519.8
                          f16 fwd_gen B=1024, M=512, K=256   |     53945.5    |      55808.6    |  36146.1
                          f32 fwd_gen B=1024, M=1024, K=16   |     29905.4    |      33163.2    |         
                          f16 fwd_gen B=1024, M=1024, K=16   |     33120.9    |      36215.5    |         
                          f32 fwd_gen B=1024, M=1024, K=32   |     35358.5    |      38556.3    |         
                          f16 fwd_gen B=1024, M=1024, K=32   |     39408.2    |      42834.3    |         
                          f32 fwd_gen B=1024, M=1024, K=64   |     46072.7    |      49090.7    |         
                          f16 fwd_gen B=1024, M=1024, K=64   |     51468.1    |      54710.6    |         
                          f32 fwd_gen B=1024, M=1024, K=128  |     93170.0    |     102208.5    |         
                          f16 fwd_gen B=1024, M=1024, K=128  |    103136.5    |     109579.5    |         
                          f32 fwd_gen B=1024, M=1024, K=256  |    206180.7    |     221698.4    |         
                          f16 fwd_gen B=1024, M=1024, K=256  |    214954.3    |     221447.5    |         
                          f32 fwd_gen B=2400, M=256, K=64    |      7280.2    |       7783.3    |  10638.5
                          f16 fwd_gen B=2400, M=256, K=64    |      8157.6    |       8566.3    |   9450.2
                          f32 fwd_gen B=8192, M=82, K=64     |      5551.0    |       5841.8    |   6891.2
                          f16 fwd_gen B=8192, M=82, K=64     |      6113.9    |       6300.3    |   6451.9
  (Tesla_V100_SXM2_16GB)  f32 fwd_gen B=256, M=128, K=16     |        93.2    |         91.9    |    173.3
                          f16 fwd_gen B=256, M=128, K=16     |        49.2    |         53.3    |    102.5
                          f32 fwd_gen B=256, M=128, K=32     |        99.6    |        102.4    |    142.7
                          f16 fwd_gen B=256, M=128, K=32     |        53.9    |         53.3    |    103.3
                          f32 fwd_gen B=256, M=128, K=64     |       132.5    |        135.1    |    199.3
                          f16 fwd_gen B=256, M=128, K=64     |        73.2    |         71.5    |    100.0
                          f32 fwd_gen B=256, M=128, K=128    |       252.3    |        255.6    |    327.9
                          f16 fwd_gen B=256, M=128, K=128    |       112.9    |        110.2    |    133.7
                          f32 fwd_gen B=256, M=128, K=256    |       493.0    |        507.0    |    617.4
                          f16 fwd_gen B=256, M=128, K=256    |       221.4    |        220.1    |    204.0
                          f32 fwd_gen B=256, M=512, K=16     |      1101.6    |       1144.2    |   1767.5
                          f16 fwd_gen B=256, M=512, K=16     |       516.6    |        506.2    |    791.2
                          f32 fwd_gen B=256, M=512, K=32     |      1299.5    |       1335.8    |   1915.2
                          f16 fwd_gen B=256, M=512, K=32     |       532.3    |        522.3    |    835.5
                          f32 fwd_gen B=256, M=512, K=64     |      1711.8    |       1759.6    |   2283.6
                          f16 fwd_gen B=256, M=512, K=64     |       654.1    |        681.1    |    974.5
                          f32 fwd_gen B=256, M=512, K=128    |      3407.9    |       3466.0    |   3676.6
                          f16 fwd_gen B=256, M=512, K=128    |      1150.7    |       1139.3    |   1155.0
                          f32 fwd_gen B=256, M=512, K=256    |      7134.9    |       7387.3    |   6426.8
                          f16 fwd_gen B=256, M=512, K=256    |      2939.5    |       2931.6    |   1737.3
                          f32 fwd_gen B=256, M=1024, K=16    |      4237.1    |       4432.5    |   7033.7
                          f16 fwd_gen B=256, M=1024, K=16    |      1969.0    |       1928.1    |   3318.6
                          f32 fwd_gen B=256, M=1024, K=32    |      5096.2    |       5289.1    |   7737.6
                          f16 fwd_gen B=256, M=1024, K=32    |      2006.9    |       1983.6    |   3427.6
                          f32 fwd_gen B=256, M=1024, K=64    |      6680.9    |       6854.9    |   9108.6
                          f16 fwd_gen B=256, M=1024, K=64    |      2454.8    |       2441.4    |   3803.7
                          f32 fwd_gen B=256, M=1024, K=128   |     13242.4    |      13630.7    |  14723.4
                          f16 fwd_gen B=256, M=1024, K=128   |      4214.2    |       4188.0    |   4278.1
                          f32 fwd_gen B=256, M=1024, K=256   |     27950.3    |      28930.1    |  25739.1
                          f16 fwd_gen B=256, M=1024, K=256   |     11583.1    |      11723.5    |   6389.4
                          f32 fwd_gen B=384, M=192, K=88     |       721.6    |        752.7    |    767.1
                          f16 fwd_gen B=384, M=192, K=88     |       286.6    |        285.6    |    316.5
                          f32 fwd_gen B=768, M=256, K=64     |      1347.3    |       1389.5    |   1894.9
                          f16 fwd_gen B=768, M=256, K=64     |       549.1    |        546.4    |    773.6
                          f32 fwd_gen B=1024, M=128, K=16    |       298.5    |        308.8    |    440.6
                          f16 fwd_gen B=1024, M=128, K=16    |       156.2    |        154.2    |    231.4
                          f32 fwd_gen B=1024, M=128, K=32    |       359.3    |        371.0    |    524.4
                          f16 fwd_gen B=1024, M=128, K=32    |       173.4    |        171.3    |    259.6
                          f32 fwd_gen B=1024, M=128, K=64    |       493.2    |        505.6    |    689.6
                          f16 fwd_gen B=1024, M=128, K=64    |       238.6    |        237.1    |    324.4
                          f32 fwd_gen B=1024, M=128, K=128   |       964.9    |        985.7    |   1139.8
                          f16 fwd_gen B=1024, M=128, K=128   |       405.1    |        398.6    |    452.1
                          f32 fwd_gen B=1024, M=128, K=256   |      1888.9    |       1948.3    |   1964.5
                          f16 fwd_gen B=1024, M=128, K=256   |       833.0    |        825.3    |    724.1
                          f32 fwd_gen B=1024, M=512, K=16    |      4246.8    |       4451.7    |   7166.3
                          f16 fwd_gen B=1024, M=512, K=16    |      1991.3    |       1951.1    |   3077.2
                          f32 fwd_gen B=1024, M=512, K=32    |      5121.6    |       5302.2    |   7961.7
                          f16 fwd_gen B=1024, M=512, K=32    |      2066.8    |       2047.9    |   3210.7
                          f32 fwd_gen B=1024, M=512, K=64    |      6822.8    |       7097.1    |   9351.6
                          f16 fwd_gen B=1024, M=512, K=64    |      2560.6    |       2541.9    |   3852.7
                          f32 fwd_gen B=1024, M=512, K=128   |     13559.8    |      14019.4    |  14925.2
                          f16 fwd_gen B=1024, M=512, K=128   |      4470.9    |       4446.2    |   4539.5
                          f32 fwd_gen B=1024, M=512, K=256   |     28195.7    |      29146.3    |  27001.1
                          f16 fwd_gen B=1024, M=512, K=256   |     11877.8    |      11962.9    |   6894.8
                          f32 fwd_gen B=1024, M=1024, K=16   |     16675.2    |      17541.8    |         
                          f16 fwd_gen B=1024, M=1024, K=16   |      7634.9    |       7427.3    |         
                          f32 fwd_gen B=1024, M=1024, K=32   |     20014.9    |      20726.0    |         
                          f16 fwd_gen B=1024, M=1024, K=32   |      7841.6    |       7760.9    |         
                          f32 fwd_gen B=1024, M=1024, K=64   |     26619.9    |      27418.8    |         
                          f16 fwd_gen B=1024, M=1024, K=64   |      9538.6    |       9435.3    |         
                          f32 fwd_gen B=1024, M=1024, K=128  |     52749.2    |      53916.0    |         
                          f16 fwd_gen B=1024, M=1024, K=128  |     16626.0    |      16490.8    |         
                          f32 fwd_gen B=1024, M=1024, K=256  |    111714.5    |     116257.0    |         
                          f16 fwd_gen B=1024, M=1024, K=256  |     45511.0    |      45515.1    |         
                          f32 fwd_gen B=2400, M=256, K=64    |      4195.4    |       4300.2    |   5727.8
                          f16 fwd_gen B=2400, M=256, K=64    |      1686.3    |       1688.7    |   2369.7
                          f32 fwd_gen B=8192, M=82, K=64     |      3124.6    |                 |   3194.6
                          f16 fwd_gen B=8192, M=82, K=64     |      1449.6    |                 |   1709.2

Times are in microseconds (us).
A100 fw
[---------------- attention (attn_bias=<class 'NoneType'>) ----------------]
                                 |  4ed5c36_it2  |  f527a89_main  |  vanilla
1 threads: -----------------------------------------------------------------
      f16 B=256, M=128, K=16     |       29.1    |       28.9     |     85.5
      f32 B=256, M=128, K=16     |       46.4    |       46.6     |    125.4
      f16 B=256, M=128, K=32     |       28.7    |       28.8     |     62.9
      f32 B=256, M=128, K=32     |       47.9    |       48.3     |    136.0
      f16 B=256, M=128, K=64     |       28.8    |       28.9     |     62.3
      f32 B=256, M=128, K=64     |       67.0    |       67.8     |    158.8
      f16 B=256, M=128, K=128    |       50.3    |       50.1     |     61.6
      f32 B=256, M=128, K=128    |      115.6    |      115.8     |    213.5
      f16 B=256, M=128, K=256    |      105.9    |      105.8     |    114.9
      f32 B=256, M=128, K=256    |      209.4    |      213.1     |    363.4
      f16 B=256, M=512, K=16     |      189.0    |      188.2     |    472.2
      f32 B=256, M=512, K=16     |      564.9    |      569.3     |   1571.6
      f16 B=256, M=512, K=32     |      196.3    |      195.4     |    488.0
      f32 B=256, M=512, K=32     |      572.9    |      578.2     |   1687.1
      f16 B=256, M=512, K=64     |      240.3    |      239.2     |    537.9
      f32 B=256, M=512, K=64     |      709.4    |      714.6     |   1933.0
      f16 B=256, M=512, K=128    |      404.5    |      398.4     |    610.7
      f32 B=256, M=512, K=128    |     1449.7    |     1432.5     |   2428.3
      f16 B=256, M=512, K=256    |      906.8    |      902.6     |    759.9
      f32 B=256, M=512, K=256    |     2777.1    |     2799.1     |   4268.8
      f16 B=256, M=1024, K=16    |      710.4    |      708.9     |   1788.8
      f32 B=256, M=1024, K=16    |     2171.5    |     2181.6     |   6002.8
      f16 B=256, M=1024, K=32    |      720.0    |      718.6     |   1819.5
      f32 B=256, M=1024, K=32    |     2191.7    |     2200.7     |   6424.0
      f16 B=256, M=1024, K=64    |      847.6    |      844.8     |   1937.6
      f32 B=256, M=1024, K=64    |     2691.2    |     2703.9     |   7370.8
      f16 B=256, M=1024, K=128   |     1434.5    |     1424.0     |   2143.6
      f32 B=256, M=1024, K=128   |     5516.5    |     5482.1     |   9232.3
      f16 B=256, M=1024, K=256   |     3303.6    |     3303.2     |   2527.7
      f32 B=256, M=1024, K=256   |    10773.6    |    10826.9     |  16598.7
      f16 B=384, M=192, K=88     |      125.6    |      122.9     |    189.3
      f32 B=384, M=192, K=88     |      372.0    |      367.5     |    692.4
      f16 B=768, M=256, K=64     |      214.3    |      213.0     |    473.9
      f32 B=768, M=256, K=64     |      571.7    |      575.9     |   1502.3
      f16 B=1024, M=128, K=16    |       60.6    |       60.4     |    159.3
      f32 B=1024, M=128, K=16    |      164.8    |      166.0     |    427.7
      f16 B=1024, M=128, K=32    |       69.3    |       69.0     |    169.1
      f32 B=1024, M=128, K=32    |      173.3    |      175.5     |    476.1
      f16 B=1024, M=128, K=64    |      105.6    |      105.0     |    214.7
      f32 B=1024, M=128, K=64    |      227.5    |      229.3     |    561.8
      f16 B=1024, M=128, K=128   |      188.2    |      187.6     |    290.7
      f32 B=1024, M=128, K=128   |      420.8    |      415.8     |    731.5
      f16 B=1024, M=128, K=256   |      364.8    |      365.5     |    435.8
      f32 B=1024, M=128, K=256   |      784.1    |      795.9     |   1281.2
      f16 B=1024, M=512, K=16    |      742.1    |      738.4     |   1761.6
      f32 B=1024, M=512, K=16    |     2208.9    |     2218.5     |   6089.0
      f16 B=1024, M=512, K=32    |      762.1    |      760.0     |   1834.9
      f32 B=1024, M=512, K=32    |     2238.1    |     2248.8     |   6488.8
      f16 B=1024, M=512, K=64    |      916.6    |      913.3     |   2004.1
      f32 B=1024, M=512, K=64    |     2765.5    |     2776.5     |   7446.7
      f16 B=1024, M=512, K=128   |     1586.6    |     1557.8     |   2329.0
      f32 B=1024, M=512, K=128   |     5756.7    |     5712.5     |   9438.5
      f16 B=1024, M=512, K=256   |     3557.5    |     3547.5     |   2915.8
      f32 B=1024, M=512, K=256   |    11002.9    |    11104.0     |  17044.3
      f16 B=1024, M=1024, K=16   |     2793.6    |     2782.9     |   7053.0
      f32 B=1024, M=1024, K=16   |     8604.6    |     8646.7     |  23928.0
      f16 B=1024, M=1024, K=32   |     2826.2    |     2819.4     |   7194.9
      f32 B=1024, M=1024, K=32   |     8673.6    |     8711.4     |  25606.2
      f16 B=1024, M=1024, K=64   |     3314.4    |     3302.4     |   7716.3
      f32 B=1024, M=1024, K=64   |    10666.6    |    10700.5     |  29398.0
      f16 B=1024, M=1024, K=128  |     5772.3    |     5681.6     |   8475.8
      f32 B=1024, M=1024, K=128  |    21922.2    |    21788.1     |  36698.4
      f16 B=1024, M=1024, K=256  |    13071.0    |    12901.7     |   9961.9
      f32 B=1024, M=1024, K=256  |    42936.8    |    43153.2     |  66201.8
      f16 B=2400, M=256, K=64    |      632.4    |      629.1     |   1397.6
      f32 B=2400, M=256, K=64    |     1728.8    |     1740.1     |   4531.2
      f16 B=8192, M=82, K=64     |      594.9    |      589.2     |   1168.6
      f32 B=8192, M=82, K=64     |     1457.6    |     1459.7     |   2917.2

Times are in microseconds (us).

Stacked PR Chain: MMA_FROM_SMEM_IT_RES

PR Title Merges Into
#391 [MMA_FROM_SMEM_IT_RES] Fix padding of LSE in forward N/A
#392 [MMA_FROM_SMEM_IT_RES] Copy over files from cutlass #391
#393 [MMA_FROM_SMEM_IT_RES] Change how residuals are handled #392

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 15, 2022
@codecov-commenter
Copy link

Codecov Report

Base: 91.14% // Head: 91.23% // Increases project coverage by +0.09% 🎉

Coverage data is based on head (4ed5c36) compared to base (1e01667).
Patch has no changes to coverable lines.

Additional details and impacted files
@@            Coverage Diff             @@
##              it1     #393      +/-   ##
==========================================
+ Coverage   91.14%   91.23%   +0.09%     
==========================================
  Files          75       75              
  Lines        4358     4346      -12     
==========================================
- Hits         3972     3965       -7     
+ Misses        386      381       -5     
Flag Coverage Δ
Python 91.23% <ø> (+0.09%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
xformers/ops.py 85.88% <ø> (+1.23%) ⬆️

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

☔ View full report at Codecov.
📢 Do you have feedback about the report comment? Let us know in this issue.

Copy link
Contributor

@fmassa fmassa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is awesome Daniel!

@@ -337,6 +339,11 @@ class PredicatedTileIterator<
address_iterator_.clear_mask(enable);
}

CUTLASS_HOST_DEVICE
void set_residual_tile(bool enable) {
address_iterator_.set_residual_tile(enable);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a pity that address_iterator is a private member, otherwise we could have just inherited from the base class and added this method to all the classes in this file.

params_.inc_advance_ * LongIndex(tile_offset.contiguous() - 1);
pointer_ += Shape::kStrided * tile_offset.strided();
}
if (!Gather) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment here, this is the only thing that we changed (in addition to the set_residual_tile) but we need to copy the whole file. :-/

xformers/ops.py Show resolved Hide resolved
Base automatically changed from it1 to main September 15, 2022 12:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants