Skip to content

Commit

Permalink
fix convert_element_type on large inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Mar 21, 2021
1 parent af59542 commit 7a1d1e2
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 4 deletions.
8 changes: 7 additions & 1 deletion jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2658,6 +2658,12 @@ def _minmax_translation_rule(c, x, y, *, minmax=None, cmp=None):
ad.defjvp_zero(lt_p)


def _convert_element_type_impl(operand, *, new_dtype, weak_type):
if not isinstance(operand, xla.DeviceArray):
operand = np.asarray(operand, dtype=new_dtype)
return xla.apply_primitive(convert_element_type_p, operand,
new_dtype=new_dtype, weak_type=weak_type)

def _convert_element_type_shape_rule(operand, *, new_dtype, weak_type):
return operand.shape

Expand Down Expand Up @@ -2693,7 +2699,7 @@ def _convert_element_type_jvp_rule(tangent, operand , *, new_dtype, weak_type):
return convert_element_type_p.bind(tangent, new_dtype=new_dtype, weak_type=weak_type)

convert_element_type_p = core.convert_element_type_p
convert_element_type_p.def_impl(partial(xla.apply_primitive, convert_element_type_p))
convert_element_type_p.def_impl(_convert_element_type_impl)
convert_element_type_p.def_abstract_eval(
partial(standard_abstract_eval, convert_element_type_p,
_convert_element_type_shape_rule, _convert_element_type_dtype_rule,
Expand Down
2 changes: 2 additions & 0 deletions jax/experimental/host_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,6 +934,7 @@ def _outside_call_jvp_rule(primals, tangents, **params):
if not params["identity"]:
raise NotImplementedError("JVP rule is implemented only for id_tap, not for call.")
tangent_instantiated = tuple(map(_instantiate_zeros, primals, tangents))
tangent_instantiated = tuple(map(ad.replace_float0s, primals, tangent_instantiated))

arg_treedef = params["arg_treedef"]
# The argument to the jvp tap is a pair of the tapped primals and tangents
Expand All @@ -946,6 +947,7 @@ def _outside_call_jvp_rule(primals, tangents, **params):
arg_treedef=jvp_arg_treedef,
))
out_primals_tapped, out_tangents_tapped = util.split_list(out_all, [len(primals)])
out_tangents_tapped = map(ad.recast_to_float0, out_primals_tapped, out_tangents_tapped)
return tuple(out_primals_tapped), tuple(out_tangents_tapped)


Expand Down
6 changes: 3 additions & 3 deletions tests/host_callback_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,7 +1028,7 @@ def func(x, yint):
2 )
transforms: ['jvp', 'transpose'] what: pair
( 2.00
False )""", testing_stream.output)
0 )""", testing_stream.output)
testing_stream.reset()

def test_tap_vmap(self):
Expand Down Expand Up @@ -1590,8 +1590,8 @@ def padded_sum(x):
( 3 ) ) )
( ( [0. 0.1 0.2 0.3 0.4]
[0. 0.2 0.4 0.6 0.8] )
( ( False )
( False ) ) ) )""", testing_stream.output)
( ( 0 )
( 0 ) ) ) )""", testing_stream.output)
testing_stream.reset()

# Now with JIT
Expand Down

0 comments on commit 7a1d1e2

Please sign in to comment.