Skip to content

Commit

Permalink
chore: Proper logging message and rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
keehyuna committed Dec 16, 2024
1 parent 11886fe commit c211c98
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
14 changes: 6 additions & 8 deletions py/torch_tensorrt/runtime/_cudagraphs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Any
from typing import Any, Union

import torch
import torch_tensorrt
Expand Down Expand Up @@ -88,17 +88,15 @@ def __enter__(self) -> torch.nn.Module:
torch.ops.tensorrt.set_cudagraphs_mode(_PY_RT_CUDAGRAPHS)

logger.debug(
f"{num_torch_module} torch modules are in subgraphs. Using wrapper module for cuda graphs"
"Found pytorch subgraphs in module, wrapping module in CudaGraphsTorchTensorRTModule"
)
return CudaGraphsTorchTensorRTModule(self.compiled_module)
else:
if num_trt_module > 0:
logger.debug(
"There is no graph breaks. Using original module for cuda graphs"
)
logger.debug("No graph breaks detected, using runtime cudagraphs mode")
else:
logger.warning(
"Please consider dynamo if there is graph breaks. Using original module for cuda graphs"
logger.debug(
"Please consider dynamo if there are graph breaks. Using runtime cudagraphs mode"
)
# Enable cudagraphs for TRT submodule
set_cudagraphs_mode(True)
Expand All @@ -110,6 +108,6 @@ def __exit__(self, *args: Any) -> None:


def enable_cudagraphs(
compiled_module: torch.nn.Module,
compiled_module: Union[torch.fx.GraphModule, torch.nn.Module],
) -> _CudagraphsContextManager:
return _CudagraphsContextManager(compiled_module)
4 changes: 2 additions & 2 deletions py/torch_tensorrt/runtime/_weight_streaming.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Any
from typing import Any, Union

import torch
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
Expand All @@ -16,7 +16,7 @@ class _WeightStreamingContextManager(object):
"""

def __init__(
self, module: torch.fx.GraphModule | CudaGraphsTorchTensorRTModule
self, module: Union[torch.fx.GraphModule, CudaGraphsTorchTensorRTModule]
) -> None:
rt_mods = []
self.current_device_budget = 0
Expand Down

0 comments on commit c211c98

Please sign in to comment.