Skip to content

Commit

Permalink
Output None as attention when layer is skipped (huggingface#30597)
Browse files Browse the repository at this point in the history
* Output `None` as attention when layer is skipped

* Add test for output_attentions
  • Loading branch information
jonghwanhyeon authored May 2, 2024
1 parent 39359e5 commit 4c94093
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/transformers/models/wavlm/modeling_wavlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,7 @@ def forward(
hidden_states, position_bias = layer_outputs[:2]

if skip_the_layer:
layer_outputs = (None, None)
layer_outputs = (None, None, None)

if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[2],)
Expand Down Expand Up @@ -810,7 +810,7 @@ def forward(
hidden_states, position_bias = layer_outputs[:2]

if skip_the_layer:
layer_outputs = (None, None)
layer_outputs = (None, None, None)

if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[2],)
Expand Down
13 changes: 13 additions & 0 deletions tests/models/wavlm/test_modeling_wavlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,15 @@ def check_seq_classifier_training(self, config, input_values, *args):

loss.backward()

def check_output_attentions(self, config, input_values, attention_mask):
model = WavLMModel(config=config)
model.config.layerdrop = 1.0
model.to(torch_device)
model.train()

outputs = model(input_values, attention_mask=attention_mask, output_attentions=True)
self.parent.assertTrue(len(outputs.attentions) > 0)

def check_labels_out_of_vocab(self, config, input_values, *args):
model = WavLMForCTC(config)
model.to(torch_device)
Expand Down Expand Up @@ -354,6 +363,10 @@ def test_seq_classifier_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_seq_classifier_training(*config_and_inputs)

def test_output_attentions(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_output_attentions(*config_and_inputs)

def test_labels_out_of_vocab(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
Expand Down

0 comments on commit 4c94093

Please sign in to comment.