Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

idwtn should allow coefficients to be set as None #291

Merged
merged 4 commits into from
Mar 9, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions pywt/_multidim.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def idwt2(coeffs, wavelet, mode='symmetric', axes=(-2, -1)):
----------
coeffs : tuple
(cA, (cH, cV, cD)) A tuple with approximation coefficients and three
details coefficients 2D arrays like from `dwt2()`
details coefficients 2D arrays like from `dwt2()`. If any of these
components are set to ``None``, it will be treated as zeros.
wavelet : Wavelet object or name string, or 2-tuple of wavelets
Wavelet to use. This can also be a tuple containing a wavelet to
apply along each axis in ``axes``.
Expand Down Expand Up @@ -113,10 +114,6 @@ def idwt2(coeffs, wavelet, mode='symmetric', axes=(-2, -1)):
raise ValueError("Expected 2 axes")

coeffs = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH}

# drop the keys corresponding to value = None
coeffs = dict((k, v) for k, v in coeffs.items() if v is not None)

return idwtn(coeffs, wavelet, mode, axes)


Expand Down Expand Up @@ -224,8 +221,8 @@ def idwtn(coeffs, wavelet, mode='symmetric', axes=None):
Parameters
----------
coeffs: dict
Dictionary as in output of `dwtn`. Missing or None items
will be treated as zeroes.
Dictionary as in output of ``dwtn``. Missing or ``None`` items
will be treated as zeros.
wavelet : Wavelet object or name string, or tuple of wavelets
Wavelet to use. This can also be a tuple containing a wavelet to
apply along each axis in ``axes``.
Expand All @@ -247,6 +244,10 @@ def idwtn(coeffs, wavelet, mode='symmetric', axes=None):
Original signal reconstructed from input data.

"""

# drop the keys corresponding to value = None
coeffs = dict((k, v) for k, v in coeffs.items() if v is not None)

# Raise error for invalid key combinations
coeffs = _fix_coeffs(coeffs)

Expand Down
38 changes: 34 additions & 4 deletions pywt/tests/test_multidim.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,6 @@ def test_error_on_invalid_keys():
d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH, 'ff': LH}
assert_raises(ValueError, pywt.idwtn, d, wavelet)

# a key whose value is None
d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': None}
assert_raises(ValueError, pywt.idwtn, d, wavelet)

# mismatched key lengths
d = {'a': LL, 'da': HL, 'ad': LH, 'dd': HH}
assert_raises(ValueError, pywt.idwtn, d, wavelet)
Expand Down Expand Up @@ -268,6 +264,40 @@ def test_idwtn_axes():
assert_allclose(pywt.idwtn(coefs, 'haar', axes=(1, 1)), data, atol=1e-14)


def test_idwt2_none_coeffs():
data = np.array([[0, 1, 2, 3],
[1, 1, 1, 1],
[1, 4, 2, 8]])
data = data + 1j*data # test with complex data
cA, (cH, cV, cD) = pywt.dwt2(data, 'haar', axes=(1, 1))

# verify setting coefficients to None is the same as zeroing them
cD = np.zeros_like(cD)
result_zeros = pywt.idwt2((cA, (cH, cV, cD)), 'haar', axes=(1, 1))

cD = None
result_none = pywt.idwt2((cA, (cH, cV, cD)), 'haar', axes=(1, 1))

assert_equal(result_zeros, result_none)


def test_idwtn_none_coeffs():
data = np.array([[0, 1, 2, 3],
[1, 1, 1, 1],
[1, 4, 2, 8]])
data = data + 1j*data # test with complex data
coefs = pywt.dwtn(data, 'haar', axes=(1, 1))

# verify setting coefficients to None is the same as zeroing them
coefs['dd'] = np.zeros_like(coefs['dd'])
result_zeros = pywt.idwtn(coefs, 'haar', axes=(1, 1))

coefs['dd'] = None
result_none = pywt.idwtn(coefs, 'haar', axes=(1, 1))

assert_equal(result_zeros, result_none)


def test_idwt2_axes():
data = np.array([[0, 1, 2, 3],
[1, 1, 1, 1],
Expand Down