generalize jaxpr simplification machinery, fix convert_element_type simplification and add one for broadcast #8552
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This started as an attempt to simplify some jaxpr pretty-prints, by
convert_element_type
applications that I thought were unnecessary, andbroadcast_in_dim
s which were causing additional clutter.But it turned out that we were actually pruning more
convert_element_type
s than we should have done! In particular, seetest_weak_type_jit_invariance
; that test fails on the main branch even if we add the fixes inDynamicJaxprTrace.new_const
, because this logic from #6014 was not paying attention to weak types and hence clobbered them. Or here's a runnable example showing how we were losingjit
-invariance:In addition to fixing those bugs that turned up (the changes in
DynamicJaxprTrace
, and in what is now_convert_elt_type_fwd_rule
), this PR generalizes the jaxpr simplification machinery so as not to be special-cased onconvert_element_type_p
. Instead, we have tables of rules! How we love them.These rule signatures should let us add simplifications like forwarding variables through calls and other higher-order primitives. That's all future work though.
I had to skip a test in x64_context_tests.py because the test was already creating incorrect jaxprs and "succeeding by accident". The jaxprs were incorrect for the reasons we already understand. One way to see the issue in the test is to do this:
The jaxpr's input binder is incorrectly typed as an
i32[]
, even tough it's ani64[]
. That happens on main as well as on this branch. But on main we just happened to still perform the conversion, even though in the jaxpr it's trivial, because we (inconsistently) didn't elide it. I suspect I ran into this issue in #6014 and didn't take the time to understand it, instead only adding elision of trivial convert_element_types for constants instead. The reason constants worked is that we query the actual value for the dtype, rather than the incorrectly-dtyped variable.Anyway, since we don't plan to keep
enable_x64
around for long (as it's busted in other ways too), skipping this test seems best. An alternative would be to continue (for now) the policy of not eliding trivial convert_element_types except those applied to constants.