Skip to content

Commit

Permalink
feat: add more print statements to refinement
Browse files Browse the repository at this point in the history
  • Loading branch information
danfke committed Jun 13, 2022
1 parent 532fde3 commit 2378e67
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions src/emgdecompy/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,26 +459,33 @@ def refinement(
# Estimate pulse train pt_n with peak detection applied to the square of the source vector
s_i2 = np.square(s_i)

# Peak-finding algorithm
peak_indices, _ = find_peaks(
s_i2, distance=l
)

# b. Use KMeans to separate large peaks from relatively small peaks, which are discarded
kmeans = KMeans(n_clusters=2, random_state=random_seed)
kmeans.fit(s_i2[peak_indices].reshape(-1, 1))

# Determine which cluster contains large peaks
centroid_a = np.argmax(
kmeans.cluster_centers_
) # Determine which cluster contains large peaks
)

# Determine which peaks are large (part of cluster a)
peak_a = ~kmeans.labels_.astype(
bool
) # Determine which peaks are large (part of cluster a)
)

if centroid_a == 1: # If cluster a corresponds to kmeans label 1, change indices correspondingly
peak_a = ~peak_a


# Get the indices of the peaks in cluster a
peak_indices_a = peak_indices[
peak_a
] # Get the indices of the peaks in cluster a
]

# Create pulse train, where values are 0 except for when MU fires, which have values of 1
# pt_n = np.zeros_like(s_i2)
Expand All @@ -488,6 +495,7 @@ def refinement(
isi = np.diff(peak_indices_a) # inter-spike intervals
cv_prev = cv_curr
cv_curr = variation(isi)

if np.isnan(cv_curr): # Translate nan to 0
cv_curr = 0

Expand All @@ -496,7 +504,7 @@ def refinement(
):
break

elif iter != max_iter - 1:
elif iter != max_iter - 1: # If we are not on the last iteration
# d. Update separation vector for next iteration unless refinement doesn't converge
j = len(peak_indices_a)
w_i = (1 / j) * z[:, peak_indices_a].sum(axis=1)
Expand All @@ -506,12 +514,18 @@ def refinement(
s_i2, peak_indices_a
)
pnr_score = pnr(s_i2, peak_indices_a)

if isi.size > 0 and verbose:
print(f"Cov(ISI): {cv_curr / isi.mean() * 100}")

if verbose:
print(f"PNR: {pnr_score}")
print(f"SIL: {sil}")
print(f"cv_curr = {cv_curr}")
print(f"cv_prev = {cv_prev}")
print(f"cv_prev = {cv_prev}")

if cv_curr > cv_prev:
print(f"Refinement converged after {iter} iterations.")

if sil_pnr:
score = sil # If using SIL as acceptance criterion
Expand Down

0 comments on commit 2378e67

Please sign in to comment.