diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py
index 792cf937f32093..8c126f5d809c9e 100644
--- a/src/transformers/models/musicgen/modeling_musicgen.py
+++ b/src/transformers/models/musicgen/modeling_musicgen.py
@@ -545,7 +545,6 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
     )
 
 
-# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->Musicgen
 class MusicgenSdpaAttention(MusicgenAttention):
     def forward(
         self,
@@ -572,6 +571,23 @@ def forward(
                 output_attentions=output_attentions,
             )
 
+        if (
+            attention_mask is not None
+            and (attention_mask.mean(dim=[1, 2, 3]) <= torch.finfo(attention_mask.dtype).min).any()
+        ):
+            logger.warning_once(
+                '`torch.nn.functional.scaled_dot_product_attention` does not support having an empty attention mask. Falling back to the manual attention implementation. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+                " Note that this probably happens because `guidance_scale>1` or because you used `get_unconditional_inputs`. See https://github.com/huggingface/transformers/issues/31189 for more information."
+            )
+            return super().forward(
+                hidden_states,
+                key_value_states=key_value_states,
+                past_key_value=past_key_value,
+                attention_mask=attention_mask,
+                layer_head_mask=layer_head_mask,
+                output_attentions=output_attentions,
+            )
+
         # if key_value_states are provided this layer is used as a cross-attention layer
         # for the decoder
         is_cross_attention = key_value_states is not None
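
For context, here is a minimal sketch (not part of the patch) of what the new guard detects: an attention mask whose values are all at the dtype minimum, as happens for the unconditional branch when `guidance_scale>1` or when `get_unconditional_inputs` is used. SDPA can produce NaNs for such fully masked rows, whereas the eager path degrades to a uniform attention distribution, hence the fallback. The shapes below are illustrative assumptions, not taken from the model.

import torch

# Illustrative shapes: (batch, 1, tgt_len, src_len), broadcastable over attention heads.
batch, tgt_len, src_len = 2, 4, 6
min_value = torch.finfo(torch.float32).min

# Sample 0 is fully masked (every position at the dtype minimum), sample 1 is unmasked.
attention_mask = torch.zeros(batch, 1, tgt_len, src_len)
attention_mask[0] = min_value

# Same detection logic as the patch: a per-sample mean at or below the dtype minimum
# means every position of that sample is masked out, so the eager fallback is taken.
fully_masked = attention_mask.mean(dim=[1, 2, 3]) <= torch.finfo(attention_mask.dtype).min
print(fully_masked)  # tensor([ True, False])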