-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BYOC][TRT] Add DFPattern support for TRT backend #10759
Conversation
a4f84ce
to
53b0c48
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this pass needs its own unit tests so it can be tested outside of the TRT partitioning flow.
@@ -446,7 +451,7 @@ def dense_annotate_fn(expr): # pylint: disable=unused-variable | |||
|
|||
|
|||
@_register_external_dynamic_check_func("nn.batch_matmul") | |||
def batch_matmul_annotate_fn(expr): | |||
def batch_matmul_annotate_fn(expr): # pylint: disable=unused-variable |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this change required?
@@ -543,7 +543,10 @@ def MergeComposite(pattern_table): | |||
for tup in pattern_table: | |||
if len(tup) == 2: | |||
pattern_name, pattern = tup | |||
check = lambda extract: True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you explain this change?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
missmatch in black autoformat it
|
||
namespace relay { | ||
|
||
class Unmerger : ExprMutator { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MixedModeMutator is now preferred where possible.
Function gv = GetRef<Function>(function_var_node); | ||
const auto* fn = gv.as<FunctionNode>(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this needed - it looks like we already start with the FunctionNode?
|
||
Expr VisitExpr_(const CallNode* call_node) final { | ||
Call vanilla_call = GetAnyCall(call_node); | ||
const auto* global_var_node = vanilla_call->op.as<GlobalVarNode>(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unused
// Attrs need to be empty at this point to avoid propagating Composite and | ||
// PartitionedFromPattern that fiddling TRT code gen for registered ops. | ||
auto func = Function(fn->params, fn->body, fn->ret_type, fn->type_params, {}); | ||
return Bind(func->body, bind_map); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure I understand this, why can't we just do
return Bind(fn->body, bind_map);
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1
|
||
if (!base_func->GetAttr<String>(attr::kCompiler).defined() && | ||
base_func->GetAttr<String>(attr::kCompiler) != target) { | ||
return module; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it'd be better to continue;
here rather than return, otherwise it seems if any partitioning for a different target has taken place, this will bail out.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1
*/ | ||
|
||
/*! | ||
* \file src/relay/transforms/unmerge_composites.cc |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Personal preference for the name here would be either InlineComposite or RemoveComposite, not a huge deal though, if no one else agrees we can keep it as Unmerge.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
InlineComposite makes sense, I will rename it
|
||
/*! | ||
* \file src/relay/transforms/unmerge_composites.cc | ||
* \brief Undo the partioned graphs originate from merge composite. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think 'Inline composite functions for a given target' describes this a bit better.
53b0c48
to
a4d95c7
Compare
PTAL |
a961ef8
to
a3dc554
Compare
I still think this needs a couple of simple unit tests to confirm the behaviour. Also ping @mbs-octoml if you want to take a quick look. |
@mbaret PTAL. Under |
print("merge composite reusult") | ||
print(result) | ||
print("---------------------") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should probably omit the prints.
def expected(): | ||
a = relay.var("a", shape=(10, 10)) | ||
b = relay.var("b", shape=(10, 10)) | ||
|
||
# add_relu function | ||
in_1 = relay.var("in_1", shape=(10, 10)) | ||
in_2 = relay.var("in_2", shape=(10, 10)) | ||
add_node = relay.add(in_1, in_2) | ||
relu_node = relay.nn.relu(add_node) | ||
add_relu = relay.Function([in_1, in_2], relu_node) | ||
return add_relu |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is the same as before() (given a and b aren't used). If all we really want to test is that doing InlineComposites undoes MergeComposite, we can probably just test that the result is equal to the input.
"""Utility function to check inline composites results.""" | ||
result = run_opt_pass( | ||
graph, relay.transform.MergeComposite(pattern_table), import_prelude=import_prelude | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should put some form of check here just to confirm that a composite function has been created (so we know MergeComposite didn't just skip everything if for instance there was a pattern error).
relu relu | ||
|
||
""" | ||
pattern_table = [("add", make_add_pattern()), ("nn.relu", make_relu_pattern())] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't seem to match the description above.
""" | ||
|
||
|
||
def make_conv_bias_relu_pattern(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This pattern doesn't seem to be used, I think either add a test for it or remove.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
This PR adds DFPattern support for the TRT backend without removing the existing predicate registry. Adds and extends the following: In tensorrt.py: Add a pattern_table for all the supported ops and consumes the pre-existing op_registry checks Adds an additional pass as unmerge_composites.cc. This is required for the TRT backend as it expects a single primitive function to work with, while the MergeComposite and PartitionGraph will produce a single function for each Composite pattern. Adds test_inline_composites.py which tests the newly introduced pass. Both the pattern-based and predicate-based pass sequences produce syntactically equivalent IRModules. This is to ensure backwards compatibility."
This PR adds DFPattern support for the TRT backend without removing the existing predicate registry. Adds and extends the following: In tensorrt.py: Add a pattern_table for all the supported ops and consumes the pre-existing op_registry checks Adds an additional pass as unmerge_composites.cc. This is required for the TRT backend as it expects a single primitive function to work with, while the MergeComposite and PartitionGraph will produce a single function for each Composite pattern. Adds test_inline_composites.py which tests the newly introduced pass. Both the pattern-based and predicate-based pass sequences produce syntactically equivalent IRModules. This is to ensure backwards compatibility."
This PR adds DFPattern support for the TRT backend without removing the existing predicate registry. Adds and extends the following: In tensorrt.py: Add a pattern_table for all the supported ops and consumes the pre-existing op_registry checks Adds an additional pass as unmerge_composites.cc. This is required for the TRT backend as it expects a single primitive function to work with, while the MergeComposite and PartitionGraph will produce a single function for each Composite pattern. Adds test_inline_composites.py which tests the newly introduced pass. Both the pattern-based and predicate-based pass sequences produce syntactically equivalent IRModules. This is to ensure backwards compatibility."
This PR adds
DFPattern
support for the TRT backend without removing the existing predicate registry.Adds and extends the following:
tensorrt.py
: Add apattern_table
for all the supported ops and consumes the pre-existing op_registry checksunmerge_composites.cc
. This is required for the TRT backend as it expects a single primitive function to work with, while theMergeComposite
andPartitionGraph
will produce a single function for eachComposite
pattern.test_inline_composites.py
which tests the newly introduced pass.Both the pattern-based and predicate-based pass sequences produce syntactically equivalent
IRModules
.This is to ensure backwards compatibility."
Original Pass orderding:
Pass ordering with MergeComposites and UnmergeComposites:
@mbs-octoml @mbaret @masahi