diff --git a/docs/source/whats_new.rst b/docs/source/whats_new.rst index d7f1398c2..134c5415b 100644 --- a/docs/source/whats_new.rst +++ b/docs/source/whats_new.rst @@ -24,7 +24,9 @@ Bugs - Fix Stieger2021 dataset bugs (:gh:`651` by `Martin Wimpff`_) - Unpinning major version Scikit-learn and numpy (:gh:`652` by `Bruno Aristimunha`_) - Replacing the func:`numpy.string_` to func:`numpy.bytes_` (:gh:`665` by `Bruno Aristimunha`_) -- Fixing the set_download_dir that was not working when we tried to set the dir more than 10 times at the same time (:gh:`668` by `Bruno Aristimunha`_) +- Fixing the set_download_dir that was not working when we tried to set the dir more than 10 times at the same time (:gh:`668` by `Bruno Aristimunha`_) +- Creating stimulus channels in :class:`moabb.datasets.Zhou2016` and :class:`moabb.datasets.PhysionetMI` to allow braindecode compatibility (:gh:`669` by `Bruno Aristimunha`_) + API changes ~~~~~~~~~~~ diff --git a/moabb/datasets/Zhou2016.py b/moabb/datasets/Zhou2016.py index c4037cea9..fb7f07b04 100644 --- a/moabb/datasets/Zhou2016.py +++ b/moabb/datasets/Zhou2016.py @@ -14,6 +14,7 @@ from .base import BaseDataset from .download import get_dataset_path +from .utils import stim_channels_with_selected_ids DATA_PATH = "https://ndownloader.figshare.com/files/3662952" @@ -88,6 +89,7 @@ def __init__(self): paradigm="imagery", doi="10.1371/journal.pone.0162657", ) + self.events = dict(left_hand=1, right_hand=2, feet=3) def _get_single_subject_data(self, subject): """Return data for a single subject.""" @@ -105,7 +107,9 @@ def _get_single_subject_data(self, subject): stim[stim == "2"] = "right_hand" stim[stim == "3"] = "feet" raw.annotations.description = stim - out[sess_key][run_key] = raw + out[sess_key][run_key] = stim_channels_with_selected_ids( + raw, desired_event_id=self.events + ) out[sess_key][run_key].set_montage(make_standard_montage("standard_1005")) return out diff --git a/moabb/datasets/liu2024.py b/moabb/datasets/liu2024.py index 82dbf4af6..07e0738d8 100644 --- a/moabb/datasets/liu2024.py +++ b/moabb/datasets/liu2024.py @@ -14,6 +14,7 @@ from moabb.datasets import download as dl from moabb.datasets.base import BaseDataset +from moabb.datasets.utils import stim_channels_with_selected_ids # Link to the raw data @@ -77,15 +78,15 @@ class Liu2024(BaseDataset): def __init__(self, break_events=False, instr_events=False): self.break_events = break_events self.instr_events = instr_events - events = {"left_hand": 1, "right_hand": 2} + self.events = {"left_hand": 1, "right_hand": 2} if break_events: - events["instr"] = 3 + self.events["instr"] = 3 if instr_events: - events["break"] = 4 + self.events["break"] = 4 super().__init__( subjects=list(range(1, 50 + 1)), sessions_per_subject=1, - events=events, + events=self.events, code="Liu2024", interval=(2, 6), paradigm="imagery", @@ -277,7 +278,7 @@ def _get_single_subject_data(self, subject): # Loading dataset raw = raw.load_data(verbose=False) # There is only one session - sessions = {"0": {"0": raw}} + sessions = {"0": {"0": stim_channels_with_selected_ids(raw, self.event_id)}} return sessions diff --git a/moabb/datasets/mpi_mi.py b/moabb/datasets/mpi_mi.py index 06e8bd366..9abff905b 100644 --- a/moabb/datasets/mpi_mi.py +++ b/moabb/datasets/mpi_mi.py @@ -5,6 +5,7 @@ from moabb.datasets import download as dl from moabb.datasets.base import BaseDataset +from moabb.datasets.utils import stim_channels_with_selected_ids from moabb.utils import depreciated_alias @@ -56,10 +57,11 @@ class GrosseWentrup2009(BaseDataset): """ def __init__(self): + self.events_id = dict(right_hand=2, left_hand=1) super().__init__( subjects=list(range(1, 11)), sessions_per_subject=1, - events=dict(right_hand=2, left_hand=1), + events=self.events_id, code="GrosseWentrup2009", interval=[0, 7], paradigm="imagery", @@ -76,7 +78,7 @@ def _get_single_subject_data(self, subject): stim[stim == "20"] = "right_hand" stim[stim == "10"] = "left_hand" raw.annotations.description = stim - return {"0": {"0": raw}} + return {"0": {"0": stim_channels_with_selected_ids(raw, self.event_id)}} def data_path( self, subject, path=None, force_update=False, update_path=None, verbose=None diff --git a/moabb/datasets/physionet_mi.py b/moabb/datasets/physionet_mi.py index 3367c9cf3..8f99dfad2 100644 --- a/moabb/datasets/physionet_mi.py +++ b/moabb/datasets/physionet_mi.py @@ -6,6 +6,7 @@ from moabb.datasets.base import BaseDataset from moabb.datasets.download import data_dl, get_dataset_path +from moabb.datasets.utils import stim_channels_with_selected_ids BASE_URL = "https://physionet.org/files/eegmmidb/1.0.0/" @@ -79,7 +80,7 @@ def __init__(self, imagined=True, executed=False): paradigm="imagery", doi="10.1109/TBME.2004.827072", ) - + self.events = dict(left_hand=2, right_hand=3, feet=5, hands=4, rest=1) self.imagined = imagined self.executed = executed self.feet_runs = [] @@ -123,7 +124,9 @@ def _get_single_subject_data(self, subject): stim[stim == "T1"] = "left_hand" stim[stim == "T2"] = "right_hand" raw.annotations.description = stim - data[str(idx)] = raw + data[str(idx)] = stim_channels_with_selected_ids( + raw, desired_event_id=self.events + ) idx += 1 # feet runs @@ -136,7 +139,9 @@ def _get_single_subject_data(self, subject): stim[stim == "T1"] = "hands" stim[stim == "T2"] = "feet" raw.annotations.description = stim - data[str(idx)] = raw + data[str(idx)] = stim_channels_with_selected_ids( + raw, desired_event_id=self.events + ) idx += 1 return {"0": data} diff --git a/moabb/datasets/sosulski2019.py b/moabb/datasets/sosulski2019.py index ade789445..2fc20db7e 100644 --- a/moabb/datasets/sosulski2019.py +++ b/moabb/datasets/sosulski2019.py @@ -7,6 +7,7 @@ from moabb.datasets import download as dl from moabb.datasets.base import BaseDataset +from moabb.datasets.utils import stim_channels_with_selected_ids SPOT_PILOT_P300_URL = ( @@ -95,12 +96,13 @@ def __init__( self.n_channels = 31 self.use_soas_as_sessions = use_soas_as_sessions self.description_map = {"Stimulus/S 21": "Target", "Stimulus/S 1": "NonTarget"} + self.events = dict(Target=21, NonTarget=1) code = "Sosulski2019" interval = [-0.2, 1] if interval is None else interval super().__init__( subjects=list(range(1, 13 + 1)), sessions_per_subject=1, - events=dict(Target=21, NonTarget=1), + events=self.events, code=code, interval=interval, paradigm="p300", @@ -133,7 +135,7 @@ def _get_single_run_data(self, file_path): if self.reject_non_iid: raw.set_annotations(raw.annotations[7:85]) # non-iid rejection raw.annotations.rename(self.description_map) - return raw + return stim_channels_with_selected_ids(raw, self.events) def _get_single_subject_data(self, subject): """Return data for a single subject.""" diff --git a/moabb/datasets/upper_limb.py b/moabb/datasets/upper_limb.py index f10db8ccd..e8095de81 100644 --- a/moabb/datasets/upper_limb.py +++ b/moabb/datasets/upper_limb.py @@ -3,6 +3,7 @@ from mne.io import read_raw_gdf from moabb.datasets.base import BaseDataset +from moabb.datasets.utils import stim_channels_with_selected_ids from . import download as dl @@ -58,7 +59,7 @@ class Ofner2017(BaseDataset): def __init__(self, imagined=True, executed=False): self.imagined = imagined self.executed = executed - event_id = { + self.event_id = { "right_elbow_flexion": 1536, "right_elbow_extension": 1537, "right_supination": 1538, @@ -72,7 +73,7 @@ def __init__(self, imagined=True, executed=False): super().__init__( subjects=list(range(1, 16)), sessions_per_subject=n_sessions, - events=event_id, + events=self.event_id, code="Ofner2017", interval=[0, 3], # according to paper 2-5 paradigm="imagery", @@ -114,7 +115,7 @@ def _get_single_subject_data(self, subject): stim[stim == "1541"] = "right_hand_open" stim[stim == "1542"] = "rest" raw.annotations.description = stim - data[str(ii)] = raw + data[str(ii)] = stim_channels_with_selected_ids(raw, self.event_id) out[session_name] = data return out diff --git a/moabb/datasets/utils.py b/moabb/datasets/utils.py index f0e280f94..c0f198537 100644 --- a/moabb/datasets/utils.py +++ b/moabb/datasets/utils.py @@ -2,6 +2,7 @@ import inspect +import mne import numpy as np from mne import create_info from mne.io import RawArray @@ -273,3 +274,52 @@ def add_stim_channel_epoch( ) raw = raw.add_channels([RawArray(data=stim_chan, info=info, verbose=False)]) return raw + + +def stim_channels_with_selected_ids( + raw: mne.io.BaseRaw, desired_event_id: dict, stim_channel_name="STIM" +): + """ + Add a stimulus channel with filtering and renaming based on events_ids. + + Parameters + ---------- + raw: mne.Raw + The raw object to add the stimulus channel to. + desired_event_id: dict + Dictionary with events + """ + + # Get events using the consistent event_id mapping + events, _ = mne.events_from_annotations(raw, event_id=desired_event_id) + + # Filter the events array to include only desired events + desired_event_ids = list(desired_event_id.values()) + filtered_events = events[np.isin(events[:, 2], desired_event_ids)] + + # Create annotations from filtered events using the inverted mapping + event_desc = {v: k for k, v in desired_event_id.items()} + annot_from_events = mne.annotations_from_events( + events=filtered_events, + event_desc=event_desc, + sfreq=raw.info["sfreq"], + orig_time=raw.info["meas_date"], + ) + raw.set_annotations(annot_from_events) + + # Create the stim channel data array + stim_channs = np.zeros((1, raw.n_times)) + for event in filtered_events: + sample_index = event[0] + event_code = event[2] # Consistent event IDs + stim_channs[0, sample_index] = event_code + + # Create the stim channel and add it to raw + + stim_info = mne.create_info( + [stim_channel_name], sfreq=raw.info["sfreq"], ch_types=["stim"] + ) + stim_raw = mne.io.RawArray(stim_channs, stim_info, verbose=False) + raw_with_stim = raw.copy().add_channels([stim_raw], force_update_info=True) + + return raw_with_stim