# spiking_learning.py
# IMSL Lab - University of Notre Dame
# Author: Clemens JS Schaefer
import functools
from typing import Any, Callable, Sequence
import jax
from jax import dtypes
from jax import random
from flax import linen as nn
import jax.numpy as jnp
import numpy as np
from jax._src.nn.initializers import lecun_normal
Array = jnp.ndarray
DType = Any
def uniform(scale=1e-2, dtype: DType = jnp.float_) -> Callable:
    """Builds an initializer that returns real uniformly-distributed random
    arrays.

    Args:
      scale: optional; the upper and lower bound of the random distribution.
      dtype: optional; the initializer's default dtype.

    Returns:
      An initializer ``init(key, shape, dtype=dtype)`` that returns arrays
      whose values are uniformly distributed in the range
      ``[-scale, scale)``.
    """

    def init(key, shape, dtype=dtype):
        dtype = dtypes.canonicalize_dtype(dtype)
        # random.uniform draws from [0, 1); stretch to [0, 2 * scale) and
        # shift down by scale so the interval is centred on zero.
        return random.uniform(key, shape, dtype) * scale * 2 - scale

    return init
def static_init(val=1.0, dtype: DType = jnp.float_) -> Callable:
    """Builds an initializer that fills the array with a constant.

    Args:
      val: optional; the value every entry is set to.
      dtype: optional; the initializer's default dtype.

    Returns:
      An initializer ``init(key, shape, dtype=dtype)`` returning
      ``jnp.ones(shape, dtype) * val``.
    """

    def init(key, shape, dtype=dtype):
        # `key` is accepted so the signature matches the random initializers
        # in this module; it is not used.
        dtype = dtypes.canonicalize_dtype(dtype)
        return jnp.ones(shape, dtype) * val

    return init
def normal_shift(
    bias=0, scale=1e-2, no_sign_flip=True, dtype: DType = jnp.float_
) -> Callable:
    """Builds an initializer that returns shifted, normally-distributed random
    arrays.

    Args:
      bias: optional; value added to every sample (the mean of the
        distribution before the optional absolute value).
      scale: optional; factor the standard-normal samples are multiplied by.
      no_sign_flip: optional; if true, the absolute value of the result is
        returned, so every entry is non-negative.
      dtype: optional; the initializer's default dtype.

    Returns:
      An initializer ``init(key, shape, dtype=dtype)`` that returns
      ``normal(key, shape) * scale + bias``, passed through ``jnp.abs`` when
      ``no_sign_flip`` is set.
    """

    def init(key, shape, dtype=dtype):
        dtype = dtypes.canonicalize_dtype(dtype)
        x = random.normal(key, shape, dtype) * scale + bias
        if no_sign_flip:
            x = jnp.abs(x)
        return x

    return init
@jax.custom_vjp
def debug(x):
    """Identity function whose backward pass drops into ``pdb``.

    Insert it anywhere in a differentiated computation to inspect the
    cotangent flowing through that point; the forward value is unchanged.
    """
    return x


def debug_fwd(x):
    """Forward rule: return the identity output and keep ``x`` as residual."""
    return debug(x), x


def debug_bwd(res, g):
    """Backward rule: open the debugger, then pass ``g`` through unchanged."""
    # Imported here so pdb is only loaded when a backward pass actually
    # reaches this point.
    import pdb

    pdb.set_trace()
    return (g,)


debug.defvjp(debug_fwd, debug_bwd)
class gsis(nn.Module):
    """Scaled sigmoid whose backward pass is amplified by a custom rule.

    Forward: ``sigmoid(x * alpha)``, where ``alpha`` is a learnable parameter
    named ``"upscale"`` with one entry per element of the last axis of ``x``,
    initialised with ``normal_shift(sigmoid_bias, sigmoid_scale)``.

    Backward: on top of the ordinary sigmoid gradient, the cotangent is
    multiplied by ``1 + theta * fn(y)`` where ``y`` is the sigmoid output.
    """

    # Mean of the initial distribution of the "upscale" parameter.
    sigmoid_bias: float = 2
    # Scale of the initial distribution of the "upscale" parameter.
    sigmoid_scale: float = 2
    # Weight of the extra gradient term added in the backward pass.
    theta: float = 0.1
    # Shape function evaluated on the sigmoid output in the backward pass.
    fn: Callable = lambda x: 1 / (1 + (2 * jnp.pi / 2 * x) ** 2)

    @nn.compact
    def __call__(self, x):
        @jax.custom_vjp
        def gsis_fn(x):
            # Identity in the forward direction; only the gradient is altered.
            return x

        def gsis_fn_fwd(x):
            return gsis_fn(x), x

        def gsis_fn_bwd(res, g):
            x = res
            # Alternatives that were tried here and left disabled:
            #   theta * relu(1 - 2 * |x|)            (piecewise)
            #   1 + theta * |x - (x >= .5)|
            scale = 1 + self.theta * self.fn(x)  # atan
            return (g * scale,)

        gsis_fn.defvjp(gsis_fn_fwd, gsis_fn_bwd)

        alpha = self.param(
            "upscale",
            normal_shift(self.sigmoid_bias, self.sigmoid_scale),
            (x.shape[-1],),
        )

        # pre process
        x = jax.nn.sigmoid(x * alpha)
        return gsis_fn(x)
