
[BYOC] Switch TensorRT BYOC integration to IRModule-at-a-time using RelayToTIR hook #11979

Merged
10 commits merged on Jul 1, 2022
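For users, the headline change is that `partition_for_tensorrt` no longer takes individual TensorRT options and no longer returns a `(module, config)` pair to be threaded back through `PassContext`: options now live on a `Target` of kind "tensorrt" (see the new `target` parameter in the diff below), and the function returns just the partitioned `IRModule`. A minimal sketch of the new calling convention, assuming a TVM build with both the TensorRT codegen and CUDA enabled; the toy conv+relu network and its parameters are placeholders for illustration only:

```python
import numpy as np
import tvm
from tvm import relay
from tvm.relay.op.contrib import tensorrt

# A toy network; any Relay module containing TensorRT-supported operators would do.
data = relay.var("data", shape=(1, 3, 224, 224), dtype="float32")
weight = relay.var("weight", shape=(16, 3, 3, 3), dtype="float32")
out = relay.nn.relu(relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1)))
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))
params = {"weight": np.random.uniform(size=(16, 3, 3, 3)).astype("float32")}

# After this PR: no "relay.ext.tensorrt.options" config dict is returned or needed.
mod = tensorrt.partition_for_tensorrt(mod, params)
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="cuda", params=params)
```

Previously the same flow required capturing the returned config dict and passing it as "relay.ext.tensorrt.options" in the PassContext, as the diff to custom_builder_runner.py below shows.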
2 changes: 1 addition & 1 deletion include/tvm/runtime/module.h
@@ -113,7 +113,7 @@ class Module : public ObjectRef {
class TVM_DLL ModuleNode : public Object {
public:
/*! \brief virtual destructor */
virtual ~ModuleNode() {}
virtual ~ModuleNode() = default;
/*!
* \return The per module type key.
* \note This key is used for serializing custom modules.
7 changes: 2 additions & 5 deletions python/tvm/meta_schedule/testing/custom_builder_runner.py
@@ -85,11 +85,8 @@ def build_relay_with_tensorrt(
from tvm.relay.op.contrib import tensorrt
from tvm.runtime import Module

mod, config = tensorrt.partition_for_tensorrt(mod, params)
with PassContext(
opt_level=3,
config={"relay.ext.tensorrt.options": config},
):
mod = tensorrt.partition_for_tensorrt(mod, params)
with PassContext(opt_level=3):
result = relay_build(mod, target=target, target_host=None, params=params)
assert isinstance(result, Module)
return result
191 changes: 98 additions & 93 deletions python/tvm/relay/op/contrib/tensorrt.py
@@ -33,125 +33,116 @@
logger = logging.getLogger("TensorRT")


def is_tensorrt_compiler_enabled() -> bool:
    """Return True if the TensorRT BYOC compiler (codegen) is available in this build of TVM."""
    return tvm.get_global_func("relay.ext.tensorrt.is_runtime_enabled", True) is not None


def is_tensorrt_runtime_enabled() -> bool:
"""Check if the TensorRT graph executor is present.
Returns
-------
ret: bool
True if present, False if not.
"""
check_enabled = tvm.get_global_func("relay.op.is_tensorrt_runtime_enabled", True)
check_enabled = tvm.get_global_func("relay.ext.tensorrt.is_runtime_enabled", True)
if check_enabled:
return check_enabled()
return False


def get_tensorrt_target() -> tvm.target.Target:
"""Returns the current Target, which must be of kind "tensorrt"."""
target = tvm.target.Target.current()
if target is None or target.kind.name != "tensorrt":
# Create the default target.
return tvm.target.Target("tensorrt")
return target


def get_tensorrt_version() -> Tuple[int, int, int]:
"""Gets the version of TensorRT that TVM is built against or is targeting.
"""Returns the version of TensorRT to assume during compilation.
In order of preference this is taken from:
- The current "tensorrt" target's "tensorrt_version" attribute string.
- The version linked to the TVM runtime.
- (6, 0, 1)

Returns
-------
ret: Tuple[int, int, int]
TensorRT version as a tuple of major, minor, and patch number. If TVM
is not built with TensorRT, the value set by set_tensorrt_version() is returned instead.
TensorRT version as a tuple of (major, minor, patch).
"""
pass_ctx = tvm.transform.PassContext.current()
if "relay.ext.tensorrt.options" in pass_ctx.config:
return tuple(pass_ctx.config["relay.ext.tensorrt.options"].tensorrt_version) # type: ignore
return tuple(tvm.get_global_func("relay.op.get_tensorrt_version")()) # type: ignore
# cf logic in tensorrt/codegen.cc::SaveGlobalAttributes
# First check for version in target.
target = get_tensorrt_target()
version = target.attrs["tensorrt_version"]
if len(version) == 3:
return int(version[0]), int(version[1]), int(version[2])
assert len(version) == 0

# Next, ask runtime for its version.
if is_tensorrt_runtime_enabled():
get_version = tvm.get_global_func("relay.ext.tensorrt.get_version")
version = get_version()
assert len(version) == 3
return int(version[0]), int(version[1]), int(version[2])

# Finally, use default.
logger.warning(
"TVM was not built against TensorRT and no version was provided in the 'tensorrt' target."
"Defaulting to 6.0.1."
)
return (6, 0, 1)
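As a hedged illustration of the lookup order above (the version numbers here are arbitrary, and the "tensorrt" target kind with its attributes is only registered when the TensorRT codegen is enabled):

```python
import tvm
from tvm.relay.op.contrib.tensorrt import get_tensorrt_version

# Pin the TensorRT version via the target attribute; it takes precedence over everything else.
trt_target = tvm.target.Target({"kind": "tensorrt", "tensorrt_version": [8, 2, 1]})
with trt_target:
    assert get_tensorrt_version() == (8, 2, 1)

# With no "tensorrt" target in scope, the version linked into the runtime is used,
# falling back to (6, 0, 1) if TVM was not built against TensorRT.
print(get_tensorrt_version())
```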


def get_tensorrt_use_implicit_batch_mode() -> bool:
pass_ctx = tvm.transform.PassContext.current()
if "relay.ext.tensorrt.options" in pass_ctx.config:
return pass_ctx.config["relay.ext.tensorrt.options"].use_implicit_batch
logger.warning(
"PassContext has no relay.ext.tensorrt.options config, using default value "
"use_implicit_batch=True."
)
return True
"""Returns the "use_implicit_batch" attribute of the current "tensorrt" target."""
target = get_tensorrt_target()
return target.attrs["use_implicit_batch"]


def get_tensorrt_remove_no_mac_subgraphs() -> bool:
pass_ctx = tvm.transform.PassContext.current()
if "relay.ext.tensorrt.options" in pass_ctx.config:
return pass_ctx.config["relay.ext.tensorrt.options"].remove_no_mac_subgraphs
logger.warning(
"PassContext has no relay.ext.tensorrt.options config, using default value "
"remove_no_mac_subgraphs=False."
)
return False
"""Returns the "remove_no_mac_subgraphs" attribute of the current "tensorrt" target."""
target = get_tensorrt_target()
return target.attrs["remove_no_mac_subgraphs"]


def get_tensorrt_use_fp16() -> bool:
"""Returns the "use_fp16" attribute of the current "tensorrt" target."""
target = get_tensorrt_target()
return target.attrs["use_fp16"]


def partition_for_tensorrt(
mod: tvm.IRModule,
params: Optional[Dict[str, tvm.nd.NDArray]] = None,
version: Optional[Tuple[int, int, int]] = None,
use_implicit_batch: bool = True,
remove_no_mac_subgraphs: bool = False,
max_workspace_size: int = 1 << 30,
use_fp16: bool = False,
use_uint8: bool = False,
) -> Tuple[tvm.IRModule, Dict[str, Any]]:
"""Partition the graph greedily offloading supported operators to TensorRT.
# CAUTION: Can't use default Target("tensorrt") here since the target kind is only available
# if is_tensorrt_compiler_enabled() == True.
target: Optional[tvm.target.Target] = None,
) -> tvm.IRModule:
"""Partition all functions in mod to greedily offload supported operators to TensorRT.

Parameters
----------
mod : tvm.IRModule
The module to run passes on.
The module to partition.
target : tvm.target.Target
A target of kind "tensorrt" describing additional partitioning and compilation options.
params : Optional[Dict[str, tvm.nd.NDArray]]
Constant input parameters.
version : Optional[Tuple[int, int, int]]
TensorRT version to target as tuple of (major, minor, patch). If TVM is compiled with
USE_TENSORRT_RUNTIME=ON, the linked TensorRT version will be used instead.
use_implicit_batch : bool
Use TensorRT implicit batch mode (default true). Setting to false will enable explicit batch
mode which will widen supported operators to include those which modify the batch dimension,
but may reduce performance for some models.
remove_no_mac_subgraphs : bool
Removes subgraphs which have been partitioned for TensorRT if they do not have any
multiply-accumulate operations. The removed subgraphs will go through TVM's standard
compilation instead. Can improve performance.
max_workspace_size : int
How many bytes of workspace size to allow each subgraph to use for TensorRT engine creation.
See TensorRT documentation for more info.
use_fp16: bool
Allows TRT to automatically convert FP32 inputs to FP16. It must also be enabled
if FP16 input tensors and weights are used.
Note that TensorRT will still choose a higher-precision kernel if it results in overall
lower runtime, or if no low-precision implementation exists.
use_uint8: bool
Allows TRT to automatically convert FP32 inputs to UINT8.

Returns
-------
mod_and_config : Tuple[tvm.IRModule, Dict[str, Any]]
A tuple of 1) annotated and partitioned module and 2) "relay.ext.tensorrt.options"
configuration which should be given to PassContext when building.
partitioned_mod : tvm.IRModule
The partitioned module.

"""
config: Dict[str, Any] = {
"use_implicit_batch": use_implicit_batch,
"max_workspace_size": max_workspace_size,
"remove_no_mac_subgraphs": remove_no_mac_subgraphs,
"use_fp16": use_fp16,
"use_uint8": use_uint8,
}
if version:
assert isinstance(version, tuple) and len(version) == 3
config["tensorrt_version"] = version
else:
linked_version = tuple(tvm.get_global_func("relay.op.get_tensorrt_version")())
if not linked_version:
logger.warning(
"TVM was not built against TensorRT and no version was provided to "
"partition_for_tensorrt. Defaulting to 6.0.1"
)
linked_version = (6, 0, 1)
config["tensorrt_version"] = linked_version

assert is_tensorrt_compiler_enabled(), "Can only partition for TensorRT if it is enabled"
if params:
mod["main"] = bind_params_by_name(mod["main"], params)
if target is None:
# Use a default target. The get_tensorrt_target() function will similarly create an
# equivalent default target when compilation continues after partitioning.
target = tvm.target.Target("tensorrt")

seq = tvm.transform.Sequential(
[
@@ -174,24 +165,27 @@ def partition_for_tensorrt(
transform.InferType(),
]
)
with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
with target:
mod = seq(mod)
# TODO(mbs): Revisit
# mod = prune_tensorrt_subgraphs(mod)
return mod, config
mod = prune_tensorrt_subgraphs(mod)
return mod
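Options that used to be keyword arguments (and the old "relay.ext.tensorrt.options" config) are now expressed as attributes of the "tensorrt" target passed via `target`. A short sketch, assuming an existing Relay module `mod` and parameter dict `params`, with attribute names as used by the getters above:

```python
import tvm
from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt

# Attribute names mirror the former keyword arguments; the "tensorrt" target kind is
# only registered when the TensorRT codegen is enabled.
trt_target = tvm.target.Target(
    {
        "kind": "tensorrt",
        "use_implicit_batch": False,      # widen op coverage to batch-modifying ops
        "remove_no_mac_subgraphs": True,  # drop offloaded subgraphs with no MACs
        "use_fp16": True,                 # also admit float16 tensors during partitioning
    }
)
partitioned_mod = partition_for_tensorrt(mod, params, target=trt_target)
```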


def is_supported_trt_type(typ: Union[tvm.ir.TensorType, tvm.ir.TupleType], op_name: str) -> bool:
"""Check whether a type is supported by TensorRT."""
supported_dtypes = ["float32", "float16"]
supported_dtypes = ["float32"]
if get_tensorrt_use_fp16():
supported_dtypes.append("float16")
if isinstance(typ, tvm.ir.TensorType):
if typ.dtype not in supported_dtypes:
logger.info(f"{op_name}: Only float32 and float16 tensor dtypes are supported.")
logger.info(f"{op_name}: Only {supported_dtypes} tensor dtypes are supported.")
return False
# assumes dim 0 is for batch and can be dynamic
# TODO(mbs): But does this depend use_implicit_batch flag?
for dim_shape in typ.shape[1:]:
if isinstance(dim_shape, tvm.tir.expr.Any):
dims = typ.shape
if get_tensorrt_use_implicit_batch_mode():
# The first dimension can be Any.
dims = dims[1:]
for dim in dims:
if isinstance(dim, tvm.tir.expr.Any):
logger.info(f"{op_name}: Only statically known tensor shapes are supported.")
return False
elif isinstance(typ, tvm.ir.TupleType):
@@ -241,13 +235,19 @@ def get_attrs(expr: relay.expr.Expr) -> Any:


def make_predicate(checker: CheckFunc) -> Callable[[relay.expr.Expr], bool]:
"""Returns the pattern predicate which performs the standard checks, then invokes the
more primitive checker."""

def predicate(expr: relay.expr.Expr) -> bool:
op_name = get_op_name(expr)
attrs = get_attrs(expr)
args = get_args(expr)
if not all([is_supported_trt_type(arg.checked_type, op_name) for arg in args]):
return False
return checker(attrs, args, op_name)
if not checker(attrs, args, op_name):
return False
logger.info(f"{op_name}: Predicate passes")
return True

return predicate
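For context, each supported operator contributes a small checker with this `(attrs, args, op_name)` shape, and `make_predicate` layers the generic dtype/shape checks on top of it. The sketch below is illustrative only; it mirrors the batch-dimension rule used by several checkers in this file, but is not the literal registration code:

```python
def softmax_checker(attrs, args, op_name) -> bool:
    # Operator-specific rule: in implicit batch mode the batch axis can't be modified.
    if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0:
        logger.info(f"{op_name}: can't modify batch dimension.")
        return False
    return True

# Wrap it so the generic type checks run before the operator-specific rule.
softmax_predicate = make_predicate(softmax_checker)
```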

@@ -535,11 +535,16 @@ def concatenate_checker(
if int(attrs.axis) == 0:
logger.info(f"{op_name}: can't modify batch dimension.")
return False
if isinstance(args[0], relay.Tuple):
for tuple_input in args[0].fields:
if isinstance(tuple_input, Constant):
logger.info(f"{op_name}: can't concatenate tensors with constants.")
return False

if not isinstance(args[0], relay.Tuple):
logger.info("f{op_name}: concatenate must be applied to a literal tuple")
return False

for tuple_input in args[0].fields:
if isinstance(tuple_input, Constant):
logger.info(f"{op_name}: can't concatenate tensors with constants.")
return False

return True


10 changes: 5 additions & 5 deletions src/relay/backend/contrib/codegen_c/codegen.cc
@@ -360,8 +360,8 @@ class CodegenCModule {
};

/*! \brief The actual translation pass. */
transform::Pass CCompilerImpl() {
auto pass_func = [=](IRModule mod, const transform::PassContext& pass_ctx) {
tvm::transform::Pass CCompilerImpl() {
auto pass_func = [=](IRModule mod, const tvm::transform::PassContext& pass_ctx) {
VLOG(1) << "CCompilerImpl input:" << std::endl << PrettyPrint(mod);
Target target = GetCCompilerTarget();

@@ -388,10 +388,10 @@ transform::Pass CCompilerImpl() {
return tvm::transform::CreateModulePass(pass_func, 0, "CCompilerImpl", {});
}

transform::Pass CCompilerPass() {
tvm::transform::Pass CCompilerPass() {
return transform::Sequential(
{transforms::OutlineCompilerFunctionsWithExistingGlobalSymbols("ccompiler"), CCompilerImpl(),
transforms::MarkCompilerFunctionsAsExtern("ccompiler")});
{transform::OutlineCompilerFunctionsWithExistingGlobalSymbols("ccompiler"), CCompilerImpl(),
transform::MarkCompilerFunctionsAsExtern("ccompiler")});
}

} // namespace contrib
10 changes: 5 additions & 5 deletions src/relay/backend/contrib/cutlass/codegen.cc
@@ -902,8 +902,8 @@ class CutlassModuleCodegen {
* \brief A small shim to redirect to the 'relay.ext.cutlass.compile_for_cutlass' Python
* function which does the main CUTLASS tuning, C-code generation and compilation steps.
*/
transform::Pass CompileForCutlassImpl() {
auto pass_func = [=](IRModule mod, const transform::PassContext& pass_ctx) {
tvm::transform::Pass CompileForCutlassImpl() {
auto pass_func = [=](IRModule mod, const tvm::transform::PassContext& pass_ctx) {
VLOG(1) << "CompileForCutlass input:" << std::endl << PrettyPrint(mod);
const auto* pf = runtime::Registry::Get("relay.ext.cutlass.compile_for_cutlass");
ICHECK(pf != nullptr) << "Cannot find compile_for_cutlass function";
@@ -926,10 +926,10 @@ runtime::Module CreateCSourceModule(const IRModule& mod) {

TVM_REGISTER_GLOBAL("relay.ext.cutlass.create_c_source_module").set_body_typed(CreateCSourceModule);

transform::Pass CompileForCutlass() {
tvm::transform::Pass CompileForCutlass() {
return transform::Sequential(
{transforms::OutlineCompilerFunctionsWithExistingGlobalSymbols("cutlass"),
CompileForCutlassImpl(), transforms::MarkCompilerFunctionsAsExtern("cutlass")});
{transform::OutlineCompilerFunctionsWithExistingGlobalSymbols("cutlass"),
CompileForCutlassImpl(), transform::MarkCompilerFunctionsAsExtern("cutlass")});
}

} // namespace cutlass