Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TorchScriptWrapper_v1 #802

Merged
merged 3 commits into from
Nov 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions thinc/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .model import Model, serialize_attr, deserialize_attr
from .model import set_dropout_rate, change_attr_values, wrap_model_recursive
from .shims import Shim, PyTorchGradScaler, PyTorchShim, TensorFlowShim, keras_model_fns
from .shims import MXNetShim, maybe_handshake_model
from .shims import MXNetShim, TorchScriptShim, maybe_handshake_model
from .optimizers import Adam, RAdam, SGD, Optimizer
from .schedules import cyclic_triangular, warmup_linear, constant, constant_then
from .schedules import decaying, slanted_triangular, compounding
Expand All @@ -30,7 +30,8 @@
from .layers import Dish, HardSwish, HardSwishMobilenet, Swish, Gelu
from .layers import PyTorchWrapper, PyTorchRNNWrapper, PyTorchLSTM
from .layers import TensorFlowWrapper, keras_subclass, MXNetWrapper
from .layers import PyTorchWrapper_v2, Softmax_v2
from .layers import PyTorchWrapper_v2, Softmax_v2, TorchScriptWrapper_v1
from .layers import pytorch_to_torchscript_wrapper

from .layers import add, bidirectional, chain, clone, concatenate, noop
from .layers import residual, uniqued, siamese, list2ragged, ragged2list
Expand Down Expand Up @@ -61,7 +62,7 @@
"set_dropout_rate", "change_attr_values", "wrap_model_recursive",
# .shims
"Shim", "PyTorchGradScaler", "PyTorchShim", "TensorFlowShim", "keras_model_fns",
"MXNetShim", "maybe_handshake_model",
"MXNetShim", "TorchScriptShim", "maybe_handshake_model",
# .optimizers
"Adam", "RAdam", "SGD", "Optimizer",
# .schedules
Expand Down Expand Up @@ -92,6 +93,7 @@
"PyTorchWrapper", "PyTorchRNNWrapper", "PyTorchLSTM",
"TensorFlowWrapper", "keras_subclass", "MXNetWrapper",
"PyTorchWrapper_v2", "Softmax_v2", "SparseLinear_v2",
"TorchScriptWrapper_v1",

"add", "bidirectional", "chain", "clone", "concatenate", "noop",
"residual", "uniqued", "siamese", "list2ragged", "ragged2list",
Expand All @@ -103,6 +105,7 @@
"array_getitem", "with_cpu", "with_debug", "with_nvtx_range",
"with_signpost_interval",
"tuplify",
"pytorch_to_torchscript_wrapper",

"reduce_first", "reduce_last", "reduce_max", "reduce_mean", "reduce_sum",
]
Expand Down
3 changes: 3 additions & 0 deletions thinc/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .softmax import Softmax, Softmax_v2
from .sparselinear import SparseLinear, SparseLinear_v2
from .tensorflowwrapper import TensorFlowWrapper, keras_subclass
from .torchscriptwrapper import TorchScriptWrapper_v1, pytorch_to_torchscript_wrapper
from .mxnetwrapper import MXNetWrapper

# Combinators
Expand Down Expand Up @@ -102,6 +103,7 @@
"SparseLinear",
"SparseLinear_v2",
"TensorFlowWrapper",
"TorchScriptWrapper_v1",
"add",
"bidirectional",
"chain",
Expand Down Expand Up @@ -154,5 +156,6 @@
"strings2arrays",
"array_getitem",
"tuplify",
"pytorch_to_torchscript_wrapper",
]
# fmt: on
90 changes: 90 additions & 0 deletions thinc/layers/torchscriptwrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from typing import Any, Callable, Optional

from ..compat import torch
from ..model import Model
from ..shims import PyTorchGradScaler, PyTorchShim, TorchScriptShim
from .pytorchwrapper import forward, convert_pytorch_default_inputs
from .pytorchwrapper import convert_pytorch_default_outputs


