From 0f0eb931cbcbaae70726653cfb32f1935e41ad50 Mon Sep 17 00:00:00 2001
From: Shaobo Hou
Date: Wed, 6 Sep 2023 02:31:13 -0700
Subject: [PATCH] Avoid reuse of polymorphic variable names in tf2jax internal.

PiperOrigin-RevId: 563039856
---
 tf2jax/_src/roundtrip_test.py | 14 ++++++++++++++
 tf2jax/experimental/mhlo.py   | 18 +++++++++++-------
 2 files changed, 25 insertions(+), 7 deletions(-)

diff --git a/tf2jax/_src/roundtrip_test.py b/tf2jax/_src/roundtrip_test.py
index e90e4a9..e48c9ce 100644
--- a/tf2jax/_src/roundtrip_test.py
+++ b/tf2jax/_src/roundtrip_test.py
@@ -744,6 +744,13 @@ def forward(x, w):
     tf_outputs2 = concrete_tf_fn2(x, w)
     self.assertAllClose(expected_outputs, tf_outputs2)
 
+    # JAX -> TF -> JAX -> TF -> SavedModel
+    module = tf.Module()
+    module.fn = tf_fn2
+    tmp_dir = self.create_tempdir()
+    options = tf.saved_model.SaveOptions(experimental_custom_gradients=True)
+    tf.saved_model.save(module, tmp_dir.full_path, options=options)
+
   @chex.variants(with_jit=True)
   @parameterized.named_parameters(
       chex.params_product(
@@ -796,6 +803,13 @@ def forward(x, y):
     tf_outputs2 = concrete_tf_fn2(x, y)
     self.assertAllClose(expected_outputs, tf_outputs2)
 
+    # JAX -> TF -> JAX -> TF -> SavedModel
+    module = tf.Module()
+    module.fn = tf_fn2
+    tmp_dir = self.create_tempdir()
+    options = tf.saved_model.SaveOptions(experimental_custom_gradients=True)
+    tf.saved_model.save(module, tmp_dir.full_path, options=options)
+
   @chex.variants(with_jit=True)
   @parameterized.named_parameters(
       chex.params_product(
diff --git a/tf2jax/experimental/mhlo.py b/tf2jax/experimental/mhlo.py
index f662d3e..c153156 100644
--- a/tf2jax/experimental/mhlo.py
+++ b/tf2jax/experimental/mhlo.py
@@ -96,23 +96,24 @@ def mhlo_apply_abstract_eval(
     symtab = ir.SymbolTable(mhlo_module.operation)
 
     # Check we are not reusing existing dimension vars.
+    dynamic_count = 0
     has_polymorphic = False
     for val in in_avals:
       for dim in val.shape:
         if not isinstance(dim, int):
           has_polymorphic = True
           if any(x.startswith(_UKNOWN_DIM_PREFIX) for x in dim.get_vars()):
-            raise ValueError(
-                "Polymorphic variable name that start with"
-                f" `{_UKNOWN_DIM_PREFIX}` are reserved for use by tf2jax"
-                f" internal for outputs: `{val.shape}`"
-            )
+            for dim in dim.get_vars():
+              if dim.startswith(_UKNOWN_DIM_PREFIX):
+                dynamic_count = max(
+                    dynamic_count,
+                    (int(dim.removeprefix(_UKNOWN_DIM_PREFIX + "_"))),
+                )
 
     # Map each `dynamic`` dimension to a unique dimension variable because we
     # do not have the information from the avals of the original JAX function.
     # In practice, the output shapes may actually be much more constrained, but
     # the information is not available here.
-    dynamic_count = 0
     output_specs = []
     for res in symtab["main"].type.results:
       if any(dim == res.get_dynamic_size() for dim in res.shape):
@@ -124,7 +125,10 @@
         )
         assert has_polymorphic, has_polymorphic
-        from jax.experimental.jax2tf import shape_poly  # pylint: disable=g-import-not-at-top  # pytype: disable=import-error
+        if jax.__version_info__ <= (0, 4, 14):
+          from jax.experimental.jax2tf import shape_poly  # pylint: disable=g-import-not-at-top  # pytype: disable=import-error
+        else:
+          from jax.experimental.export import shape_poly  # pylint: disable=g-import-not-at-top  # pytype: disable=import-error
         out_shape = shape_poly._parse_spec(out_shape, res.shape)  # pylint: disable=protected-access
       else:
         out_shape = res.shape
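
Note (illustrative, not part of the patch): the new lines in roundtrip_test.py only append a SavedModel export of the twice-converted function. A minimal standalone sketch of that export pattern, using a trivial tf.function as a stand-in for the round-tripped `tf_fn2` and tempfile in place of absltest's create_tempdir:

  import tempfile
  import tensorflow as tf

  # Stand-in for the JAX -> TF -> JAX -> TF round-tripped function in the test.
  @tf.function(input_signature=[tf.TensorSpec([8], tf.float32)])
  def fn(x):
    return x * 2.0

  module = tf.Module()
  module.fn = fn
  # experimental_custom_gradients=True asks SavedModel to also serialize any
  # custom gradients registered on the saved functions.
  options = tf.saved_model.SaveOptions(experimental_custom_gradients=True)
  with tempfile.TemporaryDirectory() as tmp_dir:
    tf.saved_model.save(module, tmp_dir, options=options)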
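
Note (illustrative, not part of the patch): in mhlo.py the ValueError is replaced by a scan over input dimension variables that already carry the reserved prefix, so freshly generated output dimension variables continue numbering above the highest index already in use instead of colliding with it. A standalone sketch of that counting step, with a hypothetical prefix standing in for _UKNOWN_DIM_PREFIX:

  # Hypothetical prefix; the real value is _UKNOWN_DIM_PREFIX in
  # tf2jax/experimental/mhlo.py.
  RESERVED_PREFIX = "tf2jax_unknown"

  def highest_reserved_index(dim_var_names):
    """Returns the largest <n> among names of the form f"{RESERVED_PREFIX}_<n>"."""
    dynamic_count = 0
    for name in dim_var_names:
      if name.startswith(RESERVED_PREFIX):
        dynamic_count = max(
            dynamic_count, int(name.removeprefix(RESERVED_PREFIX + "_"))
        )
    return dynamic_count

  # New output dimension variables would then be numbered from 4 onwards.
  assert highest_reserved_index(["b", "tf2jax_unknown_1", "tf2jax_unknown_3"]) == 3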