-
Notifications
You must be signed in to change notification settings - Fork 7
/
generator.py
59 lines (55 loc) · 2.56 KB
/
generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import numpy as np
import tensorflow as tf
from util import log
from ops import fc
from ops import conv2d_res
class Generator(object):
def __init__(self, name, h, w, c, norm_type, deconv_type,
num_res_block, is_train):
self.name = name
self._h = h
self._w = w
self._c = c
self._norm_type = norm_type
self._deconv_type = deconv_type
self._num_res_block = num_res_block
self._is_train = is_train
self._reuse = False
self.start_dim_x = 4 if w > 32 else 1
self.start_dim_y = 4 if h > 32 else 1
self.start_dim_ch = 256
def __call__(self, input):
if self._deconv_type == 'bilinear':
from ops import bilinear_deconv2d as deconv2d
elif self._deconv_type == 'nn':
from ops import nn_deconv2d as deconv2d
elif self._deconv_type == 'transpose':
from ops import deconv2d
else:
raise NotImplementedError
with tf.variable_scope(self.name, reuse=self._reuse):
if not self._reuse:
log.warn(self.name)
_ = fc(input, self.start_dim_x * self.start_dim_y * self.start_dim_ch,
self._is_train, info=not self._reuse, norm='none', name='fc')
_ = tf.reshape(_, [_.shape.as_list()[0], self.start_dim_y,
self.start_dim_x, self.start_dim_ch])
if not self._reuse:
log.info('reshape {} '.format(_.shape.as_list()))
num_deconv_layer = int(np.ceil(np.log2(
max(float(self._h/self.start_dim_y), float(self._w/self.start_dim_x)))))
for i in range(num_deconv_layer):
_ = deconv2d(_, max(self._c, int(_.get_shape().as_list()[-1]/2)),
self._is_train, info=not self._reuse, norm=self._norm_type,
name='deconv{}'.format(i+1))
if num_deconv_layer - i <= self._num_res_block:
_ = conv2d_res(
_, self._is_train, info=not self._reuse,
name='res_block{}'.format(self._num_res_block - num_deconv_layer + i + 1))
_ = deconv2d(_, self._c, self._is_train, k=1, s=1, info=not self._reuse,
activation_fn=tf.tanh, norm='none',
name='deconv{}'.format(i+2))
_ = tf.image.resize_bilinear(_, [self._h, self._w])
self._reuse = True
self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.name)
return _