diff --git a/pyrasa/irasa.py b/pyrasa/irasa.py index 6780d48..49a970b 100644 --- a/pyrasa/irasa.py +++ b/pyrasa/irasa.py @@ -1,11 +1,10 @@ import fractions from collections.abc import Callable +from typing import TYPE_CHECKING, Any import numpy as np import scipy.signal as dsp -from pyrasa.utils.input_classes import IrasaKwargsTyped, IrasaSprintKwargsTyped - # from scipy.stats.mstats import gmean from pyrasa.utils.irasa_utils import ( _check_irasa_settings, @@ -13,6 +12,10 @@ _compute_sgramm, _crop_data, # _find_nearest, _gen_time_from_sft, _get_windows, ) +from pyrasa.utils.types import IrasaFun + +if TYPE_CHECKING: + from pyrasa.utils.input_classes import IrasaSprintKwargsTyped # TODO: Port to Cython @@ -20,9 +23,8 @@ def _gen_irasa( data: np.ndarray, orig_spectrum: np.ndarray, fs: int, - irasa_fun: Callable, + irasa_fun: IrasaFun, hset: np.ndarray, - irasa_kwargs: dict | IrasaKwargsTyped | IrasaSprintKwargsTyped, time: np.ndarray | None = None, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """ @@ -45,12 +47,8 @@ def _gen_irasa( data_down = dsp.resample_poly(data, down, up, axis=-1) # Calculate an up/downsampled version of the PSD using same params as original - irasa_kwargs['h'] = h - irasa_kwargs['time_orig'] = time - irasa_kwargs['up_down'] = 'up' - spectrum_up = irasa_fun(data_up, int(fs * h), spectrum_only=True, **irasa_kwargs) - irasa_kwargs['up_down'] = 'down' - spectrum_dw = irasa_fun(data_down, int(fs / h), spectrum_only=True, **irasa_kwargs) + spectrum_up = irasa_fun(data=data_up, fs=int(fs * h), h=h, time_orig=time, up_down='up') + spectrum_dw = irasa_fun(data=data_down, fs=int(fs / h), h=h, time_orig=time, up_down='down') # geometric mean between up and downsampled # be aware of the input dimensions @@ -153,26 +151,34 @@ def irasa( 'eigenvalue_weighting': dpss_eigenvalue_weighting, } - irasa_kwargs: IrasaKwargsTyped = { - 'nperseg': psd_kwargs.get('nperseg'), - 'noverlap': psd_kwargs.get('noverlap'), - 'nfft': psd_kwargs.get('nfft'), - 'h': None, - 'up_down': None, - 'time_orig': None, - 'dpss_settings': dpss_settings, - 'win_kwargs': win_kwargs, - } - - freq, psd = _compute_psd_welch(data, fs=fs, **irasa_kwargs) + def _local_irasa_fun( + data: np.ndarray, + fs: int, + *args: Any, + **kwargs: Any, + ) -> np.ndarray: + return _compute_psd_welch( + data, + fs=fs, + nperseg=psd_kwargs.get('nperseg'), + win_kwargs=win_kwargs, + dpss_settings=dpss_settings, + noverlap=psd_kwargs.get('noverlap'), + nfft=psd_kwargs.get('nfft'), + )[1] + + freq, psd = _compute_psd_welch( + data, + fs=fs, + nperseg=psd_kwargs.get('nperseg'), + win_kwargs=win_kwargs, + dpss_settings=dpss_settings, + noverlap=psd_kwargs.get('noverlap'), + nfft=psd_kwargs.get('nfft'), + ) psd, psd_aperiodic, psd_periodic = _gen_irasa( - data=np.squeeze(data), - orig_spectrum=psd, - fs=fs, - irasa_fun=_compute_psd_welch, - hset=hset, - irasa_kwargs=irasa_kwargs, + data=np.squeeze(data), orig_spectrum=psd, fs=fs, irasa_fun=_local_irasa_fun, hset=hset ) freq, psd_aperiodic, psd_periodic = _crop_data(band, freq, psd_aperiodic, psd_periodic, axis=-1) @@ -296,13 +302,19 @@ def irasa_sprint( # noqa PLR0915 C901 'mfft': mfft, 'hop': hop, 'win_duration': win_duration, - 'h': None, - 'up_down': None, 'dpss_settings': dpss_settings, 'win_kwargs': win_kwargs, - 'time_orig': None, } + def _local_irasa_fun( + data: np.ndarray, + fs: int, + h: int | None, + up_down: str | None, + time_orig: np.ndarray | None = None, + ) -> np.ndarray: + return _compute_sgramm(data, fs, h=h, up_down=up_down, time_orig=time_orig, **irasa_kwargs)[2] + # get time and frequency info freq, time, sgramm = _compute_sgramm(data, fs, **irasa_kwargs) @@ -310,9 +322,8 @@ def irasa_sprint( # noqa PLR0915 C901 data=data, orig_spectrum=sgramm, fs=fs, - irasa_fun=_compute_sgramm, + irasa_fun=_local_irasa_fun, hset=hset, - irasa_kwargs=dict(irasa_kwargs), time=time, ) diff --git a/pyrasa/utils/input_classes.py b/pyrasa/utils/input_classes.py index 27e7b0f..4090853 100644 --- a/pyrasa/utils/input_classes.py +++ b/pyrasa/utils/input_classes.py @@ -1,27 +1,11 @@ from typing import TypedDict -import numpy as np - - -class IrasaKwargsTyped(TypedDict): - nperseg: int | None - noverlap: int | None - nfft: int | None - h: int | None - time_orig: None | np.ndarray - up_down: str | None - dpss_settings: dict - win_kwargs: dict - class IrasaSprintKwargsTyped(TypedDict): mfft: int hop: int win_duration: float - h: int | None - up_down: str | None dpss_settings: dict win_kwargs: dict - time_orig: None | np.ndarray # smooth: bool # n_avgs: list diff --git a/pyrasa/utils/irasa_utils.py b/pyrasa/utils/irasa_utils.py index 9f73cd5..830bcf2 100644 --- a/pyrasa/utils/irasa_utils.py +++ b/pyrasa/utils/irasa_utils.py @@ -137,10 +137,6 @@ def _compute_psd_welch( scaling: str = 'density', axis: int = -1, average: str = 'mean', - up_down: str | None = None, - spectrum_only: bool = False, - h: float | None = None, - time_orig: np.ndarray | None = None, ) -> tuple[np.ndarray, np.ndarray]: """Function to compute power spectral densities using welchs method""" @@ -171,10 +167,7 @@ def _compute_psd_welch( weighted_psds = [ratios[ix] * cur_sgramm for ix, cur_sgramm in enumerate(psds)] psd = np.sum(weighted_psds, axis=0) / np.sum(ratios) - if spectrum_only: - return psd - else: - return freq, psd + return freq, psd def _compute_sgramm( # noqa C901 @@ -188,8 +181,7 @@ def _compute_sgramm( # noqa C901 up_down: str | None = None, h: int | None = None, time_orig: np.ndarray | None = None, - spectrum_only: bool = False, -) -> tuple[np.ndarray, np.ndarray, np.ndarray] | np.ndarray: +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """Function to compute spectrograms""" if h is None: @@ -231,7 +223,4 @@ def _compute_sgramm( # noqa C901 sgramm = np.squeeze(sgramm) # bring in proper format - if spectrum_only: - return sgramm - else: - return freq, time, sgramm + return freq, time, sgramm diff --git a/pyrasa/utils/types.py b/pyrasa/utils/types.py new file mode 100644 index 0000000..a4dd657 --- /dev/null +++ b/pyrasa/utils/types.py @@ -0,0 +1,9 @@ +from typing import Protocol + +import numpy as np + + +class IrasaFun(Protocol): + def __call__( + self, data: np.ndarray, fs: int, h: int | None, up_down: str | None, time_orig: np.ndarray | None = None + ) -> np.ndarray: ...