attention.py
from keras.layers import Layer
import keras.backend as K


class attention(Layer):
    def __init__(self, **kwargs):
        super(attention, self).__init__(**kwargs)

    def build(self, input_shape):
        # Learnable weight vector (one score per feature) and per-timestep bias
        # used to compute the alignment scores.
        self.W = self.add_weight(name='attention_weight', shape=(input_shape[-1], 1),
                                 initializer='random_normal', trainable=True)
        self.b = self.add_weight(name='attention_bias', shape=(input_shape[1], 1),
                                 initializer='zeros', trainable=True)
        super(attention, self).build(input_shape)

    def call(self, x):
        # Alignment scores, passed through a tanh activation.
        e = K.tanh(K.dot(x, self.W) + self.b)
        # Remove the trailing dimension of size 1.
        e = K.squeeze(e, axis=-1)
        # Attention weights: softmax over the time axis.
        alpha = K.softmax(e)
        # Restore the trailing dimension so the weights broadcast over the features.
        alpha = K.expand_dims(alpha, axis=-1)
        # Context vector: attention-weighted sum of the inputs over the time axis.
        context = x * alpha
        context = K.sum(context, axis=1)
        return context
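

# Usage sketch (not part of the original file): one way to plug this layer into a
# model, assuming it follows a recurrent layer that returns its full sequence of
# hidden states. The input shape and layer sizes below are illustrative assumptions,
# not values taken from this repository.
if __name__ == '__main__':
    from keras.layers import Input, LSTM, Dense
    from keras.models import Model

    # Assumed toy shape: sequences of 20 timesteps with 8 features each.
    inputs = Input(shape=(20, 8))
    # return_sequences=True so the attention layer sees every timestep.
    hidden = LSTM(32, return_sequences=True)(inputs)
    # The attention layer collapses the time axis into a single context vector.
    context = attention()(hidden)
    outputs = Dense(1)(context)
    model = Model(inputs=inputs, outputs=outputs)
    model.summary()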