diff --git a/CHANGELOG.md b/CHANGELOG.md
index 22bd7058ba..b5c9ec4dd5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -46,6 +46,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added instance segmentation task ([#608](https://github.com/PyTorchLightning/lightning-flash/pull/608))
+- Added Torch ORT support to Transformer based tasks ([#667](https://github.com/PyTorchLightning/lightning-flash/pull/667))
+
### Changed
- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
diff --git a/docs/source/reference/summarization.rst b/docs/source/reference/summarization.rst
index ff7bedf4bc..6010324cb1 100644
--- a/docs/source/reference/summarization.rst
+++ b/docs/source/reference/summarization.rst
@@ -85,3 +85,21 @@ You can now perform inference from your client like this:
.. literalinclude:: ../../../flash_examples/serve/summarization/client.py
:language: python
:lines: 14-
+
+------
+
+**********************************************
+Accelerate Training & Inference with Torch ORT
+**********************************************
+
+`Torch ORT `__ converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. Enabling Torch ORT requires a single flag passed to the ``SummarizationTask`` once installed. See installation instructions `here `__.
+
+.. note::
+
+ Not all Transformer models are supported. See `this table `__ for supported models + branches containing fixes for certain models.
+
+.. code-block:: python
+
+ ...
+
+ model = SummarizationTask(backbone="t5-large", num_classes=datamodule.num_classes, enable_ort=True)
diff --git a/docs/source/reference/text_classification.rst b/docs/source/reference/text_classification.rst
index 42424cc980..989ce2e387 100644
--- a/docs/source/reference/text_classification.rst
+++ b/docs/source/reference/text_classification.rst
@@ -85,3 +85,21 @@ You can now perform inference from your client like this:
.. literalinclude:: ../../../flash_examples/serve/text_classification/client.py
:language: python
:lines: 14-
+
+------
+
+**********************************************
+Accelerate Training & Inference with Torch ORT
+**********************************************
+
+`Torch ORT `__ converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. Enabling Torch ORT requires a single flag passed to the ``TextClassifier`` once installed. See installation instructions `here `__.
+
+.. note::
+
+ Not all Transformer models are supported. See `this table `__ for supported models + branches containing fixes for certain models.
+
+.. code-block:: python
+
+ ...
+
+ model = TextClassifier(backbone="facebook/bart-large", num_classes=datamodule.num_classes, enable_ort=True)
diff --git a/docs/source/reference/translation.rst b/docs/source/reference/translation.rst
index 939e3f544a..cc7c21c517 100644
--- a/docs/source/reference/translation.rst
+++ b/docs/source/reference/translation.rst
@@ -85,3 +85,21 @@ You can now perform inference from your client like this:
.. literalinclude:: ../../../flash_examples/serve/translation/client.py
:language: python
:lines: 14-
+
+------
+
+**********************************************
+Accelerate Training & Inference with Torch ORT
+**********************************************
+
+`Torch ORT `__ converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. Enabling Torch ORT requires a single flag passed to the ``TranslationTask`` once installed. See installation instructions `here `__.
+
+.. note::
+
+ Not all Transformer models are supported. See `this table `__ for supported models + branches containing fixes for certain models.
+
+.. code-block:: python
+
+ ...
+
+ model = TranslationTask(backbone="t5-large", num_classes=datamodule.num_classes, enable_ort=True)
diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py
index 1a4837c68b..015c432c57 100644
--- a/flash/core/utilities/imports.py
+++ b/flash/core/utilities/imports.py
@@ -96,6 +96,7 @@ def _compare_version(package: str, op, version) -> bool:
_SENTENCEPIECE_AVAILABLE = _module_available("sentencepiece")
_DATASETS_AVAILABLE = _module_available("datasets")
_ICEVISION_AVAILABLE = _module_available("icevision")
+_TORCH_ORT_AVAILABLE = _module_available("torch_ort")
if Version:
_TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0")
diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py
index c9ba5fa0a1..cf339153a0 100644
--- a/flash/text/classification/model.py
+++ b/flash/text/classification/model.py
@@ -16,15 +16,17 @@
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union
import torch
+from pytorch_lightning import Callback
from torchmetrics import Metric
from flash.core.classification import ClassificationTask, Labels
from flash.core.data.process import Serializer
from flash.core.utilities.imports import _TEXT_AVAILABLE
+from flash.text.ort_callback import ORTCallback
if _TEXT_AVAILABLE:
- from transformers import BertForSequenceClassification
- from transformers.modeling_outputs import SequenceClassifierOutput
+ from transformers import AutoModelForSequenceClassification
+ from transformers.modeling_outputs import Seq2SeqSequenceClassifierOutput, SequenceClassifierOutput
class TextClassifier(ClassificationTask):
@@ -43,6 +45,7 @@ class TextClassifier(ClassificationTask):
learning_rate: Learning rate to use for training, defaults to `1e-3`
multi_label: Whether the targets are multi-label or not.
serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs.
+ enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
"""
required_extras: str = "text"
@@ -57,6 +60,7 @@ def __init__(
learning_rate: float = 1e-2,
multi_label: bool = False,
serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
+ enable_ort: bool = False,
):
self.save_hyperparameters()
@@ -76,25 +80,24 @@ def __init__(
multi_label=multi_label,
serializer=serializer or Labels(multi_label=multi_label),
)
- self.model = BertForSequenceClassification.from_pretrained(backbone, num_labels=num_classes)
-
+ self.enable_ort = enable_ort
+ self.model = AutoModelForSequenceClassification.from_pretrained(backbone, num_labels=num_classes)
self.save_hyperparameters()
@property
def backbone(self):
- # see huggingface's BertForSequenceClassification
- return self.model.bert
+ return self.model.base_model
def forward(self, batch: Dict[str, torch.Tensor]):
return self.model(input_ids=batch.get("input_ids", None), attention_mask=batch.get("attention_mask", None))
def to_loss_format(self, x) -> torch.Tensor:
- if isinstance(x, SequenceClassifierOutput):
+ if isinstance(x, (SequenceClassifierOutput, Seq2SeqSequenceClassifierOutput)):
x = x.logits
return super().to_loss_format(x)
def to_metrics_format(self, x) -> torch.Tensor:
- if isinstance(x, SequenceClassifierOutput):
+ if isinstance(x, (SequenceClassifierOutput, Seq2SeqSequenceClassifierOutput)):
x = x.logits
return super().to_metrics_format(x)
@@ -112,3 +115,9 @@ def _ci_benchmark_fn(self, history: List[Dict[str, Any]]):
assert history[-1]["val_f1"] > 0.40, history[-1]["val_f1"]
else:
assert history[-1]["val_accuracy"] > 0.70, history[-1]["val_accuracy"]
+
+ def configure_callbacks(self) -> List[Callback]:
+ callbacks = super().configure_callbacks() or []
+ if self.enable_ort:
+ callbacks.append(ORTCallback())
+ return callbacks
diff --git a/flash/text/ort_callback.py b/flash/text/ort_callback.py
new file mode 100644
index 0000000000..b3d1a615a3
--- /dev/null
+++ b/flash/text/ort_callback.py
@@ -0,0 +1,52 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pytorch_lightning import Callback, LightningModule
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+from flash import Trainer
+from flash.core.utilities.imports import _TORCH_ORT_AVAILABLE
+
+if _TORCH_ORT_AVAILABLE:
+ from torch_ort import ORTModule
+
+
+class ORTCallback(Callback):
+ """Enables Torch ORT: Accelerate PyTorch models with ONNX Runtime.
+
+ Wraps a model with the ORT wrapper, lazily converting your module into an ONNX export, to optimize for
+ training and inference.
+
+ Usage:
+
+ # via Transformer Tasks
+ model = TextClassifier(backbone="facebook/bart-large", num_classes=datamodule.num_classes, enable_ort=True)
+
+ # or via the trainer
+ trainer = flash.Trainer(callbacks=ORTCallback())
+ """
+
+ def __init__(self):
+ if not _TORCH_ORT_AVAILABLE:
+ raise MisconfigurationException(
+ "Torch ORT is required to use ORT. See here for installation: https://github.com/pytorch/ort"
+ )
+
+ def on_before_accelerator_backend_setup(self, trainer: Trainer, pl_module: LightningModule) -> None:
+ if not hasattr(pl_module, "model"):
+ raise MisconfigurationException(
+ "Torch ORT requires to wrap a single model that defines a forward function "
+ "assigned as `model` inside the `LightningModule`."
+ )
+ if not isinstance(pl_module.model, ORTModule):
+ pl_module.model = ORTModule(pl_module.model)
diff --git a/flash/text/seq2seq/core/model.py b/flash/text/seq2seq/core/model.py
index 283abaf120..d79ca18a78 100644
--- a/flash/text/seq2seq/core/model.py
+++ b/flash/text/seq2seq/core/model.py
@@ -16,6 +16,7 @@
from typing import Any, Callable, List, Mapping, Optional, Sequence, Type, Union
import torch
+from pytorch_lightning import Callback
from pytorch_lightning.utilities import rank_zero_info
from torch import Tensor
from torchmetrics import Metric
@@ -23,6 +24,7 @@
from flash.core.finetuning import FlashBaseFinetuning
from flash.core.model import Task
from flash.core.utilities.imports import _TEXT_AVAILABLE
+from flash.text.ort_callback import ORTCallback
from flash.text.seq2seq.core.finetuning import Seq2SeqFreezeEmbeddings
if _TEXT_AVAILABLE:
@@ -54,6 +56,7 @@ class Seq2SeqTask(Task):
learning_rate: Learning rate to use for training, defaults to `3e-4`
val_target_max_length: Maximum length of targets in validation. Defaults to `128`
num_beams: Number of beams to use in validation when generating predictions. Defaults to `4`
+ enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
"""
required_extras: str = "text"
@@ -67,6 +70,7 @@ def __init__(
learning_rate: float = 5e-5,
val_target_max_length: Optional[int] = None,
num_beams: Optional[int] = None,
+ enable_ort: bool = False,
):
os.environ["TOKENIZERS_PARALLELISM"] = "TRUE"
# disable HF thousand warnings
@@ -75,6 +79,7 @@ def __init__(
os.environ["PYTHONWARNINGS"] = "ignore"
super().__init__(loss_fn=loss_fn, optimizer=optimizer, metrics=metrics, learning_rate=learning_rate)
self.model = AutoModelForSeq2SeqLM.from_pretrained(backbone)
+ self.enable_ort = enable_ort
self.val_target_max_length = val_target_max_length
self.num_beams = num_beams
self._initialize_model_specific_parameters()
@@ -134,3 +139,9 @@ def tokenize_labels(self, labels: Tensor) -> List[str]:
def configure_finetune_callback(self) -> List[FlashBaseFinetuning]:
return [Seq2SeqFreezeEmbeddings(self.model.config.model_type, train_bn=True)]
+
+ def configure_callbacks(self) -> List[Callback]:
+ callbacks = super().configure_callbacks() or []
+ if self.enable_ort:
+ callbacks.append(ORTCallback())
+ return callbacks
diff --git a/flash/text/seq2seq/question_answering/model.py b/flash/text/seq2seq/question_answering/model.py
index 2db3a6d6aa..0ebec8aed3 100644
--- a/flash/text/seq2seq/question_answering/model.py
+++ b/flash/text/seq2seq/question_answering/model.py
@@ -42,6 +42,7 @@ class QuestionAnsweringTask(Seq2SeqTask):
num_beams: Number of beams to use in validation when generating predictions. Defaults to `4`
use_stemmer: Whether Porter stemmer should be used to strip word suffixes to improve matching.
rouge_newline_sep: Add a new line at the beginning of each sentence in Rouge Metric calculation.
+ enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
"""
def __init__(
@@ -55,6 +56,7 @@ def __init__(
num_beams: Optional[int] = 4,
use_stemmer: bool = True,
rouge_newline_sep: bool = True,
+ enable_ort: bool = False,
):
self.save_hyperparameters()
super().__init__(
@@ -65,6 +67,7 @@ def __init__(
learning_rate=learning_rate,
val_target_max_length=val_target_max_length,
num_beams=num_beams,
+ enable_ort=enable_ort,
)
self.rouge = RougeMetric(
rouge_newline_sep=rouge_newline_sep,
diff --git a/flash/text/seq2seq/summarization/model.py b/flash/text/seq2seq/summarization/model.py
index af7820b10e..19e812baf1 100644
--- a/flash/text/seq2seq/summarization/model.py
+++ b/flash/text/seq2seq/summarization/model.py
@@ -42,6 +42,7 @@ class SummarizationTask(Seq2SeqTask):
num_beams: Number of beams to use in validation when generating predictions. Defaults to `4`
use_stemmer: Whether Porter stemmer should be used to strip word suffixes to improve matching.
rouge_newline_sep: Add a new line at the beginning of each sentence in Rouge Metric calculation.
+ enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
"""
def __init__(
@@ -55,6 +56,7 @@ def __init__(
num_beams: Optional[int] = 4,
use_stemmer: bool = True,
rouge_newline_sep: bool = True,
+ enable_ort: bool = False,
):
self.save_hyperparameters()
super().__init__(
@@ -65,6 +67,7 @@ def __init__(
learning_rate=learning_rate,
val_target_max_length=val_target_max_length,
num_beams=num_beams,
+ enable_ort=enable_ort,
)
self.rouge = RougeMetric(
rouge_newline_sep=rouge_newline_sep,
diff --git a/flash/text/seq2seq/translation/model.py b/flash/text/seq2seq/translation/model.py
index ad99f47e31..c70089e8d6 100644
--- a/flash/text/seq2seq/translation/model.py
+++ b/flash/text/seq2seq/translation/model.py
@@ -42,6 +42,7 @@ class TranslationTask(Seq2SeqTask):
num_beams: Number of beams to use in validation when generating predictions. Defaults to `4`
n_gram: Maximum n_grams to use in metric calculation. Defaults to `4`
smooth: Apply smoothing in BLEU calculation. Defaults to `True`
+ enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
"""
def __init__(
@@ -55,6 +56,7 @@ def __init__(
num_beams: Optional[int] = 4,
n_gram: bool = 4,
smooth: bool = True,
+ enable_ort: bool = False,
):
self.save_hyperparameters()
super().__init__(
@@ -65,6 +67,7 @@ def __init__(
learning_rate=learning_rate,
val_target_max_length=val_target_max_length,
num_beams=num_beams,
+ enable_ort=enable_ort,
)
self.bleu = BLEUScore(
n_gram=n_gram,
diff --git a/tests/text/classification/test_ort.py b/tests/text/classification/test_ort.py
new file mode 100644
index 0000000000..01d987e092
--- /dev/null
+++ b/tests/text/classification/test_ort.py
@@ -0,0 +1,62 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import pytest
+import torch
+from pytorch_lightning import Callback
+from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+from flash import Trainer
+from flash.core.utilities.imports import _TORCH_ORT_AVAILABLE
+from flash.text import TextClassifier
+from flash.text.ort_callback import ORTCallback
+from tests.helpers.boring_model import BoringModel
+from tests.helpers.utils import _TEXT_TESTING
+from tests.text.classification.test_model import DummyDataset, TEST_BACKBONE
+
+if _TORCH_ORT_AVAILABLE:
+ from torch_ort import ORTModule
+
+
+@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
+@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
+@pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="ORT Module aren't installed.")
+def test_init_train_enable_ort(tmpdir):
+ class TestCallback(Callback):
+ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
+ assert isinstance(pl_module.model, ORTModule)
+
+ model = TextClassifier(2, TEST_BACKBONE, enable_ort=True)
+ trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=TestCallback())
+ trainer.fit(
+ model,
+ train_dataloader=torch.utils.data.DataLoader(DummyDataset()),
+ val_dataloaders=torch.utils.data.DataLoader(DummyDataset()),
+ )
+ trainer.test(model, test_dataloaders=torch.utils.data.DataLoader(DummyDataset()))
+
+
+@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
+@pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="ORT Module aren't installed.")
+def test_ort_callback_fails_no_model(tmpdir):
+ model = BoringModel()
+ trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=ORTCallback())
+ with pytest.raises(MisconfigurationException, match="Torch ORT requires to wrap a single model"):
+ trainer.fit(
+ model,
+ train_dataloader=torch.utils.data.DataLoader(DummyDataset()),
+ val_dataloaders=torch.utils.data.DataLoader(DummyDataset()),
+ )