def TorchScriptWrapper_v1(
torchscript_model: Optional["torch.jit.ScriptModule"] = None,
convert_inputs: Optional[Callable] = None,
convert_outputs: Optional[Callable] = None,
mixed_precision: bool = False,
grad_scaler: Optional[PyTorchGradScaler] = None,
device: Optional["torch.device"] = None,
) -> Model[Any, Any]:
"""Wrap a TorchScript model, so that it has the same API as Thinc models.

torchscript_model:
The TorchScript module. A value of `None` is also possible to
construct a shim to deserialize into.
convert_inputs:
Function that converts inputs and gradients that should be passed
to the model to Torch tensors.
convert_outputs:
Function that converts model outputs and gradients from Torch tensors
Thinc arrays.
mixed_precision:
danieldk marked this conversation as resolved.
Show resolved Hide resolved
Enable mixed-precision. This changes whitelisted ops to run
in half precision for better performance and lower memory use.
grad_scaler:
The gradient scaler to use for mixed-precision training. If this
argument is set to "None" and mixed precision is enabled, a gradient
scaler with the default configuration is used.
device:
The PyTorch device to run the model on. When this argument is
set to "None", the default device for the currently active Thinc
ops is used.
"""

if convert_inputs is None:
convert_inputs = convert_pytorch_default_inputs
if convert_outputs is None:
convert_outputs = convert_pytorch_default_outputs

return Model(
"pytorch_script",
forward,
attrs={"convert_inputs": convert_inputs, "convert_outputs": convert_outputs},
shims=[
TorchScriptShim(
model=torchscript_model,
mixed_precision=mixed_precision,
grad_scaler=grad_scaler,
device=device,
)
],
dims={"nI": None, "nO": None},
)


def pytorch_to_torchscript_wrapper(model: Model):
"""Convert a PyTorch wrapper to a TorchScript wrapper. The embedded PyTorch
`Module` is converted to `ScriptModule`.
"""
shim = model.shims[0]
if not isinstance(shim, PyTorchShim):
raise ValueError("Expected PyTorchShim when converting a PyTorch wrapper")

convert_inputs = model.attrs["convert_inputs"]
convert_outputs = model.attrs["convert_outputs"]

pytorch_model = shim._model
if not isinstance(pytorch_model, torch.nn.Module):
raise ValueError("PyTorchShim does not wrap a PyTorch module")

torchscript_model = torch.jit.script(pytorch_model)
grad_scaler = shim._grad_scaler
mixed_precision = shim._mixed_precision
device = shim.device

return TorchScriptWrapper_v1(
torchscript_model,
convert_inputs=convert_inputs,
convert_outputs=convert_outputs,
mixed_precision=mixed_precision,
grad_scaler=grad_scaler,
device=device,
)
10 changes: 7 additions & 3 deletions thinc/shims/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,19 @@
from .pytorch import PyTorchShim
from .pytorch_grad_scaler import PyTorchGradScaler
from .tensorflow import keras_model_fns, TensorFlowShim, maybe_handshake_model
from .torchscript import TorchScriptShim
from .mxnet import MXNetShim


# fmt: off
__all__ = [
"Shim",
"MXNetShim",
"PyTorchShim",
"PyTorchGradScaler",
"keras_model_fns", "TensorFlowShim", "maybe_handshake_model",
"MXNetShim",
"Shim",
"TensorFlowShim",
"TorchScriptShim",
"maybe_handshake_model",
"keras_model_fns",
]
# fmt: on
63 changes: 63 additions & 0 deletions thinc/shims/torchscript.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from typing import Any, Optional
from io import BytesIO
import srsly

from ..compat import torch
from ..util import get_torch_default_device
from .pytorch import PyTorchShim
from .pytorch_grad_scaler import PyTorchGradScaler


class TorchScriptShim(PyTorchShim):
"""A Thinc shim that wraps a TorchScript module.

model:
The TorchScript module. A value of `None` is also possible to
construct a shim to deserialize into.
mixed_precision:
Enable mixed-precision. This changes whitelisted ops to run
in half precision for better performance and lower memory use.
grad_scaler:
The gradient scaler to use for mixed-precision training. If this
argument is set to "None" and mixed precision is enabled, a gradient
scaler with the default configuration is used.
device:
The PyTorch device to run the model on. When this argument is
set to "None", the default device for the currently active Thinc
ops is used.
"""
svlandeg marked this conversation as resolved.
Show resolved Hide resolved

