[Relay] Register layout conversion function to more reduce ops #9048

Merged · 4 commits · Sep 27, 2021
Changes from all commits
48 changes: 36 additions & 12 deletions src/relay/op/tensor/reduce.cc
@@ -116,13 +116,14 @@ Array<Integer> GetExcludeAxes(size_t indim, const Array<Integer>& inaxis) {
 }
 
 // Return the modified layout for AlterOpLayout pass.
+template <typename T>
 InferCorrectLayoutOutput ReduceInferCorrectLayout(const Attrs& attrs,
                                                   const Array<Layout>& new_in_layouts,
                                                   const Array<Layout>& old_in_layouts,
                                                   const Array<tvm::relay::Type>& old_in_types) {
-  const auto* attrs_ptr = attrs.as<ReduceAttrs>();
+  const auto* attrs_ptr = attrs.as<T>();
   ICHECK(attrs_ptr);
-  ObjectPtr<ReduceAttrs> params = make_object<ReduceAttrs>(*attrs_ptr);
+  ObjectPtr<T> params = make_object<T>(*attrs_ptr);
 
   // Get the reduce axes.
   Array<Array<IndexExpr>> old_in_shapes;
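Templating the helper over the attribute type is what lets the registrations below reuse it for ops whose attributes are not ReduceAttrs: argmax/argmin carry ArgReduceAttrs and variance carries VarianceAttrs. A minimal sketch of the effect on a build that includes this patch (variable names are illustrative, not part of the diff):

    import tvm
    from tvm import relay

    # After ConvertLayout rewrites conv2d to NHWC, the reduce axis of
    # argmax (an ArgReduceAttrs op) is remapped instead of forcing a
    # layout_transform back to NCHW before the reduction.
    x = relay.var("x", shape=(1, 64, 56, 56))
    w = relay.var("w", shape=(64, 64, 3, 3))
    y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1),
                        data_layout="NCHW", kernel_layout="OIHW")
    y = relay.argmax(y, axis=[1])  # channel axis in NCHW
    mod = tvm.IRModule.from_expr(relay.Function([x, w], y))
    seq = tvm.transform.Sequential(
        [relay.transform.InferType(),
         relay.transform.ConvertLayout({"nn.conv2d": ["NHWC", "default"]})])
    with tvm.transform.PassContext(opt_level=3):
        mod = seq(mod)  # argmax now reduces over axis 3 (channels in NHWC)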
@@ -152,11 +153,14 @@ InferCorrectLayoutOutput ReduceInferCorrectLayout(const Attrs& attrs,
     for (auto iter_var : layout->axes) {
       const auto& layout_axis = LayoutAxis::Get(iter_var);
       const std::string& layout_dim = layout_axis.name();
-      if (old_r_dims.count(layout_dim)) {
-        new_r_axes.push_back(tvm::Integer(axis_index));
-      }
       // Collect only the primal axis.
       if (layout_axis.IsPrimal()) {
+        if (old_r_dims.count(layout_dim) && !params->exclude) {
+          new_r_axes.push_back(tvm::Integer(axis_index));
+        }
+        if (!old_r_dims.count(layout_dim) && params->exclude) {
+          new_r_axes.push_back(tvm::Integer(axis_index));
+        }
         if (!old_r_dims.count(layout_dim) || params->keepdims) {
           inferred_out_string += layout_dim;
         }
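The two new branches encode the `exclude` semantics: with exclude=True the reduction runs over every primal axis except those listed in `axis`, so the remapped axes are exactly the ones not found in old_r_dims. A small NumPy illustration (mine, not part of the patch):

    import numpy as np

    # For NCHW data, mean(axis=[1], exclude=True) reduces N, H and W,
    # leaving one value per channel.
    x = np.random.rand(2, 4, 8, 8).astype("float32")
    per_channel = x.mean(axis=(0, 2, 3))
    assert per_channel.shape == (4,)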
@@ -171,26 +175,38 @@ InferCorrectLayoutOutput ReduceInferCorrectLayout(const Attrs& attrs,
 
   std::string new_layout_string;
   Array<Integer> new_r_axes;
+  Array<Layout> new_input_layouts;
+
+  auto check_num_input_layouts = [](Array<Layout> in_layouts) {
+    // The second case is for variance op
+    ICHECK(in_layouts.size() == 1 || in_layouts.size() == 2);
+  };
 
   if (new_in_layouts.defined() && r_axes.size()) {
     // Adapt to new layout. The axis has to change. Record original reduce axes. Convert to the
     // modified layout axes.
-    ICHECK_EQ(new_in_layouts.size(), 1);
-    ICHECK_EQ(old_in_layouts.size(), 1);
+    check_num_input_layouts(new_in_layouts);
+    check_num_input_layouts(old_in_layouts);
 
     // Get inferred_in and inferred_out from new_in_layout.
     std::tie(inferred_in, inferred_out, new_r_axes) = infer(new_in_layouts[0]);
     params->axis = new_r_axes;
   } else if (old_in_layouts.defined()) {
-    ICHECK_EQ(old_in_layouts.size(), 1);
+    check_num_input_layouts(old_in_layouts);
 
     // If the new layout is undefined, get inferred_in and inferred_out from old_in_layout.
     if (old_in_layouts[0].defined()) {
       std::tie(inferred_in, inferred_out, std::ignore) = infer(old_in_layouts[0]);
     }
   }
 
-  return InferCorrectLayoutOutput({inferred_in}, {inferred_out}, Attrs(params));
+  new_input_layouts.push_back(inferred_in);
+
+  if (old_in_layouts.size() == 2) {
+    new_input_layouts.push_back(inferred_in);
+  }
+
+  return InferCorrectLayoutOutput(new_input_layouts, {inferred_out}, Attrs(params));
 }
 
 template <typename F>
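The `size() == 2` case exists because the "variance" op takes two tensor inputs, data and mean (see the add_argument("mean", ...) registration further down), so layout inference receives, and must return, two input layouts; the patch reuses the inferred data layout for both. A sketch of how the two-input form arises from the Python API:

    from tvm import relay

    # relay.variance is built on an internal "variance" op that takes
    # the mean as an explicit second input, which is why
    # ReduceInferCorrectLayout can see two input layouts here.
    data = relay.var("data", shape=(1, 64, 56, 56))
    v = relay.variance(data, axis=[2, 3])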
@@ -389,6 +405,7 @@ values over a given axis.
     .set_support_level(4)
     .add_type_rel("ArgReduce", GenericReduceRel<ArgReduceAttrs>)
     .set_attr<FTVMCompute>("FTVMCompute", ArgMaxCompute)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ArgReduceAttrs>)
     .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 Array<te::Tensor> ArgMinCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
@@ -405,6 +422,7 @@ values over a given axis.
     .set_support_level(4)
     .add_type_rel("ArgReduce", GenericReduceRel<ArgReduceAttrs>)
     .set_attr<FTVMCompute>("FTVMCompute", ArgMinCompute)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ArgReduceAttrs>)
     .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 Array<te::Tensor> SumCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
@@ -433,7 +451,7 @@ Example::
     .set_attrs_type<ReduceAttrs>()
     .set_support_level(4)
     .add_type_rel("Reduce", ReduceRel)
-    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
     .set_attr<FTVMCompute>("FTVMCompute", SumCompute)
     .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
@@ -468,6 +486,7 @@ Example::
     .set_support_level(4)
     .add_type_rel("Reduce", ReduceRel)
     .set_attr<FTVMCompute>("FTVMCompute", AllCompute)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
     .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 Array<te::Tensor> AnyCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
@@ -516,6 +535,7 @@ RELAY_REGISTER_REDUCE_OP("max")
     .set_support_level(4)
     .add_type_rel("Reduce", ReduceRel)
     .set_attr<FTVMCompute>("FTVMCompute", MaxCompute)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
     .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 Array<te::Tensor> MinCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
@@ -531,6 +551,7 @@ RELAY_REGISTER_REDUCE_OP("min")
     .set_support_level(4)
     .add_type_rel("Reduce", ReduceRel)
     .set_attr<FTVMCompute>("FTVMCompute", MinCompute)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
     .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 Array<te::Tensor> ProdCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
@@ -551,17 +572,18 @@ Example::
           [[1,4],[4,3],[5,2]],
           [[7,1],[7,2],[7,3]]]
 
-  mean(data, axis=1)
+  prod(data, axis=1)
   [35562240]
 
-  mean(data, axis=[1,2])
+  prod(data, axis=[1,2])
   [ 36 480 2058]
 
 )code" TVM_ADD_FILELINE)
     .set_attrs_type<ReduceAttrs>()
     .set_support_level(4)
     .add_type_rel("Reduce", ReduceRel)
     .set_attr<FTVMCompute>("FTVMCompute", ProdCompute)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
     .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 Array<te::Tensor> MeanCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
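The docstring fix swaps `mean` for `prod` in the examples; a quick NumPy check of the two data rows visible in this hunk (a verification sketch, not part of the diff):

    import numpy as np

    rows = np.array([[[1, 4], [4, 3], [5, 2]],
                     [[7, 1], [7, 2], [7, 3]]])
    print(rows.prod(axis=(1, 2)))  # [ 480 2058], matching the doc values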
@@ -600,6 +622,7 @@ Example::
     .set_support_level(4)
     .add_type_rel("Reduce", ReduceRel)
     .set_attr<FTVMCompute>("FTVMCompute", MeanCompute)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
     .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 bool VarianceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
@@ -675,6 +698,7 @@ RELAY_REGISTER_OP("variance")
     .add_argument("mean", "Tensor", "The mean tensor.")
     .add_type_rel("Variance", VarianceRel)
     .set_attr<FTVMCompute>("FTVMCompute", VarianceCompute)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<VarianceAttrs>)
     .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 }  // namespace relay
8 changes: 3 additions & 5 deletions tests/python/relay/test_pass_alter_op_layout.py
@@ -486,8 +486,7 @@ def before():
         beta = relay.var("beta")
         y = relay.nn.batch_norm(y, gamma, beta, mean, var, axis=3)
         y = y[0]
-        y = relay.Function(analysis.free_vars(y), y)
-        return y
+        return relay.Function(analysis.free_vars(y), y)
 
     def alter_conv2d(attrs, inputs, tinfos, out_type):
         data, weight = inputs
@@ -509,9 +508,8 @@ def expected():
         bias = relay.layout_transform(bias, src_layout="NCHW", dst_layout="NCHW16c")
         add = relay.add(y, bias)
         y = relay.layout_transform(add, src_layout="NCHW16c", dst_layout="NCHW")
-        y = relay.layout_transform(y, src_layout="NCHW", dst_layout="NHWC")
-        mean = relay.mean(y, axis=3, exclude=True)
-        var = relay.variance(y, axis=3, exclude=True)
+        mean = relay.mean(y, axis=1, exclude=True)
+        var = relay.variance(y, axis=1, exclude=True)
         denom = relay.const(1.0) / relay.sqrt(var + relay.const(1e-05))
         gamma = relay.var("gamma", shape=(16,))
         denom = denom * gamma
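The expected IR no longer needs the NCHW-to-NHWC transform before the reduction: reducing NHWC input with axis=3, exclude=True and NCHW input with axis=1, exclude=True produce the same per-channel statistic. A NumPy sanity check (illustrative only, not part of the test):

    import numpy as np

    nchw = np.random.rand(1, 16, 8, 8).astype("float32")
    nhwc = nchw.transpose(0, 2, 3, 1)  # NCHW -> NHWC
    assert np.allclose(nchw.mean(axis=(0, 2, 3)), nhwc.mean(axis=(0, 1, 2)))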
80 changes: 47 additions & 33 deletions tests/python/relay/test_pass_convert_op_layout.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 """Test alter op layout pass"""
+import pytest
+
 import tvm
 from tvm import te
 
@@ -1925,37 +1927,49 @@ def infer_correct_layout_relu(attrs, new_in_layouts, old_in_layouts, old_in_types):
 
     assert test_infer_correct_layout_flag == True
 
 
+def test_reduce_op_convert_layout():
+    for reduce_op in [relay.argmax, relay.mean, relay.max]:
+
+        def before():
+            x = relay.var("x", shape=(1, 64, 56, 56))
+            weight = relay.var("weight", shape=(64, 64, 3, 3))
+            y = relay.nn.conv2d(
+                x,
+                weight,
+                channels=64,
+                kernel_size=(3, 3),
+                padding=(1, 1),
+                data_layout="NCHW",
+                kernel_layout="OIHW",
+            )
+            y = reduce_op(y, axis=[2, 3])
+            y = relay.Function([x, weight], y)
+            return y
+
+        def expected():
+            x = relay.var("x", shape=(1, 64, 56, 56))
+            weight = relay.var("weight", shape=(64, 64, 3, 3))
+            x = relay.layout_transform(x, "NCHW", "NHWC")
+            weight = relay.layout_transform(weight, "OIHW", "HWIO")
+            y = relay.nn.conv2d(
+                x,
+                weight,
+                channels=64,
+                kernel_size=(3, 3),
+                padding=(1, 1),
+                data_layout="NHWC",
+                kernel_layout="HWIO",
+            )
+            y = reduce_op(y, axis=[1, 2])
+            y = relay.Function(relay.analysis.free_vars(y), y)
+            return y
+
+        a = before()
+        a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NHWC", "default"]}))
+        b = run_opt_pass(expected(), transform.InferType())
+
+        assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
+
 if __name__ == "__main__":
-    test_qnn_binary_no_convert_layout()
-    test_no_convert_layout()
-    test_conv_convert_layout()
-    test_conv_nhwc_convert_layout()
-    test_conv_bias_pool_convert_layout()
-    test_conv_concat_convert_layout()
-    test_dual_path_convert_layout()
-    test_bn_convert_layout()
-    test_slice_like_convert_layout()
-    test_transpose_convert_layout()
-    test_resnet_convert_layout()
-    test_scalar_convert_layout()
-    test_conv_bn_convert_layout()
-    test_qnn_conv_requantize_convert_layout()
-    test_qnn_conv_concat_convert_layout()
-    test_qnn_conv_add_convert_layout()
-    test_qnn_conv_nhwc_convert_layout()
-    test_conv_convert_kernel_layout()
-    test_conv_transpose_convert_layout()
-    test_conv_roi_align_convert_layout()
-    test_conv_roi_pool_convert_layout()
-    test_conv_strided_slice_convert_layout()
-    test_deformable_conv_bias_pool_convert_layout()
-    test_default_keyword()
-    test_different_ops_convert_layout()
-    test_no_desired_layout()
-    test_convert_with_config()
-    test_conv_squeeze_convert_layout()
-    test_conv_reduce_convert_layout()
-    test_conv_strided_slice_axes_convert_layout()
-    test_image_resize_convert_layout()
-    test_conv_image_resize_convert_layout()
-    test_infer_correct_layout()
+    pytest.main([__file__])