@jax.custom_vjp
def fast_sigmoid(x):
    """Spike function: step at zero with a fast-sigmoid surrogate gradient.

    Forward returns ``1`` where ``x >= 0`` and ``0`` elsewhere, in the dtype
    of ``x``. The backward pass uses ``1 / (alpha * |x| + 1) ** 2`` with
    ``alpha = 10`` in place of the (zero almost everywhere) true derivative.
    """
    # if not dtype float grad ops wont work
    return jnp.array(x >= 0.0, dtype=x.dtype)


def fast_sigmoid_fwd(x):
    """Forward rule: return the spikes and keep ``x`` as residual."""
    return fast_sigmoid(x), x


def fast_sigmoid_bwd(res, g):
    """Backward rule: scale ``g`` by the fast-sigmoid surrogate derivative."""
    x = res
    alpha = 10
    scale = 1 / (alpha * jnp.abs(x) + 1.0) ** 2
    return (g * scale,)


fast_sigmoid.defvjp(fast_sigmoid_fwd, fast_sigmoid_bwd)
@jax.custom_vjp
def slayer(x):
    """Spike function: step at zero with an exponential surrogate gradient.

    Forward returns ``1`` where ``x >= 0`` and ``0`` elsewhere, in the dtype
    of ``x``. The backward pass uses ``exp(-5 * |x|)`` as the derivative.
    """
    # if not dtype float grad ops wont work
    return jnp.array(x >= 0.0, dtype=x.dtype)


def slayer_fwd(x):
    """Forward rule: return the spikes and keep ``x`` as residual."""
    return slayer(x), x


def slayer_bwd(res, g):
    """Backward rule: scale ``g`` by ``exp(-5 * |x|)``."""
    x = res
    scale = jnp.exp(-jnp.abs(x) * 5)
    return (g * scale,)


slayer.defvjp(slayer_fwd, slayer_bwd)
@jax.custom_vjp
def smooth_step(x):
    """Spike function: step at zero with a boxcar surrogate gradient.

    Forward returns ``1`` where ``x >= 0`` and ``0`` elsewhere, in the dtype
    of ``x``. The backward pass lets the cotangent through unchanged for
    ``-0.5 <= x < 0.5`` and blocks it everywhere else.
    """
    # if not dtype float grad ops wont work
    return jnp.array(x >= 0.0, dtype=x.dtype)


def smooth_step_fwd(x):
    """Forward rule: return the spikes and keep ``x`` as residual."""
    return smooth_step(x), x


def smooth_step_bwd(res, g):
    """Backward rule: mask ``g`` to the window ``[-0.5, 0.5)``."""
    x = res
    # Boolean mask; multiplying promotes it to the dtype of g.
    scale = jnp.logical_and((x < 0.5), (x >= -0.5))
    return (g * scale,)


smooth_step.defvjp(smooth_step_fwd, smooth_step_bwd)
@jax.custom_vjp
def piecewise_linear(x):
    """Spike function: step at zero with a triangular surrogate gradient.

    Forward returns ``1`` where ``x >= 0`` and ``0`` elsewhere, in the dtype
    of ``x``. The backward pass uses ``relu(1 - 2 * |x|)``: one at ``x = 0``,
    falling linearly to zero at ``|x| = 0.5``.
    """
    # if not dtype float grad ops wont work
    return jnp.array(x >= 0.0, dtype=x.dtype)


def piecewise_linear_fwd(x):
    """Forward rule: return the spikes and keep ``x`` as residual."""
    return piecewise_linear(x), x


def piecewise_linear_bwd(res, g):
    """Backward rule: scale ``g`` by ``relu(1 - 2 * |x|)``."""
    x = res
    # The relu already zeroes everything outside |x| < 0.5, so no explicit
    # window mask is needed.
    scale = jax.nn.relu(1 - (jnp.abs(x) * 2))
    return (g * scale,)


