-
Notifications
You must be signed in to change notification settings - Fork 275
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for PyTorch Metal Performance Shaders (#685)
* Add `test_slow_gpu` explosion-bot command * Auto-format code with black (#682) Co-authored-by: explosion-bot <explosion-bot@users.noreply.github.com> * Add support for PyTorch Metal Performance Shaders Nightly PyTorch versions add support for Metal Performance Shaders (MPS). Metal is a low-level graphics API for Apple platforms that also supports compute kernels (shaders). MPS is a framework of highly-optimized compute and graphics kernels, including kernels for neural networks. MPS is supported on both Apple Silicon, such as the M1 family of SoC, as well as a range of AMD GPUs used in Macs. Since devices are handled in Thinc through a specific `Ops` implementation (e.g. `CupyOps` == CUDA GPUs), this change introduces the `MPSOps` class. This class is a subclass of `NumpyOps` or `AppleOps` (when available). `MPSOps` does not override any methods, but is used to signal to relevant code paths (e.g. `xp2torch`) that Torch tensors should be placed on the MPS device. The mapping in the previously introduced `get_torch_default_device` function is updated to: - `NumpyOps` -> `cpu` - `CupyOps` -> `cuda:N`, where N is the selected CUDA device. - `MPSOps` -> `mps` to ensure placement of Torch tensors on the `mps` device when `MPSOps` is active. Finally, the following booleans have been added to or changed in `compat`: - `has_torch_mps` (new): PyTorch has MPS support - `has_torch_mps_gpu` (new): PyTorch has MPS support and an MPS-capable GPU is available. - `has_torch_cuda_gpu` (new): PyTorch has CUDA support and a CUDA-capable GPU is available. - `has_torch_gpu` (changed): PyTorch has a GPU available (CUDA or MPS). * Test PyTorch wrapper with all xp ops * Azure: pin protobuf to fix Tensorflow * Extend typing_extensions to <4.2.0 (#689) * Fix type checking error * Only back-off to NumpyOps on import error We do not want to hide other issues while importing thinc_apple_ops. * Remove unneeded `has_torch_mps` bool * Add `has_gpu` bool and use it in `util` * Replace another expression by has_gpu * Set `has_torch_gpu` to `has_torch_cuda_gpu` We need to decide whether we want to make the potentially breaking change from `has_torch_cuda_gpu` to `has_torch_cuda_gpu or has_torch_mps_gpu`. But since the latter is not needed for this PR, remove the change. * Update thinc/util.py Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> Co-authored-by: shademe <shadeMe@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: explosion-bot <explosion-bot@users.noreply.github.com> Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
- Loading branch information
1 parent
b8054fd
commit 5beeaf2
Showing
9 changed files
with
111 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from typing import TYPE_CHECKING | ||
import numpy | ||
|
||
from .. import registry | ||
from . import NumpyOps, Ops | ||
|
||
if TYPE_CHECKING: | ||
# Type checking does not work with dynamic base classes, since MyPy cannot | ||
# determine against which base class to check. So, always derive from Ops | ||
# during type checking. | ||
_Ops = Ops | ||
else: | ||
try: | ||
from thinc_apple_ops import AppleOps | ||
|
||
_Ops = AppleOps | ||
except ImportError: | ||
_Ops = NumpyOps | ||
|
||
|
||
@registry.ops("MPSOps") | ||
class MPSOps(_Ops): | ||
"""Ops class for Metal Performance shaders.""" | ||
|
||
name = "mps" | ||
xp = numpy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters