Block mask #3

Merged: 4 commits, Mar 28, 2024

4 changes: 2 additions & 2 deletions .github/workflows/python-app.yml
@@ -27,9 +27,9 @@ jobs:
python -m pip install --upgrade pip
pip install -e .
pip install -e .'[dev]'
- name: Lint with flake8
- name: Lint with ruff
run: |
flake8
ruff check .
- name: Test with pytest
run: |
pytest
9 changes: 5 additions & 4 deletions .pre-commit-config.yaml
@@ -19,8 +19,9 @@ repos:
- ufmt == 2.1.0
- libcst == 1.0.1

- repo: https://github.com/pycqa/flake8
rev: 7.0.0
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.3.0
hooks:
- id: flake8
additional_dependencies: [flake8-pyproject]
# Run the linter.
- id: ruff
48 changes: 40 additions & 8 deletions pyproject.toml
@@ -18,7 +18,7 @@ classifiers = [
]

dependencies = [
"torch >= 2.1.1",
"torch >= 2.2.1",
"scipy >= 1.9.1",
"tqdm >= 4.66",
"tabulate >= 0.8",
@@ -35,8 +35,7 @@ dev = [
"bumpver",
"pip-tools",
"pytest",
"flake8==6.1.0",
"flake8-pyproject",
"ruff==0.3.0",
"jsonargparse",
"docstring-parser"
]
@@ -52,16 +51,49 @@ llama = [
]

# ---------- TOOL CONFIGURATIONS ------------
[tool.flake8]
max-line-length = 99
ignore = ['E231', 'E241', 'E501', 'C408', 'E261', 'E731', 'G004', 'W503', 'E203']
per-file-ignores = [
'__init__.py:F401',

# ---------- RUFF ------------
[tool.ruff]
ignore = ['E231', 'E731']
# Exclude a variety of commonly ignored directories.
exclude = [
".bzr",
".direnv",
".eggs",
".git",
".git-rewrite",
".hg",
".ipynb_checkpoints",
".mypy_cache",
".nox",
".pants.d",
".pyenv",
".pytest_cache",
".pytype",
".ruff_cache",
".svn",
".tox",
".venv",
".vscode",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"node_modules",
"site-packages",
"venv",
]

[tool.ruff.per-file-ignores]
"__init__.py" = ["F401", "F403"]

# ---------- UFMT ------------

[tool.usort]
first_party_detection = false

# ---------- Black ------------
[tool.black]
target-version = ["py38"]
line-length = 99
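
Note on the [tool.ruff.per-file-ignores] entry above: F401 (imported but unused) and F403 (star import) are the codes Ruff would otherwise raise on re-export-style __init__.py files such as the one touched at the bottom of this PR. A minimal sketch of the kind of file the ignore covers; the module names here are illustrative, not taken from this repo:

# Hypothetical __init__.py, shown only to illustrate the per-file-ignores entry.
# Without the ignore, Ruff flags the first import as F401 (imported but unused)
# and the second as F403 (star import); with it, no `# noqa` comments are needed.
from .attention import attention  # re-export; would trigger F401
from .masks import *  # re-export everything; would trigger F403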
140 changes: 115 additions & 25 deletions test/test_flash.py
@@ -1,16 +1,47 @@
import pytest
import torch

from torch.nn.attention import sdpa_kernel, SDPBackend
from transformer_nuggets.flash import attention, BiasMode, build_rel_mask
from transformer_nuggets.flash import attention, BiasMode, build_causal_mask, build_rel_mask


@pytest.mark.parametrize("Z, H, N_CTX, D_HEAD", [(6, 8, 128, 16)])
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("bias_choice", [BiasMode.rel_pos, BiasMode.none, BiasMode.alibi])
def clone_grad_and_reset(tensor):
cloned_grad = tensor.grad.clone()
tensor.grad = None
return cloned_grad


def clone_grad_and_reset_all(*tensors):
return (clone_grad_and_reset(tensor) for tensor in tensors)


def maybe_grab_upper_section(mask, N_CTX, causal):
BLOCK_M = 128
if N_CTX > BLOCK_M and causal:
# Since the kernel will not iterate over all of seq_len_kv when causal,
# we only check the minimum rectangular block
mask = mask[:, :, :, :BLOCK_M]
return mask


def check_bias(bias_choice, causal, attn_bias, mask, N_CTX):
if bias_choice != BiasMode.none:
mask = maybe_grab_upper_section(mask, N_CTX, causal)
attn_bias = maybe_grab_upper_section(attn_bias, N_CTX, causal)
torch.testing.assert_close(attn_bias, mask.to(attn_bias.dtype), atol=4e-2, rtol=0)


@pytest.mark.parametrize("Z, H, N_CTX, D_HEAD", [(1, 8, 128, 16)])
@pytest.mark.parametrize("is_causal", [False])
@pytest.mark.parametrize(
"bias_choice", [BiasMode.rel_pos, BiasMode.none, BiasMode.alibi, BiasMode.causal]
)
@pytest.mark.parametrize("sm_scale", [None, 1])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
def test_op(Z, H, N_CTX, D_HEAD, causal, bias_choice, sm_scale, dtype=torch.float16):
def test_flash_specific_masks(
Z, H, N_CTX, D_HEAD, is_causal, bias_choice, sm_scale, dtype=torch.float16
):
from torch.nn.attention import sdpa_kernel, SDPBackend

torch.manual_seed(20)
q = (
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
@@ -33,41 +64,100 @@ def test_op(Z, H, N_CTX, D_HEAD, causal, bias_choice, sm_scale, dtype=torch.floa
dout = torch.randn_like(q)

# reference implementation
if bias_choice == BiasMode.none:
attn_bias = None
else:
attn_bias = build_rel_mask(N_CTX, N_CTX, H, bias_choice, causal=causal)
is_causal = False
attn_bias = None
if bias_choice in {BiasMode.causal}:
attn_bias = (
build_causal_mask(N_CTX, N_CTX)
.to(device=q.device, dtype=q.dtype)
.expand(Z, H, N_CTX, N_CTX)
)
elif bias_choice in {BiasMode.rel_pos, BiasMode.alibi}:
attn_bias = build_rel_mask(N_CTX, N_CTX, H, bias_choice, causal=is_causal)
attn_bias = attn_bias.expand(Z, H, N_CTX, N_CTX).to(q.device).to(q.dtype)
elif bias_choice == BiasMode.none:
pass
else:
raise ValueError(f"Invalid bias_choice: {bias_choice}")

is_causal = causal if (bias_choice == BiasMode.none) else False
with sdpa_kernel(SDPBackend.MATH):
ref_out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, scale=sm_scale, is_causal=is_causal, attn_mask=attn_bias
)
ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
ref_dq, ref_dk, ref_dv = clone_grad_and_reset_all(q, k, v)
# triton implementation
tri_out, mask = attention(q, k, v, causal, sm_scale, bias_choice, True)
tri_out, mask = attention(q, k, v, is_causal, sm_scale, bias_choice, True)
tri_out.half()
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dq, q.grad = q.grad.clone(), None
# Check attn_bias equivalence
if bias_choice != BiasMode.none:
torch.testing.assert_close(attn_bias, mask.half(), atol=4e-2, rtol=0)
tri_dq, tri_dk, tri_dv = clone_grad_and_reset_all(q, k, v)

# compare
torch.testing.assert_close(ref_out, tri_out, atol=5.5e-2, rtol=0)
check_bias(bias_choice, is_causal, attn_bias, mask, N_CTX)

torch.testing.assert_close(ref_out, tri_out, atol=5.8e-2, rtol=0)
if bias_choice != BiasMode.none:
fudge_factor = 6.1
else:
fudge_factor = 1
atol = 1e-2 * fudge_factor
if bias_choice == BiasMode.rel_pos and not causal:
atol *= 3
atol = 2e-2 * fudge_factor
if bias_choice == BiasMode.rel_pos and not is_causal:
atol *= 4.5
torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0)
torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0)
torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0)


@pytest.mark.xfail(reason="This test is failing due to a bug in the implementation")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
def test_flash_masked_block(dtype=torch.float16):
from torch.nn.attention import sdpa_kernel, SDPBackend

torch.manual_seed(20)
Z, H, N_CTX, D_HEAD = (6, 8, 256, 16)
q = (
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
.normal_(mean=0.0, std=0.5)
.requires_grad_()
)
k = (
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
.normal_(mean=0.0, std=0.5)
.requires_grad_()
)
v = (
torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
.normal_(mean=0.0, std=0.5)
.requires_grad_()
)

sm_scale = 1 / (D_HEAD**0.5)

temp_mask = torch.ones((Z, H, N_CTX, N_CTX)).tril_(-1).bool()
ref_mask = torch.zeros_like(temp_mask, dtype=torch.float32)
ref_mask.masked_fill_(temp_mask, float("-inf"))
ref_mask = ref_mask.to(q.device).to(q.dtype)
dout = torch.randn_like(q)
with sdpa_kernel(SDPBackend.MATH):
ref_out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, scale=sm_scale, is_causal=False, attn_mask=ref_mask
)

ref_out.backward(dout)
ref_dq, ref_dk, ref_dv = clone_grad_and_reset_all(q, k, v)

tri_out, mask = attention(q, k, v, False, sm_scale, BiasMode.inverse_causal, True) # type: ignore

tri_out.half()
tri_out.backward(dout)
tri_dq, tri_dk, tri_dv = clone_grad_and_reset_all(q, k, v)
# Check attn_bias equivalence
atol = 2e-2 * 6
# compare
check_bias(BiasMode.inverse_causal, False, ref_mask, mask, N_CTX)

torch.testing.assert_close(ref_out, tri_out, atol=5.8e-2, rtol=0)

torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0)
torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0)
torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0)
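
For readers skimming the new test_flash_masked_block test above: the reference mask it builds blocks everything strictly below the diagonal, i.e. the opposite of a causal mask. A standalone sketch of that construction, reduced to a tiny sequence length and using only stock PyTorch (nothing from this repo):

import torch
import torch.nn.functional as F

N_CTX = 4
# True strictly below the diagonal -> those positions get -inf in the additive mask,
# so each query may only attend to the diagonal and the positions above it.
blocked = torch.ones(N_CTX, N_CTX).tril_(-1).bool()
bias = torch.zeros(N_CTX, N_CTX)
bias.masked_fill_(blocked, float("-inf"))

q = k = v = torch.randn(1, 1, N_CTX, 8)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)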
3 changes: 2 additions & 1 deletion transformer_nuggets/flash/__init__.py
@@ -1 +1,2 @@
from transformer_nuggets.flash.flash_attention import * # noqa: F403
from transformer_nuggets.flash.flash_attention import *
from transformer_nuggets.flash.masks import *
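
With the masks module re-exported from the package root, the names exercised in the tests can be imported directly from transformer_nuggets.flash. A hedged usage sketch: the positional argument order and the (output, mask) return pair are inferred from the calls in test_flash.py rather than from separate documentation, and it assumes a CUDA device is available:

import torch
from transformer_nuggets.flash import attention, BiasMode

q = torch.randn(1, 8, 128, 16, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

# q, k, v, is_causal, sm_scale (None -> default), bias mode, return the materialized mask
out, mask = attention(q, k, v, False, None, BiasMode.alibi, True)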