diff --git a/ts/torch_handler/base_handler.py b/ts/torch_handler/base_handler.py index 711e24956c..3143d28472 100644 --- a/ts/torch_handler/base_handler.py +++ b/ts/torch_handler/base_handler.py @@ -313,7 +313,7 @@ def inference(self, data, *args, **kwargs): Returns: Torch Tensor : The Predicted Torch Tensor is returned in this function. """ - with torch.no_grad(): + with torch.inference_mode(): marshalled_data = data.to(self.device) results = self.model(marshalled_data, *args, **kwargs) return results