
Add support for loading torchscript models
AnandInguva committed Feb 6, 2023
1 parent 16cb63b commit e371f37
Showing 1 changed file with 61 additions and 14 deletions.
sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -58,9 +58,11 @@


 def _load_model(
-    model_class: torch.nn.Module, state_dict_path, device, **model_params):
-  model = model_class(**model_params)
-
+    model_class: torch.nn.Module,
+    state_dict_path,
+    device,
+    model_params,
+    use_torch_script_format=False):
   if device == torch.device('cuda') and not torch.cuda.is_available():
     logging.warning(
         "Model handler specified a 'GPU' device, but GPUs are not available. " \
@@ -71,18 +73,26 @@ def _load_model(
   try:
     logging.info(
         "Loading state_dict_path %s onto a %s device", state_dict_path, device)
-    state_dict = torch.load(file, map_location=device)
+    if not use_torch_script_format:
+      model = model_class(**model_params)
+      state_dict = torch.load(file, map_location=device)
+      model.load_state_dict(state_dict)
+    else:
+      model = torch.jit.load(file, map_location=device)
   except RuntimeError as e:
     if device == torch.device('cuda'):
       message = "Loading the model onto a GPU device failed due to an " \
        f"exception:\n{e}\nAttempting to load onto a CPU device instead."
       logging.warning(message)
       return _load_model(
-          model_class, state_dict_path, torch.device('cpu'), **model_params)
+          model_class,
+          state_dict_path,
+          torch.device('cpu'),
+          model_params,
+          use_torch_script_format)
     else:
       raise e

-  model.load_state_dict(state_dict)
   model.to(device)
   model.eval()
   logging.info("Finished loading PyTorch model.")
@@ -149,11 +159,13 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
   def __init__(
       self,
       state_dict_path: str,
-      model_class: Callable[..., torch.nn.Module],
-      model_params: Dict[str, Any],
+      model_class: Optional[Callable[..., torch.nn.Module]] = None,
+      model_params: Optional[Dict[str, Any]] = None,
       device: str = 'CPU',
       *,
-      inference_fn: TensorInferenceFn = default_tensor_inference_fn):
+      inference_fn: TensorInferenceFn = default_tensor_inference_fn,
+      use_torch_script_format=False,
+  ):
     """Implementation of the ModelHandler interface for PyTorch.

     Example Usage::
@@ -174,6 +186,9 @@ def __init__(
         Otherwise, it will be CPU.
       inference_fn: the inference function to use during RunInference.
         default=_default_tensor_inference_fn
+      use_torch_script_format: When `use_torch_script_format` is set to `True`,
+        the model will be loaded using `torch.jit.load()`. The `model_class`
+        and `model_params` arguments will be disregarded.

     **Supported Versions:** RunInference APIs in Apache Beam have been tested
     with PyTorch 1.9 and 1.10.
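
For illustration (not part of the diff), a pipeline could opt into the TorchScript path like this — a sketch assuming a scripted model saved at 'linear_model.pt' as in the earlier example:

    import apache_beam as beam
    import torch
    from apache_beam.ml.inference.base import RunInference
    from apache_beam.ml.inference.pytorch_inference import (
        PytorchModelHandlerTensor)

    # model_class and model_params are omitted: with
    # use_torch_script_format=True the handler restores the full model
    # via torch.jit.load() instead of instantiating the class itself.
    model_handler = PytorchModelHandlerTensor(
        state_dict_path='linear_model.pt',
        use_torch_script_format=True)

    with beam.Pipeline() as p:
      _ = (
          p
          | beam.Create([torch.ones(10)])
          | RunInference(model_handler)
          | beam.Map(print))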
@@ -188,14 +203,28 @@
     self._model_class = model_class
     self._model_params = model_params
     self._inference_fn = inference_fn
+    self._use_torch_script_format = use_torch_script_format
+
+    self._validate_func_args()
+
+  def _validate_func_args(self):
+    if not self._use_torch_script_format and (self._model_class is None or
+                                              self._model_params is None):
+      raise RuntimeError(
+          "Please pass both `model_class` and `model_params` to the torch "
+          "model handler when using it with PyTorch. "
+          "If you opt to load the entire model that was saved using "
+          "TorchScript, set `use_torch_script_format` to True.")

   def load_model(self) -> torch.nn.Module:
     """Loads and initializes a Pytorch model for processing."""
     model, device = _load_model(
         self._model_class,
         self._state_dict_path,
         self._device,
-        **self._model_params)
+        self._model_params,
+        self._use_torch_script_format
+    )
     self._device = device
     return model

@@ -323,11 +352,12 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
   def __init__(
       self,
       state_dict_path: str,
-      model_class: Callable[..., torch.nn.Module],
-      model_params: Dict[str, Any],
+      model_class: Optional[Callable[..., torch.nn.Module]] = None,
+      model_params: Optional[Dict[str, Any]] = None,
       device: str = 'CPU',
       *,
-      inference_fn: KeyedTensorInferenceFn = default_keyed_tensor_inference_fn):
+      inference_fn: KeyedTensorInferenceFn = default_keyed_tensor_inference_fn,
+      use_torch_script_format: bool = False):
     """Implementation of the ModelHandler interface for PyTorch.

     Example Usage::
@@ -352,6 +382,9 @@ def __init__(
         Otherwise, it will be CPU.
       inference_fn: the function to invoke on run_inference.
         default = default_keyed_tensor_inference_fn
+      use_torch_script_format: When `use_torch_script_format` is set to `True`,
+        the model will be loaded using `torch.jit.load()`. The `model_class`
+        and `model_params` arguments will be disregarded.

     **Supported Versions:** RunInference APIs in Apache Beam have been tested
     on torch>=1.9.0,<1.14.0.
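
The keyed variant takes the same flag. A construction sketch (the artifact path is hypothetical; the scripted model's forward would need to accept the keyed tensors the pipeline supplies):

    from apache_beam.ml.inference.pytorch_inference import (
        PytorchModelHandlerKeyedTensor)

    # model_class and model_params stay None; the TorchScript archive
    # carries the full model definition.
    keyed_handler = PytorchModelHandlerKeyedTensor(
        state_dict_path='keyed_model.pt',
        use_torch_script_format=True)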
@@ -366,14 +399,19 @@
     self._model_class = model_class
     self._model_params = model_params
     self._inference_fn = inference_fn
+    self._use_torch_script_format = use_torch_script_format
+
+    self._validate_func_args()

   def load_model(self) -> torch.nn.Module:
     """Loads and initializes a Pytorch model for processing."""
     model, device = _load_model(
         self._model_class,
         self._state_dict_path,
         self._device,
-        **self._model_params)
+        self._model_params,
+        self._use_torch_script_format
+    )
     self._device = device
     return model

@@ -429,3 +467,12 @@ def get_metrics_namespace(self) -> str:

   def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
     pass
+
+  def _validate_func_args(self):
+    if not self._use_torch_script_format and (self._model_class is None or
+                                              self._model_params is None):
+      raise RuntimeError(
+          "Please pass both `model_class` and `model_params` to the torch "
+          "model handler when using it with PyTorch. "
+          "If you opt to load the entire model that was saved using "
+          "TorchScript, set `use_torch_script_format` to True.")
