Experimental GPU PJRT Plugin (pytorch#6240)
will-cromar authored Jan 3, 2024
1 parent 0cd6f10 commit 29977cb
Showing 3 changed files with 70 additions and 0 deletions.
41 changes: 41 additions & 0 deletions plugins/cuda/README.md
@@ -0,0 +1,41 @@
# CUDA PJRT plugin (experimental)

This directory contains an experimental build of the PJRT GPU client as a
standalone plugin. The actual implementation of the PJRT C API lives in the
main OpenXLA repository (see the `bazel build` command below).

## Building

```bash
# Build PJRT plugin
bazel build @xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=1 --config=cuda
# Copy to package dir
cp bazel-bin/external/xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so plugins/cuda/torch_xla_cuda_plugin

# Build wheel
pip wheel plugins/cuda
# Or install directly
pip install plugins/cuda
```
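As a quick sanity check after building, the copied shared object can be loaded
directly and probed for the standard PJRT entry point. This is a rough sketch,
not part of the official build flow: the path simply mirrors the `cp`
destination above, and loading may still fail if CUDA libraries are not on the
library path.

```python
import ctypes
import os

# Hypothetical post-build check; the path mirrors the `cp` destination above.
path = 'plugins/cuda/torch_xla_cuda_plugin/pjrt_c_api_gpu_plugin.so'
assert os.path.exists(path), f'{path} not found; run the bazel build first'

# PJRT C API plugins conventionally export a GetPjrtApi symbol.
lib = ctypes.CDLL(path)
print('GetPjrtApi found:', hasattr(lib, 'GetPjrtApi'))
```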

## Usage

```python
import os

# Enable verbose logging from the PJRT registry so the selected device type is printed
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
os.environ['TF_CPP_VMODULE'] = 'pjrt_registry=5'

from torch_xla.experimental import plugins
import torch_xla_cuda_plugin
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

# Use dynamic plugin instead of built-in CUDA support
plugins.use_dynamic_plugins()
plugins.register_plugin('CUDA', torch_xla_cuda_plugin.GpuPlugin())
xr.set_device_type('CUDA')

print(xm.xla_device())
```
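Once the plugin is registered, a simple smoke test is to run a small
computation on the returned device. This is a sketch under the assumption that
the registration snippet above has already executed in the same process:

```python
import torch
import torch_xla.core.xla_model as xm

# Assumes the registration snippet above has already run in this process.
device = xm.xla_device()
t = torch.randn(2, 2, device=device)
print(t + t)  # Forces execution on the plugin-backed device.
```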
18 changes: 18 additions & 0 deletions plugins/cuda/pyproject.toml
@@ -0,0 +1,18 @@
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[project]
name = "torch_xla_cuda_plugin"
version = "0.0.1"
authors = [
{name = "Will Cromar", email = "wcromar@google.com"},
]
description = "CUDA Plugin"
requires-python = ">=3.8"

[tool.setuptools.package-data]
torch_xla_cuda_plugin = ["*.so"]

[project.entry-points."torch_xla.plugins"]
gpu = "torch_xla_cuda_plugin:GpuPlugin"
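The `torch_xla.plugins` entry point above is what lets torch_xla discover this
package once it is installed. As a rough illustration (an assumption about the
mechanism, not the actual discovery code in torch_xla), entry points in this
group can be enumerated with the standard library:

```python
from importlib.metadata import entry_points

# Enumerate plugins registered under the "torch_xla.plugins" group.
# entry_points(group=...) requires Python 3.10+; on 3.8/3.9 use
# entry_points().get('torch_xla.plugins', []) instead.
for ep in entry_points(group='torch_xla.plugins'):
    plugin_cls = ep.load()  # e.g. torch_xla_cuda_plugin:GpuPlugin
    print(ep.name, plugin_cls)
```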
11 changes: 11 additions & 0 deletions plugins/cuda/torch_xla_cuda_plugin/__init__.py
@@ -0,0 +1,11 @@
import os

from torch_xla.experimental import plugins


class GpuPlugin(plugins.DevicePlugin):

  def library_path(self) -> str:
    return os.path.join(os.path.dirname(__file__), 'pjrt_c_api_gpu_plugin.so')

  def physical_chip_count(self) -> int:
    # TODO: default to actual device count
    # os.getenv returns a string, so cast to match the annotated return type.
    return int(os.getenv('GPU_NUM_DEVICES', '1'))
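The TODO above could plausibly be resolved by querying the driver for the real
device count. A hedged sketch follows; the `nvidia-smi --list-gpus` invocation
is an assumption about the deployment environment, not part of this commit:

```python
import os
import subprocess

def _detect_gpu_count() -> int:
  """Hypothetical helper: count GPUs via nvidia-smi, else fall back to env."""
  try:
    out = subprocess.run(['nvidia-smi', '--list-gpus'],
                         capture_output=True, text=True, check=True)
    return len(out.stdout.strip().splitlines())
  except (OSError, subprocess.CalledProcessError):
    return int(os.getenv('GPU_NUM_DEVICES', '1'))
```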
