Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix convert_element_type on large Python int inputs #6165

Merged
merged 3 commits into from
Mar 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,8 +428,8 @@ def convert_element_type(operand: Array, new_dtype: DType = None,
if hasattr(operand, '__jax_array__'):
operand = operand.__jax_array__()

# Note: don't canonicalize old_dtype because x64 context might
# cause un-canonicalized operands to be passed in.
# Don't canonicalize old_dtype because x64 context might cause
# un-canonicalized operands to be passed in.
old_dtype = np.result_type(operand)
old_weak_type = dtypes.is_weakly_typed(operand)

Expand All @@ -441,6 +441,14 @@ def convert_element_type(operand: Array, new_dtype: DType = None,
msg = "Casting complex values to real discards the imaginary part"
warnings.warn(msg, np.ComplexWarning, stacklevel=2)

# Python has big integers, but convert_element_type(2 ** 100, np.float32) need
# not be an error since the target dtype fits the value. Handle this case by
# converting to a NumPy array before calling bind. Without this step, we'd
# first canonicalize the input to a value of dtype int32 or int64, leading to
# an overflow error.
if type(operand) is int:
operand = np.asarray(operand, new_dtype)

if ((old_dtype, old_weak_type) == (new_dtype, new_weak_type)
and isinstance(operand, (core.Tracer, xla.DeviceArray))):
return operand
Expand Down
5 changes: 2 additions & 3 deletions jax/abstract_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,8 @@ def _zeros_like_python_scalar(t, x):
return np.array(0, dtypes.python_scalar_dtypes[t])

def _make_concrete_python_scalar(t, x):
return ConcreteArray(
np.array(x, dtype=dtypes.python_scalar_dtypes[t]),
weak_type=True)
return ConcreteArray(np.array(x, dtype=dtypes.python_scalar_dtypes[t]),
weak_type=True)

for t in dtypes.python_scalar_dtypes:
core.pytype_aval_mappings[t] = partial(_make_concrete_python_scalar, t)
Expand Down
6 changes: 6 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2393,6 +2393,12 @@ def f(_):
expected = jnp.arange(1) + 1
self.assertAllClose(ans, expected)

def test_large_python_int_to_float(self):
# https://github.com/google/jax/pull/6165
jnp.multiply(2 ** 100, 3.) # doesn't crash
out = lax.convert_element_type(2 ** 100, jnp.float32) # doesn't crash
self.assertArraysEqual(out, np.float32(2 ** 100))


class RematTest(jtu.JaxTestCase):

Expand Down
2 changes: 2 additions & 0 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,8 @@ def f(x):
grad(lambda x: jnp.sum(vmap(f)(x)))(jnp.ones(2))

def testNoOpByOpUnderHash(self):
if not config.omnistaging_enabled:
raise SkipTest("test requires omnistaging")
def fail(*args, **kwargs): assert False
apply_primitive, xla.apply_primitive = xla.apply_primitive, fail
try:
Expand Down