diff --git a/tests/models/wavlm/test_modeling_wavlm.py b/tests/models/wavlm/test_modeling_wavlm.py index c0a8eed2096f35..d3670308bbd52f 100644 --- a/tests/models/wavlm/test_modeling_wavlm.py +++ b/tests/models/wavlm/test_modeling_wavlm.py @@ -288,6 +288,14 @@ def check_seq_classifier_training(self, config, input_values, *args): loss.backward() + def check_output_attentions(self, config, input_values, attention_mask): + model = WavLMModel(config=config) + model.to(torch_device) + model.train() + + outputs = model(input_values, attention_mask=attention_mask, output_attentions=True) + self.parent.assertTrue(len(outputs.attentions) > 0) + def check_labels_out_of_vocab(self, config, input_values, *args): model = WavLMForCTC(config) model.to(torch_device) @@ -354,6 +362,10 @@ def test_seq_classifier_train(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_seq_classifier_training(*config_and_inputs) + def test_output_attentions(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_output_attentions(*config_and_inputs) + def test_labels_out_of_vocab(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_labels_out_of_vocab(*config_and_inputs)