From 4e9bafdd145042df3469b4d4a1ead592c39ed299 Mon Sep 17 00:00:00 2001 From: Jonghwan Hyeon Date: Thu, 2 May 2024 15:44:33 +0000 Subject: [PATCH] Add test for output_attentions --- tests/models/wavlm/test_modeling_wavlm.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/models/wavlm/test_modeling_wavlm.py b/tests/models/wavlm/test_modeling_wavlm.py index c0a8eed2096f35..3cf4348f6c83d7 100644 --- a/tests/models/wavlm/test_modeling_wavlm.py +++ b/tests/models/wavlm/test_modeling_wavlm.py @@ -288,6 +288,15 @@ def check_seq_classifier_training(self, config, input_values, *args): loss.backward() + def check_output_attentions(self, config, input_values, attention_mask): + model = WavLMModel(config=config) + model.config.layerdrop = 1.0 + model.to(torch_device) + model.train() + + outputs = model(input_values, attention_mask=attention_mask, output_attentions=True) + self.parent.assertTrue(len(outputs.attentions) > 0) + def check_labels_out_of_vocab(self, config, input_values, *args): model = WavLMForCTC(config) model.to(torch_device) @@ -354,6 +363,10 @@ def test_seq_classifier_train(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_seq_classifier_training(*config_and_inputs) + def test_output_attentions(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_output_attentions(*config_and_inputs) + def test_labels_out_of_vocab(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_labels_out_of_vocab(*config_and_inputs)