Skip to content

Commit

Permalink
[Topi] fix get_pad_tuple3d bug, the conv3d kernel layout should be DH…
Browse files Browse the repository at this point in the history
…W. (apache#9788)
  • Loading branch information
FredJia-intellif authored Dec 28, 2021
1 parent 0f3441a commit 7448eab
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 11 deletions.
6 changes: 3 additions & 3 deletions python/tvm/topi/nn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,9 @@ def get_pad_tuple3d(padding, kernel):
pad_w = 0
pad_d = 0
elif padding == "SAME":
pad_h = kernel[0] - 1
pad_w = kernel[1] - 1
pad_d = kernel[2] - 1
pad_d = kernel[0] - 1
pad_h = kernel[1] - 1
pad_w = kernel[2] - 1
else:
raise ValueError("Unknown padding option %s" % padding)
pad_top = (pad_h + 1) // 2
Expand Down
66 changes: 58 additions & 8 deletions tests/python/topi/python/test_topi_conv3d_ncdhw.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,41 @@ def verify_conv3d_ncdhw(
add_bias=False,
add_relu=False,
):
if isinstance(kernel, (tuple, list)):
if len(kernel) == 3:
kernel_d = kernel[0]
kernel_h = kernel[1]
kernel_w = kernel[2]
else:
raise ValueError("Size of kernel can only be 3")
elif isinstance(kernel, int):
kernel_d = kernel_h = kernel_w = kernel
else:
raise ValueError("Unknown kernel option %s" % kernel)
pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = get_pad_tuple3d(
padding, (kernel, kernel, kernel)
padding, (kernel_d, kernel_h, kernel_w)
)
padding_sum = pad_front + pad_back + pad_top + pad_left + pad_bottom + pad_right
print(
"Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
% (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
"Workload: (%d, %d, %d, %d, %d, %d, %d, %d, %d, %d)"
% (
batch,
in_channel,
in_size,
num_filter,
kernel_d,
kernel_h,
kernel_w,
stride,
padding_sum,
dilation,
)
)

in_depth = in_height = in_width = in_size

A = te.placeholder((batch, in_channel, in_depth, in_height, in_width), name="A")
W = te.placeholder((num_filter, in_channel, kernel, kernel, kernel), name="W")
W = te.placeholder((num_filter, in_channel, kernel_d, kernel_h, kernel_w), name="W")
bias = te.placeholder((num_filter, 1, 1, 1), name="bias")

a_shape = get_const_tuple(A.shape)
Expand Down Expand Up @@ -103,17 +125,39 @@ def check_target(target, dev):
s,
[A, W, bias, C],
target,
name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
% (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d_%d"
% (
batch,
in_channel,
in_size,
num_filter,
kernel_d,
kernel_h,
kernel_w,
stride,
padding_sum,
dilation,
),
)
func(a, w, b, c)
else:
func = tvm.build(
s,
[A, W, C],
target,
name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
% (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d_%d"
% (
batch,
in_channel,
in_size,
num_filter,
kernel_d,
kernel_h,
kernel_w,
stride,
padding_sum,
dilation,
),
)
func(a, w, c)
tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-4, atol=1e-6)
Expand Down Expand Up @@ -155,6 +199,12 @@ def test_conv3d_ncdhw():
verify_conv3d_ncdhw(1, 32, 32, 1, 3, 1, "VALID")
verify_conv3d_ncdhw(1, 32, 32, 5, 1, 1, "VALID")

# DHW kernel layout
verify_conv3d_ncdhw(1, 32, 56, 16, (3, 5, 7), 2, (1, 2, 3))
verify_conv3d_ncdhw(1, 3, 56, 16, (3, 7, 7), 2, (1, 2, 3, 0, 3, 2))
verify_conv3d_ncdhw(1, 3, 56, 16, (3, 3, 7), 2, (1, 2, 3))
verify_conv3d_ncdhw(1, 3, 56, 16, (3, 7, 3), 2, (1, 3, 1))


if __name__ == "__main__":
test_conv3d_ncdhw()
12 changes: 12 additions & 0 deletions tests/python/topi/python/test_topi_conv3d_transpose_ncdhw.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,18 @@ def test_conv3d_transpose_ncdhw():
verify_conv3d_transpose_ncdhw(
1, 8, (32, 32, 32), 64, (5, 5, 5), (2, 2, 2), (1, 1, 1, 1, 1, 1), (1, 1, 1)
)
verify_conv3d_transpose_ncdhw(
1, 8, (32, 32, 32), 64, (3, 5, 7), (2, 2, 2), (1, 1, 1, 1, 1, 1), (0, 0, 0)
)
verify_conv3d_transpose_ncdhw(
1, 8, (32, 32, 32), 64, (3, 5, 5), (2, 2, 2), (1, 1, 1, 1, 1, 1), (0, 0, 0)
)
verify_conv3d_transpose_ncdhw(
1, 8, (32, 32, 32), 64, (3, 3, 7), (2, 2, 2), (1, 1, 1, 1, 1, 1), (0, 0, 0)
)
verify_conv3d_transpose_ncdhw(
1, 8, (32, 32, 32), 64, (3, 5, 3), (2, 2, 2), (1, 1, 1, 1, 1, 1), (0, 0, 0)
)


if __name__ == "__main__":
Expand Down

0 comments on commit 7448eab

Please sign in to comment.