Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Daubechies Wavelets #10

Merged
merged 36 commits into from
Aug 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
3569df2
checkpoint
sjperkins Jul 20, 2020
f4fb46a
More tests
sjperkins Jul 20, 2020
88b3675
fix
sjperkins Jul 20, 2020
1d38f46
Add row slicing
sjperkins Jul 24, 2020
a2ec875
Generic axis slicing. Segfaults in some cases
sjperkins Jul 29, 2020
29cb0d3
This seems to resolve the segfaults
sjperkins Jul 29, 2020
223f777
Modularise
sjperkins Jul 29, 2020
77779f2
Test flags are the same
sjperkins Jul 29, 2020
c8602c2
Seems to fix segfaults
sjperkins Jul 30, 2020
ff96660
Contiguity handling
sjperkins Jul 30, 2020
999873a
Remove debug in jit decorators
sjperkins Jul 30, 2020
1ada536
tidy up
sjperkins Jul 30, 2020
5a7ed5a
Mode enums
sjperkins Jul 30, 2020
fbb2f48
Outline, segfaults when compiling
sjperkins Jul 30, 2020
fa2780c
utils.py -> numba_llvm.py
sjperkins Jul 30, 2020
a502696
Use step in downsampling convolution
sjperkins Jul 30, 2020
29ceecc
Fixes and more tests, but not working yet
sjperkins Jul 30, 2020
40bdbdf
numba_llvm.py => intrinsics.py
sjperkins Jul 30, 2020
aa00e35
Extend test_slice_axis
sjperkins Jul 31, 2020
9fe8e94
dwt working
sjperkins Aug 12, 2020
b2b5498
simplify
sjperkins Aug 12, 2020
46dc0fe
idwt_axis working
sjperkins Aug 12, 2020
f944144
Working idwtn
sjperkins Aug 12, 2020
d945fd9
Argument promotion fixups
sjperkins Aug 12, 2020
443f29b
More test case
sjperkins Aug 12, 2020
1abf2f3
Complain about periodisation downsampling
sjperkins Aug 12, 2020
0afedf1
cache decorators
sjperkins Aug 12, 2020
6df99e9
waverecn in place, coeff trimming not yet done
sjperkins Aug 14, 2020
3e23da8
Add extent to slice_axis
sjperkins Aug 17, 2020
2162daf
Clip coefficients in idwt_axis
sjperkins Aug 17, 2020
e3bb57d
cache=True, remove level arg from waverecn
sjperkins Aug 18, 2020
a72df74
Add axes and level tests to waverecn/wavedecn tests
sjperkins Aug 18, 2020
00e87fb
Add fastmath decorators
sjperkins Aug 18, 2020
a2b6388
Use numba typed Lists
sjperkins Aug 18, 2020
8be255a
test zeropad and levels
sjperkins Aug 20, 2020
4d353dc
Disable wavelet level warnings for now
sjperkins Aug 20, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
288 changes: 288 additions & 0 deletions pfb/test/test_wavelets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,288 @@
from itertools import product

import numba
from numba.cpython.unsafe.tuple import tuple_setitem
import numpy as np
from numpy.testing import assert_array_almost_equal, assert_array_equal
import pytest

from pfb.wavelets.wavelets import (dwt, dwt_axis,
idwt, idwt_axis,
dwt_max_level,
str_to_int,
coeff_product,
promote_axis,
promote_level,
discrete_wavelet,
wavedecn, waverecn)
from pfb.wavelets.modes import (Modes,
promote_mode,
mode_str_to_enum)

from pfb.wavelets.intrinsics import slice_axis


convert_mode = numba.njit(lambda s: mode_str_to_enum(s))


@pytest.mark.parametrize("ndim", [1, 2, 3])
@pytest.mark.parametrize("extent", [-1, 0, 1])
def test_slice_axis(ndim, extent):
@numba.njit
def fn(a, index, axis=1, extent=None):
return slice_axis(a, index, axis, extent)

A = np.random.random(np.random.randint(4, 10, size=ndim))
assert A.ndim == ndim

for axis in range(A.ndim):
# Randomly choose indexes within the array
tup_idx = tuple(np.random.randint(0, d) for d in A.shape)
# Replace index with slice along desired axis
ax_size = A.shape[axis]

ext = None if extent == 0 else ax_size + extent
slice_idx = tuple(slice(ext) if a == axis
else i for a, i in enumerate(tup_idx))

As = A[slice_idx]
B = fn(A, tup_idx, axis, ext)

assert_array_equal(As, B)

if ndim == 1:
assert B.flags.c_contiguous is As.flags.c_contiguous
assert B.flags.f_contiguous is As.flags.f_contiguous
assert B.flags.aligned is As.flags.aligned
assert B.flags.writeable is As.flags.writeable
assert B.flags.writebackifcopy is As.flags.writebackifcopy
assert B.flags.updateifcopy is As.flags.updateifcopy

