# spiking_model.py
# forked from yjwu17/STBP-for-training-SpikingNN
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
thresh = 0.5          # neuronal firing threshold
lens = 0.5            # half-width of the surrogate-gradient window
decay = 0.2           # membrane decay constant
num_classes = 10
batch_size = 100
learning_rate = 1e-3
num_epochs = 100      # max epoch


# Approximate (surrogate) firing function: a hard threshold in the forward
# pass, with a rectangular surrogate gradient in the backward pass.
class ActFun(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.gt(thresh).float()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        # pass gradient only inside a window of width 2*lens around thresh
        temp = abs(input - thresh) < lens
        return grad_input * temp.float()


act_fun = ActFun.apply
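
# A minimal usage sketch, not part of the original file: the input values
# below are arbitrary, chosen to straddle the threshold. Forward is a hard
# threshold; backward passes gradient only where |input - thresh| < lens.
def _demo_act_fun():
    mem = torch.tensor([-0.2, 0.45, 0.55, 1.2], requires_grad=True)
    spike = act_fun(mem)
    spike.sum().backward()
    print(spike)     # tensor([0., 0., 1., 1.])
    print(mem.grad)  # tensor([0., 1., 1., 0.]) with thresh = lens = 0.5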

# Membrane potential update: decay the potential, reset it to zero wherever a
# spike was emitted on the previous step, then integrate the new input ops(x).
def mem_update(ops, x, mem, spike):
    mem = mem * decay * (1. - spike) + ops(x)
    spike = act_fun(mem)  # act_fun: approximate firing function
    return mem, spike
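
# A usage sketch, not part of the original file: one neuron layer driven by a
# fixed input for a few simulation steps. The Linear layer and shapes here are
# arbitrary; mem and spike must match the shape of ops(x).
def _demo_mem_update():
    layer = nn.Linear(4, 3)
    x = torch.rand(2, 4)
    mem = spike = torch.zeros(2, 3)
    for _ in range(5):
        mem, spike = mem_update(layer, x, mem, spike)
    print(mem)    # decayed, integrated membrane potential
    print(spike)  # binary spikes from the last step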

# conv layers: (in_planes, out_planes, stride, padding, kernel_size)
cfg_cnn = [(1, 32, 1, 1, 3),
           (32, 32, 1, 1, 3)]
# feature-map side length after each stage (28x28 input, two 2x2 poolings)
cfg_kernel = [28, 14, 7]
# fully connected layer widths
cfg_fc = [128, 10]


# Decay learning_rate
def lr_scheduler(optimizer, epoch, init_lr=0.1, lr_decay_epoch=50):
    """Decay the learning rate by a factor of 0.1 every lr_decay_epoch epochs."""
    if epoch % lr_decay_epoch == 0 and epoch > 1:
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * 0.1
    return optimizer
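
# A usage sketch, not part of the original file; the Linear model is a
# stand-in for any nn.Module. Note that init_lr is unused: the scheduler
# multiplies whatever rate the optimizer currently holds by 0.1.
def _demo_lr_scheduler():
    model = nn.Linear(2, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    for epoch in range(num_epochs):
        optimizer = lr_scheduler(optimizer, epoch, lr_decay_epoch=50)
        # ... one epoch of training would run here ...
    print(optimizer.param_groups[0]['lr'])  # 1e-4 after the decay at epoch 50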

class SCNN(nn.Module):
    def __init__(self):
        super(SCNN, self).__init__()
        in_planes, out_planes, stride, padding, kernel_size = cfg_cnn[0]
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding)
        in_planes, out_planes, stride, padding, kernel_size = cfg_cnn[1]
        self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding)
        self.fc1 = nn.Linear(cfg_kernel[-1] * cfg_kernel[-1] * cfg_cnn[-1][1], cfg_fc[0])
        self.fc2 = nn.Linear(cfg_fc[0], cfg_fc[1])

    def forward(self, input, time_window=20):
        # membrane potentials and spikes for every layer, all starting at rest
        c1_mem = c1_spike = torch.zeros(batch_size, cfg_cnn[0][1], cfg_kernel[0], cfg_kernel[0], device=device)
        c2_mem = c2_spike = torch.zeros(batch_size, cfg_cnn[1][1], cfg_kernel[1], cfg_kernel[1], device=device)
        h1_mem = h1_spike = h1_sumspike = torch.zeros(batch_size, cfg_fc[0], device=device)
        h2_mem = h2_spike = h2_sumspike = torch.zeros(batch_size, cfg_fc[1], device=device)

        for step in range(time_window):  # simulation time steps
            # Bernoulli rate coding: pixel intensity = firing probability
            x = input > torch.rand(input.size(), device=device)

            c1_mem, c1_spike = mem_update(self.conv1, x.float(), c1_mem, c1_spike)
            x = F.avg_pool2d(c1_spike, 2)

            c2_mem, c2_spike = mem_update(self.conv2, x, c2_mem, c2_spike)
            x = F.avg_pool2d(c2_spike, 2)
            x = x.view(batch_size, -1)

            h1_mem, h1_spike = mem_update(self.fc1, x, h1_mem, h1_spike)
            h1_sumspike += h1_spike
            h2_mem, h2_spike = mem_update(self.fc2, h1_spike, h2_mem, h2_spike)
            h2_sumspike += h2_spike

        # output = average firing rate of the last layer over the time window
        outputs = h2_sumspike / time_window
        return outputs
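
# An end-to-end sketch, not part of the original file: a forward pass on a
# random MNIST-shaped batch. The batch must contain exactly `batch_size`
# samples, since the state tensors in forward() are allocated with that size.
if __name__ == "__main__":
    model = SCNN().to(device)
    images = torch.rand(batch_size, 1, 28, 28, device=device)  # intensities in [0, 1)
    outputs = model(images)  # per-class firing rates, shape (batch_size, num_classes)
    print(outputs.shape)     # torch.Size([100, 10])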