-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathISC.py
154 lines (116 loc) · 4.58 KB
/
ISC.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from timeit import default_timer
import numpy as np
from scipy.linalg import eigh
def train_cca(data, gamma=0.1):
    """Run Correlated Component Analysis on your training data.

    The within-subject (Rw) and between-subject (Rb) covariance matrices are
    pooled over subjects inside every condition, averaged over conditions,
    and the generalized eigenproblem ``Rb w = ISC * Rw_reg w`` is solved.

    Parameters:
    ----------
    data : dict
        Dictionary whose keys are names of conditions and whose values are
        numpy arrays structured like (subjects, channels, samples).
        The number of channels must be the same between all conditions!
    gamma : float, optional
        Shrinkage strength applied to Rw before solving the eigenproblem:
        ``Rw_reg = (1 - gamma) * Rw + gamma * mean_eigenvalue(Rw) * I``.
        Defaults to 0.1.

    Returns:
    -------
    W : np.array
        Columns are spatial filters. They are sorted in descending order of
        ISC, i.e. the first column-vector maximizes correlation the most.
    ISC : np.array
        Inter-subject correlation sorted in descending order.

    Raises:
    ------
    ValueError
        If ``data`` is empty or a condition holds fewer than two subjects
        (the between-subject covariance is undefined in that case).
    """
    start = default_timer()

    C = len(data)
    if C == 0:
        raise ValueError('train_cca needs at least one condition in `data`')
    print(f'train_cca - calculations started. There are {C} conditions')

    Rw, Rb = 0, 0
    for cond in data.values():
        cond = np.asarray(cond)
        N, D, T = cond.shape
        print(f'Condition has {N} subjects, {D} sensors and {T} samples')
        if N < 2:
            raise ValueError('every condition needs at least two subjects, '
                             f'got {N}')

        # Rij[i, j] is the (D, D) cross-covariance between subjects i and j.
        # Rows of the flattened array are ordered subject-major, which is
        # what makes the (N, D, N, D) reshape valid.
        Rij = np.cov(cond.reshape(N * D, T)).reshape(N, D, N, D).swapaxes(1, 2)

        # Rw: mean of the N within-subject (diagonal) blocks.
        subj = np.arange(N)
        within = Rij[subj, subj]
        Rw = Rw + within.mean(axis=0)

        # Rb: mean of the N * (N - 1) off-diagonal blocks, obtained as
        # "all blocks minus diagonal blocks" instead of materialising a
        # Python list of every (i, j) pair.
        Rb = Rb + (Rij.sum(axis=(0, 1)) - within.sum(axis=0)) / (N * (N - 1))

    # Divide by number of conditions
    Rw, Rb = Rw / C, Rb / C

    # Regularization: shrink Rw towards a scaled identity. The mean
    # eigenvalue of a symmetric matrix equals trace / dimension, so no
    # eigendecomposition is needed to obtain it.
    n_channels = Rw.shape[0]
    mean_eigenvalue = np.trace(Rw) / n_channels
    Rw_reg = (1 - gamma) * Rw + gamma * mean_eigenvalue * np.identity(n_channels)

    # ISCs and Ws (eigh returns eigenvalues in ascending order)
    ISC, W = eigh(Rb, Rw_reg)

    # Make descending order
    ISC, W = ISC[::-1], W[:, ::-1]

    stop = default_timer()
    print(f'Elapsed time: {round(stop - start)} seconds.')
    return W, ISC
