From 6b1d198a58a29e7d783131d5c53de8b64885eaaa Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 20 May 2022 10:28:59 -0700 Subject: [PATCH] fix: correct parameters in refinement function --- src/emgdecompy/emgdecompy.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/emgdecompy/emgdecompy.py b/src/emgdecompy/emgdecompy.py index 74eadb7..b7cb37c 100644 --- a/src/emgdecompy/emgdecompy.py +++ b/src/emgdecompy/emgdecompy.py @@ -488,7 +488,7 @@ def separation(z, B, Tolx=10e-4, fun=skew, max_iter=10): return w_curr -def refinement(w_i, z, th_sil=0.9, name="", max_iter=10): +def refinement(w_i, z, i, th_sil=0.9, filepath="", max_iter=10): """ Parameters ---------- @@ -496,12 +496,14 @@ def refinement(w_i, z, th_sil=0.9, name="", max_iter=10): current separation vector to refine z: numpy.ndarray xtended, whitened, centered emg data + i: int + iteration number max_iter: int > 0 maximum iterations for refinement th_sil: float silhouette score threshold for accepting a separation vector - name: str - name to be used when saving pulse trains + filepath: str + filepath/name to be used when saving pulse trains Returns ------- @@ -566,7 +568,7 @@ def refinement(w_i, z, th_sil=0.9, name="", max_iter=10): # Save pulse train pd.DataFrame(pt_n, columns=["pulse_train"]).rename_axis("sample").to_csv( - f"{name}_PT_{i}" + f"{filepath}_PT_{i}" ) # If silhouette score is greater than threshold, accept estimated source and add w_i to B @@ -581,7 +583,7 @@ def refinement(w_i, z, th_sil=0.9, name="", max_iter=10): def decomposition( - x, M=64, Tolx=10e-4, fun=skew, max_iter_sep=10, th_sil=0.9, name="", max_iter_ref=10 + x, M=64, Tolx=10e-4, fun=skew, max_iter_sep=10, th_sil=0.9, filepath="", max_iter_ref=10 ): """ Main function duplicating algorithm @@ -636,6 +638,6 @@ def decomposition( w_i = separation(z, B, Tolx, fun, max_iter_sep) # Refine - B[:i] = refinement(w_i, z, max_iter_ref, th_sil, name) + B[:i] = refinement(w_i, z, i, max_iter_ref, th_sil, filepath) return B