Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

unify configuration state handling #6112

Merged
merged 1 commit into from
Mar 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
del _cloud_tpu_init

# flake8: noqa: F401
from .config import config
from .config import (config, enable_checks, check_tracer_leaks, checking_leaks,
debug_nans, debug_infs, log_compiles)
from .api import (
ad, # TODO(phawkins): update users to avoid this.
argnums_partial, # TODO(phawkins): update Haiku to not use this.
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/lax/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1113,7 +1113,7 @@ def _cond_typecheck(*avals, branches, linear):
f'called with operands of type {_avals_short(op_avals)}')

def cond_bind(*args, branches, linear):
if not core.skip_checks:
if config.jax_enable_checks:
avals = _map(core.get_aval, args)
_cond_typecheck(*avals, branches=branches, linear=linear)
for jaxpr in branches:
Expand Down Expand Up @@ -1876,7 +1876,7 @@ def _scan_typecheck(bind_time, *avals, reverse, length, num_consts, num_carry,
f'called with sequence of type\n{_avals_short(x_avals)}')

def scan_bind(*args, **params):
if not core.skip_checks:
if config.jax_enable_checks:
avals = _map(core.get_aval, args)
_scan_typecheck(True, *avals, **params)
core.check_jaxpr(params['jaxpr'].jaxpr)
Expand Down
3 changes: 1 addition & 2 deletions jax/_src/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import numpy as np

import jax
from jax.config import config

partial = functools.partial
Expand Down Expand Up @@ -192,7 +191,7 @@ def cached(_, *args, **kwargs):

