Skip to content

Commit

Permalink
[ConvertLayout] slice_like support (apache#7184)
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac authored Jan 5, 2021
1 parent 7163b5c commit d052752
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 0 deletions.
41 changes: 41 additions & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2752,6 +2752,46 @@ Expr MakeSliceLike(Expr data, Expr shape_like, Array<Integer> axes) {
return Call(op, {data, shape_like}, Attrs(attrs), {});
}

Array<Array<Layout>> SliceLikeInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
Array<Integer> new_axes;
if (old_in_layouts.defined() && new_in_layouts.defined()) {
ICHECK_EQ(new_in_layouts.size(), 2);
ICHECK_EQ(new_in_layouts[0]->name, new_in_layouts[1]->name);
ICHECK_EQ(old_in_layouts.size(), 2);
ICHECK_EQ(old_in_layouts[0]->name, old_in_layouts[1]->name);

auto old_layout = old_in_layouts[0];
auto new_layout = new_in_layouts[0];

// Discard "const" qualifier.
auto* params = const_cast<SliceLikeAttrs*>(attrs.as<SliceLikeAttrs>());
ICHECK(params != nullptr);

for (auto axis : params->axes) {
auto new_axis = new_layout.IndexOf(old_layout[axis->value]);
// Cannot find the target axis in the new layout.
if (new_axis == -1) {
new_axes.clear();
break;
}
new_axes.push_back(new_axis);
}
if (!new_axes.empty()) {
params->axes = std::move(new_axes);
return Array<Array<Layout>>({{new_layout, new_layout}, {new_layout}});
}
}

if (old_in_layouts.defined()) {
ICHECK_EQ(old_in_layouts.size(), 2);
return {{old_in_layouts[0], old_in_layouts[1]}, {old_in_layouts[1]}};
}
return Array<Array<Layout>>({{Layout::Undef(), Layout::Undef()}, {Layout::Undef()}});
}

Array<te::Tensor> SliceLikeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const auto* param = attrs.as<SliceLikeAttrs>();
Expand Down Expand Up @@ -2801,6 +2841,7 @@ RELAY_REGISTER_OP("slice_like")
.set_support_level(10)
.add_type_rel("SliceLike", SliceLikeRel)
.set_attr<FTVMCompute>("FTVMCompute", SliceLikeCompute)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", SliceLikeInferCorrectLayout)
.set_attr<TOpPattern>("TOpPattern", kInjective);

// relay.layout_transform
Expand Down
70 changes: 70 additions & 0 deletions tests/python/relay/test_pass_convert_op_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,75 @@ def before():
assert len(has_lt) == 1


def test_slice_like_convert_layout():
def verify_slice_like(after, expected_axes):
# Verify if the slice_like after the convert layout has the expected axes.
has_expected = list()
checker = lambda x: has_expected.append(
isinstance(x, tvm.relay.expr.Call)
and x.op.name == "slice_like"
and str(x.attrs.axes) == str(expected_axes)
)
relay.analysis.post_order_visit(after, checker)
assert any(has_expected)

def func_nhwc():
x = relay.var("x", shape=(1, 56, 56, 64))
weight1 = relay.var("weight1", shape=(3, 3, 64, 32))
y = relay.nn.conv2d(
x,
weight1,
channels=32,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NHWC",
kernel_layout="HWIO",
)
out = relay.slice_like(y, y, axes=[1, 2])
return relay.Function(analysis.free_vars(out), out)

after = run_opt_pass(func_nhwc(), transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
verify_slice_like(after, [2, 3])

def func_nchw():
x = relay.var("x", shape=(1, 64, 56, 56))
weight1 = relay.var("weight1", shape=(32, 64, 3, 3))
y = relay.nn.conv2d(
x,
weight1,
channels=32,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
)
out = relay.slice_like(y, y, axes=[2, 3])
return relay.Function(analysis.free_vars(out), out)

after = run_opt_pass(func_nchw(), transform.ConvertLayout({"nn.conv2d": ["NHWC", "default"]}))
verify_slice_like(after, [1, 2])

def func_vars():
x = relay.var("x", shape=(1, 56, 56, 64))
weight1 = relay.var("weight1", shape=(3, 3, 64, 32))
y = relay.nn.conv2d(
x,
weight1,
channels=32,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NHWC",
kernel_layout="HWIO",
)
# z has no layout information so convert layout won't happen.
z = relay.var("y", shape=(1, 56, 56, 32))
out = relay.slice_like(y, z, axes=[1, 2])
return relay.Function(analysis.free_vars(out), out)

after = run_opt_pass(func_vars(), transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
verify_slice_like(after, [1, 2])


def test_resnet_convert_layout():
def before():
x = relay.var("x", shape=(1, 56, 56, 64))
Expand Down Expand Up @@ -1412,6 +1481,7 @@ def expected():
test_conv_concat_convert_layout()
test_dual_path_convert_layout()
test_bn_convert_layout()
test_slice_like_convert_layout()
test_resnet_convert_layout()
test_scalar_convert_layout()
test_conv_bn_convert_layout()
Expand Down

0 comments on commit d052752

Please sign in to comment.