Skip to content

Commit

Permalink
Merge pull request #37 from schmidtfa/Factor-out-_irasa_funs
Browse files Browse the repository at this point in the history
Factor out  irasa funs and provide a consistent interface
  • Loading branch information
schmidtfa authored Jul 30, 2024
2 parents 3acc2ed + 4fb1156 commit 856ed43
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 63 deletions.
77 changes: 44 additions & 33 deletions pyrasa/irasa.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,30 @@
import fractions
from collections.abc import Callable
from typing import TYPE_CHECKING, Any

import numpy as np
import scipy.signal as dsp

from pyrasa.utils.input_classes import IrasaKwargsTyped, IrasaSprintKwargsTyped

# from scipy.stats.mstats import gmean
from pyrasa.utils.irasa_utils import (
_check_irasa_settings,
_compute_psd_welch,
_compute_sgramm,
_crop_data, # _find_nearest, _gen_time_from_sft, _get_windows,
)
from pyrasa.utils.types import IrasaFun

if TYPE_CHECKING:
from pyrasa.utils.input_classes import IrasaSprintKwargsTyped


# TODO: Port to Cython
def _gen_irasa(
data: np.ndarray,
orig_spectrum: np.ndarray,
fs: int,
irasa_fun: Callable,
irasa_fun: IrasaFun,
hset: np.ndarray,
irasa_kwargs: dict | IrasaKwargsTyped | IrasaSprintKwargsTyped,
time: np.ndarray | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Expand All @@ -45,12 +47,8 @@ def _gen_irasa(
data_down = dsp.resample_poly(data, down, up, axis=-1)

# Calculate an up/downsampled version of the PSD using same params as original
irasa_kwargs['h'] = h
irasa_kwargs['time_orig'] = time
irasa_kwargs['up_down'] = 'up'
spectrum_up = irasa_fun(data_up, int(fs * h), spectrum_only=True, **irasa_kwargs)
irasa_kwargs['up_down'] = 'down'
spectrum_dw = irasa_fun(data_down, int(fs / h), spectrum_only=True, **irasa_kwargs)
spectrum_up = irasa_fun(data=data_up, fs=int(fs * h), h=h, time_orig=time, up_down='up')
spectrum_dw = irasa_fun(data=data_down, fs=int(fs / h), h=h, time_orig=time, up_down='down')

# geometric mean between up and downsampled
# be aware of the input dimensions
Expand Down Expand Up @@ -153,26 +151,34 @@ def irasa(
'eigenvalue_weighting': dpss_eigenvalue_weighting,
}

irasa_kwargs: IrasaKwargsTyped = {
'nperseg': psd_kwargs.get('nperseg'),
'noverlap': psd_kwargs.get('noverlap'),
'nfft': psd_kwargs.get('nfft'),
'h': None,
'up_down': None,
'time_orig': None,
'dpss_settings': dpss_settings,
'win_kwargs': win_kwargs,
}

freq, psd = _compute_psd_welch(data, fs=fs, **irasa_kwargs)
def _local_irasa_fun(
data: np.ndarray,
fs: int,
*args: Any,
**kwargs: Any,
) -> np.ndarray:
return _compute_psd_welch(
data,
fs=fs,
nperseg=psd_kwargs.get('nperseg'),
win_kwargs=win_kwargs,
dpss_settings=dpss_settings,
noverlap=psd_kwargs.get('noverlap'),
nfft=psd_kwargs.get('nfft'),
)[1]

freq, psd = _compute_psd_welch(
data,
fs=fs,
nperseg=psd_kwargs.get('nperseg'),
win_kwargs=win_kwargs,
dpss_settings=dpss_settings,
noverlap=psd_kwargs.get('noverlap'),
nfft=psd_kwargs.get('nfft'),
)

psd, psd_aperiodic, psd_periodic = _gen_irasa(
data=np.squeeze(data),
orig_spectrum=psd,
fs=fs,
irasa_fun=_compute_psd_welch,
hset=hset,
irasa_kwargs=irasa_kwargs,
data=np.squeeze(data), orig_spectrum=psd, fs=fs, irasa_fun=_local_irasa_fun, hset=hset
)

freq, psd_aperiodic, psd_periodic = _crop_data(band, freq, psd_aperiodic, psd_periodic, axis=-1)
Expand Down Expand Up @@ -296,23 +302,28 @@ def irasa_sprint( # noqa PLR0915 C901
'mfft': mfft,
'hop': hop,
'win_duration': win_duration,
'h': None,
'up_down': None,
'dpss_settings': dpss_settings,
'win_kwargs': win_kwargs,
'time_orig': None,
}

def _local_irasa_fun(
data: np.ndarray,
fs: int,
h: int | None,
up_down: str | None,
time_orig: np.ndarray | None = None,
) -> np.ndarray:
return _compute_sgramm(data, fs, h=h, up_down=up_down, time_orig=time_orig, **irasa_kwargs)[2]

# get time and frequency info
freq, time, sgramm = _compute_sgramm(data, fs, **irasa_kwargs)

sgramm, sgramm_aperiodic, sgramm_periodic = _gen_irasa(
data=data,
orig_spectrum=sgramm,
fs=fs,
irasa_fun=_compute_sgramm,
irasa_fun=_local_irasa_fun,
hset=hset,
irasa_kwargs=dict(irasa_kwargs),
time=time,
)

Expand Down
16 changes: 0 additions & 16 deletions pyrasa/utils/input_classes.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,11 @@
from typing import TypedDict

import numpy as np


class IrasaKwargsTyped(TypedDict):
nperseg: int | None
noverlap: int | None
nfft: int | None
h: int | None
time_orig: None | np.ndarray
up_down: str | None
dpss_settings: dict
win_kwargs: dict


class IrasaSprintKwargsTyped(TypedDict):
mfft: int
hop: int
win_duration: float
h: int | None
up_down: str | None
dpss_settings: dict
win_kwargs: dict
time_orig: None | np.ndarray
# smooth: bool
# n_avgs: list
17 changes: 3 additions & 14 deletions pyrasa/utils/irasa_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,6 @@ def _compute_psd_welch(
scaling: str = 'density',
axis: int = -1,
average: str = 'mean',
up_down: str | None = None,
spectrum_only: bool = False,
h: float | None = None,
time_orig: np.ndarray | None = None,
) -> tuple[np.ndarray, np.ndarray]:
"""Function to compute power spectral densities using welchs method"""

Expand Down Expand Up @@ -171,10 +167,7 @@ def _compute_psd_welch(
weighted_psds = [ratios[ix] * cur_sgramm for ix, cur_sgramm in enumerate(psds)]
psd = np.sum(weighted_psds, axis=0) / np.sum(ratios)

if spectrum_only:
return psd
else:
return freq, psd
return freq, psd


def _compute_sgramm( # noqa C901
Expand All @@ -188,8 +181,7 @@ def _compute_sgramm( # noqa C901
up_down: str | None = None,
h: int | None = None,
time_orig: np.ndarray | None = None,
spectrum_only: bool = False,
) -> tuple[np.ndarray, np.ndarray, np.ndarray] | np.ndarray:
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Function to compute spectrograms"""

if h is None:
Expand Down Expand Up @@ -231,7 +223,4 @@ def _compute_sgramm( # noqa C901

sgramm = np.squeeze(sgramm) # bring in proper format

if spectrum_only:
return sgramm
else:
return freq, time, sgramm
return freq, time, sgramm
9 changes: 9 additions & 0 deletions pyrasa/utils/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from typing import Protocol

import numpy as np


class IrasaFun(Protocol):
def __call__(
self, data: np.ndarray, fs: int, h: int | None, up_down: str | None, time_orig: np.ndarray | None = None
) -> np.ndarray: ...

0 comments on commit 856ed43

Please sign in to comment.