From 4c940934dad4cad6992d8470a71942db6ab6c0ac Mon Sep 17 00:00:00 2001
From: Jonghwan Hyeon
Date: Fri, 3 May 2024 01:25:19 +0900
Subject: [PATCH] Output `None` as attention when layer is skipped (#30597)

* Output `None` as attention when layer is skipped

* Add test for output_attentions
---
 src/transformers/models/wavlm/modeling_wavlm.py |  4 ++--
 tests/models/wavlm/test_modeling_wavlm.py       | 13 +++++++++++++
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py
index 5d1a44c00a2302..1db656da60a538 100755
--- a/src/transformers/models/wavlm/modeling_wavlm.py
+++ b/src/transformers/models/wavlm/modeling_wavlm.py
@@ -727,7 +727,7 @@ def forward(
                 hidden_states, position_bias = layer_outputs[:2]
 
             if skip_the_layer:
-                layer_outputs = (None, None)
+                layer_outputs = (None, None, None)
 
             if output_attentions:
                 all_self_attentions = all_self_attentions + (layer_outputs[2],)
@@ -810,7 +810,7 @@ def forward(
                 hidden_states, position_bias = layer_outputs[:2]
 
             if skip_the_layer:
-                layer_outputs = (None, None)
+                layer_outputs = (None, None, None)
 
             if output_attentions:
                 all_self_attentions = all_self_attentions + (layer_outputs[2],)
diff --git a/tests/models/wavlm/test_modeling_wavlm.py b/tests/models/wavlm/test_modeling_wavlm.py
index c0a8eed2096f35..3cf4348f6c83d7 100644
--- a/tests/models/wavlm/test_modeling_wavlm.py
+++ b/tests/models/wavlm/test_modeling_wavlm.py
@@ -288,6 +288,15 @@ def check_seq_classifier_training(self, config, input_values, *args):
 
         loss.backward()
 
+    def check_output_attentions(self, config, input_values, attention_mask):
+        model = WavLMModel(config=config)
+        model.config.layerdrop = 1.0
+        model.to(torch_device)
+        model.train()
+
+        outputs = model(input_values, attention_mask=attention_mask, output_attentions=True)
+        self.parent.assertTrue(len(outputs.attentions) > 0)
+
     def check_labels_out_of_vocab(self, config, input_values, *args):
         model = WavLMForCTC(config)
         model.to(torch_device)
@@ -354,6 +363,10 @@ def test_seq_classifier_train(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.check_seq_classifier_training(*config_and_inputs)
 
+    def test_output_attentions(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.check_output_attentions(*config_and_inputs)
+
     def test_labels_out_of_vocab(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
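
The fix above works because the attention-collection step always indexes layer_outputs[2]; when layerdrop skipped a layer, the old two-element placeholder made that index raise IndexError. Below is a minimal, self-contained sketch of that failure mode and the fix. The loop body is a hypothetical stand-in, not the actual WavLM encoder; only the tuple shapes and the indexing mirror the patched code.

# Hypothetical reproduction sketch -- not library code.
def collect_attentions(skip_the_layer: bool, output_attentions: bool):
    all_self_attentions = ()

    if not skip_the_layer:
        # A real encoder layer returns (hidden_states, position_bias, attn_weights).
        layer_outputs = ("hidden_states", "position_bias", "attn_weights")

    if skip_the_layer:
        # Before the patch this was (None, None), so layer_outputs[2] below
        # raised IndexError whenever a layer was dropped during training with
        # output_attentions=True. The third None keeps the tuple shape stable
        # and records the skipped layer's attention as None.
        layer_outputs = (None, None, None)

    if output_attentions:
        all_self_attentions = all_self_attentions + (layer_outputs[2],)

    return all_self_attentions

# A skipped layer now contributes None instead of crashing:
assert collect_attentions(skip_the_layer=True, output_attentions=True) == (None,)
assert collect_attentions(skip_the_layer=False, output_attentions=True) == ("attn_weights",)

The new test drives this path deterministically: layerdrop = 1.0 combined with model.train() makes essentially every layer take the skip branch, so the forward pass must return attention placeholders rather than raise.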