"""
https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/ops/rnn_cell_impl.py
https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/contrib/rnn/python/ops/rnn_cell.py
BasicLSTMCell (and other RNN based cell) only for input with (batch, time).
Dynamic RNN cell can be handle input with (batch, time, input size) and dynamic sequence length.
"""
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import nn_impl
_EPSILON = 10**-4


class CustomLSTMCell(rnn_cell_impl.RNNCell):
    """Customized LSTM cell with several additional regularization techniques.

    Adapted from `BasicLSTMCell` of TensorFlow.
    The implementation is based on: http://arxiv.org/abs/1409.2329.
    We add forget_bias (default: 1) to the biases of the forget gate in order to
    reduce the scale of forgetting at the beginning of training.
    It does not allow cell clipping or a projection layer, and does not
    use peep-hole connections: it is the basic baseline.
    For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell}.
    Additional regularization supported by this cell:
    - layer normalization
    - recurrent dropout
    - variational dropout (per-sample masking version)
    """

    def __init__(self,
                 num_units,
                 forget_bias=1.0,
                 activation=None,
                 reuse=None,
                 layer_norm: bool = False,
                 norm_shift: float = 0.0,
                 norm_gain: float = 1.0,  # layer normalization
                 dropout_keep_prob_in: float = 1.0,
                 dropout_keep_prob_h: float = 1.0,
                 dropout_keep_prob_out: float = 1.0,
                 dropout_keep_prob_gate: float = 1.0,
                 dropout_keep_prob_forget: float = 1.0,
                 dropout_prob_seed: int = None,
                 variational_dropout: bool = False,
                 recurrent_dropout: bool = False
                 ):
"""Initialize the basic LSTM cell.
Args:
num_units: int, The number of units in the LSTM cell.
forget_bias: float, The bias added to forget gates (see above).
Must set to `0.0` manually when restoring from CudnnLSTM-trained
checkpoints.
activation: Activation function of the inner states. Default: `tanh`.
reuse: (optional) Python boolean describing whether to reuse variables
in an existing scope. If not `True`, and the existing scope already has
the given variables, an error is raised.
layer_norm: (optional) If True, apply layer normalization.
norm_shift: (optional) Shift parameter for layer normalization.
norm_gain: (optional) Gain parameter for layer normalization.
dropout_prob_seed: (optional)
recurrent_dropout: (optional)
dropout_keep_prob_in: (optional) keep probability of variational dropout for input
dropout_keep_prob_out: (optional) keep probability of variational dropout for output
dropout_keep_prob_gate: (optional) keep probability of variational dropout for gating cell
dropout_keep_prob_forget: (optional) keep probability of variational dropout for forget cell
dropout_keep_prob_h: (optional) keep probability of recurrent dropout for gated state
"""
        super(CustomLSTMCell, self).__init__(_reuse=reuse)
        self._num_units = num_units
        self._forget_bias = forget_bias
        self._activation = activation or math_ops.tanh
        self._layer_norm = layer_norm
        self._g = norm_gain
        self._b = norm_shift
        self._recurrent_dropout = recurrent_dropout
        self._variational_dropout = variational_dropout
        self._seed = dropout_prob_seed
        self._keep_prob_i = dropout_keep_prob_in
        self._keep_prob_g = dropout_keep_prob_gate
        self._keep_prob_f = dropout_keep_prob_forget
        self._keep_prob_o = dropout_keep_prob_out
        self._keep_prob_h = dropout_keep_prob_h

    @property
    def state_size(self):
        return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)

    @property
    def output_size(self):
        return self._num_units

    def _layer_normalization(self, inputs, scope=None):
        """
        :param inputs: (batch, shape)
        :param scope: (optional) variable scope name
        :return: layer-normalized inputs (batch, shape)
        """
        shape = inputs.get_shape()[-1:]
        with vs.variable_scope(scope or "layer_norm"):
            # Create the learnable gain and shift (gamma and beta) variables.
            g = vs.get_variable("gain", shape=shape, initializer=init_ops.constant_initializer(self._g))  # (shape,)
            s = vs.get_variable("shift", shape=shape, initializer=init_ops.constant_initializer(self._b))  # (shape,)
            m, v = nn_impl.moments(inputs, [1], keep_dims=True)  # mean and variance over the feature axis: (batch, 1)
            normalized_input = (inputs - m) / math_ops.sqrt(v + _EPSILON)  # (batch, shape)
            return normalized_input * g + s
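
    # For reference, the method above computes, per example x over its feature axis
    # (Ba et al., "Layer Normalization", 2016):
    #   LN(x) = gain * (x - mean(x)) / sqrt(var(x) + _EPSILON) + shift
    # where `gain` is initialized to norm_gain and `shift` to norm_shift.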

    @staticmethod
    def _linear(x, weight_shape, bias=True, scope=None):
        """ linear projection (weight_shape: input size, output size) """
        with vs.variable_scope(scope or "linear"):
            w = vs.get_variable("kernel", shape=weight_shape)
            x = math_ops.matmul(x, w)
            if bias:
                b = vs.get_variable("bias", initializer=[0.0] * weight_shape[-1])
                return nn_ops.bias_add(x, b)
            else:
                return x

    def call(self, inputs, state):
        """Long short-term memory cell (LSTM).

        Args:
            inputs: `2-D` tensor with shape `[batch_size x input_size]`.
            state: An `LSTMStateTuple` of state tensors, each shaped
                `[batch_size x self.state_size]`, if `state_is_tuple` has been set to
                `True`. Otherwise, a `Tensor` shaped
                `[batch_size x 2 * self.state_size]`.

        Returns:
            A pair containing the new hidden state and the new state (either an
            `LSTMStateTuple` or a concatenated state, depending on `state_is_tuple`).

        A PEP 8 inspection warning appears because this signature is not the same as that of
        `call` in tensorflow/python/layers/base:
        https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/layers/base.py
        """
        c, h = state  # memory cell, hidden unit
        # Compute all four gate blocks with a single linear projection of [inputs, h].
        args = array_ops.concat([inputs, h], 1)
        concat = self._linear(args, [args.get_shape()[-1], 4 * self._num_units])
        i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
        if self._layer_norm:
            i = self._layer_normalization(i, "layer_norm_i")
            j = self._layer_normalization(j, "layer_norm_j")
            f = self._layer_normalization(f, "layer_norm_f")
            o = self._layer_normalization(o, "layer_norm_o")
        g = self._activation(j)  # gating (cell candidate)
        # variational dropout
        if self._variational_dropout:
            i = nn_ops.dropout(i, self._keep_prob_i, seed=self._seed)
            g = nn_ops.dropout(g, self._keep_prob_g, seed=self._seed)
            f = nn_ops.dropout(f, self._keep_prob_f, seed=self._seed)
            o = nn_ops.dropout(o, self._keep_prob_o, seed=self._seed)
        gated_in = math_ops.sigmoid(i) * g
        memory = c * math_ops.sigmoid(f + self._forget_bias)
        # recurrent dropout
        if self._recurrent_dropout:
            gated_in = nn_ops.dropout(gated_in, self._keep_prob_h, seed=self._seed)
        # Layer normalization of the memory cell is disabled
        # (the original paper did not apply it to the memory cell):
        # if self._layer_norm:
        #     new_c = self._layer_normalization(new_c, "state")
        new_c = memory + gated_in
        new_h = self._activation(new_c) * math_ops.sigmoid(o)
        new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
        return new_h, new_state
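

# A minimal usage sketch (not part of the original cell): wrapping `CustomLSTMCell` in
# `tf.nn.dynamic_rnn` so it can consume a (batch, time, input size) tensor with dynamic
# sequence lengths, as noted in the module docstring. The shapes, keep probabilities and
# placeholder names below are illustrative assumptions, not values from the original code.
if __name__ == "__main__":
    import tensorflow as tf

    batch_size, max_time, input_size, num_units = 32, 20, 100, 128
    x = tf.placeholder(tf.float32, [batch_size, max_time, input_size])
    seq_len = tf.placeholder(tf.int32, [batch_size])  # per-example sequence lengths

    cell = CustomLSTMCell(num_units,
                          layer_norm=True,
                          recurrent_dropout=True,
                          dropout_keep_prob_h=0.75)
    # `dynamic_rnn` calls `cell.call` once per time step and masks outputs beyond `sequence_length`.
    outputs, final_state = tf.nn.dynamic_rnn(cell, x, sequence_length=seq_len, dtype=tf.float32)
    # outputs: (batch, time, num_units); final_state: LSTMStateTuple(c, h), each (batch, num_units)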