This repository has been archived by the owner on Oct 9, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 212
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for Torch ORT to Transformer based Tasks (#667)
* Add torch ORT support, move transformer Tasks to use general task class * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix import * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformers version * Revert * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Revert * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add tests * Add tests * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add docs for text classification and translation * Add note * Add CHANGELOG.md * Address code review * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Apply suggestions from code review Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>
- Loading branch information
1 parent
4e89a37
commit 741a838
Showing
12 changed files
with
208 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from pytorch_lightning import Callback, LightningModule | ||
from pytorch_lightning.utilities.exceptions import MisconfigurationException | ||
|
||
from flash import Trainer | ||
from flash.core.utilities.imports import _TORCH_ORT_AVAILABLE | ||
|
||
if _TORCH_ORT_AVAILABLE: | ||
from torch_ort import ORTModule | ||
|
||
|
||
class ORTCallback(Callback): | ||
"""Enables Torch ORT: Accelerate PyTorch models with ONNX Runtime. | ||
Wraps a model with the ORT wrapper, lazily converting your module into an ONNX export, to optimize for | ||
training and inference. | ||
Usage: | ||
# via Transformer Tasks | ||
model = TextClassifier(backbone="facebook/bart-large", num_classes=datamodule.num_classes, enable_ort=True) | ||
# or via the trainer | ||
trainer = flash.Trainer(callbacks=ORTCallback()) | ||
""" | ||
|
||
def __init__(self): | ||
if not _TORCH_ORT_AVAILABLE: | ||
raise MisconfigurationException( | ||
"Torch ORT is required to use ORT. See here for installation: https://github.com/pytorch/ort" | ||
) | ||
|
||
def on_before_accelerator_backend_setup(self, trainer: Trainer, pl_module: LightningModule) -> None: | ||
if not hasattr(pl_module, "model"): | ||
raise MisconfigurationException( | ||
"Torch ORT requires to wrap a single model that defines a forward function " | ||
"assigned as `model` inside the `LightningModule`." | ||
) | ||
if not isinstance(pl_module.model, ORTModule): | ||
pl_module.model = ORTModule(pl_module.model) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import os | ||
|
||
import pytest | ||
import torch | ||
from pytorch_lightning import Callback | ||
from pytorch_lightning.core.lightning import LightningModule | ||
from pytorch_lightning.utilities.exceptions import MisconfigurationException | ||
|
||
from flash import Trainer | ||
from flash.core.utilities.imports import _TORCH_ORT_AVAILABLE | ||
from flash.text import TextClassifier | ||
from flash.text.ort_callback import ORTCallback | ||
from tests.helpers.boring_model import BoringModel | ||
from tests.helpers.utils import _TEXT_TESTING | ||
from tests.text.classification.test_model import DummyDataset, TEST_BACKBONE | ||
|
||
if _TORCH_ORT_AVAILABLE: | ||
from torch_ort import ORTModule | ||
|
||
|
||
@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") | ||
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") | ||
@pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="ORT Module aren't installed.") | ||
def test_init_train_enable_ort(tmpdir): | ||
class TestCallback(Callback): | ||
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: | ||
assert isinstance(pl_module.model, ORTModule) | ||
|
||
model = TextClassifier(2, TEST_BACKBONE, enable_ort=True) | ||
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=TestCallback()) | ||
trainer.fit( | ||
model, | ||
train_dataloader=torch.utils.data.DataLoader(DummyDataset()), | ||
val_dataloaders=torch.utils.data.DataLoader(DummyDataset()), | ||
) | ||
trainer.test(model, test_dataloaders=torch.utils.data.DataLoader(DummyDataset())) | ||
|
||
|
||
@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") | ||
@pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="ORT Module aren't installed.") | ||
def test_ort_callback_fails_no_model(tmpdir): | ||
model = BoringModel() | ||
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=ORTCallback()) | ||
with pytest.raises(MisconfigurationException, match="Torch ORT requires to wrap a single model"): | ||
trainer.fit( | ||
model, | ||
train_dataloader=torch.utils.data.DataLoader(DummyDataset()), | ||
val_dataloaders=torch.utils.data.DataLoader(DummyDataset()), | ||
) |