Skip to content

Commit

Permalink
Support bool in MatrixDiagV3.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 444487420
  • Loading branch information
shaobohou authored and TF2JAXDev committed Apr 26, 2022
1 parent a217cdb commit e322b8f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
20 changes: 16 additions & 4 deletions tf2jax/_src/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,17 +829,29 @@ def band_part(x):
self._test_convert(band_part, [inputs])

@chex.variants(with_jit=True, without_jit=True)
def test_matrix_diag(self):
@parameterized.parameters(np.float32, np.int32, np.bool_)
def test_matrix_diag(self, dtype):
np.random.seed(42)
inputs = np.random.normal(size=[10, 3, 4]).astype(np.float32)

if dtype == np.float32:
inputs = np.random.normal(size=[10, 3, 4]).astype(dtype)
padding = dtype(47)
elif dtype == np.int32:
inputs = np.random.randint(low=0, high=10, size=[10, 3, 4], dtype=dtype)
padding = dtype(47)
elif dtype == np.bool_:
inputs = np.random.normal(size=[10, 3, 4]) > 0.0
padding = np.bool_(False)
else:
raise ValueError(f"Unsupported dtype={dtype}")

def raw_func(x):
return tf.raw_ops.MatrixDiagV3(
diagonal=x, k=-2, num_rows=-1, num_cols=-1, padding_value=47)
diagonal=x, k=-2, num_rows=-1, num_cols=-1, padding_value=padding)
self._test_convert(raw_func, inputs)

def tf_func(x):
return tf.linalg.diag(x, k=-2, padding_value=42)
return tf.linalg.diag(x, k=-2, padding_value=padding)
self._test_convert(tf_func, inputs)

@chex.variants(with_jit=True, without_jit=True)
Expand Down
4 changes: 2 additions & 2 deletions tf2jax/_src/tf2jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,8 +879,8 @@ def _func(
diag_fn = jax.vmap(diag_fn)

outputs = diag_fn(diagonals)
paddings = jnp.ones_like(outputs) - diag_fn(jnp.ones_like(diagonals))
return outputs + paddings * padding_value
mask = diag_fn(jnp.ones_like(diagonals, dtype=jnp.bool_))
return jnp.where(mask, outputs, padding_value)

return _func

Expand Down

0 comments on commit e322b8f

Please sign in to comment.