-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathops.py
359 lines (297 loc) · 12.4 KB
/
ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
import numpy as np
import scipy.linalg
import tensorflow as tf
try:
import horovod.tensorflow as hvd
except ImportError:
pass
from tensorflow.contrib.framework.python.ops import add_arg_scope, arg_scope
def shape(inputs, name=None):
    """Return the shape of ``inputs`` as a list mixing static and dynamic dims.

    Each statically known dimension is returned as a Python int; every unknown
    (``None``) dimension is replaced by the matching entry of ``tf.shape(inputs)``.
    """
    with tf.name_scope(name, "shape"):
        static_dims = inputs.get_shape().as_list()
        dynamic_dims = tf.shape(inputs)
        return [dynamic_dims[axis] if dim is None else dim
                for axis, dim in enumerate(static_dims)]
def allreduce_sum(x):
    """Sum ``x`` across all Horovod workers.

    With a single worker the input is returned unchanged. Uses Horovod's private
    ``mpi_ops._allreduce``. Requires ``hvd`` to be importable: the module-level
    import swallows ImportError, so a missing Horovod surfaces here as NameError.
    """
    single_worker = hvd.size() == 1
    return x if single_worker else hvd.mpi_ops._allreduce(x)
def allreduce_mean(x):
    """Average ``x`` across all Horovod workers (the all-reduced sum divided by the worker count)."""
    total = allreduce_sum(x)
    return total / hvd.size()
def squeeze2d(inputs, factor=2):
    """Space-to-depth: fold each ``factor x factor`` spatial block into channels.

    Input is read as NHWC with static H, W, C. Output has shape
    ``[N, H/factor, W/factor, C*factor*factor]``. The new channel index is
    ``c*factor*factor + dy*factor + dx``. ``factor == 1`` returns the input as is.
    """
    assert factor >= 1
    if factor == 1:
        return inputs
    dims = inputs.get_shape()
    height, width, channels = (int(dims[i]) for i in (1, 2, 3))
    assert height % factor == 0 and width % factor == 0
    out_h, out_w = height // factor, width // factor
    # Split each spatial axis into (coarse, fine) and move both fine axes after channels.
    blocks = tf.reshape(inputs, [-1, out_h, factor, out_w, factor, channels])
    blocks = tf.transpose(blocks, [0, 1, 3, 5, 2, 4])
    return tf.reshape(blocks, [-1, out_h, out_w, channels * factor * factor])
def unsqueeze2d(x, factor=2):
    """Depth-to-space: the inverse of ``squeeze2d``.

    Input is read as NHWC with static H, W, C. C must be a positive multiple of
    ``factor**2``. Output has shape ``[N, H*factor, W*factor, C/factor**2]``.
    Channels are unpacked as ``(c, dy, dx)``, matching ``squeeze2d``'s packing.
    ``factor == 1`` returns the input as is.
    """
    assert factor >= 1
    if factor == 1:
        return x
    shape = x.get_shape()
    height, width, channels = int(shape[1]), int(shape[2]), int(shape[3])
    block = factor * factor
    # The check must use factor**2, not a hard-coded 4 (only correct for factor == 2).
    assert channels >= block and channels % block == 0
    out_channels = channels // block
    x = tf.reshape(x, (-1, height, width, out_channels, factor, factor))
    x = tf.transpose(x, [0, 1, 4, 2, 5, 3])
    x = tf.reshape(x, (-1, height * factor, width * factor, out_channels))
    return x
def default_initializer(std=0.05):
    """Return a zero-mean normal initializer with standard deviation ``std``."""
    return tf.random_normal_initializer(mean=0., stddev=std)
