Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tests] Replace the Relay interpreter with the VM in the op tests #11386

Merged
merged 1 commit into from
May 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions python/tvm/relay/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def check_grad(
mean=0,
mode="higher_order",
target_devices=None,
executor_kind="debug",
):
"""Perform numerical gradient checking given a relay function.

Expand Down Expand Up @@ -146,8 +147,12 @@ def check_grad(
for target, dev in target_devices:
# Eval the backward and forward functions
# TODO(mbs): Evaluate a pair of functions so can share preparation between them.
bwd_func_compiled = relay.create_executor(device=dev, target=target).evaluate(bwd_func)
fwd_func_compiled = relay.create_executor(device=dev, target=target).evaluate(fwd_func)
bwd_func_compiled = relay.create_executor(
executor_kind, device=dev, target=target
).evaluate(bwd_func)
fwd_func_compiled = relay.create_executor(
executor_kind, device=dev, target=target
).evaluate(fwd_func)

# Get analytic gradients.
_, grads = bwd_func_compiled(*inputs)
Expand Down
54 changes: 27 additions & 27 deletions tests/python/relay/dyn/test_dynamic_op_level10.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@
import random
import tvm.testing

executor_kind = tvm.testing.parameter("debug", "vm")


@tvm.testing.uses_gpu
def test_broadcast_to():
def test_broadcast_to(executor_kind):
def verify_more_dynamic_broadcast_to(x_shape, out_shape):
rank = len(out_shape)
dtype = "float32"
Expand All @@ -45,12 +47,13 @@ def verify_more_dynamic_broadcast_to(x_shape, out_shape):
x = np.random.uniform(size=np.prod(x_shape)).astype(dtype)
ref_res = np.broadcast_to(np.reshape(x, x_shape), out_shape)
for target, dev in tvm.testing.enabled_targets():
for kind in ["vm", "debug"]:
mod = tvm.ir.IRModule.from_expr(func)
op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate(
func
)(x, np.array(x_shape).astype(shape_type), np.array(out_shape).astype(shape_type))
tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5)
mod = tvm.ir.IRModule.from_expr(func)
op_res = relay.create_executor(
executor_kind, mod=mod, device=dev, target=target
).evaluate(func)(
x, np.array(x_shape).astype(shape_type), np.array(out_shape).astype(shape_type)
)
tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5)

verify_more_dynamic_broadcast_to((4, 3), (3, 4, 3))

Expand All @@ -70,20 +73,19 @@ def verify_broadcast_to(x_shape, out_shape):
x = np.random.uniform(size=x_shape).astype(dtype)
ref_res = np.broadcast_to(x, out_shape)
for target, dev in tvm.testing.enabled_targets():
for kind in ["vm", "debug"]:
mod = tvm.ir.IRModule.from_expr(func)
op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate(
func
)(x, np.array(out_shape).astype(shape_type))
tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5)
mod = tvm.ir.IRModule.from_expr(func)
op_res = relay.create_executor(
executor_kind, mod=mod, device=dev, target=target
).evaluate(func)(x, np.array(out_shape).astype(shape_type))
tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5)

verify_broadcast_to((1,), (1, 1, 1))
verify_broadcast_to((1, 1), (4, 1, 1))
verify_broadcast_to((4, 1), (1, 4, 3))


@tvm.testing.uses_gpu
def test_dyn_broadcast_to():
def test_dyn_broadcast_to(executor_kind):
dtype = "uint8"
rank = 3
shape_type = "int64"
Expand All @@ -101,16 +103,15 @@ def test_dyn_broadcast_to():
dyn_shape = (1,) * rank
ref_res = np.broadcast_to(x, dyn_shape)
for target, dev in tvm.testing.enabled_targets():
for kind in ["vm", "debug"]:
mod = tvm.ir.IRModule.from_expr(func)
op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate(func)(
x, np.array(dyn_shape).astype(shape_type)
)
tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5)
mod = tvm.ir.IRModule.from_expr(func)
op_res = relay.create_executor(executor_kind, mod=mod, device=dev, target=target).evaluate(
func
)(x, np.array(dyn_shape).astype(shape_type))
tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5)


@tvm.testing.uses_gpu
def test_dyn_one_hot():
def test_dyn_one_hot(executor_kind):
def _get_oshape(indices_shape, depth, axis):
oshape = []
true_axis = len(indices_shape) if axis == -1 else axis
Expand All @@ -135,12 +136,11 @@ def _verify(indices_shape, depth, on_value, off_value, axis, dtype):
indices_np = np.random.randint(0, depth, size=indices_shape).astype("int32")
out_np = tvm.topi.testing.one_hot(indices_np, on_value, off_value, depth, axis, dtype)
for target, dev in tvm.testing.enabled_targets():
for kind in ["vm", "debug"]:
mod = tvm.ir.IRModule.from_expr(func)
out_relay = relay.create_executor(
kind, mod=mod, device=dev, target=target
).evaluate()(indices_np, np.array(depth).astype("int32"))
tvm.testing.assert_allclose(out_relay.numpy(), out_np)
mod = tvm.ir.IRModule.from_expr(func)
out_relay = relay.create_executor(
executor_kind, mod=mod, device=dev, target=target
).evaluate()(indices_np, np.array(depth).astype("int32"))
tvm.testing.assert_allclose(out_relay.numpy(), out_np)

_verify((3,), 3, 1, 0, -1, "int32")
_verify((3,), 3, 1.0, 0.0, -1, "float32")
Expand Down
46 changes: 26 additions & 20 deletions tests/python/relay/dyn/test_dynamic_op_level2.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@
import tvm.topi.testing
from tvm.relay.testing import run_infer_type

executor_kind = tvm.testing.parameter("debug", "vm")


@tvm.testing.uses_gpu
def test_dyn_upsampling_run():
def test_dyn_upsampling_run(executor_kind):
def verify_upsampling(dshape, scale_h, scale_w, layout, method, align_corners=False):

if layout == "NCHW":
Expand Down Expand Up @@ -58,12 +60,13 @@ def verify_upsampling(dshape, scale_h, scale_w, layout, method, align_corners=Fa
func = relay.Function([x, scale_h_var, scale_w_var], z)

for target, dev in tvm.testing.enabled_targets():
for kind in ["vm", "debug"]:
mod = tvm.ir.IRModule.from_expr(func)
op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()(
x_data, np.array(scale_h).astype("float32"), np.array(scale_w).astype("float32")
)
tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-4, atol=1e-6)
mod = tvm.ir.IRModule.from_expr(func)
op_res = relay.create_executor(
executor_kind, mod=mod, device=dev, target=target
).evaluate()(
x_data, np.array(scale_h).astype("float32"), np.array(scale_w).astype("float32")
)
tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-4, atol=1e-6)

verify_upsampling((1, 16, 32, 32), 3, 2.0, "NCHW", "nearest_neighbor")
verify_upsampling((1, 16, 32, 32), 5, 2.0, "NCHW", "bilinear", True)
Expand All @@ -85,7 +88,7 @@ def test_dyn_upsampling_infer_type_const():


@tvm.testing.uses_gpu
def test_dyn_upsampling3d_run():
def test_dyn_upsampling3d_run(executor_kind):
def verify_upsampling3d(
dshape, scale_d, scale_h, scale_w, layout, method, coord_trans="asymmetric"
):
Expand Down Expand Up @@ -124,15 +127,16 @@ def verify_upsampling3d(
func = relay.Function([x, scale_d_var, scale_h_var, scale_w_var], z)

for target, dev in enabled_targets():
for kind in ["vm", "debug"]:
mod = tvm.ir.IRModule.from_expr(func)
op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()(
x_data,
np.array(scale_d).astype("float32"),
np.array(scale_h).astype("float32"),
np.array(scale_w).astype("float32"),
)
tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-4, atol=1e-6)
mod = tvm.ir.IRModule.from_expr(func)
op_res = relay.create_executor(
executor_kind, mod=mod, device=dev, target=target
).evaluate()(
x_data,
np.array(scale_d).astype("float32"),
np.array(scale_h).astype("float32"),
np.array(scale_w).astype("float32"),
)
tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-4, atol=1e-6)

verify_upsampling3d((1, 1, 1, 1, 1), 2, 3, 4, "NCDHW", "nearest_neighbor")
verify_upsampling3d((1, 8, 16, 16, 16), 2.0, 3.0, 4.0, "NCDHW", "nearest_neighbor")
Expand Down Expand Up @@ -163,7 +167,7 @@ def test_dyn_upsampling3d_infer_type_const():


@tvm.testing.uses_gpu
def test_dyn_pad():
def test_dyn_pad(executor_kind):
def verify_pad(dshape, pad_width, pad_val, dtype):
x = relay.var("x", relay.TensorType(dshape, dtype))
ndim = len(dshape)
Expand All @@ -178,7 +182,9 @@ def verify_pad(dshape, pad_width, pad_val, dtype):
ref_res = np.pad(data, pad_width, "constant", constant_values=(((pad_val,) * 2),) * ndim)
pad_width = np.array(pad_width).astype("int64")

verify_func(func, [data, pad_width, np.array(pad_val).astype(dtype)], ref_res)
verify_func(
executor_kind, func, [data, pad_width, np.array(pad_val).astype(dtype)], ref_res
)

def verify_pad_default_fill(dshape, pad_width, dtype):
x = relay.var("x", relay.TensorType(dshape, dtype))
Expand All @@ -193,7 +199,7 @@ def verify_pad_default_fill(dshape, pad_width, dtype):
ref_res = np.pad(data, pad_width)
pad_width = np.array(pad_width).astype("int64")

verify_func(func, [data, pad_width], ref_res)
verify_func(executor_kind, func, [data, pad_width], ref_res)

verify_pad((4, 10, 7, 7), ((1, 1), (2, 2), (3, 3), (4, 4)), 2.0, "int32")
verify_pad((2, 7), ((1, 4), (2, 2)), 4.0, "float64")
Expand Down
Loading