Skip to content

Commit

Permalink
[ConvertLayout] Keep span in ConvertLayout (#7895)
Browse files Browse the repository at this point in the history
  • Loading branch information
lixiaoquan authored Apr 23, 2021
1 parent c7ac5ee commit 60c170e
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/relay/transforms/convert_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class ConvertTransformMemorizer : public TransformMemorizer {

const CallNode* new_call = new_e.as<CallNode>();
ICHECK(new_call) << "Can only replace the original operator with another call node";
return GetRef<Call>(new_call);
return Call(new_call->op, new_call->args, new_call->attrs, new_call->type_args, ref_call->span);
}

using ContainerType = ConvertTransformMemorizerNode;
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/forward_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ class ForwardRewriter : private MixedModeMutator {
}
}
if (unchanged) return ref_call;
return Call(new_op, call_args, call_node->attrs, call_node->type_args);
return Call(new_op, call_args, call_node->attrs, call_node->type_args, call_node->span);
}
};

Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/transform_layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
} else {
auto rnode = make_object<LayoutAlternatedExprNode<TransformMemorizerT>>();
ICHECK_EQ(new_out.size(), 1);
rnode->value = Call(new_call->op, transformed_args, new_call->attrs);
rnode->value = Call(new_call->op, transformed_args, new_call->attrs, {}, ref_call->span);
rnode->old_layout = old_out[0];
rnode->new_layout = new_out[0];
rnode->memorizer = memorizer;
Expand Down
13 changes: 13 additions & 0 deletions tests/python/relay/test_pass_convert_op_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -1703,6 +1703,11 @@ def expected():
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)

def _test_conv_reduce_convert_layout2():
def _set_span(y, text):
return relay.Call(
y.op, y.args, y.attrs, y.type_args, relay.Span(relay.SourceName(text), 0, 0, 0, 0)
)

def before():
x = relay.var("x", shape=(1, 38, 38, 512))
weight = relay.var("weight", shape=(3, 3, 512, 512))
Expand All @@ -1714,9 +1719,13 @@ def before():
data_layout="NHWC",
kernel_layout="HWIO",
)
y = _set_span(y, "SpanConv2D")
y = relay.nn.relu(y)
y = _set_span(y, "SpanRelu")
y = relay.multiply(y, y)
y = _set_span(y, "SpanMultiply")
y = relay.sum(y, axis=(3,), keepdims=True)
y = _set_span(y, "SpanSum")
return relay.Function(analysis.free_vars(y), y)

def expected():
Expand All @@ -1733,6 +1742,10 @@ def expected():

a = before()
a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
assert "SpanConv2D" in a.astext()
assert "SpanRelu" in a.astext()
assert "SpanMultiply" in a.astext()
assert "SpanSum" in a.astext()
b = run_opt_pass(expected(), transform.InferType())

assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
Expand Down

0 comments on commit 60c170e

Please sign in to comment.