piecewise_linear.defvjp(piecewise_linear_fwd, piecewise_linear_bwd)
@jax.custom_vjp
def atan(x):
    """Spike function: step at zero with an arctan-derivative surrogate.

    Forward returns ``1`` where ``x >= 0`` and ``0`` elsewhere, in the dtype
    of ``x``. The backward pass uses
    ``(alpha / 2) / (1 + (alpha * pi / 2 * x) ** 2)`` with ``alpha = 2``.
    """
    # if not dtype float grad ops wont work
    return jnp.array(x >= 0.0, dtype=x.dtype)


def atan_fwd(x):
    """Forward rule: return the spikes and keep ``x`` as residual."""
    return atan(x), x


def atan_bwd(res, g):
    """Backward rule: scale ``g`` by the arctan surrogate derivative."""
    # originally from SpikingJelly
    x = res
    alpha = 2
    shared_c = g / (1 + (alpha * jnp.pi / 2 * x) ** 2)
    return (alpha / 2 * shared_c,)


atan.defvjp(atan_fwd, atan_bwd)
class leaky_current_based_IF_rel_refactory(nn.Module):
    """Leaky current-based integrate-and-fire step with a refractory trace.

    From "Synaptic Plasticity Dynamics for Deep Continuous Local Learning
    (DECOLLE)" - https://arxiv.org/abs/1811.10766

    The carry is the tuple ``(Q, P, R, S)``. One call computes::

        Q = beta * Q + (1 - beta) * s_in
        P = alpha * P + (1 - alpha) * Q_prev
        R = alpharp * R - (1 - alpharp) * S_prev * wrp
        U = connection_fn(P) + R
        S = spike_fn(U)

    and returns ``((Q, P, R, S), U)``.
    """

    # Decay factor of the Q trace.
    beta: float
    # Decay factor of the P trace.
    alpha: float
    # Decay factor of the refractory trace R.
    alpharp: float
    # Maps U to the output S (e.g. one of the surrogate-gradient steps).
    spike_fn: Callable
    # Applied to P before the refractory term is added.
    connection_fn: Callable
    # Weight of the previous output S in the refractory update.
    wrp: float = 1.0

    @nn.compact
    def __call__(self, carry, s_in):
        sQ, sP, sR, sS = carry

        Q = self.beta * sQ + (1 - self.beta) * s_in
        # P is driven by the Q of the previous step (sQ), not the Q just
        # computed above.
        P = self.alpha * sP + (1 - self.alpha) * sQ
        R = self.alpharp * sR - (1 - self.alpharp) * sS * self.wrp
        U = self.connection_fn(P) + R
        S = self.spike_fn(U)

        return (Q, P, R, S), U

    @staticmethod
    def initialize_carry(inputs, connection_fn):
        """Return an all-zero float32 ``(Q, P, R, S)`` carry.

        ``Q`` and ``P`` take the shape of ``inputs``; ``R`` and ``S`` take the
        shape of ``connection_fn(inputs)``.
        """
        x = connection_fn(inputs)
        return (
            jnp.zeros_like(inputs, dtype=jnp.float32),
            jnp.zeros_like(inputs, dtype=jnp.float32),
            jnp.zeros(x.shape, dtype=jnp.float32),
            jnp.zeros(x.shape, dtype=jnp.float32),
        )
