-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
generalize jaxpr simplification machinery
also: * fix jit invariance bug around weak types * elide trivial broadcasts This started as an attempt to simplify some jaxpr pretty-prints, by (1) eliding some convert_element_type applications that I thought were unnecessary and (2) eliding some trivial broadcasts. But it turned out that we were actually pruning more convert_element_types than we should! In particular, see test_weak_type_jit_invariance; that test fails on the main branch even if we add the fixes in DynamicJaxprTrace.new_const, because [this logic](https://github.com/google/jax/blob/b53a1740428a1b44d2b9f7694a00263918e6a309/jax/interpreters/partial_eval.py#L1225) was not paying attention to weak types and hence clobbered them. In addition to fixing those bugs that turned up (the changes in DynamicJaxprTrace, and in what is now _convert_elt_type_fwd_rule), this PR generalizes the jaxpr simplification machinery so as not to be a couple special cases on convert_element_type_p. Insetad, we have tables of rules! How we love them. These rule signatures should let us add simplifications like forwarding variables through calls and other higher-order primitives. That's all future work though.
- Loading branch information
Showing
7 changed files
with
121 additions
and
41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters