Add support for PyTorch Metal Performance Shaders (#685)
* Add `test_slow_gpu` explosion-bot command

* Auto-format code with black (#682)

Co-authored-by: explosion-bot <explosion-bot@users.noreply.github.com>

* Add support for PyTorch Metal Performance Shaders

Nightly PyTorch versions add support for Metal Performance Shaders
(MPS). Metal is a low-level graphics API for Apple platforms that also
supports compute kernels (shaders). MPS is a framework of highly
optimized compute and graphics kernels, including kernels for neural
networks. MPS is supported both on Apple Silicon, such as the M1 family
of SoCs, and on a range of AMD GPUs used in Macs.

Since devices are handled in Thinc through a specific `Ops`
implementation (e.g. `CupyOps` for CUDA GPUs), this change introduces
the `MPSOps` class. This class is a subclass of `AppleOps` (when
`thinc_apple_ops` is available) or `NumpyOps` otherwise. `MPSOps` does
not override any methods, but is used to signal to relevant code paths
(e.g. `xp2torch`) that Torch tensors should be placed on the MPS
device.
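
For example, a minimal usage sketch (assuming a PyTorch build with MPS
support and an MPS-capable GPU; the printed device string is indicative):

    import numpy
    from thinc.api import MPSOps, set_current_ops, xp2torch

    set_current_ops(MPSOps())           # make MPSOps the active backend
    x = numpy.zeros((2, 3), dtype="f")
    t = xp2torch(x)                     # Torch tensor is placed on the MPS device
    print(t.device)                     # e.g. "mps:0"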

The mapping in the previously introduced `get_torch_default_device`
function is updated to:

- `NumpyOps` -> `cpu`
- `CupyOps` -> `cuda:N`, where N is the selected CUDA device.
- `MPSOps` -> `mps`

to ensure placement of Torch tensors on the `mps` device when `MPSOps`
is active.
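
In code, the mapping amounts to roughly the following (a simplified
sketch that mirrors the `thinc/util.py` change further down in this
diff):

    import torch
    from thinc.backends import get_current_ops
    from thinc.backends.cupy_ops import CupyOps
    from thinc.backends.mps_ops import MPSOps

    def get_torch_default_device() -> "torch.device":
        # Map the active Ops implementation to a Torch device.
        ops = get_current_ops()
        if isinstance(ops, CupyOps):
            return torch.device(f"cuda:{torch.cuda.current_device()}")
        if isinstance(ops, MPSOps):
            return torch.device("mps")
        return torch.device("cpu")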

Finally, the following booleans have been added to or changed in
`compat` (a sketch of how they are derived follows the list):

- `has_torch_mps` (new): PyTorch has MPS support
- `has_torch_mps_gpu` (new): PyTorch has MPS support and an
  MPS-capable GPU is available.
- `has_torch_cuda_gpu` (new): PyTorch has CUDA support and a
  CUDA-capable GPU is available.
- `has_torch_gpu` (changed): PyTorch has a GPU available (CUDA
  or MPS).
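
Roughly, the flags are derived as follows (a sketch reflecting the
final state after the follow-up commits below; `has_torch_mps` is later
dropped again, and `has_cupy_gpu` is defined elsewhere in `compat`):

    import torch

    has_torch_cuda_gpu = torch.cuda.device_count() != 0
    has_torch_mps_gpu = (
        hasattr(torch, "has_mps")
        and torch.has_mps
        and torch.backends.mps.is_available()
    )
    has_torch_gpu = has_torch_cuda_gpu  # CUDA only, see follow-up commit below
    has_gpu = has_cupy_gpu or has_torch_mps_gpu  # has_cupy_gpu comes from the CuPy checks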

* Test PyTorch wrapper with all xp ops

* Azure: pin protobuf to fix Tensorflow

* Extend typing_extensions to <4.2.0 (#689)

* Fix type checking error

* Only back off to NumpyOps on import error

We do not want to hide other issues while importing thinc_apple_ops.
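
The resulting import pattern in `thinc/backends/mps_ops.py` (shown in
full in the diff below) catches only `ImportError`:

    try:
        from thinc_apple_ops import AppleOps

        _Ops = AppleOps
    except ImportError:
        # Fall back to NumpyOps only when thinc_apple_ops is missing;
        # any other exception raised during import still propagates.
        _Ops = NumpyOps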

* Remove unneeded `has_torch_mps` bool

* Add `has_gpu` bool and use it in `util`

* Replace another expression with `has_gpu`

* Set `has_torch_gpu` to `has_torch_cuda_gpu`

We still need to decide whether we want to make the potentially
breaking change of defining `has_torch_gpu` as `has_torch_cuda_gpu or
has_torch_mps_gpu` rather than just `has_torch_cuda_gpu`. Since that
change is not needed for this PR, it has been removed.

* Update thinc/util.py

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

Co-authored-by: shademe <shadeMe@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: explosion-bot <explosion-bot@users.noreply.github.com>
Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
6 people authored Jun 10, 2022
1 parent b8054fd commit 5beeaf2
Showing 9 changed files with 111 additions and 36 deletions.
2 changes: 1 addition & 1 deletion thinc/api.py
@@ -18,7 +18,7 @@
from .util import torch2xp, xp2torch, tensorflow2xp, xp2tensorflow, mxnet2xp, xp2mxnet
from .compat import has_cupy
from .backends import get_ops, set_current_ops, get_current_ops, use_ops
from .backends import Ops, CupyOps, NumpyOps, set_gpu_allocator
from .backends import Ops, CupyOps, MPSOps, NumpyOps, set_gpu_allocator
from .backends import use_pytorch_for_gpu_memory, use_tensorflow_for_gpu_memory

from .layers import Dropout, Embed, expand_window, HashEmbed, LayerNorm, Linear
8 changes: 7 additions & 1 deletion thinc/backends/__init__.py
@@ -7,10 +7,11 @@
from .ops import Ops
from .cupy_ops import CupyOps
from .numpy_ops import NumpyOps
from .mps_ops import MPSOps
from ._cupy_allocators import cupy_tensorflow_allocator, cupy_pytorch_allocator
from ._param_server import ParamServer
from ..util import assert_tensorflow_installed, assert_pytorch_installed
from ..util import is_cupy_array, require_cpu
from ..util import get_torch_default_device, is_cupy_array, require_cpu
from .. import registry
from ..compat import cupy, has_cupy

@@ -48,6 +49,10 @@ def use_pytorch_for_gpu_memory() -> None: # pragma: no cover
    (or vice versa), but do not currently have an implementation for it.
    """
    assert_pytorch_installed()

    if get_torch_default_device().type != "cuda":
        return

    pools = context_pools.get()
    if "pytorch" not in pools:
        pools["pytorch"] = cupy.cuda.MemoryPool(allocator=cupy_pytorch_allocator)
@@ -169,6 +174,7 @@ def _create_thread_local(
    "ParamServer",
    "Ops",
    "CupyOps",
    "MPSOps",
    "NumpyOps",
    "has_cupy",
]
4 changes: 2 additions & 2 deletions thinc/backends/cupy_ops.py
@@ -6,7 +6,7 @@
from ..types import DeviceTypes
from ..util import torch2xp, tensorflow2xp, mxnet2xp
from ..util import is_cupy_array
from ..util import is_torch_gpu_array, is_tensorflow_gpu_array, is_mxnet_gpu_array
from ..util import is_torch_cuda_array, is_tensorflow_gpu_array, is_mxnet_gpu_array
from ..compat import cupy, cupyx


@@ -62,7 +62,7 @@ def asarray(self, data, dtype=None):
        # We'll try to perform a zero-copy conversion if possible.
        if is_cupy_array(data):
            array = data
        elif is_torch_gpu_array(data):
        elif is_torch_cuda_array(data):
            array = torch2xp(data)
        elif is_tensorflow_gpu_array(data):
            array = tensorflow2xp(data)
26 changes: 26 additions & 0 deletions thinc/backends/mps_ops.py
@@ -0,0 +1,26 @@
from typing import TYPE_CHECKING
import numpy

from .. import registry
from . import NumpyOps, Ops

if TYPE_CHECKING:
    # Type checking does not work with dynamic base classes, since MyPy cannot
    # determine against which base class to check. So, always derive from Ops
    # during type checking.
    _Ops = Ops
else:
    try:
        from thinc_apple_ops import AppleOps

        _Ops = AppleOps
    except ImportError:
        _Ops = NumpyOps


@registry.ops("MPSOps")
class MPSOps(_Ops):
    """Ops class for Metal Performance shaders."""

    name = "mps"
    xp = numpy
13 changes: 12 additions & 1 deletion thinc/compat.py
@@ -31,7 +31,13 @@
    import torch

    has_torch = True
    has_torch_gpu = torch.cuda.device_count() != 0
    has_torch_cuda_gpu = torch.cuda.device_count() != 0
    has_torch_mps_gpu = (
        hasattr(torch, "has_mps")
        and torch.has_mps
        and torch.backends.mps.is_available()
    )
    has_torch_gpu = has_torch_cuda_gpu
    torch_version = Version(str(torch.__version__))
    has_torch_amp = (
        torch_version >= Version("1.9.0")
@@ -40,7 +46,9 @@
except ImportError: # pragma: no cover
    torch = None  # type: ignore
    has_torch = False
    has_torch_cuda_gpu = False
    has_torch_gpu = False
    has_torch_mps_gpu = False
    has_torch_amp = False
    torch_version = Version("0.0.0")

@@ -68,3 +76,6 @@
    import h5py
except ImportError: # pragma: no cover
    h5py = None


has_gpu = has_cupy_gpu or has_torch_mps_gpu
38 changes: 28 additions & 10 deletions thinc/tests/layers/test_pytorch_wrapper.py
@@ -1,21 +1,37 @@
from thinc.api import Linear, SGD, PyTorchWrapper, PyTorchWrapper_v2
from thinc.api import xp2torch, torch2xp, ArgsKwargs, use_ops
from thinc.api import chain, get_current_ops, Relu
from thinc.api import CupyOps, MPSOps, NumpyOps
from thinc.backends import context_pools
from thinc.shims.pytorch_grad_scaler import PyTorchGradScaler
from thinc.compat import has_torch, has_torch_amp, has_torch_gpu
from thinc.compat import has_cupy
from thinc.compat import has_torch, has_torch_amp
from thinc.compat import has_cupy_gpu, has_torch_mps_gpu
import numpy
import pytest
from thinc.util import get_torch_default_device

from ..util import make_tempdir, check_input_converters


XP_OPS = [NumpyOps()]
if has_cupy_gpu:
    XP_OPS.append(CupyOps())
if has_torch_mps_gpu:
    XP_OPS.append(MPSOps())


if has_torch_amp:
    TORCH_MIXED_PRECISION = [False, True]
else:
    TORCH_MIXED_PRECISION = [False]

XP_OPS_MIXED = [
    (ops, mixed)
    for ops in XP_OPS
    for mixed in TORCH_MIXED_PRECISION
    if not mixed or isinstance(ops, CupyOps)
]


def check_learns_zero_output(model, sgd, X, Y):
    """Check we can learn to output a zero vector"""
@@ -64,32 +80,34 @@ def test_pytorch_wrapper(nN, nI, nO):
    assert isinstance(model.predict(X), numpy.ndarray)


@pytest.mark.skipif(
    not has_cupy or not has_torch_gpu, reason="needs PyTorch with CUDA-capable GPU"
)
@pytest.mark.skipif(not has_torch, reason="needs PyTorch")
@pytest.mark.parametrize("ops_mixed", XP_OPS_MIXED)
@pytest.mark.parametrize("nN,nI,nO", [(2, 3, 4)])
@pytest.mark.parametrize("mixed_precision", TORCH_MIXED_PRECISION)
def test_pytorch_wrapper_thinc_input(nN, nI, nO, mixed_precision):
def test_pytorch_wrapper_thinc_input(ops_mixed, nN, nI, nO):
    import torch.nn

    with use_ops("cupy"):
    ops, mixed_precision = ops_mixed

    with use_ops(ops.name):
        ops = get_current_ops()
        pytorch_layer = torch.nn.Linear(nO, nO)
        # Initialize with large weights to trigger overflow of FP16 in
        # mixed-precision training.
        torch.nn.init.uniform_(pytorch_layer.weight, 9.0, 11.0)
        device = get_torch_default_device()
        model = chain(
            Relu(),
            PyTorchWrapper_v2(
                pytorch_layer.cuda(),
                pytorch_layer.to(device),
                mixed_precision=mixed_precision,
                grad_scaler=PyTorchGradScaler(
                    enabled=mixed_precision, init_scale=2.0**16
                ),
            ).initialize(),
        )
        # pytorch allocator is set in PyTorchShim
        assert "pytorch" in context_pools.get()
        if isinstance(ops, CupyOps):
            assert "pytorch" in context_pools.get()
        sgd = SGD(0.001)
        X = ops.xp.zeros((nN, nI), dtype="f")
        X += ops.xp.random.uniform(size=X.size).reshape(X.shape)
4 changes: 2 additions & 2 deletions thinc/tests/regression/test_issue564.py
@@ -1,11 +1,11 @@
import pytest

from thinc.api import CupyOps
from thinc.compat import has_torch, has_torch_gpu
from thinc.compat import has_torch, has_torch_cuda_gpu


@pytest.mark.skipif(not has_torch, reason="needs PyTorch")
@pytest.mark.skipif(not has_torch_gpu, reason="needs a GPU")
@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs a GPU")
def test_issue564():
    import torch

6 changes: 3 additions & 3 deletions thinc/tests/shims/test_pytorch_grad_scaler.py
@@ -2,7 +2,7 @@

from hypothesis import given, settings
from hypothesis.strategies import lists, one_of, tuples
from thinc.compat import has_torch, has_torch_amp, has_torch_gpu, torch
from thinc.compat import has_torch, has_torch_amp, has_torch_cuda_gpu, torch
from thinc.util import is_torch_array
from thinc.api import PyTorchGradScaler

@@ -14,7 +14,7 @@ def tensors():


@pytest.mark.skipif(not has_torch, reason="needs PyTorch")
@pytest.mark.skipif(not has_torch_gpu, reason="needs a GPU")
@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs a GPU")
@pytest.mark.skipif(
not has_torch_amp, reason="requires PyTorch with mixed-precision support"
)
@@ -37,7 +37,7 @@ def test_scale_random_inputs(X):


@pytest.mark.skipif(not has_torch, reason="needs PyTorch")
@pytest.mark.skipif(not has_torch_gpu, reason="needs a GPU")
@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs a GPU")
@pytest.mark.skipif(
not has_torch_amp, reason="requires PyTorch with mixed-precision support"
)
46 changes: 30 additions & 16 deletions thinc/util.py
@@ -14,7 +14,8 @@
from contextvars import ContextVar
from dataclasses import dataclass
from .compat import has_cupy, has_mxnet, has_torch, has_tensorflow
from .compat import has_cupy_gpu, has_torch_gpu
from .compat import has_cupy_gpu, has_torch_cuda_gpu, has_gpu
from .compat import has_torch_mps_gpu
from .compat import torch, cupy, tensorflow as tf, mxnet as mx, cupy_from_dlpack

DATA_VALIDATION: ContextVar[bool] = ContextVar("DATA_VALIDATION", default=False)
@@ -33,11 +34,14 @@ def get_torch_default_device() -> "torch.device":

    from .backends import get_current_ops
    from .backends.cupy_ops import CupyOps
    from .backends.mps_ops import MPSOps

    ops = get_current_ops()
    if isinstance(ops, CupyOps):
        device_id = torch.cuda.current_device()
        return torch.device(f"cuda:{device_id}")
    elif isinstance(ops, MPSOps):
        return torch.device("mps")

    return torch.device("cpu")

@@ -50,7 +54,7 @@ def get_array_module(arr): # pragma: no cover


def gpu_is_available():
    return has_cupy_gpu
    return has_gpu


def fix_random_seed(seed: int = 0) -> None: # pragma: no cover
@@ -61,7 +65,7 @@ def fix_random_seed(seed: int = 0) -> None: # pragma: no cover
    torch.manual_seed(seed)
    if has_cupy_gpu:
        cupy.random.seed(seed)
    if has_torch and has_torch_gpu:
    if has_torch and has_torch_cuda_gpu:
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
@@ -99,10 +103,18 @@ def is_torch_array(obj: Any) -> bool: # pragma: no cover
        return False


def is_torch_gpu_array(obj: Any) -> bool: # pragma: no cover
def is_torch_cuda_array(obj: Any) -> bool: # pragma: no cover
    return is_torch_array(obj) and obj.is_cuda


def is_torch_gpu_array(obj: Any) -> bool: # pragma: no cover
    return is_torch_cuda_array(obj) or is_torch_mps_array(obj)


def is_torch_mps_array(obj: Any) -> bool: # pragma: no cover
    return is_torch_array(obj) and hasattr(obj, "is_mps") and obj.is_mps


def is_tensorflow_array(obj: Any) -> bool: # pragma: no cover
    if not has_tensorflow:
        return False
@@ -146,7 +158,7 @@ def set_active_gpu(gpu_id: int) -> "cupy.cuda.Device": # pragma: no cover
    device = cupy.cuda.device.Device(gpu_id)
    device.use()

    if has_torch_gpu:
    if has_torch_cuda_gpu:
        torch.cuda.set_device(gpu_id)

    return device
@@ -164,21 +176,23 @@ def require_cpu() -> bool: # pragma: no cover

def prefer_gpu(gpu_id: int = 0) -> bool: # pragma: no cover
"""Use GPU if it's available. Returns True if so, False otherwise."""
if not has_cupy_gpu:
return False
else:
if has_gpu:
require_gpu(gpu_id=gpu_id)
return True
return has_gpu


def require_gpu(gpu_id: int = 0) -> bool: # pragma: no cover
    from .backends import set_current_ops, CupyOps
    from .backends import set_current_ops, CupyOps, MPSOps

    if not has_cupy_gpu:
        raise ValueError("No CUDA GPU devices detected")
    if not has_gpu:
        raise ValueError("No GPU devices detected")

    if has_cupy_gpu:
        set_current_ops(CupyOps())
        set_active_gpu(gpu_id)
    else:
        set_current_ops(MPSOps())

    set_current_ops(CupyOps())
    set_active_gpu(gpu_id)
    return True


Expand Down Expand Up @@ -353,14 +367,14 @@ def torch2xp(
    from .api import NumpyOps

    assert_pytorch_installed()
    if is_torch_gpu_array(torch_tensor):
    if is_torch_cuda_array(torch_tensor):
        if isinstance(ops, NumpyOps):
            return torch_tensor.detach().cpu().numpy()
        else:
            return cupy_from_dlpack(torch.utils.dlpack.to_dlpack(torch_tensor))
    else:
        if isinstance(ops, NumpyOps) or ops is None:
            return torch_tensor.detach().numpy()
            return torch_tensor.detach().cpu().numpy()
        else:
            return cupy.asarray(torch_tensor)

