Address PR comments
Michalis Papapdimitriou committed Mar 25, 2022
1 parent a4d95c7 commit a961ef8
Showing 2 changed files with 73 additions and 30 deletions.
89 changes: 66 additions & 23 deletions python/tvm/relay/op/contrib/tensorrt.py
@@ -106,6 +106,7 @@ def partition_for_tensorrt(
max_workspace_size=1 << 30,
use_fp16=False,
use_uint8=False,
use_patterns=False,
):
"""Partition the graph greedily offloading supported operators to TensorRT.
@@ -136,6 +137,8 @@ def partition_for_tensorrt(
lower runtime, or if no low-precision implementation exists.
use_uint8: Optional[bool]
Allows TRT to automatically convert FP32 inputs to UINT8.
use_patterns: Optional[bool]
Switches to pattern-based op support by applying the MergeComposite and InlineComposites passes.
Returns
-------
mod_and_config : Tuple[Module, Dict[str, Any]]
@@ -164,34 +167,74 @@ def partition_for_tensorrt(

if params:
mod["main"] = bind_params_by_name(mod["main"], params)
seq = tvm.transform.Sequential(
[
transform.InferType(),
RemoveDropoutPass(),
transform.RemoveUnusedFunctions(),
transform.ConvertLayout(
{
"nn.conv1d": ["NCW", "default"],
"nn.conv2d": ["NCHW", "default"],
"nn.conv3d": ["NCDHW", "default"],
"nn.conv2d_transpose": ["NCHW", "default"],
}
),
transform.FoldConstant(),
transform.MergeComposite(pattern_table()),
transform.AnnotateTarget("tensorrt"),
transform.MergeCompilerRegions(),
transform.PartitionGraph(),
transform.InlineComposites("tensorrt"),
transform.InferType(),
]
)

seq = get_pass_order(use_patterns)
with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
mod = seq(mod)
mod = prune_tensorrt_subgraphs(mod)
return mod, config
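
For context, a minimal usage sketch of the new flag (the workload, shapes, and names below are illustrative; it assumes a TVM build with the TensorRT codegen enabled):

import numpy as np
import tvm
from tvm import relay
from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt

# Illustrative workload: a single conv2d + relu in NCHW layout.
data = relay.var("data", shape=(1, 3, 224, 224), dtype="float32")
weight = relay.var("weight", shape=(16, 3, 3, 3), dtype="float32")
out = relay.nn.relu(relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1)))
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))
params = {"weight": np.random.uniform(size=(16, 3, 3, 3)).astype("float32")}

# use_patterns=True selects the MergeComposite/InlineComposites flow;
# the default (False) keeps the per-op predicate flow.
mod, config = partition_for_tensorrt(mod, params=params, use_patterns=True)
print(mod)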


def get_pass_order(use_patterns):
"""
Get the pass ordering based on using predicates or patterns.
Parameters
----------
use_patterns: bool
True to use the pattern-based partitioning flow, False to use per-op predicates.
Returns
-------
ret : Sequential
The ordered pass list as a Sequential pass object.
"""
return (
tvm.transform.Sequential(
[
transform.InferType(),
RemoveDropoutPass(),
transform.RemoveUnusedFunctions(),
transform.ConvertLayout(
{
"nn.conv1d": ["NCW", "default"],
"nn.conv2d": ["NCHW", "default"],
"nn.conv3d": ["NCDHW", "default"],
"nn.conv2d_transpose": ["NCHW", "default"],
}
),
transform.FoldConstant(),
transform.MergeComposite(pattern_table()),
transform.AnnotateTarget("tensorrt"),
transform.MergeCompilerRegions(),
transform.PartitionGraph(),
transform.InlineComposites("tensorrt"),
transform.InferType(),
]
)
if use_patterns
else tvm.transform.Sequential(
[
transform.InferType(),
RemoveDropoutPass(),
transform.RemoveUnusedFunctions(),
transform.ConvertLayout(
{
"nn.conv1d": ["NCW", "default"],
"nn.conv2d": ["NCHW", "default"],
"nn.conv3d": ["NCDHW", "default"],
"nn.conv2d_transpose": ["NCHW", "default"],
}
),
transform.FoldConstant(),
transform.AnnotateTarget("tensorrt"),
transform.MergeCompilerRegions(),
transform.PartitionGraph(),
transform.InferType(),
]
)
)
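
The two orderings differ only in the MergeComposite / InlineComposites bracketing around AnnotateTarget and PartitionGraph. A quick way to compare them, as a sketch (assuming mod is a type-checked Relay module and this helper is imported alongside partition_for_tensorrt):

# Build both orderings and run them under the same pass context.
predicate_seq = get_pass_order(use_patterns=False)
pattern_seq = get_pass_order(use_patterns=True)
with tvm.transform.PassContext(opt_level=3):
    mod_predicates = predicate_seq(mod)
    mod_patterns = pattern_seq(mod)

# Both orderings end in PartitionGraph, so each result should expose the
# offloaded regions as extra global functions in the module.
for gvar in mod_patterns.get_global_vars():
    print(gvar.name_hint)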


def check_dynamism(args, op_name):
"""
Check for dynamism inside any of the args in the op.
@@ -451,7 +494,7 @@ def dense_annotate_fn(expr): # pylint: disable=unused-variable


@_register_external_dynamic_check_func("nn.batch_matmul")
def batch_matmul_annotate_fn(expr): # pylint: disable=unused-variable
def batch_matmul_annotate_fn(expr):
"""Check if dense is supported by TensorRT."""

args = expr.args
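The predicate flow relies on per-op checks registered this way. A hedged sketch of the registration shape (the op name and body below are hypothetical, not the actual nn.batch_matmul rules):

@_register_external_dynamic_check_func("nn.some_op")  # hypothetical op name
def some_op_annotate_fn(expr):
    """Illustrative predicate: offload only float32 workloads."""
    args = expr.args
    if any(arg.checked_type.dtype != "float32" for arg in args):
        return False  # hypothetical rule for the sketch
    return True
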
14 changes: 7 additions & 7 deletions src/relay/transforms/inline_composites.cc
@@ -34,12 +34,12 @@ namespace tvm {

namespace relay {

class Unmerger : ExprMutator {
class CompositeInliner : public MixedModeMutator {
public:
explicit Unmerger(CallGraphEntry* cur_node, CallGraphNode* call_graph)
explicit CompositeInliner(CallGraphEntry* cur_node, CallGraphNode* call_graph)
: cur_node_(cur_node), call_graph_(call_graph) {}

Expr VisitExpr_(const CallNode* call_node) final {
Expr Rewrite_(const CallNode* call_node) {
Call vanilla_call = GetAnyCall(call_node);
const auto* function_node = vanilla_call->op.as<FunctionNode>();

@@ -60,10 +60,10 @@ class Unmerger : ExprMutator {
return Bind(function_node->body, bind_map);
}

return ExprMutator::VisitExpr_(call_node);
return MixedModeMutator::VisitExpr_(call_node);
}

Function Unmerge(const Function& func) {
Function Inline(const Function& func) {
return WithFields(func, func->params, VisitExpr(func->body));
}

@@ -88,13 +88,13 @@ IRModule InlineComposites(const IRModule& module, runtime::String target) {

if (!base_func->GetAttr<String>(attr::kCompiler).defined() ||
base_func->GetAttr<String>(attr::kCompiler) != target) {
return module;
continue;
}

if (it->GetNameHint() != "main") {
if (const auto* fn = base_func.as<FunctionNode>()) {
auto func = GetRef<Function>(fn);
auto new_func = Unmerger(it, cg.operator->()).Unmerge(func);
auto new_func = CompositeInliner(it, cg.operator->()).Inline(func);
cg->module->Update(it->GetGlobalVar(), new_func);
}
}
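On the Python side the renamed pass is still exposed as transform.InlineComposites, so it can also be applied standalone. A minimal sketch, assuming partitioned_mod came out of PartitionGraph with composite calls still wrapped inside the Compiler="tensorrt" functions:

from tvm.relay import transform

# Flatten the "Composite" wrapper functions back into the bodies of the
# partitioned "tensorrt" functions.
inlined_mod = transform.InlineComposites("tensorrt")(partitioned_mod)
print(inlined_mod)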
