-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmodel.py
136 lines (119 loc) · 4.7 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import tensorflow as tf
from instance import InstanceNormalization
from tensorflow import keras
class Encoder(keras.Model):
    """Four-stage strided-convolution encoder.

    Each stage applies Conv2D (stride 2, "same" padding) -> InstanceNormalization
    -> leaky ReLU -> dropout.  ``call`` returns the output of every stage so a
    matching decoder can consume them as skip connections.
    """

    def __init__(self, args):
        """
        :param Arg args: config object; reads ``conv_filter``, ``kernel_size``,
            ``leaky_alpha`` and ``dropout_rate``.
        """
        super(Encoder, self).__init__()
        self.args = args
        # Filters are read from conv_filter in reverse (index 3 down to 0) so
        # channel count grows as spatial resolution shrinks.
        for i in range(1, 5):
            setattr(self, "conv" + str(i),
                    tf.compat.v1.layers.Conv2D(self.args.conv_filter[4 - i],
                                               self.args.kernel_size, 2, "same"))
            setattr(self, "norm" + str(i), InstanceNormalization())

    def call(self, inputs, training=None, mask=None):
        """Run all four stages and return the list of per-stage outputs.

        :param inputs: input image tensor.
        :param training: Keras training flag; controls dropout.
        :param mask: unused.
        :return: list of four feature maps, shallowest first.
        """
        x = inputs
        outputs = []
        for i in range(1, 5):
            x = getattr(self, "conv" + str(i))(x)
            x = getattr(self, "norm" + str(i))(x)
            x = tf.nn.leaky_relu(x, self.args.leaky_alpha)
            # BUG FIX: tf.compat.v1.layers.dropout defaults to training=False,
            # so the original call never actually applied dropout.  Forward the
            # Keras training flag so dropout is active during training only.
            x = tf.compat.v1.layers.dropout(
                x, self.args.dropout_rate,
                training=training if training is not None else False)
            outputs.append(x)
        return outputs
class Decoder(keras.Model):
    """Four-stage transposed-convolution decoder with optional skip inputs."""

    def __init__(self, args):
        """
        :param Arg args: config object; reads ``conv_filter``, ``kernel_size``
            and ``leaky_alpha``.
        """
        super(Decoder, self).__init__()
        self.args = args
        for stage in range(1, 5):
            upconv = tf.compat.v1.layers.Conv2DTranspose(
                self.args.conv_filter[stage], self.args.kernel_size,
                (2, 2), "same")
            setattr(self, "conv" + str(stage), upconv)
            setattr(self, "norm" + str(stage), InstanceNormalization())

    def call(self, inputs, training=None, mask=None):
        """Upsample through four stages, adding skip tensors where present.

        :param inputs: pair ``[x, skips]`` — ``skips`` is a list of four
            tensors; ``None`` entries disable the addition for that stage.
        :return: decoded feature map.
        """
        x, skips = inputs
        for stage in range(1, 5):
            skip = skips[stage - 1]
            if skip is not None:
                x = tf.add(x, skip)
            x = getattr(self, "conv" + str(stage))(x)
            x = getattr(self, "norm" + str(stage))(x)
            x = tf.nn.leaky_relu(x, self.args.leaky_alpha)
        return x
class Discriminator(keras.Model):
    """Discriminator head built on a shared Encoder.

    Produces a real/fake probability and a condition-vector prediction from
    the deepest encoder feature map.
    """

    def __init__(self, args, encoder):
        """
        :param Arg args: config object; reads ``cond_dim``.
        :param encoder: shared Encoder instance used as feature extractor.
        """
        super(Discriminator, self).__init__()
        self.args = args
        self.encoder = encoder
        self.dense_pr = tf.compat.v1.layers.Dense(1, "sigmoid")
        self.dense_cond = tf.compat.v1.layers.Dense(self.args.cond_dim, "sigmoid")

    @tf.contrib.eager.defun
    def call(self, inputs, training=None, mask=None):
        # TODO: try to discriminate the feature map
        features = self.encoder(inputs)
        deepest = features.pop()
        flat = tf.compat.v1.layers.flatten(deepest)
        output_pr = self.dense_pr(flat)
        output_cond = self.dense_cond(flat)
        return output_pr, output_cond
class Generator(keras.Model):
    """Maps ``[noise, condition]`` to an image via a dense stem and a shared Decoder."""

    def __init__(self, args, decoder):
        """
        :param Arg args: config object; reads ``init_dim``, ``conv_filter``,
            ``image_channel``, ``kernel_size`` and ``leaky_alpha``.
        :param decoder: shared Decoder instance used for upsampling.
        """
        super(Generator, self).__init__()
        self.args = args
        self.dense = tf.compat.v1.layers.Dense(
            self.args.init_dim ** 2 * self.args.conv_filter[0])
        self.norm = InstanceNormalization()
        self.decoder = decoder
        self.conv = tf.compat.v1.layers.Conv2DTranspose(
            self.args.image_channel, self.args.kernel_size,
            strides=(1, 1), padding="same", activation="tanh")

    @tf.contrib.eager.defun
    def call(self, inputs, training=None, mask=None):
        """
        Generator forward pass.
        :param inputs: [noise, real_cond]
        :param training: unused here.
        :param mask: unused.
        :return: generated image tensor.
        """
        stem = tf.concat(inputs, -1)
        stem = self.dense(stem)
        stem = tf.nn.leaky_relu(stem, self.args.leaky_alpha)
        grid = tf.reshape(
            stem, [-1, self.args.init_dim, self.args.init_dim, self.args.conv_filter[0]])
        grid = self.norm(grid)
        # No encoder features exist when generating from noise, so all four
        # skip slots are None.
        decoded = self.decoder([grid, [None] * 4])
        output_image = self.conv(decoded)
        return output_image
class Adjuster(keras.Model):
    """Re-encodes an image, injects a new condition, and decodes an adjusted image.

    Reuses the discriminator's encoder and the generator's decoder and output
    convolution, so no new convolutional weights are introduced.
    """

    def __init__(self, args, discriminator, generator):
        """
        :param Arg args: config object; reads ``init_dim``, ``conv_filter``
            and ``leaky_alpha``.
        :param Discriminator discriminator: supplies the shared encoder.
        :param Generator generator: supplies the shared decoder and output conv.
        """
        super(Adjuster, self).__init__()
        self.args = args
        self.encoder = discriminator.encoder
        self.dense = tf.compat.v1.layers.Dense(
            self.args.init_dim ** 2 * self.args.conv_filter[0])
        self.norm = InstanceNormalization()
        self.decoder = generator.decoder
        self.conv = generator.conv

    @tf.contrib.eager.defun
    def call(self, inputs, training=None, mask=None):
        """Adjust ``image`` toward ``cond``; ``inputs`` is ``[image, cond]``."""
        image, cond = inputs
        skips = self.encoder(image)
        embedded = self.dense(cond)
        embedded = tf.nn.leaky_relu(embedded, alpha=self.args.leaky_alpha)
        embedded = self.norm(embedded)
        embedded = tf.reshape(
            embedded, [-1, self.args.init_dim, self.args.init_dim, self.args.conv_filter[0]])
        # Deepest encoder feature must come first to line up with the
        # decoder's stage order.
        skips.reverse()
        decoded = self.decoder([embedded, skips])
        output_adj = self.conv(decoded)
        return output_adj