Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chunk3: Add custom operator to avoid torch.cat in BW #458

Merged
merged 7 commits into from
Oct 6, 2022

Conversation

danthe3rd
Copy link
Contributor

@danthe3rd danthe3rd commented Oct 5, 2022

Stack from ghstack (oldest at bottom):

SUMMARY

Also:

  • updated benchmarks to reflect the real-world scenario where we chunk a packed qkv - for both fw/bw.
  • added coverage for chunking in tests

PERFORMANCE IMPACT

A100 bw (new benchmarks)
[---------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------]
                                     |  48_chunk3_31735f9  |  45_bwpacked_e53c5f3  |  vanilla  |  47_bwpackedgrad_9bacdf6
1 threads: --------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |          560.7      |           663.9       |   2265.7  |             710.3       
      f32 B=384, M=197, H=1, K=88    |         2445.1      |          2540.3       |   1843.3  |            2611.0       
      f16 B=384, M=197, H=1, K=80    |          530.4      |           619.9       |   1922.8  |             663.0       
      f32 B=384, M=197, H=1, K=80    |         2326.1      |          2425.2       |   1788.7  |            2476.4       
      f16 B=384, M=197, H=1, K=64    |          391.7      |           462.2       |   1812.7  |             492.8       
      f32 B=384, M=197, H=1, K=64    |         1275.0      |          1379.4       |   1675.4  |            1388.4       
      f16 B=1024, M=197, H=1, K=88   |         1399.5      |          1666.2       |   5965.2  |            1775.5       
      f32 B=1024, M=197, H=1, K=88   |         6332.5      |          6618.1       |   4559.6  |            6740.5       
      f16 B=1024, M=197, H=1, K=80   |         1326.2      |          1543.9       |   5041.4  |            1652.3       
      f32 B=1024, M=197, H=1, K=80   |         6057.1      |          6301.3       |   4411.6  |            6433.6       
      f16 B=1024, M=197, H=1, K=64   |          876.9      |          1063.1       |   4749.3  |            1133.2       
      f32 B=1024, M=197, H=1, K=64   |         3360.2      |          3629.0       |   4118.8  |            3652.0       
      f16 B=512, M=197, H=1, K=80    |          669.0      |           786.4       |   2544.9  |             842.2       
      f32 B=512, M=197, H=1, K=80    |         3032.3      |          3127.8       |   2287.4  |            3229.8       
      f16 B=32, M=197, H=16, K=80    |          663.0      |           789.7       |   2569.0  |             837.8       
      f32 B=32, M=197, H=16, K=80    |         3005.5      |          3166.3       |   2354.1  |            3225.9       
      f16 B=32, M=197, H=16, K=64    |          459.9      |           553.4       |   2436.3  |             591.9       
      f32 B=32, M=197, H=16, K=64    |         1814.1      |          1962.5       |   2197.3  |            1962.1       
      f16 B=32, M=197, H=16, K=128   |          792.5      |           981.9       |   4505.9  |            1056.5       
      f32 B=32, M=197, H=16, K=128   |         3734.8      |          3995.7       |   2805.8  |            4021.5       
      f16 B=256, M=197, H=1, K=88    |          413.4      |           482.6       |   1529.5  |             515.5       
      f32 B=256, M=197, H=1, K=88    |         1741.9      |          1818.3       |   1208.6  |            1852.4       
      f16 B=16, M=197, H=16, K=88    |          410.3      |           482.9       |   1545.7  |             512.5       
      f32 B=16, M=197, H=16, K=88    |         1734.9      |          1832.1       |   1250.6  |            1849.4       
      f16 B=16, M=197, H=16, K=64    |          235.4      |           286.0       |   1247.1  |             305.3       
      f32 B=16, M=197, H=16, K=64    |         1077.1      |          1143.7       |   1125.9  |            1154.0       
      f16 B=16, M=197, H=16, K=128   |          455.4      |           554.1       |   2273.1  |             596.0       
      f32 B=16, M=197, H=16, K=128   |         2028.9      |          2164.5       |   1446.7  |            2175.0       
      f16 B=1, M=4096, H=160, K=128  |        62454.4      |         63474.5       |  45930.5  |           64052.7       
      f32 B=1, M=4096, H=160, K=128  |       239035.4      |        232672.1       |           |          240073.9       
      f16 B=2, M=4096, H=160, K=128  |        98791.3      |        101006.4       |           |          101942.0       
      f32 B=2, M=4096, H=160, K=128  |       375914.9      |        368050.6       |           |          381280.4       
      f16 B=1, M=8192, H=160, K=128  |       248498.9      |        250066.9       |           |          251500.4       
      f32 B=1, M=8192, H=160, K=128  |       945102.2      |        922549.3       |           |          949256.4       
      f16 B=2, M=8192, H=160, K=128  |       389207.8      |        394486.6       |           |          396190.4       
      f32 B=2, M=8192, H=160, K=128  |      1496334.3      |       1449974.3       |           |         1502215.3       
      f16 B=1024, M=82, H=8, K=64    |         1872.4      |          2503.8       |   3819.8  |            2693.7       
      f32 B=1024, M=82, H=8, K=64    |         8734.3      |          9637.8       |   8732.9  |            9672.2       
      f16 B=150, M=256, H=16, K=64   |         2126.4      |          2713.4       |   4554.3  |            2880.8       
      f32 B=150, M=256, H=16, K=64   |         6214.3      |          7052.2       |  12943.2  |            7099.2       
      f16 B=64, M=256, H=12, K=64    |          741.2      |           930.1       |   1493.0  |             990.6       
      f32 B=64, M=256, H=12, K=64    |         2144.2      |          2408.5       |   4267.7  |            2433.8       
      f16 B=1, M=4096, H=16, K=40    |        24583.7      |         24224.8       |   4195.2  |           24500.2       
      f32 B=1, M=4096, H=16, K=40    |        72497.9      |         72070.8       |  17744.1  |           72393.0       
      f16 B=1, M=16384, H=16, K=40   |       451481.8      |        439027.7       |           |          451499.9       
      f32 B=1, M=16384, H=16, K=40   |      1169509.1      |       1164880.1       |           |         1169769.3       
      f16 B=256, M=4096, H=16, K=64  |       597391.6      |        625921.0       |           |          610433.2       
      f16 B=16, M=128, H=16, K=16    |           93.1      |           126.7       |    241.2  |             132.3       
      f32 B=16, M=128, H=16, K=16    |          184.1      |           176.5       |    373.8  |             180.7       
      f16 B=16, M=128, H=16, K=32    |          127.9      |           126.3       |    241.4  |             106.7       
      f32 B=16, M=128, H=16, K=32    |          194.1      |           216.6       |    412.7  |             225.8       
      f16 B=16, M=128, H=16, K=64    |          131.4      |           126.8       |    239.8  |             134.5       
      f32 B=16, M=128, H=16, K=64    |          280.4      |           326.0       |    500.0  |             334.0       
      f16 B=16, M=128, H=16, K=128   |          175.6      |           236.1       |    298.8  |             261.1       
      f32 B=16, M=128, H=16, K=128   |          531.8      |           615.8       |    677.2  |             638.0       
      f16 B=16, M=512, H=16, K=16    |          558.2      |           595.0       |   1201.9  |             607.8       
      f32 B=16, M=512, H=16, K=16    |         2146.7      |          2169.9       |   4416.1  |            2200.6       
      f16 B=16, M=512, H=16, K=32    |          653.5      |           732.3       |   1305.1  |             748.5       
      f32 B=16, M=512, H=16, K=32    |         2296.3      |          2373.9       |   4641.3  |            2400.1       
      f16 B=16, M=512, H=16, K=64    |          848.8      |           996.9       |   1544.6  |            1022.5       
      f32 B=16, M=512, H=16, K=64    |         2954.0      |          3117.1       |   5124.7  |            3157.6       
      f16 B=16, M=512, H=16, K=128   |         1735.4      |          1961.1       |   1982.7  |            2056.9       
      f32 B=16, M=512, H=16, K=128   |         6218.7      |          6396.4       |   6094.0  |            6600.3       
      f16 B=16, M=1024, H=16, K=16   |         2236.4      |          2319.4       |   4279.0  |            2331.6       
      f32 B=16, M=1024, H=16, K=16   |         8379.2      |          8363.9       |  16643.9  |            8503.6       
      f16 B=16, M=1024, H=16, K=32   |         2430.8      |          2649.6       |   4496.8  |            2608.7       
      f32 B=16, M=1024, H=16, K=32   |         8864.7      |          8907.8       |  17291.0  |            9074.0       
      f16 B=16, M=1024, H=16, K=64   |         3007.2      |          3351.3       |   4995.5  |            3351.0       
      f32 B=16, M=1024, H=16, K=64   |        11355.4      |         11627.1       |  18707.5  |           11694.3       
      f16 B=16, M=1024, H=16, K=128  |         6296.2      |          6748.7       |   5943.5  |            6967.0       
      f32 B=16, M=1024, H=16, K=128  |        23425.3      |         23360.0       |  21520.6  |           24169.7       
      f16 B=64, M=128, H=16, K=16    |          165.5      |           195.9       |    440.3  |             211.5       
      f32 B=64, M=128, H=16, K=16    |          497.4      |           540.7       |   1270.8  |             550.3       
      f16 B=64, M=128, H=16, K=32    |          210.4      |           274.9       |    544.8  |             298.5       
      f32 B=64, M=128, H=16, K=32    |          604.4      |           696.6       |   1428.3  |             710.9       
      f16 B=64, M=128, H=16, K=64    |          330.4      |           452.3       |    766.0  |             498.1       
      f32 B=64, M=128, H=16, K=64    |          883.4      |          1060.4       |   1745.2  |            1082.2       
      f16 B=64, M=128, H=16, K=128   |          605.5      |           847.8       |   1223.6  |             933.9       
      f32 B=64, M=128, H=16, K=128   |         1847.4      |          2169.7       |   2388.8  |            2236.0       
      f16 B=64, M=512, H=16, K=16    |         2004.7      |          2120.0       |   4487.0  |            2179.4       
      f32 B=64, M=512, H=16, K=16    |         6655.4      |          6818.8       |  16993.8  |            6872.1       
      f16 B=64, M=512, H=16, K=32    |         2379.3      |          2593.1       |   4957.2  |            2704.0       
      f32 B=64, M=512, H=16, K=32    |         7349.4      |          7644.6       |  17852.2  |            7736.2       
      f16 B=64, M=512, H=16, K=64    |         3129.6      |          3616.6       |   5888.8  |            3786.2       
      f32 B=64, M=512, H=16, K=64    |         9432.5      |         10123.9       |  19770.6  |           10178.5       
      f16 B=64, M=512, H=16, K=128   |         6054.1      |          7019.9       |   7712.6  |            7350.2       
      f32 B=64, M=512, H=16, K=128   |        21565.6      |         22281.9       |  23653.0  |           23084.4       
      f16 B=64, M=1024, H=16, K=16   |         7929.4      |          8199.1       |  16876.3  |            8242.5       
      f32 B=64, M=1024, H=16, K=16   |        26135.2      |         26347.9       |  66351.1  |           26639.0       
      f16 B=64, M=1024, H=16, K=32   |         8876.8      |          9450.0       |  17869.4  |            9473.5       
      f32 B=64, M=1024, H=16, K=32   |        27685.3      |         28104.6       |  69105.9  |           28428.7       
      f16 B=64, M=1024, H=16, K=64   |        11198.7      |         12180.5       |  19932.3  |           12543.4       
      f32 B=64, M=1024, H=16, K=64   |        34978.2      |         36239.4       |  74813.7  |           36482.4       
      f16 B=64, M=1024, H=16, K=128  |        21618.9      |         23439.6       |  23741.1  |           24160.1       
      f32 B=64, M=1024, H=16, K=128  |        80785.3      |         81080.8       |  86003.6  |           84132.9       

Times are in microseconds (us).
P100/V100 bw (new benchmarks)
[---------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------------------]
                                                         |  48_chunk3_31735f94  |  45_bwpacked_e53c5f3a  |  vanilla   |  47_bwpackedgrad_9bacdf65
