diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 4df1f75011ae11..5e6705eaccafe2 100644
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -812,7 +812,7 @@ def test_multigpu_data_parallel_forward(self):
         # Wrap model in nn.DataParallel
         model = torch.nn.DataParallel(model)
         with torch.no_grad():
-            _ = model(**inputs_dict)
+            _ = model(**self._prepare_for_class(inputs_dict, model_class))
 
 
 global_rng = random.Random()
diff --git a/tests/test_modeling_longformer.py b/tests/test_modeling_longformer.py
index 1819e1aa06057c..fed22060501c4f 100644
--- a/tests/test_modeling_longformer.py
+++ b/tests/test_modeling_longformer.py
@@ -286,7 +286,12 @@ def prepare_config_and_inputs_for_common(self):
             choice_labels,
         ) = config_and_inputs
         global_attention_mask = torch.zeros_like(input_ids)
-        inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask, "global_attention_mask": global_attention_mask}
+        inputs_dict = {
+            "input_ids": input_ids,
+            "token_type_ids": token_type_ids,
+            "attention_mask": input_mask,
+            "global_attention_mask": global_attention_mask,
+        }
         return config, inputs_dict
 
     def prepare_config_and_inputs_for_question_answering(self):