@add_arg_scope
def linear(name, x, width, do_weightnorm=True, do_actnorm=True, initializer=None, use_bias=True, scale=1.):
    """Fully connected layer ``x @ W (+ b)``, optionally weight-normalized and actnorm'd.

    Args:
        name: variable scope for the layer's variables.
        x: rank-2 input; the fan-in is read from axis 1.
        width: number of output units.
        do_weightnorm: L2-normalize each column of ``W`` (over axis 0) before use.
        do_actnorm: pass the result through ``actnorm`` (scope "actnorm").
        initializer: initializer for ``W``; falls back to ``default_initializer()`` when falsy.
        use_bias: add a zero-initialized bias ``b`` of shape ``[1, width]``.
        scale: forwarded to ``actnorm``.
    """
    if not initializer:
        initializer = default_initializer()
    with tf.variable_scope(name):
        fan_in = int(x.get_shape()[1])
        kernel = tf.get_variable("W", [fan_in, width], tf.float32, initializer=initializer)
        if do_weightnorm:
            kernel = tf.nn.l2_normalize(kernel, [0])
        out = tf.matmul(x, kernel)
        if use_bias:
            out = out + tf.get_variable("b", [1, width], initializer=tf.zeros_initializer())
        if do_actnorm:
            out = actnorm(out, scale=scale, name="actnorm")
    return out
@add_arg_scope
def linear_zeros(x, width, logscale_factor=3, name=None):
    """Fully connected layer whose output is exactly zero at initialization.

    Computes ``(x @ W + b) * exp(logs * logscale_factor)``. ``W``, ``b`` and ``logs``
    are all zero-initialized, so the initial output is zero.
    """
    with tf.variable_scope(name, "linear_zeros"):
        fan_in = int(x.get_shape()[1])
        kernel = tf.get_variable("W", [fan_in, width], tf.float32,
                                 initializer=tf.zeros_initializer())
        out = tf.matmul(x, kernel)
        bias = tf.get_variable("b", [1, width], initializer=tf.zeros_initializer())
        out = out + bias
        log_scale = tf.get_variable("logs", [1, width], initializer=tf.zeros_initializer())
        out = out * tf.exp(log_scale * logscale_factor)
    return out
def add_edge_padding(x, filter_size):
    """Zero-pad ``x`` for a VALID conv and append a border-indicator channel.

    Pads H by ``(filter_size[0]-1)//2`` and W by ``(filter_size[1]-1)//2`` on each
    side. Then concatenates one extra channel that is 1 on the padded border and 0
    elsewhere, so a following VALID conv can tell padding from real zeros.

    The indicator is a constant cached in a graph collection keyed on the pad
    sizes and input H, W. It assumes H and W are statically known (it is built
    with ``np.zeros``). A 1x1 filter returns ``x`` unchanged.
    """
    assert filter_size[0] % 2 == 1 and filter_size[1] % 2 == 1
    if filter_size[0] == 1 and filter_size[1] == 1:
        return x
    a = (filter_size[0] - 1) // 2  # vertical padding size
    b = (filter_size[1] - 1) // 2  # horizontal padding size
    in_shape = x.get_shape().as_list()
    x = tf.pad(x, [[0, 0], [a, a], [b, b], [0, 0]])
    name = "_".join([str(dim) for dim in [a, b, *in_shape[1:3]]])
    pads = tf.get_collection(name)
    if not pads:
        padded_h, padded_w = x.get_shape().as_list()[1:3]
        pad = np.zeros([1, padded_h, padded_w, 1], dtype='float32')
        # Explicit end indices: when a or b is 0, the slice `-0:` would mark the whole axis.
        pad[:, :a, :, 0] = 1.
        pad[:, padded_h - a:, :, 0] = 1.
        pad[:, :, :b, 0] = 1.
        pad[:, :, padded_w - b:, 0] = 1.
        pad = tf.convert_to_tensor(pad)
        tf.add_to_collection(name, pad)
    else:
        pad = pads[0]
    pad = tf.tile(pad, [tf.shape(x)[0], 1, 1, 1])
    x = tf.concat([x, pad], axis=3)
    return x
