[Fix] attn_dropout #123

Merged
merged 5 commits on Nov 29, 2021
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## TBD
### Fixed
- Dropout setting not properly passed in many attention mechanisms [#123]

## [0.0.6] - 2021-11-24
### Fixed
- Fix self attention optimization not being triggered, broken residual path [#119]
12 changes: 10 additions & 2 deletions tests/test_attentions.py
@@ -77,7 +77,7 @@ def noop(x):
return multi_head


@pytest.mark.parametrize("attn_dropout", [0.0, 0.1])
@pytest.mark.parametrize("attn_dropout", [0.0, 0.3])
@pytest.mark.parametrize("residual_dropout", [0.0, 0.1])
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("heads", [1, 4])
@@ -91,6 +91,9 @@ def test_order_invariance(
causal: bool,
device: torch.device,
):

torch.manual_seed(42)

multi_head = _get_multihead(
attention_name, attn_dropout, residual_dropout, causal, heads, device
)
@@ -110,7 +113,12 @@
torch.allclose(results[:, shuffle, :], results_shuffled)

# Test the non-self-attention codepath
_ = multi_head(inputs, inputs_shuffled, inputs)
att = multi_head(inputs, inputs_shuffled, inputs)

# Check that dropout actually drops some values
Contributor comment: Great to have this test now!

if attn_dropout > 0:
att_2 = multi_head(inputs, inputs_shuffled, inputs)
assert (att != att_2).any()


@pytest.mark.parametrize("heads", [1, 4])
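The new check leans on the fact that an active dropout module is stochastic in training mode, so two forward passes over identical inputs should differ whenever attn_dropout > 0. A minimal standalone sketch of that property (the module and shapes here are illustrative, not the test's actual fixtures):

```python
import torch

# Illustrative only: any module wrapping an active nn.Dropout behaves this way.
drop = torch.nn.Dropout(p=0.3)
drop.train()  # dropout only fires in training mode

x = torch.randn(2, 8, 16)
out_1 = drop(x)
out_2 = drop(x)

# With p > 0, the two passes mask different elements with overwhelming probability.
assert (out_1 != out_2).any()

# In eval mode dropout is a no-op, so repeated passes match exactly.
drop.eval()
assert torch.equal(drop(x), drop(x))
```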
5 changes: 3 additions & 2 deletions xformers/components/attention/blocksparse.py
@@ -10,7 +10,6 @@
from typing import Optional

import torch
from torch import nn

from xformers import _is_triton_available
from xformers.components.attention import Attention, AttentionConfig, register_attention
@@ -80,7 +79,7 @@ def __init__(
assert block_size >= 16, "Minimum block size is 16, for now at least"

super().__init__()
self.attn_drop = nn.Dropout(dropout, inplace=False)
self.attn_drop = torch.nn.Dropout(dropout, inplace=False)

# Pure blocksparse data
self.layout = layout
@@ -193,6 +192,8 @@ def forward(
attn_mask_mode=MaskType.ADD,
)

sparse_att_mat = self.attn_drop(sparse_att_mat)

# - then (dense) attention is (sparse) attention matrix * dense (value)
a = self.sparse_dot_dsd(sparse_att_mat, v)
return a.to(q_dtype)
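The blocksparse fix slots dropout between the sparse softmax and the final attention-times-value product. A dense sketch of that ordering, with made-up shapes (the real path runs through Triton block-sparse kernels, which this sketch does not attempt to reproduce):

```python
import torch

# (batch, heads, seq, head_dim) -- hypothetical shapes for illustration
q, k, v = (torch.randn(1, 4, 128, 32) for _ in range(3))
attn_drop = torch.nn.Dropout(p=0.1)

scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)  # (batch, heads, seq, seq)
att = torch.softmax(scores, dim=-1)
att = attn_drop(att)  # dropout on the attention matrix, as in the patched forward
out = att @ v         # only then is it multiplied by the values
```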
4 changes: 4 additions & 0 deletions xformers/components/attention/core.py
@@ -154,6 +154,7 @@ def bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
def _apply_dropout(att, dropout):
if dropout is None:
return att

# Dropout chokes on sparse tensors
if _is_sparse_available:
if isinstance(att, SparseCS):
@@ -172,6 +173,9 @@ def _apply_dropout(att, dropout):
values = att.values().clone() # protect against in-place dropout
values = dropout(values)
att = torch.sparse_coo_tensor(att.indices(), values, att.shape)
else:
# Simple dense case
att = dropout(att)

return att

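For the sparse COO branch above, dropout cannot be applied to the sparse tensor directly, so it is applied to the cloned values and the tensor is rebuilt. A toy reproduction of that pattern (the 2x3 matrix is illustrative only):

```python
import torch

# A 2x3 matrix with three non-zero entries, standing in for a sparse attention matrix.
indices = torch.tensor([[0, 1, 1], [2, 0, 2]])
values = torch.tensor([0.5, 0.3, 0.2])
att = torch.sparse_coo_tensor(indices, values, (2, 3)).coalesce()

dropout = torch.nn.Dropout(p=0.5)
new_values = dropout(att.values().clone())  # clone protects against in-place dropout
att = torch.sparse_coo_tensor(att.indices(), new_values, att.shape)
```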
9 changes: 7 additions & 2 deletions xformers/components/attention/fourier_mix.py
@@ -10,15 +10,20 @@

@register_attention("fourier_mix", AttentionConfig)
class FourierMix(Attention):
def __init__(self, *_, **__):
def __init__(self, dropout: float, *_, **__):
"""
FFT-based pseudo-attention mechanism, from
"
"FNet: Mixing Tokens with Fourier Transforms"
Lee-Thorp et al., 2021, https://arxiv.org/pdf/2105.03824.pdf
"""
super().__init__()
self.attn_drop = torch.nn.Dropout(dropout, inplace=False)
self.requires_input_projection = False

def forward(self, q: torch.Tensor, *_, **__):
return torch.fft.fft2(q).real
att = torch.fft.fft2(q).real

att = self.attn_drop(att)

return att
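With this change FourierMix takes the same dropout argument as the other mechanisms and applies it to the real part of the 2D FFT. A self-contained sketch of the resulting forward pass (TinyFourierMix and the shapes are illustrative stand-ins, not the library class):

```python
import torch


class TinyFourierMix(torch.nn.Module):
    """Minimal FNet-style token mixing, for illustration only."""

    def __init__(self, dropout: float):
        super().__init__()
        self.attn_drop = torch.nn.Dropout(dropout, inplace=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        att = torch.fft.fft2(x).real  # mix over the last two dimensions
        return self.attn_drop(att)


mix = TinyFourierMix(dropout=0.1)
y = mix(torch.randn(2, 64, 32))  # (batch, seq, model_dim) in, same shape out
```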
2 changes: 2 additions & 0 deletions xformers/components/attention/global_tokens.py
@@ -111,5 +111,7 @@ def forward(
q=q_, k=k_, v=v_, att_mask=mask, dropout=self.attn_drop
)

att = self.attn_drop(att)

# Take into account a hypothetical padding
return att[:, :seq_len, :]
9 changes: 7 additions & 2 deletions xformers/components/attention/lambda_layer.py
@@ -27,7 +27,7 @@ class LambdaLayerConfig(AttentionConfig):

@register_attention("lambda", LambdaLayerConfig)
class LambdaLayer(Attention):
def __init__(self, seq_len: int, dim_head: int, *_, **__):
def __init__(self, dropout: float, seq_len: int, dim_head: int, *_, **__):
"""
Attention approximation using Lambda layers, from
"Lambda networks: modeling long-range interactions without attention.", Bello, I. (2021).
@@ -43,6 +43,7 @@ def __init__(self, seq_len: int, dim_head: int, *_, **__):
torch.randn(2 * seq_len - 1, int(dim_head))
)
self.rel_pos = calc_rel_pos(seq_len)
self.attn_drop = torch.nn.Dropout(dropout, inplace=True)

def forward(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
@@ -65,4 +66,8 @@
"mnk,bnv->bnkv", rel_pos_emb, v
) # one lambda per position
position_output = (q.unsqueeze(2) @ position_lambdas).squeeze()
return content_output + position_output
att = content_output + position_output

att = self.attn_drop(att)

return att
3 changes: 3 additions & 0 deletions xformers/components/attention/linformer.py
@@ -61,4 +61,7 @@ def forward(
y = scaled_dot_product_attention(
q=q, k=k_projected, v=v_projected, att_mask=None, dropout=self.attn_drop
)

y = self.attn_drop(y)

return y[:, :-padding, :] if padding > 0 else y
3 changes: 2 additions & 1 deletion xformers/components/attention/scaled_dot_product.py
@@ -47,6 +47,7 @@ def __init__(
**kwargs,
):
super().__init__()

self.attn_drop = nn.Dropout(dropout, inplace=False)
self.causal = causal
self.seq_len = seq_len
@@ -119,7 +120,7 @@ def forward(
)
raise NotImplementedError

# Self-attend: (B x nh, S, hs) x (B x nh, hs, S) -> (B x nh, S, S)
# Attend: (B x nh, S, hs) x (B x nh, hs, S) -> (B x nh, S, S)
y = scaled_dot_product_attention(
q=q, k=k, v=v, att_mask=att_mask, dropout=self.attn_drop
)
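The renamed comment covers both the self-attention and cross-attention code paths; either way the shapes multiply out as noted. A hedged sketch of the dense computation the comment describes, reusing its B, nh, S, hs dimension names (values chosen arbitrarily):

```python
import torch

B, nh, S, hs = 2, 4, 16, 8  # arbitrary values for the comment's dimension names
q = torch.randn(B * nh, S, hs)
k = torch.randn(B * nh, S, hs)
v = torch.randn(B * nh, S, hs)
attn_drop = torch.nn.Dropout(p=0.1)

# (B x nh, S, hs) x (B x nh, hs, S) -> (B x nh, S, S)
att = torch.softmax(q @ k.transpose(1, 2) / hs ** 0.5, dim=-1)
att = attn_drop(att)
y = att @ v  # (B x nh, S, S) x (B x nh, S, hs) -> (B x nh, S, hs)
```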