diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index 76585cf1272f..e365dca3860f 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -144,11 +144,12 @@ class AnnotateTargetRewriter : public ExprRewriter { */ Expr new_expr = expr; const CallNode* call = expr.as(); + const TupleNode* tup = expr.as(); if (op_expr_to_target_.find(expr) != op_expr_to_target_.end()) { // Check whether expr has args, if not - do not insert compiler_end. if (expr->IsInstance() || expr->IsInstance() || - expr->IsInstance() || expr->IsInstance() || - expr->IsInstance() || (call && !call->args.empty())) { + expr->IsInstance() || expr->IsInstance() || + (call && !call->args.empty()) || (tup && !tup->fields.empty())) { std::string target = op_expr_to_target_[new_expr]; new_expr = InsertAnnotation(new_expr, target, make_end_op); op_expr_to_target_[new_expr] = target; diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py index 4f35066a8384..ce86cc603d6d 100644 --- a/tests/python/relay/test_pass_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -738,8 +738,8 @@ def after(): mod = tvm.IRModule.from_expr(func) return mod - for annotate_non_call_ops in [True, False, True]: - result = transform.AnnotateTarget(target)(before()) + for annotate_non_call_ops in [True, False]: + result = transform.AnnotateTarget(target, annotate_non_call_ops)(before()) expected = transform.InferType()(after()) assert tvm.ir.structural_equal(expected, result) @@ -764,6 +764,27 @@ def after(): assert tvm.ir.structural_equal(expected, result) +def test_empty_tuple(): + target = "test_empty_tuple" + + """An empty tuple should behave just like a call with no args (see above test).""" + + def before(): + func = relay.Function([], relay.Tuple([])) + mod = tvm.IRModule.from_expr(func) + return mod + + def after(): + func = relay.Function([], relay.Tuple([])) + mod = tvm.IRModule.from_expr(func) + return mod + + for annotate_non_call_ops in [True, False]: + result = transform.AnnotateTarget(target, annotate_non_call_ops)(before()) + expected = transform.InferType()(after()) + assert tvm.ir.structural_equal(expected, result) + + if __name__ == "__main__": test_extern_dnnl() test_composite_function() @@ -780,3 +801,4 @@ def after(): test_double_target() test_ends_with_tuple() test_ref_create_read_write() + test_empty_tuple()