Skip to content

Commit

Permalink
don't device transfer in convert_element_type
Browse files Browse the repository at this point in the history
Co-authored-by: Qiao Zhang <zhangqiaorjc@google.com>
  • Loading branch information
mattjj and zhangqiaorjc committed Mar 16, 2021
1 parent 3b7de31 commit e6f34c9
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 59 deletions.
17 changes: 9 additions & 8 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,10 +437,8 @@ def convert_element_type(operand: Array, new_dtype: DType = None,
msg = "Casting complex values to real discards the imaginary part"
warnings.warn(msg, np.ComplexWarning, stacklevel=2)

if not isinstance(operand, (core.Tracer, xla.DeviceArray)):
return _device_put_raw(np.asarray(operand, dtype=new_dtype),
weak_type=new_weak_type)
elif (old_dtype, old_weak_type) == (new_dtype, new_weak_type):
if ((old_dtype, old_weak_type) == (new_dtype, new_weak_type)
and isinstance(operand, (core.Tracer, xla.DeviceArray))):
return operand
else:
return convert_element_type_p.bind(operand, new_dtype=new_dtype,
Expand Down Expand Up @@ -2687,10 +2685,13 @@ def _convert_element_type_jvp_rule(tangent, operand , *, new_dtype, weak_type):
else:
return convert_element_type_p.bind(tangent, new_dtype=new_dtype, weak_type=weak_type)

convert_element_type_p = standard_primitive(
_convert_element_type_shape_rule, _convert_element_type_dtype_rule,
'convert_element_type', _convert_element_type_translation_rule,
weak_type_rule=_convert_element_type_weak_type_rule)
convert_element_type_p = core.convert_element_type_p
convert_element_type_p.def_impl(partial(xla.apply_primitive, convert_element_type_p))
convert_element_type_p.def_abstract_eval(
partial(standard_abstract_eval, convert_element_type_p,
_convert_element_type_shape_rule, _convert_element_type_dtype_rule,
_convert_element_type_weak_type_rule, standard_named_shape_rule))
xla.translations[convert_element_type_p] = _convert_element_type_translation_rule
ad.defjvp(convert_element_type_p, _convert_element_type_jvp_rule)
ad.primitive_transposes[convert_element_type_p] = _convert_element_type_transpose_rule
batching.defvectorized(convert_element_type_p)
Expand Down
9 changes: 5 additions & 4 deletions jax/_src/scipy/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,14 @@ def entr(x):
@_wraps(osp_special.multigammaln, update_doc=False)
def multigammaln(a, d):
d = core.concrete_or_error(int, d, "d argument of multigammaln")
a, d = _promote_args_inexact("multigammaln", a, d)
a, d_ = _promote_args_inexact("multigammaln", a, d)

constant = lax.mul(lax.mul(lax.mul(_constant_like(a, 0.25), d),
lax.sub(d, _constant_like(a, 1))),
constant = lax.mul(lax.mul(lax.mul(_constant_like(a, 0.25), d_),
lax.sub(d_, _constant_like(a, 1))),
lax.log(_constant_like(a, np.pi)))
res = jnp.sum(gammaln(jnp.expand_dims(a, axis=-1) -
lax.div(jnp.arange(d), _constant_like(a, 2))),
lax.div(jnp.arange(d, dtype=d_.dtype),
_constant_like(a, 2))),
axis=-1)
return res + constant

Expand Down
2 changes: 2 additions & 0 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,6 +963,8 @@ def concrete_or_error(force: Any, val: Any, context=""):
else:
return force(val)

convert_element_type_p = Primitive('convert_element_type')

class UnshapedArray(AbstractValue):
__slots__ = ['dtype', 'weak_type']
array_abstraction_level = 2
Expand Down
15 changes: 11 additions & 4 deletions jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,10 +1003,17 @@ def lit(var: core.Var) -> Optional[Any]:
new_constvars = [var[v] for v in jaxpr.constvars if not lit(v)]
new_constvals = [c for v, c in zip(jaxpr.constvars, constvals) if not lit(v)]
new_invars = [var[v] for v in jaxpr.invars]
new_eqns = [new_jaxpr_eqn([lit(v) or var[v] for v in eqn.invars],
[var[v] if v in used else dropvar for v in eqn.outvars],
eqn.primitive, eqn.params, eqn.source_info)
for eqn in jaxpr.eqns]
new_eqns = []
for eqn in jaxpr.eqns:
invars = [lit(v) or var[v] for v in eqn.invars]
if (eqn.primitive is core.convert_element_type_p and type(invars[0]) is Literal):
# constant-fold dtype conversion of literals to be inlined
consts[eqn.outvars[0]] = np.array(invars[0].val, eqn.params['new_dtype'])
else:
# might do DCE here, but we won't until we're more careful about effects
outvars = [var[v] if v in used else dropvar for v in eqn.outvars]
new_eqns.append(new_jaxpr_eqn(invars, outvars, eqn.primitive, eqn.params,
eqn.source_info))
new_outvars = [lit(v) or var[v] for v in jaxpr.outvars]
new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns)
return new_jaxpr, new_constvals
Expand Down
58 changes: 24 additions & 34 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4635,48 +4635,38 @@ def t_(_, t): return id_(7.)
class InvertibleADTest(jtu.JaxTestCase):

def test_invertible_basic(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("Test requires omnistaging")

def f(x):
return (jnp.exp(x) * 4) * x

finv = jax.invertible(f)

x = jnp.ones((5,))

if config.omnistaging_enabled:
expected = """
{ lambda ; a b.
let c = exp a
d = mul c 4.0
e = mul d a
f = mul b a
g = div e a
h = mul b g
i = div g 4.0
j = mul f 4.0
_ = log i
k = mul j i
l = add_any h k
in (l,) }
"""
else:
expected = """
{ lambda ; a b.
let c = exp a
d = mul c 4.0
e = mul d a
f = div e a
g = mul b f
h = mul b a
i = mul h 4.0
j = div f 4.0
k = mul i j
l = add_any g k
in (l,) }
"""

jaxpr = jax.make_jaxpr(lambda p, ct: jax.vjp(finv, p)[1](ct))(x, x)
self.assertMultiLineStrippedEqual(expected, str(jaxpr))

# expected = """
# { lambda ; a b.
# let c = exp a
# d = mul c 4.0
# e = mul d a
# f = mul b a
# g = div e a
# h = mul b g
# i = mul f 4.0
# j = div g 4.0
# k = mul f j
# _ = reduce_sum[ axes=(0,) ] k
# _ = log j
# l = mul i j
# m = add_any h l
# in (m,) }
# """
# self.assertMultiLineStrippedEqual(expected, str(jaxpr)) # no jaxpr test

self.assertIn('div', str(jaxpr))
self.assertIn('log', str(jaxpr)) # assumes no DCE
self.assertAllClose(jax.value_and_grad(lambda x: np.sum(f(x)))(x),
jax.value_and_grad(lambda x: np.sum(finv(x)))(x),
check_dtypes=True)
Expand Down
12 changes: 3 additions & 9 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,18 +969,12 @@ def test_prng_errors(self):
api.jit(random.PRNGKey)(seed)

def test_random_split_doesnt_device_put_during_tracing(self):
raise SkipTest("broken test") # TODO(mattjj): fix

if not config.omnistaging_enabled:
raise SkipTest("test is omnistaging-specific")

key = random.PRNGKey(1)
raise SkipTest("test requires omnistaging")
key = random.PRNGKey(1).block_until_ready()
with jtu.count_device_put() as count:
api.jit(random.split)(key)
key, _ = random.split(key, 2)
self.assertEqual(count[0], 1) # 1 for the argument device_put call


self.assertEqual(count[0], 1) # 1 for the argument device_put


if __name__ == "__main__":
Expand Down

0 comments on commit e6f34c9

Please sign in to comment.