Skip to content

Commit

Permalink
Improve handling of input specs.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 463040220
  • Loading branch information
shaobohou authored and TF2JAXDev committed Jul 25, 2022
1 parent c631489 commit 3fa9361
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion tf2jax/_src/tf2jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,12 @@ def convert_from_restored(
return convert(tf_func, *args, **kwargs)


def _is_tensorspec_like(v: Any) -> bool:
return (isinstance(v, tf.TensorSpec) or (
type(v).__module__.endswith("tensorflow.python.framework.tensor_spec") and
type(v).__name__ == "VariableSpec"))


def _fix_tfhub_specs(
structured_specs,
flat_tensors,
Expand All @@ -382,7 +388,7 @@ def _fix_tfhub_specs(
flat_specs = tree.flatten(structured_specs)
tensor_count = 0
for idx, val in enumerate(flat_specs):
if isinstance(val, tf.TensorSpec):
if _is_tensorspec_like(val):
flat_specs[idx] = tf.TensorSpec(
val.shape, val.dtype, name=flat_tensors[tensor_count].op.name)
tensor_count += 1
Expand Down

0 comments on commit 3fa9361

Please sign in to comment.