From 89768a3d2898ee7e08425f90ccf1435449c07ee6 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Tue, 23 Mar 2021 20:58:52 -0700 Subject: [PATCH] add jax_default_matmul_precision flag & context mngr --- jax/__init__.py | 3 +- jax/_src/lax/lax.py | 84 ++++++++++++++++++++++------------ jax/_src/numpy/lax_numpy.py | 17 ++----- jax/config.py | 91 +++++++++++++++++++++++++++++++++++-- tests/api_test.py | 53 +++++++++++++++++++++ tests/lax_numpy_test.py | 27 +++++++++-- 6 files changed, 224 insertions(+), 51 deletions(-) diff --git a/jax/__init__.py b/jax/__init__.py index 1cf4eb412cd8..2e10b66c74c5 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -31,7 +31,8 @@ # flake8: noqa: F401 from .config import (config, enable_checks, check_tracer_leaks, checking_leaks, - debug_nans, debug_infs, log_compiles) + debug_nans, debug_infs, log_compiles, + default_matmul_precision, numpy_rank_promotion) from .api import ( ad, # TODO(phawkins): update users to avoid this. argnums_partial, # TODO(phawkins): update Haiku to not use this. diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index b4a4568f08b9..403e1a1b8bba 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -506,8 +506,17 @@ def concatenate(operands: Sequence[Array], dimension: int) -> Array: Precision = xla_client.PrecisionConfig.Precision Precision.__str__ = lambda precision: precision.name PrecisionType = Any -PrecisionLike = Union[None, PrecisionType, Tuple[PrecisionType, PrecisionType]] - +PrecisionLike = Union[None, str, PrecisionType, Tuple[str, str], + Tuple[PrecisionType, PrecisionType]] +_precision_strings = { + 'highest': Precision.HIGHEST, + 'float32': Precision.HIGHEST, + 'bfloat16_3x': Precision.HIGH, + 'tensorfloat32': Precision.HIGH, + 'bfloat16': Precision.DEFAULT, + 'fastest': Precision.DEFAULT, + None: Precision.DEFAULT, +} class ConvDimensionNumbers(NamedTuple): """Describes batch, spatial, and feature dimensions of a convolution. @@ -555,23 +564,25 @@ def conv_general_dilated( rhs_dilation: `None`, or a sequence of `n` integers, giving the dilation factor to apply in each spatial dimension of `rhs`. RHS dilation is also known as atrous convolution. - dimension_numbers: either `None`, a `ConvDimensionNumbers` object, or - a 3-tuple `(lhs_spec, rhs_spec, out_spec)`, where each element is a string - of length `n+2`. + dimension_numbers: either `None`, a ``ConvDimensionNumbers`` object, or + a 3-tuple ``(lhs_spec, rhs_spec, out_spec)``, where each element is a + string of length `n+2`. feature_group_count: integer, default 1. See XLA HLO docs. batch_group_count: integer, default 1. See XLA HLO docs. precision: Optional. Either ``None``, which means the default precision for the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``, - ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two - ``lax.Precision`` enums indicating precision of ``lhs``` and ``rhs``. + ``Precision.HIGH`` or ``Precision.HIGHEST``), a string (e.g. 'highest' or + 'fastest', see the ``jax.default_matmul_precision`` context manager), or a + tuple of two ``lax.Precision`` enums or strings indicating precision of + ``lhs`` and ``rhs``. Returns: An array containing the convolution result. 
- In the string case of `dimension_numbers`, each character identifies by + In the string case of ``dimension_numbers``, each character identifies by position: - - the batch dimensions in `lhs`, `rhs`, and the output with the character + - the batch dimensions in ``lhs``, ``rhs``, and the output with the character 'N', - the feature dimensions in `lhs` and the output with the character 'C', - the input and output feature dimensions in rhs with the characters 'I' @@ -579,18 +590,18 @@ def conv_general_dilated( - spatial dimension correspondences between lhs, rhs, and the output using any distinct characters. - For example, to indicate dimension numbers consistent with the `conv` function - with two spatial dimensions, one could use `('NCHW', 'OIHW', 'NCHW')`. As - another example, to indicate dimension numbers consistent with the TensorFlow - Conv2D operation, one could use `('NHWC', 'HWIO', 'NHWC')`. When using the - latter form of convolution dimension specification, window strides are - associated with spatial dimension character labels according to the order in - which the labels appear in the `rhs_spec` string, so that `window_strides[0]` - is matched with the dimension corresponding to the first character - appearing in rhs_spec that is not `'I'` or `'O'`. - - If `dimension_numbers` is `None`, the default is `('NCHW', 'OIHW', 'NCHW')` - (for a 2D convolution). + For example, to indicate dimension numbers consistent with the ``conv`` + function with two spatial dimensions, one could use ``('NCHW', 'OIHW', + 'NCHW')``. As another example, to indicate dimension numbers consistent with + the TensorFlow Conv2D operation, one could use ``('NHWC', 'HWIO', 'NHWC')``. + When using the latter form of convolution dimension specification, window + strides are associated with spatial dimension character labels according to + the order in which the labels appear in the ``rhs_spec`` string, so that + ``window_strides[0]`` is matched with the dimension corresponding to the first + character appearing in rhs_spec that is not ``'I'`` or ``'O'``. + + If ``dimension_numbers`` is ``None``, the default is ``('NCHW', 'OIHW', + 'NCHW')`` (for a 2D convolution). 
""" dnums = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers) if lhs_dilation is None: @@ -6394,16 +6405,31 @@ def remaining(original, *removed_lists): def _canonicalize_precision(precision): if precision is None: - return None - if isinstance(precision, Precision) or ( - isinstance(precision, tuple) - and len(precision) == 2 - and all(isinstance(p, Precision) for p in precision) - ): + if config.jax_default_matmul_precision is None: + return None + try: + return _precision_strings[config.jax_default_matmul_precision] + except KeyError: + raise ValueError( + "jax_default_matmul_precision flag must be set to None or a value in " + f"{_precision_strings}, but got {config.jax_default_matmul_precision}" + ) from None + elif isinstance(precision, str) and precision in _precision_strings: + return _precision_strings.get(precision) + elif isinstance(precision, Precision): return precision + elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and + all(isinstance(p, Precision) for p in precision)): + return precision + elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and + all(isinstance(s, str) for s in precision)): + s1, s2 = precision + return (_canonicalize_precision(s1), _canonicalize_precision(s2)) else: - raise ValueError("Precision argument must be None, a lax.Precision value " - f"or a tuple of two lax.Precision values; got {precision}") + raise ValueError( + f"Precision argument must be None, a string in {_precision_strings}, " + "a lax.Precision value or a tuple of two lax.Precision values or " + f"strings; got {precision}.") def conv_dimension_numbers(lhs_shape, rhs_shape, dimension_numbers diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 4c2fbe9c91cf..e2e916f7f2d8 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -28,7 +28,6 @@ import collections import collections.abc import operator -import os import types from typing import Any, Sequence, FrozenSet, Optional, Tuple, Union, cast from textwrap import dedent as _dedent @@ -45,7 +44,7 @@ from jax import dtypes from jax import errors from jax.core import UnshapedArray, ShapedArray, ConcreteArray, canonicalize_shape -from jax.config import flags, config +from jax.config import config from jax.interpreters.xla import DeviceArray, _DeviceArray, _CppDeviceArray from jax.interpreters.masking import Poly from jax import lax @@ -55,14 +54,6 @@ canonicalize_axis as _canonicalize_axis, maybe_named_axis) from jax.tree_util import tree_leaves, tree_flatten, tree_map -FLAGS = flags.FLAGS -flags.DEFINE_enum( - 'jax_numpy_rank_promotion', os.getenv('JAX_NUMPY_RANK_PROMOTION', 'allow'), - enum_values=['allow', 'warn', 'raise'], - help= - 'Control NumPy-style automatic rank promotion broadcasting ' - '("allow", "warn", or "raise").') - newaxis = None # Common docstring additions: @@ -247,20 +238,20 @@ def _promote_shapes(fun_name, *args): if not nonscalar_ranks or len(set(nonscalar_ranks)) == 1: return args else: - if FLAGS.jax_numpy_rank_promotion != "allow": + if config.jax_numpy_rank_promotion != "allow": _rank_promotion_warning_or_error(fun_name, shapes) result_rank = len(lax.broadcast_shapes(*shapes)) return [broadcast_to(arg, (1,) * (result_rank - len(shp)) + shp) for arg, shp in zip(args, shapes)] def _rank_promotion_warning_or_error(fun_name, shapes): - if FLAGS.jax_numpy_rank_promotion == "warn": + if config.jax_numpy_rank_promotion == "warn": msg = ("Following NumPy automatic rank promotion for {} on shapes {}. 
" "Set the jax_numpy_rank_promotion config option to 'allow' to " "disable this warning; for more information, see " "https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.") warnings.warn(msg.format(fun_name, ' '.join(map(str, shapes)))) - elif FLAGS.jax_numpy_rank_promotion == "raise": + elif config.jax_numpy_rank_promotion == "raise": msg = ("Operands could not be broadcast together for {} on shapes {} " "and with the config option jax_numpy_rank_promotion='raise'. " "For more information, see " diff --git a/jax/config.py b/jax/config.py index 1c4be855bf38..c2a4eebf7672 100644 --- a/jax/config.py +++ b/jax/config.py @@ -17,9 +17,9 @@ import os import sys import threading +from typing import List, Callable, Optional from jax import lib -from typing import Callable, Optional def bool_env(varname: str, default: bool) -> bool: """Read an environment variable and interpret it as a boolean. @@ -52,7 +52,7 @@ class Config: def __init__(self): self.values = {} self.meta = {} - self.FLAGS = NameSpace(self.read) + self.FLAGS = NameSpace(self.read, self.update) self.use_absl = False self._contextmanager_flags = set() @@ -255,18 +255,70 @@ def set_state(new_val: bool): set_state.__doc__ = f"Context manager for `{name}` config option.\n\n{help}" return set_state + def define_enum_state(self, name: str, enum_values: List[str], + default: Optional[str], help: str): + """Set up thread-local state and return a contextmanager for managing it. + Args: + name: string, converted to lowercase to define the name of the config + option (and absl flag). It is converted to uppercase to define the + corresponding shell environment variable. + enum_values: list of strings representing the possible values for the + option. + default: optional string, default value. + help: string, used to populate the flag help information as well as the + docstring of the returned context manager. + Returns: + A contextmanager to control the thread-local state value. + See docstring for ``define_bool_state``. 
+ """ + name = name.lower() + self.DEFINE_enum(name, os.getenv(name.upper(), default), + enum_values=enum_values, help=help) + self._contextmanager_flags.add(name) + + def get_state(self): + val = getattr(_thread_local_state, name, unset) + return val if val is not unset else self._read(name) + setattr(Config, name, property(get_state)) + + @contextlib.contextmanager + def set_state(new_val: Optional[str]): + if (new_val is not None and + (type(new_val) is not str or new_val not in enum_values)): + raise ValueError(f"new enum value must be None or in {enum_values}, " + f"got {new_val} of type {type(new_val)}.") + prev_val = getattr(_thread_local_state, name, unset) + setattr(_thread_local_state, name, new_val) + try: + yield + finally: + if prev_val is unset: + delattr(_thread_local_state, name) + else: + setattr(_thread_local_state, name, prev_val) + set_state.__name__ = name[4:] if name.startswith('jax_') else name + set_state.__doc__ = f"Context manager for `{name}` config option.\n\n{help}" + return set_state + + _thread_local_state = threading.local() class Unset: pass unset = Unset() -class NameSpace(object): - def __init__(self, getter): - self._getter = getter +class NameSpace: + def __init__(self, getter, setter): + # must use super because we override this class's __setattr__, see + # https://docs.python.org/3/reference/datamodel.html#object.__setattr__ + super().__setattr__('_getter', getter) + super().__setattr__('_setter', setter) def __getattr__(self, name): return self._getter(name) + def __setattr__(self, name, val): + self._setter(name, val) + config = Config() flags = config @@ -357,3 +409,32 @@ def _update_x64_thread_local(val): config._contextmanager_flags.remove("jax_enable_x64") Config.x64_enabled = Config.jax_enable_x64 # type: ignore + + +numpy_rank_promotion = config.define_enum_state( + name='jax_numpy_rank_promotion', + enum_values=['allow', 'warn', 'raise'], + default='allow', + help=('Control NumPy-style automatic rank promotion broadcasting ' + '("allow", "warn", or "raise").')) + +default_matmul_precision = config.define_enum_state( + name='jax_default_matmul_precision', + enum_values=['bfloat16', 'tensorfloat32', 'float32'], + default=None, + help=('Control the default matmul and conv precision for 32bit inputs.\n\n' + + 'Some platforms, like TPU, offer configurable precision levels for ' + 'matrix multiplication and convolution computations, trading off ' + 'accuracy for speed. The precision can be controlled for each ' + 'operation; for example, see the :func:`jax.lax.conv_general_dilated` ' + 'and :func:`jax.lax.dot` docstrings. But it can be useful to control ' + 'the default behavior obtained when an operation is not given a ' + 'specific precision.\n\n' + + 'This option can be used to control the default precision ' + 'level for computations involved in matrix multiplication and ' + 'convolution on 32bit inputs. The levels roughly describe the ' + "precision at which scalar products are computed. 
The 'bfloat16' " + "option is the fastest and least precise; 'float32' is similar to " + "full float32 precision; 'tensorfloat32' is intermediate.\n\n")) diff --git a/tests/api_test.py b/tests/api_test.py index cad8776aaf4a..ad10636a2ab8 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -25,6 +25,7 @@ import weakref import functools import itertools as it +import operator as op from absl import logging from absl.testing import absltest, parameterized @@ -2399,6 +2400,58 @@ def test_large_python_int_to_float(self): out = lax.convert_element_type(2 ** 100, jnp.float32) # doesn't crash self.assertArraysEqual(out, np.float32(2 ** 100)) + def test_dot_precision_context_manager(self): + x = jnp.zeros((2, 2)) + + with jax.default_matmul_precision(None): + jnp.dot(x, x) # doesn't crash + jaxpr = jax.make_jaxpr(jnp.dot)(x, x) + self.assertIn('precision=None', str(jaxpr)) + + with jax.default_matmul_precision("bfloat16"): + x @ x # doesn't crash + jaxpr = jax.make_jaxpr(op.matmul)(x, x) + self.assertIn('precision=DEFAULT', str(jaxpr)) + + with jax.default_matmul_precision("tensorfloat32"): + jnp.dot(x, x) # doesn't crash + jaxpr = jax.make_jaxpr(jnp.dot)(x, x) + self.assertIn('precision=HIGH\n', str(jaxpr)) + + with jax.default_matmul_precision("float32"): + jnp.dot(x, x) # doesn't crash + jaxpr = jax.make_jaxpr(jnp.dot)(x, x) + self.assertIn('precision=HIGHEST', str(jaxpr)) + + dot = partial(jnp.dot, precision=lax.Precision.HIGHEST) + with jax.default_matmul_precision("tensorfloat32"): + dot(x, x) # doesn't crash + jaxpr = jax.make_jaxpr(dot)(x, x) + self.assertIn('precision=HIGHEST', str(jaxpr)) + + def test_dot_precision_flag(self): + x = jnp.zeros((2, 2)) + + prev_val = config._read("jax_default_matmul_precision") + try: + config.FLAGS.jax_default_matmul_precision = "tensorfloat32" + jnp.dot(x, x) # doesn't crash + jaxpr = jax.make_jaxpr(jnp.dot)(x, x) + finally: + config.FLAGS.jax_default_matmul_precision = prev_val + self.assertIn('precision=HIGH', str(jaxpr)) + self.assertEqual(prev_val, config._read("jax_default_matmul_precision")) + + prev_val = config._read("jax_default_matmul_precision") + try: + config.update('jax_default_matmul_precision','tensorfloat32') + jnp.dot(x, x) # doesn't crash + jaxpr = jax.make_jaxpr(jnp.dot)(x, x) + finally: + config.update('jax_default_matmul_precision', prev_val) + self.assertIn('precision=HIGH', str(jaxpr)) + self.assertEqual(prev_val, config._read("jax_default_matmul_precision")) + class RematTest(jtu.JaxTestCase): diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index d7b54e0664ea..aa03ea29b86a 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -4769,21 +4769,21 @@ def np_op(start, stop): def testDisableNumpyRankPromotionBroadcasting(self): try: - prev_flag = FLAGS.jax_numpy_rank_promotion + prev_flag = config.jax_numpy_rank_promotion FLAGS.jax_numpy_rank_promotion = "allow" jnp.ones(2) + jnp.ones((1, 2)) # works just fine finally: FLAGS.jax_numpy_rank_promotion = prev_flag try: - prev_flag = FLAGS.jax_numpy_rank_promotion + prev_flag = config.jax_numpy_rank_promotion FLAGS.jax_numpy_rank_promotion = "raise" self.assertRaises(ValueError, lambda: jnp.ones(2) + jnp.ones((1, 2))) finally: FLAGS.jax_numpy_rank_promotion = prev_flag try: - prev_flag = FLAGS.jax_numpy_rank_promotion + prev_flag = config.jax_numpy_rank_promotion FLAGS.jax_numpy_rank_promotion = "warn" with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") @@ -4800,6 +4800,27 @@ def 
testDisableNumpyRankPromotionBroadcasting(self): finally: FLAGS.jax_numpy_rank_promotion = prev_flag + def testDisableNumpyRankPromotionBroadcastingDecorator(self): + with jax.numpy_rank_promotion("allow"): + jnp.ones(2) + jnp.ones((1, 2)) # works just fine + + with jax.numpy_rank_promotion("raise"): + self.assertRaises(ValueError, lambda: jnp.ones(2) + jnp.ones((1, 2))) + + with jax.numpy_rank_promotion("warn"): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + jnp.ones(2) + jnp.ones((1, 2)) + assert len(w) > 0 + msg = str(w[-1].message) + expected_msg = ("Following NumPy automatic rank promotion for add on " + "shapes (2,) (1, 2).") + self.assertEqual(msg[:len(expected_msg)], expected_msg) + + prev_len = len(w) + jnp.ones(2) + 3 + self.assertEqual(len(w), prev_len) # don't want to warn for scalars + def testStackArrayArgument(self): # tests https://github.com/google/jax/issues/1271 @api.jit
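
Usage note (not part of the patch): a minimal sketch of the matmul/conv precision controls this patch adds, mirroring the test cases above. The per-call string form follows the updated conv_general_dilated/dot docstrings; exact numerical behavior depends on the backend (on CPU, 32-bit matmuls are typically computed in full float32 regardless of the setting).

    import jax
    import jax.numpy as jnp
    from jax import lax
    from jax.config import config

    x = jnp.zeros((2, 2), dtype=jnp.float32)

    # Per-call control: the precision argument now also accepts strings
    # (or a pair of strings, one per operand), in addition to the enums.
    y = jnp.dot(x, x, precision='tensorfloat32')        # -> Precision.HIGH
    y = jnp.dot(x, x, precision=lax.Precision.HIGHEST)  # enums still work
    y = lax.dot(x, x, precision=('bfloat16', 'float32'))  # per-operand

    # Scoped control: the new context manager sets the default used by
    # 32-bit matmul/conv ops that are not given an explicit precision.
    with jax.default_matmul_precision('float32'):
        y = x @ x   # lowered with precision=HIGHEST

    # Process-wide control via the config object, as exercised in
    # test_dot_precision_flag above.
    config.update('jax_default_matmul_precision', 'bfloat16')
    config.FLAGS.jax_default_matmul_precision = 'tensorfloat32'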
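
Likewise, a sketch of the rank-promotion control, which this patch migrates from an absl flag to thread-local config state with a top-level context manager (this mirrors testDisableNumpyRankPromotionBroadcastingDecorator above):

    import warnings
    import jax
    import jax.numpy as jnp

    with jax.numpy_rank_promotion('raise'):
        try:
            jnp.ones(2) + jnp.ones((1, 2))   # rank promotion is an error
        except ValueError as e:
            print(e)

    with jax.numpy_rank_promotion('warn'):
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter('always')
            jnp.ones(2) + jnp.ones((1, 2))   # emits a UserWarning
        print(w[-1].message)

    with jax.numpy_rank_promotion('allow'):
        jnp.ones(2) + jnp.ones((1, 2))       # silent NumPy-style broadcasting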
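
Finally, a sketch of how an additional option could be declared with the new Config.define_enum_state helper, following the pattern of the two options defined at the bottom of jax/config.py. The option name and values here are hypothetical, shown only to illustrate the API; the returned value is a context manager over thread-local state, with the flag/environment-variable value as the fallback.

    from jax.config import config

    # Hypothetical option, for illustration only.
    example_mode = config.define_enum_state(
        name='jax_example_mode',
        enum_values=['off', 'on'],
        default='off',
        help='Illustrative enum-valued config option.')

    with example_mode('on'):
        assert config.jax_example_mode == 'on'   # thread-local override
    assert config.jax_example_mode == 'off'      # falls back to the flag value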