-
Notifications
You must be signed in to change notification settings - Fork 62
/
convlstm.py
146 lines (124 loc) · 5.41 KB
/
convlstm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import tensorflow as tf
class ConvLSTMCell(tf.contrib.rnn.RNNCell):
    """A LSTM cell with convolutions instead of multiplications.

    Reference:
      Xingjian, S. H. I., et al. "Convolutional LSTM network: A machine
      learning approach for precipitation nowcasting." Advances in Neural
      Information Processing Systems. 2015.

    Args:
      shape: Shape of the per-step input/state excluding the batch and channel
        dimensions (e.g. spatial extent). Any sequence of ints is accepted
        (list, tuple or ``tf.TensorShape``).
      filters: Number of channels of the cell's memory and output.
      kernel: Convolution kernel size, one entry per dimension in ``shape``.
        Any sequence of ints is accepted (list or tuple).
      initializer: Initializer for the convolution kernel. ``None`` leaves the
        choice to ``tf.get_variable``.
      forget_bias: Constant added to the forget-gate pre-activation before
        the sigmoid.
      activation: Activation applied to the input contribution and to the
        memory when producing the output.
      normalize: If True, layer-normalize each gate and the new memory and add
        no convolution bias. If False, add a learned bias (initialized to 0.0)
        to the convolution result instead.
    """

    def __init__(self,
                 shape,
                 filters,
                 kernel,
                 initializer=None,
                 forget_bias=1.0,
                 activation=tf.tanh,
                 normalize=True):
        # Initialize the RNNCell/Layer base class so its bookkeeping
        # attributes exist on the instance.
        super(ConvLSTMCell, self).__init__()
        self._kernel = kernel
        self._filters = filters
        self._initializer = initializer
        self._forget_bias = forget_bias
        self._activation = activation
        # list() makes tuples (and TensorShapes) work as well as lists.
        self._size = tf.TensorShape(list(shape) + [self._filters])
        self._normalize = normalize
        # Channels-last: with one leading batch axis, the last axis of a
        # tensor of rank len(shape) + 2 has index len(shape) + 1 == ndims.
        self._feature_axis = self._size.ndims

    @property
    def state_size(self):
        """Memory and output both have shape ``shape + [filters]``."""
        return tf.contrib.rnn.LSTMStateTuple(self._size, self._size)

    @property
    def output_size(self):
        """Per-step output shape: ``shape + [filters]``."""
        return self._size

    def __call__(self, x, h, scope=None):
        """Run one ConvLSTM step.

        Args:
          x: Input tensor for this step. Its last axis is read as the channel
            axis and must be statically known.
          h: ``LSTMStateTuple(memory, output)`` from the previous step.
          scope: Optional variable scope; defaults to the class name.

        Returns:
          A pair ``(output, LSTMStateTuple(memory, output))``.
        """
        with tf.variable_scope(scope or self.__class__.__name__):
            previous_memory, previous_output = h
            channels = x.shape[-1].value
            filters = self._filters
            # One block of `filters` channels for each of the four gates.
            gates = 4 * filters
            # Convolve the input and the previous output jointly.
            x = tf.concat([x, previous_output], axis=self._feature_axis)
            n = channels + filters
            m = gates
            W = tf.get_variable(
                'kernel', list(self._kernel) + [n, m],
                initializer=self._initializer)
            y = tf.nn.convolution(x, W, 'SAME')
            if not self._normalize:
                y += tf.get_variable(
                    'bias', [m], initializer=tf.constant_initializer(0.0))
            input_contribution, input_gate, forget_gate, output_gate = tf.split(
                y, 4, axis=self._feature_axis)
            if self._normalize:
                # Keep this call order: layer_norm auto-names its variable
                # scopes (LayerNorm, LayerNorm_1, ...), so reordering would
                # break existing checkpoints.
                input_contribution = tf.contrib.layers.layer_norm(
                    input_contribution)
                input_gate = tf.contrib.layers.layer_norm(input_gate)
                forget_gate = tf.contrib.layers.layer_norm(forget_gate)
                output_gate = tf.contrib.layers.layer_norm(output_gate)
            memory = (
                previous_memory * tf.sigmoid(forget_gate + self._forget_bias) +
                tf.sigmoid(input_gate) * self._activation(input_contribution))
            if self._normalize:
                memory = tf.contrib.layers.layer_norm(memory)
            output = self._activation(memory) * tf.sigmoid(output_gate)
            return output, tf.contrib.rnn.LSTMStateTuple(memory, output)
