clean-up
ydshieh committed Mar 17, 2022
1 parent 82b6ce0 commit 27909ca
Showing 1 changed file with 72 additions and 116 deletions.
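Note: the diff below removes the per-run `results`/`context` bookkeeping (and the JSON dumps it fed) that had been threaded through `check_outputs` and both cross-framework equivalence tests, leaving the plain PT/Flax comparison. For orientation, the in-memory half of the pattern the tests keep is sketched here. This is a minimal illustration, not code from the commit; `BertModel`/`FlaxBertModel` and the tiny config are stand-ins for the per-model classes and config the tests resolve from `model_class.__name__`.

# Minimal sketch, not code from this commit: the in-memory PT -> Flax check the
# tests keep. BertModel/FlaxBertModel and the tiny config are illustrative
# stand-ins for the per-model classes the tests derive per model.
import numpy as np
import torch
import jax.numpy as jnp

from transformers import BertConfig, BertModel, FlaxBertModel
from transformers.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax

config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
pt_model = BertModel(config).eval()
pt_model.config.use_cache = False  # Flax models return no cache by default

fx_model = FlaxBertModel(config, dtype=jnp.float32)
fx_model.params = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)

input_ids = np.ones((1, 8), dtype="i4")
with torch.no_grad():
    pt_outputs = pt_model(input_ids=torch.tensor(input_ids, dtype=torch.long))
fx_outputs = fx_model(input_ids=input_ids)

# The same elementwise tolerance check_outputs applies to each returned tensor.
max_diff = np.amax(np.abs(np.asarray(fx_outputs.last_hidden_state) - pt_outputs.last_hidden_state.numpy()))
assert max_diff <= 1e-5, max_diff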
188 changes: 72 additions & 116 deletions tests/test_modeling_flax_common.py
@@ -168,7 +168,7 @@ def recursive_check(tuple_object, dict_object):
             dict_inputs = self._prepare_for_class(inputs_dict, model_class)
             check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
 
-    def check_outputs(self, fx_outputs, pt_outputs, model_class, names, context, results):
+    def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
         """
         Args:
             model_class: The class of the model that is currently testing. For example, ..., etc.
@@ -183,10 +183,10 @@ def check_outputs(self, fx_outputs, pt_outputs, model_class, names, context, results):
             self.assertEqual(len(fx_outputs), len(pt_outputs))
             if type(names) == tuple:
                 for fo, po, name in zip(fx_outputs, pt_outputs, names):
-                    self.check_outputs(fo, po, model_class, names=name, context=context, results=results)
+                    self.check_outputs(fo, po, model_class, names=name)
             elif type(names) == str:
                 for idx, (fo, po) in enumerate(zip(fx_outputs, pt_outputs)):
-                    self.check_outputs(fo, po, model_class, names=f"{names}_{idx}", context=context, results=results)
+                    self.check_outputs(fo, po, model_class, names=f"{names}_{idx}")
             else:
                 raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.")
         elif isinstance(fx_outputs, jnp.ndarray):
@@ -206,12 +206,6 @@ def check_outputs(self, fx_outputs, pt_outputs, model_class, names, context, results):
 
             max_diff = np.amax(np.abs(fx_outputs - pt_outputs))
             self.assertLessEqual(max_diff, 1e-5)
-
-            if context not in results[model_class.__name__]:
-                results[model_class.__name__][context] = {}
-            if names not in results[model_class.__name__][context]:
-                results[model_class.__name__][context][names] = []
-            results[model_class.__name__][context][names].append(float(max_diff))
         else:
             raise ValueError(
                 f"`fx_outputs` should be a `tuple` or an instance of `jnp.ndarray`. Got {type(fx_outputs)} instead."
@@ -223,150 +217,112 @@ def test_equivalence_pt_to_flax(self):
         # But logically, it is fine.
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
-        results = {}
-
         for model_class in self.all_model_classes:
             with self.subTest(model_class.__name__):
 
-                results[model_class.__name__] = {}
-
-                for _ in range(3):
-
-                    # Output all for aggressive testing
-                    config.output_hidden_states = True
-
-                    # prepare inputs
-                    prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
-                    pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()}
+                # Output all for aggressive testing
+                config.output_hidden_states = True
 
-                    # load corresponding PyTorch class
-                    pt_model_class_name = model_class.__name__[4:]  # Skip the "Flax" at the beginning
-                    pt_model_class = getattr(transformers, pt_model_class_name)
-
-                    pt_model = pt_model_class(config).eval()
-                    # Flax models don't use the `use_cache` option and cache is not returned as a default.
-                    # So we disable `use_cache` here for PyTorch model.
-                    pt_model.config.use_cache = False
-                    fx_model = model_class(config, dtype=jnp.float32)
-
-                    fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
-                    fx_model.params = fx_state
-
-                    # send pytorch model to the correct device
-                    pt_model.to(torch_device)
+                # prepare inputs
+                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+                pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()}
 
-                    with torch.no_grad():
-                        pt_outputs = pt_model(**pt_inputs)
-                    fx_outputs = fx_model(**prepared_inputs_dict)
+                # load corresponding PyTorch class
+                pt_model_class_name = model_class.__name__[4:]  # Skip the "Flax" at the beginning
+                pt_model_class = getattr(transformers, pt_model_class_name)
 
-                    fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
-                    pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
+                pt_model = pt_model_class(config).eval()
+                # Flax models don't use the `use_cache` option and cache is not returned as a default.
+                # So we disable `use_cache` here for PyTorch model.
+                pt_model.config.use_cache = False
+                fx_model = model_class(config, dtype=jnp.float32)
 
-                    self.assertEqual(fx_keys, pt_keys)
-                    self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys, context="load_via_memory", results=results)
+                fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
+                fx_model.params = fx_state
 
-                    with tempfile.TemporaryDirectory() as tmpdirname:
-                        pt_model.save_pretrained(tmpdirname)
-                        fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
+                # send pytorch model to the correct device
+                pt_model.to(torch_device)
 
-                        fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict)
+                with torch.no_grad():
+                    pt_outputs = pt_model(**pt_inputs)
+                fx_outputs = fx_model(**prepared_inputs_dict)
 
-                        fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None])
-                        pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
+                fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
+                pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
 
-                        self.assertEqual(fx_keys, pt_keys)
-                        self.check_outputs(fx_outputs_loaded.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys, context="load_via_disk", results=results)
+                self.assertEqual(fx_keys, pt_keys)
+                self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
 
-        if len(self.all_model_classes) > 0:
+                with tempfile.TemporaryDirectory() as tmpdirname:
+                    pt_model.save_pretrained(tmpdirname)
+                    fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
 
-            for model_class_name in results:
+                    fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict)
 
-                for context in results[model_class_name]:
-                    for names in results[model_class_name][context]:
-                        results[model_class_name][context][names] = float(
-                            np.amax(np.array(results[model_class_name][context][names])))
+                    fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None])
+                    pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
 
-            import json
-            with open(f"pt_to_flax_test_results_{type(self).__name__}.json", "w", encoding="UTF-8") as fp:
-                json.dump(results, fp, ensure_ascii=False, indent=4)
+                    self.assertEqual(fx_keys, pt_keys)
+                    self.check_outputs(fx_outputs_loaded.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
 
     @is_pt_flax_cross_test
     def test_equivalence_flax_to_pt(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
-        results = {}
-
         for model_class in self.all_model_classes:
             with self.subTest(model_class.__name__):
 
-                results[model_class.__name__] = {}
-
-                for _ in range(3):
-
-                    # Output all for aggressive testing
-                    config.output_hidden_states = True
-                    # Pure convolutional models have no attention
-
-                    # prepare inputs
-                    prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
-                    pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()}
-
-                    # load corresponding PyTorch class
-                    pt_model_class_name = model_class.__name__[4:]  # Skip the "Flax" at the beginning
-                    pt_model_class = getattr(transformers, pt_model_class_name)
+                # Output all for aggressive testing
+                config.output_hidden_states = True
+                # Pure convolutional models have no attention
 
-                    pt_model = pt_model_class(config).eval()
-                    # Flax models don't use the `use_cache` option and cache is not returned as a default.
-                    # So we disable `use_cache` here for PyTorch model.
-                    pt_model.config.use_cache = False
-                    fx_model = model_class(config, dtype=jnp.float32)
-
-                    pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
-
-                    # make sure weights are tied in PyTorch
-                    pt_model.tie_weights()
+                # prepare inputs
+                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+                pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()}
 
-                    # send pytorch model to the correct device
-                    pt_model.to(torch_device)
+                # load corresponding PyTorch class
+                pt_model_class_name = model_class.__name__[4:]  # Skip the "Flax" at the beginning
+                pt_model_class = getattr(transformers, pt_model_class_name)
 
-                    with torch.no_grad():
-                        pt_outputs = pt_model(**pt_inputs)
-                    fx_outputs = fx_model(**prepared_inputs_dict)
+                pt_model = pt_model_class(config).eval()
+                # Flax models don't use the `use_cache` option and cache is not returned as a default.
+                # So we disable `use_cache` here for PyTorch model.
+                pt_model.config.use_cache = False
+                fx_model = model_class(config, dtype=jnp.float32)
 
-                    fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
-                    pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
+                pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
 
-                    self.assertEqual(fx_keys, pt_keys)
-                    self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys, context="load_via_memory", results=results)
+                # make sure weights are tied in PyTorch
+                pt_model.tie_weights()
 
-                    with tempfile.TemporaryDirectory() as tmpdirname:
-                        fx_model.save_pretrained(tmpdirname)
-                        pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)
+                # send pytorch model to the correct device
+                pt_model.to(torch_device)
 
-                        # send pytorch model to the correct device
-                        pt_model_loaded.to(torch_device)
+                with torch.no_grad():
+                    pt_outputs = pt_model(**pt_inputs)
+                fx_outputs = fx_model(**prepared_inputs_dict)
 
-                        with torch.no_grad():
-                            pt_outputs_loaded = pt_model_loaded(**pt_inputs)
+                fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
+                pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
 
-                        fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
-                        pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
+                self.assertEqual(fx_keys, pt_keys)
+                self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
 
-                        self.assertEqual(fx_keys, pt_keys)
-                        self.check_outputs(fx_outputs.to_tuple(), pt_outputs_loaded.to_tuple(), model_class, names=fx_keys, context="load_via_disk", results=results)
+                with tempfile.TemporaryDirectory() as tmpdirname:
+                    fx_model.save_pretrained(tmpdirname)
+                    pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)
 
-        if len(self.all_model_classes) > 0:
+                    # send pytorch model to the correct device
+                    pt_model_loaded.to(torch_device)
 
-            for model_class_name in results:
+                    with torch.no_grad():
+                        pt_outputs_loaded = pt_model_loaded(**pt_inputs)
 
-                for context in results[model_class_name]:
-                    for names in results[model_class_name][context]:
-                        results[model_class_name][context][names] = float(
-                            np.amax(np.array(results[model_class_name][context][names])))
+                    fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
+                    pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
 
-            import json
-            with open(f"flax_to_pt_test_results_{type(self).__name__}.json", "w", encoding="UTF-8") as fp:
-                json.dump(results, fp, ensure_ascii=False, indent=4)
+                    self.assertEqual(fx_keys, pt_keys)
+                    self.check_outputs(fx_outputs.to_tuple(), pt_outputs_loaded.to_tuple(), model_class, names=fx_keys)
 
     def test_from_pretrained_save_pretrained(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
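Each test's second half round-trips the weights through disk before re-checking outputs. Continuing the illustrative Bert sketch from above (again an assumption-level sketch, not the commit's code):

# Continuation of the sketch above (illustrative only): the disk round-trip half
# of the PT -> Flax test, reloading the PyTorch checkpoint as a Flax model.
import tempfile

with tempfile.TemporaryDirectory() as tmpdirname:
    pt_model.save_pretrained(tmpdirname)
    fx_model_loaded = FlaxBertModel.from_pretrained(tmpdirname, from_pt=True)

fx_outputs_loaded = fx_model_loaded(input_ids=input_ids)
max_diff = np.amax(
    np.abs(np.asarray(fx_outputs_loaded.last_hidden_state) - pt_outputs.last_hidden_state.numpy())
)
assert max_diff <= 1e-5, max_diff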