Add trace functionality to the function to_torchscript #4142
Conversation
Hello @NumesSanguis! Thanks for updating this PR. There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻 Comment last updated at 2020-10-14 10:18:36 UTC
Looks fine to me. Could you maybe also extend the example in the docstrings?
Codecov Report

@@           Coverage Diff           @@
##           master   #4142    +/-  ##
======================================
- Coverage      92%     92%     -0%
======================================
  Files         103     103
  Lines        7792    7798      +6
======================================
+ Hits         7147    7152      +5
- Misses        645     646      +1
if example_inputs is None:
    example_inputs = self.example_input_array
# automatically send example inputs to the right device and use trace
torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs.to(self.device),
Here you assume that example_input_array is a tensor, but this is not true. If forward takes *args, example_input_array is a tuple, and if forward takes **kwargs, example_input_array must be a dict.
Given your comment on the issue that .trace() accepts either a tuple or a torch.Tensor (which is automatically converted to a tuple), does that mean the input should be example_input_array: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]]?
However, when the forward function accepts **kwargs, self.example_input_array could be a dict, in which case .trace(example_inputs=example_inputs) will fail?
What would be the best way to approach this? Does this mean that .trace() cannot be used if forward expects a dict?
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)


class Net2(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x, y):
        return self.conv(x)


class Net3(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x, y):
        return self.conv(x)


# SINGLE INPUT
net = Net()
ex_inp = torch.rand(1, 1, 3, 3)
torch.jit.trace(net, ex_inp)

# TWO INPUTS
net = Net2()
torch.jit.trace(net, (ex_inp, ex_inp))

# DICT (**kwargs)
# fails
# net = Net3()
# torch.jit.trace(net, dict(x=ex_inp, y=ex_inp))
Here is an example. Tracing supports a single input and a tuple, which gets unpacked into multiple positional args. In these two cases, you can use the Lightning self.example_input_array. However, a dict will not be passed as kwargs, but as a single input. In Lightning, however, a dict would mean **kwargs.
I see several ways to handle it:
1. leave as is; the user needs to know how self.example_input_array works
2. error when self.example_input_array is a dict
3. do not even use self.example_input_array, and require the user to give inputs to the method directly
Then there is a second issue: you should use pytorch_lightning.utilities.apply_func.move_data_to_device to move the example input to the device, since it could be a tuple.
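For illustration, a minimal sketch of that suggestion (assuming, as in the diff above, that this code runs inside to_torchscript; variable names follow the snippet under review):

import torch
from pytorch_lightning.utilities.apply_func import move_data_to_device

if example_inputs is None:
    example_inputs = self.example_input_array
# unlike Tensor.to(), move_data_to_device also handles tensors
# nested inside tuples/lists/dicts
example_inputs = move_data_to_device(example_inputs, self.device)
torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs)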
cc @ananthsub
1 & 2 could be combined by raising a warning instead of an error. From PL's side throw a warning similar to:
self.example_input_array
cannot be a dict. Please provide a sample Tensor/Tuple toexample_inputs
as argument, or setself.example_input_array
to a Tensor/Tuple.
Then output the actual error produced by .trace()
.
If in the future .trace()
would be updated to support a dict, there is no need for a change (except removing the warning) on PL's side.
Personally, PL is for me about removing boilerplate code. Since self.example_input_array
is already a thing in PL, it's better to use it. Therefore, I would advise against option 3.
I haven't used self.example_input_array
personally yet, but in how many projects would this be a dict?
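A hypothetical sketch of options 1 & 2 combined (the exact message wording and the use of warnings.warn are illustrative, not from the PR):

import warnings

if isinstance(example_inputs, dict):
    warnings.warn(
        "self.example_input_array cannot be a dict. Please provide a sample "
        "Tensor/Tuple to example_inputs as argument, or set "
        "self.example_input_array to a Tensor/Tuple."
    )
# then let torch.jit.trace() raise its own error on the unsupported input
torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs)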
Yes, makes sense.
Would you like to follow up on this with a PR? I would greatly appreciate it. For me, the main concern is to properly move the input to the device with the function I referenced. For the way inputs are passed in, I don't have a strong opinion.
IDK how much future support there will be for tracing vs scripting (scripting is strongly recommended). Rather than adding more trace support at the top level of the PL module, why not override to_torchscript in your LightningModule to determine how you want to export? Then you have way more flexibility with tracing.
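For example, a hypothetical override could look like this (sketch only; it assumes example_input_array is a single tensor and ignores the file_path/**kwargs handling of the real method):

import torch
import torch.nn as nn
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)
        self.example_input_array = torch.rand(1, 1, 3, 3)

    def forward(self, x):
        return self.conv(x)

    def to_torchscript(self, file_path=None):
        # export via tracing instead of the default scripting
        mode = self.training
        traced = torch.jit.trace(self.eval(), self.example_input_array.to(self.device))
        self.train(mode)  # restore the original train/eval mode
        if file_path is not None:
            torch.jit.save(traced, file_path)
        return traced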
@awaelchli OK, I'll follow up with another pull request using the move_data_to_device function.
@ananthsub Edit: moved my comment to the feature request, as it is a more relevant place for this discussion: #4140
@awaelchli I addressed your issues in a follow-up pull request (it could not be added to this one, as it had already been merged): #4360
What does this PR do?
- Adds the ability to also choose to make use of TorchScript's trace() method, besides the default script() method.
- Adds the arguments method and example_inputs to support both modes. See the issue for the rationale.
- Default values assure that with no arguments provided, the original behaviour is kept.
- Adds a test for the trace() function.

Fixes #4140
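With the new arguments, usage looks roughly like this (a sketch; LitModel stands in for any LightningModule with example_input_array set):

import torch

model = LitModel()
scripted = model.to_torchscript()  # default behaviour, uses torch.jit.script
traced = model.to_torchscript(method='trace', example_inputs=torch.rand(1, 1, 3, 3))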
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Yes :)
Notes