Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport] Support XLA_USE_BF16 #6841

Merged
merged 1 commit into from
Mar 28, 2024
Merged

[Backport] Support XLA_USE_BF16 #6841

merged 1 commit into from
Mar 28, 2024

Conversation

alanwaketan
Copy link
Collaborator

Summary:
XLA_USE_BF16=1 will make all the internal xla tensors to use BF16 but torch.tensor wrappers will still return torch.float. To address this, we need to set the jax tracers correctly to produce the correct Mosaic.

Test Plan:
PJRT_DEVICE=TPU python test/test_pallas.py -v -k test_flash_attention_wrapper_bf16

Summary:
XLA_USE_BF16=1 will make all the internal xla tensors to use BF16 but
torch.tensor wrappers will still return torch.float. To address this,
we need to set the jax tracers correctly to produce the correct Mosaic.

Test Plan:
PJRT_DEVICE=TPU python test/test_pallas.py -v -k test_flash_attention_wrapper_bf16

address comments
@lsy323 lsy323 merged commit c6a8874 into r2.3 Mar 28, 2024
17 checks passed
@alanwaketan
Copy link
Collaborator Author

Thanks Jack and Siyuan.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants