
move example inputs to correct device when tracing module #4360

Merged 14 commits on Oct 29, 2020. The diff below shows changes from the first 6 commits.
16 changes: 12 additions & 4 deletions pytorch_lightning/core/lightning.py
@@ -20,7 +20,7 @@
import tempfile
from abc import ABC
from argparse import Namespace
- from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, Mapping

import torch
from pytorch_lightning import _logger as log
@@ -38,6 +38,7 @@
collect_init_args,
get_init_args,
)
+ from pytorch_lightning.utilities.apply_func import move_data_to_device
from torch import ScriptModule, Tensor
from torch.nn import Module
from torch.optim.optimizer import Optimizer
@@ -1539,7 +1540,7 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwargs)

def to_torchscript(
self, file_path: Optional[str] = None, method: Optional[str] = 'script',
-         example_inputs: Optional[torch.Tensor] = None, **kwargs
+         example_inputs: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None, **kwargs
) -> Union[ScriptModule, Dict[str, ScriptModule]]:
"""
By default compiles the whole model to a :class:`~torch.jit.ScriptModule`.
@@ -1576,6 +1577,9 @@
>>> model = SimpleModel()
>>> torch.jit.save(model.to_torchscript(), "model.pt") # doctest: +SKIP
>>> os.path.isfile("model.pt") # doctest: +SKIP
+         >>> model.to_torchscript(file_path="model_trace.pt", method='trace', # doctest: +SKIP
+         ...                      example_inputs=torch.randn(1, 64)) # doctest: +SKIP
+         >>> os.path.isfile("model_trace.pt") # doctest: +SKIP
True

Return:
@@ -1591,9 +1595,13 @@
# if no example inputs are provided, try to see if model has example_input_array set
if example_inputs is None:
example_inputs = self.example_input_array
+             # dicts are not supported, so show the user an error; not raising an error to show the original error
Contributor:
I don't understand this comment; it says there is an error, but no error is raised here. What is it? Can we just remove the comment? The code should speak for itself.

Contributor (@rohitgr7, Oct 27, 2020):
The inputs must be a tensor or a tuple of tensors. IMO a better way to handle this is to wrap the input tensor in a tuple and check whether each element of the tuple is an instance of torch.Tensor.
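The normalization the reviewer suggests could be sketched as below. This is a hypothetical helper, not Lightning code; `tensor_type` stands in for `torch.Tensor` so the sketch runs without torch installed, and the function name is illustrative only.

```python
# Hypothetical sketch of the suggested validation: wrap a bare input in a
# tuple, then verify every element has the expected tensor type.
# `tensor_type` is a stand-in for torch.Tensor (an assumption for this sketch).
def normalize_example_inputs(example_inputs, tensor_type):
    # Wrap a single input so downstream code always sees a tuple.
    inputs = example_inputs if isinstance(example_inputs, tuple) else (example_inputs,)
    for i, inp in enumerate(inputs):
        if not isinstance(inp, tensor_type):
            raise TypeError(
                f"`example_inputs` element {i} must be a {tensor_type.__name__},"
                f" got {type(inp).__name__}"
            )
    return inputs
```

With this shape, a dict argument fails fast with a clear TypeError instead of surfacing a confusing error from deep inside `torch.jit.trace`.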

Contributor:
Also add a test for the same.

Contributor (author):
@awaelchli good catch, that doesn't make sense indeed.
@rohitgr7 I think trace() already wraps a single torch.Tensor in a tuple internally, so I don't think we have to add that again on Lightning's side?

Contributor (author):
I changed the comment to hopefully make more sense.

Contributor (@rohitgr7, Oct 28, 2020):
Give me some time; I need to check what the actual issue is here. Is something wrong on the PyTorch side, or are we doing something wrong here? In the meantime, can you open an issue on the PyTorch forums if possible? Maybe we can get a quick response there :) It would be good to resolve all issues in this PR itself, to avoid future issues related to to_torchscript. I will also make similar changes to to_onnx in #4378.

Contributor (author):
This issue is already in master, since the earlier pull request (#4140) was already merged. This pull request just adds quality-of-life changes on top of it. If we merge this one, it becomes much easier for other people to reproduce the issue, because they will get the same error output (this PR does not add a new problem; it is one step closer to solving it).

We can keep the original issue (#4140) open and discuss there, where it is easier to find than this comment thread. A PyTorch forum post can then point to it, and a new pull request can target the dict improvement specifically (which might go very deep), instead of making this PR huge.

Honestly, I would like to make all parts work nicely, but I'm not affected by the dict issue, and I have already spent too much time on this pull request. The previous pull request added everything needed for my use case; this one just makes it a little less rough.

Contributor:
Ok cool. Then let's remove the check for Mapping and merge this one, since it doesn't throw any error with dicts :)

Contributor (author):
@rohitgr7 Thanks. The logger error has been removed :)

Contributor (author):
@rohitgr7 I put a summary of the dict issue in #4140 (comment), which should make the discussion more visible to others.

+             if isinstance(example_inputs, Mapping):
+                 log.error("`example_inputs` should be a Tensor or a tuple of Tensors,"
+                           f" but got {type(example_inputs)}.")
              # automatically send example inputs to the right device and use trace
-             torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs.to(self.device),
-                                                  **kwargs)
+             example_inputs = move_data_to_device(example_inputs, device=self.device)
+             torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
          else:
              raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was:"
                               f"{method}")
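The key change above swaps `example_inputs.to(self.device)` (which only works for a single Tensor) for `move_data_to_device`, which also handles collections. A simplified sketch of that behavior — not the real Lightning utility, just an illustration of the recursive idea — might look like this:

```python
from collections.abc import Mapping

# Simplified sketch (an assumption, not Lightning's actual implementation)
# of what `move_data_to_device` does: recursively walk collections and
# call `.to(device)` on anything that supports it, leaving other values as-is.
def move_data_to_device(batch, device):
    if hasattr(batch, "to"):
        # Tensors (and tensor-like objects) move directly.
        return batch.to(device)
    if isinstance(batch, Mapping):
        return type(batch)((k, move_data_to_device(v, device)) for k, v in batch.items())
    if isinstance(batch, (list, tuple)):
        return type(batch)(move_data_to_device(v, device) for v in batch)
    # Non-tensor leaves (ints, strings, ...) pass through unchanged.
    return batch
```

This is why a tuple of tensors now works as `example_inputs` for tracing: each element is moved to the module's device before `torch.jit.trace` runs.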