Skip to content

Commit

Permalink
minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Schmidt Fabian committed Jul 29, 2024
1 parent b3d946e commit f210b9b
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 4 deletions.
6 changes: 3 additions & 3 deletions pyrasa/utils/irasa_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,14 @@ def _get_windows(
) -> tuple[np.ndarray, np.ndarray]:
"""Generate a window function used for tapering"""
low_bias_ratio = 0.9
max_time_bandwidth = 2.0
min_time_bandwidth = 2.0
win_func_kwargs = copy(win_func_kwargs)

# special settings in case multitapering is required
if win_func == dsp.windows.dpss:
time_bandwidth = dpss_settings['time_bandwidth']
if time_bandwidth < max_time_bandwidth:
raise ValueError(f'time_bandwidth should be >= {max_time_bandwidth} for good tapers')
if time_bandwidth > min_time_bandwidth:
raise ValueError(f'time_bandwidth should be >= {min_time_bandwidth} for good tapers')

n_taps = int(np.floor(time_bandwidth - 1))
win_func_kwargs.update(
Expand Down
16 changes: 15 additions & 1 deletion tests/test_irasa_sprint.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import pytest
from neurodsp.utils.sim import set_random_seed

from pyrasa.irasa import irasa_sprint
Expand Down Expand Up @@ -80,7 +81,7 @@ def test_irasa_sprint_settings(ts4sprint):
# test dpss
import scipy.signal as dsp

sgramm_ap, sgramm_p, freqs_ir, times_ir = irasa_sprint(
irasa_sprint(
ts4sprint[np.newaxis, :],
fs=500,
band=(1, 100),
Expand All @@ -89,3 +90,16 @@ def test_irasa_sprint_settings(ts4sprint):
# smooth=False,
# n_avgs=[3, 7, 11],
)

# test too much bandwidht
with pytest.raises(ValueError):
irasa_sprint(
ts4sprint[np.newaxis, :],
fs=500,
band=(1, 100),
win_func=dsp.windows.dpss,
dpss_settings_time_bandwidth=1,
freq_res=0.5,
# smooth=False,
# n_avgs=[3, 7, 11],
)
12 changes: 12 additions & 0 deletions tests/test_peak_detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,18 @@ def test_peak_detection(oscillation, fs, osc_freq):
assert bool(np.isclose(pe_filt['cf'][0], osc_freq, atol=2))


@pytest.mark.parametrize('fs, exponent', [(500, -1)], scope='session')
def test_no_peak_detection(fixed_aperiodic_signal, fs):
f_range = [1, 250]
# test whether recombining periodic and aperiodic spectrum is equivalent to the original spectrum
freqs, psd = dsp.welch(fixed_aperiodic_signal, fs, nperseg=int(4 * fs))
freq_logical = np.logical_and(freqs >= f_range[0], freqs <= f_range[1])
freqs, psd = freqs[freq_logical], psd[freq_logical]
# test whether we can reconstruct the peaks correctly
pe_params = get_peak_params(psd[np.newaxis, :], freqs, min_peak_height=0.1)
assert pe_params.shape[0] == 0


@pytest.mark.parametrize('osc_freq, fs', [(10, 500)], scope='session')
def test_peak_detection_setings(oscillation, fs, osc_freq):
f_range = [1, 250]
Expand Down

0 comments on commit f210b9b

Please sign in to comment.