Include bf16 support for TPUs and CPUs, and a better check for whether a CUDA device supports BF16 #462
Conversation
The documentation is not available anymore as the PR was closed or merged.
Thanks for all the fixes in this PR, plus the new BF16 support!
elif self.mixed_precision == "bf16" and is_bf16_available():
    if self.distributed_type in [DistributedType.NO, DistributedType.MULTI_CPU, DistributedType.MULTI_GPU]:
        device_type = "cpu" if not torch.cuda.is_available() else "cuda"
        autocast_context = torch.autocast(dtype=torch.bfloat16, device_type=device_type)
We need to be extra sure that this torch.autocast call always exists for every PyTorch version for which is_bf16_available() returns True.
Fixed by adding a check for torch >= 1.10.
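For context, here is a minimal sketch of the kind of version-guarded helper being discussed. The is_bf16_available name comes from this PR; the body below is an illustrative approximation under that assumption, not the merged code.

```python
from packaging import version
import torch


def is_bf16_available() -> bool:
    # torch.autocast(device_type=...) only exists from PyTorch 1.10 on,
    # so gate everything on the version first (the check added in this PR).
    if version.parse(torch.__version__) < version.parse("1.10"):
        return False
    if torch.cuda.is_available():
        # Ask PyTorch whether the current CUDA device supports bf16.
        return torch.cuda.is_bf16_supported()
    # On CPU a recent enough torch is all we require here; the TPU case is
    # handled separately through the XLA_USE_BF16 environment variable.
    return True
```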
# TEST that previous fp16 flag still works
print("Legacy FP16 training check.")
Not sure we need to keep this. We have done a couple of releases since we deprecated it, so it's okay if we stop testing it IMO.
I'd feel more comfortable dropping the test once we've entirely removed the legacy param (whenever that may be).
Add BF16 support for TPUs and CPUs
What does this add?
This PR improves BF16 support on TPUs, introduces BF16 support for CPUs, and adds a helper to check whether your GPU also supports bf16 when it is requested.
Who is it for?
Users of Accelerate who want to train on BF16
Why is it needed?
After fixing the conditional for testing fp16 in the test script, I realized that CPUs can support bfloat16 (though whether that is advisable is debatable), and that to enable bf16 on TPUs in Accelerate you must set an environment variable beforehand. This PR addresses both.
What parts of the API does this impact?
User-facing:
The user can now pass mixed_precision="bf16" and train on bf16 on modern CPUs and GPUs.

Internal structure:
Adds an is_bf16_available function that checks, on CPU, that the torch version is >= 1.10, runs torch.cuda.is_bf16_supported() if on the GPU, and takes a param for whether the TPU should use BF16 (useful for testing, or for opting not to run on bf16). Internally sets the XLA_USE_BF16 env variable in AcceleratorState based on whether we're using BF16 or not.

Basic Usage Example(s):
If your GPU is not supported, it will raise an error stating so.
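A minimal usage sketch is below; the toy model, optimizer, and batch are placeholders, and mixed_precision="bf16" is the argument this PR enables.

```python
import torch
from accelerate import Accelerator

# Request bf16 mixed precision; on a GPU without bf16 support this raises
# an error stating that the device is unsupported.
accelerator = Accelerator(mixed_precision="bf16")

# Toy model and optimizer, just for illustration.
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)

batch = torch.randn(4, 8, device=accelerator.device)
loss = model(batch).sum()
accelerator.backward(loss)
optimizer.step()
```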
When would I use it, and when wouldn't I?
When wanting to train on bf16 on CPU, GPU, and TPU.
Does a similar feature exist? If so, why is this better?
For TPU, to use bf16 you have to set XLA_USE_BF16=1 beforehand. We do this automatically for you.
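Roughly, the automatic handling amounts to something like the following sketch. The real logic lives in AcceleratorState, and the helper name here is hypothetical.

```python
import os


def _configure_tpu_bf16(mixed_precision: str) -> None:
    # XLA reads XLA_USE_BF16 when tensors are created, so the variable has to
    # be set before training starts; Accelerate flips it based on mixed_precision.
    os.environ["XLA_USE_BF16"] = "1" if mixed_precision == "bf16" else "0"
```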