def __init__(
self,
model: Optional["torch.ScriptModule"],
config=None,
optimizer: Any = None,
mixed_precision: bool = False,
grad_scaler: Optional[PyTorchGradScaler] = None,
device: Optional["torch.device"] = None,
):
if model is not None and not isinstance(model, torch.jit.ScriptModule):
raise ValueError(
"PyTorchScriptShim must be initialized with ScriptModule or None (for deserialization)"
)

super().__init__(model, config, optimizer, mixed_precision, grad_scaler, device)

def to_bytes(self):
filelike = BytesIO()
torch.jit.save(self._model, filelike)
filelike.seek(0)
model_bytes = filelike.getvalue()
msg = {"config": self.cfg, "model": model_bytes}
return srsly.msgpack_dumps(msg)

def from_bytes(self, bytes_data):
device = get_torch_default_device()
msg = srsly.msgpack_loads(bytes_data)
self.cfg = msg["config"]
filelike = BytesIO(msg["model"])
filelike.seek(0)
self._model = torch.jit.load(filelike, map_location=device)
self._model.to(device)
self._grad_scaler.to_(device)
return self
25 changes: 25 additions & 0 deletions thinc/tests/layers/test_torchscriptwrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import pytest
import numpy

from thinc.api import PyTorchWrapper_v2, TorchScriptWrapper_v1
from thinc.api import pytorch_to_torchscript_wrapper
from thinc.compat import has_torch, torch


@pytest.mark.skipif(not has_torch, reason="needs PyTorch")
@pytest.mark.parametrize("nN,nI,nO", [(2, 3, 4)])
def test_pytorch_script(nN, nI, nO):

model = PyTorchWrapper_v2(torch.nn.Linear(nI, nO)).initialize()
script_model = pytorch_to_torchscript_wrapper(model)

X = numpy.random.randn(nN, nI).astype("f")
Y = model.predict(X)
Y_script = script_model.predict(X)
numpy.testing.assert_allclose(Y, Y_script)

serialized = script_model.to_bytes()
script_model2 = TorchScriptWrapper_v1()
script_model2.from_bytes(serialized)

numpy.testing.assert_allclose(Y, script_model2.predict(X))
60 changes: 60 additions & 0 deletions website/docs/api-layers.md
Original file line number Diff line number Diff line change
Expand Up @@ -1695,6 +1695,66 @@ the backward pass.
https://github.com/explosion/thinc/blob/master/thinc/layers/pytorchwrapper.py
```

### TorchScriptWrapper_v1 {#torchscriptwrapper tag="function" new="8.1.6"}

<inline-list>

- **Input:** <tt>Any</tt>
- **Output:** <tt>Any</tt>

</inline-list>

Wrap a [TorchScript](https://pytorch.org/docs/stable/jit.html) model so that it
has the same API as Thinc models. To optimize the model, you'll need to create a
PyTorch optimizer and call `optimizer.step` after each batch.

Your TorchScript model's forward method can take arbitrary positional arguments
and keyword arguments, but must return either a **single tensor** as output or a
**tuple**. The convert functions are used to map inputs and outputs to and from
your TorchScript model. Each function should return the converted output, and a
callback to use during the backward pass:

```python
Xtorch, get_dX = convert_inputs(X)
Ytorch, torch_backprop = model.shims[0](Xtorch, is_train)
Y, get_dYtorch = convert_outputs(Ytorch)
```

To allow maximum flexibility, the [`TorchScriptShim`](/docs/api-model#shims)
expects [`ArgsKwargs`](/docs/api-types#argskwargs) objects on the way into the
forward and backward passes. The `ArgsKwargs` objects will be passed straight
into the model in the forward pass, and straight into `torch.autograd.backward`
during the backward pass.

Note that the `torchscript_model` argument can be `None`. This is useful for
deserialization since serialized TorchScript contains both the model and its
weights.

A PyTorch wrapper can be converted to a TorchScript wrapper using the
`pytorch_to_torchscript_wrapper` function:

```python
from thinc.api import PyTorchWrapper_v2, pytorch_to_torchscript_wrapper
import torch

