Skip to content

Commit

Permalink
MAINT: Remove six
Browse files Browse the repository at this point in the history
`_check_mode_input` was the only user of six. It was only used to check the
input of `MODES.from_object`, and raised a `TypeError` if a numeric mode was
out of range or a mode object was not a string type. `MODES.from_object` raises
a `ValueError` in these situations, and when the mode name string is not found.

Now, `MODES.from_object` is wrapped in a function `_try_mode` that returns its
result and fixes the type of any exception raised to keep the API the same.

This is a candidate for cleanup after the API break. Part of issue PyWavelets#61
  • Loading branch information
Kai Wohlfahrt authored and aaren committed Aug 3, 2015
1 parent b671610 commit 78e9bd0
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 51 deletions.
34 changes: 0 additions & 34 deletions pywt/_tools/six.py

This file was deleted.

30 changes: 13 additions & 17 deletions pywt/src/_pywt.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ import warnings
import numpy as np
cimport numpy as np

from pywt._tools.six import string_types


ctypedef fused data_t:
np.float32_t
Expand Down Expand Up @@ -91,9 +89,9 @@ class _Modes(object):

def from_object(self, mode):
if isinstance(mode, int):
m = mode
if m <= c_wt.MODE_INVALID or m >= c_wt.MODE_MAX:
if mode <= c_wt.MODE_INVALID or mode >= c_wt.MODE_MAX:
raise ValueError("Invalid mode.")
m = mode
else:
try:
m = getattr(MODES, mode)
Expand Down Expand Up @@ -642,7 +640,6 @@ def dwt(object data, object wavelet, object mode='sym'):
[-0.70710678 -0.70710678 -0.70710678]
"""
_check_mode_input(mode)
# accept array_like input; make a copy to ensure a contiguous array
dt = _check_dtype(data)
data = np.array(data, dtype=dt)
Expand All @@ -656,7 +653,7 @@ def _dwt(np.ndarray[data_t, ndim=1] data, object wavelet, object mode='sym'):
cdef c_wt.MODE mode_

w = c_wavelet_from_object(wavelet)
mode_ = MODES.from_object(mode)
mode_ = _try_mode(mode)

data = np.array(data)
output_len = c_wt.dwt_buffer_length(data.size, w.dec_len, mode_)
Expand Down Expand Up @@ -719,7 +716,6 @@ def dwt_coeff_len(data_len, filter_len, mode='sym'):
"""
cdef index_t filter_len_

_check_mode_input(mode)
if isinstance(filter_len, Wavelet):
filter_len_ = filter_len.dec_len
else:
Expand All @@ -730,18 +726,20 @@ def dwt_coeff_len(data_len, filter_len, mode='sym'):
if filter_len_ < 1:
raise ValueError("Value of filter_len must be greater than zero.")

return c_wt.dwt_buffer_length(data_len, filter_len_, MODES.from_object(mode))
return c_wt.dwt_buffer_length(data_len, filter_len_, _try_mode(mode))


###############################################################################
# idwt


def _check_mode_input(mode):
valid_ints = range(len(MODES.modes))
if not ((isinstance(mode, string_types)) or (mode in valid_ints)):
raise TypeError("`mode` should be a string, unicode or a pywt.MODES "
"object.")
def _try_mode(mode):
try:
return MODES.from_object(mode)
except ValueError as e:
if "Unknown mode name" in str(e):
raise
raise TypeError("Invalid mode: {}".format(mode))


def _check_dtype(data):
Expand Down Expand Up @@ -788,7 +786,6 @@ def idwt(cA, cD, object wavelet, object mode='sym', int correct_size=0):
Single level reconstruction of signal from given coefficients.
"""
_check_mode_input(mode)
# accept array_like input; make a copy to ensure a contiguous array

if cA is None and cD is None:
Expand Down Expand Up @@ -826,7 +823,7 @@ def _idwt(np.ndarray[data_t, ndim=1, mode="c"] cA,
cdef c_wt.MODE mode_

w = c_wavelet_from_object(wavelet)
mode_ = MODES.from_object(mode)
mode_ = _try_mode(mode)

cdef np.ndarray[data_t, ndim=1, mode="c"] rec
cdef index_t rec_len
Expand Down Expand Up @@ -1049,9 +1046,8 @@ def _downcoef(part, np.ndarray[data_t, ndim=1, mode="c"] data,
cdef Wavelet w
cdef c_wt.MODE mode_

_check_mode_input(mode)
w = c_wavelet_from_object(wavelet)
mode_ = MODES.from_object(mode)
mode_ = _try_mode(mode)

if part not in ('a', 'd'):
raise ValueError("Argument 1 must be 'a' or 'd', not '%s'." % part)
Expand Down
1 change: 1 addition & 0 deletions pywt/tests/test_modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def test_available_modes():

def test_invalid_modes():
x = np.arange(4)
assert_raises(ValueError, pywt.dwt, x, 'db2', 'unknown')
assert_raises(TypeError, pywt.dwt, x, 'db2', -1)
assert_raises(TypeError, pywt.dwt, x, 'db2', 7)
assert_raises(TypeError, pywt.dwt, x, 'db2', None)
Expand Down

0 comments on commit 78e9bd0

Please sign in to comment.