Skip to content

Commit

Permalink
fix: Necessary default fixes for compile
Browse files Browse the repository at this point in the history
  • Loading branch information
gs-olive committed Aug 23, 2023
1 parent 28ba56f commit e03e94d
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions py/torch_tensorrt/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,23 @@
import torch
import torch.fx
import torch_tensorrt.ts
from torch._export import ExportedProgram
from torch_tensorrt._enums import dtype
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo.compile import compile as dynamo_compile
from torch_tensorrt._utils import sanitized_torch_version
from torch_tensorrt.fx import InputTensorSpec
from torch_tensorrt.fx.lower import compile as fx_compile
from torch_tensorrt.fx.utils import LowerPrecision
from torch_tensorrt.ts._compiler import compile as torchscript_compile
from typing_extensions import TypeGuard

from packaging import version

DYNAMO_ENABLED = version.parse(sanitized_torch_version()) >= version.parse("2.1.dev")

if DYNAMO_ENABLED:
from torch._export import ExportedProgram
from torch_tensorrt.dynamo.compile import compile as dynamo_compile

logger = logging.getLogger(__name__)

__all__ = [
Expand Down Expand Up @@ -64,7 +71,7 @@ def _parse_module_type(module: Any) -> _ModuleType:
return _ModuleType.ts
elif isinstance(module, torch.fx.GraphModule):
return _ModuleType.fx
elif isinstance(module, ExportedProgram):
elif DYNAMO_ENABLED and isinstance(module, ExportedProgram):
return _ModuleType.ep
elif isinstance(module, torch.nn.Module):
return _ModuleType.nn
Expand Down Expand Up @@ -93,13 +100,14 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
else:
if ir == "default":
# Options are listed in order of preference
if module_is_fxable:
if DYNAMO_ENABLED and module_is_fxable:
logger.info("ir was set to default, using dynamo as ir")
return _IRType.dynamo
elif module_is_tsable:
logger.warning(
"Input graph is a Torchscript module but the ir provided is default (dynamo). Please set ir=torchscript to suppress the warning. Compiling the module with ir=torchscript"
)
if DYNAMO_ENABLED:
logger.warning(
"Input graph is a Torchscript module but the ir provided is default (dynamo). Please set ir=torchscript to suppress the warning. Compiling the module with ir=torchscript"
)
return _IRType.ts
elif module_is_exportable:
raise ValueError(
Expand Down

0 comments on commit e03e94d

Please sign in to comment.