@add_arg_scope
def conv2d(x, width,
           filter_size=[3, 3],
           stride=[1, 1],
           pad="SAME",
           do_weightnorm=False,
           do_actnorm=True,
           context1d=None,
           skip=1,
           edge_bias=True,
           name=None):
    """2-D convolution with optional weight norm, actnorm and a 1-D context offset.

    Args:
        x: NHWC input; the input channel count is read from axis 3.
        width: number of output channels.
        filter_size: kernel ``[height, width]`` (list or tuple).
        stride: ``[vertical, horizontal]`` stride (list or tuple).
        pad: "SAME" or "VALID".
        do_weightnorm: L2-normalize the kernel over axes ``[0, 1, 2]``.
        do_actnorm: follow the conv with ``actnorm``; otherwise add a zero-initialized bias.
        context1d: optional rank-2 tensor mapped through ``linear`` to a per-channel
            offset that is added to every spatial position.
        skip: dilation rate. Values other than 1 use ``tf.nn.atrous_conv2d`` and
            require unit stride.
        edge_bias: with ``pad == "SAME"``, pad via ``add_edge_padding`` (which appends
            a border-indicator channel) and run a VALID conv instead.
        name: variable scope name (defaults to "conv2d").

    Returns:
        The convolved tensor.
    """
    # Fresh lists: tuples then work with `+` below, and the shared defaults are never aliased.
    filter_size = list(filter_size)
    stride = list(stride)
    with tf.variable_scope(name, "conv2d"):
        if edge_bias and pad == "SAME":
            # A kernel dilated by `skip` spans (k - 1) * skip + 1 pixels. Pad for that
            # extent so the VALID conv keeps the spatial size (unchanged when skip == 1).
            dilated_size = [(k - 1) * skip + 1 for k in filter_size]
            x = add_edge_padding(x, dilated_size)
            pad = 'VALID'
        n_in = int(x.get_shape()[3])
        stride_shape = [1] + stride + [1]
        filter_shape = filter_size + [n_in, width]
        w = tf.get_variable("W", filter_shape, tf.float32,
                            initializer=default_initializer())
        if do_weightnorm:
            w = tf.nn.l2_normalize(w, [0, 1, 2])
        if skip == 1:
            x = tf.nn.conv2d(x, w, stride_shape, pad, data_format='NHWC')
        else:
            assert stride[0] == 1 and stride[1] == 1
            x = tf.nn.atrous_conv2d(x, w, skip, pad)
        if do_actnorm:
            x = actnorm(x, name="actnorm")
        else:
            x += tf.get_variable("b", [1, 1, 1, width],
                                 initializer=tf.zeros_initializer())
        if context1d is not None:
            context = tf.reshape(linear("context", context1d,
                                        width), [-1, 1, 1, width])
            x += context
    return x
@add_arg_scope
def conv2d_zeros(x,
                 width,
                 filter_size=[3, 3],
                 stride=[1, 1],
                 pad="SAME",
                 logscale_factor=3,
                 skip=1,
                 edge_bias=True,
                 name=None):
    """2-D convolution with a zero-initialized kernel and a learned log-scale.

    Computes ``(conv(x, W) + b) * exp(logs * logscale_factor)``. ``W`` and ``logs``
    start at zero and ``b`` starts at one.

    Args:
        x: NHWC input; the input channel count is read from axis 3.
        width: number of output channels.
        filter_size: kernel ``[height, width]`` (list or tuple).
        stride: ``[vertical, horizontal]`` stride (list or tuple).
        pad: "SAME" or "VALID".
        logscale_factor: multiplier applied to the ``logs`` variable.
        skip: dilation rate. Values other than 1 use ``tf.nn.atrous_conv2d`` and
            require unit stride.
        edge_bias: with ``pad == "SAME"``, pad via ``add_edge_padding`` and run a
            VALID conv instead.
        name: variable scope name (defaults to "conv2d").
    """
    # Fresh lists: tuples then work with `+` below, and the shared defaults are never aliased.
    filter_size = list(filter_size)
    stride = list(stride)
    with tf.variable_scope(name, "conv2d"):
        if edge_bias and pad == "SAME":
            # Pad for the dilated kernel extent so the VALID conv keeps the spatial size.
            dilated_size = [(k - 1) * skip + 1 for k in filter_size]
            x = add_edge_padding(x, dilated_size)
            pad = 'VALID'
        n_in = int(x.get_shape()[3])
        stride_shape = [1] + stride + [1]
        filter_shape = filter_size + [n_in, width]
        w = tf.get_variable("W", filter_shape, tf.float32,
                            initializer=tf.zeros_initializer())
        if skip == 1:
            x = tf.nn.conv2d(x, w, stride_shape, pad, data_format='NHWC')
        else:
            assert stride[0] == 1 and stride[1] == 1
            x = tf.nn.atrous_conv2d(x, w, skip, pad)
        # NOTE(review): bias starts at ones, so the initial output is 1, not 0, unlike
        # linear_zeros. Confirm this is intentional.
        x += tf.get_variable("b", [1, 1, 1, width],
                             initializer=tf.ones_initializer())
        x *= tf.exp(tf.get_variable("logs",
                                    [1, width], initializer=tf.zeros_initializer()) * logscale_factor)
    return x
