From be57f818d4da11fc5422d5d925947834356da99a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Bj=C3=A4reholt?= Date: Mon, 8 Mar 2021 18:32:17 +0100 Subject: [PATCH 01/17] style: formatted with black --- docs/source/conf.py | 46 +-- examples/plot_cross_session_motor_imagery.py | 37 +- .../plot_cross_session_multiple_datasets.py | 19 +- examples/plot_cross_session_ssvep.py | 11 +- examples/plot_cross_subject_ssvep.py | 31 +- examples/plot_filterbank_csp_vs_csp.py | 39 ++- examples/plot_within_session_p300.py | 39 ++- moabb/analysis/__init__.py | 7 +- moabb/analysis/meta_analysis.py | 71 ++-- moabb/analysis/plotting.py | 132 ++++--- moabb/analysis/results.py | 82 +++-- moabb/datasets/Weibo2014.py | 68 ++-- moabb/datasets/Zhou2016.py | 47 +-- moabb/datasets/alex_mi.py | 13 +- moabb/datasets/base.py | 42 ++- moabb/datasets/bbci_eeg_fnirs.py | 84 ++--- moabb/datasets/bnci.py | 323 +++++++++++------- moabb/datasets/braininvaders.py | 35 +- moabb/datasets/download.py | 3 +- moabb/datasets/epfl.py | 38 +-- moabb/datasets/fake.py | 37 +- moabb/datasets/gigadb.py | 54 +-- moabb/datasets/mpi_mi.py | 23 +- moabb/datasets/physionet_mi.py | 36 +- moabb/datasets/schirrmeister2017.py | 177 +++++----- moabb/datasets/ssvep_exo.py | 21 +- moabb/datasets/ssvep_mamem.py | 62 ++-- moabb/datasets/ssvep_nakanishi.py | 50 +-- moabb/datasets/ssvep_wang.py | 34 +- moabb/datasets/upper_limb.py | 46 ++- moabb/datasets/utils.py | 25 +- moabb/evaluations/base.py | 85 +++-- moabb/evaluations/evaluations.py | 104 +++--- moabb/paradigms/base.py | 60 ++-- moabb/paradigms/motor_imagery.py | 83 +++-- moabb/paradigms/p300.py | 68 ++-- moabb/paradigms/ssvep.py | 62 ++-- moabb/pipelines/classification.py | 9 +- moabb/pipelines/csp.py | 12 +- moabb/pipelines/features.py | 13 +- moabb/pipelines/utils.py | 12 +- moabb/run.py | 72 ++-- moabb/tests/analysis.py | 107 +++--- moabb/tests/datasets.py | 12 +- moabb/tests/download.py | 6 +- moabb/tests/evaluations.py | 27 +- moabb/tests/paradigms.py | 49 +-- moabb/tests/util_tests.py | 46 ++- moabb/utils.py | 4 +- pipelines/CSP_svm_search.py | 4 +- pipelines/FBCSP.py | 11 +- pipelines/LogVar.py | 4 +- pipelines/TSSVM.py | 4 +- pipelines/WTRCSP.py | 7 +- pyproject.toml | 78 +---- setup.py | 31 +- tutorials/plot_Getting_Started.py | 8 +- tutorials/plot_statistical_analysis.py | 12 +- tutorials/select_electrodes_resample.py | 15 +- ...tutorial_1_simple_example_motor_imagery.py | 10 +- .../tutorial_2_using_mulitple_datasets.py | 16 +- ...orial_3_benchmarking_multiple_pipelines.py | 24 +- tutorials/tutorial_4_adding_a_dataset.py | 21 +- 63 files changed, 1553 insertions(+), 1255 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 6547eda51..3bcba8365 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -74,7 +74,8 @@ sphinx_gallery_conf = { 'examples_dirs': ['../../examples', '../../tutorials'], 'gallery_dirs': ['auto_examples', 'auto_tutorials'], - 'backreferences_dir': False} + 'backreferences_dir': False, +} # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] @@ -124,10 +125,8 @@ html_theme_options = { # Navigation bar title. (Default: ``project`` value) # 'navbar_title': "Demo", - # Tab name for entire site. (Default: "Site") # 'navbar_site_name': "Site", - # A list of tuples containing pages or urls to link to. 
# Valid tuples should be in the following forms: # (name, page) # a link to a page @@ -135,23 +134,20 @@ # (name, "http://example.com", True) # arbitrary absolute url # Note the "1" or "True" value above as the third argument to indicate # an arbitrary url. - 'navbar_links': [("API", "api"), - ("Gallery", "auto_examples/index"), - ("Tutorials", "auto_tutorials/index")], - + 'navbar_links': [ + ("API", "api"), + ("Gallery", "auto_examples/index"), + ("Tutorials", "auto_tutorials/index"), + ], # Render the next and previous page links in navbar. (Default: true) 'navbar_sidebarrel': False, - # Render the current pages TOC in the navbar. (Default: true) 'navbar_pagenav': True, - # Tab name for the current pages TOC. (Default: "Page") 'navbar_pagenav_name': "Page", - # Global TOC depth for "site" navbar tab. (Default: 1) # Switching to -1 shows all levels. 'globaltoc_depth': 2, - # Include hidden TOCs in Site navbar? # # Note: If this is "false", you cannot have mixed ``:hidden:`` and @@ -160,19 +156,15 @@ # # Values: "true" (default) or "false" 'globaltoc_includehidden': "true", - # HTML navbar class (Default: "navbar") to attach to
element. # For black navbar, do "navbar navbar-inverse" 'navbar_class': "navbar navbar-inverse", - # Fix navigation bar to top of page? # Values: "true" (default) or "false" 'navbar_fixed_top': "true", - # Location of link to source. # Options are "nav" (default), "footer" or anything else to exclude. 'source_link_position': "footer", - # Bootswatch (http://bootswatch.com/) theme. # # Options are nothing (default) or the name of a valid theme @@ -185,7 +177,6 @@ # - Bootstrap 2: https://bootswatch.com/2 # - Bootstrap 3: https://bootswatch.com/3 'bootswatch_theme': "united", - # Choose Bootstrap version. # Values: "3" (default) or "2" (in quotes) 'bootstrap_version': "3", @@ -219,15 +210,12 @@ # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. # # 'preamble': '', - # Latex figure (float) alignment # # 'figure_align': 'htbp', @@ -237,8 +225,7 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'moabb.tex', 'moabb Documentation', - 'Alexandre Barachant', 'manual'), + (master_doc, 'moabb.tex', 'moabb Documentation', 'Alexandre Barachant', 'manual'), ] @@ -246,10 +233,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'moabb', 'moabb Documentation', - [author], 1) -] +man_pages = [(master_doc, 'moabb', 'moabb Documentation', [author], 1)] # -- Options for Texinfo output ---------------------------------------------- @@ -258,9 +242,15 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'moabb', 'moabb Documentation', - author, 'moabb', 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + 'moabb', + 'moabb Documentation', + author, + 'moabb', + 'One line description of project.', + 'Miscellaneous', + ), ] diff --git a/examples/plot_cross_session_motor_imagery.py b/examples/plot_cross_session_motor_imagery.py index 87992be8e..96bab1f03 100644 --- a/examples/plot_cross_session_motor_imagery.py +++ b/examples/plot_cross_session_motor_imagery.py @@ -54,12 +54,11 @@ pipelines = {} -pipelines['CSP + LDA'] = make_pipeline(CSP(n_components=8), - LDA()) +pipelines['CSP + LDA'] = make_pipeline(CSP(n_components=8), LDA()) -pipelines['RG + LR'] = make_pipeline(Covariances(), - TangentSpace(), - LogisticRegression(solver='lbfgs')) +pipelines['RG + LR'] = make_pipeline( + Covariances(), TangentSpace(), LogisticRegression(solver='lbfgs') +) ############################################################################## # Evaluation @@ -79,8 +78,9 @@ dataset.subject_list = dataset.subject_list[:2] datasets = [dataset] overwrite = False # set to True if we want to overwrite cached results -evaluation = CrossSessionEvaluation(paradigm=paradigm, datasets=datasets, - suffix='examples', overwrite=overwrite) +evaluation = CrossSessionEvaluation( + paradigm=paradigm, datasets=datasets, suffix='examples', overwrite=overwrite +) results = evaluation.process(pipelines) @@ -98,21 +98,28 @@ fig, axes = plt.subplots(1, 2, figsize=[8, 4], sharey=True) -sns.stripplot(data=results, y='score', x='pipeline', ax=axes[0], jitter=True, - alpha=.5, zorder=1, palette="Set1") -sns.pointplot(data=results, y='score', x='pipeline', ax=axes[0], - zorder=1, palette="Set1") +sns.stripplot( + data=results, + y='score', + 
x='pipeline', + ax=axes[0], + jitter=True, + alpha=0.5, + zorder=1, + palette="Set1", +) +sns.pointplot(data=results, y='score', x='pipeline', ax=axes[0], zorder=1, palette="Set1") axes[0].set_ylabel('ROC AUC') axes[0].set_ylim(0.5, 1) # paired plot -paired = results.pivot_table(values='score', columns='pipeline', - index=['subject', 'session']) +paired = results.pivot_table( + values='score', columns='pipeline', index=['subject', 'session'] +) paired = paired.reset_index() -sns.regplot(data=paired, y='RG + LR', x='CSP + LDA', ax=axes[1], - fit_reg=False) +sns.regplot(data=paired, y='RG + LR', x='CSP + LDA', ax=axes[1], fit_reg=False) axes[1].plot([0, 1], [0, 1], ls='--', c='k') axes[1].set_xlim(0.5, 1) diff --git a/examples/plot_cross_session_multiple_datasets.py b/examples/plot_cross_session_multiple_datasets.py index 5e053bc05..2e4363c8a 100644 --- a/examples/plot_cross_session_multiple_datasets.py +++ b/examples/plot_cross_session_multiple_datasets.py @@ -66,8 +66,7 @@ freqs = paradigm.used_events(datasets[0]) pipeline = {} -pipeline["CCA"] = make_pipeline( - SSVEP_CCA(interval=interval, freqs=freqs, n_harmonics=3)) +pipeline["CCA"] = make_pipeline(SSVEP_CCA(interval=interval, freqs=freqs, n_harmonics=3)) ############################################################################## # Get data (optional) @@ -97,8 +96,9 @@ overwrite = True # set to True if we want to overwrite cached results -evaluation = CrossSessionEvaluation(paradigm=paradigm, datasets=datasets, - suffix='examples', overwrite=overwrite) +evaluation = CrossSessionEvaluation( + paradigm=paradigm, datasets=datasets, suffix='examples', overwrite=overwrite +) results = evaluation.process(pipeline) print(results.head()) @@ -109,6 +109,13 @@ # # Here we plot the results, indicating the score for each session and subject -sns.catplot(data=results, x='session', y='score', hue='subject', - col='dataset', kind='bar', palette='viridis') +sns.catplot( + data=results, + x='session', + y='score', + hue='subject', + col='dataset', + kind='bar', + palette='viridis', +) plt.show() diff --git a/examples/plot_cross_session_ssvep.py b/examples/plot_cross_session_ssvep.py index a46c4b13c..92e2f6ea5 100644 --- a/examples/plot_cross_session_ssvep.py +++ b/examples/plot_cross_session_ssvep.py @@ -64,8 +64,7 @@ freqs = paradigm.used_events(dataset) pipeline = {} -pipeline["CCA"] = make_pipeline( - SSVEP_CCA(interval=interval, freqs=freqs, n_harmonics=3)) +pipeline["CCA"] = make_pipeline(SSVEP_CCA(interval=interval, freqs=freqs, n_harmonics=3)) ############################################################################## # Get data (optional) @@ -90,8 +89,9 @@ overwrite = True # set to True if we want to overwrite cached results -evaluation = CrossSessionEvaluation(paradigm=paradigm, datasets=dataset, - suffix='examples', overwrite=overwrite) +evaluation = CrossSessionEvaluation( + paradigm=paradigm, datasets=dataset, suffix='examples', overwrite=overwrite +) results = evaluation.process(pipeline) print(results.head()) @@ -103,7 +103,6 @@ # Here we plot the results, indicating the score for each session and subject plt.figure() -sns.barplot(data=results, y='score', x='session', - hue='subject', palette='viridis') +sns.barplot(data=results, y='score', x='session', hue='subject', palette='viridis') plt.show() diff --git a/examples/plot_cross_subject_ssvep.py b/examples/plot_cross_subject_ssvep.py index f7eceb183..fc064d054 100644 --- a/examples/plot_cross_subject_ssvep.py +++ b/examples/plot_cross_subject_ssvep.py @@ -88,11 +88,11 @@ 
ExtendedSSVEPSignal(), Covariances(estimator='lwf'), TangentSpace(), - LogisticRegression(solver='lbfgs', multi_class='auto')) + LogisticRegression(solver='lbfgs', multi_class='auto'), +) pipelines = {} -pipelines['CCA'] = make_pipeline( - SSVEP_CCA(interval=interval, freqs=freqs, n_harmonics=3)) +pipelines['CCA'] = make_pipeline(SSVEP_CCA(interval=interval, freqs=freqs, n_harmonics=3)) ############################################################################## # Evaluation @@ -107,14 +107,16 @@ overwrite = False # set to True if we want to overwrite cached results -evaluation = CrossSubjectEvaluation(paradigm=paradigm, - datasets=dataset, overwrite=overwrite) +evaluation = CrossSubjectEvaluation( + paradigm=paradigm, datasets=dataset, overwrite=overwrite +) results = evaluation.process(pipelines) # Filter bank processing, determine automatically the filter from the # stimulation frequency values of events. -evaluation_fb = CrossSubjectEvaluation(paradigm=paradigm_fb, - datasets=dataset, overwrite=overwrite) +evaluation_fb = CrossSubjectEvaluation( + paradigm=paradigm_fb, datasets=dataset, overwrite=overwrite +) results_fb = evaluation_fb.process(pipelines_fb) ############################################################################### @@ -129,10 +131,17 @@ # Here we plot the results. fig, ax = plt.subplots(facecolor='white', figsize=[8, 4]) -sns.stripplot(data=results, y='score', x='pipeline', ax=ax, jitter=True, - alpha=.5, zorder=1, palette="Set1") -sns.pointplot(data=results, y='score', x='pipeline', ax=ax, - zorder=1, palette="Set1") +sns.stripplot( + data=results, + y='score', + x='pipeline', + ax=ax, + jitter=True, + alpha=0.5, + zorder=1, + palette="Set1", +) +sns.pointplot(data=results, y='score', x='pipeline', ax=ax, zorder=1, palette="Set1") ax.set_ylabel('Accuracy') ax.set_ylim(0.1, 0.6) plt.savefig('ssvep.png') diff --git a/examples/plot_filterbank_csp_vs_csp.py b/examples/plot_filterbank_csp_vs_csp.py index 29f29cb80..c0e2adaa6 100644 --- a/examples/plot_filterbank_csp_vs_csp.py +++ b/examples/plot_filterbank_csp_vs_csp.py @@ -41,12 +41,10 @@ # their own dict. 
pipelines = {} -pipelines['CSP + LDA'] = make_pipeline(CSP(n_components=8), - LDA()) +pipelines['CSP + LDA'] = make_pipeline(CSP(n_components=8), LDA()) pipelines_fb = {} -pipelines_fb['FBCSP + LDA'] = make_pipeline(FilterBank(CSP(n_components=4)), - LDA()) +pipelines_fb['FBCSP + LDA'] = make_pipeline(FilterBank(CSP(n_components=4)), LDA()) ############################################################################## # Evaluation @@ -72,15 +70,17 @@ fmin = 8 fmax = 35 paradigm = LeftRightImagery(fmin=fmin, fmax=fmax) -evaluation = CrossSessionEvaluation(paradigm=paradigm, datasets=datasets, - suffix='examples', overwrite=overwrite) +evaluation = CrossSessionEvaluation( + paradigm=paradigm, datasets=datasets, suffix='examples', overwrite=overwrite +) results = evaluation.process(pipelines) # bank of 6 filter, by 4 Hz increment filters = [[8, 12], [12, 16], [16, 20], [20, 24], [24, 28], [28, 35]] paradigm = FilterBankLeftRightImagery(filters=filters) -evaluation = CrossSessionEvaluation(paradigm=paradigm, datasets=datasets, - suffix='examples', overwrite=overwrite) +evaluation = CrossSessionEvaluation( + paradigm=paradigm, datasets=datasets, suffix='examples', overwrite=overwrite +) results_fb = evaluation.process(pipelines_fb) ############################################################################### @@ -101,21 +101,28 @@ fig, axes = plt.subplots(1, 2, figsize=[8, 4], sharey=True) -sns.stripplot(data=results, y='score', x='pipeline', ax=axes[0], jitter=True, - alpha=.5, zorder=1, palette="Set1") -sns.pointplot(data=results, y='score', x='pipeline', ax=axes[0], - zorder=1, palette="Set1") +sns.stripplot( + data=results, + y='score', + x='pipeline', + ax=axes[0], + jitter=True, + alpha=0.5, + zorder=1, + palette="Set1", +) +sns.pointplot(data=results, y='score', x='pipeline', ax=axes[0], zorder=1, palette="Set1") axes[0].set_ylabel('ROC AUC') axes[0].set_ylim(0.5, 1) # paired plot -paired = results.pivot_table(values='score', columns='pipeline', - index=['subject', 'session']) +paired = results.pivot_table( + values='score', columns='pipeline', index=['subject', 'session'] +) paired = paired.reset_index() -sns.regplot(data=paired, y='FBCSP + LDA', x='CSP + LDA', ax=axes[1], - fit_reg=False) +sns.regplot(data=paired, y='FBCSP + LDA', x='CSP + LDA', ax=axes[1], fit_reg=False) axes[1].plot([0, 1], [0, 1], ls='--', c='k') axes[1].set_xlim(0.5, 1) diff --git a/examples/plot_within_session_p300.py b/examples/plot_within_session_p300.py index 7d6da07e0..0a23901c7 100644 --- a/examples/plot_within_session_p300.py +++ b/examples/plot_within_session_p300.py @@ -50,7 +50,6 @@ class Vectorizer(BaseEstimator, TransformerMixin): - def __init__(self): pass @@ -62,6 +61,7 @@ def transform(self, X): """transform. 
""" return np.reshape(X, (X.shape[0], -1)) + ############################################################################## # Create pipelines # ---------------- @@ -78,17 +78,15 @@ def transform(self, X): pipelines['RG + LDA'] = make_pipeline( XdawnCovariances( - nfilter=2, - classes=[ - labels_dict['Target']], - estimator='lwf', - xdawn_estimator='lwf'), + nfilter=2, classes=[labels_dict['Target']], estimator='lwf', xdawn_estimator='lwf' + ), TangentSpace(), - LDA(solver='lsqr', shrinkage='auto')) + LDA(solver='lsqr', shrinkage='auto'), +) -pipelines['Xdw + LDA'] = make_pipeline(Xdawn(nfilter=2, estimator='lwf'), - Vectorizer(), LDA(solver='lsqr', - shrinkage='auto')) +pipelines['Xdw + LDA'] = make_pipeline( + Xdawn(nfilter=2, estimator='lwf'), Vectorizer(), LDA(solver='lsqr', shrinkage='auto') +) ############################################################################## # Evaluation @@ -107,9 +105,9 @@ def transform(self, X): dataset.subject_list = dataset.subject_list[:2] datasets = [dataset] overwrite = True # set to True if we want to overwrite cached results -evaluation = WithinSessionEvaluation(paradigm=paradigm, - datasets=datasets, - suffix='examples', overwrite=overwrite) +evaluation = WithinSessionEvaluation( + paradigm=paradigm, datasets=datasets, suffix='examples', overwrite=overwrite +) results = evaluation.process(pipelines) ############################################################################## @@ -120,10 +118,17 @@ def transform(self, X): fig, ax = plt.subplots(facecolor='white', figsize=[8, 4]) -sns.stripplot(data=results, y='score', x='pipeline', ax=ax, jitter=True, - alpha=.5, zorder=1, palette="Set1") -sns.pointplot(data=results, y='score', x='pipeline', ax=ax, - zorder=1, palette="Set1") +sns.stripplot( + data=results, + y='score', + x='pipeline', + ax=ax, + jitter=True, + alpha=0.5, + zorder=1, + palette="Set1", +) +sns.pointplot(data=results, y='score', x='pipeline', ax=ax, zorder=1, palette="Set1") ax.set_ylabel('ROC AUC') ax.set_ylim(0.5, 1) diff --git a/moabb/analysis/__init__.py b/moabb/analysis/__init__.py index 213f84e10..726ee1e0b 100644 --- a/moabb/analysis/__init__.py +++ b/moabb/analysis/__init__.py @@ -15,7 +15,7 @@ def analyze(results, out_path, name='analysis', plot=False): - '''Analyze results. + """Analyze results. 
Given a results dataframe, generates a folder with results and a dataframe of the exact data used to generate those results, @@ -33,7 +33,7 @@ def analyze(results, out_path, name='analysis', plot=False): Either path or results is necessary - ''' + """ # input checks # if not isinstance(out_path, str): raise ValueError('Given out_path argument is not string') @@ -47,8 +47,7 @@ def analyze(results, out_path, name='analysis', plot=False): print(unique_ids) print(set(unique_ids)) if len(unique_ids) != len(set(unique_ids)): - log.warning( - 'Pipeline names are too similar, turning off name shortening') + log.warning('Pipeline names are too similar, turning off name shortening') simplify = False os.makedirs(analysis_path, exist_ok=True) diff --git a/moabb/analysis/meta_analysis.py b/moabb/analysis/meta_analysis.py index ac47d98dc..6005407bf 100644 --- a/moabb/analysis/meta_analysis.py +++ b/moabb/analysis/meta_analysis.py @@ -14,7 +14,7 @@ def collapse_session_scores(df): def compute_pvals_wilcoxon(df, order=None): - '''Returns kxk matrix of p-values computed via the Wilcoxon rank-sum test, + """Returns kxk matrix of p-values computed via the Wilcoxon rank-sum test, order defines the order of rows and columns df: DataFrame, samples are index, columns are pipelines, and values are @@ -23,7 +23,7 @@ def compute_pvals_wilcoxon(df, order=None): order: list of length (num algorithms) with names corresponding to columns of df - ''' + """ if order is None: order = df.columns else: @@ -47,23 +47,23 @@ def compute_pvals_wilcoxon(df, order=None): def _pairedttest_exhaustive(data): - '''Returns p-values for exhaustive ttest that runs through all possible + """Returns p-values for exhaustive ttest that runs through all possible permutations of the first dimension. Very bad idea for size greater than 12 data is a (subj, alg, alg) matrix of differences between scores for each pair of algorithms per subject - ''' + """ out = np.ones((data.shape[1], data.shape[1])) true = data.sum(axis=0) - nperms = 2**data.shape[0] + nperms = 2 ** data.shape[0] for perm in itertools.product([-1, 1], repeat=data.shape[0]): # turn into numpy array perm = np.array(perm) # multiply permutation by subject dimension and sum over subjects randperm = (data * perm[:, None, None]).sum(axis=0) # compare to true difference (numpy autocasts bool to 0/1) - out += (randperm > true) + out += randperm > true out = out / nperms # control for cases where pval is 1 out[out == 1] = 1 - (1 / nperms) @@ -71,11 +71,11 @@ def _pairedttest_exhaustive(data): def _pairedttest_random(data, nperms): - '''Returns p-values based on nperms permutations of a paired ttest + """Returns p-values based on nperms permutations of a paired ttest data is a (subj, alg, alg) matrix of differences between scores for each pair of algorithms per subject - ''' + """ out = np.ones((data.shape[1], data.shape[1])) true = data.sum(axis=0) for _ in range(nperms): @@ -84,13 +84,13 @@ def _pairedttest_random(data, nperms): # multiply permutation by subject dimension and sum over subjects randperm = (data * perm[:, None, None]).sum(axis=0) # compare to true difference (numpy autocasts bool to 0/1) - out += (randperm > true) + out += randperm > true out[out == nperms] = nperms - 1 return out / nperms def compute_pvals_perm(df, order=None): - '''Returns kxk matrix of p-values computed via permutation test, + """Returns kxk matrix of p-values computed via permutation test, order defines the order of rows and columns df: DataFrame, samples are index, columns are pipelines, and values 
are @@ -99,7 +99,7 @@ def compute_pvals_perm(df, order=None): order: list of length (num algorithms) with names corresponding to columns of df - ''' + """ if order is None: order = df.columns else: @@ -121,7 +121,7 @@ def compute_pvals_perm(df, order=None): def compute_effect(df, order=None): - '''Returns kxk matrix of effect sizes, order defines the order of rows/columns + """Returns kxk matrix of effect sizes, order defines the order of rows/columns df: DataFrame, samples are index, columns are pipelines, and values are scores @@ -129,7 +129,7 @@ def compute_effect(df, order=None): order: list of length (num algorithms) with names corresponding to columns of df - ''' + """ if order is None: order = df.columns else: @@ -141,38 +141,34 @@ def compute_effect(df, order=None): for j, pipe2 in enumerate(order): if i != j: # for now it's just the standardized difference - diffs = (df.loc[:, pipe1] - df.loc[:, pipe2]) + diffs = df.loc[:, pipe1] - df.loc[:, pipe2] diffs = diffs.mean() / diffs.std() out[i, j] = diffs return out def compute_dataset_statistics(df, perm_cutoff=20): - ''' + """ Returns dict of datasets to DataFrames with stats - ''' + """ df = collapse_session_scores(df) algs = df.pipeline.unique() dsets = df.dataset.unique() out = {} for d in dsets: - score_data = df[df.dataset == d].pivot(index='subject', - values='score', - columns='pipeline') + score_data = df[df.dataset == d].pivot( + index='subject', values='score', columns='pipeline' + ) if score_data.shape[0] < perm_cutoff: p = compute_pvals_perm(score_data, algs) else: p = compute_pvals_wilcoxon(score_data, algs) t = compute_effect(score_data, algs) - P = pd.DataFrame(index=pd.Index(algs, name='pipe1'), - columns=algs, data=p) - T = pd.DataFrame(index=pd.Index(algs, name='pipe1'), - columns=algs, data=t) - D1 = pd.melt(P.reset_index(), id_vars='pipe1', - var_name='pipe2', value_name='p') - D2 = pd.melt(T.reset_index(), id_vars='pipe1', - var_name='pipe2', value_name='smd') + P = pd.DataFrame(index=pd.Index(algs, name='pipe1'), columns=algs, data=p) + T = pd.DataFrame(index=pd.Index(algs, name='pipe1'), columns=algs, data=t) + D1 = pd.melt(P.reset_index(), id_vars='pipe1', var_name='pipe2', value_name='p') + D2 = pd.melt(T.reset_index(), id_vars='pipe1', var_name='pipe2', value_name='smd') stats_df = D1.merge(D2) stats_df['nsub'] = score_data.shape[0] out[d] = stats_df @@ -180,31 +176,30 @@ def compute_dataset_statistics(df, perm_cutoff=20): def combine_effects(effects, nsubs): - '''Function that takes effects from each experiments and number of subjects to + """Function that takes effects from each experiments and number of subjects to return meta-analysis effect - ''' + """ W = np.sqrt(nsubs) W = W / W.sum() return (W * effects).sum() def combine_pvalues(p, nsubs): - '''Function that takes pvals from each experiments and number of subjects to + """Function that takes pvals from each experiments and number of subjects to return meta-analysis significance - ''' + """ if len(p) == 1: return p.item() else: W = np.sqrt(nsubs) - out = stats.combine_pvalues(np.array(p), - weights=W, method='stouffer')[1] + out = stats.combine_pvalues(np.array(p), weights=W, method='stouffer')[1] return out def find_significant_differences(df, perm_cutoff=20): - '''Compute matrix of p-values for all algorithms over all datasets via + """Compute matrix of p-values for all algorithms over all datasets via combined p-values method df: DataFrame, output of compute_dataset_statistics @@ -218,14 +213,12 @@ def find_significant_differences(df, 
perm_cutoff=20): T: matrix (k,k) of signed standardized mean difference - ''' + """ dsets = df.dataset.unique() algs = df.pipe1.unique() nsubs = np.array([df.loc[df.dataset == d, 'nsub'].mean() for d in dsets]) - P_full = df.pivot_table( - values='p', index=['dataset', 'pipe1'], columns='pipe2') - T_full = df.pivot_table(values='smd', index=[ - 'dataset', 'pipe1'], columns='pipe2') + P_full = df.pivot_table(values='p', index=['dataset', 'pipe1'], columns='pipe2') + T_full = df.pivot_table(values='smd', index=['dataset', 'pipe1'], columns='pipe2') P = np.full((len(algs), len(algs)), np.NaN) T = np.full((len(algs), len(algs)), np.NaN) for i in range(len(algs)): diff --git a/moabb/analysis/plotting.py b/moabb/analysis/plotting.py index 8e22dd3d4..994811984 100644 --- a/moabb/analysis/plotting.py +++ b/moabb/analysis/plotting.py @@ -15,8 +15,7 @@ PIPELINE_PALETTE = sea.color_palette("husl", 6) -sea.set(font='serif', style='whitegrid', - palette=PIPELINE_PALETTE, color_codes=False) +sea.set(font='serif', style='whitegrid', palette=PIPELINE_PALETTE, color_codes=False) log = logging.getLogger() @@ -29,13 +28,13 @@ def _simplify_names(x): def score_plot(data, pipelines=None): - ''' + """ In: data: output of Results.to_dataframe() pipelines: list of string|None, pipelines to include in this plot Out: ax: pyplot Axes reference - ''' + """ data = collapse_session_scores(data) data['dataset'] = data['dataset'].apply(_simplify_names) if pipelines is not None: @@ -43,9 +42,17 @@ def score_plot(data, pipelines=None): fig = plt.figure(figsize=(8.5, 11)) ax = fig.add_subplot(111) # markers = ['o', '8', 's', 'p', '+', 'x', 'D', 'd', '>', '<', '^'] - sea.stripplot(data=data, y="dataset", x="score", jitter=0.15, - palette=PIPELINE_PALETTE, hue='pipeline', dodge=True, ax=ax, - alpha=0.7) + sea.stripplot( + data=data, + y="dataset", + x="score", + jitter=0.15, + palette=PIPELINE_PALETTE, + hue='pipeline', + dodge=True, + ax=ax, + alpha=0.7, + ) ax.set_xlim([0, 1]) ax.axvline(0.5, linestyle='--', color='k', linewidth=2) ax.set_title('Scores per dataset and algorithm') @@ -56,17 +63,18 @@ def score_plot(data, pipelines=None): def paired_plot(data, alg1, alg2): - ''' + """ returns figure with an axis that has a paired plot on it Data: dataframe from Results alg1: name of a member of column data.pipeline alg2: name of a member of column data.pipeline - ''' + """ data = collapse_session_scores(data) data = data[data.pipeline.isin([alg1, alg2])] - data = data.pivot_table(values='score', columns='pipeline', - index=['subject', 'dataset']) + data = data.pivot_table( + values='score', columns='pipeline', index=['subject', 'dataset'] + ) data = data.reset_index() fig = plt.figure(figsize=(11, 8.5)) ax = fig.add_subplot(111) @@ -78,12 +86,12 @@ def paired_plot(data, alg1, alg2): def summary_plot(sig_df, effect_df, p_threshold=0.05, simplify=True): - '''Visualize significances as a heatmap with green/grey/red for significantly + """Visualize significances as a heatmap with green/grey/red for significantly higher/significantly lower. 
sig_df is a DataFrame of pipeline x pipeline where each value is a p-value, effect_df is a DF where each value is an effect size - ''' + """ if simplify: effect_df.columns = effect_df.columns.map(_simplify_names) sig_df.columns = sig_df.columns.map(_simplify_names) @@ -91,8 +99,9 @@ def summary_plot(sig_df, effect_df, p_threshold=0.05, simplify=True): for row in annot_df.index: for col in annot_df.columns: if effect_df.loc[row, col] > 0: - txt = '{:.2f}\np={:1.0e}'.format(effect_df.loc[row, col], - sig_df.loc[row, col]) + txt = '{:.2f}\np={:1.0e}'.format( + effect_df.loc[row, col], sig_df.loc[row, col] + ) else: # we need the effect direction and p-value to coincide. # TODO: current is hack @@ -105,10 +114,18 @@ def summary_plot(sig_df, effect_df, p_threshold=0.05, simplify=True): palette = sea.light_palette("green", as_cmap=True) palette.set_under(color=[1, 1, 1]) palette.set_over(color=[0.5, 0, 0]) - sea.heatmap(data=-np.log(sig_df), annot=annot_df, - fmt='', cmap=palette, linewidths=1, - linecolor='0.8', annot_kws={'size': 10}, cbar=False, - vmin=-np.log(0.05), vmax=-np.log(1e-100)) + sea.heatmap( + data=-np.log(sig_df), + annot=annot_df, + fmt='', + cmap=palette, + linewidths=1, + linecolor='0.8', + annot_kws={'size': 10}, + cbar=False, + vmin=-np.log(0.05), + vmax=-np.log(1e-100), + ) for lb in ax.get_xticklabels(): lb.set_rotation(45) ax.tick_params(axis='y', rotation=0.9) @@ -118,9 +135,10 @@ def summary_plot(sig_df, effect_df, p_threshold=0.05, simplify=True): def meta_analysis_plot(stats_df, alg1, alg2): # noqa: C901 - '''A meta-analysis style plot that shows the standardized effect with + """A meta-analysis style plot that shows the standardized effect with confidence intervals over all datasets for two algorithms. - Hypothesis is that alg1 is larger than alg2''' + Hypothesis is that alg1 is larger than alg2""" + def _marker(pval): if pval < 0.001: return '$***$', 100 @@ -130,8 +148,9 @@ def _marker(pval): return '$*$', 30 else: raise ValueError('insignificant pval {}'.format(pval)) - assert (alg1 in stats_df.pipe1.unique()) - assert (alg2 in stats_df.pipe1.unique()) + + assert alg1 in stats_df.pipe1.unique() + assert alg2 in stats_df.pipe1.unique() df_fw = stats_df.loc[(stats_df.pipe1 == alg1) & (stats_df.pipe2 == alg2)] df_fw = df_fw.sort_values(by='pipe1') df_bk = stats_df.loc[(stats_df.pipe1 == alg2) & (stats_df.pipe2 == alg1)] @@ -166,21 +185,22 @@ def _marker(pval): pvals.append(p) _min = _min if (_min < (v - ci[-1])) else (v - ci[-1]) _max = _max if (_max > (v + ci[-1])) else (v + ci[-1]) - ax.plot(np.array([v - ci[-1], v + ci[-1]]), - np.ones((2,)) * (ind + 1), c='tab:grey') + ax.plot( + np.array([v - ci[-1], v + ci[-1]]), np.ones((2,)) * (ind + 1), c='tab:grey' + ) _range = max(abs(_min), abs(_max)) ax.set_xlim((0 - _range, 0 + _range)) final_effect = combine_effects(df_fw['smd'], df_fw['nsub']) - ax.scatter(pd.concat([pd.Series([final_effect]), df_fw['smd']]), - np.arange(len(dsets) + 1), - s=np.array([50] + [30] * len(dsets)), - marker='D', - c=['k'] + ['tab:grey'] * len(dsets)) + ax.scatter( + pd.concat([pd.Series([final_effect]), df_fw['smd']]), + np.arange(len(dsets) + 1), + s=np.array([50] + [30] * len(dsets)), + marker='D', + c=['k'] + ['tab:grey'] * len(dsets), + ) for i, p in zip(sig_ind, pvals): m, s = _marker(p) - ax.scatter(df_fw['smd'].iloc[i], - i + 1.4, s=s, - marker=m, color='r') + ax.scatter(df_fw['smd'].iloc[i], i + 1.4, s=s, marker=m, color='r') # pvalues axis stuf pval_ax.set_xlim([-0.1, 0.1]) pval_ax.grid(False) @@ -189,37 +209,49 @@ def 
_marker(pval): for spine in pval_ax.spines.values(): spine.set_visible(False) for ind, p in zip(sig_ind, pvals): - pval_ax.text(0, ind + 1, horizontalalignment='center', - verticalalignment='center', - s='{:.2e}'.format(p), fontsize=8) + pval_ax.text( + 0, + ind + 1, + horizontalalignment='center', + verticalalignment='center', + s='{:.2e}'.format(p), + fontsize=8, + ) if final_effect > 0: p = combine_pvalues(df_fw['p'], df_fw['nsub']) if p < 0.05: m, s = _marker(p) - ax.scatter([final_effect], [-0.4], s=s, - marker=m, c='r') - pval_ax.text(0, 0, horizontalalignment='center', - verticalalignment='center', - s='{:.2e}'.format(p), fontsize=8) + ax.scatter([final_effect], [-0.4], s=s, marker=m, c='r') + pval_ax.text( + 0, + 0, + horizontalalignment='center', + verticalalignment='center', + s='{:.2e}'.format(p), + fontsize=8, + ) else: p = combine_pvalues(df_bk['p'], df_bk['nsub']) if p < 0.05: m, s = _marker(p) - ax.scatter([final_effect], [-0.4], s=s, - marker=m, c='r') - pval_ax.text(0, 0, horizontalalignment='center', - verticalalignment='center', - s='{:.2e}'.format(p), fontsize=8) + ax.scatter([final_effect], [-0.4], s=s, marker=m, c='r') + pval_ax.text( + 0, + 0, + horizontalalignment='center', + verticalalignment='center', + s='{:.2e}'.format(p), + fontsize=8, + ) ax.grid(False) ax.spines['top'].set_visible(False) ax.spines['right'].set_visible(False) ax.axvline(0, linestyle='--', c='k') ax.axhline(0.5, linestyle='-', linewidth=3, c='k') - title = '< {} better{}\n{}{} better >'.format(alg2, - ' ' * (45 - len(alg2)), - ' ' * (45 - len(alg1)), - alg1) + title = '< {} better{}\n{}{} better >'.format( + alg2, ' ' * (45 - len(alg2)), ' ' * (45 - len(alg1)), alg1 + ) ax.set_title(title, ha='left', ma='right', loc='left') ax.set_xlabel('Standardized Mean Difference') fig.tight_layout() diff --git a/moabb/analysis/results.py b/moabb/analysis/results.py index 212291135..fc065e21a 100644 --- a/moabb/analysis/results.py +++ b/moabb/analysis/results.py @@ -28,7 +28,7 @@ def get_digest(obj): class Results: - '''Class to hold results from the evaluation.evaluate method. + """Class to hold results from the evaluation.evaluate method. 
Appropriate test would be to ensure the result of 'evaluate' is consistent and can be accepted by 'results.add' @@ -36,16 +36,24 @@ class Results: Saves dataframe per pipeline and can query to see if particular subject has already been run - ''' + """ - def __init__(self, evaluation_class, paradigm_class, suffix='', - overwrite=False, hdf5_path=None, additional_columns=None): + def __init__( + self, + evaluation_class, + paradigm_class, + suffix='', + overwrite=False, + hdf5_path=None, + additional_columns=None, + ): """ class that will abstract result storage """ import moabb from moabb.evaluations.base import BaseEvaluation from moabb.paradigms.base import BaseParadigm + assert issubclass(evaluation_class, BaseEvaluation) assert issubclass(paradigm_class, BaseParadigm) @@ -56,14 +64,16 @@ class that will abstract result storage self.additional_columns = additional_columns if hdf5_path is None: - self.mod_dir = os.path.dirname( - os.path.abspath(inspect.getsourcefile(moabb))) + self.mod_dir = os.path.dirname(os.path.abspath(inspect.getsourcefile(moabb))) else: self.mod_dir = os.path.abspath(hdf5_path) - self.filepath = os.path.join(self.mod_dir, 'results', - paradigm_class.__name__, - evaluation_class.__name__, - 'results{}.hdf5'.format('_' + suffix)) + self.filepath = os.path.join( + self.mod_dir, + 'results', + paradigm_class.__name__, + evaluation_class.__name__, + 'results{}.hdf5'.format('_' + suffix), + ) os.makedirs(os.path.dirname(self.filepath), exist_ok=True) self.filepath = self.filepath @@ -74,16 +84,20 @@ class that will abstract result storage if not os.path.isfile(self.filepath): with h5py.File(self.filepath, 'w') as f: f.attrs['create_time'] = np.string_( - '{:%Y-%m-%d, %H:%M}'.format(datetime.now())) + '{:%Y-%m-%d, %H:%M}'.format(datetime.now()) + ) def add(self, results, pipelines): """add results""" + def to_list(res): if type(res) is dict: return [res] elif type(res) is not list: - raise ValueError("Results are given as neither dict nor" - "list but {}".format(type(res).__name__)) + raise ValueError( + "Results are given as neither dict nor" + "list but {}".format(type(res).__name__) + ) else: return res @@ -108,34 +122,34 @@ def to_list(res): dset.attrs['n_subj'] = len(d1['dataset'].subject_list) dset.attrs['n_sessions'] = d1['dataset'].n_sessions dt = h5py.special_dtype(vlen=str) - dset.create_dataset('id', (0, 2), dtype=dt, - maxshape=(None, 2)) - dset.create_dataset('data', (0, 3 + n_add_cols), - maxshape=(None, 3 + n_add_cols)) + dset.create_dataset('id', (0, 2), dtype=dt, maxshape=(None, 2)) + dset.create_dataset( + 'data', (0, 3 + n_add_cols), maxshape=(None, 3 + n_add_cols) + ) dset.attrs['channels'] = d1['n_channels'] - dset.attrs.create('columns', - ['score', 'time', 'samples', - *self.additional_columns], - dtype=dt) + dset.attrs.create( + 'columns', + ['score', 'time', 'samples', *self.additional_columns], + dtype=dt, + ) dset = ppline_grp[dname] for d in dlist: # add id and scores to group length = len(dset['id']) + 1 dset['id'].resize(length, 0) dset['data'].resize(length, 0) - dset['id'][-1, :] = np.asarray([str(d['subject']), - str(d['session'])]) + dset['id'][-1, :] = np.asarray([str(d['subject']), str(d['session'])]) try: add_cols = [d[ac] for ac in self.additional_columns] except KeyError: raise ValueError( f'Additional columns: {self.additional_columns} ' f'were specified in the evaluation, but results' - f' contain only these keys: {d.keys()}.') - dset['data'][-1, :] = np.asarray([d['score'], - d['time'], - d['n_samples'], - *add_cols]) + f' contain 
only these keys: {d.keys()}.' + ) + dset['data'][-1, :] = np.asarray( + [d['score'], d['time'], d['n_samples'], *add_cols] + ) def to_dataframe(self, pipelines=None): df_list = [] @@ -156,8 +170,7 @@ def to_dataframe(self, pipelines=None): for dname, dset in p_group.items(): array = np.array(dset['data']) ids = np.array(dset['id']) - df = pd.DataFrame(array, - columns=dset.attrs['columns']) + df = pd.DataFrame(array, columns=dset.attrs['columns']) df['subject'] = ids[:, 0] df['session'] = ids[:, 1] df['channels'] = dset.attrs['channels'] @@ -169,8 +182,11 @@ def to_dataframe(self, pipelines=None): def not_yet_computed(self, pipelines, dataset, subj): """Check if a results has already been computed.""" - ret = {k: pipelines[k] for k in pipelines.keys() - if not self._already_computed(pipelines[k], dataset, subj)} + ret = { + k: pipelines[k] + for k in pipelines.keys() + if not self._already_computed(pipelines[k], dataset, subj) + } return ret def _already_computed(self, pipeline, dataset, subject, session=None): @@ -192,4 +208,4 @@ def _already_computed(self, pipeline, dataset, subject, session=None): else: # if dataset, check for subject dset = pipe_grp[dataset.code] - return (str(subject).encode('utf-8') in dset['id'][:, 0]) + return str(subject).encode('utf-8') in dset['id'][:, 0] diff --git a/moabb/datasets/Weibo2014.py b/moabb/datasets/Weibo2014.py index 2b296f887..24d18296f 100644 --- a/moabb/datasets/Weibo2014.py +++ b/moabb/datasets/Weibo2014.py @@ -33,22 +33,25 @@ def eeg_data_path(base_path, subject): def get_subjects(sub_inds, sub_names, ind): dataname = 'data{}'.format(ind) if not os.path.isfile(os.path.join(base_path, dataname + '.zip')): - _fetch_file(FILES[ind], os.path.join( - base_path, dataname + '.zip'), print_destination=False) + _fetch_file( + FILES[ind], + os.path.join(base_path, dataname + '.zip'), + print_destination=False, + ) with z.ZipFile(os.path.join(base_path, dataname + '.zip'), 'r') as f: os.makedirs(os.path.join(base_path, dataname), exist_ok=True) f.extractall(os.path.join(base_path, dataname)) for fname in os.listdir(os.path.join(base_path, dataname)): for ind, prefix in zip(sub_inds, sub_names): if fname.startswith(prefix): - os.rename(os.path.join(base_path, dataname, fname), - os.path.join(base_path, - 'subject_{}.mat'.format(ind))) + os.rename( + os.path.join(base_path, dataname, fname), + os.path.join(base_path, 'subject_{}.mat'.format(ind)), + ) os.remove(os.path.join(base_path, dataname + '.zip')) shutil.rmtree(os.path.join(base_path, dataname)) - if not os.path.isfile(os.path.join(base_path, - 'subject_{}.mat'.format(subject))): + if not os.path.isfile(os.path.join(base_path, 'subject_{}.mat'.format(subject))): if subject in range(1, 5): get_subjects(list(range(1, 5)), file1_subj, 0) elif subject in range(5, 8): @@ -99,22 +102,35 @@ def __init__(self): super().__init__( subjects=list(range(1, 11)), sessions_per_subject=1, - events=dict(left_hand=1, right_hand=2, - hands=3, feet=4, left_hand_right_foot=5, - right_hand_left_foot=6, rest=7), + events=dict( + left_hand=1, + right_hand=2, + hands=3, + feet=4, + left_hand_right_foot=5, + right_hand_left_foot=6, + rest=7, + ), code='Weibo 2014', # Full trial w/ rest is 0-8 interval=[3, 7], paradigm='imagery', - doi='10.1371/journal.pone.0114853') + doi='10.1371/journal.pone.0114853', + ) def _get_single_subject_data(self, subject): """return data for a single subject""" fname = self.data_path(subject) # TODO: add 1s 0 buffer between trials and make continuous - data = loadmat(fname, squeeze_me=True, 
struct_as_record=False, - verify_compressed_data_integrity=False) + data = loadmat( + fname, + squeeze_me=True, + struct_as_record=False, + verify_compressed_data_integrity=False, + ) montage = mne.channels.make_standard_montage('standard_1005') + + # fmt: off ch_names = ['Fp1', 'Fpz', 'Fp2', 'AF3', 'AF4', 'F7', 'F5', 'F3', 'F1', 'Fz', 'F2', 'F4', 'F6', 'F8', 'FT7', 'FC5', 'FC3', 'FC1', 'FCz', 'FC2', 'FC4', 'FC6', 'FT8', 'T7', 'C5', 'C3', 'C1', @@ -123,14 +139,15 @@ def _get_single_subject_data(self, subject): 'Pz', 'P2', 'P4', 'P6', 'P8', 'PO7', 'PO5', 'PO3', 'POz', 'PO4', 'PO6', 'PO8', 'CB1', 'O1', 'Oz', 'O2', 'CB2', 'VEO', 'HEO'] + # fmt: on ch_types = ['eeg'] * 62 + ['eog'] * 2 # FIXME not sure what are those CB1 / CB2 ch_types[57] = 'misc' ch_types[61] = 'misc' - info = mne.create_info(ch_names=ch_names + ['STIM014'], - ch_types=ch_types + ['stim'], - sfreq=200) + info = mne.create_info( + ch_names=ch_names + ['STIM014'], ch_types=ch_types + ['stim'], sfreq=200 + ) # until we get the channel names montage is None event_ids = data['label'].ravel() raw_data = np.transpose(data['data'], axes=[2, 0, 1]) @@ -141,20 +158,21 @@ def _get_single_subject_data(self, subject): data = np.concatenate([1e-6 * raw_data, raw_events], axis=1) # add buffer in between trials log.warning( - "Trial data de-meaned and concatenated with a buffer to create " - "cont data") + "Trial data de-meaned and concatenated with a buffer to create " "cont data" + ) zeroshape = (data.shape[0], data.shape[1], 50) - data = np.concatenate([np.zeros(zeroshape), data, - np.zeros(zeroshape)], axis=2) - raw = mne.io.RawArray(data=np.concatenate(list(data), axis=1), - info=info, verbose=False) + data = np.concatenate([np.zeros(zeroshape), data, np.zeros(zeroshape)], axis=2) + raw = mne.io.RawArray( + data=np.concatenate(list(data), axis=1), info=info, verbose=False + ) raw.set_montage(montage) return {'session_0': {'run_0': raw}} - def data_path(self, subject, path=None, force_update=False, - update_path=None, verbose=None): + def data_path( + self, subject, path=None, force_update=False, update_path=None, verbose=None + ): if subject not in self.subject_list: - raise(ValueError("Invalid subject number")) + raise (ValueError("Invalid subject number")) key = 'MNE_DATASETS_WEIBO2014_PATH' path = _get_path(path, key, "Weibo 2014") _do_path_update(path, True, key, "Weibo 2014") diff --git a/moabb/datasets/Zhou2016.py b/moabb/datasets/Zhou2016.py index a962e6f3f..63229a4b3 100644 --- a/moabb/datasets/Zhou2016.py +++ b/moabb/datasets/Zhou2016.py @@ -20,11 +20,11 @@ def local_data_path(base_path, subject): - if not os.path.isdir(os.path.join(base_path, - 'subject_{}'.format(subject))): + if not os.path.isdir(os.path.join(base_path, 'subject_{}'.format(subject))): if not os.path.isdir(os.path.join(base_path, 'data')): - _fetch_file(DATA_PATH, os.path.join(base_path, 'data.zip'), - print_destination=False) + _fetch_file( + DATA_PATH, os.path.join(base_path, 'data.zip'), print_destination=False + ) with z.ZipFile(os.path.join(base_path, 'data.zip'), 'r') as f: f.extractall(base_path) os.remove(os.path.join(base_path, 'data.zip')) @@ -33,16 +33,20 @@ def local_data_path(base_path, subject): os.makedirs(os.path.join(base_path, 'subject_{}'.format(i))) for session in range(1, 4): for run in ['A', 'B']: - os.rename(os.path.join(datapath, - 'S{}_{}{}.cnt'.format(i, session, - run)), - os.path.join(base_path, - 'subject_{}'.format(i), - '{}{}.cnt'.format(session, run))) + os.rename( + os.path.join(datapath, 'S{}_{}{}.cnt'.format(i, session, 
run)), + os.path.join( + base_path, + 'subject_{}'.format(i), + '{}{}.cnt'.format(session, run), + ), + ) shutil.rmtree(os.path.join(base_path, 'data')) subjpath = os.path.join(base_path, 'subject_{}'.format(subject)) - return [[os.path.join(subjpath, '{}{}.cnt'.format(y, x)) - for x in ['A', 'B']] for y in ['1', '2', '3']] + return [ + [os.path.join(subjpath, '{}{}.cnt'.format(y, x)) for x in ['A', 'B']] + for y in ['1', '2', '3'] + ] class Zhou2016(BaseDataset): @@ -78,14 +82,14 @@ def __init__(self): super().__init__( subjects=list(range(1, 5)), sessions_per_subject=3, - events=dict(left_hand=1, right_hand=2, - feet=3), + events=dict(left_hand=1, right_hand=2, feet=3), code='Zhou 2016', # MI 1-6s, prepare 0-1, break 6-10 # boundary effects interval=[0, 5], paradigm='imagery', - doi='10.1371/journal.pone.0162657') + doi='10.1371/journal.pone.0162657', + ) def _get_single_subject_data(self, subject): """return data for a single subject""" @@ -97,22 +101,21 @@ def _get_single_subject_data(self, subject): out[sess_key] = {} for run_ind, fname in enumerate(runlist): run_key = 'run_{}'.format(run_ind) - raw = read_raw_cnt(fname, preload=True, - eog=['VEOU', 'VEOL']) + raw = read_raw_cnt(fname, preload=True, eog=['VEOU', 'VEOL']) stim = raw.annotations.description.astype(np.dtype('<10U')) stim[stim == '1'] = 'left_hand' stim[stim == '2'] = 'right_hand' stim[stim == '3'] = 'feet' raw.annotations.description = stim out[sess_key][run_key] = raw - out[sess_key][run_key].set_montage( - make_standard_montage('standard_1005')) + out[sess_key][run_key].set_montage(make_standard_montage('standard_1005')) return out - def data_path(self, subject, path=None, force_update=False, - update_path=None, verbose=None): + def data_path( + self, subject, path=None, force_update=False, update_path=None, verbose=None + ): if subject not in self.subject_list: - raise(ValueError("Invalid subject number")) + raise (ValueError("Invalid subject number")) key = 'MNE_DATASETS_ZHOU2016_PATH' path = _get_path(path, key, "Zhou 2016") _do_path_update(path, True, key, "Zhou 2016") diff --git a/moabb/datasets/alex_mi.py b/moabb/datasets/alex_mi.py index 30de8ecdf..f488f4d4d 100644 --- a/moabb/datasets/alex_mi.py +++ b/moabb/datasets/alex_mi.py @@ -45,17 +45,18 @@ def __init__(self): events=dict(right_hand=2, feet=3, rest=4), code='Alexandre Motor Imagery', interval=[0, 3], - paradigm='imagery') + paradigm='imagery', + ) def _get_single_subject_data(self, subject): """return data for a single subject""" raw = Raw(self.data_path(subject), preload=True, verbose='ERROR') return {"session_0": {"run_0": raw}} - def data_path(self, subject, path=None, force_update=False, - update_path=None, verbose=None): + def data_path( + self, subject, path=None, force_update=False, update_path=None, verbose=None + ): if subject not in self.subject_list: - raise(ValueError("Invalid subject number")) + raise (ValueError("Invalid subject number")) url = '{:s}subject{:d}.raw.fif'.format(ALEX_URL, subject) - return dl.data_path(url, 'ALEXEEG', path, force_update, update_path, - verbose) + return dl.data_path(url, 'ALEXEEG', path, force_update, update_path, verbose) diff --git a/moabb/datasets/base.py b/moabb/datasets/base.py index 25db67d09..c5c562437 100644 --- a/moabb/datasets/base.py +++ b/moabb/datasets/base.py @@ -49,10 +49,19 @@ class BaseDataset(metaclass=abc.ABCMeta): doi: DOI for dataset, optional (for now) """ - def __init__(self, subjects, sessions_per_subject, events, - code, interval, paradigm, doi=None, unit_factor=1e6): + def __init__( + 
self, + subjects, + sessions_per_subject, + events, + code, + interval, + paradigm, + doi=None, + unit_factor=1e6, + ): if not isinstance(subjects, list): - raise(ValueError("subjects must be a list")) + raise (ValueError("subjects must be a list")) self.subject_list = subjects self.n_sessions = sessions_per_subject @@ -94,7 +103,7 @@ def get_data(self, subjects=None): subjects = self.subject_list if not isinstance(subjects, list): - raise(ValueError('subjects must be a list')) + raise (ValueError('subjects must be a list')) data = dict() for subject in subjects: @@ -104,8 +113,14 @@ def get_data(self, subjects=None): return data - def download(self, subject_list=None, path=None, force_update=False, - update_path=None, verbose=None): + def download( + self, + subject_list=None, + path=None, + force_update=False, + update_path=None, + verbose=None, + ): """Download all data from the dataset. This function is only usefull to download all the dataset at once. @@ -135,9 +150,13 @@ def download(self, subject_list=None, path=None, force_update=False, if subject_list is None: subject_list = self.subject_list for subject in subject_list: - self.data_path(subject=subject, path=path, - force_update=force_update, - update_path=update_path, verbose=verbose) + self.data_path( + subject=subject, + path=path, + force_update=force_update, + update_path=update_path, + verbose=verbose, + ) @abc.abstractmethod def _get_single_subject_data(self, subject): @@ -162,8 +181,9 @@ def _get_single_subject_data(self, subject): pass @abc.abstractmethod - def data_path(self, subject, path=None, force_update=False, - update_path=None, verbose=None): + def data_path( + self, subject, path=None, force_update=False, update_path=None, verbose=None + ): """Get path to local copy of a subject data. 
Parameters diff --git a/moabb/datasets/bbci_eeg_fnirs.py b/moabb/datasets/bbci_eeg_fnirs.py index 1b7b800c3..5d823db9d 100644 --- a/moabb/datasets/bbci_eeg_fnirs.py +++ b/moabb/datasets/bbci_eeg_fnirs.py @@ -21,8 +21,9 @@ def eeg_data_path(base_path, subject): - datapath = op.join(base_path, 'EEG', 'subject {:02d}'.format( - subject), 'with occular artifact') + datapath = op.join( + base_path, 'EEG', 'subject {:02d}'.format(subject), 'with occular artifact' + ) if not op.isfile(op.join(datapath, 'cnt.mat')): if not op.isdir(op.join(base_path, 'EEG')): os.makedirs(op.join(base_path, 'EEG')) @@ -30,17 +31,16 @@ def eeg_data_path(base_path, subject): for low, high in intervals: if subject >= low and subject <= high: if not op.isfile(op.join(base_path, 'EEG.zip')): - _fetch_file('{}/EEG/EEG_{:02d}-{:02d}.zip'.format(SHIN_URL, - low, - high), - op.join(base_path, 'EEG.zip'), - print_destination=False) + _fetch_file( + '{}/EEG/EEG_{:02d}-{:02d}.zip'.format(SHIN_URL, low, high), + op.join(base_path, 'EEG.zip'), + print_destination=False, + ) with z.ZipFile(op.join(base_path, 'EEG.zip'), 'r') as f: f.extractall(op.join(base_path, 'EEG')) os.remove(op.join(base_path, 'EEG.zip')) break - assert op.isfile(op.join(datapath, 'cnt.mat') - ), op.join(datapath, 'cnt.mat') + assert op.isfile(op.join(datapath, 'cnt.mat')), op.join(datapath, 'cnt.mat') return [op.join(datapath, fn) for fn in ['cnt.mat', 'mrk.mat']] @@ -49,8 +49,11 @@ def fnirs_data_path(path, subject): if not op.isfile(op.join(datapath, 'mrk.mat')): # fNIRS if not op.isfile(op.join(path, 'fNIRS.zip')): - _fetch_file('http://doc.ml.tu-berlin.de/hBCI/NIRS/NIRS_01-29.zip', - op.join(path, 'fNIRS.zip'), print_destination=False) + _fetch_file( + 'http://doc.ml.tu-berlin.de/hBCI/NIRS/NIRS_01-29.zip', + op.join(path, 'fNIRS.zip'), + print_destination=False, + ) if not op.isdir(op.join(path, 'NIRS')): os.makedirs(op.join(path, 'NIRS')) with z.ZipFile(op.join(path, 'fNIRS.zip'), 'r') as f: @@ -60,13 +63,15 @@ def fnirs_data_path(path, subject): class Shin2017(BaseDataset): - """Not to be used. 
- """ - def __init__(self, fnirs=False, motor_imagery=True, - mental_arithmetic=False): + """Not to be used.""" + + def __init__(self, fnirs=False, motor_imagery=True, mental_arithmetic=False): if not any([motor_imagery, mental_arithmetic]): - raise(ValueError("at least one of motor_imagery or" - " mental_arithmetic must be true")) + raise ( + ValueError( + "at least one of motor_imagery or" " mental_arithmetic must be true" + ) + ) events = dict() paradigms = [] n_sessions = 0 @@ -83,39 +88,38 @@ def __init__(self, fnirs=False, motor_imagery=True, self.motor_imagery = motor_imagery self.mental_arithmetic = mental_arithmetic - super().__init__(subjects=list(range(1, 30)), - sessions_per_subject=n_sessions, - events=events, - code='Shin2017', - # marker is for *task* start not cue start - interval=[0, 10], - paradigm=('/').join(paradigms), - doi='10.1109/TNSRE.2016.2628057') + super().__init__( + subjects=list(range(1, 30)), + sessions_per_subject=n_sessions, + events=events, + code='Shin2017', + # marker is for *task* start not cue start + interval=[0, 10], + paradigm=('/').join(paradigms), + doi='10.1109/TNSRE.2016.2628057', + ) if fnirs: - raise(NotImplementedError("Fnirs not implemented.")) + raise (NotImplementedError("Fnirs not implemented.")) self.fnirs = fnirs # TODO: actually incorporate fNIRS somehow def _get_single_subject_data(self, subject): """return data for a single subject""" fname, fname_mrk = self.data_path(subject) data = loadmat(fname, squeeze_me=True, struct_as_record=False)['cnt'] - mrk = loadmat(fname_mrk, squeeze_me=True, - struct_as_record=False)['mrk'] + mrk = loadmat(fname_mrk, squeeze_me=True, struct_as_record=False)['mrk'] sessions = {} # motor imagery if self.motor_imagery: for ii in [0, 2, 4]: - session = self._convert_one_session(data, mrk, ii, - trig_offset=0) + session = self._convert_one_session(data, mrk, ii, trig_offset=0) sessions['session_%d' % ii] = session # arithmetic/rest if self.mental_arithmetic: for ii in [1, 3, 5]: - session = self._convert_one_session(data, mrk, ii, - trig_offset=2) + session = self._convert_one_session(data, mrk, ii, trig_offset=2) sessions['session_%d' % ii] = session return sessions @@ -130,16 +134,16 @@ def _convert_one_session(self, data, mrk, session, trig_offset=0): ch_types = ['eeg'] * 30 + ['eog'] * 2 + ['stim'] montage = make_standard_montage('standard_1005') - info = create_info(ch_names=ch_names, ch_types=ch_types, - sfreq=200.) 
+ info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=200.0) raw = RawArray(data=eeg, info=info, verbose=False) raw.set_montage(montage) return {'run_0': raw} - def data_path(self, subject, path=None, force_update=False, - update_path=None, verbose=None): + def data_path( + self, subject, path=None, force_update=False, update_path=None, verbose=None + ): if subject not in self.subject_list: - raise(ValueError("Invalid subject number")) + raise (ValueError("Invalid subject number")) key = 'MNE_DATASETS_BBCIFNIRS_PATH' path = _get_path(path, key, 'BBCI EEG-fNIRS') @@ -269,8 +273,7 @@ class Shin2017A(Shin2017): """ def __init__(self): - super().__init__(fnirs=False, motor_imagery=True, - mental_arithmetic=False) + super().__init__(fnirs=False, motor_imagery=True, mental_arithmetic=False) self.code = 'Shin2017A' @@ -367,6 +370,5 @@ class Shin2017B(Shin2017): """ def __init__(self): - super().__init__(fnirs=False, motor_imagery=False, - mental_arithmetic=True) + super().__init__(fnirs=False, motor_imagery=False, mental_arithmetic=True) self.code = 'Shin2017B' diff --git a/moabb/datasets/bnci.py b/moabb/datasets/bnci.py index 345fe7816..ba265ab57 100644 --- a/moabb/datasets/bnci.py +++ b/moabb/datasets/bnci.py @@ -16,24 +16,20 @@ BBCI_URL = 'http://doc.ml.tu-berlin.de/bbci/' -def data_path(url, - path=None, - force_update=False, - update_path=None, - verbose=None): - return [ - dl.data_path(url, 'BNCI', path, force_update, update_path, verbose) - ] +def data_path(url, path=None, force_update=False, update_path=None, verbose=None): + return [dl.data_path(url, 'BNCI', path, force_update, update_path, verbose)] @verbose -def load_data(subject, - dataset='001-2014', - path=None, - force_update=False, - update_path=None, - base_url=BNCI_URL, - verbose=None): # noqa: D301 +def load_data( + subject, + dataset='001-2014', + path=None, + force_update=False, + update_path=None, + base_url=BNCI_URL, + verbose=None, +): # noqa: D301 """Get paths to local copies of a BNCI dataset files. This will fetch data for a given BNCI dataset. Report to the bnci website @@ -81,7 +77,7 @@ def load_data(subject, '009-2015': _load_data_009_2015, '010-2015': _load_data_010_2015, '012-2015': _load_data_012_2015, - '013-2015': _load_data_013_2015 + '013-2015': _load_data_013_2015, } baseurl_list = { @@ -96,34 +92,40 @@ def load_data(subject, '009-2015': BBCI_URL, '010-2015': BBCI_URL, '012-2015': BBCI_URL, - '013-2015': BNCI_URL + '013-2015': BNCI_URL, } if dataset not in dataset_list.keys(): - raise ValueError("Dataset '%s' is not a valid BNCI dataset ID. " - "Valid dataset are %s." % - (dataset, ", ".join(dataset_list.keys()))) + raise ValueError( + "Dataset '%s' is not a valid BNCI dataset ID. " + "Valid dataset are %s." % (dataset, ", ".join(dataset_list.keys())) + ) - return dataset_list[dataset](subject, path, force_update, update_path, - baseurl_list[dataset], verbose) + return dataset_list[dataset]( + subject, path, force_update, update_path, baseurl_list[dataset], verbose + ) @verbose -def _load_data_001_2014(subject, - path=None, - force_update=False, - update_path=None, - base_url=BNCI_URL, - verbose=None): +def _load_data_001_2014( + subject, + path=None, + force_update=False, + update_path=None, + base_url=BNCI_URL, + verbose=None, +): """Load data for 001-2014 dataset.""" if (subject < 1) or (subject > 9): raise ValueError("Subject must be between 1 and 9. Got %d." 
% subject) + # fmt: off ch_names = [ 'Fz', 'FC3', 'FC1', 'FCz', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'Cz', 'C2', 'C4', 'C6', 'CP3', 'CP1', 'CPz', 'CP2', 'CP4', 'P1', 'Pz', 'P2', 'POz', 'EOG1', 'EOG2', 'EOG3' ] + # fmt: on ch_types = ['eeg'] * 22 + ['eog'] * 3 sessions = {} @@ -132,18 +134,19 @@ def _load_data_001_2014(subject, filename = data_path(url, path, force_update, update_path) runs, ev = _convert_mi(filename[0], ch_names, ch_types) # FIXME: deal with run with no event (1:3) and name them - sessions['session_%s' % r] = {'run_%d' % ii: run - for ii, run in enumerate(runs)} + sessions['session_%s' % r] = {'run_%d' % ii: run for ii, run in enumerate(runs)} return sessions @verbose -def _load_data_002_2014(subject, - path=None, - force_update=False, - update_path=None, - base_url=BNCI_URL, - verbose=None): +def _load_data_002_2014( + subject, + path=None, + force_update=False, + update_path=None, + base_url=BNCI_URL, + verbose=None, +): """Load data for 002-2014 dataset.""" if (subject < 1) or (subject > 14): raise ValueError("Subject must be between 1 and 14. Got %d." % subject) @@ -162,17 +165,21 @@ def _load_data_002_2014(subject, @verbose -def _load_data_004_2014(subject, - path=None, - force_update=False, - update_path=None, - base_url=BNCI_URL, - verbose=None): +def _load_data_004_2014( + subject, + path=None, + force_update=False, + update_path=None, + base_url=BNCI_URL, + verbose=None, +): """Load data for 004-2014 dataset.""" if (subject < 1) or (subject > 9): raise ValueError("Subject must be between 1 and 9. Got %d." % subject) + # fmt: off ch_names = ['C3', 'Cz', 'C4', 'EOG1', 'EOG2', 'EOG3'] + # fmt: on ch_types = ['eeg'] * 3 + ['eog'] * 3 sessions = [] @@ -182,18 +189,19 @@ def _load_data_004_2014(subject, raws, _ = _convert_mi(filename, ch_names, ch_types) sessions.extend(raws) - sessions = {'session_%d' % ii: {'run_0': run} - for ii, run in enumerate(sessions)} + sessions = {'session_%d' % ii: {'run_0': run} for ii, run in enumerate(sessions)} return sessions @verbose -def _load_data_008_2014(subject, - path=None, - force_update=False, - update_path=None, - base_url=BNCI_URL, - verbose=None): +def _load_data_008_2014( + subject, + path=None, + force_update=False, + update_path=None, + base_url=BNCI_URL, + verbose=None, +): """Load data for 008-2014 dataset.""" if (subject < 1) or (subject > 8): raise ValueError("Subject must be between 1 and 8. Got %d." % subject) @@ -211,12 +219,14 @@ def _load_data_008_2014(subject, @verbose -def _load_data_009_2014(subject, - path=None, - force_update=False, - update_path=None, - base_url=BNCI_URL, - verbose=None): +def _load_data_009_2014( + subject, + path=None, + force_update=False, + update_path=None, + base_url=BNCI_URL, + verbose=None, +): """Load data for 009-2014 dataset.""" if (subject < 1) or (subject > 10): raise ValueError("Subject must be between 1 and 10. Got %d." % subject) @@ -244,12 +254,14 @@ def _load_data_009_2014(subject, @verbose -def _load_data_001_2015(subject, - path=None, - force_update=False, - update_path=None, - base_url=BNCI_URL, - verbose=None): +def _load_data_001_2015( + subject, + path=None, + force_update=False, + update_path=None, + base_url=BNCI_URL, + verbose=None, +): """Load data for 001-2015 dataset.""" if (subject < 1) or (subject > 12): raise ValueError("Subject must be between 1 and 12. Got %d." 
% subject) @@ -259,8 +271,21 @@ def _load_data_001_2015(subject, else: ses = ['A', 'B'] - ch_names = ['FC3', 'FCz', 'FC4', 'C5', 'C3', 'C1', 'Cz', 'C2', 'C4', 'C6', - 'CP3', 'CPz', 'CP4'] + ch_names = [ + 'FC3', + 'FCz', + 'FC4', + 'C5', + 'C3', + 'C1', + 'Cz', + 'C2', + 'C4', + 'C6', + 'CP3', + 'CPz', + 'CP4', + ] ch_types = ['eeg'] * 13 sessions = {} @@ -268,18 +293,19 @@ def _load_data_001_2015(subject, url = '{u}001-2015/S{s:02d}{r}.mat'.format(u=base_url, s=subject, r=r) filename = data_path(url, path, force_update, update_path) runs, ev = _convert_mi(filename[0], ch_names, ch_types) - sessions['session_%s' % r] = {'run_%d' % ii: run - for ii, run in enumerate(runs)} + sessions['session_%s' % r] = {'run_%d' % ii: run for ii, run in enumerate(runs)} return sessions @verbose -def _load_data_003_2015(subject, - path=None, - force_update=False, - update_path=None, - base_url=BNCI_URL, - verbose=None): +def _load_data_003_2015( + subject, + path=None, + force_update=False, + update_path=None, + base_url=BNCI_URL, + verbose=None, +): """Load data for 003-2015 dataset.""" if (subject < 1) or (subject > 10): raise ValueError("Subject must be between 1 and 12. Got %d." % subject) @@ -291,11 +317,13 @@ def _load_data_003_2015(subject, data = loadmat(filename, struct_as_record=False, squeeze_me=True) data = data['s%d' % subject] - sfreq = 256. + sfreq = 256.0 + # fmt: off ch_names = [ 'Fz', 'Cz', 'P3', 'Pz', 'P4', 'PO7', 'Oz', 'PO8', 'Target', 'Flash' ] + # fmt: on ch_types = ['eeg'] * 8 + ['stim'] * 2 montage = make_standard_montage('standard_1005') @@ -332,12 +360,14 @@ def _load_data_003_2015(subject, @verbose -def _load_data_004_2015(subject, - path=None, - force_update=False, - update_path=None, - base_url=BNCI_URL, - verbose=None): +def _load_data_004_2015( + subject, + path=None, + force_update=False, + update_path=None, + base_url=BNCI_URL, + verbose=None, +): """Load data for 004-2015 dataset.""" if (subject < 1) or (subject > 9): raise ValueError("Subject must be between 1 and 9. Got %d." % subject) @@ -347,34 +377,39 @@ def _load_data_004_2015(subject, url = '{u}004-2015/{s}.mat'.format(u=base_url, s=subjects[subject - 1]) filename = data_path(url, path, force_update, update_path)[0] + # fmt: off ch_names = [ 'AFz', 'F7', 'F3', 'Fz', 'F4', 'F8', 'FC3', 'FCz', 'FC4', 'T3', 'C3', 'Cz', 'C4', 'T4', 'CP3', 'CPz', 'CP4', 'P7', 'P5', 'P3', 'P1', 'Pz', 'P2', 'P4', 'P6', 'P8', 'PO3', 'PO4', 'O1', 'O2' ] + # fmt: on ch_types = ['eeg'] * 30 raws, ev = _convert_mi(filename, ch_names, ch_types) - sessions = {'session_%d' % ii: {'run_0': run} - for ii, run in enumerate(raws)} + sessions = {'session_%d' % ii: {'run_0': run} for ii, run in enumerate(raws)} return sessions @verbose -def _load_data_009_2015(subject, - path=None, - force_update=False, - update_path=None, - base_url=BBCI_URL, - verbose=None): +def _load_data_009_2015( + subject, + path=None, + force_update=False, + update_path=None, + base_url=BBCI_URL, + verbose=None, +): """Load data for 009-2015 dataset.""" if (subject < 1) or (subject > 21): raise ValueError("Subject must be between 1 and 21. Got %d." 
% subject) + # fmt: off subjects = [ 'fce', 'kw', 'faz', 'fcj', 'fcg', 'far', 'faw', 'fax', 'fcc', 'fcm', 'fas', 'fch', 'fcd', 'fca', 'fcb', 'fau', 'fci', 'fav', 'fat', 'fcl', 'fck' ] + # fmt: on s = subjects[subject - 1] url = '{u}BNCIHorizon2020-AMUSE/AMUSE_VP{s}.mat'.format(u=base_url, s=s) filename = data_path(url, path, force_update, update_path)[0] @@ -385,20 +420,24 @@ def _load_data_009_2015(subject, @verbose -def _load_data_010_2015(subject, - path=None, - force_update=False, - update_path=None, - base_url=BBCI_URL, - verbose=None): +def _load_data_010_2015( + subject, + path=None, + force_update=False, + update_path=None, + base_url=BBCI_URL, + verbose=None, +): """Load data for 010-2015 dataset.""" if (subject < 1) or (subject > 12): raise ValueError("Subject must be between 1 and 12. Got %d." % subject) + # fmt: off subjects = [ 'fat', 'gcb', 'gcc', 'gcd', 'gce', 'gcf', 'gcg', 'gch', 'iay', 'icn', 'icr', 'pia' ] + # fmt: on s = subjects[subject - 1] url = '{u}BNCIHorizon2020-RSVP/RSVP_VP{s}.mat'.format(u=base_url, s=s) @@ -410,19 +449,23 @@ def _load_data_010_2015(subject, @verbose -def _load_data_012_2015(subject, - path=None, - force_update=False, - update_path=None, - base_url=BBCI_URL, - verbose=None): +def _load_data_012_2015( + subject, + path=None, + force_update=False, + update_path=None, + base_url=BBCI_URL, + verbose=None, +): """Load data for 012-2015 dataset.""" if (subject < 1) or (subject > 12): raise ValueError("Subject must be between 1 and 12. Got %d." % subject) + # fmt: off subjects = [ 'nv', 'nw', 'nx', 'ny', 'nz', 'mg', 'oa', 'ob', 'oc', 'od', 'ja', 'oe' ] + # fmt: on s = subjects[subject - 1] url = '{u}BNCIHorizon2020-PASS2D/PASS2D_VP{s}.mat'.format(u=base_url, s=s) @@ -434,20 +477,21 @@ def _load_data_012_2015(subject, @verbose -def _load_data_013_2015(subject, - path=None, - force_update=False, - update_path=None, - base_url=BNCI_URL, - verbose=None): +def _load_data_013_2015( + subject, + path=None, + force_update=False, + update_path=None, + base_url=BNCI_URL, + verbose=None, +): """Load data for 013-2015 dataset.""" if (subject < 1) or (subject > 6): raise ValueError("Subject must be between 1 and 6. Got %d." % subject) data_paths = [] for r in ['s1', 's2']: - url = '{u}013-2015/Subject{s:02d}_{r}.mat'.format( - u=base_url, s=subject, r=r) + url = '{u}013-2015/Subject{s:02d}_{r}.mat'.format(u=base_url, s=subject, r=r) data_paths.extend(data_path(url, path, force_update, update_path)) raws = [] @@ -464,10 +508,10 @@ def _load_data_013_2015(subject, def _convert_mi(filename, ch_names, ch_types): - ''' + """ Processes (Graz) motor imagery data from MAT files, returns list of recording runs. 
- ''' + """ from scipy.io import loadmat runs = [] @@ -491,14 +535,16 @@ def _convert_mi(filename, ch_names, ch_types): def standardize_keys(d): - master_list = [['both feet', 'feet'], - ['left hand', 'left_hand'], - ['right hand', 'right_hand'], - ['FEET', 'feet'], - ['HAND', 'right_hand'], - ['NAV', 'navigation'], - ['SUB', 'subtraction'], - ['WORD', 'word_ass']] + master_list = [ + ['both feet', 'feet'], + ['left hand', 'left_hand'], + ['right hand', 'right_hand'], + ['FEET', 'feet'], + ['HAND', 'right_hand'], + ['NAV', 'navigation'], + ['SUB', 'subtraction'], + ['WORD', 'word_ass'], + ] for old, new in master_list: if old in d.keys(): d[new] = d.pop(old) @@ -643,12 +689,18 @@ def _get_single_subject_data(self, subject): sessions = load_data(subject=subject, dataset=self.code, verbose=False) return sessions - def data_path(self, subject, path=None, force_update=False, - update_path=None, verbose=None): + def data_path( + self, subject, path=None, force_update=False, update_path=None, verbose=None + ): print(f"warning - datapath not implemented correctly for {self.code}") - return load_data(subject=subject, dataset=self.code, verbose=verbose, - update_path=update_path, path=path, - force_update=force_update) + return load_data( + subject=subject, + dataset=self.code, + verbose=verbose, + update_path=update_path, + path=path, + force_update=force_update, + ) class BNCI2014001(MNEBNCI): @@ -702,7 +754,8 @@ def __init__(self): code='001-2014', interval=[2, 6], paradigm='imagery', - doi='10.3389/fnins.2012.00055') + doi='10.3389/fnins.2012.00055', + ) class BNCI2014002(MNEBNCI): @@ -755,7 +808,8 @@ def __init__(self): code='002-2014', interval=[3, 8], paradigm='imagery', - doi='10.1515/bmt-2014-0117') + doi='10.1515/bmt-2014-0117', + ) class BNCI2014004(MNEBNCI): @@ -828,7 +882,8 @@ def __init__(self): code='004-2014', interval=[3, 7.5], paradigm='imagery', - doi='10.1109/TNSRE.2007.906956') + doi='10.1109/TNSRE.2007.906956', + ) class BNCI2014008(MNEBNCI): @@ -890,7 +945,8 @@ def __init__(self): code='008-2014', interval=[0, 1.0], paradigm='p300', - doi='10.3389/fnhum.2013.00732') + doi='10.3389/fnhum.2013.00732', + ) class BNCI2014009(MNEBNCI): @@ -943,7 +999,8 @@ def __init__(self): code='009-2014', interval=[0, 0.8], paradigm='p300', - doi='10.1088/1741-2560/11/3/035008') + doi='10.1088/1741-2560/11/3/035008', + ) class BNCI2015001(MNEBNCI): @@ -990,7 +1047,8 @@ def __init__(self): code='001-2015', interval=[0, 5], paradigm='imagery', - doi='10.1109/tnsre.2012.2189584') + doi='10.1109/tnsre.2012.2189584', + ) class BNCI2015003(MNEBNCI): @@ -1024,7 +1082,8 @@ def __init__(self): code='003-2015', interval=[0, 0.8], paradigm='p300', - doi='10.1016/j.neulet.2009.06.045') + doi='10.1016/j.neulet.2009.06.045', + ) class BNCI2015004(MNEBNCI): @@ -1085,9 +1144,9 @@ def __init__(self): super().__init__( subjects=list(range(1, 10)), sessions_per_subject=2, - events=dict(right_hand=4, feet=5, navigation=3, subtraction=2, - word_ass=1), + events=dict(right_hand=4, feet=5, navigation=3, subtraction=2, word_ass=1), code='004-2015', interval=[3, 10], paradigm='imagery', - doi='10.1371/journal.pone.0123727') + doi='10.1371/journal.pone.0123727', + ) diff --git a/moabb/datasets/braininvaders.py b/moabb/datasets/braininvaders.py index 1cc7e1c02..33a7bfa98 100644 --- a/moabb/datasets/braininvaders.py +++ b/moabb/datasets/braininvaders.py @@ -14,7 +14,7 @@ class bi2013a(BaseDataset): - '''P300 dataset bi2013a from a "Brain Invaders" experiment (2013) + """P300 dataset bi2013a from a "Brain Invaders" 
experiment (2013) carried-out at University of Grenoble Alpes. Dataset following the setup from [1]_. @@ -103,14 +103,9 @@ class bi2013a(BaseDataset): Design, Test and Use Brain-Computer Interfaces in Real and Virtual Environments. PRESENCE : Teleoperators and Virtual Environments 19(1), 35-53. - ''' - - def __init__( - self, - NonAdaptive=True, - Adaptive=False, - Training=True, - Online=False): + """ + + def __init__(self, NonAdaptive=True, Adaptive=False, Training=True, Online=False): super().__init__( subjects=list(range(1, 24 + 1)), sessions_per_subject=1, @@ -118,7 +113,8 @@ def __init__( code='Brain Invaders 2013a', interval=[0, 1], paradigm='p300', - doi='') + doi='', + ) self.adaptive = Adaptive self.nonadaptive = NonAdaptive @@ -142,8 +138,7 @@ def _get_single_subject_data(self, subject): run_number = run_number.split('.gdf')[0] run_name = 'run_' + run_number - raw_original = mne.io.read_raw_gdf(file_path, - preload=True) + raw_original = mne.io.read_raw_gdf(file_path, preload=True) raw_original.rename_channels({'FP1': 'Fp1', 'FP2': 'Fp2'}) raw_original.set_montage(make_standard_montage('standard_1020')) @@ -151,11 +146,12 @@ def _get_single_subject_data(self, subject): return sessions - def data_path(self, subject, path=None, force_update=False, - update_path=None, verbose=None): + def data_path( + self, subject, path=None, force_update=False, update_path=None, verbose=None + ): if subject not in self.subject_list: - raise(ValueError("Invalid subject number")) + raise (ValueError("Invalid subject number")) # check if has the .zip url = '{:s}subject{:d}.zip'.format(BI2013a_URL, subject) @@ -163,7 +159,7 @@ def data_path(self, subject, path=None, force_update=False, path_folder = path_zip.strip('subject{:d}.zip'.format(subject)) # check if has to unzip - if not(os.path.isdir(path_folder + 'subject{:d}'.format(subject))): + if not (os.path.isdir(path_folder + 'subject{:d}'.format(subject))): print('unzip', path_zip) zip_ref = zipfile.ZipFile(path_zip, "r") zip_ref.extractall(path_folder) @@ -193,6 +189,9 @@ def data_path(self, subject, path=None, force_update=False, # list the filepaths for this subject subject_paths = [] for filename in filenames: - subject_paths = subject_paths + \ - glob.glob(os.path.join(path_folder, 'subject{:d}'.format(subject), 'Session*', filename)) # noqa + subject_paths = subject_paths + glob.glob( + os.path.join( + path_folder, 'subject{:d}'.format(subject), 'Session*', filename + ) + ) # noqa return subject_paths diff --git a/moabb/datasets/download.py b/moabb/datasets/download.py index 9907bc346..1c9e601ed 100644 --- a/moabb/datasets/download.py +++ b/moabb/datasets/download.py @@ -11,8 +11,7 @@ @verbose -def data_path(url, sign, path=None, force_update=False, update_path=True, - verbose=None): +def data_path(url, sign, path=None, force_update=False, update_path=True, verbose=None): """Get path to local copy of given dataset URL. This is a low-level function useful for getting a local copy of a diff --git a/moabb/datasets/epfl.py b/moabb/datasets/epfl.py index 428f317cd..75b743e0a 100644 --- a/moabb/datasets/epfl.py +++ b/moabb/datasets/epfl.py @@ -69,7 +69,8 @@ def __init__(self): code='EPFL P300 dataset', interval=[0, 1], paradigm='p300', - doi='10.1016/j.jneumeth.2007.03.005') + doi='10.1016/j.jneumeth.2007.03.005', + ) def _get_single_run_data(self, file_path): @@ -116,14 +117,14 @@ def _get_single_run_data(self, file_path): 'Fz', 'Cz', 'MA1', - 'MA2'] + 'MA2', + ] ch_types = ['eeg'] * 32 + ['misc'] * 2 # The last X entries are 0 for all signals. 
This leads to # artifacts when epoching and band-pass filtering the data. # Correct the signals for this. - sig_i = np.where( - np.diff(np.all(signals == 0, axis=0).astype(int)) != 0)[0][0] + sig_i = np.where(np.diff(np.all(signals == 0, axis=0).astype(int)) != 0)[0][0] signals = signals[:, :sig_i] signals *= 1e-6 # data is stored as uV, but MNE expects V # we have to re-reference the signals @@ -135,15 +136,15 @@ def _get_single_run_data(self, file_path): # getting the event time in a Python standardized way events_datetime = [] for eventi in events: - events_datetime.append(dt.datetime( - *eventi.astype(int), int(eventi[-1] * 1e3) % 1000 * 1000)) + events_datetime.append( + dt.datetime(*eventi.astype(int), int(eventi[-1] * 1e3) % 1000 * 1000) + ) # get the indices of the stimuli pos = [] n_trials = len(stimuli) for j in range(n_trials): - delta_seconds = ( - events_datetime[j] - events_datetime[0]).total_seconds() + delta_seconds = (events_datetime[j] - events_datetime[0]).total_seconds() delta_indices = int(delta_seconds * sfreq) # has to add an offset pos.append(delta_indices + int(0.4 * sfreq)) @@ -177,28 +178,22 @@ def _get_single_subject_data(self, subject): for file_path in sorted(file_path_list): - session_name = 'session_' + \ - file_path.split(os.sep)[-2].replace('session', '') + session_name = 'session_' + file_path.split(os.sep)[-2].replace('session', '') if session_name not in sessions.keys(): sessions[session_name] = {} run_name = 'run_' + str(len(sessions[session_name]) + 1) - sessions[session_name][run_name] = self._get_single_run_data( - file_path) + sessions[session_name][run_name] = self._get_single_run_data(file_path) return sessions def data_path( - self, - subject, - path=None, - force_update=False, - update_path=None, - verbose=None): + self, subject, path=None, force_update=False, update_path=None, verbose=None + ): if subject not in self.subject_list: - raise(ValueError("Invalid subject number")) + raise (ValueError("Invalid subject number")) # check if has the .zip url = '{:s}subject{:d}.zip'.format(EPFLP300_URL, subject) @@ -206,14 +201,13 @@ def data_path( path_folder = path_zip.strip('subject{:d}.zip'.format(subject)) # check if has to unzip - if not(os.path.isdir(path_folder + 'subject{:d}'.format(subject))): + if not (os.path.isdir(path_folder + 'subject{:d}'.format(subject))): print('unzip', path_zip) zip_ref = zipfile.ZipFile(path_zip, "r") zip_ref.extractall(path_folder) # get the path to all files pattern = os.path.join('subject{:d}'.format(subject), '*', '*') - subject_paths = glob.glob( - path_folder + pattern) + subject_paths = glob.glob(path_folder + pattern) return subject_paths diff --git a/moabb/datasets/fake.py b/moabb/datasets/fake.py index c063d763b..a59ba7477 100644 --- a/moabb/datasets/fake.py +++ b/moabb/datasets/fake.py @@ -13,19 +13,32 @@ class FakeDataset(BaseDataset): """ - def __init__(self, event_list=('fake_c1', 'fake_c2', 'fake_c3'), - n_sessions=2, n_runs=2, n_subjects=10, paradigm='imagery'): + def __init__( + self, + event_list=('fake_c1', 'fake_c2', 'fake_c3'), + n_sessions=2, + n_runs=2, + n_subjects=10, + paradigm='imagery', + ): self.n_runs = n_runs event_id = {ev: ii + 1 for ii, ev in enumerate(event_list)} - super().__init__(list(range(1, n_subjects + 1)), n_sessions, event_id, - 'FakeDataset', [0, 3], paradigm) + super().__init__( + list(range(1, n_subjects + 1)), + n_sessions, + event_id, + 'FakeDataset', + [0, 3], + paradigm, + ) def _get_single_subject_data(self, subject): data = dict() for session in 
range(self.n_sessions): - data[f"session_{session}"] = {f"run_{ii}": self._generate_raw() - for ii in range(self.n_runs)} + data[f"session_{session}"] = { + f"run_{ii}": self._generate_raw() for ii in range(self.n_runs) + } return data def _generate_raw(self): @@ -38,8 +51,8 @@ def _generate_raw(self): eeg_data = 2e-5 * np.random.randn(duration * sfreq, len(ch_names)) y = np.zeros((duration * sfreq)) for ii, ev in enumerate(self.event_id): - start_idx = ((1 + 5 * ii) * 128) - jump = (5 * len(self.event_id) * 128) + start_idx = (1 + 5 * ii) * 128 + jump = 5 * len(self.event_id) * 128 y[start_idx::jump] = self.event_id[ev] ch_types = ['eeg'] * len(ch_names) + ['stim'] @@ -47,12 +60,12 @@ def _generate_raw(self): eeg_data = np.c_[eeg_data, y] - info = create_info(ch_names=ch_names, ch_types=ch_types, - sfreq=sfreq) + info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sfreq) raw = RawArray(data=eeg_data.T, info=info, verbose=False) raw.set_montage(montage) return raw - def data_path(self, subject, path=None, force_update=False, - update_path=None, verbose=None): + def data_path( + self, subject, path=None, force_update=False, update_path=None, verbose=None + ): pass diff --git a/moabb/datasets/gigadb.py b/moabb/datasets/gigadb.py index 98226d1e9..956b6e8cb 100644 --- a/moabb/datasets/gigadb.py +++ b/moabb/datasets/gigadb.py @@ -15,7 +15,9 @@ log = logging.getLogger() -GIGA_URL = 'ftp://parrot.genomics.cn/gigadb/pub/10.5524/100001_101000/100295/mat_data/' # noqa +GIGA_URL = ( + 'ftp://parrot.genomics.cn/gigadb/pub/10.5524/100001_101000/100295/mat_data/' # noqa +) class Cho2017(BaseDataset): @@ -65,7 +67,8 @@ def __init__(self): code='Cho2017', interval=[0, 3], # full trial is 0-3s, but edge effects paradigm='imagery', - doi='10.5524/100295') + doi='10.5524/100295', + ) for ii in [32, 46, 49]: self.subject_list.remove(ii) @@ -74,9 +77,14 @@ def _get_single_subject_data(self, subject): """return data for a single subject""" fname = self.data_path(subject) - data = loadmat(fname, squeeze_me=True, struct_as_record=False, - verify_compressed_data_integrity=False)['eeg'] + data = loadmat( + fname, + squeeze_me=True, + struct_as_record=False, + verify_compressed_data_integrity=False, + )['eeg'] + # fmt: off eeg_ch_names = ['Fp1', 'AF7', 'AF3', 'F1', 'F3', 'F5', 'F7', 'FT7', 'FC5', 'FC3', 'FC1', 'C1', 'C3', 'C5', 'T7', 'TP7', 'CP5', 'CP3', 'CP1', 'P1', 'P3', 'P5', 'P7', 'P9', @@ -85,38 +93,40 @@ def _get_single_subject_data(self, subject): 'F6', 'F8', 'FT8', 'FC6', 'FC4', 'FC2', 'FCz', 'Cz', 'C2', 'C4', 'C6', 'T8', 'TP8', 'CP6', 'CP4', 'CP2', 'P2', 'P4', 'P6', 'P8', 'P10', 'PO8', 'PO4', 'O2'] + # fmt: on emg_ch_names = ['EMG1', 'EMG2', 'EMG3', 'EMG4'] ch_names = eeg_ch_names + emg_ch_names + ['Stim'] ch_types = ['eeg'] * 64 + ['emg'] * 4 + ['stim'] montage = make_standard_montage('standard_1005') - imagery_left = data.imagery_left - \ - data.imagery_left.mean(axis=1, keepdims=True) - imagery_right = data.imagery_right - \ - data.imagery_right.mean(axis=1, keepdims=True) + imagery_left = data.imagery_left - data.imagery_left.mean(axis=1, keepdims=True) + imagery_right = data.imagery_right - data.imagery_right.mean( + axis=1, keepdims=True + ) eeg_data_l = np.vstack([imagery_left * 1e-6, data.imagery_event]) - eeg_data_r = np.vstack([imagery_right * 1e-6, - data.imagery_event * 2]) + eeg_data_r = np.vstack([imagery_right * 1e-6, data.imagery_event * 2]) # trials are already non continuous. 
edge artifact can appears but # are likely to be present during rest / inter-trial activity - eeg_data = np.hstack([eeg_data_l, np.zeros((eeg_data_l.shape[0], 500)), - eeg_data_r]) - log.warning("Trials demeaned and stacked with zero buffer to create " - "continuous data -- edge effects present") - - info = create_info(ch_names=ch_names, ch_types=ch_types, - sfreq=data.srate) + eeg_data = np.hstack( + [eeg_data_l, np.zeros((eeg_data_l.shape[0], 500)), eeg_data_r] + ) + log.warning( + "Trials demeaned and stacked with zero buffer to create " + "continuous data -- edge effects present" + ) + + info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=data.srate) raw = RawArray(data=eeg_data, info=info, verbose=False) raw.set_montage(montage) return {'session_0': {'run_0': raw}} - def data_path(self, subject, path=None, force_update=False, - update_path=None, verbose=None): + def data_path( + self, subject, path=None, force_update=False, update_path=None, verbose=None + ): if subject not in self.subject_list: - raise(ValueError("Invalid subject number")) + raise (ValueError("Invalid subject number")) url = '{:s}s{:02d}.mat'.format(GIGA_URL, subject) - return dl.data_path(url, 'GIGADB', path, force_update, update_path, - verbose) + return dl.data_path(url, 'GIGADB', path, force_update, update_path, verbose) diff --git a/moabb/datasets/mpi_mi.py b/moabb/datasets/mpi_mi.py index 3f7d3eaff..b0b99c2e9 100644 --- a/moabb/datasets/mpi_mi.py +++ b/moabb/datasets/mpi_mi.py @@ -64,12 +64,14 @@ def __init__(self): code='Grosse-Wentrup 2009', interval=[0, 7], paradigm='imagery', - doi='10.1109/TBME.2008.2009768') + doi='10.1109/TBME.2008.2009768', + ) def _get_single_subject_data(self, subject): """return data for a single subject""" - raw = mne.io.read_raw_eeglab(self.data_path(subject), preload=True, - verbose='ERROR') + raw = mne.io.read_raw_eeglab( + self.data_path(subject), preload=True, verbose='ERROR' + ) stim = raw.annotations.description.astype(np.dtype('<10U')) stim[stim == '20'] = 'right_hand' @@ -77,17 +79,18 @@ def _get_single_subject_data(self, subject): raw.annotations.description = stim return {"session_0": {"run_0": raw}} - def data_path(self, subject, path=None, force_update=False, - update_path=None, verbose=None): + def data_path( + self, subject, path=None, force_update=False, update_path=None, verbose=None + ): if subject not in self.subject_list: - raise(ValueError("Invalid subject number")) + raise (ValueError("Invalid subject number")) # download .set _set = '{:s}subject{:d}.set'.format(DOWNLOAD_URL, subject) - set_local = dl.data_path(_set, 'MUNICHMI', path, force_update, - update_path, verbose) + set_local = dl.data_path( + _set, 'MUNICHMI', path, force_update, update_path, verbose + ) # download .fdt _fdt = '{:s}subject{:d}.fdt'.format(DOWNLOAD_URL, subject) - dl.data_path(_fdt, 'MUNICHMI', path, force_update, - update_path, verbose) + dl.data_path(_fdt, 'MUNICHMI', path, force_update, update_path, verbose) return set_local diff --git a/moabb/datasets/physionet_mi.py b/moabb/datasets/physionet_mi.py index a80883b26..2e6be1825 100644 --- a/moabb/datasets/physionet_mi.py +++ b/moabb/datasets/physionet_mi.py @@ -82,7 +82,8 @@ def __init__(self, imagined=True, executed=False): # interval between 2 trial is 4 second. 
interval=[0, 3], paradigm='imagery', - doi='10.1109/TBME.2004.827072') + doi='10.1109/TBME.2004.827072', + ) self.feet_runs = [] self.hand_runs = [] @@ -97,17 +98,21 @@ def __init__(self, imagined=True, executed=False): def _load_one_run(self, subject, run, preload=True): if get_config('MNE_DATASETS_EEGBCI_PATH') is None: - set_config('MNE_DATASETS_EEGBCI_PATH', - osp.join(osp.expanduser("~"), "mne_data")) - raw_fname = eegbci.load_data(subject, runs=[run], verbose='ERROR', - base_url=BASE_URL)[0] + set_config( + 'MNE_DATASETS_EEGBCI_PATH', osp.join(osp.expanduser("~"), "mne_data") + ) + raw_fname = eegbci.load_data( + subject, runs=[run], verbose='ERROR', base_url=BASE_URL + )[0] raw = read_raw_edf(raw_fname, preload=preload, verbose='ERROR') raw.rename_channels(lambda x: x.strip('.')) raw.rename_channels(lambda x: x.upper()) + # fmt: off raw.rename_channels({'AFZ': 'AFz', 'PZ': 'Pz', 'FPZ': 'Fpz', 'FCZ': 'FCz', 'FP1': 'Fp1', 'CZ': 'Cz', 'OZ': 'Oz', 'POZ': 'POz', 'IZ': 'Iz', 'CPZ': 'CPz', 'FP2': 'Fp2', 'FZ': 'Fz'}) + # fmt: on raw.set_montage(mne.channels.make_standard_montage('standard_1005')) return raw @@ -115,8 +120,9 @@ def _get_single_subject_data(self, subject): """return data for a single subject""" data = {} if get_config('MNE_DATASETS_EEGBCI_PATH') is None: - set_config('MNE_DATASETS_EEGBCI_PATH', - osp.join(osp.expanduser("~"), "mne_data")) + set_config( + 'MNE_DATASETS_EEGBCI_PATH', osp.join(osp.expanduser("~"), "mne_data") + ) # hand runs for run in self.hand_runs: @@ -142,15 +148,17 @@ def _get_single_subject_data(self, subject): return {"session_0": data} - def data_path(self, subject, path=None, force_update=False, - update_path=None, verbose=None): + def data_path( + self, subject, path=None, force_update=False, update_path=None, verbose=None + ): if subject not in self.subject_list: raise (ValueError("Invalid subject number")) if get_config('MNE_DATASETS_EEGBCI_PATH') is None: - set_config('MNE_DATASETS_EEGBCI_PATH', - osp.join(osp.expanduser("~"), "mne_data")) - paths = eegbci.load_data(subject, - runs=[1, 2] + self.hand_runs + self.feet_runs, - verbose=verbose) + set_config( + 'MNE_DATASETS_EEGBCI_PATH', osp.join(osp.expanduser("~"), "mne_data") + ) + paths = eegbci.load_data( + subject, runs=[1, 2] + self.hand_runs + self.feet_runs, verbose=verbose + ) return paths diff --git a/moabb/datasets/schirrmeister2017.py b/moabb/datasets/schirrmeister2017.py index b139b5aa7..230663750 100644 --- a/moabb/datasets/schirrmeister2017.py +++ b/moabb/datasets/schirrmeister2017.py @@ -13,45 +13,47 @@ log = logging.getLogger(__name__) -GIN_URL = "https://web.gin.g-node.org/robintibor/high-gamma-dataset/raw/master/data" # noqa +GIN_URL = ( + "https://web.gin.g-node.org/robintibor/high-gamma-dataset/raw/master/data" # noqa +) class Schirrmeister2017(BaseDataset): """High-gamma dataset discribed in Schirrmeister et al. 2017 -Our “High-Gamma Dataset” is a 128-electrode dataset (of which we later only use -44 sensors covering the motor cortex, (see Section 2.7.1), obtained from 14 -healthy subjects (6 female, 2 left-handed, age 27.2 ± 3.6 (mean ± std)) with -roughly 1000 (963.1 ± 150.9, mean ± std) four-second trials of executed -movements divided into 13 runs per subject. The four classes of movements were -movements of either the left hand, the right hand, both feet, and rest (no -movement, but same type of visual cue as for the other classes). The training -set consists of the approx. 880 trials of all runs except the last two runs, -the test set of the approx. 
160 trials of the last 2 runs. This dataset was -acquired in an EEG lab optimized for non-invasive detection of high- frequency -movement-related EEG components (Ball et al., 2008; Darvas et al., 2010). - - Depending on the direction of a gray arrow that was shown on black back- -ground, the subjects had to repetitively clench their toes (downward arrow), -perform sequential finger-tapping of their left (leftward arrow) or right -(rightward arrow) hand, or relax (upward arrow). The movements were selected -to require little proximal muscular activity while still being complex enough -to keep subjects in- volved. Within the 4-s trials, the subjects performed the -repetitive movements at their own pace, which had to be maintained as long as -the arrow was showing. Per run, 80 arrows were displayed for 4 s each, with 3 -to 4 s of continuous random inter-trial interval. The order of presentation -was pseudo-randomized, with all four arrows being shown every four trials. -Ideally 13 runs were performed to collect 260 trials of each movement and rest. -The stimuli were presented and the data recorded with BCI2000 (Schalk et al., -2004). The experiment was approved by the ethical committee of the University -of Freiburg. - -References ----------- - -.. [1] Schirrmeister, Robin Tibor, et al. "Deep learning with convolutional -neural networks for EEG decoding and visualization." Human brain mapping 38.11 -(2017): 5391-5420. + Our “High-Gamma Dataset” is a 128-electrode dataset (of which we later only use + 44 sensors covering the motor cortex, (see Section 2.7.1), obtained from 14 + healthy subjects (6 female, 2 left-handed, age 27.2 ± 3.6 (mean ± std)) with + roughly 1000 (963.1 ± 150.9, mean ± std) four-second trials of executed + movements divided into 13 runs per subject. The four classes of movements were + movements of either the left hand, the right hand, both feet, and rest (no + movement, but same type of visual cue as for the other classes). The training + set consists of the approx. 880 trials of all runs except the last two runs, + the test set of the approx. 160 trials of the last 2 runs. This dataset was + acquired in an EEG lab optimized for non-invasive detection of high- frequency + movement-related EEG components (Ball et al., 2008; Darvas et al., 2010). + + Depending on the direction of a gray arrow that was shown on black back- + ground, the subjects had to repetitively clench their toes (downward arrow), + perform sequential finger-tapping of their left (leftward arrow) or right + (rightward arrow) hand, or relax (upward arrow). The movements were selected + to require little proximal muscular activity while still being complex enough + to keep subjects in- volved. Within the 4-s trials, the subjects performed the + repetitive movements at their own pace, which had to be maintained as long as + the arrow was showing. Per run, 80 arrows were displayed for 4 s each, with 3 + to 4 s of continuous random inter-trial interval. The order of presentation + was pseudo-randomized, with all four arrows being shown every four trials. + Ideally 13 runs were performed to collect 260 trials of each movement and rest. + The stimuli were presented and the data recorded with BCI2000 (Schalk et al., + 2004). The experiment was approved by the ethical committee of the University + of Freiburg. + + References + ---------- + + .. [1] Schirrmeister, Robin Tibor, et al. "Deep learning with convolutional + neural networks for EEG decoding and visualization." 
Human brain mapping 38.11 + (2017): 5391-5420. """ @@ -63,18 +65,24 @@ def __init__(self): code='Schirrmeister2017', interval=[0, 4], paradigm='imagery', - doi='10.1002/hbm.23730') + doi='10.1002/hbm.23730', + ) - def data_path(self, subject, path=None, force_update=False, - update_path=None, verbose=None): + def data_path( + self, subject, path=None, force_update=False, update_path=None, verbose=None + ): if subject not in self.subject_list: - raise(ValueError('Invalid subject number')) + raise (ValueError('Invalid subject number')) def _url(prefix): return '/'.join([GIN_URL, prefix, '{:d}.mat'.format(subject)]) - return [dl.data_path(_url(t), 'SCHIRRMEISTER2017', path, force_update, - update_path, verbose) for t in ['train', 'test']] + return [ + dl.data_path( + _url(t), 'SCHIRRMEISTER2017', path, force_update, update_path, verbose + ) + for t in ['train', 'test'] + ] def _get_single_subject_data(self, subject): train, test = [BBCIDataset(path) for path in self.data_path(subject)] @@ -115,8 +123,7 @@ def _load_continuous_signal(self): with h5py.File(self.filename, 'r') as h5file: samples = int(h5file['nfo']['T'][0, 0]) cnt_signal_shape = (samples, len(wanted_chan_inds)) - continuous_signal = np.ones(cnt_signal_shape, - dtype=np.float32) * np.nan + continuous_signal = np.ones(cnt_signal_shape, dtype=np.float32) * np.nan for chan_ind_arr, chan_ind_set in enumerate(wanted_chan_inds): # + 1 because matlab/this hdf5-naming logic # has 1-based indexing @@ -124,15 +131,14 @@ def _load_continuous_signal(self): chan_set_name = 'ch' + str(chan_ind_set + 1) # first 0 to unpack into vector, before it is 1xN matrix chan_signal = h5file[chan_set_name][ - :].squeeze() # already load into memory + : + ].squeeze() # already load into memory continuous_signal[:, chan_ind_arr] = chan_signal - assert not np.any( - np.isnan(continuous_signal)), "No NaNs expected in signal" + assert not np.any(np.isnan(continuous_signal)), "No NaNs expected in signal" # Assume we cant know channel type here automatically ch_types = ['eeg'] * len(wanted_chan_inds) - info = mne.create_info(ch_names=wanted_sensor_names, sfreq=fs, - ch_types=ch_types) + info = mne.create_info(ch_names=wanted_sensor_names, sfreq=fs, ch_types=ch_types) # Scale to volts from microvolts, (VJ 19.6.18) continuous_signal = continuous_signal * 1e-6 cnt = mne.io.RawArray(continuous_signal.T, info) @@ -144,22 +150,21 @@ def _determine_sensors(self): # if no sensor names given, take all EEG-chans eeg_sensor_names = all_sensor_names - eeg_sensor_names = filter(lambda s: not s.startswith('BIP'), - eeg_sensor_names) - eeg_sensor_names = filter(lambda s: not s.startswith('E'), - eeg_sensor_names) - eeg_sensor_names = filter(lambda s: not s.startswith('Microphone'), - eeg_sensor_names) - eeg_sensor_names = filter(lambda s: not s.startswith('Breath'), - eeg_sensor_names) - eeg_sensor_names = filter(lambda s: not s.startswith('GSR'), - eeg_sensor_names) + eeg_sensor_names = filter(lambda s: not s.startswith('BIP'), eeg_sensor_names) + eeg_sensor_names = filter(lambda s: not s.startswith('E'), eeg_sensor_names) + eeg_sensor_names = filter( + lambda s: not s.startswith('Microphone'), eeg_sensor_names + ) + eeg_sensor_names = filter( + lambda s: not s.startswith('Breath'), eeg_sensor_names + ) + eeg_sensor_names = filter(lambda s: not s.startswith('GSR'), eeg_sensor_names) eeg_sensor_names = list(eeg_sensor_names) - assert (len(eeg_sensor_names) in set( - [128, 64, 32, 16])), "check this code if you have different sensors..." 
# noqa + assert len(eeg_sensor_names) in set( + [128, 64, 32, 16] + ), "check this code if you have different sensors..." # noqa self.load_sensor_names = eeg_sensor_names - chan_inds = self._determine_chan_inds(all_sensor_names, - self.load_sensor_names) + chan_inds = self._determine_chan_inds(all_sensor_names, self.load_sensor_names) return chan_inds, self.load_sensor_names def _determine_samplingrate(self): @@ -173,14 +178,10 @@ def _determine_samplingrate(self): def _determine_chan_inds(all_sensor_names, sensor_names): assert sensor_names is not None chan_inds = [all_sensor_names.index(s) for s in sensor_names] - assert len(chan_inds) == len(sensor_names), ("All" - "sensors" - "should be there.") + assert len(chan_inds) == len(sensor_names), "All" "sensors" "should be there." # TODO: is it possible for this to fail? the list # comp fails first right? - assert len(set(chan_inds)) == len(chan_inds), ("No" - "duplicated sensors" - "wanted.") + assert len(set(chan_inds)) == len(chan_inds), "No" "duplicated sensors" "wanted." return chan_inds @staticmethod @@ -201,19 +202,19 @@ def get_all_sensors(filename, pattern=None): """ with h5py.File(filename, 'r') as h5file: clab_set = h5file['nfo']['clab'][:].squeeze() - all_sensor_names = [''.join( - chr(c.squeeze()) for c in h5file[obj_ref]) for obj_ref in clab_set] + all_sensor_names = [ + ''.join(chr(c.squeeze()) for c in h5file[obj_ref]) for obj_ref in clab_set + ] if pattern is not None: all_sensor_names = filter( - lambda sname: re.search(pattern, sname), - all_sensor_names) + lambda sname: re.search(pattern, sname), all_sensor_names + ) return all_sensor_names def _add_markers(self, cnt): with h5py.File(self.filename, 'r') as h5file: event_times_in_ms = h5file['mrk']['time'][:].squeeze() - event_classes = h5file['mrk']['event']['desc'][:].squeeze().astype( - np.int64) + event_classes = h5file['mrk']['event']['desc'][:].squeeze().astype(np.int64) # Check whether class names known and correct order # class_name_set = h5file['nfo']['className'][:].squeeze() @@ -225,27 +226,31 @@ def _add_markers(self, cnt): # Check if there are markers at the same time previous_i_sample = -1 - for i_event, (i_sample, _) in enumerate(zip(event_times_in_samples, event_classes)): + for i_event, (i_sample, _) in enumerate( + zip(event_times_in_samples, event_classes) + ): if i_sample == previous_i_sample: - info = "{:d}: ({:.0f} and {:.0f}).\n".format(i_sample, - event_classes[ - i_event - 1], - event_classes[ - i_event]) - log.warning(f"Same sample has at least two markers.\n{info}Marker codes will be summed.") + info = "{:d}: ({:.0f} and {:.0f}).\n".format( + i_sample, event_classes[i_event - 1], event_classes[i_event] + ) + log.warning( + f"Same sample has at least two markers.\n{info}Marker codes will be summed." 
+ ) previous_i_sample = i_sample # Now create stim chan stim_chan = np.zeros_like(cnt.get_data()[0]) for i_sample, id_class in zip(event_times_in_samples, event_classes): stim_chan[i_sample] += id_class - info = mne.create_info(ch_names=['STI 014'], - sfreq=cnt.info['sfreq'], - ch_types=['stim']) + info = mne.create_info( + ch_names=['STI 014'], sfreq=cnt.info['sfreq'], ch_types=['stim'] + ) stim_cnt = mne.io.RawArray(stim_chan[None], info, verbose='WARNING') cnt = cnt.add_channels([stim_cnt]) - event_arr = [event_times_in_samples, - [0] * len(event_times_in_samples), - event_classes] + event_arr = [ + event_times_in_samples, + [0] * len(event_times_in_samples), + event_classes, + ] cnt.info['events'] = np.array(event_arr).T return cnt diff --git a/moabb/datasets/ssvep_exo.py b/moabb/datasets/ssvep_exo.py index 2e87d62b4..afcb51876 100644 --- a/moabb/datasets/ssvep_exo.py +++ b/moabb/datasets/ssvep_exo.py @@ -56,7 +56,8 @@ def __init__(self): code='SSVEP Exoskeleton', interval=[2, 4], paradigm='ssvep', - doi='10.1016/j.neucom.2016.01.007') + doi='10.1016/j.neucom.2016.01.007', + ) def _get_single_subject_data(self, subject): """Return the data of a single subject""" @@ -68,20 +69,20 @@ def _get_single_subject_data(self, subject): out['run_%d' % ii] = raw return {"session_0": out} - def data_path(self, subject, path=None, force_update=False, - update_path=None, verbose=None): + def data_path( + self, subject, path=None, force_update=False, update_path=None, verbose=None + ): - runs = {s + 1: n for s, - n in enumerate([2] * 6 + [3] + [2] * 2 + [4, 2, 5])} + runs = {s + 1: n for s, n in enumerate([2] * 6 + [3] + [2] * 2 + [4, 2, 5])} if subject not in self.subject_list: - raise(ValueError("Invalid subject number")) + raise (ValueError("Invalid subject number")) paths = [] for run in range(runs[subject]): - url = '{:s}subject{:02d}_run{:d}_raw.fif'.format(SSVEPEXO_URL, - subject, run + 1) - p = dl.data_path(url, 'SSVEPEXO', path, force_update, update_path, - verbose) + url = '{:s}subject{:02d}_run{:d}_raw.fif'.format( + SSVEPEXO_URL, subject, run + 1 + ) + p = dl.data_path(url, 'SSVEPEXO', path, force_update, update_path, verbose) paths.append(p) return paths diff --git a/moabb/datasets/ssvep_mamem.py b/moabb/datasets/ssvep_mamem.py index b2b639ad6..f05ec40d9 100644 --- a/moabb/datasets/ssvep_mamem.py +++ b/moabb/datasets/ssvep_mamem.py @@ -31,13 +31,20 @@ # MAMEM2_URL = 'https://ndownloader.figshare.com/articles/3153409/versions/2' # MAMEM3_URL = 'https://ndownloader.figshare.com/articles/3413851/versions/1' -MAMEM1_URL = "https://archive.physionet.org/physiobank/database/mssvepdb/dataset1/" # noqa: E501 -MAMEM2_URL = "https://archive.physionet.org/physiobank/database/mssvepdb/dataset2/" # noqa: E501 -MAMEM3_URL = "https://archive.physionet.org/physiobank/database/mssvepdb/dataset3/" # noqa: E501 +MAMEM1_URL = ( + "https://archive.physionet.org/physiobank/database/mssvepdb/dataset1/" # noqa: E501 +) +MAMEM2_URL = ( + "https://archive.physionet.org/physiobank/database/mssvepdb/dataset2/" # noqa: E501 +) +MAMEM3_URL = ( + "https://archive.physionet.org/physiobank/database/mssvepdb/dataset3/" # noqa: E501 +) class BaseMAMEM(BaseDataset): """Base class for MAMEM datasets""" + def __init__(self, sessions_per_subject, code, doi): super().__init__( subjects=list(range(1, 11)), @@ -46,7 +53,7 @@ def __init__(self, sessions_per_subject, code, doi): paradigm='ssvep', sessions_per_subject=sessions_per_subject, code=code, - doi=doi + doi=doi, ) def _get_single_subject_data(self, subject): @@ -73,8 +80,9 
@@ def _get_single_subject_data(self, subject): n_samples = record.sig_len stim_freq = np.array([float(e) for e in self.event_id.keys()]) # aux_note are the exact frequencies, matched to nearest class - events_label = [np.argmin(np.abs(stim_freq - float(f))) + 1 - for f in annots.aux_note] + events_label = [ + np.argmin(np.abs(stim_freq - float(f))) + 1 for f in annots.aux_note + ] raw_events = np.zeros([1, n_samples]) # annots.sample indicates the start of the trial # of class "events_label" @@ -106,8 +114,9 @@ def _get_single_subject_data(self, subject): sessions[session_name][run_name] = raw return sessions - def data_path(self, subject, path=None, force_update=False, - update_path=None, verbose=None): + def data_path( + self, subject, path=None, force_update=False, update_path=None, verbose=None + ): if subject not in self.subject_list: raise (ValueError("Invalid subject number")) # Check if the .dat, .hea and .win files are present @@ -129,8 +138,7 @@ def data_path(self, subject, path=None, force_update=False, key_dest = "MNE-{:s}-data".format(sign.lower()) path = _get_path(path, key, sign) path = os.path.join(path, key_dest) - s_paths = glob.glob(os.path.join( - path, fn.format(sub))) + s_paths = glob.glob(os.path.join(path, fn.format(sub))) subject_paths = [] for name in s_paths: subject_paths.append(os.path.splitext(name)[0]) @@ -142,8 +150,9 @@ def data_path(self, subject, path=None, force_update=False, for ele in datarec: if fn.format(sub) in ele: datalist.append(ele) - wfdb.io.dl_database("mssvepdb", path, datalist, - annotators="win", overwrite=force_update) + wfdb.io.dl_database( + "mssvepdb", path, datalist, annotators="win", overwrite=force_update + ) # Return the file paths depending on the number of sessions s_paths = glob.glob(os.path.join(path, fn.format(sub))) subject_paths = [] @@ -245,11 +254,14 @@ class MAMEM1(BaseMAMEM): [2] DataAcquisitionDetails.pdf on https://figshare.com/articles/dataset/MAMEM_EEG_SSVEP_Dataset_I_256_channels_11_subjects_5_frequencies_/2068677?file=3793738 # noqa: E501 """ + def __init__(self): - super().__init__(sessions_per_subject=3, - # 3 for S001, S003, S008, 4 for S004 - code="SSVEP MAMEM1", - doi="https://arxiv.org/abs/1602.00904") + super().__init__( + sessions_per_subject=3, + # 3 for S001, S003, S008, 4 for S004 + code="SSVEP MAMEM1", + doi="https://arxiv.org/abs/1602.00904", + ) class MAMEM2(BaseMAMEM): @@ -319,10 +331,13 @@ class MAMEM2(BaseMAMEM): [2] DataAcquisitionDetails.pdf on https://figshare.com/articles/dataset/MAMEM_EEG_SSVEP_Dataset_II_256_channels_11_subjects_5_frequencies_presented_simultaneously_/3153409?file=4911931 # noqa: E501 """ + def __init__(self): - super().__init__(sessions_per_subject=5, - code="SSVEP MAMEM2", - doi="https://arxiv.org/abs/1602.00904") + super().__init__( + sessions_per_subject=5, + code="SSVEP MAMEM2", + doi="https://arxiv.org/abs/1602.00904", + ) class MAMEM3(BaseMAMEM): @@ -401,7 +416,10 @@ class MAMEM3(BaseMAMEM): [2] DataAcquisitionDetails.pdf on https://figshare.com/articles/dataset/MAMEM_EEG_SSVEP_Dataset_III_14_channels_11_subjects_5_frequencies_presented_simultaneously_/3413851 # noqa: E501 """ + def __init__(self): - super().__init__(sessions_per_subject=5, - code="SSVEP MAMEM3", - doi="https://arxiv.org/abs/1602.00904") + super().__init__( + sessions_per_subject=5, + code="SSVEP MAMEM3", + doi="https://arxiv.org/abs/1602.00904", + ) diff --git a/moabb/datasets/ssvep_nakanishi.py b/moabb/datasets/ssvep_nakanishi.py index 09f0010dc..daa29f895 100644 --- 
a/moabb/datasets/ssvep_nakanishi.py +++ b/moabb/datasets/ssvep_nakanishi.py @@ -40,13 +40,25 @@ def __init__(self): super().__init__( subjects=list(range(1, 10)), sessions_per_subject=1, - events={'9.25': 1, '11.25': 2, '13.25': 3, '9.75': 4, '11.75': 5, - '13.75': 6, '10.25': 7, '12.25': 8, '14.25': 9, - '10.75': 10, '12.75': 11, '14.75': 12}, + events={ + '9.25': 1, + '11.25': 2, + '13.25': 3, + '9.75': 4, + '11.75': 5, + '13.75': 6, + '10.25': 7, + '12.25': 8, + '14.25': 9, + '10.75': 10, + '12.75': 11, + '14.75': 12, + }, code='SSVEP Nakanishi', interval=[0.15, 4.3], paradigm='ssvep', - doi='doi.org/10.1371/journal.pone.0140703') + doi='doi.org/10.1371/journal.pone.0140703', + ) def _get_single_subject_data(self, subject): """Return the data of a single subject""" @@ -59,30 +71,30 @@ def _get_single_subject_data(self, subject): data = np.reshape(data, newshape=(-1, n_channels, n_samples)) data = data - data.mean(axis=2, keepdims=True) raw_events = np.zeros((data.shape[0], 1, n_samples)) - raw_events[:, 0, 0] = np.array([n_trials * [i + 1] - for i in range(n_classes)]).flatten() + raw_events[:, 0, 0] = np.array( + [n_trials * [i + 1] for i in range(n_classes)] + ).flatten() data = np.concatenate([1e-6 * data, raw_events], axis=1) # add buffer in between trials - log.warning("Trial data de-meaned and concatenated with a buffer" - " to create continuous data") + log.warning( + "Trial data de-meaned and concatenated with a buffer" + " to create continuous data" + ) buff = (data.shape[0], n_channels + 1, 50) - data = np.concatenate([np.zeros(buff), data, - np.zeros(buff)], axis=2) - ch_names = ['PO7', 'PO3', 'POz', 'PO4', 'PO8', - 'O1', 'Oz', 'O2', 'stim'] + data = np.concatenate([np.zeros(buff), data, np.zeros(buff)], axis=2) + ch_names = ['PO7', 'PO3', 'POz', 'PO4', 'PO8', 'O1', 'Oz', 'O2', 'stim'] ch_types = ['eeg'] * 8 + ['stim'] sfreq = 256 info = create_info(ch_names, sfreq, ch_types) - raw = RawArray(data=np.concatenate(list(data), axis=1), - info=info, verbose=False) + raw = RawArray(data=np.concatenate(list(data), axis=1), info=info, verbose=False) montage = make_standard_montage('standard_1005') raw.set_montage(montage) return {'session_0': {'run_0': raw}} - def data_path(self, subject, path=None, force_update=False, - update_path=None, verbose=None): + def data_path( + self, subject, path=None, force_update=False, update_path=None, verbose=None + ): if subject not in self.subject_list: - raise(ValueError("Invalid subject number")) + raise (ValueError("Invalid subject number")) url = '{:s}s{:d}.mat'.format(NAKAHISHI_URL, subject) - return dl.data_path(url, 'NAKANISHI', path, force_update, update_path, - verbose) + return dl.data_path(url, 'NAKANISHI', path, force_update, update_path, verbose) diff --git a/moabb/datasets/ssvep_wang.py b/moabb/datasets/ssvep_wang.py index 4b1c8c468..e209d8b78 100644 --- a/moabb/datasets/ssvep_wang.py +++ b/moabb/datasets/ssvep_wang.py @@ -85,6 +85,7 @@ def __init__(self): super().__init__( subjects=list(range(1, 35)), sessions_per_subject=1, + # fmt: off events={'8': 1, '9': 2, '10': 3, '11': 4, '12': 5, '13': 6, '14': 7, '15': 8, '8.2': 9, '9.2': 10, '10.2': 11, '11.2': 12, '12.2': 13, '13.2': 14, '14.2': 15, '15.2': 16, @@ -94,10 +95,12 @@ def __init__(self): '14.6': 31, '15.6': 32, '8.8': 33, '9.8': 34, '10.8': 35, '11.8': 36, '12.8': 37, '13.8': 38, '14.8': 39, '15.8': 40}, + # fmt: on code='SSVEP Wang', interval=[0.5, 5.5], paradigm='ssvep', - doi='doi://10.1109/TNSRE.2016.2627556') + doi='doi://10.1109/TNSRE.2016.2627556', + ) def 
_get_single_subject_data(self, subject): """Return the data of a single subject""" @@ -112,16 +115,19 @@ def _get_single_subject_data(self, subject): data = np.reshape(data, newshape=(-1, n_channels, n_samples)) data = data - data.mean(axis=2, keepdims=True) raw_events = np.zeros((data.shape[0], 1, n_samples)) - raw_events[:, 0, 0] = np.array([n_trials * [i + 1] - for i in range(n_classes)]).flatten() + raw_events[:, 0, 0] = np.array( + [n_trials * [i + 1] for i in range(n_classes)] + ).flatten() data = np.concatenate([1e-6 * data, raw_events], axis=1) # add buffer in between trials - log.warning("Trial data de-meaned and concatenated with a buffer" - " to create continuous data") + log.warning( + "Trial data de-meaned and concatenated with a buffer" + " to create continuous data" + ) buff = (data.shape[0], n_channels + 1, 50) - data = np.concatenate([np.zeros(buff), data, - np.zeros(buff)], axis=2) + data = np.concatenate([np.zeros(buff), data, np.zeros(buff)], axis=2) + # fmt: off ch_names = ['Fp1', 'Fpz', 'Fp2', 'AF3', 'AF4', 'F7', 'F5', 'F3', 'F1', 'Fz', 'F2', 'F4', 'F6', 'F8', 'FT7', 'FC5', 'FC3', 'FC1', 'FCz', 'FC2', 'FC4', 'FC6', 'FT8', 'T7', 'C5', 'C3', 'C1', @@ -130,19 +136,19 @@ def _get_single_subject_data(self, subject): 'P3', 'P1', 'Pz', 'P2', 'P4', 'P6', 'P8', 'PO7', 'PO5', 'PO3', 'POz', 'PO4', 'PO6', 'PO8', 'CB1', 'O1', 'Oz', 'O2', 'CB2', 'stim'] + # fmt: on ch_types = ['eeg'] * 59 + ['misc'] + 3 * ['eeg'] + ['misc', 'stim'] sfreq = 250 info = create_info(ch_names, sfreq, ch_types) - raw = RawArray(data=np.concatenate(list(data), axis=1), - info=info, verbose=False) + raw = RawArray(data=np.concatenate(list(data), axis=1), info=info, verbose=False) montage = make_standard_montage('standard_1005') raw.set_montage(montage) return {'session_0': {'run_0': raw}} - def data_path(self, subject, path=None, force_update=False, - update_path=None, verbose=None): + def data_path( + self, subject, path=None, force_update=False, update_path=None, verbose=None + ): if subject not in self.subject_list: - raise(ValueError("Invalid subject number")) + raise (ValueError("Invalid subject number")) url = '{:s}s{:d}.rar'.format(WANG_URL, subject) - return dl.data_path(url, 'WANG', path, force_update, update_path, - verbose) + return dl.data_path(url, 'WANG', path, force_update, update_path, verbose) diff --git a/moabb/datasets/upper_limb.py b/moabb/datasets/upper_limb.py index 9246a2ebf..9dc63659d 100644 --- a/moabb/datasets/upper_limb.py +++ b/moabb/datasets/upper_limb.py @@ -59,13 +59,15 @@ class Ofner2017(BaseDataset): def __init__(self, imagined=True, executed=False): self.imagined = imagined self.executed = executed - event_id = {"right_elbow_flexion": 1536, - "right_elbow_extension": 1537, - "right_supination": 1538, - "right_pronation": 1539, - "right_hand_close": 1540, - "right_hand_open": 1541, - "rest": 1542} + event_id = { + "right_elbow_flexion": 1536, + "right_elbow_extension": 1537, + "right_supination": 1538, + "right_pronation": 1539, + "right_hand_close": 1540, + "right_hand_open": 1541, + "rest": 1542, + } n_sessions = int(imagined) + int(executed) super().__init__( @@ -75,7 +77,8 @@ def __init__(self, imagined=True, executed=False): code='Ofner2017', interval=[0, 3], # according to paper 2-5 paradigm='imagery', - doi='10.1371/journal.pone.0182578') + doi='10.1371/journal.pone.0182578', + ) def _get_single_subject_data(self, subject): """return data for a single subject""" @@ -95,8 +98,9 @@ def _get_single_subject_data(self, subject): montage = 
make_standard_montage('standard_1005') data = {} for ii, path in enumerate(paths): - raw = read_raw_gdf(path, eog=eog, misc=range(64, 96), - preload=True, verbose='ERROR') + raw = read_raw_gdf( + path, eog=eog, misc=range(64, 96), preload=True, verbose='ERROR' + ) raw.set_montage(montage) # there is nan in the data raw._data[np.isnan(raw._data)] = 0 @@ -115,10 +119,17 @@ def _get_single_subject_data(self, subject): out[session] = data return out - def data_path(self, subject, path=None, force_update=False, - update_path=None, verbose=None, session=None): + def data_path( + self, + subject, + path=None, + force_update=False, + update_path=None, + verbose=None, + session=None, + ): if subject not in self.subject_list: - raise(ValueError("Invalid subject number")) + raise (ValueError("Invalid subject number")) paths = [] @@ -135,9 +146,12 @@ def data_path(self, subject, path=None, force_update=False, # FIXME check the value are in V and not uV. for session in sessions: for run in range(1, 11): - url = (f"{UPPER_LIMB_URL}motor{session}_subject{subject}" + f"_run{run}.gdf") - p = dl.data_path(url, 'UPPERLIMB', path, force_update, - update_path, verbose) + url = ( + f"{UPPER_LIMB_URL}motor{session}_subject{subject}" + f"_run{run}.gdf" + ) + p = dl.data_path( + url, 'UPPERLIMB', path, force_update, update_path, verbose + ) paths.append(p) return paths diff --git a/moabb/datasets/utils.py b/moabb/datasets/utils.py index ffb2b3973..f595d93c5 100644 --- a/moabb/datasets/utils.py +++ b/moabb/datasets/utils.py @@ -14,10 +14,16 @@ dataset_list.append(ds[1]) -def dataset_search(paradigm, multi_session=False, events=None, # noqa: C901 - has_all_events=False, interval=None, - min_subjects=1, channels=()): - ''' +def dataset_search( + paradigm, + multi_session=False, + events=None, # noqa: C901 + has_all_events=False, + interval=None, + min_subjects=1, + channels=(), +): + """ Function that returns a list of datasets that match given criteria. Valid criteria are: @@ -46,7 +52,7 @@ def dataset_search(paradigm, multi_session=False, events=None, # noqa: C901 channels: list of str list or set of channels - ''' + """ channels = set(channels) out_data = [] if events is not None and has_all_events: @@ -97,13 +103,13 @@ def dataset_search(paradigm, multi_session=False, events=None, # noqa: C901 def find_intersecting_channels(datasets, verbose=False): - ''' + """ Given a list of dataset instances return a list of channels shared by all datasets. Skip datasets which have 0 overlap with the others returns: set of common channels, list of datasets with valid channels - ''' + """ allchans = set() dset_chans = [] keep_datasets = [] @@ -126,8 +132,9 @@ def find_intersecting_channels(datasets, verbose=False): dset_chans.append(processed) keep_datasets.append(d) else: - print('Dataset {:s} has no recognizable EEG channels'. - format(type(d).__name__)) # noqa + print( + 'Dataset {:s} has no recognizable EEG channels'.format(type(d).__name__) + ) # noqa allchans.intersection_update(*dset_chans) allchans = [s.replace('Z', 'z') for s in allchans] return allchans, keep_datasets diff --git a/moabb/evaluations/base.py b/moabb/evaluations/base.py index 06f46b11c..36884dd00 100644 --- a/moabb/evaluations/base.py +++ b/moabb/evaluations/base.py @@ -12,7 +12,7 @@ class BaseEvaluation(ABC): - '''Base class that defines necessary operations for an evaluation. + """Base class that defines necessary operations for an evaluation. 
Evaluations determine what the train and test sets are and can implement additional data preprocessing steps for more complicated algorithms. @@ -31,11 +31,20 @@ class BaseEvaluation(ABC): if true, overwrite the results. suffix: str suffix for the results file. - ''' - - def __init__(self, paradigm, datasets=None, random_state=None, n_jobs=1, - overwrite=False, error_score='raise', suffix='', - hdf5_path=None, additional_columns=None): + """ + + def __init__( + self, + paradigm, + datasets=None, + random_state=None, + n_jobs=1, + overwrite=False, + error_score='raise', + suffix='', + hdf5_path=None, + additional_columns=None, + ): self.random_state = random_state self.n_jobs = n_jobs self.error_score = error_score @@ -43,7 +52,7 @@ def __init__(self, paradigm, datasets=None, random_state=None, n_jobs=1, # check paradigm if not isinstance(paradigm, BaseParadigm): - raise(ValueError("paradigm must be an Paradigm instance")) + raise (ValueError("paradigm must be an Paradigm instance")) self.paradigm = paradigm # if no dataset provided, then we get the list from the paradigm @@ -54,43 +63,49 @@ def __init__(self, paradigm, datasets=None, random_state=None, n_jobs=1, if isinstance(datasets, BaseDataset): datasets = [datasets] else: - raise(ValueError("datasets must be a list or a dataset " - "instance")) + raise (ValueError("datasets must be a list or a dataset " "instance")) for dataset in datasets: - if not(isinstance(dataset, BaseDataset)): - raise(ValueError("datasets must only contains dataset " - "instance")) + if not (isinstance(dataset, BaseDataset)): + raise (ValueError("datasets must only contains dataset " "instance")) rm = [] for dataset in datasets: # fixme, we might want to drop dataset that are not compatible valid_for_paradigm = self.paradigm.is_valid(dataset) valid_for_eval = self.is_valid(dataset) if not valid_for_paradigm: - log.warning(f"{dataset} not compatible with " - "paradigm. Removing this dataset from the list.") + log.warning( + f"{dataset} not compatible with " + "paradigm. Removing this dataset from the list." + ) rm.append(dataset) elif not valid_for_eval: - log.warning(f"{dataset} not compatible with evaluation. " - "Removing this dataset from the list.") + log.warning( + f"{dataset} not compatible with evaluation. " + "Removing this dataset from the list." + ) rm.append(dataset) [datasets.remove(r) for r in rm] if len(datasets) > 0: self.datasets = datasets else: - raise Exception('''No datasets left after paradigm - and evaluation checks''') - - self.results = Results(type(self), - type(self.paradigm), - overwrite=overwrite, - suffix=suffix, - hdf5_path=self.hdf5_path, - additional_columns=additional_columns) + raise Exception( + '''No datasets left after paradigm + and evaluation checks''' + ) + + self.results = Results( + type(self), + type(self.paradigm), + overwrite=overwrite, + suffix=suffix, + hdf5_path=self.hdf5_path, + additional_columns=additional_columns, + ) def process(self, pipelines): - '''Runs all pipelines on all datasets. + """Runs all pipelines on all datasets. This function will apply all provided pipelines and return a dataframe containing the results of the evaluation. @@ -105,16 +120,15 @@ def process(self, pipelines): results: pd.DataFrame A dataframe containing the results. 
-        '''
+        """
         # check pipelines
         if not isinstance(pipelines, dict):
-            raise(ValueError("pipelines must be a dict"))
+            raise (ValueError("pipelines must be a dict"))
         for _, pipeline in pipelines.items():
-            if not(isinstance(pipeline, BaseEstimator)):
-                raise(ValueError("pipelines must only contains Pipelines "
-                                 "instance"))
+            if not (isinstance(pipeline, BaseEstimator)):
+                raise (ValueError("pipelines must only contains Pipelines " "instance"))
         for dataset in self.datasets:
             log.info('Processing dataset: {}'.format(dataset.code))
@@ -126,8 +140,9 @@ def process(self, pipelines):
     def push_result(self, res, pipelines):
         message = '{} | '.format(res['pipeline'])
-        message += '{} | {} | {}'.format(res['dataset'].code,
-                                         res['subject'], res['session'])
+        message += '{} | {} | {}'.format(
+            res['dataset'].code, res['subject'], res['session']
+        )
         message += ': Score %.3f' % res['score']
         log.info(message)
         self.results.add({res['pipeline']: res}, pipelines=pipelines)
@@ -137,7 +152,7 @@ def get_results(self):
     @abstractmethod
     def evaluate(self, dataset, pipelines):
-        '''Evaluate results on a single dataset.
+        """Evaluate results on a single dataset.
        This method returns a generator. Each result item is a dict with
        the following convention::
@@ -150,7 +165,7 @@ def evaluate(self, dataset, pipelines):
                'n_samples': number of training examples,
                'n_channels': number of channels,
                'pipeline': pipeline name}
-        '''
+        """
        pass
     @abstractmethod
diff --git a/moabb/evaluations/evaluations.py b/moabb/evaluations/evaluations.py
index af0b32c0c..71a4d49da 100644
--- a/moabb/evaluations/evaluations.py
+++ b/moabb/evaluations/evaluations.py
@@ -26,9 +26,7 @@ def evaluate(self, dataset, pipelines):
         for subject in dataset.subject_list:
             # check if we already have result for this subject/pipeline
             # we might need a better granularity, if we query the DB
-            run_pipes = self.results.not_yet_computed(pipelines,
-                                                      dataset,
-                                                      subject)
+            run_pipes = self.results.not_yet_computed(pipelines, dataset, subject)
             if len(run_pipes) == 0:
                 continue
@@ -42,17 +40,18 @@ def evaluate(self, dataset, pipelines):
                 for name, clf in run_pipes.items():
                     t_start = time()
-                    score = self.score(clf, X[ix], y[ix],
-                                       self.paradigm.scoring)
+                    score = self.score(clf, X[ix], y[ix], self.paradigm.scoring)
                     duration = time() - t_start
-                    res = {'time': duration / 5.,  # 5 fold CV
-                           'dataset': dataset,
-                           'subject': subject,
-                           'session': session,
-                           'score': score,
-                           'n_samples': len(y[ix]),  # not training sample
-                           'n_channels': X.shape[1],
-                           'pipeline': name}
+                    res = {
+                        'time': duration / 5.0,  # 5 fold CV
+                        'dataset': dataset,
+                        'subject': subject,
+                        'session': session,
+                        'score': score,
+                        'n_samples': len(y[ix]),  # not training sample
+                        'n_channels': X.shape[1],
+                        'pipeline': name,
+                    }
                     yield res
@@ -61,8 +60,15 @@ def score(self, clf, X, y, scoring):
         le = LabelEncoder()
         y = le.fit_transform(y)
-        acc = cross_val_score(clf, X, y, cv=cv, scoring=scoring,
-                              n_jobs=self.n_jobs, error_score=self.error_score)
+        acc = cross_val_score(
+            clf,
+            X,
+            y,
+            cv=cv,
+            scoring=scoring,
+            n_jobs=self.n_jobs,
+            error_score=self.error_score,
+        )
         return acc.mean()
     def is_valid(self, dataset):
@@ -84,9 +90,7 @@ def evaluate(self, dataset, pipelines):
         for subject in dataset.subject_list:
             # check if we already have result for this subject/pipeline
             # we might need a better granularity, if we query the DB
-            run_pipes = self.results.not_yet_computed(pipelines,
-                                                      dataset,
-                                                      subject)
+            run_pipes = self.results.not_yet_computed(pipelines, dataset, subject)
             if len(run_pipes) == 0:
                 continue
@@
-103,24 +107,33 @@ def evaluate(self, dataset, pipelines): cv = LeaveOneGroupOut() for train, test in cv.split(X, y, groups): t_start = time() - score = _fit_and_score(clone(clf), X, y, scorer, train, - test, verbose=False, - parameters=None, - fit_params=None, - error_score=self.error_score)[0] + score = _fit_and_score( + clone(clf), + X, + y, + scorer, + train, + test, + verbose=False, + parameters=None, + fit_params=None, + error_score=self.error_score, + )[0] duration = time() - t_start - res = {'time': duration, - 'dataset': dataset, - 'subject': subject, - 'session': groups[test][0], - 'score': score, - 'n_samples': len(train), - 'n_channels': X.shape[1], - 'pipeline': name} + res = { + 'time': duration, + 'dataset': dataset, + 'subject': subject, + 'session': groups[test][0], + 'score': score, + 'n_samples': len(train), + 'n_channels': X.shape[1], + 'pipeline': name, + } yield res def is_valid(self, dataset): - return (dataset.n_sessions > 1) + return dataset.n_sessions > 1 class CrossSubjectEvaluation(BaseEvaluation): @@ -140,9 +153,7 @@ def evaluate(self, dataset, pipelines): # we might need a better granularity, if we query the DB run_pipes = {} for subject in dataset.subject_list: - run_pipes.update(self.results.not_yet_computed(pipelines, - dataset, - subject)) + run_pipes.update(self.results.not_yet_computed(pipelines, dataset, subject)) if len(run_pipes) != 0: # get the data @@ -164,8 +175,7 @@ def evaluate(self, dataset, pipelines): subject = groups[test[0]] # now we can check if this subject has results - run_pipes = self.results.not_yet_computed(pipelines, dataset, - subject) + run_pipes = self.results.not_yet_computed(pipelines, dataset, subject) # iterate over pipelines for name, clf in run_pipes.items(): @@ -178,16 +188,18 @@ def evaluate(self, dataset, pipelines): ix = sessions[test] == session score = _score(model, X[test[ix]], y[test[ix]], scorer) - res = {'time': duration, - 'dataset': dataset, - 'subject': subject, - 'session': session, - 'score': score, - 'n_samples': len(train), - 'n_channels': X.shape[1], - 'pipeline': name} + res = { + 'time': duration, + 'dataset': dataset, + 'subject': subject, + 'session': session, + 'score': score, + 'n_samples': len(train), + 'n_channels': X.shape[1], + 'pipeline': name, + } yield res def is_valid(self, dataset): - return (len(dataset.subject_list) > 1) + return len(dataset.subject_list) > 1 diff --git a/moabb/paradigms/base.py b/moabb/paradigms/base.py index ead4f9cd3..0967791c0 100644 --- a/moabb/paradigms/base.py +++ b/moabb/paradigms/base.py @@ -10,26 +10,23 @@ class BaseParadigm(metaclass=ABCMeta): - """Base Paradigm. - """ + """Base Paradigm.""" def __init__(self): pass @abstractproperty def scoring(self): - '''Property that defines scoring metric (e.g. ROC-AUC or accuracy + """Property that defines scoring metric (e.g. ROC-AUC or accuracy or f-score), given as a sklearn-compatible string or a compatible sklearn scorer. 
-        '''
+        """
        pass
    @abstractproperty
    def datasets(self):
-        '''Property that defines the list of compatible datasets
-
-        '''
+        """Property that defines the list of compatible datasets"""
        pass
    @abstractmethod
@@ -105,18 +102,16 @@ def process_raw(self, raw, dataset, return_epochs=False):  # noqa: C901
        event_id = self.used_events(dataset)
        # find the events, first check stim_channels then annotations
-        stim_channels = mne.utils._get_stim_channel(None, raw.info,
-                                                    raise_error=False)
+        stim_channels = mne.utils._get_stim_channel(None, raw.info, raise_error=False)
        if len(stim_channels) > 0:
            events = mne.find_events(raw, shortest_event=0, verbose=False)
        else:
            try:
-                events, _ = mne.events_from_annotations(raw,
-                                                        event_id=event_id,
-                                                        verbose=False)
+                events, _ = mne.events_from_annotations(
+                    raw, event_id=event_id, verbose=False
+                )
            except ValueError:
-                log.warning("No matching annotations in {}"
-                            .format(raw.filenames))
+                log.warning("No matching annotations in {}".format(raw.filenames))
                return
        # picks channels
@@ -143,24 +138,35 @@ def process_raw(self, raw, dataset, return_epochs=False):  # noqa: C901
        for bandpass in self.filters:
            fmin, fmax = bandpass
            # filter data
-            raw_f = raw.copy().filter(fmin, fmax, method='iir',
-                                      picks=picks, verbose=False)
+            raw_f = raw.copy().filter(
+                fmin, fmax, method='iir', picks=picks, verbose=False
+            )
            # epoch data
            baseline = self.baseline
            if baseline is not None:
-                baseline = (self.baseline[0] + dataset.interval[0],
-                            self.baseline[1] + dataset.interval[0])
+                baseline = (
+                    self.baseline[0] + dataset.interval[0],
+                    self.baseline[1] + dataset.interval[0],
+                )
                bmin = baseline[0] if baseline[0] < tmin else tmin
                bmax = baseline[1] if baseline[1] > tmax else tmax
            else:
                bmin = tmin
                bmax = tmax
-            epochs = mne.Epochs(raw_f, events, event_id=event_id,
-                                tmin=bmin, tmax=bmax, proj=False,
-                                baseline=baseline, preload=True,
-                                verbose=False, picks=picks,
-                                event_repeated='drop',
-                                on_missing='ignore')
+            epochs = mne.Epochs(
+                raw_f,
+                events,
+                event_id=event_id,
+                tmin=bmin,
+                tmax=bmax,
+                proj=False,
+                baseline=baseline,
+                preload=True,
+                verbose=False,
+                picks=picks,
+                event_repeated='drop',
+                on_missing='ignore',
+            )
            if bmin < tmin or bmax > tmax:
                epochs.crop(tmin=tmin, tmax=tmax)
            if self.resample is not None:
@@ -217,8 +223,7 @@ def get_data(self, dataset, subjects=None, return_epochs=False):
        """
        if not self.is_valid(dataset):
-            message = "Dataset {} is not valid for paradigm".format(
-                dataset.code)
+            message = "Dataset {} is not valid for paradigm".format(dataset.code)
            raise AssertionError(message)
        data = dataset.get_data(subjects)
@@ -230,8 +235,7 @@ def get_data(self, dataset, subjects=None, return_epochs=False):
        for subject, sessions in data.items():
            for session, runs in sessions.items():
                for run, raw in runs.items():
-                    proc = self.process_raw(raw, dataset,
-                                            return_epochs=return_epochs)
+                    proc = self.process_raw(raw, dataset, return_epochs=return_epochs)
                    if proc is None:
                        # this means the run did not contain any selected event
diff --git a/moabb/paradigms/motor_imagery.py b/moabb/paradigms/motor_imagery.py
index 2673eccc6..8c42cbeb3 100644
--- a/moabb/paradigms/motor_imagery.py
+++ b/moabb/paradigms/motor_imagery.py
@@ -53,8 +53,16 @@ class BaseMotorImagery(BaseParadigm):
        If not None, resample the eeg data with the sampling rate provided.
""" - def __init__(self, filters=([7, 35],), events=None, tmin=0.0, tmax=None, - baseline=None, channels=None, resample=None): + def __init__( + self, + filters=([7, 35],), + events=None, + tmin=0.0, + tmax=None, + baseline=None, + channels=None, + resample=None, + ): super().__init__() self.filters = filters self.events = events @@ -62,9 +70,9 @@ def __init__(self, filters=([7, 35],), events=None, tmin=0.0, tmax=None, self.baseline = baseline self.resample = resample - if (tmax is not None): + if tmax is not None: if tmin >= tmax: - raise(ValueError("tmax must be greater than tmin")) + raise (ValueError("tmax must be greater than tmin")) self.tmin = tmin self.tmax = tmax @@ -92,10 +100,9 @@ def datasets(self): interval = None else: interval = self.tmax - self.tmin - return utils.dataset_search(paradigm='imagery', - events=self.events, - interval=interval, - has_all_events=True) + return utils.dataset_search( + paradigm='imagery', events=self.events, interval=interval, has_all_events=True + ) @property def scoring(self): @@ -149,15 +156,18 @@ class SinglePass(BaseMotorImagery): def __init__(self, fmin=8, fmax=32, **kwargs): if 'filters' in kwargs.keys(): - raise(ValueError("MotorImagery does not take argument filters")) + raise (ValueError("MotorImagery does not take argument filters")) super().__init__(filters=[[fmin, fmax]], **kwargs) class FilterBank(BaseMotorImagery): """Filter Bank MI.""" - def __init__(self, filters=([8, 12], [12, 16], [16, 20], [20, 24], - [24, 28], [28, 32]), **kwargs): + def __init__( + self, + filters=([8, 12], [12, 16], [16, 20], [20, 24], [24, 28], [28, 32]), + **kwargs, + ): """init""" super().__init__(filters=filters, **kwargs) @@ -171,7 +181,7 @@ class LeftRightImagery(SinglePass): def __init__(self, **kwargs): if 'events' in kwargs.keys(): - raise(ValueError('LeftRightImagery dont accept events')) + raise (ValueError('LeftRightImagery dont accept events')) super().__init__(events=['left_hand', 'right_hand'], **kwargs) def used_events(self, dataset): @@ -191,7 +201,7 @@ class FilterBankLeftRightImagery(FilterBank): def __init__(self, **kwargs): if 'events' in kwargs.keys(): - raise(ValueError('LeftRightImagery dont accept events')) + raise (ValueError('LeftRightImagery dont accept events')) super().__init__(events=['left_hand', 'right_hand'], **kwargs) def used_events(self, dataset): @@ -228,8 +238,7 @@ def __init__(self, n_classes=2, **kwargs): if self.events is None: log.warning("Choosing from all possible events") else: - assert n_classes <= len( - self.events), 'More classes than events specified' + assert n_classes <= len(self.events), 'More classes than events specified' def is_valid(self, dataset): ret = True @@ -258,8 +267,12 @@ def used_events(self, dataset): if len(out) == self.n_classes: break if len(out) < self.n_classes: - raise(ValueError(f"Dataset {dataset.code} did not have enough " - f"events in {self.events} to run analysis")) + raise ( + ValueError( + f"Dataset {dataset.code} did not have enough " + f"events in {self.events} to run analysis" + ) + ) return out @property @@ -268,11 +281,13 @@ def datasets(self): interval = None else: interval = self.tmax - self.tmin - return utils.dataset_search(paradigm='imagery', - events=self.events, - total_classes=self.n_classes, - interval=interval, - has_all_events=False) + return utils.dataset_search( + paradigm='imagery', + events=self.events, + total_classes=self.n_classes, + interval=interval, + has_all_events=False, + ) @property def scoring(self): @@ -339,8 +354,7 @@ def __init__(self, 
n_classes=2, **kwargs): if self.events is None: log.warning("Choosing from all possible events") else: - assert n_classes <= len( - self.events), 'More classes than events specified' + assert n_classes <= len(self.events), 'More classes than events specified' def is_valid(self, dataset): ret = True @@ -369,8 +383,12 @@ def used_events(self, dataset): if len(out) == self.n_classes: break if len(out) < self.n_classes: - raise(ValueError(f"Dataset {dataset.code} did not have enough " - f"events in {self.events} to run analysis")) + raise ( + ValueError( + f"Dataset {dataset.code} did not have enough " + f"events in {self.events} to run analysis" + ) + ) return out @property @@ -379,10 +397,12 @@ def datasets(self): interval = None else: interval = self.tmax - self.tmin - return utils.dataset_search(paradigm='imagery', - events=self.events, - interval=interval, - has_all_events=False) + return utils.dataset_search( + paradigm='imagery', + events=self.events, + interval=interval, + has_all_events=False, + ) @property def scoring(self): @@ -393,8 +413,7 @@ def scoring(self): class FakeImageryParadigm(LeftRightImagery): - """Fake Imagery for left hand/right hand classification. - """ + """Fake Imagery for left hand/right hand classification.""" @property def datasets(self): diff --git a/moabb/paradigms/p300.py b/moabb/paradigms/p300.py index 71129e6d2..9ef5928d9 100644 --- a/moabb/paradigms/p300.py +++ b/moabb/paradigms/p300.py @@ -57,8 +57,16 @@ class BaseP300(BaseParadigm): If not None, resample the eeg data with the sampling rate provided. """ - def __init__(self, filters=([1, 24],), events=None, tmin=0.0, tmax=None, - baseline=None, channels=None, resample=None): + def __init__( + self, + filters=([1, 24],), + events=None, + tmin=0.0, + tmax=None, + baseline=None, + channels=None, + resample=None, + ): super().__init__() self.filters = filters self.events = events @@ -66,9 +74,9 @@ def __init__(self, filters=([1, 24],), events=None, tmin=0.0, tmax=None, self.baseline = baseline self.resample = resample - if (tmax is not None): + if tmax is not None: if tmin >= tmax: - raise(ValueError("tmax must be greater than tmin")) + raise (ValueError("tmax must be greater than tmin")) self.tmin = tmin self.tmax = tmax @@ -92,8 +100,7 @@ def used_events(self, dataset): def process_raw(self, raw, dataset, return_epochs=False): # find the events, first check stim_channels then annotations - stim_channels = mne.utils._get_stim_channel( - None, raw.info, raise_error=False) + stim_channels = mne.utils._get_stim_channel(None, raw.info, raise_error=False) if len(stim_channels) > 0: events = mne.find_events(raw, shortest_event=0, verbose=False) else: @@ -102,15 +109,14 @@ def process_raw(self, raw, dataset, return_epochs=False): channels = () if self.channels is None else self.channels # picks channels - picks = mne.pick_types(raw.info, eeg=True, stim=False, - include=channels) + picks = mne.pick_types(raw.info, eeg=True, stim=False, include=channels) # get event id event_id = self.used_events(dataset) # pick events, based on event_id try: - if (type(event_id['Target']) is list and type(event_id['NonTarget']) == list): + if type(event_id['Target']) is list and type(event_id['NonTarget']) == list: event_id_new = dict(Target=1, NonTarget=0) events = mne.merge_events(events, event_id['Target'], 1) events = mne.merge_events(events, event_id['NonTarget'], 0) @@ -131,23 +137,34 @@ def process_raw(self, raw, dataset, return_epochs=False): for bandpass in self.filters: fmin, fmax = bandpass # filter data - raw_f = 
raw.copy().filter(fmin, fmax, method='iir', - picks=picks, verbose=False) + raw_f = raw.copy().filter( + fmin, fmax, method='iir', picks=picks, verbose=False + ) # epoch data baseline = self.baseline if baseline is not None: - baseline = (self.baseline[0] + dataset.interval[0], - self.baseline[1] + dataset.interval[0]) + baseline = ( + self.baseline[0] + dataset.interval[0], + self.baseline[1] + dataset.interval[0], + ) bmin = baseline[0] if baseline[0] < tmin else tmin bmax = baseline[1] if baseline[1] > tmax else tmax else: bmin = tmin bmax = tmax - epochs = mne.Epochs(raw_f, events, event_id=event_id, - tmin=bmin, tmax=bmax, proj=False, - baseline=baseline, preload=True, - verbose=False, picks=picks, - on_missing='ignore') + epochs = mne.Epochs( + raw_f, + events, + event_id=event_id, + tmin=bmin, + tmax=bmax, + proj=False, + baseline=baseline, + preload=True, + verbose=False, + picks=picks, + on_missing='ignore', + ) if bmin < tmin or bmax > tmax: epochs.crop(tmin=tmin, tmax=tmax) if self.resample is not None: @@ -176,10 +193,9 @@ def datasets(self): interval = None else: interval = self.tmax - self.tmin - return utils.dataset_search(paradigm='p300', - events=self.events, - interval=interval, - has_all_events=True) + return utils.dataset_search( + paradigm='p300', events=self.events, interval=interval, has_all_events=True + ) @property def scoring(self): @@ -230,9 +246,10 @@ class SinglePass(BaseP300): If not None, resample the eeg data with the sampling rate provided. """ + def __init__(self, fmin=1, fmax=24, **kwargs): if 'filters' in kwargs.keys(): - raise(ValueError("P300 does not take argument filters")) + raise (ValueError("P300 does not take argument filters")) super().__init__(filters=[[fmin, fmax]], **kwargs) @@ -245,7 +262,7 @@ class P300(SinglePass): def __init__(self, **kwargs): if 'events' in kwargs.keys(): - raise(ValueError('P300 dont accept events')) + raise (ValueError('P300 dont accept events')) super().__init__(events=['Target', 'NonTarget'], **kwargs) def used_events(self, dataset): @@ -257,8 +274,7 @@ def scoring(self): class FakeP300Paradigm(P300): - """Fake P300 for Target/NonTarget classification. - """ + """Fake P300 for Target/NonTarget classification.""" @property def datasets(self): diff --git a/moabb/paradigms/ssvep.py b/moabb/paradigms/ssvep.py index cff83601a..d584ad6dc 100644 --- a/moabb/paradigms/ssvep.py +++ b/moabb/paradigms/ssvep.py @@ -52,9 +52,17 @@ class BaseSSVEP(BaseParadigm): If not None, resample the eeg data with the sampling rate provided. 
""" - def __init__(self, filters=((7, 45), ), events=None, n_classes=None, - tmin=0.0, tmax=None, baseline=None, channels=None, - resample=None): + def __init__( + self, + filters=((7, 45),), + events=None, + n_classes=None, + tmin=0.0, + tmax=None, + baseline=None, + channels=None, + resample=None, + ): super().__init__() self.filters = filters self.events = events @@ -64,16 +72,19 @@ def __init__(self, filters=((7, 45), ), events=None, n_classes=None, self.resample = resample if tmax is not None and tmin >= tmax: - raise(ValueError("tmax must be greater than tmin")) + raise (ValueError("tmax must be greater than tmin")) self.tmin = tmin self.tmax = tmax if self.events is None: - log.warning("Choosing the first " + str(n_classes) + " classes" - + " from all possible events") + log.warning( + "Choosing the first " + + str(n_classes) + + " classes" + + " from all possible events" + ) else: - assert n_classes <= len( - self.events), 'More classes than events specified' + assert n_classes <= len(self.events), 'More classes than events specified' def is_valid(self, dataset): ret = True @@ -101,8 +112,12 @@ def used_events(self, dataset): if self.n_classes and len(out) == self.n_classes: break if self.n_classes and len(out) < self.n_classes: - raise(ValueError(f"Dataset {dataset.code} did not have enough " - f"freqs in {self.events} to run analysis")) + raise ( + ValueError( + f"Dataset {dataset.code} did not have enough " + f"freqs in {self.events} to run analysis" + ) + ) return out def prepare_process(self, dataset): @@ -110,9 +125,11 @@ def prepare_process(self, dataset): # get filters if self.filters is None: - self.filters = [[float(f) - 0.5, float(f) + 0.5] - for f in event_id.keys() - if f.replace('.', '', 1).isnumeric()] + self.filters = [ + [float(f) - 0.5, float(f) + 0.5] + for f in event_id.keys() + if f.replace('.', '', 1).isnumeric() + ] @property def datasets(self): @@ -120,11 +137,13 @@ def datasets(self): interval = None else: interval = self.tmax - self.tmin - return utils.dataset_search(paradigm='ssvep', - events=self.events, - # total_classes=self.n_classes, - interval=interval, - has_all_events=True) + return utils.dataset_search( + paradigm='ssvep', + events=self.events, + # total_classes=self.n_classes, + interval=interval, + has_all_events=True, + ) @property def scoring(self): @@ -184,12 +203,12 @@ class SSVEP(BaseSSVEP): def __init__(self, fmin=7, fmax=45, **kwargs): if 'filters' in kwargs.keys(): - raise(ValueError("SSVEP does not take argument filters")) + raise (ValueError("SSVEP does not take argument filters")) super().__init__(filters=[(fmin, fmax)], **kwargs) class FilterBankSSVEP(BaseSSVEP): - """ Filtered bank n-class SSVEP paradigm + """Filtered bank n-class SSVEP paradigm SSVEP paradigm with multiple narrow bandpass filters, centered around the frequencies of considered events. @@ -240,8 +259,7 @@ def __init__(self, filters=None, **kwargs): class FakeSSVEPParadigm(BaseSSVEP): - """Fake SSVEP classification. - """ + """Fake SSVEP classification.""" @property def datasets(self): diff --git a/moabb/pipelines/classification.py b/moabb/pipelines/classification.py index cb4d1efb0..02406e4b9 100644 --- a/moabb/pipelines/classification.py +++ b/moabb/pipelines/classification.py @@ -20,6 +20,7 @@ class SSVEP_CCA(BaseEstimator, ClassifierMixin): engineering, 6(4), 046002. 
https://doi.org/10.1088/1741-2560/6/4/046002 """ + def __init__(self, interval, freqs, n_harmonics=3): self.Yf = dict() self.cca = CCA(n_components=1) @@ -43,8 +44,12 @@ def fit(self, X, y, sample_weight=None): freq = float(f) yf = [] for h in range(1, self.n_harmonics + 1): - yf.append(np.sin(2 * np.pi * freq * h * np.linspace(0, self.slen, n_times))) - yf.append(np.cos(2 * np.pi * freq * h * np.linspace(0, self.slen, n_times))) + yf.append( + np.sin(2 * np.pi * freq * h * np.linspace(0, self.slen, n_times)) + ) + yf.append( + np.cos(2 * np.pi * freq * h * np.linspace(0, self.slen, n_times)) + ) self.Yf[f] = np.array(yf) return self diff --git a/moabb/pipelines/csp.py b/moabb/pipelines/csp.py index b62396d8d..ae2313507 100644 --- a/moabb/pipelines/csp.py +++ b/moabb/pipelines/csp.py @@ -5,19 +5,19 @@ class TRCSP(CSP): - ''' + """ Weighted Tikhonov-regularized CSP as described in Lotte and Guan 2011 - ''' + """ def __init__(self, nfilter=4, metric='euclid', log=True, alpha=1): super().__init__(nfilter, metric, log) self.alpha = alpha def fit(self, X, y): - ''' + """ Train spatial filters. Only deals with two class - ''' + """ if not isinstance(X, (np.ndarray, list)): raise TypeError('X must be an array.') @@ -56,8 +56,8 @@ def fit(self, X, y): evecs[i] = evecs[i][:, ix] # spatial patterns A = np.linalg.pinv(evecs[i].T) - filters.append(evecs[i][:, :(self.nfilter // 2)]) - patterns.append(A[:, :(self.nfilter // 2)]) + filters.append(evecs[i][:, : (self.nfilter // 2)]) + patterns.append(A[:, : (self.nfilter // 2)]) self.filters_ = np.concatenate(filters, axis=1).T self.patterns_ = np.concatenate(patterns, axis=1).T diff --git a/moabb/pipelines/features.py b/moabb/pipelines/features.py index f4d5f85b6..03e9d0117 100644 --- a/moabb/pipelines/features.py +++ b/moabb/pipelines/features.py @@ -4,7 +4,6 @@ class LogVariance(BaseEstimator, TransformerMixin): - def fit(self, X, y): """fit.""" return self @@ -16,14 +15,13 @@ def transform(self, X): class FM(BaseEstimator, TransformerMixin): - def __init__(self, freq=128): - '''instantaneous frequencies require a sampling frequency to be properly + """instantaneous frequencies require a sampling frequency to be properly scaled, which is helpful for some algorithms. This assumes 128 if not told otherwise. - ''' + """ self.freq = freq def fit(self, X, y): @@ -33,8 +31,7 @@ def fit(self, X, y): def transform(self, X): """transform. """ xphase = np.unwrap(np.angle(signal.hilbert(X, axis=-1))) - return np.median(self.freq * np.diff(xphase, axis=-1) / (2 * np.pi), - axis=-1) + return np.median(self.freq * np.diff(xphase, axis=-1) / (2 * np.pi), axis=-1) class ExtendedSSVEPSignal(BaseEstimator, TransformerMixin): @@ -47,6 +44,7 @@ class ExtendedSSVEPSignal(BaseEstimator, TransformerMixin): and should be convert in (n_trials, n_channels*n_freqs, n_times) to estimate covariance matrices of (n_channels*n_freqs, n_channels*n_freqs). 
""" + def __init__(self): """Empty init for ExtendedSSVEPSignal""" pass @@ -56,8 +54,7 @@ def fit(self, X, y): return self def transform(self, X): - """Transpose and reshape EEG for extended covmat estimation - """ + """Transpose and reshape EEG for extended covmat estimation""" out = X.transpose((0, 3, 1, 2)) n_trials, n_freqs, n_channels, n_times = out.shape out = out.reshape((n_trials, n_channels * n_freqs, n_times)) diff --git a/moabb/pipelines/utils.py b/moabb/pipelines/utils.py index e0c51ffe9..eb52727c6 100644 --- a/moabb/pipelines/utils.py +++ b/moabb/pipelines/utils.py @@ -65,16 +65,17 @@ def __init__(self, estimator, flatten=True): def fit(self, X, y=None): assert X.ndim == 4 self.models = [ - deepcopy(self.estimator).fit(X[..., i], y) - for i in range(X.shape[-1]) + deepcopy(self.estimator).fit(X[..., i], y) for i in range(X.shape[-1]) ] return self def transform(self, X): assert X.ndim == 4 out = [self.models[i].transform(X[..., i]) for i in range(X.shape[-1])] - assert out[0].ndim == 2, ("Each band must return a two dimensional " - f" matrix, currently have {out[0].ndim}") + assert out[0].ndim == 2, ( + "Each band must return a two dimensional " + f" matrix, currently have {out[0].ndim}" + ) if self.flatten: return np.concatenate(out, axis=1) else: @@ -84,4 +85,5 @@ def __repr__(self): estimator_name = type(self).__name__ estimator_prms = self.estimator.get_params() return '{}(estimator={}, flatten={})'.format( - estimator_name, estimator_prms, self.flatten) + estimator_name, estimator_prms, self.flatten + ) diff --git a/moabb/run.py b/moabb/run.py index 3dcba4f03..f7f6e3f81 100755 --- a/moabb/run.py +++ b/moabb/run.py @@ -34,54 +34,54 @@ def parser_init(): dest="pipelines", type=str, default='./pipelines/', - help="Folder containing the pipelines to evaluates.") + help="Folder containing the pipelines to evaluates.", + ) parser.add_argument( "-r", "--results", dest="results", type=str, default='./results/', - help="Folder to store the results.") + help="Folder to store the results.", + ) parser.add_argument( "-f", "--force-update", dest="force", action="store_true", default=False, - help="Force evaluation of cached pipelines.") + help="Force evaluation of cached pipelines.", + ) parser.add_argument( - "-v", - "--verbose", - dest="verbose", - action="store_true", - default=False) + "-v", "--verbose", dest="verbose", action="store_true", default=False + ) parser.add_argument( "-d", "--debug", dest="debug", action="store_true", default=False, - help="Print debug level parse statements. Overrides verbose") + help="Print debug level parse statements. Overrides verbose", + ) parser.add_argument( "-o", "--output", dest="output", type=str, default='./', - help="Folder to put analysis results") + help="Folder to put analysis results", + ) parser.add_argument( - "--threads", - dest="threads", - type=int, - default=1, - help="Number of threads to run") + "--threads", dest="threads", type=int, default=1, help="Number of threads to run" + ) parser.add_argument( "--plot", dest="plot", action="store_true", default=False, - help="Plot results after computing. Defaults false") + help="Plot results after computing. Defaults false", + ) parser.add_argument( "-c", "--contexts", @@ -90,20 +90,22 @@ def parser_init(): default=None, help="File path to context.yml file that describes context parameters." "If none, assumes all defaults. 
Must contain an entry for all " - "paradigms described in the pipelines") + "paradigms described in the pipelines", + ) return parser def parse_pipelines_from_directory(d): - ''' + """ Given directory, returns generated pipeline config dictionaries. Each entry has structure: 'name': string 'pipeline': sklearn.BaseEstimator 'paradigms': list of class names that are compatible with said pipeline - ''' - assert os.path.isdir(os.path.abspath(d) - ), "Given pipeline path {} is not valid".format(d) + """ + assert os.path.isdir( + os.path.abspath(d) + ), "Given pipeline path {} is not valid".format(d) # get list of config files yaml_files = glob(os.path.join(d, '*.yml')) @@ -116,9 +118,13 @@ def parse_pipelines_from_directory(d): # load config config_dict = yaml.load(content, Loader=yaml.FullLoader) ppl = create_pipeline_from_config(config_dict['pipeline']) - pipeline_configs.append({'paradigms': config_dict['paradigms'], - 'pipeline': ppl, - 'name': config_dict['name']}) + pipeline_configs.append( + { + 'paradigms': config_dict['paradigms'], + 'pipeline': ppl, + 'name': config_dict['name'], + } + ) # we can do the same for python defined pipeline python_files = glob(os.path.join(d, '*.py')) @@ -149,14 +155,17 @@ def generate_paradigms(pipeline_configs, context=None): if len(context) > 0: if paradigm not in context.keys(): log.debug(context) - log.warning("Paradigm {} not in context file {}".format( - paradigm, context.keys())) + log.warning( + "Paradigm {} not in context file {}".format( + paradigm, context.keys() + ) + ) if isinstance(config['pipeline'], BaseEstimator): pipeline = deepcopy(config['pipeline']) else: log.error(config['pipeline']) - raise(ValueError('pipeline must be a sklearn estimator')) + raise (ValueError('pipeline must be a sklearn estimator')) # append the pipeline in the paradigm list if paradigm not in paradigms.keys(): @@ -203,10 +212,9 @@ def generate_paradigms(pipeline_configs, context=None): # get the context log.debug('{}: {}'.format(paradigm, context_params[paradigm])) p = getattr(moabb_paradigms, paradigm)(**context_params[paradigm]) - context = WithinSessionEvaluation(paradigm=p, random_state=42, - n_jobs=options.threads, - overwrite=options.force) + context = WithinSessionEvaluation( + paradigm=p, random_state=42, n_jobs=options.threads, overwrite=options.force + ) results = context.process(pipelines=paradigms[paradigm]) all_results.append(results) - analyze(pd.concat(all_results, ignore_index=True), options.output, - plot=options.plot) + analyze(pd.concat(all_results, ignore_index=True), options.output, plot=options.plot) diff --git a/moabb/tests/analysis.py b/moabb/tests/analysis.py index 1b584bcac..f9c8e6fe4 100644 --- a/moabb/tests/analysis.py +++ b/moabb/tests/analysis.py @@ -15,7 +15,6 @@ class DummyEvaluation(BaseEvaluation): - def evaluate(self, dataset, pipelines): raise NotImplementedError('dummy') @@ -24,7 +23,6 @@ def is_valid(self, dataset): class DummyParadigm(BaseParadigm): - def __init__(self): pass @@ -44,38 +42,46 @@ def datasets(self): # Create dummy data for tests -d1 = {'time': 1, - 'dataset': FakeDataset(['d1', 'd2']), - 'subject': 1, - 'session': 'session_0', - 'score': 0.9, - 'n_samples': 100, - 'n_channels': 10} - -d2 = {'time': 2, - 'dataset': FakeDataset(['d1', 'd2']), - 'subject': 2, - 'session': 'session_0', - 'score': 0.9, - 'n_samples': 100, - 'n_channels': 10} - - -d3 = {'time': 2, - 'dataset': FakeDataset(['d1', 'd2']), - 'subject': 2, - 'session': 'session_0', - 'score': 0.9, - 'n_samples': 100, - 'n_channels': 10} - -d4 = 
{'time': 2, - 'dataset': FakeDataset(['d1', 'd2']), - 'subject': 1, - 'session': 'session_0', - 'score': 0.9, - 'n_samples': 100, - 'n_channels': 10} +d1 = { + 'time': 1, + 'dataset': FakeDataset(['d1', 'd2']), + 'subject': 1, + 'session': 'session_0', + 'score': 0.9, + 'n_samples': 100, + 'n_channels': 10, +} + +d2 = { + 'time': 2, + 'dataset': FakeDataset(['d1', 'd2']), + 'subject': 2, + 'session': 'session_0', + 'score': 0.9, + 'n_samples': 100, + 'n_channels': 10, +} + + +d3 = { + 'time': 2, + 'dataset': FakeDataset(['d1', 'd2']), + 'subject': 2, + 'session': 'session_0', + 'score': 0.9, + 'n_samples': 100, + 'n_channels': 10, +} + +d4 = { + 'time': 2, + 'dataset': FakeDataset(['d1', 'd2']), + 'subject': 1, + 'session': 'session_0', + 'score': 0.9, + 'n_samples': 100, + 'n_channels': 10, +} def to_pipeline_dict(pnames): @@ -87,7 +93,6 @@ def to_result_input(pnames, dsets): class Test_Stats(unittest.TestCase): - def return_df(self, shape): size = shape[0] * shape[1] data = np.arange(size).reshape(*shape) @@ -109,11 +114,10 @@ def test_perm_random(self): class Test_Integration(unittest.TestCase): - def setUp(self): - self.obj = Results(evaluation_class=DummyEvaluation, - paradigm_class=DummyParadigm, - suffix='test') + self.obj = Results( + evaluation_class=DummyEvaluation, paradigm_class=DummyParadigm, suffix='test' + ) def tearDown(self): path = self.obj.filepath @@ -122,11 +126,10 @@ def tearDown(self): class Test_Results(unittest.TestCase): - def setUp(self): - self.obj = Results(evaluation_class=DummyEvaluation, - paradigm_class=DummyParadigm, - suffix='test') + self.obj = Results( + evaluation_class=DummyEvaluation, paradigm_class=DummyParadigm, suffix='test' + ) def tearDown(self): path = self.obj.filepath @@ -140,7 +143,8 @@ def testRecognizesAlreadyComputed(self): _in = to_result_input(['a'], [d1]) self.obj.add(_in, to_pipeline_dict(['a'])) not_yet_computed = self.obj.not_yet_computed( - to_pipeline_dict(['a']), d1['dataset'], d1['subject']) + to_pipeline_dict(['a']), d1['dataset'], d1['subject'] + ) self.assertTrue(len(not_yet_computed) == 0) def testCanAddMultiplePipelines(self): @@ -151,13 +155,16 @@ def testCanAddMultipleValuesPerPipeline(self): _in = to_result_input(['a', 'b'], [[d1, d2], [d2, d1]]) self.obj.add(_in, to_pipeline_dict(['a', 'b'])) not_yet_computed = self.obj.not_yet_computed( - to_pipeline_dict(['a']), d1['dataset'], d1['subject']) + to_pipeline_dict(['a']), d1['dataset'], d1['subject'] + ) self.assertTrue(len(not_yet_computed) == 0, not_yet_computed) not_yet_computed = self.obj.not_yet_computed( - to_pipeline_dict(['b']), d2['dataset'], d2['subject']) + to_pipeline_dict(['b']), d2['dataset'], d2['subject'] + ) self.assertTrue(len(not_yet_computed) == 0, not_yet_computed) not_yet_computed = self.obj.not_yet_computed( - to_pipeline_dict(['b']), d1['dataset'], d1['subject']) + to_pipeline_dict(['b']), d1['dataset'], d1['subject'] + ) self.assertTrue(len(not_yet_computed) == 0, not_yet_computed) def testCanExportToDataframe(self): @@ -166,8 +173,10 @@ def testCanExportToDataframe(self): _in = to_result_input(['a', 'b', 'c'], [d2, d2, d3]) self.obj.add(_in, to_pipeline_dict(['a', 'b', 'c'])) df = self.obj.to_dataframe() - self.assertTrue(set(np.unique(df['pipeline'])) == set( - ('a', 'b', 'c')), np.unique(df['pipeline'])) + self.assertTrue( + set(np.unique(df['pipeline'])) == set(('a', 'b', 'c')), + np.unique(df['pipeline']), + ) self.assertTrue(df.shape[0] == 6, df.shape[0]) diff --git a/moabb/tests/datasets.py b/moabb/tests/datasets.py index 
cadf91f06..a0bb7295e 100644 --- a/moabb/tests/datasets.py +++ b/moabb/tests/datasets.py @@ -25,7 +25,6 @@ def _run_tests_on_dataset(d): class Test_Datasets(unittest.TestCase): - def test_fake_dataset(self): """this test will insure the basedataset works""" n_subjects = 3 @@ -34,8 +33,12 @@ def test_fake_dataset(self): for paradigm in ['imagery', 'p300']: - ds = FakeDataset(n_sessions=n_sessions, n_runs=n_runs, - n_subjects=n_subjects, paradigm=paradigm) + ds = FakeDataset( + n_sessions=n_sessions, + n_runs=n_runs, + n_subjects=n_subjects, + paradigm=paradigm, + ) data = ds.get_data() # we should get a dict @@ -51,8 +54,7 @@ def test_fake_dataset(self): self.assertEqual(len(data[1]['session_0']), n_runs) # We should get a raw array at the end - self.assertEqual(type(data[1]['session_0']['run_0']), - mne.io.RawArray) + self.assertEqual(type(data[1]['session_0']['run_0']), mne.io.RawArray) # bad subject id must raise error self.assertRaises(ValueError, ds.get_data, [1000]) diff --git a/moabb/tests/download.py b/moabb/tests/download.py index 9e904d112..686e9754e 100644 --- a/moabb/tests/download.py +++ b/moabb/tests/download.py @@ -21,11 +21,9 @@ class Test_Downloads(unittest.TestCase): - def run_dataset(self, dataset, subj=(0, 2)): def _get_events(raw): - stim_channels = mne.utils._get_stim_channel( - None, raw.info, raise_error=False) + stim_channels = mne.utils._get_stim_channel(None, raw.info, raise_error=False) if len(stim_channels) > 0: events = mne.find_events(raw, shortest_event=0, verbose=False) else: @@ -33,7 +31,7 @@ def _get_events(raw): return events obj = dataset() - obj.subject_list = obj.subject_list[subj[0]:subj[1]] + obj.subject_list = obj.subject_list[subj[0] : subj[1]] data = obj.get_data(obj.subject_list) # get data return a dict diff --git a/moabb/tests/evaluations.py b/moabb/tests/evaluations.py index 0689f4c85..cce511fcf 100644 --- a/moabb/tests/evaluations.py +++ b/moabb/tests/evaluations.py @@ -18,16 +18,17 @@ class Test_WithinSess(unittest.TestCase): - '''This is actually integration testing but I don't know how to do this + """This is actually integration testing but I don't know how to do this better. A paradigm implements pre-processing so it needs files to run MNE stuff on. To test the scoring and train/test we need to also have data and run it. Putting this on the future docket... 
- ''' + """ def setUp(self): - self.eval = ev.WithinSessionEvaluation(paradigm=FakeImageryParadigm(), - datasets=[dataset]) + self.eval = ev.WithinSessionEvaluation( + paradigm=FakeImageryParadigm(), datasets=[dataset] + ) def tearDown(self): path = self.eval.results.filepath @@ -42,11 +43,12 @@ def test_eval_results(self): class Test_AdditionalColumns(unittest.TestCase): - def setUp(self): self.eval = ev.WithinSessionEvaluation( - paradigm=FakeImageryParadigm(), datasets=[dataset], - additional_columns=['one', 'two']) + paradigm=FakeImageryParadigm(), + datasets=[dataset], + additional_columns=['one', 'two'], + ) def tearDown(self): path = self.eval.results.filepath @@ -59,10 +61,10 @@ def test_fails_if_nothing_returned(self): class Test_CrossSubj(Test_WithinSess): - def setUp(self): - self.eval = ev.CrossSubjectEvaluation(paradigm=FakeImageryParadigm(), - datasets=[dataset]) + self.eval = ev.CrossSubjectEvaluation( + paradigm=FakeImageryParadigm(), datasets=[dataset] + ) def test_compatible_dataset(self): # raise @@ -76,8 +78,9 @@ def test_compatible_dataset(self): class Test_CrossSess(Test_WithinSess): def setUp(self): - self.eval = ev.CrossSessionEvaluation(paradigm=FakeImageryParadigm(), - datasets=[dataset]) + self.eval = ev.CrossSessionEvaluation( + paradigm=FakeImageryParadigm(), datasets=[dataset] + ) def test_compatible_dataset(self): ds = FakeDataset(['left_hand', 'right_hand'], n_sessions=1) diff --git a/moabb/tests/paradigms.py b/moabb/tests/paradigms.py index f6f237cf1..c861d8341 100644 --- a/moabb/tests/paradigms.py +++ b/moabb/tests/paradigms.py @@ -27,7 +27,6 @@ def used_events(self, dataset): class Test_MotorImagery(unittest.TestCase): - def test_BaseImagery_paradigm(self): paradigm = SimpleMotorImagery() dataset = FakeDataset(paradigm='imagery') @@ -86,8 +85,7 @@ def test_BaseImagery_noevent(self): def test_LeftRightImagery_paradigm(self): # with a good dataset paradigm = LeftRightImagery() - dataset = FakeDataset(event_list=['left_hand', 'right_hand'], - paradigm='imagery') + dataset = FakeDataset(event_list=['left_hand', 'right_hand'], paradigm='imagery') X, labels, metadata = paradigm.get_data(dataset, subjects=[1]) self.assertEqual(len(np.unique(labels)), 2) self.assertEqual(list(np.unique(labels)), ['left_hand', 'right_hand']) @@ -113,14 +111,14 @@ def test_FilterBankMotorImagery_paradigm(self): self.assertEqual(X.shape[-1], 6) def test_FilterBankMotorImagery_moreclassesthanevent(self): - self.assertRaises(AssertionError, FilterBankMotorImagery, n_classes=3, - events=['hands', 'feet']) + self.assertRaises( + AssertionError, FilterBankMotorImagery, n_classes=3, events=['hands', 'feet'] + ) def test_FilterBankLeftRightImagery_paradigm(self): # can work with filter bank paradigm = FilterBankLeftRightImagery() - dataset = FakeDataset(event_list=['left_hand', 'right_hand'], - paradigm='imagery') + dataset = FakeDataset(event_list=['left_hand', 'right_hand'], paradigm='imagery') X, labels, metadata = paradigm.get_data(dataset, subjects=[1]) # X must be a 4D Array @@ -134,11 +132,9 @@ def used_events(self, dataset): class Test_P300(unittest.TestCase): - def test_BaseP300_paradigm(self): paradigm = SimpleP300() - dataset = FakeDataset(paradigm='p300', - event_list=['Target', 'NonTarget']) + dataset = FakeDataset(paradigm='p300', event_list=['Target', 'NonTarget']) X, labels, metadata = paradigm.get_data(dataset, subjects=[1]) # we should have all the same length @@ -168,8 +164,7 @@ def test_BaseP300_tmintmax(self): def test_BaseP300_filters(self): # can work with filter 
bank paradigm = SimpleP300(filters=[[1, 12], [12, 24]]) - dataset = FakeDataset(paradigm='p300', - event_list=['Target', 'NonTarget']) + dataset = FakeDataset(paradigm='p300', event_list=['Target', 'NonTarget']) X, labels, metadata = paradigm.get_data(dataset, subjects=[1]) # X must be a 4D Array @@ -180,8 +175,7 @@ def test_BaseP300_wrongevent(self): # test process_raw return empty list if raw does not contain any # selected event. cetain runs in dataset are event specific. paradigm = SimpleP300(filters=[[1, 12], [12, 24]]) - dataset = FakeDataset(paradigm='p300', - event_list=['Target', 'NonTarget']) + dataset = FakeDataset(paradigm='p300', event_list=['Target', 'NonTarget']) raw = dataset.get_data([1])[1]['session_0']['run_0'] # add something on the event channel raw._data[-1] *= 10 @@ -203,12 +197,10 @@ def test_P300_wrongevent(self): def test_P300_paradigm(self): # with a good dataset paradigm = P300() - dataset = FakeDataset(event_list=['Target', 'NonTarget'], - paradigm='p300') + dataset = FakeDataset(event_list=['Target', 'NonTarget'], paradigm='p300') X, labels, metadata = paradigm.get_data(dataset, subjects=[1]) self.assertEqual(len(np.unique(labels)), 2) - self.assertEqual(list(np.unique(labels)), - sorted(['Target', 'NonTarget'])) + self.assertEqual(list(np.unique(labels)), sorted(['Target', 'NonTarget'])) def test_BaseImagery_noevent(self): # Assert error if events from paradigm and dataset dont overlap @@ -245,8 +237,9 @@ def test_FilterBankMotorImagery_paradigm(self): self.assertEqual(X.shape[-1], 6) def test_FilterBankMotorImagery_moreclassesthanevent(self): - self.assertRaises(AssertionError, FilterBankMotorImagery, n_classes=3, - events=['hands', 'feet']) + self.assertRaises( + AssertionError, FilterBankMotorImagery, n_classes=3, events=['hands', 'feet'] + ) def test_FilterBankLeftRightImagery_paradigm(self): # can work with filter bank @@ -260,7 +253,6 @@ def test_FilterBankLeftRightImagery_paradigm(self): class Test_SSVEP(unittest.TestCase): - def test_BaseSSVEP_paradigm(self): paradigm = BaseSSVEP(n_classes=None) dataset = FakeDataset(paradigm='ssvep') @@ -312,8 +304,7 @@ def test_BaseSSVEP_nclasses_default(self): def test_BaseSSVEP_specified_nclasses(self): # Set the number of classes paradigm = BaseSSVEP(n_classes=3) - dataset = FakeDataset(event_list=['13', '15', '17', '19'], - paradigm='ssvep') + dataset = FakeDataset(event_list=['13', '15', '17', '19'], paradigm='ssvep') X, labels, metadata = paradigm.get_data(dataset, subjects=[1]) # labels must contain 3 values @@ -325,8 +316,7 @@ def test_BaseSSVEP_toomany_nclasses(self): self.assertRaises(ValueError, paradigm.get_data, dataset) def test_BaseSSVEP_moreclassesthanevent(self): - self.assertRaises(AssertionError, BaseSSVEP, n_classes=3, - events=['13.', '14.']) + self.assertRaises(AssertionError, BaseSSVEP, n_classes=3, events=['13.', '14.']) def test_SSVEP_noevent(self): # Assert error if events from paradigm and dataset dont overlap @@ -336,8 +326,7 @@ def test_SSVEP_noevent(self): def test_SSVEP_paradigm(self): paradigm = SSVEP(n_classes=None) - dataset = FakeDataset(event_list=['13', '15', '17', '19'], - paradigm='ssvep') + dataset = FakeDataset(event_list=['13', '15', '17', '19'], paradigm='ssvep') X, labels, metadata = paradigm.get_data(dataset, subjects=[1]) # Verify that they have the same length @@ -374,14 +363,12 @@ def test_SSVEP_singlepass(self): def test_SSVEP_filter(self): # Do not accept multiple filters - self.assertRaises(ValueError, SSVEP, - filters=[(10.5, 11.5), (12.5, 13.5)]) + 
self.assertRaises(ValueError, SSVEP, filters=[(10.5, 11.5), (12.5, 13.5)]) def test_FilterBankSSVEP_paradigm(self): # FilterBankSSVEP with all events paradigm = FilterBankSSVEP(n_classes=None) - dataset = FakeDataset(event_list=['13', '15', '17', '19'], - paradigm='ssvep') + dataset = FakeDataset(event_list=['13', '15', '17', '19'], paradigm='ssvep') X, labels, metadata = paradigm.get_data(dataset, subjects=[1]) # X must be a 4D array diff --git a/moabb/tests/util_tests.py b/moabb/tests/util_tests.py index bc7f4ec45..20f452615 100644 --- a/moabb/tests/util_tests.py +++ b/moabb/tests/util_tests.py @@ -4,34 +4,44 @@ class Test_Utils(unittest.TestCase): - def test_channel_intersection_fun(self): - print(utils.find_intersecting_channels( - [d() for d in utils.dataset_list])[0]) + print(utils.find_intersecting_channels([d() for d in utils.dataset_list])[0]) def test_dataset_search_fun(self): - print([type(i).__name__ for i in utils.dataset_search( - 'imagery', multi_session=True)]) - print([type(i).__name__ for i in utils.dataset_search( - 'imagery', multi_session=False)]) - res = utils.dataset_search('imagery', events=[ - 'right_hand', 'left_hand', 'feet', 'tongue', 'rest']) + print( + [ + type(i).__name__ + for i in utils.dataset_search('imagery', multi_session=True) + ] + ) + print( + [ + type(i).__name__ + for i in utils.dataset_search('imagery', multi_session=False) + ] + ) + res = utils.dataset_search( + 'imagery', events=['right_hand', 'left_hand', 'feet', 'tongue', 'rest'] + ) for out in res: print('multiclass: {}'.format(out.event_id.keys())) - res = utils.dataset_search('imagery', events=[ - 'right_hand', 'feet'], has_all_events=True) + res = utils.dataset_search( + 'imagery', events=['right_hand', 'feet'], has_all_events=True + ) for out in res: - self.assertTrue(set(['right_hand', 'feet']) - <= set(out.event_id.keys())) + self.assertTrue(set(['right_hand', 'feet']) <= set(out.event_id.keys())) def test_dataset_channel_search(self): chans = ['C3', 'Cz'] - All = utils.dataset_search('imagery', events=[ - 'right_hand', 'left_hand', 'feet', 'tongue', 'rest']) - has_chans = utils.dataset_search('imagery', events=[ - 'right_hand', 'left_hand', 'feet', 'tongue', 'rest'], - channels=chans) + All = utils.dataset_search( + 'imagery', events=['right_hand', 'left_hand', 'feet', 'tongue', 'rest'] + ) + has_chans = utils.dataset_search( + 'imagery', + events=['right_hand', 'left_hand', 'feet', 'tongue', 'rest'], + channels=chans, + ) has_types = set([type(x) for x in has_chans]) for d in has_chans: s1 = d.get_data([1])[1] diff --git a/moabb/utils.py b/moabb/utils.py index f0c145c72..ca8c348dd 100644 --- a/moabb/utils.py +++ b/moabb/utils.py @@ -11,8 +11,6 @@ def set_log_level(verbose='info'): """ mne.set_log_level(False) - level = {'debug': logging.DEBUG, - 'info': logging.INFO, - 'warning': logging.WARNING} + level = {'debug': logging.DEBUG, 'info': logging.INFO, 'warning': logging.WARNING} coloredlogs.install(level=level.get(verbose, logging.INFO)) diff --git a/pipelines/CSP_svm_search.py b/pipelines/CSP_svm_search.py index 2f3a5cc13..1c559584a 100644 --- a/pipelines/CSP_svm_search.py +++ b/pipelines/CSP_svm_search.py @@ -10,6 +10,4 @@ pipe = make_pipeline(Covariances('oas'), CSP(6), clf) # this is what will be loaded -PIPELINE = {'name': 'CSP + optSVM', - 'paradigms': ['LeftRightImagery'], - 'pipeline': pipe} +PIPELINE = {'name': 'CSP + optSVM', 'paradigms': ['LeftRightImagery'], 'pipeline': pipe} diff --git a/pipelines/FBCSP.py b/pipelines/FBCSP.py index 967642f74..cc574d3bd 100644 --- 
a/pipelines/FBCSP.py +++ b/pipelines/FBCSP.py @@ -12,10 +12,11 @@ parameters = {'C': np.logspace(-2, 2, 10)} clf = GridSearchCV(SVC(kernel='linear'), parameters) fb = FilterBank(make_pipeline(Covariances(estimator='oas'), CSP(nfilter=4))) -pipe = make_pipeline(fb, SelectKBest(score_func=mutual_info_classif, k=10), - clf) +pipe = make_pipeline(fb, SelectKBest(score_func=mutual_info_classif, k=10), clf) # this is what will be loaded -PIPELINE = {'name': 'FBCSP + optSVM', - 'paradigms': ['FilterBankMotorImagery'], - 'pipeline': pipe} +PIPELINE = { + 'name': 'FBCSP + optSVM', + 'paradigms': ['FilterBankMotorImagery'], + 'pipeline': pipe, +} diff --git a/pipelines/LogVar.py b/pipelines/LogVar.py index cc8e911bb..6dcbcbefe 100644 --- a/pipelines/LogVar.py +++ b/pipelines/LogVar.py @@ -11,6 +11,4 @@ pipe = make_pipeline(LogVariance(), clf) # this is what will be loaded -PIPELINE = {'name': 'AM + optSVM', - 'paradigms': ['MotorImagery'], - 'pipeline': pipe} +PIPELINE = {'name': 'AM + optSVM', 'paradigms': ['MotorImagery'], 'pipeline': pipe} diff --git a/pipelines/TSSVM.py b/pipelines/TSSVM.py index 83b936822..fab032a3d 100644 --- a/pipelines/TSSVM.py +++ b/pipelines/TSSVM.py @@ -11,6 +11,4 @@ pipe = make_pipeline(Covariances('oas'), TangentSpace(metric='riemann'), clf) # this is what will be loaded -PIPELINE = {'name': 'TS + optSVM', - 'paradigms': ['MotorImagery'], - 'pipeline': pipe} +PIPELINE = {'name': 'TS + optSVM', 'paradigms': ['MotorImagery'], 'pipeline': pipe} diff --git a/pipelines/WTRCSP.py b/pipelines/WTRCSP.py index da681f9d8..38b2139fc 100644 --- a/pipelines/WTRCSP.py +++ b/pipelines/WTRCSP.py @@ -5,10 +5,7 @@ from moabb.pipelines.csp import TRCSP -pipe = make_pipeline(Covariances('scm'), TRCSP( - nfilter=6), LinearDiscriminantAnalysis()) +pipe = make_pipeline(Covariances('scm'), TRCSP(nfilter=6), LinearDiscriminantAnalysis()) # this is what will be loaded -PIPELINE = {'name': 'TRCSP + LDA', - 'paradigms': ['MotorImagery'], - 'pipeline': pipe} +PIPELINE = {'name': 'TRCSP + LDA', 'paradigms': ['MotorImagery'], 'pipeline': pipe} diff --git a/pyproject.toml b/pyproject.toml index 5b5feb285..af7be85a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,83 +1,7 @@ [tool.black] line-length = 90 target-version = ["py36"] -force-exclude = ''' -( - # default black exclude - /(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist)/ - - # moabb specific files - | setup.py - | moabb/run.py - | moabb/__init__.py - | moabb/utils.py - | moabb/analysis/results.py - | moabb/analysis/plotting.py - | moabb/analysis/__init__.py - | moabb/analysis/meta_analysis.py - | moabb/evaluations/__init__.py - | moabb/evaluations/evaluations.py - | moabb/evaluations/base.py - | moabb/datasets/ssvep_mamem.py - | moabb/datasets/ssvep_nakanishi.py - | moabb/datasets/fake.py - | moabb/datasets/download.py - | moabb/datasets/__init__.py - | moabb/datasets/alex_mi.py - | moabb/datasets/physionet_mi.py - | moabb/datasets/mpi_mi.py - | moabb/datasets/ssvep_exo.py - | moabb/datasets/gigadb.py - | moabb/datasets/utils.py - | moabb/datasets/bnci.py - | moabb/datasets/braininvaders.py - | moabb/datasets/schirrmeister2017.py - | moabb/datasets/epfl.py - | moabb/datasets/upper_limb.py - | moabb/datasets/Zhou2016.py - | moabb/datasets/Weibo2014.py - | moabb/datasets/ssvep_wang.py - | moabb/datasets/base.py - | moabb/datasets/bbci_eeg_fnirs.py - | moabb/tests/analysis.py - | moabb/tests/download.py - | moabb/tests/datasets.py - | moabb/tests/__init__.py - | moabb/tests/paradigms.py - | 
moabb/tests/evaluations.py - | moabb/tests/util_tests.py - | moabb/pipelines/classification.py - | moabb/pipelines/__init__.py - | moabb/pipelines/csp.py - | moabb/pipelines/features.py - | moabb/pipelines/utils.py - | moabb/paradigms/p300.py - | moabb/paradigms/__init__.py - | moabb/paradigms/motor_imagery.py - | moabb/paradigms/ssvep.py - | moabb/paradigms/base.py - | pipelines/FBCSP.py - | pipelines/TSSVM.py - | pipelines/CSP_svm_search.py - | pipelines/LogVar.py - | pipelines/WTRCSP.py - | docs/source/conf.py - | examples/plot_cross_subject_ssvep.py - | examples/plot_within_session_p300.py - | examples/plot_cross_session_ssvep.py - | examples/plot_cross_session_motor_imagery.py - | examples/plot_cross_session_multiple_datasets.py - | examples/plot_filterbank_csp_vs_csp.py - | tutorials/plot_Getting_Started.py - | tutorials/tutorial_4_adding_a_dataset.py - | tutorials/plot_statistical_analysis.py - | tutorials/tutorial_1_simple_example_motor_imagery.py - | tutorials/tutorial_2_using_mulitple_datasets.py - | tutorials/plot_explore_paradigm.py - | tutorials/select_electrodes_resample.py - | tutorials/tutorial_3_benchmarking_multiple_pipelines.py -) -''' +skip-string-normalization = true # useful for migrating codebases, but not respected in 20.8b1, see: https://github.com/psf/black/issues/1880 [tool.isort] src_paths = ["moabb"] diff --git a/setup.py b/setup.py index fa76fe325..285ad1fe0 100644 --- a/setup.py +++ b/setup.py @@ -1,14 +1,23 @@ from setuptools import find_packages, setup -setup(name='moabb', - version='0.2.1', - description='Mother of all BCI Benchmarks', - url='', - author='Alexandre Barachant, Vinay Jayaram', - author_email='{alexandre.barachant, vinayjayaram13}@gmail.com', - license='BSD (3-clause)', - packages=find_packages(), - install_requires=['numpy', 'scipy', 'scikit-learn', 'pandas', - 'mne', 'pyriemann', 'pyyaml'], - zip_safe=False) +setup( + name='moabb', + version='0.2.1', + description='Mother of all BCI Benchmarks', + url='', + author='Alexandre Barachant, Vinay Jayaram', + author_email='{alexandre.barachant, vinayjayaram13}@gmail.com', + license='BSD (3-clause)', + packages=find_packages(), + install_requires=[ + 'numpy', + 'scipy', + 'scikit-learn', + 'pandas', + 'mne', + 'pyriemann', + 'pyyaml', + ], + zip_safe=False, +) diff --git a/tutorials/plot_Getting_Started.py b/tutorials/plot_Getting_Started.py index 6438c687e..acbe92b17 100644 --- a/tutorials/plot_Getting_Started.py +++ b/tutorials/plot_Getting_Started.py @@ -55,8 +55,7 @@ # is the name of the pipeline and the value is the Pipeline object pipelines = {} -pipelines['AM + LDA'] = make_pipeline(LogVariance(), - LDA()) +pipelines['AM + LDA'] = make_pipeline(LogVariance(), LDA()) parameters = {'C': np.logspace(-2, 2, 10)} clf = GridSearchCV(SVC(kernel='linear'), parameters) pipe = make_pipeline(LogVariance(), clf) @@ -105,8 +104,9 @@ # be cross-validated within a single recording, or across days, or sessions, or # subjects. This also is the correct place to specify multiple threads. 
-evaluation = CrossSessionEvaluation(paradigm=paradigm, datasets=datasets, - suffix='examples', overwrite=False) +evaluation = CrossSessionEvaluation( + paradigm=paradigm, datasets=datasets, suffix='examples', overwrite=False +) results = evaluation.process(pipelines) ########################################################################## diff --git a/tutorials/plot_statistical_analysis.py b/tutorials/plot_statistical_analysis.py index 8c050ba89..140e67430 100644 --- a/tutorials/plot_statistical_analysis.py +++ b/tutorials/plot_statistical_analysis.py @@ -58,11 +58,9 @@ pipelines['CSP + LDA'] = make_pipeline(CSP(n_components=8), LDA()) -pipelines['RG + LR'] = make_pipeline(Covariances(), TangentSpace(), - LogisticRegression()) +pipelines['RG + LR'] = make_pipeline(Covariances(), TangentSpace(), LogisticRegression()) -pipelines['CSP + LR'] = make_pipeline( - CSP(n_components=8), LogisticRegression()) +pipelines['CSP + LR'] = make_pipeline(CSP(n_components=8), LogisticRegression()) pipelines['RG + LDA'] = make_pipeline(Covariances(), TangentSpace(), LDA()) @@ -84,10 +82,8 @@ datasets = [dataset] overwrite = False # set to True if we want to overwrite cached results evaluation = CrossSessionEvaluation( - paradigm=paradigm, - datasets=datasets, - suffix='examples', - overwrite=overwrite) + paradigm=paradigm, datasets=datasets, suffix='examples', overwrite=overwrite +) results = evaluation.process(pipelines) diff --git a/tutorials/select_electrodes_resample.py b/tutorials/select_electrodes_resample.py index 97a0d5971..7e782d2ad 100644 --- a/tutorials/select_electrodes_resample.py +++ b/tutorials/select_electrodes_resample.py @@ -43,7 +43,7 @@ # Also, use a specific resampling. In this example, all datasets are # set to 200 Hz. -paradigm = LeftRightImagery(channels=['C3', 'C4', 'Cz'], resample=200.) +paradigm = LeftRightImagery(channels=['C3', 'C4', 'Cz'], resample=200.0) ############################################################################## # Evaluation @@ -52,12 +52,11 @@ # The evaluation is conducted on with CSP+LDA, only on the 3 electrodes, with # a sampling rate of 200 Hz. -evaluation = WithinSessionEvaluation(paradigm=paradigm, - datasets=datasets) +evaluation = WithinSessionEvaluation(paradigm=paradigm, datasets=datasets) csp_lda = make_pipeline(CSP(n_components=2), LDA()) -ts_lr = make_pipeline(Covariances(estimator='oas'), - TangentSpace(metric='riemann'), - LR(C=1.0)) +ts_lr = make_pipeline( + Covariances(estimator='oas'), TangentSpace(metric='riemann'), LR(C=1.0) +) results = evaluation.process({'csp+lda': csp_lda, 'ts+lr': ts_lr}) print(results.head()) @@ -71,9 +70,7 @@ # as well as the list of datasets with valid channels. electrodes, datasets = find_intersecting_channels(datasets) -evaluation = WithinSessionEvaluation(paradigm=paradigm, - datasets=datasets, - overwrite=True) +evaluation = WithinSessionEvaluation(paradigm=paradigm, datasets=datasets, overwrite=True) results = evaluation.process({'csp+lda': csp_lda, 'ts+lr': ts_lr}) print(results.head()) diff --git a/tutorials/tutorial_1_simple_example_motor_imagery.py b/tutorials/tutorial_1_simple_example_motor_imagery.py index ff0fb89aa..32a6a6a4c 100644 --- a/tutorials/tutorial_1_simple_example_motor_imagery.py +++ b/tutorials/tutorial_1_simple_example_motor_imagery.py @@ -120,8 +120,9 @@ # `BetweenSessionEvaluation`, which takes all but one session as training # partition and the remaining one as testing partition. 
-evaluation = WithinSessionEvaluation(paradigm=paradigm, datasets=[dataset], - overwrite=True) +evaluation = WithinSessionEvaluation( + paradigm=paradigm, datasets=[dataset], overwrite=True +) # We obtain the results in the form of a pandas dataframe results = evaluation.process({'csp+lda': pipeline}) @@ -145,6 +146,7 @@ fig, ax = plt.subplots(figsize=(8, 7)) results["subj"] = results["subject"].apply(str) -sns.barplot(x="score", y="subj", hue='session', data=results, orient='h', - palette='viridis', ax=ax) +sns.barplot( + x="score", y="subj", hue='session', data=results, orient='h', palette='viridis', ax=ax +) fig.show() diff --git a/tutorials/tutorial_2_using_mulitple_datasets.py b/tutorials/tutorial_2_using_mulitple_datasets.py index 5650b5b58..88c205a08 100644 --- a/tutorials/tutorial_2_using_mulitple_datasets.py +++ b/tutorials/tutorial_2_using_mulitple_datasets.py @@ -47,8 +47,9 @@ # set `overwrite` to False to cache the results, avoiding to restart all the # evaluation from scratch if a problem occurs. paradigm = LeftRightImagery() -evaluation = WithinSessionEvaluation(paradigm=paradigm, datasets=datasets, - overwrite=False) +evaluation = WithinSessionEvaluation( + paradigm=paradigm, datasets=datasets, overwrite=False +) pipeline = make_pipeline(CSP(n_components=8), LDA()) results = evaluation.process({"csp+lda": pipeline}) @@ -68,6 +69,13 @@ # is to plot the results from the three datasets with just one line. results["subj"] = [str(resi).zfill(2) for resi in results["subject"]] -g = sns.catplot(kind='bar', x="score", y="subj", col="dataset", - data=results, orient='h', palette='viridis') +g = sns.catplot( + kind='bar', + x="score", + y="subj", + col="dataset", + data=results, + orient='h', + palette='viridis', +) plt.show() diff --git a/tutorials/tutorial_3_benchmarking_multiple_pipelines.py b/tutorials/tutorial_3_benchmarking_multiple_pipelines.py index dbc9ef3c9..ff3285d42 100644 --- a/tutorials/tutorial_3_benchmarking_multiple_pipelines.py +++ b/tutorials/tutorial_3_benchmarking_multiple_pipelines.py @@ -47,17 +47,16 @@ # MDM classifier that works directly on covariance matrices. pipelines = {} pipelines["csp+lda"] = make_pipeline(CSP(n_components=8), LDA()) -pipelines["tgsp+svm"] = make_pipeline(Covariances('oas'), - TangentSpace(metric='riemann'), - SVC(kernel='linear')) +pipelines["tgsp+svm"] = make_pipeline( + Covariances('oas'), TangentSpace(metric='riemann'), SVC(kernel='linear') +) pipelines["MDM"] = make_pipeline(Covariances('oas'), MDM(metric='riemann')) # The following lines go exactly as in the previous example, where we end up # obtaining a pandas dataframe containing the results of the evaluation. datasets = [BNCI2014001(), Weibo2014(), Zhou2016()] paradigm = LeftRightImagery() -evaluation = WithinSessionEvaluation(paradigm=paradigm, datasets=datasets, - overwrite=True) +evaluation = WithinSessionEvaluation(paradigm=paradigm, datasets=datasets, overwrite=True) results = evaluation.process(pipelines) if not os.path.exists("./results"): os.mkdir("./results") @@ -72,7 +71,16 @@ # for each subject of each dataset. 
results["subj"] = [str(resi).zfill(2) for resi in results["subject"]] -g = sns.catplot(kind='bar', x="score", y="subj", hue="pipeline", - col="dataset", height=12, aspect=0.5, data=results, - orient='h', palette='viridis') +g = sns.catplot( + kind='bar', + x="score", + y="subj", + hue="pipeline", + col="dataset", + height=12, + aspect=0.5, + data=results, + orient='h', + palette='viridis', +) plt.show() diff --git a/tutorials/tutorial_4_adding_a_dataset.py b/tutorials/tutorial_4_adding_a_dataset.py index ec4a28447..14c59f8b3 100644 --- a/tutorials/tutorial_4_adding_a_dataset.py +++ b/tutorials/tutorial_4_adding_a_dataset.py @@ -31,6 +31,7 @@ # The fake dataset is available on the # [Zenodo website](https://sandbox.zenodo.org/record/369543) + def create_example_dataset(): """Create a fake example for a dataset""" sfreq = 256 @@ -51,7 +52,7 @@ def create_example_dataset(): tn = int(t_offset * sfreq + n * (t_trial + intertrial) * sfreq) stim[tn] = label noise = 0.1 * np.random.randn(n_chan, len(signal)) - x[:-1, tn:(tn + t_trial * sfreq)] = label * signal + noise + x[:-1, tn : (tn + t_trial * sfreq)] = label * signal + noise x[-1, :] = stim return x, sfreq @@ -92,11 +93,12 @@ def create_example_dataset(): class ExampleDataset(BaseDataset): - ''' + """ Dataset used to exemplify the creation of a dataset class in MOABB. The data samples have been simulated and has no physiological meaning whatsoever. - ''' + """ + def __init__(self): super().__init__( subjects=[1, 2, 3], @@ -105,7 +107,8 @@ def __init__(self): code='Example dataset', interval=[0, 0.75], paradigm='imagery', - doi='') + doi='', + ) def _get_single_subject_data(self, subject): """return data for a single subject""" @@ -124,11 +127,12 @@ def _get_single_subject_data(self, subject): sessions['session_1']['run_1'] = raw return sessions - def data_path(self, subject, path=None, force_update=False, - update_path=None, verbose=None): + def data_path( + self, subject, path=None, force_update=False, update_path=None, verbose=None + ): """Download the data from one subject""" if subject not in self.subject_list: - raise(ValueError("Invalid subject number")) + raise (ValueError("Invalid subject number")) url = '{:s}subject_0{:d}.mat'.format(ExampleDataset_URL, subject) path = dl.data_path(url, 'ExampleDataset') @@ -146,8 +150,7 @@ def data_path(self, subject, path=None, force_update=False, paradigm = LeftRightImagery() X, labels, meta = paradigm.get_data(dataset=dataset, subjects=[1]) -evaluation = WithinSessionEvaluation(paradigm=paradigm, datasets=dataset, - overwrite=True) +evaluation = WithinSessionEvaluation(paradigm=paradigm, datasets=dataset, overwrite=True) pipelines = {} pipelines['MDM'] = make_pipeline(Covariances('oas'), MDM(metric='riemann')) scores = evaluation.process(pipelines) From fe2d4895bed8ef23fa61bb633cebd3b7ff560341 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Bj=C3=A4reholt?= Date: Mon, 8 Mar 2021 18:56:29 +0100 Subject: [PATCH 02/17] fix: fixed noqa comment, simplified dataset_search slightly --- moabb/datasets/utils.py | 60 +++++++++++++++++++++-------------------- 1 file changed, 31 insertions(+), 29 deletions(-) diff --git a/moabb/datasets/utils.py b/moabb/datasets/utils.py index f595d93c5..01ba0301a 100644 --- a/moabb/datasets/utils.py +++ b/moabb/datasets/utils.py @@ -14,10 +14,10 @@ dataset_list.append(ds[1]) -def dataset_search( +def dataset_search( # noqa: C901 paradigm, multi_session=False, - events=None, # noqa: C901 + events=None, has_all_events=False, interval=None, min_subjects=1, @@ -70,35 +70,37 @@ 
def dataset_search( if len(d.subject_list) < min_subjects: continue - if paradigm == d.paradigm: - if interval is not None: - if d.interval[1] - d.interval[0] < interval: - continue - keep_event_dict = {} - if events is None: - keep_event_dict = d.event_id.copy() - else: - n_events = 0 - for e in events: - if n_classes is not None: - if n_events == n_classes: - break - if e in d.event_id.keys(): - keep_event_dict[e] = d.event_id[e] - n_events += 1 - else: - if has_all_events: - skip_dataset = True - if keep_event_dict and not skip_dataset: - if len(channels) > 0: - s1 = d.get_data([1])[1] - sess1 = s1[list(s1.keys())[0]] - raw = sess1[list(sess1.keys())[0]] - raw.pick_types(eeg=True) - if channels <= set(raw.info['ch_names']): - out_data.append(d) + if paradigm != d.paradigm: + continue + + if interval is not None and d.interval[1] - d.interval[0] < interval: + continue + + keep_event_dict = {} + if events is None: + keep_event_dict = d.event_id.copy() + else: + n_events = 0 + for e in events: + if n_classes is not None: + if n_events == n_classes: + break + if e in d.event_id.keys(): + keep_event_dict[e] = d.event_id[e] + n_events += 1 else: + if has_all_events: + skip_dataset = True + if keep_event_dict and not skip_dataset: + if len(channels) > 0: + s1 = d.get_data([1])[1] + sess1 = s1[list(s1.keys())[0]] + raw = sess1[list(sess1.keys())[0]] + raw.pick_types(eeg=True) + if channels <= set(raw.info['ch_names']): out_data.append(d) + else: + out_data.append(d) return out_data From 2b540845d15efcc09326d792ced2269aaf1808bf Mon Sep 17 00:00:00 2001 From: Sylvain Chevallier Date: Thu, 11 Mar 2021 17:06:16 +0100 Subject: [PATCH 03/17] correct electrodes names --- moabb/datasets/Weibo2014.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/moabb/datasets/Weibo2014.py b/moabb/datasets/Weibo2014.py index 24d18296f..1f5d10e33 100644 --- a/moabb/datasets/Weibo2014.py +++ b/moabb/datasets/Weibo2014.py @@ -131,14 +131,12 @@ def _get_single_subject_data(self, subject): montage = mne.channels.make_standard_montage('standard_1005') # fmt: off - ch_names = ['Fp1', 'Fpz', 'Fp2', 'AF3', 'AF4', 'F7', 'F5', 'F3', 'F1', - 'Fz', 'F2', 'F4', 'F6', 'F8', 'FT7', 'FC5', 'FC3', 'FC1', - 'FCz', 'FC2', 'FC4', 'FC6', 'FT8', 'T7', 'C5', 'C3', 'C1', - 'Cz', 'C2', 'C4', 'C6', 'T8', 'TP7', 'CP5', 'CP3', 'CP1', - 'CPz', 'CP2', 'CP4', 'CP6', 'TP8', 'P7', 'P5', 'P3', 'P1', - 'Pz', 'P2', 'P4', 'P6', 'P8', 'PO7', 'PO5', 'PO3', 'POz', - 'PO4', 'PO6', 'PO8', 'CB1', 'O1', 'Oz', 'O2', 'CB2', 'VEO', - 'HEO'] + ch_names = [ + "Fp1", "Fpz", "Fp2", "AF3", "AF4", "F7", "F5", "F3", "F1", "Fz", "F2", "F4", "F6", + "F8", "FT7", "FC5", "FC3", "FC1", "FCz", "FC2", "FC4", "FC6", "FT8", "T7", "C5", + "C3", "C1", "Pz", "P2", "P4", "P6", "P8", "PO7", "PO5", "PO3", "POz", "PO4", "PO6", + "PO8", "CB1", "O1", "Oz", "O2", "CB2", "VEO", "HEO", + ] # fmt: on ch_types = ['eeg'] * 62 + ['eog'] * 2 From 711501817be2e0fd4d5f0ba25921f5679fbafe59 Mon Sep 17 00:00:00 2001 From: Sylvain Chevallier Date: Thu, 11 Mar 2021 17:13:39 +0100 Subject: [PATCH 04/17] correct formatting for gigadb --- moabb/datasets/gigadb.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/moabb/datasets/gigadb.py b/moabb/datasets/gigadb.py index 956b6e8cb..5b622a675 100644 --- a/moabb/datasets/gigadb.py +++ b/moabb/datasets/gigadb.py @@ -85,14 +85,14 @@ def _get_single_subject_data(self, subject): )['eeg'] # fmt: off - eeg_ch_names = ['Fp1', 'AF7', 'AF3', 'F1', 'F3', 'F5', 'F7', 'FT7', - 'FC5', 'FC3', 'FC1', 'C1', 
'C3', 'C5', 'T7', 'TP7', - 'CP5', 'CP3', 'CP1', 'P1', 'P3', 'P5', 'P7', 'P9', - 'PO7', 'PO3', 'O1', 'Iz', 'Oz', 'POz', 'Pz', 'CPz', - 'Fpz', 'Fp2', 'AF8', 'AF4', 'AFz', 'Fz', 'F2', 'F4', - 'F6', 'F8', 'FT8', 'FC6', 'FC4', 'FC2', 'FCz', 'Cz', - 'C2', 'C4', 'C6', 'T8', 'TP8', 'CP6', 'CP4', 'CP2', - 'P2', 'P4', 'P6', 'P8', 'P10', 'PO8', 'PO4', 'O2'] + eeg_ch_names = [ + "Fp1", "AF7", "AF3", "F1", "F3", "F5", "F7", "FT7", "FC5", "FC3", "FC1", + "C1", "C3", "C5", "T7", "TP7", "CP5", "CP3", "CP1", "P1", "P3", "P5", "P7", + "P9", "PO7", "PO3", "O1", "Iz", "Oz", "POz", "Pz", "CPz", "Fpz", "Fp2", + "AF8", "AF4", "AFz", "Fz", "F2", "F4", "F6", "F8", "FT8", "FC6", "FC4", + "FC2", "FCz", "Cz", "C2", "C4", "C6", "T8", "TP8", "CP6", "CP4", "CP2", + "P2", "P4", "P6", "P8", "P10", "PO8", "PO4", "O2", + ] # fmt: on emg_ch_names = ['EMG1', 'EMG2', 'EMG3', 'EMG4'] ch_names = eeg_ch_names + emg_ch_names + ['Stim'] From a1c43adcdc0f05d7bf76195a0bd2693eff4e2629 Mon Sep 17 00:00:00 2001 From: Sylvain Chevallier Date: Thu, 11 Mar 2021 17:18:25 +0100 Subject: [PATCH 05/17] correcting electrodes rename formatting --- moabb/datasets/physionet_mi.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/moabb/datasets/physionet_mi.py b/moabb/datasets/physionet_mi.py index 2e6be1825..4ee4cb854 100644 --- a/moabb/datasets/physionet_mi.py +++ b/moabb/datasets/physionet_mi.py @@ -108,11 +108,13 @@ def _load_one_run(self, subject, run, preload=True): raw.rename_channels(lambda x: x.strip('.')) raw.rename_channels(lambda x: x.upper()) # fmt: off - raw.rename_channels({'AFZ': 'AFz', 'PZ': 'Pz', 'FPZ': 'Fpz', - 'FCZ': 'FCz', 'FP1': 'Fp1', 'CZ': 'Cz', - 'OZ': 'Oz', 'POZ': 'POz', 'IZ': 'Iz', - 'CPZ': 'CPz', 'FP2': 'Fp2', 'FZ': 'Fz'}) + renames = { + "AFZ": "AFz", "PZ": "Pz", "FPZ": "Fpz", "FCZ": "FCz", "FP1": "Fp1", + "CZ": "Cz", "OZ": "Oz", "POZ": "POz", "IZ": "Iz", "CPZ": "CPz", + "FP2": "Fp2", "FZ": "Fz", + } # fmt: on + raw.rename_channels(renames) raw.set_montage(mne.channels.make_standard_montage('standard_1005')) return raw From 9bdfa5a46bd563de8075bcb2429602169f6358c2 Mon Sep 17 00:00:00 2001 From: Sylvain Chevallier Date: Thu, 11 Mar 2021 17:19:20 +0100 Subject: [PATCH 06/17] code linting --- moabb/datasets/gigadb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/moabb/datasets/gigadb.py b/moabb/datasets/gigadb.py index 5b622a675..6f20a149f 100644 --- a/moabb/datasets/gigadb.py +++ b/moabb/datasets/gigadb.py @@ -91,7 +91,7 @@ def _get_single_subject_data(self, subject): "P9", "PO7", "PO3", "O1", "Iz", "Oz", "POz", "Pz", "CPz", "Fpz", "Fp2", "AF8", "AF4", "AFz", "Fz", "F2", "F4", "F6", "F8", "FT8", "FC6", "FC4", "FC2", "FCz", "Cz", "C2", "C4", "C6", "T8", "TP8", "CP6", "CP4", "CP2", - "P2", "P4", "P6", "P8", "P10", "PO8", "PO4", "O2", + "P2", "P4", "P6", "P8", "P10", "PO8", "PO4", "O2", ] # fmt: on emg_ch_names = ['EMG1', 'EMG2', 'EMG3', 'EMG4'] From 1f6e558aeb8b1b6a87422cb8d749cfe76f1e8f54 Mon Sep 17 00:00:00 2001 From: Sylvain Chevallier Date: Thu, 11 Mar 2021 17:34:01 +0100 Subject: [PATCH 07/17] Correct formatting for non-black part --- moabb/datasets/bnci.py | 51 +++++++++++++++++------------------------- 1 file changed, 21 insertions(+), 30 deletions(-) diff --git a/moabb/datasets/bnci.py b/moabb/datasets/bnci.py index ba265ab57..dae8ace38 100644 --- a/moabb/datasets/bnci.py +++ b/moabb/datasets/bnci.py @@ -121,9 +121,9 @@ def _load_data_001_2014( # fmt: off ch_names = [ - 'Fz', 'FC3', 'FC1', 'FCz', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'Cz', 'C2', - 'C4', 'C6', 'CP3', 'CP1', 
'CPz', 'CP2', 'CP4', 'P1', 'Pz', 'P2', 'POz', - 'EOG1', 'EOG2', 'EOG3' + "Fz", "FC3", "FC1", "FCz", "FC2", "FC4", "C5", "C3", "C1", "Cz", "C2", + "C4", "C6", "CP3", "CP1", "CPz", "CP2", "CP4", "P1", "Pz", "P2", "POz", + "EOG1", "EOG2", "EOG3", ] # fmt: on ch_types = ['eeg'] * 22 + ['eog'] * 3 @@ -178,7 +178,7 @@ def _load_data_004_2014( raise ValueError("Subject must be between 1 and 9. Got %d." % subject) # fmt: off - ch_names = ['C3', 'Cz', 'C4', 'EOG1', 'EOG2', 'EOG3'] + ch_names = ["C3", "Cz", "C4", "EOG1", "EOG2", "EOG3", ] # fmt: on ch_types = ['eeg'] * 3 + ['eog'] * 3 @@ -267,25 +267,16 @@ def _load_data_001_2015( raise ValueError("Subject must be between 1 and 12. Got %d." % subject) if subject in [8, 9, 10, 11]: - ses = ['A', 'B', 'C'] # 3 sessions for those subjects + ses = ["A", "B", "C"] # 3 sessions for those subjects else: - ses = ['A', 'B'] + ses = ["A", "B"] + # fmt: off ch_names = [ - 'FC3', - 'FCz', - 'FC4', - 'C5', - 'C3', - 'C1', - 'Cz', - 'C2', - 'C4', - 'C6', - 'CP3', - 'CPz', - 'CP4', + "FC3", "FCz", "FC4", "C5", "C3", "C1", "Cz", + "C2", "C4", "C6", "CP3","CPz", "CP4", ] + # fmt: on ch_types = ['eeg'] * 13 sessions = {} @@ -321,7 +312,7 @@ def _load_data_003_2015( # fmt: off ch_names = [ - 'Fz', 'Cz', 'P3', 'Pz', 'P4', 'PO7', 'Oz', 'PO8', 'Target', 'Flash' + "Fz", "Cz", "P3", "Pz", "P4", "PO7", "Oz", "PO8", "Target", "Flash", ] # fmt: on @@ -372,16 +363,16 @@ def _load_data_004_2015( if (subject < 1) or (subject > 9): raise ValueError("Subject must be between 1 and 9. Got %d." % subject) - subjects = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'L'] + subjects = ["A", "C", "D", "E", "F", "G", "H", "J", "L"] url = '{u}004-2015/{s}.mat'.format(u=base_url, s=subjects[subject - 1]) filename = data_path(url, path, force_update, update_path)[0] # fmt: off ch_names = [ - 'AFz', 'F7', 'F3', 'Fz', 'F4', 'F8', 'FC3', 'FCz', 'FC4', 'T3', 'C3', - 'Cz', 'C4', 'T4', 'CP3', 'CPz', 'CP4', 'P7', 'P5', 'P3', 'P1', 'Pz', - 'P2', 'P4', 'P6', 'P8', 'PO3', 'PO4', 'O1', 'O2' + "AFz", "F7", "F3", "Fz", "F4", "F8", "FC3", "FCz", "FC4", "T3", "C3", + "Cz", "C4", "T4", "CP3", "CPz", "CP4", "P7", "P5", "P3", "P1", "Pz", + "P2", "P4", "P6", "P8", "PO3", "PO4", "O1", "O2", ] # fmt: on ch_types = ['eeg'] * 30 @@ -405,9 +396,9 @@ def _load_data_009_2015( # fmt: off subjects = [ - 'fce', 'kw', 'faz', 'fcj', 'fcg', 'far', 'faw', 'fax', 'fcc', 'fcm', - 'fas', 'fch', 'fcd', 'fca', 'fcb', 'fau', 'fci', 'fav', 'fat', 'fcl', - 'fck' + "fce", "kw", "faz", "fcj", "fcg", "far", "faw", "fax", "fcc", "fcm", + "fas", "fch", "fcd", "fca", "fcb", "fau", "fci", "fav", "fat", "fcl", + "fck", ] # fmt: on s = subjects[subject - 1] @@ -434,8 +425,8 @@ def _load_data_010_2015( # fmt: off subjects = [ - 'fat', 'gcb', 'gcc', 'gcd', 'gce', 'gcf', 'gcg', 'gch', 'iay', 'icn', - 'icr', 'pia' + "fat", "gcb", "gcc", "gcd", "gce", "gcf", + "gcg", "gch", "iay", "icn", "icr", "pia", ] # fmt: on @@ -463,7 +454,7 @@ def _load_data_012_2015( # fmt: off subjects = [ - 'nv', 'nw', 'nx', 'ny', 'nz', 'mg', 'oa', 'ob', 'oc', 'od', 'ja', 'oe' + "nv", "nw", "nx", "ny", "nz", "mg", "oa", "ob", "oc", "od", "ja", "oe" ] # fmt: on From 7e04a4ba9839d7aceb3f8bf771bfdc3a0e4966a4 Mon Sep 17 00:00:00 2001 From: Sylvain Chevallier Date: Thu, 11 Mar 2021 17:36:44 +0100 Subject: [PATCH 08/17] remove black formatting for url --- moabb/datasets/ssvep_mamem.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/moabb/datasets/ssvep_mamem.py b/moabb/datasets/ssvep_mamem.py index f05ec40d9..a732bdd9d 100644 --- 
a/moabb/datasets/ssvep_mamem.py +++ b/moabb/datasets/ssvep_mamem.py @@ -31,15 +31,11 @@ # MAMEM2_URL = 'https://ndownloader.figshare.com/articles/3153409/versions/2' # MAMEM3_URL = 'https://ndownloader.figshare.com/articles/3413851/versions/1' -MAMEM1_URL = ( - "https://archive.physionet.org/physiobank/database/mssvepdb/dataset1/" # noqa: E501 -) -MAMEM2_URL = ( - "https://archive.physionet.org/physiobank/database/mssvepdb/dataset2/" # noqa: E501 -) -MAMEM3_URL = ( - "https://archive.physionet.org/physiobank/database/mssvepdb/dataset3/" # noqa: E501 -) +# fmt: off +MAMEM1_URL = "https://archive.physionet.org/physiobank/database/mssvepdb/dataset1/" # noqa: E501 +MAMEM2_URL = "https://archive.physionet.org/physiobank/database/mssvepdb/dataset2/" # noqa: E501 +MAMEM3_URL = "https://archive.physionet.org/physiobank/database/mssvepdb/dataset3/" # noqa: E501 +# fmt: on class BaseMAMEM(BaseDataset): From 8c79a52b4f7fcaa866a991c9dee1b51e55232c8c Mon Sep 17 00:00:00 2001 From: Sylvain Chevallier Date: Thu, 11 Mar 2021 17:37:41 +0100 Subject: [PATCH 09/17] correct url formatting --- moabb/datasets/ssvep_mamem.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/moabb/datasets/ssvep_mamem.py b/moabb/datasets/ssvep_mamem.py index a732bdd9d..a9e396e34 100644 --- a/moabb/datasets/ssvep_mamem.py +++ b/moabb/datasets/ssvep_mamem.py @@ -32,9 +32,9 @@ # MAMEM3_URL = 'https://ndownloader.figshare.com/articles/3413851/versions/1' # fmt: off -MAMEM1_URL = "https://archive.physionet.org/physiobank/database/mssvepdb/dataset1/" # noqa: E501 -MAMEM2_URL = "https://archive.physionet.org/physiobank/database/mssvepdb/dataset2/" # noqa: E501 -MAMEM3_URL = "https://archive.physionet.org/physiobank/database/mssvepdb/dataset3/" # noqa: E501 +MAMEM1_URL = "https://archive.physionet.org/physiobank/database/mssvepdb/dataset1/" +MAMEM2_URL = "https://archive.physionet.org/physiobank/database/mssvepdb/dataset2/" +MAMEM3_URL = "https://archive.physionet.org/physiobank/database/mssvepdb/dataset3/" # fmt: on From 514059e9554e018a02472804cd5affa270dfffc1 Mon Sep 17 00:00:00 2001 From: Sylvain Chevallier Date: Thu, 11 Mar 2021 17:42:13 +0100 Subject: [PATCH 10/17] correct linting in bnci --- moabb/datasets/bnci.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/moabb/datasets/bnci.py b/moabb/datasets/bnci.py index dae8ace38..42e284fdd 100644 --- a/moabb/datasets/bnci.py +++ b/moabb/datasets/bnci.py @@ -274,7 +274,7 @@ def _load_data_001_2015( # fmt: off ch_names = [ "FC3", "FCz", "FC4", "C5", "C3", "C1", "Cz", - "C2", "C4", "C6", "CP3","CPz", "CP4", + "C2", "C4", "C6", "CP3", "CPz", "CP4", ] # fmt: on ch_types = ['eeg'] * 13 From a56d789423b68e1d383e871e775710f0a1f39972 Mon Sep 17 00:00:00 2001 From: Sylvain Chevallier Date: Thu, 11 Mar 2021 17:43:15 +0100 Subject: [PATCH 11/17] correct error in linting for physionet --- moabb/datasets/physionet_mi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/moabb/datasets/physionet_mi.py b/moabb/datasets/physionet_mi.py index 4ee4cb854..5b7e657c9 100644 --- a/moabb/datasets/physionet_mi.py +++ b/moabb/datasets/physionet_mi.py @@ -112,7 +112,7 @@ def _load_one_run(self, subject, run, preload=True): "AFZ": "AFz", "PZ": "Pz", "FPZ": "Fpz", "FCZ": "FCz", "FP1": "Fp1", "CZ": "Cz", "OZ": "Oz", "POZ": "POz", "IZ": "Iz", "CPZ": "CPz", "FP2": "Fp2", "FZ": "Fz", - } + } # fmt: on raw.rename_channels(renames) raw.set_montage(mne.channels.make_standard_montage('standard_1005')) From 18e1c6b315681cbfe986c8b5a248ed70063934b0 Mon Sep 17 00:00:00 
2001 From: Sylvain Chevallier Date: Thu, 11 Mar 2021 18:14:46 +0100 Subject: [PATCH 12/17] correct some black formatting in util_test --- moabb/tests/util_tests.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/moabb/tests/util_tests.py b/moabb/tests/util_tests.py index 20f452615..1118ccba3 100644 --- a/moabb/tests/util_tests.py +++ b/moabb/tests/util_tests.py @@ -8,18 +8,10 @@ def test_channel_intersection_fun(self): print(utils.find_intersecting_channels([d() for d in utils.dataset_list])[0]) def test_dataset_search_fun(self): - print( - [ - type(i).__name__ - for i in utils.dataset_search('imagery', multi_session=True) - ] - ) - print( - [ - type(i).__name__ - for i in utils.dataset_search('imagery', multi_session=False) - ] - ) + found = utils.dataset_search('imagery', multi_session=True) + print([type(dataset).__name__ for dataset in found]) + found = utils.dataset_search('imagery', multi_session=False) + print([type(dataset).__name__ for dataset in found]) res = utils.dataset_search( 'imagery', events=['right_hand', 'left_hand', 'feet', 'tongue', 'rest'] ) From eeba303b088b976734cbf73601de541875eb2fc1 Mon Sep 17 00:00:00 2001 From: Vladislav Goncharenko Date: Thu, 11 Mar 2021 20:09:54 +0300 Subject: [PATCH 13/17] removed skipping single quotes --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index af7be85a7..c02ce9bb1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,6 @@ [tool.black] line-length = 90 target-version = ["py36"] -skip-string-normalization = true # useful for migrating codebases, but not respected in 20.8b1, see: https://github.com/psf/black/issues/1880 [tool.isort] src_paths = ["moabb"] From 8a9c30ac35d7eb677efca391a6f3a1dbd485233c Mon Sep 17 00:00:00 2001 From: Vladislav Goncharenko Date: Thu, 11 Mar 2021 20:11:11 +0300 Subject: [PATCH 14/17] formatted to double quotes --- docs/source/conf.py | 98 +++---- examples/plot_cross_session_motor_imagery.py | 24 +- .../plot_cross_session_multiple_datasets.py | 20 +- examples/plot_cross_session_ssvep.py | 10 +- examples/plot_cross_subject_ssvep.py | 26 +- examples/plot_filterbank_csp_vs_csp.py | 24 +- examples/plot_within_session_p300.py | 30 +- moabb/analysis/__init__.py | 24 +- moabb/analysis/meta_analysis.py | 34 +-- moabb/analysis/plotting.py | 106 +++---- moabb/analysis/results.py | 82 +++--- moabb/datasets/Weibo2014.py | 56 ++-- moabb/datasets/Zhou2016.py | 60 ++-- moabb/datasets/alex_mi.py | 12 +- moabb/datasets/base.py | 4 +- moabb/datasets/bbci_eeg_fnirs.py | 92 +++--- moabb/datasets/bnci.py | 264 +++++++++--------- moabb/datasets/braininvaders.py | 54 ++-- moabb/datasets/download.py | 4 +- moabb/datasets/epfl.py | 110 ++++---- moabb/datasets/fake.py | 14 +- moabb/datasets/gigadb.py | 24 +- moabb/datasets/mpi_mi.py | 24 +- moabb/datasets/physionet_mi.py | 48 ++-- moabb/datasets/schirrmeister2017.py | 58 ++-- moabb/datasets/ssvep_exo.py | 18 +- moabb/datasets/ssvep_mamem.py | 14 +- moabb/datasets/ssvep_nakanishi.py | 46 +-- moabb/datasets/ssvep_wang.py | 20 +- moabb/datasets/upper_limb.py | 42 +-- moabb/datasets/utils.py | 20 +- moabb/evaluations/base.py | 20 +- moabb/evaluations/evaluations.py | 52 ++-- moabb/paradigms/base.py | 12 +- moabb/paradigms/motor_imagery.py | 46 +-- moabb/paradigms/p300.py | 28 +- moabb/paradigms/ssvep.py | 16 +- moabb/pipelines/classification.py | 6 +- moabb/pipelines/csp.py | 14 +- moabb/pipelines/utils.py | 10 +- moabb/run.py | 42 +-- moabb/tests/analysis.py | 104 +++---- 
moabb/tests/datasets.py | 10 +- moabb/tests/download.py | 6 +- moabb/tests/evaluations.py | 14 +- moabb/tests/paradigms.py | 102 +++---- moabb/tests/util_tests.py | 26 +- moabb/utils.py | 4 +- pipelines/CSP_svm_search.py | 6 +- pipelines/FBCSP.py | 12 +- pipelines/LogVar.py | 6 +- pipelines/TSSVM.py | 8 +- pipelines/WTRCSP.py | 4 +- setup.py | 28 +- tutorials/plot_Getting_Started.py | 14 +- tutorials/plot_explore_paradigm.py | 2 +- tutorials/plot_statistical_analysis.py | 16 +- tutorials/select_electrodes_resample.py | 10 +- ...tutorial_1_simple_example_motor_imagery.py | 18 +- .../tutorial_2_using_mulitple_datasets.py | 12 +- ...orial_3_benchmarking_multiple_pipelines.py | 18 +- tutorials/tutorial_4_adding_a_dataset.py | 34 +-- 62 files changed, 1081 insertions(+), 1081 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 3bcba8365..4ff1dc3ca 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -22,21 +22,21 @@ import moabb # noqa: F401 -sys.path.insert(0, os.path.abspath('../')) -sys.path.insert(0, os.path.abspath('../../')) +sys.path.insert(0, os.path.abspath("../")) +sys.path.insert(0, os.path.abspath("../../")) -matplotlib.use('Agg') +matplotlib.use("Agg") # -- Project information ----------------------------------------------------- -project = 'moabb' -copyright = '2018, Alexandre Barachant' -author = 'Alexandre Barachant, Vinay Jayaram' +project = "moabb" +copyright = "2018, Alexandre Barachant" +author = "Alexandre Barachant, Vinay Jayaram" # The short X.Y version -version = '' +version = "" # The full version, including alpha/beta/rc tags -release = '' +release = "" # -- General configuration --------------------------------------------------- @@ -49,17 +49,17 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.doctest', - 'sphinx.ext.intersphinx', - 'sphinx.ext.todo', - 'sphinx.ext.coverage', - 'sphinx.ext.imgmath', - 'sphinx.ext.viewcode', - 'sphinx.ext.napoleon', - 'sphinx.ext.autosummary', - 'sphinx_gallery.gen_gallery', - 'm2r', + "sphinx.ext.autodoc", + "sphinx.ext.doctest", + "sphinx.ext.intersphinx", + "sphinx.ext.todo", + "sphinx.ext.coverage", + "sphinx.ext.imgmath", + "sphinx.ext.viewcode", + "sphinx.ext.napoleon", + "sphinx.ext.autosummary", + "sphinx_gallery.gen_gallery", + "m2r", ] napoleon_google_docstring = False @@ -72,24 +72,24 @@ plot_html_show_source_link = False sphinx_gallery_conf = { - 'examples_dirs': ['../../examples', '../../tutorials'], - 'gallery_dirs': ['auto_examples', 'auto_tutorials'], - 'backreferences_dir': False, + "examples_dirs": ["../../examples", "../../tutorials"], + "gallery_dirs": ["auto_examples", "auto_tutorials"], + "backreferences_dir": False, } # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] -autodoc_default_flags = ['inherited-members'] +autodoc_default_flags = ["inherited-members"] autosummary_generate = True numpydoc_show_class_members = True # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # -source_suffix = ['.rst', '.md'] +source_suffix = [".rst", ".md"] # The master toctree document. -master_doc = 'index' +master_doc = "index" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -104,7 +104,7 @@ exclude_patterns = [] # The name of the Pygments (syntax highlighting) style to use. 
-pygments_style = 'sphinx' +pygments_style = "sphinx" # -- Options for HTML output ------------------------------------------------- @@ -112,7 +112,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'bootstrap' +html_theme = "bootstrap" html_theme_path = sphinx_bootstrap_theme.get_html_theme_path() @@ -134,20 +134,20 @@ # (name, "http://example.com", True) # arbitrary absolute url # Note the "1" or "True" value above as the third argument to indicate # an arbitrary url. - 'navbar_links': [ + "navbar_links": [ ("API", "api"), ("Gallery", "auto_examples/index"), ("Tutorials", "auto_tutorials/index"), ], # Render the next and previous page links in navbar. (Default: true) - 'navbar_sidebarrel': False, + "navbar_sidebarrel": False, # Render the current pages TOC in the navbar. (Default: true) - 'navbar_pagenav': True, + "navbar_pagenav": True, # Tab name for the current pages TOC. (Default: "Page") - 'navbar_pagenav_name': "Page", + "navbar_pagenav_name": "Page", # Global TOC depth for "site" navbar tab. (Default: 1) # Switching to -1 shows all levels. - 'globaltoc_depth': 2, + "globaltoc_depth": 2, # Include hidden TOCs in Site navbar? # # Note: If this is "false", you cannot have mixed ``:hidden:`` and @@ -155,16 +155,16 @@ # will break. # # Values: "true" (default) or "false" - 'globaltoc_includehidden': "true", + "globaltoc_includehidden": "true", # HTML navbar class (Default: "navbar") to attach to
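<div>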
element. # For black navbar, do "navbar navbar-inverse" - 'navbar_class': "navbar navbar-inverse", + "navbar_class": "navbar navbar-inverse", # Fix navigation bar to top of page? # Values: "true" (default) or "false" - 'navbar_fixed_top': "true", + "navbar_fixed_top": "true", # Location of link to source. # Options are "nav" (default), "footer" or anything else to exclude. - 'source_link_position': "footer", + "source_link_position": "footer", # Bootswatch (http://bootswatch.com/) theme. # # Options are nothing (default) or the name of a valid theme @@ -176,16 +176,16 @@ # Currently, the supported themes are: # - Bootstrap 2: https://bootswatch.com/2 # - Bootstrap 3: https://bootswatch.com/3 - 'bootswatch_theme': "united", + "bootswatch_theme": "united", # Choose Bootstrap version. # Values: "3" (default) or "2" (in quotes) - 'bootstrap_version': "3", + "bootstrap_version": "3", } # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # Custom sidebar templates, must be a dictionary that maps document names # to template names. @@ -201,7 +201,7 @@ # -- Options for HTMLHelp output --------------------------------------------- # Output file base name for HTML help builder. -htmlhelp_basename = 'moabbdoc' +htmlhelp_basename = "moabbdoc" # -- Options for LaTeX output ------------------------------------------------ @@ -225,7 +225,7 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'moabb.tex', 'moabb Documentation', 'Alexandre Barachant', 'manual'), + (master_doc, "moabb.tex", "moabb Documentation", "Alexandre Barachant", "manual"), ] @@ -233,7 +233,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [(master_doc, 'moabb', 'moabb Documentation', [author], 1)] +man_pages = [(master_doc, "moabb", "moabb Documentation", [author], 1)] # -- Options for Texinfo output ---------------------------------------------- @@ -244,12 +244,12 @@ texinfo_documents = [ ( master_doc, - 'moabb', - 'moabb Documentation', + "moabb", + "moabb Documentation", author, - 'moabb', - 'One line description of project.', - 'Miscellaneous', + "moabb", + "One line description of project.", + "Miscellaneous", ), ] @@ -259,7 +259,7 @@ # -- Options for intersphinx extension --------------------------------------- # Example configuration for intersphinx: refer to the Python standard library. 
-intersphinx_mapping = {'https://docs.python.org/': None} +intersphinx_mapping = {"https://docs.python.org/": None} # -- Options for todo extension ---------------------------------------------- diff --git a/examples/plot_cross_session_motor_imagery.py b/examples/plot_cross_session_motor_imagery.py index 96bab1f03..cf8187c71 100644 --- a/examples/plot_cross_session_motor_imagery.py +++ b/examples/plot_cross_session_motor_imagery.py @@ -38,7 +38,7 @@ from moabb.paradigms import LeftRightImagery -moabb.set_log_level('info') +moabb.set_log_level("info") ############################################################################## # Create pipelines @@ -54,10 +54,10 @@ pipelines = {} -pipelines['CSP + LDA'] = make_pipeline(CSP(n_components=8), LDA()) +pipelines["CSP + LDA"] = make_pipeline(CSP(n_components=8), LDA()) -pipelines['RG + LR'] = make_pipeline( - Covariances(), TangentSpace(), LogisticRegression(solver='lbfgs') +pipelines["RG + LR"] = make_pipeline( + Covariances(), TangentSpace(), LogisticRegression(solver="lbfgs") ) ############################################################################## @@ -79,7 +79,7 @@ datasets = [dataset] overwrite = False # set to True if we want to overwrite cached results evaluation = CrossSessionEvaluation( - paradigm=paradigm, datasets=datasets, suffix='examples', overwrite=overwrite + paradigm=paradigm, datasets=datasets, suffix="examples", overwrite=overwrite ) results = evaluation.process(pipelines) @@ -100,27 +100,27 @@ sns.stripplot( data=results, - y='score', - x='pipeline', + y="score", + x="pipeline", ax=axes[0], jitter=True, alpha=0.5, zorder=1, palette="Set1", ) -sns.pointplot(data=results, y='score', x='pipeline', ax=axes[0], zorder=1, palette="Set1") +sns.pointplot(data=results, y="score", x="pipeline", ax=axes[0], zorder=1, palette="Set1") -axes[0].set_ylabel('ROC AUC') +axes[0].set_ylabel("ROC AUC") axes[0].set_ylim(0.5, 1) # paired plot paired = results.pivot_table( - values='score', columns='pipeline', index=['subject', 'session'] + values="score", columns="pipeline", index=["subject", "session"] ) paired = paired.reset_index() -sns.regplot(data=paired, y='RG + LR', x='CSP + LDA', ax=axes[1], fit_reg=False) -axes[1].plot([0, 1], [0, 1], ls='--', c='k') +sns.regplot(data=paired, y="RG + LR", x="CSP + LDA", ax=axes[1], fit_reg=False) +axes[1].plot([0, 1], [0, 1], ls="--", c="k") axes[1].set_xlim(0.5, 1) plt.show() diff --git a/examples/plot_cross_session_multiple_datasets.py b/examples/plot_cross_session_multiple_datasets.py index 2e4363c8a..b9ca2db51 100644 --- a/examples/plot_cross_session_multiple_datasets.py +++ b/examples/plot_cross_session_multiple_datasets.py @@ -28,9 +28,9 @@ from moabb.pipelines import SSVEP_CCA -warnings.simplefilter(action='ignore', category=FutureWarning) -warnings.simplefilter(action='ignore', category=RuntimeWarning) -moabb.set_log_level('info') +warnings.simplefilter(action="ignore", category=FutureWarning) +warnings.simplefilter(action="ignore", category=RuntimeWarning) +moabb.set_log_level("info") ############################################################################### # Loading dataset @@ -97,7 +97,7 @@ overwrite = True # set to True if we want to overwrite cached results evaluation = CrossSessionEvaluation( - paradigm=paradigm, datasets=datasets, suffix='examples', overwrite=overwrite + paradigm=paradigm, datasets=datasets, suffix="examples", overwrite=overwrite ) results = evaluation.process(pipeline) @@ -111,11 +111,11 @@ sns.catplot( data=results, - x='session', - y='score', - 
hue='subject', - col='dataset', - kind='bar', - palette='viridis', + x="session", + y="score", + hue="subject", + col="dataset", + kind="bar", + palette="viridis", ) plt.show() diff --git a/examples/plot_cross_session_ssvep.py b/examples/plot_cross_session_ssvep.py index 92e2f6ea5..ac69fad25 100644 --- a/examples/plot_cross_session_ssvep.py +++ b/examples/plot_cross_session_ssvep.py @@ -28,9 +28,9 @@ from moabb.pipelines import SSVEP_CCA -warnings.simplefilter(action='ignore', category=FutureWarning) -warnings.simplefilter(action='ignore', category=RuntimeWarning) -moabb.set_log_level('info') +warnings.simplefilter(action="ignore", category=FutureWarning) +warnings.simplefilter(action="ignore", category=RuntimeWarning) +moabb.set_log_level("info") ############################################################################### # Loading dataset @@ -90,7 +90,7 @@ overwrite = True # set to True if we want to overwrite cached results evaluation = CrossSessionEvaluation( - paradigm=paradigm, datasets=dataset, suffix='examples', overwrite=overwrite + paradigm=paradigm, datasets=dataset, suffix="examples", overwrite=overwrite ) results = evaluation.process(pipeline) @@ -103,6 +103,6 @@ # Here we plot the results, indicating the score for each session and subject plt.figure() -sns.barplot(data=results, y='score', x='session', hue='subject', palette='viridis') +sns.barplot(data=results, y="score", x="session", hue="subject", palette="viridis") plt.show() diff --git a/examples/plot_cross_subject_ssvep.py b/examples/plot_cross_subject_ssvep.py index fc064d054..41ffa7975 100644 --- a/examples/plot_cross_subject_ssvep.py +++ b/examples/plot_cross_subject_ssvep.py @@ -29,9 +29,9 @@ from moabb.pipelines import SSVEP_CCA, ExtendedSSVEPSignal -warnings.simplefilter(action='ignore', category=FutureWarning) -warnings.simplefilter(action='ignore', category=RuntimeWarning) -moabb.set_log_level('info') +warnings.simplefilter(action="ignore", category=FutureWarning) +warnings.simplefilter(action="ignore", category=RuntimeWarning) +moabb.set_log_level("info") ############################################################################### # Loading dataset @@ -84,15 +84,15 @@ # The second pipeline relies on the above defined CCA classifier. pipelines_fb = {} -pipelines_fb['RG + LogReg'] = make_pipeline( +pipelines_fb["RG + LogReg"] = make_pipeline( ExtendedSSVEPSignal(), - Covariances(estimator='lwf'), + Covariances(estimator="lwf"), TangentSpace(), - LogisticRegression(solver='lbfgs', multi_class='auto'), + LogisticRegression(solver="lbfgs", multi_class="auto"), ) pipelines = {} -pipelines['CCA'] = make_pipeline(SSVEP_CCA(interval=interval, freqs=freqs, n_harmonics=3)) +pipelines["CCA"] = make_pipeline(SSVEP_CCA(interval=interval, freqs=freqs, n_harmonics=3)) ############################################################################## # Evaluation @@ -130,19 +130,19 @@ # # Here we plot the results. 
-fig, ax = plt.subplots(facecolor='white', figsize=[8, 4]) +fig, ax = plt.subplots(facecolor="white", figsize=[8, 4]) sns.stripplot( data=results, - y='score', - x='pipeline', + y="score", + x="pipeline", ax=ax, jitter=True, alpha=0.5, zorder=1, palette="Set1", ) -sns.pointplot(data=results, y='score', x='pipeline', ax=ax, zorder=1, palette="Set1") -ax.set_ylabel('Accuracy') +sns.pointplot(data=results, y="score", x="pipeline", ax=ax, zorder=1, palette="Set1") +ax.set_ylabel("Accuracy") ax.set_ylim(0.1, 0.6) -plt.savefig('ssvep.png') +plt.savefig("ssvep.png") fig.show() diff --git a/examples/plot_filterbank_csp_vs_csp.py b/examples/plot_filterbank_csp_vs_csp.py index c0e2adaa6..7fc02ed86 100644 --- a/examples/plot_filterbank_csp_vs_csp.py +++ b/examples/plot_filterbank_csp_vs_csp.py @@ -24,7 +24,7 @@ from moabb.pipelines.utils import FilterBank -moabb.set_log_level('info') +moabb.set_log_level("info") ############################################################################## # Create pipelines @@ -41,10 +41,10 @@ # their own dict. pipelines = {} -pipelines['CSP + LDA'] = make_pipeline(CSP(n_components=8), LDA()) +pipelines["CSP + LDA"] = make_pipeline(CSP(n_components=8), LDA()) pipelines_fb = {} -pipelines_fb['FBCSP + LDA'] = make_pipeline(FilterBank(CSP(n_components=4)), LDA()) +pipelines_fb["FBCSP + LDA"] = make_pipeline(FilterBank(CSP(n_components=4)), LDA()) ############################################################################## # Evaluation @@ -71,7 +71,7 @@ fmax = 35 paradigm = LeftRightImagery(fmin=fmin, fmax=fmax) evaluation = CrossSessionEvaluation( - paradigm=paradigm, datasets=datasets, suffix='examples', overwrite=overwrite + paradigm=paradigm, datasets=datasets, suffix="examples", overwrite=overwrite ) results = evaluation.process(pipelines) @@ -79,7 +79,7 @@ filters = [[8, 12], [12, 16], [16, 20], [20, 24], [24, 28], [28, 35]] paradigm = FilterBankLeftRightImagery(filters=filters) evaluation = CrossSessionEvaluation( - paradigm=paradigm, datasets=datasets, suffix='examples', overwrite=overwrite + paradigm=paradigm, datasets=datasets, suffix="examples", overwrite=overwrite ) results_fb = evaluation.process(pipelines_fb) @@ -103,27 +103,27 @@ sns.stripplot( data=results, - y='score', - x='pipeline', + y="score", + x="pipeline", ax=axes[0], jitter=True, alpha=0.5, zorder=1, palette="Set1", ) -sns.pointplot(data=results, y='score', x='pipeline', ax=axes[0], zorder=1, palette="Set1") +sns.pointplot(data=results, y="score", x="pipeline", ax=axes[0], zorder=1, palette="Set1") -axes[0].set_ylabel('ROC AUC') +axes[0].set_ylabel("ROC AUC") axes[0].set_ylim(0.5, 1) # paired plot paired = results.pivot_table( - values='score', columns='pipeline', index=['subject', 'session'] + values="score", columns="pipeline", index=["subject", "session"] ) paired = paired.reset_index() -sns.regplot(data=paired, y='FBCSP + LDA', x='CSP + LDA', ax=axes[1], fit_reg=False) -axes[1].plot([0, 1], [0, 1], ls='--', c='k') +sns.regplot(data=paired, y="FBCSP + LDA", x="CSP + LDA", ax=axes[1], fit_reg=False) +axes[1].plot([0, 1], [0, 1], ls="--", c="k") axes[1].set_xlim(0.5, 1) plt.show() diff --git a/examples/plot_within_session_p300.py b/examples/plot_within_session_p300.py index 0a23901c7..b0fd1a98a 100644 --- a/examples/plot_within_session_p300.py +++ b/examples/plot_within_session_p300.py @@ -37,11 +37,11 @@ from moabb.paradigms import P300 -warnings.simplefilter(action='ignore', category=FutureWarning) -warnings.simplefilter(action='ignore', category=RuntimeWarning) 
+warnings.simplefilter(action="ignore", category=FutureWarning) +warnings.simplefilter(action="ignore", category=RuntimeWarning) -moabb.set_log_level('info') +moabb.set_log_level("info") # This is an auxiliary transformer that allows one to vectorize data # structures in a pipeline For instance, in the case of a X with dimensions @@ -74,18 +74,18 @@ def transform(self, X): # we have to do this because the classes are called 'Target' and 'NonTarget' # but the evaluation function uses a LabelEncoder, transforming them # to 0 and 1 -labels_dict = {'Target': 1, 'NonTarget': 0} +labels_dict = {"Target": 1, "NonTarget": 0} -pipelines['RG + LDA'] = make_pipeline( +pipelines["RG + LDA"] = make_pipeline( XdawnCovariances( - nfilter=2, classes=[labels_dict['Target']], estimator='lwf', xdawn_estimator='lwf' + nfilter=2, classes=[labels_dict["Target"]], estimator="lwf", xdawn_estimator="lwf" ), TangentSpace(), - LDA(solver='lsqr', shrinkage='auto'), + LDA(solver="lsqr", shrinkage="auto"), ) -pipelines['Xdw + LDA'] = make_pipeline( - Xdawn(nfilter=2, estimator='lwf'), Vectorizer(), LDA(solver='lsqr', shrinkage='auto') +pipelines["Xdw + LDA"] = make_pipeline( + Xdawn(nfilter=2, estimator="lwf"), Vectorizer(), LDA(solver="lsqr", shrinkage="auto") ) ############################################################################## @@ -106,7 +106,7 @@ def transform(self, X): datasets = [dataset] overwrite = True # set to True if we want to overwrite cached results evaluation = WithinSessionEvaluation( - paradigm=paradigm, datasets=datasets, suffix='examples', overwrite=overwrite + paradigm=paradigm, datasets=datasets, suffix="examples", overwrite=overwrite ) results = evaluation.process(pipelines) @@ -116,21 +116,21 @@ def transform(self, X): # # Here we plot the results. -fig, ax = plt.subplots(facecolor='white', figsize=[8, 4]) +fig, ax = plt.subplots(facecolor="white", figsize=[8, 4]) sns.stripplot( data=results, - y='score', - x='pipeline', + y="score", + x="pipeline", ax=ax, jitter=True, alpha=0.5, zorder=1, palette="Set1", ) -sns.pointplot(data=results, y='score', x='pipeline', ax=ax, zorder=1, palette="Set1") +sns.pointplot(data=results, y="score", x="pipeline", ax=ax, zorder=1, palette="Set1") -ax.set_ylabel('ROC AUC') +ax.set_ylabel("ROC AUC") ax.set_ylim(0.5, 1) fig.show() diff --git a/moabb/analysis/__init__.py b/moabb/analysis/__init__.py index 726ee1e0b..f50001757 100644 --- a/moabb/analysis/__init__.py +++ b/moabb/analysis/__init__.py @@ -14,7 +14,7 @@ log = logging.getLogger() -def analyze(results, out_path, name='analysis', plot=False): +def analyze(results, out_path, name="analysis", plot=False): """Analyze results. 
Given a results dataframe, generates a folder with @@ -36,9 +36,9 @@ def analyze(results, out_path, name='analysis', plot=False): """ # input checks # if not isinstance(out_path, str): - raise ValueError('Given out_path argument is not string') + raise ValueError("Given out_path argument is not string") elif not os.path.isdir(out_path): - raise IOError('Given directory does not exist') + raise IOError("Given directory does not exist") else: analysis_path = os.path.join(out_path, name) @@ -47,24 +47,24 @@ def analyze(results, out_path, name='analysis', plot=False): print(unique_ids) print(set(unique_ids)) if len(unique_ids) != len(set(unique_ids)): - log.warning('Pipeline names are too similar, turning off name shortening') + log.warning("Pipeline names are too similar, turning off name shortening") simplify = False os.makedirs(analysis_path, exist_ok=True) # TODO: no good cross-platform way of recording CPU info? - with open(os.path.join(analysis_path, 'info.txt'), 'a') as f: + with open(os.path.join(analysis_path, "info.txt"), "a") as f: dt = datetime.now() - f.write('Date: {:%Y-%m-%d}\n Time: {:%H:%M}\n'.format(dt, dt)) - f.write('System: {}\n'.format(platform.system())) - f.write('CPU: {}\n'.format(platform.processor())) + f.write("Date: {:%Y-%m-%d}\n Time: {:%H:%M}\n".format(dt, dt)) + f.write("System: {}\n".format(platform.system())) + f.write("CPU: {}\n".format(platform.processor())) - results.to_csv(os.path.join(analysis_path, 'data.csv')) + results.to_csv(os.path.join(analysis_path, "data.csv")) stats = compute_dataset_statistics(results) - stats.to_csv(os.path.join(analysis_path, 'stats.csv')) + stats.to_csv(os.path.join(analysis_path, "stats.csv")) P, T = find_significant_differences(stats) if plot: fig, color_dict = plt.score_plot(results) - fig.savefig(os.path.join(analysis_path, 'scores.pdf')) + fig.savefig(os.path.join(analysis_path, "scores.pdf")) fig = plt.summary_plot(P, T, simplify=simplify) - fig.savefig(os.path.join(analysis_path, 'ordering.pdf')) + fig.savefig(os.path.join(analysis_path, "ordering.pdf")) diff --git a/moabb/analysis/meta_analysis.py b/moabb/analysis/meta_analysis.py index 6005407bf..f33687a22 100644 --- a/moabb/analysis/meta_analysis.py +++ b/moabb/analysis/meta_analysis.py @@ -10,7 +10,7 @@ def collapse_session_scores(df): - return df.groupby(['pipeline', 'dataset', 'subject']).mean().reset_index() + return df.groupby(["pipeline", "dataset", "subject"]).mean().reset_index() def compute_pvals_wilcoxon(df, order=None): @@ -27,7 +27,7 @@ def compute_pvals_wilcoxon(df, order=None): if order is None: order = df.columns else: - errormsg = 'provided order does not have all columns of dataframe' + errormsg = "provided order does not have all columns of dataframe" assert set(order) == set(df.columns), errormsg out = np.zeros((len(df.columns), len(df.columns))) @@ -103,7 +103,7 @@ def compute_pvals_perm(df, order=None): if order is None: order = df.columns else: - errormsg = 'provided order does not have all columns of dataframe' + errormsg = "provided order does not have all columns of dataframe" assert set(order) == set(df.columns), errormsg # reshape df into matrix (sub, k, k) of differences data = np.zeros((df.shape[0], len(order), len(order))) @@ -133,7 +133,7 @@ def compute_effect(df, order=None): if order is None: order = df.columns else: - errormsg = 'provided order does not have all columns of dataframe' + errormsg = "provided order does not have all columns of dataframe" assert set(order) == set(df.columns), errormsg out = np.zeros((len(df.columns), 
len(df.columns))) @@ -158,21 +158,21 @@ def compute_dataset_statistics(df, perm_cutoff=20): out = {} for d in dsets: score_data = df[df.dataset == d].pivot( - index='subject', values='score', columns='pipeline' + index="subject", values="score", columns="pipeline" ) if score_data.shape[0] < perm_cutoff: p = compute_pvals_perm(score_data, algs) else: p = compute_pvals_wilcoxon(score_data, algs) t = compute_effect(score_data, algs) - P = pd.DataFrame(index=pd.Index(algs, name='pipe1'), columns=algs, data=p) - T = pd.DataFrame(index=pd.Index(algs, name='pipe1'), columns=algs, data=t) - D1 = pd.melt(P.reset_index(), id_vars='pipe1', var_name='pipe2', value_name='p') - D2 = pd.melt(T.reset_index(), id_vars='pipe1', var_name='pipe2', value_name='smd') + P = pd.DataFrame(index=pd.Index(algs, name="pipe1"), columns=algs, data=p) + T = pd.DataFrame(index=pd.Index(algs, name="pipe1"), columns=algs, data=t) + D1 = pd.melt(P.reset_index(), id_vars="pipe1", var_name="pipe2", value_name="p") + D2 = pd.melt(T.reset_index(), id_vars="pipe1", var_name="pipe2", value_name="smd") stats_df = D1.merge(D2) - stats_df['nsub'] = score_data.shape[0] + stats_df["nsub"] = score_data.shape[0] out[d] = stats_df - return pd.concat(out, axis=0, names=['dataset', 'index']).reset_index() + return pd.concat(out, axis=0, names=["dataset", "index"]).reset_index() def combine_effects(effects, nsubs): @@ -194,7 +194,7 @@ def combine_pvalues(p, nsubs): return p.item() else: W = np.sqrt(nsubs) - out = stats.combine_pvalues(np.array(p), weights=W, method='stouffer')[1] + out = stats.combine_pvalues(np.array(p), weights=W, method="stouffer")[1] return out @@ -216,9 +216,9 @@ def find_significant_differences(df, perm_cutoff=20): """ dsets = df.dataset.unique() algs = df.pipe1.unique() - nsubs = np.array([df.loc[df.dataset == d, 'nsub'].mean() for d in dsets]) - P_full = df.pivot_table(values='p', index=['dataset', 'pipe1'], columns='pipe2') - T_full = df.pivot_table(values='smd', index=['dataset', 'pipe1'], columns='pipe2') + nsubs = np.array([df.loc[df.dataset == d, "nsub"].mean() for d in dsets]) + P_full = df.pivot_table(values="p", index=["dataset", "pipe1"], columns="pipe2") + T_full = df.pivot_table(values="smd", index=["dataset", "pipe1"], columns="pipe2") P = np.full((len(algs), len(algs)), np.NaN) T = np.full((len(algs), len(algs)), np.NaN) for i in range(len(algs)): @@ -228,8 +228,8 @@ def find_significant_differences(df, perm_cutoff=20): t = T_full.loc[(slice(None), algs[i]), algs[j]] P[i, j] = combine_pvalues(p, nsubs) if np.isnan(P[i, j]): - log.info('NaN p-value found, turned to 1') - print('NaN') + log.info("NaN p-value found, turned to 1") + print("NaN") # P[i, j] = 1.0 T[i, j] = combine_effects(t, nsubs) dfP = pd.DataFrame(index=algs, columns=algs, data=P) diff --git a/moabb/analysis/plotting.py b/moabb/analysis/plotting.py index 994811984..aa772576e 100644 --- a/moabb/analysis/plotting.py +++ b/moabb/analysis/plotting.py @@ -15,14 +15,14 @@ PIPELINE_PALETTE = sea.color_palette("husl", 6) -sea.set(font='serif', style='whitegrid', palette=PIPELINE_PALETTE, color_codes=False) +sea.set(font="serif", style="whitegrid", palette=PIPELINE_PALETTE, color_codes=False) log = logging.getLogger() def _simplify_names(x): if len(x) > 10: - return x.split(' ')[0] + return x.split(" ")[0] else: return x @@ -36,7 +36,7 @@ def score_plot(data, pipelines=None): ax: pyplot Axes reference """ data = collapse_session_scores(data) - data['dataset'] = data['dataset'].apply(_simplify_names) + data["dataset"] = 
data["dataset"].apply(_simplify_names) if pipelines is not None: data = data[data.pipeline.isin(pipelines)] fig = plt.figure(figsize=(8.5, 11)) @@ -48,14 +48,14 @@ def score_plot(data, pipelines=None): x="score", jitter=0.15, palette=PIPELINE_PALETTE, - hue='pipeline', + hue="pipeline", dodge=True, ax=ax, alpha=0.7, ) ax.set_xlim([0, 1]) - ax.axvline(0.5, linestyle='--', color='k', linewidth=2) - ax.set_title('Scores per dataset and algorithm') + ax.axvline(0.5, linestyle="--", color="k", linewidth=2) + ax.set_title("Scores per dataset and algorithm") handles, labels = ax.get_legend_handles_labels() color_dict = {lb: h.get_facecolor()[0] for lb, h in zip(labels, handles)} plt.tight_layout() @@ -73,13 +73,13 @@ def paired_plot(data, alg1, alg2): data = collapse_session_scores(data) data = data[data.pipeline.isin([alg1, alg2])] data = data.pivot_table( - values='score', columns='pipeline', index=['subject', 'dataset'] + values="score", columns="pipeline", index=["subject", "dataset"] ) data = data.reset_index() fig = plt.figure(figsize=(11, 8.5)) ax = fig.add_subplot(111) data.plot.scatter(alg1, alg2, ax=ax) - ax.plot([0, 1], [0, 1], ls='--', c='k') + ax.plot([0, 1], [0, 1], ls="--", c="k") ax.set_xlim([0.5, 1]) ax.set_ylim([0.5, 1]) return fig @@ -99,7 +99,7 @@ def summary_plot(sig_df, effect_df, p_threshold=0.05, simplify=True): for row in annot_df.index: for col in annot_df.columns: if effect_df.loc[row, col] > 0: - txt = '{:.2f}\np={:1.0e}'.format( + txt = "{:.2f}\np={:1.0e}".format( effect_df.loc[row, col], sig_df.loc[row, col] ) else: @@ -107,7 +107,7 @@ def summary_plot(sig_df, effect_df, p_threshold=0.05, simplify=True): # TODO: current is hack if sig_df.loc[row, col] < p_threshold: sig_df.loc[row, col] = 1e-110 - txt = '' + txt = "" annot_df.loc[row, col] = txt fig = plt.figure() ax = fig.add_subplot(111) @@ -117,18 +117,18 @@ def summary_plot(sig_df, effect_df, p_threshold=0.05, simplify=True): sea.heatmap( data=-np.log(sig_df), annot=annot_df, - fmt='', + fmt="", cmap=palette, linewidths=1, - linecolor='0.8', - annot_kws={'size': 10}, + linecolor="0.8", + annot_kws={"size": 10}, cbar=False, vmin=-np.log(0.05), vmax=-np.log(1e-100), ) for lb in ax.get_xticklabels(): lb.set_rotation(45) - ax.tick_params(axis='y', rotation=0.9) + ax.tick_params(axis="y", rotation=0.9) ax.set_title("Algorithm comparison") plt.tight_layout() return fig @@ -141,20 +141,20 @@ def meta_analysis_plot(stats_df, alg1, alg2): # noqa: C901 def _marker(pval): if pval < 0.001: - return '$***$', 100 + return "$***$", 100 elif pval < 0.01: - return '$**$', 70 + return "$**$", 70 elif pval < 0.05: - return '$*$', 30 + return "$*$", 30 else: - raise ValueError('insignificant pval {}'.format(pval)) + raise ValueError("insignificant pval {}".format(pval)) assert alg1 in stats_df.pipe1.unique() assert alg2 in stats_df.pipe1.unique() df_fw = stats_df.loc[(stats_df.pipe1 == alg1) & (stats_df.pipe2 == alg2)] - df_fw = df_fw.sort_values(by='pipe1') + df_fw = df_fw.sort_values(by="pipe1") df_bk = stats_df.loc[(stats_df.pipe1 == alg2) & (stats_df.pipe2 == alg1)] - df_bk = df_bk.sort_values(by='pipe1') + df_bk = df_bk.sort_values(by="pipe1") dsets = df_fw.dataset.unique() ci = [] fig = plt.figure() @@ -163,48 +163,48 @@ def _marker(pval): pvals = [] ax = fig.add_subplot(gs[0, :-1]) ax.set_yticks(np.arange(len(dsets) + 1)) - ax.set_yticklabels(['Meta-effect'] + [_simplify_names(d) for d in dsets]) + ax.set_yticklabels(["Meta-effect"] + [_simplify_names(d) for d in dsets]) pval_ax = fig.add_subplot(gs[0, -1], sharey=ax) 
plt.setp(pval_ax.get_yticklabels(), visible=False) _min = 0 _max = 0 for ind, d in enumerate(dsets): - nsub = float(df_fw.loc[df_fw.dataset == d, 'nsub']) + nsub = float(df_fw.loc[df_fw.dataset == d, "nsub"]) t_dof = nsub - 1 ci.append(t.ppf(0.95, t_dof) / np.sqrt(nsub)) - v = float(df_fw.loc[df_fw.dataset == d, 'smd']) + v = float(df_fw.loc[df_fw.dataset == d, "smd"]) if v > 0: - p = df_fw.loc[df_fw.dataset == d, 'p'].item() + p = df_fw.loc[df_fw.dataset == d, "p"].item() if p < 0.05: sig_ind.append(ind) pvals.append(p) else: - p = df_bk.loc[df_bk.dataset == d, 'p'].item() + p = df_bk.loc[df_bk.dataset == d, "p"].item() if p < 0.05: sig_ind.append(ind) pvals.append(p) _min = _min if (_min < (v - ci[-1])) else (v - ci[-1]) _max = _max if (_max > (v + ci[-1])) else (v + ci[-1]) ax.plot( - np.array([v - ci[-1], v + ci[-1]]), np.ones((2,)) * (ind + 1), c='tab:grey' + np.array([v - ci[-1], v + ci[-1]]), np.ones((2,)) * (ind + 1), c="tab:grey" ) _range = max(abs(_min), abs(_max)) ax.set_xlim((0 - _range, 0 + _range)) - final_effect = combine_effects(df_fw['smd'], df_fw['nsub']) + final_effect = combine_effects(df_fw["smd"], df_fw["nsub"]) ax.scatter( - pd.concat([pd.Series([final_effect]), df_fw['smd']]), + pd.concat([pd.Series([final_effect]), df_fw["smd"]]), np.arange(len(dsets) + 1), s=np.array([50] + [30] * len(dsets)), - marker='D', - c=['k'] + ['tab:grey'] * len(dsets), + marker="D", + c=["k"] + ["tab:grey"] * len(dsets), ) for i, p in zip(sig_ind, pvals): m, s = _marker(p) - ax.scatter(df_fw['smd'].iloc[i], i + 1.4, s=s, marker=m, color='r') + ax.scatter(df_fw["smd"].iloc[i], i + 1.4, s=s, marker=m, color="r") # pvalues axis stuf pval_ax.set_xlim([-0.1, 0.1]) pval_ax.grid(False) - pval_ax.set_title('p-value', fontdict={'fontsize': 10}) + pval_ax.set_title("p-value", fontdict={"fontsize": 10}) pval_ax.set_xticks([]) for spine in pval_ax.spines.values(): spine.set_visible(False) @@ -212,48 +212,48 @@ def _marker(pval): pval_ax.text( 0, ind + 1, - horizontalalignment='center', - verticalalignment='center', - s='{:.2e}'.format(p), + horizontalalignment="center", + verticalalignment="center", + s="{:.2e}".format(p), fontsize=8, ) if final_effect > 0: - p = combine_pvalues(df_fw['p'], df_fw['nsub']) + p = combine_pvalues(df_fw["p"], df_fw["nsub"]) if p < 0.05: m, s = _marker(p) - ax.scatter([final_effect], [-0.4], s=s, marker=m, c='r') + ax.scatter([final_effect], [-0.4], s=s, marker=m, c="r") pval_ax.text( 0, 0, - horizontalalignment='center', - verticalalignment='center', - s='{:.2e}'.format(p), + horizontalalignment="center", + verticalalignment="center", + s="{:.2e}".format(p), fontsize=8, ) else: - p = combine_pvalues(df_bk['p'], df_bk['nsub']) + p = combine_pvalues(df_bk["p"], df_bk["nsub"]) if p < 0.05: m, s = _marker(p) - ax.scatter([final_effect], [-0.4], s=s, marker=m, c='r') + ax.scatter([final_effect], [-0.4], s=s, marker=m, c="r") pval_ax.text( 0, 0, - horizontalalignment='center', - verticalalignment='center', - s='{:.2e}'.format(p), + horizontalalignment="center", + verticalalignment="center", + s="{:.2e}".format(p), fontsize=8, ) ax.grid(False) - ax.spines['top'].set_visible(False) - ax.spines['right'].set_visible(False) - ax.axvline(0, linestyle='--', c='k') - ax.axhline(0.5, linestyle='-', linewidth=3, c='k') - title = '< {} better{}\n{}{} better >'.format( - alg2, ' ' * (45 - len(alg2)), ' ' * (45 - len(alg1)), alg1 + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + ax.axvline(0, linestyle="--", c="k") + ax.axhline(0.5, linestyle="-", 
linewidth=3, c="k") + title = "< {} better{}\n{}{} better >".format( + alg2, " " * (45 - len(alg2)), " " * (45 - len(alg1)), alg1 ) - ax.set_title(title, ha='left', ma='right', loc='left') - ax.set_xlabel('Standardized Mean Difference') + ax.set_title(title, ha="left", ma="right", loc="left") + ax.set_xlabel("Standardized Mean Difference") fig.tight_layout() return fig diff --git a/moabb/analysis/results.py b/moabb/analysis/results.py index fc065e21a..c1ff202fa 100644 --- a/moabb/analysis/results.py +++ b/moabb/analysis/results.py @@ -15,8 +15,8 @@ def get_string_rep(obj): str_repr = repr(obj.get_params()) else: str_repr = repr(obj) - str_no_addresses = re.sub('0x[a-z0-9]*', '0x__', str_repr) - return str_no_addresses.replace('\n', '').encode('utf8') + str_no_addresses = re.sub("0x[a-z0-9]*", "0x__", str_repr) + return str_no_addresses.replace("\n", "").encode("utf8") def get_digest(obj): @@ -42,7 +42,7 @@ def __init__( self, evaluation_class, paradigm_class, - suffix='', + suffix="", overwrite=False, hdf5_path=None, additional_columns=None, @@ -69,10 +69,10 @@ class that will abstract result storage self.mod_dir = os.path.abspath(hdf5_path) self.filepath = os.path.join( self.mod_dir, - 'results', + "results", paradigm_class.__name__, evaluation_class.__name__, - 'results{}.hdf5'.format('_' + suffix), + "results{}.hdf5".format("_" + suffix), ) os.makedirs(os.path.dirname(self.filepath), exist_ok=True) @@ -82,9 +82,9 @@ class that will abstract result storage os.remove(self.filepath) if not os.path.isfile(self.filepath): - with h5py.File(self.filepath, 'w') as f: - f.attrs['create_time'] = np.string_( - '{:%Y-%m-%d, %H:%M}'.format(datetime.now()) + with h5py.File(self.filepath, "w") as f: + f.attrs["create_time"] = np.string_( + "{:%Y-%m-%d, %H:%M}".format(datetime.now()) ) def add(self, results, pipelines): @@ -101,7 +101,7 @@ def to_list(res): else: return res - with h5py.File(self.filepath, 'r+') as f: + with h5py.File(self.filepath, "r+") as f: for name, data_dict in results.items(): digest = get_digest(pipelines[name]) if digest not in f.keys(): @@ -109,46 +109,46 @@ def to_list(res): f.create_group(digest) ppline_grp = f[digest] - ppline_grp.attrs['name'] = name - ppline_grp.attrs['repr'] = repr(pipelines[name]) + ppline_grp.attrs["name"] = name + ppline_grp.attrs["repr"] = repr(pipelines[name]) dlist = to_list(data_dict) d1 = dlist[0] # FIXME: handle multiple session ? 
- dname = d1['dataset'].code + dname = d1["dataset"].code n_add_cols = len(self.additional_columns) if dname not in ppline_grp.keys(): # create dataset subgroup if nonexistant dset = ppline_grp.create_group(dname) - dset.attrs['n_subj'] = len(d1['dataset'].subject_list) - dset.attrs['n_sessions'] = d1['dataset'].n_sessions + dset.attrs["n_subj"] = len(d1["dataset"].subject_list) + dset.attrs["n_sessions"] = d1["dataset"].n_sessions dt = h5py.special_dtype(vlen=str) - dset.create_dataset('id', (0, 2), dtype=dt, maxshape=(None, 2)) + dset.create_dataset("id", (0, 2), dtype=dt, maxshape=(None, 2)) dset.create_dataset( - 'data', (0, 3 + n_add_cols), maxshape=(None, 3 + n_add_cols) + "data", (0, 3 + n_add_cols), maxshape=(None, 3 + n_add_cols) ) - dset.attrs['channels'] = d1['n_channels'] + dset.attrs["channels"] = d1["n_channels"] dset.attrs.create( - 'columns', - ['score', 'time', 'samples', *self.additional_columns], + "columns", + ["score", "time", "samples", *self.additional_columns], dtype=dt, ) dset = ppline_grp[dname] for d in dlist: # add id and scores to group - length = len(dset['id']) + 1 - dset['id'].resize(length, 0) - dset['data'].resize(length, 0) - dset['id'][-1, :] = np.asarray([str(d['subject']), str(d['session'])]) + length = len(dset["id"]) + 1 + dset["id"].resize(length, 0) + dset["data"].resize(length, 0) + dset["id"][-1, :] = np.asarray([str(d["subject"]), str(d["session"])]) try: add_cols = [d[ac] for ac in self.additional_columns] except KeyError: raise ValueError( - f'Additional columns: {self.additional_columns} ' - f'were specified in the evaluation, but results' - f' contain only these keys: {d.keys()}.' + f"Additional columns: {self.additional_columns} " + f"were specified in the evaluation, but results" + f" contain only these keys: {d.keys()}." ) - dset['data'][-1, :] = np.asarray( - [d['score'], d['time'], d['n_samples'], *add_cols] + dset["data"][-1, :] = np.asarray( + [d["score"], d["time"], d["n_samples"], *add_cols] ) def to_dataframe(self, pipelines=None): @@ -159,24 +159,24 @@ def to_dataframe(self, pipelines=None): if pipelines is not None: digests = [get_digest(pipelines[name]) for name in pipelines] - with h5py.File(self.filepath, 'r') as f: + with h5py.File(self.filepath, "r") as f: for digest, p_group in f.items(): # skip if not in pipeline list if (pipelines is not None) & (digest not in digests): continue - name = p_group.attrs['name'] + name = p_group.attrs["name"] for dname, dset in p_group.items(): - array = np.array(dset['data']) - ids = np.array(dset['id']) - df = pd.DataFrame(array, columns=dset.attrs['columns']) - df['subject'] = ids[:, 0] - df['session'] = ids[:, 1] - df['channels'] = dset.attrs['channels'] - df['n_sessions'] = dset.attrs['n_sessions'] - df['dataset'] = dname - df['pipeline'] = name + array = np.array(dset["data"]) + ids = np.array(dset["id"]) + df = pd.DataFrame(array, columns=dset.attrs["columns"]) + df["subject"] = ids[:, 0] + df["session"] = ids[:, 1] + df["channels"] = dset.attrs["channels"] + df["n_sessions"] = dset.attrs["n_sessions"] + df["dataset"] = dname + df["pipeline"] = name df_list.append(df) return pd.concat(df_list, ignore_index=True) @@ -193,7 +193,7 @@ def _already_computed(self, pipeline, dataset, subject, session=None): """Check if we have results for a current combination of pipeline / dataset / subject. 
""" - with h5py.File(self.filepath, 'r') as f: + with h5py.File(self.filepath, "r") as f: # get the digest from repr digest = get_digest(pipeline) @@ -208,4 +208,4 @@ def _already_computed(self, pipeline, dataset, subject, session=None): else: # if dataset, check for subject dset = pipe_grp[dataset.code] - return str(subject).encode('utf-8') in dset['id'][:, 0] + return str(subject).encode("utf-8") in dset["id"][:, 0] diff --git a/moabb/datasets/Weibo2014.py b/moabb/datasets/Weibo2014.py index 1f5d10e33..c2545c674 100644 --- a/moabb/datasets/Weibo2014.py +++ b/moabb/datasets/Weibo2014.py @@ -1,7 +1,7 @@ -''' +""" Simple and compound motor imagery https://doi.org/10.1371/journal.pone.0114853 -''' +""" import logging import os @@ -20,25 +20,25 @@ log = logging.getLogger() FILES = [] -FILES.append('https://dataverse.harvard.edu/api/access/datafile/2499178') -FILES.append('https://dataverse.harvard.edu/api/access/datafile/2499182') -FILES.append('https://dataverse.harvard.edu/api/access/datafile/2499179') +FILES.append("https://dataverse.harvard.edu/api/access/datafile/2499178") +FILES.append("https://dataverse.harvard.edu/api/access/datafile/2499182") +FILES.append("https://dataverse.harvard.edu/api/access/datafile/2499179") def eeg_data_path(base_path, subject): - file1_subj = ['cl', 'cyy', 'kyf', 'lnn'] - file2_subj = ['ls', 'ry', 'wcf'] - file3_subj = ['wx', 'yyx', 'zd'] + file1_subj = ["cl", "cyy", "kyf", "lnn"] + file2_subj = ["ls", "ry", "wcf"] + file3_subj = ["wx", "yyx", "zd"] def get_subjects(sub_inds, sub_names, ind): - dataname = 'data{}'.format(ind) - if not os.path.isfile(os.path.join(base_path, dataname + '.zip')): + dataname = "data{}".format(ind) + if not os.path.isfile(os.path.join(base_path, dataname + ".zip")): _fetch_file( FILES[ind], - os.path.join(base_path, dataname + '.zip'), + os.path.join(base_path, dataname + ".zip"), print_destination=False, ) - with z.ZipFile(os.path.join(base_path, dataname + '.zip'), 'r') as f: + with z.ZipFile(os.path.join(base_path, dataname + ".zip"), "r") as f: os.makedirs(os.path.join(base_path, dataname), exist_ok=True) f.extractall(os.path.join(base_path, dataname)) for fname in os.listdir(os.path.join(base_path, dataname)): @@ -46,19 +46,19 @@ def get_subjects(sub_inds, sub_names, ind): if fname.startswith(prefix): os.rename( os.path.join(base_path, dataname, fname), - os.path.join(base_path, 'subject_{}.mat'.format(ind)), + os.path.join(base_path, "subject_{}.mat".format(ind)), ) - os.remove(os.path.join(base_path, dataname + '.zip')) + os.remove(os.path.join(base_path, dataname + ".zip")) shutil.rmtree(os.path.join(base_path, dataname)) - if not os.path.isfile(os.path.join(base_path, 'subject_{}.mat'.format(subject))): + if not os.path.isfile(os.path.join(base_path, "subject_{}.mat".format(subject))): if subject in range(1, 5): get_subjects(list(range(1, 5)), file1_subj, 0) elif subject in range(5, 8): get_subjects(list(range(5, 8)), file2_subj, 1) elif subject in range(8, 11): get_subjects(list(range(8, 11)), file3_subj, 2) - return os.path.join(base_path, 'subject_{}.mat'.format(subject)) + return os.path.join(base_path, "subject_{}.mat".format(subject)) class Weibo2014(BaseDataset): @@ -111,11 +111,11 @@ def __init__(self): right_hand_left_foot=6, rest=7, ), - code='Weibo 2014', + code="Weibo 2014", # Full trial w/ rest is 0-8 interval=[3, 7], - paradigm='imagery', - doi='10.1371/journal.pone.0114853', + paradigm="imagery", + doi="10.1371/journal.pone.0114853", ) def _get_single_subject_data(self, subject): @@ -128,7 +128,7 @@ def 
_get_single_subject_data(self, subject): struct_as_record=False, verify_compressed_data_integrity=False, ) - montage = mne.channels.make_standard_montage('standard_1005') + montage = mne.channels.make_standard_montage("standard_1005") # fmt: off ch_names = [ @@ -139,16 +139,16 @@ def _get_single_subject_data(self, subject): ] # fmt: on - ch_types = ['eeg'] * 62 + ['eog'] * 2 + ch_types = ["eeg"] * 62 + ["eog"] * 2 # FIXME not sure what are those CB1 / CB2 - ch_types[57] = 'misc' - ch_types[61] = 'misc' + ch_types[57] = "misc" + ch_types[61] = "misc" info = mne.create_info( - ch_names=ch_names + ['STIM014'], ch_types=ch_types + ['stim'], sfreq=200 + ch_names=ch_names + ["STIM014"], ch_types=ch_types + ["stim"], sfreq=200 ) # until we get the channel names montage is None - event_ids = data['label'].ravel() - raw_data = np.transpose(data['data'], axes=[2, 0, 1]) + event_ids = data["label"].ravel() + raw_data = np.transpose(data["data"], axes=[2, 0, 1]) # de-mean each trial raw_data = raw_data - np.mean(raw_data, axis=2, keepdims=True) raw_events = np.zeros((raw_data.shape[0], 1, raw_data.shape[2])) @@ -164,14 +164,14 @@ def _get_single_subject_data(self, subject): data=np.concatenate(list(data), axis=1), info=info, verbose=False ) raw.set_montage(montage) - return {'session_0': {'run_0': raw}} + return {"session_0": {"run_0": raw}} def data_path( self, subject, path=None, force_update=False, update_path=None, verbose=None ): if subject not in self.subject_list: raise (ValueError("Invalid subject number")) - key = 'MNE_DATASETS_WEIBO2014_PATH' + key = "MNE_DATASETS_WEIBO2014_PATH" path = _get_path(path, key, "Weibo 2014") _do_path_update(path, True, key, "Weibo 2014") basepath = os.path.join(path, "MNE-weibo-2014") diff --git a/moabb/datasets/Zhou2016.py b/moabb/datasets/Zhou2016.py index 63229a4b3..354f18eac 100644 --- a/moabb/datasets/Zhou2016.py +++ b/moabb/datasets/Zhou2016.py @@ -1,7 +1,7 @@ -''' +""" Simple and compound motor imagery. 
https://doi.org/10.1371/journal.pone.0114853 -''' +""" import os import shutil @@ -16,36 +16,36 @@ from .base import BaseDataset -DATA_PATH = 'https://ndownloader.figshare.com/files/3662952' +DATA_PATH = "https://ndownloader.figshare.com/files/3662952" def local_data_path(base_path, subject): - if not os.path.isdir(os.path.join(base_path, 'subject_{}'.format(subject))): - if not os.path.isdir(os.path.join(base_path, 'data')): + if not os.path.isdir(os.path.join(base_path, "subject_{}".format(subject))): + if not os.path.isdir(os.path.join(base_path, "data")): _fetch_file( - DATA_PATH, os.path.join(base_path, 'data.zip'), print_destination=False + DATA_PATH, os.path.join(base_path, "data.zip"), print_destination=False ) - with z.ZipFile(os.path.join(base_path, 'data.zip'), 'r') as f: + with z.ZipFile(os.path.join(base_path, "data.zip"), "r") as f: f.extractall(base_path) - os.remove(os.path.join(base_path, 'data.zip')) - datapath = os.path.join(base_path, 'data') + os.remove(os.path.join(base_path, "data.zip")) + datapath = os.path.join(base_path, "data") for i in range(1, 5): - os.makedirs(os.path.join(base_path, 'subject_{}'.format(i))) + os.makedirs(os.path.join(base_path, "subject_{}".format(i))) for session in range(1, 4): - for run in ['A', 'B']: + for run in ["A", "B"]: os.rename( - os.path.join(datapath, 'S{}_{}{}.cnt'.format(i, session, run)), + os.path.join(datapath, "S{}_{}{}.cnt".format(i, session, run)), os.path.join( base_path, - 'subject_{}'.format(i), - '{}{}.cnt'.format(session, run), + "subject_{}".format(i), + "{}{}.cnt".format(session, run), ), ) - shutil.rmtree(os.path.join(base_path, 'data')) - subjpath = os.path.join(base_path, 'subject_{}'.format(subject)) + shutil.rmtree(os.path.join(base_path, "data")) + subjpath = os.path.join(base_path, "subject_{}".format(subject)) return [ - [os.path.join(subjpath, '{}{}.cnt'.format(y, x)) for x in ['A', 'B']] - for y in ['1', '2', '3'] + [os.path.join(subjpath, "{}{}.cnt".format(y, x)) for x in ["A", "B"]] + for y in ["1", "2", "3"] ] @@ -83,12 +83,12 @@ def __init__(self): subjects=list(range(1, 5)), sessions_per_subject=3, events=dict(left_hand=1, right_hand=2, feet=3), - code='Zhou 2016', + code="Zhou 2016", # MI 1-6s, prepare 0-1, break 6-10 # boundary effects interval=[0, 5], - paradigm='imagery', - doi='10.1371/journal.pone.0162657', + paradigm="imagery", + doi="10.1371/journal.pone.0162657", ) def _get_single_subject_data(self, subject): @@ -97,18 +97,18 @@ def _get_single_subject_data(self, subject): out = {} for sess_ind, runlist in enumerate(files): - sess_key = 'session_{}'.format(sess_ind) + sess_key = "session_{}".format(sess_ind) out[sess_key] = {} for run_ind, fname in enumerate(runlist): - run_key = 'run_{}'.format(run_ind) - raw = read_raw_cnt(fname, preload=True, eog=['VEOU', 'VEOL']) - stim = raw.annotations.description.astype(np.dtype('<10U')) - stim[stim == '1'] = 'left_hand' - stim[stim == '2'] = 'right_hand' - stim[stim == '3'] = 'feet' + run_key = "run_{}".format(run_ind) + raw = read_raw_cnt(fname, preload=True, eog=["VEOU", "VEOL"]) + stim = raw.annotations.description.astype(np.dtype("<10U")) + stim[stim == "1"] = "left_hand" + stim[stim == "2"] = "right_hand" + stim[stim == "3"] = "feet" raw.annotations.description = stim out[sess_key][run_key] = raw - out[sess_key][run_key].set_montage(make_standard_montage('standard_1005')) + out[sess_key][run_key].set_montage(make_standard_montage("standard_1005")) return out def data_path( @@ -116,7 +116,7 @@ def data_path( ): if subject not in 
self.subject_list: raise (ValueError("Invalid subject number")) - key = 'MNE_DATASETS_ZHOU2016_PATH' + key = "MNE_DATASETS_ZHOU2016_PATH" path = _get_path(path, key, "Zhou 2016") _do_path_update(path, True, key, "Zhou 2016") basepath = os.path.join(path, "MNE-zhou-2016") diff --git a/moabb/datasets/alex_mi.py b/moabb/datasets/alex_mi.py index f488f4d4d..2e234aefd 100644 --- a/moabb/datasets/alex_mi.py +++ b/moabb/datasets/alex_mi.py @@ -8,7 +8,7 @@ from .base import BaseDataset -ALEX_URL = 'https://zenodo.org/record/806023/files/' +ALEX_URL = "https://zenodo.org/record/806023/files/" class AlexMI(BaseDataset): @@ -43,14 +43,14 @@ def __init__(self): subjects=list(range(1, 9)), sessions_per_subject=1, events=dict(right_hand=2, feet=3, rest=4), - code='Alexandre Motor Imagery', + code="Alexandre Motor Imagery", interval=[0, 3], - paradigm='imagery', + paradigm="imagery", ) def _get_single_subject_data(self, subject): """return data for a single subject""" - raw = Raw(self.data_path(subject), preload=True, verbose='ERROR') + raw = Raw(self.data_path(subject), preload=True, verbose="ERROR") return {"session_0": {"run_0": raw}} def data_path( @@ -58,5 +58,5 @@ def data_path( ): if subject not in self.subject_list: raise (ValueError("Invalid subject number")) - url = '{:s}subject{:d}.raw.fif'.format(ALEX_URL, subject) - return dl.data_path(url, 'ALEXEEG', path, force_update, update_path, verbose) + url = "{:s}subject{:d}.raw.fif".format(ALEX_URL, subject) + return dl.data_path(url, "ALEXEEG", path, force_update, update_path, verbose) diff --git a/moabb/datasets/base.py b/moabb/datasets/base.py index c5c562437..73148b251 100644 --- a/moabb/datasets/base.py +++ b/moabb/datasets/base.py @@ -103,12 +103,12 @@ def get_data(self, subjects=None): subjects = self.subject_list if not isinstance(subjects, list): - raise (ValueError('subjects must be a list')) + raise (ValueError("subjects must be a list")) data = dict() for subject in subjects: if subject not in self.subject_list: - raise ValueError('Invalid subject {:d} given'.format(subject)) + raise ValueError("Invalid subject {:d} given".format(subject)) data[subject] = self._get_single_subject_data(subject) return data diff --git a/moabb/datasets/bbci_eeg_fnirs.py b/moabb/datasets/bbci_eeg_fnirs.py index 5d823db9d..451b39d6b 100644 --- a/moabb/datasets/bbci_eeg_fnirs.py +++ b/moabb/datasets/bbci_eeg_fnirs.py @@ -17,49 +17,49 @@ from .base import BaseDataset -SHIN_URL = 'http://doc.ml.tu-berlin.de/hBCI' +SHIN_URL = "http://doc.ml.tu-berlin.de/hBCI" def eeg_data_path(base_path, subject): datapath = op.join( - base_path, 'EEG', 'subject {:02d}'.format(subject), 'with occular artifact' + base_path, "EEG", "subject {:02d}".format(subject), "with occular artifact" ) - if not op.isfile(op.join(datapath, 'cnt.mat')): - if not op.isdir(op.join(base_path, 'EEG')): - os.makedirs(op.join(base_path, 'EEG')) + if not op.isfile(op.join(datapath, "cnt.mat")): + if not op.isdir(op.join(base_path, "EEG")): + os.makedirs(op.join(base_path, "EEG")) intervals = [[1, 5], [6, 10], [11, 15], [16, 20], [21, 25], [26, 29]] for low, high in intervals: if subject >= low and subject <= high: - if not op.isfile(op.join(base_path, 'EEG.zip')): + if not op.isfile(op.join(base_path, "EEG.zip")): _fetch_file( - '{}/EEG/EEG_{:02d}-{:02d}.zip'.format(SHIN_URL, low, high), - op.join(base_path, 'EEG.zip'), + "{}/EEG/EEG_{:02d}-{:02d}.zip".format(SHIN_URL, low, high), + op.join(base_path, "EEG.zip"), print_destination=False, ) - with z.ZipFile(op.join(base_path, 'EEG.zip'), 'r') as f: - 
f.extractall(op.join(base_path, 'EEG')) - os.remove(op.join(base_path, 'EEG.zip')) + with z.ZipFile(op.join(base_path, "EEG.zip"), "r") as f: + f.extractall(op.join(base_path, "EEG")) + os.remove(op.join(base_path, "EEG.zip")) break - assert op.isfile(op.join(datapath, 'cnt.mat')), op.join(datapath, 'cnt.mat') - return [op.join(datapath, fn) for fn in ['cnt.mat', 'mrk.mat']] + assert op.isfile(op.join(datapath, "cnt.mat")), op.join(datapath, "cnt.mat") + return [op.join(datapath, fn) for fn in ["cnt.mat", "mrk.mat"]] def fnirs_data_path(path, subject): - datapath = op.join(path, 'NIRS', 'subject {:02d}'.format(subject)) - if not op.isfile(op.join(datapath, 'mrk.mat')): + datapath = op.join(path, "NIRS", "subject {:02d}".format(subject)) + if not op.isfile(op.join(datapath, "mrk.mat")): # fNIRS - if not op.isfile(op.join(path, 'fNIRS.zip')): + if not op.isfile(op.join(path, "fNIRS.zip")): _fetch_file( - 'http://doc.ml.tu-berlin.de/hBCI/NIRS/NIRS_01-29.zip', - op.join(path, 'fNIRS.zip'), + "http://doc.ml.tu-berlin.de/hBCI/NIRS/NIRS_01-29.zip", + op.join(path, "fNIRS.zip"), print_destination=False, ) - if not op.isdir(op.join(path, 'NIRS')): - os.makedirs(op.join(path, 'NIRS')) - with z.ZipFile(op.join(path, 'fNIRS.zip'), 'r') as f: - f.extractall(op.join(path, 'NIRS')) - os.remove(op.join(path, 'fNIRS.zip')) - return [op.join(datapath, fn) for fn in ['cnt.mat', 'mrk.mat']] + if not op.isdir(op.join(path, "NIRS")): + os.makedirs(op.join(path, "NIRS")) + with z.ZipFile(op.join(path, "fNIRS.zip"), "r") as f: + f.extractall(op.join(path, "NIRS")) + os.remove(op.join(path, "fNIRS.zip")) + return [op.join(datapath, fn) for fn in ["cnt.mat", "mrk.mat"]] class Shin2017(BaseDataset): @@ -77,12 +77,12 @@ def __init__(self, fnirs=False, motor_imagery=True, mental_arithmetic=False): n_sessions = 0 if motor_imagery: events.update(dict(left_hand=1, right_hand=2)) - paradigms.append('imagery') + paradigms.append("imagery") n_sessions += 3 if mental_arithmetic: events.update(dict(substraction=3, rest=4)) - paradigms.append('arithmetic') + paradigms.append("arithmetic") n_sessions += 3 self.motor_imagery = motor_imagery @@ -92,11 +92,11 @@ def __init__(self, fnirs=False, motor_imagery=True, mental_arithmetic=False): subjects=list(range(1, 30)), sessions_per_subject=n_sessions, events=events, - code='Shin2017', + code="Shin2017", # marker is for *task* start not cue start interval=[0, 10], - paradigm=('/').join(paradigms), - doi='10.1109/TNSRE.2016.2628057', + paradigm=("/").join(paradigms), + doi="10.1109/TNSRE.2016.2628057", ) if fnirs: @@ -106,21 +106,21 @@ def __init__(self, fnirs=False, motor_imagery=True, mental_arithmetic=False): def _get_single_subject_data(self, subject): """return data for a single subject""" fname, fname_mrk = self.data_path(subject) - data = loadmat(fname, squeeze_me=True, struct_as_record=False)['cnt'] - mrk = loadmat(fname_mrk, squeeze_me=True, struct_as_record=False)['mrk'] + data = loadmat(fname, squeeze_me=True, struct_as_record=False)["cnt"] + mrk = loadmat(fname_mrk, squeeze_me=True, struct_as_record=False)["mrk"] sessions = {} # motor imagery if self.motor_imagery: for ii in [0, 2, 4]: session = self._convert_one_session(data, mrk, ii, trig_offset=0) - sessions['session_%d' % ii] = session + sessions["session_%d" % ii] = session # arithmetic/rest if self.mental_arithmetic: for ii in [1, 3, 5]: session = self._convert_one_session(data, mrk, ii, trig_offset=2) - sessions['session_%d' % ii] = session + sessions["session_%d" % ii] = session return sessions @@ -130,14 +130,14 
@@ def _convert_one_session(self, data, mrk, session, trig_offset=0): idx = (mrk[session].time - 1) // 5 trig[0, idx] = mrk[session].event.desc // 16 + trig_offset eeg = np.vstack([eeg, trig]) - ch_names = list(data[session].clab) + ['Stim'] - ch_types = ['eeg'] * 30 + ['eog'] * 2 + ['stim'] + ch_names = list(data[session].clab) + ["Stim"] + ch_types = ["eeg"] * 30 + ["eog"] * 2 + ["stim"] - montage = make_standard_montage('standard_1005') + montage = make_standard_montage("standard_1005") info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=200.0) raw = RawArray(data=eeg, info=info, verbose=False) raw.set_montage(montage) - return {'run_0': raw} + return {"run_0": raw} def data_path( self, subject, path=None, force_update=False, update_path=None, verbose=None @@ -145,16 +145,16 @@ def data_path( if subject not in self.subject_list: raise (ValueError("Invalid subject number")) - key = 'MNE_DATASETS_BBCIFNIRS_PATH' - path = _get_path(path, key, 'BBCI EEG-fNIRS') + key = "MNE_DATASETS_BBCIFNIRS_PATH" + path = _get_path(path, key, "BBCI EEG-fNIRS") # FIXME: this always update the path - _do_path_update(path, True, key, 'BBCI EEG-fNIRS') - if not op.isdir(op.join(path, 'MNE-eegfnirs-data')): - os.makedirs(op.join(path, 'MNE-eegfnirs-data')) + _do_path_update(path, True, key, "BBCI EEG-fNIRS") + if not op.isdir(op.join(path, "MNE-eegfnirs-data")): + os.makedirs(op.join(path, "MNE-eegfnirs-data")) if self.fnirs: - return fnirs_data_path(op.join(path, 'MNE-eegfnirs-data'), subject) + return fnirs_data_path(op.join(path, "MNE-eegfnirs-data"), subject) else: - return eeg_data_path(op.join(path, 'MNE-eegfnirs-data'), subject) + return eeg_data_path(op.join(path, "MNE-eegfnirs-data"), subject) class Shin2017A(Shin2017): @@ -274,7 +274,7 @@ class Shin2017A(Shin2017): def __init__(self): super().__init__(fnirs=False, motor_imagery=True, mental_arithmetic=False) - self.code = 'Shin2017A' + self.code = "Shin2017A" class Shin2017B(Shin2017): @@ -371,4 +371,4 @@ class Shin2017B(Shin2017): def __init__(self): super().__init__(fnirs=False, motor_imagery=False, mental_arithmetic=True) - self.code = 'Shin2017B' + self.code = "Shin2017B" diff --git a/moabb/datasets/bnci.py b/moabb/datasets/bnci.py index 42e284fdd..6138caa6e 100644 --- a/moabb/datasets/bnci.py +++ b/moabb/datasets/bnci.py @@ -12,18 +12,18 @@ from moabb.datasets.base import BaseDataset -BNCI_URL = 'http://bnci-horizon-2020.eu/database/data-sets/' -BBCI_URL = 'http://doc.ml.tu-berlin.de/bbci/' +BNCI_URL = "http://bnci-horizon-2020.eu/database/data-sets/" +BBCI_URL = "http://doc.ml.tu-berlin.de/bbci/" def data_path(url, path=None, force_update=False, update_path=None, verbose=None): - return [dl.data_path(url, 'BNCI', path, force_update, update_path, verbose)] + return [dl.data_path(url, "BNCI", path, force_update, update_path, verbose)] @verbose def load_data( subject, - dataset='001-2014', + dataset="001-2014", path=None, force_update=False, update_path=None, @@ -66,33 +66,33 @@ def load_data( dictonary containing events and their code. 
""" dataset_list = { - '001-2014': _load_data_001_2014, - '002-2014': _load_data_002_2014, - '004-2014': _load_data_004_2014, - '008-2014': _load_data_008_2014, - '009-2014': _load_data_009_2014, - '001-2015': _load_data_001_2015, - '003-2015': _load_data_003_2015, - '004-2015': _load_data_004_2015, - '009-2015': _load_data_009_2015, - '010-2015': _load_data_010_2015, - '012-2015': _load_data_012_2015, - '013-2015': _load_data_013_2015, + "001-2014": _load_data_001_2014, + "002-2014": _load_data_002_2014, + "004-2014": _load_data_004_2014, + "008-2014": _load_data_008_2014, + "009-2014": _load_data_009_2014, + "001-2015": _load_data_001_2015, + "003-2015": _load_data_003_2015, + "004-2015": _load_data_004_2015, + "009-2015": _load_data_009_2015, + "010-2015": _load_data_010_2015, + "012-2015": _load_data_012_2015, + "013-2015": _load_data_013_2015, } baseurl_list = { - '001-2014': BNCI_URL, - '002-2014': BNCI_URL, - '001-2015': BNCI_URL, - '004-2014': BNCI_URL, - '008-2014': BNCI_URL, - '009-2014': BNCI_URL, - '003-2015': BNCI_URL, - '004-2015': BNCI_URL, - '009-2015': BBCI_URL, - '010-2015': BBCI_URL, - '012-2015': BBCI_URL, - '013-2015': BNCI_URL, + "001-2014": BNCI_URL, + "002-2014": BNCI_URL, + "001-2015": BNCI_URL, + "004-2014": BNCI_URL, + "008-2014": BNCI_URL, + "009-2014": BNCI_URL, + "003-2015": BNCI_URL, + "004-2015": BNCI_URL, + "009-2015": BBCI_URL, + "010-2015": BBCI_URL, + "012-2015": BBCI_URL, + "013-2015": BNCI_URL, } if dataset not in dataset_list.keys(): @@ -126,15 +126,15 @@ def _load_data_001_2014( "EOG1", "EOG2", "EOG3", ] # fmt: on - ch_types = ['eeg'] * 22 + ['eog'] * 3 + ch_types = ["eeg"] * 22 + ["eog"] * 3 sessions = {} - for r in ['T', 'E']: - url = '{u}001-2014/A{s:02d}{r}.mat'.format(u=base_url, s=subject, r=r) + for r in ["T", "E"]: + url = "{u}001-2014/A{s:02d}{r}.mat".format(u=base_url, s=subject, r=r) filename = data_path(url, path, force_update, update_path) runs, ev = _convert_mi(filename[0], ch_names, ch_types) # FIXME: deal with run with no event (1:3) and name them - sessions['session_%s' % r] = {'run_%d' % ii: run for ii, run in enumerate(runs)} + sessions["session_%s" % r] = {"run_%d" % ii: run for ii, run in enumerate(runs)} return sessions @@ -152,16 +152,16 @@ def _load_data_002_2014( raise ValueError("Subject must be between 1 and 14. Got %d." % subject) runs = [] - for r in ['T', 'E']: - url = '{u}002-2014/S{s:02d}{r}.mat'.format(u=base_url, s=subject, r=r) + for r in ["T", "E"]: + url = "{u}002-2014/S{s:02d}{r}.mat".format(u=base_url, s=subject, r=r) filename = data_path(url, path, force_update, update_path)[0] # FIXME: electrode position and name are not provided directly. 
- raws, _ = _convert_mi(filename, None, ['eeg'] * 15) + raws, _ = _convert_mi(filename, None, ["eeg"] * 15) runs.extend(raws) - runs = {'run_%d' % ii: run for ii, run in enumerate(runs)} - return {'session_0': runs} + runs = {"run_%d" % ii: run for ii, run in enumerate(runs)} + return {"session_0": runs} @verbose @@ -180,16 +180,16 @@ def _load_data_004_2014( # fmt: off ch_names = ["C3", "Cz", "C4", "EOG1", "EOG2", "EOG3", ] # fmt: on - ch_types = ['eeg'] * 3 + ['eog'] * 3 + ch_types = ["eeg"] * 3 + ["eog"] * 3 sessions = [] - for r in ['T', 'E']: - url = '{u}004-2014/B{s:02d}{r}.mat'.format(u=base_url, s=subject, r=r) + for r in ["T", "E"]: + url = "{u}004-2014/B{s:02d}{r}.mat".format(u=base_url, s=subject, r=r) filename = data_path(url, path, force_update, update_path)[0] raws, _ = _convert_mi(filename, ch_names, ch_types) sessions.extend(raws) - sessions = {'session_%d' % ii: {'run_0': run} for ii, run in enumerate(sessions)} + sessions = {"session_%d" % ii: {"run_0": run} for ii, run in enumerate(sessions)} return sessions @@ -206,14 +206,14 @@ def _load_data_008_2014( if (subject < 1) or (subject > 8): raise ValueError("Subject must be between 1 and 8. Got %d." % subject) - url = '{u}008-2014/A{s:02d}.mat'.format(u=base_url, s=subject) + url = "{u}008-2014/A{s:02d}.mat".format(u=base_url, s=subject) filename = data_path(url, path, force_update, update_path)[0] from scipy.io import loadmat - run = loadmat(filename, struct_as_record=False, squeeze_me=True)['data'] + run = loadmat(filename, struct_as_record=False, squeeze_me=True)["data"] raw, event_id = _convert_run_p300_sl(run, verbose=verbose) - sessions = {'session_0': {'run_0': raw}} + sessions = {"session_0": {"run_0": raw}} return sessions @@ -233,11 +233,11 @@ def _load_data_009_2014( # FIXME there is two type of speller, grid speller and geo-speller. # we load only grid speller data - url = '{u}009-2014/A{s:02d}S.mat'.format(u=base_url, s=subject) + url = "{u}009-2014/A{s:02d}S.mat".format(u=base_url, s=subject) filename = data_path(url, path, force_update, update_path)[0] from scipy.io import loadmat - data = loadmat(filename, struct_as_record=False, squeeze_me=True)['data'] + data = loadmat(filename, struct_as_record=False, squeeze_me=True)["data"] raws = [] event_id = {} for run in data: @@ -246,9 +246,9 @@ def _load_data_009_2014( event_id.update(ev) sessions = {} - sessions['session_0'] = {} + sessions["session_0"] = {} for i, rawi in enumerate(raws): - sessions['session_0']['run_' + str(i)] = rawi + sessions["session_0"]["run_" + str(i)] = rawi return sessions @@ -277,14 +277,14 @@ def _load_data_001_2015( "C2", "C4", "C6", "CP3", "CPz", "CP4", ] # fmt: on - ch_types = ['eeg'] * 13 + ch_types = ["eeg"] * 13 sessions = {} for r in ses: - url = '{u}001-2015/S{s:02d}{r}.mat'.format(u=base_url, s=subject, r=r) + url = "{u}001-2015/S{s:02d}{r}.mat".format(u=base_url, s=subject, r=r) filename = data_path(url, path, force_update, update_path) runs, ev = _convert_mi(filename[0], ch_names, ch_types) - sessions['session_%s' % r] = {'run_%d' % ii: run for ii, run in enumerate(runs)} + sessions["session_%s" % r] = {"run_%d" % ii: run for ii, run in enumerate(runs)} return sessions @@ -301,13 +301,13 @@ def _load_data_003_2015( if (subject < 1) or (subject > 10): raise ValueError("Subject must be between 1 and 12. Got %d." 
% subject) - url = '{u}003-2015/s{s:d}.mat'.format(u=base_url, s=subject) + url = "{u}003-2015/s{s:d}.mat".format(u=base_url, s=subject) filename = data_path(url, path, force_update, update_path)[0] from scipy.io import loadmat data = loadmat(filename, struct_as_record=False, squeeze_me=True) - data = data['s%d' % subject] + data = data["s%d" % subject] sfreq = 256.0 # fmt: off @@ -316,13 +316,13 @@ def _load_data_003_2015( ] # fmt: on - ch_types = ['eeg'] * 8 + ['stim'] * 2 - montage = make_standard_montage('standard_1005') + ch_types = ["eeg"] * 8 + ["stim"] * 2 + montage = make_standard_montage("standard_1005") info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sfreq) sessions = {} - sessions['session_0'] = {} + sessions["session_0"] = {} for ri, run in enumerate([data.train, data.test]): # flash events on the channel 9 flashs = run[9:10] @@ -332,11 +332,11 @@ def _load_data_003_2015( if len(flash_code) == 36: # char mode - evd = {'Char%d' % ii: (ii + 2) for ii in range(1, 37)} + evd = {"Char%d" % ii: (ii + 2) for ii in range(1, 37)} else: # row / column mode - evd = {'Col%d' % ii: (ii + 2) for ii in range(1, 7)} - evd.update({'Row%d' % ii: (ii + 8) for ii in range(1, 7)}) + evd = {"Col%d" % ii: (ii + 2) for ii in range(1, 7)} + evd.update({"Row%d" % ii: (ii + 8) for ii in range(1, 7)}) # target events are on channel 10 targets = np.zeros_like(flashs) @@ -345,7 +345,7 @@ def _load_data_003_2015( eeg_data = np.r_[run[1:-2] * 1e-6, targets, flashs] raw = RawArray(data=eeg_data, info=info, verbose=verbose) raw.set_montage(montage) - sessions['session_0']['run_' + str(ri)] = raw + sessions["session_0"]["run_" + str(ri)] = raw return sessions @@ -365,7 +365,7 @@ def _load_data_004_2015( subjects = ["A", "C", "D", "E", "F", "G", "H", "J", "L"] - url = '{u}004-2015/{s}.mat'.format(u=base_url, s=subjects[subject - 1]) + url = "{u}004-2015/{s}.mat".format(u=base_url, s=subjects[subject - 1]) filename = data_path(url, path, force_update, update_path)[0] # fmt: off @@ -375,9 +375,9 @@ def _load_data_004_2015( "P2", "P4", "P6", "P8", "PO3", "PO4", "O1", "O2", ] # fmt: on - ch_types = ['eeg'] * 30 + ch_types = ["eeg"] * 30 raws, ev = _convert_mi(filename, ch_names, ch_types) - sessions = {'session_%d' % ii: {'run_0': run} for ii, run in enumerate(raws)} + sessions = {"session_%d" % ii: {"run_0": run} for ii, run in enumerate(raws)} return sessions @@ -402,10 +402,10 @@ def _load_data_009_2015( ] # fmt: on s = subjects[subject - 1] - url = '{u}BNCIHorizon2020-AMUSE/AMUSE_VP{s}.mat'.format(u=base_url, s=s) + url = "{u}BNCIHorizon2020-AMUSE/AMUSE_VP{s}.mat".format(u=base_url, s=s) filename = data_path(url, path, force_update, update_path)[0] - ch_types = ['eeg'] * 60 + ['eog'] * 2 + ch_types = ["eeg"] * 60 + ["eog"] * 2 return _convert_bbci(filename, ch_types, verbose=None) @@ -431,10 +431,10 @@ def _load_data_010_2015( # fmt: on s = subjects[subject - 1] - url = '{u}BNCIHorizon2020-RSVP/RSVP_VP{s}.mat'.format(u=base_url, s=s) + url = "{u}BNCIHorizon2020-RSVP/RSVP_VP{s}.mat".format(u=base_url, s=s) filename = data_path(url, path, force_update, update_path)[0] - ch_types = ['eeg'] * 63 + ch_types = ["eeg"] * 63 return _convert_bbci(filename, ch_types, verbose=None) @@ -459,10 +459,10 @@ def _load_data_012_2015( # fmt: on s = subjects[subject - 1] - url = '{u}BNCIHorizon2020-PASS2D/PASS2D_VP{s}.mat'.format(u=base_url, s=s) + url = "{u}BNCIHorizon2020-PASS2D/PASS2D_VP{s}.mat".format(u=base_url, s=s) filename = data_path(url, path, force_update, update_path)[0] - ch_types = ['eeg'] * 63 + 
ch_types = ["eeg"] * 63 return _convert_bbci(filename, ch_types, verbose=None) @@ -481,8 +481,8 @@ def _load_data_013_2015( raise ValueError("Subject must be between 1 and 6. Got %d." % subject) data_paths = [] - for r in ['s1', 's2']: - url = '{u}013-2015/Subject{s:02d}_{r}.mat'.format(u=base_url, s=subject, r=r) + for r in ["s1", "s2"]: + url = "{u}013-2015/Subject{s:02d}_{r}.mat".format(u=base_url, s=subject, r=r) data_paths.extend(data_path(url, path, force_update, update_path)) raws = [] @@ -491,7 +491,7 @@ def _load_data_013_2015( for filename in data_paths: data = loadmat(filename, struct_as_record=False, squeeze_me=True) - for run in data['run']: + for run in data["run"]: raw, evd = _convert_run_epfl(run, verbose=verbose) raws.append(raw) event_id.update(evd) @@ -509,10 +509,10 @@ def _convert_mi(filename, ch_names, ch_types): event_id = {} data = loadmat(filename, struct_as_record=False, squeeze_me=True) - if isinstance(data['data'], np.ndarray): - run_array = data['data'] + if isinstance(data["data"], np.ndarray): + run_array = data["data"] else: - run_array = [data['data']] + run_array = [data["data"]] for run in run_array: raw, evd = _convert_run(run, ch_names, ch_types, None) @@ -527,14 +527,14 @@ def _convert_mi(filename, ch_names, ch_types): def standardize_keys(d): master_list = [ - ['both feet', 'feet'], - ['left hand', 'left_hand'], - ['right hand', 'right_hand'], - ['FEET', 'feet'], - ['HAND', 'right_hand'], - ['NAV', 'navigation'], - ['SUB', 'subtraction'], - ['WORD', 'word_ass'], + ["both feet", "feet"], + ["left hand", "left_hand"], + ["right hand", "right_hand"], + ["FEET", "feet"], + ["HAND", "right_hand"], + ["NAV", "navigation"], + ["SUB", "subtraction"], + ["WORD", "word_ass"], ] for old, new in master_list: if old in d.keys(): @@ -547,16 +547,16 @@ def _convert_run(run, ch_names=None, ch_types=None, verbose=None): # parse eeg data event_id = {} n_chan = run.X.shape[1] - montage = make_standard_montage('standard_1005') + montage = make_standard_montage("standard_1005") eeg_data = 1e-6 * run.X sfreq = run.fs if not ch_names: - ch_names = ['EEG%d' % ch for ch in range(1, n_chan + 1)] + ch_names = ["EEG%d" % ch for ch in range(1, n_chan + 1)] montage = None # no montage if not ch_types: - ch_types = ['eeg'] * n_chan + ch_types = ["eeg"] * n_chan trigger = np.zeros((len(eeg_data), 1)) # some runs does not contains trials i.e baseline runs @@ -566,8 +566,8 @@ def _convert_run(run, ch_names=None, ch_types=None, verbose=None): return None, None eeg_data = np.c_[eeg_data, trigger] - ch_names = ch_names + ['stim'] - ch_types = ch_types + ['stim'] + ch_names = ch_names + ["stim"] + ch_types = ch_types + ["stim"] event_id = {ev: (ii + 1) for ii, ev in enumerate(run.classes)} info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sfreq) raw = RawArray(data=eeg_data.T, info=info, verbose=verbose) @@ -578,11 +578,11 @@ def _convert_run(run, ch_names=None, ch_types=None, verbose=None): @verbose def _convert_run_p300_sl(run, verbose=None): """Convert one p300 run from santa lucia file format.""" - montage = make_standard_montage('standard_1005') + montage = make_standard_montage("standard_1005") eeg_data = 1e-6 * run.X sfreq = 256 - ch_names = list(run.channels) + ['Target stim', 'Flash stim'] - ch_types = ['eeg'] * len(run.channels) + ['stim'] * 2 + ch_names = list(run.channels) + ["Target stim", "Flash stim"] + ch_types = ["eeg"] * len(run.channels) + ["stim"] * 2 flash_stim = run.y_stim flash_stim[flash_stim > 0] += 2 @@ -603,7 +603,7 @@ def _convert_bbci(filename, 
ch_types, verbose=None): from scipy.io import loadmat data = loadmat(filename, struct_as_record=False, squeeze_me=True) - for run in data['data']: + for run in data["data"]: raw, evd = _convert_run_bbci(run, ch_types, verbose) raws.append(raw) event_id.update(evd) @@ -615,7 +615,7 @@ def _convert_bbci(filename, ch_types, verbose=None): def _convert_run_bbci(run, ch_types, verbose=None): """Convert one run to raw.""" # parse eeg data - montage = make_standard_montage('standard_1005') + montage = make_standard_montage("standard_1005") eeg_data = 1e-6 * run.X sfreq = run.fs @@ -627,12 +627,12 @@ def _convert_run_bbci(run, ch_types, verbose=None): flash = np.zeros((len(eeg_data), 1)) flash[run.trial - 1, 0] = run.y_stim + 2 - ev_fl = {'Stim%d' % (stim): (stim + 2) for stim in np.unique(run.y_stim)} + ev_fl = {"Stim%d" % (stim): (stim + 2) for stim in np.unique(run.y_stim)} event_id.update(ev_fl) eeg_data = np.c_[eeg_data, trigger, flash] - ch_names = ch_names + ['Target', 'Flash'] - ch_types = ch_types + ['stim'] * 2 + ch_names = ch_names + ["Target", "Flash"] + ch_types = ch_types + ["stim"] * 2 info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sfreq) raw = RawArray(data=eeg_data.T, info=info, verbose=verbose) @@ -646,12 +646,12 @@ def _convert_run_epfl(run, verbose=None): # parse eeg data event_id = {} - montage = make_standard_montage('standard_1005') + montage = make_standard_montage("standard_1005") eeg_data = 1e-6 * run.eeg sfreq = run.header.SampleRate ch_names = list(run.header.Label[:-1]) - ch_types = ['eeg'] * len(ch_names) + ch_types = ["eeg"] * len(ch_names) trigger = np.zeros((len(eeg_data), 1)) @@ -662,9 +662,9 @@ def _convert_run_epfl(run, verbose=None): trigger[run.header.EVENT.POS[ii] - 1, 0] = 1 eeg_data = np.c_[eeg_data, trigger] - ch_names = ch_names + ['stim'] - ch_types = ch_types + ['stim'] - event_id = {'correct': 1, 'error': 2} + ch_names = ch_names + ["stim"] + ch_types = ch_types + ["stim"] + event_id = {"correct": 1, "error": 2} info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sfreq) raw = RawArray(data=eeg_data.T, info=info, verbose=verbose) @@ -741,11 +741,11 @@ def __init__(self): super().__init__( subjects=list(range(1, 10)), sessions_per_subject=2, - events={'left_hand': 1, 'right_hand': 2, 'feet': 3, 'tongue': 4}, - code='001-2014', + events={"left_hand": 1, "right_hand": 2, "feet": 3, "tongue": 4}, + code="001-2014", interval=[2, 6], - paradigm='imagery', - doi='10.3389/fnins.2012.00055', + paradigm="imagery", + doi="10.3389/fnins.2012.00055", ) @@ -795,11 +795,11 @@ def __init__(self): super().__init__( subjects=list(range(1, 15)), sessions_per_subject=1, - events={'right_hand': 1, 'feet': 2}, - code='002-2014', + events={"right_hand": 1, "feet": 2}, + code="002-2014", interval=[3, 8], - paradigm='imagery', - doi='10.1515/bmt-2014-0117', + paradigm="imagery", + doi="10.1515/bmt-2014-0117", ) @@ -869,11 +869,11 @@ def __init__(self): super().__init__( subjects=list(range(1, 10)), sessions_per_subject=5, - events={'left_hand': 1, 'right_hand': 2}, - code='004-2014', + events={"left_hand": 1, "right_hand": 2}, + code="004-2014", interval=[3, 7.5], - paradigm='imagery', - doi='10.1109/TNSRE.2007.906956', + paradigm="imagery", + doi="10.1109/TNSRE.2007.906956", ) @@ -932,11 +932,11 @@ def __init__(self): super().__init__( subjects=list(range(1, 9)), sessions_per_subject=1, - events={'Target': 2, 'NonTarget': 1}, - code='008-2014', + events={"Target": 2, "NonTarget": 1}, + code="008-2014", interval=[0, 1.0], - paradigm='p300', - 
doi='10.3389/fnhum.2013.00732', + paradigm="p300", + doi="10.3389/fnhum.2013.00732", ) @@ -986,11 +986,11 @@ def __init__(self): super().__init__( subjects=list(range(1, 11)), sessions_per_subject=1, - events={'Target': 2, 'NonTarget': 1}, - code='009-2014', + events={"Target": 2, "NonTarget": 1}, + code="009-2014", interval=[0, 0.8], - paradigm='p300', - doi='10.1088/1741-2560/11/3/035008', + paradigm="p300", + doi="10.1088/1741-2560/11/3/035008", ) @@ -1034,11 +1034,11 @@ def __init__(self): super().__init__( subjects=list(range(1, 13)), sessions_per_subject=2, - events={'right_hand': 1, 'feet': 2}, - code='001-2015', + events={"right_hand": 1, "feet": 2}, + code="001-2015", interval=[0, 5], - paradigm='imagery', - doi='10.1109/tnsre.2012.2189584', + paradigm="imagery", + doi="10.1109/tnsre.2012.2189584", ) @@ -1069,11 +1069,11 @@ def __init__(self): super().__init__( subjects=list(range(1, 11)), sessions_per_subject=1, - events={'Target': 2, 'NonTarget': 1}, - code='003-2015', + events={"Target": 2, "NonTarget": 1}, + code="003-2015", interval=[0, 0.8], - paradigm='p300', - doi='10.1016/j.neulet.2009.06.045', + paradigm="p300", + doi="10.1016/j.neulet.2009.06.045", ) @@ -1136,8 +1136,8 @@ def __init__(self): subjects=list(range(1, 10)), sessions_per_subject=2, events=dict(right_hand=4, feet=5, navigation=3, subtraction=2, word_ass=1), - code='004-2015', + code="004-2015", interval=[3, 10], - paradigm='imagery', - doi='10.1371/journal.pone.0123727', + paradigm="imagery", + doi="10.1371/journal.pone.0123727", ) diff --git a/moabb/datasets/braininvaders.py b/moabb/datasets/braininvaders.py index 33a7bfa98..6dddf77dc 100644 --- a/moabb/datasets/braininvaders.py +++ b/moabb/datasets/braininvaders.py @@ -10,7 +10,7 @@ from moabb.datasets.base import BaseDataset -BI2013a_URL = 'https://zenodo.org/record/1494240/files/' +BI2013a_URL = "https://zenodo.org/record/1494240/files/" class bi2013a(BaseDataset): @@ -110,10 +110,10 @@ def __init__(self, NonAdaptive=True, Adaptive=False, Training=True, Online=False subjects=list(range(1, 24 + 1)), sessions_per_subject=1, events=dict(Target=1, NonTarget=2), - code='Brain Invaders 2013a', + code="Brain Invaders 2013a", interval=[0, 1], - paradigm='p300', - doi='', + paradigm="p300", + doi="", ) self.adaptive = Adaptive @@ -128,19 +128,19 @@ def _get_single_subject_data(self, subject): sessions = {} for file_path in file_path_list: - session_number = file_path.split(os.sep)[-2].replace('Session', '') - session_name = 'session_' + session_number + session_number = file_path.split(os.sep)[-2].replace("Session", "") + session_name = "session_" + session_number if session_name not in sessions.keys(): sessions[session_name] = {} run_number = file_path.split(os.sep)[-1] - run_number = run_number.split('_')[-1] - run_number = run_number.split('.gdf')[0] - run_name = 'run_' + run_number + run_number = run_number.split("_")[-1] + run_number = run_number.split(".gdf")[0] + run_name = "run_" + run_number raw_original = mne.io.read_raw_gdf(file_path, preload=True) - raw_original.rename_channels({'FP1': 'Fp1', 'FP2': 'Fp2'}) - raw_original.set_montage(make_standard_montage('standard_1020')) + raw_original.rename_channels({"FP1": "Fp1", "FP2": "Fp2"}) + raw_original.set_montage(make_standard_montage("standard_1020")) sessions[session_name][run_name] = raw_original @@ -154,44 +154,44 @@ def data_path( raise (ValueError("Invalid subject number")) # check if has the .zip - url = '{:s}subject{:d}.zip'.format(BI2013a_URL, subject) - path_zip = dl.data_path(url, 
'BRAININVADERS') - path_folder = path_zip.strip('subject{:d}.zip'.format(subject)) + url = "{:s}subject{:d}.zip".format(BI2013a_URL, subject) + path_zip = dl.data_path(url, "BRAININVADERS") + path_folder = path_zip.strip("subject{:d}.zip".format(subject)) # check if has to unzip - if not (os.path.isdir(path_folder + 'subject{:d}'.format(subject))): - print('unzip', path_zip) + if not (os.path.isdir(path_folder + "subject{:d}".format(subject))): + print("unzip", path_zip) zip_ref = zipfile.ZipFile(path_zip, "r") zip_ref.extractall(path_folder) # filter the data regarding the experimental conditions - meta_file = os.path.join('subject{:d}'.format(subject), 'meta.yml') + meta_file = os.path.join("subject{:d}".format(subject), "meta.yml") meta_path = path_folder + meta_file - with open(meta_path, 'r') as stream: + with open(meta_path, "r") as stream: meta = yaml.load(stream, Loader=yaml.FullLoader) conditions = [] if self.adaptive: - conditions = conditions + ['adaptive'] + conditions = conditions + ["adaptive"] if self.nonadaptive: - conditions = conditions + ['nonadaptive'] + conditions = conditions + ["nonadaptive"] types = [] if self.training: - types = types + ['training'] + types = types + ["training"] if self.online: - types = types + ['online'] + types = types + ["online"] filenames = [] - for run in meta['runs']: - run_condition = run['experimental_condition'] - run_type = run['type'] + for run in meta["runs"]: + run_condition = run["experimental_condition"] + run_type = run["type"] if (run_condition in conditions) and (run_type in types): - filenames = filenames + [run['filename']] + filenames = filenames + [run["filename"]] # list the filepaths for this subject subject_paths = [] for filename in filenames: subject_paths = subject_paths + glob.glob( os.path.join( - path_folder, 'subject{:d}'.format(subject), 'Session*', filename + path_folder, "subject{:d}".format(subject), "Session*", filename ) ) # noqa return subject_paths diff --git a/moabb/datasets/download.py b/moabb/datasets/download.py index 1c9e601ed..e5897edc1 100644 --- a/moabb/datasets/download.py +++ b/moabb/datasets/download.py @@ -46,8 +46,8 @@ def data_path(url, sign, path=None, force_update=False, update_path=True, verbos """ # noqa: E501 sign = sign.upper() - key = 'MNE_DATASETS_{:s}_PATH'.format(sign) - key_dest = 'MNE-{:s}-data'.format(sign.lower()) + key = "MNE_DATASETS_{:s}_PATH".format(sign) + key_dest = "MNE-{:s}-data".format(sign.lower()) if get_config(key) is None: set_config(key, osp.join(osp.expanduser("~"), "mne_data")) path = _get_path(path, key, sign) diff --git a/moabb/datasets/epfl.py b/moabb/datasets/epfl.py index 75b743e0a..4e8bd2b68 100644 --- a/moabb/datasets/epfl.py +++ b/moabb/datasets/epfl.py @@ -12,7 +12,7 @@ from moabb.datasets.base import BaseDataset -EPFLP300_URL = 'http://documents.epfl.ch/groups/m/mm/mmspg/www/BCI/p300/' +EPFLP300_URL = "http://documents.epfl.ch/groups/m/mm/mmspg/www/BCI/p300/" class EPFLP300(BaseDataset): @@ -66,60 +66,60 @@ def __init__(self): subjects=[1, 2, 3, 4, 6, 7, 8, 9], sessions_per_subject=4, events=dict(Target=2, NonTarget=1), - code='EPFL P300 dataset', + code="EPFL P300 dataset", interval=[0, 1], - paradigm='p300', - doi='10.1016/j.jneumeth.2007.03.005', + paradigm="p300", + doi="10.1016/j.jneumeth.2007.03.005", ) def _get_single_run_data(self, file_path): # data from the .mat data = loadmat(file_path) - signals = data['data'] - stimuli = data['stimuli'].squeeze() - events = data['events'] - target = data['target'][0][0] + signals = data["data"] + stimuli = 
data["stimuli"].squeeze() + events = data["events"] + target = data["target"][0][0] # meta-info from the readme.pdf sfreq = 2048 ch_names = [ - 'Fp1', - 'AF3', - 'F7', - 'F3', - 'FC1', - 'FC5', - 'T7', - 'C3', - 'CP1', - 'CP5', - 'P7', - 'P3', - 'Pz', - 'PO3', - 'O1', - 'Oz', - 'O2', - 'PO4', - 'P4', - 'P8', - 'CP6', - 'CP2', - 'C4', - 'T8', - 'FC6', - 'FC2', - 'F4', - 'F8', - 'AF4', - 'Fp2', - 'Fz', - 'Cz', - 'MA1', - 'MA2', + "Fp1", + "AF3", + "F7", + "F3", + "FC1", + "FC5", + "T7", + "C3", + "CP1", + "CP5", + "P7", + "P3", + "Pz", + "PO3", + "O1", + "Oz", + "O2", + "PO4", + "P4", + "P8", + "CP6", + "CP2", + "C4", + "T8", + "FC6", + "FC2", + "F4", + "F8", + "AF4", + "Fp2", + "Fz", + "Cz", + "MA1", + "MA2", ] - ch_types = ['eeg'] * 32 + ['misc'] * 2 + ch_types = ["eeg"] * 32 + ["misc"] * 2 # The last X entries are 0 for all signals. This leads to # artifacts when epoching and band-pass filtering the data. @@ -155,17 +155,17 @@ def _get_single_run_data(self, file_path): stim_aux[stimuli != target] = 1 stim_channel = np.zeros(signals.shape[1]) stim_channel[pos] = stim_aux - ch_names = ch_names + ['STI'] - ch_types = ch_types + ['stim'] + ch_names = ch_names + ["STI"] + ch_types = ch_types + ["stim"] signals = np.concatenate([signals, stim_channel[None, :]]) # create info dictionary info = mne.create_info(ch_names, sfreq, ch_types) - info['description'] = 'EPFL P300 dataset' + info["description"] = "EPFL P300 dataset" # create the Raw structure raw = mne.io.RawArray(signals, info, verbose=False) - montage = make_standard_montage('biosemi32') + montage = make_standard_montage("biosemi32") raw.set_montage(montage) return raw @@ -178,12 +178,12 @@ def _get_single_subject_data(self, subject): for file_path in sorted(file_path_list): - session_name = 'session_' + file_path.split(os.sep)[-2].replace('session', '') + session_name = "session_" + file_path.split(os.sep)[-2].replace("session", "") if session_name not in sessions.keys(): sessions[session_name] = {} - run_name = 'run_' + str(len(sessions[session_name]) + 1) + run_name = "run_" + str(len(sessions[session_name]) + 1) sessions[session_name][run_name] = self._get_single_run_data(file_path) return sessions @@ -196,18 +196,18 @@ def data_path( raise (ValueError("Invalid subject number")) # check if has the .zip - url = '{:s}subject{:d}.zip'.format(EPFLP300_URL, subject) - path_zip = dl.data_path(url, 'EPFLP300') - path_folder = path_zip.strip('subject{:d}.zip'.format(subject)) + url = "{:s}subject{:d}.zip".format(EPFLP300_URL, subject) + path_zip = dl.data_path(url, "EPFLP300") + path_folder = path_zip.strip("subject{:d}.zip".format(subject)) # check if has to unzip - if not (os.path.isdir(path_folder + 'subject{:d}'.format(subject))): - print('unzip', path_zip) + if not (os.path.isdir(path_folder + "subject{:d}".format(subject))): + print("unzip", path_zip) zip_ref = zipfile.ZipFile(path_zip, "r") zip_ref.extractall(path_folder) # get the path to all files - pattern = os.path.join('subject{:d}'.format(subject), '*', '*') + pattern = os.path.join("subject{:d}".format(subject), "*", "*") subject_paths = glob.glob(path_folder + pattern) return subject_paths diff --git a/moabb/datasets/fake.py b/moabb/datasets/fake.py index a59ba7477..c14d1c627 100644 --- a/moabb/datasets/fake.py +++ b/moabb/datasets/fake.py @@ -15,11 +15,11 @@ class FakeDataset(BaseDataset): def __init__( self, - event_list=('fake_c1', 'fake_c2', 'fake_c3'), + event_list=("fake_c1", "fake_c2", "fake_c3"), n_sessions=2, n_runs=2, n_subjects=10, - paradigm='imagery', + 
paradigm="imagery", ): self.n_runs = n_runs event_id = {ev: ii + 1 for ii, ev in enumerate(event_list)} @@ -27,7 +27,7 @@ def __init__( list(range(1, n_subjects + 1)), n_sessions, event_id, - 'FakeDataset', + "FakeDataset", [0, 3], paradigm, ) @@ -43,9 +43,9 @@ def _get_single_subject_data(self, subject): def _generate_raw(self): - ch_names = ['C3', 'Cz', 'C4'] + ch_names = ["C3", "Cz", "C4"] - montage = make_standard_montage('standard_1005') + montage = make_standard_montage("standard_1005") sfreq = 128 duration = len(self.event_id) * 60 eeg_data = 2e-5 * np.random.randn(duration * sfreq, len(ch_names)) @@ -55,8 +55,8 @@ def _generate_raw(self): jump = 5 * len(self.event_id) * 128 y[start_idx::jump] = self.event_id[ev] - ch_types = ['eeg'] * len(ch_names) + ['stim'] - ch_names = ch_names + ['stim'] + ch_types = ["eeg"] * len(ch_names) + ["stim"] + ch_names = ch_names + ["stim"] eeg_data = np.c_[eeg_data, y] diff --git a/moabb/datasets/gigadb.py b/moabb/datasets/gigadb.py index 6f20a149f..3f5dd81a5 100644 --- a/moabb/datasets/gigadb.py +++ b/moabb/datasets/gigadb.py @@ -16,7 +16,7 @@ log = logging.getLogger() GIGA_URL = ( - 'ftp://parrot.genomics.cn/gigadb/pub/10.5524/100001_101000/100295/mat_data/' # noqa + "ftp://parrot.genomics.cn/gigadb/pub/10.5524/100001_101000/100295/mat_data/" # noqa ) @@ -64,10 +64,10 @@ def __init__(self): subjects=list(range(1, 53)), sessions_per_subject=1, events=dict(left_hand=1, right_hand=2), - code='Cho2017', + code="Cho2017", interval=[0, 3], # full trial is 0-3s, but edge effects - paradigm='imagery', - doi='10.5524/100295', + paradigm="imagery", + doi="10.5524/100295", ) for ii in [32, 46, 49]: @@ -82,7 +82,7 @@ def _get_single_subject_data(self, subject): squeeze_me=True, struct_as_record=False, verify_compressed_data_integrity=False, - )['eeg'] + )["eeg"] # fmt: off eeg_ch_names = [ @@ -94,10 +94,10 @@ def _get_single_subject_data(self, subject): "P2", "P4", "P6", "P8", "P10", "PO8", "PO4", "O2", ] # fmt: on - emg_ch_names = ['EMG1', 'EMG2', 'EMG3', 'EMG4'] - ch_names = eeg_ch_names + emg_ch_names + ['Stim'] - ch_types = ['eeg'] * 64 + ['emg'] * 4 + ['stim'] - montage = make_standard_montage('standard_1005') + emg_ch_names = ["EMG1", "EMG2", "EMG3", "EMG4"] + ch_names = eeg_ch_names + emg_ch_names + ["Stim"] + ch_types = ["eeg"] * 64 + ["emg"] * 4 + ["stim"] + montage = make_standard_montage("standard_1005") imagery_left = data.imagery_left - data.imagery_left.mean(axis=1, keepdims=True) imagery_right = data.imagery_right - data.imagery_right.mean( axis=1, keepdims=True @@ -120,7 +120,7 @@ def _get_single_subject_data(self, subject): raw = RawArray(data=eeg_data, info=info, verbose=False) raw.set_montage(montage) - return {'session_0': {'run_0': raw}} + return {"session_0": {"run_0": raw}} def data_path( self, subject, path=None, force_update=False, update_path=None, verbose=None @@ -128,5 +128,5 @@ def data_path( if subject not in self.subject_list: raise (ValueError("Invalid subject number")) - url = '{:s}s{:02d}.mat'.format(GIGA_URL, subject) - return dl.data_path(url, 'GIGADB', path, force_update, update_path, verbose) + url = "{:s}s{:02d}.mat".format(GIGA_URL, subject) + return dl.data_path(url, "GIGADB", path, force_update, update_path, verbose) diff --git a/moabb/datasets/mpi_mi.py b/moabb/datasets/mpi_mi.py index b0b99c2e9..5c0a9a665 100644 --- a/moabb/datasets/mpi_mi.py +++ b/moabb/datasets/mpi_mi.py @@ -9,7 +9,7 @@ from moabb.datasets.base import BaseDataset -DOWNLOAD_URL = 'https://zenodo.org/record/1217449/files/' +DOWNLOAD_URL = 
"https://zenodo.org/record/1217449/files/" class MunichMI(BaseDataset): @@ -61,21 +61,21 @@ def __init__(self): subjects=list(range(1, 11)), sessions_per_subject=1, events=dict(right_hand=2, left_hand=1), - code='Grosse-Wentrup 2009', + code="Grosse-Wentrup 2009", interval=[0, 7], - paradigm='imagery', - doi='10.1109/TBME.2008.2009768', + paradigm="imagery", + doi="10.1109/TBME.2008.2009768", ) def _get_single_subject_data(self, subject): """return data for a single subject""" raw = mne.io.read_raw_eeglab( - self.data_path(subject), preload=True, verbose='ERROR' + self.data_path(subject), preload=True, verbose="ERROR" ) - stim = raw.annotations.description.astype(np.dtype('<10U')) + stim = raw.annotations.description.astype(np.dtype("<10U")) - stim[stim == '20'] = 'right_hand' - stim[stim == '10'] = 'left_hand' + stim[stim == "20"] = "right_hand" + stim[stim == "10"] = "left_hand" raw.annotations.description = stim return {"session_0": {"run_0": raw}} @@ -86,11 +86,11 @@ def data_path( raise (ValueError("Invalid subject number")) # download .set - _set = '{:s}subject{:d}.set'.format(DOWNLOAD_URL, subject) + _set = "{:s}subject{:d}.set".format(DOWNLOAD_URL, subject) set_local = dl.data_path( - _set, 'MUNICHMI', path, force_update, update_path, verbose + _set, "MUNICHMI", path, force_update, update_path, verbose ) # download .fdt - _fdt = '{:s}subject{:d}.fdt'.format(DOWNLOAD_URL, subject) - dl.data_path(_fdt, 'MUNICHMI', path, force_update, update_path, verbose) + _fdt = "{:s}subject{:d}.fdt".format(DOWNLOAD_URL, subject) + dl.data_path(_fdt, "MUNICHMI", path, force_update, update_path, verbose) return set_local diff --git a/moabb/datasets/physionet_mi.py b/moabb/datasets/physionet_mi.py index 5b7e657c9..b61cddea4 100644 --- a/moabb/datasets/physionet_mi.py +++ b/moabb/datasets/physionet_mi.py @@ -13,7 +13,7 @@ from .base import BaseDataset -BASE_URL = 'http://archive.physionet.org/pn4/eegmmidb/' +BASE_URL = "http://archive.physionet.org/pn4/eegmmidb/" class PhysionetMI(BaseDataset): @@ -77,12 +77,12 @@ def __init__(self, imagined=True, executed=False): subjects=list(range(1, 110)), sessions_per_subject=1, events=dict(left_hand=2, right_hand=3, feet=5, hands=4, rest=1), - code='Physionet Motor Imagery', + code="Physionet Motor Imagery", # website does not specify how long the trials are, but the # interval between 2 trial is 4 second. 
interval=[0, 3], - paradigm='imagery', - doi='10.1109/TBME.2004.827072', + paradigm="imagery", + doi="10.1109/TBME.2004.827072", ) self.feet_runs = [] @@ -97,15 +97,15 @@ def __init__(self, imagined=True, executed=False): self.hand_runs += [3, 7, 11] def _load_one_run(self, subject, run, preload=True): - if get_config('MNE_DATASETS_EEGBCI_PATH') is None: + if get_config("MNE_DATASETS_EEGBCI_PATH") is None: set_config( - 'MNE_DATASETS_EEGBCI_PATH', osp.join(osp.expanduser("~"), "mne_data") + "MNE_DATASETS_EEGBCI_PATH", osp.join(osp.expanduser("~"), "mne_data") ) raw_fname = eegbci.load_data( - subject, runs=[run], verbose='ERROR', base_url=BASE_URL + subject, runs=[run], verbose="ERROR", base_url=BASE_URL )[0] - raw = read_raw_edf(raw_fname, preload=preload, verbose='ERROR') - raw.rename_channels(lambda x: x.strip('.')) + raw = read_raw_edf(raw_fname, preload=preload, verbose="ERROR") + raw.rename_channels(lambda x: x.strip(".")) raw.rename_channels(lambda x: x.upper()) # fmt: off renames = { @@ -115,38 +115,38 @@ def _load_one_run(self, subject, run, preload=True): } # fmt: on raw.rename_channels(renames) - raw.set_montage(mne.channels.make_standard_montage('standard_1005')) + raw.set_montage(mne.channels.make_standard_montage("standard_1005")) return raw def _get_single_subject_data(self, subject): """return data for a single subject""" data = {} - if get_config('MNE_DATASETS_EEGBCI_PATH') is None: + if get_config("MNE_DATASETS_EEGBCI_PATH") is None: set_config( - 'MNE_DATASETS_EEGBCI_PATH', osp.join(osp.expanduser("~"), "mne_data") + "MNE_DATASETS_EEGBCI_PATH", osp.join(osp.expanduser("~"), "mne_data") ) # hand runs for run in self.hand_runs: raw = self._load_one_run(subject, run) - stim = raw.annotations.description.astype(np.dtype(' 0: if verbose: - print('Found EEG channels: {}'.format(processed)) + print("Found EEG channels: {}".format(processed)) dset_chans.append(processed) keep_datasets.append(d) else: print( - 'Dataset {:s} has no recognizable EEG channels'.format(type(d).__name__) + "Dataset {:s} has no recognizable EEG channels".format(type(d).__name__) ) # noqa allchans.intersection_update(*dset_chans) - allchans = [s.replace('Z', 'z') for s in allchans] + allchans = [s.replace("Z", "z") for s in allchans] return allchans, keep_datasets diff --git a/moabb/evaluations/base.py b/moabb/evaluations/base.py index 36884dd00..d607e80c5 100644 --- a/moabb/evaluations/base.py +++ b/moabb/evaluations/base.py @@ -40,8 +40,8 @@ def __init__( random_state=None, n_jobs=1, overwrite=False, - error_score='raise', - suffix='', + error_score="raise", + suffix="", hdf5_path=None, additional_columns=None, ): @@ -91,8 +91,8 @@ def __init__( self.datasets = datasets else: raise Exception( - '''No datasets left after paradigm - and evaluation checks''' + """No datasets left after paradigm + and evaluation checks""" ) self.results = Results( @@ -131,7 +131,7 @@ def process(self, pipelines): raise (ValueError("pipelines must only contains Pipelines " "instance")) for dataset in self.datasets: - log.info('Processing dataset: {}'.format(dataset.code)) + log.info("Processing dataset: {}".format(dataset.code)) results = self.evaluate(dataset, pipelines) for res in results: self.push_result(res, pipelines) @@ -139,13 +139,13 @@ def process(self, pipelines): return self.results.to_dataframe(pipelines=pipelines) def push_result(self, res, pipelines): - message = '{} | '.format(res['pipeline']) - message += '{} | {} | {}'.format( - res['dataset'].code, res['subject'], res['session'] + message = "{} | 
".format(res["pipeline"]) + message += "{} | {} | {}".format( + res["dataset"].code, res["subject"], res["session"] ) - message += ': Score %.3f' % res['score'] + message += ": Score %.3f" % res["score"] log.info(message) - self.results.add({res['pipeline']: res}, pipelines=pipelines) + self.results.add({res["pipeline"]: res}, pipelines=pipelines) def get_results(self): return self.results.to_dataframe() diff --git a/moabb/evaluations/evaluations.py b/moabb/evaluations/evaluations.py index 71a4d49da..eb274c612 100644 --- a/moabb/evaluations/evaluations.py +++ b/moabb/evaluations/evaluations.py @@ -43,14 +43,14 @@ def evaluate(self, dataset, pipelines): score = self.score(clf, X[ix], y[ix], self.paradigm.scoring) duration = time() - t_start res = { - 'time': duration / 5.0, # 5 fold CV - 'dataset': dataset, - 'subject': subject, - 'session': session, - 'score': score, - 'n_samples': len(y[ix]), # not training sample - 'n_channels': X.shape[1], - 'pipeline': name, + "time": duration / 5.0, # 5 fold CV + "dataset": dataset, + "subject": subject, + "session": session, + "score": score, + "n_samples": len(y[ix]), # not training sample + "n_channels": X.shape[1], + "pipeline": name, } yield res @@ -86,7 +86,7 @@ class CrossSessionEvaluation(BaseEvaluation): def evaluate(self, dataset, pipelines): if not self.is_valid(dataset): - raise AssertionError('Dataset is not appropriate for evaluation') + raise AssertionError("Dataset is not appropriate for evaluation") for subject in dataset.subject_list: # check if we already have result for this subject/pipeline # we might need a better granularity, if we query the DB @@ -121,14 +121,14 @@ def evaluate(self, dataset, pipelines): )[0] duration = time() - t_start res = { - 'time': duration, - 'dataset': dataset, - 'subject': subject, - 'session': groups[test][0], - 'score': score, - 'n_samples': len(train), - 'n_channels': X.shape[1], - 'pipeline': name, + "time": duration, + "dataset": dataset, + "subject": subject, + "session": groups[test][0], + "score": score, + "n_samples": len(train), + "n_channels": X.shape[1], + "pipeline": name, } yield res @@ -146,7 +146,7 @@ class CrossSubjectEvaluation(BaseEvaluation): def evaluate(self, dataset, pipelines): if not self.is_valid(dataset): - raise AssertionError('Dataset is not appropriate for evaluation') + raise AssertionError("Dataset is not appropriate for evaluation") # this is a bit akward, but we need to check if at least one pipe # have to be run before loading the data. If at least one pipeline # need to be run, we have to load all the data. 
@@ -189,14 +189,14 @@ def evaluate(self, dataset, pipelines): score = _score(model, X[test[ix]], y[test[ix]], scorer) res = { - 'time': duration, - 'dataset': dataset, - 'subject': subject, - 'session': session, - 'score': score, - 'n_samples': len(train), - 'n_channels': X.shape[1], - 'pipeline': name, + "time": duration, + "dataset": dataset, + "subject": subject, + "session": session, + "score": score, + "n_samples": len(train), + "n_channels": X.shape[1], + "pipeline": name, } yield res diff --git a/moabb/paradigms/base.py b/moabb/paradigms/base.py index 0967791c0..96256a775 100644 --- a/moabb/paradigms/base.py +++ b/moabb/paradigms/base.py @@ -139,7 +139,7 @@ def process_raw(self, raw, dataset, return_epochs=False): # noqa: C901 fmin, fmax = bandpass # filter data raw_f = raw.copy().filter( - fmin, fmax, method='iir', picks=picks, verbose=False + fmin, fmax, method="iir", picks=picks, verbose=False ) # epoch data baseline = self.baseline @@ -164,8 +164,8 @@ def process_raw(self, raw, dataset, return_epochs=False): # noqa: C901 preload=True, verbose=False, picks=picks, - event_repeated='drop', - on_missing='ignore', + event_repeated="drop", + on_missing="ignore", ) if bmin < tmin or bmax > tmax: epochs.crop(tmin=tmin, tmax=tmax) @@ -243,9 +243,9 @@ def get_data(self, dataset, subjects=None, return_epochs=False): continue x, lbs, met = proc - met['subject'] = subject - met['session'] = session - met['run'] = run + met["subject"] = subject + met["session"] = session + met["run"] = run metadata.append(met) # grow X and labels in a memory efficient way. can be slow diff --git a/moabb/paradigms/motor_imagery.py b/moabb/paradigms/motor_imagery.py index 8c42cbeb3..95ba3bf89 100644 --- a/moabb/paradigms/motor_imagery.py +++ b/moabb/paradigms/motor_imagery.py @@ -79,7 +79,7 @@ def __init__( def is_valid(self, dataset): ret = True - if not (dataset.paradigm == 'imagery'): + if not (dataset.paradigm == "imagery"): ret = False # check if dataset has required events @@ -101,12 +101,12 @@ def datasets(self): else: interval = self.tmax - self.tmin return utils.dataset_search( - paradigm='imagery', events=self.events, interval=interval, has_all_events=True + paradigm="imagery", events=self.events, interval=interval, has_all_events=True ) @property def scoring(self): - return 'accuracy' + return "accuracy" class SinglePass(BaseMotorImagery): @@ -155,7 +155,7 @@ class SinglePass(BaseMotorImagery): """ def __init__(self, fmin=8, fmax=32, **kwargs): - if 'filters' in kwargs.keys(): + if "filters" in kwargs.keys(): raise (ValueError("MotorImagery does not take argument filters")) super().__init__(filters=[[fmin, fmax]], **kwargs) @@ -180,16 +180,16 @@ class LeftRightImagery(SinglePass): """ def __init__(self, **kwargs): - if 'events' in kwargs.keys(): - raise (ValueError('LeftRightImagery dont accept events')) - super().__init__(events=['left_hand', 'right_hand'], **kwargs) + if "events" in kwargs.keys(): + raise (ValueError("LeftRightImagery dont accept events")) + super().__init__(events=["left_hand", "right_hand"], **kwargs) def used_events(self, dataset): return {ev: dataset.event_id[ev] for ev in self.events} @property def scoring(self): - return 'roc_auc' + return "roc_auc" class FilterBankLeftRightImagery(FilterBank): @@ -200,16 +200,16 @@ class FilterBankLeftRightImagery(FilterBank): """ def __init__(self, **kwargs): - if 'events' in kwargs.keys(): - raise (ValueError('LeftRightImagery dont accept events')) - super().__init__(events=['left_hand', 'right_hand'], **kwargs) + if "events" in 
kwargs.keys(): + raise (ValueError("LeftRightImagery dont accept events")) + super().__init__(events=["left_hand", "right_hand"], **kwargs) def used_events(self, dataset): return {ev: dataset.event_id[ev] for ev in self.events} @property def scoring(self): - return 'roc_auc' + return "roc_auc" class FilterBankMotorImagery(FilterBank): @@ -238,11 +238,11 @@ def __init__(self, n_classes=2, **kwargs): if self.events is None: log.warning("Choosing from all possible events") else: - assert n_classes <= len(self.events), 'More classes than events specified' + assert n_classes <= len(self.events), "More classes than events specified" def is_valid(self, dataset): ret = True - if not dataset.paradigm == 'imagery': + if not dataset.paradigm == "imagery": ret = False if self.events is None: if not len(dataset.event_id) >= self.n_classes: @@ -282,7 +282,7 @@ def datasets(self): else: interval = self.tmax - self.tmin return utils.dataset_search( - paradigm='imagery', + paradigm="imagery", events=self.events, total_classes=self.n_classes, interval=interval, @@ -292,9 +292,9 @@ def datasets(self): @property def scoring(self): if self.n_classes == 2: - return 'roc_auc' + return "roc_auc" else: - return 'accuracy' + return "accuracy" class MotorImagery(SinglePass): @@ -354,11 +354,11 @@ def __init__(self, n_classes=2, **kwargs): if self.events is None: log.warning("Choosing from all possible events") else: - assert n_classes <= len(self.events), 'More classes than events specified' + assert n_classes <= len(self.events), "More classes than events specified" def is_valid(self, dataset): ret = True - if not dataset.paradigm == 'imagery': + if not dataset.paradigm == "imagery": ret = False if self.events is None: if not len(dataset.event_id) >= self.n_classes: @@ -398,7 +398,7 @@ def datasets(self): else: interval = self.tmax - self.tmin return utils.dataset_search( - paradigm='imagery', + paradigm="imagery", events=self.events, interval=interval, has_all_events=False, @@ -407,9 +407,9 @@ def datasets(self): @property def scoring(self): if self.n_classes == 2: - return 'roc_auc' + return "roc_auc" else: - return 'accuracy' + return "accuracy" class FakeImageryParadigm(LeftRightImagery): @@ -417,4 +417,4 @@ class FakeImageryParadigm(LeftRightImagery): @property def datasets(self): - return [FakeDataset(['left_hand', 'right_hand'], paradigm='imagery')] + return [FakeDataset(["left_hand", "right_hand"], paradigm="imagery")] diff --git a/moabb/paradigms/p300.py b/moabb/paradigms/p300.py index 9ef5928d9..384d2979d 100644 --- a/moabb/paradigms/p300.py +++ b/moabb/paradigms/p300.py @@ -83,7 +83,7 @@ def __init__( def is_valid(self, dataset): ret = True - if not (dataset.paradigm == 'p300'): + if not (dataset.paradigm == "p300"): ret = False # check if dataset has required events @@ -116,10 +116,10 @@ def process_raw(self, raw, dataset, return_epochs=False): # pick events, based on event_id try: - if type(event_id['Target']) is list and type(event_id['NonTarget']) == list: + if type(event_id["Target"]) is list and type(event_id["NonTarget"]) == list: event_id_new = dict(Target=1, NonTarget=0) - events = mne.merge_events(events, event_id['Target'], 1) - events = mne.merge_events(events, event_id['NonTarget'], 0) + events = mne.merge_events(events, event_id["Target"], 1) + events = mne.merge_events(events, event_id["NonTarget"], 0) event_id = event_id_new events = mne.pick_events(events, include=list(event_id.values())) except RuntimeError: @@ -138,7 +138,7 @@ def process_raw(self, raw, dataset, return_epochs=False): 
fmin, fmax = bandpass # filter data raw_f = raw.copy().filter( - fmin, fmax, method='iir', picks=picks, verbose=False + fmin, fmax, method="iir", picks=picks, verbose=False ) # epoch data baseline = self.baseline @@ -163,7 +163,7 @@ def process_raw(self, raw, dataset, return_epochs=False): preload=True, verbose=False, picks=picks, - on_missing='ignore', + on_missing="ignore", ) if bmin < tmin or bmax > tmax: epochs.crop(tmin=tmin, tmax=tmax) @@ -194,12 +194,12 @@ def datasets(self): else: interval = self.tmax - self.tmin return utils.dataset_search( - paradigm='p300', events=self.events, interval=interval, has_all_events=True + paradigm="p300", events=self.events, interval=interval, has_all_events=True ) @property def scoring(self): - return 'roc_auc' + return "roc_auc" class SinglePass(BaseP300): @@ -248,7 +248,7 @@ class SinglePass(BaseP300): """ def __init__(self, fmin=1, fmax=24, **kwargs): - if 'filters' in kwargs.keys(): + if "filters" in kwargs.keys(): raise (ValueError("P300 does not take argument filters")) super().__init__(filters=[[fmin, fmax]], **kwargs) @@ -261,16 +261,16 @@ class P300(SinglePass): """ def __init__(self, **kwargs): - if 'events' in kwargs.keys(): - raise (ValueError('P300 dont accept events')) - super().__init__(events=['Target', 'NonTarget'], **kwargs) + if "events" in kwargs.keys(): + raise (ValueError("P300 dont accept events")) + super().__init__(events=["Target", "NonTarget"], **kwargs) def used_events(self, dataset): return {ev: dataset.event_id[ev] for ev in self.events} @property def scoring(self): - return 'roc_auc' + return "roc_auc" class FakeP300Paradigm(P300): @@ -278,4 +278,4 @@ class FakeP300Paradigm(P300): @property def datasets(self): - return [FakeDataset(['Target', 'NonTarget'], paradigm='p300')] + return [FakeDataset(["Target", "NonTarget"], paradigm="p300")] diff --git a/moabb/paradigms/ssvep.py b/moabb/paradigms/ssvep.py index d584ad6dc..d3c6783dc 100644 --- a/moabb/paradigms/ssvep.py +++ b/moabb/paradigms/ssvep.py @@ -84,11 +84,11 @@ def __init__( + " from all possible events" ) else: - assert n_classes <= len(self.events), 'More classes than events specified' + assert n_classes <= len(self.events), "More classes than events specified" def is_valid(self, dataset): ret = True - if not (dataset.paradigm == 'ssvep'): + if not (dataset.paradigm == "ssvep"): ret = False # check if dataset has required events @@ -128,7 +128,7 @@ def prepare_process(self, dataset): self.filters = [ [float(f) - 0.5, float(f) + 0.5] for f in event_id.keys() - if f.replace('.', '', 1).isnumeric() + if f.replace(".", "", 1).isnumeric() ] @property @@ -138,7 +138,7 @@ def datasets(self): else: interval = self.tmax - self.tmin return utils.dataset_search( - paradigm='ssvep', + paradigm="ssvep", events=self.events, # total_classes=self.n_classes, interval=interval, @@ -148,9 +148,9 @@ def datasets(self): @property def scoring(self): if self.n_classes == 2: - return 'roc_auc' + return "roc_auc" else: - return 'accuracy' + return "accuracy" class SSVEP(BaseSSVEP): @@ -202,7 +202,7 @@ class SSVEP(BaseSSVEP): """ def __init__(self, fmin=7, fmax=45, **kwargs): - if 'filters' in kwargs.keys(): + if "filters" in kwargs.keys(): raise (ValueError("SSVEP does not take argument filters")) super().__init__(filters=[(fmin, fmax)], **kwargs) @@ -263,4 +263,4 @@ class FakeSSVEPParadigm(BaseSSVEP): @property def datasets(self): - return [FakeDataset(event_list=['13', '15'], paradigm='ssvep')] + return [FakeDataset(event_list=["13", "15"], paradigm="ssvep")] diff --git 
a/moabb/pipelines/classification.py b/moabb/pipelines/classification.py index 02406e4b9..4c8ebe64f 100644 --- a/moabb/pipelines/classification.py +++ b/moabb/pipelines/classification.py @@ -40,7 +40,7 @@ def fit(self, X, y, sample_weight=None): n_times = X.shape[2] for f in self.freqs: - if f.replace('.', '', 1).isnumeric(): + if f.replace(".", "", 1).isnumeric(): freq = float(f) yf = [] for h in range(1, self.n_harmonics + 1): @@ -59,7 +59,7 @@ def predict(self, X): for x in X: corr_f = {} for f in self.freqs: - if f.replace('.', '', 1).isnumeric(): + if f.replace(".", "", 1).isnumeric(): S_x, S_y = self.cca.fit_transform(x.T, self.Yf[f].T) corr_f[f] = np.corrcoef(S_x.T, S_y.T)[0, 1] y.append(self.one_hot[max(corr_f, key=lambda k: corr_f[k])]) @@ -70,7 +70,7 @@ def predict_proba(self, X): P = np.zeros(shape=(len(X), len(self.freqs))) for i, x in enumerate(X): for j, f in enumerate(self.freqs): - if f.replace('.', '', 1).isnumeric(): + if f.replace(".", "", 1).isnumeric(): S_x, S_y = self.cca.fit_transform(x.T, self.Yf[f].T) P[i, j] = np.corrcoef(S_x.T, S_y.T)[0, 1] return P / np.resize(P.sum(axis=1), P.T.shape).T diff --git a/moabb/pipelines/csp.py b/moabb/pipelines/csp.py index ae2313507..c829decf1 100644 --- a/moabb/pipelines/csp.py +++ b/moabb/pipelines/csp.py @@ -10,7 +10,7 @@ class TRCSP(CSP): """ - def __init__(self, nfilter=4, metric='euclid', log=True, alpha=1): + def __init__(self, nfilter=4, metric="euclid", log=True, alpha=1): super().__init__(nfilter, metric, log) self.alpha = alpha @@ -20,20 +20,20 @@ def fit(self, X, y): """ if not isinstance(X, (np.ndarray, list)): - raise TypeError('X must be an array.') + raise TypeError("X must be an array.") if not isinstance(y, (np.ndarray, list)): - raise TypeError('y must be an array.') + raise TypeError("y must be an array.") X, y = np.asarray(X), np.asarray(y) if X.ndim != 3: - raise ValueError('X must be n_trials * n_channels * n_channels') + raise ValueError("X must be n_trials * n_channels * n_channels") if len(y) != len(X): - raise ValueError('X and y must have the same length.') + raise ValueError("X and y must have the same length.") if np.squeeze(y).ndim != 1: - raise ValueError('y must be of shape (n_trials,).') + raise ValueError("y must be of shape (n_trials,).") Nt, Ne, Ns = X.shape classes = np.unique(y) - assert len(classes) == 2, 'Can only do 2-class TRCSP' + assert len(classes) == 2, "Can only do 2-class TRCSP" # estimate class means C = [] for c in classes: diff --git a/moabb/pipelines/utils.py b/moabb/pipelines/utils.py index eb52727c6..9c4355327 100644 --- a/moabb/pipelines/utils.py +++ b/moabb/pipelines/utils.py @@ -25,13 +25,13 @@ def create_pipeline_from_config(config): for component in config: # load the package - mod = __import__(component['from'], fromlist=[component['name']]) + mod = __import__(component["from"], fromlist=[component["name"]]) # create the instance - if 'parameters' in component.keys(): - params = component['parameters'] + if "parameters" in component.keys(): + params = component["parameters"] else: params = {} - instance = getattr(mod, component['name'])(**params) + instance = getattr(mod, component["name"])(**params) components.append(instance) pipeline = make_pipeline(*components) @@ -84,6 +84,6 @@ def transform(self, X): def __repr__(self): estimator_name = type(self).__name__ estimator_prms = self.estimator.get_params() - return '{}(estimator={}, flatten={})'.format( + return "{}(estimator={}, flatten={})".format( estimator_name, estimator_prms, self.flatten ) diff --git a/moabb/run.py 
b/moabb/run.py index f7f6e3f81..58b1fd795 100755 --- a/moabb/run.py +++ b/moabb/run.py @@ -33,7 +33,7 @@ def parser_init(): "--pipelines", dest="pipelines", type=str, - default='./pipelines/', + default="./pipelines/", help="Folder containing the pipelines to evaluates.", ) parser.add_argument( @@ -41,7 +41,7 @@ def parser_init(): "--results", dest="results", type=str, - default='./results/', + default="./results/", help="Folder to store the results.", ) parser.add_argument( @@ -69,7 +69,7 @@ def parser_init(): "--output", dest="output", type=str, - default='./', + default="./", help="Folder to put analysis results", ) parser.add_argument( @@ -108,26 +108,26 @@ def parse_pipelines_from_directory(d): ), "Given pipeline path {} is not valid".format(d) # get list of config files - yaml_files = glob(os.path.join(d, '*.yml')) + yaml_files = glob(os.path.join(d, "*.yml")) pipeline_configs = [] for yaml_file in yaml_files: - with open(yaml_file, 'r') as _file: + with open(yaml_file, "r") as _file: content = _file.read() # load config config_dict = yaml.load(content, Loader=yaml.FullLoader) - ppl = create_pipeline_from_config(config_dict['pipeline']) + ppl = create_pipeline_from_config(config_dict["pipeline"]) pipeline_configs.append( { - 'paradigms': config_dict['paradigms'], - 'pipeline': ppl, - 'name': config_dict['name'], + "paradigms": config_dict["paradigms"], + "pipeline": ppl, + "name": config_dict["name"], } ) # we can do the same for python defined pipeline - python_files = glob(os.path.join(d, '*.py')) + python_files = glob(os.path.join(d, "*.py")) for python_file in python_files: spec = importlib.util.spec_from_file_location("custom", python_file) @@ -143,13 +143,13 @@ def generate_paradigms(pipeline_configs, context=None): paradigms = OrderedDict() for config in pipeline_configs: - if 'paradigms' not in config.keys(): + if "paradigms" not in config.keys(): log.error("{} must have a 'paradigms' key.".format(config)) continue # iterate over paradigms - for paradigm in config['paradigms']: + for paradigm in config["paradigms"]: # check if it is in the context parameters file if len(context) > 0: @@ -161,24 +161,24 @@ def generate_paradigms(pipeline_configs, context=None): ) ) - if isinstance(config['pipeline'], BaseEstimator): - pipeline = deepcopy(config['pipeline']) + if isinstance(config["pipeline"], BaseEstimator): + pipeline = deepcopy(config["pipeline"]) else: - log.error(config['pipeline']) - raise (ValueError('pipeline must be a sklearn estimator')) + log.error(config["pipeline"]) + raise (ValueError("pipeline must be a sklearn estimator")) # append the pipeline in the paradigm list if paradigm not in paradigms.keys(): paradigms[paradigm] = {} # FIXME name are not unique - log.debug('Pipeline: \n\n {} \n'.format(get_string_rep(pipeline))) - paradigms[paradigm][config['name']] = pipeline + log.debug("Pipeline: \n\n {} \n".format(get_string_rep(pipeline))) + paradigms[paradigm][config["name"]] = pipeline return paradigms -if __name__ == '__main__': +if __name__ == "__main__": # set logs mne.set_log_level(False) # logging.basicConfig(level=logging.WARNING) @@ -198,7 +198,7 @@ def generate_paradigms(pipeline_configs, context=None): context_params = {} if options.context is not None: - with open(options.context, 'r') as cfile: + with open(options.context, "r") as cfile: context_params = yaml.load(cfile.read(), Loader=yaml.FullLoader) paradigms = generate_paradigms(pipeline_configs, context_params) @@ -210,7 +210,7 @@ def generate_paradigms(pipeline_configs, context=None): all_results 
= [] for paradigm in paradigms: # get the context - log.debug('{}: {}'.format(paradigm, context_params[paradigm])) + log.debug("{}: {}".format(paradigm, context_params[paradigm])) p = getattr(moabb_paradigms, paradigm)(**context_params[paradigm]) context = WithinSessionEvaluation( paradigm=p, random_state=42, n_jobs=options.threads, overwrite=options.force diff --git a/moabb/tests/analysis.py b/moabb/tests/analysis.py index f9c8e6fe4..cfd688f15 100644 --- a/moabb/tests/analysis.py +++ b/moabb/tests/analysis.py @@ -16,7 +16,7 @@ class DummyEvaluation(BaseEvaluation): def evaluate(self, dataset, pipelines): - raise NotImplementedError('dummy') + raise NotImplementedError("dummy") def is_valid(self, dataset): pass @@ -28,64 +28,64 @@ def __init__(self): @property def scoring(self): - raise NotImplementedError('dummy') + raise NotImplementedError("dummy") def is_valid(self, dataset): pass def process_raw(self, raw, dataset, return_epochs=False): - raise NotImplementedError('dummy') + raise NotImplementedError("dummy") @property def datasets(self): - return [FakeDataset(['d1', 'd2'])] + return [FakeDataset(["d1", "d2"])] # Create dummy data for tests d1 = { - 'time': 1, - 'dataset': FakeDataset(['d1', 'd2']), - 'subject': 1, - 'session': 'session_0', - 'score': 0.9, - 'n_samples': 100, - 'n_channels': 10, + "time": 1, + "dataset": FakeDataset(["d1", "d2"]), + "subject": 1, + "session": "session_0", + "score": 0.9, + "n_samples": 100, + "n_channels": 10, } d2 = { - 'time': 2, - 'dataset': FakeDataset(['d1', 'd2']), - 'subject': 2, - 'session': 'session_0', - 'score': 0.9, - 'n_samples': 100, - 'n_channels': 10, + "time": 2, + "dataset": FakeDataset(["d1", "d2"]), + "subject": 2, + "session": "session_0", + "score": 0.9, + "n_samples": 100, + "n_channels": 10, } d3 = { - 'time': 2, - 'dataset': FakeDataset(['d1', 'd2']), - 'subject': 2, - 'session': 'session_0', - 'score': 0.9, - 'n_samples': 100, - 'n_channels': 10, + "time": 2, + "dataset": FakeDataset(["d1", "d2"]), + "subject": 2, + "session": "session_0", + "score": 0.9, + "n_samples": 100, + "n_channels": 10, } d4 = { - 'time': 2, - 'dataset': FakeDataset(['d1', 'd2']), - 'subject': 1, - 'session': 'session_0', - 'score': 0.9, - 'n_samples': 100, - 'n_channels': 10, + "time": 2, + "dataset": FakeDataset(["d1", "d2"]), + "subject": 1, + "session": "session_0", + "score": 0.9, + "n_samples": 100, + "n_channels": 10, } def to_pipeline_dict(pnames): - return {n: 'pipeline {}'.format(n) for n in pnames} + return {n: "pipeline {}".format(n) for n in pnames} def to_result_input(pnames, dsets): @@ -116,7 +116,7 @@ def test_perm_random(self): class Test_Integration(unittest.TestCase): def setUp(self): self.obj = Results( - evaluation_class=DummyEvaluation, paradigm_class=DummyParadigm, suffix='test' + evaluation_class=DummyEvaluation, paradigm_class=DummyParadigm, suffix="test" ) def tearDown(self): @@ -128,7 +128,7 @@ def tearDown(self): class Test_Results(unittest.TestCase): def setUp(self): self.obj = Results( - evaluation_class=DummyEvaluation, paradigm_class=DummyParadigm, suffix='test' + evaluation_class=DummyEvaluation, paradigm_class=DummyParadigm, suffix="test" ) def tearDown(self): @@ -137,45 +137,45 @@ def tearDown(self): os.remove(path) def testCanAddSample(self): - self.obj.add(to_result_input(['a'], [d1]), to_pipeline_dict(['a'])) + self.obj.add(to_result_input(["a"], [d1]), to_pipeline_dict(["a"])) def testRecognizesAlreadyComputed(self): - _in = to_result_input(['a'], [d1]) - self.obj.add(_in, to_pipeline_dict(['a'])) + _in = 
to_result_input(["a"], [d1]) + self.obj.add(_in, to_pipeline_dict(["a"])) not_yet_computed = self.obj.not_yet_computed( - to_pipeline_dict(['a']), d1['dataset'], d1['subject'] + to_pipeline_dict(["a"]), d1["dataset"], d1["subject"] ) self.assertTrue(len(not_yet_computed) == 0) def testCanAddMultiplePipelines(self): - _in = to_result_input(['a', 'b', 'c'], [d1, d1, d2]) - self.obj.add(_in, to_pipeline_dict(['a', 'b', 'c'])) + _in = to_result_input(["a", "b", "c"], [d1, d1, d2]) + self.obj.add(_in, to_pipeline_dict(["a", "b", "c"])) def testCanAddMultipleValuesPerPipeline(self): - _in = to_result_input(['a', 'b'], [[d1, d2], [d2, d1]]) - self.obj.add(_in, to_pipeline_dict(['a', 'b'])) + _in = to_result_input(["a", "b"], [[d1, d2], [d2, d1]]) + self.obj.add(_in, to_pipeline_dict(["a", "b"])) not_yet_computed = self.obj.not_yet_computed( - to_pipeline_dict(['a']), d1['dataset'], d1['subject'] + to_pipeline_dict(["a"]), d1["dataset"], d1["subject"] ) self.assertTrue(len(not_yet_computed) == 0, not_yet_computed) not_yet_computed = self.obj.not_yet_computed( - to_pipeline_dict(['b']), d2['dataset'], d2['subject'] + to_pipeline_dict(["b"]), d2["dataset"], d2["subject"] ) self.assertTrue(len(not_yet_computed) == 0, not_yet_computed) not_yet_computed = self.obj.not_yet_computed( - to_pipeline_dict(['b']), d1['dataset'], d1['subject'] + to_pipeline_dict(["b"]), d1["dataset"], d1["subject"] ) self.assertTrue(len(not_yet_computed) == 0, not_yet_computed) def testCanExportToDataframe(self): - _in = to_result_input(['a', 'b', 'c'], [d1, d1, d2]) - self.obj.add(_in, to_pipeline_dict(['a', 'b', 'c'])) - _in = to_result_input(['a', 'b', 'c'], [d2, d2, d3]) - self.obj.add(_in, to_pipeline_dict(['a', 'b', 'c'])) + _in = to_result_input(["a", "b", "c"], [d1, d1, d2]) + self.obj.add(_in, to_pipeline_dict(["a", "b", "c"])) + _in = to_result_input(["a", "b", "c"], [d2, d2, d3]) + self.obj.add(_in, to_pipeline_dict(["a", "b", "c"])) df = self.obj.to_dataframe() self.assertTrue( - set(np.unique(df['pipeline'])) == set(('a', 'b', 'c')), - np.unique(df['pipeline']), + set(np.unique(df["pipeline"])) == set(("a", "b", "c")), + np.unique(df["pipeline"]), ) self.assertTrue(df.shape[0] == 6, df.shape[0]) diff --git a/moabb/tests/datasets.py b/moabb/tests/datasets.py index a0bb7295e..44eb1389d 100644 --- a/moabb/tests/datasets.py +++ b/moabb/tests/datasets.py @@ -5,7 +5,7 @@ from moabb.datasets.fake import FakeDataset -_ = mne.set_log_level('CRITICAL') +_ = mne.set_log_level("CRITICAL") def _run_tests_on_dataset(d): @@ -16,7 +16,7 @@ def _run_tests_on_dataset(d): assert isinstance(data, dict) # We should get a raw array at the end - rawdata = data[s]['session_0']['run_0'] + rawdata = data[s]["session_0"]["run_0"] assert issubclass(type(rawdata), mne.io.BaseRaw), type(rawdata) # print events @@ -31,7 +31,7 @@ def test_fake_dataset(self): n_sessions = 2 n_runs = 2 - for paradigm in ['imagery', 'p300']: + for paradigm in ["imagery", "p300"]: ds = FakeDataset( n_sessions=n_sessions, @@ -51,10 +51,10 @@ def test_fake_dataset(self): self.assertEqual(len(data[1]), n_sessions) # right number of run - self.assertEqual(len(data[1]['session_0']), n_runs) + self.assertEqual(len(data[1]["session_0"]), n_runs) # We should get a raw array at the end - self.assertEqual(type(data[1]['session_0']['run_0']), mne.io.RawArray) + self.assertEqual(type(data[1]["session_0"]["run_0"]), mne.io.RawArray) # bad subject id must raise error self.assertRaises(ValueError, ds.get_data, [1000]) diff --git a/moabb/tests/download.py b/moabb/tests/download.py 
index 686e9754e..d8185fcb2 100644 --- a/moabb/tests/download.py +++ b/moabb/tests/download.py @@ -1,6 +1,6 @@ -''' +""" Tests to ensure that datasets download correctly -''' +""" # from moabb.datasets.gigadb import Cho2017 # from moabb.datasets.alex_mi import AlexMI # from moabb.datasets.physionet_mi import PhysionetMI @@ -115,5 +115,5 @@ def _get_events(raw): # self.run_dataset(bi2013a) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/moabb/tests/evaluations.py b/moabb/tests/evaluations.py index cce511fcf..3b9f364f0 100644 --- a/moabb/tests/evaluations.py +++ b/moabb/tests/evaluations.py @@ -13,8 +13,8 @@ pipelines = OrderedDict() -pipelines['C'] = make_pipeline(Covariances('oas'), CSP(8), LDA()) -dataset = FakeDataset(['left_hand', 'right_hand'], n_subjects=2) +pipelines["C"] = make_pipeline(Covariances("oas"), CSP(8), LDA()) +dataset = FakeDataset(["left_hand", "right_hand"], n_subjects=2) class Test_WithinSess(unittest.TestCase): @@ -47,7 +47,7 @@ def setUp(self): self.eval = ev.WithinSessionEvaluation( paradigm=FakeImageryParadigm(), datasets=[dataset], - additional_columns=['one', 'two'], + additional_columns=["one", "two"], ) def tearDown(self): @@ -68,11 +68,11 @@ def setUp(self): def test_compatible_dataset(self): # raise - ds = FakeDataset(['left_hand', 'right_hand'], n_subjects=1) + ds = FakeDataset(["left_hand", "right_hand"], n_subjects=1) self.assertFalse(self.eval.is_valid(dataset=ds)) # do not raise - ds = FakeDataset(['left_hand', 'right_hand'], n_subjects=2) + ds = FakeDataset(["left_hand", "right_hand"], n_subjects=2) self.assertTrue(self.eval.is_valid(dataset=ds)) @@ -83,11 +83,11 @@ def setUp(self): ) def test_compatible_dataset(self): - ds = FakeDataset(['left_hand', 'right_hand'], n_sessions=1) + ds = FakeDataset(["left_hand", "right_hand"], n_sessions=1) self.assertFalse(self.eval.is_valid(ds)) # do not raise - ds = FakeDataset(['left_hand', 'right_hand'], n_sessions=2) + ds = FakeDataset(["left_hand", "right_hand"], n_sessions=2) self.assertTrue(self.eval.is_valid(dataset=ds)) diff --git a/moabb/tests/paradigms.py b/moabb/tests/paradigms.py index c861d8341..7ecc3b230 100644 --- a/moabb/tests/paradigms.py +++ b/moabb/tests/paradigms.py @@ -29,7 +29,7 @@ def used_events(self, dataset): class Test_MotorImagery(unittest.TestCase): def test_BaseImagery_paradigm(self): paradigm = SimpleMotorImagery() - dataset = FakeDataset(paradigm='imagery') + dataset = FakeDataset(paradigm="imagery") X, labels, metadata = paradigm.get_data(dataset, subjects=[1]) # we should have all the same length @@ -40,9 +40,9 @@ def test_BaseImagery_paradigm(self): self.assertEqual(len(np.unique(labels)), 3) # metadata must have subjets, sessions, runs - self.assertTrue('subject' in metadata.columns) - self.assertTrue('session' in metadata.columns) - self.assertTrue('run' in metadata.columns) + self.assertTrue("subject" in metadata.columns) + self.assertTrue("session" in metadata.columns) + self.assertTrue("run" in metadata.columns) # we should have only one subject in the metadata self.assertEqual(np.unique(metadata.subject), 1) @@ -56,7 +56,7 @@ def test_BaseImagery_tmintmax(self): def test_BaseImagery_filters(self): # can work with filter bank paradigm = SimpleMotorImagery(filters=[[7, 12], [12, 24]]) - dataset = FakeDataset(paradigm='imagery') + dataset = FakeDataset(paradigm="imagery") X, labels, metadata = paradigm.get_data(dataset, subjects=[1]) # X must be a 4D Array @@ -67,8 +67,8 @@ def test_baseImagery_wrongevent(self): # test process_raw return empty 
list if raw does not contain any # selected event. cetain runs in dataset are event specific. paradigm = SimpleMotorImagery(filters=[[7, 12], [12, 24]]) - dataset = FakeDataset(paradigm='imagery') - raw = dataset.get_data([1])[1]['session_0']['run_0'] + dataset = FakeDataset(paradigm="imagery") + raw = dataset.get_data([1])[1]["session_0"]["run_0"] # add something on the event channel raw._data[-1] *= 10 self.assertIsNone(paradigm.process_raw(raw, dataset)) @@ -78,32 +78,32 @@ def test_baseImagery_wrongevent(self): def test_BaseImagery_noevent(self): # Assert error if events from paradigm and dataset dont overlap - paradigm = SimpleMotorImagery(events=['left_hand', 'right_hand']) - dataset = FakeDataset(paradigm='imagery') + paradigm = SimpleMotorImagery(events=["left_hand", "right_hand"]) + dataset = FakeDataset(paradigm="imagery") self.assertRaises(AssertionError, paradigm.get_data, dataset) def test_LeftRightImagery_paradigm(self): # with a good dataset paradigm = LeftRightImagery() - dataset = FakeDataset(event_list=['left_hand', 'right_hand'], paradigm='imagery') + dataset = FakeDataset(event_list=["left_hand", "right_hand"], paradigm="imagery") X, labels, metadata = paradigm.get_data(dataset, subjects=[1]) self.assertEqual(len(np.unique(labels)), 2) - self.assertEqual(list(np.unique(labels)), ['left_hand', 'right_hand']) + self.assertEqual(list(np.unique(labels)), ["left_hand", "right_hand"]) def test_LeftRightImagery_noevent(self): # we cant pass event to this class - self.assertRaises(ValueError, LeftRightImagery, events=['a']) + self.assertRaises(ValueError, LeftRightImagery, events=["a"]) def test_LeftRightImagery_badevents(self): paradigm = LeftRightImagery() # does not accept dataset with bad event - dataset = FakeDataset(paradigm='imagery') + dataset = FakeDataset(paradigm="imagery") self.assertRaises(AssertionError, paradigm.get_data, dataset) def test_FilterBankMotorImagery_paradigm(self): # can work with filter bank paradigm = FilterBankMotorImagery() - dataset = FakeDataset(paradigm='imagery') + dataset = FakeDataset(paradigm="imagery") X, labels, metadata = paradigm.get_data(dataset, subjects=[1]) # X must be a 4D Array @@ -112,13 +112,13 @@ def test_FilterBankMotorImagery_paradigm(self): def test_FilterBankMotorImagery_moreclassesthanevent(self): self.assertRaises( - AssertionError, FilterBankMotorImagery, n_classes=3, events=['hands', 'feet'] + AssertionError, FilterBankMotorImagery, n_classes=3, events=["hands", "feet"] ) def test_FilterBankLeftRightImagery_paradigm(self): # can work with filter bank paradigm = FilterBankLeftRightImagery() - dataset = FakeDataset(event_list=['left_hand', 'right_hand'], paradigm='imagery') + dataset = FakeDataset(event_list=["left_hand", "right_hand"], paradigm="imagery") X, labels, metadata = paradigm.get_data(dataset, subjects=[1]) # X must be a 4D Array @@ -134,7 +134,7 @@ def used_events(self, dataset): class Test_P300(unittest.TestCase): def test_BaseP300_paradigm(self): paradigm = SimpleP300() - dataset = FakeDataset(paradigm='p300', event_list=['Target', 'NonTarget']) + dataset = FakeDataset(paradigm="p300", event_list=["Target", "NonTarget"]) X, labels, metadata = paradigm.get_data(dataset, subjects=[1]) # we should have all the same length @@ -145,9 +145,9 @@ def test_BaseP300_paradigm(self): self.assertEqual(len(np.unique(labels)), 2) # metadata must have subjets, sessions, runs - self.assertTrue('subject' in metadata.columns) - self.assertTrue('session' in metadata.columns) - self.assertTrue('run' in metadata.columns) + 
self.assertTrue("subject" in metadata.columns) + self.assertTrue("session" in metadata.columns) + self.assertTrue("run" in metadata.columns) # we should have only one subject in the metadata self.assertEqual(np.unique(metadata.subject), 1) @@ -164,7 +164,7 @@ def test_BaseP300_tmintmax(self): def test_BaseP300_filters(self): # can work with filter bank paradigm = SimpleP300(filters=[[1, 12], [12, 24]]) - dataset = FakeDataset(paradigm='p300', event_list=['Target', 'NonTarget']) + dataset = FakeDataset(paradigm="p300", event_list=["Target", "NonTarget"]) X, labels, metadata = paradigm.get_data(dataset, subjects=[1]) # X must be a 4D Array @@ -175,8 +175,8 @@ def test_BaseP300_wrongevent(self): # test process_raw return empty list if raw does not contain any # selected event. cetain runs in dataset are event specific. paradigm = SimpleP300(filters=[[1, 12], [12, 24]]) - dataset = FakeDataset(paradigm='p300', event_list=['Target', 'NonTarget']) - raw = dataset.get_data([1])[1]['session_0']['run_0'] + dataset = FakeDataset(paradigm="p300", event_list=["Target", "NonTarget"]) + raw = dataset.get_data([1])[1]["session_0"]["run_0"] # add something on the event channel raw._data[-1] *= 10 self.assertIsNone(paradigm.process_raw(raw, dataset)) @@ -186,39 +186,39 @@ def test_BaseP300_wrongevent(self): def test_P300_specifyevent(self): # we cant pass event to this class - self.assertRaises(ValueError, P300, events=['a']) + self.assertRaises(ValueError, P300, events=["a"]) def test_P300_wrongevent(self): # does not accept dataset with bad event paradigm = P300() - dataset = FakeDataset(paradigm='p300') + dataset = FakeDataset(paradigm="p300") self.assertRaises(AssertionError, paradigm.get_data, dataset) def test_P300_paradigm(self): # with a good dataset paradigm = P300() - dataset = FakeDataset(event_list=['Target', 'NonTarget'], paradigm='p300') + dataset = FakeDataset(event_list=["Target", "NonTarget"], paradigm="p300") X, labels, metadata = paradigm.get_data(dataset, subjects=[1]) self.assertEqual(len(np.unique(labels)), 2) - self.assertEqual(list(np.unique(labels)), sorted(['Target', 'NonTarget'])) + self.assertEqual(list(np.unique(labels)), sorted(["Target", "NonTarget"])) def test_BaseImagery_noevent(self): # Assert error if events from paradigm and dataset dont overlap - paradigm = SimpleMotorImagery(events=['left_hand', 'right_hand']) + paradigm = SimpleMotorImagery(events=["left_hand", "right_hand"]) dataset = FakeDataset() self.assertRaises(AssertionError, paradigm.get_data, dataset) def test_LeftRightImagery_paradigm(self): # with a good dataset paradigm = LeftRightImagery() - dataset = FakeDataset(event_list=['left_hand', 'right_hand']) + dataset = FakeDataset(event_list=["left_hand", "right_hand"]) X, labels, metadata = paradigm.get_data(dataset, subjects=[1]) self.assertEqual(len(np.unique(labels)), 2) - self.assertEqual(list(np.unique(labels)), ['left_hand', 'right_hand']) + self.assertEqual(list(np.unique(labels)), ["left_hand", "right_hand"]) def test_LeftRightImagery_noevent(self): # we cant pass event to this class - self.assertRaises(ValueError, LeftRightImagery, events=['a']) + self.assertRaises(ValueError, LeftRightImagery, events=["a"]) def test_LeftRightImagery_badevents(self): paradigm = LeftRightImagery() @@ -238,13 +238,13 @@ def test_FilterBankMotorImagery_paradigm(self): def test_FilterBankMotorImagery_moreclassesthanevent(self): self.assertRaises( - AssertionError, FilterBankMotorImagery, n_classes=3, events=['hands', 'feet'] + AssertionError, FilterBankMotorImagery, 
n_classes=3, events=["hands", "feet"] ) def test_FilterBankLeftRightImagery_paradigm(self): # can work with filter bank paradigm = FilterBankLeftRightImagery() - dataset = FakeDataset(event_list=['left_hand', 'right_hand']) + dataset = FakeDataset(event_list=["left_hand", "right_hand"]) X, labels, metadata = paradigm.get_data(dataset, subjects=[1]) # X must be a 4D Array @@ -255,7 +255,7 @@ def test_FilterBankLeftRightImagery_paradigm(self): class Test_SSVEP(unittest.TestCase): def test_BaseSSVEP_paradigm(self): paradigm = BaseSSVEP(n_classes=None) - dataset = FakeDataset(paradigm='ssvep') + dataset = FakeDataset(paradigm="ssvep") X, labels, metadata = paradigm.get_data(dataset, subjects=[1]) # Verify that they have the same length @@ -266,9 +266,9 @@ def test_BaseSSVEP_paradigm(self): self.assertEqual(len(np.unique(labels)), 3) # metadata must have subjets, sessions, runs - self.assertTrue('subject' in metadata.columns) - self.assertTrue('session' in metadata.columns) - self.assertTrue('run' in metadata.columns) + self.assertTrue("subject" in metadata.columns) + self.assertTrue("session" in metadata.columns) + self.assertTrue("run" in metadata.columns) # Only one subject in the metadata self.assertEqual(np.unique(metadata.subject), 1) @@ -283,7 +283,7 @@ def test_baseSSVEP_tmintmax(self): def test_BaseSSVEP_filters(self): # Accept filters paradigm = BaseSSVEP(filters=[(10.5, 11.5), (12.5, 13.5)]) - dataset = FakeDataset(paradigm='ssvep') + dataset = FakeDataset(paradigm="ssvep") X, labels, metadata = paradigm.get_data(dataset, subjects=[1]) # X must be a 4D array @@ -294,7 +294,7 @@ def test_BaseSSVEP_filters(self): def test_BaseSSVEP_nclasses_default(self): # Default is with 3 classes paradigm = BaseSSVEP() - dataset = FakeDataset(paradigm='ssvep') + dataset = FakeDataset(paradigm="ssvep") X, labels, metadata = paradigm.get_data(dataset, subjects=[1]) # labels must contain all 3 classes of dataset, @@ -304,7 +304,7 @@ def test_BaseSSVEP_nclasses_default(self): def test_BaseSSVEP_specified_nclasses(self): # Set the number of classes paradigm = BaseSSVEP(n_classes=3) - dataset = FakeDataset(event_list=['13', '15', '17', '19'], paradigm='ssvep') + dataset = FakeDataset(event_list=["13", "15", "17", "19"], paradigm="ssvep") X, labels, metadata = paradigm.get_data(dataset, subjects=[1]) # labels must contain 3 values @@ -312,21 +312,21 @@ def test_BaseSSVEP_specified_nclasses(self): def test_BaseSSVEP_toomany_nclasses(self): paradigm = BaseSSVEP(n_classes=4) - dataset = FakeDataset(event_list=['13', '15'], paradigm='ssvep') + dataset = FakeDataset(event_list=["13", "15"], paradigm="ssvep") self.assertRaises(ValueError, paradigm.get_data, dataset) def test_BaseSSVEP_moreclassesthanevent(self): - self.assertRaises(AssertionError, BaseSSVEP, n_classes=3, events=['13.', '14.']) + self.assertRaises(AssertionError, BaseSSVEP, n_classes=3, events=["13.", "14."]) def test_SSVEP_noevent(self): # Assert error if events from paradigm and dataset dont overlap - paradigm = SSVEP(events=['11', '12'], n_classes=2) - dataset = FakeDataset(event_list=['13', '14'], paradigm='ssvep') + paradigm = SSVEP(events=["11", "12"], n_classes=2) + dataset = FakeDataset(event_list=["13", "14"], paradigm="ssvep") self.assertRaises(AssertionError, paradigm.get_data, dataset) def test_SSVEP_paradigm(self): paradigm = SSVEP(n_classes=None) - dataset = FakeDataset(event_list=['13', '15', '17', '19'], paradigm='ssvep') + dataset = FakeDataset(event_list=["13", "15", "17", "19"], paradigm="ssvep") X, labels, metadata = 
paradigm.get_data(dataset, subjects=[1]) # Verify that they have the same length @@ -337,9 +337,9 @@ def test_SSVEP_paradigm(self): self.assertEqual(len(np.unique(labels)), 4) # metadata must have subjets, sessions, runs - self.assertTrue('subject' in metadata.columns) - self.assertTrue('session' in metadata.columns) - self.assertTrue('run' in metadata.columns) + self.assertTrue("subject" in metadata.columns) + self.assertTrue("session" in metadata.columns) + self.assertTrue("run" in metadata.columns) # Only one subject in the metadata self.assertEqual(np.unique(metadata.subject), 1) @@ -350,7 +350,7 @@ def test_SSVEP_paradigm(self): def test_SSVEP_singlepass(self): # Accept only single pass filter paradigm = SSVEP(fmin=2, fmax=25) - dataset = FakeDataset(paradigm='ssvep') + dataset = FakeDataset(paradigm="ssvep") X, labels, metadata = paradigm.get_data(dataset, subjects=[1]) # Verify that they have the same length @@ -368,7 +368,7 @@ def test_SSVEP_filter(self): def test_FilterBankSSVEP_paradigm(self): # FilterBankSSVEP with all events paradigm = FilterBankSSVEP(n_classes=None) - dataset = FakeDataset(event_list=['13', '15', '17', '19'], paradigm='ssvep') + dataset = FakeDataset(event_list=["13", "15", "17", "19"], paradigm="ssvep") X, labels, metadata = paradigm.get_data(dataset, subjects=[1]) # X must be a 4D array @@ -379,7 +379,7 @@ def test_FilterBankSSVEP_paradigm(self): def test_FilterBankSSVEP_filters(self): # can work with filter bank paradigm = FilterBankSSVEP(filters=[(10.5, 11.5), (12.5, 13.5)]) - dataset = FakeDataset(event_list=['13', '15', '17'], paradigm='ssvep') + dataset = FakeDataset(event_list=["13", "15", "17"], paradigm="ssvep") X, labels, metadata = paradigm.get_data(dataset, subjects=[1]) # X must be a 4D array with d=2 as last dimension for the 2 filters diff --git a/moabb/tests/util_tests.py b/moabb/tests/util_tests.py index 1118ccba3..f9c918a89 100644 --- a/moabb/tests/util_tests.py +++ b/moabb/tests/util_tests.py @@ -8,30 +8,30 @@ def test_channel_intersection_fun(self): print(utils.find_intersecting_channels([d() for d in utils.dataset_list])[0]) def test_dataset_search_fun(self): - found = utils.dataset_search('imagery', multi_session=True) + found = utils.dataset_search("imagery", multi_session=True) print([type(dataset).__name__ for dataset in found]) - found = utils.dataset_search('imagery', multi_session=False) + found = utils.dataset_search("imagery", multi_session=False) print([type(dataset).__name__ for dataset in found]) res = utils.dataset_search( - 'imagery', events=['right_hand', 'left_hand', 'feet', 'tongue', 'rest'] + "imagery", events=["right_hand", "left_hand", "feet", "tongue", "rest"] ) for out in res: - print('multiclass: {}'.format(out.event_id.keys())) + print("multiclass: {}".format(out.event_id.keys())) res = utils.dataset_search( - 'imagery', events=['right_hand', 'feet'], has_all_events=True + "imagery", events=["right_hand", "feet"], has_all_events=True ) for out in res: - self.assertTrue(set(['right_hand', 'feet']) <= set(out.event_id.keys())) + self.assertTrue(set(["right_hand", "feet"]) <= set(out.event_id.keys())) def test_dataset_channel_search(self): - chans = ['C3', 'Cz'] + chans = ["C3", "Cz"] All = utils.dataset_search( - 'imagery', events=['right_hand', 'left_hand', 'feet', 'tongue', 'rest'] + "imagery", events=["right_hand", "left_hand", "feet", "tongue", "rest"] ) has_chans = utils.dataset_search( - 'imagery', - events=['right_hand', 'left_hand', 'feet', 'tongue', 'rest'], + "imagery", + events=["right_hand", "left_hand", 
"feet", "tongue", "rest"], channels=chans, ) has_types = set([type(x) for x in has_chans]) @@ -39,14 +39,14 @@ def test_dataset_channel_search(self): s1 = d.get_data([1])[1] sess1 = s1[list(s1.keys())[0]] raw = sess1[list(sess1.keys())[0]] - self.assertTrue(set(chans) <= set(raw.info['ch_names'])) + self.assertTrue(set(chans) <= set(raw.info["ch_names"])) for d in All: if type(d) not in has_types: s1 = d.get_data([1])[1] sess1 = s1[list(s1.keys())[0]] raw = sess1[list(sess1.keys())[0]] - self.assertFalse(set(chans) <= set(raw.info['ch_names'])) + self.assertFalse(set(chans) <= set(raw.info["ch_names"])) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/moabb/utils.py b/moabb/utils.py index ca8c348dd..9b89f2487 100644 --- a/moabb/utils.py +++ b/moabb/utils.py @@ -4,13 +4,13 @@ import mne -def set_log_level(verbose='info'): +def set_log_level(verbose="info"): """Set lot level. Set the general log level. level can be 'info', 'debug' or 'warning' """ mne.set_log_level(False) - level = {'debug': logging.DEBUG, 'info': logging.INFO, 'warning': logging.WARNING} + level = {"debug": logging.DEBUG, "info": logging.INFO, "warning": logging.WARNING} coloredlogs.install(level=level.get(verbose, logging.INFO)) diff --git a/pipelines/CSP_svm_search.py b/pipelines/CSP_svm_search.py index 1c559584a..3223c5fea 100644 --- a/pipelines/CSP_svm_search.py +++ b/pipelines/CSP_svm_search.py @@ -5,9 +5,9 @@ from sklearn.svm import SVC -parameters = {'kernel': ('linear', 'rbf'), 'C': [0.1, 1, 10]} +parameters = {"kernel": ("linear", "rbf"), "C": [0.1, 1, 10]} clf = GridSearchCV(SVC(), parameters, cv=3) -pipe = make_pipeline(Covariances('oas'), CSP(6), clf) +pipe = make_pipeline(Covariances("oas"), CSP(6), clf) # this is what will be loaded -PIPELINE = {'name': 'CSP + optSVM', 'paradigms': ['LeftRightImagery'], 'pipeline': pipe} +PIPELINE = {"name": "CSP + optSVM", "paradigms": ["LeftRightImagery"], "pipeline": pipe} diff --git a/pipelines/FBCSP.py b/pipelines/FBCSP.py index cc574d3bd..9b0ca4d47 100644 --- a/pipelines/FBCSP.py +++ b/pipelines/FBCSP.py @@ -9,14 +9,14 @@ from moabb.pipelines.utils import FilterBank -parameters = {'C': np.logspace(-2, 2, 10)} -clf = GridSearchCV(SVC(kernel='linear'), parameters) -fb = FilterBank(make_pipeline(Covariances(estimator='oas'), CSP(nfilter=4))) +parameters = {"C": np.logspace(-2, 2, 10)} +clf = GridSearchCV(SVC(kernel="linear"), parameters) +fb = FilterBank(make_pipeline(Covariances(estimator="oas"), CSP(nfilter=4))) pipe = make_pipeline(fb, SelectKBest(score_func=mutual_info_classif, k=10), clf) # this is what will be loaded PIPELINE = { - 'name': 'FBCSP + optSVM', - 'paradigms': ['FilterBankMotorImagery'], - 'pipeline': pipe, + "name": "FBCSP + optSVM", + "paradigms": ["FilterBankMotorImagery"], + "pipeline": pipe, } diff --git a/pipelines/LogVar.py b/pipelines/LogVar.py index 6dcbcbefe..1ec381354 100644 --- a/pipelines/LogVar.py +++ b/pipelines/LogVar.py @@ -6,9 +6,9 @@ from moabb.pipelines.features import LogVariance -parameters = {'C': np.logspace(-2, 2, 10)} -clf = GridSearchCV(SVC(kernel='linear'), parameters) +parameters = {"C": np.logspace(-2, 2, 10)} +clf = GridSearchCV(SVC(kernel="linear"), parameters) pipe = make_pipeline(LogVariance(), clf) # this is what will be loaded -PIPELINE = {'name': 'AM + optSVM', 'paradigms': ['MotorImagery'], 'pipeline': pipe} +PIPELINE = {"name": "AM + optSVM", "paradigms": ["MotorImagery"], "pipeline": pipe} diff --git a/pipelines/TSSVM.py b/pipelines/TSSVM.py index fab032a3d..52bc0498d 100644 --- 
a/pipelines/TSSVM.py +++ b/pipelines/TSSVM.py @@ -6,9 +6,9 @@ from sklearn.svm import SVC -parameters = {'C': np.logspace(-2, 2, 10)} -clf = GridSearchCV(SVC(kernel='linear'), parameters, cv=3) -pipe = make_pipeline(Covariances('oas'), TangentSpace(metric='riemann'), clf) +parameters = {"C": np.logspace(-2, 2, 10)} +clf = GridSearchCV(SVC(kernel="linear"), parameters, cv=3) +pipe = make_pipeline(Covariances("oas"), TangentSpace(metric="riemann"), clf) # this is what will be loaded -PIPELINE = {'name': 'TS + optSVM', 'paradigms': ['MotorImagery'], 'pipeline': pipe} +PIPELINE = {"name": "TS + optSVM", "paradigms": ["MotorImagery"], "pipeline": pipe} diff --git a/pipelines/WTRCSP.py b/pipelines/WTRCSP.py index 38b2139fc..d19029eeb 100644 --- a/pipelines/WTRCSP.py +++ b/pipelines/WTRCSP.py @@ -5,7 +5,7 @@ from moabb.pipelines.csp import TRCSP -pipe = make_pipeline(Covariances('scm'), TRCSP(nfilter=6), LinearDiscriminantAnalysis()) +pipe = make_pipeline(Covariances("scm"), TRCSP(nfilter=6), LinearDiscriminantAnalysis()) # this is what will be loaded -PIPELINE = {'name': 'TRCSP + LDA', 'paradigms': ['MotorImagery'], 'pipeline': pipe} +PIPELINE = {"name": "TRCSP + LDA", "paradigms": ["MotorImagery"], "pipeline": pipe} diff --git a/setup.py b/setup.py index 285ad1fe0..9f5f2e080 100644 --- a/setup.py +++ b/setup.py @@ -2,22 +2,22 @@ setup( - name='moabb', - version='0.2.1', - description='Mother of all BCI Benchmarks', - url='', - author='Alexandre Barachant, Vinay Jayaram', - author_email='{alexandre.barachant, vinayjayaram13}@gmail.com', - license='BSD (3-clause)', + name="moabb", + version="0.2.1", + description="Mother of all BCI Benchmarks", + url="", + author="Alexandre Barachant, Vinay Jayaram", + author_email="{alexandre.barachant, vinayjayaram13}@gmail.com", + license="BSD (3-clause)", packages=find_packages(), install_requires=[ - 'numpy', - 'scipy', - 'scikit-learn', - 'pandas', - 'mne', - 'pyriemann', - 'pyyaml', + "numpy", + "scipy", + "scikit-learn", + "pandas", + "mne", + "pyriemann", + "pyyaml", ], zip_safe=False, ) diff --git a/tutorials/plot_Getting_Started.py b/tutorials/plot_Getting_Started.py index acbe92b17..d0746df67 100644 --- a/tutorials/plot_Getting_Started.py +++ b/tutorials/plot_Getting_Started.py @@ -42,7 +42,7 @@ # we will make a couple pipelines just for convenience -moabb.set_log_level('info') +moabb.set_log_level("info") ############################################################################## # Create pipelines @@ -55,12 +55,12 @@ # is the name of the pipeline and the value is the Pipeline object pipelines = {} -pipelines['AM + LDA'] = make_pipeline(LogVariance(), LDA()) -parameters = {'C': np.logspace(-2, 2, 10)} -clf = GridSearchCV(SVC(kernel='linear'), parameters) +pipelines["AM + LDA"] = make_pipeline(LogVariance(), LDA()) +parameters = {"C": np.logspace(-2, 2, 10)} +clf = GridSearchCV(SVC(kernel="linear"), parameters) pipe = make_pipeline(LogVariance(), clf) -pipelines['AM + SVM'] = pipe +pipelines["AM + SVM"] = pipe ############################################################################## # Datasets @@ -74,7 +74,7 @@ ########################################################################## # Or you can run a search through the available datasets: -print(utils.dataset_search(paradigm='imagery', min_subjects=6)) +print(utils.dataset_search(paradigm="imagery", min_subjects=6)) ########################################################################## # Or you can simply make your own list (which we do here due to computational @@ -105,7 +105,7 @@ # 
subjects. This also is the correct place to specify multiple threads. evaluation = CrossSessionEvaluation( - paradigm=paradigm, datasets=datasets, suffix='examples', overwrite=False + paradigm=paradigm, datasets=datasets, suffix="examples", overwrite=False ) results = evaluation.process(pipelines) diff --git a/tutorials/plot_explore_paradigm.py b/tutorials/plot_explore_paradigm.py index c870b4414..ec07b7fe6 100644 --- a/tutorials/plot_explore_paradigm.py +++ b/tutorials/plot_explore_paradigm.py @@ -91,7 +91,7 @@ # For this data, we have one subjecy, 2 sessions (2 different recording day) # and 6 run per session. -print(metadata.describe(include='all')) +print(metadata.describe(include="all")) ############################################################################### # Paradigm object can also return the list of all dataset compatible. here diff --git a/tutorials/plot_statistical_analysis.py b/tutorials/plot_statistical_analysis.py index 140e67430..6942b1636 100644 --- a/tutorials/plot_statistical_analysis.py +++ b/tutorials/plot_statistical_analysis.py @@ -30,7 +30,7 @@ from moabb.paradigms import LeftRightImagery -moabb.set_log_level('info') +moabb.set_log_level("info") print(__doc__) @@ -56,13 +56,13 @@ pipelines = {} -pipelines['CSP + LDA'] = make_pipeline(CSP(n_components=8), LDA()) +pipelines["CSP + LDA"] = make_pipeline(CSP(n_components=8), LDA()) -pipelines['RG + LR'] = make_pipeline(Covariances(), TangentSpace(), LogisticRegression()) +pipelines["RG + LR"] = make_pipeline(Covariances(), TangentSpace(), LogisticRegression()) -pipelines['CSP + LR'] = make_pipeline(CSP(n_components=8), LogisticRegression()) +pipelines["CSP + LR"] = make_pipeline(CSP(n_components=8), LogisticRegression()) -pipelines['RG + LDA'] = make_pipeline(Covariances(), TangentSpace(), LDA()) +pipelines["RG + LDA"] = make_pipeline(Covariances(), TangentSpace(), LDA()) ############################################################################## # Evaluation @@ -82,7 +82,7 @@ datasets = [dataset] overwrite = False # set to True if we want to overwrite cached results evaluation = CrossSessionEvaluation( - paradigm=paradigm, datasets=datasets, suffix='examples', overwrite=overwrite + paradigm=paradigm, datasets=datasets, suffix="examples", overwrite=overwrite ) results = evaluation.process(pipelines) @@ -104,7 +104,7 @@ # datasets. Note that there is only one score per subject, regardless of the # number of sessions. -fig = moabb_plt.paired_plot(results, 'CSP + LDA', 'RG + LDA') +fig = moabb_plt.paired_plot(results, "CSP + LDA", "RG + LDA") plt.show() ############################################################################### @@ -124,7 +124,7 @@ # The meta-analysis style plot shows the standardized mean difference within # each tested dataset for the two algorithms in question, in addition to a # meta-effect and significances both per-dataset and overall. -fig = moabb_plt.meta_analysis_plot(stats, 'CSP + LDA', 'RG + LDA') +fig = moabb_plt.meta_analysis_plot(stats, "CSP + LDA", "RG + LDA") plt.show() ############################################################################### diff --git a/tutorials/select_electrodes_resample.py b/tutorials/select_electrodes_resample.py index 7e782d2ad..e067fe97d 100644 --- a/tutorials/select_electrodes_resample.py +++ b/tutorials/select_electrodes_resample.py @@ -43,7 +43,7 @@ # Also, use a specific resampling. In this example, all datasets are # set to 200 Hz. 
-paradigm = LeftRightImagery(channels=['C3', 'C4', 'Cz'], resample=200.0) +paradigm = LeftRightImagery(channels=["C3", "C4", "Cz"], resample=200.0) ############################################################################## # Evaluation @@ -55,9 +55,9 @@ evaluation = WithinSessionEvaluation(paradigm=paradigm, datasets=datasets) csp_lda = make_pipeline(CSP(n_components=2), LDA()) ts_lr = make_pipeline( - Covariances(estimator='oas'), TangentSpace(metric='riemann'), LR(C=1.0) + Covariances(estimator="oas"), TangentSpace(metric="riemann"), LR(C=1.0) ) -results = evaluation.process({'csp+lda': csp_lda, 'ts+lr': ts_lr}) +results = evaluation.process({"csp+lda": csp_lda, "ts+lr": ts_lr}) print(results.head()) ############################################################################## @@ -71,7 +71,7 @@ electrodes, datasets = find_intersecting_channels(datasets) evaluation = WithinSessionEvaluation(paradigm=paradigm, datasets=datasets, overwrite=True) -results = evaluation.process({'csp+lda': csp_lda, 'ts+lr': ts_lr}) +results = evaluation.process({"csp+lda": csp_lda, "ts+lr": ts_lr}) print(results.head()) ############################################################################## @@ -81,5 +81,5 @@ # Compare the obtained results with the two pipelines, CSP+LDA and logistic # regression computed in the tangent space of the covariance matrices. -fig = moabb_plt.paired_plot(results, 'csp+lda', 'ts+lr') +fig = moabb_plt.paired_plot(results, "csp+lda", "ts+lr") plt.show() diff --git a/tutorials/tutorial_1_simple_example_motor_imagery.py b/tutorials/tutorial_1_simple_example_motor_imagery.py index 32a6a6a4c..71a6fd851 100644 --- a/tutorials/tutorial_1_simple_example_motor_imagery.py +++ b/tutorials/tutorial_1_simple_example_motor_imagery.py @@ -28,7 +28,7 @@ from moabb.paradigms import LeftRightImagery -moabb.set_log_level('info') +moabb.set_log_level("info") warnings.filterwarnings("ignore") ############################################################################## @@ -60,8 +60,8 @@ # for some users, since the pre-processing and epoching steps can be easily # done via MNE. However, to conduct an assessment of several classifiers on # multiple subjects, MOABB ends up being a more appropriate option. 
-session_name = 'session_T' -run_name = 'run_1' +session_name = "session_T" +run_name = "run_1" raw = sessions[session_name][run_name] @@ -125,15 +125,15 @@ ) # We obtain the results in the form of a pandas dataframe -results = evaluation.process({'csp+lda': pipeline}) +results = evaluation.process({"csp+lda": pipeline}) # To export the results in CSV within a directory: -if not os.path.exists('./results'): - os.mkdir('./results') -results.to_csv('./results/results_part2-1.csv') +if not os.path.exists("./results"): + os.mkdir("./results") +results.to_csv("./results/results_part2-1.csv") # To load previously obtained results saved in CSV -results = pd.read_csv('./results/results_part2-1.csv') +results = pd.read_csv("./results/results_part2-1.csv") ############################################################################## # Plotting Results @@ -147,6 +147,6 @@ fig, ax = plt.subplots(figsize=(8, 7)) results["subj"] = results["subject"].apply(str) sns.barplot( - x="score", y="subj", hue='session', data=results, orient='h', palette='viridis', ax=ax + x="score", y="subj", hue="session", data=results, orient="h", palette="viridis", ax=ax ) fig.show() diff --git a/tutorials/tutorial_2_using_mulitple_datasets.py b/tutorials/tutorial_2_using_mulitple_datasets.py index 88c205a08..d6a925963 100644 --- a/tutorials/tutorial_2_using_mulitple_datasets.py +++ b/tutorials/tutorial_2_using_mulitple_datasets.py @@ -27,9 +27,9 @@ from moabb.paradigms import LeftRightImagery -moabb.set_log_level('info') -mne.set_log_level('CRITICAL') -warnings.filterwarnings('ignore') +moabb.set_log_level("info") +mne.set_log_level("CRITICAL") +warnings.filterwarnings("ignore") ############################################################################## @@ -70,12 +70,12 @@ results["subj"] = [str(resi).zfill(2) for resi in results["subject"]] g = sns.catplot( - kind='bar', + kind="bar", x="score", y="subj", col="dataset", data=results, - orient='h', - palette='viridis', + orient="h", + palette="viridis", ) plt.show() diff --git a/tutorials/tutorial_3_benchmarking_multiple_pipelines.py b/tutorials/tutorial_3_benchmarking_multiple_pipelines.py index ff3285d42..0f373465c 100644 --- a/tutorials/tutorial_3_benchmarking_multiple_pipelines.py +++ b/tutorials/tutorial_3_benchmarking_multiple_pipelines.py @@ -30,9 +30,9 @@ from moabb.paradigms import LeftRightImagery -mne.set_log_level('CRITICAL') -moabb.set_log_level('info') -warnings.filterwarnings('ignore') +mne.set_log_level("CRITICAL") +moabb.set_log_level("info") +warnings.filterwarnings("ignore") ############################################################################## @@ -48,9 +48,9 @@ pipelines = {} pipelines["csp+lda"] = make_pipeline(CSP(n_components=8), LDA()) pipelines["tgsp+svm"] = make_pipeline( - Covariances('oas'), TangentSpace(metric='riemann'), SVC(kernel='linear') + Covariances("oas"), TangentSpace(metric="riemann"), SVC(kernel="linear") ) -pipelines["MDM"] = make_pipeline(Covariances('oas'), MDM(metric='riemann')) +pipelines["MDM"] = make_pipeline(Covariances("oas"), MDM(metric="riemann")) # The following lines go exactly as in the previous example, where we end up # obtaining a pandas dataframe containing the results of the evaluation. 
@@ -61,7 +61,7 @@ if not os.path.exists("./results"): os.mkdir("./results") results.to_csv("./results/results_part2-3.csv") -results = pd.read_csv('./results/results_part2-3.csv') +results = pd.read_csv("./results/results_part2-3.csv") ############################################################################## # Plotting Results @@ -72,7 +72,7 @@ results["subj"] = [str(resi).zfill(2) for resi in results["subject"]] g = sns.catplot( - kind='bar', + kind="bar", x="score", y="subj", hue="pipeline", @@ -80,7 +80,7 @@ height=12, aspect=0.5, data=results, - orient='h', - palette='viridis', + orient="h", + palette="viridis", ) plt.show() diff --git a/tutorials/tutorial_4_adding_a_dataset.py b/tutorials/tutorial_4_adding_a_dataset.py index 14c59f8b3..709b6aa66 100644 --- a/tutorials/tutorial_4_adding_a_dataset.py +++ b/tutorials/tutorial_4_adding_a_dataset.py @@ -60,10 +60,10 @@ def create_example_dataset(): # Create the fake data for subject in [1, 2, 3]: x, fs = create_example_dataset() - filename = 'subject_' + str(subject).zfill(2) + '.mat' + filename = "subject_" + str(subject).zfill(2) + ".mat" mdict = {} - mdict['x'] = x - mdict['fs'] = fs + mdict["x"] = x + mdict["fs"] = fs savemat(filename, mdict) @@ -83,7 +83,7 @@ def create_example_dataset(): # The global variable with the dataset's URL should specify an online # repository where all the files are stored. -ExampleDataset_URL = 'https://sandbox.zenodo.org/record/369543/files/' +ExampleDataset_URL = "https://sandbox.zenodo.org/record/369543/files/" # The `ExampleDataset` needs to implement only 3 functions: # - `__init__` for indicating the parameter of the dataset @@ -103,11 +103,11 @@ def __init__(self): super().__init__( subjects=[1, 2, 3], sessions_per_subject=1, - events={'left_hand': 1, 'right_hand': 2}, - code='Example dataset', + events={"left_hand": 1, "right_hand": 2}, + code="Example dataset", interval=[0, 0.75], - paradigm='imagery', - doi='', + paradigm="imagery", + doi="", ) def _get_single_subject_data(self, subject): @@ -115,16 +115,16 @@ def _get_single_subject_data(self, subject): file_path_list = self.data_path(subject) data = loadmat(file_path_list[0]) - x = data['x'] - fs = data['fs'] - ch_names = ['ch' + str(i) for i in range(8)] + ['stim'] - ch_types = ['eeg' for i in range(8)] + ['stim'] + x = data["x"] + fs = data["fs"] + ch_names = ["ch" + str(i) for i in range(8)] + ["stim"] + ch_types = ["eeg" for i in range(8)] + ["stim"] info = mne.create_info(ch_names, fs, ch_types) raw = mne.io.RawArray(x, info) sessions = {} - sessions['session_1'] = {} - sessions['session_1']['run_1'] = raw + sessions["session_1"] = {} + sessions["session_1"]["run_1"] = raw return sessions def data_path( @@ -134,8 +134,8 @@ def data_path( if subject not in self.subject_list: raise (ValueError("Invalid subject number")) - url = '{:s}subject_0{:d}.mat'.format(ExampleDataset_URL, subject) - path = dl.data_path(url, 'ExampleDataset') + url = "{:s}subject_0{:d}.mat".format(ExampleDataset_URL, subject) + path = dl.data_path(url, "ExampleDataset") return [path] # it has to return a list @@ -152,7 +152,7 @@ def data_path( evaluation = WithinSessionEvaluation(paradigm=paradigm, datasets=dataset, overwrite=True) pipelines = {} -pipelines['MDM'] = make_pipeline(Covariances('oas'), MDM(metric='riemann')) +pipelines["MDM"] = make_pipeline(Covariances("oas"), MDM(metric="riemann")) scores = evaluation.process(pipelines) print(scores) From e394201b8a1a6c6844300f3f4682c9933e11de4b Mon Sep 17 00:00:00 2001 From: Vladislav Goncharenko Date: Thu, 11 Mar 
2021 20:13:43 +0300
Subject: [PATCH 15/17] shortened renaming

---
 moabb/datasets/physionet_mi.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/moabb/datasets/physionet_mi.py b/moabb/datasets/physionet_mi.py
index b61cddea4..44876c211 100644
--- a/moabb/datasets/physionet_mi.py
+++ b/moabb/datasets/physionet_mi.py
@@ -109,9 +109,8 @@ def _load_one_run(self, subject, run, preload=True):
         raw.rename_channels(lambda x: x.upper())
         # fmt: off
         renames = {
-            "AFZ": "AFz", "PZ": "Pz", "FPZ": "Fpz", "FCZ": "FCz", "FP1": "Fp1",
-            "CZ": "Cz", "OZ": "Oz", "POZ": "POz", "IZ": "Iz", "CPZ": "CPz",
-            "FP2": "Fp2", "FZ": "Fz",
+            "AFZ": "AFz", "PZ": "Pz", "FPZ": "Fpz", "FCZ": "FCz", "FP1": "Fp1", "CZ": "Cz",
+            "OZ": "Oz", "POZ": "POz", "IZ": "Iz", "CPZ": "CPz", "FP2": "Fp2", "FZ": "Fz",
         }
         # fmt: on
         raw.rename_channels(renames)

From 62879e3dd2f6b38f91ef53afd108dbd36d46e86e Mon Sep 17 00:00:00 2001
From: Vladislav Goncharenko
Date: Thu, 11 Mar 2021 20:14:35 +0300
Subject: [PATCH 16/17] removed unnecessary fmt off/on blocks

---
 moabb/datasets/ssvep_mamem.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/moabb/datasets/ssvep_mamem.py b/moabb/datasets/ssvep_mamem.py
index cef4cdcb5..7edf29d31 100644
--- a/moabb/datasets/ssvep_mamem.py
+++ b/moabb/datasets/ssvep_mamem.py
@@ -31,11 +31,9 @@
 # MAMEM2_URL = 'https://ndownloader.figshare.com/articles/3153409/versions/2'
 # MAMEM3_URL = 'https://ndownloader.figshare.com/articles/3413851/versions/1'
 
-# fmt: off
 MAMEM1_URL = "https://archive.physionet.org/physiobank/database/mssvepdb/dataset1/"
 MAMEM2_URL = "https://archive.physionet.org/physiobank/database/mssvepdb/dataset2/"
 MAMEM3_URL = "https://archive.physionet.org/physiobank/database/mssvepdb/dataset3/"
-# fmt: on
 
 
 class BaseMAMEM(BaseDataset):

From c62f7a6c5e358432daf9cdff3f2cd101d9aa87ed Mon Sep 17 00:00:00 2001
From: Vladislav Goncharenko
Date: Thu, 11 Mar 2021 20:15:33 +0300
Subject: [PATCH 17/17] simplified url

---
 moabb/datasets/gigadb.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/moabb/datasets/gigadb.py b/moabb/datasets/gigadb.py
index 3f5dd81a5..b9ba81a1a 100644
--- a/moabb/datasets/gigadb.py
+++ b/moabb/datasets/gigadb.py
@@ -15,9 +15,7 @@
 
 log = logging.getLogger()
 
-GIGA_URL = (
-    "ftp://parrot.genomics.cn/gigadb/pub/10.5524/100001_101000/100295/mat_data/"  # noqa
-)
+GIGA_URL = "ftp://parrot.genomics.cn/gigadb/pub/10.5524/100001_101000/100295/mat_data/"
 
 
 class Cho2017(BaseDataset):