diff --git a/src/emgdecompy/contrast.py b/src/emgdecompy/contrast.py index ab1d436..593241b 100644 --- a/src/emgdecompy/contrast.py +++ b/src/emgdecompy/contrast.py @@ -1,73 +1,78 @@ import numpy as np +import warnings -def skew(x, der=False): +def skew(w, der=False): """ Applies contrast function (if der=False) or first derivative of contrast function (if der=True) to w. - skew = x^3 / 3 + skew = w^3 / 3 Parameters ---------- - x: float - Number to apply contrast function to. + w: np.array + Array to apply contrast function to. der: boolean Whether to apply derivative (or base version). Returns ------- - float - Float with contrast function applied. + np.array + Array with contrast function applied, same shape as w. Examples -------- - >>> x = 4 - >>> skew(x, der=True) - 16 + >>> w = np.array([1, 2, 3, 800]) + >>> skew(w, der=True) + array([1, 4, 9, 640000]) """ - # first derivative of x^3/3 = x^2 + # first derivitive of x^3/3 = x^2 if der == True: - rtn = x ** 2 + rtn = w ** 2 else: - rtn = (x ** 3) / 3 + rtn = (w ** 3) / 3 return rtn -def log_cosh(x, der=False): +def log_cosh(w, der=False): """ Applies contrast function (if der=False) or first derivative of contrast function (if der=True) - to w. - function = log(cosh(x)) + to each element of w. + function = log(cosh(w)) + Parameters ---------- - x: float - Number to apply contrast function to. + w: np.array + Array to apply contrast function to. der: boolean Whether to apply derivative (or base version). + Returns ------- - float - Float with contrast function applied. + np.array + Array with contrast function applied, same shape as w. + Examples -------- - >>> x = 4 - >>> log_cosh(x) - 3.3071882258129506 + >>> w = np.array([1, 2, 3, 800]) + >>> log_cosh(w) + array([4.33780830e-01, 1.32500275e+00, 2.30932850e+00, 7.99300000e+02]) """ - # first derivative of log(cosh(x)) = tanh(x) + # First derivitive of log(cosh(x)) = tanh(x) if der == True: - rtn = np.tanh(x) + rtn = np.tanh(w) else: - x = abs(x) - if x > 710: # cosh(x) breaks for abs(x) > 710 - rtn = x - 0.7 - else: - rtn = np.log(np.cosh(x)) + warnings.filterwarnings( + "ignore" + ) # To avoid warning from np.cosh(w) for values over 710 + x = abs(w) + rtn = np.where(w > 710, w - 0.7, np.log(np.cosh(w))) + warnings.resetwarnings() return rtn @@ -77,28 +82,28 @@ def exp_sq(x, der=False): Applies contrast function (if der=False) or first derivative of contrast function (if der=True) to w. - function = exp((-x^2/2)) + exp_sq = exp((-x^2/2)) Parameters ---------- - x: float - Number to apply contrast function to. + w: np.array + Array to apply contrast function to. der: boolean Whether to apply derivative (or base version). Returns ------- - float - Float with contrast function applied. + np.array + Array with contrast function applied, same shape as w. Examples -------- - >>> x = 4 - >>> exp_sq(4, der=True) - -0.0013418505116100474 + >>> w = np.array([1, 2, 3, 800]) + >>> exp_sq(w, der=False) + array([0.60653066, 0.13533528, 0.011109, 0.]) """ - # first derivative of exp((-x^2/2)) = -e^(-x^2/2) x + # first derivitive of exp((-x^2/2)) = -e^(-x^2/2) x pwr_x = -(x ** 2) / 2 if der == True: rtn = -(np.exp(pwr_x) * x) @@ -130,8 +135,15 @@ def apply_contrast(w, fun=skew, der=False): -------- >>> w = np.array([1, 2, 3]) >>> fun = skew - >>> apply_contrast(w, fun) + >>> apply_contrast(w, fun, True) array([1, 4, 9]) + + >>> w = np.array([0.01, 0.1, 1, 10, 100, 1000]) + >>> fun = log_cosh + >>> apply_contrast(w, fun) + array([4.99991667e-05, 4.99168882e-03, 4.33780830e-01, 9.30685282e+00, + 9.93068528e+01, 9.99300000e+02]) + """ rtn = fun(w, der)