Skip to content

Commit

Permalink
convert_element_type: don't canonicalize old_dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Mar 10, 2021
1 parent cf9b77f commit 3714c76
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
9 changes: 6 additions & 3 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,14 +421,17 @@ def convert_element_type(operand: Array, new_dtype: DType = None,
Returns:
An array with the same shape as `operand`, cast elementwise to `new_dtype`.
"""
new_dtype = dtypes.canonicalize_dtype(new_dtype or _dtype(operand))
if hasattr(operand, '__jax_array__'):
operand = operand.__jax_array__()
new_weak_type = bool(weak_type)

old_dtype = dtypes.canonicalize_dtype(_dtype(operand))
# Note: don't canonicalize old_dtype because x64 context might
# cause un-canonicalized operands to be passed in.
old_dtype = np.result_type(operand)
old_weak_type = dtypes.is_weakly_typed(operand)

new_dtype = dtypes.canonicalize_dtype(new_dtype or old_dtype)
new_weak_type = bool(weak_type)

if (dtypes.issubdtype(old_dtype, np.complexfloating) and
not dtypes.issubdtype(new_dtype, np.complexfloating)):
msg = "Casting complex values to real discards the imaginary part"
Expand Down
9 changes: 9 additions & 0 deletions tests/x64_context_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,5 +157,14 @@ def test_jit_cache(self):
for _ in range(2):
f()

def test_convert_element_type(self):
# Regression test for part of https://github.com/google/jax/issues/5982
with enable_x64():
x = jnp.int64(1)
self.assertEqual(x.dtype, jnp.int64)

y = lax.convert_element_type(x, jnp.int32)
self.assertEqual(y.dtype, jnp.int32)

if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 3714c76

Please sign in to comment.