# TODO(sjperkins)
# Why is owndata True in the
# case of the numba intrinsic, but
# not in the case of numpy?
assert B.flags.owndata is not (extent or As.flags.owndata)
else:
assert B.flags == As.flags

# Check that modifying the numba slice
# modifies the numpy slice
B[:] = np.arange(B.shape[0])
assert_array_equal(As, B)


def test_internal_slice_axis():
@numba.njit
def fn(A):
for axis in range(A.ndim):
for i in np.ndindex(*tuple_setitem(A.shape, axis, 1)):
S = slice_axis(A, i, axis, None)

if S.flags.c_contiguous != (S.itemsize == S.strides[0]):
raise ValueError("contiguity flag doesn't match layout")

fn(np.random.random((8, 9, 10)))


@pytest.mark.parametrize("repeat", range(5))
def test_coeff_product(repeat):
res = coeff_product('ad', repeat=repeat)
coeffs = [''.join(c) for c in product('ad', repeat=repeat)]
assert list(res) == coeffs


def test_str_to_int():
assert str_to_int("111") == 111
assert str_to_int("23") == 23
assert str_to_int("3") == 3


def test_promote_mode():
assert [Modes.symmetric] == list(promote_mode("symmetric", 1))
assert [Modes.symmetric]*3 == list(promote_mode("symmetric", 3))

assert [Modes.symmetric] == list(promote_mode(["symmetric"], 1))
assert [Modes.symmetric] == list(promote_mode(("symmetric",), 1))

with pytest.raises(ValueError):
assert [Modes.symmetric] == list(promote_mode(["symmetric"], 2))

list_inputs = ["symmetric", "reflect"]
tuple_inputs = tuple(list_inputs)
result_enums = [Modes.symmetric, Modes.reflect]

assert result_enums == list(promote_mode(list_inputs, 2))
assert result_enums == list(promote_mode(tuple_inputs, 2))

with pytest.raises(ValueError):
assert result_enums == list(promote_mode(list_inputs, 3))

with pytest.raises(ValueError):
assert result_enums == list(promote_mode(list_inputs, 1))


def test_promote_axis():
assert [0] == list(promote_axis(0, 1))
assert [0] == list(promote_axis([0], 1))
assert [0] == list(promote_axis((0,), 1))

with pytest.raises(ValueError):
assert [0, 1] == list(promote_axis((0, 1), 1))

assert [0, 1] == list(promote_axis((0, 1), 2))
assert [0, 1] == list(promote_axis([0, 1], 2))

assert [0, 1] == list(promote_axis((0, 1), 3))


@pytest.mark.parametrize("data", [500, 100, 12])
@pytest.mark.parametrize("filter", [500, 100, 24])
def test_dwt_max_level(data, filter):
pywt = pytest.importorskip("pywt")
assert pywt.dwt_max_level(data, filter) == dwt_max_level(data, filter)


@pytest.mark.parametrize("wavelet", ["db1", "db4", "db5"])
def test_discrete_wavelet(wavelet):
pfb_wave = discrete_wavelet(wavelet)

pywt = pytest.importorskip("pywt")
py_wave = pywt.Wavelet(wavelet)

# assert py_wave.support_width == pfb_wave.support_width
assert py_wave.orthogonal == pfb_wave.orthogonal
assert py_wave.biorthogonal == pfb_wave.biorthogonal
#assert py_wave.compact_support == pfb_wave.compact_support
assert py_wave.family_name == pfb_wave.family_name
assert py_wave.short_family_name == pfb_wave.short_name
assert py_wave.vanishing_moments_phi == pfb_wave.vanishing_moments_phi
assert py_wave.vanishing_moments_psi == pfb_wave.vanishing_moments_psi

assert_array_almost_equal(py_wave.rec_lo, pfb_wave.rec_lo)
assert_array_almost_equal(py_wave.dec_lo, pfb_wave.dec_lo)
assert_array_almost_equal(py_wave.rec_hi, pfb_wave.rec_hi)
assert_array_almost_equal(py_wave.dec_hi, pfb_wave.dec_hi)


@pytest.mark.parametrize("wavelet", ["db1", "db4", "db5"])
@pytest.mark.parametrize("data_shape", [(13,), (12, 7)])
@pytest.mark.parametrize("mode", ["symmetric", "zero"])
def test_dwt_idwt_axis(wavelet, mode, data_shape):
pywt = pytest.importorskip("pywt")
data = np.random.random(size=data_shape)
enum_mode = convert_mode(mode)

pywt_dwt_axis = pywt._dwt.dwt_axis
pywt_idwt_axis = pywt._dwt.idwt_axis

pywt_wavelet = pywt.Wavelet(wavelet)
pywt_mode = pywt.Modes.from_object(mode)

wavelet = discrete_wavelet(wavelet)

