Skip to content

Commit

Permalink
[Op] Do not override specified layout in pooling (2nd PR) (#9328)
Browse files Browse the repository at this point in the history
* [Op] Do not override specified layout in pooling (2nd PR)

* [Op] Do not override specified layout in pooling (2nd PR)

* [Op] Do not override specified layout in pooling (2nd PR)

* [Op] Do not override specified layout in pooling (2nd PR)
  • Loading branch information
ccjoechou authored Oct 21, 2021
1 parent e62075d commit d11bdcd
Show file tree
Hide file tree
Showing 9 changed files with 675 additions and 75 deletions.
78 changes: 78 additions & 0 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,7 @@ struct MaxPool2DAttrs : public tvm::AttrsNode<MaxPool2DAttrs> {
Array<IndexExpr> padding;
Array<IndexExpr> dilation;
tvm::String layout;
tvm::String out_layout;
bool ceil_mode;

TVM_DECLARE_ATTRS(MaxPool2DAttrs, "relay.attrs.MaxPool2DAttrs") {
Expand All @@ -709,6 +710,13 @@ struct MaxPool2DAttrs : public tvm::AttrsNode<MaxPool2DAttrs> {
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Pooling is applied on the 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(out_layout)
.set_default("")
.describe(
"Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Pooling is applied on the 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(ceil_mode).set_default(false).describe(
"When true, will use ceil instead of floor to compute the output shape.");
}
Expand All @@ -721,6 +729,7 @@ struct AvgPool2DAttrs : public tvm::AttrsNode<AvgPool2DAttrs> {
Array<IndexExpr> padding;
Array<IndexExpr> dilation;
tvm::String layout;
tvm::String out_layout;
bool ceil_mode;
bool count_include_pad;

Expand All @@ -745,6 +754,13 @@ struct AvgPool2DAttrs : public tvm::AttrsNode<AvgPool2DAttrs> {
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Pooling is applied on the 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(out_layout)
.set_default("")
.describe(
"Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Pooling is applied on the 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(ceil_mode).set_default(false).describe(
"When true, will use ceil instead of floor to compute the output shape.");
TVM_ATTR_FIELD(count_include_pad)
Expand All @@ -756,20 +772,29 @@ struct AvgPool2DAttrs : public tvm::AttrsNode<AvgPool2DAttrs> {
/*! \brief Attributes for global pool operator */
struct GlobalPool2DAttrs : public tvm::AttrsNode<GlobalPool2DAttrs> {
tvm::String layout;
tvm::String out_layout;

TVM_DECLARE_ATTRS(GlobalPool2DAttrs, "relay.attrs.GlobalPool2DAttrs") {
TVM_ATTR_FIELD(layout).set_default("NCHW").describe(
"Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Pooling is applied on the 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(out_layout)
.set_default("")
.describe(
"Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Pooling is applied on the 'H' and"
"'W' dimensions.");
}
};

/*! \brief Attributes for 1d adaptive pool operator */
struct AdaptivePool1DAttrs : public tvm::AttrsNode<AdaptivePool1DAttrs> {
Array<IndexExpr> output_size;
std::string layout;
tvm::String out_layout;

TVM_DECLARE_ATTRS(AdaptivePool1DAttrs, "relay.attrs.AdaptivePool1DAttrs") {
TVM_ATTR_FIELD(output_size).set_default(Array<IndexExpr>({})).describe("Output width.");
Expand All @@ -778,13 +803,21 @@ struct AdaptivePool1DAttrs : public tvm::AttrsNode<AdaptivePool1DAttrs> {
"'N', 'C', 'W' stands for batch, channel, and width"
"dimensions respectively. Pooling is applied on the"
"'W' dimension.");
TVM_ATTR_FIELD(out_layout)
.set_default("")
.describe(
"Dimension ordering of output data. Can be 'NCW', 'NWC', etc."
"'N', 'C', 'W' stands for batch, channel, and width"
"dimensions respectively. Pooling is applied on the"
"'W' dimension.");
}
};

/*! \brief Attributes for 2d adaptive pool operator */
struct AdaptivePool2DAttrs : public tvm::AttrsNode<AdaptivePool2DAttrs> {
Array<IndexExpr> output_size;
std::string layout;
tvm::String out_layout;

TVM_DECLARE_ATTRS(AdaptivePool2DAttrs, "relay.attrs.AdaptivePool2DAttrs") {
TVM_ATTR_FIELD(output_size)
Expand All @@ -795,13 +828,21 @@ struct AdaptivePool2DAttrs : public tvm::AttrsNode<AdaptivePool2DAttrs> {
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Pooling is applied on the 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(out_layout)
.set_default("")
.describe(
"Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Pooling is applied on the 'H' and"
"'W' dimensions.");
}
};

/*! \brief Attributes for 3d adaptive pool operator */
struct AdaptivePool3DAttrs : public tvm::AttrsNode<AdaptivePool3DAttrs> {
Array<IndexExpr> output_size;
std::string layout;
tvm::String out_layout;

TVM_DECLARE_ATTRS(AdaptivePool3DAttrs, "relay.attrs.AdaptivePool3DAttrs") {
TVM_ATTR_FIELD(output_size)
Expand All @@ -812,6 +853,13 @@ struct AdaptivePool3DAttrs : public tvm::AttrsNode<AdaptivePool3DAttrs> {
"'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
"dimensions respectively. Pooling is applied on 'D', 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(out_layout)
.set_default("")
.describe(
"Dimension ordering of output data. Can be 'NCDHW', 'NDHWC', etc."
"'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
"dimensions respectively. Pooling is applied on 'D', 'H' and"
"'W' dimensions.");
}
};

Expand All @@ -822,6 +870,7 @@ struct MaxPool1DAttrs : public tvm::AttrsNode<MaxPool1DAttrs> {
Array<IndexExpr> dilation;
Array<IndexExpr> padding;
std::string layout;
tvm::String out_layout;
bool ceil_mode;

TVM_DECLARE_ATTRS(MaxPool1DAttrs, "relay.attrs.MaxPool1DAttrs") {
Expand All @@ -844,6 +893,12 @@ struct MaxPool1DAttrs : public tvm::AttrsNode<MaxPool1DAttrs> {
"Dimension ordering of input data. Can be 'NCW', 'NWC', etc."
"'N', 'C', 'W' stands for batch, channel, and width"
"dimensions respectively. Pooling is applied on the 'W' dimensions.");
TVM_ATTR_FIELD(out_layout)
.set_default("")
.describe(
"Dimension ordering of output data. Can be 'NCW', 'NWC', etc."
"'N', 'C', 'W' stands for batch, channel, and width"
"dimensions respectively. Pooling is applied on the 'W' dimensions.");
TVM_ATTR_FIELD(ceil_mode).set_default(false).describe(
"When true, will use ceil instead of floor to compute the output shape.");
}
Expand All @@ -856,6 +911,7 @@ struct AvgPool1DAttrs : public tvm::AttrsNode<AvgPool1DAttrs> {
Array<IndexExpr> dilation;
Array<IndexExpr> padding;
std::string layout;
tvm::String out_layout;
bool ceil_mode;
bool count_include_pad;

Expand All @@ -879,6 +935,12 @@ struct AvgPool1DAttrs : public tvm::AttrsNode<AvgPool1DAttrs> {
"Dimension ordering of input data. Can be 'NCW', 'NHC', etc."
"'N', 'C', 'W' stands for batch, channel, and width"
"dimensions respectively. Pooling is applied on the 'W' dimension.");
TVM_ATTR_FIELD(out_layout)
.set_default("")
.describe(
"Dimension ordering of output data. Can be 'NCW', 'NHC', etc."
"'N', 'C', 'W' stands for batch, channel, and width"
"dimensions respectively. Pooling is applied on the 'W' dimension.");
TVM_ATTR_FIELD(ceil_mode).set_default(false).describe(
"When true, will use ceil instead of floor to compute the output shape.");
TVM_ATTR_FIELD(count_include_pad)
Expand All @@ -894,6 +956,7 @@ struct MaxPool3DAttrs : public tvm::AttrsNode<MaxPool3DAttrs> {
Array<IndexExpr> dilation;
Array<IndexExpr> padding;
std::string layout;
tvm::String out_layout;
bool ceil_mode;

TVM_DECLARE_ATTRS(MaxPool3DAttrs, "relay.attrs.MaxPool3DAttrs") {
Expand All @@ -917,6 +980,13 @@ struct MaxPool3DAttrs : public tvm::AttrsNode<MaxPool3DAttrs> {
"'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
"dimensions respectively. Pooling is applied on the 'D', 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(out_layout)
.set_default("")
.describe(
"Dimension ordering of output data. Can be 'NCDHW', 'NDHWC', etc."
"'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
"dimensions respectively. Pooling is applied on the 'D', 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(ceil_mode).set_default(false).describe(
"When true, will use ceil instead of floor to compute the output shape.");
}
Expand All @@ -929,6 +999,7 @@ struct AvgPool3DAttrs : public tvm::AttrsNode<AvgPool3DAttrs> {
Array<IndexExpr> dilation;
Array<IndexExpr> padding;
std::string layout;
tvm::String out_layout;
bool ceil_mode;
bool count_include_pad;

Expand All @@ -953,6 +1024,13 @@ struct AvgPool3DAttrs : public tvm::AttrsNode<AvgPool3DAttrs> {
"'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
"dimensions respectively. Pooling is applied on the 'D', 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(out_layout)
.set_default("")
.describe(
"Dimension ordering of output data. Can be 'NCDHW', 'NDHWC', etc."
"'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
"dimensions respectively. Pooling is applied on the 'D', 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(ceil_mode).set_default(false).describe(
"When true, will use ceil instead of floor to compute the output shape.");
TVM_ATTR_FIELD(count_include_pad)
Expand Down
110 changes: 97 additions & 13 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""Backend compiler related feature registration"""
from __future__ import absolute_import

from tvm import topi
from tvm import topi, relay
from tvm.topi.utils import get_const_tuple

from tvm.runtime import convert
Expand Down Expand Up @@ -267,9 +267,6 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layouts):
result : tvm.relay.Expr
The transformed expr
"""
# pylint: disable=import-outside-toplevel
from tvm import relay

data, weight = inputs

# First check if there is a LayoutConfig scope, and if so, whether
Expand Down Expand Up @@ -363,9 +360,6 @@ def convert_conv2d_transpose(attrs, inputs, tinfos, desired_layouts):
result : tvm.relay.Expr
The transformed expr
"""
# pylint: disable=import-outside-toplevel
from tvm import relay

data, weight = inputs
new_attrs = dict(attrs)
assert len(desired_layouts) == 2, "A desired layout is expected for both of nn.conv2d's inputs"
Expand Down Expand Up @@ -446,9 +440,6 @@ def convert_conv3d(attrs, inputs, tinfos, desired_layouts):
result : tvm.relay.Expr
The transformed expr
"""
# pylint: disable=import-outside-toplevel
from tvm import relay

data, weight = inputs
new_attrs = dict(attrs)
assert len(desired_layouts) == 2, "A desired layout is expected for both of nn.conv3d's inputs"
Expand Down Expand Up @@ -515,6 +506,30 @@ def compute_contrib_conv3d_winograd_weight_transform(attrs, inputs, out_dtype):
reg.register_pattern("nn.max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_convert_op_layout("nn.max_pool2d")
def convert_max_pool2d(attrs, inputs, tinfos, desired_layouts):
"""Convert Layout pass registration for max_pool2d op.
Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current pooling
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
tinfos : list of types
List of input and output types
desired_layouts : list of one layout string
layout string defining our desired layout for input and output.
Returns
-------
result : tvm.relay.Expr
The transformed expr
"""
new_attrs = dict(attrs)
new_attrs["layout"] = str(desired_layouts[0])
new_attrs["out_layout"] = str(desired_layouts[0])
return relay.nn.max_pool2d(*inputs, **new_attrs)


# max_pool3d
reg.register_schedule("nn.max_pool3d", strategy.schedule_pool)
reg.register_pattern("nn.max_pool3d", OpPattern.OUT_ELEMWISE_FUSABLE)
Expand All @@ -530,6 +545,30 @@ def compute_contrib_conv3d_winograd_weight_transform(attrs, inputs, out_dtype):
reg.register_pattern("nn.avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_convert_op_layout("nn.avg_pool2d")
def convert_avg_pool2d(attrs, inputs, tinfos, desired_layouts):
"""Convert Layout pass registration for avg_pool2d op.
Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current pooling
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
tinfos : list of types
List of input and output types
desired_layouts : list of one layout string
layout string defining our desired layout for input and output.
Returns
-------
result : tvm.relay.Expr
The transformed expr
"""
new_attrs = dict(attrs)
new_attrs["layout"] = str(desired_layouts[0])
new_attrs["out_layout"] = str(desired_layouts[0])
return relay.nn.avg_pool2d(*inputs, **new_attrs)


# avg_pool3d
reg.register_schedule("nn.avg_pool3d", strategy.schedule_pool)
reg.register_pattern("nn.avg_pool3d", OpPattern.OUT_ELEMWISE_FUSABLE)
Expand Down Expand Up @@ -560,11 +599,59 @@ def compute_contrib_conv3d_winograd_weight_transform(attrs, inputs, out_dtype):
reg.register_pattern("nn.global_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_convert_op_layout("nn.global_max_pool2d")
def convert_global_max_pool2d(attrs, inputs, tinfos, desired_layouts):
"""Convert Layout pass registration for global_max_pool2d op.
Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current pooling
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
tinfos : list of types
List of input and output types
desired_layouts : list of one layout string
layout string defining our desired layout for input and output.
Returns
-------
result : tvm.relay.Expr
The transformed expr
"""
new_attrs = dict(attrs)
new_attrs["layout"] = str(desired_layouts[0])
new_attrs["out_layout"] = str(desired_layouts[0])
return relay.nn.global_max_pool2d(*inputs, **new_attrs)


# global_avg_pool2d
reg.register_schedule("nn.global_avg_pool2d", strategy.schedule_adaptive_pool)
reg.register_pattern("nn.global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_convert_op_layout("nn.global_avg_pool2d")
def convert_global_avg_pool2d(attrs, inputs, tinfos, desired_layouts):
"""Convert Layout pass registration for global_avg_pool2d op.
Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current pooling
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
tinfos : list of types
List of input and output types
desired_layouts : list of one layout string
layout string defining our desired layout for input and output.
Returns
-------
result : tvm.relay.Expr
The transformed expr
"""
new_attrs = dict(attrs)
new_attrs["layout"] = str(desired_layouts[0])
new_attrs["out_layout"] = str(desired_layouts[0])
return relay.nn.global_avg_pool2d(*inputs, **new_attrs)


# adaptive_max_pool2d
reg.register_schedule("nn.adaptive_max_pool2d", strategy.schedule_adaptive_pool)
reg.register_pattern("nn.adaptive_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
Expand Down Expand Up @@ -796,9 +883,6 @@ def convert_deformable_conv2d(attrs, inputs, tinfos, desired_layouts):
result : tvm.relay.Expr
The transformed expr
"""
# pylint: disable=import-outside-toplevel
from tvm import relay

data, offset, weight = inputs
new_attrs = dict(attrs)
for attr in new_attrs:
Expand Down
Loading

0 comments on commit d11bdcd

Please sign in to comment.