From c84cab45ba3af64a1e202d9ec42114ee6fb3c7fd Mon Sep 17 00:00:00 2001
From: Yoach Lacombe
Date: Mon, 3 Jun 2024 15:27:15 +0200
Subject: [PATCH 1/3] fix sdpa musicgen

---
 .../models/musicgen/modeling_musicgen.py           | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py
index 810f34f7804716..075d465becae58 100644
--- a/src/transformers/models/musicgen/modeling_musicgen.py
+++ b/src/transformers/models/musicgen/modeling_musicgen.py
@@ -571,6 +571,20 @@ def forward(
             layer_head_mask=layer_head_mask,
             output_attentions=output_attentions,
         )
+        # Ignore copy
+        if attention_mask is not None and (attention_mask.mean(dim=[1,2,3]) <= torch.finfo(attention_mask.dtype).min).any():
+            logger.warning_once(
+                '`torch.nn.functional.scaled_dot_product_attention` does not support having an empty attention mask. Falling back to the manual attention implementation. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+                'Note that this probably happens because `guidance_scale>1` or because you used `get_unconditional_inputs`. See https://github.com/huggingface/transformers/issues/31189 for more information.'
+            )
+            return super().forward(
+                hidden_states,
+                key_value_states=key_value_states,
+                past_key_value=past_key_value,
+                attention_mask=attention_mask,
+                layer_head_mask=layer_head_mask,
+                output_attentions=output_attentions,
+                )
 
         # if key_value_states are provided this layer is used as a cross-attention layer
         # for the decoder

From cc3bbd68413de61dc433de204228eb3202bb8af1 Mon Sep 17 00:00:00 2001
From: Yoach Lacombe
Date: Mon, 3 Jun 2024 16:58:48 +0200
Subject: [PATCH 2/3] make style

---
 src/transformers/models/musicgen/modeling_musicgen.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py
index 075d465becae58..69d202de66f506 100644
--- a/src/transformers/models/musicgen/modeling_musicgen.py
+++ b/src/transformers/models/musicgen/modeling_musicgen.py
@@ -571,11 +571,15 @@ def forward(
             layer_head_mask=layer_head_mask,
             output_attentions=output_attentions,
         )
+
         # Ignore copy
-        if attention_mask is not None and (attention_mask.mean(dim=[1,2,3]) <= torch.finfo(attention_mask.dtype).min).any():
+        if (
+            attention_mask is not None
+            and (attention_mask.mean(dim=[1, 2, 3]) <= torch.finfo(attention_mask.dtype).min).any()
+        ):
             logger.warning_once(
                 '`torch.nn.functional.scaled_dot_product_attention` does not support having an empty attention mask. Falling back to the manual attention implementation. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-                'Note that this probably happens because `guidance_scale>1` or because you used `get_unconditional_inputs`. See https://github.com/huggingface/transformers/issues/31189 for more information.'
+                "Note that this probably happens because `guidance_scale>1` or because you used `get_unconditional_inputs`. See https://github.com/huggingface/transformers/issues/31189 for more information."
             )
             return super().forward(
                 hidden_states,
@@ -584,7 +588,7 @@ def forward(
                 attention_mask=attention_mask,
                 layer_head_mask=layer_head_mask,
                 output_attentions=output_attentions,
-                )
+            )
 
         # if key_value_states are provided this layer is used as a cross-attention layer
         # for the decoder

From 605df12a77e12da9344db2d4d5710afc2cfc5214 Mon Sep 17 00:00:00 2001
From: Yoach Lacombe
Date: Fri, 14 Jun 2024 10:58:57 +0200
Subject: [PATCH 3/3] remove copied from statement from Musicgen SDPA

---
 src/transformers/models/musicgen/modeling_musicgen.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py
index fc6cc6c7e1c42e..8c126f5d809c9e 100644
--- a/src/transformers/models/musicgen/modeling_musicgen.py
+++ b/src/transformers/models/musicgen/modeling_musicgen.py
@@ -545,7 +545,6 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
     )
 
 
-# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->Musicgen
 class MusicgenSdpaAttention(MusicgenAttention):
     def forward(
         self,
@@ -572,7 +571,6 @@ def forward(
             output_attentions=output_attentions,
         )
 
-        # Ignore copy
         if (
            attention_mask is not None
            and (attention_mask.mean(dim=[1, 2, 3]) <= torch.finfo(attention_mask.dtype).min).any()
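
For context on the check these patches introduce, here is a minimal, illustrative sketch (not part of the patch series) of why a fully masked additive attention mask trips the condition `attention_mask.mean(dim=[1, 2, 3]) <= torch.finfo(attention_mask.dtype).min` and triggers the fallback to the eager attention path. The tensor shapes and the float32 dtype are assumptions chosen for the example.

import torch

# Illustrative sketch only: build a 4D additive mask of shape
# (batch, 1, tgt_len, src_len), where 0.0 keeps a position and
# torch.finfo(dtype).min masks it out.
batch, tgt_len, src_len = 2, 4, 4
dtype = torch.float32
min_value = torch.finfo(dtype).min

attention_mask = torch.zeros(batch, 1, tgt_len, src_len, dtype=dtype)
attention_mask[1] = min_value  # second batch item is fully masked ("empty" mask)

# The check added by PATCH 1/3: the per-item mean can only reach dtype-min
# when every position of that item is masked. Per the patch's warning, this
# typically happens for the unconditional half of the batch when
# guidance_scale > 1 or when get_unconditional_inputs() is used, and SDPA
# does not support such an empty mask.
is_fully_masked = attention_mask.mean(dim=[1, 2, 3]) <= min_value
print(is_fully_masked)  # tensor([False,  True]) -> falls back to the eager attention path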