diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py
index b7e221b3595a99..32bbefd3a1d49f 100644
--- a/tests/test_modeling_flax_common.py
+++ b/tests/test_modeling_flax_common.py
@@ -168,43 +168,43 @@ def recursive_check(tuple_object, dict_object):
             dict_inputs = self._prepare_for_class(inputs_dict, model_class)
             check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
 
-    def check_outputs(self, fxo, pto, model_class, names, context, results):
+    def check_outputs(self, fx_outputs, pt_outputs, model_class, names, context, results):
         """
         Args:
             model_class: The class of the model that is currently testing. For example, ..., etc.
                 Currently unused, but it could make debugging easier and faster.
 
-            names: A string, or a list of strings. These specify what fxo/pto represent in the model outputs.
+            names: A string, or a list of strings. These specify what fx_outputs/pt_outputs represent in the model outputs.
                 Currently unused, but in the future, we could use this information to make the error message clearer
                 by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax.
         """
-        if type(fxo) in [tuple, list]:
-            self.assertEqual(type(fxo), type(pto))
-            self.assertEqual(len(fxo), len(pto))
+        if type(fx_outputs) in [tuple, list]:
+            self.assertEqual(type(fx_outputs), type(pt_outputs))
+            self.assertEqual(len(fx_outputs), len(pt_outputs))
             if type(names) == tuple:
-                for fo, po, name in zip(fxo, pto, names):
+                for fo, po, name in zip(fx_outputs, pt_outputs, names):
                     self.check_outputs(fo, po, model_class, names=name, context=context, results=results)
             elif type(names) == str:
-                for idx, (fo, po) in enumerate(zip(fxo, pto)):
+                for idx, (fo, po) in enumerate(zip(fx_outputs, pt_outputs)):
                     self.check_outputs(fo, po, model_class, names=f"{names}_{idx}", context=context, results=results)
             else:
                 raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.")
-        elif isinstance(fxo, jnp.ndarray):
-            self.assertTrue(isinstance(pto, torch.Tensor))
+        elif isinstance(fx_outputs, jnp.ndarray):
+            self.assertTrue(isinstance(pt_outputs, torch.Tensor))
 
-            # Using `np.asarray` gives `ValueError: assignment destination is read-only` at the line `fxo[fx_nans] = 0`.
-            fxo = np.array(fxo)
-            pto = pto.detach().to("cpu").numpy()
+            # Using `np.asarray` gives `ValueError: assignment destination is read-only` at the line `fx_outputs[fx_nans] = 0`.
+            fx_outputs = np.array(fx_outputs)
+            pt_outputs = pt_outputs.detach().to("cpu").numpy()
 
-            fx_nans = np.isnan(fxo)
-            pt_nans = np.isnan(pto)
+            fx_nans = np.isnan(fx_outputs)
+            pt_nans = np.isnan(pt_outputs)
 
-            pto[fx_nans] = 0
-            fxo[fx_nans] = 0
-            pto[pt_nans] = 0
-            fxo[pt_nans] = 0
+            pt_outputs[fx_nans] = 0
+            fx_outputs[fx_nans] = 0
+            pt_outputs[pt_nans] = 0
+            fx_outputs[pt_nans] = 0
 
-            max_diff = np.amax(np.abs(fxo - pto))
+            max_diff = np.amax(np.abs(fx_outputs - pt_outputs))
             self.assertLessEqual(max_diff, 1e-5)
 
             if context not in results[model_class.__name__]:
@@ -214,7 +214,7 @@ def check_outputs(self, fxo, pto, model_class, names, context, results):
             results[model_class.__name__][context][names].append(float(max_diff))
         else:
             raise ValueError(
-                f"`fxo` should be a `tuple` or an instance of `jnp.ndarray`. Got {type(fxo)} instead."
+                f"`fx_outputs` should be a `tuple` or an instance of `jnp.ndarray`. Got {type(fx_outputs)} instead."
             )
 
     @is_pt_flax_cross_test
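
For reference, here is a minimal standalone sketch of the NaN-masked comparison that `check_outputs` applies to each pair of Flax/PyTorch output arrays: positions that are NaN in either output are zeroed in both before the maximum absolute difference is bounded by 1e-5. The helper name `max_abs_diff` and the sample inputs are illustrative only and are not part of this diff.

```python
import numpy as np


def max_abs_diff(fx_outputs, pt_outputs):
    # Copy into writable arrays (np.asarray may return a read-only view,
    # which would break the in-place NaN masking below).
    fx_outputs = np.array(fx_outputs, dtype=np.float64)
    pt_outputs = np.array(pt_outputs, dtype=np.float64)

    fx_nans = np.isnan(fx_outputs)
    pt_nans = np.isnan(pt_outputs)

    # Zero out positions that are NaN in either output so they don't
    # dominate the elementwise difference.
    pt_outputs[fx_nans] = 0
    fx_outputs[fx_nans] = 0
    pt_outputs[pt_nans] = 0
    fx_outputs[pt_nans] = 0

    return float(np.amax(np.abs(fx_outputs - pt_outputs)))


# NaN in one framework's output is ignored; remaining values must agree closely.
assert max_abs_diff([1.0, float("nan")], [1.0, 2.0]) <= 1e-5
```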