diff --git a/BENCHMARKS.md b/BENCHMARKS.md
index ef04add59d..65101d8459 100644
--- a/BENCHMARKS.md
+++ b/BENCHMARKS.md
@@ -129,7 +129,7 @@ __Some results:__
 
 _Note_: The estimated flops currently miss accounting for many operators, and are almost certainly an undercount. See issue [#154](https://github.com/fairinternal/xformers/issues/154)
 
-## Casual Attention Blocksparse Optimization
+## Causal Attention Blocksparse Optimization
 
 FP16 | FP32
 :-------------------------:|:-------------------------:
diff --git a/CHANGELOG.md b/CHANGELOG.md
index df328734bd..fe2c87b01d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Support several initialization options [#312]
 - Conv2DFeedforward feedforward part [#321]
 - VisualAttention [#329]
+- Automatic blocksparse for causal attention [#334]
 
 ## [0.0.11] - 2022-05-30
 
diff --git a/tests/test_core_attention.py b/tests/test_core_attention.py
index b1cc306fed..26c71bb05e 100644
--- a/tests/test_core_attention.py
+++ b/tests/test_core_attention.py
@@ -7,10 +7,18 @@
 import torch
 from torch import nn
 
+from xformers import _is_triton_available
 from xformers.components.attention._sputnik_sparse import SparseCS
 from xformers.components.attention.attention_mask import AttentionMask
 from xformers.components.attention.core import scaled_dot_product_attention
 
+if _is_triton_available:
+    from xformers.triton.utils import gpu_capabilities_older_than_70
+
+_is_blocksparse_available = (
+    _is_triton_available and not gpu_capabilities_older_than_70()
+)
+
 _devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
 
 
@@ -112,6 +120,9 @@ def test_amp_attention_sparsecs(device):
     assert r.dtype == expected_device
 
 
+@pytest.mark.skipif(
+    not _is_blocksparse_available, reason="Blocksparse is not available"
+)
 @pytest.mark.parametrize("device", ["cuda"])
 @pytest.mark.parametrize("data_type", [torch.float16, torch.float32])
 def test_switch_blocksparse(device, data_type):
@@ -138,9 +149,14 @@
     assert r_sparse.dtype == expected_device
 
     if r_custom.dtype == r_att_mask.dtype:
-        assert torch.allclose(r_custom, r_att_mask, atol=1e-6, rtol=1e-3)
+        assert torch.allclose(r_custom, r_att_mask, atol=1e-6, rtol=1e-2)
+    else:  # r_custom fp16, r_att_mask fp32
+        assert torch.allclose(r_custom, r_att_mask.half(), atol=1e-6, rtol=1e-2)
 
 
+@pytest.mark.skipif(
+    not _is_blocksparse_available, reason="Blocksparse is not available"
+)
 @pytest.mark.parametrize("device", ["cuda"])
 def test_switch_blocksparse_dims(device):
     b, s, d, nh = 8, 128, 32, 8
@@ -159,6 +175,9 @@
     assert r.dtype == expected_device
 
 
+@pytest.mark.skipif(
+    not _is_blocksparse_available, reason="Blocksparse is not available"
+)
 @pytest.mark.parametrize("device", ["cuda"])
 @pytest.mark.parametrize("training", [True, False])
 @pytest.mark.parametrize("drop_prob", [0.0, 0.3])
diff --git a/xformers/benchmarks/benchmark_causal_blocksparse.py b/xformers/benchmarks/benchmark_causal_blocksparse.py
index 05d630ba91..70a5e499b9 100644
--- a/xformers/benchmarks/benchmark_causal_blocksparse.py
+++ b/xformers/benchmarks/benchmark_causal_blocksparse.py
@@ -122,12 +122,14 @@ def sdp_attention():
         title=f"Causal Blocksparse Runtime FW{bw.upper()} {datatype} Blocksize:{BS}",
         units="runtime in ms",
         dash_key="torch",
+        legend_loc="upper left",
     )
 
     pretty_plot(
         results_mem,
         title=f"Causal Blocksparse Memory FW{bw.upper()} {datatype} Blocksize:{BS}",
         units="peak memory usage in MB",
         dash_key="torch",
+        legend_loc="upper left",
     )
 
diff --git a/xformers/benchmarks/utils.py b/xformers/benchmarks/utils.py
index 8b5838fd44..37c3be02b2 100644
--- a/xformers/benchmarks/utils.py
+++ b/xformers/benchmarks/utils.py
@@ -55,7 +55,9 @@ def pretty_print(results, title, units):
     print("")
 
 
-def pretty_plot(results, title, units: str, filename=None, dash_key=""):
+def pretty_plot(
+    results, title, units: str, filename=None, dash_key="", legend_loc="lower right"
+):
     """Graph out the contents of a dict.
     Dash key means that if the result label has this key, then it will be displayed with a dash"""
 
@@ -86,7 +88,7 @@ def pretty_plot(results, title, units: str, filename=None, dash_key=""):
         plt.plot(list(results.keys()), v)
 
     plt.title(title)
-    plt.legend(list(workloads.keys()), loc="lower right")
+    plt.legend(list(workloads.keys()), loc=legend_loc)
     plt.ylabel(units)
     plt.xticks(rotation=45)
diff --git a/xformers/components/attention/core.py b/xformers/components/attention/core.py
index 0dd457466c..9c8e47404b 100644
--- a/xformers/components/attention/core.py
+++ b/xformers/components/attention/core.py
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 
+import logging
 import math
 from contextlib import nullcontext
 from functools import lru_cache
@@ -13,14 +14,20 @@
 
 from xformers import _is_sparse_available, _is_triton_available
 from xformers.components.attention.attention_mask import AttentionMask
-from xformers.components.attention.blocksparse import BlockSparseAttention
-from xformers.components.attention.utils import reshape_heads
 
 if _is_sparse_available:
     from ._sputnik_sparse import SparseCS
 
 if _is_triton_available:
     from xformers.triton.softmax import softmax as triton_softmax
+    from xformers.triton.utils import gpu_capabilities_older_than_70
+
+_is_blocksparse_available = (
+    _is_triton_available and not gpu_capabilities_older_than_70()
+)
+
+if _is_blocksparse_available:
+    from xformers.components.attention.blocksparse import BlockSparseAttention
 
 
 def _create_random_sparsity(matrix, sparsity, divisible_by=4):
@@ -215,17 +222,19 @@ def scaled_query_key_softmax(
     return att
 
 
-# 128 is default maxsize
-@lru_cache(maxsize=128)
-def _retrieve_blocksparse(
-    num_heads: int, seq_len: int, block_size: int
-) -> BlockSparseAttention:
-    # Checks if blocksparse object exists in cache
+if _is_blocksparse_available:
+    # 128 is default maxsize
+    @lru_cache(maxsize=128)
+    def _retrieve_blocksparse(
+        num_heads: int, seq_len: int, block_size: int
+    ) -> BlockSparseAttention:
+        # Checks if blocksparse object exists in cache
 
-    blocks = seq_len // block_size
-    print("Made uncached blocksparse")
-    layout_fill = torch.ones((num_heads, blocks, blocks), dtype=torch.long)
-    return BlockSparseAttention(layout=layout_fill, block_size=block_size, causal=True)
+        blocks = seq_len // block_size
+        layout_fill = torch.ones((num_heads, blocks, blocks), dtype=torch.long)
+        return BlockSparseAttention(
+            layout=layout_fill, block_size=block_size, causal=True
+        )
 
 
 def blocksparse_attention(
@@ -266,7 +275,7 @@
 
     # Reshape attention (B, nh, S, hs) back to (N, S, hs)
     if orig_dim == 3:
-        return reshape_heads(att, *att.size())
+        return att.flatten(0, 1)
 
     return att
 
@@ -276,31 +285,33 @@ def scaled_dot_product_attention(
     v: torch.Tensor,
     att_mask: Optional[Union[AttentionMask, "SparseCS", torch.Tensor]],
     dropout: Optional[torch.nn.Module] = None,
-    block_size=128,
+    block_size: int = 128,
 ) -> torch.Tensor:
     autocast_disabled = (
         _is_sparse_available
         and isinstance(att_mask, SparseCS)
         or (att_mask is not None and att_mask.is_sparse)
     )
+    seq_len = q.shape[-2]
 
-    # Check if causal is required but mask is not sparse; if fp16 or under amp context
+    # switch if:
+    #   causal is required but mask is not sparse
+    #   fp16 or under amp context
+    #   sequence length is divisible by block size
+    #   same seq len for K and Q
     switch_to_blocksparse = (
-        _is_triton_available
+        _is_blocksparse_available
         and (att_mask is not None and not att_mask.is_sparse)
        and (isinstance(att_mask, AttentionMask) and att_mask.is_causal)
        and (q.dtype == torch.float16 or torch.is_autocast_enabled())
+        and not seq_len % block_size
+        and q.shape[-2] == k.shape[-2]
     )
-
-    # Switch only if sequence length is divisible by block size
-    # Blocksparse requires the same dimensions for K and Q for now
-    seq_len = q.shape[-2]
-    if (
-        switch_to_blocksparse and not seq_len % block_size and q.shape[-2] == k.shape[-2]
-    ):
-        # print("switching to blocksparse...")
+
+    if switch_to_blocksparse:
+        logging.info("Switching causal attention to Triton blocksparse...")
         return blocksparse_attention(q, k, v, dropout, block_size)
 
     with torch.cuda.amp.autocast(enabled=False) if autocast_disabled else nullcontext():
diff --git a/xformers/components/attention/utils.py b/xformers/components/attention/utils.py
index 012188e1a2..d6bb06a1ac 100644
--- a/xformers/components/attention/utils.py
+++ b/xformers/components/attention/utils.py
@@ -106,13 +106,3 @@ def bool_mask_to_additive(
     mask_ = torch.zeros_like(mask, dtype=dtype)
     mask_[~mask] = float("-inf")
     return mask_
-
-
-# (B, S, D) to (B, S, nh, hs)
-def split_heads(t: torch.Tensor, B: int, nH: int, S: int, Hs: int):
-    return t.view(B, nH, S, Hs)
-
-
-# (B, nh, S, hs) back to (N, S, hs)
-def reshape_heads(t: torch.Tensor, B: int, nH: int, S: int, Hs: int):
-    return t.view(B * nH, S, Hs)
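
A minimal usage sketch of the new automatic switch, assuming a CUDA GPU with compute capability >= 7.0 and Triton installed. The tensor layout and the `AttentionMask.make_causal` helper mirror `tests/test_core_attention.py`; they are illustrative, not part of this diff. With fp16 inputs, a dense causal mask, matching K/Q lengths, and a sequence length divisible by `block_size`, `scaled_dot_product_attention` should route to the Triton blocksparse kernel on its own.

```python
# Hypothetical usage sketch of the automatic causal-blocksparse switch.
# Assumes: CUDA device, compute capability >= 7.0, Triton available.
import torch

from xformers.components.attention.attention_mask import AttentionMask
from xformers.components.attention.core import scaled_dot_product_attention

B, nh, S, hs = 8, 4, 128, 32      # batch, heads, sequence length, head size
shape = (B * nh, S, hs)           # 3D layout (N, S, hs), as used in the tests

q = torch.randn(shape, device="cuda", dtype=torch.float16)
k = torch.randn(shape, device="cuda", dtype=torch.float16)
v = torch.randn(shape, device="cuda", dtype=torch.float16)

# Dense causal mask: is_causal=True and not sparse; together with fp16 inputs,
# S % block_size == 0 and equal K/Q lengths, the switch criteria are satisfied.
mask = AttentionMask.make_causal(S, S, device="cuda")

out = scaled_dot_product_attention(q, k, v, att_mask=mask, block_size=128)
print(out.shape)  # torch.Size([32, 128, 32]), same (N, S, hs) layout as the inputs
```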