Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow exporting to onnx when input is tuple #8800

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed truncated backprop through time enablement when set as a property on the LightningModule and not the Trainer ([#8804](https://github.com/PyTorchLightning/pytorch-lightning/pull/8804/))


- Fixed an issues with export to ONNX format when a model has multiple inputs ([#8800](https://github.com/PyTorchLightning/pytorch-lightning/pull/8800))


## [1.4.0] - 2021-07-27

### Added
Expand Down
5 changes: 4 additions & 1 deletion pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -1891,7 +1891,10 @@ def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = Non

if "example_outputs" not in kwargs:
self.eval()
kwargs["example_outputs"] = self(input_sample)
if isinstance(input_sample, Tuple):
kwargs["example_outputs"] = self(*input_sample)
else:
kwargs["example_outputs"] = self(input_sample)
ananthsub marked this conversation as resolved.
Show resolved Hide resolved

torch.onnx.export(self, input_sample, file_path, **kwargs)
self.train(mode)
Expand Down
16 changes: 12 additions & 4 deletions tests/models/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pytorch_lightning import Trainer
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
from tests.utilities.test_model_summary import UnorderedModel


def test_model_saves_with_input_sample(tmpdir):
Expand Down Expand Up @@ -66,10 +67,17 @@ def test_model_saves_with_example_output(tmpdir):
assert os.path.exists(file_path) is True


def test_model_saves_with_example_input_array(tmpdir):
"""Test that ONNX model saves with_example_input_array and size is greater than 3 MB"""
model = BoringModel()
model.example_input_array = torch.randn(5, 32)
@pytest.mark.parametrize(
["modelclass", "input_sample"],
[
(BoringModel, torch.randn(1, 32)),
(UnorderedModel, (torch.rand(2, 3), torch.rand(2, 10))),
],
)
def test_model_saves_with_example_input_array(tmpdir, modelclass, input_sample):
"""Test that ONNX model saves with example_input_array and size is greater than 3 MB"""
model = modelclass()
model.example_input_array = input_sample

file_path = os.path.join(tmpdir, "model.onnx")
model.to_onnx(file_path)
Expand Down