diff --git a/py/torch_tensorrt/runtime/_cudagraphs.py b/py/torch_tensorrt/runtime/_cudagraphs.py
index 6bdc07cbce..d1564cb4dc 100644
--- a/py/torch_tensorrt/runtime/_cudagraphs.py
+++ b/py/torch_tensorrt/runtime/_cudagraphs.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Any
+from typing import Any, Union
 
 import torch
 import torch_tensorrt
@@ -88,17 +88,15 @@ def __enter__(self) -> torch.nn.Module:
             torch.ops.tensorrt.set_cudagraphs_mode(_PY_RT_CUDAGRAPHS)
 
             logger.debug(
-                f"{num_torch_module} torch modules are in subgraphs. Using wrapper module for cuda graphs"
+                "Found PyTorch subgraphs in module, wrapping module in CudaGraphsTorchTensorRTModule"
             )
             return CudaGraphsTorchTensorRTModule(self.compiled_module)
         else:
             if num_trt_module > 0:
-                logger.debug(
-                    "There is no graph breaks. Using original module for cuda graphs"
-                )
+                logger.debug("No graph breaks detected, using runtime cudagraphs mode")
             else:
-                logger.warning(
-                    "Please consider dynamo if there is graph breaks. Using original module for cuda graphs"
+                logger.debug(
+                    "Consider using dynamo if there are graph breaks. Using runtime cudagraphs mode"
                 )
             # Enable cudagraphs for TRT submodule
             set_cudagraphs_mode(True)
@@ -110,6 +108,6 @@ def __exit__(self, *args: Any) -> None:
 
 
 def enable_cudagraphs(
-    compiled_module: torch.nn.Module,
+    compiled_module: Union[torch.fx.GraphModule, torch.nn.Module],
 ) -> _CudagraphsContextManager:
     return _CudagraphsContextManager(compiled_module)
diff --git a/py/torch_tensorrt/runtime/_weight_streaming.py b/py/torch_tensorrt/runtime/_weight_streaming.py
index 8ae5dc1841..42f02a02a8 100755
--- a/py/torch_tensorrt/runtime/_weight_streaming.py
+++ b/py/torch_tensorrt/runtime/_weight_streaming.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Any
+from typing import Any, Union
 
 import torch
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
@@ -16,7 +16,7 @@ class _WeightStreamingContextManager(object):
     """
 
     def __init__(
-        self, module: torch.fx.GraphModule | CudaGraphsTorchTensorRTModule
+        self, module: Union[torch.fx.GraphModule, CudaGraphsTorchTensorRTModule]
     ) -> None:
         rt_mods = []
         self.current_device_budget = 0
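
For context, here is a minimal usage sketch of the `enable_cudagraphs` entry point as typed in this diff: the context manager takes the compiled module and its `__enter__` yields the module to run, either the original module (no graph breaks) or a `CudaGraphsTorchTensorRTModule` wrapping it (PyTorch subgraphs present). The toy model, inputs, and `compile` arguments below are illustrative assumptions, not taken from this patch; only the `with` pattern and the accepted module types follow from the diff. The switch from PEP 604 `X | Y` annotations to `Union[...]` presumably keeps these signatures importable on Python versions below 3.10, where `X | Y` fails at evaluation time.

```python
import torch
import torch_tensorrt

# Hypothetical toy model and inputs, purely for illustration.
model = torch.nn.Linear(64, 64).eval().cuda()
inputs = [torch.randn(8, 64, device="cuda")]

# Compile through the dynamo path so the result is a torch.fx.GraphModule,
# one of the two types enable_cudagraphs() now accepts.
trt_module = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)

# __enter__ returns either the original module (no graph breaks) or a
# CudaGraphsTorchTensorRTModule wrapper (PyTorch subgraphs present).
with torch_tensorrt.runtime.enable_cudagraphs(trt_module) as cudagraphs_module:
    out = cudagraphs_module(*inputs)
```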