-
Notifications
You must be signed in to change notification settings - Fork 0
/
Model.py
58 lines (52 loc) · 2.76 KB
/
Model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Activation, Permute, Dropout
from tensorflow.keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
from tensorflow.keras.layers import SeparableConv2D, DepthwiseConv2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import SpatialDropout2D
from tensorflow.keras.layers import Input, Flatten
from tensorflow.keras.constraints import max_norm
from tensorflow.keras import backend as K
def EEGNet(nb_classes, Chans=64, Samples=128,
           dropoutRate=0.5, kernLength=64, F1=8,
           D=2, F2=16, norm_rate=0.25, dropoutType='Dropout'):
    """
    Create an (uncompiled) EEGNet Keras model with the given parameters.

    Original implementation: https://github.com/vlawhern/arl-eegmodels

    Args:
        nb_classes: number of classes, i.e. number of units in the last
            dense layer.
        Chans: number of data channels (first axis of one input sample).
        Samples: number of time samples (second axis of one input sample).
        dropoutRate: rate passed to both dropout layers.
        kernLength: width of the (1, kernLength) kernel of the first
            convolutional layer.
        F1: number of filters in the first convolutional layer.
        D: depth multiplier of the depthwise layer.
        F2: number of filters in the separable convolutional layer.
        norm_rate: max-norm limit applied to the kernel of the last dense
            layer.
        dropoutType: name of the dropout layer to use, either 'Dropout'
            (default) or 'SpatialDropout2D'.

    Returns:
        A tensorflow.keras Model taking inputs of shape (Chans, Samples, 1)
        and producing a softmax over nb_classes outputs.

    Raises:
        ValueError: if dropoutType is not one of the two supported names.
    """
    dropout_layers = {
        'Dropout': Dropout,
        'SpatialDropout2D': SpatialDropout2D,
    }
    if dropoutType not in dropout_layers:
        raise ValueError("dropoutType must be one of 'SpatialDropout2D' "
                         "or 'Dropout', passed as a string.")
    dropout_layer = dropout_layers[dropoutType]

    input1 = Input(shape=(Chans, Samples, 1))

    ##################################################################
    # Block 1: convolution along the Samples axis, then a depthwise
    # convolution whose (Chans, 1) kernel spans the whole channel axis.
    # No input_shape is given to Conv2D: the layer is called on input1,
    # which already fixes the input shape.
    block1 = Conv2D(F1, (1, kernLength), padding='same',
                    use_bias=False)(input1)
    block1 = BatchNormalization()(block1)
    block1 = DepthwiseConv2D((Chans, 1), use_bias=False,
                             depth_multiplier=D,
                             depthwise_constraint=max_norm(1.))(block1)
    block1 = BatchNormalization()(block1)
    block1 = Activation('elu')(block1)
    block1 = AveragePooling2D((1, 4))(block1)
    block1 = dropout_layer(dropoutRate)(block1)

    # Block 2: separable convolution followed by a second pooling stage.
    block2 = SeparableConv2D(F2, (1, 16),
                             use_bias=False, padding='same')(block1)
    block2 = BatchNormalization()(block2)
    block2 = Activation('elu')(block2)
    block2 = AveragePooling2D((1, 8))(block2)
    block2 = dropout_layer(dropoutRate)(block2)

    # Classification head: flatten, max-norm constrained dense, softmax.
    flatten = Flatten(name='flatten')(block2)
    dense = Dense(nb_classes, name='dense',
                  kernel_constraint=max_norm(norm_rate))(flatten)
    softmax = Activation('softmax', name='softmax')(dense)

    return Model(inputs=input1, outputs=softmax)