diff --git a/thinc/api.py b/thinc/api.py
index 7853596b7..8f5b3247e 100644
--- a/thinc/api.py
+++ b/thinc/api.py
@@ -30,8 +30,8 @@
 from .layers import Dish, HardSwish, HardSwishMobilenet, Swish, Gelu
 from .layers import PyTorchWrapper, PyTorchRNNWrapper, PyTorchLSTM
 from .layers import TensorFlowWrapper, keras_subclass, MXNetWrapper
-from .layers import PyTorchWrapper_v2, Softmax_v2, TorchScriptWrapper_v1
-from .layers import pytorch_to_torchscript_wrapper
+from .layers import PyTorchWrapper_v2, Softmax_v2, PyTorchWrapper_v3
+from .layers import TorchScriptWrapper_v1, pytorch_to_torchscript_wrapper
 from .layers import add, bidirectional, chain, clone, concatenate, noop
 from .layers import residual, uniqued, siamese, list2ragged, ragged2list
@@ -92,8 +92,8 @@
     "Dish", "HardSwish", "HardSwishMobilenet", "Swish", "Gelu",
     "PyTorchWrapper", "PyTorchRNNWrapper", "PyTorchLSTM",
     "TensorFlowWrapper", "keras_subclass", "MXNetWrapper",
-    "PyTorchWrapper_v2", "Softmax_v2", "SparseLinear_v2",
-    "TorchScriptWrapper_v1",
+    "PyTorchWrapper_v2", "Softmax_v2", "PyTorchWrapper_v3",
+    "SparseLinear_v2", "TorchScriptWrapper_v1",
     "add", "bidirectional", "chain", "clone", "concatenate", "noop",
     "residual", "uniqued", "siamese", "list2ragged", "ragged2list",
diff --git a/thinc/layers/__init__.py b/thinc/layers/__init__.py
index f1cf779f3..95c8fcc83 100644
--- a/thinc/layers/__init__.py
+++ b/thinc/layers/__init__.py
@@ -13,7 +13,7 @@
 from .mish import Mish
 from .multisoftmax import MultiSoftmax
 from .parametricattention import ParametricAttention
-from .pytorchwrapper import PyTorchWrapper, PyTorchWrapper_v2
+from .pytorchwrapper import PyTorchWrapper, PyTorchWrapper_v2, PyTorchWrapper_v3
 from .pytorchwrapper import PyTorchRNNWrapper
 from .relu import Relu
 from .clipped_linear import ClippedLinear, ReluK, HardSigmoid, HardTanh
@@ -93,6 +93,7 @@
     "PyTorchLSTM",
     "PyTorchWrapper",
     "PyTorchWrapper_v2",
+    "PyTorchWrapper_v3",
     "PyTorchRNNWrapper",
     "Relu",
     "sigmoid_activation",
diff --git a/thinc/layers/pytorchwrapper.py b/thinc/layers/pytorchwrapper.py
index 882132dcb..a1b0c462a 100644
--- a/thinc/layers/pytorchwrapper.py
+++ b/thinc/layers/pytorchwrapper.py
@@ -1,4 +1,4 @@
-from typing import Callable, Tuple, Optional, Any, cast
+from typing import Callable, Dict, Tuple, Optional, Any, cast
 from ..compat import torch
 from ..model import Model
@@ -11,7 +11,7 @@
 @registry.layers("PyTorchRNNWrapper.v1")
 def PyTorchRNNWrapper(
-    pytorch_model,
+    pytorch_model: Any,
     convert_inputs: Optional[Callable] = None,
     convert_outputs: Optional[Callable] = None,
 ) -> Model[Padded, Padded]:
