PyTorchShim: Add serde callbacks to facilitate lazy loading models #796

Merged 11 commits on Dec 8, 2022
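
This PR adds `serialize_model`/`deserialize_model` callbacks to `PyTorchShim` and exposes them through a new `PyTorchWrapper_v3`, so callers can control how the wrapped PyTorch model is turned into bytes and back — for example, to defer loading weights from an external checkpoint. A minimal sketch of that lazy-loading use case (the checkpoint path and the placeholder payload are hypothetical, for illustration only):

```python
from thinc.api import PyTorchWrapper_v3
import torch
import torch.nn


def serialize(model) -> bytes:
    # Don't embed the weights in the serialized model; they live in an
    # external checkpoint, so a small marker payload is enough here.
    return b"weights-in-external-checkpoint"


def deserialize(model, state_bytes: bytes, device):
    # Load the weights only when the model is actually deserialized,
    # directly onto the target device. "checkpoint.pt" is a hypothetical path.
    model.load_state_dict(torch.load("checkpoint.pt", map_location=device))
    return model.to(device)


wrapped = PyTorchWrapper_v3(
    torch.nn.Linear(2, 3),
    serialize_model=serialize,
    deserialize_model=deserialize,
)
```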
4 changes: 2 additions & 2 deletions thinc/api.py
@@ -30,7 +30,7 @@
from .layers import Dish, HardSwish, HardSwishMobilenet, Swish, Gelu
from .layers import PyTorchWrapper, PyTorchRNNWrapper, PyTorchLSTM
from .layers import TensorFlowWrapper, keras_subclass, MXNetWrapper
from .layers import PyTorchWrapper_v2, Softmax_v2
from .layers import PyTorchWrapper_v2, Softmax_v2, PyTorchWrapper_v3

from .layers import add, bidirectional, chain, clone, concatenate, noop
from .layers import residual, uniqued, siamese, list2ragged, ragged2list
@@ -91,7 +91,7 @@
"Dish", "HardSwish", "HardSwishMobilenet", "Swish", "Gelu",
"PyTorchWrapper", "PyTorchRNNWrapper", "PyTorchLSTM",
"TensorFlowWrapper", "keras_subclass", "MXNetWrapper",
"PyTorchWrapper_v2", "Softmax_v2",
"PyTorchWrapper_v2", "Softmax_v2", "PyTorchWrapper_v3",

"add", "bidirectional", "chain", "clone", "concatenate", "noop",
"residual", "uniqued", "siamese", "list2ragged", "ragged2list",
3 changes: 2 additions & 1 deletion thinc/layers/__init__.py
@@ -13,7 +13,7 @@
from .mish import Mish
from .multisoftmax import MultiSoftmax
from .parametricattention import ParametricAttention
from .pytorchwrapper import PyTorchWrapper, PyTorchWrapper_v2
from .pytorchwrapper import PyTorchWrapper, PyTorchWrapper_v2, PyTorchWrapper_v3
from .pytorchwrapper import PyTorchRNNWrapper
from .relu import Relu
from .clipped_linear import ClippedLinear, ReluK, HardSigmoid, HardTanh
@@ -92,6 +92,7 @@
"PyTorchLSTM",
"PyTorchWrapper",
"PyTorchWrapper_v2",
"PyTorchWrapper_v3",
"PyTorchRNNWrapper",
"Relu",
"sigmoid_activation",
80 changes: 79 additions & 1 deletion thinc/layers/pytorchwrapper.py
@@ -1,4 +1,4 @@
from typing import Callable, Tuple, Optional, Any, cast
from typing import Callable, Dict, Tuple, Optional, Any, cast

from ..compat import torch
from ..model import Model
@@ -132,6 +132,84 @@ def PyTorchWrapper_v2(
)


@registry.layers("PyTorchWrapper.v3")
def PyTorchWrapper_v3(
pytorch_model,
convert_inputs: Optional[Callable] = None,
convert_outputs: Optional[Callable] = None,
mixed_precision: bool = False,
grad_scaler: Optional[PyTorchGradScaler] = None,
device: Optional["torch.device"] = None,
serialize_model: Optional[Callable[[Any], bytes]] = None,
deserialize_model: Optional[Callable[[Any, bytes, "torch.device"], Any]] = None,
) -> Model[Any, Any]:
"""Wrap a PyTorch model, so that it has the same API as Thinc models.
To optimize the model, you'll need to create a PyTorch optimizer and call
optimizer.step() after each batch. See examples/wrap_pytorch.py

Your PyTorch model's forward method can take arbitrary args and kwargs,
but must return either a single tensor as output or a tuple. You may find the
PyTorch register_forward_hook helpful if you need to adapt the output.

The convert functions are used to map inputs and outputs to and from your
PyTorch model. Each function should return the converted output, and a callback
to use during the backward pass. So:

Xtorch, get_dX = convert_inputs(X)
Ytorch, torch_backprop = model.shims[0](Xtorch, is_train)
Y, get_dYtorch = convert_outputs(Ytorch)

To allow maximum flexibility, the PyTorchShim expects ArgsKwargs objects
    on the way into the forward and backward passes. The ArgsKwargs objects
will be passed straight into the model in the forward pass, and straight
into `torch.autograd.backward` during the backward pass.

mixed_precision:
Enable mixed-precision. This changes whitelisted ops to run
in half precision for better performance and lower memory use.
grad_scaler:
The gradient scaler to use for mixed-precision training. If this
argument is set to "None" and mixed precision is enabled, a gradient
scaler with the default configuration is used.
device:
The PyTorch device to run the model on. When this argument is
set to "None", the default device for the currently active Thinc
ops is used.
serialize_model:
Callback that receives the wrapped PyTorch model as its argument and
returns a "bytes" representation of the same. The representation should
contain all the necessary information to fully deserialize the model.

When set to "None", the default serializer serializes the model's parameters.
deserialize_model:
Callback that receives the default PyTorch model (passed to the constructor), the
serialized "bytes" representation and a PyTorch device. It should return a
fully deserialized model on the target device as its result.

        When set to "None", the default deserializer deserializes the model's parameters.
"""
if convert_inputs is None:
convert_inputs = convert_pytorch_default_inputs
if convert_outputs is None:
convert_outputs = convert_pytorch_default_outputs
return Model(
"pytorch",
forward,
attrs={"convert_inputs": convert_inputs, "convert_outputs": convert_outputs},
shims=[
PyTorchShim(
pytorch_model,
mixed_precision=mixed_precision,
grad_scaler=grad_scaler,
device=device,
serialize_model=serialize_model,
deserialize_model=deserialize_model,
)
],
dims={"nI": None, "nO": None},
)


def forward(model: Model, X: Any, is_train: bool) -> Tuple[Any, Callable]:
"""Return the output of the wrapped PyTorch model for the given input,
along with a callback to handle the backward pass.
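
Both callbacks default to None, in which case the wrapper keeps the previous behavior of round-tripping the model's state_dict. A small sketch of that default round-trip:

```python
from thinc.api import PyTorchWrapper_v3
import torch.nn

model = PyTorchWrapper_v3(torch.nn.Linear(2, 3))
data = model.to_bytes()  # default serializer: msgpack message with config + state_dict

# from_bytes hands the stored state to the default deserializer, which
# restores the parameters onto the default Thinc device.
fresh = PyTorchWrapper_v3(torch.nn.Linear(2, 3)).from_bytes(data)
```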
49 changes: 38 additions & 11 deletions thinc/shims/pytorch.py
@@ -1,4 +1,4 @@
from typing import Any, Optional, cast
from typing import Any, Dict, Optional, cast, Callable
import contextlib
from io import BytesIO
import itertools
@@ -30,6 +30,14 @@ class PyTorchShim(Shim):
The PyTorch device to run the model on. When this argument is
set to "None", the default device for the currently active Thinc
ops is used.
serialize_model:
Callback that receives the wrapped PyTorch model as its argument and
returns a "bytes" representation of the same. The representation should
contain all the necessary information to fully deserialize the model.
deserialize_model:
Callback that receives the default PyTorch model (passed to the constructor), the
serialized "bytes" representation and a PyTorch device. It should return a
fully deserialized model on the target device as its result.
"""

def __init__(
@@ -40,6 +48,8 @@ def __init__(
mixed_precision: bool = False,
grad_scaler: Optional[PyTorchGradScaler] = None,
device: Optional["torch.device"] = None,
serialize_model: Optional[Callable[[Any], bytes]] = None,
deserialize_model: Optional[Callable[[Any, bytes, "torch.device"], Any]] = None,
):
super().__init__(model, config, optimizer)

@@ -54,9 +64,15 @@ def __init__(
grad_scaler.to_(device)

self._grad_scaler = grad_scaler

self._mixed_precision = mixed_precision

if serialize_model is None:
serialize_model = default_serialize_torch_model
if deserialize_model is None:
deserialize_model = default_deserialize_torch_model
self._serialize_model = serialize_model
self._deserialize_model = deserialize_model

if CupyOps.xp is not None and isinstance(get_current_ops(), CupyOps):
pools = context_pools.get()
if "pytorch" not in pools:
@@ -179,20 +195,31 @@ def to_device(self, device_type: str, device_id: int):  # pragma: no cover
raise ValueError(msg)

def to_bytes(self):
filelike = BytesIO()
torch.save(self._model.state_dict(), filelike)
filelike.seek(0)
weights_bytes = filelike.getvalue()
msg = {"config": self.cfg, "state": weights_bytes}
model_bytes = self._serialize_model(self._model)
msg = {"config": self.cfg, "state": model_bytes}
return srsly.msgpack_dumps(msg)

def from_bytes(self, bytes_data):
device = get_torch_default_device()
msg = srsly.msgpack_loads(bytes_data)
self.cfg = msg["config"]
filelike = BytesIO(msg["state"])
filelike.seek(0)
self._model.load_state_dict(torch.load(filelike, map_location=device))
self._model.to(device)
self._model = self._deserialize_model(self._model, msg["state"], device)
self._grad_scaler.to_(device)
return self


def default_serialize_torch_model(model: Any) -> bytes:
filelike = BytesIO()
torch.save(model.state_dict(), filelike)
filelike.seek(0)
return filelike.getvalue()


def default_deserialize_torch_model(
model: Any, state_bytes: bytes, device: "torch.device"
) -> Any:
filelike = BytesIO(state_bytes)
filelike.seek(0)
model.load_state_dict(torch.load(filelike, map_location=device))
model.to(device)
return model
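
Because the defaults are exposed as default_serialize_torch_model and default_deserialize_torch_model, custom callbacks can also wrap them rather than replace them. A sketch that compresses the default payload (zlib compression is an assumption here, not part of this PR):

```python
import zlib

from thinc.shims.pytorch import (
    default_deserialize_torch_model,
    default_serialize_torch_model,
)


def compressed_serialize(model) -> bytes:
    # Compress the default state_dict payload before it is embedded
    # in the shim's msgpack message.
    return zlib.compress(default_serialize_torch_model(model))


def compressed_deserialize(model, state_bytes: bytes, device):
    # Undo the compression, then defer to the default state_dict loader.
    return default_deserialize_torch_model(model, zlib.decompress(state_bytes), device)
```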
34 changes: 33 additions & 1 deletion thinc/tests/layers/test_pytorch_wrapper.py
@@ -1,9 +1,14 @@
from thinc.api import Linear, SGD, PyTorchWrapper, PyTorchWrapper_v2
from thinc.api import Linear, SGD, PyTorchWrapper, PyTorchWrapper_v2, PyTorchWrapper_v3
from thinc.api import xp2torch, torch2xp, ArgsKwargs, use_ops
from thinc.api import chain, get_current_ops, Relu
from thinc.api import CupyOps, MPSOps, NumpyOps
from thinc.backends import context_pools
from thinc.layers.pytorchwrapper import PyTorchWrapper_v3
from thinc.shims.pytorch_grad_scaler import PyTorchGradScaler
from thinc.shims.pytorch import (
default_deserialize_torch_model,
default_serialize_torch_model,
)
from thinc.compat import has_torch, has_torch_amp
from thinc.compat import has_cupy_gpu, has_torch_mps_gpu
import numpy
@@ -169,3 +174,30 @@ def test_pytorch_convert_inputs(data, n_args, kwargs_keys):
convert_inputs = model.attrs["convert_inputs"]
Y, backprop = convert_inputs(model, data, is_train=True)
check_input_converters(Y, backprop, data, n_args, kwargs_keys, torch.Tensor)


@pytest.mark.skipif(not has_torch, reason="needs PyTorch")
def test_pytorch_wrapper_custom_serde():
import torch.nn

def serialize(model):
return default_serialize_torch_model(model)

def deserialize(model, state_bytes, device):
return default_deserialize_torch_model(model, state_bytes, device)

def get_model():
return PyTorchWrapper_v3(
torch.nn.Linear(2, 3),
serialize_model=serialize,
deserialize_model=deserialize,
)

model = get_model()
model_bytes = model.to_bytes()
get_model().from_bytes(model_bytes)
with make_tempdir() as path:
model_path = path / "model"
model.to_disk(model_path)
new_model = get_model().from_bytes(model_bytes)
new_model.from_disk(model_path)