Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MNT] Solving moabb and braindecode compatibility #669

Merged
4 changes: 3 additions & 1 deletion docs/source/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ Bugs
- Fix Stieger2021 dataset bugs (:gh:`651` by `Martin Wimpff`_)
- Unpinning major version Scikit-learn and numpy (:gh:`652` by `Bruno Aristimunha`_)
- Replacing the func:`numpy.string_` to func:`numpy.bytes_` (:gh:`665` by `Bruno Aristimunha`_)
- Fixing the set_download_dir that was not working when we tried to set the dir more than 10 times at the same time (:gh:`668` by `Bruno Aristimunha`_)
- Fixing the set_download_dir that was not working when we tried to set the dir more than 10 times at the same time (:gh:`668` by `Bruno Aristimunha`_)
- Creating stimulus channels in :class:`moabb.datasets.Zhou2016` and :class:`moabb.datasets.PhysionetMI` to allow braindecode compatibility (:gh:`669` by `Bruno Aristimunha`_)


API changes
~~~~~~~~~~~
Expand Down
6 changes: 5 additions & 1 deletion moabb/datasets/Zhou2016.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from .base import BaseDataset
from .download import get_dataset_path
from .utils import stim_channels_with_selected_ids


DATA_PATH = "https://ndownloader.figshare.com/files/3662952"
Expand Down Expand Up @@ -88,6 +89,7 @@ def __init__(self):
paradigm="imagery",
doi="10.1371/journal.pone.0162657",
)
self.events = dict(left_hand=1, right_hand=2, feet=3)

def _get_single_subject_data(self, subject):
"""Return data for a single subject."""
Expand All @@ -105,7 +107,9 @@ def _get_single_subject_data(self, subject):
stim[stim == "2"] = "right_hand"
stim[stim == "3"] = "feet"
raw.annotations.description = stim
out[sess_key][run_key] = raw
out[sess_key][run_key] = stim_channels_with_selected_ids(
raw, desired_event_id=self.events
)
out[sess_key][run_key].set_montage(make_standard_montage("standard_1005"))
return out

Expand Down
11 changes: 6 additions & 5 deletions moabb/datasets/liu2024.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from moabb.datasets import download as dl
from moabb.datasets.base import BaseDataset
from moabb.datasets.utils import stim_channels_with_selected_ids


