diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py
index eb814ab5012a5..2d45a8b6608ae 100644
--- a/python/tvm/relay/op/contrib/tensorrt.py
+++ b/python/tvm/relay/op/contrib/tensorrt.py
@@ -106,6 +106,7 @@ def partition_for_tensorrt(
     max_workspace_size=1 << 30,
     use_fp16=False,
     use_uint8=False,
+    use_patterns=False,
 ):
     """Partition the graph greedily offloading supported operators to TensorRT.
 
@@ -136,6 +137,8 @@ def partition_for_tensorrt(
         lower runtime, or if no low-precision implementation exists.
     use_uint8: Optional[bool]
         Allows, TRT to automatically convert FP32 inputs to UINT8.
+    use_patterns: Optional[bool]
+        Switches to pattern-based op support by applying the MergeComposite and InlineComposites passes.
     Returns
     -------
     mod_and_config : Tuple[Module, Dict[str, Any]]
@@ -164,34 +167,74 @@ def partition_for_tensorrt(
     if params:
         mod["main"] = bind_params_by_name(mod["main"], params)
-    seq = tvm.transform.Sequential(
-        [
-            transform.InferType(),
-            RemoveDropoutPass(),
-            transform.RemoveUnusedFunctions(),
-            transform.ConvertLayout(
-                {
-                    "nn.conv1d": ["NCW", "default"],
-                    "nn.conv2d": ["NCHW", "default"],
-                    "nn.conv3d": ["NCDHW", "default"],
-                    "nn.conv2d_transpose": ["NCHW", "default"],
-                }
-            ),
-            transform.FoldConstant(),
-            transform.MergeComposite(pattern_table()),
-            transform.AnnotateTarget("tensorrt"),
-            transform.MergeCompilerRegions(),
-            transform.PartitionGraph(),
-            transform.InlineComposites("tensorrt"),
-            transform.InferType(),
-        ]
-    )
+
+    seq = get_pass_order(use_patterns)
 
     with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
         mod = seq(mod)
         mod = prune_tensorrt_subgraphs(mod)
     return mod, config
 
 
+def get_pass_order(use_patterns):
+    """
+    Get the pass ordering based on using predicates or patterns.
+
+    Parameters
+    ----------
+    use_patterns: bool
+        True if the pattern-based pass ordering should be used.
+    Returns
+    -------
+    ret : tvm.transform.Sequential
+        The pass sequence to apply.
+    """
+    return (
+        tvm.transform.Sequential(
+            [
+                transform.InferType(),
+                RemoveDropoutPass(),
+                transform.RemoveUnusedFunctions(),
+                transform.ConvertLayout(
+                    {
+                        "nn.conv1d": ["NCW", "default"],
+                        "nn.conv2d": ["NCHW", "default"],
+                        "nn.conv3d": ["NCDHW", "default"],
+                        "nn.conv2d_transpose": ["NCHW", "default"],
+                    }
+                ),
+                transform.FoldConstant(),
+                transform.MergeComposite(pattern_table()),
+                transform.AnnotateTarget("tensorrt"),
+                transform.MergeCompilerRegions(),
+                transform.PartitionGraph(),
+                transform.InlineComposites("tensorrt"),
+                transform.InferType(),
+            ]
+        )
+        if use_patterns
+        else tvm.transform.Sequential(
+            [
+                transform.InferType(),
+                RemoveDropoutPass(),
+                transform.RemoveUnusedFunctions(),
+                transform.ConvertLayout(
+                    {
+                        "nn.conv1d": ["NCW", "default"],
+                        "nn.conv2d": ["NCHW", "default"],
+                        "nn.conv3d": ["NCDHW", "default"],
+                        "nn.conv2d_transpose": ["NCHW", "default"],
+                    }
+                ),
+                transform.FoldConstant(),
+                transform.AnnotateTarget("tensorrt"),
+                transform.MergeCompilerRegions(),
+                transform.PartitionGraph(),
+                transform.InferType(),
+            ]
+        )
+    )
+
+
 def check_dynamism(args, op_name):
     """
     Check for dynamism inside any of the args in the op.
@@ -451,7 +494,7 @@ def dense_annotate_fn(expr):  # pylint: disable=unused-variable
 
 
 @_register_external_dynamic_check_func("nn.batch_matmul")
-def batch_matmul_annotate_fn(expr):  # pylint: disable=unused-variable
+def batch_matmul_annotate_fn(expr):
     """Check if dense is supported by TensorRT."""
 
     args = expr.args
diff --git a/src/relay/transforms/inline_composites.cc b/src/relay/transforms/inline_composites.cc
index b878bb247874f..63e7d078b0c54 100644
--- a/src/relay/transforms/inline_composites.cc
+++ b/src/relay/transforms/inline_composites.cc
@@ -34,12 +34,12 @@
 namespace tvm {
 namespace relay {
 
-class Unmerger : ExprMutator {
+class CompositeInliner : public MixedModeMutator {
  public:
-  explicit Unmerger(CallGraphEntry* cur_node, CallGraphNode* call_graph)
+  explicit CompositeInliner(CallGraphEntry* cur_node, CallGraphNode* call_graph)
       : cur_node_(cur_node), call_graph_(call_graph) {}
 
-  Expr VisitExpr_(const CallNode* call_node) final {
+  Expr Rewrite_(const CallNode* call_node) {
     Call vanilla_call = GetAnyCall(call_node);
 
     const auto* function_node = vanilla_call->op.as<FunctionNode>();
@@ -60,10 +60,10 @@ class Unmerger : ExprMutator {
       return Bind(function_node->body, bind_map);
     }
 
-    return ExprMutator::VisitExpr_(call_node);
+    return MixedModeMutator::VisitExpr_(call_node);
   }
 
-  Function Unmerge(const Function& func) {
+  Function Inline(const Function& func) {
     return WithFields(func, func->params, VisitExpr(func->body));
   }
 
@@ -88,13 +88,13 @@ IRModule InlineComposites(const IRModule& module, runtime::String target) {
 
     if (!base_func->GetAttr<String>(attr::kCompiler).defined() &&
         base_func->GetAttr<String>(attr::kCompiler) != target) {
-      return module;
+      continue;
     }
 
     if (it->GetNameHint() != "main") {
       if (const auto* fn = base_func.as<FunctionNode>()) {
         auto func = GetRef<Function>(fn);
-        auto new_func = Unmerger(it, cg.operator->()).Unmerge(func);
+        auto new_func = CompositeInliner(it, cg.operator->()).Inline(func);
         cg->module->Update(it->GetGlobalVar(), new_func);
       }
     }
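Below is a minimal usage sketch showing how the new use_patterns flag would be driven from user code. It assumes a TVM build where the Relay TensorRT BYOC frontend is importable; the toy conv2d module and the variable names are illustrative only, not part of the diff.

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt

    # Toy network: conv2d + relu, enough to form one TensorRT region.
    data = relay.var("data", shape=(1, 3, 224, 224), dtype="float32")
    weight = relay.var("weight", shape=(16, 3, 3, 3), dtype="float32")
    conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1))
    mod = tvm.IRModule.from_expr(relay.Function([data, weight], relay.nn.relu(conv)))

    # Default flow: the predicate-based pass ordering (use_patterns=False).
    mod_pred, config = partition_for_tensorrt(mod)

    # Opting in routes through MergeComposite / InlineComposites via the
    # pattern-based ordering returned by get_pass_order(use_patterns=True).
    mod_pat, config = partition_for_tensorrt(mod, use_patterns=True)

Either call returns the partitioned module together with the "relay.ext.tensorrt.options" config dict, per the docstring above; only the pass sequence selected inside get_pass_order differs.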