diff --git a/.github/workflows/whats-new.yml b/.github/workflows/whats-new.yml new file mode 100644 index 000000000..4c121cda4 --- /dev/null +++ b/.github/workflows/whats-new.yml @@ -0,0 +1,38 @@ +name: Check What's new update + +on: + push: + branches: [develop] + pull_request: + branches: [develop] + +jobs: + check-whats-news: + runs-on: ubuntu-latest + + steps: + - name: Check for file changes in PR + run: | + pr_number=${{ github.event.pull_request.number }} + response=$(curl -s -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" \ + "https://api.github.com/repos/${{ github.repository }}/pulls/${pr_number}/files") + + file_changed=false + file_to_check="docs/source/whats_new.rst" # Specify the path to your file + + for file in $(echo "${response}" | jq -r '.[] | .filename'); do + if [ "$file" == "$file_to_check" ]; then + file_changed=true + break + fi + done + + if $file_changed; then + echo "File ${file_to_check} has been changed in the PR." + else + echo "File ${file_to_check} has not been changed in the PR." + echo "::error::File ${file_to_check} has not been changed in the PR." + exit 1 + fi + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # This token is provided by Actions, you do not need to create your own token diff --git a/docs/source/whats_new.rst b/docs/source/whats_new.rst index 301f3a0ce..f67a393ce 100644 --- a/docs/source/whats_new.rst +++ b/docs/source/whats_new.rst @@ -23,7 +23,7 @@ Enhancements Bugs ~~~~ -- None +- Fix TRCA implementation for different stimulation freqs and for signal filtering (:gh:522 by `Sylvain Chevallier`_) API changes ~~~~~~~~~~~ diff --git a/examples/plot_cross_subject_ssvep.py b/examples/plot_cross_subject_ssvep.py index e6319b509..20dd35052 100644 --- a/examples/plot_cross_subject_ssvep.py +++ b/examples/plot_cross_subject_ssvep.py @@ -104,9 +104,7 @@ pipelines["CCA"] = make_pipeline(SSVEP_CCA(interval=interval, freqs=freqs, n_harmonics=2)) pipelines_TRCA = {} -pipelines_TRCA["TRCA"] = make_pipeline( - SSVEP_TRCA(interval=interval, freqs=freqs, n_fbands=5) -) +pipelines_TRCA["TRCA"] = make_pipeline(SSVEP_TRCA(interval=interval, freqs=freqs)) pipelines_MSET_CCA = {} pipelines_MSET_CCA["MSET_CCA"] = make_pipeline(SSVEP_MsetCCA(freqs=freqs)) diff --git a/moabb/pipelines/classification.py b/moabb/pipelines/classification.py index 07fba8d4c..d9afff3fa 100644 --- a/moabb/pipelines/classification.py +++ b/moabb/pipelines/classification.py @@ -114,9 +114,6 @@ class SSVEP_TRCA(BaseEstimator, ClassifierMixin): Frequencies corresponding to the SSVEP components. These are necessary to design the filterbank bands. - n_fbands : int, default=5 - Number of sub-bands considered for filterbank analysis. - downsample: int, default=1 Factor by which downsample the data. A downsample value of N will result on a sampling frequency of (sfreq // N) by taking one sample every N of @@ -188,7 +185,6 @@ def __init__( self, interval, freqs, - n_fbands=5, downsample=1, is_ensemble=True, method="original", @@ -196,7 +192,7 @@ def __init__( ): self.freqs = freqs self.peaks = np.array([float(f) for f in freqs.keys()]) - self.n_fbands = n_fbands + self.n_fbands = len(self.peaks) self.downsample = downsample self.interval = interval self.slen = interval[1] - interval[0] diff --git a/moabb/pipelines/utils.py b/moabb/pipelines/utils.py index f836bb01a..105b2a8b4 100644 --- a/moabb/pipelines/utils.py +++ b/moabb/pipelines/utils.py @@ -289,6 +289,10 @@ def filterbank(X, sfreq, idx_fb, peaks): Code based on the Matlab implementation from authors of [1]_ (https://github.com/mnakanishi/TRCA-SSVEP). """ + if idx_fb > len(peaks): + raise ( + ValueError("idx_fb should be less than number of SSVEP stimulus frequency") + ) # Calibration data comes in batches of trials if X.ndim == 3: @@ -299,39 +303,33 @@ def filterbank(X, sfreq, idx_fb, peaks): elif X.ndim == 2: num_chans = X.shape[0] num_trials = 1 + else: + print("error") sfreq = sfreq / 2 - min_freq = np.min(peaks) + peaks = np.sort(peaks) max_freq = np.max(peaks) if max_freq < 40: - top = 100 + top = 40 else: - top = 115 + top = 60 # Check for Nyquist if top >= sfreq: top = sfreq - 10 - diff = max_freq - min_freq # Lowcut frequencies for the pass band (depends on the frequencies of SSVEP) # No more than 3dB loss in the passband - - passband = [min_freq - 2 + x * diff for x in range(7)] + passband = [peaks[i] - 1 for i in range(len(peaks))] # At least 40db attenuation in the stopband - if min_freq - 4 > 0: - stopband = [ - min_freq - 4 + x * (diff - 2) if x < 3 else min_freq - 4 + x * diff - for x in range(7) - ] - else: - stopband = [2 + x * (diff - 2) if x < 3 else 2 + x * diff for x in range(7)] + stopband = [peaks[i] - 2 for i in range(len(peaks))] Wp = [passband[idx_fb] / sfreq, top / sfreq] - Ws = [stopband[idx_fb] / sfreq, (top + 7) / sfreq] + Ws = [stopband[idx_fb] / sfreq, (top + 20) / sfreq] - N, Wn = scp.cheb1ord(Wp, Ws, 3, 40) # Chebyshev type I filter order selection. + N, Wn = scp.cheb1ord(Wp, Ws, 3, 15) # Chebyshev type I filter order selection. B, A = scp.cheby1(N, 0.5, Wn, btype="bandpass") # Chebyshev type I filter design