Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: Fix consistency bug with orig units #11160

Merged
merged 3 commits into from
Sep 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions mne/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,11 +415,9 @@ def rename_channels(self, mapping, allow_duplicates=False, verbose=None):
if isinstance(self, BaseRaw):
# whatever mapping was provided, now we can just use a dict
mapping = dict(zip(ch_names_orig, self.info['ch_names']))
if self._orig_units is not None:
for old_name, new_name in mapping.items():
if old_name != new_name:
self._orig_units[new_name] = self._orig_units[old_name]
del self._orig_units[old_name]
for old_name, new_name in mapping.items():
if old_name in self._orig_units:
self._orig_units[new_name] = self._orig_units.pop(old_name)
ch_names = self.annotations.ch_names
for ci, ch in enumerate(ch_names):
ch_names[ci] = tuple(mapping.get(name, name) for name in ch)
Expand Down Expand Up @@ -830,6 +828,9 @@ def _pick_drop_channels(self, idx, *, verbose=None):

if isinstance(self, BaseRaw):
self.annotations._prune_ch_names(self.info, on_missing='ignore')
self._orig_units = {
k: v for k, v in self._orig_units.items()
if k in self.ch_names}

self._pick_projs()
return self
Expand Down Expand Up @@ -944,6 +945,8 @@ def add_channels(self, add_list, force_update_info=False):
self._read_picks = [
np.concatenate([r, extra_idx]) for r in self._read_picks]
assert all(len(r) == self.info['nchan'] for r in self._read_picks)
for other in add_list:
self._orig_units.update(other._orig_units)
elif isinstance(self, BaseEpochs):
self.picks = np.arange(self._data.shape[1])
if hasattr(self, '_projector'):
Expand Down
2 changes: 1 addition & 1 deletion mne/io/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def __init__(self, info, preload=False,
# Final check of orig_units, editing a unit if it is not a valid
# unit
orig_units = _check_orig_units(orig_units)
self._orig_units = orig_units
self._orig_units = orig_units or dict() # always a dict
self._projectors = list()
self._projector = None
self._dtype_ = dtype
Expand Down
14 changes: 14 additions & 0 deletions mne/io/edf/tests/test_edf.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,20 @@ def test_orig_units():
orig_units = raw._orig_units
assert len(orig_units) == len(raw.ch_names)
assert orig_units['A1'] == 'µV' # formerly 'uV' edit by _check_orig_units
del orig_units

raw.rename_channels(dict(A1='AA'))
assert raw._orig_units['AA'] == 'µV'
raw.rename_channels(dict(AA='A1'))

raw_back = raw.copy().pick(raw.ch_names[:1]) # _pick_drop_channels
assert raw_back.ch_names == ['A1']
assert set(raw_back._orig_units) == {'A1'}
raw_back.add_channels([raw.copy().pick(raw.ch_names[1:])])
assert raw_back.ch_names == raw.ch_names
assert set(raw_back._orig_units) == set(raw.ch_names)
raw_back.reorder_channels(raw.ch_names[::-1])
assert set(raw_back._orig_units) == set(raw.ch_names)


def test_units_params():
Expand Down
7 changes: 4 additions & 3 deletions mne/io/fiff/tests/test_raw_fiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -1336,11 +1336,12 @@ def test_add_channels():
"""Test raw splitting / re-appending channel types."""
rng = np.random.RandomState(0)
raw = read_raw_fif(test_fif_fname).crop(0, 1).load_data()
assert raw._orig_units == {}
raw_nopre = read_raw_fif(test_fif_fname, preload=False)
raw_eeg_meg = raw.copy().pick_types(meg=True, eeg=True)
raw_eeg = raw.copy().pick_types(meg=False, eeg=True)
raw_meg = raw.copy().pick_types(meg=True, eeg=False)
raw_stim = raw.copy().pick_types(meg=False, eeg=False, stim=True)
raw_eeg = raw.copy().pick_types(eeg=True)
raw_meg = raw.copy().pick_types(meg=True)
raw_stim = raw.copy().pick_types(stim=True)
raw_new = raw_meg.copy().add_channels([raw_eeg, raw_stim])
assert (
all(ch in raw_new.ch_names
Expand Down