Within Session splitter #664

Open
wants to merge 49 commits into
base: develop
49 commits
bacedc5
Creating new splitters and base evaluation
brunaafl Jun 6, 2024
419b2ca
Adding metasplitters
brunaafl Jun 7, 2024
d6e795d
Fixing LazyEvaluation
brunaafl Jun 10, 2024
140670c
Merge branch 'NeuroTechX:develop' into eval_splitters
brunaafl Jun 10, 2024
d724674
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 10, 2024
a278026
More optimized version of TimeSeriesSplit
brunaafl Jun 10, 2024
300a6b9
More optimized version of TimeSeriesSplit
brunaafl Jun 10, 2024
7cb79f6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 10, 2024
55db70f
Addressing some comments: documentation, types, inconsistencies
brunaafl Jun 10, 2024
2851a15
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 10, 2024
c73dd1a
Addressing some comments: optimizing code, adjusts
brunaafl Jun 12, 2024
2b0e735
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 12, 2024
cf4b709
Adding examples
brunaafl Jun 26, 2024
177bf65
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 26, 2024
a6b5772
Adding: Pytests for evaluation splitters, and examples for meta split…
brunaafl Aug 15, 2024
26b13d5
Changing: name of TimeSeriesSplit to PseudoOnlineSplit
brunaafl Sep 30, 2024
e6661c4
Merge branch 'develop' into eval_splitters
brunaafl Sep 30, 2024
430e3a8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 30, 2024
698e539
Fixing pre-commit
brunaafl Sep 30, 2024
0fff053
Merge remote-tracking branch 'origin/eval_splitters' into eval_splitters
brunaafl Sep 30, 2024
98d12ac
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 30, 2024
558d27b
Adding some tests for metasplitters
brunaafl Oct 1, 2024
34ea645
Merge remote-tracking branch 'origin/eval_splitters' into eval_splitters
brunaafl Oct 1, 2024
b435bf8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 1, 2024
d8f26a3
Fixing pre-commit
brunaafl Oct 1, 2024
eaf0fb9
Merge remote-tracking branch 'origin/eval_splitters' into eval_splitters
brunaafl Oct 1, 2024
e5159f2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 1, 2024
516a5e8
Fixing pre-commit
brunaafl Oct 1, 2024
b29ecd2
Merge remote-tracking branch 'origin/eval_splitters' into eval_splitters
brunaafl Oct 1, 2024
37cff03
Fix example SamplerSplit
brunaafl Oct 17, 2024
88ee910
Add shuffle and random_state parameters to WithinSession
brunaafl Oct 18, 2024
ea9cc59
Change nomenclature of variables
brunaafl Oct 18, 2024
819c4ff
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 18, 2024
65c305e
Merge branch 'develop' into within_session
bruAristimunha Oct 18, 2024
f1ad587
FIX: fixing the whats_new.rst file
bruAristimunha Oct 18, 2024
485e7a5
EHN: playing a little
bruAristimunha Oct 18, 2024
3f3742f
FIX: fixing the import and docs/docstring
bruAristimunha Oct 18, 2024
c181c59
FIX: fixing the import and docs/docstring
bruAristimunha Oct 18, 2024
8f034c8
FIX: fixing the import and docs/docstring
bruAristimunha Oct 18, 2024
fbef726
FIX: removing cross-session and cross-subject
bruAristimunha Oct 18, 2024
837c061
FIX: focus only in the within-session
bruAristimunha Oct 18, 2024
34822e9
Merge branch 'develop' into within_session
bruAristimunha Oct 19, 2024
39e92e5
Fix test
brunaafl Oct 19, 2024
612c6a6
Merge remote-tracking branch 'origin/within_session' into within_session
brunaafl Oct 19, 2024
590edb1
[FIX] I think it is fixed.
bruAristimunha Oct 23, 2024
b151d61
[FIX] shuffle everything
bruAristimunha Oct 23, 2024
602ccd5
Merge remote-tracking branch 'origin/within_session' into within_session
brunaafl Oct 25, 2024
74cf246
Changing WithinSession image
brunaafl Oct 25, 2024
c85928d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 25, 2024
7 changes: 7 additions & 0 deletions docs/source/evaluations.rst
@@ -19,6 +19,13 @@ Evaluations
    CrossSubjectEvaluation


.. autosummary::
    :toctree: generated/
    :template: class.rst

    WithinSessionSplitter


------------
Base & Utils
------------
Binary file added docs/source/images/withinsess.pdf
Binary file not shown.
Binary file added docs/source/images/withinsess.png
1 change: 1 addition & 0 deletions docs/source/whats_new.rst
@@ -17,6 +17,7 @@ Develop branch

Enhancements
~~~~~~~~~~~~
- Adding :class:`moabb.evaluations.splitters.WithinSessionSplitter` (:gh:`664` by `Bruna Lopes_`)

Bugs
~~~~
Binary file added images/withinsess.png
Review comment (Collaborator):

Why not use only the other version?

1 change: 1 addition & 0 deletions moabb/evaluations/__init__.py
@@ -9,4 +9,5 @@
    CrossSubjectEvaluation,
    WithinSessionEvaluation,
)
from .splitters import WithinSessionSplitter
from .utils import create_save_path, save_model_cv, save_model_list
109 changes: 109 additions & 0 deletions moabb/evaluations/splitters.py
@@ -0,0 +1,109 @@
from sklearn.model_selection import BaseCrossValidator, StratifiedKFold
from sklearn.utils import check_random_state


class WithinSessionSplitter(BaseCrossValidator):
    """Data splitter for within session evaluation.

    Within-session evaluation uses k-fold cross_validation to determine train
    and test sets on separate session for each subject. This splitter assumes that
    all data from all subjects is already known and loaded.
Review comment (Collaborator) on lines +9 to +10:

Suggested change
-    and test sets on separate session for each subject. This splitter assumes that
-    all data from all subjects is already known and loaded.
+    and test sets for each subject in each session. This splitter
+    assumes that all data from all subjects is already known and loaded.


    .. image:: images/withinsess.png
        :alt: The schematic diagram of the WithinSession split
        :align: center


    Parameters
    ----------
    n_folds : int
        Number of folds. Must be at least 2.
Review comment (Collaborator):

Suggested change
-        Number of folds. Must be at least 2.
+        Number of folds for the within-session k-fold split. Must be at least 2.

    random_state: int, RandomState instance or None, default=None
        Important when `shuffle` is True. Controls the randomness of splits.
        Pass an int for reproducible output across multiple function calls.
Review comment (Collaborator) on lines +21 to +23:

Move to the end.

Suggested change
-    random_state: int, RandomState instance or None, default=None
-        Important when `shuffle` is True. Controls the randomness of splits.
-        Pass an int for reproducible output across multiple function calls.
+    random_state: int, RandomState instance or None, default=None
+        Controls the randomness of splits. Only used when `shuffle` is True.
+        Pass an int for reproducible output across multiple function calls.

    shuffle_session : bool, default=True
        Whether to shuffle each class's samples before splitting into batches.
        Note that the samples within each split will not be shuffled.
    shuffle_subjects : bool, default=False
        Apply shuffle in mixing subjects and sessions, this parameter allows
        sample iterations of the splitter.
Review comment (Collaborator) on lines +24 to +29:

Do you think it is necessary to have both? I don't really see a use case where I would use only one of them.
I would keep a single `shuffle` parameter.


    Examples
    -----------

    >>> import pandas as pd
    >>> import numpy as np
    >>> from moabb.evaluations.splitters import WithinSessionSplitter
    >>> X = np.array([[1, 2], [3, 4], [5, 6], [1,4], [7, 4], [5, 8], [0,3], [2,4]])
    >>> y = np.array([1, 2, 1, 2, 1, 2, 1, 2])
    >>> subjects = np.array([1, 1, 1, 1, 1, 1, 1, 1])
    >>> sessions = np.array(['T', 'T', 'E', 'E', 'T', 'T', 'E', 'E'])
    >>> metadata = pd.DataFrame(data={'subject': subjects, 'session': sessions})
    >>> csess = WithinSessionSplitter(n_folds=2)
    >>> csess.get_n_splits(metadata)
    4
    >>> for i, (train_index, test_index) in enumerate(csess.split(y, metadata)):
    ...     print(f"Fold {i}:")
    ...     print(f"  Train: index={train_index}, group={subjects[train_index]}, session={sessions[train_index]}")
    ...     print(f"  Test:  index={test_index}, group={subjects[test_index]}, sessions={sessions[test_index]}")
    Fold 0:
      Train: index=[2 7], group=[1 1], session=['E' 'E']
      Test:  index=[3 6], group=[1 1], sessions=['E' 'E']
    Fold 1:
      Train: index=[3 6], group=[1 1], session=['E' 'E']
      Test:  index=[2 7], group=[1 1], sessions=['E' 'E']
    Fold 2:
      Train: index=[4 5], group=[1 1], session=['T' 'T']
      Test:  index=[0 1], group=[1 1], sessions=['T' 'T']
    Fold 3:
      Train: index=[0 1], group=[1 1], session=['T' 'T']
      Test:  index=[4 5], group=[1 1], sessions=['T' 'T']
    """

    def __init__(
        self,
        n_folds: int = 5,
        random_state: int = 42,
        shuffle_subjects: bool = False,
        shuffle_session: bool = True,
Review comment (Collaborator) on lines +66 to +68:

Default random_state should be None. Also, convention would be to put it at the end of the argument list.

Suggested change
-        random_state: int = 42,
-        shuffle_subjects: bool = False,
-        shuffle_session: bool = True,
+        shuffle_subjects: bool = False,
+        shuffle_session: bool = True,
+        random_state: int = None,

    ):
        self.n_folds = n_folds
        self.shuffle_subjects = shuffle_subjects
        self.shuffle_session = shuffle_session
        self.random_state = check_random_state(random_state)
Review comment (Collaborator):

If you use it like that, this is not a random_state anymore but a rng.
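
To illustrate the point, here is a minimal sketch of the usual scikit-learn convention: `__init__` stores `random_state` exactly as received, and the generator is only created when splitting. The class `SeededSplitter` and its toy `split` method are hypothetical and are not part of this PR.

```python
import numpy as np
from sklearn.utils import check_random_state


class SeededSplitter:
    """Hypothetical splitter, used only to illustrate the convention."""

    def __init__(self, random_state=None):
        # Store the parameter exactly as given: no validation, no conversion.
        self.random_state = random_state

    def split(self, n_samples):
        # Build the generator at call time, so every call made with an
        # int seed starts from the same state.
        rng = check_random_state(self.random_state)
        return rng.permutation(n_samples)


splitter = SeededSplitter(random_state=42)
first = splitter.split(10)
second = splitter.split(10)
```

With this layout `first` and `second` are identical. If the generator were created in `__init__` instead, the second call would continue from the already consumed state and could return a different permutation.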


    def get_n_splits(self, metadata):
        num_sessions_subjects = metadata.groupby(["subject", "session"]).ngroups
        return self.n_folds * num_sessions_subjects

    def split(self, y, metadata, **kwargs):
        all_index = metadata.index.values
        subjects = metadata.subject.unique()
Review comment (Collaborator):

In general, avoid the `.col` attribute notation in pandas: it is error-prone and hard to read. Use bracket indexing instead:

Suggested change
-        subjects = metadata.subject.unique()
+        subjects = metadata['subject'].unique()
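
One concrete way attribute access can go wrong is a column whose name collides with an existing `DataFrame` attribute. The sketch below uses toy metadata; the `size` column is hypothetical and chosen only because it clashes with the built-in `DataFrame.size` property.

```python
import pandas as pd

# Toy metadata; the "size" column is hypothetical and chosen because it
# collides with an existing DataFrame attribute.
metadata = pd.DataFrame(
    {"subject": [1, 1, 2], "session": ["0", "1", "0"], "size": [10, 20, 30]}
)

# Bracket indexing always returns the column.
subjects = metadata["subject"].unique()
size_column = metadata["size"]

# Attribute access returns the DataFrame.size property (rows * columns),
# not the "size" column.
size_attribute = metadata.size
```

Here `size_column` is a `Series`, while `size_attribute` is the number of cells in the frame.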


        # Shuffle subjects if required
        if self.shuffle_subjects:
            self.random_state.shuffle(subjects)

        for subject in subjects:
            subject_mask = metadata.subject == subject
            subject_indices = all_index[subject_mask]
            subject_metadata = metadata[subject_mask]
            sessions = subject_metadata.session.unique()

            # Shuffle sessions if required
            if self.shuffle_session:
                self.random_state.shuffle(sessions)

            for session in sessions:
                session_mask = subject_metadata.session == session
                indices = subject_indices[session_mask]
                group_y = y[indices]

                # Use StratifiedKFold with the group-specific random state
                cv = StratifiedKFold(
                    n_splits=self.n_folds,
                    shuffle=self.shuffle_session,
                    random_state=self.random_state,
                )
                for ix_train, ix_test in cv.split(indices, group_y):
                    yield indices[ix_train], indices[ix_test]
Review comment (Collaborator) on lines +87 to +109:

We talked a bit with @sylvchev, and I think the best option would be to modify this to take a cv object in the constructor (default would be StratifiedKFold), clone it with a different random seed for each (subject, session) group, and then yield the right indices.

That way, we can do a real shuffle by shuffling the groups from which we retrieve the next split.
Would this make sense?

Reply (Collaborator, Author):

Makes sense, I'm working on that, thanks!

Reply (Collaborator):

great! Thank you so much! And thanks for your patience :)
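
The design discussed in this thread could look roughly like the sketch below. It is one possible reading of the proposal, not the implementation that was eventually merged: the function name `within_session_split` and the `cv_class` parameter are hypothetical, a fresh cross-validator is instantiated per group instead of cloning an object, and the shuffle is applied to the order of the collected folds.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold
from sklearn.utils import check_random_state


def within_session_split(
    y, metadata, cv_class=StratifiedKFold, n_folds=2, shuffle=True, random_state=None
):
    """Hypothetical sketch: one cross-validator per (subject, session) group."""
    rng = check_random_state(random_state)
    # Map each (subject, session) pair to the row positions of its trials.
    groups = metadata.groupby(["subject", "session"]).indices
    keys = sorted(groups)
    # Draw one integer seed per group from the shared generator.
    seeds = rng.randint(0, np.iinfo(np.int32).max, size=len(keys))

    splits = []
    for key, seed in zip(keys, seeds):
        indices = groups[key]
        # A fresh cross-validator per group stands in for cloning the cv object.
        cv = cv_class(
            n_splits=n_folds,
            shuffle=shuffle,
            random_state=int(seed) if shuffle else None,
        )
        for ix_train, ix_test in cv.split(indices, y[indices]):
            splits.append((indices[ix_train], indices[ix_test]))

    # Shuffle the order in which the folds of the different groups are returned.
    if shuffle:
        splits = [splits[i] for i in rng.permutation(len(splits))]
    yield from splits


# Same toy data as in the docstring example of the PR.
y = np.array([1, 2, 1, 2, 1, 2, 1, 2])
metadata = pd.DataFrame(
    {"subject": [1] * 8, "session": ["T", "T", "E", "E", "T", "T", "E", "E"]}
)
folds = list(within_session_split(y, metadata, random_state=0))
```

Each element of `folds` is a `(train, test)` pair of row positions drawn from a single session, and passing the same integer `random_state` reproduces the same folds in the same order.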

13 changes: 13 additions & 0 deletions moabb/evaluations/utils.py
@@ -1,9 +1,11 @@
from __future__ import annotations

import re
from pathlib import Path
from pickle import HIGHEST_PROTOCOL, dump
from typing import Sequence

import numpy as np
from numpy import argmax
from sklearn.pipeline import Pipeline

@@ -222,6 +224,17 @@ def create_save_path(
        print("No hdf5_path provided, models will not be saved.")


def sort_group(groups):
    runs_sort = []
    pattern = r"([0-9]+)(|[a-zA-Z]+[a-zA-Z0-9]*)"
    for i, group in enumerate(groups):
        index, description = re.fullmatch(pattern, group).groups()
        index = int(index)
        runs_sort.append(index)
    sorted_ix = np.argsort(runs_sort)
    return groups[sorted_ix]


def _convert_sklearn_params_to_optuna(param_grid: dict) -> dict:
    """
    Function to convert the parameter in Optuna format. This function will
58 changes: 58 additions & 0 deletions moabb/tests/splits.py
@@ -0,0 +1,58 @@
import numpy as np
import pytest
from sklearn.model_selection import StratifiedKFold
from sklearn.utils import check_random_state

from moabb.datasets.fake import FakeDataset
from moabb.evaluations.splitters import WithinSessionSplitter
from moabb.paradigms.motor_imagery import FakeImageryParadigm


dataset = FakeDataset(["left_hand", "right_hand"], n_subjects=3, seed=12)
paradigm = FakeImageryParadigm()


# Split done for the Within Session evaluation
def eval_split_within_session(shuffle, random_state):
    random_state = check_random_state(random_state) if shuffle else None
    for subject in dataset.subject_list:
        X, y, metadata = paradigm.get_data(dataset=dataset, subjects=[subject])
        sessions = metadata.session
        for session in np.unique(sessions):
            ix = sessions == session
            cv = StratifiedKFold(n_splits=5, shuffle=shuffle, random_state=random_state)
            X_, metadata_, y_ = X[ix], y[ix], metadata[ix]
            for train, test in cv.split(y_, metadata_):
                yield X_[train], X_[test]


@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("random_state", [0, 42])
def test_within_session(shuffle, random_state):
    X, y, metadata = paradigm.get_data(dataset=dataset)
Review comment (Collaborator):

I think it is important to check if the split is the same when we load the data of one/a few subject(s) only, paradigm.get_data(dataset=dataset, subjects=[m, n...])
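
A check of this kind could be sketched as follows. The example avoids `FakeDataset` and uses synthetic metadata instead: `make_metadata`, `toy_split` (an unshuffled per-session k-fold standing in for the splitter under test) and the `trial` column are all hypothetical. Folds are compared by trial identity rather than by row position, because positions change when fewer subjects are loaded.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold


def make_metadata(subjects, n_per_session=4):
    """Build toy metadata with two sessions per subject (hypothetical data)."""
    rows = [
        {"subject": s, "session": sess, "trial": t}
        for s in subjects
        for sess in ("0", "1")
        for t in range(n_per_session)
    ]
    return pd.DataFrame(rows)


def toy_split(y, metadata, n_folds=2):
    """Stand-in for the splitter under test: unshuffled k-fold per session."""
    for _, group in metadata.groupby(["subject", "session"]):
        idx = group.index.to_numpy()
        cv = StratifiedKFold(n_splits=n_folds)
        for train, test in cv.split(idx, y[idx]):
            yield idx[train], idx[test]


def as_trial_ids(index, metadata):
    """Translate row positions into (subject, session, trial) tuples."""
    cols = metadata.loc[index, ["subject", "session", "trial"]]
    return set(cols.itertuples(index=False, name=None))


# Load "all subjects" and "subject 2 only", with alternating class labels.
full = make_metadata([1, 2, 3])
y_full = np.tile([0, 1], len(full) // 2)
subset = make_metadata([2])
y_subset = np.tile([0, 1], len(subset) // 2)

# Test folds of subject 2, extracted from the full load.
full_tests = [
    as_trial_ids(test, full)
    for _, test in toy_split(y_full, full)
    if full.loc[test, "subject"].eq(2).all()
]
# Test folds of subject 2 when it is the only subject loaded.
subset_tests = [as_trial_ids(test, subset) for _, test in toy_split(y_subset, subset)]
```

For this unshuffled stand-in the two lists are equal. The same comparison applied to a shuffling splitter would show whether its folds depend on which other subjects were loaded.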


    split = WithinSessionSplitter(n_folds=5, shuffle=shuffle, random_state=random_state)

    for (X_train_t, X_test_t), (train, test) in zip(
        eval_split_within_session(shuffle=shuffle, random_state=random_state),
        split.split(y, metadata),
    ):
        X_train, X_test = X[train], X[test]

        # Check if the output is the same as the input
        assert np.array_equal(X_train, X_train_t)
        assert np.array_equal(X_test, X_test_t)


def test_is_shuffling():
    X, y, metadata = paradigm.get_data(dataset=dataset)

    split = WithinSessionSplitter(n_folds=5, shuffle=False)
    split_shuffle = WithinSessionSplitter(n_folds=5, shuffle=True, random_state=3)

    for (train, test), (train_shuffle, test_shuffle) in zip(
        split.split(y, metadata), split_shuffle.split(y, metadata)
    ):
        # Check if the output is the same as the input
        assert np.array_equal(train, train_shuffle) == False
        assert np.array_equal(test, test_shuffle) == False