Skip to content

Commit

Permalink
fix: rewrite helper functions to have option to do derivative or not
Browse files Browse the repository at this point in the history
  • Loading branch information
Radascript committed May 19, 2022
1 parent 284fdf6 commit b0107ad
Showing 1 changed file with 37 additions and 24 deletions.
61 changes: 37 additions & 24 deletions src/emgdecompy/emgdecompy.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def normalize(w):
return w


def apply_contrast_fun_router(w, fun=der_skew):
def apply_contrast_fun_router(w, fun=skew, der=False):
"""
Takes first derivitive and applies contrast function to w with map()
for Step 2a of fixed point algorithm
Expand All @@ -273,32 +273,28 @@ def apply_contrast_fun_router(w, fun=der_skew):
Example
--------
>>> w = np.array([1, 2, 3])
>>> fun = der_skew
>>> fun = skew
>>> apply_contrast_fun_router(w, fun)
>>> array([1, 4, 9])
"""

# an_array = np.array([1, 2, 3])

# def double(x):
# return x * 2

# mapped_array = double(an_array)
# print(mapped_array)

rtn = fun(w)
rtn = fun(w, der)
return rtn


def der_skew(x):
def skew(x, der=False):
"""
Takes first derivitive and applies contrast function to w
function = x^3 / 3
Applies contrast function (if der=False) or
first derivative of contrast function (if der=True)
to w
skew = x^3 / 3
Parameters
----------
x: float
number to apply contrast function to
der: boolean
whether to apply derivative (or base version)
Returns
-------
Expand All @@ -313,20 +309,27 @@ def der_skew(x):
"""

# first derivitive of x^3/3 = x^2
rtn = x ** 2
if der == True:
rtn = x ** 2
else:
rtn = (x ** 3) / 3

return rtn


def der_log_cosh(x):
def log_cosh(x, der=False):
"""
Takes first derivitive and applies contrast function to w
Applies contrast function (if der=False) or
first derivative of contrast function (if der=True)
to w
function = log(cosh(x))
Parameters
----------
x: float
number to apply contrast function to
der: boolean
whether to apply derivative (or base version)
Returns
-------
Expand All @@ -335,26 +338,33 @@ def der_log_cosh(x):
Example
--------
>>> x = 0.5
>>> der_log_cosh(x)
>>> 0.46211715726000974
>>> x = 4
>>> log_cosh(x)
>>> 16
"""

# first derivitive of log(cosh(x)) = tanh(x)
rtn = np.tanh(x)
if der == True:
rtn = np.tanh(x)
else:
rtn = np.log(np.cosh(x))

return rtn


def der_exp_sq(x):
def exp_sq(x, der=False):
"""
Takes first derivitive and applies contrast function to w
Applies contrast function (if der=False) or
first derivative of contrast function (if der=True)
to w
function = exp((-x^2/2))
Parameters
----------
x: float
number to apply contrast function to
der: boolean
whether to apply derivative (or base version)
Returns
-------
Expand All @@ -370,7 +380,10 @@ def der_exp_sq(x):

# first derivitive of exp((-x^2/2)) = -e^(-x^2/2) x
pwr_x = -(x ** 2) / 2
rtn = -(np.exp(pwr_x) * x)
if der == True:
rtn = -(np.exp(pwr_x) * x)
else:
rtn = np.exp(pwr_x)

return rtn

Expand Down

0 comments on commit b0107ad

Please sign in to comment.