
Add trace functionality to the function to_torchscript #4142

Merged
merged 3 commits into from
Oct 14, 2020
Conversation

NumesSanguis
Contributor

@NumesSanguis NumesSanguis commented Oct 14, 2020

What does this PR do?

Add the ability to also choose to make use of TorchScript's trace() method, besides the default script()

  • (method) Adds the parameters method and example_inputs to support both modes. See the issue for the rationale. Default values ensure that, with no arguments provided, the original behaviour is kept.
  • (docs) Rewrites function description to match the extended capability
  • (tests) Adds a test case for the trace() function

Fixes #4140
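The dispatch the PR description outlines could be sketched roughly as below. This is a hypothetical free-function form for illustration only — in the PR the logic is a method on LightningModule, and a missing example_inputs falls back to self.example_input_array rather than raising:

```python
import torch


def to_torchscript(model, method="script", example_inputs=None):
    # Sketch of the method/example_inputs dispatch this PR adds. The default
    # method="script" preserves the original behaviour.
    if method not in ("script", "trace"):
        raise ValueError(f"The method {method!r} is invalid; choose 'script' or 'trace'")
    if method == "trace" and example_inputs is None:
        # The PR falls back to self.example_input_array here; a free function
        # has no such attribute, so this sketch requires explicit inputs.
        raise ValueError("method='trace' requires example_inputs")
    model.eval()
    if method == "script":
        return torch.jit.script(model)
    return torch.jit.trace(model, example_inputs)
```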

Before submitting

  • Was this discussed/approved via a GitHub issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together? Otherwise, we ask you to create a separate PR for every change.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?
  • Did you verify new and existing tests pass locally with your changes?
  • If you made a notable change (that affects users), did you update the CHANGELOG?

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues there's a high chance it will not be merged.

Did you have fun?

Yes :)

Notes

  • The documentation does not state how to run the tests locally. I did however test the new functionality in my own project.
  • The CHANGELOG did not have an entry yet for v1.1, so to prevent conflicts, I have not updated this yet.

@pep8speaks

pep8speaks commented Oct 14, 2020

Hello @NumesSanguis! Thanks for updating this PR.

There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻

Comment last updated at 2020-10-14 10:18:36 UTC

@mergify mergify bot requested a review from a team October 14, 2020 10:12
Member

@justusschock justusschock left a comment


Looks fine to me. Could you maybe also extend the example in docstrings?

@mergify mergify bot requested a review from a team October 14, 2020 12:00
@codecov

codecov bot commented Oct 14, 2020

Codecov Report

Merging #4142 into master will decrease coverage by 0%.
The diff coverage is 89%.

@@          Coverage Diff           @@
##           master   #4142   +/-   ##
======================================
- Coverage      92%     92%   -0%     
======================================
  Files         103     103           
  Lines        7792    7798    +6     
======================================
+ Hits         7147    7152    +5     
- Misses        645     646    +1     

@williamFalcon williamFalcon merged commit fa737a5 into Lightning-AI:master Oct 14, 2020
if example_inputs is None:
    example_inputs = self.example_input_array
# automatically send example inputs to the right device and use trace
torchscript_module = torch.jit.trace(func=self.eval(),
                                     example_inputs=example_inputs.to(self.device))
Contributor


Here you assume that example_input_array is a tensor, but this is not true.
If forward takes *args, example_input_array is a tuple, and if forward takes **kwargs, example_input_array must be a dict.

Contributor Author

@NumesSanguis NumesSanguis Oct 15, 2020


Given your comment on the issue that .trace() accepts either a tuple or a torch.Tensor (which is automatically converted to a tuple), it means the input should be example_input_array: Optional[Union[torch.Tensor, Tuple[torch.Tensor, ...]]]?

However, when the forward function accepts **kwargs, self.example_input_array could be a dict, in which case .trace(example_inputs=example_inputs) will fail?

What would be the best way to approach this? Does this mean that .trace() cannot be used if forward expects a dict?

Contributor

@awaelchli awaelchli Oct 17, 2020


import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)


class Net2(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x, y):
        return self.conv(x)


class Net3(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x, y):
        return self.conv(x)

# SINGLE INPUT
net = Net()
ex_inp = torch.rand(1, 1, 3, 3)
torch.jit.trace(net, ex_inp)

# TWO INPUTS
net = Net2()
torch.jit.trace(net, (ex_inp, ex_inp))

# DICT (**kwargs)
# fails
# net = Net3()
# torch.jit.trace(net, dict(x=ex_inp, y=ex_inp))

Here is an example. Tracing supports a single input and a tuple, which gets unrolled into multiple positional args. In these two cases you can use the Lightning self.example_input_array. However, a dict will not be passed as kwargs, but as a single positional input. In Lightning, by contrast, a dict means **kwargs.

I see several ways to handle it:

  1. leave as is, user needs to know how self.example_input_array works
  2. error when self.example_input_array is a dict
  3. do not even use self.example_input_array, and require the user to give inputs to the method directly

Then there is a second issue: you should use pytorch_lightning.utilities.apply_func.move_data_to_device to move the example input to the device, since it could be a tuple.
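The device-move concern can be sketched with a small recursive helper. This is an assumption-laden stand-in, not Lightning's actual implementation: it mirrors the container-walking behaviour of move_data_to_device and applies .to(device) via duck typing so the sketch stays framework-agnostic:

```python
def move_data_to_device(batch, device):
    # Recursive sketch: walk tuples/lists/dicts and call .to(device) on any
    # leaf that supports it (e.g. a torch.Tensor). Anything else is returned
    # unchanged. The real Lightning helper covers more container types.
    if isinstance(batch, (list, tuple)):
        return type(batch)(move_data_to_device(b, device) for b in batch)
    if isinstance(batch, dict):
        return {k: move_data_to_device(v, device) for k, v in batch.items()}
    if hasattr(batch, "to"):
        return batch.to(device)
    return batch
```

With something like this in place, example_inputs could be a Tensor, a tuple, or a nested container and still end up on self.device, instead of assuming a bare .to() on a single tensor.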


Contributor Author


1 & 2 could be combined by raising a warning instead of an error. From PL's side throw a warning similar to:

self.example_input_array cannot be a dict. Please provide a sample Tensor/Tuple to example_inputs as argument, or set self.example_input_array to a Tensor/Tuple.

Then output the actual error produced by .trace().
If in the future .trace() would be updated to support a dict, there is no need for a change (except removing the warning) on PL's side.

Personally, PL is for me about removing boilerplate code. Since self.example_input_array is already a thing in PL, it's better to use it. Therefore, I would advise against option 3.
I haven't used self.example_input_array personally yet, but in how many projects would this be a dict?
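The warn-then-delegate idea sketched in this comment could look like the following. The helper name is hypothetical and the warning text paraphrases the suggestion; the eventual .trace() call is left to raise its own error:

```python
import warnings


def check_example_inputs(example_inputs):
    # Hypothetical helper: warn (rather than hard-error) when tracing would
    # receive a dict, since torch.jit.trace treats a dict as a single
    # positional argument, not as **kwargs.
    if isinstance(example_inputs, dict):
        warnings.warn(
            "`self.example_input_array` cannot be a dict for method='trace'. "
            "Please provide a sample Tensor/tuple to `example_inputs`, or set "
            "`self.example_input_array` to a Tensor/tuple."
        )
    return example_inputs
```

If .trace() later gains dict support, only the warning needs removing, as the comment above points out.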

Contributor


Yes, makes sense.
Would you like to follow up on this with a PR? I would greatly appreciate it. For me the main concern is to properly move the input to the device with the function I referenced. For the way inputs are passed in, I don't have a strong opinion.

Contributor


I don't know how much future support there will be for tracing vs scripting (scripting is strongly recommended). Rather than adding more trace support at the top level of the PL module, why not override to_torchscript in your LightningModule to determine how you want to export? Then you have far more flexibility with tracing.
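The override this comment suggests could be sketched as below. Base is a stand-in for LightningModule (used here to avoid a pytorch_lightning dependency in the sketch), and MyModel and its Linear layer are illustrative:

```python
import torch
import torch.nn as nn


class Base(nn.Module):
    # Stand-in for LightningModule: the default export scripts the module.
    def to_torchscript(self):
        return torch.jit.script(self.eval())


class MyModel(Base):
    # Hypothetical user module that overrides the export strategy: it traces
    # with its own example input instead of scripting.
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(2, 2)
        self.example_input_array = torch.rand(1, 2)

    def forward(self, x):
        return self.lin(x)

    def to_torchscript(self):
        return torch.jit.trace(self.eval(), self.example_input_array)
```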

Contributor Author

@NumesSanguis NumesSanguis Oct 26, 2020


@awaelchli Ok, I'll follow up with another pull request using the move_data_to_device function.

@ananthsub edit moved my comment to the feature request, as it is a more relevant place for this discussion: #4140

Contributor Author


@awaelchli I addressed your issues in a follow-up pull request (could not be added to this one due to it already being merged):
#4360

@NumesSanguis
Contributor Author

Follow-up pull request can be found here: #4360

@NumesSanguis NumesSanguis mentioned this pull request Nov 6, 2020
1 task
Development

Successfully merging this pull request may close these issues.

Expand to_torchscript to support also TorchScript's trace method
7 participants