[Topi][Testing] Float16 unittests for dense, conv2d, depthwise conv2d (apache#8529)

* [Topi][Testing] Minor cleanup for python reference implementations

- Use the input dtype for dilate/conv2d accumulation in the Python
  implementations. Previously, the Python implementations of dilation
  and conv2d would in some cases use numpy's default dtype (float64)
  rather than the input data's dtype.

- Added a fallback for datatypes not supported by scipy.signal.convolve2d (e.g. float16).

- Refactored to avoid duplication and to use the common get_pad_tuple functionality.

* [Topi][UnitTests] Added float16 tests to test_topi_dense.py

* [Topi][UnitTests] Added float16 to test_topi_conv2d_nchw.py

* [Topi][Float16] Added float16 tests for depthwise conv2d.

* [UnitTests] Explicitly set seed for float16 tests

Intended to avoid flaky test failures later due to rounding errors.

* [UnitTests] Fixed a few failing unit tests.

- ref_data must be a test fixture, not acquired through
  request.getfixturevalue, so that the random_seed it depends on is
  known (a sketch of this fixture pattern follows the file summary
  below).

- dilate_python's return value didn't follow `out_dtype`.

- The test_topi_conv3d tests had the reference results computed in
  float64, due to dilate_python() not respecting the input data type.
  With the correct dtype, the tolerances needed to be slightly widened.

Co-authored-by: Eric Lunderberg <elunderberg@octoml.ai>
2 people authored and ylc committed Jan 13, 2022
1 parent 6470eef commit 343efb8
Showing 9 changed files with 302 additions and 132 deletions.
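
As a sketch of the seeded-fixture pattern referenced in the commit message (the names, shapes, and parameter values below are hypothetical, not taken from the changed test files):

import numpy as np
import tvm.testing

dtype = tvm.testing.parameter("float16", "float32")

@tvm.testing.fixture
def ref_data(dtype):
    # Runs after the per-test random seed has been applied, so the
    # float16 reference data (and its rounding behaviour) is reproducible.
    a_np = np.random.uniform(size=(16, 16)).astype(dtype)
    w_np = np.random.uniform(size=(16, 16)).astype(dtype)
    return a_np, w_np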
51 changes: 51 additions & 0 deletions python/tvm/topi/testing/common.py
@@ -18,6 +2,8 @@
"""Common utility for topi test"""

import numpy as np
import scipy.signal

import tvm
from tvm import topi
from tvm.testing import assert_allclose
@@ -108,3 +110,52 @@ def compare_numpy_tvm(inputs, output, target, device, compute, schedule):
arys = [tvm.nd.array(x, device=device) for x in inputs]
func(*(arys + [te_out]))
assert_allclose(te_out.numpy(), output, atol=1e-4, rtol=1e-4)


def _convolve2d(data, weights):
"""2d convolution operator in HW layout.
This is intended to be used as a replacement for
scipy.signal.convolve2d, with wider support for different dtypes.
scipy.signal.convolve2d does not support all TVM-supported
dtypes (e.g. float16). Where possible, this function uses
scipy.signal.convolve2d to take advantage of compiled scipy
routines, falling back to an explicit loop only where needed.
Parameters
----------
data : numpy.ndarray
2-D with shape [in_height, in_width]
weights : numpy.ndarray
2-D with shape [filter_height, filter_width].
Returns
-------
b_np : np.ndarray
2-D with shape [out_height, out_width]
Return value and layout conventions are matched to
``scipy.signal.convolve2d(data, weights, mode="valid")``
"""

try:
return scipy.signal.convolve2d(data, weights, mode="valid")
except ValueError:
pass

weights = np.rot90(weights, k=2)

assert len(data.shape) == len(weights.shape) == 2

dtype = data.dtype
kernel_h, kernel_w = weights.shape

output_shape = [a_dim - w_dim + 1 for a_dim, w_dim in zip(data.shape, weights.shape)]
output = np.zeros(output_shape, dtype=dtype)

for y in range(output_shape[0]):
for x in range(output_shape[1]):
output[y][x] = np.sum(data[y : y + kernel_h, x : x + kernel_w] * weights)

return output
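
A minimal usage sketch (not part of the diff; the tolerance is illustrative): for float16 inputs, scipy.signal.convolve2d raises ValueError, so _convolve2d takes the explicit-loop fallback while staying close to scipy's float32 result.

import numpy as np
import scipy.signal
from tvm.topi.testing.common import _convolve2d

data = np.random.uniform(size=(8, 8)).astype("float16")
weights = np.random.uniform(size=(3, 3)).astype("float16")

out = _convolve2d(data, weights)  # falls back to the explicit loop
ref = scipy.signal.convolve2d(
    data.astype("float32"), weights.astype("float32"), mode="valid"
)
np.testing.assert_allclose(out, ref, rtol=1e-2)
print(out.dtype)  # float16 -- accumulation stays in the input dtype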
55 changes: 51 additions & 4 deletions python/tvm/topi/testing/conv2d_nchw_python.py
@@ -17,7 +17,8 @@
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals, too-many-branches
"""Convolution in python"""
import numpy as np
import scipy.signal
import scipy

from tvm.topi.nn.utils import get_pad_tuple


@@ -58,21 +59,67 @@ def _conv2d_nchw_python(a_np, w_np, stride, padding):
out_channel = num_filter
out_height = (in_height - kernel_h + pad_h) // stride_h + 1
out_width = (in_width - kernel_w + pad_w) // stride_w + 1
b_np = np.zeros((batch, out_channel, out_height, out_width))
b_np = np.zeros((batch, out_channel, out_height, out_width), dtype=a_np.dtype)
# computation
for n in range(batch):
for f in range(out_channel):
for c in range(in_channel):
if pad_h > 0 or pad_w > 0:
apad = np.zeros((in_height + pad_h, in_width + pad_w))
apad = np.zeros((in_height + pad_h, in_width + pad_w), dtype=a_np.dtype)
apad[pad_top : pad_top + in_height, pad_left : pad_left + in_width] = a_np[n, c]
else:
apad = a_np[n, c]
out = scipy.signal.convolve2d(apad, np.rot90(np.rot90(w_np[f, c])), mode="valid")

out = _conv2d_hw(apad, w_np[f, c])
b_np[n, f] += out[::stride_h, ::stride_w]
return b_np


def _conv2d_hw(apad, w_np_fc):
"""2d convolution operator in HW layout.
This is intended to be used as a subroutine from
_conv2d_nchw_python. Using scipy.signal.convolve2d directly does
not work for all dtypes (e.g. float16). Where possible, this
function uses scipy.signal.convolve2d to take advantage of
compiled scipy routines, falling back to an explicit loop only
where needed.
Parameters
----------
apad : numpy.ndarray
2-D with shape [in_height, in_width]
w_np_fc : numpy.ndarray
2-D with shape [filter_height, filter_width].
Returns
-------
b_np : np.ndarray
2-D with shape [out_height, out_width]
"""

try:
return scipy.signal.convolve2d(apad, np.rot90(np.rot90(w_np_fc)), mode="valid")
except ValueError:
pass

assert len(apad.shape) == len(w_np_fc.shape) == 2

dtype = apad.dtype
in_height, in_width = apad.shape
kernel_h, kernel_w = w_np_fc.shape

output_shape = [a_dim - w_dim + 1 for a_dim, w_dim in zip(apad.shape, w_np_fc.shape)]
output = np.zeros(output_shape, dtype=apad.dtype)

for y in range(output_shape[0]):
for x in range(output_shape[1]):
output[y][x] = np.sum(apad[y : y + kernel_h, x : x + kernel_w] * w_np_fc)

return output
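
For reference, a small sketch of the public entry point that uses this helper (shapes and values are illustrative):

import numpy as np
from tvm.topi.testing import conv2d_nchw_python

a_np = np.random.uniform(size=(1, 3, 16, 16)).astype("float16")
w_np = np.random.uniform(size=(8, 3, 3, 3)).astype("float16")

# With the dtype fixes above, the reference result is accumulated in
# the input dtype instead of numpy's default float64.
b_np = conv2d_nchw_python(a_np, w_np, stride=1, padding="SAME")
print(b_np.shape, b_np.dtype)  # (1, 8, 16, 16) float16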


def conv2d_nchw_python(a_np, w_np, stride, padding, groups=1):
"""Convolution operator in NCHW layout.
118 changes: 34 additions & 84 deletions python/tvm/topi/testing/depthwise_conv2d_python.py
@@ -17,7 +17,9 @@
# pylint: disable=invalid-name, unused-variable, line-too-long
"""Depthwise convolution in python"""
import numpy as np
from scipy import signal

from tvm.topi.nn.utils import get_pad_tuple
from .common import _convolve2d


def depthwise_conv2d_python_nchw(input_np, filter_np, stride, padding):
@@ -49,42 +51,29 @@ def depthwise_conv2d_python_nchw(input_np, filter_np, stride, padding):
else:
stride_h, stride_w = stride

# calculate output shape
if padding == "VALID":
out_channel = in_channel * channel_multiplier
out_height = (in_height - filter_height) // stride_h + 1
out_width = (in_width - filter_width) // stride_w + 1
output_np = np.zeros((batch, out_channel, out_height, out_width))
for i in range(batch):
for j in range(out_channel):
output_np[i, j, :, :] = signal.convolve2d(
input_np[i, j // channel_multiplier, :, :],
np.rot90(filter_np[j // channel_multiplier, j % channel_multiplier, :, :], 2),
mode="valid",
)[
0 : (in_height - filter_height + 1) : stride_h,
0 : (in_width - filter_width + 1) : stride_w,
]
elif padding == "SAME":
out_channel = in_channel * channel_multiplier
out_height = int(np.ceil(float(in_height) / float(stride_h)))
out_width = int(np.ceil(float(in_width) / float(stride_w)))
output_np = np.zeros((batch, out_channel, out_height, out_width))
pad_along_height = int(np.max((out_height - 1) * stride_h + filter_height - in_height, 0))
pad_along_width = int(np.max((out_width - 1) * stride_w + filter_width - in_width, 0))
pad_top_tvm = int(np.ceil(float(pad_along_height) / 2))
pad_left_tvm = int(np.ceil(float(pad_along_width) / 2))
pad_top_scipy = int(np.ceil(float(filter_height - 1) / 2))
pad_left_scipy = int(np.ceil(float(filter_width - 1) / 2))
index_h = pad_top_scipy - pad_top_tvm
index_w = pad_left_scipy - pad_left_tvm
for i in range(batch):
for j in range(out_channel):
output_np[i, j, :, :] = signal.convolve2d(
input_np[i, j // channel_multiplier, :, :],
np.rot90(filter_np[j // channel_multiplier, j % channel_multiplier, :, :], 2),
mode="same",
)[index_h:in_height:stride_h, index_w:in_width:stride_w]
pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (filter_height, filter_width))
pad_h = pad_top + pad_bottom
pad_w = pad_left + pad_right

out_channel = in_channel * channel_multiplier
out_height = (in_height - filter_height + pad_h) // stride_h + 1
out_width = (in_width - filter_width + pad_w) // stride_w + 1
output_np = np.zeros((batch, out_channel, out_height, out_width))

for i in range(batch):
for j in range(out_channel):
apad = input_np[i, j // channel_multiplier, :, :]
if pad_h or pad_w:
apad = np.pad(apad, [(pad_top, pad_bottom), (pad_left, pad_right)])

conv = _convolve2d(
apad,
np.rot90(filter_np[j // channel_multiplier, j % channel_multiplier, :, :], k=2),
)
output_np[i, j, :, :] = conv[
::stride_h,
::stride_w,
]

return output_np
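
The unified path relies on get_pad_tuple to translate "SAME"/"VALID" (or explicit padding) into per-side amounts; a quick sketch:

from tvm.topi.nn.utils import get_pad_tuple

# For a 3x3 filter, "SAME" padding becomes one pixel on each side,
# which is what lets the two padding branches above collapse into one.
pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple("SAME", (3, 3))
print(pad_top, pad_left, pad_bottom, pad_right)  # 1 1 1 1
print(get_pad_tuple("VALID", (3, 3)))  # (0, 0, 0, 0)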

@@ -139,15 +128,17 @@ def depthwise_conv2d_python_nchwc(input_np, filter_np, stride, padding):
# Perform conv2d
output_np = depthwise_conv2d_python_nchw(input_nchw, filter_nchw, stride, padding)

# Transform back
# Transform back to NCHWc

# pylint: disable=unpacking-non-sequence
batch_size, out_channel, out_height, out_width = output_np.shape
return output_np.reshape(
(batch_size, out_channel_chunk, out_channel_block, out_height, out_width)
).transpose(0, 1, 3, 4, 2)
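
A shape-only sketch of the NCHW -> NCHWc transform above (the chunk/block factors here are hypothetical):

import numpy as np

out_nchw = np.zeros((1, 8, 5, 5))  # N, C, H, W with C = chunk * block
chunk, block = 2, 4                # hypothetical split of the 8 channels
nchwc = out_nchw.reshape((1, chunk, block, 5, 5)).transpose(0, 1, 3, 4, 2)
print(nchwc.shape)  # (1, 2, 5, 5, 4) == (N, C_chunk, H, W, c_block)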


def depthwise_conv2d_python_nhwc(input_np, filter_np, stride, padding):
"""Depthwise convolution operator in nchw layout.
"""Depthwise convolution operator in nhwc layout.
Parameters
----------
@@ -168,48 +159,7 @@ def depthwise_conv2d_python_nhwc(input_np, filter_np, stride, padding):
output_np : np.ndarray
4-D with shape [batch, out_height, out_width, out_channel]
"""
batch, in_height, in_width, in_channel = input_np.shape
filter_height, filter_width, _, channel_multiplier = filter_np.shape
if isinstance(stride, int):
stride_h = stride_w = stride
else:
stride_h, stride_w = stride

# calculate output shape
if padding == "VALID":
out_channel = in_channel * channel_multiplier
out_height = (in_height - filter_height) // stride_h + 1
out_width = (in_width - filter_width) // stride_w + 1
output_np = np.zeros((batch, out_height, out_width, out_channel))
for i in range(batch):
for j in range(out_channel):
output_np[i, :, :, j] = signal.convolve2d(
input_np[i, :, :, j // channel_multiplier],
np.rot90(filter_np[:, :, j // channel_multiplier, j % channel_multiplier], 2),
mode="valid",
)[
0 : (in_height - filter_height + 1) : stride_h,
0 : (in_width - filter_width + 1) : stride_w,
]
if padding == "SAME":
out_channel = in_channel * channel_multiplier
out_height = int(np.ceil(float(in_height) / float(stride_h)))
out_width = int(np.ceil(float(in_width) / float(stride_w)))
output_np = np.zeros((batch, out_height, out_width, out_channel))
pad_along_height = int(np.max((out_height - 1) * stride_h + filter_height - in_height, 0))
pad_along_width = int(np.max((out_width - 1) * stride_w + filter_width - in_width, 0))
pad_top_tvm = int(np.ceil(float(pad_along_height) / 2))
pad_left_tvm = int(np.ceil(float(pad_along_width) / 2))
pad_top_scipy = int(np.ceil(float(filter_height - 1) / 2))
pad_left_scipy = int(np.ceil(float(filter_width - 1) / 2))
index_h = pad_top_scipy - pad_top_tvm
index_w = pad_left_scipy - pad_left_tvm
for i in range(batch):
for j in range(out_channel):
output_np[i, :, :, j] = signal.convolve2d(
input_np[i, :, :, j // channel_multiplier],
np.rot90(filter_np[:, :, j // channel_multiplier, j % channel_multiplier], 2),
mode="same",
)[index_h:in_height:stride_h, index_w:in_width:stride_w]

return output_np
input_nchw = input_np.transpose(0, 3, 1, 2)
filter_nchw = filter_np.transpose(2, 3, 0, 1)
output_nchw = depthwise_conv2d_python_nchw(input_nchw, filter_nchw, stride, padding)
return output_nchw.transpose(0, 2, 3, 1)
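
Usage sketch for the NHWC wrapper (illustrative shapes): the filter is laid out as (filter_height, filter_width, in_channel, channel_multiplier), and the output has in_channel * channel_multiplier channels.

import numpy as np
from tvm.topi.testing import depthwise_conv2d_python_nhwc

input_np = np.random.uniform(size=(1, 16, 16, 4)).astype("float16")
filter_np = np.random.uniform(size=(3, 3, 4, 2)).astype("float16")

out = depthwise_conv2d_python_nhwc(input_np, filter_np, stride=1, padding="SAME")
print(out.shape)  # (1, 16, 16, 8) -- NHWC, channels = 4 * 2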
35 changes: 23 additions & 12 deletions python/tvm/topi/testing/dilate_python.py
@@ -19,7 +19,7 @@
import numpy as np


def dilate_python(input_np, strides, dilation_value=0.0):
def dilate_python(input_np, strides, dilation_value=0.0, out_dtype=None):
"""Dilate operation.
Parameters
@@ -33,23 +33,34 @@ def dilate_python(input_np, strides, dilation_value=0.0, out_dtype=None):
dilation_value : int/float, optional
Value used to dilate the input.
out_dtype : Optional[str]
The datatype of the dilated array. If unspecified, will use
the same dtype as the input array.
Returns
-------
output_np : numpy.ndarray
n-D, the same layout as Input.
"""
n = len(input_np.shape)
assert len(strides) == n, "Input dimension and strides size mismatch : %d vs %d" % (
n,
assert len(input_np.shape) == len(
strides
), "Input dimension and strides size dismatch : %d vs %d" % (
len(input_np.shape),
len(strides),
)
output_size = ()
no_zero = ()
for i in range(n):
output_size += ((input_np.shape[i] - 1) * strides[i] + 1,)
no_zero += ((range(0, output_size[i], strides[i])),)
output_np = np.ones(shape=output_size)
output_np = dilation_value * output_np
output_np[np.ix_(*no_zero)] = input_np

if out_dtype is None:
out_dtype = input_np.dtype

output_size = [
(input_dim - 1) * stride + 1 for input_dim, stride in zip(input_np.shape, strides)
]
non_zero_elements = np.ix_(
*[range(0, output_dim, stride) for output_dim, stride in zip(output_size, strides)]
)

output_np = np.full(shape=output_size, fill_value=dilation_value, dtype=out_dtype)
output_np[non_zero_elements] = input_np

return output_np
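
A small sketch of the rewritten dilation (values illustrative): strides of 2 place one dilation_value between neighbouring input elements along each axis, and the result keeps the input's dtype unless out_dtype is given.

import numpy as np
from tvm.topi.testing import dilate_python

x = np.array([[1, 2], [3, 4]], dtype="float16")
y = dilate_python(x, strides=(2, 2))
print(y.dtype)  # float16 -- follows the input dtype by default
print(y)
# [[1. 0. 2.]
#  [0. 0. 0.]
#  [3. 0. 4.]]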