From 9cfd29946a308526cee088c176a2285e6752453c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stef=20=7C=20=E3=82=B9=E3=83=86=E3=83=95?= Date: Thu, 29 Oct 2020 14:46:57 +0900 Subject: [PATCH] move example inputs to correct device when tracing module (#4360) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * use move_data_to_device instead of to; docstring also allow tuple of Tensor; not supported log error when example_inputs is a dict; commented docstring trace example * Use isinstance to check if example_inputs is a Mapping, instead of type Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> * import Mapping for isinstance check * multi-line docstring code to test TorchScript trace() * Fix PEP8 f-string is missing placeholders * minor code style improvements * Use (possibly user overwritten) transfer_batch_to_device instead of move_data_to_device Co-authored-by: Rohit Gupta * fixed weird comment about trace() log error * Remove unused import Co-authored-by: Jeff Yang * Remove logger warning about dict not example_inputs not supported by trace Co-authored-by: stef-ubuntu Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> Co-authored-by: Adrian Wälchli Co-authored-by: Rohit Gupta Co-authored-by: Jeff Yang --- pytorch_lightning/core/lightning.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 065b29c75da37..22d63d0a03a74 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -20,7 +20,7 @@ import tempfile from abc import ABC from argparse import Namespace -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, Mapping import torch from pytorch_lightning import _logger as log @@ -1539,7 +1539,7 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwarg def to_torchscript( self, file_path: Optional[str] = None, method: Optional[str] = 'script', - example_inputs: Optional[torch.Tensor] = None, **kwargs + example_inputs: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None, **kwargs ) -> Union[ScriptModule, Dict[str, ScriptModule]]: """ By default compiles the whole model to a :class:`~torch.jit.ScriptModule`. @@ -1576,6 +1576,9 @@ def to_torchscript( >>> model = SimpleModel() >>> torch.jit.save(model.to_torchscript(), "model.pt") # doctest: +SKIP >>> os.path.isfile("model.pt") # doctest: +SKIP + >>> torch.jit.save(model.to_torchscript(file_path="model_trace.pt", method='trace', # doctest: +SKIP + ... example_inputs=torch.randn(1, 64))) # doctest: +SKIP + >>> os.path.isfile("model_trace.pt") # doctest: +SKIP True Return: @@ -1592,8 +1595,8 @@ def to_torchscript( if example_inputs is None: example_inputs = self.example_input_array # automatically send example inputs to the right device and use trace - torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs.to(self.device), - **kwargs) + example_inputs = self.transfer_batch_to_device(example_inputs, device=self.device) + torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs) else: raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was:" f"{method}")