add jax_default_matmul_precision flag and context manager #6143
Conversation
Force-pushed 047be0a to c2ef1fc
Force-pushed ad63cc3 to f564880
jax/_src/lax/lax.py (Outdated)

        Tuple[PrecisionType, PrecisionType]]
    _precision_strings = {
        'bfloat16': Precision.DEFAULT,
        'tensorfloat32': Precision.HIGH,
I'm wondering whether tensorfloat32 is a good description. On A100, I believe CUDA tensorfloat32 gives you 10 bits of mantissa on the input, and I'm not sure that HIGH in XLA's naming actually gives you tensorfloat32 on GPU.
On TPU it means something different, namely a multipass algorithm on bfloat16 inputs. Is conflating the two wise? I suspect we should just have two different names here.
I agree that it's probably misleading to use "tensorfloat32" if XLA doesn't actually support that on GPU (but clearly it should, at least in some form). This might be a good time to consult with the XLA GPU team to see how they feel about conflating Precision.HIGH and tensorfloat32.
See #2161 (comment) for notes on possible names. My favorite alternatives were bfloat24 and bfloat16_3x.
I guess the nice thing about strings instead of enums is that we can support multiple redundant names -- they don't have to be unique, and we can extend the set over time.
So perhaps we could support all of bfloat16/bfloat16_1x, bfloat16_3x, and float32/bfloat16_6x for now, and extend that list later when XLA adds true support for tensorfloat32 on GPUs, in whatever form that takes.
I suggest you want a couple of generic options:
- fastest
- most precise
and after that point you want to be completely specific about which algorithm you mean.
OK, so for the current state of things, how about:
- precision='highest' -> lax.Precision.HIGHEST
- precision='float32' -> lax.Precision.HIGHEST
- precision='bfloat16_3x' -> lax.Precision.HIGH
- precision='bfloat16' -> lax.Precision.DEFAULT
- precision='fastest' -> lax.Precision.DEFAULT
(We could add the bfloat16_1x and bfloat16_6x aliases as well, if desired.)
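A string-to-enum mapping like the one proposed above can be a plain dictionary in which several redundant aliases point at the same level. The sketch below is illustrative only: it uses a stand-alone Precision enum and a hypothetical parse_precision helper rather than JAX's actual internals.

```python
import enum

class Precision(enum.Enum):
    # Mirrors the three dot/conv precision levels discussed above.
    DEFAULT = 0   # fastest: e.g. bfloat16 inputs on TPU
    HIGH = 1      # e.g. a three-pass bfloat16 algorithm on TPU
    HIGHEST = 2   # full float32 precision

# Redundant string aliases: names need not be unique, and the
# set can be extended over time without breaking existing code.
_precision_strings = {
    'highest': Precision.HIGHEST,
    'float32': Precision.HIGHEST,
    'bfloat16_3x': Precision.HIGH,
    'bfloat16': Precision.DEFAULT,
    'fastest': Precision.DEFAULT,
}

def parse_precision(name):
    # Hypothetical helper: resolve a user-facing string to an enum level.
    try:
        return _precision_strings[name]
    except KeyError:
        raise ValueError(
            f"Unknown precision {name!r}; expected one of "
            f"{sorted(_precision_strings)}") from None
```

One nicety of this design is that adding a future alias (say, a true tensorfloat32 on GPU) is a one-line dictionary entry rather than an enum change.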
cc @jekbradbury
Maybe we should discuss on chat for lower latency?
"matmul_precision" might be a more descriptive name than "dot_precision", as it only affects matrix-matrix multiplications and convolutions, which are arguably a special case of matrix-matrix multiplication.
Force-pushed bc56413 to 8dd05f0
Ping for robots...
Notice the "diffbase" on the flag-cleanup branch, i.e. #6112. See the comment on #6112 for more context.
Fixes #2161
This PR adds a configuration option jax_default_matmul_precision and a context manager jax.default_matmul_precision(...) to control the default precision of internal computations used in matrix multiplies and convolutions on float32 inputs for supported backends (currently just TPU, but likely soon A100 GPUs as well).

For example, say we have a function foo which includes a jnp.dot call. We can ensure that the dot is computed at the highest (or lowest) precision using any of these methods:
- set the environment variable JAX_DEFAULT_MATMUL_PRECISION=float32 (or JAX_DEFAULT_MATMUL_PRECISION=bfloat16);
- if we parse flags with absl, use the command-line flag --jax_default_matmul_precision=float32 (or --jax_default_matmul_precision=bfloat16);
- call jax.config.update('jax_default_matmul_precision', 'float32') (or jax.config.update('jax_default_matmul_precision', 'bfloat16'));
- around calls to foo, use the jax.default_matmul_precision context manager.

This configuration option controls the default precision in the sense that convolution operations like lax.conv_general_dilated and matrix multiply operations like lax.dot take an optional precision argument. The configuration option does not change the behavior of calls with explicit precision arguments; it only changes the behavior of calls with no such argument provided.

This PR does not change the default default dot precision; that remains 'bfloat16'. It only adds new ways to control the default dot precision. We might change the default in follow-up work.

In follow-up work we may add an analogous bit of enum state for controlling the default device, but I think we should land this part first.
TODO:
cc @rohan-anil @sharadmv @shoyer @SiegeLordEx @jonbarron