Skip to content

Commit

Permalink
Merge pull request #8552 from mattjj:elide-more-convert-element-types
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 411082070
  • Loading branch information
jax authors committed Nov 19, 2021
2 parents 75e063c + abbf78b commit f08a5a0
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 41 deletions.
35 changes: 19 additions & 16 deletions docs/jaxpr.rst
Original file line number Diff line number Diff line change
Expand Up @@ -374,8 +374,9 @@ For the example consider the function ``func11`` below
j:f32[] = mul h i
k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] g
l:f32[] = add k j
m:f32[] = add l f
in (m, g) }
m:f32[] = convert_element_type[new_dtype=float32 weak_type=False] f
n:f32[] = add l m
in (n, g) }
length=16
linear=(False, False, False, False)
num_carry=1
Expand Down Expand Up @@ -417,18 +418,19 @@ which the computation should run. For example
backend=None
call_jaxpr={ lambda ; d:f32[] e:f32[]. let
f:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0
g:f32[1] = mul d f
h:f32[] = convert_element_type[new_dtype=float32 weak_type=False] e
i:f32[1] = add h g
in (i,) }
g:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
h:f32[1] = mul g f
i:f32[] = convert_element_type[new_dtype=float32 weak_type=False] e
j:f32[1] = add i h
in (j,) }
device=None
donated_invars=(False, False)
inline=False
name=inner
] a b
j:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a
k:f32[1] = add j c
in (k,) }
k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a
l:f32[1] = add k c
in (l,) }


XLA_pmap
Expand All @@ -452,12 +454,13 @@ captured using the ``xla_pmap`` primitive. Consider this example
axis_size=1
backend=None
call_jaxpr={ lambda ; d:f32[] e:f32[3]. let
f:f32[3] = add e d
g:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0
h:f32[3] = add f g
i:f32[3] = psum[axes=('rows',) axis_index_groups=None] e
j:f32[3] = div h i
in (j,) }
f:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
g:f32[3] = add e f
h:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0
i:f32[3] = add g h
j:f32[3] = psum[axes=('rows',) axis_index_groups=None] e
k:f32[3] = div i j
in (k,) }
devices=None
donated_invars=(False, False)
global_arg_shapes=(None,)
Expand All @@ -466,7 +469,7 @@ captured using the ``xla_pmap`` primitive. Consider this example
name=inner
out_axes=(0,)
] b a
in (c,) }
in (c,) }

The ``xla_pmap`` primitive specifies the name of the axis (parameter
``axis_name``) and the body of the function to be mapped as the ``call_jaxpr``
Expand Down
28 changes: 26 additions & 2 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2927,7 +2927,22 @@ def _convert_element_type_jvp_rule(tangent, operand , *, new_dtype, weak_type):
return convert_element_type_p.bind(tangent, new_dtype=new_dtype,
weak_type=weak_type)

convert_element_type_p = core.convert_element_type_p
def _convert_elt_type_folding_rule(consts, eqn):
c, = consts
if type(c) in core.literalable_types and not np.shape(c):
return [np.array(c, eqn.params['new_dtype'])], None
else:
return [None], eqn

def _convert_elt_type_fwd_rule(eqn):
v, = eqn.invars
if (v.aval.dtype == eqn.params['new_dtype'] and
v.aval.weak_type == eqn.params['weak_type']):
return [v], None
else:
return [None], eqn

