remove XLA_USE_BF16 and other variants (#7582)
JackCaoG authored Jun 27, 2024
1 parent ed15ea1 commit 7d41035
Showing 12 changed files with 42 additions and 504 deletions.
22 changes: 0 additions & 22 deletions API_GUIDE.md
@@ -207,28 +207,6 @@ copying data between an XLA device and the CPU. Inserting a barrier when
taking an optimizer step explicitly synchronizes the CPU and the XLA device. For
more information about our lazy tensor design, you can read [this paper](https://arxiv.org/pdf/2102.13267.pdf).

### XLA Tensors and bFloat16

PyTorch/XLA can use the
[bfloat16](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format)
datatype when running on TPUs. In fact, PyTorch/XLA handles float types
(`torch.float` and `torch.double`) differently on TPUs. This behavior is
controlled by the `XLA_USE_BF16` and `XLA_DOWNCAST_BF16` environment variables:

- By default both `torch.float` and `torch.double` are
`torch.float` on TPUs.
- If `XLA_USE_BF16` is set, then `torch.float` and `torch.double` are both
`bfloat16` on TPUs.
- If `XLA_DOWNCAST_BF16` is set, then `torch.float` is `bfloat16` on TPUs and `torch.double` is `float32` on TPUs.
- If a PyTorch tensor has `torch.bfloat16` data type, this will be directly
mapped to the TPU `bfloat16` (XLA `BF16` primitive type).

Developers should note that *XLA tensors on TPUs will always report their PyTorch datatype* regardless of
the actual datatype they're using. This conversion is automatic and opaque.
If an XLA tensor on a TPU is moved back to the CPU it will be converted
from its actual datatype to its PyTorch datatype. Depending on how your code operates, this conversion triggered by
the type of processing unit can be important.
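
With these environment variables removed, the replacement is to cast modules and tensors to `torch.bfloat16` explicitly. Below is a minimal sketch, not part of the original guide, with an arbitrary toy layer:

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
# Cast the module and its inputs to bfloat16 explicitly instead of relying on
# the removed XLA_USE_BF16 / XLA_DOWNCAST_BF16 environment variables.
model = torch.nn.Linear(4, 2).to(device).to(torch.bfloat16)
x = torch.randn(8, 4, device=device, dtype=torch.bfloat16)
y = model(x)
print(y.dtype)  # torch.bfloat16 -- the reported dtype now matches what XLA executes
```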

### Memory Layout

The internal data representation of XLA tensors is opaque to the user. They
5 changes: 2 additions & 3 deletions docs/amp.md
@@ -32,9 +32,8 @@ Please file an issue or submit a pull request if there is an operator that shoul

### Best Practices
1. `autocast` should wrap only the forward pass(es) and loss computation(s) of the network. Backward ops run in the same type that autocast used for the corresponding forward ops.
2. Do not set `XLA_USE_BF16` flag when using AMP on TPUs. This will override the per-operator precision settings provided by AMP and cause all operators to execute in bfloat16.
3. Since TPU's use bfloat16 mixed precision, gradient scaling is not necessary.
4. Pytorch/XLA provides modified version of [optimizers](https://github.com/pytorch/xla/tree/master/torch_xla/amp/syncfree) that avoid the additional sync between device and host.
2. Since TPUs use bfloat16 mixed precision, gradient scaling is not necessary.
3. PyTorch/XLA provides modified versions of [optimizers](https://github.com/pytorch/xla/tree/master/torch_xla/amp/syncfree) that avoid the additional sync between device and host (see the sketch after this list).
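
A minimal sketch of these practices on a single TPU device, assuming the current `torch_xla.amp` APIs; the toy model and random data are placeholders:

```python
import torch
import torch.nn.functional as F
import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast, syncfree

device = xm.xla_device()
model = torch.nn.Linear(16, 4).to(device)
# Sync-free optimizer variant that avoids an extra device<->host sync.
optimizer = syncfree.SGD(model.parameters(), lr=0.1)

for step in range(3):
    data = torch.randn(8, 16, device=device)
    target = torch.randint(0, 4, (8,), device=device)
    optimizer.zero_grad()
    # Wrap only the forward pass and the loss computation in autocast.
    with autocast(device):
        loss = F.cross_entropy(model(data), target)
    # No GradScaler: bfloat16 keeps float32's exponent range, so gradients
    # are unlikely to underflow the way they can with float16.
    loss.backward()
    xm.optimizer_step(optimizer, barrier=True)
```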

### Supported Operators
AMP on TPUs operates like PyTorch's AMP. Rules for how autocasting is applied are summarized below:
13 changes: 0 additions & 13 deletions docs/first_steps.md
@@ -149,19 +149,6 @@ Now, consider using [Stable Diffusion Inference](https://github.com/huggingface/
(vm)$ python3 inference_tpu_single_device.py
```

Since there is no bf16 version of the SD-XL model available, you can use the `XLA_USE_BF16=1` flag to convert all values to bf16 and speed up inference.
```
(vm)$ XLA_USE_BF16=1 python3 inference_tpu_single_device.py # uses sd-xl version
```
or
```
(vm)$ python3 inference_tpu_multidevice.py # uses 2.1 version
```
(already includes `torch.bfloat16` in the 2.1 version of the model).

Warning: watch out for caveats highlighted [here](https://github.com/huggingface/diffusers/pull/4254#issuecomment-1712289803).
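
With the flag gone, an equivalent approach is to load the pipeline weights in `bfloat16` directly. A hedged sketch follows; the diffusers pipeline class and model id are assumptions, not taken from this repository:

```python
import torch
import torch_xla.core.xla_model as xm
from diffusers import StableDiffusionXLPipeline

# Load the pipeline weights in bfloat16 instead of exporting XLA_USE_BF16=1.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16)
pipe = pipe.to(xm.xla_device())

image = pipe("a photograph of an astronaut riding a horse",
             num_inference_steps=25).images[0]
image.save("output.png")
```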


# Running on a Single TPU device

This section describes the changes that need to be made to the [text_to_image inference example](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image#inference) code to run it on TPUs.
272 changes: 0 additions & 272 deletions docs/pytorch_xla_overview.md

This file was deleted.

1 change: 0 additions & 1 deletion docs/source/index.rst
@@ -91,7 +91,6 @@ debug
.. autofunction:: metric_names
.. autofunction:: metric_data

.. mdinclude:: ../pytorch_xla_overview.md
.. mdinclude:: ../../TROUBLESHOOTING.md
.. mdinclude:: ../pjrt.md
.. mdinclude:: ../dynamo.md
13 changes: 1 addition & 12 deletions test/run_tests.sh
@@ -82,16 +82,6 @@ function run_test_without_functionalization {
XLA_DISABLE_FUNCTIONALIZATION=1 run_test "$@"
}

function run_use_bf16 {
echo "Running with XLA_USE_BF16: $@"
XLA_USE_BF16=1 run_test "$@"
}

function run_downcast_bf16 {
echo "Running with XLA_DOWNCAST_BF16: $@"
XLA_DOWNCAST_BF16=1 run_test "$@"
}

function run_xla_ir_debug {
echo "Running with XLA_IR_DEBUG: $@"
XLA_IR_DEBUG=1 run_test "$@"
@@ -191,7 +181,7 @@ function run_xla_op_tests1 {
run_test "$CDIR/dynamo/test_num_output.py"
run_test "$CDIR/dynamo/test_graph_input_matcher.py"
run_save_tensor_ir "$CDIR/dynamo/test_dynamo_graph_dump.py"
run_use_bf16 "$CDIR/test_data_type.py"
run_test "$CDIR/test_data_type.py"
run_xla_ir_debug "$CDIR/test_env_var_mapper.py"
run_xla_hlo_debug "$CDIR/test_env_var_mapper.py"
run_xla_hlo_debug "$CDIR/stablehlo/test_stablehlo_save_load.py"
@@ -200,7 +190,6 @@
}

function run_xla_op_tests2 {
run_downcast_bf16 "$CDIR/test_data_type.py"
run_test "$CDIR/pjrt/test_dtypes.py"
run_test "$CDIR/test_while_loop.py"
run_test "$CDIR/test_autocast.py"
73 changes: 18 additions & 55 deletions test/test_data_type.py
@@ -13,63 +13,26 @@ def check_env_flag(name, default=''):

class XlaDataTypeTest(unittest.TestCase):

def test_datatype_f32(self):
t1 = torch.tensor([2.0, 3.0], dtype=torch.float, device=xm.xla_device())
t2 = torch.tensor([2.0, 3.0], dtype=torch.float, device=xm.xla_device())
t3 = torch.div(t1, t2, rounding_mode='floor')
assert t3.dtype == torch.float

hlo_text = torch_xla._XLAC._get_xla_tensors_text([t3])
device_data_hlo = hlo_text.split('\n')[1]
assert 'xla::device_data' in device_data_hlo, device_data_hlo
if check_env_flag('XLA_USE_BF16') or check_env_flag('XLA_DOWNCAST_BF16'):
assert 'bf16' in device_data_hlo, device_data_hlo
elif check_env_flag('XLA_USE_FP16') or check_env_flag('XLA_DOWNCAST_FP16'):
assert 'f16' in device_data_hlo, device_data_hlo
else:
assert 'f32' in device_data_hlo, device_data_hlo

def test_datatype_f64(self):
t1 = torch.tensor([2.0, 3.0], dtype=torch.double, device=xm.xla_device())
t2 = torch.tensor([2.0, 3.0], dtype=torch.double, device=xm.xla_device())
t3 = torch.div(t1, t2, rounding_mode='floor')
assert t3.dtype == torch.double

hlo_text = torch_xla._XLAC._get_xla_tensors_text([t3])
device_data_hlo = hlo_text.split('\n')[1]
assert 'xla::device_data' in device_data_hlo, device_data_hlo
if check_env_flag('XLA_USE_BF16'):
assert 'bf16' in device_data_hlo, device_data_hlo
elif check_env_flag('XLA_USE_FP16'):
assert 'f16' in device_data_hlo, device_data_hlo
elif check_env_flag('XLA_DOWNCAST_BF16') or check_env_flag(
'XLA_DOWNCAST_FP16'):
assert 'f32' in device_data_hlo, device_data_hlo
else:
assert 'f64' in device_data_hlo, device_data_hlo

def test_datatype_f32_div_f64(self):
t1 = torch.rand(2, 2, dtype=torch.float, device=xm.xla_device())
t2 = t1 / 2.0
hlo_text = torch_xla._XLAC._get_xla_tensors_text([t2])
assert t2.dtype == torch.float
assert 'f64' not in hlo_text

def test_datatype_U16_32_64(self):

def _dtype_round_trip(dtype):
t = torch.randint(0, 128, (2, 4), dtype=dtype).to(xm.xla_device())
return t.cpu().dtype

for dtype in [torch.uint16, torch.uint32, torch.uint64]:
dtype2 = _dtype_round_trip(dtype)
self.assertTrue(dtype == dtype2)
def test_module_to_dtype(self):
device = torch_xla.device()
linear = torch.nn.Linear(
5, 10, dtype=torch.float32).to(device).to(torch.bfloat16)
input = torch.randn(
10,
5,
).to(device).to(torch.bfloat16)
xm.mark_step()
res = linear(input)

hlo_text = torch_xla._XLAC._get_xla_tensors_text([res])
res_hlo = hlo_text.split('\n')[-3]
assert 'bf16' in res_hlo, res_hlo

linear_weight_hlo = torch_xla._XLAC._get_xla_tensors_text([linear.weight
]).split('\n')[-3]
assert 'bf16' in linear_weight_hlo, linear_weight_hlo


if __name__ == '__main__':
print(f'XLA_USE_BF16: {os.getenv("XLA_USE_BF16")}')
print(f'XLA_USE_FP16: {os.getenv("XLA_USE_FP16")}')
print(f'XLA_DOWNCAST_BF16: {os.getenv("XLA_DOWNCAST_BF16")}')
print(f'XLA_DOWNCAST_FP16: {os.getenv("XLA_DOWNCAST_FP16")}')
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
13 changes: 0 additions & 13 deletions test/test_pallas.py
@@ -267,19 +267,6 @@ def test_flash_attention_wrapper_causal(self):
self.assertFalse(torch.allclose(o.cpu(), expected_o.cpu()))
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
@unittest.mock.patch.dict(os.environ, {"XLA_USE_BF16": "1"})
def test_flash_attention_wrapper_bf16(self):
from torch_xla.experimental.custom_kernel import flash_attention

q = torch.randn(3, 2, 128, 4).to("xla")
k = torch.randn(3, 2, 128, 4).to("xla")
v = torch.randn(3, 2, 128, 4).to("xla")

# No exception being raised.
o = flash_attention(q, k, v)

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_multiple_returns(self):
import jax._src.pallas.mosaic.pallas_call_registration
2 changes: 1 addition & 1 deletion test/test_train_mp_imagenet.py
@@ -108,7 +108,7 @@
)

# Best config to achieve peak performance based on TPU version
# 1. It is recommended to use this config in conjuntion with XLA_USE_BF16=1 Flag.
# 1. It is recommended to move the model to bf16 before training.
# 2. Hyperparameters can be tuned to further improve the accuracy.
# usage: python3 /usr/share/pytorch/xla/test/test_train_mp_imagenet.py --model=resnet50 \
# --fake_data --num_epochs=10 --log_steps=300 \
13 changes: 13 additions & 0 deletions torch_xla/__init__.py
@@ -2,6 +2,7 @@
import os
import re
import tempfile
import warnings

import torch
import _XLAC
@@ -140,9 +141,21 @@ def _setup_tpu_vm_library_path() -> bool:
return False


def _check_deprecated_env_var():
deprecated_env_vars = [
'XLA_USE_BF16', 'XLA_USE_FP16', 'XLA_DOWNCAST_BF16', 'XLA_DOWNCAST_FP16',
'XLA_USE_32BIT_LONG'
]
for env_var in deprecated_env_vars:
if os.environ.get(env_var):
warnings.warn(f"The environment variable '{env_var}' is deprecated. "
"Please update your code to avoid using it.")


# These need to be called before the _XLAC module is loaded.
_setup_default_env()
_setup_xla_flags()
_check_deprecated_env_var()
if int(os.environ.get('PT_XLA_DEBUG', '0')):
_fd, _tmp_fname = _setup_debug_env()

113 changes: 7 additions & 106 deletions torch_xla/csrc/dtype.cpp
@@ -5,92 +5,6 @@

namespace torch_xla {

namespace {

bool ShouldUseBF16() {
bool use_bf16 = runtime::sys_util::GetEnvBool("XLA_USE_BF16", false);
if (use_bf16) {
std::cout
<< "XLA_USE_BF16 will be deprecated after the 2.4 release, please "
"convert your model to bf16 directly\n";
TF_LOG(INFO) << "Using BF16 data type for floating point values";
}
return use_bf16;
}

bool ShouldUseF16() {
bool use_fp16 = runtime::sys_util::GetEnvBool("XLA_USE_FP16", false);
if (use_fp16) {
std::cout
<< "XLA_USE_FP16 will be deprecated after the 2.4 release, please "
"convert your model to fp16 directly\n";
TF_LOG(INFO) << "Using F16 data type for floating point values";
}
return use_fp16;
}

bool ShouldDowncastToBF16() {
bool downcast_bf16 =
runtime::sys_util::GetEnvBool("XLA_DOWNCAST_BF16", false);
if (downcast_bf16) {
std::cout
<< "XLA_DOWNCAST_BF16 will be deprecated after the 2.4 release, please "
"downcast your model directly\n";
TF_LOG(INFO) << "Downcasting floating point values, F64->F32, F32->BF16";
}
return downcast_bf16;
}

bool ShouldDowncastToF16() {
bool downcast_fp16 =
runtime::sys_util::GetEnvBool("XLA_DOWNCAST_FP16", false);
if (downcast_fp16) {
std::cout
<< "XLA_DOWNCAST_FP16 will be deprecated after the 2.4 release, please "
"downcast your model directly\n";
TF_LOG(INFO) << "Downcasting floating point values, F64->F32, F32->FP16";
}
return downcast_fp16;
}

bool ShouldUse32BitLong() {
bool use_32bit_long =
runtime::sys_util::GetEnvBool("XLA_USE_32BIT_LONG", false);
if (use_32bit_long) {
std::cout
<< "XLA_USE_32BIT_LONG will be deprecated after the 2.4 release\n";
TF_LOG(INFO) << "Using 32bit integers for kLong values";
}
return use_32bit_long;
}

bool UseBF16() {
static bool use_bf16 = ShouldUseBF16();
return use_bf16;
}

bool UseF16() {
static bool use_fp16 = ShouldUseF16();
return use_fp16;
}

bool DowncastBF16() {
static bool downcast_bf16 = ShouldDowncastToBF16();
return downcast_bf16;
}

bool DowncastF16() {
static bool downcast_fp16 = ShouldDowncastToF16();
return downcast_fp16;
}

bool Use32BitLong() {
static bool use_32bit_long = ShouldUse32BitLong();
return use_32bit_long;
}

} // namespace

at::ScalarType TorchTypeFromXlaType(xla::PrimitiveType xla_type) {
switch (xla_type) {
case xla::PrimitiveType::BF16:
@@ -167,32 +81,22 @@ xla::PrimitiveType MaybeDowncastToXlaDeviceType(
XlaDeviceType hw_type = static_cast<XlaDeviceType>(device.type());
switch (type) {
case xla::PrimitiveType::F64:
if (UseF16()) {
return xla::PrimitiveType::F16;
}
if (UseBF16()) {
return xla::PrimitiveType::BF16;
}
if (DowncastBF16() || DowncastF16() || hw_type == XlaDeviceType::NEURON) {
if (hw_type == XlaDeviceType::NEURON) {
return xla::PrimitiveType::F32;
}
return xla::PrimitiveType::F64;
case xla::PrimitiveType::F32:
if (UseF16() || DowncastF16()) {
return xla::PrimitiveType::F16;
}
return UseBF16() || DowncastBF16() ? xla::PrimitiveType::BF16
: xla::PrimitiveType::F32;
return xla::PrimitiveType::F32;
case xla::PrimitiveType::U16:
return hw_type != XlaDeviceType::NEURON ? xla::PrimitiveType::U16
: xla::PrimitiveType::U32;
case xla::PrimitiveType::S16:
return hw_type != XlaDeviceType::NEURON ? xla::PrimitiveType::S16
: xla::PrimitiveType::S32;
case xla::PrimitiveType::S64:
return Use32BitLong() ? xla::PrimitiveType::S32 : xla::PrimitiveType::S64;
return xla::PrimitiveType::S64;
case xla::PrimitiveType::U64:
return Use32BitLong() ? xla::PrimitiveType::U32 : xla::PrimitiveType::U64;
return xla::PrimitiveType::U64;
case xla::PrimitiveType::C128:
return xla::PrimitiveType::C128;
default:
@@ -210,14 +114,11 @@ at::ScalarType MaybeUpcastToHostTorchType(xla::PrimitiveType xla_type) {
at::ScalarType scalar_type = TorchTypeFromXlaType(xla_type);
switch (scalar_type) {
case at::ScalarType::BFloat16:
return UseBF16() || DowncastBF16() ? at::ScalarType::Float
: at::ScalarType::BFloat16;
return at::ScalarType::BFloat16;
case at::ScalarType::Half:
return UseF16() || DowncastF16() ? at::ScalarType::Float
: at::ScalarType::Half;
return at::ScalarType::Half;
case at::ScalarType::Float:
return DowncastBF16() || DowncastF16() ? at::ScalarType::Double
: at::ScalarType::Float;
return at::ScalarType::Float;
default:
return scalar_type;
}