From e93ea11bfb6b5f0f6d1a26281f555770705deaf7 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Sat, 20 Aug 2022 20:54:42 +0300 Subject: [PATCH] port test_tools to arraycontext --- test/test_tools.py | 68 ++++++++++++++++++++++++++++++++++------------ 1 file changed, 50 insertions(+), 18 deletions(-) diff --git a/test/test_tools.py b/test/test_tools.py index 08a90689..263e412e 100644 --- a/test/test_tools.py +++ b/test/test_tools.py @@ -20,26 +20,38 @@ THE SOFTWARE. """ -import logging -logger = logging.getLogger(__name__) +import pytest +import sys -import sumpy.symbolic as sym -from sumpy.tools import (fft_toeplitz_upper_triangular, - matvec_toeplitz_upper_triangular, loopy_fft, fft) import numpy as np -import pyopencl as cl -import pyopencl.array as cla -from pyopencl.tools import ( # noqa - pytest_generate_tests_for_pyopencl as pytest_generate_tests) +from arraycontext import pytest_generate_tests_for_array_contexts +from sumpy.array_context import ( # noqa: F401 + PytestPyOpenCLArrayContextFactory, _acf) + +import sumpy.symbolic as sym +from sumpy.tools import ( + fft_toeplitz_upper_triangular, + matvec_toeplitz_upper_triangular, + loopy_fft, + fft) + +import logging +logger = logging.getLogger(__name__) + +pytest_generate_tests = pytest_generate_tests_for_array_contexts([ + PytestPyOpenCLArrayContextFactory, + ]) -import pytest +# {{{ test_matvec_fft def test_matvec_fft(): k = 5 - v = np.random.rand(k) - x = np.random.rand(k) + + rng = np.random.default_rng(42) + v = rng.random(k) + x = rng.random(k) fft = fft_toeplitz_upper_triangular(v, x) matvec = matvec_toeplitz_upper_triangular(v, x) @@ -47,6 +59,10 @@ def test_matvec_fft(): for i in range(k): assert abs(fft[i] - matvec[i]) < 1e-14 +# }}} + + +# {{{ test_matvec_fft_small_floats def test_matvec_fft_small_floats(): k = 5 @@ -60,15 +76,31 @@ def test_matvec_fft_small_floats(): continue assert abs(f) > 1e-10 +# }}} + + +# {{{ test_fft @pytest.mark.parametrize("size", [1, 2, 7, 10, 30, 210]) -def test_fft(ctx_factory, size): - ctx = ctx_factory() - queue = cl.CommandQueue(ctx) +def test_fft(actx_factory, size): + actx = actx_factory() + inp = np.arange(size, dtype=np.complex64) - inp_dev = cla.to_device(queue, inp) + inp_dev = actx.from_numpy(inp) out = fft(inp) fft_func = loopy_fft(inp.shape, inverse=False, complex_dtype=inp.dtype.type) - evt, (out_dev,) = fft_func(queue, y=inp_dev) - assert np.allclose(out_dev.get(), out) + evt, (out_dev,) = fft_func(actx.queue, y=inp_dev) + + assert np.allclose(actx.to_numpy(out_dev), out) + +# }}} + + +if __name__ == "__main__": + if len(sys.argv) > 1: + exec(sys.argv[1]) + else: + pytest.main([__file__]) + +# vim: fdm=marker