Skip to content

Commit

Permalink
Merge pull request #6165 from google:convert-element-type-impl
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 364239447
  • Loading branch information
jax authors committed Mar 22, 2021
2 parents af59542 + 214d273 commit 555aba8
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 5 deletions.
12 changes: 10 additions & 2 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,8 +428,8 @@ def convert_element_type(operand: Array, new_dtype: DType = None,
if hasattr(operand, '__jax_array__'):
operand = operand.__jax_array__()

# Note: don't canonicalize old_dtype because x64 context might
# cause un-canonicalized operands to be passed in.
# Don't canonicalize old_dtype because x64 context might cause
# un-canonicalized operands to be passed in.
old_dtype = np.result_type(operand)
old_weak_type = dtypes.is_weakly_typed(operand)

Expand All @@ -441,6 +441,14 @@ def convert_element_type(operand: Array, new_dtype: DType = None,
msg = "Casting complex values to real discards the imaginary part"
warnings.warn(msg, np.ComplexWarning, stacklevel=2)

# Python has big integers, but convert_element_type(2 ** 100, np.float32) need
# not be an error since the target dtype fits the value. Handle this case by
# converting to a NumPy array before calling bind. Without this step, we'd
# first canonicalize the input to a value of dtype int32 or int64, leading to
# an overflow error.
if type(operand) is int:
operand = np.asarray(operand, new_dtype)

if ((old_dtype, old_weak_type) == (new_dtype, new_weak_type)
and isinstance(operand, (core.Tracer, xla.DeviceArray))):
return operand
Expand Down
5 changes: 2 additions & 3 deletions jax/abstract_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,8 @@ def _zeros_like_python_scalar(t, x):
return np.array(0, dtypes.python_scalar_dtypes[t])

def _make_concrete_python_scalar(t, x):
return ConcreteArray(
np.array(x, dtype=dtypes.python_scalar_dtypes[t]),
weak_type=True)
return ConcreteArray(np.array(x, dtype=dtypes.python_scalar_dtypes[t]),
weak_type=True)

for t in dtypes.python_scalar_dtypes:
core.pytype_aval_mappings[t] = partial(_make_concrete_python_scalar, t)
Expand Down
6 changes: 6 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2393,6 +2393,12 @@ def f(_):
expected = jnp.arange(1) + 1
self.assertAllClose(ans, expected)

def test_large_python_int_to_float(self):
# https://github.com/google/jax/pull/6165
jnp.multiply(2 ** 100, 3.) # doesn't crash
out = lax.convert_element_type(2 ** 100, jnp.float32) # doesn't crash
self.assertArraysEqual(out, np.float32(2 ** 100))


class RematTest(jtu.JaxTestCase):

Expand Down
2 changes: 2 additions & 0 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,8 @@ def f(x):
grad(lambda x: jnp.sum(vmap(f)(x)))(jnp.ones(2))

def testNoOpByOpUnderHash(self):
if not config.omnistaging_enabled:
raise SkipTest("test requires omnistaging")
def fail(*args, **kwargs): assert False
apply_primitive, xla.apply_primitive = xla.apply_primitive, fail
try:
Expand Down

0 comments on commit 555aba8

Please sign in to comment.