convert_element_type_p = Primitive('convert_element_type')
convert_element_type_p.def_impl(partial(xla.apply_primitive, convert_element_type_p))
convert_element_type_p.def_abstract_eval(
partial(standard_abstract_eval, convert_element_type_p,
Expand All @@ -2939,6 +2954,8 @@ def _convert_element_type_jvp_rule(tangent, operand , *, new_dtype, weak_type):
ad.primitive_transposes[convert_element_type_p] = _convert_element_type_transpose_rule
batching.defvectorized(convert_element_type_p)
masking.defvectorized(convert_element_type_p)
pe.const_fold_rules[convert_element_type_p] = _convert_elt_type_folding_rule
pe.forwarding_rules[convert_element_type_p] = _convert_elt_type_fwd_rule


def _bitcast_convert_type_shape_rule(operand, *, new_dtype):
Expand Down Expand Up @@ -3641,11 +3658,19 @@ def _broadcast_in_dim_batch_rule(batched_args, batch_dims, *, shape,
new_broadcast_dimensions = (0,) + tuple(np.add(1, broadcast_dimensions))
return broadcast_in_dim(new_operand, new_shape, new_broadcast_dimensions), 0

def _broadcast_in_dim_fwd_rule(eqn):
v, = eqn.invars
if core.symbolic_equal_shape(eqn.params['shape'], v.aval.shape):
return [v], None
else:
return [None], eqn


broadcast_in_dim_p = standard_primitive(
_broadcast_in_dim_shape_rule, _input_dtype, 'broadcast_in_dim')
ad.deflinear2(broadcast_in_dim_p, _broadcast_in_dim_transpose_rule)
batching.primitive_batchers[broadcast_in_dim_p] = _broadcast_in_dim_batch_rule
pe.forwarding_rules[broadcast_in_dim_p] = _broadcast_in_dim_fwd_rule


def _clamp_shape_rule(min, operand, max):
Expand Down Expand Up @@ -4623,7 +4648,6 @@ def _gather_fill(operand, indices, *, dimension_numbers, slice_sizes,
le(indices, expand_dims(upper_bound, tuple(range(num_batch_dims)))))
mask = _reduce_and(mask, [num_batch_dims])


# Computes the output shape and the positions of the batch dimensions in the
# output
output_ndims = num_batch_dims + len(dnums.offset_dims)
Expand Down
8 changes: 3 additions & 5 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@
from ._src import traceback_util
traceback_util.register_exclusion(__file__)

zip = safe_zip
map = safe_map
zip, unsafe_zip = safe_zip, zip
map, unsafe_map = safe_map, map


# -------------------- jaxprs --------------------
Expand Down Expand Up @@ -1025,8 +1025,6 @@ def concrete_or_error(force: Any, val: Any, context=""):
else:
return force(val)

convert_element_type_p = Primitive('convert_element_type')


def _short_dtype_name(dtype):
return (dtype.name.replace('float', 'f').replace('uint', 'u')
Expand Down Expand Up @@ -1395,7 +1393,7 @@ def symbolic_equal_one_of_dim(d1: DimSize, dlist: Sequence[DimSize]) -> bool:

def symbolic_equal_shape(s1: Shape, s2: Shape) -> bool:
return (len(s1) == len(s2) and
all(map(symbolic_equal_dim, s1, s2)))
all(unsafe_map(symbolic_equal_dim, s1, s2)))

def greater_equal_dim(d1: DimSize, d2: DimSize) -> bool:
handler, ds = _dim_handler_and_canonical(d1, d2)
Expand Down
54 changes: 37 additions & 17 deletions jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1189,7 +1189,7 @@ def to_jaxpr(self, in_tracers, out_tracers):
outvars = [self.tracer_to_var[id(t)] for t in out_tracers]
constvars, constvals = unzip2(self.constvar_to_val.items())
jaxpr = Jaxpr(constvars, invars, outvars, self.eqns)
jaxpr, constvals = _prune_convert_element_types(jaxpr, constvals)
jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
jaxpr, constvals = _inline_literals(jaxpr, constvals)
out_avals = [t.aval for t in out_tracers]
return jaxpr, out_avals, constvals
Expand All @@ -1212,25 +1212,45 @@ def find_progenitors(self, tracer):
const_eqns = [eqn for eqn in self.eqns if set(eqn.invars) & constvars]
return invar_positions, const_eqns

def _prune_convert_element_types(jaxpr, constvals):
consts = dict(zip(jaxpr.constvars, constvals))
def _const_folding_and_forwarding(jaxpr, constvals):
consts: Dict[Var, Any] = dict(zip(jaxpr.constvars, constvals))
var_subs: Dict[Var, Var] = {} # not Dict[Var, Atom] b/c literals not inlined
new_eqns = []
for eqn in jaxpr.eqns:
if eqn.primitive is core.convert_element_type_p:
c = consts.get(eqn.invars[0])
if type(c) in core.literalable_types and not np.shape(c):
# constant-fold dtype conversion of literals to be inlined
consts[eqn.outvars[0]] = np.array(c, eqn.params['new_dtype'])
continue
if c is not None and dtypes.dtype(c) == eqn.params['new_dtype']:
# don't stage out no-op convert_element_type calls as clutter
consts[eqn.outvars[0]] = c
continue
# always apply invar substitutions
eqn = JaxprEqn([var_subs.get(v, v) for v in eqn.invars], eqn.outvars,
eqn.primitive, eqn.params, eqn.source_info)
# if any inputs are constants and we have a constant-folding rule, apply it
if eqn.primitive in const_fold_rules and any(v in consts for v in eqn.invars):
consts_in = [consts.get(v) for v in eqn.invars]
consts_out, new_eqn = const_fold_rules[eqn.primitive](consts_in, eqn)
assert (new_eqn is None) == all(c is not None for c in consts_out)
for v, c in zip(eqn.outvars, consts_out):
if c is not None: consts[v] = c
if new_eqn is None: continue
else: eqn = new_eqn
# if the application trivially maps some inputs to outputs, simplify
if eqn.primitive in forwarding_rules:
fwd_vars, new_eqn = forwarding_rules[eqn.primitive](eqn)
assert (new_eqn is None) == all(v is not None for v in fwd_vars)
for v_orig, v_new in zip(eqn.outvars, fwd_vars):
if v_new is not None: var_subs[v_orig] = v_new
if new_eqn is None: continue
else: eqn = new_eqn
new_eqns.append(eqn)
new_constvars, new_constvals = unzip2(consts.items())
new_jaxpr = Jaxpr(new_constvars, jaxpr.invars, jaxpr.outvars, new_eqns)
new_outvars = [var_subs.get(v, v) for v in jaxpr.outvars]
new_jaxpr = Jaxpr(new_constvars, jaxpr.invars, new_outvars, new_eqns)
return new_jaxpr, new_constvals

ConstFoldRule = Callable[[List[Optional[Any]], JaxprEqn],
Tuple[List[Optional[Any]], Optional[JaxprEqn]]]
const_fold_rules: Dict[Primitive, ConstFoldRule] = {}

ForwardingRule = Callable[[JaxprEqn],
Tuple[List[Optional[Var]], Optional[JaxprEqn]]]
forwarding_rules: Dict[Primitive, ForwardingRule] = {}

def _inline_literals(jaxpr, constvals):
# This function also ensures variables are labeled in a canonical ordering,
# prunes unused constants, and inserts `dropvar` symbols.
Expand Down Expand Up @@ -1280,7 +1300,7 @@ def new_const(self, val):
aval = raise_to_shaped(get_aval(val), weak_type=dtypes.is_weakly_typed(val))
tracer = DynamicJaxprTracer(self, aval, source_info_util.current())
self.frame.tracers.append(tracer)
var = self.frame.tracer_to_var[id(tracer)] = self.getconstvar(val)
var = self.frame.tracer_to_var[id(tracer)] = self.getconstvar(aval, val)
self.frame.constvar_to_val[var] = val
return tracer

Expand All @@ -1299,10 +1319,10 @@ def makevar(self, tracer):
var = self.frame.tracer_to_var[id(tracer)] = self.frame.newvar(tracer.aval)
return var

def getconstvar(self, c):
def getconstvar(self, aval, c):
var = self.frame.constid_to_var.get(id(c))
if var is None:
var = self.frame.constid_to_var[id(c)] = self.frame.newvar(get_aval(c))
var = self.frame.constid_to_var[id(c)] = self.frame.newvar(aval)
return var

def instantiate_const(self, val):
Expand Down
34 changes: 34 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4143,6 +4143,40 @@ def f(w, x):
self.assertEqual(shapes, expected)
self.assertIn('psum', str(jaxpr))

def test_weak_type_jit_invariance(self):
y = jnp.broadcast_to(3., (3,))
self.assertTrue(y.aval.weak_type)

def f():
return lax.convert_element_type(y, 'float32')

self.assertEqual(f().aval.weak_type, api.jit(f)().aval.weak_type)

def test_elide_trivial_convert_element_types(self):
# since we apply convert_element_type to a numpy.ndarray, the primitive is
# still bound and thus would appear in the jaxpr if we didn't clean it up
if config.x64_enabled:
x = np.arange(3, dtype='float64')
else:
x = np.arange(3, dtype='float32')

cet = partial(lax.convert_element_type, new_dtype=x.dtype)
jaxpr = api.make_jaxpr(lambda: cet(cet(cet(x))))()
self.assertLen(jaxpr.eqns, 0)

def test_elide_trivial_broadcasts(self):
# since we apply broadcast to a numpy.ndarray, the primitive is still bound
# and thus would appear in the jaxpr if we didn't clean it up
jaxpr = api.make_jaxpr(lambda: lax.broadcast(np.float32(3), ()))()
self.assertLen(jaxpr.jaxpr.eqns, 0)

def test_convert_element_type_literal_constant_folding(self):
# this convert_elemnt_type is nontrivial, but because it's on a scalar we
# constant-fold it
cet = partial(lax.convert_element_type, new_dtype='float16')
jaxpr = api.make_jaxpr(lambda: cet(3.))()
self.assertLen(jaxpr.eqns, 0)


class CustomJVPTest(jtu.JaxTestCase):

Expand Down
1 change: 0 additions & 1 deletion tests/host_callback_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,7 +934,6 @@ def func(x):
] b
_:f32[] = mul c 2.00
d:f32[] = mul 1.00 2.00
_:f32[] = broadcast_in_dim[broadcast_dimensions=() shape=()] 0.00
e:f32[] = outside_call[
arg_treedef={treedef}
callback=...
Expand Down
2 changes: 2 additions & 0 deletions tests/x64_context_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import concurrent.futures
from functools import partial
import time
import unittest

from absl.testing import absltest
from absl.testing import parameterized
Expand Down Expand Up @@ -144,6 +145,7 @@ def test_jit_cache(self):
for _ in range(2):
f()

@unittest.skip("test fails, see #8552")
def test_convert_element_type(self):
# Regression test for part of https://github.com/google/jax/issues/5982
with enable_x64():
Expand Down

0 comments on commit f08a5a0

Please sign in to comment.