# Link to the raw data
Expand Down Expand Up @@ -77,15 +78,15 @@ class Liu2024(BaseDataset):
def __init__(self, break_events=False, instr_events=False):
self.break_events = break_events
self.instr_events = instr_events
events = {"left_hand": 1, "right_hand": 2}
self.events = {"left_hand": 1, "right_hand": 2}
if break_events:
events["instr"] = 3
self.events["instr"] = 3
if instr_events:
events["break"] = 4
self.events["break"] = 4
super().__init__(
subjects=list(range(1, 50 + 1)),
sessions_per_subject=1,
events=events,
events=self.events,
code="Liu2024",
interval=(2, 6),
paradigm="imagery",
Expand Down Expand Up @@ -277,7 +278,7 @@ def _get_single_subject_data(self, subject):
# Loading dataset
raw = raw.load_data(verbose=False)
# There is only one session
sessions = {"0": {"0": raw}}
sessions = {"0": {"0": stim_channels_with_selected_ids(raw, self.event_id)}}

return sessions

Expand Down
6 changes: 4 additions & 2 deletions moabb/datasets/mpi_mi.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from moabb.datasets import download as dl
from moabb.datasets.base import BaseDataset
from moabb.datasets.utils import stim_channels_with_selected_ids
from moabb.utils import depreciated_alias


Expand Down Expand Up @@ -56,10 +57,11 @@ class GrosseWentrup2009(BaseDataset):
"""

def __init__(self):
self.events_id = dict(right_hand=2, left_hand=1)
super().__init__(
subjects=list(range(1, 11)),
sessions_per_subject=1,
events=dict(right_hand=2, left_hand=1),
events=self.events_id,
code="GrosseWentrup2009",
interval=[0, 7],
paradigm="imagery",
Expand All @@ -76,7 +78,7 @@ def _get_single_subject_data(self, subject):
stim[stim == "20"] = "right_hand"
stim[stim == "10"] = "left_hand"
raw.annotations.description = stim
return {"0": {"0": raw}}
return {"0": {"0": stim_channels_with_selected_ids(raw, self.event_id)}}

def data_path(
self, subject, path=None, force_update=False, update_path=None, verbose=None
Expand Down
11 changes: 8 additions & 3 deletions moabb/datasets/physionet_mi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from moabb.datasets.base import BaseDataset
from moabb.datasets.download import data_dl, get_dataset_path
from moabb.datasets.utils import stim_channels_with_selected_ids


BASE_URL = "https://physionet.org/files/eegmmidb/1.0.0/"
Expand Down Expand Up @@ -79,7 +80,7 @@ def __init__(self, imagined=True, executed=False):
paradigm="imagery",
doi="10.1109/TBME.2004.827072",
)

self.events = dict(left_hand=2, right_hand=3, feet=5, hands=4, rest=1)
self.imagined = imagined
self.executed = executed
self.feet_runs = []
Expand Down Expand Up @@ -123,7 +124,9 @@ def _get_single_subject_data(self, subject):
stim[stim == "T1"] = "left_hand"
stim[stim == "T2"] = "right_hand"
raw.annotations.description = stim
data[str(idx)] = raw
data[str(idx)] = stim_channels_with_selected_ids(
raw, desired_event_id=self.events
)
idx += 1

# feet runs
Expand All @@ -136,7 +139,9 @@ def _get_single_subject_data(self, subject):
stim[stim == "T1"] = "hands"
stim[stim == "T2"] = "feet"
raw.annotations.description = stim
data[str(idx)] = raw
data[str(idx)] = stim_channels_with_selected_ids(
raw, desired_event_id=self.events
)
idx += 1

return {"0": data}
Expand Down
6 changes: 4 additions & 2 deletions moabb/datasets/sosulski2019.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from moabb.datasets import download as dl
from moabb.datasets.base import BaseDataset
from moabb.datasets.utils import stim_channels_with_selected_ids


SPOT_PILOT_P300_URL = (
Expand Down Expand Up @@ -95,12 +96,13 @@ def __init__(
self.n_channels = 31
self.use_soas_as_sessions = use_soas_as_sessions
self.description_map = {"Stimulus/S 21": "Target", "Stimulus/S 1": "NonTarget"}
self.events = dict(Target=21, NonTarget=1)
code = "Sosulski2019"
interval = [-0.2, 1] if interval is None else interval
super().__init__(
subjects=list(range(1, 13 + 1)),
sessions_per_subject=1,
events=dict(Target=21, NonTarget=1),
events=self.events,
code=code,
interval=interval,
paradigm="p300",
Expand Down Expand Up @@ -133,7 +135,7 @@ def _get_single_run_data(self, file_path):
if self.reject_non_iid:
raw.set_annotations(raw.annotations[7:85]) # non-iid rejection
raw.annotations.rename(self.description_map)
return raw
return stim_channels_with_selected_ids(raw, self.events)

def _get_single_subject_data(self, subject):
"""Return data for a single subject."""
Expand Down
7 changes: 4 additions & 3 deletions moabb/datasets/upper_limb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from mne.io import read_raw_gdf

from moabb.datasets.base import BaseDataset
from moabb.datasets.utils import stim_channels_with_selected_ids

from . import download as dl

Expand Down Expand Up @@ -58,7 +59,7 @@ class Ofner2017(BaseDataset):
def __init__(self, imagined=True, executed=False):
self.imagined = imagined
self.executed = executed
event_id = {
self.event_id = {
"right_elbow_flexion": 1536,
"right_elbow_extension": 1537,
"right_supination": 1538,
Expand All @@ -72,7 +73,7 @@ def __init__(self, imagined=True, executed=False):
super().__init__(
subjects=list(range(1, 16)),
sessions_per_subject=n_sessions,
events=event_id,
events=self.event_id,
code="Ofner2017",
interval=[0, 3], # according to paper 2-5
paradigm="imagery",
Expand Down Expand Up @@ -114,7 +115,7 @@ def _get_single_subject_data(self, subject):
stim[stim == "1541"] = "right_hand_open"
stim[stim == "1542"] = "rest"
raw.annotations.description = stim
data[str(ii)] = raw
data[str(ii)] = stim_channels_with_selected_ids(raw, self.event_id)

out[session_name] = data
return out
Expand Down
50 changes: 50 additions & 0 deletions moabb/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import inspect

import mne
import numpy as np
from mne import create_info
from mne.io import RawArray
Expand Down Expand Up @@ -273,3 +274,52 @@ def add_stim_channel_epoch(
)
raw = raw.add_channels([RawArray(data=stim_chan, info=info, verbose=False)])
return raw


def stim_channels_with_selected_ids(
raw: mne.io.BaseRaw, desired_event_id: dict, stim_channel_name="STIM"
):
"""
Add a stimulus channel with filtering and renaming based on events_ids.

Parameters
----------
raw: mne.Raw
The raw object to add the stimulus channel to.
desired_event_id: dict
Dictionary with events
"""

# Get events using the consistent event_id mapping
events, _ = mne.events_from_annotations(raw, event_id=desired_event_id)

# Filter the events array to include only desired events
desired_event_ids = list(desired_event_id.values())
filtered_events = events[np.isin(events[:, 2], desired_event_ids)]

# Create annotations from filtered events using the inverted mapping
event_desc = {v: k for k, v in desired_event_id.items()}
annot_from_events = mne.annotations_from_events(
events=filtered_events,
event_desc=event_desc,
sfreq=raw.info["sfreq"],
orig_time=raw.info["meas_date"],
)
raw.set_annotations(annot_from_events)

# Create the stim channel data array
stim_channs = np.zeros((1, raw.n_times))
for event in filtered_events:
sample_index = event[0]
event_code = event[2] # Consistent event IDs
stim_channs[0, sample_index] = event_code

# Create the stim channel and add it to raw

stim_info = mne.create_info(
[stim_channel_name], sfreq=raw.info["sfreq"], ch_types=["stim"]
)
stim_raw = mne.io.RawArray(stim_channs, stim_info, verbose=False)
raw_with_stim = raw.copy().add_channels([stim_raw], force_update_info=True)

return raw_with_stim
Loading