Skip to content

Commit

Permalink
make config.FLAGS.jax_enable_foo an error
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Mar 19, 2021
1 parent 8acad26 commit 7f3bb63
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 14 deletions.
24 changes: 13 additions & 11 deletions jax/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(self):
self.meta = {}
self.FLAGS = NameSpace(self.read)
self.use_absl = False
self._contextmanager_flags = set()

# TODO(mattjj): delete these when only omnistaging is available
self.omnistaging_enabled = bool_env('JAX_OMNISTAGING', True)
Expand All @@ -71,6 +72,13 @@ def update(self, name, val):
lib.jax_jit.global_state().enable_x64 = val

def read(self, name):
if name in self._contextmanager_flags:
raise AttributeError(
"For flags with a corresponding contextmanager, read their value "
f"via e.g. `config.{name}` rather than `config.FLAGS.{name}`.")
return self._read(name)

def _read(self, name):
if self.use_absl:
return getattr(self.absl_flags.FLAGS, name)
else:
Expand Down Expand Up @@ -193,23 +201,17 @@ def define_bool_state(self, name: str, default: bool, help: str):
with enable_foo(True):
...
Accessing ``config.FLAGS.jax_enable_foo`` is different from accessing the
thread-local state value via ``config.jax_enable_foo``: the former reads the
flag value determined set by the environment variable or command-line flag
and does not read the thread-local state, whereas the latter reads the
thread-local state value managed by the contextmanager. Think of the
contextmanager state as a layer on top of the flag value: if no
contextmanager is in use then ``config.jax_enable_foo`` reflects the flag
value ``config.FLAGS.jax_enable_foo``, whereas if a contextmanager is in use
then only ``config.jax_enable_foo`` is updated. So in general using
``config.jax_enable_foo`` is best.
The value of the thread-local state or flag can be accessed via
``config.jax_enable_foo``. Reading it via ``config.FLAGS.jax_enable_foo`` is
an error.
"""
name = name.lower()
self.DEFINE_bool(name, bool_env(name.upper(), default), help)
self._contextmanager_flags.add(name)

def get_state(self):
val = getattr(_thread_local_state, name, unset)
return val if val is not unset else self.read(name)
return val if val is not unset else self._read(name)
setattr(Config, name, property(get_state))

@contextlib.contextmanager
Expand Down
2 changes: 1 addition & 1 deletion tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2334,7 +2334,7 @@ def test_leak_checker_catches_a_sublevel_leak(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")

with core.checking_leaks():
with jax.checking_leaks():
@jit
def f(x):
lst = []
Expand Down
4 changes: 2 additions & 2 deletions tests/debug_nans_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
class DebugNaNsTest(jtu.JaxTestCase):

def setUp(self):
self.cfg = config.read("jax_debug_nans")
self.cfg = config._read("jax_debug_nans")
config.update("jax_debug_nans", True)

def tearDown(self):
Expand Down Expand Up @@ -144,7 +144,7 @@ def testPjit(self):
class DebugInfsTest(jtu.JaxTestCase):

def setUp(self):
self.cfg = config.read("jax_debug_infs")
self.cfg = config._read("jax_debug_infs")
config.update("jax_debug_infs", True)

def tearDown(self):
Expand Down

0 comments on commit 7f3bb63

Please sign in to comment.