@add_arg_scope
def actnorm(x,
            scale=1.,
            logdet=None,
            logscale_factor=3.,
            batch_variance=False,
            reverse=False,
            context1d=None,
            trainable=True,
            name=None):
    """Activation normalization: a learned per-channel affine map.

    Forward computes ``(x + b) * exp(logs)``; reverse computes ``x * exp(-logs) - b``.

    Args:
        x: rank-2 or rank-4 tensor; one parameter per entry of the last axis.
        scale: unused (accepted so callers such as ``linear`` can forward it).
        logdet: optional running log-determinant. When given, it is increased by
            ``sum(logs)`` times H*W (rank 4) or times 1 (rank 2), negated in reverse,
            and the function returns ``(x, logdet)``.
        logscale_factor: multiplier applied to the stored ``logs`` variable.
        batch_variance: unused.
        reverse: apply the inverse map.
        context1d: only read by the disabled ``version == 0`` parameterization.
        trainable: whether the created variables are trainable.
        name: variable scope name (defaults to "actnorm").

    Returns:
        ``x`` when ``logdet`` is None, otherwise ``(x, logdet)``.
    """
    with tf.variable_scope(name, "actnorm"):
        shape = x.get_shape()
        rank = len(shape)
        assert rank == 2 or rank == 4
        _shape = [1 for i in range(rank - 1)] + [shape[-1]]
        # Number of positions that share each per-channel parameter.
        logdet_factor = 1 if rank == 2 else int(shape[1])*int(shape[2])
        # Hard-coded switch between two parameterizations; only version 1 is active.
        version = 1
        if version == 0:
            b = tf.get_variable("b", _shape, initializer=tf.constant_initializer(0.),
                                trainable=trainable)
            logs = tf.get_variable("logs", _shape, trainable=trainable)
            if context1d is not None:
                # NOTE(review): this gives logs a batch axis, and the logdet below then
                # sums over the batch as well. Revisit if this branch is re-enabled.
                logs = logs + tf.reshape(linear("context_s", context1d, _shape[-1], do_actnorm=False, use_bias=False), [-1] + _shape[1:])
                b = b + tf.reshape(linear("context_b", context1d, _shape[-1], do_actnorm=False, use_bias=False), [-1] + _shape[1:])
            b = tf.nn.tanh(b)
            logs = -tf.nn.sigmoid(logs)
        elif version == 1:
            # b uses get_variable's default initializer; logs starts at -logscale_factor.
            b = tf.get_variable("b", _shape, trainable=trainable)
            logs = tf.get_variable("logs", _shape, initializer=tf.constant_initializer(-1.),
                                   trainable=trainable) * logscale_factor
        if not reverse:
            x += b
            x = x * tf.exp(logs)
        else:
            x = x * tf.exp(-logs)
            x -= b
        if logdet is not None:
            dlogdet = tf.reduce_sum(logs) * logdet_factor
            if reverse:
                dlogdet *= -1
            logdet += dlogdet
            return x, logdet
        else:
            return x
