Skip to content

Commit

Permalink
better naming
Browse files Browse the repository at this point in the history
  • Loading branch information
ydshieh committed Mar 15, 2022
1 parent 1dd5b9f commit 37b3720
Showing 1 changed file with 20 additions and 20 deletions.
40 changes: 20 additions & 20 deletions tests/test_modeling_flax_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,43 +168,43 @@ def recursive_check(tuple_object, dict_object):
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})

def check_outputs(self, fxo, pto, model_class, names, context, results):
def check_outputs(self, fx_outputs, pt_outputs, model_class, names, context, results):
"""
Args:
model_class: The class of the model that is currently testing. For example, ..., etc.
Currently unused, but it could make debugging easier and faster.
names: A string, or a list of strings. These specify what fxo/pto represent in the model outputs.
names: A string, or a list of strings. These specify what fx_outputs/pt_outputs represent in the model outputs.
Currently unused, but in the future, we could use this information to make the error message clearer
by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax.
"""
if type(fxo) in [tuple, list]:
self.assertEqual(type(fxo), type(pto))
self.assertEqual(len(fxo), len(pto))
if type(fx_outputs) in [tuple, list]:
self.assertEqual(type(fx_outputs), type(pt_outputs))
self.assertEqual(len(fx_outputs), len(pt_outputs))
if type(names) == tuple:
for fo, po, name in zip(fxo, pto, names):
for fo, po, name in zip(fx_outputs, pt_outputs, names):
self.check_outputs(fo, po, model_class, names=name, context=context, results=results)
elif type(names) == str:
for idx, (fo, po) in enumerate(zip(fxo, pto)):
for idx, (fo, po) in enumerate(zip(fx_outputs, pt_outputs)):
self.check_outputs(fo, po, model_class, names=f"{names}_{idx}", context=context, results=results)
else:
raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.")
elif isinstance(fxo, jnp.ndarray):
self.assertTrue(isinstance(pto, torch.Tensor))
elif isinstance(fx_outputs, jnp.ndarray):
self.assertTrue(isinstance(pt_outputs, torch.Tensor))

# Using `np.asarray` gives `ValueError: assignment destination is read-only` at the line `fxo[fx_nans] = 0`.
fxo = np.array(fxo)
pto = pto.detach().to("cpu").numpy()
# Using `np.asarray` gives `ValueError: assignment destination is read-only` at the line `fx_outputs[fx_nans] = 0`.
fx_outputs = np.array(fx_outputs)
pt_outputs = pt_outputs.detach().to("cpu").numpy()

fx_nans = np.isnan(fxo)
pt_nans = np.isnan(pto)
fx_nans = np.isnan(fx_outputs)
pt_nans = np.isnan(pt_outputs)

pto[fx_nans] = 0
fxo[fx_nans] = 0
pto[pt_nans] = 0
fxo[pt_nans] = 0
pt_outputs[fx_nans] = 0
fx_outputs[fx_nans] = 0
pt_outputs[pt_nans] = 0
fx_outputs[pt_nans] = 0

max_diff = np.amax(np.abs(fxo - pto))
max_diff = np.amax(np.abs(fx_outputs - pt_outputs))
self.assertLessEqual(max_diff, 1e-5)

if context not in results[model_class.__name__]:
Expand All @@ -214,7 +214,7 @@ def check_outputs(self, fxo, pto, model_class, names, context, results):
results[model_class.__name__][context][names].append(float(max_diff))
else:
raise ValueError(
f"`fxo` should be a `tuple` or an instance of `jnp.ndarray`. Got {type(fxo)} instead."
f"`fx_outputs` should be a `tuple` or an instance of `jnp.ndarray`. Got {type(fx_outputs)} instead."
)

@is_pt_flax_cross_test
Expand Down

0 comments on commit 37b3720

Please sign in to comment.