Update on "Improve build time by ~30%"
... by reducing the number of ATen imports, and skipping them altogether when building the actual kernels

13 min -> 9 min on Sm61 (CI, does not build flash)

[ghstack-poisoned]
danthe3rd committed Nov 28, 2022
2 parents 80bd22d + dfd494a commit c7f5164
Showing 2 changed files with 115 additions and 59 deletions.
172 changes: 114 additions & 58 deletions tests/test_mem_eff_attention.py
@@ -5,7 +5,8 @@

import math
import random
from typing import Sequence, Type
from dataclasses import dataclass
from typing import Any, Sequence, Type

import pytest
import torch
@@ -167,7 +168,7 @@ def ref_attention(q, k, v, attn_bias=None, drop_mask=None, p=0.0, scale=None):
return attn @ v


def ref_attention_bmhk(q, k, v, attn_bias, scale=None):
def ref_attention_bmhk(q, k, v, attn_bias, scale=None) -> torch.Tensor:
assert q.ndim == 4

def T(t):
@@ -340,6 +341,108 @@ def test_forward(
)


@dataclass
class CuSeqlenInputs:
q: torch.Tensor
k: torch.Tensor
v: torch.Tensor
bias: Sequence[Any]

cu_seqlen_q: torch.Tensor
max_seqlen_q: int
cu_seqlen_k: torch.Tensor
max_seqlen_k: int

def ref_attention(self) -> torch.Tensor:
assert self.q.shape[0] == 1
cu_seqlen_q_cpu = self.cu_seqlen_q.tolist()
cu_seqlen_k_cpu = self.cu_seqlen_k.tolist()
outs = []
for attn_bias, q_start, q_end, k_start, k_end in zip(
self.bias,
cu_seqlen_q_cpu,
cu_seqlen_q_cpu[1:],
cu_seqlen_k_cpu,
cu_seqlen_k_cpu[1:],
):
outs.append(
ref_attention(
self.q[:, q_start:q_end],
self.k[:, k_start:k_end],
self.v[:, k_start:k_end],
attn_bias=attn_bias,
)
)
return torch.cat(outs, dim=1)

@staticmethod
def generate(
B_Mq_Mkv_H_K_Kv, attn_bias_type, dtype, device, op
) -> "CuSeqlenInputs":
batch_size, max_q_len, max_kv_len, num_heads, k, kv = B_Mq_Mkv_H_K_Kv
all_q = []
all_k = []
all_v = []
all_bias = []
cu_seqlen_q = [0]
cu_seqlen_k = [0]
scale = 3
# Reduce batch size to speedup tests
batch_size = min(batch_size, 20)
r = random.Random(max_q_len + k * kv)
torch.manual_seed(r.randint(0, 128))

for batch_id in range(batch_size):
q_len = r.randint(1, max_q_len)
kv_len = r.randint(1, max_kv_len)

all_q.append(
torch.randn((1, q_len, num_heads, k), device=device, dtype=dtype)
* scale
)
all_k.append(
torch.randn((1, kv_len, num_heads, k), device=device, dtype=dtype)
* scale
)
all_v.append(
torch.randn((1, kv_len, num_heads, kv), device=device, dtype=dtype)
* scale
)

if batch_id == 0:
if not op.supports(
xformers.ops.AttentionOpDispatch.from_arguments(
query=all_q[-1], key=all_k[-1], value=all_v[-1]
)
):
pytest.skip("unsupported configuration")

cu_seqlen_q += [cu_seqlen_q[-1] + q_len]
cu_seqlen_k += [cu_seqlen_k[-1] + kv_len]

attn_bias = None
if attn_bias_type is not None:
attn_bias = create_attn_bias(
attn_bias_type,
batch_size=num_heads,
q_len=q_len,
kv_len=kv_len,
dtype=dtype,
device=device,
)
all_bias.append(attn_bias)
return CuSeqlenInputs(
q=torch.cat(all_q, dim=1),
k=torch.cat(all_k, dim=1),
v=torch.cat(all_v, dim=1),
bias=all_bias,
cu_seqlen_q=torch.tensor(cu_seqlen_q, dtype=torch.int32, device=device),
max_seqlen_q=max_q_len,
cu_seqlen_k=torch.tensor(cu_seqlen_k, dtype=torch.int32, device=device),
max_seqlen_k=max_kv_len,
)


@cuda_only
@pytest.mark.parametrize("attn_bias_type", [None, xformers.ops.LowerTriangularMask])
@pytest.mark.parametrize(
@@ -359,70 +462,23 @@ def test_cu_seqlen_forward(
dtype,
):
device = "cuda"
batch_size, max_q_len, max_kv_len, num_heads, k, kv = B_Mq_Mkv_H_K_Kv
op = xformers.ops.MemoryEfficientAttentionCutlassOp
r = random.Random(max_q_len + k * kv)
torch.manual_seed(r.randint(0, 128))

all_q = []
all_k = []
all_v = []
all_o = []
cu_seqlen_q = [0]
cu_seqlen_k = [0]
scale = 3
# Reduce batch size to speedup tests
batch_size = min(batch_size, 20)

for batch_id in range(batch_size):
q_len = r.randint(1, max_q_len)
kv_len = r.randint(1, max_kv_len)

all_q.append(
torch.randn((1, q_len, num_heads, k), device=device, dtype=dtype) * scale
)
all_k.append(
torch.randn((1, kv_len, num_heads, k), device=device, dtype=dtype) * scale
)
all_v.append(
torch.randn((1, kv_len, num_heads, kv), device=device, dtype=dtype) * scale
)

if batch_id == 0:
if not op.supports(
xformers.ops.AttentionOpDispatch.from_arguments(
query=all_q[-1], key=all_k[-1], value=all_v[-1]
)
):
pytest.skip("unsupported configuration")

cu_seqlen_q += [cu_seqlen_q[-1] + q_len]
cu_seqlen_k += [cu_seqlen_k[-1] + kv_len]

attn_bias = None
if attn_bias_type is not None:
attn_bias = create_attn_bias(
attn_bias_type,
batch_size=num_heads,
q_len=q_len,
kv_len=kv_len,
dtype=dtype,
device=device,
)
all_o.append(ref_attention_bmhk(all_q[-1], all_k[-1], all_v[-1], attn_bias))
inputs = CuSeqlenInputs.generate(B_Mq_Mkv_H_K_Kv, attn_bias_type, dtype, device, op)

out, _ = op.FORWARD_OPERATOR(
torch.cat(all_q, dim=1),
torch.cat(all_k, dim=1),
torch.cat(all_v, dim=1),
max_seqlen_q=max_q_len,
cu_seqlens_q=torch.tensor(cu_seqlen_q, dtype=torch.int32, device=device),
cu_seqlens_k=torch.tensor(cu_seqlen_k, dtype=torch.int32, device=device),
inputs.q,
inputs.k,
inputs.v,
max_seqlen_q=inputs.max_seqlen_q,
cu_seqlens_q=inputs.cu_seqlen_q,
cu_seqlens_k=inputs.cu_seqlen_k,
compute_logsumexp=False,
causal=attn_bias_type is xformers.ops.LowerTriangularMask,
scale=None,
)
ref = torch.cat(all_o, dim=1)
ref = inputs.ref_attention()

assert_allclose(
out.float(),
ref,
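Note (not part of the diff): the refactor above packs all batch elements along the sequence dimension and records cumulative offsets in cu_seqlen_q / cu_seqlen_k, which CuSeqlenInputs.ref_attention then uses to slice each sequence back out. The following standalone sketch illustrates that packed layout; the lengths and shapes below are made up for illustration.

# Illustrative sketch: how cumulative seqlen offsets describe a batch of
# variable-length sequences packed along dim 1 of a single tensor.
import torch

seq_lens = [2, 3, 1]                      # per-sample query lengths (assumed)
cu_seqlen = [0]
for n in seq_lens:
    cu_seqlen.append(cu_seqlen[-1] + n)   # cumulative offsets: [0, 2, 5, 6]

num_heads, head_dim = 2, 4                # arbitrary example sizes
packed_q = torch.randn(1, cu_seqlen[-1], num_heads, head_dim)

# Slicing with consecutive offsets recovers each original sequence,
# which is what CuSeqlenInputs.ref_attention does above.
for start, end in zip(cu_seqlen, cu_seqlen[1:]):
    assert packed_q[:, start:end].shape[1] == end - start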
2 changes: 1 addition & 1 deletion xformers/ops/unbind.py
@@ -97,7 +97,7 @@ def forward(ctx, dim: int, *tensors: torch.Tensor):
@classmethod
# type: ignore
def backward(cls, ctx, grad: torch.Tensor):
return None, *(grad.unbind(dim=ctx.dim))
return (None, *grad.unbind(dim=ctx.dim))


def unbind(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, ...]:
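Note (not part of the diff): the unbind.py change only adds parentheses around the returned tuple, so the value is unchanged. One plausible reason, not stated in the commit, is compatibility: an unparenthesized starred expression in a return statement is a SyntaxError before Python 3.8, while the parenthesized form parses on all versions. A minimal sketch of the tuple the backward returns, using a made-up gradient:

# Illustrative sketch: backward returns one gradient per forward input,
# i.e. None for `dim` plus one tensor per unbound input tensor.
import torch

grad = torch.arange(6.0).reshape(2, 3)    # stand-in for the incoming gradient
dim = 0
result = (None, *grad.unbind(dim=dim))

assert result[0] is None
assert len(result) == 1 + grad.shape[dim]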
