diff --git a/tests/test_modeling_longformer.py b/tests/test_modeling_longformer.py
index 59b55866543a7d..1819e1aa06057c 100644
--- a/tests/test_modeling_longformer.py
+++ b/tests/test_modeling_longformer.py
@@ -285,7 +285,14 @@ def prepare_config_and_inputs_for_common(self):
             token_labels,
             choice_labels,
         ) = config_and_inputs
-        inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
+        # All-zero global attention mask, passed explicitly so the common tests exercise this input.
+        global_attention_mask = torch.zeros_like(input_ids)
+        inputs_dict = {
+            "input_ids": input_ids,
+            "token_type_ids": token_type_ids,
+            "attention_mask": input_mask,
+            "global_attention_mask": global_attention_mask,
+        }
         return config, inputs_dict
 
     def prepare_config_and_inputs_for_question_answering(self):
@@ -319,11 +326,10 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
         (
             LongformerModel,
             LongformerForMaskedLM,
-            # TODO: make tests pass for those models
-            # LongformerForSequenceClassification,
-            # LongformerForQuestionAnswering,
-            # LongformerForTokenClassification,
-            # LongformerForMultipleChoice,
+            LongformerForSequenceClassification,
+            LongformerForQuestionAnswering,
+            LongformerForTokenClassification,
+            LongformerForMultipleChoice,
         )
         if is_torch_available()
         else ()