[Relay] Register layout conversion function to more reduce ops (apache#9048)

* Register layout conversion function to more reduce ops

* Bug fix for the exclude=True case; the original code computed wrong axes

* Properly handle the variance op, which has two inputs

* Update expected test output
masahi authored and ylc committed Jan 13, 2022
1 parent ca87c28 commit b6da88c
Showing 3 changed files with 86 additions and 50 deletions.
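In short, this commit registers `FInferCorrectLayout` on the remaining reduce ops (`argmax`, `argmin`, `all`, `max`, `min`, `prod`, `mean`, `variance`), so ConvertLayout and AlterOpLayout can rewrite their reduce axes instead of stalling layout propagation at the reduction. A minimal sketch of what this enables, adapted from the new test added below (it assumes a TVM build that includes this commit):

```python
import tvm
from tvm import relay
from tvm.relay import transform

# conv2d in NCHW followed by a reduction over H and W (axes 2 and 3)
x = relay.var("x", shape=(1, 64, 56, 56))
weight = relay.var("weight", shape=(64, 64, 3, 3))
y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1),
                    data_layout="NCHW", kernel_layout="OIHW")
y = relay.mean(y, axis=[2, 3])

mod = tvm.IRModule.from_expr(relay.Function([x, weight], y))
mod = transform.InferType()(mod)
# Convert the conv2d to NHWC; with this commit, mean's axes are rewritten to
# [1, 2] (H and W in NHWC) rather than forcing a layout_transform back to NCHW.
mod = transform.ConvertLayout({"nn.conv2d": ["NHWC", "default"]})(mod)
print(mod)
```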
48 changes: 36 additions & 12 deletions src/relay/op/tensor/reduce.cc
@@ -116,13 +116,14 @@ Array<Integer> GetExcludeAxes(size_t indim, const Array<Integer>& inaxis) {
 }
 
 // Return the modified layout for AlterOpLayout pass.
+template <typename T>
 InferCorrectLayoutOutput ReduceInferCorrectLayout(const Attrs& attrs,
                                                   const Array<Layout>& new_in_layouts,
                                                   const Array<Layout>& old_in_layouts,
                                                   const Array<tvm::relay::Type>& old_in_types) {
-  const auto* attrs_ptr = attrs.as<ReduceAttrs>();
+  const auto* attrs_ptr = attrs.as<T>();
   ICHECK(attrs_ptr);
-  ObjectPtr<ReduceAttrs> params = make_object<ReduceAttrs>(*attrs_ptr);
+  ObjectPtr<T> params = make_object<T>(*attrs_ptr);
 
   // Get the reduce axes.
   Array<Array<IndexExpr>> old_in_shapes;
@@ -152,11 +153,14 @@ InferCorrectLayoutOutput ReduceInferCorrectLayout(const Attrs& attrs,
     for (auto iter_var : layout->axes) {
       const auto& layout_axis = LayoutAxis::Get(iter_var);
       const std::string& layout_dim = layout_axis.name();
-      if (old_r_dims.count(layout_dim)) {
-        new_r_axes.push_back(tvm::Integer(axis_index));
-      }
       // Collect only the primal axis.
       if (layout_axis.IsPrimal()) {
+        if (old_r_dims.count(layout_dim) && !params->exclude) {
+          new_r_axes.push_back(tvm::Integer(axis_index));
+        }
+        if (!old_r_dims.count(layout_dim) && params->exclude) {
+          new_r_axes.push_back(tvm::Integer(axis_index));
+        }
         if (!old_r_dims.count(layout_dim) || params->keepdims) {
           inferred_out_string += layout_dim;
         }
@@ -171,26 +175,38 @@ InferCorrectLayoutOutput ReduceInferCorrectLayout(const Attrs& attrs,
 
   std::string new_layout_string;
   Array<Integer> new_r_axes;
+  Array<Layout> new_input_layouts;
+
+  auto check_num_input_layouts = [](Array<Layout> in_layouts) {
+    // The second case is for variance op
+    ICHECK(in_layouts.size() == 1 || in_layouts.size() == 2);
+  };
 
   if (new_in_layouts.defined() && r_axes.size()) {
     // Adapt to new layout. The axis has to change. Record original reduce axes. Convert to the
     // modified layout axes.
-    ICHECK_EQ(new_in_layouts.size(), 1);
-    ICHECK_EQ(old_in_layouts.size(), 1);
+    check_num_input_layouts(new_in_layouts);
+    check_num_input_layouts(old_in_layouts);
 
     // Get inferred_in and inferred_out from new_in_layout.
     std::tie(inferred_in, inferred_out, new_r_axes) = infer(new_in_layouts[0]);
     params->axis = new_r_axes;
   } else if (old_in_layouts.defined()) {
-    ICHECK_EQ(old_in_layouts.size(), 1);
+    check_num_input_layouts(old_in_layouts);
 
     // If the new layout is undefined, get inferred_in and inferred_out from old_in_layout.
     if (old_in_layouts[0].defined()) {
       std::tie(inferred_in, inferred_out, std::ignore) = infer(old_in_layouts[0]);
     }
   }
 
-  return InferCorrectLayoutOutput({inferred_in}, {inferred_out}, Attrs(params));
+  new_input_layouts.push_back(inferred_in);
+
+  if (old_in_layouts.size() == 2) {
+    new_input_layouts.push_back(inferred_in);
+  }
+
+  return InferCorrectLayoutOutput(new_input_layouts, {inferred_out}, Attrs(params));
 }

template <typename F>
@@ -389,6 +405,7 @@ values over a given axis.
     .set_support_level(4)
     .add_type_rel("ArgReduce", GenericReduceRel<ArgReduceAttrs>)
     .set_attr<FTVMCompute>("FTVMCompute", ArgMaxCompute)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ArgReduceAttrs>)
     .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 Array<te::Tensor> ArgMinCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
@@ -405,6 +422,7 @@ values over a given axis.
     .set_support_level(4)
     .add_type_rel("ArgReduce", GenericReduceRel<ArgReduceAttrs>)
     .set_attr<FTVMCompute>("FTVMCompute", ArgMinCompute)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ArgReduceAttrs>)
     .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 Array<te::Tensor> SumCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
@@ -433,7 +451,7 @@ Example::
     .set_attrs_type<ReduceAttrs>()
     .set_support_level(4)
     .add_type_rel("Reduce", ReduceRel)
-    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
     .set_attr<FTVMCompute>("FTVMCompute", SumCompute)
     .set_attr<TOpPattern>("TOpPattern", kCommReduce);

@@ -468,6 +486,7 @@ Example::
     .set_support_level(4)
     .add_type_rel("Reduce", ReduceRel)
     .set_attr<FTVMCompute>("FTVMCompute", AllCompute)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
     .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 Array<te::Tensor> AnyCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
@@ -516,6 +535,7 @@ RELAY_REGISTER_REDUCE_OP("max")
     .set_support_level(4)
     .add_type_rel("Reduce", ReduceRel)
     .set_attr<FTVMCompute>("FTVMCompute", MaxCompute)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
     .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 Array<te::Tensor> MinCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
@@ -531,6 +551,7 @@ RELAY_REGISTER_REDUCE_OP("min")
     .set_support_level(4)
     .add_type_rel("Reduce", ReduceRel)
     .set_attr<FTVMCompute>("FTVMCompute", MinCompute)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
     .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 Array<te::Tensor> ProdCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
@@ -551,17 +572,18 @@ Example::
           [[1,4],[4,3],[5,2]],
           [[7,1],[7,2],[7,3]]]
-  mean(data, axis=1)
+  prod(data, axis=1)
   [35562240]
-  mean(data, axis=[1,2])
+  prod(data, axis=[1,2])
   [ 36 480 2058]
 )code" TVM_ADD_FILELINE)
     .set_attrs_type<ReduceAttrs>()
     .set_support_level(4)
     .add_type_rel("Reduce", ReduceRel)
     .set_attr<FTVMCompute>("FTVMCompute", ProdCompute)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
     .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 Array<te::Tensor> MeanCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
@@ -600,6 +622,7 @@ Example::
     .set_support_level(4)
     .add_type_rel("Reduce", ReduceRel)
     .set_attr<FTVMCompute>("FTVMCompute", MeanCompute)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
     .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 bool VarianceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
@@ -675,6 +698,7 @@ RELAY_REGISTER_OP("variance")
     .add_argument("mean", "Tensor", "The mean tensor.")
     .add_type_rel("Variance", VarianceRel)
     .set_attr<FTVMCompute>("FTVMCompute", VarianceCompute)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<VarianceAttrs>)
     .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 }  // namespace relay
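For reference, the corrected axis remapping can be modeled in a few lines of Python. This is an illustrative sketch of the logic above, not TVM API: `reduced_dims` stands for the set of layout dimensions actually being reduced (with `exclude=True`, the complement of the `axis` attribute), and the function returns the new value of the `axis` attribute under the new layout.

```python
def infer_new_axis_attr(new_layout, reduced_dims, exclude):
    """Sketch of the fixed ReduceInferCorrectLayout axis computation."""
    new_axes = []
    for axis_index, dim in enumerate(new_layout):  # primal dims only, e.g. "NHWC"
        if dim in reduced_dims and not exclude:
            new_axes.append(axis_index)  # axis attribute lists the reduced dims
        if dim not in reduced_dims and exclude:
            new_axes.append(axis_index)  # axis attribute lists the kept dims
    return new_axes

# mean(..., axis=[2, 3]) on NCHW (reduce H, W) moved to NHWC -> axis=[1, 2]
assert infer_new_axis_attr("NHWC", {"H", "W"}, exclude=False) == [1, 2]
# mean(..., axis=3, exclude=True) on NHWC (keep C) moved to NCHW -> axis=[1]
assert infer_new_axis_attr("NCHW", {"N", "H", "W"}, exclude=True) == [1]
```

Before the fix, the `exclude` flag was ignored, so the second case would have produced `[0, 2, 3]` (the positions of the reduced dims) even though the attribute must be interpreted in exclude semantics.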
8 changes: 3 additions & 5 deletions tests/python/relay/test_pass_alter_op_layout.py
@@ -486,8 +486,7 @@ def before():
         beta = relay.var("beta")
         y = relay.nn.batch_norm(y, gamma, beta, mean, var, axis=3)
         y = y[0]
-        y = relay.Function(analysis.free_vars(y), y)
-        return y
+        return relay.Function(analysis.free_vars(y), y)
 
     def alter_conv2d(attrs, inputs, tinfos, out_type):
         data, weight = inputs
@@ -509,9 +508,8 @@ def expected():
         bias = relay.layout_transform(bias, src_layout="NCHW", dst_layout="NCHW16c")
         add = relay.add(y, bias)
         y = relay.layout_transform(add, src_layout="NCHW16c", dst_layout="NCHW")
-        y = relay.layout_transform(y, src_layout="NCHW", dst_layout="NHWC")
-        mean = relay.mean(y, axis=3, exclude=True)
-        var = relay.variance(y, axis=3, exclude=True)
+        mean = relay.mean(y, axis=1, exclude=True)
+        var = relay.variance(y, axis=1, exclude=True)
         denom = relay.const(1.0) / relay.sqrt(var + relay.const(1e-05))
         gamma = relay.var("gamma", shape=(16,))
         denom = denom * gamma
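The one-or-two-input-layouts handling above exists because `relay.variance` lowers to an op that takes both the data and its mean, so `ReduceInferCorrectLayout` now returns the inferred input layout once per input. A quick way to see it in action (a sketch under the same assumptions as the example at the top):

```python
import tvm
from tvm import relay
from tvm.relay import transform

x = relay.var("x", shape=(1, 64, 56, 56))
w = relay.var("w", shape=(64, 64, 3, 3))
y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1),
                    data_layout="NCHW", kernel_layout="OIHW")
v = relay.variance(y, axis=[2, 3])  # two-input op under the hood: (data, mean)

mod = tvm.IRModule.from_expr(relay.Function([x, w], v))
mod = transform.InferType()(mod)
mod = transform.ConvertLayout({"nn.conv2d": ["NHWC", "default"]})(mod)
print(mod)  # both variance inputs get the same inferred layout; axes become [1, 2]
```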
80 changes: 47 additions & 33 deletions tests/python/relay/test_pass_convert_op_layout.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 """Test alter op layout pass"""
+import pytest
+
 import tvm
 from tvm import te
 
@@ -1925,37 +1927,49 @@ def infer_correct_layout_relu(attrs, new_in_layouts, old_in_layouts, old_in_type
     assert test_infer_correct_layout_flag == True
 
 
+def test_reduce_op_convert_layout():
+    for reduce_op in [relay.argmax, relay.mean, relay.max]:
+
+        def before():
+            x = relay.var("x", shape=(1, 64, 56, 56))
+            weight = relay.var("weight", shape=(64, 64, 3, 3))
+            y = relay.nn.conv2d(
+                x,
+                weight,
+                channels=64,
+                kernel_size=(3, 3),
+                padding=(1, 1),
+                data_layout="NCHW",
+                kernel_layout="OIHW",
+            )
+            y = reduce_op(y, axis=[2, 3])
+            y = relay.Function([x, weight], y)
+            return y
+
+        def expected():
+            x = relay.var("x", shape=(1, 64, 56, 56))
+            weight = relay.var("weight", shape=(64, 64, 3, 3))
+            x = relay.layout_transform(x, "NCHW", "NHWC")
+            weight = relay.layout_transform(weight, "OIHW", "HWIO")
+            y = relay.nn.conv2d(
+                x,
+                weight,
+                channels=64,
+                kernel_size=(3, 3),
+                padding=(1, 1),
+                data_layout="NHWC",
+                kernel_layout="HWIO",
+            )
+            y = reduce_op(y, axis=[1, 2])
+            y = relay.Function(relay.analysis.free_vars(y), y)
+            return y
+
+        a = before()
+        a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NHWC", "default"]}))
+        b = run_opt_pass(expected(), transform.InferType())
+
+        assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
+
 if __name__ == "__main__":
-    test_qnn_binary_no_convert_layout()
-    test_no_convert_layout()
-    test_conv_convert_layout()
-    test_conv_nhwc_convert_layout()
-    test_conv_bias_pool_convert_layout()
-    test_conv_concat_convert_layout()
-    test_dual_path_convert_layout()
-    test_bn_convert_layout()
-    test_slice_like_convert_layout()
-    test_transpose_convert_layout()
-    test_resnet_convert_layout()
-    test_scalar_convert_layout()
-    test_conv_bn_convert_layout()
-    test_qnn_conv_requantize_convert_layout()
-    test_qnn_conv_concat_convert_layout()
-    test_qnn_conv_add_convert_layout()
-    test_qnn_conv_nhwc_convert_layout()
-    test_conv_convert_kernel_layout()
-    test_conv_transpose_convert_layout()
-    test_conv_roi_align_convert_layout()
-    test_conv_roi_pool_convert_layout()
-    test_conv_strided_slice_convert_layout()
-    test_deformable_conv_bias_pool_convert_layout()
-    test_default_keyword()
-    test_different_ops_convert_layout()
-    test_no_desired_layout()
-    test_convert_with_config()
-    test_conv_squeeze_convert_layout()
-    test_conv_reduce_convert_layout()
-    test_conv_strided_slice_axes_convert_layout()
-    test_image_resize_convert_layout()
-    test_conv_image_resize_convert_layout()
-    test_infer_correct_layout()
+    pytest.main([__file__])
