Skip to content

Commit

Permalink
Correct TRCA (NeuroTechX#522)
Browse files Browse the repository at this point in the history
* fix: try to fix TRCA

* fix: remove n_fbands as not used in filterbanks

* fix: adapt filter

* fix: correct ssvep example

* enh: add whats new

* enh: adding a check for update of whats new file

* fix: correct filepath

---------

Co-authored-by: Sylvain Chevallier <sylain.chevallier@universite-paris-saclay.fr>
  • Loading branch information
sylvchev and Sylvain Chevallier authored Nov 7, 2023
1 parent 2074690 commit bce0f9f
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 24 deletions.
38 changes: 38 additions & 0 deletions .github/workflows/whats-new.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
name: Check What's new update

on:
push:
branches: [develop]
pull_request:
branches: [develop]

jobs:
check-whats-news:
runs-on: ubuntu-latest

steps:
- name: Check for file changes in PR
run: |
pr_number=${{ github.event.pull_request.number }}
response=$(curl -s -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" \
"https://api.github.com/repos/${{ github.repository }}/pulls/${pr_number}/files")
file_changed=false
file_to_check="docs/source/whats_new.rst" # Specify the path to your file
for file in $(echo "${response}" | jq -r '.[] | .filename'); do
if [ "$file" == "$file_to_check" ]; then
file_changed=true
break
fi
done
if $file_changed; then
echo "File ${file_to_check} has been changed in the PR."
else
echo "File ${file_to_check} has not been changed in the PR."
echo "::error::File ${file_to_check} has not been changed in the PR."
exit 1
fi
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # This token is provided by Actions, you do not need to create your own token
2 changes: 1 addition & 1 deletion docs/source/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Enhancements
Bugs
~~~~

- None
- Fix TRCA implementation for different stimulation freqs and for signal filtering (:gh:522 by `Sylvain Chevallier`_)

API changes
~~~~~~~~~~~
Expand Down
4 changes: 1 addition & 3 deletions examples/plot_cross_subject_ssvep.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,7 @@
pipelines["CCA"] = make_pipeline(SSVEP_CCA(interval=interval, freqs=freqs, n_harmonics=2))

pipelines_TRCA = {}
pipelines_TRCA["TRCA"] = make_pipeline(
SSVEP_TRCA(interval=interval, freqs=freqs, n_fbands=5)
)
pipelines_TRCA["TRCA"] = make_pipeline(SSVEP_TRCA(interval=interval, freqs=freqs))

pipelines_MSET_CCA = {}
pipelines_MSET_CCA["MSET_CCA"] = make_pipeline(SSVEP_MsetCCA(freqs=freqs))
Expand Down
6 changes: 1 addition & 5 deletions moabb/pipelines/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,6 @@ class SSVEP_TRCA(BaseEstimator, ClassifierMixin):
Frequencies corresponding to the SSVEP components. These are
necessary to design the filterbank bands.
n_fbands : int, default=5
Number of sub-bands considered for filterbank analysis.
downsample: int, default=1
Factor by which downsample the data. A downsample value of N will result
on a sampling frequency of (sfreq // N) by taking one sample every N of
Expand Down Expand Up @@ -188,15 +185,14 @@ def __init__(
self,
interval,
freqs,
n_fbands=5,
downsample=1,
is_ensemble=True,
method="original",
estimator="scm",
):
self.freqs = freqs
self.peaks = np.array([float(f) for f in freqs.keys()])
self.n_fbands = n_fbands
self.n_fbands = len(self.peaks)
self.downsample = downsample
self.interval = interval
self.slen = interval[1] - interval[0]
Expand Down
28 changes: 13 additions & 15 deletions moabb/pipelines/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,10 @@ def filterbank(X, sfreq, idx_fb, peaks):
Code based on the Matlab implementation from authors of [1]_
(https://github.com/mnakanishi/TRCA-SSVEP).
"""
if idx_fb > len(peaks):
raise (
ValueError("idx_fb should be less than number of SSVEP stimulus frequency")
)

# Calibration data comes in batches of trials
if X.ndim == 3:
Expand All @@ -299,39 +303,33 @@ def filterbank(X, sfreq, idx_fb, peaks):
elif X.ndim == 2:
num_chans = X.shape[0]
num_trials = 1
else:
print("error")

sfreq = sfreq / 2

min_freq = np.min(peaks)
peaks = np.sort(peaks)
max_freq = np.max(peaks)

if max_freq < 40:
top = 100
top = 40
else:
top = 115
top = 60
# Check for Nyquist
if top >= sfreq:
top = sfreq - 10

diff = max_freq - min_freq
# Lowcut frequencies for the pass band (depends on the frequencies of SSVEP)
# No more than 3dB loss in the passband

passband = [min_freq - 2 + x * diff for x in range(7)]
passband = [peaks[i] - 1 for i in range(len(peaks))]

# At least 40db attenuation in the stopband
if min_freq - 4 > 0:
stopband = [
min_freq - 4 + x * (diff - 2) if x < 3 else min_freq - 4 + x * diff
for x in range(7)
]
else:
stopband = [2 + x * (diff - 2) if x < 3 else 2 + x * diff for x in range(7)]
stopband = [peaks[i] - 2 for i in range(len(peaks))]

Wp = [passband[idx_fb] / sfreq, top / sfreq]
Ws = [stopband[idx_fb] / sfreq, (top + 7) / sfreq]
Ws = [stopband[idx_fb] / sfreq, (top + 20) / sfreq]

N, Wn = scp.cheb1ord(Wp, Ws, 3, 40) # Chebyshev type I filter order selection.
N, Wn = scp.cheb1ord(Wp, Ws, 3, 15) # Chebyshev type I filter order selection.

B, A = scp.cheby1(N, 0.5, Wn, btype="bandpass") # Chebyshev type I filter design

Expand Down

0 comments on commit bce0f9f

Please sign in to comment.