From 4bfb18fba3fd7c235a19927594386d49237a11fa Mon Sep 17 00:00:00 2001
From: Benjamin Lefaudeux
Date: Fri, 26 Nov 2021 09:05:30 -0800
Subject: [PATCH 1/5] Drop just after the softmax

---
 tests/test_attentions.py                     | 9 +++++++--
 xformers/components/attention/blocksparse.py | 5 +++--
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/tests/test_attentions.py b/tests/test_attentions.py
index a73c79a7da..e3ebb9b956 100644
--- a/tests/test_attentions.py
+++ b/tests/test_attentions.py
@@ -77,7 +77,7 @@ def noop(x):
     return multi_head


-@pytest.mark.parametrize("attn_dropout", [0.0, 0.1])
+@pytest.mark.parametrize("attn_dropout", [0.0, 0.5])
 @pytest.mark.parametrize("residual_dropout", [0.0, 0.1])
 @pytest.mark.parametrize("causal", [True, False])
 @pytest.mark.parametrize("heads", [1, 4])
@@ -110,7 +110,12 @@ def test_order_invariance(
         torch.allclose(results[:, shuffle, :], results_shuffled)

     # Test the non-self-attention codepath
-    _ = multi_head(inputs, inputs_shuffled, inputs)
+    att = multi_head(inputs, inputs_shuffled, inputs)
+
+    # Check that dropout actually drops some values
+    if attn_dropout > 0:
+        att_2 = multi_head(inputs, inputs_shuffled, inputs)
+        assert (att != att_2).any()


 @pytest.mark.parametrize("heads", [1, 4])
diff --git a/xformers/components/attention/blocksparse.py b/xformers/components/attention/blocksparse.py
index 0867657d5e..130f56bc36 100644
--- a/xformers/components/attention/blocksparse.py
+++ b/xformers/components/attention/blocksparse.py
@@ -10,7 +10,6 @@
 from typing import Optional

 import torch
-from torch import nn

 from xformers import _is_triton_available
 from xformers.components.attention import Attention, AttentionConfig, register_attention
@@ -80,7 +79,7 @@ def __init__(
         assert block_size >= 16, "Minimum block size is 16, for now at least"

         super().__init__()
-        self.attn_drop = nn.Dropout(dropout, inplace=False)
+        self.attn_drop = torch.nn.Dropout(dropout, inplace=False)

         # Pure blocksparse data
         self.layout = layout
@@ -193,6 +192,8 @@ def forward(
             attn_mask_mode=MaskType.ADD,
         )

+        sparse_att_mat = self.attn_drop(sparse_att_mat)
+
         # - then (dense) attention is (sparse) attention matrix * dense (value)
         a = self.sparse_dot_dsd(sparse_att_mat, v)
         return a.to(q_dtype)
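The new check in `test_order_invariance` relies on dropout being stochastic across forward passes. A minimal, self-contained sketch of that behaviour in plain PyTorch (not the xformers test itself): applying `nn.Dropout` to a softmaxed attention matrix yields a different result on each call in training mode, and is a no-op in eval mode.

```python
import torch

torch.manual_seed(0)
attn = torch.softmax(torch.randn(2, 8, 8), dim=-1)  # stand-in attention matrix
drop = torch.nn.Dropout(p=0.5)

drop.train()
out_1, out_2 = drop(attn), drop(attn)
assert (out_1 != out_2).any()  # a fresh mask is sampled on each call

drop.eval()
assert torch.equal(drop(attn), drop(attn))  # dropout is disabled at eval time
```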
From ef83b293f1fe791b9dcda12bb6f05e8e2e41cfc9 Mon Sep 17 00:00:00 2001
From: Benjamin Lefaudeux
Date: Sun, 28 Nov 2021 17:42:53 -0800
Subject: [PATCH 2/5] fixing a bunch of attentions, this is a good test

---
 tests/test_attentions.py                       | 2 +-
 xformers/components/attention/fourier_mix.py   | 9 +++++++--
 xformers/components/attention/global_tokens.py | 2 ++
 xformers/components/attention/linformer.py     | 4 ++++
 4 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/tests/test_attentions.py b/tests/test_attentions.py
index e3ebb9b956..8530ac759a 100644
--- a/tests/test_attentions.py
+++ b/tests/test_attentions.py
@@ -77,7 +77,7 @@ def noop(x):
     return multi_head


-@pytest.mark.parametrize("attn_dropout", [0.0, 0.5])
+@pytest.mark.parametrize("attn_dropout", [0.0, 0.9])
 @pytest.mark.parametrize("residual_dropout", [0.0, 0.1])
 @pytest.mark.parametrize("causal", [True, False])
 @pytest.mark.parametrize("heads", [1, 4])
diff --git a/xformers/components/attention/fourier_mix.py b/xformers/components/attention/fourier_mix.py
index e7353faf39..f126121c74 100644
--- a/xformers/components/attention/fourier_mix.py
+++ b/xformers/components/attention/fourier_mix.py
@@ -10,7 +10,7 @@
 @register_attention("fourier_mix", AttentionConfig)
 class FourierMix(Attention):
-    def __init__(self, *_, **__):
+    def __init__(self, dropout: float, *_, **__):
         """
         FFT-based pseudo-attention mechanism, from "

@@ -18,7 +18,12 @@ def __init__(self, *_, **__):
         Lee-Thorp et al., 2021, https://arxiv.org/pdf/2105.03824.pdf
         """
         super().__init__()
+        self.attn_drop = torch.nn.Dropout(dropout, inplace=False)
         self.requires_input_projection = False

     def forward(self, q: torch.Tensor, *_, **__):
-        return torch.fft.fft2(q).real
+        att = torch.fft.fft2(q).real
+
+        att = self.attn_drop(att)
+
+        return att
diff --git a/xformers/components/attention/global_tokens.py b/xformers/components/attention/global_tokens.py
index 290f44c2be..9ef7606fbe 100644
--- a/xformers/components/attention/global_tokens.py
+++ b/xformers/components/attention/global_tokens.py
@@ -111,5 +111,7 @@ def forward(
             q=q_, k=k_, v=v_, att_mask=mask, dropout=self.attn_drop
         )

+        att = self.attn_drop(att)
+
         # Take into account an hypothetical padding
         return att[:, :seq_len, :]
diff --git a/xformers/components/attention/linformer.py b/xformers/components/attention/linformer.py
index b1aac85d94..b458d7c0ac 100644
--- a/xformers/components/attention/linformer.py
+++ b/xformers/components/attention/linformer.py
@@ -37,6 +37,7 @@ def __init__(
         if k is None:
             k = seq_len // 4

+        print("dropout ", dropout)
         self.k = k
         self.E = nn.Linear(seq_len, k, bias=False)
         self.F = nn.Linear(seq_len, k, bias=False)
@@ -61,4 +62,7 @@ def forward(
         y = scaled_dot_product_attention(
             q=q, k=k_projected, v=v_projected, att_mask=None, dropout=self.attn_drop
         )
+
+        y = self.attn_drop(y)
+
         return y[:, :-padding, :] if padding > 0 else y
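The three attention fixes above all follow the same pattern: accept a `dropout` probability in the constructor and run the output through an `nn.Dropout` in `forward`. A stand-alone sketch of that pattern, using a toy module rather than the actual `FourierMix` class:

```python
import torch


class FFTMixerSketch(torch.nn.Module):
    """Toy stand-in for the FourierMix change: the constructor now takes a
    dropout probability and forward applies it to the mixed output."""

    def __init__(self, dropout: float = 0.0):
        super().__init__()
        self.attn_drop = torch.nn.Dropout(dropout, inplace=False)

    def forward(self, q: torch.Tensor) -> torch.Tensor:
        att = torch.fft.fft2(q).real  # token mixing via a 2D FFT, as in FNet
        return self.attn_drop(att)    # dropout applied to the mixed output


x = torch.randn(2, 16, 32)  # (batch, sequence, embedding)
y = FFTMixerSketch(dropout=0.1)(x)
assert y.shape == x.shape
```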
From 2ff0bb3cc9b1482b71fd0051be15013f29df1287 Mon Sep 17 00:00:00 2001
From: Benjamin Lefaudeux
Date: Sun, 28 Nov 2021 17:53:31 -0800
Subject: [PATCH 3/5] Fixing SDP, this is not good

---
 tests/test_attentions.py                            | 5 ++++-
 xformers/components/attention/core.py               | 5 +++++
 xformers/components/attention/lambda_layer.py       | 9 +++++++--
 xformers/components/attention/scaled_dot_product.py | 3 ++-
 4 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/tests/test_attentions.py b/tests/test_attentions.py
index 8530ac759a..69169750d0 100644
--- a/tests/test_attentions.py
+++ b/tests/test_attentions.py
@@ -77,7 +77,7 @@ def noop(x):
     return multi_head


-@pytest.mark.parametrize("attn_dropout", [0.0, 0.9])
+@pytest.mark.parametrize("attn_dropout", [0.0, 0.3])
 @pytest.mark.parametrize("residual_dropout", [0.0, 0.1])
 @pytest.mark.parametrize("causal", [True, False])
 @pytest.mark.parametrize("heads", [1, 4])
@@ -91,6 +91,9 @@ def test_order_invariance(
     causal: bool,
     device: torch.device,
 ):
+
+    torch.manual_seed(42)
+
     multi_head = _get_multihead(
         attention_name, attn_dropout, residual_dropout, causal, heads, device
     )
diff --git a/xformers/components/attention/core.py b/xformers/components/attention/core.py
index 614ff31a05..79a236f615 100644
--- a/xformers/components/attention/core.py
+++ b/xformers/components/attention/core.py
@@ -154,6 +154,7 @@ def bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
 def _apply_dropout(att, dropout):
     if dropout is None:
         return att
+
     # Dropout chokes on sparse tensors
     if _is_sparse_available:
         if isinstance(att, SparseCS):
@@ -172,10 +173,14 @@ def _apply_dropout(att, dropout):
             values = att.values().clone()  # protect against in-place dropout
             values = dropout(values)
             att = torch.sparse_coo_tensor(att.indices(), values, att.shape)
+        else:
+            # Simple dense case
+            att = dropout(att)

         return att

     # Non optimized vanilla dropout
+    print(att - dropout(att))
     att = dropout(att)
     return att
diff --git a/xformers/components/attention/lambda_layer.py b/xformers/components/attention/lambda_layer.py
index 999ba6c179..1d145d6ef6 100644
--- a/xformers/components/attention/lambda_layer.py
+++ b/xformers/components/attention/lambda_layer.py
@@ -27,7 +27,7 @@ class LambdaLayerConfig(AttentionConfig):

 @register_attention("lambda", LambdaLayerConfig)
 class LambdaLayer(Attention):
-    def __init__(self, seq_len: int, dim_head: int, *_, **__):
+    def __init__(self, dropout: float, seq_len: int, dim_head: int, *_, **__):
         """
         Attention approximation using Lambda layers, from
         "Lambda networks: modeling long-range interactions without attention.", Bello, I. (2021).
@@ -43,6 +43,7 @@ def __init__(self, seq_len: int, dim_head: int, *_, **__):
             torch.randn(2 * seq_len - 1, int(dim_head))
         )
         self.rel_pos = calc_rel_pos(seq_len)
+        self.attn_drop = torch.nn.Dropout(dropout, inplace=True)

     def forward(
         self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
@@ -65,4 +66,8 @@ def forward(
             "mnk,bnv->bnkv", rel_pos_emb, v
         )  # one lambda per position
         position_output = (q.unsqueeze(2) @ position_lambdas).squeeze()
-        return content_output + position_output
+        att = content_output + position_output
+
+        att = self.attn_drop(att)
+
+        return att
diff --git a/xformers/components/attention/scaled_dot_product.py b/xformers/components/attention/scaled_dot_product.py
index f84e674b52..a8560629cf 100644
--- a/xformers/components/attention/scaled_dot_product.py
+++ b/xformers/components/attention/scaled_dot_product.py
@@ -47,6 +47,7 @@ def __init__(
         **kwargs,
     ):
         super().__init__()
+        self.attn_drop = nn.Dropout(dropout, inplace=False)
         self.causal = causal
         self.seq_len = seq_len
@@ -119,7 +120,7 @@ def forward(
                 )
                 raise NotImplementedError

-        # Self-attend: (B x nh, S, hs) x (B x nh, hs, S) -> (B x nh, S, S)
+        # Attend: (B x nh, S, hs) x (B x nh, hs, S) -> (B x nh, S, S)
         y = scaled_dot_product_attention(
             q=q, k=k, v=v, att_mask=att_mask, dropout=self.attn_drop
         )

From 44ee4f3eeee53387540ed3e360166d11595ec586 Mon Sep 17 00:00:00 2001
From: Benjamin Lefaudeux
Date: Sun, 28 Nov 2021 18:30:15 -0800
Subject: [PATCH 4/5] code review, thanks Diana

---
 xformers/components/attention/core.py      | 1 -
 xformers/components/attention/linformer.py | 1 -
 2 files changed, 2 deletions(-)

diff --git a/xformers/components/attention/core.py b/xformers/components/attention/core.py
index 79a236f615..b33374e3b1 100644
--- a/xformers/components/attention/core.py
+++ b/xformers/components/attention/core.py
@@ -180,7 +180,6 @@ def _apply_dropout(att, dropout):
         return att

     # Non optimized vanilla dropout
-    print(att - dropout(att))
     att = dropout(att)
     return att
diff --git a/xformers/components/attention/linformer.py b/xformers/components/attention/linformer.py
index b458d7c0ac..0458e9889b 100644
--- a/xformers/components/attention/linformer.py
+++ b/xformers/components/attention/linformer.py
@@ -37,7 +37,6 @@ def __init__(
         if k is None:
             k = seq_len // 4

-        print("dropout ", dropout)
         self.k = k
         self.E = nn.Linear(seq_len, k, bias=False)
         self.F = nn.Linear(seq_len, k, bias=False)
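The `_apply_dropout` change in `core.py` has to work around the fact that `nn.Dropout` cannot be called on a sparse tensor. A small sketch of the same idea on a toy sparse COO attention matrix, mirroring the clone/rebuild lines from the patch:

```python
import torch

drop = torch.nn.Dropout(p=0.2)
dense = torch.softmax(torch.randn(4, 4), dim=-1)
sparse_att = (dense * (dense > 0.2)).to_sparse().coalesce()  # toy sparse pattern

values = sparse_att.values().clone()  # clone protects against in-place dropout
values = drop(values)                 # dropout on the non-zero values only
sparse_att = torch.sparse_coo_tensor(sparse_att.indices(), values, sparse_att.shape)
```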
From 19b802ed737a6c91269640c3a7d68c20da2d0488 Mon Sep 17 00:00:00 2001
From: Benjamin Lefaudeux
Date: Sun, 28 Nov 2021 19:30:44 -0800
Subject: [PATCH 5/5] updating the changelog

---
 CHANGELOG.md | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 92d3da79b0..a1940516ba 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

+## TBD
+### Fixed
+- Dropout setting not properly passed in many attentions [#123]
+
 ## [0.0.6] - 2021-11-24
 ### Fixed
 - Fix self attention optimization not being triggered, broken residual path [#119]
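With the changelog entry in place, the contract across these patches is simply that a non-zero dropout setting must perturb the output while training. A hypothetical helper (not part of xformers) that captures the check the updated test performs:

```python
import torch


def dropout_is_active(module: torch.nn.Module, x: torch.Tensor) -> bool:
    """Run the module twice in training mode and report whether any element changed."""
    module.train()
    return bool((module(x) != module(x)).any())


# Sanity check on a plain dropout layer; with p=0.0 nothing should ever change.
assert dropout_is_active(torch.nn.Dropout(p=0.5), torch.ones(8, 8))
assert not dropout_is_active(torch.nn.Dropout(p=0.0), torch.ones(8, 8))
```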