From 091b714eaa97adb31523e5760d2160cee7aaf5a0 Mon Sep 17 00:00:00 2001 From: James Mullenbach Date: Wed, 16 Aug 2023 11:02:36 -0700 Subject: [PATCH] Update expected attrs of VarHandleOp -- add debug_name. PiperOrigin-RevId: 557540794 --- tf2jax/_src/ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tf2jax/_src/ops.py b/tf2jax/_src/ops.py index a0b829a..2ea4186 100644 --- a/tf2jax/_src/ops.py +++ b/tf2jax/_src/ops.py @@ -2266,7 +2266,8 @@ def _func(x: jnp.ndarray) -> List[jnp.ndarray]: @register_operation("VarHandleOp") def _var_handle(proto): _check_attrs( - proto, {"shared_name", "container", "allowed_devices", "shape", "dtype"}) + proto, {"shared_name", "container", "allowed_devices", "shape", "dtype", + "debug_name"}) def _func(): raise ValueError(f"VarHandleOp `{proto.name}` cannot be evaluated.")