From 5afc03b336da274ee86a3de43040bf6011cb0226 Mon Sep 17 00:00:00 2001 From: Chris Yuan Date: Mon, 27 Jun 2022 07:17:16 -0700 Subject: [PATCH 1/3] minor cleanup; updated changelog --- BENCHMARKS.md | 2 +- CHANGELOG.md | 1 + tests/test_core_attention.py | 4 ++- .../benchmark_causal_blocksparse.py | 2 ++ xformers/benchmarks/utils.py | 6 +++-- xformers/components/attention/core.py | 27 +++++++++---------- xformers/components/attention/utils.py | 10 ------- 7 files changed, 24 insertions(+), 28 deletions(-) diff --git a/BENCHMARKS.md b/BENCHMARKS.md index ef04add59d..65101d8459 100644 --- a/BENCHMARKS.md +++ b/BENCHMARKS.md @@ -129,7 +129,7 @@ __Some results:__ _Note_: The estimated flops currently miss accounting for many operators, and are almost certainly an undercount. See issue [#154](https://github.com/fairinternal/xformers/issues/154) -## Casual Attention Blocksparse Optimization +## Causal Attention Blocksparse Optimization FP16 | FP32 :-------------------------:|:-------------------------: diff --git a/CHANGELOG.md b/CHANGELOG.md index df328734bd..fe2c87b01d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Support several initialization options [#312] - Conv2DFeedforward feedforward part [#321] - VisualAttention [#329] +- Automatic blocksparse for causal attention [#334] ## [0.0.11] - 2022-05-30 diff --git a/tests/test_core_attention.py b/tests/test_core_attention.py index b1cc306fed..9a6c579b99 100644 --- a/tests/test_core_attention.py +++ b/tests/test_core_attention.py @@ -138,7 +138,9 @@ def test_switch_blocksparse(device, data_type): assert r_sparse.dtype == expected_device if r_custom.dtype == r_att_mask.dtype: - assert torch.allclose(r_custom, r_att_mask, atol=1e-6, rtol=1e-3) + assert torch.allclose(r_custom, r_att_mask, atol=1e-6, rtol=1e-2) + else: # r_custom fp16, r_att_mask fp32 + assert torch.allclose(r_custom, r_att_mask.half(), atol=1e-6, rtol=1e-2) @pytest.mark.parametrize("device", ["cuda"]) diff --git a/xformers/benchmarks/benchmark_causal_blocksparse.py b/xformers/benchmarks/benchmark_causal_blocksparse.py index 05d630ba91..70a5e499b9 100644 --- a/xformers/benchmarks/benchmark_causal_blocksparse.py +++ b/xformers/benchmarks/benchmark_causal_blocksparse.py @@ -122,12 +122,14 @@ def sdp_attention(): title=f"Causal Blocksparse Runtime FW{bw.upper()} {datatype} Blocksize:{BS}", units="runtime in ms", dash_key="torch", + legend_loc="upper left", ) pretty_plot( results_mem, title=f"Causal Blocksparse Memory FW{bw.upper()} {datatype} Blocksize:{BS}", units="peak memory usage in MB", dash_key="torch", + legend_loc="upper left", ) diff --git a/xformers/benchmarks/utils.py b/xformers/benchmarks/utils.py index 8b5838fd44..37c3be02b2 100644 --- a/xformers/benchmarks/utils.py +++ b/xformers/benchmarks/utils.py @@ -55,7 +55,9 @@ def pretty_print(results, title, units): print("") -def pretty_plot(results, title, units: str, filename=None, dash_key=""): +def pretty_plot( + results, title, units: str, filename=None, dash_key="", legend_loc="bottom_right" +): """Graph out the contents of a dict. 
Dash key means that if the result label has this key, then it will be displayed with a dash""" @@ -86,7 +88,7 @@ def pretty_plot(results, title, units: str, filename=None, dash_key=""): plt.plot(list(results.keys()), v) plt.title(title) - plt.legend(list(workloads.keys()), loc="lower right") + plt.legend(list(workloads.keys()), loc=legend_loc) plt.ylabel(units) plt.xticks(rotation=45) diff --git a/xformers/components/attention/core.py b/xformers/components/attention/core.py index 0dd457466c..449e475de4 100644 --- a/xformers/components/attention/core.py +++ b/xformers/components/attention/core.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. +import logging import math from contextlib import nullcontext from functools import lru_cache @@ -14,7 +15,6 @@ from xformers import _is_sparse_available, _is_triton_available from xformers.components.attention.attention_mask import AttentionMask from xformers.components.attention.blocksparse import BlockSparseAttention -from xformers.components.attention.utils import reshape_heads if _is_sparse_available: from ._sputnik_sparse import SparseCS @@ -223,7 +223,6 @@ def _retrieve_blocksparse( # Checks if blocksparse object exists in cache blocks = seq_len // block_size - print("Made uncached blocksparse") layout_fill = torch.ones((num_heads, blocks, blocks), dtype=torch.long) return BlockSparseAttention(layout=layout_fill, block_size=block_size, causal=True) @@ -266,7 +265,7 @@ def blocksparse_attention( # Reshape attention (B, nh, S, hs) back to (N, S, hs) if orig_dim == 3: - return reshape_heads(att, *att.size()) + return att.flatten(0, 1) return att @@ -276,31 +275,31 @@ def scaled_dot_product_attention( v: torch.Tensor, att_mask: Optional[Union[AttentionMask, "SparseCS", torch.Tensor]], dropout: Optional[torch.nn.Module] = None, - block_size=128, + block_size: int = 128, ) -> torch.Tensor: autocast_disabled = ( _is_sparse_available and isinstance(att_mask, SparseCS) or (att_mask is not None and att_mask.is_sparse) ) + seq_len = q.shape[-2] - # Check if causal is required but mask is not sparse; if fp16 or under amp context + # switch if: + # causal is required but mask is not sparse + # fp16 or under amp context + # sequence length is divisible by block size + # same seq len for K and Q switch_to_blocksparse = ( _is_triton_available and (att_mask is not None and not att_mask.is_sparse) and (isinstance(att_mask, AttentionMask) and att_mask.is_causal) and (q.dtype == torch.float16 or torch.is_autocast_enabled()) - ) - - # Switch only if sequence length is divisible by block size - # Blocksparse requires the same dimensions for K and Q for now - seq_len = q.shape[-2] - if ( - switch_to_blocksparse and not seq_len % block_size and q.shape[-2] == k.shape[-2] - ): - # print("switching to blocksparse...") + ) + + if switch_to_blocksparse: + logging.info("Switching causal attention to Triton blocksparse...") return blocksparse_attention(q, k, v, dropout, block_size) with torch.cuda.amp.autocast(enabled=False) if autocast_disabled else nullcontext(): diff --git a/xformers/components/attention/utils.py b/xformers/components/attention/utils.py index 012188e1a2..d6bb06a1ac 100644 --- a/xformers/components/attention/utils.py +++ b/xformers/components/attention/utils.py @@ -106,13 +106,3 @@ def bool_mask_to_additive( mask_ = torch.zeros_like(mask, dtype=dtype) mask_[~mask] = float("-inf") return mask_ - - -# (B, S, D) to (B, S, nh, hs) -def split_heads(t: torch.Tensor, B: int, nH: int, S: int, Hs: int): - return t.view(B, nH, S, Hs) - - 
-# (B, nh, S, hs) back to (N, S, hs) -def reshape_heads(t: torch.Tensor, B: int, nH: int, S: int, Hs: int): - return t.view(B * nH, S, Hs) From 3dc74d3866b942755fcc8819ffcd0b564c2e175c Mon Sep 17 00:00:00 2001 From: Chris Yuan Date: Mon, 27 Jun 2022 08:35:35 -0700 Subject: [PATCH 2/3] fixed mypy error --- xformers/components/attention/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xformers/components/attention/core.py b/xformers/components/attention/core.py index 449e475de4..92be261cc7 100644 --- a/xformers/components/attention/core.py +++ b/xformers/components/attention/core.py @@ -14,12 +14,12 @@ from xformers import _is_sparse_available, _is_triton_available from xformers.components.attention.attention_mask import AttentionMask -from xformers.components.attention.blocksparse import BlockSparseAttention if _is_sparse_available: from ._sputnik_sparse import SparseCS if _is_triton_available: + from xformers.components.attention.blocksparse import BlockSparseAttention from xformers.triton.softmax import softmax as triton_softmax From 1900e9dc05a5bfa060d518ee4fc7c1eee511f5fe Mon Sep 17 00:00:00 2001 From: Chris Yuan Date: Mon, 27 Jun 2022 09:14:39 -0700 Subject: [PATCH 3/3] added checking for blocksparse availability --- tests/test_core_attention.py | 17 ++++++++++++++ xformers/components/attention/core.py | 32 ++++++++++++++++++--------- 2 files changed, 38 insertions(+), 11 deletions(-) diff --git a/tests/test_core_attention.py b/tests/test_core_attention.py index 9a6c579b99..26c71bb05e 100644 --- a/tests/test_core_attention.py +++ b/tests/test_core_attention.py @@ -7,10 +7,18 @@ import torch from torch import nn +from xformers import _is_triton_available from xformers.components.attention._sputnik_sparse import SparseCS from xformers.components.attention.attention_mask import AttentionMask from xformers.components.attention.core import scaled_dot_product_attention +if _is_triton_available: + from xformers.triton.utils import gpu_capabilities_older_than_70 + +_is_blocksparse_available = ( + _is_triton_available and not gpu_capabilities_older_than_70() +) + _devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"] @@ -112,6 +120,9 @@ def test_amp_attention_sparsecs(device): assert r.dtype == expected_device +@pytest.mark.skipif( + not _is_blocksparse_available, reason="Blocksparse is not available" +) @pytest.mark.parametrize("device", ["cuda"]) @pytest.mark.parametrize("data_type", [torch.float16, torch.float32]) def test_switch_blocksparse(device, data_type): @@ -143,6 +154,9 @@ def test_switch_blocksparse(device, data_type): assert torch.allclose(r_custom, r_att_mask.half(), atol=1e-6, rtol=1e-2) +@pytest.mark.skipif( + not _is_blocksparse_available, reason="Blocksparse is not available" +) @pytest.mark.parametrize("device", ["cuda"]) def test_switch_blocksparse_dims(device): b, s, d, nh = 8, 128, 32, 8 @@ -161,6 +175,9 @@ def test_switch_blocksparse_dims(device): assert r.dtype == expected_device +@pytest.mark.skipif( + not _is_blocksparse_available, reason="Blocksparse is not available" +) @pytest.mark.parametrize("device", ["cuda"]) @pytest.mark.parametrize("training", [True, False]) @pytest.mark.parametrize("drop_prob", [0.0, 0.3]) diff --git a/xformers/components/attention/core.py b/xformers/components/attention/core.py index 92be261cc7..9c8e47404b 100644 --- a/xformers/components/attention/core.py +++ b/xformers/components/attention/core.py @@ -19,8 +19,15 @@ from ._sputnik_sparse import SparseCS if _is_triton_available: - from 
xformers.components.attention.blocksparse import BlockSparseAttention from xformers.triton.softmax import softmax as triton_softmax + from xformers.triton.utils import gpu_capabilities_older_than_70 + +_is_blocksparse_available = ( + _is_triton_available and not gpu_capabilities_older_than_70() +) + +if _is_blocksparse_available: + from xformers.components.attention.blocksparse import BlockSparseAttention def _create_random_sparsity(matrix, sparsity, divisible_by=4): @@ -215,16 +222,19 @@ def scaled_query_key_softmax( return att -# 128 is default maxsize -@lru_cache(maxsize=128) -def _retrieve_blocksparse( - num_heads: int, seq_len: int, block_size: int -) -> BlockSparseAttention: - # Checks if blocksparse object exists in cache +if _is_blocksparse_available: + # 128 is default maxsize + @lru_cache(maxsize=128) + def _retrieve_blocksparse( + num_heads: int, seq_len: int, block_size: int + ) -> BlockSparseAttention: + # Checks if blocksparse object exists in cache - blocks = seq_len // block_size - layout_fill = torch.ones((num_heads, blocks, blocks), dtype=torch.long) - return BlockSparseAttention(layout=layout_fill, block_size=block_size, causal=True) + blocks = seq_len // block_size + layout_fill = torch.ones((num_heads, blocks, blocks), dtype=torch.long) + return BlockSparseAttention( + layout=layout_fill, block_size=block_size, causal=True + ) def blocksparse_attention( @@ -290,7 +300,7 @@ def scaled_dot_product_attention( # sequence length is divisible by block size # same seq len for K and Q switch_to_blocksparse = ( - _is_triton_available + _is_blocksparse_available and (att_mask is not None and not att_mask.is_sparse) and (isinstance(att_mask, AttentionMask) and att_mask.is_causal) and (q.dtype == torch.float16 or torch.is_autocast_enabled())
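
Usage sketch (not part of the patches above): a minimal illustration of the behaviour this series introduces, i.e. `scaled_dot_product_attention` silently routing a dense causal mask to the cached Triton `BlockSparseAttention`. It assumes a CUDA device with compute capability >= 7.0, fp16 inputs, and the default block size of 128; the tensor shapes are illustrative only.

    import torch

    from xformers.components.attention.attention_mask import AttentionMask
    from xformers.components.attention.core import scaled_dot_product_attention

    # batch, heads, sequence length, head size (illustrative values)
    b, nh, s, hs = 8, 8, 1024, 64

    # Flattened (B * nh, S, hs) layout, fp16 on CUDA
    q = torch.rand(b * nh, s, hs, device="cuda", dtype=torch.float16)
    k = torch.rand_like(q)
    v = torch.rand_like(q)

    # A dense causal AttentionMask is what the new switch looks for: combined with
    # fp16 (or an autocast context), a sequence length divisible by the block size,
    # and matching K/Q lengths, the call is dispatched to the lru_cache'd
    # BlockSparseAttention instead of the dense masked-softmax path.
    att_mask = AttentionMask.make_causal(s, s, "cuda")
    out = scaled_dot_product_attention(q, k, v, att_mask=att_mask)
    assert out.shape == (b * nh, s, hs)

If any of the conditions fail (e.g. fp32 without autocast, or a sequence length that is not a multiple of the block size), the call falls through to the existing dense implementation unchanged.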