Skip to content

Commit

Permalink
make sampling irregular again
Browse files Browse the repository at this point in the history
  • Loading branch information
Fabi committed Aug 20, 2024
1 parent 1798ef8 commit e5684b3
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 579 deletions.
509 changes: 0 additions & 509 deletions examples/hset_optimization.ipynb

This file was deleted.

28 changes: 14 additions & 14 deletions examples/irasa_sprint.ipynb

Large diffs are not rendered by default.

11 changes: 3 additions & 8 deletions pyrasa/irasa.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def irasa(
_check_irasa_settings(irasa_params=irasa_params, hset_info=hset_info)

hset = np.round(np.arange(*hset_info), hset_accuracy)
hset = [h for h in hset if h % 1 != 0] # filter integers

win_kwargs = {'win_func': win_func, 'win_func_kwargs': win_func_kwargs}
dpss_settings = {
Expand Down Expand Up @@ -163,15 +164,12 @@ def _local_irasa_fun(

freq, psd_aperiodic, psd_periodic, psd = _crop_data(band, freq, psd_aperiodic, psd_periodic, psd, axis=-1)

del irasa_params['data']
irasa_params['hmax'] = hset_info[1]
return IrasaSpectrum(
freqs=freq,
raw_spectrum=psd,
aperiodic=psd_aperiodic,
periodic=psd_periodic,
ch_names=ch_names,
irasa_settings=irasa_params,
)


Expand All @@ -181,7 +179,6 @@ def irasa_sprint( # noqa PLR0915 C901
fs: int,
ch_names: np.ndarray | None = None,
band: tuple[float, float] = (1.0, 100.0),
freq_res: float = 0.5,
win_duration: float = 0.4,
hop: int = 10,
win_func: Callable = dsp.windows.hann,
Expand Down Expand Up @@ -210,8 +207,6 @@ def irasa_sprint( # noqa PLR0915 C901
Channel names associated with the data, if available. Default is None.
band : tuple[float, float], optional
The frequency range (lower and upper bounds in Hz) over which to compute the spectra. Default is (1.0, 100.0).
freq_res : float, optional
Desired frequency resolution in Hz. Default is 0.5 Hz.
win_duration : float, optional
Duration of the window in seconds used for the short-time Fourier transforms (STFTs). Default is 0.4 seconds.
hop : int, optional
Expand Down Expand Up @@ -283,8 +278,9 @@ def irasa_sprint( # noqa PLR0915 C901
_check_irasa_settings(irasa_params=irasa_params, hset_info=hset_info)

hset = np.round(np.arange(*hset_info), hset_accuracy)
hset = [h for h in hset if h % 1 != 0] # filter integers

nfft = int(fs / freq_res)
nfft = int(2 ** np.ceil(np.log2(np.max(hset) * win_duration * fs)))
win_kwargs = {'win_func': win_func, 'win_func_kwargs': win_func_kwargs}
dpss_settings = {
'time_bandwidth': dpss_settings_time_bandwidth,
Expand Down Expand Up @@ -321,7 +317,6 @@ def _local_irasa_fun(
time=time,
)

# NOTE: we need to transpose the data as crop_data extracts stuff from the last axis
freq, sgramm_aperiodic, sgramm_periodic, sgramm = _crop_data(
band, freq, sgramm_aperiodic, sgramm_periodic, sgramm, axis=0
)
Expand Down
1 change: 0 additions & 1 deletion pyrasa/utils/fit_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def _get_gof(psd: np.ndarray, psd_pred: np.ndarray, k: int, fit_type: str) -> pd
For further details on BIC and AIC, see: https://machinelearningmastery.com/probabilistic-model-selection-measures/
"""

# add np.log10 to psd
residuals = psd - psd_pred
ss_res = np.sum(residuals**2)
ss_tot = np.sum((psd - np.mean(psd)) ** 2)
Expand Down
42 changes: 0 additions & 42 deletions pyrasa/utils/irasa_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import numpy as np
import pandas as pd
from attrs import define
from scipy.signal import medfilt

from pyrasa.utils.aperiodic_utils import compute_aperiodic_model
from pyrasa.utils.fit_funcs import AbstractFitFun
Expand All @@ -18,7 +17,6 @@ class IrasaSpectrum:
aperiodic: np.ndarray
periodic: np.ndarray
ch_names: np.ndarray | None
irasa_settings: dict

def fit_aperiodic_model(
self,
Expand Down Expand Up @@ -158,43 +156,3 @@ def get_peaks(
polyorder=polyorder,
peak_width_limits=peak_width_limits,
)

def get_aperiodic_error(self, kernel_size: int | np.ndarray) -> pd.DataFrame:
"""
Computes the error of the aperiodic_fit.
The function takes the absolut of the periodic spectrum and gets rid of spectral peaks or oscillation
using median filtering. Afterwards the area under the curve is computed.
If the kernel size of the median filter is specified correctly this function can
be used to get an estimate of the aperiodic model error.
Parameters
----------
kernel_size: A scalar or an N-length list giving the size of the median filter window in each dimension.
Elements of kernel_size should be odd.
If kernel_size is a scalar, then this scalar is used as the size in each dimension.
Returns
----------
Definite integral of the aperiodic error as approximated by the trapezoidal rule. (see numpy.trapz)
"""

ch_names = self.ch_names
assert isinstance(
ch_names, list | tuple | np.ndarray | None
), 'Channel names should be of type list, tuple or numpy.ndarray or None'

if ch_names is None:
ch_names = np.arange(self.periodic.shape[0])

aperiodic_error = [
np.trapz(medfilt(np.abs(self.periodic[ix, :]), kernel_size=kernel_size)) for ix, ch in enumerate(ch_names)
]

df_aperiodic_error = pd.DataFrame({'error': aperiodic_error, 'ch_name': ch_names})
df_aperiodic_error['hmax'] = self.irasa_settings['hmax']
df_aperiodic_error['lower_band'] = self.irasa_settings['band'][0]
df_aperiodic_error['upper_band'] = self.irasa_settings['band'][1]
df_aperiodic_error['frequency_resolution'] = self.freqs[1] - self.freqs[0]

return df_aperiodic_error
33 changes: 28 additions & 5 deletions simulations/notebooks/test_basic_functionality.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import scipy.signal as dsp
import matplotlib.pyplot as plt
import seaborn as sns
from pyrasa import irasa

sns.set_style('ticks')
sns.set_context('poster')
Expand Down Expand Up @@ -32,16 +33,38 @@
plt.tight_layout()


from pyrasa.irasa import irasa
from pyrasa.irasa import irasa, irasa_sprint

# %%
freq_irasa, psd_ap, psd_p = irasa(
irasa_psd = irasa(
sig,
fs=fs,
band=(1, 150),
kwargs_psd={'nperseg': duration * fs, 'noverlap': duration * fs * overlap},
hset_info=(1, 2, 0.05),
band=(1, 50),
psd_kwargs={'nperseg': duration * fs, 'noverlap': duration * fs * overlap},
hset_info=(1, 4, 0.05),
)

#%%
irasa_out_tf = irasa_sprint(
sig,
fs=fs,
band=(1, 50),
win_duration=4,
hset_info=(1, 3, 0.1),
)
#%%
from neurodsp.plts import plot_timefrequency
plot_timefrequency(times=irasa_out_tf.time,
freqs=irasa_out_tf.freqs,
powers=irasa_out_tf.periodic)

#%%
plot_timefrequency(times=irasa_out_tf.time,
freqs=irasa_out_tf.freqs,
powers=irasa_out_tf.aperiodic)



# %%
f, axes = plt.subplots(ncols=2, figsize=(8, 4))
axes[0].set_title('Periodic')
Expand Down

0 comments on commit e5684b3

Please sign in to comment.