Skip to content

Commit

Permalink
unify configuration state handling
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Mar 18, 2021
1 parent 5e88ed2 commit 6930015
Show file tree
Hide file tree
Showing 19 changed files with 204 additions and 142 deletions.
3 changes: 2 additions & 1 deletion jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
del _cloud_tpu_init

# flake8: noqa: F401
from .config import config
from .config import (config, enable_checks, check_tracer_leaks, checking_leaks,
debug_nans, debug_infs, log_compiles)
from .api import (
ad, # TODO(phawkins): update users to avoid this.
argnums_partial, # TODO(phawkins): update Haiku to not use this.
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/lax/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1113,7 +1113,7 @@ def _cond_typecheck(*avals, branches, linear):
f'called with operands of type {_avals_short(op_avals)}')

def cond_bind(*args, branches, linear):
if not core.skip_checks:
if config.jax_enable_checks:
avals = _map(core.get_aval, args)
_cond_typecheck(*avals, branches=branches, linear=linear)
for jaxpr in branches:
Expand Down Expand Up @@ -1876,7 +1876,7 @@ def _scan_typecheck(bind_time, *avals, reverse, length, num_consts, num_carry,
f'called with sequence of type\n{_avals_short(x_avals)}')

def scan_bind(*args, **params):
if not core.skip_checks:
if config.jax_enable_checks:
avals = _map(core.get_aval, args)
_scan_typecheck(True, *avals, **params)
core.check_jaxpr(params['jaxpr'].jaxpr)
Expand Down
3 changes: 1 addition & 2 deletions jax/_src/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import numpy as np

import jax
from jax.config import config

partial = functools.partial
Expand Down Expand Up @@ -192,7 +191,7 @@ def cached(_, *args, **kwargs):

