Skip to content

Commit

Permalink
Allow exporting to onnx when input is tuple (#8800)
Browse files Browse the repository at this point in the history
Fixes #8799
  • Loading branch information
xerus authored and justusschock committed Sep 7, 2021
1 parent 881285b commit 6130540
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `accelerator=ddp` choice for CPU ([#8645](https://github.com/PyTorchLightning/pytorch-lightning/pull/8645))


- Fixed an issues with export to ONNX format when a model has multiple inputs ([#8800](https://github.com/PyTorchLightning/pytorch-lightning/pull/8800))


## [1.4.0] - 2021-07-27

### Added
Expand Down
5 changes: 4 additions & 1 deletion pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -1892,7 +1892,10 @@ def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = Non

if "example_outputs" not in kwargs:
self.eval()
kwargs["example_outputs"] = self(input_sample)
if isinstance(input_sample, Tuple):
kwargs["example_outputs"] = self(*input_sample)
else:
kwargs["example_outputs"] = self(input_sample)

torch.onnx.export(self, input_sample, file_path, **kwargs)
self.train(mode)
Expand Down
16 changes: 12 additions & 4 deletions tests/models/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pytorch_lightning import Trainer
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
from tests.utilities.test_model_summary import UnorderedModel


def test_model_saves_with_input_sample(tmpdir):
Expand Down Expand Up @@ -66,10 +67,17 @@ def test_model_saves_with_example_output(tmpdir):
assert os.path.exists(file_path) is True


def test_model_saves_with_example_input_array(tmpdir):
"""Test that ONNX model saves with_example_input_array and size is greater than 3 MB"""
model = BoringModel()
model.example_input_array = torch.randn(5, 32)
@pytest.mark.parametrize(
["modelclass", "input_sample"],
[
(BoringModel, torch.randn(1, 32)),
(UnorderedModel, (torch.rand(2, 3), torch.rand(2, 10))),
],
)
def test_model_saves_with_example_input_array(tmpdir, modelclass, input_sample):
"""Test that ONNX model saves with example_input_array and size is greater than 3 MB"""
model = modelclass()
model.example_input_array = input_sample

file_path = os.path.join(tmpdir, "model.onnx")
model.to_onnx(file_path)
Expand Down

0 comments on commit 6130540

Please sign in to comment.