class ConvGRUCell(tf.contrib.rnn.RNNCell):
    """A GRU cell with convolutions instead of multiplications.

    Args:
      shape: Shape of the per-step input/state excluding the batch and channel
        dimensions (e.g. spatial extent). Any sequence of ints is accepted
        (list, tuple or ``tf.TensorShape``).
      filters: Number of channels of the cell's state/output.
      kernel: Convolution kernel size, one entry per dimension in ``shape``.
        Any sequence of ints is accepted (list or tuple).
      initializer: Initializer for both convolution kernels. ``None`` leaves
        the choice to ``tf.get_variable``.
      activation: Activation applied to the candidate state.
      normalize: If True, layer-normalize the gate and candidate
        pre-activations and add no biases. If False, add learned biases (gate
        bias initialized to 1.0, candidate bias to 0.0).
    """

    def __init__(self,
                 shape,
                 filters,
                 kernel,
                 initializer=None,
                 activation=tf.tanh,
                 normalize=True):
        # Initialize the RNNCell/Layer base class so its bookkeeping
        # attributes exist on the instance.
        super(ConvGRUCell, self).__init__()
        self._filters = filters
        self._kernel = kernel
        self._initializer = initializer
        self._activation = activation
        # list() makes tuples (and TensorShapes) work as well as lists.
        self._size = tf.TensorShape(list(shape) + [self._filters])
        self._normalize = normalize
        # Channels-last: with one leading batch axis, the last axis of a
        # tensor of rank len(shape) + 2 has index len(shape) + 1 == ndims.
        self._feature_axis = self._size.ndims

    @property
    def state_size(self):
        """State shape: ``shape + [filters]``."""
        return self._size

    @property
    def output_size(self):
        """Per-step output shape: ``shape + [filters]``."""
        return self._size

    def __call__(self, x, h, scope=None):
        """Run one ConvGRU step.

        Args:
          x: Input tensor for this step. Its last axis is read as the channel
            axis and must be statically known.
          h: Previous state tensor.
          scope: Optional variable scope; defaults to the class name.

        Returns:
          A pair ``(output, output)``; the new state is also the output.
        """
        with tf.variable_scope(scope or self.__class__.__name__):
            with tf.variable_scope('Gates'):
                channels = x.shape[-1].value
                inputs = tf.concat([x, h], axis=self._feature_axis)
                n = channels + self._filters
                # One block of `filters` channels for each of the two gates.
                m = 2 * self._filters
                W = tf.get_variable(
                    'kernel',
                    list(self._kernel) + [n, m],
                    initializer=self._initializer)
                y = tf.nn.convolution(inputs, W, 'SAME')
                if self._normalize:
                    reset_gate, update_gate = tf.split(
                        y, 2, axis=self._feature_axis)
                    # Order matters for the auto-generated LayerNorm scopes.
                    reset_gate = tf.contrib.layers.layer_norm(reset_gate)
                    update_gate = tf.contrib.layers.layer_norm(update_gate)
                else:
                    y += tf.get_variable(
                        'bias', [m], initializer=tf.constant_initializer(1.0))
                    reset_gate, update_gate = tf.split(
                        y, 2, axis=self._feature_axis)
                reset_gate, update_gate = tf.sigmoid(reset_gate), tf.sigmoid(
                    update_gate)
            with tf.variable_scope('Output'):
                # Candidate state from the input and the reset-gated state.
                inputs = tf.concat(
                    [x, reset_gate * h], axis=self._feature_axis)
                n = channels + self._filters
                m = self._filters
                W = tf.get_variable(
                    'kernel',
                    list(self._kernel) + [n, m],
                    initializer=self._initializer)
                y = tf.nn.convolution(inputs, W, 'SAME')
                if self._normalize:
                    y = tf.contrib.layers.layer_norm(y)
                else:
                    y += tf.get_variable(
                        'bias', [m], initializer=tf.constant_initializer(0.0))
                y = self._activation(y)
                # The update gate interpolates between old state and candidate.
                output = update_gate * h + (1 - update_gate) * y
            return output, output