Skip to content

Commit

Permalink
fix: rewrite child contrast functions and docstrings to do arrays, re…
Browse files Browse the repository at this point in the history
…write parent apply_contrast function docstring accordingly
  • Loading branch information
Radascript committed May 25, 2022
1 parent bf901e2 commit 2bc1fbf
Showing 1 changed file with 51 additions and 39 deletions.
90 changes: 51 additions & 39 deletions src/emgdecompy/contrast.py
Original file line number Diff line number Diff line change
@@ -1,73 +1,78 @@
import numpy as np
import warnings


def skew(x, der=False):
def skew(w, der=False):
"""
Applies contrast function (if der=False) or
first derivative of contrast function (if der=True)
to w.
skew = x^3 / 3
skew = w^3 / 3
Parameters
----------
x: float
Number to apply contrast function to.
w: np.array
Array to apply contrast function to.
der: boolean
Whether to apply derivative (or base version).
Returns
-------
float
Float with contrast function applied.
np.array
Array with contrast function applied, same shape as w.
Examples
--------
>>> x = 4
>>> skew(x, der=True)
16
>>> w = np.array([1, 2, 3, 800])
>>> skew(w, der=True)
array([1, 4, 9, 640000])
"""

# first derivative of x^3/3 = x^2
# first derivitive of x^3/3 = x^2
if der == True:
rtn = x ** 2
rtn = w ** 2
else:
rtn = (x ** 3) / 3
rtn = (w ** 3) / 3

return rtn


def log_cosh(x, der=False):
def log_cosh(w, der=False):
"""
Applies contrast function (if der=False) or
first derivative of contrast function (if der=True)
to w.
function = log(cosh(x))
to each element of w.
function = log(cosh(w))
Parameters
----------
x: float
Number to apply contrast function to.
w: np.array
Array to apply contrast function to.
der: boolean
Whether to apply derivative (or base version).
Returns
-------
float
Float with contrast function applied.
np.array
Array with contrast function applied, same shape as w.
Examples
--------
>>> x = 4
>>> log_cosh(x)
3.3071882258129506
>>> w = np.array([1, 2, 3, 800])
>>> log_cosh(w)
array([4.33780830e-01, 1.32500275e+00, 2.30932850e+00, 7.99300000e+02])
"""

# first derivative of log(cosh(x)) = tanh(x)
# First derivitive of log(cosh(x)) = tanh(x)
if der == True:
rtn = np.tanh(x)
rtn = np.tanh(w)
else:
x = abs(x)
if x > 710: # cosh(x) breaks for abs(x) > 710
rtn = x - 0.7
else:
rtn = np.log(np.cosh(x))
warnings.filterwarnings(
"ignore"
) # To avoid warning from np.cosh(w) for values over 710
x = abs(w)
rtn = np.where(w > 710, w - 0.7, np.log(np.cosh(w)))
warnings.resetwarnings()

return rtn

Expand All @@ -77,28 +82,28 @@ def exp_sq(x, der=False):
Applies contrast function (if der=False) or
first derivative of contrast function (if der=True)
to w.
function = exp((-x^2/2))
exp_sq = exp((-x^2/2))
Parameters
----------
x: float
Number to apply contrast function to.
w: np.array
Array to apply contrast function to.
der: boolean
Whether to apply derivative (or base version).
Returns
-------
float
Float with contrast function applied.
np.array
Array with contrast function applied, same shape as w.
Examples
--------
>>> x = 4
>>> exp_sq(4, der=True)
-0.0013418505116100474
>>> w = np.array([1, 2, 3, 800])
>>> exp_sq(w, der=False)
array([0.60653066, 0.13533528, 0.011109, 0.])
"""

# first derivative of exp((-x^2/2)) = -e^(-x^2/2) x
# first derivitive of exp((-x^2/2)) = -e^(-x^2/2) x
pwr_x = -(x ** 2) / 2
if der == True:
rtn = -(np.exp(pwr_x) * x)
Expand Down Expand Up @@ -130,8 +135,15 @@ def apply_contrast(w, fun=skew, der=False):
--------
>>> w = np.array([1, 2, 3])
>>> fun = skew
>>> apply_contrast(w, fun)
>>> apply_contrast(w, fun, True)
array([1, 4, 9])
>>> w = np.array([0.01, 0.1, 1, 10, 100, 1000])
>>> fun = log_cosh
>>> apply_contrast(w, fun)
array([4.99991667e-05, 4.99168882e-03, 4.33780830e-01, 9.30685282e+00,
9.93068528e+01, 9.99300000e+02])
"""

rtn = fun(w, der)
Expand Down

0 comments on commit 2bc1fbf

Please sign in to comment.