Skip to content

Commit

Permalink
fix convert_element_type on large Py int inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Mar 22, 2021
1 parent af59542 commit 8c3125c
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 7 deletions.
8 changes: 7 additions & 1 deletion jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2658,6 +2658,12 @@ def _minmax_translation_rule(c, x, y, *, minmax=None, cmp=None):
ad.defjvp_zero(lt_p)


def _convert_element_type_impl(operand, *, new_dtype, weak_type):
if dtypes.is_python_scalar(operand):
operand = np.asarray(operand, dtype=new_dtype)
return xla.apply_primitive(convert_element_type_p, operand,
new_dtype=new_dtype, weak_type=weak_type)

def _convert_element_type_shape_rule(operand, *, new_dtype, weak_type):
return operand.shape

Expand Down Expand Up @@ -2693,7 +2699,7 @@ def _convert_element_type_jvp_rule(tangent, operand , *, new_dtype, weak_type):
return convert_element_type_p.bind(tangent, new_dtype=new_dtype, weak_type=weak_type)

convert_element_type_p = core.convert_element_type_p
convert_element_type_p.def_impl(partial(xla.apply_primitive, convert_element_type_p))
convert_element_type_p.def_impl(_convert_element_type_impl)
convert_element_type_p.def_abstract_eval(
partial(standard_abstract_eval, convert_element_type_p,
_convert_element_type_shape_rule, _convert_element_type_dtype_rule,
Expand Down
5 changes: 2 additions & 3 deletions jax/abstract_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,8 @@ def _zeros_like_python_scalar(t, x):
return np.array(0, dtypes.python_scalar_dtypes[t])

def _make_concrete_python_scalar(t, x):
return ConcreteArray(
np.array(x, dtype=dtypes.python_scalar_dtypes[t]),
weak_type=True)
return ConcreteArray(np.array(x, dtype=dtypes.python_scalar_dtypes[t]),
weak_type=True)

for t in dtypes.python_scalar_dtypes:
core.pytype_aval_mappings[t] = partial(_make_concrete_python_scalar, t)
Expand Down
2 changes: 2 additions & 0 deletions jax/experimental/host_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,6 +934,7 @@ def _outside_call_jvp_rule(primals, tangents, **params):
if not params["identity"]:
raise NotImplementedError("JVP rule is implemented only for id_tap, not for call.")
tangent_instantiated = tuple(map(_instantiate_zeros, primals, tangents))
tangent_instantiated = tuple(map(ad.replace_float0s, primals, tangent_instantiated))

arg_treedef = params["arg_treedef"]
# The argument to the jvp tap is a pair of the tapped primals and tangents
Expand All @@ -946,6 +947,7 @@ def _outside_call_jvp_rule(primals, tangents, **params):
arg_treedef=jvp_arg_treedef,
))
out_primals_tapped, out_tangents_tapped = util.split_list(out_all, [len(primals)])
out_tangents_tapped = map(ad.recast_to_float0, out_primals_tapped, out_tangents_tapped)
return tuple(out_primals_tapped), tuple(out_tangents_tapped)


Expand Down
14 changes: 14 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2393,6 +2393,20 @@ def f(_):
expected = jnp.arange(1) + 1
self.assertAllClose(ans, expected)

def test_large_python_int_to_float(self):
# https://github.com/google/jax/pull/6165
# We skip checks because otherwise we end up calling valid_jaxtype(2**100),
# which tries to form a ConcreteArray with that value and thus leads to a
# NumPy OverflowError. It's true that 2**100 does not inhabit a jax type,
# but as an issue of Python embedding we can handle operations like
# lax.convert_element_type(2 ** 100, jnp.float32) as in the tests below.
# That is, lax.convert_element_type(2 ** 100, jnp.int32) is an error while
# lax.convert_element_type(2 ** 100, jnp.float32) is not.
with jax.core.skipping_checks():
jnp.multiply(2 ** 100, 3.) # doesn't crash
out = lax.convert_element_type(2 ** 100, jnp.float32) # doesn't crash
self.assertArraysEqual(out, np.float32(2 ** 100))


class RematTest(jtu.JaxTestCase):

Expand Down
6 changes: 3 additions & 3 deletions tests/host_callback_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,7 +1028,7 @@ def func(x, yint):
2 )
transforms: ['jvp', 'transpose'] what: pair
( 2.00
False )""", testing_stream.output)
0 )""", testing_stream.output)
testing_stream.reset()

def test_tap_vmap(self):
Expand Down Expand Up @@ -1590,8 +1590,8 @@ def padded_sum(x):
( 3 ) ) )
( ( [0. 0.1 0.2 0.3 0.4]
[0. 0.2 0.4 0.6 0.8] )
( ( False )
( False ) ) ) )""", testing_stream.output)
( ( 0 )
( 0 ) ) ) )""", testing_stream.output)
testing_stream.reset()

# Now with JIT
Expand Down
2 changes: 2 additions & 0 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,8 @@ def f(x):
grad(lambda x: jnp.sum(vmap(f)(x)))(jnp.ones(2))

def testNoOpByOpUnderHash(self):
if not config.omnistaging_enabled:
raise SkipTest("test requires omnistaging")
def fail(*args, **kwargs): assert False
apply_primitive, xla.apply_primitive = xla.apply_primitive, fail
try:
Expand Down

0 comments on commit 8c3125c

Please sign in to comment.