Flash Attention V2 w/ arbitrary attention bias (#1)
* rel position + causal

* we are getting greener

* damn

* format the tests

* cleanup device handling

* works but I want to benchmark myself

* write pytorch run

* why are the timers so different

* add profile and apply black

* fix arg parsing

* flashv2

* I like mine more

* bingo
drisspg authored Jul 18, 2023
1 parent 9f14785 commit 44ce5f3
Showing 6 changed files with 767 additions and 1 deletion.
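For orientation before the per-file diffs, here is a condensed sketch of the comparison the new benchmark performs: the Triton kernel with the bias fused in, versus torch.nn.functional.scaled_dot_product_attention fed an equivalent explicit bias. The signatures follow the call sites in benchmarks/flash.py and test/test_flash.py below; the shapes and the ALiBi choice are illustrative, not taken from the commit.

import torch
from torch.nn.functional import scaled_dot_product_attention
from transformer_nuggets.flash import BiasMode, attention, build_alibi_mask

# Illustrative shapes only; requires a CUDA GPU with triton installed.
bsz, heads, seq_len, head_dim = 4, 16, 1024, 64
q, k, v = (
    torch.randn(bsz, heads, seq_len, head_dim, dtype=torch.float16, device="cuda")
    for _ in range(3)
)

# PyTorch reference path: materialize the ALiBi bias and pass it as attn_mask
# (causality is folded into the bias, so is_causal stays False).
bias = build_alibi_mask(seq_len, seq_len, heads, scale=None, causal=True)
bias = bias.expand(bsz, heads, seq_len, seq_len).to("cuda", torch.float16)
ref_out = scaled_dot_product_attention(q, k, v, attn_mask=bias, scale=1)

# Triton path: the same bias is computed inside the kernel, selected via BiasMode.
tri_out, _ = attention(q, k, v, True, 1, BiasMode.alibi)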
248 changes: 248 additions & 0 deletions benchmarks/flash.py
@@ -0,0 +1,248 @@
import argparse
import csv
import enum
import itertools
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import List, Optional, Tuple

import torch
import triton
from torch.nn.functional import scaled_dot_product_attention
from tqdm import tqdm

from transformer_nuggets.flash import BiasMode, attention, build_alibi_mask
from transformer_nuggets.utils import benchmark_torch_function_in_microseconds
import transformer_nuggets.utils as utils

device = torch.device("cuda")


def build_mask(bias_choice, batch, num_heads, seq_len, causal, dtype):
if bias_choice == BiasMode.rel_pos:
attn_bias = build_alibi_mask(seq_len, seq_len, num_heads, scale=1, causal=causal)
attn_bias = attn_bias.expand(batch, num_heads, seq_len, seq_len).to(device).to(dtype)
elif bias_choice == BiasMode.alibi:
attn_bias = build_alibi_mask(seq_len, seq_len, num_heads, scale=None, causal=causal)
attn_bias = attn_bias.expand(batch, num_heads, seq_len, seq_len).to(device).to(dtype)
    elif bias_choice == BiasMode.none:
        attn_bias = None
    else:
        raise ValueError(f"Invalid bias_choice: {bias_choice}")
    return attn_bias


@dataclass
class ExperimentConfig:
bsz: int
num_heads: int
seqlen: int
head_dim: int
bias_choice: BiasMode
causal: bool
dtype: torch.dtype
direction: str


@dataclass
class ExperimentResult:
triton_time: float
pytorch_time: float


def get_input(
config: ExperimentConfig,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
q = torch.randn(
(config.bsz, config.num_heads, config.seqlen, config.head_dim),
dtype=config.dtype,
device=device,
requires_grad=True,
)
k = torch.randn(
(config.bsz, config.num_heads, config.seqlen, config.head_dim),
dtype=config.dtype,
device=device,
requires_grad=True,
)
v = torch.randn(
(config.bsz, config.num_heads, config.seqlen, config.head_dim),
dtype=config.dtype,
device=device,
requires_grad=True,
)
if config.bias_choice != BiasMode.none and config.seqlen < 8192:
mask = build_mask(
config.bias_choice,
config.bsz,
config.num_heads,
config.seqlen,
config.causal,
config.dtype,
)
return q, k, v, mask
else:
return q, k, v, None


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
q, k, v, mask = get_input(config)
causal = config.causal
sm_scale = 1
bias_choice = config.bias_choice
is_causal = causal if (bias_choice == BiasMode.none) else False
if config.direction == "fwd":
if config.seqlen >= 8192 and config.bias_choice != BiasMode.none:
# Skip PyTorch for large seq_len because of OOM
pytorch_time = float("nan")
else:
pytorch_time = benchmark_torch_function_in_microseconds(
scaled_dot_product_attention,
q,
k,
v,
is_causal=is_causal,
attn_mask=mask,
scale=sm_scale,
)
triton_time = benchmark_torch_function_in_microseconds(
attention, q, k, v, causal, sm_scale, bias_choice
)

elif config.direction == "bwd":
out_triton, _ = attention(q, k, v, causal, sm_scale, bias_choice)
dOut = torch.randn_like(out_triton)
triton_time = benchmark_torch_function_in_microseconds(
out_triton.backward, dOut, retain_graph=True
)
if config.seqlen >= 8192 and config.bias_choice != BiasMode.none:
# Skip PyTorch for large seq_len because of OOM
pytorch_time = float("nan")
else:
                out_torch = scaled_dot_product_attention(
                    q, k, v, is_causal=is_causal, attn_mask=mask, scale=sm_scale
                )
pytorch_time = benchmark_torch_function_in_microseconds(
out_torch.backward, dOut, retain_graph=True
)
else:
raise ValueError("Invalid direction")

return ExperimentResult(triton_time, pytorch_time)


class KernelChoice(enum.Enum):
triton = "triton"
torch = "torch"


def profile_experiment(
    kernel: KernelChoice, config: ExperimentConfig, profile_config: utils.ProfileConfig
) -> None:
    q, k, v, mask = get_input(config)
    sm_scale = 1
    causal = config.causal
    bias_choice = config.bias_choice
    is_causal = causal if (bias_choice == BiasMode.none) else False
    dOut = torch.randn_like(q)
    if kernel == KernelChoice.torch:
        fn = lambda: scaled_dot_product_attention(
            q, k, v, attn_mask=mask, is_causal=is_causal, scale=sm_scale
        ).backward(dOut, retain_graph=True)
    else:
        # attention returns (out, mask); backward runs on the output only
        fn = lambda: attention(q, k, v, causal, sm_scale, bias_choice)[0].backward(
            dOut, retain_graph=True
        )
    utils.profile_function(profile_config, fn)

def gen_configs() -> List[ExperimentConfig]:
seqlens = [512, 1024, 2048, 4096, 8192, 16384]
head_dim = [64, 128]
bias_choices = [BiasMode.none, BiasMode.rel_pos, BiasMode.alibi]
causal = [True, False]
dtypes = [torch.float16]
directions = ["fwd", "bwd"]
configs = []
def get_bsz_num_heads(hidden_dim, seq_len, head_dim, max_tokens=2**14):
# Default max_tokens = 2**14 = 16384
assert hidden_dim % head_dim == 0, "hidden_dim must be divisible by head_dim"
assert max_tokens % seq_len == 0, "max_tokens must be divisible by seq_len"
num_heads = hidden_dim / head_dim
batch_size = max_tokens / seq_len
return int(batch_size), int(num_heads)

for item in itertools.product(
seqlens, head_dim, bias_choices, causal, dtypes, directions
):
        # hidden_dim = 2048, chosen from the FlashAttention-2 paper's benchmark setup
bsz, num_heads = get_bsz_num_heads(2048, *item[:2])
configs.append(ExperimentConfig(bsz, num_heads, *item))
return configs

def main(output_file: Optional[Path], profile_path: Optional[Path]):
if output_file is not None:
configs = gen_configs()
results = []
for experiment_config in tqdm(configs, unit="Experiment"):
experiment_result = run_experiment(experiment_config)
merged = asdict(experiment_config) | asdict(experiment_result)
results.append(merged)

print(f"Writing results to {output_path}")
with open(output_path, "w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=results[0].keys())
writer.writeheader()
writer.writerows(results)
else:
print("No output file specified, skipping experiment!")

if profile_path is not None:
if not profile_path.suffix:
profile_path = profile_path.with_suffix(".json")
print(f"Writing profile to {profile_path}")
# Kernel Choice and Experiment Config
kernel_choice = KernelChoice.triton
experiment_config = ExperimentConfig(
4, 32, 4096, 64, BiasMode.none, False, torch.float16, "fwd"
)

profile_config = utils.ProfileConfig(
str(profile_path),
name=f"sdpa-{kernel_choice.value}",
iters=5,
warmup_iters=3,
sync=True,
)
profile_experiment(kernel_choice, experiment_config, profile_config)


if __name__ == "__main__":
"""Sample usage:
# Running sweep
python benchmarks/flash.py -o benchmarks/data/flash_attention_sweep.csv
# only works on post-Ampere GPUs right now
# bench_flash_attention.run(save_path=None, print_data=True)
"""
parser = argparse.ArgumentParser(description="Run experiments and output results to file")
parser.add_argument(
"-o",
"--output_file",
type=str,
help="Path to write out CSV file for experiment results.",
default=None,
)
parser.add_argument(
"-p",
"--profile_path",
type=str,
help="Path to write out json chrome trace file for an experiment.",
default=None,
)
args = parser.parse_args()
output_path = None
profile_path = None
if args.output_file is not None:
output_path = Path(args.output_file)
if args.profile_path is not None:
profile_path = Path(args.profile_path)
if output_path is None and profile_path is None:
raise ValueError("Must specify at least one of output_file or profile_path")
main(output_path, profile_path)
1 change: 1 addition & 0 deletions pyproject.toml
@@ -26,6 +26,7 @@ dependencies = [
[project.optional-dependencies]
dev = ["black", "bumpver", "isort", "pip-tools", "pytest"]
qlora = ['bitsandbytes']
flash = ['triton']

[tool.black]
line-length = 99
74 changes: 74 additions & 0 deletions test/test_flash.py
@@ -0,0 +1,74 @@
import torch
import pytest
from transformer_nuggets.flash import BiasMode, build_alibi_mask, attention


@pytest.mark.parametrize("Z, H, N_CTX, D_HEAD", [(6, 8, 128, 16)])
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("bias_choice", [BiasMode.rel_pos, BiasMode.none, BiasMode.alibi])
def test_op(Z, H, N_CTX, D_HEAD, causal, bias_choice, dtype=torch.float16):
torch.manual_seed(20)
q = (
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
.normal_(mean=0.0, std=0.5)
.requires_grad_()
)
k = (
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
.normal_(mean=0.0, std=0.5)
.requires_grad_()
)
v = (
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
.normal_(mean=0.0, std=0.5)
.requires_grad_()
)

sm_scale = 1
dout = torch.randn_like(q)

# reference implementation
if bias_choice == BiasMode.rel_pos:
attn_bias = build_alibi_mask(N_CTX, N_CTX, H, scale=1, causal=causal)
attn_bias = attn_bias.expand(Z, H, N_CTX, N_CTX).to(q.device).to(q.dtype)
elif bias_choice == BiasMode.alibi:
attn_bias = build_alibi_mask(N_CTX, N_CTX, H, scale=None, causal=causal)
attn_bias = attn_bias.expand(Z, H, N_CTX, N_CTX).to(q.device).to(q.dtype)
elif bias_choice == BiasMode.none:
attn_bias = None
is_causal = causal if (bias_choice == BiasMode.none) else False
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False):
ref_out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, scale=sm_scale, is_causal=is_causal, attn_mask=attn_bias
)
ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
# triton implementation
tri_out, mask = attention(q, k, v, causal, sm_scale, bias_choice, True)
tri_out.half()
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dq, q.grad = q.grad.clone(), None
# Check attn_bias equivalence
if bias_choice != BiasMode.none:
torch.testing.assert_close(attn_bias, mask.half(), atol=4e-2, rtol=0)

# compare
torch.testing.assert_close(ref_out, tri_out, atol=4e-2, rtol=0)
if bias_choice != BiasMode.none:
fudge_factor = 5
else:
fudge_factor = 1
atol = 1e-2 * fudge_factor
if bias_choice == BiasMode.rel_pos and not causal:
atol *= 3
torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0)
torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0)
torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0)


if __name__ == "__main__":
pytest.main([__file__])
1 change: 1 addition & 0 deletions transformer_nuggets/flash/__init__.py
@@ -0,0 +1 @@
from transformer_nuggets.flash.flash_attention import *
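The star import above re-exports the names used throughout this commit (attention, BiasMode, build_alibi_mask), so they can be imported straight from transformer_nuggets.flash. A minimal backward-pass smoke test, assuming the (out, mask) return value that the benchmark and test call sites rely on; shapes are arbitrary:

import torch
from transformer_nuggets.flash import BiasMode, attention

q = torch.randn(2, 8, 128, 64, dtype=torch.float16, device="cuda", requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

# causal=True, sm_scale=1, relative-position bias fused into the kernel
out, _ = attention(q, k, v, True, 1, BiasMode.rel_pos)
out.backward(torch.randn_like(out))
assert all(t.grad is not None for t in (q, k, v))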