forked from ChFrenkel/eprop-PyTorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
models.py
254 lines (181 loc) · 11.8 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
# -*- coding: utf-8 -*-
"""
------------------------------------------------------------------------------
Copyright (C) 2020-2022 University of Zurich
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
------------------------------------------------------------------------------
"models.py" - Spiking RNN network model embedding hardcoded e-prop training procedures.
Project: PyTorch e-prop
Author: C. Frenkel, Institute of Neuroinformatics, University of Zurich and ETH Zurich
Cite this code: BibTeX/APA citation formats auto-converted from the CITATION.cff file in the repository are available
through the "Cite this repository" link in the root GitHub repo https://github.com/ChFrenkel/eprop-PyTorch/
------------------------------------------------------------------------------
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import numpy as np
import matplotlib.pyplot as plt
class SRNN(nn.Module):
    """Spiking recurrent network of LIF neurons with hardcoded e-prop training.

    The network has an input layer (``n_in``), a recurrent layer of leaky
    integrate-and-fire neurons (``n_rec``) and a leaky output layer (``n_out``).
    Weight gradients are NOT obtained through autograd: when ``forward`` is
    called with ``do_training=True``, e-prop eligibility traces and learning
    signals are computed explicitly and the resulting updates are written into
    ``w_in.grad``, ``w_rec.grad`` and ``w_out.grad`` (presumably consumed by an
    optimizer step in the caller).

    Args:
        n_in, n_rec, n_out: input, recurrent and output layer sizes.
        n_t: number of simulated time steps per sample.
        thr: firing threshold of the recurrent neurons.
        tau_m: membrane time constant of the recurrent neurons (same unit as ``dt``).
        tau_o: time constant of the leaky output neurons (same unit as ``dt``).
        b_o: bias added to the output membrane potential at every step.
        gamma: scale factor of the surrogate (pseudo-)derivative.
        dt: simulation time step.
        model: neuron model name; only ``"LIF"`` is supported for training.
        classif: if True, outputs are passed through a softmax over output units.
        w_init_gain: 3 gains applied after Kaiming-normal init of w_in, w_rec, w_out.
        lr_layer: 3 per-layer factors scaling the computed updates of w_in, w_rec, w_out.
        t_crop: if non-zero, only the last ``t_crop`` steps contribute to the updates.
        visualize: if True, open an interactive matplotlib figure refreshed at each training call.
        visualize_light: if True, plot only a subset of the traces.
        device: torch device on which the state tensors are allocated.
    """

    def __init__(self, n_in, n_rec, n_out, n_t, thr, tau_m, tau_o, b_o, gamma, dt, model, classif, w_init_gain, lr_layer, t_crop, visualize, visualize_light, device):
        super().__init__()
        self.n_in = n_in
        self.n_rec = n_rec
        self.n_out = n_out
        self.n_t = n_t
        self.thr = thr
        self.dt = dt
        self.alpha = np.exp(-dt / tau_m)  # per-step membrane decay of the recurrent layer
        self.kappa = np.exp(-dt / tau_o)  # per-step decay of the output layer
        self.gamma = gamma
        self.b_o = b_o
        self.model = model
        self.classif = classif
        self.lr_layer = lr_layer
        self.t_crop = t_crop
        self.visu = visualize
        self.visu_l = visualize_light
        self.device = device

        # Parameters (w[post, pre]); values are set by reset_parameters below
        self.w_in = nn.Parameter(torch.empty(n_rec, n_in))
        self.w_rec = nn.Parameter(torch.empty(n_rec, n_rec))
        self.w_out = nn.Parameter(torch.empty(n_out, n_rec))
        # reg_term and B_out are allocated but not read anywhere in this class;
        # the learning signal uses w_out itself as feedback matrix.
        self.reg_term = torch.zeros(self.n_rec, device=self.device)
        self.B_out = torch.empty(n_out, n_rec, device=self.device)
        self.reset_parameters(w_init_gain)

        # Visualization: 2 spike rasters + one axis per output + 5 trace/learning-signal axes
        if self.visu:
            plt.ion()
            self.fig, self.ax_list = plt.subplots(2 + self.n_out + 5, sharex=True)

    def reset_parameters(self, gain):
        """Kaiming-normal initialise w_in, w_rec, w_out, then scale them by gain[0], gain[1], gain[2]."""
        for k, weight in enumerate((self.w_in, self.w_rec, self.w_out)):
            torch.nn.init.kaiming_normal_(weight)
            weight.data = gain[k] * weight.data

    def init_net(self, n_b, n_t, n_in, n_rec, n_out):
        """Reset the time-major state buffers for a batch of n_b samples and zero the weight gradients."""
        # Hidden state: recurrent and output membrane potentials, (n_t, n_b, n_rec) / (n_t, n_b, n_out)
        self.v = torch.zeros(n_t, n_b, n_rec, device=self.device)
        self.vo = torch.zeros(n_t, n_b, n_out, device=self.device)
        # Visible state: recurrent spikes, (n_t, n_b, n_rec)
        self.z = torch.zeros(n_t, n_b, n_rec, device=self.device)
        # Weight gradients (e-prop updates are accumulated into these)
        self.w_in.grad = torch.zeros_like(self.w_in)
        self.w_rec.grad = torch.zeros_like(self.w_rec)
        self.w_out.grad = torch.zeros_like(self.w_out)

    def forward(self, x, yt, do_training):
        """Simulate the network on a batch and optionally compute e-prop weight updates.

        Args:
            x: time-major input of shape (n_t, n_b, n_in).
            yt: targets, broadcast-compatible with the output (n_t, n_b, n_out).
            do_training: if True, write e-prop updates into the parameters' ``.grad``.

        Returns:
            Tensor (n_t, n_b, n_out): softmax over output units if ``classif``,
            otherwise the output membrane potentials themselves.

        The whole pass runs under ``torch.no_grad()``: learning uses the
        hand-written e-prop rule, and autograd would otherwise reject the
        in-place masking of the leaf parameter ``w_rec``.
        """
        with torch.no_grad():
            self.n_b = x.shape[1]  # batch size
            self.init_net(self.n_b, self.n_t, self.n_in, self.n_rec, self.n_out)
            # Cancel recurrent self-excitation/inhibition (zero diagonal of w_rec)
            self.w_rec.mul_(1 - torch.eye(self.n_rec, self.n_rec, device=self.device))

            for t in range(self.n_t - 1):
                # Recurrent membrane: leak + recurrent and input currents, reset by subtraction after a spike
                self.v[t + 1] = (self.alpha * self.v[t] + torch.mm(self.z[t], self.w_rec.t()) + torch.mm(x[t], self.w_in.t())) - self.z[t] * self.thr
                self.z[t + 1] = (self.v[t + 1] > self.thr).float()
                # Leaky output layer driven by the new spikes
                self.vo[t + 1] = self.kappa * self.vo[t] + torch.mm(self.z[t + 1], self.w_out.t()) + self.b_o

            yo = F.softmax(self.vo, dim=2) if self.classif else self.vo

            if do_training:
                self.grads_batch(x, yo, yt)
        return yo

    def _decay_kernel(self, decay):
        """Return the (1, 1, n_t) kernel [decay**(n_t-1), ..., decay**0] for F.conv1d (cross-correlation)."""
        return torch.tensor([decay ** (self.n_t - i - 1) for i in range(self.n_t)]).float().view(1, 1, -1).to(self.device)

    def _causal_filter(self, signal, kernel, start):
        """Exponentially filter every channel of ``signal`` (n_b, C, n_t) with a ``_decay_kernel`` output.

        Returns n_t samples starting at output index ``start``. With ``start=1``,
        sample t is sum_{s<=t} decay**(t-s) * signal[..., s]. With ``start=0``,
        the result is delayed by one step: only s <= t-1 contributes.
        """
        n_ch = signal.shape[1]
        filtered = F.conv1d(signal, kernel.expand(n_ch, -1, -1), padding=self.n_t, groups=n_ch)
        return filtered[:, :, start:start + self.n_t]

    def grads_batch(self, x, yo, yt):
        """Compute e-prop updates for one batch and add them to w_in.grad, w_rec.grad, w_out.grad.

        Updates are summed over batch and time (restricted to the last ``t_crop``
        steps if non-zero) and scaled per layer by ``lr_layer``.

        Raises:
            ValueError: if ``self.model`` is not ``"LIF"``.
        """
        # Explicit check (an assert would be stripped under `python -O`)
        if self.model != "LIF":
            raise ValueError("Nice try, but model " + self.model + " is not supported. ;-)")

        # Surrogate (pseudo-)derivatives, (n_t, n_b, n_rec)
        h = self.gamma * torch.max(torch.zeros_like(self.v), 1 - torch.abs((self.v - self.thr) / self.thr))

        alpha_conv = self._decay_kernel(self.alpha)
        kappa_conv = self._decay_kernel(self.kappa)

        # Input and recurrent eligibility vectors for the 'LIF' model (vectorized)
        # NOTE(review): trace_in at step t includes x[t], while forward() only lets x[t] reach v[t+1];
        # the recurrent trace does apply that one-step shift (start=0). Confirm the asymmetry is intended.
        trace_in = self._causal_filter(x.permute(1, 2, 0), alpha_conv, 1).unsqueeze(1).expand(-1, self.n_rec, -1, -1)  # n_b, n_rec, n_in, n_t
        trace_in = torch.einsum('tbr,brit->brit', h, trace_in)  # n_b, n_rec, n_in, n_t
        trace_rec = self._causal_filter(self.z.permute(1, 2, 0), alpha_conv, 0).unsqueeze(1).expand(-1, self.n_rec, -1, -1)  # n_b, n_rec, n_rec, n_t
        trace_rec = torch.einsum('tbr,brit->brit', h, trace_rec)  # n_b, n_rec, n_rec, n_t
        trace_reg = trace_rec  # unfiltered recurrent traces, kept for plotting

        # Output eligibility vector: recurrent spikes filtered by the output decay
        trace_out = self._causal_filter(self.z.permute(1, 2, 0), kappa_conv, 1)  # n_b, n_rec, n_t

        # Eligibility traces: eligibility vectors filtered by the output decay
        trace_in = self._causal_filter(trace_in.reshape(self.n_b, self.n_in * self.n_rec, self.n_t), kappa_conv, 1).reshape(self.n_b, self.n_rec, self.n_in, self.n_t)  # n_b, n_rec, n_in, n_t
        trace_rec = self._causal_filter(trace_rec.reshape(self.n_b, self.n_rec * self.n_rec, self.n_t), kappa_conv, 1).reshape(self.n_b, self.n_rec, self.n_rec, self.n_t)  # n_b, n_rec, n_rec, n_t

        # Learning signals: output error projected back through w_out, (n_b, n_rec, n_t)
        err = yo - yt
        L = torch.einsum('tbo,or->brt', err, self.w_out)

        if self.visu:
            self.update_plot(x, self.z, yo, yt, L, trace_reg, trace_in, trace_rec, trace_out)

        # Only the timesteps where the target is present contribute to the updates
        if self.t_crop != 0:
            L = L[:, :, -self.t_crop:]
            err = err[-self.t_crop:, :, :]
            trace_in = trace_in[:, :, :, -self.t_crop:]
            trace_rec = trace_rec[:, :, :, -self.t_crop:]
            trace_out = trace_out[:, :, -self.t_crop:]

        # Weight gradient updates, summed over batch (dim 0) and time (dim 3)
        self.w_in.grad += self.lr_layer[0] * torch.sum(L.unsqueeze(2).expand(-1, -1, self.n_in, -1) * trace_in, dim=(0, 3))
        self.w_rec.grad += self.lr_layer[1] * torch.sum(L.unsqueeze(2).expand(-1, -1, self.n_rec, -1) * trace_rec, dim=(0, 3))
        self.w_out.grad += self.lr_layer[2] * torch.einsum('tbo,brt->or', err, trace_out)

    def update_plot(self, x, z, yo, yt, L, trace_reg, trace_in, trace_rec, trace_out):
        """Redraw the interactive figure using the first sample (batch index 0).

        Adapted from the original TensorFlow e-prop implementation from TU Graz, available at
        https://github.com/IGITUGraz/eligibility_propagation
        """
        t_axis = np.arange(self.n_t)
        for ax in self.ax_list:
            ax.clear()

        # Spike rasters of the inputs and of the recurrent layer
        for k, (name, spikes) in enumerate((('In spikes', x), ('Rec spikes', z))):
            ax = self.ax_list[k]
            ax.imshow(spikes[:, 0, :].cpu().numpy().T, aspect='auto', cmap='hot_r', interpolation="none")
            ax.set_xlim([0, self.n_t])
            ax.set_ylabel(name)

        # Outputs against targets (targets only over the cropped window, if any)
        for i in range(self.n_out):
            ax = self.ax_list[i + 2]
            if self.classif:
                ax.set_ylim([-0.05, 1.05])
            ax.set_ylabel('Output ' + str(i))
            ax.plot(t_axis, yo[:, 0, i].cpu().numpy(), linestyle='dashed', label='Output', alpha=0.8)
            if self.t_crop != 0:
                ax.plot(t_axis[-self.t_crop:], yt[-self.t_crop:, 0, i].cpu().numpy(), linestyle='solid', label='Target', alpha=0.8)
            else:
                ax.plot(t_axis, yt[:, 0, i].cpu().numpy(), linestyle='solid', label='Target', alpha=0.8)
            ax.set_xlim([0, self.n_t])

        # Traces and learning signals; "light" mode only shows traces for presynaptic index 0
        for i, name in enumerate(("Trace reg", "Traces out", "Traces rec", "Traces in", "Learning sigs")):
            ax = self.ax_list[i + 2 + self.n_out]
            ax.set_ylabel(name)
            t_plot = t_axis
            if i == 0:
                y = trace_reg[0, :, 0, :] if self.visu_l else trace_reg[0].reshape(self.n_rec * self.n_rec, self.n_t)
            elif i == 1:
                y = trace_out[0]
            elif i == 2:
                y = trace_rec[0, :, 0, :] if self.visu_l else trace_rec[0].reshape(self.n_rec * self.n_rec, self.n_t)
            elif i == 3:
                y = trace_in[0, :, 0, :] if self.visu_l else trace_in[0].reshape(self.n_rec * self.n_in, self.n_t)
            elif self.t_crop != 0:
                t_plot = t_axis[-self.t_crop:]
                y = L[0, :, -self.t_crop:]
            else:
                y = L[0]
            ax.plot(t_plot, y.T.cpu().numpy(), linestyle='dashed', label='Output', alpha=0.8)
            ax.set_xlim([0, self.n_t])
        ax.set_xlabel('Time in ms')

        # Short wait time to draw with interactive python
        plt.draw()
        plt.pause(0.1)

    def __repr__(self):
        return f"{self.__class__.__name__} ({self.n_in} -> {self.n_rec} -> {self.n_out}) "