Prepare the topi tests for AArch64 CI.
This pull request cleans up the TOPI test suite for use on the AArch64
CI target by doing the following:

- Introducing a script to run the tests on AArch64 with a suitable
  llvm target string, supplied via the TVM_TEST_TARGETS environment
  variable (see the shell sketch after this list).

- Cleaning up the use of hard-coded targets and moving the test suite
  to iterate over tvm.testing.enabled_targets() instead (see the Python
  sketch after this list).

- Replacing remaining uses of tvm.target.create with tvm.target.Target.

- Together, the above lets the topi tests run against whichever targets
  are enabled and trims what the test suite needs to hard-code.
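
As a rough illustration of the first bullet: on an AArch64 host, a run
might look like the sketch below. The exact target string is an
assumption (the new script itself is not rendered in this diff);
tvm.testing reads TVM_TEST_TARGETS to decide which targets are enabled.

    # Hypothetical invocation; the real script may use a different target string.
    export TVM_TEST_TARGETS="llvm -mtriple=aarch64-linux-gnu"
    python -m pytest tests/python/topi/python/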
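
The second and third bullets converge on the pattern sketched below.
This is a minimal sketch around a hypothetical elementwise op, not any
specific test changed here: each test's check helper takes a
(target, dev) pair from tvm.testing.enabled_targets() and builds under
tvm.target.Target in place of the old tvm.target.create calls.

    import numpy as np
    import tvm
    import tvm.testing
    import tvm.topi.testing
    from tvm import te, topi

    def verify_relu(shape, dtype="float32"):
        A = te.placeholder(shape, name="A", dtype=dtype)
        B = topi.nn.relu(A)
        a_np = np.random.uniform(low=-1.0, size=shape).astype(dtype)
        b_np = np.maximum(a_np, 0)

        def check_device(target, dev):
            # Build and run for the requested target; tvm.target.Target
            # is used rather than the old tvm.target.create.
            with tvm.target.Target(target):
                s = tvm.topi.testing.get_injective_schedule(target)(B)
                func = tvm.build(s, [A, B], target)
            a = tvm.nd.array(a_np, dev)
            b = tvm.nd.array(np.zeros(shape, dtype=dtype), dev)
            func(a, b)
            tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)

        # enabled_targets() yields (target string, device) pairs for every
        # target enabled in this build or via TVM_TEST_TARGETS.
        for target, dev in tvm.testing.enabled_targets():
            check_device(target, dev)

    verify_relu((1, 16))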

Putting this up for a test run on ci_gpu and ci_cpu to see the effects
of moving TOPI test runs to AArch64 CPU before firing up the Jenkins
changes.

The motivation comes from apache#8361: to pipe-clean the test suite and
add this support.
Ramana Radhakrishnan committed Jul 5, 2021
1 parent ec47129 commit 51ba6d6
Showing 12 changed files with 88 additions and 103 deletions.
10 changes: 4 additions & 6 deletions tests/python/topi/python/test_topi_batch_matmul.py
@@ -110,14 +110,12 @@ def get_ref_data():
# get the test data
a_np, b_np, c_np = get_ref_data()

def check_device(device):
dev = tvm.device(device, 0)
def check_device(target, dev):
if device == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
print("Skip because int8 intrinsics are not available")
return

print("Running on target: %s" % device)
with tvm.target.Target(device):
with tvm.target.Target(target):
out = topi.cuda.batch_matmul_int8(x, y, None, out_dtype)
s = topi.cuda.schedule_batch_matmul_int8([out])
a = tvm.nd.array(a_np, dev)
@@ -127,8 +125,8 @@ def check_device(device):
f(a, b, c)
tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)

for device in ["cuda"]:
check_device(device)
for target, dev in tvm.testing.enabled_targets():
check_device(target, dev)


@tvm.testing.uses_gpu
2 changes: 1 addition & 1 deletion tests/python/topi/python/test_topi_batch_to_space_nd.py
@@ -44,7 +44,7 @@ def verify_batch_to_space_nd(input_shape, block_shape, crop_begin_list, crop_end

def check_device(target, dev):
print("Running on target: %s" % target)
with tvm.target.create(target):
with tvm.target.Target(target):
s = tvm.topi.testing.get_injective_schedule(target)(B)
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(np.zeros(out_shape, dtype=dtype), dev)
21 changes: 8 additions & 13 deletions tests/python/topi/python/test_topi_conv2d_NCHWc.py
@@ -51,7 +51,7 @@ def _transform_bias(bias, bn):
bias = np.transpose(bias, (0, 2, 3, 1))
return bias


@tvm.testing.requires_llvm
def verify_conv2d_NCHWc(
batch,
in_channel,
@@ -115,13 +115,8 @@ def get_ref_data():

a_np, w_np, b_np, c_np = get_ref_data()

def check_device(device):
dev = tvm.device(device, 0)
if not tvm.testing.device_enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.Target(device):
def check_device(target, dev):
with tvm.target.Target(target):
C = topi.x86.conv2d_NCHWc(
A,
W,
@@ -146,7 +141,7 @@ def check_device(device):
func = tvm.build(
s,
[A, W, bias, C],
device,
target,
name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
% (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
)
@@ -155,17 +150,17 @@ def check_device(device):
func = tvm.build(
s,
[A, W, C],
device,
target,
name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
% (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
)
func(a, w, c)
tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)

# test llvm only for now since conv2d_NCHWc implement is missing in other backend.
for device in ["llvm"]:
with autotvm.tophub.context(device): # load tophub pre-tuned parameters
check_device(device)
for target, device in tvm.testing.enabled_targets():
with autotvm.tophub.context(target): # load tophub pre-tuned parameters
check_device(target, device)


def test_conv2d_NCHWc():
6 changes: 3 additions & 3 deletions tests/python/topi/python/test_topi_conv2d_int8.py
@@ -175,8 +175,7 @@ def get_ref_data():

a_np, w_np, b_np, c_np = get_ref_data()

def check_target(target):
dev = tvm.device(target, 0)
def check_target(target, dev):
if not tvm.testing.device_enabled(target):
print("Skip because %s is not enabled" % target)
return
@@ -222,7 +221,8 @@ def check_target(target):
func(a, w, c)
tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)

check_target("llvm")
for target, dev in tvm.testing.enabled_targets():
check_target(target, dev)


oc_block_factor = 4
19 changes: 9 additions & 10 deletions tests/python/topi/python/test_topi_conv2d_nhwc_pack_int8.py
@@ -26,7 +26,7 @@
import tvm.topi.testing
from tvm.contrib.pickle_memoize import memoize
from tvm.topi.utils import get_const_tuple

import tvm.testing

def verify_conv2d_1x1_nhwc_pack_int8(
batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1
@@ -51,26 +51,25 @@ def get_ref_data():

a_np, w_np, b_np = get_ref_data()

def check_device(device):
dev = tvm.device(device, 0)
if not tvm.testing.device_enabled(device):
print("Skip because %s is not enabled" % device)
def check_device(target, dev):
if not tvm.testing.device_enabled(target):
print("Skip because %s is not enabled" % target)
return
print("Running on target: %s" % device)
print("Running on target: %s" % target)

with tvm.target.Target(device):
with tvm.target.Target(target):
B = topi.nn.conv2d(A, W, stride, padding, dilation, layout="NHWC", out_dtype="int32")
s = topi.x86.schedule_conv2d_nhwc_pack_int8([B])
a = tvm.nd.array(a_np, dev)
w = tvm.nd.array(w_np, dev)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev)
func = tvm.build(s, [A, W, B], device)
func = tvm.build(s, [A, W, B], target)
func(a, w, b)
tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)

# for device in ['llvm -mcpu=skylake-avx512']:
for device in ["llvm"]:
check_device(device)
for target, dev in tvm.testing.enabled_targets():
check_device(target, dev)


# TODO(@llyfacebook): Please fix https://github.com/apache/tvm/issues/4122 to enable this test.
19 changes: 9 additions & 10 deletions tests/python/topi/python/test_topi_deformable_conv2d.py
@@ -92,14 +92,13 @@ def get_ref_data():

a_np, offset_np, w_np, c_np = get_ref_data()

def check_device(device):
dev = tvm.device(device, 0)
if not tvm.testing.device_enabled(device):
print("Skip because %s is not enabled" % device)
def check_device(target, dev):
if not tvm.testing.device_enabled(target):
print("Skip because %s is not enabled" % target)
return
print("Running on target: %s" % device)
fcompute, fschedule = tvm.topi.testing.dispatch(device, _deformable_conv2d_nchw_implement)
with tvm.target.Target(device):
print("Running on target: %s" % target)
fcompute, fschedule = tvm.topi.testing.dispatch(target, _deformable_conv2d_nchw_implement)
with tvm.target.Target(target):
C = fcompute(A, Offset, W, stride, padding, dilation, deformable_groups, groups, dtype)
s = fschedule([C])

@@ -108,12 +107,12 @@ def check_device(device):
w = tvm.nd.array(w_np, dev)
c = tvm.nd.empty(c_np.shape, dtype=c_np.dtype, device=dev)

func = tvm.build(s, [A, Offset, W, C], device)
func = tvm.build(s, [A, Offset, W, C], target)
func(a, offset, w, c)
tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)

for device in ["llvm", "cuda"]:
check_device(device)
for target, dev in tvm.testing.enabled_targets():
check_device(target, dev)


def verify_deformable_conv2d_nhwc(
17 changes: 6 additions & 11 deletions tests/python/topi/python/test_topi_lrn.py
@@ -42,23 +42,18 @@ def verify_lrn(shape, size, axis, bias, alpha, beta):
a_np = np.random.uniform(size=shape).astype(dtype)
b_np = tvm.topi.testing.lrn_python(a_np, size, axis, bias, alpha, beta)

def check_device(device):
if not tvm.testing.device_enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.Target(device):
s_func = tvm.topi.testing.dispatch(device, _lrn_schedule)
def check_device(target, dev):
with tvm.target.Target(target):
s_func = tvm.topi.testing.dispatch(target, _lrn_schedule)
s = s_func([B])
dev = tvm.device(device, 0)
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), dev)
f = tvm.build(s, [A, B], device)
f = tvm.build(s, [A, B], target)
f(a, b)
tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)

for device in ["llvm", "cuda", "opencl", "metal", "rocm", "vulkan", "nvptx"]:
check_device(device)
for target, dev in tvm.testing.enabled_targets():
check_device(target, dev)


@tvm.testing.uses_gpu
17 changes: 6 additions & 11 deletions tests/python/topi/python/test_topi_reorg.py
@@ -46,24 +46,19 @@ def get_ref_data_reorg():

a_np, b_np = get_ref_data_reorg()

def check_device(device):
def check_device(target, dev):
"""Cheching devices is enabled or not"""
dev = tvm.device(device, 0)
if not tvm.testing.device_enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.Target(device):
s_func = tvm.topi.testing.dispatch(device, _reorg_schedule)
with tvm.target.Target(target):
s_func = tvm.topi.testing.dispatch(target, _reorg_schedule)
s = s_func([B])
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev)
func = tvm.build(s, [A, B], device)
func = tvm.build(s, [A, B], target)
func(a, b)
tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)

for device in ["llvm", "cuda"]:
check_device(device)
for target, dev in tvm.testing.enabled_targets():
check_device(target, dev)


@tvm.testing.uses_gpu
2 changes: 1 addition & 1 deletion tests/python/topi/python/test_topi_space_to_batch_nd.py
@@ -44,7 +44,7 @@ def verify_space_to_batch_nd(input_shape, block_shape, pad_before, pad_after, pa

def check_target(target, dev):
print("Running on target: %s" % target)
with tvm.target.create(target):
with tvm.target.Target(target):
s = tvm.topi.testing.get_injective_schedule(target)(B)
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(np.zeros(out_shape, dtype=dtype), dev)
53 changes: 17 additions & 36 deletions tests/python/topi/python/test_topi_sparse.py
@@ -58,12 +58,7 @@ def get_ref_data():

a_np, b_np, c_np, d_np = get_ref_data()

def check_device(device):
dev = tvm.device(device, 0)
if not tvm.testing.device_enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
def check_device(target, dev):
a = tvmsp.array(a_np, dev)
_nr, _nc, _n = a.shape[0], a.shape[1], a.data.shape[0]
assert a.shape[0] == a.indptr.shape[0] - 1
@@ -73,12 +68,12 @@ def check_device(device):
assert a.data.dtype == A.data.dtype
assert a.indices.dtype == A.indices.dtype
assert a.indptr.dtype == A.indptr.dtype
f = tvm.build(s, [nr, A.data, A.indices, A.indptr, B, C, D], device, name="csrmv")
f = tvm.build(s, [nr, A.data, A.indices, A.indptr, B, C, D], target, name="csrmv")
f(_nr, a.data, a.indices, a.indptr, b, c, d)
tvm.testing.assert_allclose(d.numpy(), d_np, rtol=1e-4, atol=1e-4)

for device in ["llvm"]:
check_device(device)
for target, dev in tvm.testing.enabled_targets():
check_device(target, dev)


def verify_dynamic_csrmm(batch, in_dim, out_dim, use_bias=True):
@@ -104,25 +99,20 @@ def get_ref_data():

a_np, b_np, c_np, d_np = get_ref_data()

def check_device(device):
dev = tvm.device(device, 0)
if not tvm.testing.device_enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
def check_device(target, dev):
a = tvmsp.array(a_np, dev)
_nr, _nc, _n = a.shape[0], a.shape[1], a.data.shape[0]
assert a.shape[0] == a.indptr.shape[0] - 1
b = tvm.nd.array(b_np, dev)
c = tvm.nd.array(c_np, dev)
d = tvm.nd.array(np.zeros((_nr, out_dim), dtype=dtype), dev)
f = tvm.build(s, [nr, A.data, A.indices, A.indptr, B, C, D], device, name="csrmm")
f = tvm.build(s, [nr, A.data, A.indices, A.indptr, B, C, D], target, name="csrmm")

f(_nr, a.data, a.indices, a.indptr, b, c, d)
tvm.testing.assert_allclose(d.numpy(), d_np, rtol=1e-2, atol=1e-2)

for device in ["llvm"]:
check_device(device)
for target, dev in tvm.testing.enabled_targets():
check_device(target, dev)


def verify_dense_si(batch, in_dim, out_dim, use_bias=True, dtype="float32"):
@@ -438,14 +428,9 @@ def test_sparse_dense_bsr_randomized():
W_indptr = te.placeholder(shape=W_sp_np.indptr.shape, dtype=str(W_sp_np.indptr.dtype))
X = te.placeholder(shape=X_np.shape, dtype=str(X_np.dtype))

def check_device(device):
dev = tvm.device(device, 0)
if not tvm.testing.device_enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
fcompute, fschedule = tvm.topi.testing.dispatch(device, _sparse_dense_implement)
with tvm.target.Target(device):
def check_device(target, dev):
fcompute, fschedule = tvm.topi.testing.dispatch(target, _sparse_dense_implement)
with tvm.target.Target(target):
Y = fcompute(X, W_data, W_indices, W_indptr)
s = fschedule([Y])
func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
@@ -459,8 +444,8 @@ def check_device(device):
)
tvm.testing.assert_allclose(Y_tvm.numpy(), Y_np, atol=1e-5, rtol=1e-5)

for device in ["llvm", "cuda"]:
check_device(device)
for target, dev in tvm.testing.enabled_targets():
check_device(target, dev)


@tvm.testing.parametrize_targets("cuda", "rocm")
@@ -577,14 +562,9 @@ def verify_sparse_conv2d_bsr(M, H, W, N, K, BS_R, BS_C, density, layout):
Y = topi.nn.sparse_conv2d(X, W_data, W_indices, W_indptr, layout)
s = te.create_schedule(Y.op)

def check_device(device):
dev = tvm.device(device, 0)
if not tvm.testing.device_enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
def check_device(target, dev):

func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y], target)
Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype="float32"))
func(
tvm.nd.array(X_np, dev),
Expand All @@ -595,7 +575,8 @@ def check_device(device):
)
tvm.testing.assert_allclose(Y_tvm.numpy(), Y_np.astype("float32"), atol=1e-4, rtol=1e-4)

check_device("llvm")
for target, dev in tvm.testing.enabled_targets():
check_device(target, dev)


def test_sparse_conv2d_bsr():
2 changes: 1 addition & 1 deletion tests/python/topi/python/test_topi_transform.py
@@ -794,7 +794,7 @@ def check_device(target, dev):
print("Skip because %s is not enabled" % target)
return
print("Running on target: %s" % target)
with tvm.target.create(target):
with tvm.target.Target(target):
s = tvm.topi.testing.get_injective_schedule(target)(out)

func = tvm.build(s, [data] + indices + [out], target, name="adv_index")