diff --git a/CHANGELOG.md b/CHANGELOG.md
index 22bd7058ba..b5c9ec4dd5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -46,6 +46,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added instance segmentation task ([#608](https://github.com/PyTorchLightning/lightning-flash/pull/608))

+- Added Torch ORT support to Transformer-based tasks ([#667](https://github.com/PyTorchLightning/lightning-flash/pull/667))
+
 ### Changed

 - Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
diff --git a/docs/source/reference/summarization.rst b/docs/source/reference/summarization.rst
index ff7bedf4bc..6010324cb1 100644
--- a/docs/source/reference/summarization.rst
+++ b/docs/source/reference/summarization.rst
@@ -85,3 +85,21 @@ You can now perform inference from your client like this:
 .. literalinclude:: ../../../flash_examples/serve/summarization/client.py
     :language: python
     :lines: 14-
+
+------
+
+**********************************************
+Accelerate Training & Inference with Torch ORT
+**********************************************
+
+`Torch ORT <https://github.com/pytorch/ort>`__ converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. Once Torch ORT is installed, enabling it requires a single flag passed to the ``SummarizationTask``. See the installation instructions `here <https://github.com/pytorch/ort>`__.
+
+.. note::
+
+    Not all Transformer models are supported. See `this table <https://github.com/pytorch/ort>`__ for supported models + branches containing fixes for certain models.
+
+.. code-block:: python
+
+    ...
+
+    model = SummarizationTask(backbone="t5-large", enable_ort=True)
diff --git a/docs/source/reference/text_classification.rst b/docs/source/reference/text_classification.rst
index 42424cc980..989ce2e387 100644
--- a/docs/source/reference/text_classification.rst
+++ b/docs/source/reference/text_classification.rst
@@ -85,3 +85,21 @@ You can now perform inference from your client like this:
 .. literalinclude:: ../../../flash_examples/serve/text_classification/client.py
     :language: python
     :lines: 14-
+
+------
+
+**********************************************
+Accelerate Training & Inference with Torch ORT
+**********************************************
+
+`Torch ORT <https://github.com/pytorch/ort>`__ converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. Once Torch ORT is installed, enabling it requires a single flag passed to the ``TextClassifier``. See the installation instructions `here <https://github.com/pytorch/ort>`__.
+
+.. note::
+
+    Not all Transformer models are supported. See `this table <https://github.com/pytorch/ort>`__ for supported models + branches containing fixes for certain models.
+
+.. code-block:: python
+
+    ...
+
+    model = TextClassifier(backbone="facebook/bart-large", num_classes=datamodule.num_classes, enable_ort=True)
diff --git a/docs/source/reference/translation.rst b/docs/source/reference/translation.rst
index 939e3f544a..cc7c21c517 100644
--- a/docs/source/reference/translation.rst
+++ b/docs/source/reference/translation.rst
@@ -85,3 +85,21 @@ You can now perform inference from your client like this:
 .. literalinclude:: ../../../flash_examples/serve/translation/client.py
     :language: python
     :lines: 14-
+
+------
+
+**********************************************
+Accelerate Training & Inference with Torch ORT
+**********************************************
+
+`Torch ORT <https://github.com/pytorch/ort>`__ converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. Once Torch ORT is installed, enabling it requires a single flag passed to the ``TranslationTask``. See the installation instructions `here <https://github.com/pytorch/ort>`__.
+
+.. note::
+
+    Not all Transformer models are supported. See `this table <https://github.com/pytorch/ort>`__ for supported models + branches containing fixes for certain models.
+
+.. code-block:: python
+
+    ...
+
+    model = TranslationTask(backbone="t5-large", enable_ort=True)
diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py
index 1a4837c68b..015c432c57 100644
--- a/flash/core/utilities/imports.py
+++ b/flash/core/utilities/imports.py
@@ -96,6 +96,7 @@ def _compare_version(package: str, op, version) -> bool:
 _SENTENCEPIECE_AVAILABLE = _module_available("sentencepiece")
 _DATASETS_AVAILABLE = _module_available("datasets")
 _ICEVISION_AVAILABLE = _module_available("icevision")
+_TORCH_ORT_AVAILABLE = _module_available("torch_ort")

 if Version:
     _TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0")
diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py
index c9ba5fa0a1..cf339153a0 100644
--- a/flash/text/classification/model.py
+++ b/flash/text/classification/model.py
@@ -16,15 +16,17 @@
 from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union

 import torch
+from pytorch_lightning import Callback
 from torchmetrics import Metric

 from flash.core.classification import ClassificationTask, Labels
 from flash.core.data.process import Serializer
 from flash.core.utilities.imports import _TEXT_AVAILABLE
+from flash.text.ort_callback import ORTCallback

 if _TEXT_AVAILABLE:
-    from transformers import BertForSequenceClassification
-    from transformers.modeling_outputs import SequenceClassifierOutput
+    from transformers import AutoModelForSequenceClassification
+    from transformers.modeling_outputs import Seq2SeqSequenceClassifierOutput, SequenceClassifierOutput


 class TextClassifier(ClassificationTask):
@@ -43,6 +45,7 @@ class TextClassifier(ClassificationTask):
         learning_rate: Learning rate to use for training, defaults to `1e-3`
         multi_label: Whether the targets are multi-label or not.
         serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs.
+        enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
     """

     required_extras: str = "text"
@@ -57,6 +60,7 @@ def __init__(
         learning_rate: float = 1e-2,
         multi_label: bool = False,
         serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
+        enable_ort: bool = False,
     ):
         self.save_hyperparameters()

@@ -76,25 +80,24 @@ def __init__(
             multi_label=multi_label,
             serializer=serializer or Labels(multi_label=multi_label),
         )
-        self.model = BertForSequenceClassification.from_pretrained(backbone, num_labels=num_classes)
-
+        self.enable_ort = enable_ort
+        self.model = AutoModelForSequenceClassification.from_pretrained(backbone, num_labels=num_classes)
         self.save_hyperparameters()

     @property
     def backbone(self):
-        # see huggingface's BertForSequenceClassification
-        return self.model.bert
+        return self.model.base_model

     def forward(self, batch: Dict[str, torch.Tensor]):
         return self.model(input_ids=batch.get("input_ids", None), attention_mask=batch.get("attention_mask", None))

     def to_loss_format(self, x) -> torch.Tensor:
-        if isinstance(x, SequenceClassifierOutput):
+        if isinstance(x, (SequenceClassifierOutput, Seq2SeqSequenceClassifierOutput)):
             x = x.logits
         return super().to_loss_format(x)

     def to_metrics_format(self, x) -> torch.Tensor:
-        if isinstance(x, SequenceClassifierOutput):
+        if isinstance(x, (SequenceClassifierOutput, Seq2SeqSequenceClassifierOutput)):
             x = x.logits
         return super().to_metrics_format(x)

@@ -112,3 +115,9 @@ def _ci_benchmark_fn(self, history: List[Dict[str, Any]]):
             assert history[-1]["val_f1"] > 0.40, history[-1]["val_f1"]
         else:
             assert history[-1]["val_accuracy"] > 0.70, history[-1]["val_accuracy"]
+
+    def configure_callbacks(self) -> List[Callback]:
+        callbacks = super().configure_callbacks() or []
+        if self.enable_ort:
+            callbacks.append(ORTCallback())
+        return callbacks
diff --git a/flash/text/ort_callback.py b/flash/text/ort_callback.py
new file mode 100644
index 0000000000..b3d1a615a3
--- /dev/null
+++ b/flash/text/ort_callback.py
@@ -0,0 +1,52 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pytorch_lightning import Callback, LightningModule
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+from flash import Trainer
+from flash.core.utilities.imports import _TORCH_ORT_AVAILABLE
+
+if _TORCH_ORT_AVAILABLE:
+    from torch_ort import ORTModule
+
+
+class ORTCallback(Callback):
+    """Enables Torch ORT: Accelerate PyTorch models with ONNX Runtime.
+
+    Wraps a model with the ORT wrapper, lazily converting the module into an optimized ONNX graph
+    for training and inference.
+
+    Usage:
+
+        # via Transformer Tasks
+        model = TextClassifier(backbone="facebook/bart-large", num_classes=datamodule.num_classes, enable_ort=True)
+
+        # or via the trainer
+        trainer = flash.Trainer(callbacks=ORTCallback())
+    """
+
+    def __init__(self):
+        if not _TORCH_ORT_AVAILABLE:
+            raise MisconfigurationException(
+                "Torch ORT is required to use `ORTCallback`. See here for installation: https://github.com/pytorch/ort"
+            )
+
+    def on_before_accelerator_backend_setup(self, trainer: Trainer, pl_module: LightningModule) -> None:
+        if not hasattr(pl_module, "model"):
+            raise MisconfigurationException(
+                "Torch ORT requires a single model that defines a forward function, "
+                "assigned as `model` inside the `LightningModule`."
+            )
+        if not isinstance(pl_module.model, ORTModule):
+            pl_module.model = ORTModule(pl_module.model)
diff --git a/flash/text/seq2seq/core/model.py b/flash/text/seq2seq/core/model.py
index 283abaf120..d79ca18a78 100644
--- a/flash/text/seq2seq/core/model.py
+++ b/flash/text/seq2seq/core/model.py
@@ -16,6 +16,7 @@
 from typing import Any, Callable, List, Mapping, Optional, Sequence, Type, Union

 import torch
+from pytorch_lightning import Callback
 from pytorch_lightning.utilities import rank_zero_info
 from torch import Tensor
 from torchmetrics import Metric
@@ -23,6 +24,7 @@
 from flash.core.finetuning import FlashBaseFinetuning
 from flash.core.model import Task
 from flash.core.utilities.imports import _TEXT_AVAILABLE
+from flash.text.ort_callback import ORTCallback
 from flash.text.seq2seq.core.finetuning import Seq2SeqFreezeEmbeddings

 if _TEXT_AVAILABLE:
@@ -54,6 +56,7 @@ class Seq2SeqTask(Task):
         learning_rate: Learning rate to use for training, defaults to `3e-4`
         val_target_max_length: Maximum length of targets in validation. Defaults to `128`
         num_beams: Number of beams to use in validation when generating predictions. Defaults to `4`
+        enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
     """

     required_extras: str = "text"
@@ -67,6 +70,7 @@ def __init__(
         learning_rate: float = 5e-5,
         val_target_max_length: Optional[int] = None,
         num_beams: Optional[int] = None,
+        enable_ort: bool = False,
     ):
         os.environ["TOKENIZERS_PARALLELISM"] = "TRUE"
         # disable HF thousand warnings
         os.environ["PYTHONWARNINGS"] = "ignore"
         super().__init__(loss_fn=loss_fn, optimizer=optimizer, metrics=metrics, learning_rate=learning_rate)
         self.model = AutoModelForSeq2SeqLM.from_pretrained(backbone)
+        self.enable_ort = enable_ort
         self.val_target_max_length = val_target_max_length
         self.num_beams = num_beams
         self._initialize_model_specific_parameters()
@@ -134,3 +139,9 @@ def tokenize_labels(self, labels: Tensor) -> List[str]:

     def configure_finetune_callback(self) -> List[FlashBaseFinetuning]:
         return [Seq2SeqFreezeEmbeddings(self.model.config.model_type, train_bn=True)]
+
+    def configure_callbacks(self) -> List[Callback]:
+        callbacks = super().configure_callbacks() or []
+        if self.enable_ort:
+            callbacks.append(ORTCallback())
+        return callbacks
diff --git a/flash/text/seq2seq/question_answering/model.py b/flash/text/seq2seq/question_answering/model.py
index 2db3a6d6aa..0ebec8aed3 100644
--- a/flash/text/seq2seq/question_answering/model.py
+++ b/flash/text/seq2seq/question_answering/model.py
@@ -42,6 +42,7 @@ class QuestionAnsweringTask(Seq2SeqTask):
         num_beams: Number of beams to use in validation when generating predictions. Defaults to `4`
         use_stemmer: Whether Porter stemmer should be used to strip word suffixes to improve matching.
         rouge_newline_sep: Add a new line at the beginning of each sentence in Rouge Metric calculation.
+        enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
     """

     def __init__(
@@ -55,6 +56,7 @@ def __init__(
         num_beams: Optional[int] = 4,
         use_stemmer: bool = True,
         rouge_newline_sep: bool = True,
+        enable_ort: bool = False,
     ):
         self.save_hyperparameters()
         super().__init__(
@@ -65,6 +67,7 @@ def __init__(
             learning_rate=learning_rate,
             val_target_max_length=val_target_max_length,
             num_beams=num_beams,
+            enable_ort=enable_ort,
         )
         self.rouge = RougeMetric(
             rouge_newline_sep=rouge_newline_sep,
diff --git a/flash/text/seq2seq/summarization/model.py b/flash/text/seq2seq/summarization/model.py
index af7820b10e..19e812baf1 100644
--- a/flash/text/seq2seq/summarization/model.py
+++ b/flash/text/seq2seq/summarization/model.py
@@ -42,6 +42,7 @@ class SummarizationTask(Seq2SeqTask):
         num_beams: Number of beams to use in validation when generating predictions. Defaults to `4`
         use_stemmer: Whether Porter stemmer should be used to strip word suffixes to improve matching.
         rouge_newline_sep: Add a new line at the beginning of each sentence in Rouge Metric calculation.
+        enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
     """

     def __init__(
@@ -55,6 +56,7 @@ def __init__(
         num_beams: Optional[int] = 4,
         use_stemmer: bool = True,
         rouge_newline_sep: bool = True,
+        enable_ort: bool = False,
     ):
         self.save_hyperparameters()
         super().__init__(
@@ -65,6 +67,7 @@ def __init__(
             learning_rate=learning_rate,
             val_target_max_length=val_target_max_length,
             num_beams=num_beams,
+            enable_ort=enable_ort,
         )
         self.rouge = RougeMetric(
             rouge_newline_sep=rouge_newline_sep,
diff --git a/flash/text/seq2seq/translation/model.py b/flash/text/seq2seq/translation/model.py
index ad99f47e31..c70089e8d6 100644
--- a/flash/text/seq2seq/translation/model.py
+++ b/flash/text/seq2seq/translation/model.py
@@ -42,6 +42,7 @@ class TranslationTask(Seq2SeqTask):
         num_beams: Number of beams to use in validation when generating predictions. Defaults to `4`
         n_gram: Maximum n_grams to use in metric calculation. Defaults to `4`
         smooth: Apply smoothing in BLEU calculation. Defaults to `True`
+        enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
     """

     def __init__(
@@ -55,6 +56,7 @@ def __init__(
         num_beams: Optional[int] = 4,
         n_gram: bool = 4,
         smooth: bool = True,
+        enable_ort: bool = False,
     ):
         self.save_hyperparameters()
         super().__init__(
@@ -65,6 +67,7 @@ def __init__(
             learning_rate=learning_rate,
             val_target_max_length=val_target_max_length,
             num_beams=num_beams,
+            enable_ort=enable_ort,
         )
         self.bleu = BLEUScore(
             n_gram=n_gram,
diff --git a/tests/text/classification/test_ort.py b/tests/text/classification/test_ort.py
new file mode 100644
index 0000000000..01d987e092
--- /dev/null
+++ b/tests/text/classification/test_ort.py
@@ -0,0 +1,62 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import pytest
+import torch
+from pytorch_lightning import Callback
+from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+from flash import Trainer
+from flash.core.utilities.imports import _TORCH_ORT_AVAILABLE
+from flash.text import TextClassifier
+from flash.text.ort_callback import ORTCallback
+from tests.helpers.boring_model import BoringModel
+from tests.helpers.utils import _TEXT_TESTING
+from tests.text.classification.test_model import DummyDataset, TEST_BACKBONE
+
+if _TORCH_ORT_AVAILABLE:
+    from torch_ort import ORTModule
+
+
+@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
+@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
+@pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="torch_ort isn't installed.")
+def test_init_train_enable_ort(tmpdir):
+    class TestCallback(Callback):
+        def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
+            assert isinstance(pl_module.model, ORTModule)
+
+    model = TextClassifier(2, TEST_BACKBONE, enable_ort=True)
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=TestCallback())
+    trainer.fit(
+        model,
+        train_dataloader=torch.utils.data.DataLoader(DummyDataset()),
+        val_dataloaders=torch.utils.data.DataLoader(DummyDataset()),
+    )
+    trainer.test(model, test_dataloaders=torch.utils.data.DataLoader(DummyDataset()))
+
+
+@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
+@pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="torch_ort isn't installed.")
+def test_ort_callback_fails_no_model(tmpdir):
+    model = BoringModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=ORTCallback())
+    with pytest.raises(MisconfigurationException, match="Torch ORT requires a single model"):
+        trainer.fit(
+            model,
+            train_dataloader=torch.utils.data.DataLoader(DummyDataset()),
+            val_dataloaders=torch.utils.data.DataLoader(DummyDataset()),
+        )
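------

A minimal end-to-end sketch of the new flag in use. The data setup below mirrors the existing Flash IMDB text-classification example; the download URL, column names, and backbone come from that example rather than from this diff, so treat them as illustrative. ``enable_ort=True`` is the only ORT-specific line.

.. code-block:: python

    import flash
    from flash.core.data.utils import download_data
    from flash.text import TextClassificationData, TextClassifier

    # Standard Flash IMDB example data (not part of this PR).
    download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/")

    datamodule = TextClassificationData.from_csv(
        "review",
        "sentiment",
        train_file="data/imdb/train.csv",
        val_file="data/imdb/valid.csv",
    )

    # enable_ort=True makes configure_callbacks() append ORTCallback, which wraps
    # the task's `model` attribute in torch_ort.ORTModule before accelerator setup.
    model = TextClassifier(
        backbone="prajjwal1/bert-medium",
        num_classes=datamodule.num_classes,
        enable_ort=True,
    )

    # Torch ORT only accelerates on NVIDIA or AMD GPUs.
    trainer = flash.Trainer(max_epochs=1, gpus=1)
    trainer.finetune(model, datamodule=datamodule, strategy="freeze")

Alternatively, per the ``ORTCallback`` docstring, the callback can be attached directly to any task that exposes its Transformer as ``.model``:

.. code-block:: python

    from flash.text.ort_callback import ORTCallback

    trainer = flash.Trainer(callbacks=ORTCallback())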