class DecolleSpikingBlock(nn.Module):
    """Spiking layer with a local readout trained on a local loss.

    From "Synaptic Plasticity Dynamics for Deep Continuous Local Learning
    (DECOLLE)" - https://arxiv.org/abs/1811.10766

    ``__call__`` is lifted with ``nn.transforms.scan``: parameters are
    broadcast across scan steps, the ``"params"`` RNG is not split and the
    ``"dropout"`` RNG is split per step. Each step consumes a pair
    ``(inputs, trgt)`` and emits ``(spikes, local_readout)``.
    """

    # Passed to ``neural_dynamics`` as its ``connection_fn`` argument.
    connection_fn: Callable
    # ``loss_type(readout, trgt)``; differentiated in the local backward rule.
    loss_type: Callable
    # Number of rows of the local readout matrix.
    num_classes: int
    # Factory called as ``neural_dynamics(connection_fn=...)`` to build the
    # per-step dynamics module.
    neural_dynamics: Callable
    # Window (and stride) of the max-pool applied to the membrane output.
    pool_window: Sequence[int] = (1, 1)
    # When false, dropout runs deterministically.
    train: bool = True
    # Rate handed to ``nn.Dropout``.
    drop_out: float = 0.5

    @functools.partial(
        nn.transforms.scan,
        variable_broadcast="params",
        split_rngs={"params": False, "dropout": True},
    )
    @nn.compact
    def __call__(self, carry, pair):
        inputs, trgt = pair

        carry, u = self.neural_dynamics(connection_fn=self.connection_fn)(
            carry, inputs
        )
        u_p = nn.max_pool(u, self.pool_window, strides=self.pool_window)
        s_ = fast_sigmoid(u_p)

        # local learning
        flatten_size = np.prod(u_p.shape[1:])
        w_ro = self.param(
            "w_ro", lecun_normal(), (self.num_classes, flatten_size)
        )
        stdv = 0.5 / np.sqrt(self.num_classes)  # lc_ampl
        b_ro = self.param("b_ro", uniform(stdv), (self.num_classes,))

        @jax.custom_vjp
        def decolle(x, w, b, trgt):
            # Plain affine readout in the forward direction.
            out_local = jnp.dot(x, w.transpose()) + b
            return out_local

        def decolle_fwd(x, w, b, trgt):
            out_local = decolle(x, w, b, trgt)
            return out_local, (out_local, w, trgt, x.shape)

        def decolle_bwd(res, g):
            # The incoming cotangent g is deliberately ignored: the gradient
            # sent back to x comes from the local loss on this layer's own
            # readout, not from whatever consumes the readout downstream.
            (out_local, w, trgt, _shape) = res
            err = jax.grad(lambda x: jnp.mean(self.loss_type(x, trgt)))(
                out_local
            )
            grad = jnp.dot(err, w)
            # Zero cotangents for w and b: the readout parameters receive no
            # gradient through this path. trgt is not differentiated.
            return grad, jnp.zeros_like(w), jnp.zeros((err.shape[-1])), None

        decolle.defvjp(decolle_fwd, decolle_bwd)

        sd_ = nn.Dropout(self.drop_out)(s_, deterministic=not self.train)
        # reshape has to be compatible with decolle pytorch: the last axis is
        # moved to position 1 before flattening (presumably NHWC -> NCHW;
        # requires a 4-D tensor).
        sd_ = jnp.reshape(
            jnp.moveaxis(sd_, (0, 1, 2, 3), (0, 2, 3, 1)), (sd_.shape[0], -1)
        )
        out_local = decolle(sd_, w_ro, b_ro, trgt)

        return carry, (s_, out_local)

    @staticmethod
    def initialize_carry(inputs, connection_fn, neural_dynamics):
        """Build the initial carry from the first slice of ``inputs``.

        Delegates to ``initialize_carry`` of the module produced by
        ``neural_dynamics(connection_fn=connection_fn)``, passing
        ``inputs[0, :]`` and ``connection_fn``.
        """
        return neural_dynamics(connection_fn=connection_fn).initialize_carry(
            inputs[0, :], connection_fn
        )
class parametric_leaky_IF(nn.Module):
    """Leaky integrate-and-fire step with a learnable membrane time constant.

    From "Incorporating Learnable Membrane Time Constant to Enhance Learning of
    Spiking Neural Networks" - https://arxiv.org/pdf/2007.05785.pdf

    One call computes::

        u = u + (s_in - (u - v_reset)) * sigmoid(tau)
        s = spike_fn(u - v_threshold)
        u = v_reset where s is non-zero, else u

    and returns ``(u, s)``. ``tau`` is a single learnable scalar initialised
    to ``-log(init_tau - 1)``, so that ``sigmoid(tau) == 1 / init_tau`` at
    initialisation; ``init_tau`` must therefore be greater than 1.
    """

    # Initial membrane time constant (> 1).
    init_tau: float
    # Maps ``u - v_threshold`` to the output spikes.
    spike_fn: Callable
    v_threshold: float = 1.0
    v_reset: float = 0.0
    # Not referenced by ``__call__``.
    pre_spike_fn: Callable = None
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, u, s_in):
        tau = self.param(
            "tau",
            static_init(-jnp.log(self.init_tau - 1), dtype=self.dtype),
            (1,),
        )
        v_threshold = jnp.array([self.v_threshold], dtype=self.dtype)
        v_reset = jnp.array([self.v_reset], dtype=self.dtype)

        # Rebind instead of `+=` so a caller-owned mutable array (e.g. a
        # NumPy array) is never modified in place.
        u = u + (s_in - (u - v_reset)) * jax.nn.sigmoid(tau)
        s = self.spike_fn(u - v_threshold)
        # Hard reset wherever a spike was emitted.
        u = jnp.where(s, v_reset, u)

        return u, s
