diff --git a/tf2jax/_src/ops_test.py b/tf2jax/_src/ops_test.py index 4e72955..2136c61 100644 --- a/tf2jax/_src/ops_test.py +++ b/tf2jax/_src/ops_test.py @@ -829,17 +829,29 @@ def band_part(x): self._test_convert(band_part, [inputs]) @chex.variants(with_jit=True, without_jit=True) - def test_matrix_diag(self): + @parameterized.parameters(np.float32, np.int32, np.bool_) + def test_matrix_diag(self, dtype): np.random.seed(42) - inputs = np.random.normal(size=[10, 3, 4]).astype(np.float32) + + if dtype == np.float32: + inputs = np.random.normal(size=[10, 3, 4]).astype(dtype) + padding = dtype(47) + elif dtype == np.int32: + inputs = np.random.randint(low=0, high=10, size=[10, 3, 4], dtype=dtype) + padding = dtype(47) + elif dtype == np.bool_: + inputs = np.random.normal(size=[10, 3, 4]) > 0.0 + padding = np.bool_(False) + else: + raise ValueError(f"Unsupported dtype={dtype}") def raw_func(x): return tf.raw_ops.MatrixDiagV3( - diagonal=x, k=-2, num_rows=-1, num_cols=-1, padding_value=47) + diagonal=x, k=-2, num_rows=-1, num_cols=-1, padding_value=padding) self._test_convert(raw_func, inputs) def tf_func(x): - return tf.linalg.diag(x, k=-2, padding_value=42) + return tf.linalg.diag(x, k=-2, padding_value=padding) self._test_convert(tf_func, inputs) @chex.variants(with_jit=True, without_jit=True) diff --git a/tf2jax/_src/tf2jax.py b/tf2jax/_src/tf2jax.py index fe4d51c..55e8f76 100644 --- a/tf2jax/_src/tf2jax.py +++ b/tf2jax/_src/tf2jax.py @@ -879,8 +879,8 @@ def _func( diag_fn = jax.vmap(diag_fn) outputs = diag_fn(diagonals) - paddings = jnp.ones_like(outputs) - diag_fn(jnp.ones_like(diagonals)) - return outputs + paddings * padding_value + mask = diag_fn(jnp.ones_like(diagonals, dtype=jnp.bool_)) + return jnp.where(mask, outputs, padding_value) return _func