def apply_cca(X, W, fs, window_sec=5):
    """Applying precomputed spatial filters to your data.

    Parameters:
    ----------
    X : ndarray
        3-D numpy array structured like (subject, channel, sample)
    W : ndarray
        Spatial filters, one per column, shaped (channels, components).
    fs : int
        Frequency sampling. Also the hop (in samples) between consecutive
        windows of the per-second analysis.
    window_sec : int or float, optional
        Length in seconds of the sliding window used for ``ISC_persecond``.
        Defaults to 5. Windows near the end of the recording are truncated.

    Returns:
    -------
    ISC : ndarray
        Inter-subject correlations values are sorted in descending order.
        NOTE: the values are sorted on their own, so entry ``k`` matches
        column ``k`` of ``W`` only if ``W`` is already ordered by ISC.
    ISC_persecond : ndarray
        Inter-subject correlations per window, shaped
        (components, T // fs + 1); row ``k`` belongs to column ``k`` of
        ``W``. Columns for which no window exists (the last one when T is a
        multiple of fs) are NaN.
    ISC_bysubject : ndarray
        Shaped (components, subjects). Column ``k`` holds the ISC of subject
        ``k`` against all remaining subjects, one value per column of ``W``.
    A : ndarray
        Scalp projections of ISC (forward model),
        ``A = Rw @ W @ inv(W.T @ Rw @ W)``, shaped (channels, components).

    Raises:
    ------
    ValueError
        If ``X`` has fewer than two subjects or ``W`` does not have one row
        per channel.
    """
    start = default_timer()
    print('apply_cca - calculations started')

    X = np.asarray(X)
    W = np.asarray(W)
    N, D, T = X.shape
    if N < 2:
        raise ValueError(f'apply_cca needs at least two subjects, got {N}')
    if W.ndim != 2 or W.shape[0] != D:
        raise ValueError(f'W must be shaped ({D}, components), got {W.shape}')
    n_comp = W.shape[1]
    subj = np.arange(N)

    def pooled_covariances(samples):
        """Return (Rij, Rw, Rb) for an (N * D, n_samples) data matrix."""
        # Rij[i, j] is the (D, D) cross-covariance between subjects i and j.
        Rij = np.cov(samples).reshape(N, D, N, D).swapaxes(1, 2)
        within = Rij[subj, subj]
        # Rw: mean of diagonal blocks; Rb: mean of off-diagonal blocks.
        Rw = within.mean(axis=0)
        Rb = (Rij.sum(axis=(0, 1)) - within.sum(axis=0)) / (N * (N - 1))
        return Rij, Rw, Rb

    def component_isc(Rb, Rw):
        """Return between/within power ratio for every column of W."""
        return np.diag(W.T @ Rb @ W) / np.diag(W.T @ Rw @ W)

    # Rows are ordered subject-major: row = subject * D + channel.
    X = X.reshape(N * D, T)
    Rij, Rw, Rb = pooled_covariances(X)

    # ISCs
    ISC = np.sort(component_isc(Rb, Rw))[::-1]

    # Scalp projections: A = Rw W (W' Rw W)^-1. This is a right-division,
    # so solve the transposed system M' A' = (Rw W)' and transpose back.
    A = np.linalg.solve((W.T @ Rw @ W).T, (Rw @ W).T).T

    # ISC by subject
    print('by subject is calculating')
    ISC_bysubject = np.empty((n_comp, N))
    within = Rij[subj, subj]
    for subj_k in range(N):
        others = [subj_l for subj_l in range(N) if subj_l != subj_k]
        # Averages over all partners l != k of (R_kk + R_ll) and
        # (R_kl + R_lk) respectively.
        Rw_k = Rij[subj_k, subj_k] + within[others].mean(axis=0)
        Rb_k = (Rij[subj_k, others] + Rij[others, subj_k]).mean(axis=0)
        ISC_bysubject[:, subj_k] = component_isc(Rb_k, Rw_k)

    # ISC per second
    print('by persecond is calculating')
    window_len = int(round(window_sec * fs))
    # NaN-filled so that a column without a window never exposes
    # uninitialised memory.
    ISC_persecond = np.full((n_comp, T // fs + 1), np.nan)
    for window_i, t in enumerate(range(0, T, fs)):
        _, Rw_t, Rb_t = pooled_covariances(X[:, t:t + window_len])
        ISC_persecond[:, window_i] = component_isc(Rb_t, Rw_t)

    stop = default_timer()
    print(f'Elapsed time: {round(stop - start)} seconds.')
    return ISC, ISC_persecond, ISC_bysubject, A