diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 50d6a468b23e..095d84d32690 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -421,14 +421,17 @@ def convert_element_type(operand: Array, new_dtype: DType = None, Returns: An array with the same shape as `operand`, cast elementwise to `new_dtype`. """ - new_dtype = dtypes.canonicalize_dtype(new_dtype or _dtype(operand)) if hasattr(operand, '__jax_array__'): operand = operand.__jax_array__() - new_weak_type = bool(weak_type) - old_dtype = dtypes.canonicalize_dtype(_dtype(operand)) + # Note: don't canonicalize old_dtype because x64 context might + # cause un-canonicalized operands to be passed in. + old_dtype = np.result_type(operand) old_weak_type = dtypes.is_weakly_typed(operand) + new_dtype = dtypes.canonicalize_dtype(new_dtype or old_dtype) + new_weak_type = bool(weak_type) + if (dtypes.issubdtype(old_dtype, np.complexfloating) and not dtypes.issubdtype(new_dtype, np.complexfloating)): msg = "Casting complex values to real discards the imaginary part" diff --git a/tests/lax_test.py b/tests/lax_test.py index 85a7e7c1083e..4b81c2c486e0 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -1535,8 +1535,8 @@ def testReduce(self, op, init_val, shape, dtype, dims): for init_weak_type in [True, False])) def testReduceWeakType(self, op_namespace, op, arr_weak_type, init_weak_type): op = getattr(op_namespace, op) - arr = lax.convert_element_type(np.arange(10), np.int32, weak_type=arr_weak_type) - init = lax.convert_element_type(1, np.int32, weak_type=init_weak_type) + arr = lax.convert_element_type(np.arange(10), int, weak_type=arr_weak_type) + init = lax.convert_element_type(1, int, weak_type=init_weak_type) fun = lambda arr, init: lax.reduce(arr, init, op, (0,)) out = fun(arr, init) self.assertEqual(dtypes.is_weakly_typed(out), arr_weak_type and init_weak_type) diff --git a/tests/x64_context_test.py b/tests/x64_context_test.py index 00e80abfd227..6b653e477b7f 100644 --- a/tests/x64_context_test.py +++ b/tests/x64_context_test.py @@ -157,5 +157,17 @@ def test_jit_cache(self): for _ in range(2): f() + def test_convert_element_type(self): + # Regression test for part of https://github.com/google/jax/issues/5982 + with enable_x64(): + x = jnp.int64(1) + self.assertEqual(x.dtype, jnp.int64) + + y = x.astype(jnp.int32) + self.assertEqual(y.dtype, jnp.int32) + + z = api.jit(lambda x: x.astype(jnp.int32))(x) + self.assertEqual(z.dtype, jnp.int32) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())