@add_arg_scope
def invertible_1x1_conv(z, logdet, reverse=False, name=None, use_bias=False):
    """Invertible 1x1 convolution (a learned CxC channel-mixing matrix W).

    Args:
        z: NHWC input with static H, W, C.
        logdet: running log-determinant (must not be None). It is increased by
            ``H * W * log|det W|`` forward and decreased by the same amount in reverse.
        reverse: apply the inverse map (conv with ``W^-1``).
        name: variable scope name (defaults to "invconv").
        use_bias: add a per-channel "bias" after the forward conv and subtract it
            before the inverse conv.

    Returns:
        ``(z, logdet)``.
    """
    with tf.variable_scope(name, "invconv"):
        dims = z.get_shape().as_list()
        channels = dims[3]
        w_shape = [channels, channels]
        # Start from a random orthogonal matrix: the Q factor of a Gaussian matrix.
        q_factor, _ = np.linalg.qr(np.random.randn(*w_shape))
        w = tf.get_variable("W", dtype=tf.float32, initializer=q_factor.astype('float32'))
        # The determinant is evaluated in float64, then cast back.
        det_w = tf.matrix_determinant(tf.cast(w, 'float64'))
        dlogdet = tf.cast(tf.log(abs(det_w)), 'float32') * dims[1] * dims[2]
        bias = tf.get_variable("bias", [1, 1, 1, channels]) if use_bias else None

        if reverse:
            if use_bias:
                z -= bias
            kernel = tf.matrix_inverse(w)[tf.newaxis, tf.newaxis, ...]
            z = tf.nn.conv2d(z, kernel, [1, 1, 1, 1], 'SAME', data_format='NHWC')
            logdet -= dlogdet
        else:
            kernel = w[tf.newaxis, tf.newaxis, ...]
            z = tf.nn.conv2d(z, kernel, [1, 1, 1, 1], 'SAME', data_format='NHWC')
            logdet += dlogdet
            if use_bias:
                z += bias
    return z, logdet
def flatten_sum(logps):
    """Sum ``logps`` over every axis except the leading one.

    Ranks 2 and 4 reduce over ``[1]`` and ``[1, 2, 3]`` respectively, as before.
    Other ranks >= 2 are handled the same way.

    Raises:
        ValueError: if the rank is below 2.
    """
    rank = len(logps.get_shape())
    if rank < 2:
        raise ValueError("flatten_sum expects a tensor of rank >= 2, got rank %d" % rank)
    return tf.reduce_sum(logps, list(range(1, rank)))
def gaussian_diag(mean, logsd):
    """Diagonal Gaussian with location ``mean`` and log standard deviation ``logsd``.

    Returns a namespace with these attributes:
        mean, logsd: the parameters.
        eps: a standard-normal draw shaped like ``mean``, created once here.
        sample(eps=None): ``mean + exp(logsd) * eps`` (defaults to the stored ``eps``).
        logps(x): elementwise log-density.
        logp(x): ``logps(x)`` summed over non-batch axes via ``flatten_sum``.
        get_eps(x): the noise that ``sample`` would map to ``x``.
    """
    import types  # local import: the module header does not import it

    def sample(eps=None):
        # Read o.eps at call time, so a reassigned eps is picked up.
        epsilon = eps if eps is not None else o.eps
        return mean + tf.exp(logsd) * epsilon

    def logps(x):
        return -0.5 * (np.log(2 * np.pi) + 2. * logsd + tf.square(x - mean) * tf.exp(-2. * logsd))

    o = types.SimpleNamespace(mean=mean, logsd=logsd)
    o.eps = tf.random_normal(tf.shape(mean))
    o.sample = sample
    o.logps = logps
    o.logp = lambda x: flatten_sum(o.logps(x))
    o.get_eps = lambda x: (x - mean) * tf.exp(-logsd)
    return o
def logistic_logpdf(inputs, mean, logs, name="logistic_logpdf"):
    """Elementwise log-density of a logistic distribution.

    With ``z = (inputs - mean) / exp(logs)`` this is ``z - logs - 2*softplus(z)``,
    which equals ``-z - logs - 2*log(1 + exp(-z))``.
    """
    with tf.variable_scope(name):
        standardized = (inputs - mean) * tf.exp(-logs)
        return standardized - logs - 2 * tf.nn.softplus(standardized)
def logistic_logcdf(inputs, mean, logs, name="logistic_logcdf"):
    """Elementwise log-CDF of a logistic distribution: ``log_sigmoid((inputs - mean) / exp(logs))``."""
    with tf.variable_scope(name):
        z = (inputs - mean) * tf.exp(-logs)
        return tf.log_sigmoid(z)


