From 2c1eaf6c0242badb86f3d626811ee3d7d220ec06 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 22 Aug 2023 21:53:09 +0900 Subject: [PATCH] Check output_attentions is False in BetterTransformer (#1306) add checks on output_attentions --- .../models/encoder_models.py | 53 +++++++++++++++---- 1 file changed, 42 insertions(+), 11 deletions(-) diff --git a/optimum/bettertransformer/models/encoder_models.py b/optimum/bettertransformer/models/encoder_models.py index c56f20aae8d..20f7f4de50c 100644 --- a/optimum/bettertransformer/models/encoder_models.py +++ b/optimum/bettertransformer/models/encoder_models.py @@ -288,6 +288,7 @@ def __init__(self, bert_layer, config): self.validate_bettertransformer() def forward(self, hidden_states, attention_mask, *_): + # No check on output_attentions here as roformer relies on BertLayerBetterTransformer but does not pass output_attentions as keyword argument. if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled(): if hidden_states.is_nested: attention_mask = None @@ -463,7 +464,10 @@ def __init__(self, bart_layer, config): self.validate_bettertransformer() - def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__): + def forward(self, hidden_states, attention_mask, output_attentions: bool, position_bias=None, *_, **__): + if output_attentions: + raise ValueError("output_attentions=True can not be supported with BetterTransformer.") + if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled(): if not hasattr(hidden_states, "original_shape"): original_shape = hidden_states.shape @@ -655,7 +659,10 @@ def __init__(self, mbart_layer, config): self.validate_bettertransformer() - def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__): + def forward(self, hidden_states, attention_mask, output_attentions: bool, position_bias=None, *_, **__): + if output_attentions: + raise ValueError("output_attentions=True can not be supported with BetterTransformer.") + if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled(): if not hasattr(hidden_states, "original_shape"): original_shape = hidden_states.shape @@ -842,7 +849,10 @@ def __init__(self, bert_layer, config): self.validate_bettertransformer() - def forward(self, hidden_states, attn_mask, head_mask=None, output_attentions=None, *_): + def forward(self, hidden_states, attn_mask, output_attentions: bool, head_mask=None, *_): + if output_attentions: + raise ValueError("output_attentions=True can not be supported with BetterTransformer.") + if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled(): if hidden_states.is_nested: attn_mask = None @@ -1019,7 +1029,10 @@ def __init__(self, whisper_layer, config): self.validate_bettertransformer() - def forward(self, hidden_states, attention_mask, *_, **__): + def forward(self, hidden_states, attention_mask, output_attentions: bool, *_, **__): + if output_attentions: + raise ValueError("output_attentions=True can not be supported with BetterTransformer.") + if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled(): attention_mask = None # attention mask seems to be always None: https://github.com/huggingface/transformers/blob/94b3f544a1f5e04b78d87a2ae32a7ac252e22e31/src/transformers/models/whisper/modeling_whisper.py#L690 @@ -1139,7 +1152,10 @@ def __init__(self, vit_layer, config): self.validate_bettertransformer() - def forward(self, hidden_states, *_, **__): + def forward(self, hidden_states, output_attentions: bool, *_, **__): + if output_attentions: + raise ValueError("output_attentions=True can not be supported with BetterTransformer.") + if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled(): attention_mask = None @@ -1259,7 +1275,10 @@ def __init__(self, vilt_layer, config): self.validate_bettertransformer() - def forward(self, hidden_states, *_, **__): + def forward(self, hidden_states, layer_head_mask, output_attentions: bool, *_, **__): + if output_attentions: + raise ValueError("output_attentions=True can not be supported with BetterTransformer.") + if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled(): attention_mask = None @@ -1375,7 +1394,10 @@ def __init__(self, wav2vec2_layer, config): self.validate_bettertransformer() - def forward(self, hidden_states, attention_mask, **__): + def forward(self, hidden_states, attention_mask, output_attentions: bool, **__): + if output_attentions: + raise ValueError("output_attentions=True can not be supported with BetterTransformer.") + if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled(): if hidden_states.is_nested: attention_mask = None @@ -1497,7 +1519,10 @@ def __init__(self, fsmt_layer, config): self.validate_bettertransformer() - def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__): + def forward(self, hidden_states, attention_mask, output_attentions: bool, position_bias=None, *_, **__): + if output_attentions: + raise ValueError("output_attentions=True can not be supported with BetterTransformer.") + if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled(): if not hasattr(hidden_states, "original_shape"): original_shape = hidden_states.shape @@ -1638,7 +1663,10 @@ def __init__(self, prophetnet_layer, config): self.validate_bettertransformer() - def forward(self, hidden_states, attention_mask, *_, **__): + def forward(self, hidden_states, attention_mask, output_attentions: bool, *_, **__): + if output_attentions: + raise ValueError("output_attentions=True can not be supported with BetterTransformer.") + if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled(): if not hasattr(hidden_states, "original_shape"): original_shape = hidden_states.shape @@ -1772,10 +1800,13 @@ def __init__(self, layer, config): self.validate_bettertransformer() - def forward(self, hidden_states, attention_mask, *_, **__): + def forward(self, hidden_states, attention_mask, causal_attention_mask, output_attentions: bool, *_, **__): + if output_attentions: + raise ValueError("output_attentions=True can not be supported with BetterTransformer.") + if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled(): # we expect attention_mask to be None in the vision model - if attention_mask is not None: + if attention_mask is not None or causal_attention_mask is not None: raise ValueError( "Please do not use attention masks when using `BetterTransformer` converted vision models" )