@functools.wraps(f)
def wrapper(*args, **kwargs):
if jax.core.debug_state.check_leaks:
if config.jax_check_tracer_leaks:
return f(*args, **kwargs)
else:
return cached(bool(config.x64_enabled), *args, **kwargs)
Expand Down
10 changes: 5 additions & 5 deletions jax/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from . import linear_util as lu
from . import ad_util
from . import dtypes
from .core import eval_jaxpr, checking_leaks
from .core import eval_jaxpr
from .api_util import (flatten_fun, apply_flat_fun, flatten_fun_nokwargs,
flatten_fun_nokwargs2, argnums_partial,
argnums_partial_except, flatten_axes, donation_vector,
Expand Down Expand Up @@ -362,7 +362,7 @@ def f_jitted(*args, **kwargs):
context = (getattr(core.thread_local_state.trace_state.trace_stack,
"dynamic", None), config.x64_enabled)
# TODO(jblespiau): Move this to C++.
if (FLAGS.jax_debug_nans or FLAGS.jax_debug_infs) and not _jit_is_disabled():
if (config.jax_debug_nans or config.jax_debug_infs) and not _jit_is_disabled():
device_arrays = cpp_jitted_f(context, *args, **kwargs)
try:
xla.check_special(xla.xla_call_p, [
Expand All @@ -372,7 +372,7 @@ def f_jitted(*args, **kwargs):
])
return device_arrays
except FloatingPointError:
assert FLAGS.jax_debug_nans or FLAGS.jax_debug_infs # compiled_fun can only raise in this case
assert config.jax_debug_nans or config.jax_debug_infs # compiled_fun can only raise in this case
print("Invalid nan value encountered in the output of a C++-jit "
"function. Calling the de-optimized version.")
return cache_miss(*args, **kwargs)[0] # probably won't return
Expand All @@ -389,7 +389,7 @@ def f_jitted(*args, **kwargs):
@api_boundary
def f_jitted(*args, **kwargs):
# TODO(jblespiau): Move this to C++.
if (FLAGS.jax_debug_nans or FLAGS.jax_debug_infs) and not _jit_is_disabled():
if (config.jax_debug_nans or config.jax_debug_infs) and not _jit_is_disabled():
device_arrays = cpp_jitted_f(*args, **kwargs)
try:
xla.check_special(xla.xla_call_p, [
Expand All @@ -399,7 +399,7 @@ def f_jitted(*args, **kwargs):
])
return device_arrays
except FloatingPointError:
assert FLAGS.jax_debug_nans or FLAGS.jax_debug_infs # compiled_fun can only raise in this case
assert config.jax_debug_nans or config.jax_debug_infs # compiled_fun can only raise in this case
print("Invalid nan value encountered in the output of a C++-jit "
"function. Calling the de-optimized version.")
return cache_miss(*args, **kwargs)[0] # probably won't return
Expand Down
182 changes: 161 additions & 21 deletions jax/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import functools
import os
import sys
import threading

from jax import lib

def bool_env(varname: str, default: bool) -> bool:
Expand Down Expand Up @@ -42,11 +46,16 @@ def int_env(varname: str, default: int) -> int:


class Config:
_HAS_DYNAMIC_ATTRIBUTES = True

def __init__(self):
self.values = {}
self.meta = {}
self.FLAGS = NameSpace(self.read)
self.use_absl = False
self._contextmanager_flags = set()

# TODO(mattjj): delete these when only omnistaging is available
self.omnistaging_enabled = bool_env('JAX_OMNISTAGING', True)
self._omnistaging_disablers = []

Expand All @@ -65,6 +74,13 @@ def update(self, name, val):
lib.jax_jit.global_state().enable_x64 = val

def read(self, name):
if name in self._contextmanager_flags:
raise AttributeError(
"For flags with a corresponding contextmanager, read their value "
f"via e.g. `config.{name}` rather than `config.FLAGS.{name}`.")
return self._read(name)

def _read(self, name):
if self.use_absl:
return getattr(self.absl_flags.FLAGS, name)
else:
Expand Down Expand Up @@ -143,14 +159,82 @@ def disable_omnistaging(self):
disabler()
self.omnistaging_enabled = False

@property
def x64_enabled(self):
return lib.jax_jit.get_enable_x64()

# TODO(jakevdp): make this public when thread-local x64 is fully implemented.
def _set_x64_enabled(self, state):
lib.jax_jit.thread_local_state().enable_x64 = bool(state)

# # TODO(jakevdp, mattjj): unify this with `define_bool_state` stuff below
# @property
# def x64_enabled(self):
# return lib.jax_jit.get_enable_x64()

# def _set_x64_enabled(self, state):
# lib.jax_jit.thread_local_state().enable_x64 = bool(state)

def define_bool_state(self, name: str, default: bool, help: str):
"""Set up thread-local state and return a contextmanager for managing it.

This function is a convenience wrapper. It defines a flag and corresponding
thread-local state, which can be managed via the contextmanager it returns.

The thread-local state value can be read via the ``config.<option_name>``
attribute, where ``config`` is the singleton ``Config`` instance.

Args:
name: string, converted to lowercase to define the name of the config
option (and absl flag). It is converted to uppercase to define the
corresponding shell environment variable.
default: boolean, a default value for the option.
help: string, used to populate the flag help information as well as the
docstring of the returned context manager.

Returns:
A contextmanager to control the thread-local state value.

Example:

enable_foo = config.define_bool_state(
name='jax_enable_foo',
default=False,
help='Enable foo.')

# Now the JAX_ENABLE_FOO shell environment variable and --jax_enable_foo
# command-line flag can be used to control the process-level value of
# the configuration option, in addition to using e.g.
# ``config.update("jax_enable_foo", True)`` directly. We can also use a
# context manager:

with enable_foo(True):
...

The value of the thread-local state or flag can be accessed via
``config.jax_enable_foo``. Reading it via ``config.FLAGS.jax_enable_foo`` is
an error.
"""
name = name.lower()
self.DEFINE_bool(name, bool_env(name.upper(), default), help)
self._contextmanager_flags.add(name)

def get_state(self):
val = getattr(_thread_local_state, name, unset)
return val if val is not unset else self._read(name)
setattr(Config, name, property(get_state))

@contextlib.contextmanager
def set_state(new_val: bool):
prev_val = getattr(_thread_local_state, name, unset)
setattr(_thread_local_state, name, new_val)
try:
yield
finally:
if prev_val is unset:
delattr(_thread_local_state, name)
else:
setattr(_thread_local_state, name, prev_val)
set_state.__name__ = name[4:] if name.startswith('jax_') else name
set_state.__doc__ = f"Context manager for `{name}` config option.\n\n{help}"
return set_state

_thread_local_state = threading.local()

class Unset: pass
unset = Unset()

class NameSpace(object):
def __init__(self, getter):
Expand All @@ -166,11 +250,6 @@ def __getattr__(self, name):

already_configured_with_absl = False

flags.DEFINE_bool(
'jax_enable_checks',
bool_env('JAX_ENABLE_CHECKS', False),
help='Turn on invariant checking (core.skip_checks = False)'
)

flags.DEFINE_bool(
'jax_omnistaging',
Expand All @@ -184,14 +263,6 @@ def __getattr__(self, name):
help='Set the number of stack frames in JAX tracer error messages.'
)

flags.DEFINE_bool(
'jax_check_tracer_leaks',
bool_env('JAX_CHECK_TRACER_LEAKS', False),
help=('Turn on checking for leaked tracers as soon as a trace completes. '
'Enabling leak checking may have performance impacts: some caching '
'is disabled, and other overheads may be added.'),
)

flags.DEFINE_bool(
'jax_host_callback_inline',
bool_env('JAX_HOST_CALLBACK_INLINE', False),
Expand All @@ -206,3 +277,72 @@ def __getattr__(self, name):
'until the Python callback consume more outfeeds.'),
lower_bound=int(16 * 1e6)
)