# Backward-compatible alias for the original, misspelled name.
logitic_logcdf = logistic_logcdf
def mixlogistic_logpdf(inputs, prior_logits, means, logs, name="mixlogistic_logpdf"):
    """Elementwise log-density of a mixture of logistics.

    ``inputs`` gets a new trailing axis so it broadcasts against the per-component
    ``means``/``logs``. The component axis (-1) is then reduced with a logsumexp
    weighted by ``softmax(prior_logits)``.
    """
    with tf.variable_scope(name):
        component_logpdf = logistic_logpdf(tf.expand_dims(inputs, axis=-1), means, logs)
        log_weights = tf.nn.log_softmax(prior_logits, axis=-1)
        return tf.reduce_logsumexp(log_weights + component_logpdf, axis=-1)
def mixlogistic_logcdf(inputs, prior_logits, means, logs, name="mixlogistic_logcdf"):
    """Elementwise log-CDF of a mixture of logistics.

    ``inputs`` gets a new trailing axis so it broadcasts against the per-component
    ``means``/``logs``. The component axis (-1) is then reduced with a logsumexp
    weighted by ``softmax(prior_logits)``.

    NOTE: calls ``logistic_logcdf``; the module must define that name.
    """
    with tf.variable_scope(name):
        component_logcdf = logistic_logcdf(tf.expand_dims(inputs, axis=-1), means, logs)
        log_weights = tf.nn.log_softmax(prior_logits, axis=-1)
        return tf.reduce_logsumexp(log_weights + component_logcdf, axis=-1)
def assert_in_range(x, min_value, max_value):
    """Return a ``tf.Assert`` op that fails (printing ``x``) unless all of ``x`` lies in ``[min_value, max_value]``."""
    lower_ok = tf.greater_equal(tf.reduce_min(x), min_value)
    upper_ok = tf.less_equal(tf.reduce_max(x), max_value)
    return tf.Assert(tf.logical_and(lower_ok, upper_ok), [x])
def mixlogistic_invcdf(inputs, prior_logits, means, logs, tol=1e-10, max_iters=500, name="mixlogistic_invcdf"):
    """Invert the mixture-of-logistics CDF elementwise by bisection.

    Args:
        inputs: target CDF values, asserted at run time to lie in [0, 1].
        prior_logits, means, logs: mixture parameters with components on the last axis.
        tol: stop once the largest per-element step of the iterate is <= tol.
        max_iters: hard cap on bisection iterations.
        name: variable scope name.

    Returns:
        ``x`` with ``CDF(x) ~= inputs``. The loop runs with ``back_prop=False``.
    """
    with tf.variable_scope(name):
        with tf.control_dependencies([assert_in_range(inputs, 0., 1.)]):
            y = tf.identity(inputs)

        def body(x, lb, ub, _last_diff):
            cur_y = tf.exp(mixlogistic_logcdf(x, prior_logits, means, logs))
            # Where CDF(x) > y the root lies left of x, so move toward lb; otherwise toward ub.
            gt = tf.cast(tf.greater(cur_y, y), dtype=y.dtype)
            lt = 1. - gt
            new_x = gt * (x + lb) / 2. + lt * (x + ub) / 2.
            new_lb = gt * lb + lt * x
            new_ub = gt * x + lt * ub
            diff = tf.reduce_max(tf.abs(new_x - x))
            return new_x, new_lb, new_ub, diff

        init_x = tf.zeros_like(y)
        # The sum of component scales bounds the largest scale, so this bracket is conservative.
        max_scale = tf.reduce_sum(tf.exp(logs), axis=-1, keepdims=True)
        init_lb = tf.reduce_min(means - 50 * max_scale, axis=-1)
        # Upper bound must take the max over components; min could exclude the root.
        init_ub = tf.reduce_max(means + 50 * max_scale, axis=-1)
        init_diff = tf.constant(np.inf, dtype=y.dtype)
        out_x, _, _, _ = tf.while_loop(cond=lambda _x, _lb, _ub, last_diff: last_diff > tol,
                                       body=body, loop_vars=(init_x, init_lb, init_ub, init_diff),
                                       back_prop=False, maximum_iterations=max_iters)
        return out_x