From 7d41035b89fe6b6f6cfb13679ffc256429efa7b2 Mon Sep 17 00:00:00 2001
From: JackCaoG <59073027+JackCaoG@users.noreply.github.com>
Date: Thu, 27 Jun 2024 11:57:41 -0700
Subject: [PATCH] remove XLA_USE_BF16 and other variants (#7582)
---
API_GUIDE.md | 22 --
docs/amp.md | 5 +-
docs/first_steps.md | 13 --
docs/pytorch_xla_overview.md | 272 ------------------------
docs/source/index.rst | 1 -
test/run_tests.sh | 13 +-
test/test_data_type.py | 73 ++-----
test/test_pallas.py | 13 --
test/test_train_mp_imagenet.py | 2 +-
torch_xla/__init__.py | 13 ++
torch_xla/csrc/dtype.cpp | 113 +---------
torch_xla/experimental/custom_kernel.py | 6 -
12 files changed, 42 insertions(+), 504 deletions(-)
delete mode 100644 docs/pytorch_xla_overview.md
diff --git a/API_GUIDE.md b/API_GUIDE.md
index 1fc9c506b8c..778d29cdf1d 100644
--- a/API_GUIDE.md
+++ b/API_GUIDE.md
@@ -207,28 +207,6 @@ copying data between an XLA device and the CPU. Inserting a barrier when
taking an optimizer step explicitly synchronizes the CPU and the XLA device. For
more information about our lazy tensor design, you can read [this paper](https://arxiv.org/pdf/2102.13267.pdf).
-### XLA Tensors and bFloat16
-
-PyTorch/XLA can use the
-[bfloat16](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format)
-datatype when running on TPUs. In fact, PyTorch/XLA handles float types
-(`torch.float` and `torch.double`) differently on TPUs. This behavior is
-controlled by the `XLA_USE_BF16` and `XLA_DOWNCAST_BF16` environment variable:
-
-- By default both `torch.float` and `torch.double` are
-`torch.float` on TPUs.
-- If `XLA_USE_BF16` is set, then `torch.float` and `torch.double` are both
-`bfloat16` on TPUs.
-- If `XLA_DOWNCAST_BF16` is set, then `torch.float` is `bfloat16` on TPUs and `torch.double` is `float32` on TPUs.
-- If a PyTorch tensor has `torch.bfloat16` data type, this will be directly
-mapped to the TPU `bfloat16` (XLA `BF16` primitive type).
-
-Developers should note that *XLA tensors on TPUs will always report their PyTorch datatype* regardless of
-the actual datatype they're using. This conversion is automatic and opaque.
-If an XLA tensor on a TPU is moved back to the CPU it will be converted
-from its actual datatype to its PyTorch datatype. Depending on how your code operates, this conversion triggered by
-the type of processing unit can be important.
-
### Memory Layout
The internal data representation of XLA tensors is opaque to the user. They
diff --git a/docs/amp.md b/docs/amp.md
index 905fde036a1..0f138b32eec 100644
--- a/docs/amp.md
+++ b/docs/amp.md
@@ -32,9 +32,8 @@ Please file an issue or submit a pull request if there is an operator that shoul
### Best Practices
1. `autocast` should wrap only the forward pass(es) and loss computation(s) of the network. Backward ops run in the same type that autocast used for the corresponding forward ops.
-2. Do not set `XLA_USE_BF16` flag when using AMP on TPUs. This will override the per-operator precision settings provided by AMP and cause all operators to execute in bfloat16.
-3. Since TPU's use bfloat16 mixed precision, gradient scaling is not necessary.
-4. Pytorch/XLA provides modified version of [optimizers](https://github.com/pytorch/xla/tree/master/torch_xla/amp/syncfree) that avoid the additional sync between device and host.
+2. Since TPU's use bfloat16 mixed precision, gradient scaling is not necessary.
+3. Pytorch/XLA provides modified version of [optimizers](https://github.com/pytorch/xla/tree/master/torch_xla/amp/syncfree) that avoid the additional sync between device and host.
### Supported Operators
AMP on TPUs operates like Pytorch's AMP. Rules for how autocasting is applied is summarized below:
diff --git a/docs/first_steps.md b/docs/first_steps.md
index 2658d2d8bb8..a54b8e36e64 100644
--- a/docs/first_steps.md
+++ b/docs/first_steps.md
@@ -149,19 +149,6 @@ Now, consider using [Stable Diffusion Inference](https://github.com/huggingface/
(vm)$ python3 inference_tpu_single_device.py
```
-Since there is no bf16 version of the SD-XL model available, you can use the `XLA_USE_BF16=1` flag to convert all values to bf16 and speed up training.
-```
-(vm)$ XLA_USE_BF16=1 python3 inference_tpu_single_device.py # uses sd-xl version
-```
-or
-```
-(vm)$ python3 inference_tpu_multidevice.py # uses 2.1 version
-```
-(already includes `torch.bfloat16` in the 2.1 version of the model).
-
-Warning: watch out for caveats highlighted [here](https://github.com/huggingface/diffusers/pull/4254#issuecomment-1712289803).
-
-
# Running on a Single TPU device
This section describes the changes that need to be made to the [text_to_image inference example](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image#inference) code to run it on TPUs.
diff --git a/docs/pytorch_xla_overview.md b/docs/pytorch_xla_overview.md
deleted file mode 100644
index da087098cae..00000000000
--- a/docs/pytorch_xla_overview.md
+++ /dev/null
@@ -1,272 +0,0 @@
-# Beginner's Guide to PyTorch/XLA
-
-### **Objective:**
-This document provides a high-level overview of PyTorch XLA and illustrates a
-few examples how PyTorch code is converted to run on XLA devices (e.g. TPUs).
-This is not a complete solution, and additional changes may be required
-depending on the specific code. However, this document should serve as a
-starting point for the conversion process.
-
-
-## Basic high-level understanding of some XLA details
-This section provides a brief overview of the basic details of PyTorch XLA,
- which should help readers better understand the required modifications and
- optimizations of code. It is supplement to the API guide described [here](https://github.com/pytorch/xla/blob/master/API_GUIDE.md).
-
-Unlike regular PyTorch, which executes code line by line and does not block execution until the value of a PyTorch tensor is fetched, PyTorch XLA works differently. It iterates through the python code and records the operations on (PyTorch) XLA tensors in an intermediate representation (IR) graph until it encounters a barrier (discussed below). This process of generating the IR graph is referred to as tracing (LazyTensor tracing or code tracing). PyTorch XLA then converts the IR graph to a lower-level machine-readable format called HLO (High-Level Opcodes). HLO is a representation of a computation that is specific to the XLA compiler and allows it to generate efficient code for the hardware that it is running on. HLO is fed to the XLA compiler for compilation and optimization. Compilation is then cached by PyTorch XLA to be reused later if/when needed. The compilation of the graph is done on the host (CPU), which is the machine that runs the Python code. If there are multiple XLA devices, the host compiles the code for each of the devices separately except when using SPMD (single-program, multiple-data). For example, v4-8 has one host machine and [four devices](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v4). In this case the host compiles the code for each of the four devices separately. In case of pod slices, when there are multiple hosts, each host does the compilation for XLA devices it is attached to. If SPMD is used, then the code is compiled only once (for given shapes and computations) on each host for all the devices.
-
-![img](assets/pytorchXLA_flow.svg)
-
-For more details and examples, please refer to the [LazyTensor guide](https://pytorch.org/blog/understanding-lazytensor-system-performance-with-pytorch-xla-on-cloud-tpu/).
-
-The operations in the IR graph are executed only when values of tensors are needed. This is referred to as evaluation or materialization of tensors. Sometimes this is also called lazy evaluation and it can lead to significant [performance improvements](https://arxiv.org/pdf/2102.13267.pdf).
-
-The _synchronous_ operations in Pytorch XLA, like printing, logging, checkpointing or callbacks block tracing and result in slower execution. In the case when an operation requires a specific value of an XLA tensor, e.g. `print(xla_tensor_z)`, tracing is blocked until the value of that tensor is available to the host. Note that only the part of the graph responsible for computing that tensor value is executed. These operations do not cut the IR graph, but they trigger host-device communication through `TransferFromDevice`, which results in slower performance.
-
-A _barrier_ is a special instruction that tells XLA to execute the IR graph and materialize the tensors. This means that the PyTorch XLA tensors will be evaluated, and the results will be available to the host. The user-exposed barrier in Pytorch XLA is [xm.mark_step()](https://github.com/pytorch/xla/blob/bdceee54eca1269ee954f6cdd1868c584d0e88a4/torch_xla/core/xla_model.py#L808), which breaks the IR graph and results in code execution on the XLA devices. One of the key properties of `xm.mark_step` is that unlike synchronous operations it does not block the further tracing while the device is executing the graph. However, it does block access to the values of the tensors that are being materialized.
-
-The example in the LazyTensor guide illustrates what happens in a simple case of adding two tensors. Now, suppose we have a for loop that adds XLA tensors and uses the value later:
-
-```
-for x, y in tensors_on_device:
- z += x + y
-```
-
-Without a barrier, the Python tracing will result in a single graph that wraps the addition of tensors `len(tensors_on_device)` times. This is because the `for` loop is not captured by the tracing, so each iteration of the loop will create a new subgraph corresponding to the computation of `z += x+y` and add it to the graph. Here is an example when `len(tensors_on_device)=3`.
-
-![img](assets/IRgraph_no_markstep.png)
-
-However, introducing a barrier at the end of the loop will result in a smaller graph that will be compiled once during the first pass inside the `for` loop and will be reused for the next `len(tensors_on_device)-1 ` iterations. The barrier will signal to the tracing that the graph traced so far can be submitted for execution, and if that graph has been seen before, a cached compiled program will be reused.
-
-```
-for x, y in tensors_on_device:
- z += x + y
- xm.mark_step()
-```
-
-In this case there will be a small graph that is used `len(tensors_on_device)=3` times.
-
-![img](assets/IRgraph_markstep.png)
-
-It is important to highlight that in PyTorch XLA Python code inside for loops is traced and a new graph is constructed for each iteration if there is a barrier at the end. This can be a significant performance bottleneck.
-
-The XLA graphs can be reused when the same computation happens on the same shapes of tensors. If the shapes of the inputs or intermediate tensors change, then the XLA compiler will recompile a new graph with the new tensor shapes. This means that if you have dynamic shapes or if your code does not reuse tensor graphs, running your model on XLA will not be suitable for that use case. Padding the input into a fixed shape can be an option to help avoid dynamic shapes. Otherwise, a significant amount of time will be spent by the compiler on optimizing and fusing operations which will not be used again.
-
-The trade-off between graph size and compilation time is also important to consider. If there is one large IR graph, the XLA compiler can spend a lot of time on optimization and fusion of the ops. This can result in a very long compilation time. However, the later execution may be much faster, due to the optimizations that were performed during compilation.
-
-Sometimes it is worth breaking the IR graph with `xm.mark_step()`. As explained above, this will result in a smaller graph that can be reused later. However making graphs smaller can reduce optimizations that otherwise could be done by the XLA compiler.
-
-Another important point to consider is [MPDeviceLoader](https://github.com/pytorch/xla/blob/a1f822e2627a5639464273241821852677401026/torch_xla/distributed/parallel_loader.py#L186). Once your code is running on an XLA device, consider wrapping the torch dataloader with XLA `MPDeviceLoader` which preloads data to the device to improve performance and includes `xm.mark_step()` in it. The latter automatically breaks the iterations over batches of data and sends them for execution. Note, if you are not using MPDeviceLoader, you might need to set `barrier=True` in the `optimizer_step()` to enable `xm.mark_step()` if running a training job or explicitly adding `xm.mark_step()`.
-
-## TPU Setup
-Create TPU with base image to use nightly wheels or from the stable release by specifying the `RUNTIME_VERSION`.
-```
-export ZONE=us-central2-b
-export PROJECT_ID=your-project-id
-export ACCELERATOR_TYPE=v4-8 # v4-16, v4-32, …
-export RUNTIME_VERSION=tpu-vm-v4-pt-2.0 # or tpu-vm-v4-base
-export TPU_NAME=your_tpu_name
-
-gcloud compute tpus tpu-vm create ${TPU_NAME} \
---zone=${ZONE} \
---accelerator-type=${ACCELERATOR_TYPE} \
---version=${RUNTIME_VERSION} \
---subnetwork=tpusubnet
-```
-
-If you have a single host VM (e.g. v4-8), you can ssh to your vm and run the following commands from the vm directly. Otherwise, in case of TPU pods, you can use `--worker=all --command=""` similar to
-
-```
-gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
---zone=us-central2-b \
---worker=all \
---command="pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly-cp38-cp38-linux_x86_64.whl"
-```
-
-Next, if you are using base image, install nightly packages and required libraries
-
-```
-pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly-cp38-cp38-linux_x86_64.whl
-pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl
-sudo apt-get install libopenblas-dev -y
-
-sudo apt-get update && sudo apt-get install libgl1 -y # diffusion specific
-```
-
-## Converting code to PyTorch XLA
-General guidelines to modify your code:
-* Replace `cuda` with `xm.xla_device()`
-* Remove progress bar, printing that would access the XLA tensor values
-* Reduce logging and callbacks that would access the XLA tensor values
-* Wrap data loader with MPDeviceLoader
-* Profile to further optimize the code
-
-Remember: each case is unique so you might need to do something different for each case.
-
-## Example 1. Stable Diffusion inference in PyTorch Lightning on a Single TPU Device
-
-As a first example consider the [inference code](https://github.com/pytorch-tpu/stable-diffusion/blob/main/scripts/txt2img.py) of the stable diffusion model in PyTorch Lightning which can be run from command line as
-```
-python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse"
-```
-
-For your reference, the diff of modifications described below can be found [here](https://github.com/pytorch-tpu/stable-diffusion/commit/57f398eb784387e244dc5fb78421aa5261abd1ef). Let's go over them step by step.
-As in the general guideline above, start with changes related to `cuda` device. This inference code is written to run on GPUs and `cuda` can be found in multiple places. Start making changes by removing `model.cuda()` from [this line](https://github.com/pytorch-tpu/stable-diffusion/blob/978da4c625a712a01ee066d019a0b0d2319cd8b3/scripts/txt2img.py#L64), and `precision_scope` from [here](https://github.com/pytorch-tpu/stable-diffusion/blob/978da4c625a712a01ee066d019a0b0d2319cd8b3/scripts/txt2img.py#L290). Additionally, replace the `cuda` device in [this line](https://github.com/pytorch-tpu/stable-diffusion/blob/978da4c625a712a01ee066d019a0b0d2319cd8b3/scripts/txt2img.py#L248) with the `xla` device similar to the code below:
-
-Next, this particular configuration of the model is using `FrozenCLIPEmbedder`, therefore we will modify this [line](https://github.com/pytorch-tpu/stable-diffusion/blob/978da4c625a712a01ee066d019a0b0d2319cd8b3/ldm/modules/encoders/modules.py#L143) as well. For simplicity we will directly define the `device` in this tutorial, but you can pass the `device` value to the function as well.
-```
-import torch_xla.core.xla_model as xm
-self.device = xm.xla_device()
-```
-
-Another place in the code that has cuda specific code is DDIM scheduler. Add `import torch_xla.core.xla_model as xm` on top of the file then replace [these](https://github.com/pytorch-tpu/stable-diffusion/blob/978da4c625a712a01ee066d019a0b0d2319cd8b3/ldm/models/diffusion/ddim.py#L21-L22) lines
-
-
-```
-if attr.device != torch.device("cuda"):
- attr = attr.to(torch.device("cuda"))
-```
-
-with
-
-```
-device = xm.xla_device()
-attr = attr.to(torch.device(device))
-```
-
-Next, you can reduce device (TPU) and host (CPU) communication by removing print statements, disabling progress bars, and reducing or removing callbacks and logging. These operations require the device to stop executing, falling back to the CPU, executing the logging/callbacks, and then returning to the device. This can be a significant performance bottleneck, especially on large models.
-
-After making these changes, the code will run on TPUs. However, the performance will be very slow. This is because the XLA compiler tries to build a single (huge) graph that wraps the number of inference steps (in this case, 50) as there is no barrier inside the for loop. It is difficult for the compiler to optimize the graph, and this leads to significant performance degradation. As discussed above, breaking the for loop with the barrier (xm.mark_step()) will result in a smaller graph that is easier for the compiler to optimize. This will also allow the compiler to reuse the graph from the previous step, which can improve performance.
-
-Now the [code](https://github.com/pytorch-tpu/stable-diffusion/blob/ss-inference/scripts/txt2img.py) is ready to run on TPUs in a reasonable time. More optimization and analysis can be done by [capturing a profile](https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm) and investigating further. However, this is not covered here.
-
-Note: if you are running on v4-8 TPU, then you have 4 available XLA (TPU) devices. Running the code as above will only use one XLA device. In order to run on all 4 devices you need to use `xmp.spawn()` function to spawn the code on all the devices. We will discuss an `xmp.spawn` in the next example.
-
-## Example 2. HF Stable Diffusion Inference
-Now, consider using [Stable Diffusion Inference](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image) in the HuggingFace diffusers library for both the SD-XL and 2.1 versions of the model. For your reference, the changes described below can be found in this [repo](https://github.com/pytorch-tpu/diffusers). You can clone the repo and run the inference using the following command on your TPU VM:
-
-```
-(vm)$ git clone https://github.com/pytorch-tpu/diffusers.git
-(vm)$ cd diffusers/examples/text_to_image/
-(vm)$ python3 inference_tpu_single_device.py
-```
-
-Since there is no bf16 version of the SD-XL model available, you can use the `XLA_USE_BF16=1` flag to convert all values to bf16 and speed up training.
-```
-(vm)$ XLA_USE_BF16=1 python3 inference_tpu_single_device.py # uses sd-xl version
-```
-or
-```
-(vm)$ python3 inference_tpu_multidevice.py # uses 2.1 version
-```
-(already includes `torch.bfloat16` in the 2.1 version of the model).
-
-Warning: watch out for caveats highlighted [here](https://github.com/huggingface/diffusers/pull/4254#issuecomment-1712289803).
-
-
-## Running on a Single TPU device
-
-This section describes the changes that need to be made to the [text_to_image inference example](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image#inference) code to run it on TPUs.
-
-The original code uses Lora for inference, but this tutorial will not use it. Instead, we will set the `model_id` argument to `stabilityai/stable-diffusion-xl-base-0.9` when initializing the pipeline. We will also use the default scheduler (DPMSolverMultistepScheduler). However, similar changes can be made to the other schedulers as well.
-```
-git clone https://github.com/huggingface/diffusers
-cd diffusers
-pip install . # pip install -e .
-
-cd examples/text_to_image/
-pip install -r requirements.txt
-pip install invisible_watermark transformers accelerate safetensors
-```
-(If `accelerate` is not found, log out, log back in.)
-
-Log in to HF and agree to the [sd-xl 0.9 license](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9) on the model card. Next, go to [account→settings→access](https://huggingface.co/settings/tokens) token and generate a new token. Copy the token and run the following command with that specific token value on your vm
-```
-(vm)$ huggingface-cli login --token _your_copied_token__
-```
-
-The HuggingFace readme provides PyTorch code that is written to run on GPUs. To run it on TPUs, the first step is to change the CUDA device to an XLA device. This can be done by replacing the line `pipe.to("cuda")` with the following lines:
-
-```
-import torch_xla.core.xla_model as xm
-device = xm.xla_device()
-pipe.to(device)
-```
-
-Additionally, it is important to note that the first time you run inference with XLA, it will take a long time to compile. For example, compilation time for stable diffusion XL model inference from HuggingFace can take about an hour to compile, whereas the actual inference may take only 5 seconds, depending on the batch size. Likewise, a GPT-2 model can take about 10-15 mins to compile, after which the training epoch time becomes much faster. This is because XLA builds a graph of the computation that will be performed, and then optimizes this graph for the specific hardware that it is running on. However, once the graph has been compiled, it can be reused for subsequent inferences, which will be much faster. Therefore, if you are only running inference once, you may not benefit from using XLA. However, if you are running inference multiple times, or if you are running inference on a list of prompts, you will start to see the advantages of XLA after the first few inferences. For example, if you run inference on a list of 10 prompts, the first inference (maybe two[^1]) may take a long time to compile, but the remaining inference steps will be much faster. This is because XLA will reuse the graph that it compiled for the first inference.
-
-If you try to run the code without making any additional changes, you will notice that the compilation time is very long (>6 hours). This is because the XLA compiler tries to build a single graph for all of the scheduler steps at once similar to what we have discussed in the previous example. To make the code run faster, we need to break the graph up into smaller pieces with `xm.mark_step()` and reuse them in the next steps. This happens inside the `pipe.__call__` [function](https://github.com/huggingface/diffusers/blob/2b1786735e27bc97f4d4699712292d5c463a7380/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L559) in [these lines](https://github.com/huggingface/diffusers/blob/2b1786735e27bc97f4d4699712292d5c463a7380/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L805-L839). Disabling the progress bar, removing callbacks and adding `xm.mark_step()` at the end of the for loop speeds up the code significantly. Changes are provided in this [commit](https://github.com/huggingface/diffusers/compare/main...pytorch-tpu:diffusers:main).
-
-
-Additionally, the `self.scheduler.step()` function, which by default uses the DPMSolverMultistepScheduler scheduler, has a few issues that are described in the
-[PyTorch XLA caveats](https://pytorch.org/xla/release/2.0/index.html#known-performance-caveats). The `.nonzero()` and `.item()` calls in this function send requests to the CPU for tensor evaluation, which trigger device-host communication. This is not desirable, as it can slow down the code. In this particular case, we can avoid these calls by passing the index to the function directly. This will prevent the function from sending requests to the CPU, and will improve the performance of the code. Changes are available in [this](https://github.com/pytorch-tpu/diffusers/commit/0243d2ef9c2c7bc06956bb1bcc92c23038f6519d) commit. The code now is ready to be run on TPUs.
-
-[^1]: 0 and 1 are magic numbers in XLA and treated as constants in the HLO. So if there is a random number generator in the code that can generate these values, the code will compile for each value separately. This can be disabled with `XLA_NO_SPECIAL_SCALARS=1` environment variable.
-
-
-## Profiling and performance analysis
-
-To further investigate the performance of the model, we can profile it using the profiling [guide](https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm). As a rule of thumb, the profiling script should be run with the maximum batch size that fits into the memory for [optimal memory usage](https://cloud.google.com/tpu/docs/performance-guide). It also helps to overlap tracing of the code with device execution which leads to more optimal device usage. The duration of profiling should be long enough to capture at least one step. Good performance of the model on TPUs means that device-host communication is minimized and the device is constantly running processes with no idle time.
-
-Starting a server in the `inference_tpu_*.py` file and running `capture_profile.py` script as described in the guide will give us information on processes that run on the devices. Currently, only one XLA device is profiled. To better understand the TPU idle time (gaps in the profile), profiling traces (`xp.Trace()`) should be added to the code. The `xp.Trace()` measures the time it takes to trace the python code on the host machine wrapped with the trace. For this example, `xp.Trace()` traces were added inside the [pipeline](https://github.com/ssusie/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py) and the [U-net model](https://github.com/ssusie/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py) to measure the time to run specific sections of the code on the host (CPU).
-
-If the gaps in the profile are due to Python code tracing that happens on the host, then this might be a bottleneck and there is no further straightforward optimization that can be done. Otherwise, the code should be analyzed further to understand the caveats and improve the performance further. Note that you cannot `xp.Trace()` wrap portions of the code where `xm.mark_step()` is called.
-
-To illustrate this we can look at already captured profiles that were uploaded to tensorboard following the profiling guide.
-
-Starting from Stable Diffusion model version 2.1
-
-If we capture a profile without inserting any traces, we will see the following:
-
-![Alt text](assets/image.png)
-
-The single TPU device on v4-8, which has two cores, appears to be busy. There are no significant gaps in their usage, except for a small one in the middle. If we scroll up to try to find which process is occupying the host machine, we will not find any information. Therefore, we will add `xp.traces` to the pipeline [file](https://github.com/pytorch-tpu/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py) as well as the U-net [function](https://github.com/pytorch-tpu/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py). The latter may not be useful for this particular use case, but it does demonstrate how traces can be added in different places and how their information is displayed in TensorBoard.
-
-If we add traces and re-capture the profile with the largest batch size that can fit on the device (32 in this case), we will see that the gap in the device is caused by a Python process that is running on the host machine.
-![Alt text](assets/image-1.png)
-![Alt text](assets/image-2.png)
-
-We can use the appropriate tool to zoom in on the timeline and see which process is running during that period. This is when the Python code tracing happens on the host, and we cannot improve the tracing further at this point.
-
-
-Now, let's examine the XL version of the model and do the same thing. We will add traces to the pipeline [file](https://github.com/pytorch-tpu/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py) in the same way that we did for the 2.1 version and capture a profile.
-
-![Alt text](assets/image-4.png)
-
-This time, in addition to the large gap in the middle, which is caused by the `pipe_watermark` tracing, there are many small gaps between the inference steps within [this loop](https://github.com/pytorch-tpu/diffusers/blob/0243d2ef9c2c7bc06956bb1bcc92c23038f6519d/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L814-L830).
-
-First look closer into the large gap that is caused by `pipe_watermark`. The gap is preceded with `TransferFromDevice` which indicates that something is happening on the host machine that is waiting for computation to finish before proceeding. Looking into watermark [code](https://github.com/pytorch-tpu/diffusers/blob/0243d2ef9c2c7bc06956bb1bcc92c23038f6519d/src/diffusers/pipelines/stable_diffusion_xl/watermark.py#L29), we can see that tensors are transferred to cpu and converted to numpy arrays in order to be processed with `cv2` and `pywt` libraries later. Since this part is not straightforward to optimize, we will leave this as is.
-
-Now if we zoom in on the loop, we can see that the graph within the loop is broken into smaller parts because the `TransferFromDevice` operation happens.
-![Alt text](assets/image-3.png)
-
-
-If we investigate the U-Net function and the scheduler, we can see that the U-Net code does not contain any optimization targets for PyTorch/XLA. However, there are `.item()` and `.nonzero()` calls inside the [scheduler.step](https://github.com/huggingface/diffusers/blob/15782fd506e8c4a7c2b288fc2e558bd77fdfa51a/src/diffusers/schedulers/scheduling_euler_discrete.py#L371). We can [rewrite](https://github.com/pytorch-tpu/diffusers/blob/0243d2ef9c2c7bc06956bb1bcc92c23038f6519d/src/diffusers/schedulers/scheduling_euler_discrete.py#L310) the function to avoid those calls. If we fix this issue and rerun a profile, we will not see much difference. However, since we have reduced the device-host communication that was introducing smaller graphs, we allowed the compiler to optimize the code better. The function [scale_model_input](https://github.com/huggingface/diffusers/blob/15782fd506e8c4a7c2b288fc2e558bd77fdfa51a/src/diffusers/schedulers/scheduling_euler_discrete.py#L205) has similar issues, and we can fix these by making the changes we made above to the `step` function. Overall, since many of the gaps are caused from python level code tracing and graph building, these gaps are not possible to optimize with the current version of PyTorch XLA, but we may see improvements in the future when dynamo is enabled in PyTorch XLA.
-
-
-## Running on Multiple TPU Devices
-
-To use multiple TPU devices, you can use the `xmp.spawn` function to spawn the function you ran on a single device to multiple devices. The `xmp.spawn` function will start processes on multiple TPU devices and sync them when needed. This can be done by passing the `index` argument to the function that runs on a single device. For example,
-```
-import torch_xla.distributed.xla_multiprocessing as xmp
-
-def my_function(index):
- # function that runs on a single device
-
-xmp.spawn(my_function, args=(0,), nprocs=4)
-```
-
-In this example, the `my_function` function will be spawned on 4 TPU devices on v4-8, with each device being assigned an index from 0 to 3.
-
-[This file](https://github.com/ssusie/diffusers/blob/main/examples/text_to_image/inference_tpu_multidevice.py) illustrates how xmp.spawn can be used to run stable diffusion 2.1 version on multiple TPU devices. For this version similar to the above changes were made to the [pipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py) file.
-
-
-## Running on Pods
-Once you have the code for running on a single host device, there is no further change needed. You can create the TPU pod, for example, by following these [instructions](https://cloud.google.com/tpu/docs/pytorch-pods#create-tpu-vm). Then run your script with
-```
-gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
- --zone=${ZONE} \
- --worker=all \
- --command="python3 your_script.py"
-```
-
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 7d1f51ea0dd..0d7182a45f1 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -91,7 +91,6 @@ debug
.. autofunction:: metric_names
.. autofunction:: metric_data
-.. mdinclude:: ../pytorch_xla_overview.md
.. mdinclude:: ../../TROUBLESHOOTING.md
.. mdinclude:: ../pjrt.md
.. mdinclude:: ../dynamo.md
diff --git a/test/run_tests.sh b/test/run_tests.sh
index 7dbeba95bd1..2095dab3cb5 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -82,16 +82,6 @@ function run_test_without_functionalization {
XLA_DISABLE_FUNCTIONALIZATION=1 run_test "$@"
}
-function run_use_bf16 {
- echo "Running with XLA_USE_BF16: $@"
- XLA_USE_BF16=1 run_test "$@"
-}
-
-function run_downcast_bf16 {
- echo "Running with XLA_DOWNCAST_BF16: $@"
- XLA_DOWNCAST_BF16=1 run_test "$@"
-}
-
function run_xla_ir_debug {
echo "Running with XLA_IR_DEBUG: $@"
XLA_IR_DEBUG=1 run_test "$@"
@@ -191,7 +181,7 @@ function run_xla_op_tests1 {
run_test "$CDIR/dynamo/test_num_output.py"
run_test "$CDIR/dynamo/test_graph_input_matcher.py"
run_save_tensor_ir "$CDIR/dynamo/test_dynamo_graph_dump.py"
- run_use_bf16 "$CDIR/test_data_type.py"
+ run_test "$CDIR/test_data_type.py"
run_xla_ir_debug "$CDIR/test_env_var_mapper.py"
run_xla_hlo_debug "$CDIR/test_env_var_mapper.py"
run_xla_hlo_debug "$CDIR/stablehlo/test_stablehlo_save_load.py"
@@ -200,7 +190,6 @@ function run_xla_op_tests1 {
}
function run_xla_op_tests2 {
- run_downcast_bf16 "$CDIR/test_data_type.py"
run_test "$CDIR/pjrt/test_dtypes.py"
run_test "$CDIR/test_while_loop.py"
run_test "$CDIR/test_autocast.py"
diff --git a/test/test_data_type.py b/test/test_data_type.py
index 9b7f55ff148..f676b566206 100644
--- a/test/test_data_type.py
+++ b/test/test_data_type.py
@@ -13,63 +13,26 @@ def check_env_flag(name, default=''):
class XlaDataTypeTest(unittest.TestCase):
- def test_datatype_f32(self):
- t1 = torch.tensor([2.0, 3.0], dtype=torch.float, device=xm.xla_device())
- t2 = torch.tensor([2.0, 3.0], dtype=torch.float, device=xm.xla_device())
- t3 = torch.div(t1, t2, rounding_mode='floor')
- assert t3.dtype == torch.float
-
- hlo_text = torch_xla._XLAC._get_xla_tensors_text([t3])
- device_data_hlo = hlo_text.split('\n')[1]
- assert 'xla::device_data' in device_data_hlo, device_data_hlo
- if check_env_flag('XLA_USE_BF16') or check_env_flag('XLA_DOWNCAST_BF16'):
- assert 'bf16' in device_data_hlo, device_data_hlo
- elif check_env_flag('XLA_USE_FP16') or check_env_flag('XLA_DOWNCAST_FP16'):
- assert 'f16' in device_data_hlo, device_data_hlo
- else:
- assert 'f32' in device_data_hlo, device_data_hlo
-
- def test_datatype_f64(self):
- t1 = torch.tensor([2.0, 3.0], dtype=torch.double, device=xm.xla_device())
- t2 = torch.tensor([2.0, 3.0], dtype=torch.double, device=xm.xla_device())
- t3 = torch.div(t1, t2, rounding_mode='floor')
- assert t3.dtype == torch.double
-
- hlo_text = torch_xla._XLAC._get_xla_tensors_text([t3])
- device_data_hlo = hlo_text.split('\n')[1]
- assert 'xla::device_data' in device_data_hlo, device_data_hlo
- if check_env_flag('XLA_USE_BF16'):
- assert 'bf16' in device_data_hlo, device_data_hlo
- elif check_env_flag('XLA_USE_FP16'):
- assert 'f16' in device_data_hlo, device_data_hlo
- elif check_env_flag('XLA_DOWNCAST_BF16') or check_env_flag(
- 'XLA_DOWNCAST_FP16'):
- assert 'f32' in device_data_hlo, device_data_hlo
- else:
- assert 'f64' in device_data_hlo, device_data_hlo
-
- def test_datatype_f32_div_f64(self):
- t1 = torch.rand(2, 2, dtype=torch.float, device=xm.xla_device())
- t2 = t1 / 2.0
- hlo_text = torch_xla._XLAC._get_xla_tensors_text([t2])
- assert t2.dtype == torch.float
- assert 'f64' not in hlo_text
-
- def test_datatype_U16_32_64(self):
-
- def _dtype_round_trip(dtype):
- t = torch.randint(0, 128, (2, 4), dtype=dtype).to(xm.xla_device())
- return t.cpu().dtype
-
- for dtype in [torch.uint16, torch.uint32, torch.uint64]:
- dtype2 = _dtype_round_trip(dtype)
- self.assertTrue(dtype == dtype2)
+ def test_module_to_dtype(self):
+ device = torch_xla.device()
+ linear = torch.nn.Linear(
+ 5, 10, dtype=torch.float32).to(device).to(torch.bfloat16)
+ input = torch.randn(
+ 10,
+ 5,
+ ).to(device).to(torch.bfloat16)
+ xm.mark_step()
+ res = linear(input)
+
+ hlo_text = torch_xla._XLAC._get_xla_tensors_text([res])
+ res_hlo = hlo_text.split('\n')[-3]
+ assert 'bf16' in res_hlo, res_hlo
+
+ linear_weight_hlo = torch_xla._XLAC._get_xla_tensors_text([linear.weight
+ ]).split('\n')[-3]
+ assert 'bf16' in linear_weight_hlo, linear_weight_hlo
if __name__ == '__main__':
- print(f'XLA_USE_BF16: {os.getenv("XLA_USE_BF16")}')
- print(f'XLA_USE_FP16: {os.getenv("XLA_USE_FP16")}')
- print(f'XLA_DOWNCAST_BF16: {os.getenv("XLA_DOWNCAST_BF16")}')
- print(f'XLA_DOWNCAST_FP16: {os.getenv("XLA_DOWNCAST_FP16")}')
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/test/test_pallas.py b/test/test_pallas.py
index 25c487912cf..1127198e452 100644
--- a/test/test_pallas.py
+++ b/test/test_pallas.py
@@ -267,19 +267,6 @@ def test_flash_attention_wrapper_causal(self):
self.assertFalse(torch.allclose(o.cpu(), expected_o.cpu()))
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
- @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
- "This test only works on TPUv3+.")
- @unittest.mock.patch.dict(os.environ, {"XLA_USE_BF16": "1"})
- def test_flash_attention_wrapper_bf16(self):
- from torch_xla.experimental.custom_kernel import flash_attention
-
- q = torch.randn(3, 2, 128, 4).to("xla")
- k = torch.randn(3, 2, 128, 4).to("xla")
- v = torch.randn(3, 2, 128, 4).to("xla")
-
- # No exception being raised.
- o = flash_attention(q, k, v)
-
@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_multiple_returns(self):
import jax._src.pallas.mosaic.pallas_call_registration
diff --git a/test/test_train_mp_imagenet.py b/test/test_train_mp_imagenet.py
index 3342d3901d6..6d8615ee967 100644
--- a/test/test_train_mp_imagenet.py
+++ b/test/test_train_mp_imagenet.py
@@ -108,7 +108,7 @@
)
# Best config to achieve peak performance based on TPU version
-# 1. It is recommended to use this config in conjuntion with XLA_USE_BF16=1 Flag.
+# 1. It is recommended to move the model to bf16 before training.
# 2. Hyperparameters can be tuned to further improve the accuracy.
# usage: python3 /usr/share/pytorch/xla/test/test_train_mp_imagenet.py --model=resnet50 \
# --fake_data --num_epochs=10 --log_steps=300 \
diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py
index 2839ae36758..76ac8e0a3f2 100644
--- a/torch_xla/__init__.py
+++ b/torch_xla/__init__.py
@@ -2,6 +2,7 @@
import os
import re
import tempfile
+import warnings
import torch
import _XLAC
@@ -140,9 +141,21 @@ def _setup_tpu_vm_library_path() -> bool:
return False
+def _check_deprecated_env_var():
+ deprecated_env_vars = [
+ 'XLA_USE_BF16', 'XLA_USE_FP16', 'XLA_DOWNCAST_BF16', 'XLA_DOWNCAST_FP16',
+ 'XLA_USE_32BIT_LONG'
+ ]
+ for env_var in deprecated_env_vars:
+ if os.environ.get(env_var):
+ warnings.warn(f"The environment variable '{env_var}' is deprecated "
+ "Please update your code to avoid using it.")
+
+
# These needs to be called before the _XLAC module is loaded.
_setup_default_env()
_setup_xla_flags()
+_check_deprecated_env_var()
if int(os.environ.get('PT_XLA_DEBUG', '0')):
_fd, _tmp_fname = _setup_debug_env()
diff --git a/torch_xla/csrc/dtype.cpp b/torch_xla/csrc/dtype.cpp
index c5310a9e2ea..8232381a867 100644
--- a/torch_xla/csrc/dtype.cpp
+++ b/torch_xla/csrc/dtype.cpp
@@ -5,92 +5,6 @@
namespace torch_xla {
-namespace {
-
-bool ShouldUseBF16() {
- bool use_bf16 = runtime::sys_util::GetEnvBool("XLA_USE_BF16", false);
- if (use_bf16) {
- std::cout
- << "XLA_USE_BF16 will be deprecated after the 2.4 release, please "
- "convert your model to bf16 directly\n";
- TF_LOG(INFO) << "Using BF16 data type for floating point values";
- }
- return use_bf16;
-}
-
-bool ShouldUseF16() {
- bool use_fp16 = runtime::sys_util::GetEnvBool("XLA_USE_FP16", false);
- if (use_fp16) {
- std::cout
- << "XLA_USE_FP16 will be deprecated after the 2.4 release, please "
- "convert your model to fp16 directly\n";
- TF_LOG(INFO) << "Using F16 data type for floating point values";
- }
- return use_fp16;
-}
-
-bool ShouldDowncastToBF16() {
- bool downcast_bf16 =
- runtime::sys_util::GetEnvBool("XLA_DOWNCAST_BF16", false);
- if (downcast_bf16) {
- std::cout
- << "XLA_DOWNCAST_BF16 will be deprecated after the 2.4 release, please "
- "downcast your model directly\n";
- TF_LOG(INFO) << "Downcasting floating point values, F64->F32, F32->BF16";
- }
- return downcast_bf16;
-}
-
-bool ShouldDowncastToF16() {
- bool downcast_fp16 =
- runtime::sys_util::GetEnvBool("XLA_DOWNCAST_FP16", false);
- if (downcast_fp16) {
- std::cout
- << "XLA_DOWNCAST_FP16 will be deprecated after the 2.4 release, please "
- "downcast your model directly\n";
- TF_LOG(INFO) << "Downcasting floating point values, F64->F32, F32->FP16";
- }
- return downcast_fp16;
-}
-
-bool ShouldUse32BitLong() {
- bool use_32bit_long =
- runtime::sys_util::GetEnvBool("XLA_USE_32BIT_LONG", false);
- if (use_32bit_long) {
- std::cout
- << "XLA_USE_32BIT_LONG will be deprecated after the 2.4 release\n";
- TF_LOG(INFO) << "Using 32bit integers for kLong values";
- }
- return use_32bit_long;
-}
-
-bool UseBF16() {
- static bool use_bf16 = ShouldUseBF16();
- return use_bf16;
-}
-
-bool UseF16() {
- static bool use_fp16 = ShouldUseF16();
- return use_fp16;
-}
-
-bool DowncastBF16() {
- static bool downcast_bf16 = ShouldDowncastToBF16();
- return downcast_bf16;
-}
-
-bool DowncastF16() {
- static bool downcast_fp16 = ShouldDowncastToF16();
- return downcast_fp16;
-}
-
-bool Use32BitLong() {
- static bool use_32bit_long = ShouldUse32BitLong();
- return use_32bit_long;
-}
-
-} // namespace
-
at::ScalarType TorchTypeFromXlaType(xla::PrimitiveType xla_type) {
switch (xla_type) {
case xla::PrimitiveType::BF16:
@@ -167,22 +81,12 @@ xla::PrimitiveType MaybeDowncastToXlaDeviceType(
XlaDeviceType hw_type = static_cast(device.type());
switch (type) {
case xla::PrimitiveType::F64:
- if (UseF16()) {
- return xla::PrimitiveType::F16;
- }
- if (UseBF16()) {
- return xla::PrimitiveType::BF16;
- }
- if (DowncastBF16() || DowncastF16() || hw_type == XlaDeviceType::NEURON) {
+ if (hw_type == XlaDeviceType::NEURON) {
return xla::PrimitiveType::F32;
}
return xla::PrimitiveType::F64;
case xla::PrimitiveType::F32:
- if (UseF16() || DowncastF16()) {
- return xla::PrimitiveType::F16;
- }
- return UseBF16() || DowncastBF16() ? xla::PrimitiveType::BF16
- : xla::PrimitiveType::F32;
+ return xla::PrimitiveType::F32;
case xla::PrimitiveType::U16:
return hw_type != XlaDeviceType::NEURON ? xla::PrimitiveType::U16
: xla::PrimitiveType::U32;
@@ -190,9 +94,9 @@ xla::PrimitiveType MaybeDowncastToXlaDeviceType(
return hw_type != XlaDeviceType::NEURON ? xla::PrimitiveType::S16
: xla::PrimitiveType::S32;
case xla::PrimitiveType::S64:
- return Use32BitLong() ? xla::PrimitiveType::S32 : xla::PrimitiveType::S64;
+ return xla::PrimitiveType::S64;
case xla::PrimitiveType::U64:
- return Use32BitLong() ? xla::PrimitiveType::U32 : xla::PrimitiveType::U64;
+ return xla::PrimitiveType::U64;
case xla::PrimitiveType::C128:
return xla::PrimitiveType::C128;
default:
@@ -210,14 +114,11 @@ at::ScalarType MaybeUpcastToHostTorchType(xla::PrimitiveType xla_type) {
at::ScalarType scalar_type = TorchTypeFromXlaType(xla_type);
switch (scalar_type) {
case at::ScalarType::BFloat16:
- return UseBF16() || DowncastBF16() ? at::ScalarType::Float
- : at::ScalarType::BFloat16;
+ return at::ScalarType::BFloat16;
case at::ScalarType::Half:
- return UseF16() || DowncastF16() ? at::ScalarType::Float
- : at::ScalarType::Half;
+ return at::ScalarType::Half;
case at::ScalarType::Float:
- return DowncastBF16() || DowncastF16() ? at::ScalarType::Double
- : at::ScalarType::Float;
+ return at::ScalarType::Float;
default:
return scalar_type;
}
diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index f0e09d7e81a..95d5e9f0df7 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -12,8 +12,6 @@
from torch.library import impl
from torch_xla.core.xla_model import XLA_LIB
-_XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0") == "1"
-
def _extract_backend_config(
module: "jaxlib.mlir._mlir_libs._mlir.ir.Module") -> Optional[str]:
@@ -63,12 +61,8 @@ def convert_torch_dtype_to_jax(dtype: torch.dtype) -> "jnp.dtype":
import jax.numpy as jnp
if dtype == torch.float32:
- if _XLA_USE_BF16:
- return jnp.bfloat16
return jnp.float32
elif dtype == torch.float64:
- if _XLA_USE_BF16:
- return jnp.bfloat16
return jnp.float64
elif dtype == torch.float16:
return jnp.float16