Skip to content

Commit

Permalink
Fix Annotate Target to support freevars(relay.zeros, relay.ones etc) …
Browse files Browse the repository at this point in the history
…of any size (including zero) (apache#6826)

* Fix Annotate Target

* Add Test Cases

* Formatting

* Comments C++

* Remove Unnecesssary test cases

* typo

* annotate_target

Co-authored-by: Ubuntu <ubuntu@ip-172-31-27-149.us-east-2.compute.internal>
  • Loading branch information
2 people authored and Trevor Morris committed Dec 2, 2020
1 parent 5f5c7ad commit edad358
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 3 deletions.
14 changes: 11 additions & 3 deletions src/relay/transforms/annotate_target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,13 @@ class AnnotateTargetRewriter : public ExprRewriter {
compiler_ends.push_back(call->args[0]);
} else if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) {
arg_target = op_expr_to_target_[arg];
compiler_ends.push_back(InsertAnnotation(arg, arg_target, make_end_op));
// If an argument is a call node and has no argument, then it should be tensor ops such as
// zeros, so we treat it as input vars.
if (call && call->args.size() == 0) {
compiler_ends.push_back(arg);
} else {
compiler_ends.push_back(InsertAnnotation(arg, arg_target, make_end_op));
}
} else {
// Input vars.
compiler_ends.push_back(arg);
Expand Down Expand Up @@ -113,14 +119,16 @@ class AnnotateTargetRewriter : public ExprRewriter {
* \brief This function inserts compiler end to expr and maps the corresponding target to the
* new expression.
*
* This function checks for expr existence within the map and inserts the annotation
* This function checks for expr existence within the map and inserts the annotation.
* If the expression has a free variable (e.g: relay.zeros, relay.ones) we do not insert
* compiler end, since there are no compiler begins for it.
* Further, it propagates the target to the new expression and returns it
*
* \param expr A relay expression
* \return An annotated and target-propagated relay expression.
*/
Expr new_expr = expr;
if (op_expr_to_target_.find(expr) != op_expr_to_target_.end()) {
if (op_expr_to_target_.find(expr) != op_expr_to_target_.end() && FreeVars(expr).size() != 0) {
new_expr = InsertAnnotation(expr, op_expr_to_target_[expr], make_end_op);
op_expr_to_target_[new_expr] = op_expr_to_target_[expr];
}
Expand Down
87 changes: 87 additions & 0 deletions tests/python/relay/test_pass_annotate_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,91 @@ def after():
assert tvm.ir.structural_equal(expected, result)


def test_if_free_vars():
target = "test_if_free_vars"

@tvm.ir.register_op_attr("equal", "target." + target)
def equal(attrs, args): # pylint: disable=unused-variable
return True

@tvm.ir.register_op_attr("sigmoid", "target." + target)
def sigmoid(attrs, args): # pylint: disable=unused-variable
return True

@tvm.ir.register_op_attr("erf", "target." + target)
def erf(attrs, args): # pylint: disable=unused-variable
return True

"""Test that If-else nodes compiles correctly when surrounded by free variables"""

def before():
data = relay.var("data", shape=(1, 32))
eq1 = relay.var("e1", shape=[], dtype="float32")
eq2 = relay.var("e2", shape=[], dtype="float32")
eq = relay.equal(eq1, eq2)

true_branch = relay.zeros(shape=(1, 32), dtype="float32")
false_branch = relay.sigmoid(data)
ife = relay.If(eq, true_branch, false_branch)
out = relay.erf(ife)

func = relay.Function([data, eq1, eq2], out)
mod = tvm.IRModule.from_expr(func)

return mod

def after():
data = relay.var("data", shape=(1, 32))
eq1 = relay.var("e1", shape=[], dtype="float32")
eq2 = relay.var("e2", shape=[], dtype="float32")

cb_1 = relay.annotation.compiler_begin(eq1, target)
cb_2 = relay.annotation.compiler_begin(eq2, target)

equality_condition = relay.equal(cb_1, cb_2)
ce_1 = relay.annotation.compiler_end(equality_condition, target)

# if condition
true_branch = relay.zeros(shape=(1, 32), dtype="float32")

# else condition
cb_3 = relay.annotation.compiler_begin(data, target)
false_branch = relay.sigmoid(cb_3)
ce_2 = relay.annotation.compiler_end(false_branch, target)

if_condition = relay.If(ce_1, true_branch, ce_2)
cb_4 = relay.annotation.compiler_begin(if_condition, target)
erf_out = relay.erf(cb_4)
ce_3 = relay.annotation.compiler_end(erf_out, target)
func = relay.Function([data, eq1, eq2], ce_3)
mod = tvm.IRModule.from_expr(func)
return mod

result = transform.AnnotateTarget(target)(before())
expected = transform.InferType()(after())
assert tvm.ir.structural_equal(expected, result)


def test_free_vars_zeros():
target = "test_free_vars_zeros"

"""Test that free variables compile correctly on their own"""

def before():
func = relay.Function([], relay.zeros(shape=(0), dtype="float32"))
mod = tvm.IRModule.from_expr(func)
return mod

def after():
func = relay.Function([], relay.zeros(shape=(0), dtype="float32"))
mod = tvm.IRModule.from_expr(func)
return mod

result = transform.AnnotateTarget(target)(before())
expected = transform.InferType()(after())
assert tvm.ir.structural_equal(expected, result)


if __name__ == "__main__":
test_extern_dnnl()
test_composite_function()
Expand All @@ -520,3 +605,5 @@ def after():
test_multiple_runs()
test_if_else()
test_while_let()
test_if_free_vars()
test_free_vars_zeros()

0 comments on commit edad358

Please sign in to comment.