
Merge pytorch-device branch into master #695

Merged
merged 9 commits into master from pytorch-device on Jun 14, 2022
9 changes: 7 additions & 2 deletions examples/transformers_tagger.py
@@ -132,7 +132,9 @@ def forward(
return TokensPlus(**token_data), lambda d_tokens: []

return Model(
"tokenizer", forward, attrs={"tokenizer": AutoTokenizer.from_pretrained(name)},
"tokenizer",
forward,
attrs={"tokenizer": AutoTokenizer.from_pretrained(name)},
)


@@ -166,11 +168,14 @@ def convert_transformer_outputs(model, inputs_outputs, is_train):

def backprop(d_tokvecs: List[Floats2d]) -> ArgsKwargs:
# Restore entries for bos and eos markers.
shim = model.shims[0]
row = model.ops.alloc2f(1, d_tokvecs[0].shape[1])
d_tokvecs = [model.ops.xp.vstack((row, arr, row)) for arr in d_tokvecs]
return ArgsKwargs(
args=(torch_tokvecs,),
kwargs={"grad_tensors": xp2torch(model.ops.pad(d_tokvecs))},
kwargs={
"grad_tensors": xp2torch(model.ops.pad(d_tokvecs, device=shim.device))
},
)

return tokvecs, backprop
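For context, a minimal sketch (not part of the diff) of the conversion pattern this example now relies on: `xp2torch` accepts a `device` argument, so gradients produced by Thinc ops can be placed directly on the device of the wrapped PyTorch model. The array shape and the in-scope `model` are illustrative.

```python
# Illustrative only: convert a Thinc/numpy gradient array into a torch tensor
# on the device that the wrapped PyTorch model lives on.
import numpy
from thinc.api import xp2torch

d_arr = numpy.zeros((4, 768), dtype="f")           # stand-in gradient array
shim = model.shims[0]                               # PyTorchShim of a wrapped model (assumed in scope)
grad_tensor = xp2torch(d_arr, device=shim.device)   # lands on CPU, CUDA or MPS as appropriate
```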
2 changes: 1 addition & 1 deletion thinc/api.py
@@ -18,7 +18,7 @@
from .util import torch2xp, xp2torch, tensorflow2xp, xp2tensorflow, mxnet2xp, xp2mxnet
from .compat import has_cupy
from .backends import get_ops, set_current_ops, get_current_ops, use_ops
from .backends import Ops, CupyOps, NumpyOps, set_gpu_allocator
from .backends import Ops, CupyOps, MPSOps, NumpyOps, set_gpu_allocator
from .backends import use_pytorch_for_gpu_memory, use_tensorflow_for_gpu_memory

from .layers import Dropout, Embed, expand_window, HashEmbed, LayerNorm, Linear
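A small usage sketch (not from the diff) of what the new export enables: `MPSOps` can now be imported from `thinc.api` like the other backends and installed as the current ops.

```python
from thinc.api import MPSOps, set_current_ops, get_current_ops

# MPSOps falls back to NumpyOps-compatible behaviour when thinc-apple-ops is
# not installed, so constructing it is safe on any platform.
set_current_ops(MPSOps())
print(type(get_current_ops()).__name__)  # "MPSOps"
```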
9 changes: 7 additions & 2 deletions thinc/backends/__init__.py
@@ -7,10 +7,11 @@
from .ops import Ops
from .cupy_ops import CupyOps
from .numpy_ops import NumpyOps
from .mps_ops import MPSOps
from ._cupy_allocators import cupy_tensorflow_allocator, cupy_pytorch_allocator
from ._param_server import ParamServer
from ..util import assert_tensorflow_installed, assert_pytorch_installed
from ..util import is_cupy_array, set_torch_tensor_type_for_ops, require_cpu
from ..util import get_torch_default_device, is_cupy_array, require_cpu
from .. import registry
from ..compat import cupy, has_cupy

@@ -48,6 +49,10 @@ def use_pytorch_for_gpu_memory() -> None: # pragma: no cover
(or vice versa), but do not currently have an implementation for it.
"""
assert_pytorch_installed()

if get_torch_default_device().type != "cuda":
return

pools = context_pools.get()
if "pytorch" not in pools:
pools["pytorch"] = cupy.cuda.MemoryPool(allocator=cupy_pytorch_allocator)
@@ -134,7 +139,6 @@ def set_current_ops(ops: Ops) -> None:
"""Change the current backend object."""
context_ops.set(ops)
_get_thread_state().ops = ops
set_torch_tensor_type_for_ops(ops)


def contextvars_eq_thread_ops() -> bool:
@@ -170,6 +174,7 @@ def _create_thread_local(
"ParamServer",
"Ops",
"CupyOps",
"MPSOps",
"NumpyOps",
"has_cupy",
]
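Sketch of the behaviour the new guard gives `use_pytorch_for_gpu_memory` (illustrative, assumes torch is installed): routing cupy allocations through PyTorch only applies when the default torch device is CUDA, so on CPU or MPS the call now returns without installing a memory pool.

```python
from thinc.api import set_gpu_allocator
from thinc.util import get_torch_default_device

# The allocator swap is only meaningful for CUDA; elsewhere it is now a no-op.
if get_torch_default_device().type == "cuda":
    set_gpu_allocator("pytorch")
```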
7 changes: 5 additions & 2 deletions thinc/backends/_cupy_allocators.py
@@ -1,7 +1,7 @@
from typing import cast

from ..types import ArrayXd
from ..util import tensorflow2xp
from ..util import get_torch_default_device, tensorflow2xp
from ..compat import torch, cupy, tensorflow


@@ -23,6 +23,7 @@ def cupy_tensorflow_allocator(size_in_bytes: int):


def cupy_pytorch_allocator(size_in_bytes: int):
device = get_torch_default_device()
"""Function that can be passed into cupy.cuda.set_allocator, to have cupy
allocate memory via PyTorch. This is important when using the two libraries
together, as otherwise OOM errors can occur when there's available memory
@@ -34,7 +35,9 @@ def cupy_pytorch_allocator(size_in_bytes: int):
# creating a whole Tensor.
# This turns out to be way faster than making FloatStorage? Maybe
# a Python vs C++ thing I guess?
torch_tensor = torch.zeros((size_in_bytes // 4,), requires_grad=False)
torch_tensor = torch.zeros(
(size_in_bytes // 4,), requires_grad=False, device=device
)
# cupy has a neat class to help us here. Otherwise it will try to free.
# I think this is a private API? It's not in the types.
address = torch_tensor.data_ptr() # type: ignore
4 changes: 2 additions & 2 deletions thinc/backends/cupy_ops.py
@@ -6,7 +6,7 @@
from ..types import DeviceTypes
from ..util import torch2xp, tensorflow2xp, mxnet2xp
from ..util import is_cupy_array
from ..util import is_torch_gpu_array, is_tensorflow_gpu_array, is_mxnet_gpu_array
from ..util import is_torch_cuda_array, is_tensorflow_gpu_array, is_mxnet_gpu_array
from ..compat import cupy, cupyx


@@ -62,7 +62,7 @@ def asarray(self, data, dtype=None):
# We'll try to perform a zero-copy conversion if possible.
if is_cupy_array(data):
array = data
elif is_torch_gpu_array(data):
elif is_torch_cuda_array(data):
array = torch2xp(data)
elif is_tensorflow_gpu_array(data):
array = tensorflow2xp(data)
26 changes: 26 additions & 0 deletions thinc/backends/mps_ops.py
@@ -0,0 +1,26 @@
from typing import TYPE_CHECKING
import numpy

from .. import registry
from . import NumpyOps, Ops

if TYPE_CHECKING:
# Type checking does not work with dynamic base classes, since MyPy cannot
# determine against which base class to check. So, always derive from Ops
# during type checking.
_Ops = Ops
else:
try:
from thinc_apple_ops import AppleOps

_Ops = AppleOps
except ImportError:
_Ops = NumpyOps


@registry.ops("MPSOps")
class MPSOps(_Ops):
"""Ops class for Metal Performance shaders."""

name = "mps"
xp = numpy
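To illustrate the fallback logic in the new file (a sketch, not part of the diff): `MPSOps` inherits accelerated kernels from `thinc-apple-ops` when that package is installed and otherwise behaves like `NumpyOps`; either way its arrays are plain numpy arrays on the host.

```python
from thinc.backends.mps_ops import MPSOps

ops = MPSOps()
print(ops.name)            # "mps"
x = ops.alloc2f(2, 3)      # allocated with numpy, since MPSOps.xp is numpy
print(type(x).__module__)  # "numpy"
```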
13 changes: 12 additions & 1 deletion thinc/compat.py
@@ -31,7 +31,13 @@
import torch

has_torch = True
has_torch_gpu = torch.cuda.device_count() != 0
has_torch_cuda_gpu = torch.cuda.device_count() != 0
has_torch_mps_gpu = (
hasattr(torch, "has_mps")
and torch.has_mps
and torch.backends.mps.is_available()
)
has_torch_gpu = has_torch_cuda_gpu
torch_version = Version(str(torch.__version__))
has_torch_amp = (
torch_version >= Version("1.9.0")
@@ -40,7 +46,9 @@
except ImportError: # pragma: no cover
torch = None # type: ignore
has_torch = False
has_torch_cuda_gpu = False
has_torch_gpu = False
has_torch_mps_gpu = False
has_torch_amp = False
torch_version = Version("0.0.0")

@@ -68,3 +76,6 @@
import h5py
except ImportError: # pragma: no cover
h5py = None


has_gpu = has_cupy_gpu or has_torch_mps_gpu
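A hedged example of how the new capability flags compose: `has_gpu` is true when either a CUDA-capable cupy installation or a usable torch MPS backend is present.

```python
from thinc.compat import has_cupy_gpu, has_torch_mps_gpu, has_gpu

if has_gpu:
    backend = "cupy" if has_cupy_gpu else "mps"
    print(f"GPU support available via {backend}")
else:
    print("no GPU detected; NumpyOps will be used")
```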
30 changes: 24 additions & 6 deletions thinc/layers/pytorchwrapper.py
@@ -1,9 +1,10 @@
from typing import Callable, Tuple, Optional, Any, cast

from ..compat import torch
from ..model import Model
from ..shims import PyTorchGradScaler, PyTorchShim
from ..config import registry
from ..util import is_xp_array, is_torch_array
from ..util import is_xp_array, is_torch_array, partial
from ..util import xp2torch, torch2xp, convert_recursive
from ..types import Floats3d, ArgsKwargs, Padded

@@ -76,6 +77,7 @@ def PyTorchWrapper_v2(
convert_outputs: Optional[Callable] = None,
mixed_precision: bool = False,
grad_scaler: Optional[PyTorchGradScaler] = None,
device: Optional["torch.device"] = None,
) -> Model[Any, Any]:
"""Wrap a PyTorch model, so that it has the same API as Thinc models.
To optimize the model, you'll need to create a PyTorch optimizer and call
@@ -105,6 +107,10 @@ def PyTorchWrapper_v2(
The gradient scaler to use for mixed-precision training. If this
argument is set to "None" and mixed precision is enabled, a gradient
scaler with the default configuration is used.
device:
The PyTorch device to run the model on. When this argument is
set to "None", the default device for the currently active Thinc
ops is used.
"""
if convert_inputs is None:
convert_inputs = convert_pytorch_default_inputs
@@ -116,7 +122,10 @@
attrs={"convert_inputs": convert_inputs, "convert_outputs": convert_outputs},
shims=[
PyTorchShim(
pytorch_model, mixed_precision=mixed_precision, grad_scaler=grad_scaler
pytorch_model,
mixed_precision=mixed_precision,
grad_scaler=grad_scaler,
device=device,
)
],
dims={"nI": None, "nO": None},
@@ -149,7 +158,8 @@ def backprop(dY: Any) -> Any:
def convert_pytorch_default_inputs(
model: Model, X: Any, is_train: bool
) -> Tuple[ArgsKwargs, Callable[[ArgsKwargs], Any]]:
xp2torch_ = lambda x: xp2torch(x, requires_grad=is_train)
shim = cast(PyTorchShim, model.shims[0])
xp2torch_ = lambda x: xp2torch(x, requires_grad=is_train, device=shim.device)
converted = convert_recursive(is_xp_array, xp2torch_, X)
if isinstance(converted, ArgsKwargs):

@@ -181,11 +191,14 @@ def reverse_conversion(dXtorch):


def convert_pytorch_default_outputs(model: Model, X_Ytorch: Any, is_train: bool):
shim = cast(PyTorchShim, model.shims[0])
X, Ytorch = X_Ytorch
Y = convert_recursive(is_torch_array, torch2xp, Ytorch)

def reverse_conversion(dY: Any) -> ArgsKwargs:
dYtorch = convert_recursive(is_xp_array, xp2torch, dY)
dYtorch = convert_recursive(
is_xp_array, partial(xp2torch, device=shim.device), dY
)
return ArgsKwargs(args=((Ytorch,),), kwargs={"grad_tensors": dYtorch})

return Y, reverse_conversion
@@ -195,6 +208,7 @@ def reverse_conversion(dY: Any) -> ArgsKwargs:


def convert_rnn_inputs(model: Model, Xp: Padded, is_train: bool):
shim = cast(PyTorchShim, model.shims[0])
size_at_t = Xp.size_at_t
lengths = Xp.lengths
indices = Xp.indices
@@ -203,15 +217,19 @@ def convert_from_torch_backward(d_inputs: ArgsKwargs) -> Padded:
dX = torch2xp(d_inputs.args[0])
return Padded(dX, size_at_t, lengths, indices) # type: ignore

output = ArgsKwargs(args=(xp2torch(Xp.data, requires_grad=True), None), kwargs={})
output = ArgsKwargs(
args=(xp2torch(Xp.data, requires_grad=True, device=shim.device), None),
kwargs={},
)
return output, convert_from_torch_backward


def convert_rnn_outputs(model: Model, inputs_outputs: Tuple, is_train):
shim = cast(PyTorchShim, model.shims[0])
Xp, (Ytorch, _) = inputs_outputs

def convert_for_torch_backward(dYp: Padded) -> ArgsKwargs:
dYtorch = xp2torch(dYp.data, requires_grad=True)
dYtorch = xp2torch(dYp.data, requires_grad=True, device=shim.device)
return ArgsKwargs(args=(Ytorch,), kwargs={"grad_tensors": dYtorch})

Y = cast(Floats3d, torch2xp(Ytorch))
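Usage sketch for the new `device` argument (the module and device here are illustrative, not from the diff): the shim moves the wrapped module onto the requested device, and the default converters now place inputs and gradients on that same device.

```python
import numpy
import torch
from thinc.api import PyTorchWrapper_v2

# Wrap a small torch module and pin it to an explicit device.
model = PyTorchWrapper_v2(
    torch.nn.Linear(16, 8),
    device=torch.device("cpu"),  # torch.device("cuda:0") or torch.device("mps") on suitable hardware
)
X = numpy.zeros((4, 16), dtype="f")
Y, backprop = model(X, is_train=True)  # Y is converted back to an xp array
```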
40 changes: 29 additions & 11 deletions thinc/shims/pytorch.py
@@ -5,6 +5,7 @@
import srsly

from ..util import torch2xp, xp2torch, convert_recursive, iterate_recursive
from ..util import get_torch_default_device
from ..compat import torch
from ..backends import get_current_ops, context_pools, CupyOps
from ..backends import set_gpu_allocator
@@ -25,6 +26,10 @@ class PyTorchShim(Shim):
The gradient scaler to use for mixed-precision training. If this
argument is set to "None" and mixed precision is enabled, a gradient
scaler with the default configuration is used.
device:
The PyTorch device to run the model on. When this argument is
set to "None", the default device for the currently active Thinc
ops is used.
"""

def __init__(
@@ -34,12 +39,20 @@ def __init__(
optimizer: Any = None,
mixed_precision: bool = False,
grad_scaler: Optional[PyTorchGradScaler] = None,
device: Optional["torch.device"] = None,
):
super().__init__(model, config, optimizer)

if device is None:
device = get_torch_default_device()
if model is not None:
model.to(device)

if grad_scaler is None:
grad_scaler = PyTorchGradScaler(mixed_precision)

grad_scaler.to_(device)

self._grad_scaler = grad_scaler

self._mixed_precision = mixed_precision
@@ -58,6 +71,14 @@ def __call__(self, inputs, is_train):
else:
return self.predict(inputs), lambda a: ...

@property
def device(self):
p = next(self._model.parameters(), None)
if p is None:
return get_torch_default_device()
else:
return p.device

def predict(self, inputs: ArgsKwargs) -> Any:
"""Pass inputs through to the underlying PyTorch model, and return the
output. No conversions are performed. The PyTorch model is set into
@@ -126,7 +147,9 @@ def finish_update(self, optimizer: Optimizer):
cast(FloatsXd, torch2xp(torch_data.data)),
cast(FloatsXd, torch2xp(torch_data.grad)),
)
torch_data.data = xp2torch(param, requires_grad=True)
torch_data.data = xp2torch(
param, requires_grad=True, device=torch_data.device
)
torch_data.grad.zero_()

self._grad_scaler.update()
@@ -137,7 +160,7 @@ def use_params(self, params):
state_dict = {}
for k, v in params.items():
if hasattr(k, "startswith") and k.startswith(key_prefix):
state_dict[k.replace(key_prefix, "")] = xp2torch(v)
state_dict[k.replace(key_prefix, "")] = xp2torch(v, device=self.device)
if state_dict:
backup = {k: v.clone() for k, v in self._model.state_dict().items()}
self._model.load_state_dict(state_dict)
@@ -164,17 +187,12 @@ def to_bytes(self):
return srsly.msgpack_dumps(msg)

def from_bytes(self, bytes_data):
ops = get_current_ops()
device = get_torch_default_device()
msg = srsly.msgpack_loads(bytes_data)
self.cfg = msg["config"]
filelike = BytesIO(msg["state"])
filelike.seek(0)
if ops.device_type == "cpu":
map_location = "cpu"
else: # pragma: no cover
device_id = torch.cuda.current_device()
map_location = "cuda:%d" % device_id
self._model.load_state_dict(torch.load(filelike, map_location=map_location))
self._model.to(map_location)
self._grad_scaler.to_(map_location)
self._model.load_state_dict(torch.load(filelike, map_location=device))
self._model.to(device)
self._grad_scaler.to_(device)
return self
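A short sketch of what the simplified (de)serialization path above implies (the `shim` here is a hypothetical existing `PyTorchShim`): weights are restored directly onto the default torch device rather than branching on the current ops' `device_type`.

```python
from thinc.util import get_torch_default_device

data = shim.to_bytes()      # serialize an existing PyTorchShim (assumed in scope)
shim.from_bytes(data)       # weights are loaded with map_location=default device
print(shim.device, get_torch_default_device())  # both point at the restored device
```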