enable_checks = config.define_bool_state(
name='jax_enable_checks',
default=False,
help='Turn on invariant checking for JAX internals. Makes things slower.')

check_tracer_leaks = config.define_bool_state(
name='jax_check_tracer_leaks',
default=False,
help=('Turn on checking for leaked tracers as soon as a trace completes. '
'Enabling leak checking may have performance impacts: some caching '
'is disabled, and other overheads may be added.'))
checking_leaks = functools.partial(check_tracer_leaks, True)

debug_nans = config.define_bool_state(
name='jax_debug_nans',
default=False,
help=('Add nan checks to every operation. When a nan is detected on the '
'output of a jit-compiled computation, call into the un-compiled '
'version in an attempt to more precisely identify the operation '
'which produced the nan.'))

debug_infs = config.define_bool_state(
name='jax_debug_infs',
default=False,
help=('Add inf checks to every operation. When an inf is detected on the '
'output of a jit-compiled computation, call into the un-compiled '
'version in an attempt to more precisely identify the operation '
'which produced the inf.'))

log_compiles = config.define_bool_state(
name='jax_log_compiles',
default=False,
help=('Log a message each time every time `jit` or `pmap` compiles an XLA '
'computation. Logging is performed with `absl.logging`. When this '
'option is set, the log level is WARNING; otherwise the level is '
'DEBUG.'))

# Because jax_enable_x64 is managed by C++ code, we don't reuse the
# config.define_bool_state mechanism, though conceptually it is the same.
config.DEFINE_bool('jax_enable_x64', bool_env('JAX_ENABLE_X64', False),
help='Enable 64-bit types to be used')
lib.jax_jit.global_state().enable_x64 = bool_env('JAX_ENABLE_X64', False)

@contextlib.contextmanager
def enable_x64(new_val: bool = True):
"""Experimental context manager to temporarily enable X64 mode.

Usage::

>>> import jax.numpy as jnp
>>> with enable_x64(True):
... print(jnp.arange(10.0).dtype)
...
float64
"""
prev_val = config.jax_enable_x64
lib.jax_jit.thread_local_state().enable_x64 = bool(new_val)
try:
yield
finally:
lib.jax_jit.thread_local_state().enable_x64 = prev_val
Config.jax_enable_x64 = property(lambda self: lib.jax_jit.get_enable_x64())
# config._contextmanager_flags.add('jax_enable_x64') # TODO(mattjj): remove footgun

# The `x64_enabled` property doesn't fit the naming scheme, but we use it for
# backward compatibility.
Config.x64_enabled = Config.jax_enable_x64
Loading