[BYOC] Switch TensorRT BYOC integration to IRModule-at-a-time using RelayToTIR hook (#11979)

* [BYOC] Switch TensorRT BYOC integration to IRModule-at-a-time using RelayToTIR hook

This does for the TensorRT integration what #11631 did for the CUTLASS integration.

- All compilation options are captured within the attributes of a Target of
  kind "tensorrt" (instead of the "relay.ext.tensorrt.options" attribute in
  PassContext). This means all BYOC configuration options needed by Collage can
  be captured uniformly by a list of Targets. It also means RPC boundaries (as used
  internally at OctoML) only need to worry about maintaining the fidelity of the
  Target instance(s) rather than reaching into the PassContext. (See the first
  sketch below.)

- Compilation is switched from function-at-a-time (relying on the TECompiler) to
  IRModule-at-a-time (using the RelayToTIR target-specific hook mechanism). Though
  not strictly necessary for Collage, this lets me check that the path is now clear
  to deprecate BYOC support in the TECompiler. (See the second sketch below.)

- Get all the TensorRT tests going again, except for a few I've disabled and
  cross-linked to the new issue #11765. CAUTION: The TensorRT runtime is not
  supported in CI, so many of these tests are effectively cosmetic.

- While trying to track down a 'free(): invalid pointer' error in test_tensorrt_int8_exp.py,
  I made the TensorRT allocs/frees more robust, but it turns out the error also occurs on main.
  No harm in leaving these changes in though.
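
First, a rough sketch of how compilation options are now supplied via a "tensorrt"
Target and passed to partition_for_tensorrt. This is illustrative only: the toy model
and the dict-form Target construction are assumptions rather than code from this PR,
and the "tensorrt" target kind only exists when the TensorRT codegen is built in
(is_tensorrt_compiler_enabled() == True). The attribute names match those read back by
get_tensorrt_use_implicit_batch_mode()/get_tensorrt_use_fp16() in the diff below.

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.relay.op.contrib import tensorrt

    # Toy model, just to have something to partition (illustrative only).
    x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32")
    w = relay.var("w", shape=(16, 3, 3, 3), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.nn.relu(relay.nn.conv2d(x, w)))
    params = {"w": tvm.nd.array(np.zeros((16, 3, 3, 3), dtype="float32"))}

    # Options now live on a Target of kind "tensorrt" rather than in the
    # "relay.ext.tensorrt.options" PassContext config.
    trt_target = tvm.target.Target(
        {"kind": "tensorrt", "use_implicit_batch": False, "use_fp16": True}
    )

    # Omitting target= falls back to a default Target("tensorrt"), matching what
    # get_tensorrt_target() creates when no "tensorrt" target is in scope.
    mod = tensorrt.partition_for_tensorrt(mod, params=params, target=trt_target)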
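
Second, a sketch of compilation after partitioning, continuing from the sketch above
and mirroring the change to custom_builder_runner.py in this PR: nothing
TensorRT-specific goes into the PassContext any more, since the offloaded "tensorrt"
functions are compiled via the RelayToTIR hook as part of the ordinary build.

    # Continuing from the sketch above (mod is already partitioned for TensorRT).
    # No "relay.ext.tensorrt.options" entry is needed in the PassContext config;
    # the RelayToTIR hook compiles the "tensorrt" functions during relay.build().
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="cuda", params=params)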

* - Lints

* - Woops, fix test

* - lints

* - Use default tensorrt target if none given in targets list

* - fix free error

* - accidentally introduced 'transforms' namespace
- can't use default Target("tensorrt") arg

* - D'oh! Include ended up #if protected

* - restore mark for test_dynamic_offload
- handle missing runtime in versioning
- turn test_maskrcnn_resnet50 back on now that we have the
  import-torch-first workaround.

* - wibble
mbs-octoml authored Jul 1, 2022
1 parent 55dcd5f commit d2a14a6
Showing 19 changed files with 524 additions and 348 deletions.
2 changes: 1 addition & 1 deletion include/tvm/runtime/module.h
@@ -113,7 +113,7 @@ class Module : public ObjectRef {
class TVM_DLL ModuleNode : public Object {
public:
/*! \brief virtual destructor */
virtual ~ModuleNode() {}
virtual ~ModuleNode() = default;
/*!
* \return The per module type key.
* \note This key is used to for serializing custom modules.
7 changes: 2 additions & 5 deletions python/tvm/meta_schedule/testing/custom_builder_runner.py
@@ -85,11 +85,8 @@ def build_relay_with_tensorrt(
from tvm.relay.op.contrib import tensorrt
from tvm.runtime import Module

mod, config = tensorrt.partition_for_tensorrt(mod, params)
with PassContext(
opt_level=3,
config={"relay.ext.tensorrt.options": config},
):
mod = tensorrt.partition_for_tensorrt(mod, params)
with PassContext(opt_level=3):
result = relay_build(mod, target=target, target_host=None, params=params)
assert isinstance(result, Module)
return result
191 changes: 98 additions & 93 deletions python/tvm/relay/op/contrib/tensorrt.py
@@ -33,125 +33,116 @@
logger = logging.getLogger("TensorRT")


def is_tensorrt_compiler_enabled() -> bool:
return tvm.get_global_func("relay.ext.tensorrt.is_runtime_enabled", True) is not None


def is_tensorrt_runtime_enabled() -> bool:
"""Check if the TensorRT graph executor is present.
Returns
-------
ret: bool
True if present, False if not.
"""
check_enabled = tvm.get_global_func("relay.op.is_tensorrt_runtime_enabled", True)
check_enabled = tvm.get_global_func("relay.ext.tensorrt.is_runtime_enabled", True)
if check_enabled:
return check_enabled()
return False


def get_tensorrt_target() -> tvm.target.Target:
"""Returns the current Target, which must be of kind "tensorrt"."""
target = tvm.target.Target.current()
if target is None or target.kind.name != "tensorrt":
# Create the default target.
return tvm.target.Target("tensorrt")
return target


def get_tensorrt_version() -> Tuple[int, int, int]:
"""Gets the version of TensorRT that TVM is built against or is targeting.
"""Returns the version of TensorRT to assume during compilation.
In order of preference this is taken from:
- The current "tensorrt" target's "tensorrt_version" attribute string.
- The version linked to the TVM runtime.
- (6, 0, 1)
Returns
-------
ret: Tuple[int, int, int]
TensorRT version as a tuple of major, minor, and patch number. If TVM
is not built with TensorRT, the value set by set_tensorrt_version() is returned instead.
TensorRT version as a tuple of (major, minor, patch).
"""
pass_ctx = tvm.transform.PassContext.current()
if "relay.ext.tensorrt.options" in pass_ctx.config:
return tuple(pass_ctx.config["relay.ext.tensorrt.options"].tensorrt_version) # type: ignore
return tuple(tvm.get_global_func("relay.op.get_tensorrt_version")()) # type: ignore
# cf logic in tensorrt/codegen.cc::SaveGlobalAttributes
# First check for version in target.
target = get_tensorrt_target()
version = target.attrs["tensorrt_version"]
if len(version) == 3:
return int(version[0]), int(version[1]), int(version[2])
assert len(version) == 0

# Next, ask runtime for its version.
if is_tensorrt_runtime_enabled():
get_version = tvm.get_global_func("relay.ext.tensorrt.get_version")
version = get_version()
assert len(version) == 3
return int(version[0]), int(version[1]), int(version[2])

# Finally, use default.
logger.warning(
"TVM was not built against TensorRT and no version was provided in the 'tensorrt' target."
"Defaulting to 6.0.1."
)
return (6, 0, 1)


def get_tensorrt_use_implicit_batch_mode() -> bool:
pass_ctx = tvm.transform.PassContext.current()
if "relay.ext.tensorrt.options" in pass_ctx.config:
return pass_ctx.config["relay.ext.tensorrt.options"].use_implicit_batch
logger.warning(
"PassContext has no relay.ext.tensorrt.options config, using default value "
"use_implicit_batch=True."
)
return True
"""Returns the "use_implicit_batch" attribute of the current "tensorrt" target."""
target = get_tensorrt_target()
return target.attrs["use_implicit_batch"]


def get_tensorrt_remove_no_mac_subgraphs() -> bool:
pass_ctx = tvm.transform.PassContext.current()
if "relay.ext.tensorrt.options" in pass_ctx.config:
return pass_ctx.config["relay.ext.tensorrt.options"].remove_no_mac_subgraphs
logger.warning(
"PassContext has no relay.ext.tensorrt.options config, using default value "
"remove_no_mac_subgraphs=False."
)
return False
"""Returns the "remove_no_mac_subgraphs" attribute of the current "tensorrt" target."""
target = get_tensorrt_target()
return target.attrs["remove_no_mac_subgraphs"]


def get_tensorrt_use_fp16() -> bool:
"""Returns the "use_fp16" attribute of the current "tensorrt" target."""
target = get_tensorrt_target()
return target.attrs["use_fp16"]


def partition_for_tensorrt(
mod: tvm.IRModule,
params: Optional[Dict[str, tvm.nd.NDArray]] = None,
version: Optional[Tuple[int, int, int]] = None,
use_implicit_batch: bool = True,
remove_no_mac_subgraphs: bool = False,
max_workspace_size: int = 1 << 30,
use_fp16: bool = False,
use_uint8: bool = False,
) -> Tuple[tvm.IRModule, Dict[str, Any]]:
"""Partition the graph greedily offloading supported operators to TensorRT.
# CAUTION: Can't use default Target("tensorrt") here since the target kind is only available
# if is_tensorrt_compiler_enabled() == True.
target: Optional[tvm.target.Target] = None,
) -> tvm.IRModule:
"""Partition all functions in mod to greedily offload supported operators to TensorRT.
Parameters
----------
mod : tvm.IRModule
The module to run passes on.
The module to partition.
target : tvm.target.Target
A target of kind "tensorrt" describing additional partitioning and compilation options.
params : Optional[Dict[str, tvm.nd.NDArray]]
Constant input parameters.
version : Optional[Tuple[int, int, int]]
TensorRT version to target as tuple of (major, minor, patch). If TVM is compiled with
USE_TENSORRT_RUNTIME=ON, the linked TensorRT version will be used instead.
use_implicit_batch : bool
Use TensorRT implicit batch mode (default true). Setting to false will enable explicit batch
mode which will widen supported operators to include those which modify the batch dimension,
but may reduce performance for some models.
remove_no_mac_subgraphs : bool
Removes subgraphs which have been partitioned for TensorRT if they do not have any
multiply-accumulate operations. The removed subgraphs will go through TVM's standard
compilation instead. Can improve performance.
max_workspace_size : int
How many bytes of workspace size to allow each subgraph to use for TensorRT engine creation.
See TensorRT documentation for more info.
use_fp16: bool
Allows, TRT to automatically convert FP32 inputs to FP16. Also, it is required to be enabled
if FP16 inputs tensors and weights are used.
Note that TensorRT will still choose a higher-precision kernel if it results in overall
lower runtime, or if no low-precision implementation exists.
use_uint8: bool
Allows, TRT to automatically convert FP32 inputs to UINT8.
Returns
-------
mod_and_config : Tuple[tvm.IRModule, Dict[str, Any]]
A tuple of 1) annotated and partitioned module and 2) "relay.ext.tensorrt.options"
configuration which should be given to PassContext when building.
partitioned_mod : tvm.IRModule
The partitioned module.
"""
config: Dict[str, Any] = {
"use_implicit_batch": use_implicit_batch,
"max_workspace_size": max_workspace_size,
"remove_no_mac_subgraphs": remove_no_mac_subgraphs,
"use_fp16": use_fp16,
"use_uint8": use_uint8,
}
if version:
assert isinstance(version, tuple) and len(version) == 3
config["tensorrt_version"] = version
else:
linked_version = tuple(tvm.get_global_func("relay.op.get_tensorrt_version")())
if not linked_version:
logger.warning(
"TVM was not built against TensorRT and no version was provided to "
"partition_for_tensorrt. Defaulting to 6.0.1"
)
linked_version = (6, 0, 1)
config["tensorrt_version"] = linked_version

assert is_tensorrt_compiler_enabled(), "Can only partition for TensorRT if it is enabled"
if params:
mod["main"] = bind_params_by_name(mod["main"], params)
if target is None:
# Use a default target. The get_tensorrt_target() function will similarly create an
# equivalent default target when compilation continues after partitioning.
target = tvm.target.Target("tensorrt")

seq = tvm.transform.Sequential(
[
@@ -174,24 +165,27 @@ def partition_for_tensorrt(
transform.InferType(),
]
)
with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
with target:
mod = seq(mod)
# TODO(mbs): Revisit
# mod = prune_tensorrt_subgraphs(mod)
return mod, config
mod = prune_tensorrt_subgraphs(mod)
return mod


def is_supported_trt_type(typ: Union[tvm.ir.TensorType, tvm.ir.TupleType], op_name: str) -> bool:
"""Check whether a type is supported by TensorRT."""
supported_dtypes = ["float32", "float16"]
supported_dtypes = ["float32"]
if get_tensorrt_use_fp16():
supported_dtypes.append("float16")
if isinstance(typ, tvm.ir.TensorType):
if typ.dtype not in supported_dtypes:
logger.info(f"{op_name}: Only float32 and float16 tensor dtypes are supported.")
logger.info(f"{op_name}: Only {supported_dtypes} tensor dtypes are supported.")
return False
# assumes dim 0 is for batch and can be dynamic
# TODO(mbs): But does this depend use_implicit_batch flag?
for dim_shape in typ.shape[1:]:
if isinstance(dim_shape, tvm.tir.expr.Any):
dims = typ.shape
if get_tensorrt_use_implicit_batch_mode():
# The first dimension can be Any.
dims = dims[1:]
for dim in dims:
if isinstance(dim, tvm.tir.expr.Any):
logger.info(f"{op_name}: Only statically known tensor shapes are supported.")
return False
elif isinstance(typ, tvm.ir.TupleType):
@@ -241,13 +235,19 @@ def get_attrs(expr: relay.expr.Expr) -> Any:


def make_predicate(checker: CheckFunc) -> Callable[[relay.expr.Expr], bool]:
"""Returns the pattern predicate which performs the standard checks, then invokes the
more primitive checker."""

def predicate(expr: relay.expr.Expr) -> bool:
op_name = get_op_name(expr)
attrs = get_attrs(expr)
args = get_args(expr)
if not all([is_supported_trt_type(arg.checked_type, op_name) for arg in args]):
return False
return checker(attrs, args, op_name)
if not checker(attrs, args, op_name):
return False
logger.info(f"{op_name}: Predicate passes")
return True

return predicate

@@ -535,11 +535,16 @@ def concatenate_checker(
if int(attrs.axis) == 0:
logger.info(f"{op_name}: can't modify batch dimension.")
return False
if isinstance(args[0], relay.Tuple):
for tuple_input in args[0].fields:
if isinstance(tuple_input, Constant):
logger.info(f"{op_name}: can't concatenate tensors with constants.")
return False

if not isinstance(args[0], relay.Tuple):
logger.info("f{op_name}: concatenate must be applied to a literal tuple")
return False

for tuple_input in args[0].fields:
if isinstance(tuple_input, Constant):
logger.info(f"{op_name}: can't concatenate tensors with constants.")
return False

return True


10 changes: 5 additions & 5 deletions src/relay/backend/contrib/codegen_c/codegen.cc
@@ -360,8 +360,8 @@ class CodegenCModule {
};

/*! \brief The actual translation pass. */
transform::Pass CCompilerImpl() {
auto pass_func = [=](IRModule mod, const transform::PassContext& pass_ctx) {
tvm::transform::Pass CCompilerImpl() {
auto pass_func = [=](IRModule mod, const tvm::transform::PassContext& pass_ctx) {
VLOG(1) << "CCompilerImpl input:" << std::endl << PrettyPrint(mod);
Target target = GetCCompilerTarget();

@@ -388,10 +388,10 @@ transform::Pass CCompilerImpl() {
return tvm::transform::CreateModulePass(pass_func, 0, "CCompilerImpl", {});
}

transform::Pass CCompilerPass() {
tvm::transform::Pass CCompilerPass() {
return transform::Sequential(
{transforms::OutlineCompilerFunctionsWithExistingGlobalSymbols("ccompiler"), CCompilerImpl(),
transforms::MarkCompilerFunctionsAsExtern("ccompiler")});
{transform::OutlineCompilerFunctionsWithExistingGlobalSymbols("ccompiler"), CCompilerImpl(),
transform::MarkCompilerFunctionsAsExtern("ccompiler")});
}

} // namespace contrib
10 changes: 5 additions & 5 deletions src/relay/backend/contrib/cutlass/codegen.cc
@@ -902,8 +902,8 @@ class CutlassModuleCodegen {
* \brief A small shim to redirect to the 'relay.ext.cutlass.compile_for_cutlass' Python
* function which does the main CUTLASS training, c-code generation and compilation steps.
*/
transform::Pass CompileForCutlassImpl() {
auto pass_func = [=](IRModule mod, const transform::PassContext& pass_ctx) {
tvm::transform::Pass CompileForCutlassImpl() {
auto pass_func = [=](IRModule mod, const tvm::transform::PassContext& pass_ctx) {
VLOG(1) << "CompileForCutlass input:" << std::endl << PrettyPrint(mod);
const auto* pf = runtime::Registry::Get("relay.ext.cutlass.compile_for_cutlass");
ICHECK(pf != nullptr) << "Cannot find compile_for_cutlass function";
@@ -926,10 +926,10 @@ runtime::Module CreateCSourceModule(const IRModule& mod) {

TVM_REGISTER_GLOBAL("relay.ext.cutlass.create_c_source_module").set_body_typed(CreateCSourceModule);

transform::Pass CompileForCutlass() {
tvm::transform::Pass CompileForCutlass() {
return transform::Sequential(
{transforms::OutlineCompilerFunctionsWithExistingGlobalSymbols("cutlass"),
CompileForCutlassImpl(), transforms::MarkCompilerFunctionsAsExtern("cutlass")});
{transform::OutlineCompilerFunctionsWithExistingGlobalSymbols("cutlass"),
CompileForCutlassImpl(), transform::MarkCompilerFunctionsAsExtern("cutlass")});
}

} // namespace cutlass