Within Session splitter #664
base: develop
@@ -0,0 +1,109 @@
from sklearn.model_selection import BaseCrossValidator, StratifiedKFold
from sklearn.utils import check_random_state


class WithinSessionSplitter(BaseCrossValidator):
    """Data splitter for within session evaluation.

    Within-session evaluation uses k-fold cross-validation to determine train
    and test sets on separate sessions for each subject. This splitter assumes that
    all data from all subjects is already known and loaded.
Review comment on lines +9 to +10 (suggested change)

    .. image:: images/withinsess.png
        :alt: The schematic diagram of the WithinSession split
        :align: center

    Parameters
    ----------
    n_folds : int
        Number of folds. Must be at least 2.
Review comment (suggested change)
    random_state : int, RandomState instance or None, default=None
        Important when `shuffle` is True. Controls the randomness of splits.
        Pass an int for reproducible output across multiple function calls.
Review comment on lines +21 to +23 (suggested change): Move at the end.
    shuffle_session : bool, default=True
        Whether to shuffle each class's samples before splitting into batches.
        Note that the samples within each split will not be shuffled.
    shuffle_subjects : bool, default=False
        Whether to shuffle the order in which subjects and sessions are
        visited. This parameter allows sampling iterations of the splitter.
Review comment on lines +24 to +29: Do you think it is necessary to have both? I don't really see any use case where I would use only one.
    Examples
    --------

    >>> import pandas as pd
    >>> import numpy as np
    >>> from moabb.evaluations.splitters import WithinSessionSplitter
    >>> X = np.array([[1, 2], [3, 4], [5, 6], [1, 4], [7, 4], [5, 8], [0, 3], [2, 4]])
    >>> y = np.array([1, 2, 1, 2, 1, 2, 1, 2])
    >>> subjects = np.array([1, 1, 1, 1, 1, 1, 1, 1])
    >>> sessions = np.array(['T', 'T', 'E', 'E', 'T', 'T', 'E', 'E'])
    >>> metadata = pd.DataFrame(data={'subject': subjects, 'session': sessions})
    >>> csess = WithinSessionSplitter(n_folds=2)
    >>> csess.get_n_splits(metadata)
    4
    >>> for i, (train_index, test_index) in enumerate(csess.split(y, metadata)):
    ...     print(f"Fold {i}:")
    ...     print(f" Train: index={train_index}, group={subjects[train_index]}, session={sessions[train_index]}")
    ...     print(f" Test: index={test_index}, group={subjects[test_index]}, sessions={sessions[test_index]}")
    Fold 0:
     Train: index=[2 7], group=[1 1], session=['E' 'E']
     Test: index=[3 6], group=[1 1], sessions=['E' 'E']
    Fold 1:
     Train: index=[3 6], group=[1 1], session=['E' 'E']
     Test: index=[2 7], group=[1 1], sessions=['E' 'E']
    Fold 2:
     Train: index=[4 5], group=[1 1], session=['T' 'T']
     Test: index=[0 1], group=[1 1], sessions=['T' 'T']
    Fold 3:
     Train: index=[0 1], group=[1 1], session=['T' 'T']
     Test: index=[4 5], group=[1 1], sessions=['T' 'T']
    """

    def __init__(
        self,
        n_folds: int = 5,
        random_state: int = 42,
        shuffle_subjects: bool = False,
        shuffle_session: bool = True,
Review comment on lines +66 to +68 (suggested change): Default
    ):
        self.n_folds = n_folds
        self.shuffle_subjects = shuffle_subjects
        self.shuffle_session = shuffle_session
        self.random_state = check_random_state(random_state)
Review comment: If you use it like that, this is not a
    def get_n_splits(self, metadata):
        num_sessions_subjects = metadata.groupby(["subject", "session"]).ngroups
        return self.n_folds * num_sessions_subjects

    def split(self, y, metadata, **kwargs):
        all_index = metadata.index.values
        subjects = metadata.subject.unique()
Review comment (suggested change): In general, avoid using the
        # Shuffle subjects if required
        if self.shuffle_subjects:
            self.random_state.shuffle(subjects)

        for subject in subjects:
            subject_mask = metadata.subject == subject
            subject_indices = all_index[subject_mask]
            subject_metadata = metadata[subject_mask]
            sessions = subject_metadata.session.unique()

            # Shuffle sessions if required
            if self.shuffle_session:
                self.random_state.shuffle(sessions)

            for session in sessions:
                session_mask = subject_metadata.session == session
                indices = subject_indices[session_mask]
                group_y = y[indices]

                # Stratified k-fold within this (subject, session) group.
                # StratifiedKFold raises a ValueError if random_state is set
                # while shuffle is False, so pass None in that case.
                cv = StratifiedKFold(
                    n_splits=self.n_folds,
                    shuffle=self.shuffle_session,
                    random_state=self.random_state if self.shuffle_session else None,
                )
                for ix_train, ix_test in cv.split(indices, group_y):
                    yield indices[ix_train], indices[ix_test]
Review comment on lines +87 to +109: We talked a bit with @sylvchev and I think the best would be to modify this, to take a That way, we can do a real shuffle, by shuffling the groups from which we retrieve the next split.
Reply: Makes sense, I'm working on that, thanks!
Reply: Great! Thank you so much! And thanks for your patience :)
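One possible reading of the suggestion above, sketched under stated assumptions: build one fold generator per (subject, session) group and randomly choose which group yields the next split. The function `within_session_splits`, the helper `_group_folds`, and the `cv_class` parameter are hypothetical names introduced here for illustration; this is not the implementation that was merged, and it assumes every group yields exactly `n_folds` splits.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold
from sklearn.utils import check_random_state


def _group_folds(cv, idx, y):
    """Yield (train, test) row positions for one (subject, session) group."""
    for train, test in cv.split(np.zeros(len(idx)), y[idx]):
        yield idx[train], idx[test]


def within_session_splits(y, metadata, n_folds=2, shuffle=True,
                          random_state=None, cv_class=StratifiedKFold):
    """Hypothetical sketch: shuffle which group yields the next split."""
    rng = check_random_state(random_state)
    # Row positions of every (subject, session) pair.
    groups = metadata.groupby(["subject", "session"]).indices
    generators = []
    for key in sorted(groups):
        # Per-group integer seed; None when not shuffling, because the
        # k-fold classes reject a seed when shuffle is False.
        seed = rng.randint(np.iinfo(np.int32).max) if shuffle else None
        cv = cv_class(n_splits=n_folds, shuffle=shuffle, random_state=seed)
        generators.append(_group_folds(cv, groups[key], y))
    # Each group contributes n_folds splits; shuffle the visiting order.
    order = np.repeat(np.arange(len(generators)), n_folds)
    if shuffle:
        rng.shuffle(order)
    for g in order:
        yield next(generators[g])


# Same toy layout as the docstring example: one subject, sessions T and E.
y = np.array([1, 2, 1, 2, 1, 2, 1, 2])
metadata = pd.DataFrame({
    "subject": [1] * 8,
    "session": ["T", "T", "E", "E", "T", "T", "E", "E"],
})
splits = list(within_session_splits(y, metadata, n_folds=2, random_state=0))
for train, test in splits:
    print("train:", train, "test:", test)
```

Every split still stays inside a single session; only the order in which sessions are visited changes between seeds.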
@@ -0,0 +1,58 @@
import numpy as np
import pytest
from sklearn.model_selection import StratifiedKFold
from sklearn.utils import check_random_state

from moabb.datasets.fake import FakeDataset
from moabb.evaluations.splitters import WithinSessionSplitter
from moabb.paradigms.motor_imagery import FakeImageryParadigm


dataset = FakeDataset(["left_hand", "right_hand"], n_subjects=3, seed=12)
paradigm = FakeImageryParadigm()


# Split done for the Within Session evaluation
def eval_split_within_session(shuffle, random_state):
    random_state = check_random_state(random_state) if shuffle else None
    for subject in dataset.subject_list:
        X, y, metadata = paradigm.get_data(dataset=dataset, subjects=[subject])
        sessions = metadata.session
        for session in np.unique(sessions):
            ix = sessions == session
            cv = StratifiedKFold(n_splits=5, shuffle=shuffle, random_state=random_state)
            # Stratify on the labels of this session only
            X_, y_ = X[ix], y[ix]
            for train, test in cv.split(X_, y_):
                yield X_[train], X_[test]


@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("random_state", [0, 42])
def test_within_session(shuffle, random_state):
    X, y, metadata = paradigm.get_data(dataset=dataset)
Review comment: I think it is important to check if the split is the same when we load the data of one/a few subject(s) only,
    split = WithinSessionSplitter(n_folds=5, shuffle=shuffle, random_state=random_state)

    for (X_train_t, X_test_t), (train, test) in zip(
        eval_split_within_session(shuffle=shuffle, random_state=random_state),
        split.split(y, metadata),
    ):
        X_train, X_test = X[train], X[test]

        # Check that the splitter output matches the reference split
        assert np.array_equal(X_train, X_train_t)
        assert np.array_equal(X_test, X_test_t)


def test_is_shuffling():
    X, y, metadata = paradigm.get_data(dataset=dataset)

    split = WithinSessionSplitter(n_folds=5, shuffle=False)
    split_shuffle = WithinSessionSplitter(n_folds=5, shuffle=True, random_state=3)

    for (train, test), (train_shuffle, test_shuffle) in zip(
        split.split(y, metadata), split_shuffle.split(y, metadata)
    ):
        # Check that shuffling changes the splits
        assert not np.array_equal(train, train_shuffle)
        assert not np.array_equal(test, test_shuffle)
Review comment: Why not use only the other version?
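The earlier review point on `test_within_session`, about getting the same split when only one or a few subjects are loaded, could be checked along the following lines. This is a sketch on synthetic metadata that does not use moabb; `per_session_folds` is a hypothetical stand-in for the splitter (unshuffled stratified k-fold per subject and session), and folds are compared by index label so that a full load and a partial load can be matched.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold


def per_session_folds(y, metadata, n_folds=2):
    """Unshuffled stratified folds per (subject, session), as index labels."""
    out = []
    for _, group in metadata.groupby(["subject", "session"]):
        labels = group.index.to_numpy()
        # Positions of this group's rows inside y.
        pos = metadata.index.get_indexer(labels)
        cv = StratifiedKFold(n_splits=n_folds)
        for train, test in cv.split(pos, y[pos]):
            out.append((tuple(labels[train]), tuple(labels[test])))
    return out


# Synthetic layout: 2 subjects x 2 sessions x 4 trials, two balanced classes.
metadata = pd.DataFrame({
    "subject": np.repeat([1, 2], 8),
    "session": np.tile(np.repeat(["0", "1"], 4), 2),
})
y = np.tile([0, 1], 8)

# Folds when everything is loaded, restricted to subject 2 ...
full = per_session_folds(y, metadata)
full_s2 = [f for f in full if metadata.loc[f[1][0], "subject"] == 2]

# ... versus folds when only subject 2 is loaded.
mask = (metadata["subject"] == 2).to_numpy()
subset = per_session_folds(y[mask], metadata[mask])

print("same folds for subject 2:", subset == full_s2)
```

With shuffling enabled and a single `RandomState` shared across groups, this equality may no longer hold, because the random draws used for a subject would depend on how many groups were processed before it.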