Blocksparse switch revisions #342

Merged (3 commits) on Jun 27, 2022

Changes from all commits:
2 changes: 1 addition & 1 deletion BENCHMARKS.md
@@ -129,7 +129,7 @@ __Some results:__
 _Note_: The estimated flops currently miss accounting for many operators, and are almost certainly an undercount. See issue [#154](https://github.com/fairinternal/xformers/issues/154)
 
 
-## Casual Attention Blocksparse Optimization
+## Causal Attention Blocksparse Optimization
 
 FP16 | FP32
 :-------------------------:|:-------------------------:
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Support several initialization options [#312]
 - Conv2DFeedforward feedforward part [#321]
 - VisualAttention [#329]
+- Automatic blocksparse for causal attention [#334]
 
 
 ## [0.0.11] - 2022-05-30
21 changes: 20 additions & 1 deletion tests/test_core_attention.py
@@ -7,10 +7,18 @@
 import torch
 from torch import nn
 
+from xformers import _is_triton_available
 from xformers.components.attention._sputnik_sparse import SparseCS
 from xformers.components.attention.attention_mask import AttentionMask
 from xformers.components.attention.core import scaled_dot_product_attention
 
+if _is_triton_available:
+    from xformers.triton.utils import gpu_capabilities_older_than_70
+
+_is_blocksparse_available = (
+    _is_triton_available and not gpu_capabilities_older_than_70()
+)
+
 _devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
 
 
@@ -112,6 +120,9 @@ def test_amp_attention_sparsecs(device):
     assert r.dtype == expected_device
 
 
+@pytest.mark.skipif(
+    not _is_blocksparse_available, reason="Blocksparse is not available"
+)
 @pytest.mark.parametrize("device", ["cuda"])
 @pytest.mark.parametrize("data_type", [torch.float16, torch.float32])
 def test_switch_blocksparse(device, data_type):
@@ -138,9 +149,14 @@ def test_switch_blocksparse(device, data_type):
     assert r_sparse.dtype == expected_device
 
     if r_custom.dtype == r_att_mask.dtype:
-        assert torch.allclose(r_custom, r_att_mask, atol=1e-6, rtol=1e-3)
+        assert torch.allclose(r_custom, r_att_mask, atol=1e-6, rtol=1e-2)
+    else:  # r_custom fp16, r_att_mask fp32
+        assert torch.allclose(r_custom, r_att_mask.half(), atol=1e-6, rtol=1e-2)
 
 
+@pytest.mark.skipif(
+    not _is_blocksparse_available, reason="Blocksparse is not available"
+)
 @pytest.mark.parametrize("device", ["cuda"])
 def test_switch_blocksparse_dims(device):
     b, s, d, nh = 8, 128, 32, 8
@@ -159,6 +175,9 @@ def test_switch_blocksparse_dims(device):
     assert r.dtype == expected_device
 
 
+@pytest.mark.skipif(
+    not _is_blocksparse_available, reason="Blocksparse is not available"
+)
 @pytest.mark.parametrize("device", ["cuda"])
 @pytest.mark.parametrize("training", [True, False])
 @pytest.mark.parametrize("drop_prob", [0.0, 0.3])
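For reference, a minimal sketch of the code path these tests gate on: a causal `AttentionMask` plus fp16 inputs on a capable GPU should route `scaled_dot_product_attention` through the blocksparse kernel. The tensor shapes and the `AttentionMask.make_causal` call are illustrative assumptions, not taken from this diff.

```python
import torch

from xformers import _is_triton_available
from xformers.components.attention.attention_mask import AttentionMask
from xformers.components.attention.core import scaled_dot_product_attention

if _is_triton_available:
    from xformers.triton.utils import gpu_capabilities_older_than_70

# Same guard as in the test file: Triton importable and a GPU of capability >= 7.0.
_is_blocksparse_available = (
    _is_triton_available and not gpu_capabilities_older_than_70()
)

if _is_blocksparse_available:
    # Hypothetical shapes: 8 batch*head rows, seq_len 128 (divisible by the default
    # block_size of 128), head dim 32, fp16 so the switch condition is met.
    q = torch.randn(8, 128, 32, device="cuda", dtype=torch.float16)
    k, v = q.clone(), q.clone()
    mask = AttentionMask.make_causal(128, 128, device="cuda")  # assumed constructor
    out = scaled_dot_product_attention(q, k, v, mask)
    assert out.shape == (8, 128, 32)
```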
2 changes: 2 additions & 0 deletions xformers/benchmarks/benchmark_causal_blocksparse.py
@@ -122,12 +122,14 @@ def sdp_attention():
         title=f"Causal Blocksparse Runtime FW{bw.upper()} {datatype} Blocksize:{BS}",
         units="runtime in ms",
         dash_key="torch",
+        legend_loc="upper left",
     )
     pretty_plot(
         results_mem,
         title=f"Causal Blocksparse Memory FW{bw.upper()} {datatype} Blocksize:{BS}",
         units="peak memory usage in MB",
         dash_key="torch",
+        legend_loc="upper left",
     )
 
 
6 changes: 4 additions & 2 deletions xformers/benchmarks/utils.py
@@ -55,7 +55,9 @@ def pretty_print(results, title, units):
     print("")
 
 
-def pretty_plot(results, title, units: str, filename=None, dash_key=""):
+def pretty_plot(
+    results, title, units: str, filename=None, dash_key="", legend_loc="bottom_right"
+):
     """Graph out the contents of a dict.
     Dash key means that if the result label has this key, then it will be displayed with a dash"""
 
@@ -86,7 +88,7 @@ def pretty_plot(results, title, units: str, filename=None, dash_key=""):
         plt.plot(list(results.keys()), v)
 
     plt.title(title)
-    plt.legend(list(workloads.keys()), loc="lower right")
+    plt.legend(list(workloads.keys()), loc=legend_loc)
     plt.ylabel(units)
     plt.xticks(rotation=45)
 
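A hedged usage sketch for the new keyword, under the assumption that `results` is the nested dict `pretty_plot` already iterates over ({x-axis label: {workload: value}}) and that `filename` is handled as a plain save path; the numbers are placeholders.

```python
from xformers.benchmarks.utils import pretty_plot

# Placeholder data in the {x_label: {workload_name: value}} shape pretty_plot plots.
results = {
    "seq 512": {"torch - runtime": 1.2, "blocksparse - runtime": 0.8},
    "seq 1024": {"torch - runtime": 4.9, "blocksparse - runtime": 2.6},
}

pretty_plot(
    results,
    title="Causal attention runtime (illustrative)",
    units="runtime in ms",
    dash_key="torch",          # any workload containing "torch" is drawn dashed
    legend_loc="upper left",   # new: pick a corner that does not cover the curves
    filename="causal_runtime.png",
)
```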
59 changes: 34 additions & 25 deletions xformers/components/attention/core.py
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 
+import logging
 import math
 from contextlib import nullcontext
 from functools import lru_cache
@@ -13,14 +14,20 @@
 
 from xformers import _is_sparse_available, _is_triton_available
 from xformers.components.attention.attention_mask import AttentionMask
-from xformers.components.attention.blocksparse import BlockSparseAttention
-from xformers.components.attention.utils import reshape_heads
 
 if _is_sparse_available:
     from ._sputnik_sparse import SparseCS
 
 if _is_triton_available:
     from xformers.triton.softmax import softmax as triton_softmax
+    from xformers.triton.utils import gpu_capabilities_older_than_70
+
+_is_blocksparse_available = (
+    _is_triton_available and not gpu_capabilities_older_than_70()
+)
+
+if _is_blocksparse_available:
+    from xformers.components.attention.blocksparse import BlockSparseAttention
 
 
 def _create_random_sparsity(matrix, sparsity, divisible_by=4):
@@ -215,17 +222,19 @@ def scaled_query_key_softmax(
     return att
 
 
-# 128 is default maxsize
-@lru_cache(maxsize=128)
-def _retrieve_blocksparse(
-    num_heads: int, seq_len: int, block_size: int
-) -> BlockSparseAttention:
-    # Checks if blocksparse object exists in cache
+if _is_blocksparse_available:
+    # 128 is default maxsize
+    @lru_cache(maxsize=128)
+    def _retrieve_blocksparse(
+        num_heads: int, seq_len: int, block_size: int
+    ) -> BlockSparseAttention:
+        # Checks if blocksparse object exists in cache
 
-    blocks = seq_len // block_size
-    print("Made uncached blocksparse")
-    layout_fill = torch.ones((num_heads, blocks, blocks), dtype=torch.long)
-    return BlockSparseAttention(layout=layout_fill, block_size=block_size, causal=True)
+        blocks = seq_len // block_size
+        layout_fill = torch.ones((num_heads, blocks, blocks), dtype=torch.long)
+        return BlockSparseAttention(
+            layout=layout_fill, block_size=block_size, causal=True
+        )
 
 
 def blocksparse_attention(
@@ -266,7 +275,7 @@
 
     # Reshape attention (B, nh, S, hs) back to (N, S, hs)
     if orig_dim == 3:
-        return reshape_heads(att, *att.size())
+        return att.flatten(0, 1)
     return att
 
 
@@ -276,31 +285,31 @@ def scaled_dot_product_attention(
     v: torch.Tensor,
     att_mask: Optional[Union[AttentionMask, "SparseCS", torch.Tensor]],
     dropout: Optional[torch.nn.Module] = None,
-    block_size=128,
+    block_size: int = 128,
 ) -> torch.Tensor:
     autocast_disabled = (
         _is_sparse_available
         and isinstance(att_mask, SparseCS)
         or (att_mask is not None and att_mask.is_sparse)
     )
+    seq_len = q.shape[-2]
 
-    # Check if causal is required but mask is not sparse; if fp16 or under amp context
+    # switch if:
+    # causal is required but mask is not sparse
+    # fp16 or under amp context
+    # sequence length is divisible by block size
+    # same seq len for K and Q
     switch_to_blocksparse = (
-        _is_triton_available
+        _is_blocksparse_available
        and (att_mask is not None and not att_mask.is_sparse)
        and (isinstance(att_mask, AttentionMask) and att_mask.is_causal)
        and (q.dtype == torch.float16 or torch.is_autocast_enabled())
-    )
-
-    # Switch only if sequence length is divisible by block size
-    # Blocksparse requires the same dimensions for K and Q for now
-    seq_len = q.shape[-2]
-    if (
-        switch_to_blocksparse
        and not seq_len % block_size
        and q.shape[-2] == k.shape[-2]
-    ):
-        # print("switching to blocksparse...")
+    )
+
+    if switch_to_blocksparse:
+        logging.info("Switching causal attention to Triton blocksparse...")
         return blocksparse_attention(q, k, v, dropout, block_size)
 
     with torch.cuda.amp.autocast(enabled=False) if autocast_disabled else nullcontext():
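For reference, a small sketch of the caching behaviour the guarded `_retrieve_blocksparse` relies on: `lru_cache` keys on the `(num_heads, seq_len, block_size)` arguments, so the dense causal layout (and the `BlockSparseAttention` module built from it) is constructed once per shape and then reused on every later call. The `_causal_layout` helper below is a stand-in for illustration, not part of the diff.

```python
from functools import lru_cache

import torch


@lru_cache(maxsize=128)  # same default maxsize as in core.py
def _causal_layout(num_heads: int, seq_len: int, block_size: int) -> torch.Tensor:
    # Dense block layout; the causal part is handled by the kernel (causal=True).
    blocks = seq_len // block_size
    return torch.ones((num_heads, blocks, blocks), dtype=torch.long)


a = _causal_layout(8, 1024, 128)
b = _causal_layout(8, 1024, 128)
assert a is b                    # second call is a cache hit, nothing is rebuilt
assert a.shape == (8, 8, 8)      # 1024 // 128 = 8 blocks per side
```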
10 changes: 0 additions & 10 deletions xformers/components/attention/utils.py
@@ -106,13 +106,3 @@ def bool_mask_to_additive(
     mask_ = torch.zeros_like(mask, dtype=dtype)
     mask_[~mask] = float("-inf")
     return mask_
-
-
-# (B, S, D) to (B, S, nh, hs)
-def split_heads(t: torch.Tensor, B: int, nH: int, S: int, Hs: int):
-    return t.view(B, nH, S, Hs)
-
-
-# (B, nh, S, hs) back to (N, S, hs)
-def reshape_heads(t: torch.Tensor, B: int, nH: int, S: int, Hs: int):
-    return t.view(B * nH, S, Hs)
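A quick check of why these helpers could be dropped: for a contiguous `(B, nH, S, Hs)` tensor, `Tensor.flatten(0, 1)` yields the same `(B * nH, S, Hs)` result that `reshape_heads` produced with `.view()`, which is what core.py now does inline. The shapes below are arbitrary.

```python
import torch

att = torch.randn(2, 4, 16, 8)   # (B, nH, S, Hs)
flat = att.flatten(0, 1)         # (B * nH, S, Hs), as used in core.py now
assert flat.shape == (8, 16, 8)
assert torch.equal(flat, att.view(2 * 4, 16, 8))
```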