[fmha/decoding] Microbenchmarks
Results on A100 (B = batch size, Mq = query length, Mkv = KV cache length, Hq / Hkv = query / key-value head counts, K = head dimension):
```
[------------------------------------------------------------------ attn_decodingfw ------------------------------------------------------------------]
                                             |  pytorch  |  optimized[flash-decoding]  |  optimized[triton_splitK]  |  optimized[flash-attention2.0.9]
1 threads: --------------------------------------------------------------------------------------------------------------------------------------------
      B=256 Mq=1 Mkv=256 Hq=16 Hkv=1 K=128   |   1883.4  |             40.4            |            50.3            |                394.8
      B=128 Mq=1 Mkv=512 Hq=16 Hkv=1 K=128   |   1955.9  |             44.2            |            48.4            |                366.5
      B=64 Mq=1 Mkv=1024 Hq=16 Hkv=1 K=128   |   2012.0  |             34.7            |            48.8            |                368.4
      B=32 Mq=1 Mkv=2048 Hq=16 Hkv=1 K=128   |   2101.4  |             33.7            |            47.3            |                351.4
      B=16 Mq=1 Mkv=4096 Hq=16 Hkv=1 K=128   |   2057.9  |             31.9            |            50.0            |                403.2
      B=8 Mq=1 Mkv=8192 Hq=16 Hkv=1 K=128    |   2078.3  |             34.2            |            51.6            |                527.5
      B=4 Mq=1 Mkv=16384 Hq=16 Hkv=1 K=128   |   2135.2  |             37.4            |            47.0            |                581.6
      B=2 Mq=1 Mkv=32768 Hq=16 Hkv=1 K=128   |   2163.6  |             45.0            |            57.4            |               1154.1
      B=1 Mq=1 Mkv=65536 Hq=16 Hkv=1 K=128   |    413.3  |             61.0            |            80.7            |               2299.2
      B=1 Mq=1 Mkv=131072 Hq=16 Hkv=1 K=128  |    803.5  |             81.7            |           144.3            |               4585.0
      B=256 Mq=1 Mkv=256 Hq=16 Hkv=2 K=128   |   3059.6  |            402.6            |            73.7            |                393.8
      B=128 Mq=1 Mkv=512 Hq=16 Hkv=2 K=128   |   3148.5  |            377.0            |            72.0            |                369.3
      B=64 Mq=1 Mkv=1024 Hq=16 Hkv=2 K=128   |   3161.7  |            375.0            |            70.3            |                368.0
      B=32 Mq=1 Mkv=2048 Hq=16 Hkv=2 K=128   |   3157.6  |            363.8            |            70.6            |                354.4
      B=16 Mq=1 Mkv=4096 Hq=16 Hkv=2 K=128   |   3154.0  |            417.2            |            72.3            |                405.0
      B=8 Mq=1 Mkv=8192 Hq=16 Hkv=2 K=128    |   3173.3  |            532.8            |            76.8            |                528.3
      B=4 Mq=1 Mkv=16384 Hq=16 Hkv=2 K=128   |   3222.5  |            195.5            |            78.8            |                582.8
      B=2 Mq=1 Mkv=32768 Hq=16 Hkv=2 K=128   |   3221.3  |            222.1            |            72.1            |               1154.3
      B=1 Mq=1 Mkv=65536 Hq=16 Hkv=2 K=128   |   1333.9  |            222.6            |            96.3            |               2298.2
      B=1 Mq=1 Mkv=131072 Hq=16 Hkv=2 K=128  |   2656.6  |            427.5            |           169.4            |               4583.7

Times are in microseconds (us).
```
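
The decoding shapes here (Mq=1) are memory-bandwidth-bound: every step streams the entire KV cache. As a rough sanity check (our own estimate, not part of the commit): in the B=1 Mkv=131072 Hkv=1 row, the fp16 K and V tensors total 2 · 131072 · 128 · 2 bytes ≈ 64 MiB, so A100 HBM bandwidth (roughly 1.5–2 TB/s) puts a ~35–45 µs floor on the kernel; flash-decoding's 81.7 µs is within about 2× of that bound, while the PyTorch baseline is an order of magnitude off it.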

ghstack-source-id: f3e0817f6e9be418eda7afbf72f1797d33acf60e
Pull Request resolved: https://github.com/fairinternal/xformers/pull/797

__original_commit__ = fairinternal/xformers@381ad8088345b4c61051b6c767597dc3e320076c
danthe3rd authored and xFormers Bot committed Sep 22, 2023
1 parent 0b8c7b6 commit d3656d5
Showing 1 changed file: xformers/benchmarks/benchmark_attn_decoding.py (+159 lines)

```python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import itertools
from typing import Any

import torch
from torch.utils import benchmark
from utils import benchmark_main_helper

import xformers.ops as xops

min_run_time = 0.5
device = torch.device("cuda")


def product_dict(**kwargs):
    keys = kwargs.keys()
    vals = kwargs.values()
    for instance in itertools.product(*vals):
        yield dict(zip(keys, instance))
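
# Not used below (CASES is written out with comprehensions), but handy for
# building sweeps: product_dict(B=[1, 2], K=[128]) yields
# {"B": 1, "K": 128}, then {"B": 2, "K": 128}.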


# Decoding workload: a single query token (Mq=1) attending to a growing KV
# cache; B shrinks as Mkv grows, keeping the total number of KV tokens at
# 2**16 until B bottoms out at 1.
CASES = [
    dict(B=max(1, 2 ** (16 - i)), Mq=1, Mkv=2**i, Hq=16, Hkv=1, K=128)
    for i in range(8, 18)
] + [
    dict(B=max(1, 2 ** (16 - i)), Mq=1, Mkv=2**i, Hq=16, Hkv=2, K=128)
    for i in range(8, 18)
]


def _setup_test(
    functions, fw: bool = False, bw: bool = False, cuda_graph: bool = True, **kwargs
):
    for k, benchmark_cls in functions.items():
        benchmark_object = benchmark_cls(**kwargs, bw=bw)
        label = benchmark_object.label
        label += "fw" if fw else ""
        label += "bw" if bw else ""

        def run_one():
            if fw:
                benchmark_object.fw()
            if bw:
                benchmark_object.bw()

        if cuda_graph:
            # Warm up once, then capture the op in a CUDA graph and time graph
            # replays: this removes CPU launch overhead, which would otherwise
            # dominate these microsecond-scale kernels.
            run_one()
            g = torch.cuda.CUDAGraph()
            with torch.cuda.graph(g):
                run_one()

            def run_one():
                g.replay()

        yield benchmark.Timer(
            stmt="fn()",
            globals={
                "fn": run_one,
            },
            label=label,
            description=k,
            sub_label=benchmark_object.sub_label,
        )


class AttentionDecodingFlashDecoding:
    OP: Any = xops.fmha.flash.FwOp

    def __init__(
        self, B: int, Mq: int, Mkv: int, Hq: int, Hkv: int, K: int, bw: bool
    ) -> None:
        dtype = torch.float16
        self.sub_label = f"B={B} Mq={Mq} Mkv={Mkv} Hq={Hq} Hkv={Hkv} K={K}"
        self.label = "attn_decoding"
        self.shapes = (B, Mq, Mkv, Hq, Hkv, K)

        assert Hkv <= Hq
        assert Hq % Hkv == 0
        # Grouped 5D layout [B, M, G, H, K]: each of the Hkv key-value heads
        # is shared by a group of Hq // Hkv query heads; K/V are materialized
        # once per KV head and expanded (stride 0) across the group dim.
        self.q = torch.randn(
            [B, Mq, Hkv, Hq // Hkv, K], device="cuda", dtype=dtype, requires_grad=bw
        )
        self.k = torch.randn(
            [B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, requires_grad=bw
        ).expand(-1, -1, -1, Hq // Hkv, -1)
        self.v = torch.randn(
            [B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, requires_grad=bw
        ).expand(-1, -1, -1, Hq // Hkv, -1)

        # Slice away degenerate dimensions so ops receive the simplest
        # supported layout (4D BMHK for multi-head / multiquery).
        if Hq == Hkv:
            self.q = self.q[:, :, :, 0]
            self.k = self.k[:, :, :, 0]
            self.v = self.v[:, :, :, 0]
        if Hkv == 1:
            self.q = self.q[:, :, 0]
            self.k = self.k[:, :, 0]
            self.v = self.v[:, :, 0]

    def fw(self) -> None:
        xops.memory_efficient_attention_forward(self.q, self.k, self.v, op=self.OP)


class AttentionDecodingSplitKV(AttentionDecodingFlashDecoding):
    OP = xops.fmha.triton_splitk.FwOp


class AttentionDecodingPyTorchRepeat(AttentionDecodingFlashDecoding):
    def fw(self) -> None:
        B, Mq, Mkv, Hq, Hkv, K = self.shapes
        scale = 1 / K**0.5
        q = self.q.reshape([B, Mq, -1, K]).permute(0, 2, 1, 3)
        k = self.k.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3)
        v = self.v.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3)
        # Scaled dot-product attention: scale the logits before the softmax.
        attn = (q @ k.transpose(-1, -2) * scale).softmax(-1)
        return attn @ v


BENCHMARKS = {
    "pytorch": AttentionDecodingPyTorchRepeat,
    "flash-decoding": AttentionDecodingFlashDecoding,
    "triton_splitK": AttentionDecodingSplitKV,
}


try:
    import flash_attn

    class AttentionDecodingFlashAttention(AttentionDecodingFlashDecoding):
        def fw(self) -> None:
            q, k, v = self.q, self.k, self.v
            if q.ndim == 5:
                # flash_attn takes BMHK tensors: flatten the grouped query
                # heads and pass a single copy of each KV head.
                B, Mq, H1, H2, K = q.shape
                B, Mkv, H1, H2, K = k.shape
                q = q.reshape([B, Mq, H1 * H2, K])
                k = k[:, :, :, 0]
                v = v[:, :, :, 0]
            return flash_attn.flash_attn_func(q, k, v)

    BENCHMARKS[
        f"flash-attention@{flash_attn.__version__}"
    ] = AttentionDecodingFlashAttention
except ImportError:
    pass


def attn_decoding(**kwargs):
    yield from _setup_test(
        **kwargs,
        fw=True,
        cuda_graph=True,
        functions=BENCHMARKS,
    )


benchmark_main_helper(attn_decoding, CASES, min_run_time=min_run_time)

```
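
For a quick smoke test outside the harness, something like the snippet below exercises the same multiquery layout as the Hkv=1 rows (a minimal sketch of ours, not part of the commit; op selection is left to xformers rather than pinned via `op=` as the benchmark classes do). Note that the script itself imports `benchmark_main_helper` from a sibling `utils` module, so it is meant to be run from within `xformers/benchmarks/`.

```python
import torch

import xformers.ops as xops

# Shapes from the sweep's Mkv=8192 case, with B=1 for a single sequence.
B, Mq, Mkv, Hq, K = 1, 1, 8192, 16, 128
dtype = torch.float16
q = torch.randn([B, Mq, Hq, K], device="cuda", dtype=dtype)
# One KV head expanded (stride 0) across the 16 query heads, as in the benchmark.
k = torch.randn([B, Mkv, 1, K], device="cuda", dtype=dtype).expand(-1, -1, Hq, -1)
v = torch.randn([B, Mkv, 1, K], device="cuda", dtype=dtype).expand(-1, -1, Hq, -1)

out = xops.memory_efficient_attention_forward(q, k, v)
print(out.shape)  # torch.Size([1, 1, 16, 128])
```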