fuse constant padding into conv kernels (#7515)
* fuse constant padding into conv kernels

* change the kernel to support other layouts

* add channel-last test

* add a comment about bailing early
Matthew Brookhart authored Mar 2, 2021
1 parent a1d43c1 commit 633ee11
Showing 2 changed files with 194 additions and 0 deletions.
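In effect, the new rewrite replaces a zero-valued constant nn.pad feeding nn.conv1d/nn.conv2d/nn.conv3d with a single convolution whose padding attribute absorbs the pad widths. A minimal sketch of that behavior, modeled on the new test below (the shapes and pad widths are illustrative only, not taken from the commit):

import tvm
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import run_opt_pass

x = relay.var("x", shape=(1, 3, 10, 10), dtype="float32")
w = relay.var("w", shape=(8, 3, 3, 3), dtype="float32")
# Zero-valued constant pad on the two spatial axes of an NCHW tensor.
padded = relay.nn.pad(x, ((0, 0), (0, 0), (1, 1), (1, 1)), pad_value=0.0, pad_mode="constant")
before = relay.nn.conv2d(padded, w, padding=(0, 0))

# After SimplifyExpr the pad is gone and its widths appear in the conv's padding attribute.
simplified = run_opt_pass(before, transform.SimplifyExpr())
expected = run_opt_pass(relay.nn.conv2d(x, w, padding=(1, 1, 1, 1)), transform.InferType())
assert tvm.ir.structural_equal(simplified, expected)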
116 changes: 116 additions & 0 deletions src/relay/transforms/simplify_expr.cc
@@ -82,6 +82,121 @@ class SimplifyReshape : public SimplifyPattern {
  DFPattern x_;
};

/*!
 * \brief SimplifyConvPad matches a pad followed by a conv/convtranspose/pool/etc
 * with a pad attribute and merges the padding into the kernel.
 */
class SimplifyConvPad : public SimplifyPattern {
 public:
  SimplifyConvPad() {
    x_ = IsWildcard();
    w_ = IsWildcard();
    pad_ = IsOp("nn.pad")({x_});
    conv1d_ = IsOp("nn.conv1d");
    conv2d_ = IsOp("nn.conv2d");
    conv3d_ = IsOp("nn.conv3d");
    conv_ = (conv1d_ || conv2d_ || conv3d_)({pad_, w_});
    pattern_ = conv_;
  }
  template <typename T>
  Attrs MakeConvAttrs(const T* old_attrs, const Array<PrimExpr> padding) const {
    ICHECK(old_attrs);
    ICHECK(padding.size() == old_attrs->padding.size())
        << "Number of dimensions to pad and convolution padding attributes should have the same "
           "extent";

    auto new_attrs = make_object<T>();
    Array<PrimExpr> combined_padding;
    for (size_t i = 0; i < padding.size(); ++i) {
      combined_padding.push_back(padding[i] + old_attrs->padding[i]);
    }
    new_attrs->strides = old_attrs->strides;
    new_attrs->padding = combined_padding;
    new_attrs->dilation = old_attrs->dilation;
    new_attrs->groups = old_attrs->groups;
    new_attrs->channels = old_attrs->channels;
    new_attrs->kernel_size = old_attrs->kernel_size;
    new_attrs->data_layout = old_attrs->data_layout;
    new_attrs->kernel_layout = old_attrs->kernel_layout;
    new_attrs->out_layout = old_attrs->out_layout;
    new_attrs->out_dtype = old_attrs->out_dtype;
    return Attrs(new_attrs);
  }
  template <typename T>
  Attrs GetAttrs(const PadAttrs* param, const T* attrs) const {
    ICHECK(param);
    ICHECK(attrs);
    ICHECK(attrs->data_layout.size() == param->pad_width.size())
        << "Data Layout and padding attributes should have the same extent";

    std::string data_layout = attrs->data_layout;
    std::set<char> image_dims({'H', 'W', 'D'});
    Array<PrimExpr> padding;
    // If we're padding a non-spatial dimension, don't simplify
    // Convolution can only pad on spatial axes
    for (size_t i = 0; i < param->pad_width.size(); ++i) {
      if (!image_dims.count(data_layout[i])) {
        for (size_t j = 0; j < param->pad_width[i].size(); ++j) {
          if (param->pad_width[i][j] != 0) {
            return Attrs();
          }
        }
      }
    }
    for (size_t j = 0; j < param->pad_width[0].size(); ++j) {
      for (size_t i = 0; i < param->pad_width.size(); ++i) {
        if (image_dims.count(data_layout[i])) {
          padding.push_back(param->pad_width[i][j]);
        }
      }
    }

    return MakeConvAttrs(attrs, padding);
  }
  Expr callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    const CallNode* call_node = post.as<CallNode>();
    ICHECK(call_node);
    auto pad = node_map[pad_][0];
    const CallNode* pad_node = pad.as<CallNode>();
    ICHECK(pad_node);
    const PadAttrs* param = pad_node->attrs.as<PadAttrs>();
    ICHECK(param);
    if (param->pad_mode == "constant" && param->pad_value == 0.0) {
      Attrs attrs;
      if (node_map.count(conv1d_)) {
        attrs = GetAttrs(param, call_node->attrs.as<Conv1DAttrs>());
      } else if (node_map.count(conv2d_)) {
        attrs = GetAttrs(param, call_node->attrs.as<Conv2DAttrs>());
      } else if (node_map.count(conv3d_)) {
        attrs = GetAttrs(param, call_node->attrs.as<Conv3DAttrs>());
      } else {
        return post;
      }
      if (!attrs.defined()) {
        return post;
      }
      auto x = node_map[x_][0];
      auto w = node_map[w_][0];
      return Call(call_node->op, {x, w}, attrs, call_node->type_args, call_node->span);
    }
    return post;
  }

 private:
  /*! \brief Pattern input */
  DFPattern x_;
  /*! \brief Pattern input weight */
  DFPattern w_;
  /*! \brief Pattern pad */
  DFPattern pad_;
  /*! \brief Pattern conv */
  DFPattern conv_;
  DFPattern conv1d_;
  DFPattern conv2d_;
  DFPattern conv3d_;
};

/*!
 * \brief FullArgwhere finds full followed by argwhere and turns it into an Arange op
 */
@@ -163,6 +278,7 @@ class ExprSimplifier {
  explicit ExprSimplifier(IRModule mod) : mod_(mod) {
    CreateCallback(SimplifyReshape());
    CreateCallback(FullElementwise());
    CreateCallback(SimplifyConvPad());
  }
  template <typename T>
  void CreateCallback(const T& pattern) {
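GetAttrs above keeps only the spatial ('D', 'H', 'W') entries of nn.pad's per-dimension [before, after] pairs and flattens them into the [befores..., afters...] order used by the convolution padding attribute; MakeConvAttrs then adds those values element-wise to the conv's existing padding. A small Python sketch of that reindexing (pad_width_to_conv_padding is a hypothetical helper written only for illustration, not part of the commit):

def pad_width_to_conv_padding(pad_width, data_layout):
    # Keep spatial dims only and emit all "before" pads, then all "after" pads.
    spatial = {"D", "H", "W"}
    padding = []
    for j in range(2):  # column 0 = pad-before, column 1 = pad-after
        for dim, pair in zip(data_layout, pad_width):
            if dim in spatial:
                padding.append(pair[j])
    return padding


# NCHW input, H padded by (1, 2) and W padded by (3, 4) -> conv padding [1, 3, 2, 4].
assert pad_width_to_conv_padding([[0, 0], [0, 0], [1, 2], [3, 4]], "NCHW") == [1, 3, 2, 4]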
78 changes: 78 additions & 0 deletions tests/python/relay/test_pass_simplify_expr.py
@@ -19,6 +19,8 @@
from tvm.relay import transform
from tvm.relay.testing import run_opt_pass

import numpy as np


def test_simplify_reshape():
    def before():
@@ -122,6 +124,82 @@ def after_right(x, elem_op, value):
                validate(shape, value, dtype)


def test_simplify_conv_pad():
    convs = [relay.nn.conv1d, relay.nn.conv2d, relay.nn.conv3d]

    def validate(ndim, pad_width, pad_value, pad_mode, orig_padding, layout):
        if layout[1] == "C":
            shape = [1, 3] + [10] * ndim
            wshape = [8, 3] + [3] * ndim
        elif layout[-1] == "C":
            shape = [1] + [10] * ndim + [3]
            wshape = [8] + [3] * ndim + [3]
        else:
            raise ValueError("This test only supports NC* and N*C")

        x = relay.var("x", shape=shape, dtype="float32")
        w = relay.var("w", shape=wshape, dtype="float32")
        pad = relay.nn.pad(x, pad_width, pad_value, pad_mode)
        if layout[1] == "C":
            conv = convs[ndim - 1](pad, w, padding=orig_padding)
        else:
            conv = convs[ndim - 1](
                pad, w, padding=orig_padding, data_layout=layout, kernel_layout="DHWIO"[3 - ndim :]
            )

        if pad_mode == "constant" and pad_value == 0:
            new_padding = []
            for j in range(2):
                for i in range(len(pad_width)):
                    if layout[i] in ["D", "H", "W"]:
                        new_padding.append(pad_width[i][j])
            for i in range(len(new_padding)):
                new_padding[i] += orig_padding[i]
            if layout[1] == "C":
                after = convs[ndim - 1](x, w, padding=new_padding)
            else:
                after = convs[ndim - 1](
                    x, w, padding=new_padding, data_layout=layout, kernel_layout="DHWIO"[3 - ndim :]
                )
        else:
            after = conv

        zz = run_opt_pass(conv, transform.SimplifyExpr())
        expected = run_opt_pass(after, transform.InferType())
        assert tvm.ir.structural_equal(zz, expected)

        mod1 = tvm.IRModule.from_expr(conv)
        mod2 = tvm.IRModule.from_expr(zz)

        with tvm.transform.PassContext(disabled_pass="SimplifyExpr"):
            ex1 = relay.create_executor("vm", mod=mod1, ctx=tvm.cpu(), target="llvm")
            ex2 = relay.create_executor("vm", mod=mod2, ctx=tvm.cpu(), target="llvm")
            x_np = np.random.rand(*shape).astype("float32")
            w_np = np.random.rand(*wshape).astype("float32")
            result1 = ex1.evaluate()(x_np, w_np)
            result2 = ex2.evaluate()(x_np, w_np)

            tvm.testing.assert_allclose(result1.asnumpy(), result2.asnumpy())

    for orig_pad in [[0, 0], [2, 0], [0, 2]]:
        for i_pad in [[0, 0], [1, 1], [1, 0]]:
            for ndim in [1, 2, 3]:
                for channels_last in [0, 1]:
                    if channels_last:
                        layout = "NDHWC"
                        layout = layout[0:1] + layout[4 - ndim : 4] + layout[-1:]
                        padding = [[0, 0]] + [i_pad] * ndim + [[0, 0]]
                    else:
                        layout = "NCDHW"
                        layout = layout[0:2] + layout[5 - ndim :]
                        padding = [[0, 0]] * 2 + [i_pad] * ndim

                    validate(ndim, padding, 0, "constant", orig_pad * ndim, layout)
    ndim = 2
    validate(ndim, [[0, 0]] * 2 + [i_pad] * ndim, 1, "constant", orig_pad * ndim, "NCHW")
    validate(ndim, [[0, 0]] * 2 + [i_pad] * ndim, 0, "edge", orig_pad * ndim, "NCHW")


if __name__ == "__main__":
    test_simplify_reshape()
    test_simplify_full_elementwise()
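The final two validate calls above cover the bail-out path: the callback only fires for a constant pad with value 0.0, and GetAttrs additionally refuses to fold pads that touch a non-spatial axis, so anything else is left untouched. A minimal sketch of that case, again with illustrative shapes:

import tvm
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import run_opt_pass

x = relay.var("x", shape=(1, 3, 10, 10), dtype="float32")
w = relay.var("w", shape=(8, 3, 3, 3), dtype="float32")
# Non-zero pad value: folding it into the conv would change the result, so the pass must bail.
padded = relay.nn.pad(x, ((0, 0), (0, 0), (1, 1), (1, 1)), pad_value=1.0)
conv = relay.nn.conv2d(padded, w, padding=(0, 0))

simplified = run_opt_pass(conv, transform.SimplifyExpr())
unchanged = run_opt_pass(conv, transform.InferType())
assert tvm.ir.structural_equal(simplified, unchanged)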
