Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Jit support (#389)
Browse files Browse the repository at this point in the history
* Add support for jit script

* Add jit test to image classification

* Add jit for object detection

* Add style transfer jit

* Add jit support for embedding

* Add jit for tabular and template

* init

* Add jit for text classification

* Add video classification jit

* Add seq2seq jit

* Fixes

* Add jit support matrix

* Update CHANGELOG.md
  • Loading branch information
ethanwharris authored Jun 10, 2021
1 parent a047330 commit b49bf04
Show file tree
Hide file tree
Showing 27 changed files with 358 additions and 24 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for `torch.jit` to tasks where possible and documented task JIT compatibility ([#389](https://github.com/PyTorchLightning/lightning-flash/pull/389))

### Changed

Expand Down
59 changes: 59 additions & 0 deletions docs/source/general/jit.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#######################
TorchScript JIT Support
#######################

.. _jit:

We test all of our tasks for compatibility with :mod:`torch.jit`.
This table gives a breakdown of the supported features.

.. list-table::
:widths: 25 25 25 25
:header-rows: 1

* - Task
- :func:`torch.jit.script`
- :func:`torch.jit.trace`
- :func:`torch.jit.save`
* - :class:`~flash.image.classification.model.ImageClassifier`
- Yes
- Yes
- Yes
* - :class:`~flash.image.detection.model.ObjectDetector`
- Yes
- No
- Yes
* - :class:`~flash.image.embedding.model.ImageEmbedder`
- Yes
- Yes
- Yes
* - :class:`~flash.image.segmentation.model.SemanticSegmentation`
- Yes
- Yes
- Yes
* - :class:`~flash.image.style_transfer.model.StyleTransfer`
- No
- Yes
- Yes
* - :class:`~flash.tabular.classification.model.TabularClassifier`
- No
- Yes
- No
* - :class:`~flash.text.classification.model.TextClassifier`
- No
- Yes :sup:`*`
- Yes
* - :class:`~flash.text.seq2seq.summarization.model.SummarizationTask`
- No
- Yes
- Yes
* - :class:`~flash.text.seq2seq.translation.model.TranslationTask`
- No
- Yes
- Yes
* - :class:`~flash.video.classification.model.VideoClassifier`
- No
- Yes
- Yes

:sup:`*` Only with ``strict=False``.
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ Lightning Flash
general/training
general/finetuning
general/predictions
general/jit


.. toctree::
Expand Down
4 changes: 2 additions & 2 deletions flash/core/data/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
pass


class Preprocess(BasePreprocess, Properties, Module):
class Preprocess(BasePreprocess, Properties):
"""The :class:`~flash.core.data.process.Preprocess` encapsulates all the data processing logic that should run before
the data is passed to the model. It is particularly useful when you want to provide an end to end implementation
which works with 4 different stages: ``train``, ``validation``, ``test``, and inference (``predict``).
Expand Down Expand Up @@ -454,7 +454,7 @@ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool):
return cls(**state_dict)


class Postprocess(Properties, Module):
class Postprocess(Properties):

def __init__(self, save_path: Optional[str] = None):
super().__init__()
Expand Down
8 changes: 7 additions & 1 deletion flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,12 +263,14 @@ def _resolve(

return preprocess, postprocess, serializer

@torch.jit.unused
@property
def serializer(self) -> Optional[Serializer]:
"""The current :class:`.Serializer` associated with this model. If this property was set to a mapping
(e.g. ``.serializer = {'output1': SerializerOne()}``) then this will be a :class:`.MappingSerializer`."""
return self._serializer

@torch.jit.unused
@serializer.setter
def serializer(self, serializer: Union[Serializer, Mapping[str, Serializer]]):
if isinstance(serializer, Mapping):
Expand Down Expand Up @@ -350,12 +352,14 @@ def build_data_pipeline(
self._data_pipeline_state = data_pipeline.initialize(self._data_pipeline_state)
return data_pipeline

@torch.jit.unused
@property
def data_pipeline(self) -> DataPipeline:
"""The current :class:`.DataPipeline`. If set, the new value will override the :class:`.Task` defaults. See
:py:meth:`~build_data_pipeline` for more details on the resolution order."""
return self.build_data_pipeline()

@torch.jit.unused
@data_pipeline.setter
def data_pipeline(self, data_pipeline: Optional[DataPipeline]) -> None:
self._preprocess, self._postprocess, self.serializer = Task._resolve(
Expand All @@ -366,14 +370,16 @@ def data_pipeline(self, data_pipeline: Optional[DataPipeline]) -> None:
getattr(data_pipeline, '_postprocess_pipeline', None),
getattr(data_pipeline, '_serializer', None),
)
self._preprocess.state_dict()
# self._preprocess.state_dict()
if getattr(self._preprocess, "_ddp_params_and_buffers_to_ignore", None):
self._ddp_params_and_buffers_to_ignore = self._preprocess._ddp_params_and_buffers_to_ignore

@torch.jit.unused
@property
def preprocess(self) -> Preprocess:
return getattr(self.data_pipeline, '_preprocess_pipeline', None)

@torch.jit.unused
@property
def postprocess(self) -> Postprocess:
return getattr(self.data_pipeline, '_postprocess_pipeline', None)
Expand Down
3 changes: 0 additions & 3 deletions flash/image/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,11 @@
from types import FunctionType
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union

import pytorch_lightning as pl
import torch
import torchmetrics
from pytorch_lightning.callbacks.base import Callback
from torch import nn
from torch.optim.lr_scheduler import _LRScheduler

import flash
from flash.core.classification import ClassificationTask
from flash.core.data.data_source import DefaultDataKeys
from flash.core.data.process import Serializer
Expand Down
9 changes: 6 additions & 3 deletions flash/image/detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,9 @@ def get_model(
model = RetinaNet(backbone_model, num_classes=num_classes, anchor_generator=anchor_generator)
return model

# Forward pass: hand the list of image tensors ``x`` to the wrapped ``self.model``
# and return whatever it produces, unchanged.
# NOTE(review): the return type is left as ``Any``; the JIT test for this task
# indexes the scripted output as a tuple - confirm the eager-mode output shape.
def forward(self, x: List[torch.Tensor]) -> Any:
return self.model(x)

def training_step(self, batch, batch_idx) -> Any:
"""The training step. Overrides ``Task.training_step``
"""
Expand All @@ -178,7 +181,7 @@ def training_step(self, batch, batch_idx) -> Any:
def validation_step(self, batch, batch_idx):
images, targets = batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]
# fasterrcnn takes only images for eval() mode
outs = self.model(images)
outs = self(images)
iou = torch.stack([_evaluate_iou(t, o) for t, o in zip(targets, outs)]).mean()
self.log("val_iou", iou)

Expand All @@ -188,13 +191,13 @@ def on_validation_end(self) -> None:
def test_step(self, batch, batch_idx):
images, targets = batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]
# fasterrcnn takes only images for eval() mode
outs = self.model(images)
outs = self(images)
iou = torch.stack([_evaluate_iou(t, o) for t, o in zip(targets, outs)]).mean()
self.log("test_iou", iou)

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
images = batch[DefaultDataKeys.INPUT]
return self.model(images)
return self(images)

def configure_finetune_callback(self):
return [ObjectDetectionFineTuning(train_bn=True)]
Expand Down
15 changes: 7 additions & 8 deletions flash/image/embedding/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Mapping, Optional, Sequence, Type, Union
from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Type, Union

import torch
from pytorch_lightning.utilities.distributed import rank_zero_warn
Expand Down Expand Up @@ -89,13 +89,12 @@ def __init__(
rank_zero_warn('embedding_dim. Remember to finetune first!')

def apply_pool(self, x):
if self.pooling_fn == torch.max:
# torch.max also returns argmax
x = self.pooling_fn(x, dim=-1)[0]
x = self.pooling_fn(x, dim=-1)[0]
else:
x = self.pooling_fn(x, dim=-1)
x = self.pooling_fn(x, dim=-1)
x = self.pooling_fn(x, dim=-1)
if torch.jit.isinstance(x, Tuple[torch.Tensor, torch.Tensor]):
x = x[0]
x = self.pooling_fn(x, dim=-1)
if torch.jit.isinstance(x, Tuple[torch.Tensor, torch.Tensor]):
x = x[0]
return x

def forward(self, x) -> torch.Tensor:
Expand Down
5 changes: 2 additions & 3 deletions flash/image/segmentation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,12 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A

def forward(self, x) -> torch.Tensor:
# infer the image to the model
res: Union[torch.Tensor, Dict[str, torch.Tensor]] = self.backbone(x)
res = self.backbone(x)

# some frameworks like torchvision return a dict.
# In particular, torchvision segmentation models return the output logits
# in the key `out`.
out: torch.Tensor
if isinstance(res, dict):
if torch.jit.isinstance(res, Dict[str, torch.Tensor]):
out = res['out']
elif torch.is_tensor(res):
out = res
Expand Down
6 changes: 5 additions & 1 deletion flash/tabular/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,11 @@ def __init__(

def forward(self, x_in) -> torch.Tensor:
# TabNet takes single input, x_in is composed of (categorical, numerical)
x = torch.cat([x for x in x_in if x.numel()], dim=1)
xs = []
for x in x_in:
if x.numel():
xs.append(x)
x = torch.cat(xs, dim=1)
return self.model(x)[0]

def training_step(self, batch: Any, batch_idx: int) -> Any:
Expand Down
32 changes: 29 additions & 3 deletions flash/text/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,35 @@ def backbone(self):
# see huggingface's BertForSequenceClassification
return self.model.bert

def forward(self, batch_dict):
return self.model(**batch_dict)
# Forward pass with an explicit keyword signature.
# Every argument defaults to ``None`` and is passed through to ``self.model``
# under the same keyword name; the model's output is returned unchanged.
# NOTE(review): the argument names presumably mirror the wrapped Hugging Face
# sequence-classification model's ``forward`` signature - confirm against the
# backbone in use.
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None
):
# Pure pass-through: no argument is inspected or modified here.
return self.model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict
)

def step(self, batch, batch_idx) -> dict:
output = {}
out = self.forward(batch)
out = self.forward(**batch)
loss, logits = out[:2]
output["loss"] = loss
output["y_hat"] = logits
Expand All @@ -91,6 +114,9 @@ def step(self, batch, batch_idx) -> dict:
output["logs"] = {name: metric(probs, batch["labels"]) for name, metric in self.metrics.items()}
return output

# Prediction step: unpack ``batch`` as keyword arguments and call the module on it,
# returning the result as-is.
# ``batch_idx`` and ``dataloader_idx`` are accepted but not used in this body.
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
# ``batch`` must be a mapping whose keys are valid keyword names for the forward call.
return self(**batch)

def _ci_benchmark_fn(self, history: List[Dict[str, Any]]):
"""
This function is used only for debugging usage with CI
Expand Down
20 changes: 20 additions & 0 deletions tests/image/classification/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import pytest
import torch

Expand Down Expand Up @@ -108,3 +110,21 @@ def test_multilabel(tmpdir):
assert (torch.tensor(predictions) < 0).sum() == 0
assert len(predictions[0]) == num_classes == len(label)
assert len(torch.unique(label)) <= 2


# JIT round-trip test for ``ImageClassifier``.
# Parametrized over two compilers: ``torch.jit.script`` (no extra arguments) and
# ``torch.jit.trace`` (given one random example input of shape (1, 3, 32, 32)).
# The compiled model is saved, reloaded, and run on a fresh random input.
# Skipped when ``_IMAGE_AVAILABLE`` is falsy.
@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 3, 32, 32), ))])
def test_jit(tmpdir, jitter, args):
# Destination file for the serialized module inside pytest's temporary directory.
path = os.path.join(tmpdir, "test.pt")

