From be72b26005fbc66ca0dd3631f72041edfff023ff Mon Sep 17 00:00:00 2001
From: Benjamin Lefaudeux
Date: Wed, 6 Apr 2022 19:56:37 -0700
Subject: [PATCH] Making it explicit whether the attention mechanism supports
 an attention mask or not check the assert

---
 tests/test_block_factory.py                   | 19 +++++++++++++----
 tests/test_model_factory.py                   |  6 +++---
 xformers/components/attention/base.py         |  4 ++++
 xformers/components/attention/blocksparse.py  |  5 ++++-
 .../components/attention/compositional.py     |  4 ++++
 xformers/components/attention/favor.py        |  4 ++++
 xformers/components/attention/fourier_mix.py  |  3 +++
 .../components/attention/global_tokens.py     |  3 +++
 xformers/components/attention/lambda_layer.py |  4 ++++
 xformers/components/attention/linformer.py    |  6 ++++++
 xformers/components/attention/local.py        |  4 ++++
 xformers/components/attention/nystrom.py      |  4 ++++
 xformers/components/attention/ortho.py        |  4 ++++
 xformers/components/attention/random.py       |  5 +++++
 .../attention/scaled_dot_product.py           |  4 ++++
 xformers/components/multi_head_dispatch.py    | 21 +++++++++++++------
 16 files changed, 86 insertions(+), 14 deletions(-)

diff --git a/tests/test_block_factory.py b/tests/test_block_factory.py
index 7b5450faaa..42750b7a8d 100644
--- a/tests/test_block_factory.py
+++ b/tests/test_block_factory.py
@@ -50,6 +50,7 @@ def test_xformer_encoder_block(
     device: torch.device,
     reversible: bool,
 ):
+
     block_size = 16
 
     attention_config = {
@@ -112,7 +113,13 @@ def test_xformer_encoder_block(
 
     # Check that we support attention masking, at least interface wise (do not check correctness yet)
     att_mask = torch.ones(SEQ, SEQ, dtype=torch.bool, device=device)
-    _ = block(inputs, att_mask=att_mask)
+    if block.mha.attention.supports_attention_mask:
+        _ = block(inputs, att_mask=att_mask)
+    else:
+        with pytest.raises(AssertionError):
+            # Check that passing an attention mask to a mechanism which does not support it raises
+            # an exception
+            _ = block(inputs, att_mask=att_mask)
 
     # Check that we support input masking, at least interface wise (do not check correctness yet)
     input_mask = torch.randn(SEQ, dtype=torch.float, device=device)
@@ -223,7 +230,10 @@ def test_xformer_decoder_block(
     input_mask[input_mask < 0.0] = -float("inf")
 
     encoded = encoder_block(inputs)
-    _ = decoder_block(inputs, encoded, encoder_att_mask=att_mask, input_mask=input_mask)
+    if decoder_block.mha.attention.supports_attention_mask:
+        _ = decoder_block(
+            inputs, encoded, encoder_att_mask=att_mask, input_mask=input_mask
+        )
 
     # Test different sequence lengths when encoding and decoding
     if not decoder_block.mha.attention.requires_same_k_q_dimensions:
@@ -303,8 +313,9 @@ def test_embedding_projection():
     _ = block(inputs)
 
     # Check that we support attention masking, at least interface wise (do not check correctness yet)
-    att_mask = torch.ones(SEQ, SEQ, dtype=torch.bool, device=device)
-    _ = block(inputs, att_mask=att_mask)
+    if block.mha.attention.supports_attention_mask:
+        att_mask = torch.ones(SEQ, SEQ, dtype=torch.bool, device=device)
+        _ = block(inputs, att_mask=att_mask)
 
     # Check that we support input masking, at least interface wise (do not check correctness yet)
     input_mask = torch.randn(SEQ, dtype=torch.float, device=device)
diff --git a/tests/test_model_factory.py b/tests/test_model_factory.py
index 4a4a87242e..7cda9d90a2 100644
--- a/tests/test_model_factory.py
+++ b/tests/test_model_factory.py
@@ -39,7 +39,7 @@
         "num_heads": 4,
         "residual_dropout": 0,
         "attention": {
-            "name": "linformer",
+            "name": "scaled_dot_product",
             "dropout": 0,
             "causal": True,
             "seq_len": SEQ,
@@ -73,7 +73,7 @@
         "residual_dropout": 0,
         "dim_model": EMB,
         "attention": {
-            "name": "linformer",
+            "name": "scaled_dot_product",
             "dropout": 0,
             "causal": True,
             "seq_len": SEQ,
@@ -84,7 +84,7 @@
         "residual_dropout": 0,
         "dim_model": EMB,
         "attention": {
-            "name": "linformer",
+            "name": "scaled_dot_product",
             "dropout": 0,
             "causal": True,
             "seq_len": SEQ,
diff --git a/xformers/components/attention/base.py b/xformers/components/attention/base.py
index 511dcd04f2..52d72f7e4b 100644
--- a/xformers/components/attention/base.py
+++ b/xformers/components/attention/base.py
@@ -53,6 +53,10 @@ def __init__(self, dropout: Optional[float] = None, *args, **kwargs):
         # so that the MHA wrapper should skip it
         self.requires_skip_multi_head = False
 
+        # Whether this attention mechanism supports attention masks
+        self.supports_attention_mask = True
+        self.supports_key_padding_mask = False
+
     @classmethod
     def from_config(cls: Type[Self], config: AttentionConfig) -> Self:
         # Generate the class inputs from the config
diff --git a/xformers/components/attention/blocksparse.py b/xformers/components/attention/blocksparse.py
index e3b691784c..e0f28e9446 100644
--- a/xformers/components/attention/blocksparse.py
+++ b/xformers/components/attention/blocksparse.py
@@ -114,9 +114,12 @@ def __init__(
 
         # key padding mask and attention mask must be passed in separately
         self.requires_separate_masks = True
-
         self.requires_same_k_q_dimensions = True
 
+        # Properties specific to this attention mechanism
+        self.supports_attention_mask = True
+        self.supports_key_padding_mask = True
+
     def update_mask_type(self, mask: torch.Tensor):
         global _mask_type_warning
         if _mask_type_warning:
diff --git a/xformers/components/attention/compositional.py b/xformers/components/attention/compositional.py
index 8d9276f4e1..5c862365c1 100644
--- a/xformers/components/attention/compositional.py
+++ b/xformers/components/attention/compositional.py
@@ -189,6 +189,10 @@ def __init__(
 
         self.causal = causal
 
+        # Properties specific to this attention mechanism
+        self.supports_attention_mask = True
+        self.supports_key_padding_mask = False
+
         self._reset_parameters()
 
     def _reset_parameters(self):
diff --git a/xformers/components/attention/favor.py b/xformers/components/attention/favor.py
index 340da110e8..426bfffc38 100644
--- a/xformers/components/attention/favor.py
+++ b/xformers/components/attention/favor.py
@@ -104,6 +104,10 @@ def __init__(
 
         self.feature_map: FeatureMap = feature_map_constructor(**feature_settings)  # type: ignore
 
+        # Properties specific to this attention mechanism
+        self.supports_attention_mask = False
+        self.supports_key_padding_mask = False
+
     @staticmethod
     def _maybe_promote(x: torch.Tensor) -> torch.Tensor:
         # Only promote fp16 buffers, bfloat16 would be fine for instance
diff --git a/xformers/components/attention/fourier_mix.py b/xformers/components/attention/fourier_mix.py
index 0079ed4a4d..8ceaa01731 100644
--- a/xformers/components/attention/fourier_mix.py
+++ b/xformers/components/attention/fourier_mix.py
@@ -20,6 +20,9 @@ def __init__(self, dropout: float, *_, **__):
         """
         super().__init__()
         self.attn_drop = torch.nn.Dropout(dropout, inplace=False)
+
+        # Properties specific to this attention mechanism
+        self.supports_attention_mask = False
         self.requires_input_projection = False
 
     def forward(self, q: torch.Tensor, *_, **__):
diff --git a/xformers/components/attention/global_tokens.py b/xformers/components/attention/global_tokens.py
index 0d188fde5d..fdb60576dc 100644
--- a/xformers/components/attention/global_tokens.py
+++ b/xformers/components/attention/global_tokens.py
@@ -78,7 +78,10 @@ def __init__(
             else maybe_sparsify(self.attention_mask)
         )
 
+        # Properties specific to this attention mechanism
         self.requires_same_k_q_dimensions = True
+        self.supports_attention_mask = False
+        self.supports_key_padding_mask = False
 
     def forward(
         self,
diff --git a/xformers/components/attention/lambda_layer.py b/xformers/components/attention/lambda_layer.py
index dc8130c902..0002a20cbc 100644
--- a/xformers/components/attention/lambda_layer.py
+++ b/xformers/components/attention/lambda_layer.py
@@ -44,7 +44,11 @@ def __init__(self, dropout: float, seq_len: int, dim_head: int, *_, **__):
         )
         self.rel_pos = calc_rel_pos(seq_len)
         self.attn_drop = torch.nn.Dropout(dropout, inplace=True)
+
+        # Properties specific to this attention mechanism
         self.requires_same_k_q_dimensions = True
+        self.supports_attention_mask = False
+        self.supports_key_padding_mask = False
 
     def forward(
         self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
diff --git a/xformers/components/attention/linformer.py b/xformers/components/attention/linformer.py
index 8f6b181c24..6b23d38d0b 100644
--- a/xformers/components/attention/linformer.py
+++ b/xformers/components/attention/linformer.py
@@ -42,8 +42,14 @@ def __init__(
         self.F = nn.Linear(seq_len, k, bias=False)
         self.attn_drop = nn.Dropout(dropout, inplace=False)
         self.seq_len = seq_len
+
+        # MHA related flags:
+        # kq need to have the same dimension
         self.requires_same_k_q_dimensions = True
 
+        # Properties specific to this attention mechanism
+        self.supports_attention_mask = False
+
     def forward(
         self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
     ):
diff --git a/xformers/components/attention/local.py b/xformers/components/attention/local.py
index d3699becdf..68df4bca3d 100644
--- a/xformers/components/attention/local.py
+++ b/xformers/components/attention/local.py
@@ -77,6 +77,10 @@ def __init__(
         self.attention_mask: Optional[torch.Tensor] = None
         self.requires_same_k_q_dimensions = True
 
+        # Properties specific to this attention mechanism
+        self.supports_attention_mask = True
+        self.supports_key_padding_mask = False
+
     def _get_local_mask(self, shape: torch.Size) -> torch.Tensor:
         window_size = self.window_size * 2 + 1 if self.causal else self.window_size
         mask = local_1d_pattern(shape[1], window_size)
diff --git a/xformers/components/attention/nystrom.py b/xformers/components/attention/nystrom.py
index 002ef84898..f34b788ae0 100644
--- a/xformers/components/attention/nystrom.py
+++ b/xformers/components/attention/nystrom.py
@@ -154,6 +154,10 @@ def __init__(
         self.causal_mask_2: Optional[torch.Tensor] = None
         self.causal_mask_3: Optional[torch.Tensor] = None
 
+        # This attention does not support attention masks
+        self.supports_attention_mask = False
+        self.supports_key_padding_mask = True
+
     def forward(
         self,
         q: torch.Tensor,
diff --git a/xformers/components/attention/ortho.py b/xformers/components/attention/ortho.py
index e426a65312..94a370e6ba 100644
--- a/xformers/components/attention/ortho.py
+++ b/xformers/components/attention/ortho.py
@@ -72,6 +72,10 @@ def __init__(
         self.subsample_fraction = subsample_fraction
         self.landmark_selection = landmark_selection
 
+        # Properties specific to this attention mechanism
+        self.supports_attention_mask = True
+        self.supports_key_padding_mask = False
+
     def forward(
         self,
         q: torch.Tensor,
diff --git a/xformers/components/attention/random.py b/xformers/components/attention/random.py
index 55ad307737..5e3ee08e69 100644
--- a/xformers/components/attention/random.py
+++ b/xformers/components/attention/random.py
@@ -68,6 +68,11 @@ def __init__(
         self.rand_attention_mask: Optional[torch.Tensor] = None
         self.constant_masking = constant_masking
         self.force_sparsity = force_sparsity
+
+        # Properties specific to this attention mechanism
+        self.supports_attention_mask = True
+        self.supports_key_padding_mask = False
+
         self.requires_same_k_q_dimensions = True
 
     def _get_rand_mask(self, shape: torch.Size) -> torch.Tensor:
diff --git a/xformers/components/attention/scaled_dot_product.py b/xformers/components/attention/scaled_dot_product.py
index a8560629cf..3db7b0ada7 100644
--- a/xformers/components/attention/scaled_dot_product.py
+++ b/xformers/components/attention/scaled_dot_product.py
@@ -57,6 +57,10 @@ def __init__(
         else:
             self.mask = None
 
+        # Properties specific to this attention mechanism
+        self.supports_attention_mask = True
+        self.supports_key_padding_mask = False
+
     def forward(
         self,
         q: torch.Tensor,
diff --git a/xformers/components/multi_head_dispatch.py b/xformers/components/multi_head_dispatch.py
index f754324b3e..a7eb91f459 100644
--- a/xformers/components/multi_head_dispatch.py
+++ b/xformers/components/multi_head_dispatch.py
@@ -165,10 +165,21 @@ def forward(
                 + "In that case causality is ill-determined. Please pad your sequences accordingly"
             )
 
+        kw_mask_args = {}
+        if att_mask is not None:
+            assert (
+                self.attention.supports_attention_mask
+            ), "This attention does not support attention masks"
+            kw_mask_args["att_mask"] = att_mask
+
+        if key_padding_mask is not None:
+            assert (
+                self.attention.supports_key_padding_mask
+            ), "This attention does not support key padding masks"
+            kw_mask_args["key_padding_mask"] = key_padding_mask
+
         if self.attention.requires_skip_multi_head:
-            return self.attention(
-                query, key, value, att_mask=att_mask, key_padding_mask=key_padding_mask
-            )
+            return self.attention(query, key, value, **kw_mask_args)
 
         # Calculate query, key, values for all heads in batch
         if self.attention.requires_input_projection:
@@ -199,9 +210,7 @@ def forward(
             v = reshape_fn(v, B, S_K, self.num_heads, self.dim_k)
 
         # Self-attend
-        y = self.attention(
-            q=q, k=k, v=v, att_mask=att_mask, key_padding_mask=key_padding_mask
-        )
+        y = self.attention(q=q, k=k, v=v, **kw_mask_args)
 
         # Re-assemble all head outputs side by side
         y = (
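Usage note (not part of the patch): a minimal sketch of how the capability flags
introduced above could be consulted by calling code, mirroring the asserts added
to the multi-head dispatcher. It assumes the build_attention registry helper from
xformers.components.attention; the config values and tensor shapes below are
illustrative only, not taken from this diff.

    # Minimal sketch: only forward a mask if the attention mechanism advertises
    # support for it. `supports_attention_mask` is the flag added in this patch;
    # `build_attention` is assumed to be the existing registry helper.
    import torch

    from xformers.components.attention import build_attention

    # Illustrative config; "scaled_dot_product" sets supports_attention_mask = True
    attention = build_attention(
        {"name": "scaled_dot_product", "dropout": 0.0, "causal": False}
    )

    q = k = v = torch.randn(2, 16, 32)  # (batch, sequence, embedding)

    mask_kwargs = {}
    if attention.supports_attention_mask:
        # Boolean mask, True meaning "keep this position"
        mask_kwargs["att_mask"] = torch.ones(16, 16, dtype=torch.bool)

    out = attention(q=q, k=k, v=v, **mask_kwargs)  # same shape as the inputs

Passing a mask to a mechanism whose flag is False now trips the assert in
MultiHeadDispatch rather than having the mask silently ignored, which is what the
updated tests in test_block_factory.py exercise with pytest.raises(AssertionError).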