lstm_network.py
import tensorflow as tf
import numpy as np


# currently 1 bi-directional LSTM layer followed by a dense layer
class LSTM_network:
    def __init__(self, n_hidden, embedding_dim, n_classes, weights=None, debug=False):
        self.n_hidden = n_hidden
        self.embedding_dim = embedding_dim
        self.n_classes = n_classes
        self.debug = debug
        # model parameters
        if weights is not None:
            self.check_weights(weights)
            self.W_x_fward = tf.constant(weights[0], dtype=tf.float64)
            self.W_h_fward = tf.constant(weights[1], dtype=tf.float64)
            self.b_fward = tf.constant(weights[2], dtype=tf.float64)
            self.W_x_bward = tf.constant(weights[3], dtype=tf.float64)
            self.W_h_bward = tf.constant(weights[4], dtype=tf.float64)
            self.b_bward = tf.constant(weights[5], dtype=tf.float64)
            # weights[6] is the (2 * n_hidden, n_classes) dense kernel; split it into forward and backward parts
            self.W_dense_fw = tf.constant(weights[6][:self.n_hidden], dtype=tf.float64)
            self.W_dense_bw = tf.constant(weights[6][self.n_hidden:], dtype=tf.float64)
            self.b_dense = tf.constant(weights[7], dtype=tf.float64)
        else:
            self.W_x_fward = tf.constant(np.random.randn(self.embedding_dim, 4 * self.n_hidden))
            self.W_h_fward = tf.constant(np.random.randn(self.n_hidden, 4 * self.n_hidden))
            self.b_fward = tf.constant(np.random.randn(4 * self.n_hidden,))
            self.W_x_bward = tf.constant(np.random.randn(self.embedding_dim, 4 * self.n_hidden))
            self.W_h_bward = tf.constant(np.random.randn(self.n_hidden, 4 * self.n_hidden))
            self.b_bward = tf.constant(np.random.randn(4 * self.n_hidden,))
            self.W_dense_fw = tf.constant(np.random.randn(n_hidden, n_classes))
            self.W_dense_bw = tf.constant(np.random.randn(n_hidden, n_classes))
            self.b_dense = tf.constant(np.random.randn(n_classes))
        # prediction of the net
        self.y_hat = tf.Variable(0., shape=tf.TensorShape(None), dtype=tf.float64, name='y_hat')
        # the following gate order is the one Keras uses; you might have to adjust it for other frameworks
        self.idx_i = slice(0, self.n_hidden)
        self.idx_f = slice(self.n_hidden, 2 * self.n_hidden)
        self.idx_c = slice(2 * self.n_hidden, 3 * self.n_hidden)
        self.idx_o = slice(3 * self.n_hidden, 4 * self.n_hidden)

    def check_weights(self, weights):
        assert len(weights) == 8
        assert weights[0].shape == weights[3].shape == (self.embedding_dim, 4 * self.n_hidden)
        assert weights[1].shape == weights[4].shape == (self.n_hidden, 4 * self.n_hidden)
        assert weights[2].shape == weights[5].shape == (4 * self.n_hidden,)
        assert weights[6].shape == (2 * self.n_hidden, self.n_classes)
        assert weights[7].shape == (self.n_classes,)

    # x is a batch of embedding vectors (batch_size, embedding_dim)
    @tf.function
    def cell_step(self, x, h_old, c_old, W_x, W_h, b):
        # forward pass through one LSTM cell
        gate_x = tf.matmul(x, W_x)
        gate_h = tf.matmul(h_old, W_h)
        gate_pre = gate_x + gate_h + b
        gate_post = tf.concat([
            tf.sigmoid(gate_pre[:, self.idx_i]), tf.sigmoid(gate_pre[:, self.idx_f]),
            tf.tanh(gate_pre[:, self.idx_c]), tf.sigmoid(gate_pre[:, self.idx_o]),
        ], axis=1)
        c_new = gate_post[:, self.idx_f] * c_old + gate_post[:, self.idx_i] * gate_post[:, self.idx_c]
        h_new = gate_post[:, self.idx_o] * tf.tanh(c_new)
        return gate_pre, gate_post, c_new, h_new

    # x is a batch of embedding vectors (batch_size, embedding_dim)
    @tf.function
    def one_step_fward(self, x, h_old_fw, c_old_fw):
        fward = self.cell_step(x, h_old_fw, c_old_fw, self.W_x_fward, self.W_h_fward, self.b_fward)
        return fward

    # x_rev is a batch of embedding vectors (batch_size, embedding_dim)
    @tf.function
    def one_step_bward(self, x_rev, h_old_bw, c_old_bw):
        bward = self.cell_step(x_rev, h_old_bw, c_old_bw, self.W_x_bward, self.W_h_bward, self.b_bward)
        return bward

    # input is the full batch (batch_size, T, embedding_dim)
    @tf.function(experimental_relax_shapes=True)
    def full_pass(self, x):
        batch_size = x.shape[0]
        # we have to reorder the input since tf.scan scans along the first axis
        elems = tf.transpose(x, perm=[1, 0, 2])
        initializer = (tf.constant(np.zeros((batch_size, 4 * self.n_hidden))),  # gates_pre
                       tf.constant(np.zeros((batch_size, 4 * self.n_hidden))),  # gates_post
                       tf.constant(np.zeros((batch_size, self.n_hidden))),      # c_t
                       tf.constant(np.zeros((batch_size, self.n_hidden))))      # h_t
        fn_fward = lambda a, x: self.one_step_fward(x, a[3], a[2])
        fn_bward = lambda a, x: self.one_step_bward(x, a[3], a[2])
        # the outputs are tuples of tensors stacked along the time axis: (gates_pre, gates_post, c, h)
        o_fward = tf.scan(fn_fward, elems, initializer=initializer)
        o_bward = tf.scan(fn_bward, elems, initializer=initializer, reverse=True)
        # final prediction scores
        y_fward = tf.matmul(o_fward[3][-1], self.W_dense_fw)
        y_bward = tf.matmul(o_bward[3][0], self.W_dense_bw)
        y_hat = y_fward + y_bward + self.b_dense
        self.y_hat.assign(y_hat)
        return y_hat, o_fward, o_bward

    def lrp_linear_layer(self, h_in, w, b, hout, Rout, bias_nb_units, eps, bias_factor=0.0):
        """
        LRP for a linear layer with input dim D and output dim M.
        Args:
        - h_in: forward pass input, of shape (batch_size, D)
        - w: connection weights, of shape (D, M)
        - b: biases, of shape (M,)
        - hout: forward pass output, of shape (batch_size, M) (unequal to np.dot(w.T, h_in) + b if there is more
          than one incoming layer!)
        - Rout: relevance at layer output, of shape (batch_size, M)
        - bias_nb_units: total number of connected lower-layer units (onto which the bias/stabilizer contribution
          is redistributed for the sanity check)
        - eps: stabilizer (small positive number)
        - bias_factor: set to 1.0 to check global relevance conservation, otherwise use 0.0 to ignore
          bias/stabilizer redistribution (recommended)
        Returns:
        - R_in: relevance at layer input, of shape (batch_size, D)
        """
        bias_factor_t = tf.constant(bias_factor, dtype=tf.float64)
        eps_t = tf.constant(eps, dtype=tf.float64)
        sign_out = tf.cast(tf.where(hout >= 0, 1., -1.), tf.float64)  # shape (batch_size, M)
        numerator_1 = tf.expand_dims(h_in, axis=2) * w
        numerator_2 = bias_factor_t * (tf.expand_dims(b, 0) + eps_t * sign_out) / bias_nb_units
        # use the following term instead if you want to check the relevance conservation property
        # numerator_2 = (bias_factor_t * tf.expand_dims(b, 0) + eps_t * sign_out) / bias_nb_units
        numerator = numerator_1 + tf.expand_dims(numerator_2, 1)
        denom = hout + (eps * sign_out)
        message = numerator / tf.expand_dims(denom, 1) * tf.expand_dims(Rout, 1)
        R_in = tf.reduce_sum(message, axis=2)
        return R_in

    def lrp(self, x, y=None, eps=1e-3, bias_factor=0.0):
        """
        LRP for a batch of samples x.
        Args:
        - x: input array. dim = (batch_size, T, embedding_dim)
        - y: desired output class to explain. dim = (batch_size,)
        - eps: eps value for lrp-eps
        - bias_factor: bias factor for lrp
        Returns:
        - Rx: relevances of each input dimension. dim = (batch_size, T, embedding_dim)
        - rest: remaining relevance that was not attributed to the input. dim = (batch_size,)
        """
        lrp_pass = self.lrp_lstm(x, y, eps, bias_factor)
        # add forward and backward relevances of x.
        # Here we have to reverse R_x_fw since tf.scan() starts at the last timestep (T-1) and moves towards
        # timestep 0, so the last entry of lrp_pass[2] belongs to the first timestep of x. Likewise, the last
        # entry of lrp_pass[5] (R_x_rev) belongs to the last timestep of x and is thus already in the right order.
        Rx_ = tf.reverse(lrp_pass[2], axis=[0]) + lrp_pass[5]
        Rx = tf.transpose(Rx_, perm=(1, 0, 2))  # put the batch dimension first again
        # the remaining relevance is the sum of the last entries of Rh and Rc
        rest = tf.reduce_sum(lrp_pass[0][-1] + lrp_pass[1][-1] + lrp_pass[3][-1] + lrp_pass[4][-1], axis=1)
        return Rx, rest

    @tf.function
    def lrp_lstm(self, x, y=None, eps=1e-3, bias_factor=0.0):
        batch_size = x.shape[0]
        T = x.shape[1]
        x_rev = tf.reverse(x, axis=[1])
        # update inner states
        y_hat, output_fw, output_bw = self.full_pass(x)
        # if classes are given, use them; else explain the prediction of the network
        if y is not None:
            if not y.dtype is tf.int64:
                y = tf.cast(y, tf.int64)
            R_out_mask = tf.one_hot(y, depth=self.n_classes, dtype=tf.float64)
        else:
            R_out_mask = tf.one_hot(tf.argmax(y_hat, axis=1), depth=self.n_classes, dtype=tf.float64)
        R_T = y_hat * R_out_mask
        gates_pre_fw, gates_post_fw, c_fw, h_fw = output_fw
        gates_pre_bw, gates_post_bw, c_bw, h_bw = output_bw
        # c and h need one timestep more than x (the initial zero state). Appending the zero block at the end
        # means that index t-1 == -1 at timestep t == 0 picks up exactly this initial state.
        zero_block = tf.constant(np.zeros((1, batch_size, self.n_hidden)))
        c_fw = tf.concat([c_fw, zero_block], axis=0)
        h_fw = tf.concat([h_fw, zero_block], axis=0)
        gates_pre_bw = tf.reverse(gates_pre_bw, [0])
        gates_post_bw = tf.reverse(gates_post_bw, [0])
        c_bw = tf.reverse(c_bw, [0])
        h_bw = tf.reverse(h_bw, [0])
        c_bw = tf.concat([c_bw, zero_block], axis=0)
        h_bw = tf.concat([h_bw, zero_block], axis=0)
        # first calculate the relevances coming from the final linear layer
        Rh_fw_T = self.lrp_linear_layer(h_fw[T - 1], self.W_dense_fw, self.b_dense,
                                        y_hat, R_T, 2 * self.n_hidden, eps, bias_factor)
        Rh_bw_T = self.lrp_linear_layer(h_bw[T - 1], self.W_dense_bw, self.b_dense,
                                        y_hat, R_T, 2 * self.n_hidden, eps, bias_factor)
        if self.debug:
            tf.print('Dense: Input relevance', tf.reduce_sum(R_T, axis=1))
            tf.print('Dense: Output relevance', tf.reduce_sum(Rh_fw_T + Rh_bw_T, axis=1))
        elems = np.arange(T - 1, -1, -1)
        initializer = (
            Rh_fw_T,  # R_h_fw
            Rh_fw_T,  # R_c_fw
            tf.constant(np.zeros((batch_size, self.embedding_dim)), name='R_x_fw'),  # R_x_fw
            Rh_bw_T,  # R_h_bw
            Rh_bw_T,  # R_c_bw
            tf.constant(np.zeros((batch_size, self.embedding_dim)), name='R_x_bw')   # R_x_bw
        )
        eye = tf.eye(self.n_hidden, dtype=tf.float64)
        zeros_hidden = tf.constant(np.zeros((self.n_hidden)))

        @tf.function
        def update(input_tuple, t):
            # t starts at T-1; the values we want to update are essentially Rh, Rc and Rx
            # input_tuple is (R_h_fw_t+1, R_c_fw_t+1, R_x_fw_t+1, R_h_bw_t+1, R_c_bw_t+1, R_x_bw_t+1)
            # forward direction
            Rc_fw_t = self.lrp_linear_layer(gates_post_fw[t, :, self.idx_f] * c_fw[t - 1, :], eye, zeros_hidden,
                                            c_fw[t, :], input_tuple[1], 2 * self.n_hidden, eps, bias_factor)
            R_g_fw = self.lrp_linear_layer(gates_post_fw[t, :, self.idx_i] * gates_post_fw[t, :, self.idx_c], eye,
                                           zeros_hidden, c_fw[t, :], input_tuple[1], 2 * self.n_hidden, eps,
                                           bias_factor)
            if self.debug:
                tf.print('Fw1: Input relevance', tf.reduce_sum(input_tuple[1], axis=1))
                tf.print('Fw1: Output relevance', tf.reduce_sum(Rc_fw_t + R_g_fw, axis=1))
            Rx_t = self.lrp_linear_layer(x[:, t], self.W_x_fward[:, self.idx_c], self.b_fward[self.idx_c],
                                         gates_pre_fw[t, :, self.idx_c], R_g_fw,
                                         self.n_hidden + self.embedding_dim, eps, bias_factor)
            Rh_fw_t = self.lrp_linear_layer(h_fw[t - 1, :], self.W_h_fward[:, self.idx_c], self.b_fward[self.idx_c],
                                            gates_pre_fw[t, :, self.idx_c], R_g_fw,
                                            self.n_hidden + self.embedding_dim, eps, bias_factor)
            if self.debug:
                tf.print('Fw2: Input relevance', tf.reduce_sum(R_g_fw, axis=1))
                tf.print('Fw2: Output relevance', tf.reduce_sum(Rx_t, axis=1) + tf.reduce_sum(Rh_fw_t, axis=1))
            if t != 0:
                Rc_fw_t += Rh_fw_t
            # backward direction
            Rc_bw_t = self.lrp_linear_layer(gates_post_bw[t, :, self.idx_f] * c_bw[t - 1, :], eye, zeros_hidden,
                                            c_bw[t, :], input_tuple[4], 2 * self.n_hidden, eps, bias_factor)
            R_g_bw = self.lrp_linear_layer(gates_post_bw[t, :, self.idx_i] * gates_post_bw[t, :, self.idx_c], eye,
                                           zeros_hidden, c_bw[t, :], input_tuple[4], 2 * self.n_hidden, eps,
                                           bias_factor)
            if self.debug:
                tf.print('Bw1: Input relevance', tf.reduce_sum(input_tuple[4], axis=1))
                tf.print('Bw1: Output relevance', tf.reduce_sum(Rc_bw_t + R_g_bw, axis=1))
            Rx_rev_t = self.lrp_linear_layer(x_rev[:, t], self.W_x_bward[:, self.idx_c], self.b_bward[self.idx_c],
                                             gates_pre_bw[t, :, self.idx_c], R_g_bw,
                                             self.n_hidden + self.embedding_dim, eps, bias_factor)
            Rh_bw_t = self.lrp_linear_layer(h_bw[t - 1, :], self.W_h_bward[:, self.idx_c], self.b_bward[self.idx_c],
                                            gates_pre_bw[t, :, self.idx_c], R_g_bw,
                                            self.n_hidden + self.embedding_dim, eps, bias_factor)
            if self.debug:
                tf.print('Bw2: Input relevance', tf.reduce_sum(R_g_bw, axis=1))
                tf.print('Bw2: Output relevance', tf.reduce_sum(Rx_rev_t, axis=1) + tf.reduce_sum(Rh_bw_t, axis=1))
            if t != 0:
                Rc_bw_t += Rh_bw_t
            return Rh_fw_t, Rc_fw_t, Rx_t, Rh_bw_t, Rc_bw_t, Rx_rev_t

        lrp_pass = tf.scan(update, elems, initializer)
        return lrp_pass
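

# Usage sketch: a minimal, hypothetical example of driving the class end to end. The dimensions and the
# randomly initialised weights below are made up purely for illustration; in practice you would pass the
# eight Keras-ordered weight arrays described in check_weights().
if __name__ == '__main__':
    n_hidden, embedding_dim, n_classes = 8, 16, 2
    batch_size, T = 4, 10
    net = LSTM_network(n_hidden, embedding_dim, n_classes, weights=None)
    x = tf.constant(np.random.randn(batch_size, T, embedding_dim))
    y_hat, _, _ = net.full_pass(x)    # prediction scores, shape (batch_size, n_classes)
    Rx, rest = net.lrp(x, eps=1e-3)   # input relevances, shape (batch_size, T, embedding_dim)
    # setting bias_factor=1.0 allows checking global relevance conservation (see the lrp_linear_layer docstring)
    print(y_hat.shape, Rx.shape, rest.shape)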