# presumably ``2`` is the number of classes (the output shape asserted below is [1, 2]).
model = ImageClassifier(2)
model.eval()

# Compile with the parametrized JIT entry point; ``args`` is empty for scripting.
model = jitter(model, *args)

# Serialize and reload so the assertions exercise the loaded artifact.
torch.jit.save(model, path)
model = torch.jit.load(path)

out = model(torch.rand(1, 3, 32, 32))
assert isinstance(out, torch.Tensor)
assert out.shape == torch.Size([1, 2])
22 changes: 22 additions & 0 deletions tests/image/detection/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import pytest
import torch
from pytorch_lightning import Trainer
Expand Down Expand Up @@ -75,3 +77,23 @@ def test_training(tmpdir, model):
dl = DataLoader(ds, collate_fn=collate_fn)
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.fit(model, dl)


# JIT round-trip test for ``ObjectDetector``.
# The model is compiled with ``torch.jit.script`` only, then saved, reloaded, and
# run on a single random image; the detections must expose the expected keys.
# Skipped when ``_IMAGE_AVAILABLE`` is falsy.
@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
def test_jit(tmpdir):
# Destination file for the serialized module inside pytest's temporary directory.
path = os.path.join(tmpdir, "test.pt")

# presumably ``2`` is the number of classes - confirm against ``ObjectDetector.__init__``.
model = ObjectDetector(2)
model.eval()

model = torch.jit.script(model) # torch.jit.trace doesn't work with torchvision RCNN

# Serialize and reload so the assertions exercise the loaded artifact.
torch.jit.save(model, path)
model = torch.jit.load(path)

# The scripted model takes a list of image tensors (here one 3x32x32 image).
out = model([torch.rand(3, 32, 32)])

# torchvision RCNN always returns a (Losses, Detections) tuple in scripting
out = out[1]

# Subset check: the first detection must contain at least these keys.
assert {"boxes", "labels", "scores"} <= out[0].keys()
Empty file.
Loading

0 comments on commit b49bf04

Please sign in to comment.