@functools.wraps(f)
def wrapper(*args, **kwargs):
if jax.core.debug_state.check_leaks:
if config.jax_check_tracer_leaks:
return f(*args, **kwargs)
else:
return cached(bool(config.x64_enabled), *args, **kwargs)
Expand Down
8 changes: 3 additions & 5 deletions jax/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from . import linear_util as lu
from . import ad_util
from . import dtypes
from .core import eval_jaxpr, checking_leaks
from .core import eval_jaxpr
from .api_util import (flatten_fun, apply_flat_fun, flatten_fun_nokwargs,
flatten_fun_nokwargs2, argnums_partial,
argnums_partial_except, flatten_axes, donation_vector,
Expand Down Expand Up @@ -353,10 +353,8 @@ def get_device_info():
@wraps(fun)
@api_boundary
def f_jitted(*args, **kwargs):
# TODO(jblespiau): We can remove `config.x64_enabled` when jaxlib 0.1.62 is
# the minimal version.
context = (getattr(core.thread_local_state.trace_state.trace_stack,
"dynamic", None), config.x64_enabled)
context = getattr(core.thread_local_state.trace_state.trace_stack,
"dynamic", None)
# TODO(jblespiau): Move this to C++.
if (FLAGS.jax_debug_nans or FLAGS.jax_debug_infs) and not _jit_is_disabled():
device_arrays = cpp_jitted_f(context, *args, **kwargs)
Expand Down
134 changes: 118 additions & 16 deletions jax/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import functools
import os
import sys
import threading

from jax import lib

def bool_env(varname: str, default: bool) -> bool:
Expand Down Expand Up @@ -47,6 +51,8 @@ def __init__(self):
self.meta = {}
self.FLAGS = NameSpace(self.read)
self.use_absl = False

# TODO(mattjj): delete these when only omnistaging is available
self.omnistaging_enabled = bool_env('JAX_OMNISTAGING', True)
self._omnistaging_disablers = []

Expand All @@ -60,9 +66,9 @@ def update(self, name, val):
self.values[name] = val

if name == "jax_disable_jit":
lib.jax_jit.global_state().disable_jit = val
lib.jax_jit.global_state().disable_jit = bool(val)
elif name == "jax_enable_x64":
lib.jax_jit.global_state().enable_x64 = val
lib.jax_jit.global_state().enable_x64 = bool(val)

def read(self, name):
if self.use_absl:
Expand Down Expand Up @@ -143,14 +149,85 @@ def disable_omnistaging(self):
disabler()
self.omnistaging_enabled = False

# TODO(jakevdp, mattjj): unify this with `define_bool_state` stuff below
@property
def x64_enabled(self):
return lib.jax_jit.get_enable_x64()

# TODO(jakevdp): make this public when thread-local x64 is fully implemented.
def _set_x64_enabled(self, state):
lib.jax_jit.thread_local_state().enable_x64 = bool(state)

def define_bool_state(self, name: str, default: bool, help: str):
"""Set up thread-local state and return a contextmanager for managing it.
This function is a convenience wrapper. It defines a flag and corresponding
thread-local state, which can be managed via the contextmanager it returns.
The thread-local state value can be read via the ``config.<option_name>``
attribute, where ``config`` is the singleton ``Config`` instance.
Args:
name: string, converted to lowercase to define the name of the config
option (and absl flag). It is converted to uppercase to define the
corresponding shell environment variable.
default: boolean, a default value for the option.
help: string, used to populate the flag help information as well as the
docstring of the returned context manager.
Returns:
A contextmanager to control the thread-local state value.
Example:
enable_foo = config.define_bool_state(
name='jax_enable_foo',
default=False,
help='Enable foo.')
# Now the JAX_ENABLE_FOO shell environment variable and --jax_enable_foo
# command-line flag can be used to control the process-level value of
# the configuration option, in addition to using e.g.
# ``config.update("jax_enable_foo", True)`` directly. We can also use a
# context manager:
with enable_foo(True):
...
Accessing ``config.FLAGS.jax_enable_foo`` is different from accessing the
thread-local state value via ``config.jax_enable_foo``: the former reads the
flag value determined set by the environment variable or command-line flag
and does not read the thread-local state, whereas the latter reads the
thread-local state value managed by the contextmanager. Think of the
contextmanager state as a layer on top of the flag value: if no
contextmanager is in use then ``config.jax_enable_foo`` reflects the flag
value ``config.FLAGS.jax_enable_foo``, whereas if a contextmanager is in use
then only ``config.jax_enable_foo`` is updated. So in general using
``config.jax_enable_foo`` is best.
"""
name = name.lower()
self.DEFINE_bool(name, bool_env(name.upper(), default), help)

def get_state(self):
val = getattr(_thread_local_state, name, unset)
return val if val is not unset else self.read(name)
setattr(Config, name, property(get_state))

@contextlib.contextmanager
def set_state(new_val: bool):
prev_val = getattr(_thread_local_state, name, unset)
setattr(_thread_local_state, name, new_val)
try:
yield
finally:
if prev_val is unset:
delattr(_thread_local_state, name)
else:
setattr(_thread_local_state, name, prev_val)
set_state.__name__ = name.lstrip('jax_')
set_state.__doc__ = f"Context manager for `{name}` config option.\n\n{help}"
return set_state

_thread_local_state = threading.local()

class Unset: pass
unset = Unset()

class NameSpace(object):
def __init__(self, getter):
Expand All @@ -166,11 +243,6 @@ def __getattr__(self, name):

already_configured_with_absl = False

flags.DEFINE_bool(
'jax_enable_checks',
bool_env('JAX_ENABLE_CHECKS', False),
help='Turn on invariant checking (core.skip_checks = False)'
)

flags.DEFINE_bool(
'jax_omnistaging',
Expand All @@ -184,10 +256,40 @@ def __getattr__(self, name):
help='Set the number of stack frames in JAX tracer error messages.'
)

flags.DEFINE_bool(
'jax_check_tracer_leaks',
bool_env('JAX_CHECK_TRACER_LEAKS', False),

enable_checks = config.define_bool_state(
name='jax_enable_checks',
default=False,
help='Turn on invariant checking for JAX internals. Makes things slower.')

check_tracer_leaks = config.define_bool_state(
name='jax_check_tracer_leaks',
default=False,
help=('Turn on checking for leaked tracers as soon as a trace completes. '
'Enabling leak checking may have performance impacts: some caching '
'is disabled, and other overheads may be added.'),
)
'is disabled, and other overheads may be added.'))
checking_leaks = functools.partial(check_tracer_leaks, True)

debug_nans = config.define_bool_state(
name='jax_debug_nans',
default=False,
help=('Add nan checks to every operation. When a nan is detected on the '
'output of a jit-compiled computation, call into the un-compiled '
'version in an attempt to more precisely identify the operation '
'which produced the nan.'))

debug_infs = config.define_bool_state(
name='jax_debug_infs',
default=False,
help=('Add inf checks to every operation. When an inf is detected on the '
'output of a jit-compiled computation, call into the un-compiled '
'version in an attempt to more precisely identify the operation '
'which produced the inf.'))

