Skip to content

Commit

Permalink
handle complex output
Browse files Browse the repository at this point in the history
  • Loading branch information
drammock committed Feb 16, 2024
1 parent 1d9cc39 commit 90ffc8f
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 44 deletions.
4 changes: 4 additions & 0 deletions mne/io/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2251,6 +2251,7 @@ def compute_tfr(
tmax=None,
picks=None,
proj=False,
output="power",
reject_by_annotation=True,
decim=1,
n_jobs=None,
Expand All @@ -2266,6 +2267,7 @@ def compute_tfr(
%(tmin_tmax_psd)s
%(picks_good_data_noref)s
%(proj_psd)s
%(output_compute_tfr)s
%(reject_by_annotation_tfr)s
%(decim_tfr)s
%(n_jobs)s
Expand All @@ -2285,6 +2287,8 @@ def compute_tfr(
----------
.. footbibliography::
"""
_check_option("output", output, ("power", "phase", "complex"))
method_kw["output"] = output
return RawTFR(
self,
method=method,
Expand Down
17 changes: 0 additions & 17 deletions mne/time_frequency/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,6 @@
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import inspect


def _ensure_output_not_in_method_kw(inst, method_kw):
legacy = inspect.currentframe().f_back.f_back.f_back.f_code.co_name == "_tfr_aux"
if legacy:
return method_kw
if "output" in method_kw:
raise ValueError(
f"{type(inst).__name__}.compute_tfr() got an unexpected keyword argument "
'"output". if you need more control over the output computation, please '
"use the array interfaces (mne.time_frequency.tfr_array_morlet() or "
"mne.time_frequency.tfr_array_multitaper())."
)
method_kw["output"] = "power"
return method_kw


def _get_instance_type_string(inst):
"""Get string representation of the originating instance type."""
Expand Down
24 changes: 14 additions & 10 deletions mne/time_frequency/tests/test_tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@
parametrize_morlet_multitaper = pytest.mark.parametrize(
"method", ("morlet", "multitaper")
)
parametrize_power_phase_complex = pytest.mark.parametrize(
"output", ("power", "phase", "complex")
)
parametrize_inst_and_ch_type = pytest.mark.parametrize(
"inst,ch_type",
(
Expand Down Expand Up @@ -623,7 +626,7 @@ def test_tfr_io(inst, average_tfr, request, tmp_path):


def test_raw_tfr_init(raw):
"""Test the AverageTFR, RawTFR and RawTFRArray constructors."""
"""Test the RawTFR and RawTFRArray constructors."""
one = RawTFR(inst=raw, method="morlet", freqs=freqs_linspace)
two = RawTFRArray(one.info, one.data, one.times, one.freqs, method="morlet")
# some attributes we know won't match:
Expand Down Expand Up @@ -1343,11 +1346,12 @@ def test_to_data_frame_time_format(time_format):


@parametrize_morlet_multitaper
@parametrize_power_phase_complex
@pytest.mark.parametrize("picks", ("mag", mag_names, [2, 5, 8])) # all 3 equivalent
def test_raw_compute_tfr(raw, method, picks):
def test_raw_compute_tfr(raw, method, output, picks):
"""Test Raw.compute_tfr() and picks handling."""
full_tfr = raw.compute_tfr(method, freqs=freqs_linspace)
pick_tfr = raw.compute_tfr(method, freqs=freqs_linspace, picks=picks)
full_tfr = raw.compute_tfr(method, output=output, freqs=freqs_linspace)
pick_tfr = raw.compute_tfr(method, output=output, freqs=freqs_linspace, picks=picks)
assert isinstance(pick_tfr, RawTFR), type(pick_tfr)
# ↓↓↓ can't use [2,5,8] because ch0 is IAS, so indices change between raw and TFR
want = full_tfr.get_data(picks=mag_names)
Expand All @@ -1356,10 +1360,11 @@ def test_raw_compute_tfr(raw, method, picks):


@parametrize_morlet_multitaper
@parametrize_power_phase_complex
@pytest.mark.parametrize("freqs", (freqs_linspace, freqs_unsorted_list))
def test_evoked_compute_tfr(full_evoked, method, freqs):
def test_evoked_compute_tfr(full_evoked, method, output, freqs):
"""Test Evoked.compute_tfr(), with a few different ways of specifying freqs."""
tfr = full_evoked.compute_tfr(method, freqs)
tfr = full_evoked.compute_tfr(method, freqs, output=output)
assert isinstance(tfr, AverageTFR), type(tfr)
assert tfr.nave == full_evoked.nave
assert tfr.comment == full_evoked.comment
Expand Down Expand Up @@ -1402,8 +1407,7 @@ def test_epochs_compute_tfr_average_itc(
assert avg.comment.startswith(f"mean of {len(epochs)} EpochsTFR")


@parametrize_morlet_multitaper
def test_epochs_vs_evoked_compute_tfr(epochs, method):
def test_epochs_vs_evoked_compute_tfr(epochs):
"""Compare result of averaging before or after the TFR computation.
This is mostly a test of object structure / attribute preservation. In normal cases,
Expand All @@ -1415,8 +1419,8 @@ def test_epochs_vs_evoked_compute_tfr(epochs, method):
The three things that will always end up different are `._comment`, `._inst_type`,
and `._data_type`, so we ignore those here.
"""
avg_first = epochs.average().compute_tfr(method=method, freqs=freqs_linspace)
avg_second = epochs.compute_tfr(method=method, freqs=freqs_linspace).average()
avg_first = epochs.average().compute_tfr(method="morlet", freqs=freqs_linspace)
avg_second = epochs.compute_tfr(method="morlet", freqs=freqs_linspace).average()
for attr in ("_comment", "_inst_type", "_data_type"):
assert getattr(avg_first, attr) != getattr(avg_second, attr)
delattr(avg_first, attr)
Expand Down
54 changes: 37 additions & 17 deletions mne/time_frequency/tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
figure_nobar,
plt_show,
)
from ._utils import _ensure_output_not_in_method_kw, _get_instance_type_string
from ._utils import _get_instance_type_string
from .multitaper import dpss_windows, tfr_array_multitaper
from .spectrum import EpochsSpectrum

Expand Down Expand Up @@ -1199,6 +1199,7 @@ def __init__(
# and `freqs` vector has been pre-computed
if method != "stockwell":
method_kw.update(freqs=freqs)
# ↓↓↓ if constructor called directly, prevents key error
method_kw.setdefault("output", "power")
self._freqs = np.asarray(freqs, dtype=np.float64)
del freqs
Expand All @@ -1223,9 +1224,16 @@ def __init__(
self._method = method
self._inst_type = type(inst)
self._baseline = None
self.preload = True # needed for __getitem__, never False
self.preload = True # needed for __getitem__, never False for TFRs
# self._dims may also get updated by child classes
self._dims = ("channel", "freq", "time")
self._dims = ["channel", "freq", "time"]
self._needs_taper_dim = method == "multitaper" and method_kw["output"] in (
"complex",
"phase",
)
if self._needs_taper_dim:
self._dims.insert(1, "taper")
self._dims = tuple(self._dims)
# get the instance data.
time_mask = _time_mask(inst.times, tmin, tmax, sfreq=self.sfreq)
get_instance_data_kw = dict(time_mask=time_mask)
Expand All @@ -1246,16 +1254,27 @@ def __init__(
self._set_times(self._raw_times)
self._decim = 1
# record data type (for repr and html_repr). ITC handled in the calling method.
is_complex = np.iscomplexobj(self._data)
if is_complex:
self._data_type = "Complex TFR"
else:
if method == "stockwell":
self._data_type = "Power Estimates"
# check for correct shape and bad values
self._check_values(is_complex)
else:
data_types = dict(
power="Power Estimates",
avg_power="Average Power Estimates",
avg_power_itc="Average Power Estimates",
phase="Phase",
complex="Complex Amplitude",
)
self._data_type = data_types[method_kw["output"]]
# check for correct shape and bad values. `tfr_array_stockwell` doesn't take kw
# `output` so it may be missing here, so use `.get()`
negative_ok = method_kw.get("output", "") in ("complex", "phase")
# if method_kw.get("output", None) in ("phase", "complex"):
# raise RuntimeError
self._check_values(negative_ok=negative_ok)
# we don't need these anymore, and they make save/load harder
del self._picks
del self._tfr_func
del self._needs_taper_dim
del self._shape # calculated from self._data henceforth
del self.inst # save memory

Expand Down Expand Up @@ -1463,7 +1482,7 @@ def _check_state(self):
return
raise ValueError(msg)

def _check_values(self, wants_complex=False):
def _check_values(self, negative_ok=False):
"""Check TFR results for correct shape and bad values."""
assert len(self._dims) == self._data.ndim
assert self._data.shape == self._shape
Expand All @@ -1473,7 +1492,7 @@ def _check_values(self, wants_complex=False):
dims = np.arange(self._data.ndim).tolist()
dims.pop(ch_dim)
negative_values = self._data.min(axis=tuple(dims)) < 0
if negative_values.any() and not wants_complex:
if negative_values.any() and not negative_ok:
chs = np.array(self.ch_names)[negative_values].tolist()
s = _pl(negative_values.sum())
warn(
Expand Down Expand Up @@ -1505,11 +1524,15 @@ def _compute_tfr(self, data, n_jobs, verbose):

# this is *expected* shape, it gets asserted later in _check_values()
# (and then deleted afterwards)
self._shape = (
expected_shape = [
len(self.ch_names),
len(self.freqs),
len(self._raw_times[self._decim]), # don't use self.times, not set yet
)
]
# deal with the "taper" dimension
if self._needs_taper_dim:
expected_shape.insert(1, self._data.shape[1])
self._shape = tuple(expected_shape)

@verbose
def _onselect(
Expand Down Expand Up @@ -2839,7 +2862,7 @@ def __init__(
n_fft = method_kw.get("n_fft", default_nfft)
*_, freqs = _compute_freqs_st(fmin, fmax, n_fft, inst.info["sfreq"])

# use Evoked.comment or str(Epochs.event_id) as the comment...
# use Evoked.comment or str(Epochs.event_id) as the default comment...
if comment is None:
comment = getattr(inst, "comment", ",".join(getattr(inst, "event_id", "")))
# ...but don't overwrite if it's coming in with a comment already set
Expand Down Expand Up @@ -3783,9 +3806,6 @@ def __init__(
_validate_type(
inst, (BaseRaw, dict), "object passed to RawTFR constructor", "Raw"
)
# make sure they didn't pass "output"
_ensure_output_not_in_method_kw(inst, method_kw)

super().__init__(
inst,
method,
Expand Down

0 comments on commit 90ffc8f

Please sign in to comment.