int32 pooling with int64 shapes (apache#6687)
* Failing tests for Int32 avg_pooling with Int64 shapes

* Fix pooling implementations
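
For reference, a minimal reproduction sketch of the failure mode the new tests cover: int32 data whose shape dimensions are expressed as int64 IntImms, run through avg_pool2d. The shape, pool parameters, and target below are illustrative assumptions, not taken from the original report.

import numpy as np
import tvm
from tvm import relay

# Shape dims as int64 IntImms with int32 data -- the combination the new tests
# exercise; before this fix, the TOPI pooling compute mixed these int64 dims
# with the int32 kernel/stride/pad expressions it builds internally.
dshape = (1, 3, 28, 28)
x = relay.var("x", shape=[tvm.tir.IntImm("int64", d) for d in dshape], dtype="int32")
y = relay.nn.avg_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
func = relay.Function([x], y)

data = np.random.randint(low=-128, high=128, size=dshape).astype("int32")
intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
out = intrp.evaluate(func)(data)
print(out.asnumpy().shape)  # (1, 3, 14, 14) once the shape casts below are in place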
Matthew Brookhart authored and Trevor Morris committed Dec 2, 2020
1 parent a9213c7 commit 026c06a
Showing 4 changed files with 133 additions and 93 deletions.
42 changes: 29 additions & 13 deletions include/tvm/topi/nn/pooling.h
@@ -75,8 +75,8 @@ inline Tensor pool_impl(const Tensor& x, const Array<PrimExpr>& kernel_size,
auto stride_height = cast(DataType::DataType::Int(32), stride_size[0]);
auto stride_width = cast(DataType::DataType::Int(32), stride_size[1]);

auto height = x->shape[height_axis];
auto width = x->shape[width_axis];
auto height = cast(DataType::DataType::Int(32), x->shape[height_axis]);
auto width = cast(DataType::DataType::Int(32), x->shape[width_axis]);

auto pad_top = cast(DataType::DataType::Int(32), padding_size[0]);
auto pad_left = cast(DataType::DataType::Int(32), padding_size[1]);
@@ -107,6 +107,9 @@ inline Tensor pool_impl(const Tensor& x, const Array<PrimExpr>& kernel_size,
auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width));

Array<PrimExpr> out_shape = x->shape;
for (size_t i = 0; i < out_shape.size(); ++i) {
out_shape.Set(i, cast(DataType::DataType::Int(32), out_shape[i]));
}
out_shape.Set(height_axis, out_height);
out_shape.Set(width_axis, out_width);

@@ -189,8 +192,8 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
auto stride_height = cast(DataType::DataType::Int(32), stride_size[0]);
auto stride_width = cast(DataType::DataType::Int(32), stride_size[1]);

auto height = x->shape[height_axis];
auto width = x->shape[width_axis];
auto height = cast(DataType::DataType::Int(32), x->shape[height_axis]);
auto width = cast(DataType::DataType::Int(32), x->shape[width_axis]);

auto pad_top = cast(DataType::DataType::Int(32), padding_size[0]);
auto pad_left = cast(DataType::DataType::Int(32), padding_size[1]);
@@ -220,7 +223,12 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
auto dheight = tvm::te::reduce_axis(Range(0, kernel_height));
auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width));

Array<PrimExpr> out_shape = x->shape;
Array<PrimExpr> data_shape = x->shape;
for (size_t i = 0; i < data_shape.size(); ++i) {
data_shape.Set(i, cast(DataType::DataType::Int(32), data_shape[i]));
}

Array<PrimExpr> out_shape = data_shape;
out_shape.Set(height_axis, out_height);
out_shape.Set(width_axis, out_width);

@@ -232,7 +240,7 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1));

if (pool_type == kMaxPool) {
Array<PrimExpr> ravel_shape{x->shape.begin(), x->shape.end()};
Array<PrimExpr> ravel_shape{data_shape.begin(), data_shape.end()};
ravel_shape.Set(height_axis, ravel_shape[height_axis] + pad_top + pad_bottom);
ravel_shape.Set(width_axis, ravel_shape[width_axis] + pad_left + pad_right);

@@ -257,7 +265,7 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
auto mp_inds = mp_argmax[0];

return tvm::te::compute(
x->shape,
data_shape,
[&](const Array<Var>& inds) {
Array<PrimExpr> pad_inds{inds.begin(), inds.end()};
pad_inds.Set(height_axis, pad_inds[height_axis] + pad_top);
@@ -288,7 +296,7 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height));
auto windoww = tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width));
return tvm::te::compute(
x->shape,
data_shape,
[&](const Array<Var>& inds) {
PrimExpr pad_h_idx = inds[height_axis] + pad_top;
PrimExpr pad_w_idx = inds[width_axis] + pad_left;
@@ -483,10 +491,14 @@ inline Tensor adaptive_pool_impl(const Tensor& x, const Array<PrimExpr>& output_
const auto n_dim = output_size.size();
CHECK_EQ(axes.size(), n_dim) << "The number of axes not equal to the in/out dimension";

Array<PrimExpr> out_shape = x->shape;
Array<PrimExpr> data_shape = x->shape;
for (size_t i = 0; i < data_shape.size(); ++i) {
data_shape.Set(i, cast(DataType::DataType::Int(32), data_shape[i]));
}
Array<PrimExpr> out_shape = data_shape;
Array<PrimExpr> in_size, out_size;
for (size_t i = 0; i < n_dim; ++i) {
in_size.push_back(x->shape[axes[i]]);
in_size.push_back(data_shape[axes[i]]);
out_size.push_back(cast(DataType::Int(32), output_size[i]));
out_shape.Set(axes[i], out_size[i]);
}
@@ -661,7 +673,11 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,
std::vector<PrimExpr> pad_tail(k_size);
Array<PrimExpr> pad_before(std::vector<PrimExpr>(x_size, 0));
Array<PrimExpr> pad_after(std::vector<PrimExpr>(x_size, 0));
Array<PrimExpr> out_shape = x->shape;
Array<PrimExpr> data_shape = x->shape;
for (size_t i = 0; i < data_shape.size(); ++i) {
data_shape.Set(i, cast(DataType::DataType::Int(32), data_shape[i]));
}
Array<PrimExpr> out_shape = data_shape;

bool do_pad = false;
for (int i = 0; i < k_size; i++) {
@@ -687,7 +703,7 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,

arith::Analyzer analyzer;
auto out_dim = analyzer.Simplify(
indexdiv(x->shape[ii] - kernel[i] + pad_head[i] + pad_tail[i], stride[i]) + 1);
indexdiv(data_shape[ii] - kernel[i] + pad_head[i] + pad_tail[i], stride[i]) + 1);

out_shape.Set(ii, out_dim);
}
@@ -746,7 +762,7 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,
for (int i = 0; i < k_size; i++) {
int ii = axis[i];
start[i] = output[ii] * stride[i] - pad_head[i];
end[i] = min(start[i] + kernel[i], x->shape[ii]);
end[i] = min(start[i] + kernel[i], data_shape[ii]);
start[i] = max(start[i], make_const(DataType::Int(32), 0));
kernel_size *= (end[i] - start[i]);
}
75 changes: 44 additions & 31 deletions tests/python/relay/test_op_grad_level2.py
@@ -66,39 +66,43 @@ def test_max_pool2d_grad():
)


def verify_avg_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode, count_include_pad):
x = relay.var("x", relay.TensorType(x_shape, "float32"))
y = tvm.relay.nn.avg_pool2d(
x,
pool_size=pool_size,
strides=strides,
padding=padding,
ceil_mode=ceil_mode,
count_include_pad=count_include_pad,
)

fwd_func = relay.Function([x], y)
fwd_func = run_infer_type(fwd_func)
bwd_func = run_infer_type(gradient(fwd_func))
def verify_avg_pool2d_grad(
x_shape, pool_size, strides, padding, ceil_mode, count_include_pad, dtype="float32"
):

for shape_dtype in ["int32", "int64"]:
x = relay.var("x", shape=[tvm.tir.IntImm(shape_dtype, x) for x in x_shape], dtype=dtype)
y = tvm.relay.nn.avg_pool2d(
x,
pool_size=pool_size,
strides=strides,
padding=padding,
ceil_mode=ceil_mode,
count_include_pad=count_include_pad,
)

data = np.random.rand(*x_shape).astype("float32")
ph, pw = padding
y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape)
out_grad = np.ones(shape=y_shape)
ref_grad = tvm.topi.testing.pool_grad_nchw(
data,
out_grad,
pool_size=pool_size,
strides=strides,
padding=[ph, pw, ph, pw],
pool_type="avg",
ceil_mode=ceil_mode,
)
fwd_func = relay.Function([x], y)
fwd_func = run_infer_type(fwd_func)
bwd_func = run_infer_type(gradient(fwd_func))

data = np.random.rand(*x_shape).astype(dtype)
ph, pw = padding
y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape)
out_grad = np.ones(shape=y_shape)
ref_grad = tvm.topi.testing.pool_grad_nchw(
data,
out_grad,
pool_size=pool_size,
strides=strides,
padding=[ph, pw, ph, pw],
pool_type="avg",
ceil_mode=ceil_mode,
)

for target, ctx in tvm.testing.enabled_targets():
intrp = relay.create_executor(ctx=ctx, target=target)
op_res, (op_grad,) = intrp.evaluate(bwd_func)(data)
np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)
for target, ctx in tvm.testing.enabled_targets():
intrp = relay.create_executor(ctx=ctx, target=target)
op_res, (op_grad,) = intrp.evaluate(bwd_func)(data)
np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)


@tvm.testing.uses_gpu
@@ -119,6 +123,15 @@ def test_avg_pool2d_grad():
ceil_mode=False,
count_include_pad=False,
)
verify_avg_pool2d_grad(
(1, 4, 16, 16),
pool_size=(1, 1),
strides=(1, 1),
padding=(1, 1),
ceil_mode=False,
count_include_pad=False,
dtype="int32",
)


def verify_global_avg_pool2d_grad(x_shape):
22 changes: 13 additions & 9 deletions tests/python/relay/test_op_level10.py
@@ -425,17 +425,18 @@ def verify_ndarray_size(shape):


def verify_adaptive_pool(dshape, out_size, pool_type, layout, dtype, opfunc):
x = relay.var("x", relay.TensorType(dshape, "float32"))
y = opfunc(x, out_size, layout)
func = relay.Function([x], y)
for shape_dtype in ["int32", "int64"]:
x = relay.var("x", shape=[tvm.tir.IntImm(shape_dtype, x) for x in dshape], dtype=dtype)
y = opfunc(x, out_size, layout)
func = relay.Function([x], y)

np_data = np.random.uniform(low=0, high=255, size=dshape).astype(dtype)
np_out = tvm.topi.testing.adaptive_pool(np_data, out_size, pool_type, layout)
np_data = np.random.uniform(low=0, high=255, size=dshape).astype(dtype)
np_out = tvm.topi.testing.adaptive_pool(np_data, out_size, pool_type, layout)

for target, ctx in tvm.testing.enabled_targets():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
relay_out = intrp1.evaluate(func)(np_data)
tvm.testing.assert_allclose(relay_out.asnumpy(), np_out, rtol=1e-5, atol=1e-5)
for target, ctx in tvm.testing.enabled_targets():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
relay_out = intrp1.evaluate(func)(np_data)
tvm.testing.assert_allclose(relay_out.asnumpy(), np_out, rtol=1e-5, atol=1e-5)


def verify_adaptive_pool2d(dshape, out_size, pool_type, layout="NCHW", dtype="float32"):
@@ -452,13 +453,16 @@ def verify_adaptive_pool3d(dshape, out_size, pool_type, layout="NCHW", dtype="fl
def test_adaptive_pool():
verify_adaptive_pool2d((1, 9, 224, 224), (1, 1), "max")
verify_adaptive_pool2d((1, 3, 224, 224), (2, 3), "avg")
verify_adaptive_pool2d((1, 3, 224, 224), (2, 3), "avg", dtype="int32")
verify_adaptive_pool2d((1, 14, 56, 78), (34, 13), "max")
verify_adaptive_pool2d((1, 5, 46, 97), (4, 96), "avg")
verify_adaptive_pool2d((1, 224, 224, 3), (1, 1), "max", layout="NHWC")
verify_adaptive_pool2d((1, 3, 224, 224), (2, 3), "avg", layout="NHWC")
verify_adaptive_pool3d((1, 16, 32, 32, 32), (1, 1, 1), "max", layout="NCDHW")
verify_adaptive_pool3d((1, 16, 32, 32, 32), (1, 1, 1), "avg", layout="NCDHW")
verify_adaptive_pool3d((1, 16, 32, 32, 32), (1, 1, 1), "avg", layout="NDHWC")
verify_adaptive_pool3d((1, 16, 32, 32, 32), (1, 1, 1), "avg", layout="NCDHW", dtype="int32")
verify_adaptive_pool3d((1, 16, 32, 32, 32), (1, 1, 1), "avg", layout="NDHWC", dtype="int32")
verify_adaptive_pool3d((1, 16, 32, 32, 32), (2, 4, 4), "max", layout="NDHWC")


87 changes: 47 additions & 40 deletions tests/python/relay/test_op_level2.py
@@ -959,15 +959,16 @@ def _test_pool2d_int(opfunc, reffunc, dtype):
# test execution
dtype = "int32"
dshape = (1, 3, 28, 28)
x = relay.var("x", shape=dshape, dtype=dtype)
y = opfunc(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
func = relay.Function([x], y)
data = np.random.randint(low=-128, high=128, size=dshape)
ref_res = reffunc(data.reshape(1, 3, 14, 2, 14, 2), axis=(3, 5)).astype(dtype)
for target, ctx in tvm.testing.enabled_targets():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(data)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
for shape_dtype in ["int32", "int64"]:
x = relay.var("x", shape=[tvm.tir.IntImm(shape_dtype, x) for x in dshape], dtype=dtype)
y = opfunc(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
func = relay.Function([x], y)
data = np.random.randint(low=-128, high=128, size=dshape)
ref_res = reffunc(data.reshape(1, 3, 14, 2, 14, 2), axis=(3, 5)).astype(dtype)
for target, ctx in tvm.testing.enabled_targets():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(data)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)


def _test_global_pool2d(opfunc, reffunc):
@@ -1010,32 +1011,34 @@ def test_pool2d():

@tvm.testing.uses_gpu
def test_pool1d():
def _test_pool1d(opfunc, pool_size=(2,), strides=(2,), padding=(0, 0)):
def _test_pool1d(opfunc, pool_size=(2,), strides=(2,), padding=(0, 0), dtype="float32"):
n, c, w = te.var("n"), 10, 224
x = relay.var("x", relay.TensorType((n, c, w), "float32"))
y = opfunc(x, pool_size=(1,))
assert "pool_size=" in y.astext()
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, 10, 224), "float32")
# test execution
dtype = "float32"
dshape = (1, 3, 32)
x = relay.var("x", shape=dshape)
pool_type = "max" if "max" in str(opfunc) else "avg"
y = opfunc(x, pool_size=pool_size, strides=strides, padding=padding)
func = relay.Function([x], y)
data = np.random.uniform(size=dshape).astype(dtype)
ref_res = tvm.topi.testing.pool1d_ncw_python(
data, (2,), (2,), (0, 0), (1, 3, 16), pool_type, False
)
for target, ctx in tvm.testing.enabled_targets():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(data)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
for shape_dtype in ["int32", "int64"]:
x = relay.var("x", shape=[tvm.tir.IntImm(shape_dtype, x) for x in dshape], dtype=dtype)
pool_type = "max" if "max" in str(opfunc) else "avg"
y = opfunc(x, pool_size=pool_size, strides=strides, padding=padding)
func = relay.Function([x], y)
data = np.random.uniform(size=dshape).astype(dtype)
ref_res = tvm.topi.testing.pool1d_ncw_python(
data, (2,), (2,), (0, 0), (1, 3, 16), pool_type, False
)
for target, ctx in tvm.testing.enabled_targets():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(data)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)

_test_pool1d(relay.nn.max_pool1d)
_test_pool1d(relay.nn.max_pool1d, dtype="int32")
_test_pool1d(relay.nn.max_pool1d, pool_size=2, strides=2, padding=0)
_test_pool1d(relay.nn.avg_pool1d)
_test_pool1d(relay.nn.avg_pool1d, dtype="int32")
_test_pool1d(relay.nn.avg_pool1d, pool_size=2, strides=2, padding=0)


@@ -1047,6 +1050,7 @@ def _test_pool3d(
strides=(2, 2, 2),
padding=(0, 0, 0, 0, 0, 0),
out_shape=(1, 3, 16, 16, 16),
dtype="float32",
):
n, c, d, h, w = te.size_var("n"), 10, 5, 224, 224
x = relay.var("x", relay.TensorType((n, c, d, h, w), "float32"))
@@ -1057,30 +1061,33 @@ def _test_pool3d(
# test execution
dtype = "float32"
dshape = (1, 3, 32, 32, 32)
x = relay.var("x", shape=dshape)
pool_type = "max" if "max" in str(opfunc) else "avg"
y = opfunc(x, pool_size=pool_size, strides=strides, padding=padding)
func = relay.Function([x], y)
# check output shape
f_out_shape = tuple(map(lambda x: int(x), run_infer_type(func).ret_type.shape))
assert out_shape == f_out_shape, "Output shape mismatch. expected {}, actual {}".format(
out_shape, f_out_shape
)
data = np.random.uniform(size=dshape).astype(dtype)
ref_res = tvm.topi.testing.pool3d_ncdhw_python(
data, pool_size, strides, padding, out_shape, pool_type, False
)
for target, ctx in tvm.testing.enabled_targets():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(data)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
for shape_dtype in ["int32", "int64"]:
x = relay.var("x", shape=[tvm.tir.IntImm(shape_dtype, x) for x in dshape], dtype=dtype)
pool_type = "max" if "max" in str(opfunc) else "avg"
y = opfunc(x, pool_size=pool_size, strides=strides, padding=padding)
func = relay.Function([x], y)
# check output shape
f_out_shape = tuple(map(lambda x: int(x), run_infer_type(func).ret_type.shape))
assert out_shape == f_out_shape, "Output shape mismatch. expected {}, actual {}".format(
out_shape, f_out_shape
)
data = np.random.uniform(size=dshape).astype(dtype)
ref_res = tvm.topi.testing.pool3d_ncdhw_python(
data, pool_size, strides, padding, out_shape, pool_type, False
)
for target, ctx in tvm.testing.enabled_targets():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(data)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)

_test_pool3d(relay.nn.max_pool3d)
_test_pool3d(relay.nn.max_pool3d, dtype="int32")
_test_pool3d(relay.nn.max_pool3d, padding=(2, 0, 0, 2, 0, 0), out_shape=(1, 3, 18, 16, 16))
_test_pool3d(relay.nn.max_pool3d, padding=(0, 3, 0, 0, 3, 0), out_shape=(1, 3, 16, 19, 16))
_test_pool3d(relay.nn.max_pool3d, padding=(0, 0, 4, 0, 0, 4), out_shape=(1, 3, 16, 16, 20))
_test_pool3d(relay.nn.max_pool3d, pool_size=2, padding=0, strides=2)
_test_pool3d(relay.nn.avg_pool3d)
_test_pool3d(relay.nn.avg_pool3d, dtype="int32")
_test_pool3d(relay.nn.avg_pool3d, padding=(2, 0, 0, 2, 0, 0), out_shape=(1, 3, 18, 16, 16))
_test_pool3d(relay.nn.avg_pool3d, padding=(0, 3, 0, 0, 3, 0), out_shape=(1, 3, 16, 19, 16))
_test_pool3d(relay.nn.avg_pool3d, padding=(0, 0, 4, 0, 0, 4), out_shape=(1, 3, 16, 16, 20))
