Improve PJRT C API support for GPUs and custom hardware #6242

Closed · 8 of 9 tasks
will-cromar opened this issue Dec 28, 2023 · 1 comment

will-cromar (Collaborator) commented Dec 28, 2023

🚀 Feature

  • Provide a consistent API for external packages to register available PJRT C API Plugins.
  • TPU and GPU should use the same plugin API.

Motivation

At the time of writing, PyTorch/XLA supports three PJRT Plugins: TPU, Intel XPU, and AWS Neuron. Each of these plugins required updates to PyTorch/XLA itself, and any future plugins will require the same. Hardware vendors should not have to make changes in this repository just to register their devices.

Additional context

High-level API for configuration and initialization

A working PJRT C API implementation alone is not enough to integrate cleanly with PyTorch/XLA. We will need to provide at least the following functionality on top of what the runtime APIs already offer:

  • Plugin location: The low-level ComputationClient will also need to know where to find the PJRT Plugin binary, either by looking in locations known by convention (e.g. bundled libtpu binaries) or by calling the plugin package (e.g. libtpu.get_library_path()).
  • local_rank and local_world_size: the concept of "local" devices exists in PyTorch, but not in PJRT.
  • host_index: PJRT only has a concept of "process", not host.
  • Multi-process initialization: to assign each device to a separate process, PyTorch/XLA performs additional setup. Note: this assignment must be done before the PJRT client is created.

This functionality can be bundled under a common interface, implemented for each plugin and registered with PyTorch/XLA before client initialization. #5644 contains a (likely incomplete) DevicePlugin interface and an example implementation for TPU.
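
For illustration, here is a minimal sketch of what such an interface could look like, assuming Python-level plugins. The method names are hypothetical and simply mirror the bullets above; the actual proposal lives in #5644. A TPU implementation could, for example, return libtpu.get_library_path() from library_path().

    import abc


    class DevicePlugin(abc.ABC):
        """Bundles the per-device setup PyTorch/XLA needs before creating
        the PJRT client. Method names here are illustrative only."""

        @abc.abstractmethod
        def library_path(self) -> str:
            """Path to the PJRT C API plugin binary (e.g. a bundled libtpu)."""

        def local_rank(self) -> int:
            """Index of this process among the processes on this host."""
            return 0

        def local_world_size(self) -> int:
            """Number of processes running on this host."""
            return 1

        def host_index(self) -> int:
            """Index of this host; PJRT itself only knows about processes."""
            return 0

        def configure_multiprocess(self, local_rank: int, local_world_size: int):
            """Per-process setup (e.g. assigning one device per process) that
            must happen before the PJRT client is created."""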

Hardcoded device types

Currently, the available device types have to be hardcoded across PyTorch/XLA, meaning that vendors must send us a PR for every new device. Hardcoded device lists appear in several places in the codebase.

Encoding devices as strings containing the device type (e.g. GPU:0) is a holdover from the XRT days when these were treated as identifiers of remote devices. In practice, most of our code is device agnostic, except for a few special cases (see below). For most cases, we can simply use a placeholder "PLUGIN" XlaDeviceType, and conditions that expect a specific set of strings can be relaxed.
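
As a rough sketch of that relaxation (the names below are hypothetical, not existing PyTorch/XLA APIs), device-string parsing could fall back to the placeholder type for anything a registered plugin claims, instead of rejecting unknown strings:

    # Hypothetical sketch only: fall back to a placeholder PLUGIN type instead
    # of maintaining a hardcoded list of every supported device string.
    _BUILTIN_TYPES = {'CPU', 'TPU', 'CUDA', 'SPMD'}

    def parse_device_type(device: str, registered_plugins: set) -> str:
        """Map a device string like 'GPU:0' to an XlaDeviceType name."""
        device_type = device.split(':')[0]
        if device_type in _BUILTIN_TYPES:
            return device_type
        if device_type.lower() in registered_plugins:
            # Any registered plugin maps to the generic placeholder type.
            return 'PLUGIN'
        raise ValueError(f'Unknown device type: {device_type}')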

Device special cases

There are some limited cases where we do check the XlaDeviceType while building the XLA graph, either to provide additional optimizations, apply device-specific defaults, or avoid incompatibilities.

Many old conditions specific to TPU are now likely irrelevant as the TPU runtime has improved support (see #6197). Some other examples that may still be relevant:

  • Replacing sparse gather with dense gather on TPU and Neuron:
    if (hw_type == XlaDeviceType::TPU || hw_type == XlaDeviceType::NEURON) {
      // XLA_DENSE_GATHER_FACTOR can be used to finely control the
      // sparsity check.
      static int dense_gather_factor =
          runtime::sys_util::GetEnvInt("XLA_DENSE_GATHER_FACTOR", 8192);
      int64_t input_elements = input_shape.dimensions()[dim];
      // Use a very conservative check so that we run dense gather
      // most of the time on TPU.
      return input_elements > dense_gather_factor * 10;
    }
  • Default RNG for GPU:
    case XlaDeviceType::GPU:
    case XlaDeviceType::CUDA:
    case XlaDeviceType::ROCM:
      return "three_fry";
  • Use a custom call that is only implemented for TPU:
    // Only the XLA TPU backend for now implements the CustomCall required by
    // our XLA lowering.
    XlaDeviceType hw_type =
        static_cast<XlaDeviceType>(grad_output_tensor->GetDevice().type());
    if (hw_type != XlaDeviceType::TPU) {
      return at::native::call_fallback_fn<
          &xla_cpu_fallback,
          ATEN_OP(upsample_bilinear2d_backward)>::call(grad_output, output_size,
                                                       input_size, align_corners,
                                                       scales_h, scales_w);
    }

With two exceptions (listed below), we can capture these special cases in the DevicePlugin API and make the codebase outside of the runtime mostly device agnostic. TODO: elaborate and prototype
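
As a rough prototype of that direction (the property name is hypothetical, echoing the dense gather check above), the device-specific constant could live on the plugin, and the lowering code would consult the active plugin rather than switching on XlaDeviceType:

    # Hypothetical sketch only; none of these names are an existing API.
    class NeuronPlugin:
        # Devices that want the dense-gather rewrite expose a factor;
        # devices that don't leave the attribute unset.
        dense_gather_factor = 8192


    def use_dense_gather(plugin, input_elements: int) -> bool:
        # Replaces the XlaDeviceType::TPU / NEURON check in the lowering code.
        factor = getattr(plugin, 'dense_gather_factor', None)
        return factor is not None and input_elements > factor * 10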

Special special case: CUDA

Since CUDA is implemented in both upstream PyTorch and XLA, there are cases where we specifically need to know whether we are using CUDA rather than any other device backend. For example, we use a different autocast wrapper for CUDA GPUs specifically because autocast for CUDA already exists upstream:

    if (dev_type == XlaDeviceType::GPU || dev_type == XlaDeviceType::CUDA ||
        dev_type == XlaDeviceType::ROCM) {
      auto autocast_cuda_ks = c10::DispatchKeySet(c10::DispatchKey::AutocastCUDA);
      auto autocast_xla_ks = c10::DispatchKeySet(c10::DispatchKey::AutocastXLA);
      key_set_ = (key_set_ - autocast_xla_ks) | autocast_cuda_ks;
    }

We will also need to identify GPUs in the future if we want to implement zero-copy "transfers" from eager GPU tensors to XLA:GPU tensors. CUDA GPUs remain a special case, however, since they are the accelerator natively supported in upstream torch.

Special special case: SPMD

We model a group of devices for SPMD as a single xla:0 "device" with the SPMD XlaDeviceType. The SPMD virtual device has substantially different semantics than an actual runtime device, so we can't escape needing special cases.

Plugin packaging

In terms of packaging, we can follow JAX's convention and request that plugin packages register themselves as "entry points". Plugin implementers can add something like the following to their setup.py file:

setup(
    # Assuming a plugin for a GPU device
    entry_points = {
        'torch_xla.plugins': [
            'gpu = torch_xla_gpu:GpuPlugin',
        ]
    }
)

Within PyTorch/XLA, we can initialize and register all of the plugins automatically using Python's standard library:

    # entry_points(group=...) requires Python 3.10+ (importlib.metadata).
    from importlib.metadata import entry_points

    pjrt_plugins = entry_points(group='torch_xla.plugins')
    for plugin in pjrt_plugins:
        plugins.register_plugin(plugin.name, plugin.load())
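
To close the loop with the setup.py example above, the entry point target there (torch_xla_gpu:GpuPlugin) would be a class implementing the plugin interface. A hypothetical sketch, with made-up paths and method bodies:

    # Hypothetical contents of the torch_xla_gpu package referenced by the
    # entry point above; method names mirror the interface sketch earlier in
    # this issue and are not an existing API.
    import os


    class GpuPlugin:
        def library_path(self) -> str:
            # Point PyTorch/XLA at the PJRT C API plugin binary bundled in
            # this package (path is an assumption for illustration).
            return os.path.join(
                os.path.dirname(__file__), 'lib', 'pjrt_c_api_gpu_plugin.so')

        def configure_multiprocess(self, local_rank, local_world_size):
            # Assumption: one GPU per process, selected before client creation.
            os.environ.setdefault('CUDA_VISIBLE_DEVICES', str(local_rank))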

TODO

  • Create DevicePlugin interface and use it for existing device types
    • TPU
    • GPU
    • Neuron
    • XPU
  • Add placeholder PLUGIN device type and remove hardcoded strings
  • Remove as many device special cases as possible, moving logic for supported dtypes and conditional optimizations into DevicePlugin
  • Document packaging process for PJRT plugins
  • Automatically register available plugins
will-cromar (Collaborator, Author) commented:

Closing this. The TPU and GPU plugins are fully adopted, and we have stub plugins for the other backends as well. The packaging process for future plugins is documented at https://github.com/pytorch/xla/blob/master/docs/plugins.md

I never fully addressed the device special cases throughout the code. So far, this hasn't been a problem. If new device special cases come up, I recommend extending the DevicePlugin interface to capture that (e.g. adding a configurable dense_gather_factor property that others can implement).
