Skip to content

Commit

Permalink
Merge pull request #6018 from jakevdp:conv-elem-type
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 362956294
  • Loading branch information
jax authors committed Mar 15, 2021
2 parents 3fb6a11 + 04bf02a commit 80966fe
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 5 deletions.
9 changes: 6 additions & 3 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,14 +421,17 @@ def convert_element_type(operand: Array, new_dtype: DType = None,
Returns:
An array with the same shape as `operand`, cast elementwise to `new_dtype`.
"""
new_dtype = dtypes.canonicalize_dtype(new_dtype or _dtype(operand))
if hasattr(operand, '__jax_array__'):
operand = operand.__jax_array__()
new_weak_type = bool(weak_type)

old_dtype = dtypes.canonicalize_dtype(_dtype(operand))
# Note: don't canonicalize old_dtype because x64 context might
# cause un-canonicalized operands to be passed in.
old_dtype = np.result_type(operand)
old_weak_type = dtypes.is_weakly_typed(operand)

new_dtype = dtypes.canonicalize_dtype(new_dtype or old_dtype)
new_weak_type = bool(weak_type)

if (dtypes.issubdtype(old_dtype, np.complexfloating) and
not dtypes.issubdtype(new_dtype, np.complexfloating)):
msg = "Casting complex values to real discards the imaginary part"
Expand Down
4 changes: 2 additions & 2 deletions tests/lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1535,8 +1535,8 @@ def testReduce(self, op, init_val, shape, dtype, dims):
for init_weak_type in [True, False]))
def testReduceWeakType(self, op_namespace, op, arr_weak_type, init_weak_type):
op = getattr(op_namespace, op)
arr = lax.convert_element_type(np.arange(10), np.int32, weak_type=arr_weak_type)
init = lax.convert_element_type(1, np.int32, weak_type=init_weak_type)
arr = lax.convert_element_type(np.arange(10), int, weak_type=arr_weak_type)
init = lax.convert_element_type(1, int, weak_type=init_weak_type)
fun = lambda arr, init: lax.reduce(arr, init, op, (0,))
out = fun(arr, init)
self.assertEqual(dtypes.is_weakly_typed(out), arr_weak_type and init_weak_type)
Expand Down
12 changes: 12 additions & 0 deletions tests/x64_context_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,5 +157,17 @@ def test_jit_cache(self):
for _ in range(2):
f()

def test_convert_element_type(self):
# Regression test for part of https://github.com/google/jax/issues/5982
with enable_x64():
x = jnp.int64(1)
self.assertEqual(x.dtype, jnp.int64)

y = x.astype(jnp.int32)
self.assertEqual(y.dtype, jnp.int32)

z = api.jit(lambda x: x.astype(jnp.int32))(x)
self.assertEqual(z.dtype, jnp.int32)

if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 80966fe

Please sign in to comment.