From 6e200ea011cdbd2d5fe362bc954032415b289654 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 29 Jun 2020 17:29:11 +0000 Subject: [PATCH] fix tests --- src/transformers/modeling_longformer.py | 3 +++ tests/test_modeling_longformer.py | 12 ++++++------ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/transformers/modeling_longformer.py b/src/transformers/modeling_longformer.py index 305ca4b4924014..dbcbe70a042e57 100644 --- a/src/transformers/modeling_longformer.py +++ b/src/transformers/modeling_longformer.py @@ -1590,6 +1590,9 @@ def forward( pooled_output = self.dropout(pooled_output) logits = self.classifier(pooled_output) + print(f"logits: {logits.shape}") + print(f"pooled_output: {pooled_output.shape}") + print(f"num_choices: {num_choices}") reshaped_logits = logits.view(-1, num_choices) outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here diff --git a/tests/test_modeling_longformer.py b/tests/test_modeling_longformer.py index 59b55866543a7d..1819e1aa06057c 100644 --- a/tests/test_modeling_longformer.py +++ b/tests/test_modeling_longformer.py @@ -285,7 +285,8 @@ def prepare_config_and_inputs_for_common(self): token_labels, choice_labels, ) = config_and_inputs - inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask} + global_attention_mask = torch.zeros_like(input_ids) + inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask, "global_attention_mask": global_attention_mask} return config, inputs_dict def prepare_config_and_inputs_for_question_answering(self): @@ -319,11 +320,10 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase): ( LongformerModel, LongformerForMaskedLM, - # TODO: make tests pass for those models - # LongformerForSequenceClassification, - # LongformerForQuestionAnswering, - # LongformerForTokenClassification, - # LongformerForMultipleChoice, + LongformerForSequenceClassification, + LongformerForQuestionAnswering, + LongformerForTokenClassification, + LongformerForMultipleChoice, ) if is_torch_available() else ()