class multi_step_LIF(nn.Module):
    """Leaky integrate-and-fire step with a fixed membrane time constant.

    From "TCJA-SNN: Temporal-Channel Joint Attention for Spiking Neural Networks"
    - https://arxiv.org/pdf/2206.10177.pdf

    One call computes::

        u = u + (s_in - (u - v_reset)) / tau
        s = spike_fn(u - v_threshold)
        u = v_reset where s is non-zero, else u

    and returns ``(u, s)``. The module declares no parameters.
    """

    # Membrane time constant; the update is divided by it.
    tau: float
    # Maps ``u - v_threshold`` to the output spikes.
    spike_fn: Callable
    v_threshold: float = 1.0
    v_reset: float = 0.0
    # Not referenced by ``__call__``.
    pre_spike_fn: Callable = None
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, u, s_in):
        tau = jnp.array([self.tau], dtype=self.dtype)
        v_threshold = jnp.array([self.v_threshold], dtype=self.dtype)
        v_reset = jnp.array([self.v_reset], dtype=self.dtype)

        # Rebind instead of `+=` so a caller-owned mutable array (e.g. a
        # NumPy array) is never modified in place.
        u = u + (s_in - (u - v_reset)) / tau
        s = self.spike_fn(u - v_threshold)
        # Hard reset wherever a spike was emitted.
        u = jnp.where(s, v_reset, u)

        return u, s
class LIF(nn.Module):
    """Leaky integrate-and-fire step with a learnable per-feature leak.

    One call computes::

        u = u * sigmoid(tau) + s_in
        s = spike_fn(u - v_threshold)
        u = v_reset where s > 0.5, else u

    and returns ``(u, s)``. ``tau`` is a learnable parameter with one entry
    per element of the last axis of ``u``, initialised with
    ``uniform(init_tau)``.
    """

    # Scale handed to the ``uniform`` initializer of ``tau``.
    init_tau: float
    # Maps ``u - v_threshold`` to the output spikes.
    spike_fn: Callable
    v_threshold: float = 1.0
    v_reset: float = 0.0
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, u, s_in):
        tau = self.param("tau", uniform(self.init_tau), (u.shape[-1],))
        v_threshold = jnp.array([self.v_threshold], dtype=self.dtype)
        v_reset = jnp.array([self.v_reset], dtype=self.dtype)

        # sigmoid keeps the decay factor inside (0, 1).
        u = u * jax.nn.sigmoid(tau) + s_in
        s = self.spike_fn(u - v_threshold)
        # Hard reset wherever a spike was emitted.
        u = jnp.where(s > 0.5, v_reset, u)

        return u, s
class SpikingBlock(nn.Module):
    """Connection -> optional norm -> neuron dynamics, scanned over steps.

    ``__call__`` is lifted with ``nn.transforms.scan`` (parameters broadcast
    across steps, ``"batch_stats"`` carried from step to step, ``"params"``
    RNG not split) and wrapped in ``nn.remat``. Each step takes the state
    ``u`` and one input slice and returns ``(u, s)`` from
    ``neural_dynamics``.
    """

    # Applied to each input slice first.
    connection_fn: Callable
    # Called as ``neural_dynamics(u, x)``; must return ``(u, s)``.
    neural_dynamics: Callable
    # Optional; applied to the output of ``connection_fn`` when set.
    norm_fn: Callable = None

    @nn.remat
    @functools.partial(
        nn.transforms.scan,
        variable_broadcast="params",
        variable_carry="batch_stats",
        split_rngs={"params": False},
    )
    @nn.compact
    def __call__(self, u, inputs):
        x = self.connection_fn(inputs)
        if self.norm_fn:
            x = self.norm_fn(x)
        u, s = self.neural_dynamics(u, x)

        return u, s

    @staticmethod
    def initialize_carry(
        inputs, connection_fn, norm_fn=None, dtype=jnp.float32
    ):
        """Return a zero state shaped like one step's pre-activation.

        The first slice ``inputs[0, :]`` is pushed through ``connection_fn``
        (and ``norm_fn`` when given) only to obtain the shape; the result is
        an all-zero array of that shape in ``dtype``.
        """
        x = connection_fn(inputs[0, :])
        if norm_fn:
            x = norm_fn(x)
        return jnp.zeros_like(x, dtype=dtype)