From edad358c1b81fbff40a1e2c6278947e3ae83da56 Mon Sep 17 00:00:00 2001 From: Ritwik Das Date: Tue, 3 Nov 2020 10:19:16 -0800 Subject: [PATCH] Fix Annotate Target to support freevars(relay.zeros, relay.ones etc) of any size (including zero) (#6826) * Fix Annotate Target * Add Test Cases * Formatting * Comments C++ * Remove Unnecesssary test cases * typo * annotate_target Co-authored-by: Ubuntu --- src/relay/transforms/annotate_target.cc | 14 ++- .../python/relay/test_pass_annotate_target.py | 87 +++++++++++++++++++ 2 files changed, 98 insertions(+), 3 deletions(-) diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index 7a083304515be..9d160b26f1ad8 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -77,7 +77,13 @@ class AnnotateTargetRewriter : public ExprRewriter { compiler_ends.push_back(call->args[0]); } else if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) { arg_target = op_expr_to_target_[arg]; - compiler_ends.push_back(InsertAnnotation(arg, arg_target, make_end_op)); + // If an argument is a call node and has no argument, then it should be tensor ops such as + // zeros, so we treat it as input vars. + if (call && call->args.size() == 0) { + compiler_ends.push_back(arg); + } else { + compiler_ends.push_back(InsertAnnotation(arg, arg_target, make_end_op)); + } } else { // Input vars. compiler_ends.push_back(arg); @@ -113,14 +119,16 @@ class AnnotateTargetRewriter : public ExprRewriter { * \brief This function inserts compiler end to expr and maps the corresponding target to the * new expression. * - * This function checks for expr existence within the map and inserts the annotation + * This function checks for expr existence within the map and inserts the annotation. + * If the expression has a free variable (e.g: relay.zeros, relay.ones) we do not insert + * compiler end, since there are no compiler begins for it. * Further, it propagates the target to the new expression and returns it * * \param expr A relay expression * \return An annotated and target-propagated relay expression. */ Expr new_expr = expr; - if (op_expr_to_target_.find(expr) != op_expr_to_target_.end()) { + if (op_expr_to_target_.find(expr) != op_expr_to_target_.end() && FreeVars(expr).size() != 0) { new_expr = InsertAnnotation(expr, op_expr_to_target_[expr], make_end_op); op_expr_to_target_[new_expr] = op_expr_to_target_[expr]; } diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py index b99e3bc02ba40..106909e16fa73 100644 --- a/tests/python/relay/test_pass_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -510,6 +510,91 @@ def after(): assert tvm.ir.structural_equal(expected, result) +def test_if_free_vars(): + target = "test_if_free_vars" + + @tvm.ir.register_op_attr("equal", "target." + target) + def equal(attrs, args): # pylint: disable=unused-variable + return True + + @tvm.ir.register_op_attr("sigmoid", "target." + target) + def sigmoid(attrs, args): # pylint: disable=unused-variable + return True + + @tvm.ir.register_op_attr("erf", "target." + target) + def erf(attrs, args): # pylint: disable=unused-variable + return True + + """Test that If-else nodes compiles correctly when surrounded by free variables""" + + def before(): + data = relay.var("data", shape=(1, 32)) + eq1 = relay.var("e1", shape=[], dtype="float32") + eq2 = relay.var("e2", shape=[], dtype="float32") + eq = relay.equal(eq1, eq2) + + true_branch = relay.zeros(shape=(1, 32), dtype="float32") + false_branch = relay.sigmoid(data) + ife = relay.If(eq, true_branch, false_branch) + out = relay.erf(ife) + + func = relay.Function([data, eq1, eq2], out) + mod = tvm.IRModule.from_expr(func) + + return mod + + def after(): + data = relay.var("data", shape=(1, 32)) + eq1 = relay.var("e1", shape=[], dtype="float32") + eq2 = relay.var("e2", shape=[], dtype="float32") + + cb_1 = relay.annotation.compiler_begin(eq1, target) + cb_2 = relay.annotation.compiler_begin(eq2, target) + + equality_condition = relay.equal(cb_1, cb_2) + ce_1 = relay.annotation.compiler_end(equality_condition, target) + + # if condition + true_branch = relay.zeros(shape=(1, 32), dtype="float32") + + # else condition + cb_3 = relay.annotation.compiler_begin(data, target) + false_branch = relay.sigmoid(cb_3) + ce_2 = relay.annotation.compiler_end(false_branch, target) + + if_condition = relay.If(ce_1, true_branch, ce_2) + cb_4 = relay.annotation.compiler_begin(if_condition, target) + erf_out = relay.erf(cb_4) + ce_3 = relay.annotation.compiler_end(erf_out, target) + func = relay.Function([data, eq1, eq2], ce_3) + mod = tvm.IRModule.from_expr(func) + return mod + + result = transform.AnnotateTarget(target)(before()) + expected = transform.InferType()(after()) + assert tvm.ir.structural_equal(expected, result) + + +def test_free_vars_zeros(): + target = "test_free_vars_zeros" + + """Test that free variables compile correctly on their own""" + + def before(): + func = relay.Function([], relay.zeros(shape=(0), dtype="float32")) + mod = tvm.IRModule.from_expr(func) + return mod + + def after(): + func = relay.Function([], relay.zeros(shape=(0), dtype="float32")) + mod = tvm.IRModule.from_expr(func) + return mod + + result = transform.AnnotateTarget(target)(before()) + expected = transform.InferType()(after()) + assert tvm.ir.structural_equal(expected, result) + + if __name__ == "__main__": test_extern_dnnl() test_composite_function() @@ -520,3 +605,5 @@ def after(): test_multiple_runs() test_if_else() test_while_let() + test_if_free_vars() + test_free_vars_zeros()