unify configuration state handling #6112
Merged
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Reviewer: The main action is in config.py. The other changed files are downstream of those changes.
This change arose from working on adding a
jax_tpu_dot_precision
flag / context manager (see #6143).We have a few configurable bits of global state. These bits of configurable state are managed by flags and sometimes dynamically by context managers, but neither the APIs nor implementations are uniform. Moreover the state is spread out through a few different files.
The bits of global state we have in mind are:
core.skip_checks
(state in core.py, flag defined in config.py, not properly thread-local, has a context manager, does not affectjit
dispatch since it's all about trace-time errors)core.debug_state.check_leaks
(state in core.py, flag defined in config.py, thread-local, has a context manager, does not affectjit
dispatch since it's all about trace-time errors)jax_debug_nans
/jax_debug_infs
(flag defined in xla.py, not properly thread-local, no context manager, affectsjit
dispatch in that it adds checks to every execution on the Python side) (affectsjit
dispatch)jax_log_compiles
(flag defined in xla.py, not properly thread-local, no context manager, does not affectjit
dispatch since it's all about trace-time logging)jax_enable_x64
(work-in-progress, state in jax_jit.cc, context manager being developed injax.experimental
, thread-local, affectsjit
dispatch in that it's part of the compilation cache key and affects how input arguments are handled in the c++ code)jax_default_dot_precision
(work-in-progress, not present yet, affectsjit
dispatch analogously tojax_enable_x64
)disable_jit
(state in jax_jit.cc, affectsjit
dispatch in that it's part of the compilation cache key)jax_numpy_rank_promotion
(flag defined in lax_numpy.py, not thread-local and no context manager, does not affectjit
dispatch in that it's all about trace-time errors)This PR unifies all the boolean-valued instances of Python state via a single mechanism in config.py which sets up flags, thread-local state, and context manager APIs. This PR doesn't touch
jax_default_dot_precision
orjax_numpy_rank_promotion
because those are enums rather than booleans; it doesn't touchdisable_jit
orjax_enable_x64
because those are in C++.Another effect of this PR is introducing new context managers:
jax.enable_checks
,jax.check_tracer_leaks
,jax.debug_nans
,jax.debug_infs
, andjax.log_compiles
. Each takes a single boolean argument.Follow-up work might put more of these bits in C++ (i.e. in jax_jit.cc) for fast dispatch, and/or speed up dispatch times for Python state bits. That work should be easier once we collect all the state in one place as in this PR. It's also follow-up work
to unify the API withto add an enum version of this logic forjax_enable_x64
(cc @jakevdp), andjax_default_dot_precision
,jax_numpy_rank_promotion
, and perhaps the default device. After discussing with @jakevdp , we decided to unify with the implementation ofjax_enable_x64
in this PR, but leave the API endpoints for the x64 stuff unchanged.Benchmark results on
benchmarks/api_benchmark.py
show no real differences AIUI:This PR doesn't currently include tests, though the code is pretty thoroughly exercised by existing test coverage for
skip_checks
,check_leaks
,debug_nans
,disable_jit
, etc.