A custom layer in keras to implement a pointer network decoder
Scheduling as an example.
A neural network (NN) is a layered structure of neurons, with connections only between layers, not within a layer:
every neuron receives inputs, converts them through a nonlinear activation function into an output, and passes that output to the connected neurons in the next layer.
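In symbols (notation mine), a neuron $j$ with activation function $f$ computes

$$ y_j = f\Big(\sum_i w_{ji}\, x_i + b_j\Big) $$

from its inputs $x_i$, weights $w_{ji}$, and bias $b_j$.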
An example: (figure of a small example network and its inputs.)
The output is compared with a given target through a loss function. The loss function dictates the changes in the neuron parameters (the direction and magnitude of the changes) through gradient descent. In this case, the parameters are the weights and biases of the neurons.
Backpropagation: how the loss gradient propagates back through all layers to change the parameters of every neuron. For simplicity, we use mean squared error as the loss function (the squared Euclidean distance between output and target) and the logistic (sigmoid) function as the activation. Applying the chain rule of derivatives layer by layer gives the gradient of the loss with respect to every weight, with a simpler expression when j is the last layer; a standard form of these equations is given below.
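For reference, a standard form of the backpropagation equations under these choices (notation is mine, not necessarily that of the original derivation), with $\sigma'(a) = \sigma(a)(1-\sigma(a))$:

$$
\begin{aligned}
a_j &= \sum_i w_{ji}\, y_i + b_j, \qquad y_j = \sigma(a_j) \\
\delta_j &= (y_j - t_j)\,\sigma'(a_j) \quad \text{if } j \text{ is in the last layer}, \qquad
\delta_j = \Big(\sum_k w_{kj}\,\delta_k\Big)\,\sigma'(a_j) \quad \text{otherwise} \\
\frac{\partial L}{\partial w_{ji}} &= \delta_j\, y_i, \qquad
w_{ji} \leftarrow w_{ji} - \eta\,\frac{\partial L}{\partial w_{ji}}
\end{aligned}
$$

where $y_i$ is the output of neuron $i$ in the previous layer, $t_j$ is the target for output neuron $j$, and $\eta$ is the learning rate.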
A recurrent NN (RNN) adds a feedback loop that carries information from one time step to the next, which acts as a form of memory over time. We use RNNs to solve optimization problems, especially sequencing problems such as scheduling.
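In its simplest form (notation mine), the hidden state $h_t$ carries this memory from one step to the next:

$$ h_t = f(W x_t + U h_{t-1} + b), \qquad y_t = g(V h_t + c). $$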
A simple diagram of the recurrent cell used here (the LSTM) can be found in Greff et al.: https://arxiv.org/pdf/1503.04069.pdf
Modeling (following the notation of Greff et al.):
input at time t: $x^t \in \mathbb{R}^M$
N: number of LSTM units (neurons)
M: input dimension
subscripts i, o, f, z: input gate, output gate, forget gate, block input
trainable parameters:
input weights: $W_z, W_i, W_f, W_o \in \mathbb{R}^{N \times M}$
recurrent weights: $R_z, R_i, R_f, R_o \in \mathbb{R}^{N \times N}$
biases: $b_z, b_i, b_f, b_o \in \mathbb{R}^{N}$
equations (peephole connections omitted for simplicity):
block input: $z^t = \tanh(W_z x^t + R_z y^{t-1} + b_z)$
input gate: $i^t = \sigma(W_i x^t + R_i y^{t-1} + b_i)$
forget gate: $f^t = \sigma(W_f x^t + R_f y^{t-1} + b_f)$
cell state: $c^t = z^t \odot i^t + c^{t-1} \odot f^t$
output gate: $o^t = \sigma(W_o x^t + R_o y^{t-1} + b_o)$
output: $y^t = \tanh(c^t) \odot o^t$
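In Keras, an LSTM layer with $N$ units over $M$-dimensional inputs can be declared directly; the weights above are created internally. A minimal sketch (the sizes here are placeholders, not the project's settings):

```python
from keras.models import Sequential
from keras.layers import LSTM, Dense

N, M, timesteps = 64, 8, 10   # placeholder sizes: units, input dimension, sequence length

model = Sequential()
model.add(LSTM(N, input_shape=(timesteps, M)))  # input, recurrent and bias weights created internally
model.add(Dense(1))
model.compile(optimizer='adam', loss='mse')
```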
An attention mechanism aligns the inputs and outputs of an RNN: it tells the decoder which part of that long memory is most relevant to the current output step. It has been very successful in natural language processing. The original paper is the attention network of Bahdanau et al. (https://arxiv.org/pdf/1409.0473.pdf); we apply the same concept to scheduling.
Bahdanau et al.'s attention network uses an LSTM as the encoder; the encoder hidden states are fed to an attention model, which weights them so that the decoder aligns with the relevant parts of the encoded sequence.
Decoder topology: (see the figures in Bahdanau et al.)
Modeling:
encoder input: the sequence $x = (x_1, \dots, x_T)$, encoded into hidden states $h_1, \dots, h_T$
decoder output: the sequence $y = (y_1, \dots, y_{T'})$
Decoder (a GRU-style recurrence, following Bahdanau et al.):
hidden state: $s_t = (1 - z_t) \circ s_{t-1} + z_t \circ \tilde{s}_t$
proposal state: $\tilde{s}_t = \tanh(W_p y_{t-1} + U_p (r_t \circ s_{t-1}) + C_p c_t + b_p)$
update gate: $z_t = \sigma(W_z y_{t-1} + U_z s_{t-1} + C_z c_t + b_z)$
reset gate: $r_t = \sigma(W_r y_{t-1} + U_r s_{t-1} + C_r c_t + b_r)$
context vector: $c_t = \sum_{j=1}^{T} \alpha_{tj} h_j$
probability vector: $\alpha_{tj} = \exp(e_{tj}) \big/ \sum_{k=1}^{T} \exp(e_{tk})$
alignment vector: $e_{tj} = V_a^\top \tanh(W_a s_{t-1} + U_a h_j + b_a)$
where $h_j$ is the encoder hidden state at input position $j$, $\sigma$ is the logistic sigmoid, and $\circ$ denotes element-wise multiplication.
Trainable variables are the alignment weights $V_a, W_a, U_a, b_a$ and the gate and proposal weights $W_r, U_r, C_r, b_r$, $W_z, U_z, C_z, b_z$, $W_p, U_p, C_p, b_p$.
The attention network above, and encoder-decoder RNNs in general, produce outputs drawn from a set that is different from the input. This does not match the scheduling application, where the output is a permutation of the input and its length changes with the input length.
The pointer network adjusted the structure of the attention network so that the output sequence is a sequence of pointers back into the input sequence.
(https://arxiv.org/pdf/1506.03134.pdf)
Compared to the attention network, the pointer network removes the context vector connections and uses the attention (alignment) probabilities directly as pointers into the input sequence.
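In the notation of Vinyals et al., the alignment scores are no longer blended into a context vector; their softmax is the output distribution over input positions:

$$ u^i_j = v^\top \tanh\big(W_1 e_j + W_2 d_i\big), \quad j = 1, \dots, n, \qquad p(C_i \mid C_1, \dots, C_{i-1}, \mathcal{P}) = \mathrm{softmax}(u^i) $$

where $e_j$ are the encoder hidden states and $d_i$ the decoder hidden states.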
The encoder uses the Keras LSTM. A decoder with this structure did not exist in Keras at the time of the project, so I customized an RNN layer based on Keras.
The pointer network's custom layer inherits from the Keras layer structure to ensure compatibility:
Build: called once, when the layer receives its input shape. This function creates the network structure: dimensions, memory states, trainable variables, their initialization, etc.
Call: executed when the layer is applied to its inputs and the computation graph is built. In a Keras RNN the per-timestep work happens in the step function; in my pointer network I use call to connect the encoder hidden states to the decoder through a time-distributed dense layer, so that this projection is not recomputed at every step of the recurrence.
Step: the most important calculation happens here. The function signature follows the Keras RNN standard: it receives the input at the current timestep and the list of states, and returns (output, [states]); the states must be symbolic tensors so that backpropagation through the recurrence works. In my pointer network, mirroring the standard LSTM format (output, [hidden state, cell state]), step returns the probability (pointer) vector together with the updated states [z, s_p].
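Schematically, the three methods plug into the Keras recurrent machinery roughly as follows. This is a simplified skeleton, under the assumption that the layer subclasses the legacy Keras `Recurrent` base class; the full class lives in the project file.

```python
from keras.layers.recurrent import Recurrent  # assumption: legacy Keras 2 recurrent base class

class AttentionPointer(Recurrent):
    """Skeleton only; the real class has more configuration."""

    def __init__(self, units, return_probabilities=False, **kwargs):
        self.units = units
        self.return_probabilities = return_probabilities
        super(AttentionPointer, self).__init__(**kwargs)

    def build(self, input_shape):
        # create the trainable weights (see the build sample code below)
        pass

    def call(self, x):
        # precompute the time-distributed projection U_a . h_t, then let the
        # recurrent machinery drive step() over the timesteps
        pass

    def step(self, x, states):
        # one decoding step: attention probabilities plus GRU-style state update
        pass
```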
Build sample code:

```python
def build(self, input_shape):
    """ input_shape: shape of the encoder output.
        Assuming the encoder is an LSTM,
        input_shape = (batchsize, timestep, encoder hiddensize)
    """
    self.batch_size, self.timesteps, self.input_dim = input_shape
    self.output_dim = self.timesteps

    if self.stateful:
        super(AttentionPointer, self).reset_states()

    self.states = [None, None]  # z, s_p

    # Matrices for creating the probability vector alpha
    self.V_a = self.add_weight(shape=(self.output_dim,),
                               name='V_a',
                               initializer=self.kernel_initializer,
                               regularizer=self.kernel_regularizer,
                               constraint=self.kernel_constraint)
    self.W_a = self.add_weight(shape=(self.units, self.output_dim),
                               name='W_a',
                               initializer=self.kernel_initializer,
                               regularizer=self.kernel_regularizer,
                               constraint=self.kernel_constraint)
    self.U_a = self.add_weight(shape=(self.input_dim, self.output_dim),
                               name='U_a',
                               initializer=self.kernel_initializer,
                               regularizer=self.kernel_regularizer,
                               constraint=self.kernel_constraint)
    self.b_a = self.add_weight(shape=(self.output_dim,),
                               name='b_a',
                               initializer=self.bias_initializer,
                               regularizer=self.bias_regularizer,
                               constraint=self.bias_constraint)
    # Matrices for the r (reset) gate
    self.U_r = self.add_weight(shape=(self.units, self.units),
                               name='U_r',
                               initializer=self.recurrent_initializer,
                               regularizer=self.recurrent_regularizer,
                               constraint=self.recurrent_constraint)
    self.W_r = self.add_weight(shape=(self.output_dim, self.units),
                               name='W_r',
                               initializer=self.recurrent_initializer,
                               regularizer=self.recurrent_regularizer,
                               constraint=self.recurrent_constraint)
    self.b_r = self.add_weight(shape=(self.units,),
                               name='b_r',
                               initializer=self.bias_initializer,
                               regularizer=self.bias_regularizer,
                               constraint=self.bias_constraint)
    # Matrices for the z (update) gate
    self.U_z = self.add_weight(shape=(self.units, self.units),
                               name='U_z',
                               initializer=self.recurrent_initializer,
                               regularizer=self.recurrent_regularizer,
                               constraint=self.recurrent_constraint)
    self.W_z = self.add_weight(shape=(self.output_dim, self.units),
                               name='W_z',
                               initializer=self.recurrent_initializer,
                               regularizer=self.recurrent_regularizer,
                               constraint=self.recurrent_constraint)
    self.b_z = self.add_weight(shape=(self.units,),
                               name='b_z',
                               initializer=self.bias_initializer,
                               regularizer=self.bias_regularizer,
                               constraint=self.bias_constraint)
    # Matrices for the proposal
    self.U_p = self.add_weight(shape=(self.units, self.units),
                               name='U_p',
                               initializer=self.recurrent_initializer,
                               regularizer=self.recurrent_regularizer,
                               constraint=self.recurrent_constraint)
    self.W_p = self.add_weight(shape=(self.output_dim, self.units),
                               name='W_p',
                               initializer=self.recurrent_initializer,
                               regularizer=self.recurrent_regularizer,
                               constraint=self.recurrent_constraint)
    self.b_p = self.add_weight(shape=(self.units,),
                               name='b_p',
                               initializer=self.bias_initializer,
                               regularizer=self.bias_regularizer,
                               constraint=self.bias_constraint)
    # For creating the initial state:
    # input to the pointer network is its own output, therefore
    # use output_dim to initialize states.
    self.W_s = self.add_weight(shape=(self.output_dim, self.units),
                               name='W_s',
                               initializer=self.recurrent_initializer,
                               regularizer=self.recurrent_regularizer,
                               constraint=self.recurrent_constraint)

    self.input_spec = [
        InputSpec(shape=(self.batch_size, self.timesteps, self.input_dim))]
    self.built = True
```
Call sample code:

```python
def call(self, x):
    # x is the hidden state sequence of the encoder
    self.x_seq = x

    # a_ij = softmax(V_a^T tanh(W_a \cdot s_{t-1} + U_a \cdot h_t))
    # apply a dense layer over the time dimension of the sequence
    # (get the U_a \cdot h_t part)
    self._uxpb = _time_distributed_dense(self.x_seq, self.U_a, b=self.b_a,
                                         input_dim=self.input_dim,
                                         timesteps=self.timesteps,
                                         output_dim=self.output_dim)

    x = self._uxpb
    return super(AttentionPointer, self).call(x)
```
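`_time_distributed_dense` applies the same dense projection to every timestep of a 3D tensor. It existed as a private helper in older Keras versions; if it is not available, a sketch along these lines (my reconstruction, not necessarily the original helper) does the job:

```python
from keras import backend as K

def _time_distributed_dense(x, w, b=None, input_dim=None,
                            timesteps=None, output_dim=None):
    """Apply y = x . w + b to every temporal slice of a 3D tensor x."""
    # collapse the time dimension, apply the dense projection, then restore it
    x = K.reshape(x, (-1, input_dim))        # (batch * timesteps, input_dim)
    x = K.dot(x, w)                          # (batch * timesteps, output_dim)
    if b is not None:
        x = K.bias_add(x, b)
    return K.reshape(x, (-1, timesteps, output_dim))
```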
Step sample code:

```python
def step(self, x, states):
    """ Get the previous hidden state of the decoder from states = [z, s_p].
        Alignment model:
            waStm1 = W_a \dot s_{t-1}
            uaHt   = U_a \dot h_t
            tmp    = tanh(waStm1 + uaHt)
            e_tj   = V_a^T * tmp
            the probability vector of length timesteps is: a_t = softmax(e_tj)
    """
    atm1 = x
    ztm1, s_tpm1 = states

    # old hidden state: shape (batchsize, units)
    stm1 = (1 - ztm1) * self.stm2 + ztm1 * s_tpm1
    # shape (batchsize, timesteps, units)
    _stm = K.repeat(stm1, self.timesteps)
    # shape (batchsize, timesteps, output_dim)
    _Wxstm = K.dot(_stm, self.W_a)

    # calculate the attention probabilities:
    # self._uxpb has shape (batchsize, timesteps, output_dim),
    # V_a has shape (output_dim,); after K.expand_dims it is (output_dim, 1),
    # therefore et has shape (batchsize, timesteps, 1)
    et = K.dot(activations.tanh(_Wxstm + self._uxpb),
               K.expand_dims(self.V_a))
    at = K.exp(et)
    at_sum = K.sum(at, axis=1)
    at_sum_repeated = K.repeat(at_sum, self.timesteps)
    at /= at_sum_repeated  # shape (batchsize, timesteps, 1)

    # reset gate:
    rt = activations.sigmoid(K.dot(atm1, self.W_r) + K.dot(stm1, self.U_r)
                             + self.b_r)
    # update gate:
    zt = activations.sigmoid(K.dot(atm1, self.W_z) + K.dot(stm1, self.U_z)
                             + self.b_z)
    # proposal hidden state:
    s_tp = activations.tanh(K.dot(atm1, self.W_p)
                            + K.dot((rt * stm1), self.U_p) + self.b_p)

    yt = activations.softmax(at)

    if self.return_probabilities:
        return at, [zt, s_tp]
    else:
        return yt, [zt, s_tp]
```
The full custom layer class is too long to include here; please see the source file for details.
Performance comparison of the LSTM, attention network, and pointer network decoders on a very simple scheduling problem.
Test case:
Randomly generated sequences of 8 distinct single-digit integers (10k data points):
x0 | x1 | x2 | x3 | x4 | x5 | x6 | x7 |
---|---|---|---|---|---|---|---|
9 | 7 | 4 | 1 | 6 | 2 | 5 | 3 |
9 | 2 | 4 | 1 | 3 | 8 | 7 | 0 |
5 | 9 | 1 | 0 | 8 | 3 | 2 | 4 |
9 | 5 | 3 | 1 | 2 | 7 | 0 | 6 |
7 | 3 | 1 | 9 | 0 | 8 | 2 | 4 |
6 | 4 | 8 | 7 | 1 | 0 | 2 | 3 |
4 | 8 | 1 | 2 | 9 | 5 | 6 | 3 |
6 | 5 | 7 | 8 | 0 | 3 | 4 | 1 |
8 | 9 | 7 | 0 | 1 | 3 | 2 | 5 |
3 | 9 | 6 | 2 | 0 | 7 | 4 | 1 |
Target output is the original sequence in descending order.
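A sketch of one way such a dataset could be generated (my assumption, not necessarily the project's exact script), with targets encoded as one-hot pointer matrices:

```python
import numpy as np

def make_dataset(num_samples=10000, n_steps=8, seed=0):
    """Randomly pick n_steps distinct digits per sample; the target is the
    permutation that sorts them in descending order."""
    rng = np.random.RandomState(seed)
    x = np.array([rng.choice(10, size=n_steps, replace=False)
                  for _ in range(num_samples)])        # (num_samples, n_steps)
    order = np.argsort(-x, axis=1)                     # pointer targets (descending sort)
    y = np.eye(n_steps)[order]                         # one-hot, (num_samples, n_steps, n_steps)
    return x[..., np.newaxis].astype('float32'), y.astype('float32')
```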
Create the model:

```python
from keras.models import Model
from keras.layers import Input, Masking, LSTM, Bidirectional, Dropout

# n_steps, hidden_size and dropoutRate are set elsewhere;
# AttentionPointer is the custom layer described above.
main_input = Input(shape=(n_steps, 1), name='main_input')
masked = Masking(mask_value=-1)(main_input)
enc = Bidirectional(LSTM(hidden_size, return_sequences=True),
                    merge_mode='concat')(masked)
dropout = Dropout(rate=dropoutRate)(enc)
dec = AttentionPointer(hidden_size, return_probabilities=True)(dropout)

model = Model(inputs=main_input, outputs=dec)
model.summary()
```
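The training call is not shown in the excerpt; one plausible setup, assuming the targets are the one-hot pointer matrices described above and hypothetical hyperparameters:

```python
# hypothetical training setup; x_train has shape (num_samples, n_steps, 1) and
# y_train is assumed to be one-hot pointer matrices of shape
# (num_samples, n_steps, n_steps), one row per output position.
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
model.fit(x_train, y_train,
          batch_size=64,
          epochs=100,
          validation_split=0.1)
```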
Test result:
The encoder is a bidirectional LSTM in all three models. The decoders are an LSTM, an attention network, and a pointer network, respectively.
Prediction accuracy is the percentage of correctly sequenced elements.
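Under the same assumptions as above, this metric can be computed from the predicted probability matrices as follows (a sketch; the predicted pointer is the argmax of each row):

```python
import numpy as np

def sequencing_accuracy(y_true, y_pred):
    """Fraction of output positions whose predicted pointer matches the target.
    Both arguments have shape (num_samples, n_steps, n_steps)."""
    true_idx = np.argmax(y_true, axis=-1)
    pred_idx = np.argmax(y_pred, axis=-1)
    return np.mean(true_idx == pred_idx)
```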
(Plots: prediction accuracy and loss for the three decoders.)
Sample output from the pointer network:
y0 | y1 | y2 | y3 | y4 | y5 | y6 | y7 |
---|---|---|---|---|---|---|---|
8 | 7 | 6 | 5 | 4 | 3 | 2 | 1 |
8 | 7 | 5 | 4 | 3 | 2 | 1 | 0 |
9 | 8 | 6 | 5 | 3 | 2 | 1 | 0 |
9 | 8 | 7 | 6 | 4 | 3 | 2 | 0 |
9 | 8 | 6 | 5 | 4 | 3 | 2 | 1 |
8 | 7 | 6 | 5 | 4 | 3 | 1 | 0 |
9 | 8 | 7 | 6 | 4 | 3 | 2 | 0 |
9 | 8 | 7 | 6 | 3 | 2 | 1 | 0 |
9 | 8 | 7 | 4 | 3 | 2 | 1 | 0 |
9 | 8 | 7 | 4 | 3 | 2 | 1 | 0 |
The pointer network clearly achieves higher accuracy and faster convergence than the LSTM and attention-network decoders.