Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
icemelon committed Mar 26, 2021
1 parent f8bde56 commit 7334158
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions tests/python/topi/python/test_topi_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,15 @@ def get_implementations(name, axis, dtype, exclusive):


def _run_tests(
ctx,
dev,
target,
op_name: str = "cumsum",
gt_func: Callable[..., np.array] = np.cumsum,
):
def check_scan(np_ref, data, axis=None, dtype=None, exclusive=False):
implementations = get_implementations(op_name, axis, dtype, exclusive)
fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations)
tvm.topi.testing.compare_numpy_tvm([data], np_ref, target, ctx, fcompute, fschedule)
tvm.topi.testing.compare_numpy_tvm([data], np_ref, target, dev, fcompute, fschedule)

data = np.array([2, 3, 0])
check_scan(gt_func(data), data)
Expand Down Expand Up @@ -121,24 +121,24 @@ def check_scan(np_ref, data, axis=None, dtype=None, exclusive=False):


@tvm.testing.parametrize_targets
def test_cumsum(ctx, target):
_run_tests(ctx, target, op_name="cumsum", gt_func=np.cumsum)
def test_cumsum(dev, target):
_run_tests(dev, target, op_name="cumsum", gt_func=np.cumsum)


@tvm.testing.parametrize_targets
def test_cumprod(ctx, target):
_run_tests(ctx, target, op_name="cumprod", gt_func=np.cumprod)
def test_cumprod(dev, target):
_run_tests(dev, target, op_name="cumprod", gt_func=np.cumprod)


if __name__ == "__main__":
test_cumsum(tvm.context("cpu"), tvm.target.Target("llvm"))
test_cumsum(tvm.context("cuda"), tvm.target.Target("cuda"))
test_cumsum(tvm.context("nvptx"), tvm.target.Target("nvptx"))
test_cumsum(tvm.context("vulkan"), tvm.target.Target("vulkan"))
test_cumsum(tvm.context("metal"), tvm.target.Target("metal"))

test_cumprod(tvm.context("cpu"), tvm.target.Target("llvm"))
test_cumprod(tvm.context("cuda"), tvm.target.Target("cuda"))
test_cumprod(tvm.context("nvptx"), tvm.target.Target("nvptx"))
test_cumprod(tvm.context("vulkan"), tvm.target.Target("vulkan"))
test_cumprod(tvm.context("metal"), tvm.target.Target("metal"))
test_cumsum(tvm.device("cpu"), tvm.target.Target("llvm"))
test_cumsum(tvm.device("cuda"), tvm.target.Target("cuda"))
test_cumsum(tvm.device("nvptx"), tvm.target.Target("nvptx"))
test_cumsum(tvm.device("vulkan"), tvm.target.Target("vulkan"))
test_cumsum(tvm.device("metal"), tvm.target.Target("metal"))

test_cumprod(tvm.device("cpu"), tvm.target.Target("llvm"))
test_cumprod(tvm.device("cuda"), tvm.target.Target("cuda"))
test_cumprod(tvm.device("nvptx"), tvm.target.Target("nvptx"))
test_cumprod(tvm.device("vulkan"), tvm.target.Target("vulkan"))
test_cumprod(tvm.device("metal"), tvm.target.Target("metal"))

0 comments on commit 7334158

Please sign in to comment.