Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

generalize jaxpr simplification machinery, fix convert_element_type simplification and add one for broadcast #8552

Merged

Conversation

mattjj
Copy link
Collaborator

@mattjj mattjj commented Nov 16, 2021

This started as an attempt to simplify some jaxpr pretty-prints, by

  1. eliding some convert_element_type applications that I thought were unnecessary, and
  2. eliding trivial broadcast_in_dims which were causing additional clutter.

But it turned out that we were actually pruning more convert_element_types than we should have done! In particular, see test_weak_type_jit_invariance; that test fails on the main branch even if we add the fixes in DynamicJaxprTrace.new_const, because this logic from #6014 was not paying attention to weak types and hence clobbered them. Or here's a runnable example showing how we were losing jit-invariance:

import jax
import jax.numpy as jnp
from jax import lax
import numpy as np

jax.config.update('jax_platform_name', 'cpu')

y = jnp.broadcast_to(3., (3,))
print(y.aval.weak_type)

def f():
  return lax.convert_element_type(y, 'float32')

print(f().aval.weak_type)
print(jax.jit(f)().aval.weak_type)

In addition to fixing those bugs that turned up (the changes in DynamicJaxprTrace, and in what is now _convert_elt_type_fwd_rule), this PR generalizes the jaxpr simplification machinery so as not to be special-cased on convert_element_type_p. Instead, we have tables of rules! How we love them.

These rule signatures should let us add simplifications like forwarding variables through calls and other higher-order primitives. That's all future work though.


I had to skip a test in x64_context_tests.py because the test was already creating incorrect jaxprs and "succeeding by accident". The jaxprs were incorrect for the reasons we already understand. One way to see the issue in the test is to do this:

import jax
import jax.numpy as jnp
from jax.experimental.x64_context import enable_x64

with enable_x64():
  x = jnp.int64(1)

jaxpr = jax.make_jaxpr(lambda x: x.astype('int32'))(x)
print(jaxpr)  # { lambda ; a:i32[]. let  in (a,) }

The jaxpr's input binder is incorrectly typed as an i32[], even tough it's an i64[]. That happens on main as well as on this branch. But on main we just happened to still perform the conversion, even though in the jaxpr it's trivial, because we (inconsistently) didn't elide it. I suspect I ran into this issue in #6014 and didn't take the time to understand it, instead only adding elision of trivial convert_element_types for constants instead. The reason constants worked is that we query the actual value for the dtype, rather than the incorrectly-dtyped variable.

Anyway, since we don't plan to keep enable_x64 around for long (as it's busted in other ways too), skipping this test seems best. An alternative would be to continue (for now) the policy of not eliding trivial convert_element_types except those applied to constants.

@google-cla google-cla bot added the cla: yes label Nov 16, 2021
@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Nov 16, 2021
@mattjj mattjj force-pushed the elide-more-convert-element-types branch from 4d9f277 to c4e9bb0 Compare November 16, 2021 06:43
@mattjj mattjj force-pushed the elide-more-convert-element-types branch 2 times, most recently from cc1d20d to 7332600 Compare November 17, 2021 05:18
@mattjj mattjj changed the title elide trivial convert_element_types generalize jaxpr simplification machinery, fix convert_element_type simplification and add one for broadcast Nov 17, 2021
@mattjj mattjj requested a review from jekbradbury November 17, 2021 05:21
@@ -1025,8 +1025,6 @@ def concrete_or_error(force: Any, val: Any, context=""):
else:
return force(val)

convert_element_type_p = Primitive('convert_element_type')
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was only moved here in #6014 out of necessity, because we had special-cased some jaxpr simplifications on it in partial_eval.py.

@mattjj mattjj force-pushed the elide-more-convert-element-types branch from 7332600 to 0aab652 Compare November 17, 2021 05:50
@mattjj mattjj force-pushed the elide-more-convert-element-types branch 3 times, most recently from 275e106 to bd9cb53 Compare November 19, 2021 03:20
@mattjj mattjj added pull ready Ready for copybara import and testing and removed pull ready Ready for copybara import and testing labels Nov 19, 2021
also:
* fix jit invariance bug around weak types
* elide trivial broadcasts

This started as an attempt to simplify some jaxpr pretty-prints, by (1)
eliding some convert_element_type applications that I thought were
unnecessary and (2) eliding some trivial broadcasts.

But it turned out that we were actually pruning more
convert_element_types than we should! In particular, see
test_weak_type_jit_invariance; that test fails on the main branch even
if we add the fixes in DynamicJaxprTrace.new_const, because [this
logic](https://github.com/google/jax/blob/b53a1740428a1b44d2b9f7694a00263918e6a309/jax/interpreters/partial_eval.py#L1225)
was not paying attention to weak types and hence clobbered them.

In addition to fixing those bugs that turned up (the changes in
DynamicJaxprTrace, and in what is now _convert_elt_type_fwd_rule), this
PR generalizes the jaxpr simplification machinery so as not to be a
couple special cases on convert_element_type_p. Insetad, we have tables
of rules! How we love them.

These rule signatures should let us add simplifications like forwarding
variables through calls and other higher-order primitives. That's all
future work though.
@mattjj mattjj force-pushed the elide-more-convert-element-types branch from bd9cb53 to abbf78b Compare November 19, 2021 17:01
@copybara-service copybara-service bot merged commit f08a5a0 into jax-ml:main Nov 19, 2021
@mattjj mattjj deleted the elide-more-convert-element-types branch November 19, 2021 18:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla: yes pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants