Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickvonplaten committed Jun 29, 2020
1 parent 0e4c802 commit 6e200ea
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
3 changes: 3 additions & 0 deletions src/transformers/modeling_longformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1590,6 +1590,9 @@ def forward(

pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
print(f"logits: {logits.shape}")
print(f"pooled_output: {pooled_output.shape}")
print(f"num_choices: {num_choices}")
reshaped_logits = logits.view(-1, num_choices)

outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
Expand Down
12 changes: 6 additions & 6 deletions tests/test_modeling_longformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,8 @@ def prepare_config_and_inputs_for_common(self):
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
global_attention_mask = torch.zeros_like(input_ids)
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask, "global_attention_mask": global_attention_mask}
return config, inputs_dict

def prepare_config_and_inputs_for_question_answering(self):
Expand Down Expand Up @@ -319,11 +320,10 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
(
LongformerModel,
LongformerForMaskedLM,
# TODO: make tests pass for those models
# LongformerForSequenceClassification,
# LongformerForQuestionAnswering,
# LongformerForTokenClassification,
# LongformerForMultipleChoice,
LongformerForSequenceClassification,
LongformerForQuestionAnswering,
LongformerForTokenClassification,
LongformerForMultipleChoice,
)
if is_torch_available()
else ()
Expand Down

0 comments on commit 6e200ea

Please sign in to comment.