@@ -32,7 +32,7 @@ def PyTorchRNNWrapper(
 @registry.layers("PyTorchWrapper.v1")
 def PyTorchWrapper(
-    pytorch_model,
+    pytorch_model: Any,
     convert_inputs: Optional[Callable] = None,
     convert_outputs: Optional[Callable] = None,
 ) -> Model[Any, Any]:
@@ -72,7 +72,7 @@ def PyTorchWrapper(
 @registry.layers("PyTorchWrapper.v2")
 def PyTorchWrapper_v2(
-    pytorch_model,
+    pytorch_model: Any,
     convert_inputs: Optional[Callable] = None,
     convert_outputs: Optional[Callable] = None,
     mixed_precision: bool = False,
@@ -132,6 +132,82 @@ def PyTorchWrapper_v2(
     )
+
+@registry.layers("PyTorchWrapper.v3")
+def PyTorchWrapper_v3(
+    pytorch_model: "torch.nn.Module",
+    convert_inputs: Optional[Callable] = None,
+    convert_outputs: Optional[Callable] = None,
+    mixed_precision: bool = False,
+    grad_scaler: Optional[PyTorchGradScaler] = None,
+    device: Optional["torch.device"] = None,
+    serialize_model: Optional[Callable[[Any], bytes]] = None,
+    deserialize_model: Optional[Callable[[Any, bytes, "torch.device"], Any]] = None,
+) -> Model[Any, Any]:
+    """Wrap a PyTorch model so that it has the same API as Thinc models.
+    To optimize the model, you'll need to create a PyTorch optimizer and call
+    optimizer.step() after each batch. See examples/wrap_pytorch.py
+
+    Your PyTorch model's forward method can take arbitrary args and kwargs,
+    but must return either a single tensor or a tuple. You may find the
+    PyTorch register_forward_hook helpful if you need to adapt the output.
+
+    The convert functions are used to map inputs and outputs to and from your
+    PyTorch model. Each function should return the converted output, and a callback
+    to use during the backward pass. So:
+
+        Xtorch, get_dX = convert_inputs(X)
+        Ytorch, torch_backprop = model.shims[0](Xtorch, is_train)
+        Y, get_dYtorch = convert_outputs(Ytorch)
+
+    To allow maximum flexibility, the PyTorchShim expects ArgsKwargs objects
+    on the way into the forward and backward passes. The ArgsKwargs objects
+    will be passed straight into the model in the forward pass, and straight
+    into `torch.autograd.backward` during the backward pass.
+
+    mixed_precision:
+        Enable mixed-precision. This changes whitelisted ops to run
+        in half precision for better performance and lower memory use.
+    grad_scaler:
+        The gradient scaler to use for mixed-precision training. If this
+        argument is set to "None" and mixed precision is enabled, a gradient
+        scaler with the default configuration is used.
+    device:
+        The PyTorch device to run the model on. When this argument is
+        set to "None", the default device for the currently active Thinc
+        ops is used.
+    serialize_model:
+        Callback that receives the wrapped PyTorch model as its argument and
+        returns a "bytes" representation of the model. The representation should
+        contain all the necessary information to fully deserialize the model.
+        When set to "None", the default serializer serializes the model's parameters.
+    deserialize_model:
+        Callback that receives the default PyTorch model (passed to the constructor), the
+        serialized "bytes" representation, and a PyTorch device. It should return a
+        fully deserialized model on the target device as its result.
+        When set to "None", the default deserializer deserializes the model's parameters.
+    """
+    if convert_inputs is None:
+        convert_inputs = convert_pytorch_default_inputs
+    if convert_outputs is None:
+        convert_outputs = convert_pytorch_default_outputs
+    return Model(
+        "pytorch",
+        forward,
+        attrs={"convert_inputs": convert_inputs, "convert_outputs": convert_outputs},
+        shims=[
+            PyTorchShim(
+                pytorch_model,
+                mixed_precision=mixed_precision,
+                grad_scaler=grad_scaler,
+                device=device,
+                serialize_model=serialize_model,
+                deserialize_model=deserialize_model,
+            )
+        ],
+        dims={"nI": None, "nO": None},
+    )
+
+
 def forward(model: Model, X: Any, is_train: bool) -> Tuple[Any, Callable]:
     """Return the output of the wrapped PyTorch model for the given input,
     along with a callback to handle the backward pass.
diff --git a/thinc/shims/pytorch.py b/thinc/shims/pytorch.py
index 81a2fe11f..84d4c08e5 100644
--- a/thinc/shims/pytorch.py
+++ b/thinc/shims/pytorch.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, cast
+from typing import Any, Dict, Optional, cast, Callable
 import contextlib
 from io import BytesIO
 import itertools
@@ -30,6 +30,14 @@ class PyTorchShim(Shim):
         The PyTorch device to run the model on. When this argument is
         set to "None", the default device for the currently active Thinc
         ops is used.
+    serialize_model:
+        Callback that receives the wrapped PyTorch model as its argument and
+        returns a "bytes" representation of the model. The representation should
+        contain all the necessary information to fully deserialize the model.
+    deserialize_model:
+        Callback that receives the default PyTorch model (passed to the constructor), the
+        serialized "bytes" representation, and a PyTorch device. It should return a
+        fully deserialized model on the target device as its result.
     """

     def __init__(
@@ -40,6 +48,8 @@ def __init__(
         model: Any,
         config=None,
         optimizer: Any = None,
         mixed_precision: bool = False,
         grad_scaler: Optional[PyTorchGradScaler] = None,
         device: Optional["torch.device"] = None,
+        serialize_model: Optional[Callable[[Any], bytes]] = None,
+        deserialize_model: Optional[Callable[[Any, bytes, "torch.device"], Any]] = None,
     ):
         super().__init__(model, config, optimizer)
@@ -54,9 +64,11 @@ def __init__(
         grad_scaler.to_(device)
         self._grad_scaler = grad_scaler
         self._mixed_precision = mixed_precision
+        self._serialize_model = serialize_model if serialize_model is not None else default_serialize_torch_model
+        self._deserialize_model = deserialize_model if deserialize_model is not None else default_deserialize_torch_model
+
         if CupyOps.xp is not None and isinstance(get_current_ops(), CupyOps):
             pools = context_pools.get()
             if "pytorch" not in pools:
@@ -179,20 +191,52 @@ def to_device(self, device_type: str, device_id: int):  # pragma: no cover
             raise ValueError(msg)
 
     def to_bytes(self):
-        filelike = BytesIO()
-        torch.save(self._model.state_dict(), filelike)
-        filelike.seek(0)
-        weights_bytes = filelike.getvalue()
-        msg = {"config": self.cfg, "state": weights_bytes}
+        model_bytes = self._serialize_model(self._model)
+        msg = {"config": self.cfg, "state": model_bytes}
         return srsly.msgpack_dumps(msg)
 
     def from_bytes(self, bytes_data):
         device = get_torch_default_device()
         msg = srsly.msgpack_loads(bytes_data)
         self.cfg = msg["config"]
-        filelike = BytesIO(msg["state"])
-        filelike.seek(0)
-        self._model.load_state_dict(torch.load(filelike, map_location=device))
-        self._model.to(device)
+        self._model = self._deserialize_model(self._model, msg["state"], device)
         self._grad_scaler.to_(device)
         return self
+
+
+def default_serialize_torch_model(model: Any) -> bytes:
+    """Serializes the parameters of the wrapped PyTorch model to bytes.
+
+    model:
+        Wrapped PyTorch model.
+
+    Returns:
+        A `bytes` object that encapsulates the serialized model parameters.
+    """
+    filelike = BytesIO()
+    torch.save(model.state_dict(), filelike)
+    filelike.seek(0)
+    return filelike.getvalue()
+
+
+def default_deserialize_torch_model(
+    model: Any, state_bytes: bytes, device: "torch.device"
+) -> Any:
+    """Deserializes the parameters of the wrapped PyTorch model and
+    moves it to the specified device.
+
+    model:
+        Wrapped PyTorch model.
+    state_bytes:
+        Serialized parameters as a byte stream.
+    device:
+        PyTorch device to which the model is bound.
+
+    Returns:
+        The deserialized model.
+ """ + filelike = BytesIO(state_bytes) + filelike.seek(0) + model.load_state_dict(torch.load(filelike, map_location=device)) + model.to(device) + return model diff --git a/thinc/tests/layers/test_pytorch_wrapper.py b/thinc/tests/layers/test_pytorch_wrapper.py index d2eeaeb97..1e3f928e9 100644 --- a/thinc/tests/layers/test_pytorch_wrapper.py +++ b/thinc/tests/layers/test_pytorch_wrapper.py @@ -1,9 +1,11 @@ -from thinc.api import Linear, SGD, PyTorchWrapper, PyTorchWrapper_v2 +from thinc.api import Linear, SGD, PyTorchWrapper, PyTorchWrapper_v2, PyTorchWrapper_v3 from thinc.api import xp2torch, torch2xp, ArgsKwargs, use_ops from thinc.api import chain, get_current_ops, Relu from thinc.api import CupyOps, MPSOps, NumpyOps from thinc.backends import context_pools +from thinc.layers.pytorchwrapper import PyTorchWrapper_v3 from thinc.shims.pytorch_grad_scaler import PyTorchGradScaler +from thinc.shims.pytorch import default_deserialize_torch_model, default_serialize_torch_model from thinc.compat import has_torch, has_torch_amp from thinc.compat import has_cupy_gpu, has_torch_mps_gpu import numpy @@ -169,3 +171,30 @@ def test_pytorch_convert_inputs(data, n_args, kwargs_keys): convert_inputs = model.attrs["convert_inputs"] Y, backprop = convert_inputs(model, data, is_train=True) check_input_converters(Y, backprop, data, n_args, kwargs_keys, torch.Tensor) + + +@pytest.mark.skipif(not has_torch, reason="needs PyTorch") +def test_pytorch_wrapper_custom_serde(): + import torch.nn + + def serialize(model): + return default_serialize_torch_model(model) + + def deserialize(model, state_bytes, device): + return default_deserialize_torch_model(model, state_bytes, device) + + def get_model(): + return PyTorchWrapper_v3( + torch.nn.Linear(2, 3), + serialize_model=serialize, + deserialize_model=deserialize, + ) + + model = get_model() + model_bytes = model.to_bytes() + get_model().from_bytes(model_bytes) + with make_tempdir() as path: + model_path = path / "model" + model.to_disk(model_path) + new_model = get_model().from_bytes(model_bytes) + new_model.from_disk(model_path)