for axis in reversed(range(len(data_shape))):
# Deconstruct
ca, cd = dwt_axis(data, wavelet, enum_mode, axis)
pywt_ca, pywt_cd = pywt_dwt_axis(data, pywt_wavelet, pywt_mode, axis)
assert_array_almost_equal(ca, pywt_ca)
assert_array_almost_equal(cd, pywt_cd)

# Reconstruct with both approximation and detail
pywt_out = pywt_idwt_axis(ca, cd, pywt_wavelet, pywt_mode, axis)
output = idwt_axis(ca, cd, wavelet, enum_mode, axis)
assert_array_almost_equal(output, pywt_out)

# Reconstruct with approximation only
pywt_out = pywt_idwt_axis(ca, None, pywt_wavelet, pywt_mode, axis)
output = idwt_axis(ca, None, wavelet, enum_mode, axis)
assert_array_almost_equal(output, pywt_out)

# Reconstruct with detail only
pywt_out = pywt_idwt_axis(None, cd, pywt_wavelet, pywt_mode, axis)
output = idwt_axis(None, cd, wavelet, enum_mode, axis)
assert_array_almost_equal(output, pywt_out)


def test_dwt_idwt():
pywt = pytest.importorskip("pywt")
data = np.random.random((5, 8, 11))

res = dwt(data, "db1", "symmetric")
pywt_res = pywt.dwtn(data, "db1", "symmetric")
for k, v in res.items():
assert_array_almost_equal(v, pywt_res[k])

res = dwt(data, "db1", "symmetric", 1)
pywt_res = pywt.dwtn(data, "db1", "symmetric", (1,))
for k, v in res.items():
assert_array_almost_equal(v, pywt_res[k])

res = dwt(data, ("db1", "db2"), ("symmetric", "symmetric"), (0, 1))
pywt_res = pywt.dwtn(data, ("db1", "db2"), ("symmetric", "symmetric"), (0, 1))
for k, v in res.items():
assert_array_almost_equal(v, pywt_res[k])

output = idwt(res, ("db1", "db2"), ("symmetric", "symmetric"), (0, 1))
pywt_out = pywt.idwtn(pywt_res, ("db1", "db2"), ("symmetric", "symmetric"), (0, 1))
assert_array_almost_equal(output, pywt_out)


@pytest.mark.parametrize("data_shape", [(50, 24, 63)])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we test this with multiple levels as well please?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it should be possible to parametrize this on level.

@pytest.mark.parametrize("complex_data", [True, False])
@pytest.mark.parametrize("level", list(range(10)))
@pytest.mark.parametrize("mode", ["symmetric", "zero"])
@pytest.mark.parametrize("wavelet", ["db1", "db4", "db5"])
def test_wavedecn_waverecn(data_shape, wavelet, mode, level, complex_data):
pywt = pytest.importorskip("pywt")
data = np.random.random(data_shape)

if complex_data:
data = data + np.random.random(data_shape) * 1j

out = pywt.wavedecn(data, wavelet, mode)
a, coeffs = wavedecn(data, wavelet, mode)

assert_array_almost_equal(a, out[0])

for d1, d2 in zip(out[1:], coeffs):
assert list(d1.keys()) == list(d2.keys())

for k, v in d1.items():
assert_array_almost_equal(v, d2[k])

pywt_rec = pywt.waverecn(out, wavelet, mode)
rec = waverecn(a, coeffs, wavelet, mode)
assert_array_almost_equal(pywt_rec, rec)

out = pywt.wavedecn(data, wavelet, mode, axes=(1, 2))
a, coeffs = wavedecn(data, wavelet, mode, axis=(1, 2))

assert_array_almost_equal(a, out[0])

for d1, d2 in zip(out[1:], coeffs):
assert list(d1.keys()) == list(d2.keys())

for k, v in d1.items():
assert_array_almost_equal(v, d2[k])

pywt_rec = pywt.waverecn(out, wavelet, mode, axes=(1, 2))
rec = waverecn(a, coeffs, wavelet, mode, axis=(1, 2))
assert_array_almost_equal(pywt_rec, rec)

# Test various levels of decomposition
out = pywt.wavedecn(data, wavelet, mode, level=level, axes=(1, 2))
a, coeffs = wavedecn(data, wavelet, mode, level=level, axis=(1, 2))

assert_array_almost_equal(a, out[0])

for d1, d2 in zip(out[1:], coeffs):
assert list(d1.keys()) == list(d2.keys())

for k, v in d1.items():
assert_array_almost_equal(v, d2[k])

pywt_rec = pywt.waverecn(out, wavelet, mode, axes=(1, 2))
rec = waverecn(a, coeffs, wavelet, mode, axis=(1, 2))
assert_array_almost_equal(pywt_rec, rec)

1 change: 1 addition & 0 deletions pfb/wavelets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from pfb.wavelets.wavelets import dwt, idwt
Loading