Skip to content

Commit

Permalink
fixed test bug
Browse files Browse the repository at this point in the history
  • Loading branch information
Fabi committed Sep 16, 2024
1 parent 4408a85 commit 84b3c66
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 5 deletions.
2 changes: 1 addition & 1 deletion examples/irasa_sprint.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
"version": "3.1.undefined"
}
},
"nbformat": 4,
Expand Down
29 changes: 29 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,35 @@ def ts4sprint(fs, exponent_1, exponent_2):
yield sim_ts


@pytest.fixture(scope='session')
def ts4sprint_knee(fs, exponent_1, exponent_2):
alpha = sim_oscillation(n_seconds=0.5, fs=fs, freq=10)
no_alpha = np.zeros(len(alpha))
beta = sim_oscillation(n_seconds=0.5, fs=fs, freq=25)
no_beta = np.zeros(len(beta))

knee1 = 20 ** np.abs(exponent_1)
knee2 = 20 ** np.abs(exponent_2)
exp_1 = sim_knee(n_seconds=2.5, fs=fs, exponent1=0, exponent2=exponent_1, knee=knee1)
exp_2 = sim_knee(n_seconds=2.5, fs=fs, exponent1=0, exponent2=exponent_2, knee=knee2)

# %%
alphas = np.concatenate([no_alpha, alpha, no_alpha, alpha, no_alpha])
betas = np.concatenate([beta, no_beta, beta, no_beta, beta])

sim_ts = np.concatenate(
[
exp_1 + alphas,
exp_1 + alphas + betas,
exp_1 + betas,
exp_2 + alphas,
exp_2 + alphas + betas,
exp_2 + betas,
]
)
yield sim_ts


@pytest.fixture(scope='session')
def gen_mne_data_raw():
data_path = sample.data_path()
Expand Down
18 changes: 14 additions & 4 deletions tests/test_irasa_knee.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,9 @@ def test_aperiodic_error(load_knee_cmb_signal, fs, exponent, knee, osc_freq):
@pytest.mark.parametrize('fs', [1000], scope='session')
@pytest.mark.parametrize('exponent_1', [-0], scope='session')
@pytest.mark.parametrize('exponent_2', [-2], scope='session')
def test_aperiodic_error_tf(ts4sprint, fs, exponent, knee, osc_freq):
def test_aperiodic_error_tf(ts4sprint_knee, fs, exponent_1, exponent_2):
irasa_out = irasa_sprint(
ts4sprint,
ts4sprint_knee,
fs=fs,
band=(0.1, 50),
overlap_fraction=0.95,
Expand All @@ -125,12 +125,22 @@ def test_aperiodic_error_tf(ts4sprint, fs, exponent, knee, osc_freq):
)

irasa_out_bad = irasa_sprint(
ts4sprint,
ts4sprint_knee,
fs=fs,
band=(0.1, 50),
overlap_fraction=0.95,
win_duration=0.5,
hset_info=(1, 8.0, 0.05),
)

assert np.mean(irasa_out.get_aperiodic_error()) < np.mean(irasa_out_bad.get_aperiodic_error())
kwargs = {
'cut_spectrum': (1, 40),
'smooth': True,
'smoothing_window': 3,
'min_peak_height': 0.01,
'peak_width_limits': (0.5, 12),
}

assert np.mean(irasa_out.get_aperiodic_error(peak_kwargs=kwargs)) < np.mean(
irasa_out_bad.get_aperiodic_error(peak_kwargs=kwargs)
)

0 comments on commit 84b3c66

Please sign in to comment.