Blocksparse switch revisions (#342)
* minor cleanup; updated changelog

* fixed mypy error

* added checking for blocksparse availability

Co-authored-by: Chris Yuan <christopheryuan@learnfair1490.h2.fair>
Co-authored-by: Chris Yuan <christopheryuan@devfair0278.h2.fair>
3 people committed Jun 27, 2022
1 parent e3aa730 commit 12e8abc
Showing 7 changed files with 62 additions and 39 deletions.
2 changes: 1 addition & 1 deletion BENCHMARKS.md
@@ -129,7 +129,7 @@ __Some results:__
_Note_: The estimated flops currently miss accounting for many operators, and are almost certainly an undercount. See issue [#154](https://github.com/fairinternal/xformers/issues/154)


-## Casual Attention Blocksparse Optimization
+## Causal Attention Blocksparse Optimization

FP16 | FP32
:-------------------------:|:-------------------------:
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Support several initialization options [#312]
- Conv2DFeedforward feedforward part [#321]
- VisualAttention [#329]
- Automatic blocksparse for causal attention [#334]


## [0.0.11] - 2022-05-30
21 changes: 20 additions & 1 deletion tests/test_core_attention.py
@@ -7,10 +7,18 @@
import torch
from torch import nn

from xformers import _is_triton_available
from xformers.components.attention._sputnik_sparse import SparseCS
from xformers.components.attention.attention_mask import AttentionMask
from xformers.components.attention.core import scaled_dot_product_attention

if _is_triton_available:
from xformers.triton.utils import gpu_capabilities_older_than_70

_is_blocksparse_available = (
_is_triton_available and not gpu_capabilities_older_than_70()
)

_devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]


@@ -112,6 +120,9 @@ def test_amp_attention_sparsecs(device):
assert r.dtype == expected_device


@pytest.mark.skipif(
not _is_blocksparse_available, reason="Blocksparse is not available"
)
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("data_type", [torch.float16, torch.float32])
def test_switch_blocksparse(device, data_type):
@@ -138,9 +149,14 @@ def test_switch_blocksparse(device, data_type):
assert r_sparse.dtype == expected_device

if r_custom.dtype == r_att_mask.dtype:
-assert torch.allclose(r_custom, r_att_mask, atol=1e-6, rtol=1e-3)
+assert torch.allclose(r_custom, r_att_mask, atol=1e-6, rtol=1e-2)
else: # r_custom fp16, r_att_mask fp32
assert torch.allclose(r_custom, r_att_mask.half(), atol=1e-6, rtol=1e-2)


@pytest.mark.skipif(
not _is_blocksparse_available, reason="Blocksparse is not available"
)
@pytest.mark.parametrize("device", ["cuda"])
def test_switch_blocksparse_dims(device):
b, s, d, nh = 8, 128, 32, 8
@@ -159,6 +175,9 @@ def test_switch_blocksparse_dims(device):
assert r.dtype == expected_device


@pytest.mark.skipif(
not _is_blocksparse_available, reason="Blocksparse is not available"
)
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("training", [True, False])
@pytest.mark.parametrize("drop_prob", [0.0, 0.3])
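For context, the availability gate these tests add can also be packaged as a reusable pytest marker. A minimal sketch, not part of this commit — `requires_blocksparse` and `test_example` are illustrative names:

```python
# Blocksparse needs Triton plus a GPU with compute capability >= 7.0.
# The `and` short-circuits, so gpu_capabilities_older_than_70() is only
# evaluated when the Triton import actually succeeded.
import pytest

from xformers import _is_triton_available

if _is_triton_available:
    from xformers.triton.utils import gpu_capabilities_older_than_70

_is_blocksparse_available = (
    _is_triton_available and not gpu_capabilities_older_than_70()
)

# Wrapping the skipif in a named marker avoids repeating it on every test.
requires_blocksparse = pytest.mark.skipif(
    not _is_blocksparse_available, reason="Blocksparse is not available"
)


@requires_blocksparse
def test_example():
    ...
```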
2 changes: 2 additions & 0 deletions xformers/benchmarks/benchmark_causal_blocksparse.py
@@ -122,12 +122,14 @@ def sdp_attention():
title=f"Causal Blocksparse Runtime FW{bw.upper()} {datatype} Blocksize:{BS}",
units="runtime in ms",
dash_key="torch",
legend_loc="upper left",
)
pretty_plot(
results_mem,
title=f"Causal Blocksparse Memory FW{bw.upper()} {datatype} Blocksize:{BS}",
units="peak memory usage in MB",
dash_key="torch",
legend_loc="upper left",
)


6 changes: 4 additions & 2 deletions xformers/benchmarks/utils.py
@@ -55,7 +55,9 @@ def pretty_print(results, title, units):
print("")


-def pretty_plot(results, title, units: str, filename=None, dash_key=""):
+def pretty_plot(
+    results, title, units: str, filename=None, dash_key="", legend_loc="bottom_right"
+):
"""Graph out the contents of a dict.
Dash key means that if the result label has this key, then it will be displayed with a dash"""

@@ -86,7 +88,7 @@ def pretty_plot(results, title, units: str, filename=None, dash_key=""):
plt.plot(list(results.keys()), v)

plt.title(title)
-plt.legend(list(workloads.keys()), loc="lower right")
+plt.legend(list(workloads.keys()), loc=legend_loc)
plt.ylabel(units)
plt.xticks(rotation=45)

59 changes: 34 additions & 25 deletions xformers/components/attention/core.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.


import logging
import math
from contextlib import nullcontext
from functools import lru_cache
@@ -13,14 +14,20 @@

from xformers import _is_sparse_available, _is_triton_available
from xformers.components.attention.attention_mask import AttentionMask
-from xformers.components.attention.blocksparse import BlockSparseAttention
-from xformers.components.attention.utils import reshape_heads

if _is_sparse_available:
from ._sputnik_sparse import SparseCS

if _is_triton_available:
from xformers.triton.softmax import softmax as triton_softmax
from xformers.triton.utils import gpu_capabilities_older_than_70

_is_blocksparse_available = (
_is_triton_available and not gpu_capabilities_older_than_70()
)

if _is_blocksparse_available:
from xformers.components.attention.blocksparse import BlockSparseAttention


def _create_random_sparsity(matrix, sparsity, divisible_by=4):
@@ -215,17 +222,19 @@ def scaled_query_key_softmax(
return att


-# 128 is default maxsize
-@lru_cache(maxsize=128)
-def _retrieve_blocksparse(
-    num_heads: int, seq_len: int, block_size: int
-) -> BlockSparseAttention:
-    # Checks if blocksparse object exists in cache
+if _is_blocksparse_available:
+    # 128 is default maxsize
+    @lru_cache(maxsize=128)
+    def _retrieve_blocksparse(
+        num_heads: int, seq_len: int, block_size: int
+    ) -> BlockSparseAttention:
+        # Checks if blocksparse object exists in cache

-    blocks = seq_len // block_size
-    print("Made uncached blocksparse")
-    layout_fill = torch.ones((num_heads, blocks, blocks), dtype=torch.long)
-    return BlockSparseAttention(layout=layout_fill, block_size=block_size, causal=True)
+        blocks = seq_len // block_size
+        layout_fill = torch.ones((num_heads, blocks, blocks), dtype=torch.long)
+        return BlockSparseAttention(
+            layout=layout_fill, block_size=block_size, causal=True
+        )


def blocksparse_attention(
@@ -266,7 +275,7 @@ def blocksparse_attention(

# Reshape attention (B, nh, S, hs) back to (N, S, hs)
if orig_dim == 3:
-return reshape_heads(att, *att.size())
+return att.flatten(0, 1)
return att


@@ -276,31 +285,31 @@ def scaled_dot_product_attention(
v: torch.Tensor,
att_mask: Optional[Union[AttentionMask, "SparseCS", torch.Tensor]],
dropout: Optional[torch.nn.Module] = None,
-block_size=128,
+block_size: int = 128,
) -> torch.Tensor:
autocast_disabled = (
_is_sparse_available
and isinstance(att_mask, SparseCS)
or (att_mask is not None and att_mask.is_sparse)
)
+    seq_len = q.shape[-2]

-    # Check if causal is required but mask is not sparse; if fp16 or under amp context
+    # switch if:
+    # causal is required but mask is not sparse
+    # fp16 or under amp context
+    # sequence length is divisible by block size
+    # same seq len for K and Q
    switch_to_blocksparse = (
-        _is_triton_available
+        _is_blocksparse_available
        and (att_mask is not None and not att_mask.is_sparse)
        and (isinstance(att_mask, AttentionMask) and att_mask.is_causal)
        and (q.dtype == torch.float16 or torch.is_autocast_enabled())
-    )
-
-    # Switch only if sequence length is divisible by block size
-    # Blocksparse requires the same dimensions for K and Q for now
-    seq_len = q.shape[-2]
-    if (
-        switch_to_blocksparse
        and not seq_len % block_size
        and q.shape[-2] == k.shape[-2]
-    ):
-        # print("switching to blocksparse...")
+    )
+
+    if switch_to_blocksparse:
+        logging.info("Switching causal attention to Triton blocksparse...")
return blocksparse_attention(q, k, v, dropout, block_size)

with torch.cuda.amp.autocast(enabled=False) if autocast_disabled else nullcontext():
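Putting the pieces of this file together, the sketch below shows when the new switch actually fires. It is not part of the commit: the tensor shapes and the `AttentionMask.make_causal` call are illustrative assumptions, and it presumes a compute-capability >= 7.0 GPU with Triton installed so that `_is_blocksparse_available` is True.

```python
import torch

from xformers.components.attention.attention_mask import AttentionMask
from xformers.components.attention.core import scaled_dot_product_attention

# (batch * heads, seq, head_dim) layout, matching the (N, S, hs) shape that the
# orig_dim == 3 branch above flattens back to; seq_len 256 is divisible by the
# default block_size of 128.
q = torch.rand(8, 256, 32, device="cuda", dtype=torch.float16)
k, v = torch.rand_like(q), torch.rand_like(q)

# A dense causal mask: att_mask.is_sparse is False, att_mask.is_causal is True,
# q is fp16, and K and Q share a sequence length, so every condition in
# switch_to_blocksparse holds and the call is routed to blocksparse_attention.
causal_mask = AttentionMask.make_causal(256, 256, device="cuda")
out = scaled_dot_product_attention(q, k, v, att_mask=causal_mask)

# A second call with the same (num_heads, seq_len, block_size) triple reuses the
# BlockSparseAttention object cached by the lru_cache on _retrieve_blocksparse.
out = scaled_dot_product_attention(q, k, v, att_mask=causal_mask)
```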
10 changes: 0 additions & 10 deletions xformers/components/attention/utils.py
@@ -106,13 +106,3 @@ def bool_mask_to_additive(
mask_ = torch.zeros_like(mask, dtype=dtype)
mask_[~mask] = float("-inf")
return mask_
-
-
-# (B, S, D) to (B, S, nh, hs)
-def split_heads(t: torch.Tensor, B: int, nH: int, S: int, Hs: int):
-    return t.view(B, nH, S, Hs)
-
-
-# (B, nh, S, hs) back to (N, S, hs)
-def reshape_heads(t: torch.Tensor, B: int, nH: int, S: int, Hs: int):
-    return t.view(B * nH, S, Hs)

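The helpers removed above are no longer imported by core.py, whose hunk earlier in this commit replaces the one visible `reshape_heads(att, *att.size())` call with `att.flatten(0, 1)`. A quick sketch, using arbitrary sizes, of why that is a drop-in replacement:

```python
import torch

B, nH, S, Hs = 2, 4, 16, 8
att = torch.rand(B, nH, S, Hs)

via_view = att.view(B * nH, S, Hs)  # what reshape_heads(att, *att.size()) returned
via_flatten = att.flatten(0, 1)     # what core.py does after this commit

assert via_flatten.shape == (B * nH, S, Hs)
assert torch.equal(via_view, via_flatten)
```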