🚀 Feature
Provide a consistent API for external packages to register available PJRT C API Plugins.
TPU and GPU should use the same plugin API.
Motivation
At the time of writing, PyTorch/XLA supports three PJRT Plugins: TPU, Intel XPU, and AWS Neuron. Each of these plugins required updates to PyTorch/XLA itself, and any future plugins will require the same. Hardware vendors should not have to make changes in this repository just to register their devices.
Additional context
High-level API for configuration and initialization
A working PJRT C API implementation alone is not enough to cleanly integrate with PyTorch/XLA. We will need to provide at least the following functionality in addition to what is available through the runtime APIs:
Plugin location: The low-level ComputationClient will need to know where to find the PJRT Plugin binary, either by looking in locations known by convention (e.g. bundled libtpu binaries) or by calling into the plugin package (e.g. libtpu.get_library_path()).
local_rank and local_world_size: the concept of "local" devices exists in PyTorch, but not in PJRT.
host_index: PJRT only has a concept of "process", not host.
Multi-process initialization: to assign each device to a separate process, PyTorch/XLA performs additional setup. Note: this assignment must be done before the PJRT client is created.
This functionality can be bundled under a common interface, implemented for each plugin and registered with PyTorch/XLA before client initialization. #5644 contains a (likely incomplete) DevicePlugin interface and an example implementation for TPU.
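To make the shape of this interface concrete, here is a minimal sketch; the method names are illustrative assumptions derived from the list above, not the actual interface proposed in #5644:

class DevicePlugin:
    """Sketch of a per-device plugin interface. Names are illustrative."""

    def library_path(self) -> str:
        # Where to find the PJRT C API plugin binary, e.g. a bundled
        # shared library or a path reported by the vendor package
        # (cf. libtpu.get_library_path()).
        raise NotImplementedError

    def host_index(self) -> int:
        # PJRT only has a concept of "process"; map this process back
        # to a host index here.
        return 0

    def configure_multiprocess(self, local_rank: int, local_world_size: int) -> None:
        # Pin this process to a single device *before* the PJRT client
        # is created, e.g. by setting vendor-specific environment variables.
        pass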
Hardcoded device types
Currently, the available device types have to be hardcoded across PyTorch/XLA, meaning that for every new device, vendors must send a PR to us. There are lists of devices in the following locations:
The XlaDeviceType enumeration
The XlaDeviceType-to-string conversion
Encoding devices as strings containing the device type (e.g. GPU:0) is a holdover from the XRT days when these were treated as identifiers of remote devices. In practice, most of our code is device agnostic, except for a few special cases (see below). For most cases, we can simply use a placeholder "PLUGIN" XlaDeviceType, and conditions that expect a specific set of strings can be relaxed.
Device special cases
There are some limited cases where we do check the XlaDeviceType while building the XLA graph, whether to provide additional optimizations, set device-specific defaults, or avoid incompatibilities.
Many old conditions specific to TPU are now likely irrelevant as the TPU runtime has improved support (see #6197). Some other examples that may still be relevant:
Replacing sparse gather with dense gather on TPU and Neuron (torch_xla/csrc/data_ops.cpp, lines 35 to 44 at efa6fcf)
torch_xla/csrc/random.cpp, lines 23 to 26 at efa6fcf
torch_xla/csrc/aten_xla_type.cpp, lines 3156 to 3166 at efa6fcf
With two exceptions (listed below), we can capture these special cases in the DevicePlugin API and make the codebase outside of the runtime mostly device agnostic. TODO: elaborate and prototype
Special special case: CUDA
Since CUDA is implemented in both upstream PyTorch and XLA, there are cases where we specifically want to know if we are using CUDA as opposed to any other device backend. For example, we use a different autocast wrapper for CUDA GPUs specifically because one exists upstream (torch_xla/csrc/tensor_impl.cpp, lines 79 to 84 at efa6fcf).
We will also need to identify GPUs in the future if we want to implement zero-copy "transfers" from eager GPU tensors to XLA:GPU tensors. CUDA GPUs are a special case, however, since they are the one accelerator supported natively in upstream torch.
Special special case: SPMD
We model a group of devices for SPMD as a single xla:0 "device" with the SPMD XlaDeviceType. The SPMD virtual device has substantially different semantics than an actual runtime device, so we can't escape needing special cases.
Plugin packaging
In terms of packaging, we can follow JAX's convention and request that plugin packages register themselves as "entry points". Plugin implementers can add something like the following to their setup.py file:
from setuptools import setup

setup(
    # Assuming a plugin for a GPU device
    entry_points={
        'torch_xla.plugins': [
            'gpu = torch_xla_gpu:GpuPlugin',
        ]
    }
)
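For illustration, the torch_xla_gpu package referenced by this entry point might define GpuPlugin roughly as follows; the DevicePlugin import path is an assumption (see the sketch above and #5644):

import os

# Assumed import location for the DevicePlugin base class.
from torch_xla.experimental import plugins


class GpuPlugin(plugins.DevicePlugin):
    def library_path(self) -> str:
        # Hypothetical layout: the PJRT plugin binary ships inside the package.
        return os.path.join(os.path.dirname(__file__), 'pjrt_c_api_gpu_plugin.so')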
Within PyTorch/XLA, we can initialize and register all of the plugins automatically using Python's standard library:
from importlib.metadata import entry_points

pjrt_plugins = entry_points(group='torch_xla.plugins')
for plugin in pjrt_plugins:
    plugins.register_plugin(plugin.name, plugin.load())
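(Selecting entry points by keyword, as in entry_points(group=...), requires Python 3.10's importlib.metadata; on older interpreters the importlib_metadata backport provides the same interface.)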
TODO
Create DevicePlugin interface and use it for existing device types
TPU
GPU
Neuron
XPU
Add placeholder PLUGIN device type and remove hardcoded strings
Remove as many device special cases as possible, moving logic for supported dtypes and conditional optimizations into DevicePlugin
Document packaging process for PJRT plugins
Automatically register available plugins
I never fully addressed the device special cases throughout the code. So far, this hasn't been a problem. If new device special cases come up, I recommend extending the DevicePlugin interface to capture that (e.g. adding a configurable dense_gather_factor property that others can implement).
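A minimal sketch of that suggestion, reusing the hypothetical DevicePlugin interface sketched above (the property name comes from this comment; the default and the TPU value are purely illustrative):

from typing import Optional


class DevicePlugin:
    @property
    def dense_gather_factor(self) -> Optional[int]:
        # None means "leave sparse gather alone on this device".
        return None


class TpuPlugin(DevicePlugin):
    @property
    def dense_gather_factor(self) -> Optional[int]:
        # Illustrative threshold: graph-building code could consult this
        # property instead of hardcoding a check for the TPU device type.
        return 8192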