model = PyTorchWrapper_v2(torch.nn.Linear(nI, nO)).initialize()
script_model = pytorch_to_torchscript_wrapper(model)
```

| Argument | Type | Description |
| ------------------- | ----------------------------------------- | ---------------------------------------------------------------------------------------- |
| `torchscript_model` | <tt>Optional[torch.jit.ScriptModule]</tt> | The TorchScript model. |
| `convert_inputs` | <tt>Callable</tt> | Function to convert inputs to PyTorch tensors (same signature as `forward` function). |
| `convert_outputs` | <tt>Callable</tt> | Function to convert outputs from PyTorch tensors (same signature as `forward` function). |
| `mixed_precision` | <tt>bool</tt> | Enable mixed-precision training. |
| `grad_scaler` | <tt>Optional[PyTorchGradScaler]</tt> | Gradient scaler to use during mixed-precision training. |
| `device` | <tt>Optional[torch.Device]</tt> | The Torch device to execute the model on. |
| **RETURNS** | <tt>Model[Any, Any]</tt> | The Thinc model. |

```python
https://github.com/explosion/thinc/blob/master/thinc/layers/torchscriptwrapper.py
```

### TensorFlowWrapper {#tensorflowwrapper tag="function"}

<inline-list>
Expand Down
11 changes: 6 additions & 5 deletions website/docs/api-model.md
Original file line number Diff line number Diff line change
Expand Up @@ -812,8 +812,9 @@ A shim container is **not** a Thinc `Model` subclass itself, it's a subclass of

</infobox>

| | |
| ---------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `PyTorchShim` | Interface between a [PyTorch](https://pytorch.org) model and a Thinc `Model`. For more details and examples, see the [`PyTorchWrapper` layer](/docs/api-layers#pytorchwrapper) and docs on [integrating other frameworks](/docs/usage-frameworks). |
| `TensorFlowShim` | Interface between a [TensorFlow](https://tensorflow.org) model and a Thinc `Model`. For more details, see the [`TensorFlowWrapper` layer](/docs/api-layers#tensorflowwrapper) and docs on [integrating other frameworks](/docs/usage-frameworks) |
| `MXNetShim` | Interface between a [MXNet](https://mxnet.apache.org/) model and a Thinc `Model`. For more details, see the [`MXNetWrapper` layer](/docs/api-layers#mxnetwrapper) and docs on [integrating other frameworks](/docs/usage-frameworks) |
| | |
| ----------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `PyTorchShim` | Interface between a [PyTorch](https://pytorch.org) model and a Thinc `Model`. For more details and examples, see the [`PyTorchWrapper` layer](/docs/api-layers#pytorchwrapper) and docs on [integrating other frameworks](/docs/usage-frameworks). |
| `TorchScriptShim` | Interface between a [TorchScript](https://pytorch.org/docs/stable/jit.html) model and a Thinc `Model`. For more details and examples, see the [`TorchScriptWrapper` layer](/docs/api-layers#torchscriptwrapper) and docs on [integrating other frameworks](/docs/usage-frameworks). |
| `TensorFlowShim` | Interface between a [TensorFlow](https://tensorflow.org) model and a Thinc `Model`. For more details, see the [`TensorFlowWrapper` layer](/docs/api-layers#tensorflowwrapper) and docs on [integrating other frameworks](/docs/usage-frameworks) |
| `MXNetShim` | Interface between a [MXNet](https://mxnet.apache.org/) model and a Thinc `Model`. For more details, see the [`MXNetWrapper` layer](/docs/api-layers#mxnetwrapper) and docs on [integrating other frameworks](/docs/usage-frameworks) |
Loading