1 threads: --------------------------------------------------------------------------------------------------------------------------------------
  (Quadro_GP100)          f16 B=384, M=197, H=1, K=88    |         6846.3       |          7583.8        |    3569.3  |            7599.5        
                          f32 B=384, M=197, H=1, K=88    |         9883.1       |         10107.2        |    4312.8  |           10486.3        
                          f16 B=384, M=197, H=1, K=80    |         6486.4       |          6997.7        |    3418.0  |            7037.3        
                          f32 B=384, M=197, H=1, K=80    |         9330.3       |          9550.6        |    4094.7  |            9893.4        
                          f16 B=384, M=197, H=1, K=64    |         3615.4       |          3930.4        |    2911.0  |            4074.2        
                          f32 B=384, M=197, H=1, K=64    |         6281.4       |          6554.5        |    3431.9  |            6738.1        
                          f16 B=1024, M=197, H=1, K=88   |        17226.8       |         18593.1        |    9733.2  |           18772.9        
                          f32 B=1024, M=197, H=1, K=88   |        26593.3       |         27136.2        |   12033.8  |           28184.2        
                          f16 B=1024, M=197, H=1, K=80   |        16330.1       |         17478.6        |    9270.2  |           17735.3        
                          f32 B=1024, M=197, H=1, K=80   |        25208.9       |         25680.1        |   11224.5  |           26636.1        
                          f16 B=1024, M=197, H=1, K=64   |         8889.1       |          9728.8        |    7646.1  |           10089.7        
                          f32 B=1024, M=197, H=1, K=64   |        16914.7       |         17743.4        |    9383.8  |           18068.4        
                          f16 B=512, M=197, H=1, K=80    |         8227.3       |          8878.4        |    4579.3  |            8953.6        
                          f32 B=512, M=197, H=1, K=80    |        13078.7       |         13346.0        |    5486.4  |           13817.6        
                          f16 B=32, M=197, H=16, K=80    |         8278.9       |          9002.9        |    4816.2  |            9025.6        
                          f32 B=32, M=197, H=16, K=80    |        12913.8       |         13371.2        |    5777.7  |           13667.6        
                          f16 B=32, M=197, H=16, K=64    |         4565.2       |          5000.0        |    4023.4  |            5146.3        
                          f32 B=32, M=197, H=16, K=64    |         8824.0       |          9257.7        |    4797.2  |            9400.5        
                          f16 B=32, M=197, H=16, K=128   |         9770.0       |         10849.7        |    5983.2  |           10932.0        
                          f32 B=32, M=197, H=16, K=128   |        15715.2       |         16559.9        |    7513.6  |           16839.9        
                          f16 B=256, M=197, H=1, K=88    |         5011.2       |          5363.8        |    2444.9  |            5426.0        
                          f32 B=256, M=197, H=1, K=88    |         6918.7       |          7040.8        |    2867.8  |            7303.2        
                          f16 B=16, M=197, H=16, K=88    |         4963.8       |          5343.9        |    2545.2  |            5398.9        
                          f32 B=16, M=197, H=16, K=88    |         6727.9       |          6981.7        |    3040.3  |            7121.2        
                          f16 B=16, M=197, H=16, K=64    |         2586.5       |          2777.1        |    2025.5  |            2905.6        
                          f32 B=16, M=197, H=16, K=64    |         4404.3       |          4607.2        |    2431.1  |            4691.8        
                          f16 B=16, M=197, H=16, K=128   |         5643.2       |          6194.1        |    3016.1  |            6216.3        
                          f32 B=16, M=197, H=16, K=128   |         7887.1       |          8308.3        |    3676.6  |            8456.2        
                          f16 B=1, M=4096, H=160, K=128  |      1087008.7       |       1115355.5        |            |         1091596.8        
                          f32 B=1, M=4096, H=160, K=128  |      1220066.8       |       1223422.8        |            |         1227912.2        
                          f16 B=2, M=4096, H=160, K=128  |      1734244.4       |       1794068.7        |            |         1756266.7        
                          f32 B=2, M=4096, H=160, K=128  |      2437675.5       |       2445780.4        |            |         2451957.5        
                          f16 B=1, M=8192, H=160, K=128  |      4367110.4       |       4466170.9        |            |         4383747.4        
                          f32 B=1, M=8192, H=160, K=128  |      4865732.9       |       4865708.9        |            |         4887066.5        
                          f16 B=2, M=8192, H=160, K=128  |      7002715.1       |       7146077.9        |            |         7033922.8        
                          f16 B=1024, M=82, H=8, K=64    |        23247.5       |         24929.5        |   18047.8  |           26928.2        
                          f32 B=1024, M=82, H=8, K=64    |        46463.2       |         48705.6        |   22797.5  |           50736.3        
                          f16 B=150, M=256, H=16, K=64   |        23467.9       |         25647.3        |   24569.2  |           26841.8        
                          f32 B=150, M=256, H=16, K=64   |        36887.7       |         39698.0        |   32050.2  |           40389.0        
                          f16 B=64, M=256, H=12, K=64    |         7723.7       |          8499.0        |    7702.1  |            8694.9        
                          f32 B=64, M=256, H=12, K=64    |        11992.1       |         12819.9        |    9874.5  |           13107.9        
                          f16 B=1, M=4096, H=16, K=40    |       142655.5       |        142899.7        |   28928.6  |          142922.7        
                          f32 B=1, M=4096, H=16, K=40    |       142626.8       |        142685.3        |   37303.2  |          142541.0        
                          f16 B=1, M=16384, H=16, K=40   |      2274095.0       |       2274882.0        |            |         2275019.9        
                          f32 B=1, M=16384, H=16, K=40   |      2284027.2       |       2279415.7        |            |         2277761.9        
                          f16 B=16, M=128, H=16, K=16    |          513.2       |           547.1        |     571.5  |             570.9        
                          f32 B=16, M=128, H=16, K=16    |          667.4       |           704.3        |     693.1  |             728.0        
                          f16 B=16, M=128, H=16, K=32    |          600.3       |           667.0        |     671.3  |             713.1        
                          f32 B=16, M=128, H=16, K=32    |          823.9       |           888.9        |     823.5  |             937.3        
                          f16 B=16, M=128, H=16, K=64    |          781.0       |           900.6        |     883.1  |             998.9        
                          f32 B=16, M=128, H=16, K=64    |         1173.7       |          1293.8        |    1077.0  |            1393.4        
                          f16 B=16, M=128, H=16, K=128   |         1649.2       |          1877.2        |    1323.2  |            2026.3        
                          f32 B=16, M=128, H=16, K=128   |         2250.5       |          2473.0        |    1654.7  |            2636.6        
                          f16 B=16, M=512, H=16, K=16    |         7709.3       |          7914.6        |    6945.1  |            7928.7        
                          f32 B=16, M=512, H=16, K=16    |         9797.2       |          9950.5        |    8499.4  |           10029.3        
                          f16 B=16, M=512, H=16, K=32    |         8956.9       |          9210.8        |    7517.1  |            9307.0        
                          f32 B=16, M=512, H=16, K=32    |        11480.7       |         11710.9        |    9249.4  |           11884.4        
                          f16 B=16, M=512, H=16, K=64    |        11324.0       |         11829.1        |    8849.5  |           12001.8        
                          f32 B=16, M=512, H=16, K=64    |        15744.1       |         16258.0        |   10954.6  |           16481.1        
                          f16 B=16, M=512, H=16, K=128   |        25320.2       |         26584.0        |   12412.3  |           26725.0        
                          f32 B=16, M=512, H=16, K=128   |        31187.1       |         32290.3        |   15167.5  |           32818.4        
                          f16 B=16, M=1024, H=16, K=16   |        31484.2       |         31601.4        |   26434.6  |           31894.6        
                          f32 B=16, M=1024, H=16, K=16   |        38754.1       |         38900.1        |   32320.0  |           39203.9        
                          f16 B=16, M=1024, H=16, K=32   |        36000.2       |         36672.6        |   28341.4  |           36579.5        
                          f32 B=16, M=1024, H=16, K=32   |        45070.7       |         45262.3        |   34914.2  |           45774.5        
                          f16 B=16, M=1024, H=16, K=64   |        45324.9       |         46540.4        |   32089.9  |           46784.2        
                          f32 B=16, M=1024, H=16, K=64   |        61320.3       |         62411.1        |   39565.0  |           63217.0        
                          f16 B=16, M=1024, H=16, K=128  |       104342.9       |        108469.4        |   43221.9  |          105620.6        
                          f32 B=16, M=1024, H=16, K=128  |       122688.4       |        125050.9        |   51205.7  |          126080.9        
                          f16 B=64, M=128, H=16, K=16    |         1707.9       |          1824.9        |    2106.4  |            1923.2        
                          f32 B=64, M=128, H=16, K=16    |         2487.4       |          2612.5        |    2565.1  |            2707.6        
                          f16 B=64, M=128, H=16, K=32    |         2016.8       |          2254.4        |    2485.4  |            2412.3        
                          f32 B=64, M=128, H=16, K=32    |         3135.8       |          3365.6        |    3063.2  |            3518.5        
                          f16 B=64, M=128, H=16, K=64    |         2700.2       |          3167.0        |    3306.0  |            3478.4        
                          f32 B=64, M=128, H=16, K=64    |         4435.1       |          4944.7        |    4227.6  |            5181.2        
                          f16 B=64, M=128, H=16, K=128   |         5769.1       |          6858.2        |    5299.8  |            7356.1        
                          f32 B=64, M=128, H=16, K=128   |         8577.9       |          9672.0        |    6916.3  |           10093.5        
                          f16 B=64, M=512, H=16, K=16    |        25994.0       |         26782.0        |   27240.9  |           26662.2        
                          f32 B=64, M=512, H=16, K=16    |        36864.9       |         37299.3        |   34159.3  |           37576.7        
                          f16 B=64, M=512, H=16, K=32    |        30680.4       |         32113.8        |   30109.0  |           32419.7        
                          f32 B=64, M=512, H=16, K=32    |        43638.5       |         44557.9        |   37358.5  |           45145.0        
                          f16 B=64, M=512, H=16, K=64    |        39417.5       |         41666.5        |   36004.2  |           42374.9        
                          f32 B=64, M=512, H=16, K=64    |        60049.2       |         63148.0        |   43412.6  |           63286.8        
                          f16 B=64, M=512, H=16, K=128   |        88951.1       |         93087.0        |   51730.1  |           94861.6        
                          f32 B=64, M=512, H=16, K=128   |       119728.7       |        124340.3        |   62413.7  |          126382.2        
                          f16 B=64, M=1024, H=16, K=16   |       108368.3       |        111081.8        |  106479.7  |          108716.1        
                          f32 B=64, M=1024, H=16, K=16   |       145612.0       |        147310.4        |            |          147380.7        
                          f16 B=64, M=1024, H=16, K=32   |       124296.1       |        127366.8        |  113905.0  |          126975.3        
                          f32 B=64, M=1024, H=16, K=32   |       171082.3       |        172539.0        |            |          173893.9        
                          f16 B=64, M=1024, H=16, K=64   |       155116.3       |        160429.2        |  130759.4  |          161834.0        
                          f32 B=64, M=1024, H=16, K=64   |       234356.0       |        239612.2        |            |          239948.3        
                          f16 B=64, M=1024, H=16, K=128  |       349728.3       |        360975.7        |  176158.7  |          371185.2        
                          f32 B=64, M=1024, H=16, K=128  |       468810.0       |        476415.4        |            |          481908.5        
  (Tesla_V100_SXM2_16GB)  f16 B=384, M=197, H=1, K=88    |         1700.3       |          1840.0        |    1375.3  |            1930.9        
                          f32 B=384, M=197, H=1, K=88    |         4456.4       |          4579.3        |    2235.5  |            4708.6        
                          f16 B=384, M=197, H=1, K=80    |         1623.3       |          1719.9        |    1279.5  |            1806.9        
                          f32 B=384, M=197, H=1, K=80    |         4031.2       |          4141.9        |    2149.8  |            4252.6        
                          f16 B=384, M=197, H=1, K=64    |         1092.8       |          1187.0        |    1048.5  |            1237.6        
                          f32 B=384, M=197, H=1, K=64    |         2717.5       |          2918.5        |    1738.5  |            2907.9        
                          f16 B=1024, M=197, H=1, K=88   |         4428.7       |          4906.2        |    3723.7  |            5178.2        
                          f32 B=1024, M=197, H=1, K=88   |        10947.5       |         11362.9        |    6052.5  |           11802.1        
                          f16 B=1024, M=197, H=1, K=80   |         4237.1       |          4491.4        |    3331.7  |            4725.6        
                          f32 B=1024, M=197, H=1, K=80   |         9842.6       |         10159.7        |    5682.4  |           10435.6        
                          f16 B=1024, M=197, H=1, K=64   |         2679.2       |          2927.4        |    2674.4  |            3033.0        
                          f32 B=1024, M=197, H=1, K=64   |         6597.6       |          7154.9        |    4489.7  |            7063.1        
                          f16 B=512, M=197, H=1, K=80    |         2239.5       |          2366.5        |    1684.2  |            2472.0        
                          f32 B=512, M=197, H=1, K=80    |         5362.4       |          5519.6        |    2857.9  |            5651.4        
                          f16 B=32, M=197, H=16, K=80    |         2208.1       |          2380.0        |    1803.4  |            2439.4        
                          f32 B=32, M=197, H=16, K=80    |         5503.6       |          5736.7        |    3017.5  |            5796.2        
                          f16 B=32, M=197, H=16, K=64    |         1493.4       |          1620.6        |    1457.2  |            1678.6        
                          f32 B=32, M=197, H=16, K=64    |         3672.6       |          3941.6        |    2415.0  |            3898.2        
                          f16 B=32, M=197, H=16, K=128   |         2634.3       |          2888.0        |    2215.1  |            2991.5        
                          f32 B=32, M=197, H=16, K=128   |         6811.5       |          7334.0        |    4049.3  |            7261.9        
                          f16 B=256, M=197, H=1, K=88    |         1290.3       |          1382.0        |     944.8  |            1449.4        
                          f32 B=256, M=197, H=1, K=88    |         2965.8       |          3043.2        |    1528.7  |            3137.7        
                          f16 B=16, M=197, H=16, K=88    |         1267.3       |          1357.0        |     970.8  |            1395.5        
                          f32 B=16, M=197, H=16, K=88    |         2879.9       |          3014.7        |    1626.5  |            3054.3        
                          f16 B=16, M=197, H=16, K=64    |          737.3       |           799.8        |     771.3  |             836.9        
                          f32 B=16, M=197, H=16, K=64    |         1879.2       |          2000.9        |    1282.5  |            1994.5        
                          f16 B=16, M=197, H=16, K=128   |         1443.9       |          1570.7        |    1142.2  |            1628.8        
                          f32 B=16, M=197, H=16, K=128   |         3480.5       |          3723.6        |    2027.2  |            3714.6        
                          f16 B=1, M=4096, H=160, K=128  |       150006.2       |        151877.5        |            |          152570.6        
                          f32 B=1, M=4096, H=160, K=128  |       582870.9       |        583519.8        |            |          585570.1        
                          f16 B=2, M=4096, H=160, K=128  |       301231.4       |        304511.7        |            |          305801.2        
                          f32 B=2, M=4096, H=160, K=128  |      1174724.1       |       1172498.4        |            |         1176814.0        
                          f16 B=1, M=8192, H=160, K=128  |       597461.6       |        600463.4        |            |          603066.6        
                          f32 B=1, M=8192, H=160, K=128  |      2333657.8       |       2329212.1        |            |         2339766.1        
                          f16 B=2, M=8192, H=160, K=128  |      1196837.5       |       1206932.4        |            |         1209012.2        
                          f16 B=1024, M=82, H=8, K=64    |         8926.8       |          9723.4        |    5799.4  |           10084.2        
                          f32 B=1024, M=82, H=8, K=64    |        15920.4       |         17434.4        |   11027.0  |           17492.8        
                          f16 B=150, M=256, H=16, K=64   |         5524.2       |          6363.9        |    7557.9  |            6586.2        
                          f32 B=150, M=256, H=16, K=64   |        17506.9       |         18843.5        |   16263.5  |           18988.6        
                          f16 B=64, M=256, H=12, K=64    |         1800.6       |          2050.3        |    2383.4  |            2139.0        
                          f32 B=64, M=256, H=12, K=64    |         5753.6       |          6196.3        |    4971.2  |            6200.0        
                          f16 B=1, M=4096, H=16, K=40    |        47649.5       |         47836.0        |    8368.4  |           47973.6        
                          f32 B=1, M=4096, H=16, K=40    |       111092.1       |        111027.3        |   19475.9  |          111257.8        
                          f16 B=1, M=16384, H=16, K=40   |       765320.2       |        765686.9        |            |          767337.2        
                          f32 B=1, M=16384, H=16, K=40   |      1769169.0       |       1769675.1        |            |         1769371.4        
                          f16 B=16, M=128, H=16, K=16    |          178.9       |           196.8        |     445.9  |             188.3        
                          f32 B=16, M=128, H=16, K=16    |          301.3       |           319.1        |     422.5  |             336.3        
                          f16 B=16, M=128, H=16, K=32    |          174.1       |           174.2        |     394.0  |             179.5        
                          f32 B=16, M=128, H=16, K=32    |          395.7       |           433.2        |     580.0  |             440.4        
                          f16 B=16, M=128, H=16, K=64    |          205.0       |           253.5        |     460.6  |             270.9        
                          f32 B=16, M=128, H=16, K=64    |          573.7       |           639.3        |     598.1  |             656.1        
                          f16 B=16, M=128, H=16, K=128   |          399.5       |           484.3        |     515.2  |             521.8        
                          f32 B=16, M=128, H=16, K=128   |         1126.3       |          1260.8        |    1008.1  |            1282.4        
                          f16 B=16, M=512, H=16, K=16    |         1597.6       |          1627.2        |    1901.1  |            1662.1        
                          f32 B=16, M=512, H=16, K=16    |         4458.5       |          4528.8        |    4232.0  |            4559.4        
                          f16 B=16, M=512, H=16, K=32    |         1819.1       |          1868.7        |    2097.2  |            1945.5        
                          f32 B=16, M=512, H=16, K=32    |         5604.2       |          5757.1        |    4566.4  |            5784.8        
                          f16 B=16, M=512, H=16, K=64    |         2345.5       |          2495.6        |    2558.0  |            2573.2        
                          f32 B=16, M=512, H=16, K=64    |         7778.3       |          8017.1        |    5488.2  |            8083.7        
                          f16 B=16, M=512, H=16, K=128   |         4516.6       |          4821.0        |    3386.7  |            4968.2        
                          f32 B=16, M=512, H=16, K=128   |        15412.7       |         15959.2        |    8865.9  |           16047.5        
                          f16 B=16, M=1024, H=16, K=16   |         6195.9       |          6217.6        |    6995.3  |            6326.4        
                          f32 B=16, M=1024, H=16, K=16   |        18136.2       |         18312.0        |   16088.2  |           18354.1        
                          f16 B=16, M=1024, H=16, K=32   |         7072.8       |          7122.3        |    7406.9  |            7297.7        
                          f32 B=16, M=1024, H=16, K=32   |        22108.2       |         22116.7        |   17112.5  |           22436.8        
                          f16 B=16, M=1024, H=16, K=64   |         8868.0       |          9104.6        |    8627.1  |            9311.8        
                          f32 B=16, M=1024, H=16, K=64   |        30710.5       |         31041.3        |   19860.8  |           31338.1        
                          f16 B=16, M=1024, H=16, K=128  |        17091.8       |         17655.5        |   10548.3  |           18083.8        
                          f32 B=16, M=1024, H=16, K=128  |        60317.8       |         61461.7        |   32919.2  |           61548.8        
                          f16 B=64, M=128, H=16, K=16    |          413.6       |           453.8        |     635.5  |             480.6        
                          f32 B=64, M=128, H=16, K=16    |         1033.8       |          1114.3        |    1238.9  |            1119.5        
                          f16 B=64, M=128, H=16, K=32    |          505.7       |           587.9        |     813.6  |             630.1        
                          f32 B=64, M=128, H=16, K=32    |         1423.0       |          1551.4        |    1533.4  |            1581.8        
                          f16 B=64, M=128, H=16, K=64    |          743.3       |           916.8        |    1187.7  |             976.5        
                          f32 B=64, M=128, H=16, K=64    |         2093.3       |          2384.6        |    2156.3  |            2405.4        
                          f16 B=64, M=128, H=16, K=128   |         1408.2       |          1734.3        |    1918.7  |            1859.6        
                          f32 B=64, M=128, H=16, K=128   |         4125.3       |          4671.4        |    3762.0  |            4717.0        
                          f16 B=64, M=512, H=16, K=16    |         5531.2       |          5643.3        |    7454.4  |            5770.8        
                          f32 B=64, M=512, H=16, K=16    |        16214.0       |         16531.2        |   16661.3  |           16540.8        
                          f16 B=64, M=512, H=16, K=32    |         6495.5       |          6725.2        |    8353.7  |            6941.8        
                          f32 B=64, M=512, H=16, K=32    |        20520.6       |         20941.9        |   18352.4  |           21116.8        
                          f16 B=64, M=512, H=16, K=64    |         8686.1       |          9278.6        |   10343.4  |            9593.2        
                          f32 B=64, M=512, H=16, K=64    |        28891.1       |         30003.0        |   22749.4  |           30139.1        
                          f16 B=64, M=512, H=16, K=128   |        15991.4       |         17412.3        |   14633.0  |           17848.2        
                          f32 B=64, M=512, H=16, K=128   |        57526.8       |         59970.8        |   40089.9  |           60016.9        
                          f16 B=64, M=1024, H=16, K=16   |        21552.8       |         21603.1        |   28447.1  |           22030.0        
                          f32 B=64, M=1024, H=16, K=16   |        65321.2       |         65736.8        |            |           65932.0        
                          f16 B=64, M=1024, H=16, K=32   |        25695.4       |         25905.9        |   30592.1  |           26644.8        
                          f32 B=64, M=1024, H=16, K=32   |        80213.4       |         80446.7        |            |           81363.1        
                          f16 B=64, M=1024, H=16, K=64   |        32465.6       |         33575.1        |   37233.4  |           34370.8        
                          f32 B=64, M=1024, H=16, K=64   |       112996.7       |        115632.0        |            |          115970.8        
                          f16 B=64, M=1024, H=16, K=128  |        60363.5       |         62800.2        |   48883.7  |           64505.1        
                          f32 B=64, M=1024, H=16, K=128  |       225023.4       |        230527.4        |            |          229851.8        

Times are in microseconds (us).

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 5, 2022
danthe3rd pushed a commit that referenced this pull request Oct 5, 2022
ghstack-source-id: d4449a3beb49285862c60f631849ca5e272f7578
Pull Request resolved: #458
@danthe3rd danthe3rd requested a review from fmassa October 5, 2022 12:12
danthe3rd pushed a commit that referenced this pull request Oct 5, 2022
ghstack-source-id: 49d85372692924617f2afc9d79aaa46a44325115
Pull Request resolved: #458
danthe3rd pushed a commit that referenced this pull request Oct 5, 2022
ghstack-source-id: b7304c42fbede4cf920541aa50a7554bdfbac3b5
Pull Request resolved: #458
@codecov-commenter
Copy link

Codecov Report

Base: 91.50% // Head: 91.51% // Increases project coverage by +0.01% 🎉

Coverage data is based on head (b426450) compared to base (33fbe65).
Patch coverage: 94.11% of modified lines in pull request are covered.

Additional details and impacted files
@@                   Coverage Diff                    @@
##           gh/danthe3rd/48/base     #458      +/-   ##
========================================================
+ Coverage                 91.50%   91.51%   +0.01%     
========================================================
  Files                        75       75              
  Lines                      4412     4429      +17     
========================================================
+ Hits                       4037     4053      +16     
- Misses                      375      376       +1     
Flag Coverage Δ
Python 91.51% <94.11%> (+0.01%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
xformers/ops.py 89.34% <94.11%> (+0.25%) ⬆️

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

☔ View full report at Codecov.
📢 Do you have feedback about the report comment? Let us know in this issue.

Copy link
Contributor

@fmassa fmassa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks pretty nice, thanks Daniel!

I think there is a missing condition in the backward to check that the gradients all come from the same storage.
There is also some opportunities to make the code even more generic and support an arbitrary number of input elements through torch.unbind, but up to you.

Also, can you add tests to Chunk3 for various cases? It would be good to stress-test this with some basic cases on its own test to make sure we are not missing anything else.

xformers/ops.py Outdated Show resolved Hide resolved
xformers/ops.py Outdated Show resolved Hide resolved
xformers/ops.py Outdated Show resolved Hide resolved
Comment on lines 103 to 129
attn_bias_type=[type(None), torch.Tensor, xformers.ops.LowerTriangularMask],
dtype=[torch.half, torch.bfloat16, torch.float],
attn_bias_type=[type(None)],
dtype=[torch.half],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: looks like something is missing here? :-)

