
Include bf16 support for TPUs and CPUs, and a better check for if a CUDA device supports BF16 #462

Merged
muellerzr merged 9 commits into main from bf16 on Jun 22, 2022

Conversation

muellerzr
Collaborator

Add BF16 support for TPUs and CPUs

What does this add?

This PR improves BF16 support on TPUs, introduces BF16 support for CPUs, and adds a helper to check whether your GPU also supports bf16 when it is requested.

Who is it for?

Users of Accelerate who want to train on BF16

Why is it needed?

After fixing the conditional for testing fp16 in the test script, I realized that CPUs can support bfloat16 (though whether it is advisable is debatable), and that to enable bf16 on TPUs in Accelerate you need to set an environment variable beforehand. This PR addresses both.

What parts of the API does this impact?

User-facing:

The user can now pass mixed_precision="bf16" and train in bf16 on modern CPUs and GPUs, as well as on TPUs.

Internal structure:

Adds an is_bf16_available function that checks for a torch version >= 1.10 when on CPU, runs torch.cuda.is_bf16_supported() when on GPU, and, for TPU, returns whether bf16 should be used as a parameter (useful for testing or for opting out of bf16). A sketch of such a helper follows below.

Internally sets the XLA_USE_BF16 environment variable in AcceleratorState based on whether BF16 is being used.
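
For illustration, here is a minimal sketch of what such a helper could look like. This is an approximation of the behavior described above, not the code added in the PR; the torch_xla import probe, the 1.10 threshold, and the use of torch.cuda.is_bf16_supported() are assumptions based on this thread.

import importlib.util
import torch
from packaging import version

def is_bf16_available(ignore_tpu: bool = False) -> bool:
    # TPU: bf16 is usable whenever torch_xla is present; whether to actually
    # use it is left to the caller (handy for tests or for opting out).
    if importlib.util.find_spec("torch_xla") is not None:
        return not ignore_tpu
    # bf16 autocast support landed around torch 1.10, so older versions are out.
    if version.parse(torch.__version__) < version.parse("1.10"):
        return False
    # GPU: defer to PyTorch's own hardware capability check.
    if torch.cuda.is_available():
        return torch.cuda.is_bf16_supported()
    # CPU: a recent enough torch is all that is needed.
    return True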

Basic Usage Example(s):

# When training on the CPU, TPU, or GPU
from accelerate import Accelerator
accelerator = Accelerator(mixed_precision="bf16")

If your GPU is not supported, it will raise an error stating so.
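
For a fuller picture, a typical training loop with this flag might look like the sketch below. It assumes model, optimizer, and dataloader are already defined, and follows the standard Accelerate pattern rather than quoting code from this PR.

from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="bf16")
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

model.train()
for batch in dataloader:
    outputs = model(**batch)        # forward pass runs under bf16 autocast
    loss = outputs.loss
    accelerator.backward(loss)      # use instead of loss.backward()
    optimizer.step()
    optimizer.zero_grad()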

When would I use it, and when wouldn't I?

When wanting to train on bf16 on CPU, GPU, and TPU.

Does a similar feature exist? If so, why is this better?

For TPU, to use bf16 you would normally set XLA_USE_BF16=1 yourself. We now do this automatically for you.
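
Concretely, the manual route amounts to setting that flag yourself, which Accelerate now handles via AcceleratorState. A rough sketch (the note about timing is an assumption, since XLA generally reads this before tensors are created):

import os

# What you previously had to do by hand on TPU: ask torch_xla to run
# float32 work in bfloat16. Set this before XLA is initialized.
os.environ["XLA_USE_BF16"] = "1"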

@muellerzr muellerzr added labels: enhancement (New feature or request), CPU (Bug or feature on CPU or MultiCPU platforms), GPU (Bug or feature on GPU or MultiGPU platforms), TPU (Bug or feature on TPU platforms) on Jun 22, 2022
@muellerzr muellerzr requested a review from sgugger June 22, 2022 15:49
HuggingFaceDocBuilderDev commented Jun 22, 2022

The documentation is not available anymore as the PR was closed or merged.

@sgugger sgugger (Collaborator) left a comment

Thanks for all the fixes in this PR, plus the new BF16 support!

elif self.mixed_precision == "bf16" and is_bf16_available():
    if self.distributed_type in [DistributedType.NO, DistributedType.MULTI_CPU, DistributedType.MULTI_GPU]:
        device_type = "cpu" if not torch.cuda.is_available() else "cuda"
        autocast_context = torch.autocast(dtype=torch.bfloat16, device_type=device_type)
Collaborator

Need to be extra sure that this always exists for the PyTorch versions for which is_bf16_available() returns True.

Collaborator Author

Fixed by adding a torch check for >= 1.10
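
In spirit, that guard looks something like the sketch below. It is an illustration of the check being discussed, not the exact diff; the use of packaging.version for the comparison is an assumption.

import torch
from packaging import version

# torch.autocast with bfloat16 on CPU/CUDA is only available on torch >= 1.10,
# so gate the bf16 path on the version as well as on hardware support.
if version.parse(torch.__version__) >= version.parse("1.10"):
    device_type = "cpu" if not torch.cuda.is_available() else "cuda"
    autocast_context = torch.autocast(dtype=torch.bfloat16, device_type=device_type)
else:
    autocast_context = None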

src/accelerate/launchers.py (outdated review comment, resolved)
src/accelerate/state.py (outdated review comment, resolved)
Comment on lines +279 to +280
# TEST that previous fp16 flag still works
print("Legacy FP16 training check.")
Collaborator

Not sure we need to keep this. We have done a couple of releases since we deprecated it, so it's okay if we stop testing it IMO.

Collaborator Author

I'd feel more comfortable dropping the test once we've entirely removed the legacy param (whenever that may be).
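
For context, the legacy parameter under discussion is presumably the old constructor flag, along the lines of the sketch below (a hedged reconstruction, not quoted from the codebase):

from accelerate import Accelerator

# Deprecated spelling that the test above keeps exercising:
accelerator = Accelerator(fp16=True)

# Current equivalent:
accelerator = Accelerator(mixed_precision="fp16")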

@muellerzr muellerzr merged commit f13c59f into main Jun 22, 2022
@muellerzr muellerzr deleted the bf16 branch June 22, 2022 21:53