forked from pytorch/xla
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Experimental GPU PJRT Plugin (pytorch#6240)
- Loading branch information
1 parent
0cd6f10
commit 29977cb
Showing
3 changed files
with
70 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# CUDA PJRT plugin (experimental) | ||
|
||
This directory contains an experimental implementation of the PJRT GPU client as | ||
a plugin. The actual implementation of the PJRT C API lives in the main OpenXLA | ||
repository (see `bazel build` command below). | ||
|
||
## Building | ||
|
||
```bash | ||
# Build PJRT plugin | ||
bazel build @xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=1 --config=cuda | ||
# Copy to package dir | ||
cp bazel-bin/external/xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so plugins/cuda/torch_xla_cuda_plugin | ||
|
||
# Build wheel | ||
pip wheel plugins/cuda | ||
# Or install directly | ||
pip install plugins/cuda | ||
``` | ||
|
||
## Usage | ||
|
||
```python | ||
import os | ||
|
||
# Log device type | ||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0' | ||
os.environ['TF_CPP_VMODULE'] = 'pjrt_registry=5' | ||
|
||
from torch_xla.experimental import plugins | ||
import torch_xla_cuda_plugin | ||
import torch_xla.core.xla_model as xm | ||
import torch_xla.runtime as xr | ||
|
||
# Use dynamic plugin instead of built-in CUDA support | ||
plugins.use_dynamic_plugins() | ||
plugins.register_plugin('CUDA', torch_xla_cuda_plugin.GpuPlugin()) | ||
xr.set_device_type('CUDA') | ||
|
||
print(xm.xla_device()) | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
[build-system] | ||
requires = ["setuptools"] | ||
build-backend = "setuptools.build_meta" | ||
|
||
[project] | ||
name = "torch_xla_cuda_plugin" | ||
version = "0.0.1" | ||
authors = [ | ||
{name = "Will Cromar", email = "wcromar@google.com"}, | ||
] | ||
description = "CUDA Plugin" | ||
requires-python = ">=3.8" | ||
|
||
[tool.setuptools.package-data] | ||
torch_xla_cuda_plugin = ["*.so"] | ||
|
||
[project.entry-points."torch_xla.plugins"] | ||
gpu = "torch_xla_cuda_plugin:GpuPlugin" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
import os | ||
from torch_xla.experimental import plugins | ||
from torch_xla._internal import tpu | ||
|
||
class GpuPlugin(plugins.DevicePlugin): | ||
def library_path(self) -> str: | ||
return os.path.join(os.path.dirname(__file__), 'pjrt_c_api_gpu_plugin.so') | ||
|
||
def physical_chip_count(self) -> int: | ||
# TODO: default to actual device count | ||
return os.getenv('GPU_NUM_DEVICES', 1) |