diff --git a/pyrasa/utils/irasa_utils.py b/pyrasa/utils/irasa_utils.py index f19b2c2..85d4571 100644 --- a/pyrasa/utils/irasa_utils.py +++ b/pyrasa/utils/irasa_utils.py @@ -49,14 +49,14 @@ def _get_windows( ) -> tuple[np.ndarray, np.ndarray]: """Generate a window function used for tapering""" low_bias_ratio = 0.9 - max_time_bandwidth = 2.0 + min_time_bandwidth = 2.0 win_func_kwargs = copy(win_func_kwargs) # special settings in case multitapering is required if win_func == dsp.windows.dpss: time_bandwidth = dpss_settings['time_bandwidth'] - if time_bandwidth < max_time_bandwidth: - raise ValueError(f'time_bandwidth should be >= {max_time_bandwidth} for good tapers') + if time_bandwidth > min_time_bandwidth: + raise ValueError(f'time_bandwidth should be >= {min_time_bandwidth} for good tapers') n_taps = int(np.floor(time_bandwidth - 1)) win_func_kwargs.update( diff --git a/tests/test_irasa_sprint.py b/tests/test_irasa_sprint.py index bc752f3..8d1a954 100644 --- a/tests/test_irasa_sprint.py +++ b/tests/test_irasa_sprint.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from neurodsp.utils.sim import set_random_seed from pyrasa.irasa import irasa_sprint @@ -80,7 +81,7 @@ def test_irasa_sprint_settings(ts4sprint): # test dpss import scipy.signal as dsp - sgramm_ap, sgramm_p, freqs_ir, times_ir = irasa_sprint( + irasa_sprint( ts4sprint[np.newaxis, :], fs=500, band=(1, 100), @@ -89,3 +90,16 @@ def test_irasa_sprint_settings(ts4sprint): # smooth=False, # n_avgs=[3, 7, 11], ) + + # test too much bandwidht + with pytest.raises(ValueError): + irasa_sprint( + ts4sprint[np.newaxis, :], + fs=500, + band=(1, 100), + win_func=dsp.windows.dpss, + dpss_settings_time_bandwidth=1, + freq_res=0.5, + # smooth=False, + # n_avgs=[3, 7, 11], + ) diff --git a/tests/test_peak_detect.py b/tests/test_peak_detect.py index f1652a9..c8ac781 100644 --- a/tests/test_peak_detect.py +++ b/tests/test_peak_detect.py @@ -24,6 +24,18 @@ def test_peak_detection(oscillation, fs, osc_freq): assert bool(np.isclose(pe_filt['cf'][0], osc_freq, atol=2)) +@pytest.mark.parametrize('fs, exponent', [(500, -1)], scope='session') +def test_no_peak_detection(fixed_aperiodic_signal, fs): + f_range = [1, 250] + # test whether recombining periodic and aperiodic spectrum is equivalent to the original spectrum + freqs, psd = dsp.welch(fixed_aperiodic_signal, fs, nperseg=int(4 * fs)) + freq_logical = np.logical_and(freqs >= f_range[0], freqs <= f_range[1]) + freqs, psd = freqs[freq_logical], psd[freq_logical] + # test whether we can reconstruct the peaks correctly + pe_params = get_peak_params(psd[np.newaxis, :], freqs, min_peak_height=0.1) + assert pe_params.shape[0] == 0 + + @pytest.mark.parametrize('osc_freq, fs', [(10, 500)], scope='session') def test_peak_detection_setings(oscillation, fs, osc_freq): f_range = [1, 250]