From cbbf04f9fdbefe7cfd7a2a207b66a12e26845fa6 Mon Sep 17 00:00:00 2001
From: stef-ubuntu
Date: Mon, 26 Oct 2020 14:42:02 +0900
Subject: [PATCH 01/10] Use move_data_to_device instead of .to(); allow a
 tuple of Tensors in the docstring; log an error when example_inputs is an
 unsupported dict; comment out the docstring trace example

---
 pytorch_lightning/core/lightning.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 065b29c75da37..575e72cc1c391 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -38,6 +38,7 @@
     collect_init_args,
     get_init_args,
 )
+from pytorch_lightning.utilities.apply_func import move_data_to_device
 from torch import ScriptModule, Tensor
 from torch.nn import Module
 from torch.optim.optimizer import Optimizer
@@ -1539,7 +1540,7 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwarg
 
     def to_torchscript(
         self, file_path: Optional[str] = None, method: Optional[str] = 'script',
-        example_inputs: Optional[torch.Tensor] = None, **kwargs
+        example_inputs: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None, **kwargs
     ) -> Union[ScriptModule, Dict[str, ScriptModule]]:
         """
         By default compiles the whole model to a :class:`~torch.jit.ScriptModule`.
@@ -1576,6 +1577,8 @@ def to_torchscript(
             >>> model = SimpleModel()
             >>> torch.jit.save(model.to_torchscript(), "model.pt")  # doctest: +SKIP
             >>> os.path.isfile("model.pt")  # doctest: +SKIP
+            # >>> torch.jit.save(model.to_torchscript(method='trace', example_inputs=torch.randn(1, 64)), "model_trace.pt") # doctest: +SKIP
+            # >>> os.path.isfile("model_trace.pt") # doctest: +SKIP
             True
 
         Return:
@@ -1591,8 +1594,12 @@ def to_torchscript(
             # if no example inputs are provided, try to see if model has example_input_array set
             if example_inputs is None:
                 example_inputs = self.example_input_array
+            # dicts are not supported, so show the user an error; not raising an error to show the original error
+            if type(example_inputs) == dict:
+                log.error(f"`example_inputs` should be a Tensor or a tuple of Tensors, but got a dict.")
             # automatically send example inputs to the right device and use trace
-            torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs.to(self.device),
+            example_input_array_device = move_data_to_device(example_inputs, device=self.device)
+            torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_input_array_device,
                                                  **kwargs)
         else:
             raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was:"

From fb724164547bf000dd7e698ca3974eff1ac6eddb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stef=20=7C=20=E3=82=B9=E3=83=86=E3=83=95?=
Date: Mon, 26 Oct 2020 17:15:08 +0900
Subject: [PATCH 02/10] Use isinstance to check if example_inputs is a
 Mapping, instead of type

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
---
 pytorch_lightning/core/lightning.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 575e72cc1c391..2a3d94bd08c45 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1595,7 +1595,7 @@ def to_torchscript(
             if example_inputs is None:
                 example_inputs = self.example_input_array
             # dicts are not supported, so show the user an error; not raising an error to show the original error
-            if type(example_inputs) == dict:
+            if isinstance(example_inputs, Mapping):
                 log.error(f"`example_inputs` should be a Tensor or a tuple of Tensors, but got a dict.")
             # automatically send example inputs to the right device and use trace
             example_input_array_device = move_data_to_device(example_inputs, device=self.device)

From 8421c5e5d58f0924aafc4d268dc14bd9908d4b7f Mon Sep 17 00:00:00 2001
From: stef-ubuntu
Date: Mon, 26 Oct 2020 17:47:52 +0900
Subject: [PATCH 03/10] import Mapping for isinstance check

---
 pytorch_lightning/core/lightning.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 2a3d94bd08c45..9a4e33dc231dd 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -20,7 +20,7 @@
 import tempfile
 from abc import ABC
 from argparse import Namespace
-from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, Mapping
 
 import torch
 from pytorch_lightning import _logger as log

From 8cb11dab52bd2bf6b7e78d9b325b31710579877e Mon Sep 17 00:00:00 2001
From: stef-ubuntu
Date: Mon, 26 Oct 2020 18:20:17 +0900
Subject: [PATCH 04/10] multi-line docstring code to test TorchScript trace()

---
 pytorch_lightning/core/lightning.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 9a4e33dc231dd..f2810d08b9020 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1577,8 +1577,9 @@ def to_torchscript(
             >>> model = SimpleModel()
             >>> torch.jit.save(model.to_torchscript(), "model.pt")  # doctest: +SKIP
             >>> os.path.isfile("model.pt")  # doctest: +SKIP
-            # >>> torch.jit.save(model.to_torchscript(method='trace', example_inputs=torch.randn(1, 64)), "model_trace.pt") # doctest: +SKIP
-            # >>> os.path.isfile("model_trace.pt") # doctest: +SKIP
+            >>> torch.jit.save(model.to_torchscript(file_path="model_trace.pt", method='trace', # doctest: +SKIP
+            ...                          example_inputs=torch.randn(1, 64)))  # doctest: +SKIP
+            >>> os.path.isfile("model_trace.pt")  # doctest: +SKIP
             True
 
         Return:

From 95c00b74bce51599f232b86d6ad6bcb17ff89752 Mon Sep 17 00:00:00 2001
From: stef-ubuntu
Date: Mon, 26 Oct 2020 18:21:40 +0900
Subject: [PATCH 05/10] Fix PEP8 f-string is missing placeholders

---
 pytorch_lightning/core/lightning.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index f2810d08b9020..6a73bddcd7ee6 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1597,7 +1597,8 @@ def to_torchscript(
                 example_inputs = self.example_input_array
             # dicts are not supported, so show the user an error; not raising an error to show the original error
             if isinstance(example_inputs, Mapping):
-                log.error(f"`example_inputs` should be a Tensor or a tuple of Tensors, but got a dict.")
+                log.error("`example_inputs` should be a Tensor or a tuple of Tensors,"
+                          f"but got {type(example_inputs)}.")
             # automatically send example inputs to the right device and use trace
             example_input_array_device = move_data_to_device(example_inputs, device=self.device)
             torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_input_array_device,

From f5f7cf3addde176972169d8021cf22b7194e25dc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 27 Oct 2020 09:41:52 +0100
Subject: [PATCH 06/10] minor code style improvements

---
 pytorch_lightning/core/lightning.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 6a73bddcd7ee6..fb47648632e25 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1598,11 +1598,10 @@ def to_torchscript(
             # dicts are not supported, so show the user an error; not raising an error to show the original error
             if isinstance(example_inputs, Mapping):
                 log.error("`example_inputs` should be a Tensor or a tuple of Tensors,"
-                          f"but got {type(example_inputs)}.")
+                          f" but got {type(example_inputs)}.")
             # automatically send example inputs to the right device and use trace
-            example_input_array_device = move_data_to_device(example_inputs, device=self.device)
-            torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_input_array_device,
-                                                 **kwargs)
+            example_inputs = move_data_to_device(example_inputs, device=self.device)
+            torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
         else:
             raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was:"
                              f"{method}")

From 7e769e8aecfd3b3a2f5c3130cb0e63be1a0bb7f3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stef=20=7C=20=E3=82=B9=E3=83=86=E3=83=95?=
Date: Tue, 27 Oct 2020 18:25:20 +0900
Subject: [PATCH 07/10] Use (possibly user-overridden) transfer_batch_to_device
 instead of move_data_to_device

Co-authored-by: Rohit Gupta
---
 pytorch_lightning/core/lightning.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index fb47648632e25..4d211f8d4a899 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1600,7 +1600,7 @@ def to_torchscript(
                 log.error("`example_inputs` should be a Tensor or a tuple of Tensors,"
                           f" but got {type(example_inputs)}.")
             # automatically send example inputs to the right device and use trace
-            example_inputs = move_data_to_device(example_inputs, device=self.device)
+            example_inputs = self.transfer_batch_to_device(example_inputs, device=self.device)
             torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
         else:
             raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was:"

From c1fc753f3061fcdeb4c63a481ad2c89099e03c05 Mon Sep 17 00:00:00 2001
From: stef-ubuntu
Date: Tue, 27 Oct 2020 18:36:10 +0900
Subject: [PATCH 08/10] fixed weird comment about trace() log error

---
 pytorch_lightning/core/lightning.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 4d211f8d4a899..828bcef78d918 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1595,7 +1595,7 @@ def to_torchscript(
             # if no example inputs are provided, try to see if model has example_input_array set
             if example_inputs is None:
                 example_inputs = self.example_input_array
-            # dicts are not supported, so show the user an error; not raising an error to show the original error
+            # dicts are not supported for example_inputs, so show a user-friendly message about what is wrong
             if isinstance(example_inputs, Mapping):
                 log.error("`example_inputs` should be a Tensor or a tuple of Tensors,"
                           f" but got {type(example_inputs)}.")

From d3da5f16b65122bc26e9641904b6184b208d8283 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stef=20=7C=20=E3=82=B9=E3=83=86=E3=83=95?=
Date: Wed, 28 Oct 2020 10:08:31 +0900
Subject: [PATCH 09/10] Remove unused import

Co-authored-by: Jeff Yang
---
 pytorch_lightning/core/lightning.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 828bcef78d918..ff8733a6f5886 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -38,7 +38,6 @@
     collect_init_args,
     get_init_args,
 )
-from pytorch_lightning.utilities.apply_func import move_data_to_device
 from torch import ScriptModule, Tensor
 from torch.nn import Module
 from torch.optim.optimizer import Optimizer

From a8c7e15e69f38b1a653aa4082fe2bff7c806ccf6 Mon Sep 17 00:00:00 2001
From: stef-ubuntu
Date: Thu, 29 Oct 2020 14:20:39 +0900
Subject: [PATCH 10/10] Remove logger warning about dict example_inputs not
 being supported by trace

---
 pytorch_lightning/core/lightning.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index ff8733a6f5886..22d63d0a03a74 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1594,10 +1594,6 @@ def to_torchscript(
             # if no example inputs are provided, try to see if model has example_input_array set
             if example_inputs is None:
                 example_inputs = self.example_input_array
-            # dicts are not supported for example_inputs, so show a user-friendly message about what is wrong
-            if isinstance(example_inputs, Mapping):
-                log.error("`example_inputs` should be a Tensor or a tuple of Tensors,"
-                          f" but got {type(example_inputs)}.")
             # automatically send example inputs to the right device and use trace
             example_inputs = self.transfer_batch_to_device(example_inputs, device=self.device)
             torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
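
For reference, a minimal usage sketch of the trace path patched in above, once the series is
applied. The SimpleModel below is only illustrative (it mirrors the docstring example and is
not taken from the patches); the behaviour it relies on is what the diffs implement:
to_torchscript(method='trace') falls back to self.example_input_array when example_inputs is
not passed, and moves the example batch to the module's device before calling torch.jit.trace.

    import torch
    from torch import nn
    import pytorch_lightning as pl

    class SimpleModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.l1 = nn.Linear(in_features=64, out_features=4)
            # with example_input_array set, method='trace' needs no explicit example_inputs
            self.example_input_array = torch.randn(1, 64)

        def forward(self, x):
            return torch.relu(self.l1(x))

    model = SimpleModel()
    # trace-based export; example_inputs defaults to model.example_input_array
    traced = model.to_torchscript(file_path="model_trace.pt", method='trace')
    # a Tensor (or tuple of Tensors) can also be passed explicitly
    traced = model.to_torchscript(method='trace', example_inputs=torch.randn(1, 64))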