Skip to content

Commit

Permalink
Allow exporting to onnx when input is tuple
Browse files Browse the repository at this point in the history
Fixes #8799
  • Loading branch information
xerus committed Aug 11, 2021
1 parent 24f0124 commit 10bb958
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed truncated backprop through time enablement when set as a property on the LightningModule and not the Trainer ([#8804](https://github.com/PyTorchLightning/pytorch-lightning/pull/8804/))


- Fixed an issues with export to ONNX format when a model has multiple inputs ([#8800](https://github.com/PyTorchLightning/pytorch-lightning/pull/8800))


## [1.4.0] - 2021-07-27

### Added
Expand Down
5 changes: 4 additions & 1 deletion pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -1891,7 +1891,10 @@ def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = Non

if "example_outputs" not in kwargs:
self.eval()
kwargs["example_outputs"] = self(input_sample)
if isinstance(input_sample, Tuple):
kwargs["example_outputs"] = self(*input_sample)
else:
kwargs["example_outputs"] = self(input_sample)

torch.onnx.export(self, input_sample, file_path, **kwargs)
self.train(mode)
Expand Down
16 changes: 12 additions & 4 deletions tests/models/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pytorch_lightning import Trainer
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
from tests.utilities.test_model_summary import UnorderedModel


def test_model_saves_with_input_sample(tmpdir):
Expand Down Expand Up @@ -66,10 +67,17 @@ def test_model_saves_with_example_output(tmpdir):
assert os.path.exists(file_path) is True


def test_model_saves_with_example_input_array(tmpdir):
"""Test that ONNX model saves with_example_input_array and size is greater than 3 MB"""
model = BoringModel()
model.example_input_array = torch.randn(5, 32)
@pytest.mark.parametrize(
["modelclass", "input_sample"],
[
(BoringModel, torch.randn(1, 32)),
(UnorderedModel, (torch.rand(2, 3), torch.rand(2, 10))),
],
)
def test_model_saves_with_example_input_array(tmpdir, modelclass, input_sample):
"""Test that ONNX model saves with example_input_array and size is greater than 3 MB"""
model = modelclass()
model.example_input_array = input_sample

file_path = os.path.join(tmpdir, "model.onnx")
model.to_onnx(file_path)
Expand Down

0 comments on commit 10bb958

Please sign in to comment.