From cfd256138ec075c67aced56280cae3b6d36034fa Mon Sep 17 00:00:00 2001
From: mbs-octoml
Date: Thu, 30 Jun 2022 13:29:50 -0700
Subject: [PATCH 01/10] [BYOC] Switch TensorRT BYOC integration to IRModule-at-a-time using RelayToTIR hook

This does for the TensorRT integration what #11631 did for the CUTLASS integration.
- All compilation options are captured within the attributes of a Target of kind "tensorrt"
  (instead of the "relay.ext.tensorrt.options" attribute in PassContext). This means all BYOC
  configuration options needed by Collage can be captured uniformly by a list-of-Targets. It
  also means RPC boundaries (as used internally at OctoML) only need to worry about maintaining
  the fidelity of the Target instance(s) rather than reaching into the PassContext.
- Compilation is switched from function-at-a-time (relying on the TECompiler) to
  IRModule-at-a-time (using the RelayToTIR target-specific hook mechanism). Though not strictly
  necessary for Collage, I want to check that the path is now clear to deprecate the support for
  BYOC in TECompiler.
- Get all the TensorRT tests going again, except for a few I've disabled with a cross-link to a
  new issue #11765. CAUTION: The TensorRT runtime is not supported in CI, so many of these tests
  are cosmetic.
- While trying to track down a 'free(): invalid pointer' error in test_tensorrt_int8_exp.py,
  I made the TensorRT allocs/frees more robust; it turns out the error is also present in main.
  No harm leaving these changes in though.
---
 include/tvm/runtime/module.h                   |   2 +-
 .../testing/custom_builder_runner.py           |   7 +-
 python/tvm/relay/op/contrib/tensorrt.py        | 177 ++++++------
 src/relay/backend/contrib/tensorrt/codegen.cc  | 261 ++++++++++--------
 src/relay/backend/contrib/tensorrt/codegen.h   |  47 ++++
 src/relay/backend/contrib/tensorrt/target.cc   |  31 ++-
 src/runtime/const_loader_module.cc             |  24 +-
 src/runtime/contrib/json/json_runtime.h        |   2 +
 .../contrib/tensorrt/tensorrt_builder.cc       |  27 +-
 .../contrib/tensorrt/tensorrt_builder.h        |  10 +-
 src/runtime/contrib/tensorrt/tensorrt_ops.cc   |   4 +-
 .../contrib/tensorrt/tensorrt_runtime.cc       |  14 +-
 src/target/metadata_module.cc                  |   2 -
 tests/python/contrib/test_tensorrt.py          | 166 +++++------
 .../python/contrib/test_tensorrt_int8_exp.py   |  18 +-
 15 files changed, 471 insertions(+), 321 deletions(-)
 create mode 100644 src/relay/backend/contrib/tensorrt/codegen.h

diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h
index 31d05571eefd..9d139c9feff3 100644
--- a/include/tvm/runtime/module.h
+++ b/include/tvm/runtime/module.h
@@ -113,7 +113,7 @@ class Module : public ObjectRef {
 class TVM_DLL ModuleNode : public Object {
  public:
   /*! \brief virtual destructor */
-  virtual ~ModuleNode() {}
+  virtual ~ModuleNode() = default;
   /*!
    * \return The per module type key.
    * \note This key is used to for serializing custom modules.
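As a usage sketch (not part of the patch): after this change all TensorRT options ride on a
Target of kind "tensorrt", mirroring the updated tests further down. The snippet below assumes
`mod` and `params` are an existing Relay module and its parameters, and the option values are
illustrative only; the option names are the "tensorrt" target attributes registered in target.cc.

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib import tensorrt

    # All TensorRT options now live on a Target of kind "tensorrt" rather than in
    # the "relay.ext.tensorrt.options" PassContext config.
    trt_target = tvm.target.Target("tensorrt -use_implicit_batch=False -use_fp16=True")

    # Greedily partition supported operators into Compiler="tensorrt" functions.
    mod = tensorrt.partition_for_tensorrt(mod, params=params, target=trt_target)

    # Compilation recovers the options from the Target in the target list; no
    # PassContext config entry is required any more.
    with tvm.transform.PassContext(opt_level=3):
        vm_exec = relay.vm.compile(mod, target=["cuda", trt_target], params=params)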
diff --git a/python/tvm/meta_schedule/testing/custom_builder_runner.py b/python/tvm/meta_schedule/testing/custom_builder_runner.py index e203848c2cbb..1cfd4ab833be 100644 --- a/python/tvm/meta_schedule/testing/custom_builder_runner.py +++ b/python/tvm/meta_schedule/testing/custom_builder_runner.py @@ -85,11 +85,8 @@ def build_relay_with_tensorrt( from tvm.relay.op.contrib import tensorrt from tvm.runtime import Module - mod, config = tensorrt.partition_for_tensorrt(mod, params) - with PassContext( - opt_level=3, - config={"relay.ext.tensorrt.options": config}, - ): + mod = tensorrt.partition_for_tensorrt(mod, params) + with PassContext(opt_level=3): result = relay_build(mod, target=target, target_host=None, params=params) assert isinstance(result, Module) return result diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py index a69e2d410529..c441c30808c3 100644 --- a/python/tvm/relay/op/contrib/tensorrt.py +++ b/python/tvm/relay/op/contrib/tensorrt.py @@ -26,13 +26,17 @@ from tvm.relay import transform from tvm.relay.build_module import bind_params_by_name from tvm.relay.dataflow_pattern import is_op, wildcard, is_constant, is_tuple, is_tuple_get_item -from tvm.relay.expr import Call, Constant, TupleGetItem +from tvm.relay.expr import Call, Constant, GlobalVar, TupleGetItem from tvm.relay.expr_functor import ExprMutator, ExprVisitor from tvm.relay.op.contrib.register import register_pattern_table logger = logging.getLogger("TensorRT") +def is_tensorrt_compiler_enabled() -> bool: + return tvm.get_global_func("relay.ext.tensorrt.is_runtime_enabled", True) is not None + + def is_tensorrt_runtime_enabled() -> bool: """Check if the TensorRT graph executor is present. Returns @@ -40,116 +44,90 @@ def is_tensorrt_runtime_enabled() -> bool: ret: bool True if present, False if not. """ - check_enabled = tvm.get_global_func("relay.op.is_tensorrt_runtime_enabled", True) + check_enabled = tvm.get_global_func("relay.ext.tensorrt.is_runtime_enabled", True) if check_enabled: return check_enabled() return False +def get_tensorrt_target() -> tvm.target.Target: + """Returns the current Target, which must be of kind "tensorrt".""" + target = tvm.target.Target.current() + assert target.kind.name == "tensorrt" + return target + + def get_tensorrt_version() -> Tuple[int, int, int]: - """Gets the version of TensorRT that TVM is built against or is targeting. + """Returns the version of TensorRT to assume during compilation. + In order of preference this is taken from: + - The current "tensorrt" target's "tensorrt_version" attribute string. + - The version linked to the TVM runtime. + - (6, 0, 1) Returns ------- ret: Tuple[int, int, int] - TensorRT version as a tuple of major, minor, and patch number. If TVM - is not built with TensorRT, the value set by set_tensorrt_version() is returned instead. + TensorRT version as a tuple of (major, minor, patch). 
""" - pass_ctx = tvm.transform.PassContext.current() - if "relay.ext.tensorrt.options" in pass_ctx.config: - return tuple(pass_ctx.config["relay.ext.tensorrt.options"].tensorrt_version) # type: ignore - return tuple(tvm.get_global_func("relay.op.get_tensorrt_version")()) # type: ignore + target = get_tensorrt_target() + version = target.attrs["tensorrt_version"] + if len(version) == 3: + return int(version[0]), int(version[1]), int(version[2]) + assert len(version) == 0 + + get_version = tvm.get_global_func("relay.ext.tensorrt.get_version", True) + if get_version: + version = get_version() + assert len(version) == 3 + return int(version[0]), int(version[1]), int(version[2]) - -def get_tensorrt_use_implicit_batch_mode() -> bool: - pass_ctx = tvm.transform.PassContext.current() - if "relay.ext.tensorrt.options" in pass_ctx.config: - return pass_ctx.config["relay.ext.tensorrt.options"].use_implicit_batch logger.warning( - "PassContext has no relay.ext.tensorrt.options config, using default value " - "use_implicit_batch=True." + "TVM was not built against TensorRT and no version was provided to " + "partition_for_tensorrt. Defaulting to 6.0.1" ) - return True + return (6, 0, 1) + + +def get_tensorrt_use_implicit_batch_mode() -> bool: + """Returns the "use_implicit_batch" attribute of the current "tensorrt" target.""" + target = get_tensorrt_target() + return target.attrs["use_implicit_batch"] def get_tensorrt_remove_no_mac_subgraphs() -> bool: - pass_ctx = tvm.transform.PassContext.current() - if "relay.ext.tensorrt.options" in pass_ctx.config: - return pass_ctx.config["relay.ext.tensorrt.options"].remove_no_mac_subgraphs - logger.warning( - "PassContext has no relay.ext.tensorrt.options config, using default value " - "remove_no_mac_subgraphs=False." - ) - return False + """Returns the "remove_no_mac_subgraphs" attribute of the current "tensorrt" target.""" + target = get_tensorrt_target() + return target.attrs["remove_no_mac_subgraphs"] + + +def get_tensorrt_use_fp16() -> bool: + """Returns the "use_fp16" attribute of the current "tensorrt" target.""" + target = get_tensorrt_target() + return target.attrs["use_fp16"] def partition_for_tensorrt( mod: tvm.IRModule, params: Optional[Dict[str, tvm.nd.NDArray]] = None, - version: Optional[Tuple[int, int, int]] = None, - use_implicit_batch: bool = True, - remove_no_mac_subgraphs: bool = False, - max_workspace_size: int = 1 << 30, - use_fp16: bool = False, - use_uint8: bool = False, -) -> Tuple[tvm.IRModule, Dict[str, Any]]: - """Partition the graph greedily offloading supported operators to TensorRT. + target: tvm.target.Target = tvm.target.Target("tensorrt"), +) -> tvm.IRModule: + """Partition all functions in mod to greedily offload supported operators to TensorRT. Parameters ---------- mod : tvm.IRModule - The module to run passes on. + The module to partition. + target : tvm.target.Target + A target of kind "tensorrt" describing additional partitioning and compilation options. params : Optional[Dict[str, tvm.nd.NDArray]] Constant input parameters. - version : Optional[Tuple[int, int, int]] - TensorRT version to target as tuple of (major, minor, patch). If TVM is compiled with - USE_TENSORRT_RUNTIME=ON, the linked TensorRT version will be used instead. - use_implicit_batch : bool - Use TensorRT implicit batch mode (default true). Setting to false will enable explicit batch - mode which will widen supported operators to include those which modify the batch dimension, - but may reduce performance for some models. 
- remove_no_mac_subgraphs : bool - Removes subgraphs which have been partitioned for TensorRT if they do not have any - multiply-accumulate operations. The removed subgraphs will go through TVM's standard - compilation instead. Can improve performance. - max_workspace_size : int - How many bytes of workspace size to allow each subgraph to use for TensorRT engine creation. - See TensorRT documentation for more info. - use_fp16: bool - Allows, TRT to automatically convert FP32 inputs to FP16. Also, it is required to be enabled - if FP16 inputs tensors and weights are used. - Note that TensorRT will still choose a higher-precision kernel if it results in overall - lower runtime, or if no low-precision implementation exists. - use_uint8: bool - Allows, TRT to automatically convert FP32 inputs to UINT8. Returns ------- - mod_and_config : Tuple[tvm.IRModule, Dict[str, Any]] - A tuple of 1) annotated and partitioned module and 2) "relay.ext.tensorrt.options" - configuration which should be given to PassContext when building. + partitioned_mod : tvm.IRModule + The partitioned module. """ - config: Dict[str, Any] = { - "use_implicit_batch": use_implicit_batch, - "max_workspace_size": max_workspace_size, - "remove_no_mac_subgraphs": remove_no_mac_subgraphs, - "use_fp16": use_fp16, - "use_uint8": use_uint8, - } - if version: - assert isinstance(version, tuple) and len(version) == 3 - config["tensorrt_version"] = version - else: - linked_version = tuple(tvm.get_global_func("relay.op.get_tensorrt_version")()) - if not linked_version: - logger.warning( - "TVM was not built against TensorRT and no version was provided to " - "partition_for_tensorrt. Defaulting to 6.0.1" - ) - linked_version = (6, 0, 1) - config["tensorrt_version"] = linked_version - if params: mod["main"] = bind_params_by_name(mod["main"], params) @@ -174,24 +152,27 @@ def partition_for_tensorrt( transform.InferType(), ] ) - with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}): + with target: mod = seq(mod) - # TODO(mbs): Revisit - # mod = prune_tensorrt_subgraphs(mod) - return mod, config + mod = prune_tensorrt_subgraphs(mod) + return mod def is_supported_trt_type(typ: Union[tvm.ir.TensorType, tvm.ir.TupleType], op_name: str) -> bool: """Check whether a type is supported by TensorRT.""" - supported_dtypes = ["float32", "float16"] + supported_dtypes = ["float32"] + if get_tensorrt_use_fp16(): + supported_dtypes.append("float16") if isinstance(typ, tvm.ir.TensorType): if typ.dtype not in supported_dtypes: - logger.info(f"{op_name}: Only float32 and float16 tensor dtypes are supported.") + logger.info(f"{op_name}: Only {supported_dtypes} tensor dtypes are supported.") return False - # assumes dim 0 is for batch and can be dynamic - # TODO(mbs): But does this depend use_implicit_batch flag? - for dim_shape in typ.shape[1:]: - if isinstance(dim_shape, tvm.tir.expr.Any): + dims = typ.shape + if get_tensorrt_use_implicit_batch_mode(): + # The first dimension can be Any. 
+ dims = dims[1:] + for dim in dims: + if isinstance(dim, tvm.tir.expr.Any): logger.info(f"{op_name}: Only statically known tensor shapes are supported.") return False elif isinstance(typ, tvm.ir.TupleType): @@ -247,7 +228,10 @@ def predicate(expr: relay.expr.Expr) -> bool: args = get_args(expr) if not all([is_supported_trt_type(arg.checked_type, op_name) for arg in args]): return False - return checker(attrs, args, op_name) + if not checker(attrs, args, op_name): + return False + logger.info(f"{op_name}: Predicate passes") + return True return predicate @@ -535,11 +519,16 @@ def concatenate_checker( if int(attrs.axis) == 0: logger.info(f"{op_name}: can't modify batch dimension.") return False - if isinstance(args[0], relay.Tuple): - for tuple_input in args[0].fields: - if isinstance(tuple_input, Constant): - logger.info(f"{op_name}: can't concatenate tensors with constants.") - return False + + if not isinstance(args[0], relay.Tuple): + logger.info("f{op_name}: concatenate must be applied to a literal tuple") + return False + + for tuple_input in args[0].fields: + if isinstance(tuple_input, Constant): + logger.info(f"{op_name}: can't concatenate tensors with constants.") + return False + return True diff --git a/src/relay/backend/contrib/tensorrt/codegen.cc b/src/relay/backend/contrib/tensorrt/codegen.cc index e08cd240d4d1..526f6bf7588a 100644 --- a/src/relay/backend/contrib/tensorrt/codegen.cc +++ b/src/relay/backend/contrib/tensorrt/codegen.cc @@ -33,42 +33,56 @@ #include "../codegen_json/codegen_json.h" #if TVM_GRAPH_EXECUTOR_TENSORRT +#include "../../../transforms/compiler_function_utils.h" #include "NvInfer.h" #endif namespace tvm { namespace relay { namespace contrib { +namespace tensorrt { -/*! \brief Attributes to store the compiler options for TensorRT. */ -struct TensorRTCompilerConfigNode : public tvm::AttrsNode { - Array tensorrt_version; - bool use_implicit_batch; - size_t max_workspace_size; - bool remove_no_mac_subgraphs; - bool use_fp16; - bool use_uint8; - - TVM_DECLARE_ATTRS(TensorRTCompilerConfigNode, "ext.attrs.TensorRTCompilerConfigNode") { - TVM_ATTR_FIELD(tensorrt_version) - .describe("TensorRT version as (major, minor, patch).") - .set_default(Array({6, 0, 1})); - TVM_ATTR_FIELD(use_implicit_batch).set_default(true); - TVM_ATTR_FIELD(max_workspace_size).set_default(size_t(1) << 30); - TVM_ATTR_FIELD(remove_no_mac_subgraphs).set_default(false); - TVM_ATTR_FIELD(use_fp16).set_default(false); - TVM_ATTR_FIELD(use_uint8).set_default(false); - } -}; +/*! + * \brief Check whether TensorRT graph executor is enabled. + * \return True if enabled, False if not. + */ +inline constexpr bool IsRuntimeEnabled() { +#if TVM_GRAPH_EXECUTOR_TENSORRT + return true; +#else + return false; +#endif // TVM_GRAPH_EXECUTOR_TENSORRT +} -class TensorRTCompilerConfig : public Attrs { - public: - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorRTCompilerConfig, Attrs, - TensorRTCompilerConfigNode); -}; +TVM_REGISTER_GLOBAL("relay.ext.tensorrt.is_runtime_enabled").set_body_typed(IsRuntimeEnabled); -TVM_REGISTER_NODE_TYPE(TensorRTCompilerConfigNode); -TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.tensorrt.options", TensorRTCompilerConfig); +/*! + * \brief Get TensorRT version that TVM is built against. + * \return Array of three integers for major, minor, and patch, or empty array if TensorRT graph + * runtime is not enabled. 
+ */ +Array GetVersion() { +#if TVM_GRAPH_EXECUTOR_TENSORRT + return {Integer(NV_TENSORRT_MAJOR), Integer(NV_TENSORRT_MINOR), Integer(NV_TENSORRT_PATCH)}; +#else + return {}; +#endif // TVM_GRAPH_EXECUTOR_TENSORRT +} + +TVM_REGISTER_GLOBAL("relay.ext.tensorrt.get_version").set_body_typed(GetVersion); + +/*! + * \brief Returns the "tensorrt" Target instance to use for compilation. + */ +Target GetTensorRTTarget() { + Target target = Target::Current(/*allow_not_defined=*/true); + if (!target.defined() || target->kind->name != "tensorrt") { + // Since we allow partition_for_tensorrt to use the default "tensorrt" target, we should + // similarly allow the custom pass to execute without a specific "tensorrt" target in scope. + target = Target("tensorrt"); + } + return target; +} using JSONGraphNode = tvm::runtime::json::JSONGraphNode; using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry; @@ -87,6 +101,7 @@ class CollectFromCompositeFunctionBody : public ExprVisitor { explicit CollectFromCompositeFunctionBody(TensorRTJSONSerializer* serializer) : serializer_(serializer), node_(std::make_shared()) {} + // We'll need to implement these out-of-band since they use the serializer. void VisitExpr_(const ConstantNode* constant_node) final; void VisitExpr_(const CallNode* call_node) final; @@ -190,6 +205,7 @@ class CollectFromCompositeFunctionBody : public ExprVisitor { extractor.Extract(const_cast(attr_obj)); } + /*! \brief The parent serializer for the overall TensorRT partition. */ TensorRTJSONSerializer* serializer_; /*! \brief Accumulated translated arguments. */ std::vector args_; @@ -207,9 +223,10 @@ class CollectFromCompositeFunctionBody : public ExprVisitor { */ class TensorRTJSONSerializer : public JSONSerializer { public: - TensorRTJSONSerializer(const std::string& symbol, const Expr& expr) - : JSONSerializer(symbol, expr) {} + TensorRTJSONSerializer(Target target, const std::string& symbol, const Expr& expr) + : JSONSerializer(symbol, expr), target_(std::move(target)) {} + private: using JSONSerializer::VisitExpr_; std::vector VisitExpr_(const CallNode* call_node) final { @@ -245,40 +262,58 @@ class TensorRTJSONSerializer : public JSONSerializer { node->CaptureAttrs(*collector.node_); // Capture global settings on the JSON node. - SaveGlobalAttributes(node); + // TODO(mbs): Why on every call? + SaveGlobalAttributes(node.get()); VLOG(1) << name << " has " << node->GetInputs().size() << " inputs"; return AddNode(node, GetRef(call_node)); } - static void SaveGlobalAttributes(std::shared_ptr node) { - auto ctx = transform::PassContext::Current(); - auto cfg = ctx->GetConfig("relay.ext.tensorrt.options"); - if (!cfg.defined()) { - cfg = AttrsWithDefaultValues(); + static void SetAttr(JSONGraphNode* node, const std::string& key, + std::vector values) { + node->SetAttr(key, std::vector({std::move(values)})); + } + + /*! \brief Capture the compilation options as attributes on \p node. 
*/ + void SaveGlobalAttributes(JSONGraphNode* node) { + { + Array target_attr = target_->GetAttr>("tensorrt_version").value(); + if (target_attr.empty()) { + target_attr = GetVersion(); + } + if (target_attr.empty()) { + target_attr = {6, 0, 1}; + } + ICHECK_EQ(target_attr.size(), 3); + SetAttr(node, "tensorrt_version", + {std::to_string(target_attr[0]), std::to_string(target_attr[1]), + std::to_string(target_attr[2])}); + } + + { + Bool target_attr = target_->GetAttr("use_implicit_batch").value(); + SetAttr(node, "use_implicit_batch", {std::to_string(target_attr->value)}); + } + + { + Integer target_attr = target_->GetAttr("max_workspace_size").value(); + SetAttr(node, "max_workspace_size", {std::to_string(target_attr->value)}); + } + + { + Bool target_attr = target_->GetAttr("use_fp16").value(); + SetAttr(node, "use_fp16", {std::to_string(target_attr->value)}); + } + + { + Bool target_attr = target_->GetAttr("use_uint8").value(); + SetAttr(node, "use_uint8", {std::to_string(target_attr->value)}); } - ICHECK_EQ(cfg.value()->tensorrt_version.size(), 3); - std::vector tensorrt_version = {std::to_string(cfg.value()->tensorrt_version[0]), - std::to_string(cfg.value()->tensorrt_version[1]), - std::to_string(cfg.value()->tensorrt_version[2])}; - std::vector use_implicit_batch = {std::to_string(cfg.value()->use_implicit_batch)}; - std::vector max_workspace_size = {std::to_string(cfg.value()->max_workspace_size)}; - std::vector use_fp16 = {std::to_string(cfg.value()->use_fp16)}; - std::vector use_uint8 = {std::to_string(cfg.value()->use_uint8)}; - std::vector tensorrt_version_attr, use_implicit_batch_attr, max_workspace_size_attr, - use_fp16_attr, use_uint8_attr; - tensorrt_version_attr.emplace_back(tensorrt_version); - use_implicit_batch_attr.emplace_back(use_implicit_batch); - max_workspace_size_attr.emplace_back(max_workspace_size); - use_fp16_attr.emplace_back(use_fp16); - use_uint8_attr.emplace_back(use_uint8); - node->SetAttr("tensorrt_version", tensorrt_version_attr); - node->SetAttr("use_implicit_batch", use_implicit_batch_attr); - node->SetAttr("max_workspace_size", max_workspace_size_attr); - node->SetAttr("use_fp16", use_fp16_attr); - node->SetAttr("use_uint8", use_uint8_attr); } + + /*! \brief The "tensorrt" Target guiding compilation. */ + Target target_; }; void CollectFromCompositeFunctionBody::VisitExpr_(const ConstantNode* constant_node) { @@ -304,64 +339,74 @@ void CollectFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) { } /*! - * \brief Create a runtime module for TensorRT. - * \param ref The ext_func Relay expression/module to be executed using extern ops. - * \return A runtime module. - */ -runtime::Module TensorRTCompiler(const ObjectRef& ref) { - ICHECK(ref->IsInstance()) << "The input ref is expected to be a Relay function."; - Function func = Downcast(ref); - std::string func_name = backend::GetExtSymbol(func); - - VLOG(1) << "TensorRT partition:" << std::endl << PrettyPrint(func); - TensorRTJSONSerializer serializer(func_name, func); - serializer.serialize(); - std::string graph_json = serializer.GetJSON(); - VLOG(1) << "TensorRT JSON:" << std::endl << graph_json; - - // Note that serializer.const_name_to_constant() is ignored. Instead the TECompiler invokes - // a callback which calls backend::UpdateConstants to capture the map before the function - // 'disappears' into lowered form, on the assumption the visit order and thus constant - // names match those generated by the JSONSerializer. 
- - const auto* pf = runtime::Registry::Get("runtime.tensorrt_runtime_create"); - ICHECK(pf != nullptr) << "Cannot find TensorRT runtime module create function."; - VLOG(1) << "Creating tensorrt runtime::Module for '" << func_name << "'"; - runtime::Module lib = (*pf)(func_name, graph_json, serializer.const_names()); - return lib; -} - -TVM_REGISTER_GLOBAL("relay.ext.tensorrt").set_body_typed(TensorRTCompiler); - -/*! - * \brief Check whether TensorRT graph executor is enabled. - * \return True if enabled, False if not. + * \brief The main TensorRT compiler. + * + * TODO(mbs): Currently we create a \p TensorRTRuntimeModule for every function with + * Compiler="tensorrt" (ie for each partition). Since the TensorRT engine is only designed to + * handle a single entry point this is mostly sensible, however there are probably opportunities + * for more sharing between functions. However, note this means each call to a TensorRT-compiled + * function will require a linear scan of imported runtime modules to find the matching + * TensorRTRuntimeModule implementing it. */ -inline constexpr bool IsTensorRTRuntimeEnabled() { -#if TVM_GRAPH_EXECUTOR_TENSORRT - return true; -#else - return false; -#endif // TVM_GRAPH_EXECUTOR_TENSORRT +transform::Pass CompileForTensorRTImpl() { + auto pass_func = [](IRModule mod, const transform::PassContext& pass_ctx) { + VLOG(1) << "CompileForTensorRT input:" << std::endl << PrettyPrint(mod); + Target target = GetTensorRTTarget(); + + const auto* pf = runtime::Registry::Get("runtime.tensorrt_runtime_create"); + ICHECK(pf != nullptr) << "Cannot find TensorRT runtime module create function."; + + // The accumulated external runtime modules. + Array external_mods = + mod->GetAttr>(tvm::attr::kExternalMods).value_or({}); + // The accumulated constant bindings. + Map const_name_to_constant = + mod->GetAttr>(tvm::attr::kConstNameToConstant).value_or({}); + + for (const auto& kv : mod->functions) { + if (const auto* function_node = kv.second.as()) { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + Optional opt_compiler = function_node->GetAttr(attr::kCompiler); + if (opt_compiler && opt_compiler.value() == "tensorrt") { + // Serialize the function to JSON. + TensorRTJSONSerializer serializer(target, kv.first->name_hint, + GetRef(function_node)); + serializer.serialize(); + std::string graph_json = serializer.GetJSON(); + VLOG(1) << "TensorRT JSON for '" << kv.first->name_hint << "':" << std::endl + << graph_json; + + // Remember all the constant bindings. + for (const auto& kv2 : serializer.const_name_to_constant()) { + ICHECK_EQ(const_name_to_constant.count(kv2.first), 0); + VLOG(1) << "binding constant '" << kv2.first << "' for function '" + << kv.first->name_hint << "'"; + const_name_to_constant.Set(kv2.first, kv2.second); + } + + // Create the actual runtime module. + runtime::Module runtime_mod = + (*pf)(kv.first->name_hint, graph_json, serializer.const_names()); + + // Remember the runtime module. + external_mods.push_back(runtime_mod); + } + } + } + } + return WithAttrs(mod, {{tvm::attr::kExternalMods, external_mods}, + {tvm::attr::kConstNameToConstant, const_name_to_constant}}); + }; + return tvm::transform::CreateModulePass(pass_func, 0, "CompileForTensorRT", {}); } -/*! - * \brief Get TensorRT version that TVM is built against. - * \return Array of three integers for major, minor, and patch, or empty array if TensorRT graph - * runtime is not enabled. 
- */ -Array GetTensorRTVersion() { -#if TVM_GRAPH_EXECUTOR_TENSORRT - return {Integer(NV_TENSORRT_MAJOR), Integer(NV_TENSORRT_MINOR), Integer(NV_TENSORRT_PATCH)}; -#else - return {}; -#endif // TVM_GRAPH_EXECUTOR_TENSORRT +transform::Pass CompileForTensorRT() { + return transform::Sequential( + {transforms::OutlineCompilerFunctionsWithExistingGlobalSymbols("tensorrt"), + CompileForTensorRTImpl(), transforms::MarkCompilerFunctionsAsExtern("tensorrt")}); } -TVM_REGISTER_GLOBAL("relay.op.is_tensorrt_runtime_enabled") - .set_body_typed(IsTensorRTRuntimeEnabled); -TVM_REGISTER_GLOBAL("relay.op.get_tensorrt_version").set_body_typed(GetTensorRTVersion); - +} // namespace tensorrt } // namespace contrib } // namespace relay } // namespace tvm diff --git a/src/relay/backend/contrib/tensorrt/codegen.h b/src/relay/backend/contrib/tensorrt/codegen.h new file mode 100644 index 000000000000..813a8663756d --- /dev/null +++ b/src/relay/backend/contrib/tensorrt/codegen.h @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/contrib/tensorrt/codegen.h + * \brief The 'custom' compilation pass for TensorRT (invoked by the RelayToTIRTargetHook pass). + */ + +#ifndef TVM_RELAY_BACKEND_CONTRIB_TENSORRT_CODEGEN_H_ +#define TVM_RELAY_BACKEND_CONTRIB_TENSORRT_CODEGEN_H_ + +#include + +namespace tvm { +namespace relay { +namespace contrib { +namespace tensorrt { + +/*! + * \brief Returns the pass which replaces all calls to "Primitive" functions with a "Compiler" + * attribute of "tensorrt" with calls to an extern which is implemented by a \p TensorRTRuntime + * runtime module added to the IRModule's "external_mods" attribute. + */ +transform::Pass CompileForTensorRT(); + +} // namespace tensorrt +} // namespace contrib +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_BACKEND_CONTRIB_TENSORRT_CODEGEN_H_ diff --git a/src/relay/backend/contrib/tensorrt/target.cc b/src/relay/backend/contrib/tensorrt/target.cc index 85d127ab7115..2e4581d30a3c 100644 --- a/src/relay/backend/contrib/tensorrt/target.cc +++ b/src/relay/backend/contrib/tensorrt/target.cc @@ -24,19 +24,46 @@ #include +#include "./codegen.h" + namespace tvm { namespace relay { namespace contrib { +namespace tensorrt { /*! * \brief This external codegen target can offload compilation to the TensorRT compiler. * - Patterns: python/tvm/relay/op/contrib/tensorrt.py * - Custom compiler: src/relay/backend/contrib/tensorrt/codegen.cc - * - Runtime: src/runtime/contrib/tensorrt/ *.cc + * - Runtime: src/runtime/contrib/tensorrt/... 
 */
 TVM_REGISTER_TARGET_KIND("tensorrt", kDLCUDA)
-    .set_attr<Bool>(tvm::attr::kIsExternalCodegen, Bool(true));
+    .set_attr<Bool>(tvm::attr::kIsExternalCodegen, Bool(true))
+    .set_attr<FTVMRelayToTIR>("RelayToTIR", CompileForTensorRT())
+    // An array of three integers giving the major, minor, and patch numbers for the supported
+    // TensorRT compiler version. If empty, it is auto-detected from the linked library. Default empty.
+    .add_attr_option<Array<Integer>>("tensorrt_version", Array<Integer>())
+    // If true, the first tensor dimension for most operators is allowed to be Any and
+    // TensorRT will assume it represents a batch dimension only known at inference time.
+    // Fewer Relay operators are supported in implicit batch mode. Default true.
+    .add_attr_option<Bool>("use_implicit_batch", Bool(true))
+    // If true, excludes sub-graphs which do not have multiply-accumulate operations, even though
+    // TensorRT supports them. This is a simple heuristic to optimize the partitioning between
+    // TensorRT and TVM. Not required if using Collage for partitioning. Default false.
+    .add_attr_option<Bool>("remove_no_mac_subgraphs", Bool(false))
+    // How many bytes of workspace size to allow each subgraph to use for TensorRT engine creation.
+    // Default 1G.
+    .add_attr_option<Integer>("max_workspace_size", Integer(1 << 30))
+    // If true, allows TensorRT to automatically convert float32 operations to float16. Must also be
+    // enabled if any float16 operations are in the model. Note that TensorRT may still choose a
+    // higher-precision kernel if it results in overall lower runtime, or if no low-precision
+    // implementation exists. Default false.
+    .add_attr_option<Bool>("use_fp16", Bool(false))
+    // If true, allows TensorRT to automatically convert float32 operations to uint8
+    // (aka quantized). Default false.
+    .add_attr_option<Bool>("use_uint8", Bool(false));
+}  // namespace tensorrt
 }  // namespace contrib
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/runtime/const_loader_module.cc b/src/runtime/const_loader_module.cc
index 2e91d26d5f96..a8028e616c5b 100644
--- a/src/runtime/const_loader_module.cc
+++ b/src/runtime/const_loader_module.cc
@@ -51,15 +51,24 @@ class ConstLoaderModuleNode : public ModuleNode {
       const std::unordered_map<std::string, NDArray>& const_var_ndarray,
       const std::unordered_map<std::string, std::vector<std::string>>& const_vars_by_symbol)
       : const_var_ndarray_(const_var_ndarray), const_vars_by_symbol_(const_vars_by_symbol) {
+    VLOG(1) << "Creating ConstLoaderModule";
     // Only the related submodules are cached to reduce the number of runtime
     // symbol lookup for initialization. Otherwise, symbols/primitives in the
     // DSO module will also be cached but they never need to be initialized.
-    for (const auto& it : const_vars_by_symbol_) {
-      initialized_[it.first] = false;
+    for (const auto& kv : const_vars_by_symbol_) {
+      for (const auto& var : kv.second) {
+        VLOG(1) << "ConstLoaderModuleNode has constant '" << var << "' for function '" << kv.first
+                << "'";
+        ICHECK_GT(const_var_ndarray_.count(var), 0)
+            << "ConstLoaderModuleNode is missing entry for constant '" << var << "' for function '"
+            << kv.first << "'";
+      }
+      initialized_[kv.first] = false;
     }
   }

   PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
+    VLOG(1) << "ConstLoaderModuleNode::GetFunction(" << name << ")";
     // Initialize and memoize the module.
     // Usually, we have some warmup runs. The module initialization should be
     // done at this stage. Therefore, runtime overhead is not a concern.
@@ -88,11 +97,13 @@ class ConstLoaderModuleNode : public ModuleNode { */ Array GetRequiredConstants(const std::string& symbol) { Array ret; - ICHECK_GT(const_vars_by_symbol_.count(symbol), 0U) << "No symbol is recorded for " << symbol; + ICHECK_GT(const_vars_by_symbol_.count(symbol), 0U) + << "No constants known for function '" << symbol << "'"; std::vector vars = const_vars_by_symbol_[symbol]; - for (const auto& it : vars) { - ICHECK_GT(const_var_ndarray_.count(it), 0U) << "Found not recorded constant variable: " << it; - ret.push_back(const_var_ndarray_[it]); + for (const auto& var : vars) { + ICHECK_GT(const_var_ndarray_.count(var), 0U) + << "No such constant variable '" << var << "' for function '" << symbol << "'"; + ret.push_back(const_var_ndarray_[var]); } return ret; } @@ -229,5 +240,6 @@ TVM_REGISTER_GLOBAL("runtime.module.loadbinary_metadata") .set_body_typed(ConstLoaderModuleNode::LoadFromBinary); TVM_REGISTER_GLOBAL("runtime.module.loadbinary_const_loader") .set_body_typed(ConstLoaderModuleNode::LoadFromBinary); + } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/json/json_runtime.h b/src/runtime/contrib/json/json_runtime.h index 355390765de7..3a02202b87f2 100644 --- a/src/runtime/contrib/json/json_runtime.h +++ b/src/runtime/contrib/json/json_runtime.h @@ -54,6 +54,8 @@ class JSONRuntimeBase : public ModuleNode { LoadGraph(graph_json_); } + ~JSONRuntimeBase() override = default; + const char* type_key() const override { return "json"; } // May be overridden /*! \brief Initialize a specific json runtime. */ diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.cc b/src/runtime/contrib/tensorrt/tensorrt_builder.cc index 5f923667d0c2..436a6db4c8d4 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_builder.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_builder.cc @@ -45,10 +45,11 @@ TensorRTBuilder::TensorRTBuilder(TensorRTLogger* logger, max_workspace_size_(max_workspace_size), use_implicit_batch_(use_implicit_batch), use_fp16_(use_fp16), - batch_size_(batch_size) { + use_int8_(false), + batch_size_(batch_size), + calibrator_(calibrator) { // Create TRT builder and network. builder_ = nvinfer1::createInferBuilder(*logger); - use_int8_ = false; #if TRT_VERSION_GE(6, 0, 1) // Use INetworkV2. @@ -58,8 +59,7 @@ TensorRTBuilder::TensorRTBuilder(TensorRTLogger* logger, flags = 0U; builder_->setMaxBatchSize(batch_size_); } - this->calibrator_ = calibrator; - if (calibrator != nullptr) { + if (calibrator_ != nullptr) { use_int8_ = true; } network_ = builder_->createNetworkV2(flags); @@ -177,6 +177,7 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() { if (use_int8_) { config_->setFlag(nvinfer1::BuilderFlag::kINT8); + ICHECK(calibrator_); config_->setInt8Calibrator(calibrator_); LOG(INFO) << "config finishes setting up calibrator as INT8 mode ... 
"; } @@ -210,6 +211,9 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() { nvinfer1::IExecutionContext* context = engine->createExecutionContext(); CleanUp(); + ICHECK(engine); + ICHECK(context); + return {engine, context, network_input_names_, network_output_names_}; } @@ -254,18 +258,33 @@ nvinfer1::ITensor* TensorRTBuilder::GetInputAsTensor(const TensorRTOpInput& inpu } void TensorRTBuilder::CleanUp() { + VLOG(1) << "Destroying TensorRT network"; + ICHECK(network_); network_->destroy(); + network_ = nullptr; + #if TRT_VERSION_GE(6, 0, 1) + VLOG(1) << "Destroying TensorRT config"; + ICHECK(config_); config_->destroy(); + config_ = nullptr; #endif + + VLOG(1) << "Destroying TensorRT builder"; + ICHECK(builder_); builder_->destroy(); + builder_ = nullptr; + + VLOG(1) << "Destroying TensorRT weights"; for (auto weight : trt_weights_) { + ICHECK(weight.values); if (weight.type == nvinfer1::DataType::kFLOAT || weight.type == nvinfer1::DataType::kHALF) { delete[] static_cast(weight.values); } else { delete[] static_cast(weight.values); } } + trt_weights_.clear(); } } // namespace contrib diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.h b/src/runtime/contrib/tensorrt/tensorrt_builder.h index 13a118340e11..9bccc1ea4848 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_builder.h +++ b/src/runtime/contrib/tensorrt/tensorrt_builder.h @@ -48,8 +48,8 @@ using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry; * perform inference. */ struct TensorRTEngineAndContext { - nvinfer1::ICudaEngine* engine; - nvinfer1::IExecutionContext* context; + nvinfer1::ICudaEngine* engine = nullptr; + nvinfer1::IExecutionContext* context = nullptr; std::vector inputs; std::vector outputs; }; @@ -125,15 +125,15 @@ class TensorRTBuilder { std::unordered_map> node_output_map_; /*! \brief TensorRT builder. */ - nvinfer1::IBuilder* builder_; + nvinfer1::IBuilder* builder_ = nullptr; #if TRT_VERSION_GE(6, 0, 1) /*! \brief TensorRT builder config. */ - nvinfer1::IBuilderConfig* config_; + nvinfer1::IBuilderConfig* config_ = nullptr; #endif /*! \brief TensorRT network definition. */ - nvinfer1::INetworkDefinition* network_; + nvinfer1::INetworkDefinition* network_ = nullptr; /*! \brief List of all weights held in memory. */ std::vector trt_weights_; diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.cc b/src/runtime/contrib/tensorrt/tensorrt_ops.cc index 3971081bf8f8..cd46967e532b 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_ops.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_ops.cc @@ -67,7 +67,7 @@ nvinfer1::ITensor* TensorRTOpConverter::Transpose(TensorRTOpConverterParams* par // Batch dimension cannot be modified. 
ICHECK_EQ(input->getDimensions().nbDims, order.size() - 1); ICHECK_EQ(order[0], 0); - for (size_t i = 0; i < order.size(); ++i) { + for (size_t i = 0; i + 1 < order.size(); ++i) { perm.order[i] = order[i + 1] - 1; } } else { @@ -880,7 +880,7 @@ class ConcatOpConverter : public TensorRTOpConverter { const int input_rank = params->inputs[0].tensor->getDimensions().nbDims; std::vector input_tensors; for (auto input : params->inputs) { - ICHECK(input.type == kTensor); + ICHECK_EQ(input.type, kTensor); ICHECK_EQ(input_rank, input.tensor->getDimensions().nbDims); input_tensors.push_back(input.tensor); } diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc index 18ffdbbbba85..b51684b95eb8 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc @@ -138,13 +138,21 @@ class TensorRTRuntime : public JSONRuntimeBase { /*! \brief Destroy engines and contexts. */ void DestroyEngines() { for (auto& it : trt_engine_cache_) { + VLOG(1) << "Destroying TensorRT context for function '" << it.first.first << "' (batch size " + << it.first.second << ")"; it.second.context->destroy(); + VLOG(1) << "Destroying TensorRT engine for function '" << it.first.first << "' (batch size " + << it.first.second << ")"; it.second.engine->destroy(); } trt_engine_cache_.clear(); } - ~TensorRTRuntime() { DestroyEngines(); } + ~TensorRTRuntime() override { + VLOG(1) << "Destroying TensorRT runtime"; + DestroyEngines(); + VLOG(1) << "Destroyed TensorRT runtime"; + } /*! \brief Run inference using built engine. */ void Run() override { @@ -467,7 +475,7 @@ class TensorRTRuntime : public JSONRuntimeBase { /*! \brief TensorRT logger. */ TensorRTLogger logger_; -#else +#else // TVM_GRAPH_EXECUTOR_TENSORRT void Run() override { LOG(FATAL) << "TensorRT runtime is not enabled. 
" << "Please build with USE_TENSORRT_RUNTIME."; @@ -481,7 +489,7 @@ class TensorRTRuntime : public JSONRuntimeBase { bool GetCachedEnginesFromDisk() { return false; } void CacheEngineToDisk() {} -#endif +#endif // TVM_GRAPH_EXECUTOR_TENSORRT bool use_implicit_batch_; diff --git a/src/target/metadata_module.cc b/src/target/metadata_module.cc index ec301d10812f..e5ca82d5c099 100644 --- a/src/target/metadata_module.cc +++ b/src/target/metadata_module.cc @@ -215,8 +215,6 @@ runtime::Module CreateMetadataModule( String symbol = pf_sym(); Array variables = pf_var(); for (size_t i = 0; i < variables.size(); i++) { - VLOG(1) << "From module of type '" << mod->type_key() << "' found const var '" - << variables[i] << "' for symbol '" << symbol << "'"; symbol_const_vars.push_back(variables[i].operator std::string()); } ICHECK_EQ(const_vars_by_symbol.count(symbol), 0U) << "Found duplicated symbol: " << symbol; diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py index cecb64785a49..38da305f3b17 100644 --- a/tests/python/contrib/test_tensorrt.py +++ b/tests/python/contrib/test_tensorrt.py @@ -18,7 +18,7 @@ import numpy as np import pytest import itertools - +import logging import tvm import tvm.relay.testing @@ -33,12 +33,14 @@ from tvm.contrib.download import download from tvm.relay.op.contrib import tensorrt - SUPPORTED_DTYPES = ["float16", "float32"] has_tensorrt_codegen = pytest.mark.skipif( - not tvm.get_global_func("relay.ext.tensorrt", True), reason="TensorRT codegen not available" + not tensorrt.is_tensorrt_compiler_enabled(), reason="TensorRT codegen not available" ) + +# CAUTION: Currently always false in CI since adds tens of minutes to test time and depends +# on TensorRT installation. See https://github.com/apache/tvm/issues/11765 has_tensorrt_runtime = pytest.mark.skipif( not tensorrt.is_tensorrt_runtime_enabled(), reason="TensorRT runtime not available" ) @@ -72,7 +74,7 @@ def assert_result_dict_holds(result_dict, dtype="float16"): tvm.testing.assert_allclose(r1, r2, rtol=1e-3, atol=5e-3) -def set_func_attr(func, compile_name, symbol_name): +def set_outer_func_attr(func, compile_name, symbol_name): func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1)) func = func.with_attr("Compiler", compile_name) @@ -80,6 +82,12 @@ def set_func_attr(func, compile_name, symbol_name): return func +def set_inner_func_attr(func, pattern_name, composite_name): + func = func.with_attr("PartitionedFromPattern", pattern_name) + func = func.with_attr("Composite", composite_name) + return func + + def run_and_verify_func(config, target="cuda", run_module=True, data_type="float32"): """Test a Relay func by compiling, running, and comparing TVM and TRT outputs. 
@@ -110,34 +118,31 @@ def run_and_verify_func(config, target="cuda", run_module=True, data_type="float result_dict = dict() for mode in ["vm", "graph"]: - for mode in ["graph"]: - for use_trt in [True, False]: - mod = tvm.IRModule() - mod["main"] = f - result_key = mode + ("_trt" if use_trt else "") - if use_trt: - mod = relay.transform.InferType()(mod) - mod, config = tensorrt.partition_for_tensorrt( - mod, params, use_fp16=data_type == "float16" - ) - with tvm.transform.PassContext( - opt_level=3, config={"relay.ext.tensorrt.options": config} - ): - func = relay.create_executor( - mode, mod=mod, device=dev, target=target - ).evaluate() - else: - mod = relay.transform.InferType()(mod) - with tvm.transform.PassContext(opt_level=3): - func = relay.create_executor( - mode, mod=mod, device=dev, target=target - ).evaluate() + for use_trt in [True, False]: + mod = tvm.IRModule() + mod["main"] = f + result_key = mode + ("_trt" if use_trt else "") + if use_trt: + use_fp16 = data_type == "float16" + trt_target = tvm.target.Target(f"tensorrt -use_fp16={use_fp16}") + mod = relay.transform.InferType()(mod) + mod = tensorrt.partition_for_tensorrt(mod, params=params, target=trt_target) + with tvm.transform.PassContext(opt_level=3): + func = relay.create_executor( + mode, mod=mod, device=dev, target=[target, trt_target] + ).evaluate() + else: + mod = relay.transform.InferType()(mod) + with tvm.transform.PassContext(opt_level=3): + func = relay.create_executor( + mode, mod=mod, device=dev, target=target + ).evaluate() - if run_module: - result_dict[result_key] = func(**input_dict, **params) + if run_module: + result_dict[result_key] = func(**input_dict, **params) - if run_module: - assert_result_dict_holds(result_dict, data_type) + if run_module: + assert_result_dict_holds(result_dict, data_type) def test_tensorrt_simple(run_module): @@ -163,10 +168,8 @@ def test_tensorrt_simple(run_module): result_key = mode + ("_trt" if use_trt else "") if use_trt: mod = relay.transform.InferType()(mod) - mod, config = tensorrt.partition_for_tensorrt(mod) - with tvm.transform.PassContext( - opt_level=3, config={"relay.ext.tensorrt.options": config} - ): + mod = tensorrt.partition_for_tensorrt(mod) + with tvm.transform.PassContext(opt_level=3): func = relay.create_executor( mode, mod=mod, device=tvm.cuda(0), target="cuda" ).evaluate() @@ -212,9 +215,9 @@ def test_tensorrt_not_compatible(run_module): f = relay.Function([x], out) mod = tvm.IRModule() mod["main"] = f - mod, config = tensorrt.partition_for_tensorrt(mod) + mod = tensorrt.partition_for_tensorrt(mod) for mode in ["graph", "vm"]: - with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}): + with tvm.transform.PassContext(opt_level=3): func = relay.create_executor( mode, mod=mod, device=tvm.cuda(0), target="cuda" ).evaluate() @@ -622,26 +625,18 @@ def are_ops_on_graph(self, subgraph) -> bool: def are_ops_on_trt(mod, op_list): + op_on_trt = False + op_on_tvm = False for subgraph in mod.get_global_vars(): name = subgraph.name_hint - op_on_trt = False - op_on_tvm = True - if name == "main": - op_on_tvm = AreOpsOnGraph(op_list).are_ops_on_graph(mod[name].body) - elif mod[name].attrs and mod[name].attrs["Compiler"] == "tensorrt": - op_on_trt = AreOpsOnGraph(op_list).are_ops_on_graph(mod[name].body) + if mod[name].attrs and mod[name].attrs["Compiler"] == "tensorrt": + op_on_trt |= AreOpsOnGraph(op_list).are_ops_on_graph(mod[name].body) else: - op_on_tvm &= AreOpsOnGraph(op_list).are_ops_on_graph(mod[name].body) - - if not op_on_trt or 
op_on_tvm: - return False + op_on_tvm |= AreOpsOnGraph(op_list).are_ops_on_graph(mod[name].body) - return True + return op_on_trt and not op_on_tvm -@pytest.mark.xfail( - reason=("Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901") -) def test_dynamic_reshape(run_module): def test_run(x_data_list, x_shape, new_shape, should_offload_to_trt): result_arr = [{} for _ in range(len(x_data_list))] @@ -652,9 +647,9 @@ def test_run(x_data_list, x_shape, new_shape, should_offload_to_trt): mod = tvm.IRModule() mod["main"] = f if use_trt: - mod, _ = tensorrt.partition_for_tensorrt( - mod, params={}, remove_no_mac_subgraphs=False - ) + logging.info("Before partitioning:\n%s", mod) + mod = tensorrt.partition_for_tensorrt(mod) + logging.info("After partitioning:\n%s", mod) assert are_ops_on_trt(mod, op_list=["reshape"]) == should_offload_to_trt if run_module: with relay.build_config(opt_level=3): @@ -1051,6 +1046,7 @@ def get_graph(d_type="float16"): run_and_verify_func(get_graph(d_type=type), run_module=run_module, data_type=type) +@pytest.mark.skip(reason=("Fails assert_allclose. See https://github.com/apache/tvm/issues/11765")) def test_conv3d(run_module): def get_graph( x_shape=(1, 24, 8, 8, 8), @@ -1143,11 +1139,6 @@ def get_graph( ) -@pytest.mark.xfail( - reason=("Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901") -) -@has_tensorrt_codegen -@tvm.testing.requires_cuda def test_dynamic_offload(): """ This test checks for proper dynamic offloading of relay graphs. An addition between @@ -1161,24 +1152,29 @@ def test_dynamic_offload(): x = relay.var("x", shape=(data_shape[0], data_shape[1], Any(), Any()), dtype="float32") y = relay.var("y", shape=(data_shape), dtype="float32") - kernel = relay.var("kernel", shape=(k_shape), dtype="float32") + kernel = relay.const(np.random.rand(*k_shape).astype("float32")) def get_expected(): # Create a nested TRT function that matches the expected output mod = tvm.IRModule() - var1 = relay.var("tensorrt_0_i0", shape=(data_shape), dtype="float32") - kernel_trt = relay.var("tensorrt_0_i1", shape=(k_shape), dtype="float32") - out1 = relay.nn.conv2d(var1, kernel_trt, channels=k_shape[0], kernel_size=k_shape[2:4]) - f1 = GlobalVar("tvmgen_default_tensorrt_0") - func = relay.Function([var1, kernel_trt], out1) - func = set_func_attr(func, "tensorrt", "tvmgen_default_tensorrt_0") - mod[f1] = func + outer_var = relay.var("tensorrt_0_i0", shape=(data_shape), dtype="float32") + inner_var = relay.var("FunctionVar_0_0", shape=(data_shape), dtype="float32") + inner_body = relay.nn.conv2d( + inner_var, kernel, channels=k_shape[0], kernel_size=k_shape[2:4] + ) + inner_func = relay.Function([inner_var], inner_body) + inner_func = set_inner_func_attr(inner_func, "nn.conv2d_", "tensorrt.nn.conv2d") + outer_body = inner_func(outer_var) + outer_func = relay.Function([outer_var], outer_body) + outer_func = set_outer_func_attr(outer_func, "tensorrt", "tvmgen_default_tensorrt_main_0") + gv = GlobalVar("tvmgen_default_tensorrt_main_0") + mod[gv] = outer_func mod = relay.transform.InferType()(mod) # Create the main function out1 = relay.nn.conv2d(x, kernel, channels=k_shape[0], kernel_size=k_shape[2:4]) - out = relay.add(out1, f1(y, kernel)) - f = relay.Function([x, y, kernel], out) + out = relay.add(out1, gv(y)) + f = relay.Function([x, y], out) mod["main"] = f mod = relay.transform.InferType()(mod) return mod @@ -1187,13 +1183,13 @@ def get_expected(): out1 = relay.nn.conv2d(x, kernel, channels=k_shape[0], 
kernel_size=k_shape[2:4]) out2 = relay.nn.conv2d(y, kernel, channels=k_shape[0], kernel_size=k_shape[2:4]) out = relay.add(out1, out2) - f = relay.Function([x, y, kernel], out) + f = relay.Function([x, y], out) # Pass the function to TRT compilation mod = tvm.IRModule() mod["main"] = f mod = relay.transform.InferType()(mod) - mod_trt, config = tensorrt.partition_for_tensorrt(mod, params={}) + mod_trt = tensorrt.partition_for_tensorrt(mod) # Get the expected relay graph and compare mod_exp = get_expected() @@ -1212,7 +1208,7 @@ def test_tensorrt_dynamic_batch(run_module): mod = tvm.IRModule() mod["main"] = f if use_trt: - mod, _ = tensorrt.partition_for_tensorrt(mod) + mod = tensorrt.partition_for_tensorrt(mod) if run_module: with relay.build_config(opt_level=3): @@ -1242,17 +1238,17 @@ def test_tensorrt_dynamic_batch_conv(run_module): f = relay.Function([x, kernel], out) mod = tvm.IRModule() mod["main"] = f + trt_target = tvm.target.Target(f"tensorrt -use_implicit_batch={use_implicit_batch}") if use_trt: - mod, config = tensorrt.partition_for_tensorrt( - mod, params, use_implicit_batch=use_implicit_batch - ) + mod = tensorrt.partition_for_tensorrt(mod, params=params, target=trt_target) if run_module: for target in ["llvm", "cuda"]: - with tvm.transform.PassContext( - opt_level=3, config={"relay.ext.tensorrt.options": config} - ): + targets = [target] + if use_trt: + targets.append(trt_target) + with tvm.transform.PassContext(opt_level=3): func = relay.create_executor( - "vm", mod=mod, device=tvm.device(target), target=target + "vm", mod=mod, device=tvm.device(target), target=targets ).evaluate() for i, batch_size in enumerate(batches_to_test): result_arr[i][target][use_trt] = func(x_data[:batch_size, ...], **params) @@ -1262,6 +1258,11 @@ def test_tensorrt_dynamic_batch_conv(run_module): assert_result_dict_holds(result_arr[i][target]) +@pytest.mark.skip( + reason=( + "Coredumps, possibly due to LLVM and PyTorch version mismatch. 
See https://github.com/apache/tvm/issues/11765" + ) +) def test_maskrcnn_resnet50(run_module) -> None: """ This function tests the working of pytorch maskrcnn with resnet50 as backbone with @@ -1281,9 +1282,11 @@ def convert_traced_model_to_vm_trt( input_name = "input0" shape_list = [(input_name, input_shape)] mod, params = relay.frontend.from_pytorch(traced_module, shape_list) - mod, config = tensorrt.partition_for_tensorrt(mod, params, remove_no_mac_subgraphs=True) + trt_target = tvm.target.Target("tensorrt -remove_no_mac_subgraphs=True") + mod = tensorrt.partition_for_tensorrt(mod, params=params, target=trt_target) + targets = [target, trt_target] with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]): - vm_trt_exec = relay.vm.compile(mod, target=target, params=params) + vm_trt_exec = relay.vm.compile(mod, target=targets, params=params) return vm_trt_exec @@ -1381,7 +1384,7 @@ def test_empty_subgraph(run_module): var1 = relay.var("tensorrt_0_i0", shape=(x_shape), dtype="float32") f1 = GlobalVar("tensorrt_0") func = relay.Function([var1], var1) - func = set_func_attr(func, "tensorrt", "tvmgen_default_tensorrt_0") + func = set_outer_func_attr(func, "tensorrt", "tvmgen_default_tensorrt_0") mod[f1] = func mod = relay.transform.InferType()(mod) @@ -1402,4 +1405,5 @@ def test_empty_subgraph(run_module): if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) tvm.testing.main() diff --git a/tests/python/contrib/test_tensorrt_int8_exp.py b/tests/python/contrib/test_tensorrt_int8_exp.py index 84360e92d33b..260179dba29a 100644 --- a/tests/python/contrib/test_tensorrt_int8_exp.py +++ b/tests/python/contrib/test_tensorrt_int8_exp.py @@ -19,7 +19,7 @@ import numpy as np import tvm -import tvm.relay.testing +import tvm.testing from tvm import relay from tvm.contrib.download import download_testdata from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt @@ -31,9 +31,10 @@ def skip_codegen_test(): if not tvm.runtime.enabled("cuda") or not tvm.cuda(0).exist: print("Skip because CUDA is not enabled.") return True - if not tvm.get_global_func("relay.ext.tensorrt", True): - print("Skip because TensorRT codegen is not available.") + if not tensorrt.is_tensorrt_compiler_enabled(): + print("Skip because TensorRT compiler is not available.") return True + print("TensorRT compiler is available!") return False @@ -44,6 +45,7 @@ def skip_runtime_test(): if not tensorrt.is_tensorrt_runtime_enabled(): print("Skip because TensorRT runtime is not available.") return True + print("TensorRT runtime is available!") return False @@ -102,12 +104,11 @@ def test_trt_int8(): # compile the model target = "cuda" - dev = tvm.cuda(1) - mod, config = partition_for_tensorrt(mod, params) - with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}): + dev = tvm.cuda() + mod = partition_for_tensorrt(mod, params) + with tvm.transform.PassContext(opt_level=3): lib = relay.build(mod, target=target, params=params) - dtype = "float32" gen_module = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) num_cali_int8 = int(os.environ["TENSORRT_NUM_CALI_INT8"]) @@ -146,4 +147,5 @@ def test_trt_int8(): if __name__ == "__main__": - pytest.main([__file__]) + #tvm.testing.main() + test_trt_int8() From 49347a53352a88acce324e93bec720254893656e Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Thu, 30 Jun 2022 15:15:56 -0700 Subject: [PATCH 02/10] - Lints --- tests/python/contrib/test_tensorrt_int8_exp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/tests/python/contrib/test_tensorrt_int8_exp.py b/tests/python/contrib/test_tensorrt_int8_exp.py index 260179dba29a..54e32f4b5b48 100644 --- a/tests/python/contrib/test_tensorrt_int8_exp.py +++ b/tests/python/contrib/test_tensorrt_int8_exp.py @@ -147,5 +147,5 @@ def test_trt_int8(): if __name__ == "__main__": - #tvm.testing.main() + # tvm.testing.main() test_trt_int8() From b0b0860d369ee308f8917baf13d101fc889427a3 Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Thu, 30 Jun 2022 15:19:40 -0700 Subject: [PATCH 03/10] - Woops, fix test --- tests/python/contrib/test_tensorrt_int8_exp.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/python/contrib/test_tensorrt_int8_exp.py b/tests/python/contrib/test_tensorrt_int8_exp.py index 54e32f4b5b48..7f52760b889e 100644 --- a/tests/python/contrib/test_tensorrt_int8_exp.py +++ b/tests/python/contrib/test_tensorrt_int8_exp.py @@ -147,5 +147,4 @@ def test_trt_int8(): if __name__ == "__main__": - # tvm.testing.main() - test_trt_int8() + tvm.testing.main() From 23cf487d58cc262d5c299ceb7d9dc0e5bf6ccb27 Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Thu, 30 Jun 2022 15:48:05 -0700 Subject: [PATCH 04/10] - lints --- python/tvm/relay/op/contrib/tensorrt.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py index c441c30808c3..889db223a694 100644 --- a/python/tvm/relay/op/contrib/tensorrt.py +++ b/python/tvm/relay/op/contrib/tensorrt.py @@ -26,7 +26,7 @@ from tvm.relay import transform from tvm.relay.build_module import bind_params_by_name from tvm.relay.dataflow_pattern import is_op, wildcard, is_constant, is_tuple, is_tuple_get_item -from tvm.relay.expr import Call, Constant, GlobalVar, TupleGetItem +from tvm.relay.expr import Call, Constant, TupleGetItem from tvm.relay.expr_functor import ExprMutator, ExprVisitor from tvm.relay.op.contrib.register import register_pattern_table @@ -222,6 +222,9 @@ def get_attrs(expr: relay.expr.Expr) -> Any: def make_predicate(checker: CheckFunc) -> Callable[[relay.expr.Expr], bool]: + """Returns the pattern predicate which performs the standard checks, then invokes the + more primitive checker.""" + def predicate(expr: relay.expr.Expr) -> bool: op_name = get_op_name(expr) attrs = get_attrs(expr) From 8dd0e9638f460b313ec28adbeda8b373fcc8d8ed Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Thu, 30 Jun 2022 15:55:21 -0700 Subject: [PATCH 05/10] - Use default tensorrt target if none given in targets list --- python/tvm/relay/op/contrib/tensorrt.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py index 889db223a694..a499414fa5ac 100644 --- a/python/tvm/relay/op/contrib/tensorrt.py +++ b/python/tvm/relay/op/contrib/tensorrt.py @@ -53,7 +53,9 @@ def is_tensorrt_runtime_enabled() -> bool: def get_tensorrt_target() -> tvm.target.Target: """Returns the current Target, which must be of kind "tensorrt".""" target = tvm.target.Target.current() - assert target.kind.name == "tensorrt" + if target is None or target.kind.name != "tensorrt": + # Create the default target. 
+ return tvm.target.Target("tensorrt") return target From f173fbce97543a97f7a62f4ac43eb1de4ee5163b Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Fri, 1 Jul 2022 10:33:38 -0700 Subject: [PATCH 06/10] - fix free error --- tests/python/contrib/test_tensorrt_int8_exp.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/python/contrib/test_tensorrt_int8_exp.py b/tests/python/contrib/test_tensorrt_int8_exp.py index 7f52760b889e..304d9a095e84 100644 --- a/tests/python/contrib/test_tensorrt_int8_exp.py +++ b/tests/python/contrib/test_tensorrt_int8_exp.py @@ -18,6 +18,12 @@ import os import numpy as np +try: + # See issue #9362. + import torch +except: + pass + import tvm import tvm.testing from tvm import relay From 8539ef4cd344c0357731b8135bb74fa3d25c6ac5 Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Fri, 1 Jul 2022 10:58:09 -0700 Subject: [PATCH 07/10] - accidentally introduced 'transforms' namespace - can't use default Target("tensorrt") arg --- python/tvm/relay/op/contrib/tensorrt.py | 9 ++++++++- src/relay/backend/contrib/codegen_c/codegen.cc | 10 +++++----- src/relay/backend/contrib/cutlass/codegen.cc | 10 +++++----- src/relay/backend/contrib/tensorrt/codegen.cc | 10 +++++----- src/relay/transforms/compiler_function_utils.cc | 16 ++++++++-------- src/relay/transforms/compiler_function_utils.h | 15 ++++++++------- 6 files changed, 39 insertions(+), 31 deletions(-) diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py index a499414fa5ac..d659f514d9a3 100644 --- a/python/tvm/relay/op/contrib/tensorrt.py +++ b/python/tvm/relay/op/contrib/tensorrt.py @@ -111,7 +111,9 @@ def get_tensorrt_use_fp16() -> bool: def partition_for_tensorrt( mod: tvm.IRModule, params: Optional[Dict[str, tvm.nd.NDArray]] = None, - target: tvm.target.Target = tvm.target.Target("tensorrt"), + # CAUTION: Can't use default Target("tensorrt") here since the target kind is only available + # if is_tensorrt_compiler_enabled() == True. + target: Optional[tvm.target.Target] = None, ) -> tvm.IRModule: """Partition all functions in mod to greedily offload supported operators to TensorRT. @@ -130,8 +132,13 @@ def partition_for_tensorrt( The partitioned module. """ + assert is_tensorrt_compiler_enabled(), "Can only partition for TensorRT if it is enabled" if params: mod["main"] = bind_params_by_name(mod["main"], params) + if target is None: + # Use a default target. The get_tensorrt_target() function will similarly create an + # equivalent default target when compilation continues after partitioning. + target = tvm.target.Target("tensorrt") seq = tvm.transform.Sequential( [ diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc index ee8724fe92fe..41f0a0a06408 100644 --- a/src/relay/backend/contrib/codegen_c/codegen.cc +++ b/src/relay/backend/contrib/codegen_c/codegen.cc @@ -360,8 +360,8 @@ class CodegenCModule { }; /*! \brief The actual translation pass. 
*/ -transform::Pass CCompilerImpl() { - auto pass_func = [=](IRModule mod, const transform::PassContext& pass_ctx) { +tvm::transform::Pass CCompilerImpl() { + auto pass_func = [=](IRModule mod, const tvm::transform::PassContext& pass_ctx) { VLOG(1) << "CCompilerImpl input:" << std::endl << PrettyPrint(mod); Target target = GetCCompilerTarget(); @@ -388,10 +388,10 @@ transform::Pass CCompilerImpl() { return tvm::transform::CreateModulePass(pass_func, 0, "CCompilerImpl", {}); } -transform::Pass CCompilerPass() { +tvm::transform::Pass CCompilerPass() { return transform::Sequential( - {transforms::OutlineCompilerFunctionsWithExistingGlobalSymbols("ccompiler"), CCompilerImpl(), - transforms::MarkCompilerFunctionsAsExtern("ccompiler")}); + {transform::OutlineCompilerFunctionsWithExistingGlobalSymbols("ccompiler"), CCompilerImpl(), + transform::MarkCompilerFunctionsAsExtern("ccompiler")}); } } // namespace contrib diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc index de2934173b5f..2e76ab1cbbf6 100644 --- a/src/relay/backend/contrib/cutlass/codegen.cc +++ b/src/relay/backend/contrib/cutlass/codegen.cc @@ -902,8 +902,8 @@ class CutlassModuleCodegen { * \brief A small shim to redirect to the 'relay.ext.cutlass.compile_for_cutlass' Python * function which does the main CUTLASS training, c-code generation and compilation steps. */ -transform::Pass CompileForCutlassImpl() { - auto pass_func = [=](IRModule mod, const transform::PassContext& pass_ctx) { +tvm::transform::Pass CompileForCutlassImpl() { + auto pass_func = [=](IRModule mod, const tvm::transform::PassContext& pass_ctx) { VLOG(1) << "CompileForCutlass input:" << std::endl << PrettyPrint(mod); const auto* pf = runtime::Registry::Get("relay.ext.cutlass.compile_for_cutlass"); ICHECK(pf != nullptr) << "Cannot find compile_for_cutlass function"; @@ -926,10 +926,10 @@ runtime::Module CreateCSourceModule(const IRModule& mod) { TVM_REGISTER_GLOBAL("relay.ext.cutlass.create_c_source_module").set_body_typed(CreateCSourceModule); -transform::Pass CompileForCutlass() { +tvm::transform::Pass CompileForCutlass() { return transform::Sequential( - {transforms::OutlineCompilerFunctionsWithExistingGlobalSymbols("cutlass"), - CompileForCutlassImpl(), transforms::MarkCompilerFunctionsAsExtern("cutlass")}); + {transform::OutlineCompilerFunctionsWithExistingGlobalSymbols("cutlass"), + CompileForCutlassImpl(), transform::MarkCompilerFunctionsAsExtern("cutlass")}); } } // namespace cutlass diff --git a/src/relay/backend/contrib/tensorrt/codegen.cc b/src/relay/backend/contrib/tensorrt/codegen.cc index 526f6bf7588a..dda5736b1be6 100644 --- a/src/relay/backend/contrib/tensorrt/codegen.cc +++ b/src/relay/backend/contrib/tensorrt/codegen.cc @@ -348,8 +348,8 @@ void CollectFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) { * function will require a linear scan of imported runtime modules to find the matching * TensorRTRuntimeModule implementing it. 
*/ -transform::Pass CompileForTensorRTImpl() { - auto pass_func = [](IRModule mod, const transform::PassContext& pass_ctx) { +tvm::transform::Pass CompileForTensorRTImpl() { + auto pass_func = [](IRModule mod, const tvm::transform::PassContext& pass_ctx) { VLOG(1) << "CompileForTensorRT input:" << std::endl << PrettyPrint(mod); Target target = GetTensorRTTarget(); @@ -400,10 +400,10 @@ transform::Pass CompileForTensorRTImpl() { return tvm::transform::CreateModulePass(pass_func, 0, "CompileForTensorRT", {}); } -transform::Pass CompileForTensorRT() { +tvm::transform::Pass CompileForTensorRT() { return transform::Sequential( - {transforms::OutlineCompilerFunctionsWithExistingGlobalSymbols("tensorrt"), - CompileForTensorRTImpl(), transforms::MarkCompilerFunctionsAsExtern("tensorrt")}); + {transform::OutlineCompilerFunctionsWithExistingGlobalSymbols("tensorrt"), + CompileForTensorRTImpl(), transform::MarkCompilerFunctionsAsExtern("tensorrt")}); } } // namespace tensorrt diff --git a/src/relay/transforms/compiler_function_utils.cc b/src/relay/transforms/compiler_function_utils.cc index 0df9f5ee294c..1dafcd10a361 100644 --- a/src/relay/transforms/compiler_function_utils.cc +++ b/src/relay/transforms/compiler_function_utils.cc @@ -24,14 +24,13 @@ #include "./compiler_function_utils.h" -#include "../op/call/call.h" #include "tvm/relay/analysis.h" #include "tvm/relay/expr_functor.h" #include "tvm/relay/transform.h" namespace tvm { namespace relay { -namespace transforms { +namespace transform { namespace { /*! @@ -211,8 +210,8 @@ GlobalVar ExistingGlobalSymbolCache::GetGlobalSymbol(const Function& function) { return global_var; } -transform::Pass OutlineCompilerFunctions(std::shared_ptr cache, - std::string compiler_filter) { +tvm::transform::Pass OutlineCompilerFunctions(std::shared_ptr cache, + std::string compiler_filter) { runtime::TypedPackedFunc pass_func = [cache = std::move(cache), compiler_filter = std::move(compiler_filter)]( IRModule mod, transform::PassContext ctx) { @@ -235,12 +234,13 @@ transform::Pass OutlineCompilerFunctions(std::shared_ptr cach } // Any Java programmers in the house? 
-transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string compiler_filter) { +tvm::transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols( + std::string compiler_filter) { return OutlineCompilerFunctions(std::make_shared(), std::move(compiler_filter)); } -transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) { +tvm::transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) { runtime::TypedPackedFunc pass_func = [compiler_filter = std::move(compiler_filter)](IRModule mod, transform::PassContext ctx) { VLOG(1) << "MarkCompilerFunctionsAsExtern input:" << std::endl << PrettyPrint(mod); @@ -262,7 +262,7 @@ transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) { return tvm::transform::CreateModulePass(pass_func, 0, "MarkCompilerFunctionsAsExtern", {}); } -transform::Pass InlineCompilerFunctionsBoundTo(Array global_vars) { +tvm::transform::Pass InlineCompilerFunctionsBoundTo(Array global_vars) { runtime::TypedPackedFunc pass_func = [global_vars = std::move(global_vars)](IRModule mod, transform::PassContext ctx) { VLOG(1) << "InlineCompilerFunctionsBoundTo with global_vars: " << PrettyPrint(global_vars); @@ -295,6 +295,6 @@ TVM_REGISTER_GLOBAL("relay._transform.MarkCompilerFunctionsAsExtern") TVM_REGISTER_GLOBAL("relay._transform.InlineCompilerFunctionsBoundTo") .set_body_typed(InlineCompilerFunctionsBoundTo); -} // namespace transforms +} // namespace transform } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/compiler_function_utils.h b/src/relay/transforms/compiler_function_utils.h index aa98430318a6..f3499faec262 100644 --- a/src/relay/transforms/compiler_function_utils.h +++ b/src/relay/transforms/compiler_function_utils.h @@ -66,7 +66,7 @@ namespace tvm { namespace relay { -namespace transforms { +namespace transform { /*! * \brief Abstract class representing a cache of unique global vars keyed by functions. This can @@ -105,8 +105,8 @@ class ExistingGlobalSymbolCache : public GlobalSymbolCache { * If \p compiler_filter is non-empty only functions with that as their attribute value are * outlined. */ -transform::Pass OutlineCompilerFunctions(std::shared_ptr cache, - std::string compiler_filter = ""); +tvm::transform::Pass OutlineCompilerFunctions(std::shared_ptr cache, + std::string compiler_filter = ""); /*! * \brief A pass to outline all let-bound and literal functions in direct call positions which have @@ -119,7 +119,8 @@ transform::Pass OutlineCompilerFunctions(std::shared_ptr cach * This pass may be useful for external codegen using the "RelayToTIR" custom pass mechanism * to prepare the IRModule before custom lowering. */ -transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string compiler_filter = ""); +tvm::transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols( + std::string compiler_filter = ""); /*! * \brief A pass to mark all global functions which have a "Compiler" attribute matching @@ -132,7 +133,7 @@ transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string co * This pass may be useful for external codegen using the "RelayToTIR" custom pass mechanism to * cleanup the IRModule after custom lowering. */ -transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter = ""); +tvm::transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter = ""); /*! 
* \brief A pass to inline all global "Compiler" functions which are bound to a global var @@ -142,9 +143,9 @@ transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter = ""); * This pass may be useful for external codegen which needs to undo partitioning based on * properties of the entire partition. */ -transform::Pass InlineCompilerFunctionsBoundTo(Array global_vars); +tvm::transform::Pass InlineCompilerFunctionsBoundTo(Array global_vars); -} // namespace transforms +} // namespace transform } // namespace relay } // namespace tvm From a37e745a4bd40f9181a24b8c01b8be15d7e5d458 Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Fri, 1 Jul 2022 11:40:25 -0700 Subject: [PATCH 08/10] - D'oh! Include ended up #if protected --- src/relay/backend/contrib/tensorrt/codegen.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/backend/contrib/tensorrt/codegen.cc b/src/relay/backend/contrib/tensorrt/codegen.cc index dda5736b1be6..e67e38204a48 100644 --- a/src/relay/backend/contrib/tensorrt/codegen.cc +++ b/src/relay/backend/contrib/tensorrt/codegen.cc @@ -29,11 +29,11 @@ #include #include +#include "../../../transforms/compiler_function_utils.h" #include "../../utils.h" #include "../codegen_json/codegen_json.h" #if TVM_GRAPH_EXECUTOR_TENSORRT -#include "../../../transforms/compiler_function_utils.h" #include "NvInfer.h" #endif From f395d9cb74af78b4eac2304e4e7d11e21075e050 Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Fri, 1 Jul 2022 13:13:41 -0700 Subject: [PATCH 09/10] - restore mark for test_dynamic_offload - handle missing runtime in versioning - turn test_maskrcnn_resnet50 back on now that we have the import-torch-first workaround. --- python/tvm/relay/op/contrib/tensorrt.py | 12 ++++++--- src/relay/backend/contrib/tensorrt/codegen.cc | 4 +++ tests/python/contrib/test_tensorrt.py | 25 +++++++++++-------- 3 files changed, 27 insertions(+), 14 deletions(-) diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py index d659f514d9a3..4008b0eb3f78 100644 --- a/python/tvm/relay/op/contrib/tensorrt.py +++ b/python/tvm/relay/op/contrib/tensorrt.py @@ -71,21 +71,25 @@ def get_tensorrt_version() -> Tuple[int, int, int]: ret: Tuple[int, int, int] TensorRT version as a tuple of (major, minor, patch). """ + # cf logic in tensorrt/codegen.cc::SaveGlobalAttributes + # First check for version in target. target = get_tensorrt_target() version = target.attrs["tensorrt_version"] if len(version) == 3: return int(version[0]), int(version[1]), int(version[2]) assert len(version) == 0 - get_version = tvm.get_global_func("relay.ext.tensorrt.get_version", True) - if get_version: + # Next, ask runtime for its version. + if is_tensorrt_runtime_enabled(): + get_version = tvm.get_global_func("relay.ext.tensorrt.get_version") version = get_version() assert len(version) == 3 return int(version[0]), int(version[1]), int(version[2]) + # Finally, use default. logger.warning( - "TVM was not built against TensorRT and no version was provided to " - "partition_for_tensorrt. Defaulting to 6.0.1" + "TVM was not built against TensorRT and no version was provided in the 'tensorrt' target." + "Defaulting to 6.0.1." ) return (6, 0, 1) diff --git a/src/relay/backend/contrib/tensorrt/codegen.cc b/src/relay/backend/contrib/tensorrt/codegen.cc index e67e38204a48..1c4a8d78062e 100644 --- a/src/relay/backend/contrib/tensorrt/codegen.cc +++ b/src/relay/backend/contrib/tensorrt/codegen.cc @@ -278,11 +278,15 @@ class TensorRTJSONSerializer : public JSONSerializer { /*! 
\brief Capture the compilation options as attributes on \p node. */ void SaveGlobalAttributes(JSONGraphNode* node) { { + // cf logic in tensorrt.py::get_tensorrt_version. + // First check for version in target. Array target_attr = target_->GetAttr>("tensorrt_version").value(); if (target_attr.empty()) { + // Next, ask runtime for its version. target_attr = GetVersion(); } if (target_attr.empty()) { + // Finally, use default. target_attr = {6, 0, 1}; } ICHECK_EQ(target_attr.size(), 3); diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py index 38da305f3b17..bffcfb1e33cf 100644 --- a/tests/python/contrib/test_tensorrt.py +++ b/tests/python/contrib/test_tensorrt.py @@ -14,22 +14,26 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import tvm.testing + import numpy as np import pytest import itertools import logging +from typing import Tuple + +try: + # See issue #9362. + import torch +except: + pass import tvm +import tvm.testing import tvm.relay.testing from tvm import relay -from tvm.relay.op.contrib import tensorrt - from tvm.relay import Any, GlobalVar - from tvm.relay.expr_functor import ExprVisitor -from typing import Tuple from tvm.contrib.download import download from tvm.relay.op.contrib import tensorrt @@ -1139,6 +1143,7 @@ def get_graph( ) +@has_tensorrt_codegen def test_dynamic_offload(): """ This test checks for proper dynamic offloading of relay graphs. An addition between @@ -1258,11 +1263,11 @@ def test_tensorrt_dynamic_batch_conv(run_module): assert_result_dict_holds(result_arr[i][target]) -@pytest.mark.skip( - reason=( - "Coredumps, possibly due to LLVM and PyTorch version mismatch. See https://github.com/apache/tvm/issues/11765" - ) -) +#@pytest.mark.skip( +# reason=( +# "Coredumps, possibly due to LLVM and PyTorch version mismatch. See https://github.com/apache/tvm/issues/11765" +# ) +#) def test_maskrcnn_resnet50(run_module) -> None: """ This function tests the working of pytorch maskrcnn with resnet50 as backbone with From c8e10c96d94d66dd8faace8546822dd8a4768bda Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Fri, 1 Jul 2022 13:16:56 -0700 Subject: [PATCH 10/10] - wibble --- tests/python/contrib/test_tensorrt.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py index bffcfb1e33cf..9e39821fd317 100644 --- a/tests/python/contrib/test_tensorrt.py +++ b/tests/python/contrib/test_tensorrt.py @@ -1263,11 +1263,6 @@ def test_tensorrt_dynamic_batch_conv(run_module): assert_result_dict_holds(result_arr[i][target]) -#@pytest.mark.skip( -# reason=( -# "Coredumps, possibly due to LLVM and PyTorch version mismatch. See https://github.com/apache/tvm/issues/11765" -# ) -#) def test_maskrcnn_resnet50(run_module) -> None: """ This function tests the working of pytorch maskrcnn with resnet50 as backbone with
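[Editorial note, not part of the patches: a minimal sketch of the target-based flow the updated tests exercise, mirroring the maskrcnn and int8 test changes above. It assumes a Relay `mod`/`params` pair already imported from a frontend and a CUDA device target; the `-remove_no_mac_subgraphs=True` flag is just one example attribute of the "tensorrt" target kind.]

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt

    # All TensorRT compilation options now live on a Target of kind "tensorrt",
    # not in the PassContext config.
    trt_target = tvm.target.Target("tensorrt -remove_no_mac_subgraphs=True")

    # Partitioning takes the target directly and returns only the partitioned module.
    mod = partition_for_tensorrt(mod, params=params, target=trt_target)

    with tvm.transform.PassContext(opt_level=3):
        # Passing the "tensorrt" target alongside the device target keeps the same
        # options visible to the RelayToTIR hook during compilation.
        vm_trt_exec = relay.vm.compile(
            mod, target=[tvm.target.Target("cuda"), trt_target], params=params
        )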