From e371f37137f37bbc80e3aa5adfebdc52fcea4d54 Mon Sep 17 00:00:00 2001
From: Anand Inguva
Date: Sun, 5 Feb 2023 23:12:25 -0500
Subject: [PATCH] Add support for loading torchscript models

---
 .../ml/inference/pytorch_inference.py         | 75 +++++++++++++++----
 1 file changed, 61 insertions(+), 14 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index 3366d523076fb..87bf599569166 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -58,9 +58,11 @@ def _load_model(
-    model_class: torch.nn.Module, state_dict_path, device, **model_params):
-  model = model_class(**model_params)
-
+    model_class: torch.nn.Module,
+    state_dict_path,
+    device,
+    model_params,
+    use_torch_script_format=False):
   if device == torch.device('cuda') and not torch.cuda.is_available():
     logging.warning(
         "Model handler specified a 'GPU' device, but GPUs are not available. " \
         "Switching to CPU.")
     device = torch.device('cpu')
 
@@ -71,18 +73,26 @@ def _load_model(
   try:
     logging.info(
         "Loading state_dict_path %s onto a %s device", state_dict_path, device)
-    state_dict = torch.load(file, map_location=device)
+    if not use_torch_script_format:
+      model = model_class(**model_params)
+      state_dict = torch.load(file, map_location=device)
+      model.load_state_dict(state_dict)
+    else:
+      model = torch.jit.load(file, map_location=device)
   except RuntimeError as e:
     if device == torch.device('cuda'):
       message = "Loading the model onto a GPU device failed due to an " \
         f"exception:\n{e}\nAttempting to load onto a CPU device instead."
       logging.warning(message)
       return _load_model(
-          model_class, state_dict_path, torch.device('cpu'), **model_params)
+          model_class,
+          state_dict_path,
+          torch.device('cpu'),
+          model_params,
+          use_torch_script_format)
     else:
       raise e
 
-  model.load_state_dict(state_dict)
   model.to(device)
   model.eval()
   logging.info("Finished loading PyTorch model.")
@@ -149,11 +159,13 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
   def __init__(
       self,
       state_dict_path: str,
-      model_class: Callable[..., torch.nn.Module],
-      model_params: Dict[str, Any],
+      model_class: Optional[Callable[..., torch.nn.Module]] = None,
+      model_params: Optional[Dict[str, Any]] = None,
       device: str = 'CPU',
       *,
-      inference_fn: TensorInferenceFn = default_tensor_inference_fn):
+      inference_fn: TensorInferenceFn = default_tensor_inference_fn,
+      use_torch_script_format: bool = False,
+  ):
     """Implementation of the ModelHandler interface for PyTorch.
 
     Example Usage::
@@ -174,6 +186,9 @@ def __init__(
         Otherwise, it will be CPU.
       inference_fn: the inference function to use during RunInference.
         default=_default_tensor_inference_fn
+      use_torch_script_format: When set to `True`, the model is loaded via
+        `torch.jit.load()` and the `model_class` and `model_params`
+        arguments are ignored.
 
     **Supported Versions:** RunInference APIs in Apache Beam have been tested
     with PyTorch 1.9 and 1.10.
@@ -188,6 +203,18 @@ def __init__(
     self._model_class = model_class
     self._model_params = model_params
     self._inference_fn = inference_fn
+    self._use_torch_script_format = use_torch_script_format
+
+    self._validate_func_args()
+
+  def _validate_func_args(self):
+    if not self._use_torch_script_format and (self._model_class is None or
+                                              self._model_params is None):
+      raise RuntimeError(
+          "Please pass both `model_class` and `model_params` to the torch "
+          "model handler when loading the model from a state_dict. "
+          "If you opt to load the entire model saved using TorchScript, "
+          "set `use_torch_script_format` to True.")
 
   def load_model(self) -> torch.nn.Module:
     """Loads and initializes a Pytorch model for processing."""
@@ -195,7 +222,9 @@ def load_model(self) -> torch.nn.Module:
         self._model_class,
         self._state_dict_path,
         self._device,
-        **self._model_params)
+        self._model_params,
+        self._use_torch_script_format
+    )
     self._device = device
     return model
 
@@ -323,11 +352,12 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
   def __init__(
       self,
       state_dict_path: str,
-      model_class: Callable[..., torch.nn.Module],
-      model_params: Dict[str, Any],
+      model_class: Optional[Callable[..., torch.nn.Module]] = None,
+      model_params: Optional[Dict[str, Any]] = None,
       device: str = 'CPU',
       *,
-      inference_fn: KeyedTensorInferenceFn = default_keyed_tensor_inference_fn):
+      inference_fn: KeyedTensorInferenceFn = default_keyed_tensor_inference_fn,
+      use_torch_script_format: bool = False):
     """Implementation of the ModelHandler interface for PyTorch.
 
     Example Usage::
@@ -352,6 +382,9 @@ def __init__(
         Otherwise, it will be CPU.
       inference_fn: the function to invoke on run_inference.
         default = default_keyed_tensor_inference_fn
+      use_torch_script_format: When set to `True`, the model is loaded via
+        `torch.jit.load()` and the `model_class` and `model_params`
+        arguments are ignored.
 
     **Supported Versions:** RunInference APIs in Apache Beam have been tested
     on torch>=1.9.0,<1.14.0.
@@ -366,6 +399,9 @@ def __init__(
     self._model_class = model_class
     self._model_params = model_params
     self._inference_fn = inference_fn
+    self._use_torch_script_format = use_torch_script_format
+
+    self._validate_func_args()
 
   def load_model(self) -> torch.nn.Module:
     """Loads and initializes a Pytorch model for processing."""
@@ -373,7 +409,9 @@ def load_model(self) -> torch.nn.Module:
         self._model_class,
         self._state_dict_path,
         self._device,
-        **self._model_params)
+        self._model_params,
+        self._use_torch_script_format
+    )
     self._device = device
     return model
 
@@ -429,3 +467,12 @@ def get_metrics_namespace(self) -> str:
 
   def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
     pass
+
+  def _validate_func_args(self):
+    if not self._use_torch_script_format and (self._model_class is None or
+                                              self._model_params is None):
+      raise RuntimeError(
+          "Please pass both `model_class` and `model_params` to the torch "
+          "model handler when loading the model from a state_dict. "
+          "If you opt to load the entire model saved using TorchScript, "
+          "set `use_torch_script_format` to True.")