Skip to content

Commit

Permalink
fix: rearrange arguments in decomposion(), separation(), refinement()…
Browse files Browse the repository at this point in the history
… functions
  • Loading branch information
Radascript committed May 20, 2022
1 parent 76e0f59 commit ddad019
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions src/emgdecompy/emgdecompy.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ def separation(z, B, Tolx=10e-4, fun=skew, max_iter=10):
return w_curr


def refinement(w_i, z, max_iter=10, th_sil=0.9, name=""):
def refinement(w_i, z, th_sil=0.9, name="", max_iter=10):
"""
Parameters
----------
Expand Down Expand Up @@ -580,7 +580,9 @@ def refinement(w_i, z, max_iter=10, th_sil=0.9, name=""):
return w_i # May change implementation to update B here


def decomposition(x, M=64, th_sil=0.9, name="", Tolx=10e-4, fun=skew, max_iter=10):
def decomposition(
x, M=64, Tolx=10e-4, fun=skew, max_iter_sep=10, th_sil=0.9, name="", max_iter_ref=10
):
"""
Main function duplicating algorithm
Performs decomposition of input of observations
Expand All @@ -589,16 +591,19 @@ def decomposition(x, M=64, th_sil=0.9, name="", Tolx=10e-4, fun=skew, max_iter=1
----------
x: numpy.ndarray
the input matrix
th_sil: float
threshold silloutte score
Tolx: float
Tolx for element-wise comparison in separation
fun: function
Contrast function to use
skew, og_cosh or exp_sq
max_iter: int > 1
max_iter_sep: int > 1
maximum iterations for Fixed Point Algorithm
when to stop if it doesn't converge
th_sil: float
threshold silloutte score,
max_iter_ref: int > 1
maximum iterations for Refining process
when to stop if it doesn't converge
name: str
name to be used when saving pulse trains
Expand Down Expand Up @@ -628,9 +633,9 @@ def decomposition(x, M=64, th_sil=0.9, name="", Tolx=10e-4, fun=skew, max_iter=1
for i in range(M):

# Separate
w_i = separation(z, B, Tolx, max_iter)
w_i = separation(z, B, Tolx, fun, max_iter_sep)

# Refine
B[:i] = refinement(w_i, z, max_iter=10, th_sil=0.9, name="")
B[:i] = refinement(w_i, z, max_iter_ref, th_sil, name)

return B

0 comments on commit ddad019

Please sign in to comment.