Comment on lines +58 to +61
def T(t):
return t.permute((0, 2, 1, 3)).reshape(
[t.shape[0] * t.shape[2], t.shape[1], t.shape[3]]
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I wonder if we should still keep the previous benchmark without the permute + contiguous in the hot path.

The improved benchmark you just added is very important as it's the main entry point for the users, but it might also hide some potential improvements to be done in the standard attention because now we are also measuring those extra overheads.

I would maybe this benchmark on top of the previous one, so that we have more numbers when measuring things together. But this is only a suggestion

@fmassa
Copy link
Contributor

fmassa commented Oct 5, 2022

BTW, if you generalize your function to support unbind, you could probably send this improvement to PyTorch proper.

The backward implementation currently lives in here in PyTorch

… avoid torch.cat in BW"


**SUMMARY**

Also:
- updated benchmarks to reflect the real-world scenario where we chunk a packed qkv - for both fw/bw.
- added coverage for chunking in tests

**PERFORMANCE IMPACT**

<details>
<summary>A100 bw (new benchmarks)</summary>

```
[---------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------]
                                     |  48_chunk3_31735f9  |  45_bwpacked_e53c5f3  |  vanilla  |  47_bwpackedgrad_9bacdf6
1 threads: --------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |          560.7      |           663.9       |   2265.7  |             710.3       
      f32 B=384, M=197, H=1, K=88    |         2445.1      |          2540.3       |   1843.3  |            2611.0       
      f16 B=384, M=197, H=1, K=80    |          530.4      |           619.9       |   1922.8  |             663.0       
      f32 B=384, M=197, H=1, K=80    |         2326.1      |          2425.2       |   1788.7  |            2476.4       
      f16 B=384, M=197, H=1, K=64    |          391.7      |           462.2       |   1812.7  |             492.8       
      f32 B=384, M=197, H=1, K=64    |         1275.0      |          1379.4       |   1675.4  |            1388.4       
      f16 B=1024, M=197, H=1, K=88   |         1399.5      |          1666.2       |   5965.2  |            1775.5       
      f32 B=1024, M=197, H=1, K=88   |         6332.5      |          6618.1       |   4559.6  |            6740.5       
      f16 B=1024, M=197, H=1, K=80   |         1326.2      |          1543.9       |   5041.4  |            1652.3       
      f32 B=1024, M=197, H=1, K=80   |         6057.1      |          6301.3       |   4411.6  |            6433.6       
      f16 B=1024, M=197, H=1, K=64   |          876.9      |          1063.1       |   4749.3  |            1133.2       
      f32 B=1024, M=197, H=1, K=64   |         3360.2      |          3629.0       |   4118.8  |            3652.0       
      f16 B=512, M=197, H=1, K=80    |          669.0      |           786.4       |   2544.9  |             842.2       
      f32 B=512, M=197, H=1, K=80    |         3032.3      |          3127.8       |   2287.4  |            3229.8       
      f16 B=32, M=197, H=16, K=80    |          663.0      |           789.7       |   2569.0  |             837.8       
      f32 B=32, M=197, H=16, K=80    |         3005.5      |          3166.3       |   2354.1  |            3225.9       
      f16 B=32, M=197, H=16, K=64    |          459.9      |           553.4       |   2436.3  |             591.9       
      f32 B=32, M=197, H=16, K=64    |         1814.1      |          1962.5       |   2197.3  |            1962.1       
      f16 B=32, M=197, H=16, K=128   |          792.5      |           981.9       |   4505.9  |            1056.5       
      f32 B=32, M=197, H=16, K=128   |         3734.8      |          3995.7       |   2805.8  |            4021.5       
      f16 B=256, M=197, H=1, K=88    |          413.4      |           482.6       |   1529.5  |             515.5       
      f32 B=256, M=197, H=1, K=88    |         1741.9      |          1818.3       |   1208.6  |            1852.4       
      f16 B=16, M=197, H=16, K=88    |          410.3      |           482.9       |   1545.7  |             512.5       
      f32 B=16, M=197, H=16, K=88    |         1734.9      |          1832.1       |   1250.6  |            1849.4       
      f16 B=16, M=197, H=16, K=64    |          235.4      |           286.0       |   1247.1  |             305.3       
      f32 B=16, M=197, H=16, K=64    |         1077.1      |          1143.7       |   1125.9  |            1154.0       
      f16 B=16, M=197, H=16, K=128   |          455.4      |           554.1       |   2273.1  |             596.0       
      f32 B=16, M=197, H=16, K=128   |         2028.9      |          2164.5       |   1446.7  |            2175.0       
      f16 B=1, M=4096, H=160, K=128  |        62454.4      |         63474.5       |  45930.5  |           64052.7       
      f32 B=1, M=4096, H=160, K=128  |       239035.4      |        232672.1       |           |          240073.9       
      f16 B=2, M=4096, H=160, K=128  |        98791.3      |        101006.4       |           |          101942.0       
      f32 B=2, M=4096, H=160, K=128  |       375914.9      |        368050.6       |           |          381280.4       
      f16 B=1, M=8192, H=160, K=128  |       248498.9      |        250066.9       |           |          251500.4       
      f32 B=1, M=8192, H=160, K=128  |       945102.2      |        922549.3       |           |          949256.4       
      f16 B=2, M=8192, H=160, K=128  |       389207.8      |        394486.6       |           |          396190.4       
      f32 B=2, M=8192, H=160, K=128  |      1496334.3      |       1449974.3       |           |         1502215.3       
      f16 B=1024, M=82, H=8, K=64    |         1872.4      |          2503.8       |   3819.8  |            2693.7       
      f32 B=1024, M=82, H=8, K=64    |         8734.3      |          9637.8       |   8732.9  |            9672.2       
      f16 B=150, M=256, H=16, K=64   |         2126.4      |          2713.4       |   4554.3  |            2880.8       
      f32 B=150, M=256, H=16, K=64   |         6214.3      |          7052.2       |  12943.2  |            7099.2       
      f16 B=64, M=256, H=12, K=64    |          741.2      |           930.1       |   1493.0  |             990.6       
      f32 B=64, M=256, H=12, K=64    |         2144.2      |          2408.5       |   4267.7  |            2433.8       
      f16 B=1, M=4096, H=16, K=40    |        24583.7      |         24224.8       |   4195.2  |           24500.2       
      f32 B=1, M=4096, H=16, K=40    |        72497.9      |         72070.8       |  17744.1  |           72393.0       
      f16 B=1, M=16384, H=16, K=40   |       451481.8      |        439027.7       |           |          451499.9       
      f32 B=1, M=16384, H=16, K=40   |      1169509.1      |       1164880.1       |           |         1169769.3       
      f16 B=256, M=4096, H=16, K=64  |       597391.6      |        625921.0       |           |          610433.2       
      f16 B=16, M=128, H=16, K=16    |           93.1      |           126.7       |    241.2  |             132.3       
      f32 B=16, M=128, H=16, K=16    |          184.1      |           176.5       |    373.8  |             180.7       
      f16 B=16, M=128, H=16, K=32    |          127.9      |           126.3       |    241.4  |             106.7       
      f32 B=16, M=128, H=16, K=32    |          194.1      |           216.6       |    412.7  |             225.8       
      f16 B=16, M=128, H=16, K=64    |          131.4      |           126.8       |    239.8  |             134.5       
      f32 B=16, M=128, H=16, K=64    |          280.4      |           326.0       |    500.0  |             334.0       
      f16 B=16, M=128, H=16, K=128   |          175.6      |           236.1       |    298.8  |             261.1       
      f32 B=16, M=128, H=16, K=128   |          531.8      |           615.8       |    677.2  |             638.0       
      f16 B=16, M=512, H=16, K=16    |          558.2      |           595.0       |   1201.9  |             607.8       
      f32 B=16, M=512, H=16, K=16    |         2146.7      |          2169.9       |   4416.1  |            2200.6       
      f16 B=16, M=512, H=16, K=32    |          653.5      |           732.3       |   1305.1  |             748.5       
      f32 B=16, M=512, H=16, K=32    |         2296.3      |          2373.9       |   4641.3  |            2400.1       
      f16 B=16, M=512, H=16, K=64    |          848.8      |           996.9       |   1544.6  |            1022.5       
      f32 B=16, M=512, H=16, K=64    |         2954.0      |          3117.1       |   5124.7  |            3157.6       
      f16 B=16, M=512, H=16, K=128   |         1735.4      |          1961.1       |   1982.7  |            2056.9       
      f32 B=16, M=512, H=16, K=128   |         6218.7      |          6396.4       |   6094.0  |            6600.3       
      f16 B=16, M=1024, H=16, K=16   |         2236.4      |          2319.4       |   4279.0  |            2331.6       
      f32 B=16, M=1024, H=16, K=16   |         8379.2      |          8363.9       |  16643.9  |            8503.6       
      f16 B=16, M=1024, H=16, K=32   |         2430.8      |          2649.6       |   4496.8  |            2608.7       
      f32 B=16, M=1024, H=16, K=32   |         8864.7      |          8907.8       |  17291.0  |            9074.0       
      f16 B=16, M=1024, H=16, K=64   |         3007.2      |          3351.3       |   4995.5  |            3351.0       
      f32 B=16, M=1024, H=16, K=64   |        11355.4      |         11627.1       |  18707.5  |           11694.3       
      f16 B=16, M=1024, H=16, K=128  |         6296.2      |          6748.7       |   5943.5  |            6967.0       
      f32 B=16, M=1024, H=16, K=128  |        23425.3      |         23360.0       |  21520.6  |           24169.7       
      f16 B=64, M=128, H=16, K=16    |          165.5      |           195.9       |    440.3  |             211.5       
      f32 B=64, M=128, H=16, K=16    |          497.4      |           540.7       |   1270.8  |             550.3       
      f16 B=64, M=128, H=16, K=32    |          210.4      |           274.9       |    544.8  |             298.5       
      f32 B=64, M=128, H=16, K=32    |          604.4      |           696.6       |   1428.3  |             710.9       
      f16 B=64, M=128, H=16, K=64    |          330.4      |           452.3       |    766.0  |             498.1       
      f32 B=64, M=128, H=16, K=64    |          883.4      |          1060.4       |   1745.2  |            1082.2       
      f16 B=64, M=128, H=16, K=128   |          605.5      |           847.8       |   1223.6  |             933.9       
      f32 B=64, M=128, H=16, K=128   |         1847.4      |          2169.7       |   2388.8  |            2236.0       
      f16 B=64, M=512, H=16, K=16    |         2004.7      |          2120.0       |   4487.0  |            2179.4       
      f32 B=64, M=512, H=16, K=16    |         6655.4      |          6818.8       |  16993.8  |            6872.1       
      f16 B=64, M=512, H=16, K=32    |         2379.3      |          2593.1       |   4957.2  |            2704.0       
      f32 B=64, M=512, H=16, K=32    |         7349.4      |          7644.6       |  17852.2  |            7736.2       
      f16 B=64, M=512, H=16, K=64    |         3129.6      |          3616.6       |   5888.8  |            3786.2       
      f32 B=64, M=512, H=16, K=64    |         9432.5      |         10123.9       |  19770.6  |           10178.5       
      f16 B=64, M=512, H=16, K=128   |         6054.1      |          7019.9       |   7712.6  |            7350.2       
      f32 B=64, M=512, H=16, K=128   |        21565.6      |         22281.9       |  23653.0  |           23084.4       
      f16 B=64, M=1024, H=16, K=16   |         7929.4      |          8199.1       |  16876.3  |            8242.5       
      f32 B=64, M=1024, H=16, K=16   |        26135.2      |         26347.9       |  66351.1  |           26639.0       
      f16 B=64, M=1024, H=16, K=32   |         8876.8      |          9450.0       |  17869.4  |            9473.5       
      f32 B=64, M=1024, H=16, K=32   |        27685.3      |         28104.6       |  69105.9  |           28428.7       
      f16 B=64, M=1024, H=16, K=64   |        11198.7      |         12180.5       |  19932.3  |           12543.4       
      f32 B=64, M=1024, H=16, K=64   |        34978.2      |         36239.4       |  74813.7  |           36482.4       
      f16 B=64, M=1024, H=16, K=128  |        21618.9      |         23439.6       |  23741.1  |           24160.1       
      f32 B=64, M=1024, H=16, K=128  |        80785.3      |         81080.8       |  86003.6  |           84132.9       

Times are in microseconds (us).
```
</details>

<details>
<summary>P100/V100 bw (new benchmarks)</summary>

```
[---------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------------------]
                                                         |  48_chunk3_31735f94  |  45_bwpacked_e53c5f3a  |  vanilla   |  47_bwpackedgrad_9bacdf65
1 threads: --------------------------------------------------------------------------------------------------------------------------------------
  (Quadro_GP100)          f16 B=384, M=197, H=1, K=88    |         6846.3       |          7583.8        |    3569.3  |            7599.5        
                          f32 B=384, M=197, H=1, K=88    |         9883.1       |         10107.2        |    4312.8  |           10486.3        
                          f16 B=384, M=197, H=1, K=80    |         6486.4       |          6997.7        |    3418.0  |            7037.3        
                          f32 B=384, M=197, H=1, K=80    |         9330.3       |          9550.6        |    4094.7  |            9893.4        
                          f16 B=384, M=197, H=1, K=64    |         3615.4       |          3930.4        |    2911.0  |            4074.2        
                          f32 B=384, M=197, H=1, K=64    |         6281.4       |          6554.5        |    3431.9  |            6738.1        
                          f16 B=1024, M=197, H=1, K=88   |        17226.8       |         18593.1        |    9733.2  |           18772.9        
                          f32 B=1024, M=197, H=1, K=88   |        26593.3       |         27136.2        |   12033.8  |           28184.2        
                          f16 B=1024, M=197, H=1, K=80   |        16330.1       |         17478.6        |    9270.2  |           17735.3        
                          f32 B=1024, M=197, H=1, K=80   |        25208.9       |         25680.1        |   11224.5  |           26636.1        
                          f16 B=1024, M=197, H=1, K=64   |         8889.1       |          9728.8        |    7646.1  |           10089.7        
                          f32 B=1024, M=197, H=1, K=64   |        16914.7       |         17743.4        |    9383.8  |           18068.4        
                          f16 B=512, M=197, H=1, K=80    |         8227.3       |          8878.4        |    4579.3  |            8953.6        
                          f32 B=512, M=197, H=1, K=80    |        13078.7       |         13346.0        |    5486.4  |           13817.6        
                          f16 B=32, M=197, H=16, K=80    |         8278.9       |          9002.9        |    4816.2  |            9025.6        
                          f32 B=32, M=197, H=16, K=80    |        12913.8       |         13371.2        |    5777.7  |           13667.6        
                          f16 B=32, M=197, H=16, K=64    |         4565.2       |          5000.0        |    4023.4  |            5146.3        
                          f32 B=32, M=197, H=16, K=64    |         8824.0       |          9257.7        |    4797.2  |            9400.5        
                          f16 B=32, M=197, H=16, K=128   |         9770.0       |         10849.7        |    5983.2  |           10932.0        
                          f32 B=32, M=197, H=16, K=128   |        15715.2       |         16559.9        |    7513.6  |           16839.9        
                          f16 B=256, M=197, H=1, K=88    |         5011.2       |          5363.8        |    2444.9  |            5426.0        
                          f32 B=256, M=197, H=1, K=88    |         6918.7       |          7040.8        |    2867.8  |            7303.2        
                          f16 B=16, M=197, H=16, K=88    |         4963.8       |          5343.9        |    2545.2  |            5398.9        
                          f32 B=16, M=197, H=16, K=88    |         6727.9       |          6981.7        |    3040.3  |            7121.2        
                          f16 B=16, M=197, H=16, K=64    |         2586.5       |          2777.1        |    2025.5  |            2905.6        
                          f32 B=16, M=197, H=16, K=64    |         4404.3       |          4607.2        |    2431.1  |            4691.8        
                          f16 B=16, M=197, H=16, K=128   |         5643.2       |          6194.1        |    3016.1  |            6216.3        
                          f32 B=16, M=197, H=16, K=128   |         7887.1       |          8308.3        |    3676.6  |            8456.2        
                          f16 B=1, M=4096, H=160, K=128  |      1087008.7       |       1115355.5        |            |         1091596.8        
                          f32 B=1, M=4096, H=160, K=128  |      1220066.8       |       1223422.8        |            |         1227912.2        
                          f16 B=2, M=4096, H=160, K=128  |      1734244.4       |       1794068.7        |            |         1756266.7        
                          f32 B=2, M=4096, H=160, K=128  |      2437675.5       |       2445780.4        |            |         2451957.5        
                          f16 B=1, M=8192, H=160, K=128  |      4367110.4       |       4466170.9        |            |         4383747.4        
                          f32 B=1, M=8192, H=160, K=128  |      4865732.9       |       4865708.9        |            |         4887066.5        
                          f16 B=2, M=8192, H=160, K=128  |      7002715.1       |       7146077.9        |            |         7033922.8        
                          f16 B=1024, M=82, H=8, K=64    |        23247.5       |         24929.5        |   18047.8  |           26928.2        
                          f32 B=1024, M=82, H=8, K=64    |        46463.2       |         48705.6        |   22797.5  |           50736.3        
                          f16 B=150, M=256, H=16, K=64   |        23467.9       |         25647.3        |   24569.2  |           26841.8        
                          f32 B=150, M=256, H=16, K=64   |        36887.7       |         39698.0        |   32050.2  |           40389.0        
                          f16 B=64, M=256, H=12, K=64    |         7723.7       |          8499.0        |    7702.1  |            8694.9        
                          f32 B=64, M=256, H=12, K=64    |        11992.1       |         12819.9        |    9874.5  |           13107.9        
                          f16 B=1, M=4096, H=16, K=40    |       142655.5       |        142899.7        |   28928.6  |          142922.7        
                          f32 B=1, M=4096, H=16, K=40    |       142626.8       |        142685.3        |   37303.2  |          142541.0        
                          f16 B=1, M=16384, H=16, K=40   |      2274095.0       |       2274882.0        |            |         2275019.9        
                          f32 B=1, M=16384, H=16, K=40   |      2284027.2       |       2279415.7        |            |         2277761.9        
                          f16 B=16, M=128, H=16, K=16    |          513.2       |           547.1        |     571.5  |             570.9        
                          f32 B=16, M=128, H=16, K=16    |          667.4       |           704.3        |     693.1  |             728.0        
                          f16 B=16, M=128, H=16, K=32    |          600.3       |           667.0        |     671.3  |             713.1        
                          f32 B=16, M=128, H=16, K=32    |          823.9       |           888.9        |     823.5  |             937.3        
                          f16 B=16, M=128, H=16, K=64    |          781.0       |           900.6        |     883.1  |             998.9        
                          f32 B=16, M=128, H=16, K=64    |         1173.7       |          1293.8        |    1077.0  |            1393.4        
                          f16 B=16, M=128, H=16, K=128   |         1649.2       |          1877.2        |    1323.2  |            2026.3        
                          f32 B=16, M=128, H=16, K=128   |         2250.5       |          2473.0        |    1654.7  |            2636.6        
                          f16 B=16, M=512, H=16, K=16    |         7709.3       |          7914.6        |    6945.1  |            7928.7        
                          f32 B=16, M=512, H=16, K=16    |         9797.2       |          9950.5        |    8499.4  |           10029.3        
                          f16 B=16, M=512, H=16, K=32    |         8956.9       |          9210.8        |    7517.1  |            9307.0        
                          f32 B=16, M=512, H=16, K=32    |        11480.7       |         11710.9        |    9249.4  |           11884.4        
                          f16 B=16, M=512, H=16, K=64    |        11324.0       |         11829.1        |    8849.5  |           12001.8        
                          f32 B=16, M=512, H=16, K=64    |        15744.1       |         16258.0        |   10954.6  |           16481.1        
                          f16 B=16, M=512, H=16, K=128   |        25320.2       |         26584.0        |   12412.3  |           26725.0        
                          f32 B=16, M=512, H=16, K=128   |        31187.1       |         32290.3        |   15167.5  |           32818.4        
                          f16 B=16, M=1024, H=16, K=16   |        31484.2       |         31601.4        |   26434.6  |           31894.6        
                          f32 B=16, M=1024, H=16, K=16   |        38754.1       |         38900.1        |   32320.0  |           39203.9        
                          f16 B=16, M=1024, H=16, K=32   |        36000.2       |         36672.6        |   28341.4  |           36579.5        
                          f32 B=16, M=1024, H=16, K=32   |        45070.7       |         45262.3        |   34914.2  |           45774.5        
                          f16 B=16, M=1024, H=16, K=64   |        45324.9       |         46540.4        |   32089.9  |           46784.2        
                          f32 B=16, M=1024, H=16, K=64   |        61320.3       |         62411.1        |   39565.0  |           63217.0        
                          f16 B=16, M=1024, H=16, K=128  |       104342.9       |        108469.4        |   43221.9  |          105620.6        
                          f32 B=16, M=1024, H=16, K=128  |       122688.4       |        125050.9        |   51205.7  |          126080.9        
                          f16 B=64, M=128, H=16, K=16    |         1707.9       |          1824.9        |    2106.4  |            1923.2        
                          f32 B=64, M=128, H=16, K=16    |         2487.4       |          2612.5        |    2565.1  |            2707.6        
                          f16 B=64, M=128, H=16, K=32    |         2016.8       |          2254.4        |    2485.4  |            2412.3        
                          f32 B=64, M=128, H=16, K=32    |         3135.8       |          3365.6        |    3063.2  |            3518.5        
                          f16 B=64, M=128, H=16, K=64    |         2700.2       |          3167.0        |    3306.0  |            3478.4        
                          f32 B=64, M=128, H=16, K=64    |         4435.1       |          4944.7        |    4227.6  |            5181.2        
                          f16 B=64, M=128, H=16, K=128   |         5769.1       |          6858.2        |    5299.8  |            7356.1        
                          f32 B=64, M=128, H=16, K=128   |         8577.9       |          9672.0        |    6916.3  |           10093.5        
                          f16 B=64, M=512, H=16, K=16    |        25994.0       |         26782.0        |   27240.9  |           26662.2        
                          f32 B=64, M=512, H=16, K=16    |        36864.9       |         37299.3        |   34159.3  |           37576.7        
                          f16 B=64, M=512, H=16, K=32    |        30680.4       |         32113.8        |   30109.0  |           32419.7        
                          f32 B=64, M=512, H=16, K=32    |        43638.5       |         44557.9        |   37358.5  |           45145.0        
                          f16 B=64, M=512, H=16, K=64    |        39417.5       |         41666.5        |   36004.2  |           42374.9        
                          f32 B=64, M=512, H=16, K=64    |        60049.2       |         63148.0        |   43412.6  |           63286.8        
                          f16 B=64, M=512, H=16, K=128   |        88951.1       |         93087.0        |   51730.1  |           94861.6        
                          f32 B=64, M=512, H=16, K=128   |       119728.7       |        124340.3        |   62413.7  |          126382.2        
                          f16 B=64, M=1024, H=16, K=16   |       108368.3       |        111081.8        |  106479.7  |          108716.1        
                          f32 B=64, M=1024, H=16, K=16   |       145612.0       |        147310.4        |            |          147380.7        
                          f16 B=64, M=1024, H=16, K=32   |       124296.1       |        127366.8        |  113905.0  |          126975.3        
                          f32 B=64, M=1024, H=16, K=32   |       171082.3       |        172539.0        |            |          173893.9        
                          f16 B=64, M=1024, H=16, K=64   |       155116.3       |        160429.2        |  130759.4  |          161834.0        
                          f32 B=64, M=1024, H=16, K=64   |       234356.0       |        239612.2        |            |          239948.3        
                          f16 B=64, M=1024, H=16, K=128  |       349728.3       |        360975.7        |  176158.7  |          371185.2        
                          f32 B=64, M=1024, H=16, K=128  |       468810.0       |        476415.4        |            |          481908.5        
  (Tesla_V100_SXM2_16GB)  f16 B=384, M=197, H=1, K=88    |         1700.3       |          1840.0        |    1375.3  |            1930.9        
                          f32 B=384, M=197, H=1, K=88    |         4456.4       |          4579.3        |    2235.5  |            4708.6        
                          f16 B=384, M=197, H=1, K=80    |         1623.3       |          1719.9        |    1279.5  |            1806.9        
                          f32 B=384, M=197, H=1, K=80    |         4031.2       |          4141.9        |    2149.8  |            4252.6        
                          f16 B=384, M=197, H=1, K=64    |         1092.8       |          1187.0        |    1048.5  |            1237.6        
                          f32 B=384, M=197, H=1, K=64    |         2717.5       |          2918.5        |    1738.5  |            2907.9        
                          f16 B=1024, M=197, H=1, K=88   |         4428.7       |          4906.2        |    3723.7  |            5178.2        
                          f32 B=1024, M=197, H=1, K=88   |        10947.5       |         11362.9        |    6052.5  |           11802.1        
                          f16 B=1024, M=197, H=1, K=80   |         4237.1       |          4491.4        |    3331.7  |            4725.6        
                          f32 B=1024, M=197, H=1, K=80   |         9842.6       |         10159.7        |    5682.4  |           10435.6        
                          f16 B=1024, M=197, H=1, K=64   |         2679.2       |          2927.4        |    2674.4  |            3033.0        
                          f32 B=1024, M=197, H=1, K=64   |         6597.6       |          7154.9        |    4489.7  |            7063.1        
                          f16 B=512, M=197, H=1, K=80    |         2239.5       |          2366.5        |    1684.2  |            2472.0        
                          f32 B=512, M=197, H=1, K=80    |         5362.4       |          5519.6        |    2857.9  |            5651.4        
                          f16 B=32, M=197, H=16, K=80    |         2208.1       |          2380.0        |    1803.4  |            2439.4        
                          f32 B=32, M=197, H=16, K=80    |         5503.6       |          5736.7        |    3017.5  |            5796.2        
                          f16 B=32, M=197, H=16, K=64    |         1493.4       |          1620.6        |    1457.2  |            1678.6        
                          f32 B=32, M=197, H=16, K=64    |         3672.6       |          3941.6        |    2415.0  |            3898.2        
                          f16 B=32, M=197, H=16, K=128   |         2634.3       |          2888.0        |    2215.1  |            2991.5        
                          f32 B=32, M=197, H=16, K=128   |         6811.5       |          7334.0        |    4049.3  |            7261.9        
                          f16 B=256, M=197, H=1, K=88    |         1290.3       |          1382.0        |     944.8  |            1449.4        
                          f32 B=256, M=197, H=1, K=88    |         2965.8       |          3043.2        |    1528.7  |            3137.7        
                          f16 B=16, M=197, H=16, K=88    |         1267.3       |          1357.0        |     970.8  |            1395.5        
                          f32 B=16, M=197, H=16, K=88    |         2879.9       |          3014.7        |    1626.5  |            3054.3        
                          f16 B=16, M=197, H=16, K=64    |          737.3       |           799.8        |     771.3  |             836.9        
                          f32 B=16, M=197, H=16, K=64    |         1879.2       |          2000.9        |    1282.5  |            1994.5        
                          f16 B=16, M=197, H=16, K=128   |         1443.9       |          1570.7        |    1142.2  |            1628.8        
                          f32 B=16, M=197, H=16, K=128   |         3480.5       |          3723.6        |    2027.2  |            3714.6        
                          f16 B=1, M=4096, H=160, K=128  |       150006.2       |        151877.5        |            |          152570.6        
                          f32 B=1, M=4096, H=160, K=128  |       582870.9       |        583519.8        |            |          585570.1        
                          f16 B=2, M=4096, H=160, K=128  |       301231.4       |        304511.7        |            |          305801.2        
                          f32 B=2, M=4096, H=160, K=128  |      1174724.1       |       1172498.4        |            |         1176814.0        
                          f16 B=1, M=8192, H=160, K=128  |       597461.6       |        600463.4        |            |          603066.6        
                          f32 B=1, M=8192, H=160, K=128  |      2333657.8       |       2329212.1        |            |         2339766.1        
                          f16 B=2, M=8192, H=160, K=128  |      1196837.5       |       1206932.4        |            |         1209012.2        
                          f16 B=1024, M=82, H=8, K=64    |         8926.8       |          9723.4        |    5799.4  |           10084.2        
                          f32 B=1024, M=82, H=8, K=64    |        15920.4       |         17434.4        |   11027.0  |           17492.8        
                          f16 B=150, M=256, H=16, K=64   |         5524.2       |          6363.9        |    7557.9  |            6586.2        
                          f32 B=150, M=256, H=16, K=64   |        17506.9       |         18843.5        |   16263.5  |           18988.6        
                          f16 B=64, M=256, H=12, K=64    |         1800.6       |          2050.3        |    2383.4  |            2139.0        
                          f32 B=64, M=256, H=12, K=64    |         5753.6       |          6196.3        |    4971.2  |            6200.0        
                          f16 B=1, M=4096, H=16, K=40    |        47649.5       |         47836.0        |    8368.4  |           47973.6        
                          f32 B=1, M=4096, H=16, K=40    |       111092.1       |        111027.3        |   19475.9  |          111257.8        
                          f16 B=1, M=16384, H=16, K=40   |       765320.2       |        765686.9        |            |          767337.2        
                          f32 B=1, M=16384, H=16, K=40   |      1769169.0       |       1769675.1        |            |         1769371.4        
                          f16 B=16, M=128, H=16, K=16    |          178.9       |           196.8        |     445.9  |             188.3        
                          f32 B=16, M=128, H=16, K=16    |          301.3       |           319.1        |     422.5  |             336.3        
                          f16 B=16, M=128, H=16, K=32    |          174.1       |           174.2        |     394.0  |             179.5        
                          f32 B=16, M=128, H=16, K=32    |          395.7       |           433.2        |     580.0  |             440.4        
                          f16 B=16, M=128, H=16, K=64    |          205.0       |           253.5        |     460.6  |             270.9        
                          f32 B=16, M=128, H=16, K=64    |          573.7       |           639.3        |     598.1  |             656.1        
                          f16 B=16, M=128, H=16, K=128   |          399.5       |           484.3        |     515.2  |             521.8        
                          f32 B=16, M=128, H=16, K=128   |         1126.3       |          1260.8        |    1008.1  |            1282.4        
                          f16 B=16, M=512, H=16, K=16    |         1597.6       |          1627.2        |    1901.1  |            1662.1        
                          f32 B=16, M=512, H=16, K=16    |         4458.5       |          4528.8        |    4232.0  |            4559.4        
                          f16 B=16, M=512, H=16, K=32    |         1819.1       |          1868.7        |    2097.2  |            1945.5        
                          f32 B=16, M=512, H=16, K=32    |         5604.2       |          5757.1        |    4566.4  |            5784.8        
                          f16 B=16, M=512, H=16, K=64    |         2345.5       |          2495.6        |    2558.0  |            2573.2        
                          f32 B=16, M=512, H=16, K=64    |         7778.3       |          8017.1        |    5488.2  |            8083.7        
                          f16 B=16, M=512, H=16, K=128   |         4516.6       |          4821.0        |    3386.7  |            4968.2        
                          f32 B=16, M=512, H=16, K=128   |        15412.7       |         15959.2        |    8865.9  |           16047.5        
                          f16 B=16, M=1024, H=16, K=16   |         6195.9       |          6217.6        |    6995.3  |            6326.4        
                          f32 B=16, M=1024, H=16, K=16   |        18136.2       |         18312.0        |   16088.2  |           18354.1        
                          f16 B=16, M=1024, H=16, K=32   |         7072.8       |          7122.3        |    7406.9  |            7297.7        
                          f32 B=16, M=1024, H=16, K=32   |        22108.2       |         22116.7        |   17112.5  |           22436.8        
                          f16 B=16, M=1024, H=16, K=64   |         8868.0       |          9104.6        |    8627.1  |            9311.8        
                          f32 B=16, M=1024, H=16, K=64   |        30710.5       |         31041.3        |   19860.8  |           31338.1        
                          f16 B=16, M=1024, H=16, K=128  |        17091.8       |         17655.5        |   10548.3  |           18083.8        
                          f32 B=16, M=1024, H=16, K=128  |        60317.8       |         61461.7        |   32919.2  |           61548.8        
                          f16 B=64, M=128, H=16, K=16    |          413.6       |           453.8        |     635.5  |             480.6        
                          f32 B=64, M=128, H=16, K=16    |         1033.8       |          1114.3        |    1238.9  |            1119.5        
                          f16 B=64, M=128, H=16, K=32    |          505.7       |           587.9        |     813.6  |             630.1        
                          f32 B=64, M=128, H=16, K=32    |         1423.0       |          1551.4        |    1533.4  |            1581.8        
                          f16 B=64, M=128, H=16, K=64    |          743.3       |           916.8        |    1187.7  |             976.5        
                          f32 B=64, M=128, H=16, K=64    |         2093.3       |          2384.6        |    2156.3  |            2405.4        
                          f16 B=64, M=128, H=16, K=128   |         1408.2       |          1734.3        |    1918.7  |            1859.6        
                          f32 B=64, M=128, H=16, K=128   |         4125.3       |          4671.4        |    3762.0  |            4717.0        
                          f16 B=64, M=512, H=16, K=16    |         5531.2       |          5643.3        |    7454.4  |            5770.8        
                          f32 B=64, M=512, H=16, K=16    |        16214.0       |         16531.2        |   16661.3  |           16540.8        
                          f16 B=64, M=512, H=16, K=32    |         6495.5       |          6725.2        |    8353.7  |            6941.8        
                          f32 B=64, M=512, H=16, K=32    |        20520.6       |         20941.9        |   18352.4  |           21116.8        
                          f16 B=64, M=512, H=16, K=64    |         8686.1       |          9278.6        |   10343.4  |            9593.2        
                          f32 B=64, M=512, H=16, K=64    |        28891.1       |         30003.0        |   22749.4  |           30139.1        
                          f16 B=64, M=512, H=16, K=128   |        15991.4       |         17412.3        |   14633.0  |           17848.2        
                          f32 B=64, M=512, H=16, K=128   |        57526.8       |         59970.8        |   40089.9  |           60016.9        
                          f16 B=64, M=1024, H=16, K=16   |        21552.8       |         21603.1        |   28447.1  |           22030.0        
                          f32 B=64, M=1024, H=16, K=16   |        65321.2       |         65736.8        |            |           65932.0        
                          f16 B=64, M=1024, H=16, K=32   |        25695.4       |         25905.9        |   30592.1  |           26644.8        
                          f32 B=64, M=1024, H=16, K=32   |        80213.4       |         80446.7        |            |           81363.1        
                          f16 B=64, M=1024, H=16, K=64   |        32465.6       |         33575.1        |   37233.4  |           34370.8        
                          f32 B=64, M=1024, H=16, K=64   |       112996.7       |        115632.0        |            |          115970.8        
                          f16 B=64, M=1024, H=16, K=128  |        60363.5       |         62800.2        |   48883.7  |           64505.1        
                          f32 B=64, M=1024, H=16, K=128  |       225023.4       |        230527.4        |            |          229851.8        

Times are in microseconds (us).
```
</details>

[ghstack-poisoned]
danthe3rd pushed a commit that referenced this pull request Oct 5, 2022
ghstack-source-id: b747ec522ded39beee04a77dbe70238877b0245b
Pull Request resolved: #458
**SUMMARY**

Also:
- updated benchmarks to reflect the real-world scenario where we chunk a packed qkv - for both fw/bw.
- added coverage for chunking in tests

**PERFORMANCE IMPACT**

<details>
<summary>A100 bw (new benchmarks)</summary>

```
[---------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------]
                                     |  48_chunk3_31735f9  |  45_bwpacked_e53c5f3  |  vanilla  |  47_bwpackedgrad_9bacdf6
1 threads: --------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |          560.7      |           663.9       |   2265.7  |             710.3       
      f32 B=384, M=197, H=1, K=88    |         2445.1      |          2540.3       |   1843.3  |            2611.0       
      f16 B=384, M=197, H=1, K=80    |          530.4      |           619.9       |   1922.8  |             663.0       
      f32 B=384, M=197, H=1, K=80    |         2326.1      |          2425.2       |   1788.7  |            2476.4       
      f16 B=384, M=197, H=1, K=64    |          391.7      |           462.2       |   1812.7  |             492.8       
      f32 B=384, M=197, H=1, K=64    |         1275.0      |          1379.4       |   1675.4  |            1388.4       
      f16 B=1024, M=197, H=1, K=88   |         1399.5      |          1666.2       |   5965.2  |            1775.5       
      f32 B=1024, M=197, H=1, K=88   |         6332.5      |          6618.1       |   4559.6  |            6740.5       
      f16 B=1024, M=197, H=1, K=80   |         1326.2      |          1543.9       |   5041.4  |            1652.3       
      f32 B=1024, M=197, H=1, K=80   |         6057.1      |          6301.3       |   4411.6  |            6433.6       
      f16 B=1024, M=197, H=1, K=64   |          876.9      |          1063.1       |   4749.3  |            1133.2       
      f32 B=1024, M=197, H=1, K=64   |         3360.2      |          3629.0       |   4118.8  |            3652.0       
      f16 B=512, M=197, H=1, K=80    |          669.0      |           786.4       |   2544.9  |             842.2       
      f32 B=512, M=197, H=1, K=80    |         3032.3      |          3127.8       |   2287.4  |            3229.8       
      f16 B=32, M=197, H=16, K=80    |          663.0      |           789.7       |   2569.0  |             837.8       
      f32 B=32, M=197, H=16, K=80    |         3005.5      |          3166.3       |   2354.1  |            3225.9       
      f16 B=32, M=197, H=16, K=64    |          459.9      |           553.4       |   2436.3  |             591.9       
      f32 B=32, M=197, H=16, K=64    |         1814.1      |          1962.5       |   2197.3  |            1962.1       
      f16 B=32, M=197, H=16, K=128   |          792.5      |           981.9       |   4505.9  |            1056.5       
      f32 B=32, M=197, H=16, K=128   |         3734.8      |          3995.7       |   2805.8  |            4021.5       
      f16 B=256, M=197, H=1, K=88    |          413.4      |           482.6       |   1529.5  |             515.5       
      f32 B=256, M=197, H=1, K=88    |         1741.9      |          1818.3       |   1208.6  |            1852.4       
      f16 B=16, M=197, H=16, K=88    |          410.3      |           482.9       |   1545.7  |             512.5       
      f32 B=16, M=197, H=16, K=88    |         1734.9      |          1832.1       |   1250.6  |            1849.4       
      f16 B=16, M=197, H=16, K=64    |          235.4      |           286.0       |   1247.1  |             305.3       
      f32 B=16, M=197, H=16, K=64    |         1077.1      |          1143.7       |   1125.9  |            1154.0       
      f16 B=16, M=197, H=16, K=128   |          455.4      |           554.1       |   2273.1  |             596.0       
      f32 B=16, M=197, H=16, K=128   |         2028.9      |          2164.5       |   1446.7  |            2175.0       
      f16 B=1, M=4096, H=160, K=128  |        62454.4      |         63474.5       |  45930.5  |           64052.7       
      f32 B=1, M=4096, H=160, K=128  |       239035.4      |        232672.1       |           |          240073.9       
      f16 B=2, M=4096, H=160, K=128  |        98791.3      |        101006.4       |           |          101942.0       
      f32 B=2, M=4096, H=160, K=128  |       375914.9      |        368050.6       |           |          381280.4       
      f16 B=1, M=8192, H=160, K=128  |       248498.9      |        250066.9       |           |          251500.4       
      f32 B=1, M=8192, H=160, K=128  |       945102.2      |        922549.3       |           |          949256.4       
      f16 B=2, M=8192, H=160, K=128  |       389207.8      |        394486.6       |           |          396190.4       
      f32 B=2, M=8192, H=160, K=128  |      1496334.3      |       1449974.3       |           |         1502215.3       
      f16 B=1024, M=82, H=8, K=64    |         1872.4      |          2503.8       |   3819.8  |            2693.7       
      f32 B=1024, M=82, H=8, K=64    |         8734.3      |          9637.8       |   8732.9  |            9672.2       
      f16 B=150, M=256, H=16, K=64   |         2126.4      |          2713.4       |   4554.3  |            2880.8       
      f32 B=150, M=256, H=16, K=64   |         6214.3      |          7052.2       |  12943.2  |            7099.2       
      f16 B=64, M=256, H=12, K=64    |          741.2      |           930.1       |   1493.0  |             990.6       
      f32 B=64, M=256, H=12, K=64    |         2144.2      |          2408.5       |   4267.7  |            2433.8       
      f16 B=1, M=4096, H=16, K=40    |        24583.7      |         24224.8       |   4195.2  |           24500.2       
      f32 B=1, M=4096, H=16, K=40    |        72497.9      |         72070.8       |  17744.1  |           72393.0       
      f16 B=1, M=16384, H=16, K=40   |       451481.8      |        439027.7       |           |          451499.9       
      f32 B=1, M=16384, H=16, K=40   |      1169509.1      |       1164880.1       |           |         1169769.3       
      f16 B=256, M=4096, H=16, K=64  |       597391.6      |        625921.0       |           |          610433.2       
      f16 B=16, M=128, H=16, K=16    |           93.1      |           126.7       |    241.2  |             132.3       
      f32 B=16, M=128, H=16, K=16    |          184.1      |           176.5       |    373.8  |             180.7       
      f16 B=16, M=128, H=16, K=32    |          127.9      |           126.3       |    241.4  |             106.7       
      f32 B=16, M=128, H=16, K=32    |          194.1      |           216.6       |    412.7  |             225.8       
      f16 B=16, M=128, H=16, K=64    |          131.4      |           126.8       |    239.8  |             134.5       
      f32 B=16, M=128, H=16, K=64    |          280.4      |           326.0       |    500.0  |             334.0       
      f16 B=16, M=128, H=16, K=128   |          175.6      |           236.1       |    298.8  |             261.1       
      f32 B=16, M=128, H=16, K=128   |          531.8      |           615.8       |    677.2  |             638.0       
      f16 B=16, M=512, H=16, K=16    |          558.2      |           595.0       |   1201.9  |             607.8       
      f32 B=16, M=512, H=16, K=16    |         2146.7      |          2169.9       |   4416.1  |            2200.6       
      f16 B=16, M=512, H=16, K=32    |          653.5      |           732.3       |   1305.1  |             748.5       
      f32 B=16, M=512, H=16, K=32    |         2296.3      |          2373.9       |   4641.3  |            2400.1       
      f16 B=16, M=512, H=16, K=64    |          848.8      |           996.9       |   1544.6  |            1022.5       
      f32 B=16, M=512, H=16, K=64    |         2954.0      |          3117.1       |   5124.7  |            3157.6       
      f16 B=16, M=512, H=16, K=128   |         1735.4      |          1961.1       |   1982.7  |            2056.9       
      f32 B=16, M=512, H=16, K=128   |         6218.7      |          6396.4       |   6094.0  |            6600.3       
      f16 B=16, M=1024, H=16, K=16   |         2236.4      |          2319.4       |   4279.0  |            2331.6       
      f32 B=16, M=1024, H=16, K=16   |         8379.2      |          8363.9       |  16643.9  |            8503.6       
      f16 B=16, M=1024, H=16, K=32   |         2430.8      |          2649.6       |   4496.8  |            2608.7       
      f32 B=16, M=1024, H=16, K=32   |         8864.7      |          8907.8       |  17291.0  |            9074.0       
      f16 B=16, M=1024, H=16, K=64   |         3007.2      |          3351.3       |   4995.5  |            3351.0       
      f32 B=16, M=1024, H=16, K=64   |        11355.4      |         11627.1       |  18707.5  |           11694.3       
      f16 B=16, M=1024, H=16, K=128  |         6296.2      |          6748.7       |   5943.5  |            6967.0       
      f32 B=16, M=1024, H=16, K=128  |        23425.3      |         23360.0       |  21520.6  |           24169.7       
      f16 B=64, M=128, H=16, K=16    |          165.5      |           195.9       |    440.3  |             211.5       
      f32 B=64, M=128, H=16, K=16    |          497.4      |           540.7       |   1270.8  |             550.3       
      f16 B=64, M=128, H=16, K=32    |          210.4      |           274.9       |    544.8  |             298.5       
      f32 B=64, M=128, H=16, K=32    |          604.4      |           696.6       |   1428.3  |             710.9       
      f16 B=64, M=128, H=16, K=64    |          330.4      |           452.3       |    766.0  |             498.1       
      f32 B=64, M=128, H=16, K=64    |          883.4      |          1060.4       |   1745.2  |            1082.2       
      f16 B=64, M=128, H=16, K=128   |          605.5      |           847.8       |   1223.6  |             933.9       
      f32 B=64, M=128, H=16, K=128   |         1847.4      |          2169.7       |   2388.8  |            2236.0       
      f16 B=64, M=512, H=16, K=16    |         2004.7      |          2120.0       |   4487.0  |            2179.4       
      f32 B=64, M=512, H=16, K=16    |         6655.4      |          6818.8       |  16993.8  |            6872.1       
      f16 B=64, M=512, H=16, K=32    |         2379.3      |          2593.1       |   4957.2  |            2704.0       
      f32 B=64, M=512, H=16, K=32    |         7349.4      |          7644.6       |  17852.2  |            7736.2       
      f16 B=64, M=512, H=16, K=64    |         3129.6      |          3616.6       |   5888.8  |            3786.2       
      f32 B=64, M=512, H=16, K=64    |         9432.5      |         10123.9       |  19770.6  |           10178.5       
      f16 B=64, M=512, H=16, K=128   |         6054.1      |          7019.9       |   7712.6  |            7350.2       
      f32 B=64, M=512, H=16, K=128   |        21565.6      |         22281.9       |  23653.0  |           23084.4       
      f16 B=64, M=1024, H=16, K=16   |         7929.4      |          8199.1       |  16876.3  |            8242.5       
      f32 B=64, M=1024, H=16, K=16   |        26135.2      |         26347.9       |  66351.1  |           26639.0       
      f16 B=64, M=1024, H=16, K=32   |         8876.8      |          9450.0       |  17869.4  |            9473.5       
      f32 B=64, M=1024, H=16, K=32   |        27685.3      |         28104.6       |  69105.9  |           28428.7       
      f16 B=64, M=1024, H=16, K=64   |        11198.7      |         12180.5       |  19932.3  |           12543.4       
      f32 B=64, M=1024, H=16, K=64   |        34978.2      |         36239.4       |  74813.7  |           36482.4       
      f16 B=64, M=1024, H=16, K=128  |        21618.9      |         23439.6       |  23741.1  |           24160.1       
      f32 B=64, M=1024, H=16, K=128  |        80785.3      |         81080.8       |  86003.6  |           84132.9       

Times are in microseconds (us).
```
</details>

<details>
<summary>P100/V100 bw (new benchmarks)</summary>

```
[---------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------------------]
                                                         |  48_chunk3_31735f94  |  45_bwpacked_e53c5f3a  |  vanilla   |  47_bwpackedgrad_9bacdf65
1 threads: --------------------------------------------------------------------------------------------------------------------------------------
  (Quadro_GP100)          f16 B=384, M=197, H=1, K=88    |         6846.3       |          7583.8        |    3569.3  |            7599.5        
                          f32 B=384, M=197, H=1, K=88    |         9883.1       |         10107.2        |    4312.8  |           10486.3        
                          f16 B=384, M=197, H=1, K=80    |         6486.4       |          6997.7        |    3418.0  |            7037.3        
                          f32 B=384, M=197, H=1, K=80    |         9330.3       |          9550.6        |    4094.7  |            9893.4        
                          f16 B=384, M=197, H=1, K=64    |         3615.4       |          3930.4        |    2911.0  |            4074.2        
                          f32 B=384, M=197, H=1, K=64    |         6281.4       |          6554.5        |    3431.9  |            6738.1        
                          f16 B=1024, M=197, H=1, K=88   |        17226.8       |         18593.1        |    9733.2  |           18772.9        
                          f32 B=1024, M=197, H=1, K=88   |        26593.3       |         27136.2        |   12033.8  |           28184.2        
                          f16 B=1024, M=197, H=1, K=80   |        16330.1       |         17478.6        |    9270.2  |           17735.3        
                          f32 B=1024, M=197, H=1, K=80   |        25208.9       |         25680.1        |   11224.5  |           26636.1        
                          f16 B=1024, M=197, H=1, K=64   |         8889.1       |          9728.8        |    7646.1  |           10089.7        
                          f32 B=1024, M=197, H=1, K=64   |        16914.7       |         17743.4        |    9383.8  |           18068.4        
                          f16 B=512, M=197, H=1, K=80    |         8227.3       |          8878.4        |    4579.3  |            8953.6        
                          f32 B=512, M=197, H=1, K=80    |        13078.7       |         13346.0        |    5486.4  |           13817.6        
                          f16 B=32, M=197, H=16, K=80    |         8278.9       |          9002.9        |    4816.2  |            9025.6        
                          f32 B=32, M=197, H=16, K=80    |        12913.8       |         13371.2        |    5777.7  |           13667.6        
                          f16 B=32, M=197, H=16, K=64    |         4565.2       |          5000.0        |    4023.4  |            5146.3        
                          f32 B=32, M=197, H=16, K=64    |         8824.0       |          9257.7        |    4797.2  |            9400.5        
                          f16 B=32, M=197, H=16, K=128   |         9770.0       |         10849.7        |    5983.2  |           10932.0        
                          f32 B=32, M=197, H=16, K=128   |        15715.2       |         16559.9        |    7513.6  |           16839.9        
                          f16 B=256, M=197, H=1, K=88    |         5011.2       |          5363.8        |    2444.9  |            5426.0        
                          f32 B=256, M=197, H=1, K=88    |         6918.7       |          7040.8        |    2867.8  |            7303.2        
                          f16 B=16, M=197, H=16, K=88    |         4963.8       |          5343.9        |    2545.2  |            5398.9        
                          f32 B=16, M=197, H=16, K=88    |         6727.9       |          6981.7        |    3040.3  |            7121.2        
                          f16 B=16, M=197, H=16, K=64    |         2586.5       |          2777.1        |    2025.5  |            2905.6        
                          f32 B=16, M=197, H=16, K=64    |         4404.3       |          4607.2        |    2431.1  |            4691.8        
                          f16 B=16, M=197, H=16, K=128   |         5643.2       |          6194.1        |    3016.1  |            6216.3        
                          f32 B=16, M=197, H=16, K=128   |         7887.1       |          8308.3        |    3676.6  |            8456.2        
                          f16 B=1, M=4096, H=160, K=128  |      1087008.7       |       1115355.5        |            |         1091596.8        
                          f32 B=1, M=4096, H=160, K=128  |      1220066.8       |       1223422.8        |            |         1227912.2        
                          f16 B=2, M=4096, H=160, K=128  |      1734244.4       |       1794068.7        |            |         1756266.7        
                          f32 B=2, M=4096, H=160, K=128  |      2437675.5       |       2445780.4        |            |         2451957.5        
                          f16 B=1, M=8192, H=160, K=128  |      4367110.4       |       4466170.9        |            |         4383747.4        
                          f32 B=1, M=8192, H=160, K=128  |      4865732.9       |       4865708.9        |            |         4887066.5        
                          f16 B=2, M=8192, H=160, K=128  |      7002715.1       |       7146077.9        |            |         7033922.8        
                          f16 B=1024, M=82, H=8, K=64    |        23247.5       |         24929.5        |   18047.8  |           26928.2        
                          f32 B=1024, M=82, H=8, K=64    |        46463.2       |         48705.6        |   22797.5  |           50736.3        
                          f16 B=150, M=256, H=16, K=64   |        23467.9       |         25647.3        |   24569.2  |           26841.8        
                          f32 B=150, M=256, H=16, K=64   |        36887.7       |         39698.0        |   32050.2  |           40389.0        
                          f16 B=64, M=256, H=12, K=64    |         7723.7       |          8499.0        |    7702.1  |            8694.9        
                          f32 B=64, M=256, H=12, K=64    |        11992.1       |         12819.9        |    9874.5  |           13107.9        
                          f16 B=1, M=4096, H=16, K=40    |       142655.5       |        142899.7        |   28928.6  |          142922.7        
                          f32 B=1, M=4096, H=16, K=40    |       142626.8       |        142685.3        |   37303.2  |          142541.0        
                          f16 B=1, M=16384, H=16, K=40   |      2274095.0       |       2274882.0        |            |         2275019.9        
                          f32 B=1, M=16384, H=16, K=40   |      2284027.2       |       2279415.7        |            |         2277761.9        
                          f16 B=16, M=128, H=16, K=16    |          513.2       |           547.1        |     571.5  |             570.9        
                          f32 B=16, M=128, H=16, K=16    |          667.4       |           704.3        |     693.1  |             728.0        
                          f16 B=16, M=128, H=16, K=32    |          600.3       |           667.0        |     671.3  |             713.1        
                          f32 B=16, M=128, H=16, K=32    |          823.9       |           888.9        |     823.5  |             937.3        
                          f16 B=16, M=128, H=16, K=64    |          781.0       |           900.6        |     883.1  |             998.9        
                          f32 B=16, M=128, H=16, K=64    |         1173.7       |          1293.8        |    1077.0  |            1393.4        
                          f16 B=16, M=128, H=16, K=128   |         1649.2       |          1877.2        |    1323.2  |            2026.3        
                          f32 B=16, M=128, H=16, K=128   |         2250.5       |          2473.0        |    1654.7  |            2636.6        
                          f16 B=16, M=512, H=16, K=16    |         7709.3       |          7914.6        |    6945.1  |            7928.7        
                          f32 B=16, M=512, H=16, K=16    |         9797.2       |          9950.5        |    8499.4  |           10029.3        
                          f16 B=16, M=512, H=16, K=32    |         8956.9       |          9210.8        |    7517.1  |            9307.0        
                          f32 B=16, M=512, H=16, K=32    |        11480.7       |         11710.9        |    9249.4  |           11884.4        
                          f16 B=16, M=512, H=16, K=64    |        11324.0       |         11829.1        |    8849.5  |           12001.8        
                          f32 B=16, M=512, H=16, K=64    |        15744.1       |         16258.0        |   10954.6  |           16481.1        
                          f16 B=16, M=512, H=16, K=128   |        25320.2       |         26584.0        |   12412.3  |           26725.0        
                          f32 B=16, M=512, H=16, K=128   |        31187.1       |         32290.3        |   15167.5  |           32818.4        
                          f16 B=16, M=1024, H=16, K=16   |        31484.2       |         31601.4        |   26434.6  |           31894.6        
                          f32 B=16, M=1024, H=16, K=16   |        38754.1       |         38900.1        |   32320.0  |           39203.9        
                          f16 B=16, M=1024, H=16, K=32   |        36000.2       |         36672.6        |   28341.4  |           36579.5        
                          f32 B=16, M=1024, H=16, K=32   |        45070.7       |         45262.3        |   34914.2  |           45774.5        
                          f16 B=16, M=1024, H=16, K=64   |        45324.9       |         46540.4        |   32089.9  |           46784.2        
                          f32 B=16, M=1024, H=16, K=64   |        61320.3       |         62411.1        |   39565.0  |           63217.0        
                          f16 B=16, M=1024, H=16, K=128  |       104342.9       |        108469.4        |   43221.9  |          105620.6        
                          f32 B=16, M=1024, H=16, K=128  |       122688.4       |        125050.9        |   51205.7  |          126080.9        
                          f16 B=64, M=128, H=16, K=16    |         1707.9       |          1824.9        |    2106.4  |            1923.2        
                          f32 B=64, M=128, H=16, K=16    |         2487.4       |          2612.5        |    2565.1  |            2707.6        
                          f16 B=64, M=128, H=16, K=32    |         2016.8       |          2254.4        |    2485.4  |            2412.3        
                          f32 B=64, M=128, H=16, K=32    |         3135.8       |          3365.6        |    3063.2  |            3518.5        
                          f16 B=64, M=128, H=16, K=64    |         2700.2       |          3167.0        |    3306.0  |            3478.4        
                          f32 B=64, M=128, H=16, K=64    |         4435.1       |          4944.7        |    4227.6  |            5181.2        
                          f16 B=64, M=128, H=16, K=128   |         5769.1       |          6858.2        |    5299.8  |            7356.1        
                          f32 B=64, M=128, H=16, K=128   |         8577.9       |          9672.0        |    6916.3  |           10093.5        
                          f16 B=64, M=512, H=16, K=16    |        25994.0       |         26782.0        |   27240.9  |           26662.2        
                          f32 B=64, M=512, H=16, K=16    |        36864.9       |         37299.3        |   34159.3  |           37576.7        
                          f16 B=64, M=512, H=16, K=32    |        30680.4       |         32113.8        |   30109.0  |           32419.7        
                          f32 B=64, M=512, H=16, K=32    |        43638.5       |         44557.9        |   37358.5  |           45145.0        
                          f16 B=64, M=512, H=16, K=64    |        39417.5       |         41666.5        |   36004.2  |           42374.9        
                          f32 B=64, M=512, H=16, K=64    |        60049.2       |         63148.0        |   43412.6  |           63286.8        
                          f16 B=64, M=512, H=16, K=128   |        88951.1       |         93087.0        |   51730.1  |           94861.6        
                          f32 B=64, M=512, H=16, K=128   |       119728.7       |        124340.3        |   62413.7  |          126382.2        
                          f16 B=64, M=1024, H=16, K=16   |       108368.3       |        111081.8        |  106479.7  |          108716.1        
                          f32 B=64, M=1024, H=16, K=16   |       145612.0       |        147310.4        |            |          147380.7        
                          f16 B=64, M=1024, H=16, K=32   |       124296.1       |        127366.8        |  113905.0  |          126975.3        
                          f32 B=64, M=1024, H=16, K=32   |       171082.3       |        172539.0        |            |          173893.9        
                          f16 B=64, M=1024, H=16, K=64   |       155116.3       |        160429.2        |  130759.4  |          161834.0        
                          f32 B=64, M=1024, H=16, K=64   |       234356.0       |        239612.2        |            |          239948.3        
                          f16 B=64, M=1024, H=16, K=128  |       349728.3       |        360975.7        |  176158.7  |          371185.2        
                          f32 B=64, M=1024, H=16, K=128  |       468810.0       |        476415.4        |            |          481908.5        
  (Tesla_V100_SXM2_16GB)  f16 B=384, M=197, H=1, K=88    |         1700.3       |          1840.0        |    1375.3  |            1930.9        
                          f32 B=384, M=197, H=1, K=88    |         4456.4       |          4579.3        |    2235.5  |            4708.6        
                          f16 B=384, M=197, H=1, K=80    |         1623.3       |          1719.9        |    1279.5  |            1806.9        
                          f32 B=384, M=197, H=1, K=80    |         4031.2       |          4141.9        |    2149.8  |            4252.6        
                          f16 B=384, M=197, H=1, K=64    |         1092.8       |          1187.0        |    1048.5  |            1237.6        
                          f32 B=384, M=197, H=1, K=64    |         2717.5       |          2918.5        |    1738.5  |            2907.9        
                          f16 B=1024, M=197, H=1, K=88   |         4428.7       |          4906.2        |    3723.7  |            5178.2        
                          f32 B=1024, M=197, H=1, K=88   |        10947.5       |         11362.9        |    6052.5  |           11802.1        
                          f16 B=1024, M=197, H=1, K=80   |         4237.1       |          4491.4        |    3331.7  |            4725.6        
                          f32 B=1024, M=197, H=1, K=80   |         9842.6       |         10159.7        |    5682.4  |           10435.6        
                          f16 B=1024, M=197, H=1, K=64   |         2679.2       |          2927.4        |    2674.4  |            3033.0        
                          f32 B=1024, M=197, H=1, K=64   |         6597.6       |          7154.9        |    4489.7  |            7063.1        
                          f16 B=512, M=197, H=1, K=80    |         2239.5       |          2366.5        |    1684.2  |            2472.0        
                          f32 B=512, M=197, H=1, K=80    |         5362.4       |          5519.6        |    2857.9  |            5651.4        
                          f16 B=32, M=197, H=16, K=80    |         2208.1       |          2380.0        |    1803.4  |            2439.4        
                          f32 B=32, M=197, H=16, K=80    |         5503.6       |          5736.7        |    3017.5  |            5796.2        
                          f16 B=32, M=197, H=16, K=64    |         1493.4       |          1620.6        |    1457.2  |            1678.6        
                          f32 B=32, M=197, H=16, K=64    |         3672.6       |          3941.6        |    2415.0  |            3898.2        
                          f16 B=32, M=197, H=16, K=128   |         2634.3       |          2888.0        |    2215.1  |            2991.5        
                          f32 B=32, M=197, H=16, K=128   |         6811.5       |          7334.0        |    4049.3  |            7261.9        
                          f16 B=256, M=197, H=1, K=88    |         1290.3       |          1382.0        |     944.8  |            1449.4        
                          f32 B=256, M=197, H=1, K=88    |         2965.8       |          3043.2        |    1528.7  |            3137.7        
                          f16 B=16, M=197, H=16, K=88    |         1267.3       |          1357.0        |     970.8  |            1395.5        
                          f32 B=16, M=197, H=16, K=88    |         2879.9       |          3014.7        |    1626.5  |            3054.3        
                          f16 B=16, M=197, H=16, K=64    |          737.3       |           799.8        |     771.3  |             836.9        
                          f32 B=16, M=197, H=16, K=64    |         1879.2       |          2000.9        |    1282.5  |            1994.5        
                          f16 B=16, M=197, H=16, K=128   |         1443.9       |          1570.7        |    1142.2  |            1628.8        
                          f32 B=16, M=197, H=16, K=128   |         3480.5       |          3723.6        |    2027.2  |            3714.6        
                          f16 B=1, M=4096, H=160, K=128  |       150006.2       |        151877.5        |            |          152570.6        
                          f32 B=1, M=4096, H=160, K=128  |       582870.9       |        583519.8        |            |          585570.1        
                          f16 B=2, M=4096, H=160, K=128  |       301231.4       |        304511.7        |            |          305801.2        
                          f32 B=2, M=4096, H=160, K=128  |      1174724.1       |       1172498.4        |            |         1176814.0        
                          f16 B=1, M=8192, H=160, K=128  |       597461.6       |        600463.4        |            |          603066.6        
                          f32 B=1, M=8192, H=160, K=128  |      2333657.8       |       2329212.1        |            |         2339766.1        
                          f16 B=2, M=8192, H=160, K=128  |      1196837.5       |       1206932.4        |            |         1209012.2        
                          f16 B=1024, M=82, H=8, K=64    |         8926.8       |          9723.4        |    5799.4  |           10084.2        
                          f32 B=1024, M=82, H=8, K=64    |        15920.4       |         17434.4        |   11027.0  |           17492.8        
                          f16 B=150, M=256, H=16, K=64   |         5524.2       |          6363.9        |    7557.9  |            6586.2        
                          f32 B=150, M=256, H=16, K=64   |        17506.9       |         18843.5        |   16263.5  |           18988.6        
                          f16 B=64, M=256, H=12, K=64    |         1800.6       |          2050.3        |    2383.4  |            2139.0        
                          f32 B=64, M=256, H=12, K=64    |         5753.6       |          6196.3        |    4971.2  |            6200.0        
                          f16 B=1, M=4096, H=16, K=40    |        47649.5       |         47836.0        |    8368.4  |           47973.6        
                          f32 B=1, M=4096, H=16, K=40    |       111092.1       |        111027.3        |   19475.9  |          111257.8        
                          f16 B=1, M=16384, H=16, K=40   |       765320.2       |        765686.9        |            |          767337.2        
                          f32 B=1, M=16384, H=16, K=40   |      1769169.0       |       1769675.1        |            |         1769371.4        
                          f16 B=16, M=128, H=16, K=16    |          178.9       |           196.8        |     445.9  |             188.3        
                          f32 B=16, M=128, H=16, K=16    |          301.3       |           319.1        |     422.5  |             336.3        
                          f16 B=16, M=128, H=16, K=32    |          174.1       |           174.2        |     394.0  |             179.5        
                          f32 B=16, M=128, H=16, K=32    |          395.7       |           433.2        |     580.0  |             440.4        
                          f16 B=16, M=128, H=16, K=64    |          205.0       |           253.5        |     460.6  |             270.9        
                          f32 B=16, M=128, H=16, K=64    |          573.7       |           639.3        |     598.1  |             656.1        
                          f16 B=16, M=128, H=16, K=128   |          399.5       |           484.3        |     515.2  |             521.8        
                          f32 B=16, M=128, H=16, K=128   |         1126.3       |          1260.8        |    1008.1  |            1282.4        
                          f16 B=16, M=512, H=16, K=16    |         1597.6       |          1627.2        |    1901.1  |            1662.1        
                          f32 B=16, M=512, H=16, K=16    |         4458.5       |          4528.8        |    4232.0  |            4559.4        
                          f16 B=16, M=512, H=16, K=32    |         1819.1       |          1868.7        |    2097.2  |            1945.5        
                          f32 B=16, M=512, H=16, K=32    |         5604.2       |          5757.1        |    4566.4  |            5784.8        
                          f16 B=16, M=512, H=16, K=64    |         2345.5       |          2495.6        |    2558.0  |            2573.2        
                          f32 B=16, M=512, H=16, K=64    |         7778.3       |          8017.1        |    5488.2  |            8083.7        
                          f16 B=16, M=512, H=16, K=128   |         4516.6       |          4821.0        |    3386.7  |            4968.2        
                          f32 B=16, M=512, H=16, K=128   |        15412.7       |         15959.2        |    8865.9  |           16047.5        
                          f16 B=16, M=1024, H=16, K=16   |         6195.9       |          6217.6        |    6995.3  |            6326.4        
                          f32 B=16, M=1024, H=16, K=16   |        18136.2       |         18312.0        |   16088.2  |           18354.1        
                          f16 B=16, M=1024, H=16, K=32   |         7072.8       |          7122.3        |    7406.9  |            7297.7        
                          f32 B=16, M=1024, H=16, K=32   |        22108.2       |         22116.7        |   17112.5  |           22436.8        
                          f16 B=16, M=1024, H=16, K=64   |         8868.0       |          9104.6        |    8627.1  |            9311.8        
                          f32 B=16, M=1024, H=16, K=64   |        30710.5       |         31041.3        |   19860.8  |           31338.1        
                          f16 B=16, M=1024, H=16, K=128  |        17091.8       |         17655.5        |   10548.3  |           18083.8        
                          f32 B=16, M=1024, H=16, K=128  |        60317.8       |         61461.7        |   32919.2  |           61548.8        
                          f16 B=64, M=128, H=16, K=16    |          413.6       |           453.8        |     635.5  |             480.6        
                          f32 B=64, M=128, H=16, K=16    |         1033.8       |          1114.3        |    1238.9  |            1119.5        
                          f16 B=64, M=128, H=16, K=32    |          505.7       |           587.9        |     813.6  |             630.1        
                          f32 B=64, M=128, H=16, K=32    |         1423.0       |          1551.4        |    1533.4  |            1581.8        
                          f16 B=64, M=128, H=16, K=64    |          743.3       |           916.8        |    1187.7  |             976.5        
                          f32 B=64, M=128, H=16, K=64    |         2093.3       |          2384.6        |    2156.3  |            2405.4        
                          f16 B=64, M=128, H=16, K=128   |         1408.2       |          1734.3        |    1918.7  |            1859.6        
                          f32 B=64, M=128, H=16, K=128   |         4125.3       |          4671.4        |    3762.0  |            4717.0        
                          f16 B=64, M=512, H=16, K=16    |         5531.2       |          5643.3        |    7454.4  |            5770.8        
                          f32 B=64, M=512, H=16, K=16    |        16214.0       |         16531.2        |   16661.3  |           16540.8        
                          f16 B=64, M=512, H=16, K=32    |         6495.5       |          6725.2        |    8353.7  |            6941.8        
                          f32 B=64, M=512, H=16, K=32    |        20520.6       |         20941.9        |   18352.4  |           21116.8        
                          f16 B=64, M=512, H=16, K=64    |         8686.1       |          9278.6        |   10343.4  |            9593.2        
                          f32 B=64, M=512, H=16, K=64    |        28891.1       |         30003.0        |   22749.4  |           30139.1        
                          f16 B=64, M=512, H=16, K=128   |        15991.4       |         17412.3        |   14633.0  |           17848.2        
                          f32 B=64, M=512, H=16, K=128   |        57526.8       |         59970.8        |   40089.9  |           60016.9        
                          f16 B=64, M=1024, H=16, K=16   |        21552.8       |         21603.1        |   28447.1  |           22030.0        
                          f32 B=64, M=1024, H=16, K=16   |        65321.2       |         65736.8        |            |           65932.0        
                          f16 B=64, M=1024, H=16, K=32   |        25695.4       |         25905.9        |   30592.1  |           26644.8        
                          f32 B=64, M=1024, H=16, K=32   |        80213.4       |         80446.7        |            |           81363.1        
                          f16 B=64, M=1024, H=16, K=64   |        32465.6       |         33575.1        |   37233.4  |           34370.8        
                          f32 B=64, M=1024, H=16, K=64   |       112996.7       |        115632.0        |            |          115970.8        
                          f16 B=64, M=1024, H=16, K=128  |        60363.5       |         62800.2        |   48883.7  |           64505.1        
                          f32 B=64, M=1024, H=16, K=128  |       225023.4       |        230527.4        |            |          229851.8        

Times are in microseconds (us).
```
</details>

[ghstack-poisoned]
danthe3rd pushed a commit that referenced this pull request Oct 6, 2022
ghstack-source-id: 370afb5983c34f74dca3a4d324240eed44e78add
Pull Request resolved: #458
Copy link
Contributor

@fmassa fmassa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks great, thanks!

I have some minor suggestions in the tests to ensure we cover a few more cases, but they can be done in a follow-up PR.

Also, if we are really sure about our implementation and want to enable it to all downstream users without asking them to change their code (which I'm not sure we should do now, might be better to be explicit), we could also use the new torch.library.Library functionality from PyTorch that allows overriding PyTorch functions directly from Python, so we could override the unbind_backward function.

An example is as follows (taken from pytorch/pytorch#75905):

def my_sum(*args, **kwargs):
    return args[0]
my_lib1 = torch.library.Library("aten", "IMPL")
my_lib1.impl('aten::sum', my_sum)
x = torch.tensor([1, 2])
assert torch.sum(x) == x
del my_lib1
assert torch.sum(x) == torch.tensor(3)

tests/test_mem_eff_attention.py Outdated Show resolved Hide resolved
tests/test_mem_eff_attention.py Outdated Show resolved Hide resolved
tests/test_mem_eff_attention.py Outdated Show resolved Hide resolved
**SUMMARY**

Also:
- updated benchmarks to reflect the real-world scenario where we chunk a packed qkv - for both fw/bw.
- added coverage for chunking in tests

**PERFORMANCE IMPACT**

<details>
<summary>A100 bw (new benchmarks)</summary>

```
[---------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------]
                                     |  48_chunk3_31735f9  |  45_bwpacked_e53c5f3  |  vanilla  |  47_bwpackedgrad_9bacdf6
1 threads: --------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |          560.7      |           663.9       |   2265.7  |             710.3       
      f32 B=384, M=197, H=1, K=88    |         2445.1      |          2540.3       |   1843.3  |            2611.0       
      f16 B=384, M=197, H=1, K=80    |          530.4      |           619.9       |   1922.8  |             663.0       
      f32 B=384, M=197, H=1, K=80    |         2326.1      |          2425.2       |   1788.7  |            2476.4       
      f16 B=384, M=197, H=1, K=64    |          391.7      |           462.2       |   1812.7  |             492.8       
      f32 B=384, M=197, H=1, K=64    |         1275.0      |          1379.4       |   1675.4  |            1388.4       
      f16 B=1024, M=197, H=1, K=88   |         1399.5      |          1666.2       |   5965.2  |            1775.5       
      f32 B=1024, M=197, H=1, K=88   |         6332.5      |          6618.1       |   4559.6  |            6740.5       
      f16 B=1024, M=197, H=1, K=80   |         1326.2      |          1543.9       |   5041.4  |            1652.3       
      f32 B=1024, M=197, H=1, K=80   |         6057.1      |          6301.3       |   4411.6  |            6433.6       
      f16 B=1024, M=197, H=1, K=64   |          876.9      |          1063.1       |   4749.3  |            1133.2       
      f32 B=1024, M=197, H=1, K=64   |         3360.2      |          3629.0       |   4118.8  |            3652.0       
      f16 B=512, M=197, H=1, K=80    |          669.0      |           786.4       |   2544.9  |             842.2       
      f32 B=512, M=197, H=1, K=80    |         3032.3      |          3127.8       |   2287.4  |            3229.8       
      f16 B=32, M=197, H=16, K=80    |          663.0      |           789.7       |   2569.0  |             837.8       
      f32 B=32, M=197, H=16, K=80    |         3005.5      |          3166.3       |   2354.1  |            3225.9       
      f16 B=32, M=197, H=16, K=64    |          459.9      |           553.4       |   2436.3  |             591.9       
      f32 B=32, M=197, H=16, K=64    |         1814.1      |          1962.5       |   2197.3  |            1962.1       
      f16 B=32, M=197, H=16, K=128   |          792.5      |           981.9       |   4505.9  |            1056.5       
      f32 B=32, M=197, H=16, K=128   |         3734.8      |          3995.7       |   2805.8  |            4021.5       
      f16 B=256, M=197, H=1, K=88    |          413.4      |           482.6       |   1529.5  |             515.5       
      f32 B=256, M=197, H=1, K=88    |         1741.9      |          1818.3       |   1208.6  |            1852.4       
      f16 B=16, M=197, H=16, K=88    |          410.3      |           482.9       |   1545.7  |             512.5       
      f32 B=16, M=197, H=16, K=88    |         1734.9      |          1832.1       |   1250.6  |            1849.4       
      f16 B=16, M=197, H=16, K=64    |          235.4      |           286.0       |   1247.1  |             305.3       
      f32 B=16, M=197, H=16, K=64    |         1077.1      |          1143.7       |   1125.9  |            1154.0       
      f16 B=16, M=197, H=16, K=128   |          455.4      |           554.1       |   2273.1  |             596.0       
      f32 B=16, M=197, H=16, K=128   |         2028.9      |          2164.5       |   1446.7  |            2175.0       
      f16 B=1, M=4096, H=160, K=128  |        62454.4      |         63474.5       |  45930.5  |           64052.7       
      f32 B=1, M=4096, H=160, K=128  |       239035.4      |        232672.1       |           |          240073.9       
      f16 B=2, M=4096, H=160, K=128  |        98791.3      |        101006.4       |           |          101942.0       
      f32 B=2, M=4096, H=160, K=128  |       375914.9      |        368050.6       |           |          381280.4       
      f16 B=1, M=8192, H=160, K=128  |       248498.9      |        250066.9       |           |          251500.4       
      f32 B=1, M=8192, H=160, K=128  |       945102.2      |        922549.3       |           |          949256.4       
      f16 B=2, M=8192, H=160, K=128  |       389207.8      |        394486.6       |           |          396190.4       
      f32 B=2, M=8192, H=160, K=128  |      1496334.3      |       1449974.3       |           |         1502215.3       
      f16 B=1024, M=82, H=8, K=64    |         1872.4      |          2503.8       |   3819.8  |            2693.7       
      f32 B=1024, M=82, H=8, K=64    |         8734.3      |          9637.8       |   8732.9  |            9672.2       
      f16 B=150, M=256, H=16, K=64   |         2126.4      |          2713.4       |   4554.3  |            2880.8       
      f32 B=150, M=256, H=16, K=64   |         6214.3      |          7052.2       |  12943.2  |            7099.2       
      f16 B=64, M=256, H=12, K=64    |          741.2      |           930.1       |   1493.0  |             990.6       
      f32 B=64, M=256, H=12, K=64    |         2144.2      |          2408.5       |   4267.7  |            2433.8       
      f16 B=1, M=4096, H=16, K=40    |        24583.7      |         24224.8       |   4195.2  |           24500.2       
      f32 B=1, M=4096, H=16, K=40    |        72497.9      |         72070.8       |  17744.1  |           72393.0       
      f16 B=1, M=16384, H=16, K=40   |       451481.8      |        439027.7       |           |          451499.9       
      f32 B=1, M=16384, H=16, K=40   |      1169509.1      |       1164880.1       |           |         1169769.3       
      f16 B=256, M=4096, H=16, K=64  |       597391.6      |        625921.0       |           |          610433.2       
      f16 B=16, M=128, H=16, K=16    |           93.1      |           126.7       |    241.2  |             132.3       
      f32 B=16, M=128, H=16, K=16    |          184.1      |           176.5       |    373.8  |             180.7       
      f16 B=16, M=128, H=16, K=32    |          127.9      |           126.3       |    241.4  |             106.7       
      f32 B=16, M=128, H=16, K=32    |          194.1      |           216.6       |    412.7  |             225.8       
      f16 B=16, M=128, H=16, K=64    |          131.4      |           126.8       |    239.8  |             134.5       
      f32 B=16, M=128, H=16, K=64    |          280.4      |           326.0       |    500.0  |             334.0       
      f16 B=16, M=128, H=16, K=128   |          175.6      |           236.1       |    298.8  |             261.1       
      f32 B=16, M=128, H=16, K=128   |          531.8      |           615.8       |    677.2  |             638.0       
      f16 B=16, M=512, H=16, K=16    |          558.2      |           595.0       |   1201.9  |             607.8       
      f32 B=16, M=512, H=16, K=16    |         2146.7      |          2169.9       |   4416.1  |            2200.6       
      f16 B=16, M=512, H=16, K=32    |          653.5      |           732.3       |   1305.1  |             748.5       
      f32 B=16, M=512, H=16, K=32    |         2296.3      |          2373.9       |   4641.3  |            2400.1       
      f16 B=16, M=512, H=16, K=64    |          848.8      |           996.9       |   1544.6  |            1022.5       
      f32 B=16, M=512, H=16, K=64    |         2954.0      |          3117.1       |   5124.7  |            3157.6       
      f16 B=16, M=512, H=16, K=128   |         1735.4      |          1961.1       |   1982.7  |            2056.9       
      f32 B=16, M=512, H=16, K=128   |         6218.7      |          6396.4       |   6094.0  |            6600.3       
      f16 B=16, M=1024, H=16, K=16   |         2236.4      |          2319.4       |   4279.0  |            2331.6       
      f32 B=16, M=1024, H=16, K=16   |         8379.2      |          8363.9       |  16643.9  |            8503.6       
      f16 B=16, M=1024, H=16, K=32   |         2430.8      |          2649.6       |   4496.8  |            2608.7       
      f32 B=16, M=1024, H=16, K=32   |         8864.7      |          8907.8       |  17291.0  |            9074.0       
      f16 B=16, M=1024, H=16, K=64   |         3007.2      |          3351.3       |   4995.5  |            3351.0       
      f32 B=16, M=1024, H=16, K=64   |        11355.4      |         11627.1       |  18707.5  |           11694.3       
      f16 B=16, M=1024, H=16, K=128  |         6296.2      |          6748.7       |   5943.5  |            6967.0       
      f32 B=16, M=1024, H=16, K=128  |        23425.3      |         23360.0       |  21520.6  |           24169.7       
      f16 B=64, M=128, H=16, K=16    |          165.5      |           195.9       |    440.3  |             211.5       
      f32 B=64, M=128, H=16, K=16    |          497.4      |           540.7       |   1270.8  |             550.3       
      f16 B=64, M=128, H=16, K=32    |          210.4      |           274.9       |    544.8  |             298.5       
      f32 B=64, M=128, H=16, K=32    |          604.4      |           696.6       |   1428.3  |             710.9       
      f16 B=64, M=128, H=16, K=64    |          330.4      |           452.3       |    766.0  |             498.1       
      f32 B=64, M=128, H=16, K=64    |          883.4      |          1060.4       |   1745.2  |            1082.2       
      f16 B=64, M=128, H=16, K=128   |          605.5      |           847.8       |   1223.6  |             933.9       
      f32 B=64, M=128, H=16, K=128   |         1847.4      |          2169.7       |   2388.8  |            2236.0       
      f16 B=64, M=512, H=16, K=16    |         2004.7      |          2120.0       |   4487.0  |            2179.4       
      f32 B=64, M=512, H=16, K=16    |         6655.4      |          6818.8       |  16993.8  |            6872.1       
      f16 B=64, M=512, H=16, K=32    |         2379.3      |          2593.1       |   4957.2  |            2704.0       
      f32 B=64, M=512, H=16, K=32    |         7349.4      |          7644.6       |  17852.2  |            7736.2       
      f16 B=64, M=512, H=16, K=64    |         3129.6      |          3616.6       |   5888.8  |            3786.2       
      f32 B=64, M=512, H=16, K=64    |         9432.5      |         10123.9       |  19770.6  |           10178.5       
      f16 B=64, M=512, H=16, K=128   |         6054.1      |          7019.9       |   7712.6  |            7350.2       
      f32 B=64, M=512, H=16, K=128   |        21565.6      |         22281.9       |  23653.0  |           23084.4       
      f16 B=64, M=1024, H=16, K=16   |         7929.4      |          8199.1       |  16876.3  |            8242.5       
      f32 B=64, M=1024, H=16, K=16   |        26135.2      |         26347.9       |  66351.1  |           26639.0       
      f16 B=64, M=1024, H=16, K=32   |         8876.8      |          9450.0       |  17869.4  |            9473.5       
      f32 B=64, M=1024, H=16, K=32   |        27685.3      |         28104.6       |  69105.9  |           28428.7       
      f16 B=64, M=1024, H=16, K=64   |        11198.7      |         12180.5       |  19932.3  |           12543.4       
      f32 B=64, M=1024, H=16, K=64   |        34978.2      |         36239.4       |  74813.7  |           36482.4       
      f16 B=64, M=1024, H=16, K=128  |        21618.9      |         23439.6       |  23741.1  |           24160.1       
      f32 B=64, M=1024, H=16, K=128  |        80785.3      |         81080.8       |  86003.6  |           84132.9       

Times are in microseconds (us).
```
</details>

<details>
<summary>P100/V100 bw (new benchmarks)</summary>

```
[---------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------------------]
                                                         |  48_chunk3_31735f94  |  45_bwpacked_e53c5f3a  |  vanilla   |  47_bwpackedgrad_9bacdf65
1 threads: --------------------------------------------------------------------------------------------------------------------------------------
  (Quadro_GP100)          f16 B=384, M=197, H=1, K=88    |         6846.3       |          7583.8        |    3569.3  |            7599.5        
                          f32 B=384, M=197, H=1, K=88    |         9883.1       |         10107.2        |    4312.8  |           10486.3        
                          f16 B=384, M=197, H=1, K=80    |         6486.4       |          6997.7        |    3418.0  |            7037.3        
                          f32 B=384, M=197, H=1, K=80    |         9330.3       |          9550.6        |    4094.7  |            9893.4        
                          f16 B=384, M=197, H=1, K=64    |         3615.4       |          3930.4        |    2911.0  |            4074.2        
                          f32 B=384, M=197, H=1, K=64    |         6281.4       |          6554.5        |    3431.9  |            6738.1        
                          f16 B=1024, M=197, H=1, K=88   |        17226.8       |         18593.1        |    9733.2  |           18772.9        
                          f32 B=1024, M=197, H=1, K=88   |        26593.3       |         27136.2        |   12033.8  |           28184.2        
                          f16 B=1024, M=197, H=1, K=80   |        16330.1       |         17478.6        |    9270.2  |           17735.3        
                          f32 B=1024, M=197, H=1, K=80   |        25208.9       |         25680.1        |   11224.5  |           26636.1        
                          f16 B=1024, M=197, H=1, K=64   |         8889.1       |          9728.8        |    7646.1  |           10089.7        
                          f32 B=1024, M=197, H=1, K=64   |        16914.7       |         17743.4        |    9383.8  |           18068.4        
                          f16 B=512, M=197, H=1, K=80    |         8227.3       |          8878.4        |    4579.3  |            8953.6        
                          f32 B=512, M=197, H=1, K=80    |        13078.7       |         13346.0        |    5486.4  |           13817.6        
                          f16 B=32, M=197, H=16, K=80    |         8278.9       |          9002.9        |    4816.2  |            9025.6        
                          f32 B=32, M=197, H=16, K=80    |        12913.8       |         13371.2        |    5777.7  |           13667.6        
                          f16 B=32, M=197, H=16, K=64    |         4565.2       |          5000.0        |    4023.4  |            5146.3        
                          f32 B=32, M=197, H=16, K=64    |         8824.0       |          9257.7        |    4797.2  |            9400.5        
                          f16 B=32, M=197, H=16, K=128   |         9770.0       |         10849.7        |    5983.2  |           10932.0        
                          f32 B=32, M=197, H=16, K=128   |        15715.2       |         16559.9        |    7513.6  |           16839.9        
                          f16 B=256, M=197, H=1, K=88    |         5011.2       |          5363.8        |    2444.9  |            5426.0        
                          f32 B=256, M=197, H=1, K=88    |         6918.7       |          7040.8        |    2867.8  |            7303.2        
                          f16 B=16, M=197, H=16, K=88    |         4963.8       |          5343.9        |    2545.2  |            5398.9        
                          f32 B=16, M=197, H=16, K=88    |         6727.9       |          6981.7        |    3040.3  |            7121.2        
                          f16 B=16, M=197, H=16, K=64    |         2586.5       |          2777.1        |    2025.5  |            2905.6        
                          f32 B=16, M=197, H=16, K=64    |         4404.3       |          4607.2        |    2431.1  |            4691.8        
                          f16 B=16, M=197, H=16, K=128   |         5643.2       |          6194.1        |    3016.1  |            6216.3        
                          f32 B=16, M=197, H=16, K=128   |         7887.1       |          8308.3        |    3676.6  |            8456.2        
                          f16 B=1, M=4096, H=160, K=128  |      1087008.7       |       1115355.5        |            |         1091596.8        
                          f32 B=1, M=4096, H=160, K=128  |      1220066.8       |       1223422.8        |            |         1227912.2        
                          f16 B=2, M=4096, H=160, K=128  |      1734244.4       |       1794068.7        |            |         1756266.7        
                          f32 B=2, M=4096, H=160, K=128  |      2437675.5       |       2445780.4        |            |         2451957.5        
                          f16 B=1, M=8192, H=160, K=128  |      4367110.4       |       4466170.9        |            |         4383747.4        
                          f32 B=1, M=8192, H=160, K=128  |      4865732.9       |       4865708.9        |            |         4887066.5        
                          f16 B=2, M=8192, H=160, K=128  |      7002715.1       |       7146077.9        |            |         7033922.8        
                          f16 B=1024, M=82, H=8, K=64    |        23247.5       |         24929.5        |   18047.8  |           26928.2        
                          f32 B=1024, M=82, H=8, K=64    |        46463.2       |         48705.6        |   22797.5  |           50736.3        
                          f16 B=150, M=256, H=16, K=64   |        23467.9       |         25647.3        |   24569.2  |           26841.8        
                          f32 B=150, M=256, H=16, K=64   |        36887.7       |         39698.0        |   32050.2  |           40389.0        
                          f16 B=64, M=256, H=12, K=64    |         7723.7       |          8499.0        |    7702.1  |            8694.9        
                          f32 B=64, M=256, H=12, K=64    |        11992.1       |         12819.9        |    9874.5  |           13107.9        
                          f16 B=1, M=4096, H=16, K=40    |       142655.5       |        142899.7        |   28928.6  |          142922.7        
                          f32 B=1, M=4096, H=16, K=40    |       142626.8       |        142685.3        |   37303.2  |          142541.0        
                          f16 B=1, M=16384, H=16, K=40   |      2274095.0       |       2274882.0        |            |         2275019.9        
                          f32 B=1, M=16384, H=16, K=40   |      2284027.2       |       2279415.7        |            |         2277761.9        
                          f16 B=16, M=128, H=16, K=16    |          513.2       |           547.1        |     571.5  |             570.9        
                          f32 B=16, M=128, H=16, K=16    |          667.4       |           704.3        |     693.1  |             728.0        
                          f16 B=16, M=128, H=16, K=32    |          600.3       |           667.0        |     671.3  |             713.1        
                          f32 B=16, M=128, H=16, K=32    |          823.9       |           888.9        |     823.5  |             937.3        
                          f16 B=16, M=128, H=16, K=64    |          781.0       |           900.6        |     883.1  |             998.9        
                          f32 B=16, M=128, H=16, K=64    |         1173.7       |          1293.8        |    1077.0  |            1393.4        
                          f16 B=16, M=128, H=16, K=128   |         1649.2       |          1877.2        |    1323.2  |            2026.3        
                          f32 B=16, M=128, H=16, K=128   |         2250.5       |          2473.0        |    1654.7  |            2636.6        
                          f16 B=16, M=512, H=16, K=16    |         7709.3       |          7914.6        |    6945.1  |            7928.7        
                          f32 B=16, M=512, H=16, K=16    |         9797.2       |          9950.5        |    8499.4  |           10029.3        
                          f16 B=16, M=512, H=16, K=32    |         8956.9       |          9210.8        |    7517.1  |            9307.0        
                          f32 B=16, M=512, H=16, K=32    |        11480.7       |         11710.9        |    9249.4  |           11884.4        
                          f16 B=16, M=512, H=16, K=64    |        11324.0       |         11829.1        |    8849.5  |           12001.8        
                          f32 B=16, M=512, H=16, K=64    |        15744.1       |         16258.0        |   10954.6  |           16481.1        
                          f16 B=16, M=512, H=16, K=128   |        25320.2       |         26584.0        |   12412.3  |           26725.0        
                          f32 B=16, M=512, H=16, K=128   |        31187.1       |         32290.3        |   15167.5  |           32818.4        
                          f16 B=16, M=1024, H=16, K=16   |        31484.2       |         31601.4        |   26434.6  |           31894.6        
                          f32 B=16, M=1024, H=16, K=16   |        38754.1       |         38900.1        |   32320.0  |           39203.9        
                          f16 B=16, M=1024, H=16, K=32   |        36000.2       |         36672.6        |   28341.4  |           36579.5        
                          f32 B=16, M=1024, H=16, K=32   |        45070.7       |         45262.3        |   34914.2  |           45774.5        
                          f16 B=16, M=1024, H=16, K=64   |        45324.9       |         46540.4        |   32089.9  |           46784.2        
                          f32 B=16, M=1024, H=16, K=64   |        61320.3       |         62411.1        |   39565.0  |           63217.0        
                          f16 B=16, M=1024, H=16, K=128  |       104342.9       |        108469.4        |   43221.9  |          105620.6        
                          f32 B=16, M=1024, H=16, K=128  |       122688.4       |        125050.9        |   51205.7  |          126080.9        
                          f16 B=64, M=128, H=16, K=16    |         1707.9       |          1824.9        |    2106.4  |            1923.2        
                          f32 B=64, M=128, H=16, K=16    |         2487.4       |          2612.5        |    2565.1  |            2707.6        
                          f16 B=64, M=128, H=16, K=32    |         2016.8       |          2254.4        |    2485.4  |            2412.3        
                          f32 B=64, M=128, H=16, K=32    |         3135.8       |          3365.6        |    3063.2  |            3518.5        
                          f16 B=64, M=128, H=16, K=64    |         2700.2       |          3167.0        |    3306.0  |            3478.4        
                          f32 B=64, M=128, H=16, K=64    |         4435.1       |          4944.7        |    4227.6  |            5181.2        
                          f16 B=64, M=128, H=16, K=128   |         5769.1       |          6858.2        |    5299.8  |            7356.1        
                          f32 B=64, M=128, H=16, K=128   |         8577.9       |          9672.0        |    6916.3  |           10093.5        
                          f16 B=64, M=512, H=16, K=16    |        25994.0       |         26782.0        |   27240.9  |           26662.2        
                          f32 B=64, M=512, H=16, K=16    |        36864.9       |         37299.3        |   34159.3  |           37576.7        
                          f16 B=64, M=512, H=16, K=32    |        30680.4       |         32113.8        |   30109.0  |           32419.7        
                          f32 B=64, M=512, H=16, K=32    |        43638.5       |         44557.9        |   37358.5  |           45145.0        
                          f16 B=64, M=512, H=16, K=64    |        39417.5       |         41666.5        |   36004.2  |           42374.9        
                          f32 B=64, M=512, H=16, K=64    |        60049.2       |         63148.0        |   43412.6  |           63286.8        
                          f16 B=64, M=512, H=16, K=128   |        88951.1       |         93087.0        |   51730.1  |           94861.6        
                          f32 B=64, M=512, H=16, K=128   |       119728.7       |        124340.3        |   62413.7  |          126382.2        
                          f16 B=64, M=1024, H=16, K=16   |       108368.3       |        111081.8        |  106479.7  |          108716.1        
                          f32 B=64, M=1024, H=16, K=16   |       145612.0       |        147310.4        |            |          147380.7        
                          f16 B=64, M=1024, H=16, K=32   |       124296.1       |        127366.8        |  113905.0  |          126975.3        
                          f32 B=64, M=1024, H=16, K=32   |       171082.3       |        172539.0        |            |          173893.9        
                          f16 B=64, M=1024, H=16, K=64   |       155116.3       |        160429.2        |  130759.4  |          161834.0        
                          f32 B=64, M=1024, H=16, K=64   |       234356.0       |        239612.2        |            |          239948.3        
                          f16 B=64, M=1024, H=16, K=128  |       349728.3       |        360975.7        |  176158.7  |          371185.2        
                          f32 B=64, M=1024, H=16, K=128  |       468810.0       |        476415.4        |            |          481908.5        
  (Tesla_V100_SXM2_16GB)  f16 B=384, M=197, H=1, K=88    |         1700.3       |          1840.0        |    1375.3  |            1930.9        
                          f32 B=384, M=197, H=1, K=88    |         4456.4       |          4579.3        |    2235.5  |            4708.6        
                          f16 B=384, M=197, H=1, K=80    |         1623.3       |          1719.9        |    1279.5  |            1806.9        
                          f32 B=384, M=197, H=1, K=80    |         4031.2       |          4141.9        |    2149.8  |            4252.6        
                          f16 B=384, M=197, H=1, K=64    |         1092.8       |          1187.0        |    1048.5  |            1237.6        
                          f32 B=384, M=197, H=1, K=64    |         2717.5       |          2918.5        |    1738.5  |            2907.9        
                          f16 B=1024, M=197, H=1, K=88   |         4428.7       |          4906.2        |    3723.7  |            5178.2        
                          f32 B=1024, M=197, H=1, K=88   |        10947.5       |         11362.9        |    6052.5  |           11802.1        
                          f16 B=1024, M=197, H=1, K=80   |         4237.1       |          4491.4        |    3331.7  |            4725.6        
                          f32 B=1024, M=197, H=1, K=80   |         9842.6       |         10159.7        |    5682.4  |           10435.6        
                          f16 B=1024, M=197, H=1, K=64   |         2679.2       |          2927.4        |    2674.4  |            3033.0        
                          f32 B=1024, M=197, H=1, K=64   |         6597.6       |          7154.9        |    4489.7  |            7063.1        
                          f16 B=512, M=197, H=1, K=80    |         2239.5       |          2366.5        |    1684.2  |            2472.0        
                          f32 B=512, M=197, H=1, K=80    |         5362.4       |          5519.6        |    2857.9  |            5651.4        
                          f16 B=32, M=197, H=16, K=80    |         2208.1       |          2380.0        |    1803.4  |            2439.4        
                          f32 B=32, M=197, H=16, K=80    |         5503.6       |          5736.7        |    3017.5  |            5796.2        
                          f16 B=32, M=197, H=16, K=64    |         1493.4       |          1620.6        |    1457.2  |            1678.6        
                          f32 B=32, M=197, H=16, K=64    |         3672.6       |          3941.6        |    2415.0  |            3898.2        
                          f16 B=32, M=197, H=16, K=128   |         2634.3       |          2888.0        |    2215.1  |            2991.5        
                          f32 B=32, M=197, H=16, K=128   |         6811.5       |          7334.0        |    4049.3  |            7261.9        
                          f16 B=256, M=197, H=1, K=88    |         1290.3       |          1382.0        |     944.8  |            1449.4        
                          f32 B=256, M=197, H=1, K=88    |         2965.8       |          3043.2        |    1528.7  |            3137.7        
                          f16 B=16, M=197, H=16, K=88    |         1267.3       |          1357.0        |     970.8  |            1395.5        
                          f32 B=16, M=197, H=16, K=88    |         2879.9       |          3014.7        |    1626.5  |            3054.3        
                          f16 B=16, M=197, H=16, K=64    |          737.3       |           799.8        |     771.3  |             836.9        
                          f32 B=16, M=197, H=16, K=64    |         1879.2       |          2000.9        |    1282.5  |            1994.5        
                          f16 B=16, M=197, H=16, K=128   |         1443.9       |          1570.7        |    1142.2  |            1628.8        
                          f32 B=16, M=197, H=16, K=128   |         3480.5       |          3723.6        |    2027.2  |            3714.6        
                          f16 B=1, M=4096, H=160, K=128  |       150006.2       |        151877.5        |            |          152570.6        
                          f32 B=1, M=4096, H=160, K=128  |       582870.9       |        583519.8        |            |          585570.1        
                          f16 B=2, M=4096, H=160, K=128  |       301231.4       |        304511.7        |            |          305801.2        
                          f32 B=2, M=4096, H=160, K=128  |      1174724.1       |       1172498.4        |            |         1176814.0        
                          f16 B=1, M=8192, H=160, K=128  |       597461.6       |        600463.4        |            |          603066.6        
                          f32 B=1, M=8192, H=160, K=128  |      2333657.8       |       2329212.1        |            |         2339766.1        
                          f16 B=2, M=8192, H=160, K=128  |      1196837.5       |       1206932.4        |            |         1209012.2        
                          f16 B=1024, M=82, H=8, K=64    |         8926.8       |          9723.4        |    5799.4  |           10084.2        
                          f32 B=1024, M=82, H=8, K=64    |        15920.4       |         17434.4        |   11027.0  |           17492.8        
                          f16 B=150, M=256, H=16, K=64   |         5524.2       |          6363.9        |    7557.9  |            6586.2        
                          f32 B=150, M=256, H=16, K=64   |        17506.9       |         18843.5        |   16263.5  |           18988.6        
                          f16 B=64, M=256, H=12, K=64    |         1800.6       |          2050.3        |    2383.4  |            2139.0        
                          f32 B=64, M=256, H=12, K=64    |         5753.6       |          6196.3        |    4971.2  |            6200.0        
                          f16 B=1, M=4096, H=16, K=40    |        47649.5       |         47836.0        |    8368.4  |           47973.6        
                          f32 B=1, M=4096, H=16, K=40    |       111092.1       |        111027.3        |   19475.9  |          111257.8        
                          f16 B=1, M=16384, H=16, K=40   |       765320.2       |        765686.9        |            |          767337.2        
                          f32 B=1, M=16384, H=16, K=40   |      1769169.0       |       1769675.1        |            |         1769371.4        
                          f16 B=16, M=128, H=16, K=16    |          178.9       |           196.8        |     445.9  |             188.3        
                          f32 B=16, M=128, H=16, K=16    |          301.3       |           319.1        |     422.5  |             336.3        
                          f16 B=16, M=128, H=16, K=32    |          174.1       |           174.2        |     394.0  |             179.5        
                          f32 B=16, M=128, H=16, K=32    |          395.7       |           433.2        |     580.0  |             440.4        
                          f16 B=16, M=128, H=16, K=64    |          205.0       |           253.5        |     460.6  |             270.9        
                          f32 B=16, M=128, H=16, K=64    |          573.7       |           639.3        |     598.1  |             656.1        
                          f16 B=16, M=128, H=16, K=128   |          399.5       |           484.3        |     515.2  |             521.8        
                          f32 B=16, M=128, H=16, K=128   |         1126.3       |          1260.8        |    1008.1  |            1282.4        
                          f16 B=16, M=512, H=16, K=16    |         1597.6       |          1627.2        |    1901.1  |            1662.1        
                          f32 B=16, M=512, H=16, K=16    |         4458.5       |          4528.8        |    4232.0  |            4559.4        
                          f16 B=16, M=512, H=16, K=32    |         1819.1       |          1868.7        |    2097.2  |            1945.5        
                          f32 B=16, M=512, H=16, K=32    |         5604.2       |          5757.1        |    4566.4  |            5784.8        
                          f16 B=16, M=512, H=16, K=64    |         2345.5       |          2495.6        |    2558.0  |            2573.2        
                          f32 B=16, M=512, H=16, K=64    |         7778.3       |          8017.1        |    5488.2  |            8083.7        
                          f16 B=16, M=512, H=16, K=128   |         4516.6       |          4821.0        |    3386.7  |            4968.2        
                          f32 B=16, M=512, H=16, K=128   |        15412.7       |         15959.2        |    8865.9  |           16047.5        
                          f16 B=16, M=1024, H=16, K=16   |         6195.9       |          6217.6        |    6995.3  |            6326.4        
                          f32 B=16, M=1024, H=16, K=16   |        18136.2       |         18312.0        |   16088.2  |           18354.1        
                          f16 B=16, M=1024, H=16, K=32   |         7072.8       |          7122.3        |    7406.9  |            7297.7        
                          f32 B=16, M=1024, H=16, K=32   |        22108.2       |         22116.7        |   17112.5  |           22436.8        
                          f16 B=16, M=1024, H=16, K=64   |         8868.0       |          9104.6        |    8627.1  |            9311.8        
                          f32 B=16, M=1024, H=16, K=64   |        30710.5       |         31041.3        |   19860.8  |           31338.1        
                          f16 B=16, M=1024, H=16, K=128  |        17091.8       |         17655.5        |   10548.3  |           18083.8        
                          f32 B=16, M=1024, H=16, K=128  |        60317.8       |         61461.7        |   32919.2  |           61548.8        
                          f16 B=64, M=128, H=16, K=16    |          413.6       |           453.8        |     635.5  |             480.6        
                          f32 B=64, M=128, H=16, K=16    |         1033.8       |          1114.3        |    1238.9  |            1119.5        
                          f16 B=64, M=128, H=16, K=32    |          505.7       |           587.9        |     813.6  |             630.1        
                          f32 B=64, M=128, H=16, K=32    |         1423.0       |          1551.4        |    1533.4  |            1581.8        
                          f16 B=64, M=128, H=16, K=64    |          743.3       |           916.8        |    1187.7  |             976.5        
                          f32 B=64, M=128, H=16, K=64    |         2093.3       |          2384.6        |    2156.3  |            2405.4        
                          f16 B=64, M=128, H=16, K=128   |         1408.2       |          1734.3        |    1918.7  |            1859.6        
                          f32 B=64, M=128, H=16, K=128   |         4125.3       |          4671.4        |    3762.0  |            4717.0        
                          f16 B=64, M=512, H=16, K=16    |         5531.2       |          5643.3        |    7454.4  |            5770.8        
                          f32 B=64, M=512, H=16, K=16    |        16214.0       |         16531.2        |   16661.3  |           16540.8        
                          f16 B=64, M=512, H=16, K=32    |         6495.5       |          6725.2        |    8353.7  |            6941.8        
                          f32 B=64, M=512, H=16, K=32    |        20520.6       |         20941.9        |   18352.4  |           21116.8        
                          f16 B=64, M=512, H=16, K=64    |         8686.1       |          9278.6        |   10343.4  |            9593.2        
                          f32 B=64, M=512, H=16, K=64    |        28891.1       |         30003.0        |   22749.4  |           30139.1        
                          f16 B=64, M=512, H=16, K=128   |        15991.4       |         17412.3        |   14633.0  |           17848.2        
                          f32 B=64, M=512, H=16, K=128   |        57526.8       |         59970.8        |   40089.9  |           60016.9        
                          f16 B=64, M=1024, H=16, K=16   |        21552.8       |         21603.1        |   28447.1  |           22030.0        
                          f32 B=64, M=1024, H=16, K=16   |        65321.2       |         65736.8        |            |           65932.0        
                          f16 B=64, M=1024, H=16, K=32   |        25695.4       |         25905.9        |   30592.1  |           26644.8        
                          f32 B=64, M=1024, H=16, K=32   |        80213.4       |         80446.7        |            |           81363.1        
                          f16 B=64, M=1024, H=16, K=64   |        32465.6       |         33575.1        |   37233.4  |           34370.8        
                          f32 B=64, M=1024, H=16, K=64   |       112996.7       |        115632.0        |            |          115970.8        
                          f16 B=64, M=1024, H=16, K=128  |        60363.5       |         62800.2        |   48883.7  |           64505.1        
                          f32 B=64, M=1024, H=16, K=128  |       225023.4       |        230527.4        |            |          229851.8        

Times are in microseconds (us).
```
</details>

[ghstack-poisoned]
**SUMMARY**

Also:
- updated benchmarks to reflect the real-world scenario where we chunk a packed qkv - for both fw/bw.
- added coverage for chunking in tests

**PERFORMANCE IMPACT**

<details>
<summary>A100 bw (new benchmarks)</summary>

```
[---------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------]
                                     |  48_chunk3_31735f9  |  45_bwpacked_e53c5f3  |  vanilla  |  47_bwpackedgrad_9bacdf6
1 threads: --------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |          560.7      |           663.9       |   2265.7  |             710.3       
      f32 B=384, M=197, H=1, K=88    |         2445.1      |          2540.3       |   1843.3  |            2611.0       
      f16 B=384, M=197, H=1, K=80    |          530.4      |           619.9       |   1922.8  |             663.0       
      f32 B=384, M=197, H=1, K=80    |         2326.1      |          2425.2       |   1788.7  |            2476.4       
      f16 B=384, M=197, H=1, K=64    |          391.7      |           462.2       |   1812.7  |             492.8       
      f32 B=384, M=197, H=1, K=64    |         1275.0      |          1379.4       |   1675.4  |            1388.4       
      f16 B=1024, M=197, H=1, K=88   |         1399.5      |          1666.2       |   5965.2  |            1775.5       
      f32 B=1024, M=197, H=1, K=88   |         6332.5      |          6618.1       |   4559.6  |            6740.5       
      f16 B=1024, M=197, H=1, K=80   |         1326.2      |          1543.9       |   5041.4  |            1652.3       
      f32 B=1024, M=197, H=1, K=80   |         6057.1      |          6301.3       |   4411.6  |            6433.6       
      f16 B=1024, M=197, H=1, K=64   |          876.9      |          1063.1       |   4749.3  |            1133.2       
      f32 B=1024, M=197, H=1, K=64   |         3360.2      |          3629.0       |   4118.8  |            3652.0       
      f16 B=512, M=197, H=1, K=80    |          669.0      |           786.4       |   2544.9  |             842.2       
      f32 B=512, M=197, H=1, K=80    |         3032.3      |          3127.8       |   2287.4  |            3229.8       
      f16 B=32, M=197, H=16, K=80    |          663.0      |           789.7       |   2569.0  |             837.8       
      f32 B=32, M=197, H=16, K=80    |         3005.5      |          3166.3       |   2354.1  |            3225.9       
      f16 B=32, M=197, H=16, K=64    |          459.9      |           553.4       |   2436.3  |             591.9       
      f32 B=32, M=197, H=16, K=64    |         1814.1      |          1962.5       |   2197.3  |            1962.1       
      f16 B=32, M=197, H=16, K=128   |          792.5      |           981.9       |   4505.9  |            1056.5       
      f32 B=32, M=197, H=16, K=128   |         3734.8      |          3995.7       |   2805.8  |            4021.5       
      f16 B=256, M=197, H=1, K=88    |          413.4      |           482.6       |   1529.5  |             515.5       
      f32 B=256, M=197, H=1, K=88    |         1741.9      |          1818.3       |   1208.6  |            1852.4       
      f16 B=16, M=197, H=16, K=88    |          410.3      |           482.9       |   1545.7  |             512.5       
      f32 B=16, M=197, H=16, K=88    |         1734.9      |          1832.1       |   1250.6  |            1849.4       
      f16 B=16, M=197, H=16, K=64    |          235.4      |           286.0       |   1247.1  |             305.3       
      f32 B=16, M=197, H=16, K=64    |         1077.1      |          1143.7       |   1125.9  |            1154.0       
      f16 B=16, M=197, H=16, K=128   |          455.4      |           554.1       |   2273.1  |             596.0       
      f32 B=16, M=197, H=16, K=128   |         2028.9      |          2164.5       |   1446.7  |            2175.0       
      f16 B=1, M=4096, H=160, K=128  |        62454.4      |         63474.5       |  45930.5  |           64052.7       
      f32 B=1, M=4096, H=160, K=128  |       239035.4      |        232672.1       |           |          240073.9       
      f16 B=2, M=4096, H=160, K=128  |        98791.3      |        101006.4       |           |          101942.0       
      f32 B=2, M=4096, H=160, K=128  |       375914.9      |        368050.6       |           |          381280.4       
      f16 B=1, M=8192, H=160, K=128  |       248498.9      |        250066.9       |           |          251500.4       
      f32 B=1, M=8192, H=160, K=128  |       945102.2      |        922549.3       |           |          949256.4       
      f16 B=2, M=8192, H=160, K=128  |       389207.8      |        394486.6       |           |          396190.4       
      f32 B=2, M=8192, H=160, K=128  |      1496334.3      |       1449974.3       |           |         1502215.3       
      f16 B=1024, M=82, H=8, K=64    |         1872.4      |          2503.8       |   3819.8  |            2693.7       
      f32 B=1024, M=82, H=8, K=64    |         8734.3      |          9637.8       |   8732.9  |            9672.2       
      f16 B=150, M=256, H=16, K=64   |         2126.4      |          2713.4       |   4554.3  |            2880.8       
      f32 B=150, M=256, H=16, K=64   |         6214.3      |          7052.2       |  12943.2  |            7099.2       
      f16 B=64, M=256, H=12, K=64    |          741.2      |           930.1       |   1493.0  |             990.6       
      f32 B=64, M=256, H=12, K=64    |         2144.2      |          2408.5       |   4267.7  |            2433.8       
      f16 B=1, M=4096, H=16, K=40    |        24583.7      |         24224.8       |   4195.2  |           24500.2       
      f32 B=1, M=4096, H=16, K=40    |        72497.9      |         72070.8       |  17744.1  |           72393.0       
      f16 B=1, M=16384, H=16, K=40   |       451481.8      |        439027.7       |           |          451499.9       
      f32 B=1, M=16384, H=16, K=40   |      1169509.1      |       1164880.1       |           |         1169769.3       
      f16 B=256, M=4096, H=16, K=64  |       597391.6      |        625921.0       |           |          610433.2       
      f16 B=16, M=128, H=16, K=16    |           93.1      |           126.7       |    241.2  |             132.3       
      f32 B=16, M=128, H=16, K=16    |          184.1      |           176.5       |    373.8  |             180.7       
      f16 B=16, M=128, H=16, K=32    |          127.9      |           126.3       |    241.4  |             106.7       
      f32 B=16, M=128, H=16, K=32    |          194.1      |           216.6       |    412.7  |             225.8       
      f16 B=16, M=128, H=16, K=64    |          131.4      |           126.8       |    239.8  |             134.5       
      f32 B=16, M=128, H=16, K=64    |          280.4      |           326.0       |    500.0  |             334.0       
      f16 B=16, M=128, H=16, K=128   |          175.6      |           236.1       |    298.8  |             261.1       
      f32 B=16, M=128, H=16, K=128   |          531.8      |           615.8       |    677.2  |             638.0       
      f16 B=16, M=512, H=16, K=16    |          558.2      |           595.0       |   1201.9  |             607.8       
      f32 B=16, M=512, H=16, K=16    |         2146.7      |          2169.9       |   4416.1  |            2200.6       
      f16 B=16, M=512, H=16, K=32    |          653.5      |           732.3       |   1305.1  |             748.5       
      f32 B=16, M=512, H=16, K=32    |         2296.3      |          2373.9       |   4641.3  |            2400.1       
      f16 B=16, M=512, H=16, K=64    |          848.8      |           996.9       |   1544.6  |            1022.5       
      f32 B=16, M=512, H=16, K=64    |         2954.0      |          3117.1       |   5124.7  |            3157.6       
      f16 B=16, M=512, H=16, K=128   |         1735.4      |          1961.1       |   1982.7  |            2056.9       
      f32 B=16, M=512, H=16, K=128   |         6218.7      |          6396.4       |   6094.0  |            6600.3       
      f16 B=16, M=1024, H=16, K=16   |         2236.4      |          2319.4       |   4279.0  |            2331.6       
      f32 B=16, M=1024, H=16, K=16   |         8379.2      |          8363.9       |  16643.9  |            8503.6       
      f16 B=16, M=1024, H=16, K=32   |         2430.8      |          2649.6       |   4496.8  |            2608.7       
      f32 B=16, M=1024, H=16, K=32   |         8864.7      |          8907.8       |  17291.0  |            9074.0       
      f16 B=16, M=1024, H=16, K=64   |         3007.2      |          3351.3       |   4995.5  |            3351.0       
      f32 B=16, M=1024, H=16, K=64   |        11355.4      |         11627.1       |  18707.5  |           11694.3       
      f16 B=16, M=1024, H=16, K=128  |         6296.2      |          6748.7       |   5943.5  |            6967.0       
      f32 B=16, M=1024, H=16, K=128  |        23425.3      |         23360.0       |  21520.6  |           24169.7       
      f16 B=64, M=128, H=16, K=16    |          165.5      |           195.9       |    440.3  |             211.5       
      f32 B=64, M=128, H=16, K=16    |          497.4      |           540.7       |   1270.8  |             550.3       
      f16 B=64, M=128, H=16, K=32    |          210.4      |           274.9       |    544.8  |             298.5       
      f32 B=64, M=128, H=16, K=32    |          604.4      |           696.6       |   1428.3  |             710.9       
      f16 B=64, M=128, H=16, K=64    |          330.4      |           452.3       |    766.0  |             498.1       
      f32 B=64, M=128, H=16, K=64    |          883.4      |          1060.4       |   1745.2  |            1082.2       
      f16 B=64, M=128, H=16, K=128   |          605.5      |           847.8       |   1223.6  |             933.9       
      f32 B=64, M=128, H=16, K=128   |         1847.4      |          2169.7       |   2388.8  |            2236.0       
      f16 B=64, M=512, H=16, K=16    |         2004.7      |          2120.0       |   4487.0  |            2179.4       
      f32 B=64, M=512, H=16, K=16    |         6655.4      |          6818.8       |  16993.8  |            6872.1       
      f16 B=64, M=512, H=16, K=32    |         2379.3      |          2593.1       |   4957.2  |            2704.0       
      f32 B=64, M=512, H=16, K=32    |         7349.4      |          7644.6       |  17852.2  |            7736.2       
      f16 B=64, M=512, H=16, K=64    |         3129.6      |          3616.6       |   5888.8  |            3786.2       
      f32 B=64, M=512, H=16, K=64    |         9432.5      |         10123.9       |  19770.6  |           10178.5       
      f16 B=64, M=512, H=16, K=128   |         6054.1      |          7019.9       |   7712.6  |            7350.2       
      f32 B=64, M=512, H=16, K=128   |        21565.6      |         22281.9       |  23653.0  |           23084.4       
      f16 B=64, M=1024, H=16, K=16   |         7929.4      |          8199.1       |  16876.3  |            8242.5       
      f32 B=64, M=1024, H=16, K=16   |        26135.2      |         26347.9       |  66351.1  |           26639.0       
      f16 B=64, M=1024, H=16, K=32   |         8876.8      |          9450.0       |  17869.4  |            9473.5       
      f32 B=64, M=1024, H=16, K=32   |        27685.3      |         28104.6       |  69105.9  |           28428.7       
      f16 B=64, M=1024, H=16, K=64   |        11198.7      |         12180.5       |  19932.3  |           12543.4       
      f32 B=64, M=1024, H=16, K=64   |        34978.2      |         36239.4       |  74813.7  |           36482.4       
      f16 B=64, M=1024, H=16, K=128  |        21618.9      |         23439.6       |  23741.1  |           24160.1       
      f32 B=64, M=1024, H=16, K=128  |        80785.3      |         81080.8       |  86003.6  |           84132.9       

Times are in microseconds (us).
```
</details>

<details>
<summary>P100/V100 bw (new benchmarks)</summary>

```
[---------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------------------]
                                                         |  48_chunk3_31735f94  |  45_bwpacked_e53c5f3a  |  vanilla   |  47_bwpackedgrad_9bacdf65
1 threads: --------------------------------------------------------------------------------------------------------------------------------------
  (Quadro_GP100)          f16 B=384, M=197, H=1, K=88    |         6846.3       |          7583.8        |    3569.3  |            7599.5        
                          f32 B=384, M=197, H=1, K=88    |         9883.1       |         10107.2        |    4312.8  |           10486.3        
                          f16 B=384, M=197, H=1, K=80    |         6486.4       |          6997.7        |    3418.0  |            7037.3        
                          f32 B=384, M=197, H=1, K=80    |         9330.3       |          9550.6        |    4094.7  |            9893.4        
                          f16 B=384, M=197, H=1, K=64    |         3615.4       |          3930.4        |    2911.0  |            4074.2        
                          f32 B=384, M=197, H=1, K=64    |         6281.4       |          6554.5        |    3431.9  |            6738.1        
                          f16 B=1024, M=197, H=1, K=88   |        17226.8       |         18593.1        |    9733.2  |           18772.9        
                          f32 B=1024, M=197, H=1, K=88   |        26593.3       |         27136.2        |   12033.8  |           28184.2        
                          f16 B=1024, M=197, H=1, K=80   |        16330.1       |         17478.6        |    9270.2  |           17735.3        
                          f32 B=1024, M=197, H=1, K=80   |        25208.9       |         25680.1        |   11224.5  |           26636.1        
                          f16 B=1024, M=197, H=1, K=64   |         8889.1       |          9728.8        |    7646.1  |           10089.7        
                          f32 B=1024, M=197, H=1, K=64   |        16914.7       |         17743.4        |    9383.8  |           18068.4        
                          f16 B=512, M=197, H=1, K=80    |         8227.3       |          8878.4        |    4579.3  |            8953.6        
                          f32 B=512, M=197, H=1, K=80    |        13078.7       |         13346.0        |    5486.4  |           13817.6        
                          f16 B=32, M=197, H=16, K=80    |         8278.9       |          9002.9        |    4816.2  |            9025.6        
                          f32 B=32, M=197, H=16, K=80    |        12913.8       |         13371.2        |    5777.7  |           13667.6        
                          f16 B=32, M=197, H=16, K=64    |         4565.2       |          5000.0        |    4023.4  |            5146.3        
                          f32 B=32, M=197, H=16, K=64    |         8824.0       |          9257.7        |    4797.2  |            9400.5        
                          f16 B=32, M=197, H=16, K=128   |         9770.0       |         10849.7        |    5983.2  |           10932.0        
                          f32 B=32, M=197, H=16, K=128   |        15715.2       |         16559.9        |    7513.6  |           16839.9        
                          f16 B=256, M=197, H=1, K=88    |         5011.2       |          5363.8        |    2444.9  |            5426.0        
                          f32 B=256, M=197, H=1, K=88    |         6918.7       |          7040.8        |    2867.8  |            7303.2        
                          f16 B=16, M=197, H=16, K=88    |         4963.8       |          5343.9        |    2545.2  |            5398.9        
                          f32 B=16, M=197, H=16, K=88    |         6727.9       |          6981.7        |    3040.3  |            7121.2        
                          f16 B=16, M=197, H=16, K=64    |         2586.5       |          2777.1        |    2025.5  |            2905.6        
                          f32 B=16, M=197, H=16, K=64    |         4404.3       |          4607.2        |    2431.1  |            4691.8        
                          f16 B=16, M=197, H=16, K=128   |         5643.2       |          6194.1        |    3016.1  |            6216.3        
                          f32 B=16, M=197, H=16, K=128   |         7887.1       |          8308.3        |    3676.6  |            8456.2        
                          f16 B=1, M=4096, H=160, K=128  |      1087008.7       |       1115355.5        |            |         1091596.8        
                          f32 B=1, M=4096, H=160, K=128  |      1220066.8       |       1223422.8        |            |         1227912.2        
                          f16 B=2, M=4096, H=160, K=128  |      1734244.4       |       1794068.7        |            |         1756266.7        
                          f32 B=2, M=4096, H=160, K=128  |      2437675.5       |       2445780.4        |            |         2451957.5        
                          f16 B=1, M=8192, H=160, K=128  |      4367110.4       |       4466170.9        |            |         4383747.4        
                          f32 B=1, M=8192, H=160, K=128  |      4865732.9       |       4865708.9        |            |         4887066.5        
                          f16 B=2, M=8192, H=160, K=128  |      7002715.1       |       7146077.9        |            |         7033922.8        
                          f16 B=1024, M=82, H=8, K=64    |        23247.5       |         24929.5        |   18047.8  |           26928.2        
                          f32 B=1024, M=82, H=8, K=64    |        46463.2       |         48705.6        |   22797.5  |           50736.3        
                          f16 B=150, M=256, H=16, K=64   |        23467.9       |         25647.3        |   24569.2  |           26841.8        
                          f32 B=150, M=256, H=16, K=64   |        36887.7       |         39698.0        |   32050.2  |           40389.0        
                          f16 B=64, M=256, H=12, K=64    |         7723.7       |          8499.0        |    7702.1  |            8694.9        
                          f32 B=64, M=256, H=12, K=64    |        11992.1       |         12819.9        |    9874.5  |           13107.9        
                          f16 B=1, M=4096, H=16, K=40    |       142655.5       |        142899.7        |   28928.6  |          142922.7        
                          f32 B=1, M=4096, H=16, K=40    |       142626.8       |        142685.3        |   37303.2  |          142541.0        
                          f16 B=1, M=16384, H=16, K=40   |      2274095.0       |       2274882.0        |            |         2275019.9        
                          f32 B=1, M=16384, H=16, K=40   |      2284027.2       |       2279415.7        |            |         2277761.9        
                          f16 B=16, M=128, H=16, K=16    |          513.2       |           547.1        |     571.5  |             570.9        
                          f32 B=16, M=128, H=16, K=16    |          667.4       |           704.3        |     693.1  |             728.0        
                          f16 B=16, M=128, H=16, K=32    |          600.3       |           667.0        |     671.3  |             713.1        
                          f32 B=16, M=128, H=16, K=32    |          823.9       |           888.9        |     823.5  |             937.3        
                          f16 B=16, M=128, H=16, K=64    |          781.0       |           900.6        |     883.1  |             998.9        
                          f32 B=16, M=128, H=16, K=64    |         1173.7       |          1293.8        |    1077.0  |            1393.4        
                          f16 B=16, M=128, H=16, K=128   |         1649.2       |          1877.2        |    1323.2  |            2026.3        
                          f32 B=16, M=128, H=16, K=128   |         2250.5       |          2473.0        |    1654.7  |            2636.6        
                          f16 B=16, M=512, H=16, K=16    |         7709.3       |          7914.6        |    6945.1  |            7928.7        
                          f32 B=16, M=512, H=16, K=16    |         9797.2       |          9950.5        |    8499.4  |           10029.3        
                          f16 B=16, M=512, H=16, K=32    |         8956.9       |          9210.8        |    7517.1  |            9307.0        
                          f32 B=16, M=512, H=16, K=32    |        11480.7       |         11710.9        |    9249.4  |           11884.4        
                          f16 B=16, M=512, H=16, K=64    |        11324.0       |         11829.1        |    8849.5  |           12001.8        
                          f32 B=16, M=512, H=16, K=64    |        15744.1       |         16258.0        |   10954.6  |           16481.1        
                          f16 B=16, M=512, H=16, K=128   |        25320.2       |         26584.0        |   12412.3  |           26725.0        
                          f32 B=16, M=512, H=16, K=128   |        31187.1       |         32290.3        |   15167.5  |           32818.4        
                          f16 B=16, M=1024, H=16, K=16   |        31484.2       |         31601.4        |   26434.6  |           31894.6        
                          f32 B=16, M=1024, H=16, K=16   |        38754.1       |         38900.1        |   32320.0  |           39203.9        
                          f16 B=16, M=1024, H=16, K=32   |        36000.2       |         36672.6        |   28341.4  |           36579.5        
                          f32 B=16, M=1024, H=16, K=32   |        45070.7       |         45262.3        |   34914.2  |           45774.5        
                          f16 B=16, M=1024, H=16, K=64   |        45324.9       |         46540.4        |   32089.9  |           46784.2        
                          f32 B=16, M=1024, H=16, K=64   |        61320.3       |         62411.1        |   39565.0  |           63217.0        
                          f16 B=16, M=1024, H=16, K=128  |       104342.9       |        108469.4        |   43221.9  |          105620.6        
                          f32 B=16, M=1024, H=16, K=128  |       122688.4       |        125050.9        |   51205.7  |          126080.9        
                          f16 B=64, M=128, H=16, K=16    |         1707.9       |          1824.9        |    2106.4  |            1923.2        
                          f32 B=64, M=128, H=16, K=16    |         2487.4       |          2612.5        |    2565.1  |            2707.6        
                          f16 B=64, M=128, H=16, K=32    |         2016.8       |          2254.4        |    2485.4  |            2412.3        
                          f32 B=64, M=128, H=16, K=32    |         3135.8       |          3365.6        |    3063.2  |            3518.5        
                          f16 B=64, M=128, H=16, K=64    |         2700.2       |          3167.0        |    3306.0  |            3478.4        
                          f32 B=64, M=128, H=16, K=64    |         4435.1       |          4944.7        |    4227.6  |            5181.2        
                          f16 B=64, M=128, H=16, K=128   |         5769.1       |          6858.2        |    5299.8  |            7356.1        
                          f32 B=64, M=128, H=16, K=128   |         8577.9       |          9672.0        |    6916.3  |           10093.5        
                          f16 B=64, M=512, H=16, K=16    |        25994.0       |         26782.0        |   27240.9  |           26662.2        
                          f32 B=64, M=512, H=16, K=16    |        36864.9       |         37299.3        |   34159.3  |           37576.7        
                          f16 B=64, M=512, H=16, K=32    |        30680.4       |         32113.8        |   30109.0  |           32419.7        
                          f32 B=64, M=512, H=16, K=32    |        43638.5       |         44557.9        |   37358.5  |           45145.0        
                          f16 B=64, M=512, H=16, K=64    |        39417.5       |         41666.5        |   36004.2  |           42374.9        
                          f32 B=64, M=512, H=16, K=64    |        60049.2       |         63148.0        |   43412.6  |           63286.8        
                          f16 B=64, M=512, H=16, K=128   |        88951.1       |         93087.0        |   51730.1  |           94861.6        
                          f32 B=64, M=512, H=16, K=128   |       119728.7       |        124340.3        |   62413.7  |          126382.2        
                          f16 B=64, M=1024, H=16, K=16   |       108368.3       |        111081.8        |  106479.7  |          108716.1        
                          f32 B=64, M=1024, H=16, K=16   |       145612.0       |        147310.4        |            |          147380.7        
                          f16 B=64, M=1024, H=16, K=32   |       124296.1       |        127366.8        |  113905.0  |          126975.3        
                          f32 B=64, M=1024, H=16, K=32   |       171082.3       |        172539.0        |            |          173893.9        
                          f16 B=64, M=1024, H=16, K=64   |       155116.3       |        160429.2        |  130759.4  |          161834.0        
                          f32 B=64, M=1024, H=16, K=64   |       234356.0       |        239612.2        |            |          239948.3        
                          f16 B=64, M=1024, H=16, K=128  |       349728.3       |        360975.7        |  176158.7  |          371185.2        
                          f32 B=64, M=1024, H=16, K=128  |       468810.0       |        476415.4        |            |          481908.5        
  (Tesla_V100_SXM2_16GB)  f16 B=384, M=197, H=1, K=88    |         1700.3       |          1840.0        |    1375.3  |            1930.9        
                          f32 B=384, M=197, H=1, K=88    |         4456.4       |          4579.3        |    2235.5  |            4708.6        
                          f16 B=384, M=197, H=1, K=80    |         1623.3       |          1719.9        |    1279.5  |            1806.9        
                          f32 B=384, M=197, H=1, K=80    |         4031.2       |          4141.9        |    2149.8  |            4252.6        
                          f16 B=384, M=197, H=1, K=64    |         1092.8       |          1187.0        |    1048.5  |            1237.6        
                          f32 B=384, M=197, H=1, K=64    |         2717.5       |          2918.5        |    1738.5  |            2907.9        
                          f16 B=1024, M=197, H=1, K=88   |         4428.7       |          4906.2        |    3723.7  |            5178.2        
                          f32 B=1024, M=197, H=1, K=88   |        10947.5       |         11362.9        |    6052.5  |           11802.1        
                          f16 B=1024, M=197, H=1, K=80   |         4237.1       |          4491.4        |    3331.7  |            4725.6        
                          f32 B=1024, M=197, H=1, K=80   |         9842.6       |         10159.7        |    5682.4  |           10435.6        
                          f16 B=1024, M=197, H=1, K=64   |         2679.2       |          2927.4        |    2674.4  |            3033.0        
                          f32 B=1024, M=197, H=1, K=64   |         6597.6       |          7154.9        |    4489.7  |            7063.1        
                          f16 B=512, M=197, H=1, K=80    |         2239.5       |          2366.5        |    1684.2  |            2472.0        
                          f32 B=512, M=197, H=1, K=80    |         5362.4       |          5519.6        |    2857.9  |            5651.4        
                          f16 B=32, M=197, H=16, K=80    |         2208.1       |          2380.0        |    1803.4  |            2439.4        
                          f32 B=32, M=197, H=16, K=80    |         5503.6       |          5736.7        |    3017.5  |            5796.2        
                          f16 B=32, M=197, H=16, K=64    |         1493.4       |          1620.6        |    1457.2  |            1678.6        
                          f32 B=32, M=197, H=16, K=64    |         3672.6       |          3941.6        |    2415.0  |            3898.2        
                          f16 B=32, M=197, H=16, K=128   |         2634.3       |          2888.0        |    2215.1  |            2991.5        
                          f32 B=32, M=197, H=16, K=128   |         6811.5       |          7334.0        |    4049.3  |            7261.9        
                          f16 B=256, M=197, H=1, K=88    |         1290.3       |          1382.0        |     944.8  |            1449.4        
                          f32 B=256, M=197, H=1, K=88    |         2965.8       |          3043.2        |    1528.7  |            3137.7        
                          f16 B=16, M=197, H=16, K=88    |         1267.3       |          1357.0        |     970.8  |            1395.5        
                          f32 B=16, M=197, H=16, K=88    |         2879.9       |          3014.7        |    1626.5  |            3054.3        
                          f16 B=16, M=197, H=16, K=64    |          737.3       |           799.8        |     771.3  |             836.9        
                          f32 B=16, M=197, H=16, K=64    |         1879.2       |          2000.9        |    1282.5  |            1994.5        
                          f16 B=16, M=197, H=16, K=128   |         1443.9       |          1570.7        |    1142.2  |            1628.8        
                          f32 B=16, M=197, H=16, K=128   |         3480.5       |          3723.6        |    2027.2  |            3714.6        
                          f16 B=1, M=4096, H=160, K=128  |       150006.2       |        151877.5        |            |          152570.6        
                          f32 B=1, M=4096, H=160, K=128  |       582870.9       |        583519.8        |            |          585570.1        
                          f16 B=2, M=4096, H=160, K=128  |       301231.4       |        304511.7        |            |          305801.2        
                          f32 B=2, M=4096, H=160, K=128  |      1174724.1       |       1172498.4        |            |         1176814.0        
                          f16 B=1, M=8192, H=160, K=128  |       597461.6       |        600463.4        |            |          603066.6        
                          f32 B=1, M=8192, H=160, K=128  |      2333657.8       |       2329212.1        |            |         2339766.1        
                          f16 B=2, M=8192, H=160, K=128  |      1196837.5       |       1206932.4        |            |         1209012.2        
                          f16 B=1024, M=82, H=8, K=64    |         8926.8       |          9723.4        |    5799.4  |           10084.2        
                          f32 B=1024, M=82, H=8, K=64    |        15920.4       |         17434.4        |   11027.0  |           17492.8        
                          f16 B=150, M=256, H=16, K=64   |         5524.2       |          6363.9        |    7557.9  |            6586.2        
                          f32 B=150, M=256, H=16, K=64   |        17506.9       |         18843.5        |   16263.5  |           18988.6        
                          f16 B=64, M=256, H=12, K=64    |         1800.6       |          2050.3        |    2383.4  |            2139.0        
                          f32 B=64, M=256, H=12, K=64    |         5753.6       |          6196.3        |    4971.2  |            6200.0        
                          f16 B=1, M=4096, H=16, K=40    |        47649.5       |         47836.0        |    8368.4  |           47973.6        
                          f32 B=1, M=4096, H=16, K=40    |       111092.1       |        111027.3        |   19475.9  |          111257.8        
                          f16 B=1, M=16384, H=16, K=40   |       765320.2       |        765686.9        |            |          767337.2        
                          f32 B=1, M=16384, H=16, K=40   |      1769169.0       |       1769675.1        |            |         1769371.4        
                          f16 B=16, M=128, H=16, K=16    |          178.9       |           196.8        |     445.9  |             188.3        
                          f32 B=16, M=128, H=16, K=16    |          301.3       |           319.1        |     422.5  |             336.3        
                          f16 B=16, M=128, H=16, K=32    |          174.1       |           174.2        |     394.0  |             179.5        
                          f32 B=16, M=128, H=16, K=32    |          395.7       |           433.2        |     580.0  |             440.4        
                          f16 B=16, M=128, H=16, K=64    |          205.0       |           253.5        |     460.6  |             270.9        
                          f32 B=16, M=128, H=16, K=64    |          573.7       |           639.3        |     598.1  |             656.1        
                          f16 B=16, M=128, H=16, K=128   |          399.5       |           484.3        |     515.2  |             521.8        
                          f32 B=16, M=128, H=16, K=128   |         1126.3       |          1260.8        |    1008.1  |            1282.4        
                          f16 B=16, M=512, H=16, K=16    |         1597.6       |          1627.2        |    1901.1  |            1662.1        
                          f32 B=16, M=512, H=16, K=16    |         4458.5       |          4528.8        |    4232.0  |            4559.4        
                          f16 B=16, M=512, H=16, K=32    |         1819.1       |          1868.7        |    2097.2  |            1945.5        
                          f32 B=16, M=512, H=16, K=32    |         5604.2       |          5757.1        |    4566.4  |            5784.8        
                          f16 B=16, M=512, H=16, K=64    |         2345.5       |          2495.6        |    2558.0  |            2573.2        
                          f32 B=16, M=512, H=16, K=64    |         7778.3       |          8017.1        |    5488.2  |            8083.7        
                          f16 B=16, M=512, H=16, K=128   |         4516.6       |          4821.0        |    3386.7  |            4968.2        
                          f32 B=16, M=512, H=16, K=128   |        15412.7       |         15959.2        |    8865.9  |           16047.5        
                          f16 B=16, M=1024, H=16, K=16   |         6195.9       |          6217.6        |    6995.3  |            6326.4        
                          f32 B=16, M=1024, H=16, K=16   |        18136.2       |         18312.0        |   16088.2  |           18354.1        
                          f16 B=16, M=1024, H=16, K=32   |         7072.8       |          7122.3        |    7406.9  |            7297.7        
                          f32 B=16, M=1024, H=16, K=32   |        22108.2       |         22116.7        |   17112.5  |           22436.8        
                          f16 B=16, M=1024, H=16, K=64   |         8868.0       |          9104.6        |    8627.1  |            9311.8        
                          f32 B=16, M=1024, H=16, K=64   |        30710.5       |         31041.3        |   19860.8  |           31338.1        
                          f16 B=16, M=1024, H=16, K=128  |        17091.8       |         17655.5        |   10548.3  |           18083.8        
                          f32 B=16, M=1024, H=16, K=128  |        60317.8       |         61461.7        |   32919.2  |           61548.8        
                          f16 B=64, M=128, H=16, K=16    |          413.6       |           453.8        |     635.5  |             480.6        
                          f32 B=64, M=128, H=16, K=16    |         1033.8       |          1114.3        |    1238.9  |            1119.5        
                          f16 B=64, M=128, H=16, K=32    |          505.7       |           587.9        |     813.6  |             630.1        
                          f32 B=64, M=128, H=16, K=32    |         1423.0       |          1551.4        |    1533.4  |            1581.8        
                          f16 B=64, M=128, H=16, K=64    |          743.3       |           916.8        |    1187.7  |             976.5        
                          f32 B=64, M=128, H=16, K=64    |         2093.3       |          2384.6        |    2156.3  |            2405.4        
                          f16 B=64, M=128, H=16, K=128   |         1408.2       |          1734.3        |    1918.7  |            1859.6        
                          f32 B=64, M=128, H=16, K=128   |         4125.3       |          4671.4        |    3762.0  |            4717.0        
                          f16 B=64, M=512, H=16, K=16    |         5531.2       |          5643.3        |    7454.4  |            5770.8        
                          f32 B=64, M=512, H=16, K=16    |        16214.0       |         16531.2        |   16661.3  |           16540.8        
                          f16 B=64, M=512, H=16, K=32    |         6495.5       |          6725.2        |    8353.7  |            6941.8        
                          f32 B=64, M=512, H=16, K=32    |        20520.6       |         20941.9        |   18352.4  |           21116.8        
                          f16 B=64, M=512, H=16, K=64    |         8686.1       |          9278.6        |   10343.4  |            9593.2        
                          f32 B=64, M=512, H=16, K=64    |        28891.1       |         30003.0        |   22749.4  |           30139.1        
                          f16 B=64, M=512, H=16, K=128   |        15991.4       |         17412.3        |   14633.0  |           17848.2        
                          f32 B=64, M=512, H=16, K=128   |        57526.8       |         59970.8        |   40089.9  |           60016.9        
                          f16 B=64, M=1024, H=16, K=16   |        21552.8       |         21603.1        |   28447.1  |           22030.0        
                          f32 B=64, M=1024, H=16, K=16   |        65321.2       |         65736.8        |            |           65932.0        
                          f16 B=64, M=1024, H=16, K=32   |        25695.4       |         25905.9        |   30592.1  |           26644.8        
                          f32 B=64, M=1024, H=16, K=32   |        80213.4       |         80446.7        |            |           81363.1        
                          f16 B=64, M=1024, H=16, K=64   |        32465.6       |         33575.1        |   37233.4  |           34370.8        
                          f32 B=64, M=1024, H=16, K=64   |       112996.7       |        115632.0        |            |          115970.8        
                          f16 B=64, M=1024, H=16, K=128  |        60363.5       |         62800.2        |   48883.7  |           64505.1        
                          f32 B=64, M=1024, H=16, K=128  |       225023.4       |        230527.4        |            |          229851.8        

Times are in microseconds (us).
```
</details>

[ghstack-poisoned]
@danthe3rd danthe3rd merged commit ed6a3c6 into gh/danthe3rd/48/base Oct 6, 2022
danthe3rd pushed a commit that referenced this pull request Oct 6, 2022
ghstack-source-id: e18e9b73589eac1e003c0b224bbff03c7fbb6445
Pull Request resolved: #458
@danthe3rd danthe3rd deleted the gh/danthe3rd/48/head branch October 6, 2022 12:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants