Skip to content

Commit

Permalink
MAINT: Split thresholding.threshold
Browse files Browse the repository at this point in the history
Each mode of the threshold function was independent, so they make more sense as
individual functions. This commit is just a code rearrangement, no change to
functionality.

Addresses PyWavelets#61
  • Loading branch information
Kai Wohlfahrt authored and aaren committed Aug 3, 2015
1 parent 44e09a7 commit 6931435
Showing 1 changed file with 36 additions and 28 deletions.
64 changes: 36 additions & 28 deletions pywt/thresholding.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,39 @@

import numpy as np

def soft(data, value, substitute=0):
data = np.asarray(data)
mvalue = -value

cond_less = np.less(data, value)
cond_greater = np.greater(data, mvalue)

output = np.where(cond_less & cond_greater, substitute, data)
output = np.where(cond_less, output + value, output)
output = np.where(cond_greater, output - value, output)
return output

def hard(data, value, substitute=0):
data = np.asarray(data)
mvalue = -value

cond = np.less(data, value)
cond &= np.greater(data, mvalue)

return np.where(cond, substitute, data)

def greater(data, value, substitute=0):
data = np.asarray(data)
return np.where(np.less(data, value), substitute, data)

def less(data, value, substitute=0):
data = np.asarray(data)
return np.where(np.greater(data, value), substitute, data)

thresholding_options = {'soft': soft,
'hard': hard,
'greater': greater,
'less': less}

def threshold(data, value, mode='soft', substitute=0):
"""
Expand Down Expand Up @@ -67,34 +100,9 @@ def threshold(data, value, mode='soft', substitute=0):
array([ 1. , 1.5, 2. , 0. , 0. , 0. , 0. ])
"""
data = np.asarray(data)

if mode == 'soft':
mvalue = -value

cond_less = np.less(data, value)
cond_greater = np.greater(data, mvalue)

output = np.where(cond_less & cond_greater, substitute, data)
output = np.where(cond_less, output + value, output)
output = np.where(cond_greater, output - value, output)

elif mode == 'hard':
mvalue = -value

cond = np.less(data, value)
cond &= np.greater(data, mvalue)

output = np.where(cond, substitute, data)

elif mode == 'greater':
output = np.where(np.less(data, value), substitute, data)

elif mode == 'less':
output = np.where(np.greater(data, value), substitute, data)

else:
try:
return thresholding_options[mode](data, value, substitute)
except KeyError:
raise ValueError("The mode parameter only takes value among "
"{'soft', 'hard', 'greater','less'}.")

return output

0 comments on commit 6931435

Please sign in to comment.