log_compiles = config.define_bool_state(
name='jax_log_compiles',
default=False,
help=('Log a message each time every time `jit` or `pmap` compiles an XLA '
'computation. Logging is performed with `absl.logging`. When this '
'option is set, the log level is WARNING; otherwise the level is '
'DEBUG.'))
45 changes: 9 additions & 36 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,33 +43,6 @@
from ._src import traceback_util
traceback_util.register_exclusion(__file__)

# TODO(mattjj): move this into debug_state
skip_checks = not FLAGS.jax_enable_checks

@contextmanager
def skipping_checks():
"""Context manager for temporarily disabling internal checks."""
global skip_checks
old_value, skip_checks = skip_checks, True
try:
yield
finally:
skip_checks = old_value

@contextmanager
def checking_leaks():
"""Context manager for temporarily enabling tracer leak checks."""
old_value, debug_state.check_leaks = debug_state.check_leaks, True
try:
yield
finally:
debug_state.check_leaks = old_value

class DebugState(threading.local):
def __init__(self):
self.check_leaks = FLAGS.jax_check_tracer_leaks
debug_state = DebugState()

zip = safe_zip
map = safe_map

Expand Down Expand Up @@ -277,8 +250,8 @@ def __repr__(self):


def bind(self, *args, **params):
assert skip_checks or all(isinstance(arg, Tracer)
or valid_jaxtype(arg) for arg in args), args
assert (not config.jax_enable_checks or
all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
top_trace = find_top_trace(args)
tracers = map(top_trace.full_raise, args)
out = top_trace.process_primitive(self, tracers, params)
Expand Down Expand Up @@ -567,7 +540,7 @@ def __array_module__(self, types): return self.aval._array_module(self, types)

def __getattr__(self, name):
# if the aval property raises an AttributeError, gets caught here
assert skip_checks or name != "aval"
assert not config.jax_enable_checks or name != "aval"

try:
attr = getattr(self.aval, name)
Expand Down Expand Up @@ -754,7 +727,7 @@ def new_main(trace_type: Type[Trace],
if dynamic:
stack.dynamic = prev_dynamic

if debug_state.check_leaks:
if config.jax_check_tracer_leaks:
t = ref(main)
del main
if t() is not None:
Expand All @@ -773,7 +746,7 @@ def new_base_main(trace_type: Type[Trace]) -> Generator[MainTrace, None, None]:
stack.dynamic = prev_dynamic
stack.stack[0] = prev_base

if debug_state.check_leaks:
if config.jax_check_tracer_leaks:
t = ref(main)
del main
if t() is not None:
Expand All @@ -794,7 +767,7 @@ def new_sublevel() -> Generator[None, None, None]:
thread_local_state.trace_state.substack.pop()

# TODO(mattjj): to check sublevel leaks, we need to make Sublevel weakref-able
# if debug_state.check_leaks:
# if config.jax_check_tracer_leaks:
# t = ref(sublevel)
# del sublevel
# if t() is not None:
Expand Down Expand Up @@ -866,7 +839,7 @@ class AbstractUnit(AbstractValue):
# _num_buffers = 0
def at_least_vspace(self): return self
def join(self, other):
if not skip_checks:
if config.jax_enable_checks:
assert other is abstract_unit, other
return self
def _eq(self, self_traced, other): return get_aval(other) is self
Expand Down Expand Up @@ -1894,7 +1867,7 @@ def new_main(trace_type: Type[Trace], bottom=False, **payload) -> Generator[Main
finally:
thread_local_state.trace_state.trace_stack.pop(bottom)

if debug_state.check_leaks:
if config.jax_check_tracer_leaks:
t = ref(main)
del main
if t() is not None:
Expand All @@ -1911,7 +1884,7 @@ def eval_context():
yield # dummy implementation for forward compatibility

def bind(self, *args, **kwargs):
assert skip_checks or all(isinstance(arg, Tracer)
assert not config.jax_enable_checks or all(isinstance(arg, Tracer)
or valid_jaxtype(arg) for arg in args), args
top_trace = find_top_trace(args)
if top_trace is None:
Expand Down
2 changes: 1 addition & 1 deletion jax/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,7 +913,7 @@ def rev(objective_fn, res, g):
"""
flat_args, in_tree = tree_flatten(example_args)
in_avals = tuple(map(abstractify, flat_args))
if core.debug_state.check_leaks:
if config.jax_check_tracer_leaks:
return _closure_convert_for_avals.__wrapped__(fun, in_tree, in_avals)
else:
return _closure_convert_for_avals(fun, in_tree, in_avals)
Expand Down
Loading

0 comments on commit 6930015

Please sign in to comment.