From e54660e934526bc71341227b50676a4f9c172f9c Mon Sep 17 00:00:00 2001 From: Jasmine Ortega Date: Tue, 28 Jun 2022 14:34:17 -0700 Subject: [PATCH] docs: more viz comments --- tests/test_viz.py | 78 +++++++++++++++++++++++++++++------------------ 1 file changed, 49 insertions(+), 29 deletions(-) diff --git a/tests/test_viz.py b/tests/test_viz.py index e41f76e..72d8c14 100644 --- a/tests/test_viz.py +++ b/tests/test_viz.py @@ -5,13 +5,18 @@ import panel import pytest +# Test script for all functions defined in src/viz.py + +# Note: Fixtures are special PyTest objects calld into individual tests, +# they are useful when data is repeatedly required to test functions + @pytest.fixture def mu(): """ Create fake indices of pulse trains for two motor units. """ - # ptl + # Pulse train mu_values = np.array([[32, 90], [250, 300]]) return mu_values @@ -22,11 +27,11 @@ def fx_data(): """ Create subset of EMG data to test with. """ - # load data + # Load data gl_10 = loadmat("data/raw/GL_10.mat") raw = gl_10["SIG"] - # select two channels from raw data + # Select two channels from raw data data = raw[1, 1:3] return data @@ -88,12 +93,14 @@ def test_RMSE(): """ Run unit test on RMSE function from EMGdecomPy. """ + # Create data actual = np.array([0, 10, 50, 75]) predicted = np.array([0.1, 10.1, 50.1, 75.1]) - # hand calculate MSE + # Hand calculate mean squared error mse = np.sum((actual - predicted) ** 2) / len(actual) + # Hand calculate root mean squared error rmse = np.sqrt(mse) rmse_fx = emg.viz.RMSE(actual, predicted) @@ -106,25 +113,26 @@ def test_muap_dict(fx_data, mu): Run unit test on muap_dict function from EMGdecomPy. """ - # actual function + # Function to test fx = emg.viz.muap_dict(fx_data, mu, l=2) - # hand calculating avg - l = 2 - + # Create muap_dict using a different method raw_flat = emg.preprocessing.flatten_signal(fx_data) - + l = 2 mu = mu.squeeze() - all_peak_idx = [] # list of all peaks in pulse train + all_peak_idx = [] # List of all peaks in pulse train + # For each motor unit, collect the indices around a firing (+/-l) + # This allows us to visualize the entire shape of the peak for i in mu: k = 0 while k <= 1: firing = i[k] - - if np.less(firing, l) == True: + + # Edge case where MU fires at value < l, prevents negative indexing + if np.less(firing, l) == True: idx = np.arange(firing - l, firing + l + 1) neg_idx = abs(firing - l) idx[:neg_idx] = np.repeat(0, neg_idx) @@ -135,12 +143,14 @@ def test_muap_dict(fx_data, mu): all_peak_idx.append(idx) k += 1 + # Grab values of peaks + surrouding range (+/- l) peaks = raw_flat[:, all_peak_idx] signal = np.zeros((2, 2, 5)) n_mu = mu.shape[1] + # Calculate average shape of peaks across a single channel for i in range(0, n_mu): if i == 0: avg = peaks[:, 0:n_mu].mean(axis=1) @@ -148,7 +158,7 @@ def test_muap_dict(fx_data, mu): avg = peaks[:, n_mu:].mean(axis=1) signal[i] = avg - # test sample length (plotting purposes) + # Test sample length (plotting purposes) x, y, z = signal.shape assert y * z == len(fx["mu_0"]["signal"]), "Signal length incorrect." @@ -160,7 +170,7 @@ def test_muap_dict(fx_data, mu): assert y * z == len(fx["mu_0"]["channel"]), "Channel length incorrect." assert y * z == len(fx["mu_1"]["channel"]), "Channel length incorrect." - # test values of avg signal + # Test values of avg signal assert np.array_equal( signal[0].flatten(), fx["mu_0"]["signal"] ), "Average of motor unit signal incorrectly calculated." @@ -173,13 +183,14 @@ def test_muap_dict_by_peak(fx_data): """ Run unit test on muap_dict_by_peak function from EMGdecomPy. """ - # create dictionary to test + # Create dictionary to test peak_dict = emg.viz.muap_dict_by_peak(fx_data, 100, mu_index=1, l=2) channel = peak_dict["mu_1"]["channel"] + # l = 1/2 length of firing l = 2 x = fx_data.shape[0] - rng = l * 2 + 1 + rng = l * 2 + 1 # entire range of firing sample_range = np.arange(0, rng) signal = peak_dict["mu_1"]["signal"] @@ -198,10 +209,10 @@ def test_muap_plot(fx_data, mu): shape_dict = emg.viz.muap_dict(fx_data, mu, l=2) - for i in range(0, 2): # test motor unit 1 and 2 + for i in range(0, 2): # Test motor unit 1 and 2 plots = emg.viz.muap_plot(shape_dict, i) - # test dictionary correctly converted to df + # Test dictionary correctly converted to df len_data = len(plots.data) len_input = len(shape_dict[f"mu_{i}"]["sample"]) @@ -215,7 +226,7 @@ def test_mismatch_scores(avg_mu_shape, avg_peak_shape): Run unit test on mismatch function from EMGdecomPy. """ - # test error across all channels for mu_1 + # Test error across all channels for mu_1 fx_output = emg.viz.mismatch_score(avg_mu_shape, avg_peak_shape, mu_index=1) x = avg_mu_shape["mu_1"]["signal"] @@ -227,7 +238,7 @@ def test_mismatch_scores(avg_mu_shape, avg_peak_shape): for num, i in enumerate(avg_peak_shape): - # test single channel error for single motor unit + # Test error across a single channel for both motor units fx_output = emg.viz.mismatch_score( avg_mu_shape, avg_peak_shape, mu_index=num, channel=0 ) @@ -244,7 +255,7 @@ def test_channel_preset(): """ Run unit test on channel_preset function from EMGdecomPy. """ - # test standard orientation + # Test standard orientation std = emg.viz.channel_preset(preset="standard") @@ -252,7 +263,7 @@ def test_channel_preset(): assert len(std["sort_order"]) == 64, "Standard orientation incorrect." assert std["cols"] == 8, "Standard orientation incorrect." - # test vert63 orientation + # Test vert63 orientation std = emg.viz.channel_preset(preset="vert63") @@ -260,15 +271,18 @@ def test_channel_preset(): assert std["sort_order"][0] == 63, "Standard orientation incorrect." assert len(std["sort_order"]) == 64, "Standard orientation incorrect." assert std["cols"] == 5, "Standard orientation incorrect." - - + def test_pulse_plot(fx_data): """ Run unit test on pulse_plot function from EMGdecomPy. """ - # note: doesnt appear I can test the individual plots that make up this concat'd dashboard + # Note: the individual plots are not accessible in a concat'd dashboard, + # so these tests are rather simple + # Create two motor unit pulse trains pt = np.array([[10, 60, 120], [15, 65, 125]]) + + # Pre-process data signal = emg.preprocessing.flatten_signal(fx_data) signal = np.apply_along_axis( emg.preprocessing.butter_bandpass_filter, @@ -280,6 +294,7 @@ def test_pulse_plot(fx_data): order=6, ) + # Calculate square mean of centered data centered = emg.preprocessing.center_matrix(signal) c_sq = centered**2 c_sq_mean = c_sq.mean(axis=0) @@ -288,6 +303,7 @@ def test_pulse_plot(fx_data): for i, j in enumerate(pt): plt = emg.viz.pulse_plot(pt, c_sq_mean, mu_index=i) + # Access dataframe used in plot df = plt.data df = df["Pulse"].to_numpy() @@ -304,10 +320,11 @@ def test_select_peak(fx_data, mu): """ dic = emg.viz.muap_dict(fx_data, mu, l=2) - # test empty selection - + # Test empty peak selection select = [] pulse = [[100], [200]] + + # Altair objects, like plot, can be indexed into to access individual plots plot = emg.viz.select_peak( selection=select, mu_index=1, raw=fx_data, shape_dict=dic, pt=pulse ) @@ -329,7 +346,7 @@ def test_dashboard(fake_decomp, fx_data): """ Run unit test on dashboard function from EMGdecomPy. """ - # there are not a lot of attributes to test for with concat'd plots + # There are not a lot of attributes to test for with Panel objects for i, decomp_pulse in enumerate(fake_decomp["MUPulses"]): dash = emg.viz.dashboard(fake_decomp, fx_data, i) @@ -340,7 +357,7 @@ def test_dashboard(fake_decomp, fx_data): type(dash[0].object) == alt.vegalite.v4.api.VConcatChart ), "Object returned is not concatenated plots." - # check that plotted pulses match input + # Check that plotted pulses match input assert np.all(df_pulses == decomp_pulse), "MU Pulses incorrectly plotted." @@ -351,6 +368,9 @@ def test_visualize_decomp(fake_decomp, fx_data): x = emg.viz.visualize_decomp(fake_decomp, fx_data) + # Concat'd Panel objects can be indexed into + # x[0] are the widgets (dropdown menus) + # x[1] are the actual plots assert x[0][0].values == [0, 1] assert x[0][1].values == ["standard", "vert63"] assert x[0][2].values == ["RMSE"]