Skip to content

Commit

Permalink
[BYOC][bugfix] Handle empty tuples in annotation pass (apache#7288)
Browse files Browse the repository at this point in the history
  • Loading branch information
slyubomirsky authored and masahi committed Jan 18, 2021
1 parent 91f17a0 commit 563e179
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
5 changes: 3 additions & 2 deletions src/relay/transforms/annotate_target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,12 @@ class AnnotateTargetRewriter : public ExprRewriter {
*/
Expr new_expr = expr;
const CallNode* call = expr.as<CallNode>();
const TupleNode* tup = expr.as<TupleNode>();
if (op_expr_to_target_.find(expr) != op_expr_to_target_.end()) {
// Check whether expr has args, if not - do not insert compiler_end.
if (expr->IsInstance<RefWriteNode>() || expr->IsInstance<RefCreateNode>() ||
expr->IsInstance<RefReadNode>() || expr->IsInstance<TupleNode>() ||
expr->IsInstance<TupleGetItemNode>() || (call && !call->args.empty())) {
expr->IsInstance<RefReadNode>() || expr->IsInstance<TupleGetItemNode>() ||
(call && !call->args.empty()) || (tup && !tup->fields.empty())) {
std::string target = op_expr_to_target_[new_expr];
new_expr = InsertAnnotation(new_expr, target, make_end_op);
op_expr_to_target_[new_expr] = target;
Expand Down
26 changes: 24 additions & 2 deletions tests/python/relay/test_pass_annotate_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,8 +738,8 @@ def after():
mod = tvm.IRModule.from_expr(func)
return mod

for annotate_non_call_ops in [True, False, True]:
result = transform.AnnotateTarget(target)(before())
for annotate_non_call_ops in [True, False]:
result = transform.AnnotateTarget(target, annotate_non_call_ops)(before())
expected = transform.InferType()(after())
assert tvm.ir.structural_equal(expected, result)

Expand All @@ -764,6 +764,27 @@ def after():
assert tvm.ir.structural_equal(expected, result)


def test_empty_tuple():
target = "test_empty_tuple"

"""An empty tuple should behave just like a call with no args (see above test)."""

def before():
func = relay.Function([], relay.Tuple([]))
mod = tvm.IRModule.from_expr(func)
return mod

def after():
func = relay.Function([], relay.Tuple([]))
mod = tvm.IRModule.from_expr(func)
return mod

for annotate_non_call_ops in [True, False]:
result = transform.AnnotateTarget(target, annotate_non_call_ops)(before())
expected = transform.InferType()(after())
assert tvm.ir.structural_equal(expected, result)


if __name__ == "__main__":
test_extern_dnnl()
test_composite_function()
Expand All @@ -780,3 +801,4 @@ def after():
test_double_target()
test_ends_with_tuple()
test_ref_create_read_write()
test_empty_tuple()

0 comments on commit 563e179

Please sign in to comment.