Skip to content

Commit

Permalink
fix: correct parameters in refinement function
Browse files Browse the repository at this point in the history
  • Loading branch information
danfke committed May 20, 2022
1 parent ddad019 commit 6b1d198
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions src/emgdecompy/emgdecompy.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,20 +488,22 @@ def separation(z, B, Tolx=10e-4, fun=skew, max_iter=10):
return w_curr


def refinement(w_i, z, th_sil=0.9, name="", max_iter=10):
def refinement(w_i, z, i, th_sil=0.9, filepath="", max_iter=10):
"""
Parameters
----------
w_i: numpy.ndarray
current separation vector to refine
z: numpy.ndarray
xtended, whitened, centered emg data
i: int
iteration number
max_iter: int > 0
maximum iterations for refinement
th_sil: float
silhouette score threshold for accepting a separation vector
name: str
name to be used when saving pulse trains
filepath: str
filepath/name to be used when saving pulse trains
Returns
-------
Expand Down Expand Up @@ -566,7 +568,7 @@ def refinement(w_i, z, th_sil=0.9, name="", max_iter=10):

# Save pulse train
pd.DataFrame(pt_n, columns=["pulse_train"]).rename_axis("sample").to_csv(
f"{name}_PT_{i}"
f"{filepath}_PT_{i}"
)

# If silhouette score is greater than threshold, accept estimated source and add w_i to B
Expand All @@ -581,7 +583,7 @@ def refinement(w_i, z, th_sil=0.9, name="", max_iter=10):


def decomposition(
x, M=64, Tolx=10e-4, fun=skew, max_iter_sep=10, th_sil=0.9, name="", max_iter_ref=10
x, M=64, Tolx=10e-4, fun=skew, max_iter_sep=10, th_sil=0.9, filepath="", max_iter_ref=10
):
"""
Main function duplicating algorithm
Expand Down Expand Up @@ -636,6 +638,6 @@ def decomposition(
w_i = separation(z, B, Tolx, fun, max_iter_sep)

# Refine
B[:i] = refinement(w_i, z, max_iter_ref, th_sil, name)
B[:i] = refinement(w_i, z, i, max_iter_ref, th_sil, filepath)

return B

0 comments on commit 6b1d198

Please sign in to comment.