[BYOC][TRT] Add DFPattern support for TRT backend #10759

Merged
merged 8 commits into from
Apr 4, 2022

Conversation

Contributor

@mikepapadim mikepapadim commented Mar 24, 2022

This PR adds DFPattern support for the TRT backend without removing the existing predicate registry.

Adds and extends the following:

  • In tensorrt.py: adds a pattern_table covering all the supported ops, which reuses the pre-existing op_registry checks.
  • Adds an additional pass in unmerge_composites.cc. The TRT backend needs it because it expects a single primitive function to work with, whereas MergeComposite and PartitionGraph produce a separate function for each Composite pattern.
  • Adds test_inline_composites.py, which tests the newly introduced pass.

Both the pattern-based and predicate-based pass sequences produce syntactically equivalent IRModules.
This is to ensure backwards compatibility.
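To make the first bullet concrete, here is a minimal pure-Python sketch of a pattern table that reuses existing per-op predicates. It is a toy model, not the code in tensorrt.py: `OP_CHECKS`, `supported`, the dict-based call representation and the `tensorrt.<op>` composite naming are all assumptions made for illustration, and a real table would hold dataflow pattern objects rather than op-name strings.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical predicate registry: op name -> check on a call described as a dict.
# This stands in for the pre-existing per-op checks; it is not a TVM API.
OP_CHECKS: Dict[str, Callable[[dict], bool]] = {
    "nn.relu": lambda call: True,
    "nn.dense": lambda call: call.get("rank", 0) >= 2,
}


def pattern_table() -> List[Tuple[str, str, Callable[[dict], bool]]]:
    """Build one (composite name, pattern, check) entry per registered op.

    The "pattern" here is just the op name; the check is the registry
    predicate, so both code paths accept exactly the same calls.
    """
    return [(f"tensorrt.{op}", op, check) for op, check in OP_CHECKS.items()]


def supported(call: dict) -> bool:
    """A call is offloaded if some entry matches its op and its check passes."""
    return any(
        pattern == call["op"] and check(call)
        for _name, pattern, check in pattern_table()
    )
```

Because every table entry delegates to the same predicate the registry already holds, the pattern-based and predicate-based routes agree on which calls are supported, which is the backwards-compatibility property the description asks for.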

Original pass ordering:

  seq = tvm.transform.Sequential(
        [
            transform.InferType(),
            RemoveDropoutPass(),
            transform.RemoveUnusedFunctions(),
            transform.ConvertLayout(
                {
                    "nn.conv1d": ["NCW", "default"],
                    "nn.conv2d": ["NCHW", "default"],
                    "nn.conv3d": ["NCDHW", "default"],
                    "nn.conv2d_transpose": ["NCHW", "default"],
                }
            ),
            transform.FoldConstant(),
            transform.AnnotateTarget("tensorrt"),
            transform.MergeCompilerRegions(),
            transform.PartitionGraph(),
            transform.InferType(),
        ]
    )

Pass ordering with MergeComposite and InlineComposites:

  seq = tvm.transform.Sequential(
        [
            transform.InferType(),
            RemoveDropoutPass(),
            transform.RemoveUnusedFunctions(),
            transform.ConvertLayout(
                {
                    "nn.conv1d": ["NCW", "default"],
                    "nn.conv2d": ["NCHW", "default"],
                    "nn.conv3d": ["NCDHW", "default"],
                    "nn.conv2d_transpose": ["NCHW", "default"],
                }
            ),
            transform.FoldConstant(),
            transform.MergeComposite(pattern_table()),  # <-- Change #1
            transform.AnnotateTarget("tensorrt"),
            transform.MergeCompilerRegions(),
            transform.PartitionGraph(),
            transform.InlineComposites("tensorrt"),  # <-- Change #2
            transform.InferType(),
        ]
    )

@mbs-octoml @mbaret @masahi

@mikepapadim mikepapadim force-pushed the pattern_trt branch 2 times, most recently from a4f84ce to 53b0c48 on March 24, 2022 12:48
Contributor

@mbaret mbaret left a comment

I think this pass needs its own unit tests so it can be tested outside of the TRT partitioning flow.

@@ -446,7 +451,7 @@ def dense_annotate_fn(expr): # pylint: disable=unused-variable


@_register_external_dynamic_check_func("nn.batch_matmul")
-def batch_matmul_annotate_fn(expr):
+def batch_matmul_annotate_fn(expr):  # pylint: disable=unused-variable
Contributor

Is this change required?

@@ -543,7 +543,10 @@ def MergeComposite(pattern_table):
    for tup in pattern_table:
        if len(tup) == 2:
            pattern_name, pattern = tup
            check = lambda extract: True
Contributor

Could you explain this change?

Contributor Author

A mismatch with black; the autoformatter changed it.
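For readers following the hunk above, this is a small sketch of how a pattern table with optional check functions can be normalized. Only the 2-tuple branch appears in the quoted diff; the 3-tuple branch, the error case and the name `normalize_pattern_table` are assumptions added to make the example complete.

```python
def normalize_pattern_table(pattern_table):
    """Split a table of 2- or 3-tuples into parallel name/pattern/check lists."""
    names, patterns, checks = [], [], []
    for tup in pattern_table:
        if len(tup) == 2:
            pattern_name, pattern = tup
            # No predicate supplied: accept every match of this pattern.
            check = lambda extract: True
        elif len(tup) == 3:
            # Assumed branch: the caller supplied an explicit predicate.
            pattern_name, pattern, check = tup
        else:
            raise ValueError(f"expected a 2- or 3-tuple, got {len(tup)} items")
        names.append(pattern_name)
        patterns.append(pattern)
        checks.append(check)
    return names, patterns, checks
```

With this shape, an entry such as `("add", add_pattern)` is always accepted when the pattern matches, while `("add", add_pattern, my_check)` defers the decision to `my_check`.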


namespace relay {

class Unmerger : ExprMutator {
Contributor

MixedModeMutator is now preferred where possible.
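A likely motivation for this suggestion (an assumption, the thread does not state it) is avoiding deep recursion on long operator chains. The sketch below illustrates that general technique with an explicit work stack on a toy tuple-based expression tree. It does not reproduce MixedModeMutator itself; `rewrite_iterative` and the tuple encoding are hypothetical.

```python
def rewrite_iterative(root, rewrite_call):
    """Post-order rewrite of a toy expression using an explicit stack.

    Expressions are tuples: ("var", name) or ("call", op, (arg, ...)).
    rewrite_call(op, new_args) returns the replacement for a call node.
    """
    memo = {}  # id(original node) -> rewritten node
    stack = [(root, False)]
    while stack:
        node, children_done = stack.pop()
        if id(node) in memo:
            continue  # shared sub-expression already rewritten
        if node[0] != "call":
            memo[id(node)] = node  # leaves are kept as they are
        elif not children_done:
            # Revisit this node after all of its arguments are rewritten.
            stack.append((node, True))
            for arg in node[2]:
                stack.append((arg, False))
        else:
            new_args = tuple(memo[id(arg)] for arg in node[2])
            memo[id(node)] = rewrite_call(node[1], new_args)
    return memo[id(root)]
```

A plain recursive visitor over a chain of a few thousand calls would hit Python's recursion limit, whereas this version only grows a heap-allocated list.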

Comment on lines 48 to 49
Function gv = GetRef<Function>(function_var_node);
const auto* fn = gv.as<FunctionNode>();
Contributor

Is this needed - it looks like we already start with the FunctionNode?


Expr VisitExpr_(const CallNode* call_node) final {
Call vanilla_call = GetAnyCall(call_node);
const auto* global_var_node = vanilla_call->op.as<GlobalVarNode>();
Contributor

unused

Comment on lines 62 to 65
// Attrs need to be empty at this point to avoid propagating Composite and
// PartitionedFromPattern that fiddling TRT code gen for registered ops.
auto func = Function(fn->params, fn->body, fn->ret_type, fn->type_params, {});
return Bind(func->body, bind_map);
Contributor

Not sure I understand this, why can't we just do

return Bind(fn->body, bind_map);

Contributor Author

+1
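The simplification agreed on here can be pictured with a toy substitution: inlining a composite call only needs the function body with its parameters bound to the call arguments, so whatever attributes the function object carried never reach the result. The tuple encoding and the names `bind` and `inline_call` are hypothetical and only mirror the idea of `Bind(fn->body, bind_map)`.

```python
def bind(expr, bind_map):
    """Replace variables in a toy expression according to bind_map.

    Expressions are tuples: ("var", name) or ("call", op, (arg, ...)).
    """
    kind = expr[0]
    if kind == "var":
        return bind_map.get(expr[1], expr)
    if kind == "call":
        _, op, args = expr
        return ("call", op, tuple(bind(arg, bind_map) for arg in args))
    raise ValueError(f"unknown expression kind: {kind!r}")


def inline_call(params, body, args):
    """Inline a call to a function with the given params and body."""
    if len(params) != len(args):
        raise ValueError("argument count does not match parameter count")
    # Only the body is used, so function-level attributes are dropped for free.
    return bind(body, dict(zip(params, args)))
```

For example, inlining `add_relu(in_1, in_2) = relu(add(in_1, in_2))` at a call site with arguments `a` and `b` yields `relu(add(a, b))` directly in the caller.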


if (!base_func->GetAttr<String>(attr::kCompiler).defined() &&
base_func->GetAttr<String>(attr::kCompiler) != target) {
return module;
Contributor

I think it'd be better to continue; here rather than return, otherwise it seems if any partitioning for a different target has taken place, this will bail out.

Contributor Author

+1
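The difference between `continue` and `return` in this loop can be shown with a toy module modeled as a dict of functions. Everything here (`inline_for_target`, the dict-based functions, the `"Compiler"` key as a plain string lookup) is a hypothetical stand-in, and the guard is simplified to a single comparison.

```python
def inline_for_target(functions, target, inline):
    """Apply `inline` to every function partitioned for `target`.

    Functions for other targets, or with no target at all, are copied
    through unchanged and the scan keeps going.
    """
    result = {}
    for name, fn in functions.items():
        if fn.get("Compiler") != target:
            result[name] = fn
            continue  # a `return` here would skip every later function
        result[name] = inline(fn)
    return result
```

With `return` in place of `continue`, a module that was also partitioned for some other backend would come back untouched as soon as the loop met one of that backend's functions.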

*/

/*!
* \file src/relay/transforms/unmerge_composites.cc
Contributor

Personal preference for the name here would be either InlineComposite or RemoveComposite, not a huge deal though, if no one else agrees we can keep it as Unmerge.

Contributor Author

InlineComposite makes sense, I will rename it


/*!
* \file src/relay/transforms/unmerge_composites.cc
* \brief Undo the partioned graphs originate from merge composite.
Contributor

I think 'Inline composite functions for a given target' describes this a bit better.

@mikepapadim
Contributor Author

PTAL

@mbaret
Contributor

mbaret commented Mar 25, 2022

I still think this needs a couple of simple unit tests to confirm the behaviour. Also ping @mbs-octoml if you want to take a quick look.

@mikepapadim
Contributor Author

@mbs-octoml

@mikepapadim
Contributor Author

@mbaret PTAL. test_inline_composites.py now contains a couple of unit tests for the new pass.

Comment on lines 123 to 125
print("merge composite reusult")
print(result)
print("---------------------")
Contributor

We should probably omit the prints.

Comment on lines 164 to 174
def expected():
    a = relay.var("a", shape=(10, 10))
    b = relay.var("b", shape=(10, 10))

    # add_relu function
    in_1 = relay.var("in_1", shape=(10, 10))
    in_2 = relay.var("in_2", shape=(10, 10))
    add_node = relay.add(in_1, in_2)
    relu_node = relay.nn.relu(add_node)
    add_relu = relay.Function([in_1, in_2], relu_node)
    return add_relu
Contributor

I think this is the same as before() (given a and b aren't used). If all we really want to test is that doing InlineComposites undoes MergeComposite, we can probably just test that the result is equal to the input.

"""Utility function to check inline composites results."""
result = run_opt_pass(
graph, relay.transform.MergeComposite(pattern_table), import_prelude=import_prelude
)
Contributor

I think we should put some form of check here just to confirm that a composite function has been created (so we know MergeComposite didn't just skip everything if for instance there was a pattern error).
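The two testing suggestions above (compare the result with the input, and first confirm that a composite was really created) can be combined in one helper. The sketch below runs on a toy graph modeled as a flat list of op names; `check_inline_roundtrip` and the `toy_*` functions are hypothetical, and a real test would presumably compare Relay expressions structurally rather than with `==`.

```python
def check_inline_roundtrip(graph, merge, inline, count_composites):
    """Merge, verify something was merged, inline, and compare with the input."""
    merged = merge(graph)
    # Guard: a broken pattern would otherwise make the round trip pass trivially.
    if count_composites(merged) == 0:
        raise AssertionError("merge step created no composite function")
    result = inline(merged)
    if result != graph:
        raise AssertionError("inlining did not restore the original graph")
    return result


def toy_merge(ops):
    """Group each adjacent ("add", "nn.relu") pair into one composite tuple."""
    out, i = [], 0
    while i < len(ops):
        if ops[i:i + 2] == ["add", "nn.relu"]:
            out.append(("add_relu", ["add", "nn.relu"]))
            i += 2
        else:
            out.append(ops[i])
            i += 1
    return out


def toy_inline(items):
    """Flatten composite tuples back into their member ops."""
    out = []
    for item in items:
        if isinstance(item, tuple):
            out.extend(item[1])
        else:
            out.append(item)
    return out


def toy_count(items):
    """Count composite tuples in a merged toy graph."""
    return sum(isinstance(item, tuple) for item in items)
```

Calling `check_inline_roundtrip(["add", "nn.relu", "nn.relu"], toy_merge, toy_inline, toy_count)` passes, while a graph such as `["nn.relu"]` that matches no pattern fails at the guard instead of passing silently.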

relu relu

"""
pattern_table = [("add", make_add_pattern()), ("nn.relu", make_relu_pattern())]
Contributor

This doesn't seem to match the description above.

"""


def make_conv_bias_relu_pattern():
Contributor

This pattern doesn't seem to be used, I think either add a test for it or remove.

tests/python/relay/test_pass_inline_composites.py (outdated, resolved)
Contributor

@mbaret mbaret left a comment

lgtm

@mbaret mbaret merged commit 98580a2 into apache:main Apr 4, 2022
pfk-beta pushed a commit to pfk-beta/tvm that referenced this pull request Apr 11, 2022
mehrdadh pushed a commit to mehrdadh/tvm that referenced this pull request Apr 11, 2022
Lucien0 pushed a commit to Lucien0/tvm that referenced this pull request Apr 19, 2022