Skip to content

Commit

Permalink
port test_tools to arraycontext
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Aug 20, 2022
1 parent a02133f commit e93ea11
Showing 1 changed file with 50 additions and 18 deletions.
68 changes: 50 additions & 18 deletions test/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,33 +20,49 @@
THE SOFTWARE.
"""

import logging
logger = logging.getLogger(__name__)
import pytest
import sys

import sumpy.symbolic as sym
from sumpy.tools import (fft_toeplitz_upper_triangular,
matvec_toeplitz_upper_triangular, loopy_fft, fft)
import numpy as np

import pyopencl as cl
import pyopencl.array as cla
from pyopencl.tools import ( # noqa
pytest_generate_tests_for_pyopencl as pytest_generate_tests)
from arraycontext import pytest_generate_tests_for_array_contexts
from sumpy.array_context import ( # noqa: F401
PytestPyOpenCLArrayContextFactory, _acf)

import sumpy.symbolic as sym
from sumpy.tools import (
fft_toeplitz_upper_triangular,
matvec_toeplitz_upper_triangular,
loopy_fft,
fft)

import logging
logger = logging.getLogger(__name__)

pytest_generate_tests = pytest_generate_tests_for_array_contexts([
PytestPyOpenCLArrayContextFactory,
])

import pytest

# {{{ test_matvec_fft

def test_matvec_fft():
k = 5
v = np.random.rand(k)
x = np.random.rand(k)

rng = np.random.default_rng(42)
v = rng.random(k)
x = rng.random(k)

fft = fft_toeplitz_upper_triangular(v, x)
matvec = matvec_toeplitz_upper_triangular(v, x)

for i in range(k):
assert abs(fft[i] - matvec[i]) < 1e-14

# }}}


# {{{ test_matvec_fft_small_floats

def test_matvec_fft_small_floats():
k = 5
Expand All @@ -60,15 +76,31 @@ def test_matvec_fft_small_floats():
continue
assert abs(f) > 1e-10

# }}}


# {{{ test_fft

@pytest.mark.parametrize("size", [1, 2, 7, 10, 30, 210])
def test_fft(ctx_factory, size):
ctx = ctx_factory()
queue = cl.CommandQueue(ctx)
def test_fft(actx_factory, size):
actx = actx_factory()

inp = np.arange(size, dtype=np.complex64)
inp_dev = cla.to_device(queue, inp)
inp_dev = actx.from_numpy(inp)
out = fft(inp)

fft_func = loopy_fft(inp.shape, inverse=False, complex_dtype=inp.dtype.type)
evt, (out_dev,) = fft_func(queue, y=inp_dev)
assert np.allclose(out_dev.get(), out)
evt, (out_dev,) = fft_func(actx.queue, y=inp_dev)

assert np.allclose(actx.to_numpy(out_dev), out)

# }}}


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
else:
pytest.main([__file__])

# vim: fdm=marker

0 comments on commit e93ea11

Please sign in to comment.