From 10bb95802f15a35477b1d5851b254d605aaf5962 Mon Sep 17 00:00:00 2001 From: Pavel Grunt Date: Sun, 8 Aug 2021 19:31:57 +0200 Subject: [PATCH] Allow exporting to onnx when input is tuple Fixes #8799 --- CHANGELOG.md | 3 +++ pytorch_lightning/core/lightning.py | 5 ++++- tests/models/test_onnx.py | 16 ++++++++++++---- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5e790c38cfcd6..becf55676c73d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -153,6 +153,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed truncated backprop through time enablement when set as a property on the LightningModule and not the Trainer ([#8804](https://github.com/PyTorchLightning/pytorch-lightning/pull/8804/)) +- Fixed an issues with export to ONNX format when a model has multiple inputs ([#8800](https://github.com/PyTorchLightning/pytorch-lightning/pull/8800)) + + ## [1.4.0] - 2021-07-27 ### Added diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index b4a50bf10a577..a2f238c0836d8 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1891,7 +1891,10 @@ def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = Non if "example_outputs" not in kwargs: self.eval() - kwargs["example_outputs"] = self(input_sample) + if isinstance(input_sample, Tuple): + kwargs["example_outputs"] = self(*input_sample) + else: + kwargs["example_outputs"] = self(input_sample) torch.onnx.export(self, input_sample, file_path, **kwargs) self.train(mode) diff --git a/tests/models/test_onnx.py b/tests/models/test_onnx.py index cec01e828d1ed..7cd1d2776f43c 100644 --- a/tests/models/test_onnx.py +++ b/tests/models/test_onnx.py @@ -23,6 +23,7 @@ from pytorch_lightning import Trainer from tests.helpers import BoringModel from tests.helpers.runif import RunIf +from tests.utilities.test_model_summary import UnorderedModel def test_model_saves_with_input_sample(tmpdir): @@ -66,10 +67,17 @@ def test_model_saves_with_example_output(tmpdir): assert os.path.exists(file_path) is True -def test_model_saves_with_example_input_array(tmpdir): - """Test that ONNX model saves with_example_input_array and size is greater than 3 MB""" - model = BoringModel() - model.example_input_array = torch.randn(5, 32) +@pytest.mark.parametrize( + ["modelclass", "input_sample"], + [ + (BoringModel, torch.randn(1, 32)), + (UnorderedModel, (torch.rand(2, 3), torch.rand(2, 10))), + ], +) +def test_model_saves_with_example_input_array(tmpdir, modelclass, input_sample): + """Test that ONNX model saves with example_input_array and size is greater than 3 MB""" + model = modelclass() + model.example_input_array = input_sample file_path = os.path.join(tmpdir, "model.onnx") model.to_onnx(file_path)