Skip to content

Commit

Permalink
Merge pull request #36 from schmidtfa/minor_fixes
Browse files Browse the repository at this point in the history
Not so minor fixes
  • Loading branch information
schmidtfa authored Jul 30, 2024
2 parents 8aea059 + dc780b6 commit 3acc2ed
Show file tree
Hide file tree
Showing 11 changed files with 188 additions and 160 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
## PyRASA

[![Project Status: WIP – Initial development is in progress, but there has not yet been a stable, usable release suitable for the public.](https://www.repostatus.org/badges/latest/wip.svg)](https://www.repostatus.org/#wip)
[![License](https://img.shields.io/badge/License-BSD_2--Clause-orange.svg)](https://opensource.org/licenses/BSD-2-Clause)
[![Checked with mypy](http://www.mypy-lang.org/static/mypy_badge.svg)](http://mypy-lang.org/)
[![Coverage Status](https://coveralls.io/repos/github/schmidtfa/pyrasa/badge.svg?branch=main)](https://coveralls.io/github/schmidtfa/pyrasa?branch=main)
Expand Down
4 changes: 2 additions & 2 deletions examples/basic_func_fun.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@
freq_irasa, psd_ap, psd_p = irasa(sig,
fs=fs,
band=(1, 100),
irasa_kwargs={'nperseg': duration*fs,
'noverlap': duration*fs*overlap
psd_kwargs={'nperseg': duration*fs,
'noverlap': duration*fs*overlap
},
hset_info=(1, 2, 0.05))

Expand Down
1 change: 0 additions & 1 deletion examples/check_irasa_mne.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@
# %% now lets simulate some data
# TODO: put in a helper function


def simulate_raw_(signal, scaling_fator, region, subject, info, subjects_dir):
"""Shorthand function to simulate a dipole"""

Expand Down
2 changes: 1 addition & 1 deletion pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ quote-style = 'single'

[tool.pytest.ini_options]
testpaths = ["tests"]
addopts = "--ignore=__version__.py"
addopts = "--ignore=pyrasa/__version__.py"

[tool.mypy]
disable_error_code = "import-untyped"
Expand Down
178 changes: 37 additions & 141 deletions pyrasa/irasa.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import fractions
from collections.abc import Callable
from typing import TypedDict

import numpy as np
import scipy.signal as dsp
from scipy.signal import ShortTimeFFT

from pyrasa.utils.input_classes import IrasaKwargsTyped, IrasaSprintKwargsTyped

# from scipy.stats.mstats import gmean
from pyrasa.utils.irasa_utils import _check_irasa_settings, _crop_data, _find_nearest, _gen_time_from_sft, _get_windows
from pyrasa.utils.irasa_utils import (
_check_irasa_settings,
_compute_psd_welch,
_compute_sgramm,
_crop_data, # _find_nearest, _gen_time_from_sft, _get_windows,
)


# TODO: Port to Cython
Expand All @@ -17,7 +22,7 @@ def _gen_irasa(
fs: int,
irasa_fun: Callable,
hset: np.ndarray,
irasa_kwargs: dict,
irasa_kwargs: dict | IrasaKwargsTyped | IrasaSprintKwargsTyped,
time: np.ndarray | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Expand Down Expand Up @@ -64,7 +69,12 @@ def irasa(
data: np.ndarray,
fs: int,
band: tuple[float, float],
irasa_kwargs: dict,
psd_kwargs: dict,
win_func: Callable = dsp.windows.hann,
win_func_kwargs: dict | None = None,
dpss_settings_time_bandwidth: float = 2.0,
dpss_settings_low_bias: bool = True,
dpss_eigenvalue_weighting: bool = True,
filter_settings: tuple[float | None, float | None] = (None, None),
hset_info: tuple[float, float, float] = (1.05, 2.0, 0.05),
hset_accuracy: int = 4,
Expand All @@ -87,7 +97,7 @@ def irasa(
The sampling frequency of the data. Can be omitted if data is :py:class:˚mne.io.BaseRaw˚.
band : tuple
A tuple containing the lower and upper band of the frequency range used to extract (a-)periodic spectra.
irasa_kwargs : dict
psd_kwargs : dict
A dictionary containing all the keyword arguments that are passed onto `scipy.signal.welch`.
filter_settings : tuple
A tuple containing the cut-off of the High- and Lowpass filter. It is highly advisable to set this
Expand Down Expand Up @@ -115,6 +125,9 @@ def irasa(
https://doi.org/10.1007/s10548-015-0448-0
"""
# set parameters
if win_func_kwargs is None:
win_func_kwargs = {}

# Minimal safety checks
if data.ndim == 1:
Expand All @@ -133,45 +146,23 @@ def irasa(

hset = np.round(np.arange(*hset_info), hset_accuracy)

# Calculate original spectrum
def _compute_psd_welch(
data: np.ndarray,
fs: int,
window: str = 'hann',
nperseg: int | None = None,
noverlap: int | None = None,
nfft: int | None = None,
detrend: str = 'constant',
return_onesided: bool = True,
scaling: str = 'density',
axis: int = -1,
average: str = 'mean',
spectrum_only: bool = False,
h: float | None = None,
time_orig: np.ndarray | None = None,
up_down: str | None = None,
# **irasa_kwargs ,
) -> tuple[np.ndarray, np.ndarray]:
"""Function to compute power spectral densities using welchs method"""

freq, psd = dsp.welch(
data,
fs=fs,
window=window,
nperseg=nperseg,
noverlap=noverlap,
nfft=nfft,
detrend=detrend,
return_onesided=return_onesided,
scaling=scaling,
axis=axis,
average=average,
)

if spectrum_only:
return psd
else:
return freq, psd
win_kwargs = {'win_func': win_func, 'win_func_kwargs': win_func_kwargs}
dpss_settings = {
'time_bandwidth': dpss_settings_time_bandwidth,
'low_bias': dpss_settings_low_bias,
'eigenvalue_weighting': dpss_eigenvalue_weighting,
}

irasa_kwargs: IrasaKwargsTyped = {
'nperseg': psd_kwargs.get('nperseg'),
'noverlap': psd_kwargs.get('noverlap'),
'nfft': psd_kwargs.get('nfft'),
'h': None,
'up_down': None,
'time_orig': None,
'dpss_settings': dpss_settings,
'win_kwargs': win_kwargs,
}

freq, psd = _compute_psd_welch(data, fs=fs, **irasa_kwargs)

Expand All @@ -195,8 +186,6 @@ def irasa_sprint( # noqa PLR0915 C901
fs: int,
band: tuple[float, float] = (1.0, 100.0),
freq_res: float = 0.5,
# smooth: bool = True,
# n_avgs: list = [1],
win_duration: float = 0.4,
hop: int = 10,
win_func: Callable = dsp.windows.hann,
Expand Down Expand Up @@ -303,19 +292,7 @@ def irasa_sprint( # noqa PLR0915 C901
'eigenvalue_weighting': dpss_eigenvalue_weighting,
}

class IrasaKwargsTyped(TypedDict):
mfft: int
hop: int
win_duration: float
h: int | None
up_down: str | None
dpss_settings: dict
win_kwargs: dict
time_orig: None | np.ndarray
# smooth: bool
# n_avgs: list

irasa_kwargs: IrasaKwargsTyped = {
irasa_kwargs: IrasaSprintKwargsTyped = {
'mfft': mfft,
'hop': hop,
'win_duration': win_duration,
Expand All @@ -324,89 +301,8 @@ class IrasaKwargsTyped(TypedDict):
'dpss_settings': dpss_settings,
'win_kwargs': win_kwargs,
'time_orig': None,
#'smooth': smooth,
#'n_avgs': n_avgs,
}

def _compute_sgramm( # noqa C901
x: np.ndarray,
fs: int,
mfft: int,
hop: int,
win_duration: float,
dpss_settings: dict,
win_kwargs: dict,
up_down: str | None = None,
h: int | None = None,
time_orig: np.ndarray | None = None,
# smooth: bool = True,
# n_avgs: list = [3],
spectrum_only: bool = False,
) -> tuple[np.ndarray, np.ndarray, np.ndarray] | np.ndarray:
"""Function to compute spectrograms"""

if h is None:
nperseg = int(np.floor(fs * win_duration))
elif np.logical_and(h is not None, up_down == 'up'):
nperseg = int(np.floor(fs * win_duration * h))
hop = int(hop * h)
elif np.logical_and(h is not None, up_down == 'down'):
nperseg = int(np.floor(fs * win_duration / h))
hop = int(hop / h)

win, ratios = _get_windows(nperseg, dpss_settings, **win_kwargs)

sgramms = []
for cur_win in win:
SFT = ShortTimeFFT(cur_win, hop=hop, mfft=mfft, fs=fs, scale_to='psd') # noqa N806
cur_sgramm = SFT.spectrogram(x, detr='constant')
sgramms.append(cur_sgramm)

if ratios is None:
sgramm = np.mean(sgramms, axis=0)
else:
weighted_sgramms = [ratios[ix] * cur_sgramm for ix, cur_sgramm in enumerate(sgramms)]
sgramm = np.sum(weighted_sgramms, axis=0) / np.sum(ratios)

# TODO: smoothing doesnt work properly
# if smooth:
# avgs = []
# def _moving_average(x: np.ndarray, w: int) -> np.ndarray:
# return np.convolve(x, np.ones(w), 'valid') / w

# def sgramm_smoother(sgramm: np.ndarray, n_avgs: int) -> np.ndarray:
# return np.array([_moving_average(sgramm[freq, :], w=n_avgs) for freq in range(sgramm.shape[0])])
# n_avgs_r = n_avgs[::-1]
# for avg, avg_r in zip(n_avgs, n_avgs_r):
# sgramm_fwd = sgramm_smoother(sgramm=np.squeeze(sgramm), n_avgs=avg)[:, avg_r:]
# sgramm_bwd = sgramm_smoother(sgramm=np.squeeze(sgramm)[:, ::-1], n_avgs=avg)[:, ::-1][:, avg_r:]
# sgramm_n = gmean([sgramm_fwd, sgramm_bwd], axis=0)
# avgs.append(sgramm_n)

# sgramm = np.median(avgs, axis=0)
# sgramm = sgramm[np.newaxis, :, :]

time = _gen_time_from_sft(SFT, x)
freq = SFT.f[SFT.f > 0]

# subsample the upsampled data in the time domain to allow averaging
# This is necessary as division by h can cause slight rounding differences that
# result in actual unintended temporal differences in up/dw for very long segments.
if time_orig is not None:
sgramm = np.array([_find_nearest(sgramm, time, t) for t in time_orig])
max_t_ix = time_orig.shape[0]
# swapping axes is necessitated by _find_nearest
sgramm = np.swapaxes(
np.swapaxes(sgramm[:max_t_ix, :, :], 1, 2), 0, 2
) # cut time axis for up/downsampled data to allow averaging

sgramm = np.squeeze(sgramm) # bring in proper format

if spectrum_only:
return sgramm
else:
return freq, time, sgramm

# get time and frequency info
freq, time, sgramm = _compute_sgramm(data, fs, **irasa_kwargs)

Expand Down
6 changes: 3 additions & 3 deletions pyrasa/irasa_mne/irasa_mne.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def irasa_raw(
band=band,
filter_settings=(data.info['highpass'], data.info['lowpass']),
hset_info=hset_info,
irasa_kwargs=kwargs_psd,
psd_kwargs=kwargs_psd,
)

if as_array is True:
Expand Down Expand Up @@ -169,7 +169,7 @@ def irasa_epochs(
# TODO: does zero padding make sense?
kwargs_psd = {
'window': 'hann',
#'nperseg': data_array.shape[2],
'nperseg': None,
'nfft': nfft,
'noverlap': 0,
}
Expand All @@ -183,7 +183,7 @@ def irasa_epochs(
band=band,
filter_settings=(data.info['highpass'], data.info['lowpass']),
hset_info=hset_info,
irasa_kwargs=kwargs_psd,
psd_kwargs=kwargs_psd,
)
psd_list_aperiodic.append(psd_aperiodic)
psd_list_periodic.append(psd_periodic)
Expand Down
27 changes: 27 additions & 0 deletions pyrasa/utils/input_classes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import TypedDict

import numpy as np


class IrasaKwargsTyped(TypedDict):
nperseg: int | None
noverlap: int | None
nfft: int | None
h: int | None
time_orig: None | np.ndarray
up_down: str | None
dpss_settings: dict
win_kwargs: dict


class IrasaSprintKwargsTyped(TypedDict):
mfft: int
hop: int
win_duration: float
h: int | None
up_down: str | None
dpss_settings: dict
win_kwargs: dict
time_orig: None | np.ndarray
# smooth: bool
# n_avgs: list
Loading

0 comments on commit 3acc2ed

Please sign in to comment.