Commit

Avoid reuse of polymorphic variable names in tf2jax internals.
PiperOrigin-RevId: 563039856
shaobohou authored and TF2JAXDev committed Sep 7, 2023
1 parent 749b21b commit 0f0eb93
Showing 2 changed files with 25 additions and 7 deletions.
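For context, the collision this commit guards against arises in shape-polymorphic JAX -> TF -> JAX round-trips: tf2jax invents fresh dimension variables (with a reserved internal prefix) for output shapes it cannot recover, and those generated names must not reuse names already present. Below is a minimal sketch of such a round-trip, assuming jax2tf's public `polymorphic_shapes` option and tf2jax's documented `convert` API; the function and shapes are illustrative, not taken from this commit.

import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

import tf2jax


def forward(x):
  # Reduce over the last axis; the output keeps the polymorphic batch dim.
  return jnp.sum(x**2, axis=-1)


# JAX -> TF, with a user-named polymorphic batch dimension `b`.
tf_fn = jax2tf.convert(forward, polymorphic_shapes=["(b, 8)"])

# TF -> JAX. Any dynamic output dimension gets a tf2jax-generated variable
# name, which must not collide with user-chosen names such as `b`.
jax_fn, jax_params = tf2jax.convert(
    tf.function(tf_fn), jnp.zeros((4, 8), jnp.float32))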
14 changes: 14 additions & 0 deletions tf2jax/_src/roundtrip_test.py
@@ -744,6 +744,13 @@ def forward(x, w):
     tf_outputs2 = concrete_tf_fn2(x, w)
     self.assertAllClose(expected_outputs, tf_outputs2)
 
+    # JAX -> TF -> JAX -> TF -> SavedModel
+    module = tf.Module()
+    module.fn = tf_fn2
+    tmp_dir = self.create_tempdir()
+    options = tf.saved_model.SaveOptions(experimental_custom_gradients=True)
+    tf.saved_model.save(module, tmp_dir.full_path, options=options)
+
   @chex.variants(with_jit=True)
   @parameterized.named_parameters(
       chex.params_product(
@@ -796,6 +803,13 @@ def forward(x, y):
     tf_outputs2 = concrete_tf_fn2(x, y)
     self.assertAllClose(expected_outputs, tf_outputs2)
 
+    # JAX -> TF -> JAX -> TF -> SavedModel
+    module = tf.Module()
+    module.fn = tf_fn2
+    tmp_dir = self.create_tempdir()
+    options = tf.saved_model.SaveOptions(experimental_custom_gradients=True)
+    tf.saved_model.save(module, tmp_dir.full_path, options=options)
+
   @chex.variants(with_jit=True)
   @parameterized.named_parameters(
       chex.params_product(
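The new test lines above stop at `tf.saved_model.save`, presumably the step that failed before this fix. A short sketch of how the saved module could additionally be restored and re-checked, reusing the first block's `tmp_dir`, `x`, `w`, and `expected_outputs`; this restore step is an illustration, not part of the commit:

    # Load the SavedModel back and verify outputs still match.
    restored = tf.saved_model.load(tmp_dir.full_path)
    restored_outputs = restored.fn(x, w)
    self.assertAllClose(expected_outputs, restored_outputs)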
18 changes: 11 additions & 7 deletions tf2jax/experimental/mhlo.py
@@ -96,23 +96,24 @@ def mhlo_apply_abstract_eval(
   symtab = ir.SymbolTable(mhlo_module.operation)
 
   # Check we are not reusing existing dimension vars.
+  dynamic_count = 0
   has_polymorphic = False
   for val in in_avals:
     for dim in val.shape:
       if not isinstance(dim, int):
         has_polymorphic = True
         if any(x.startswith(_UKNOWN_DIM_PREFIX) for x in dim.get_vars()):
           raise ValueError(
               "Polymorphic variable names that start with"
               f" `{_UKNOWN_DIM_PREFIX}` are reserved for internal use by"
               f" tf2jax for outputs: `{val.shape}`"
           )
+        for var in dim.get_vars():
+          if var.startswith(_UKNOWN_DIM_PREFIX):
+            dynamic_count = max(
+                dynamic_count,
+                int(var.removeprefix(_UKNOWN_DIM_PREFIX + "_")),
+            )
 
   # Map each `dynamic` dimension to a unique dimension variable because we
   # do not have this information from the avals of the original JAX function.
   # In practice, the output shapes may actually be much more constrained, but
   # that information is not available here.
-  dynamic_count = 0
   output_specs = []
   for res in symtab["main"].type.results:
     if any(dim == res.get_dynamic_size() for dim in res.shape):
@@ -124,7 +125,10 @@
       )
 
       assert has_polymorphic, has_polymorphic
-      from jax.experimental.jax2tf import shape_poly  # pylint: disable=g-import-not-at-top # pytype: disable=import-error
+      if jax.__version_info__ <= (0, 4, 14):
+        from jax.experimental.jax2tf import shape_poly  # pylint: disable=g-import-not-at-top # pytype: disable=import-error
+      else:
+        from jax.experimental.export import shape_poly  # pylint: disable=g-import-not-at-top # pytype: disable=import-error
       out_shape = shape_poly._parse_spec(out_shape, res.shape)  # pylint: disable=protected-access
     else:
       out_shape = res.shape
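The version gate above tracks `shape_poly` moving from `jax.experimental.jax2tf` to `jax.experimental.export` after jax 0.4.14. As a hedged illustration of the dimension-variable API the new check relies on, using the same private `_parse_spec` helper the diff calls (its exact return structure is an assumption here):

# jax <= 0.4.14 import path, matching the first branch above.
from jax.experimental.jax2tf import shape_poly

# Parse a polymorphic spec with one symbolic dimension named `b`.
dims = shape_poly._parse_spec("(b, 8)", (None, 8))  # pylint: disable=protected-access

# Symbolic (non-int) dimensions report their variable names via get_vars();
# these are the names compared against the reserved _UKNOWN_DIM_PREFIX.
for dim in dims:
  if not isinstance(dim, int):
    print(dim.get_vars())  # e.g. {'b'}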
