-
Notifications
You must be signed in to change notification settings - Fork 8
/
rollout.py
296 lines (247 loc) · 13 KB
/
rollout.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
import tensorflow as tf
from tensorflow.python.ops import tensor_array_ops, control_flow_ops
import numpy as np
import yaml
# Load the shared SeqGAN hyper-parameters from SeqGAN.yaml in the working directory.
# yaml.safe_load is used because yaml.load without an explicit Loader raises
# TypeError on PyYAML >= 6 (and can construct arbitrary objects on older versions).
# On a YAML syntax error the message is printed and `config` is left undefined.
with open("SeqGAN.yaml") as stream:
    try:
        config = yaml.safe_load(stream)
    except yaml.YAMLError as exc:
        print(exc)
class ROLLOUT(object):
    """Policy roll-out network used to compute rewards for the policy-gradient update.

    The roll-out mirrors the generator's LSTM structure. Given a batch of
    sequences and a prefix length ``given_num``, the graph built in
    ``__init__`` keeps the first ``given_num`` tokens and samples the rest
    (Monte Carlo search). ``get_reward`` scores those completions with a
    discriminator. ``update_params`` blends the roll-out parameters toward the
    generator's.

    NOTE(review): all parameters are ``tf.identity`` copies of the generator's
    tensors, and the sampling graph (``self.gen_x``) is built once in
    ``__init__``. ``update_params`` rebinds the Python attributes to new
    blended tensors, but ``self.gen_x`` keeps the tensors captured at
    construction time. If the generator's parameters are ``tf.Variable``
    objects (presumably), the roll-out therefore samples with the generator's
    *current* weights, and ``update_rate`` has no effect on ``self.gen_x``.
    Each ``update_params`` call also adds new, unused ops to the graph.
    """

    # Generator attribute names copied for the LSTM step (same order as before,
    # so TF op creation order and names are unchanged).
    _RECURRENT_PARAM_NAMES = ('Wi', 'Ui', 'bi',
                              'Wf', 'Uf', 'bf',
                              'Wog', 'Uog', 'bog',
                              'Wc', 'Uc', 'bc')
    # Generator attribute names copied for the output (logit) layer.
    _OUTPUT_PARAM_NAMES = ('Wo', 'bo')

    def __init__(self, lstm, update_rate):
        """Build the roll-out sampling graph from the generator ``lstm``.

        :param lstm: generator object. Its ``num_emb``, ``batch_size``,
            ``emb_dim``, ``hidden_dim``, ``sequence_length``, ``start_token``,
            ``learning_rate``, ``g_embeddings`` and LSTM/output weights are read.
        :param update_rate: weight kept on the roll-out's own parameters when
            ``update_params`` blends them with the generator's.
        """
        # network & update rate
        self.lstm = lstm
        self.update_rate = update_rate
        # hyper-parameters mirror the generator
        self.num_emb = self.lstm.num_emb
        self.batch_size = self.lstm.batch_size
        self.emb_dim = self.lstm.emb_dim
        self.hidden_dim = self.lstm.hidden_dim
        self.sequence_length = self.lstm.sequence_length
        # start token and learning rate copied from the generator
        self.start_token = tf.identity(self.lstm.start_token)
        self.learning_rate = self.lstm.learning_rate
        # embeddings and step functions
        self.g_embeddings = tf.identity(self.lstm.g_embeddings)
        self.g_recurrent_unit = self.create_recurrent_unit()  # maps h_tm1 to h_t
        self.g_output_unit = self.create_output_unit()  # maps h_t to o_t (token logits)

        # inputs: real token sequences and the prefix length to keep
        self.x = tf.placeholder(tf.int32, shape=[self.batch_size, self.sequence_length])
        self.given_num = tf.placeholder(tf.int32)

        # embedded inputs, transposed time-major to [sequence_length, batch_size, emb_dim];
        # the lookup is pinned to CPU as in the reference code
        with tf.device('/cpu:0'):
            self.processed_x = tf.transpose(
                tf.nn.embedding_lookup(self.g_embeddings, self.x), perm=[1, 0, 2])

        # per-time-step access to embedded inputs (float32) and raw token ids (int32)
        ta_emb_x = tensor_array_ops.TensorArray(dtype=tf.float32, size=self.sequence_length)
        ta_emb_x = ta_emb_x.unstack(self.processed_x)
        ta_x = tensor_array_ops.TensorArray(dtype=tf.int32, size=self.sequence_length)
        ta_x = ta_x.unstack(tf.transpose(self.x, perm=[1, 0]))

        # zero initial state: (hidden, cell) stacked into one tensor, which the
        # recurrent unit splits with tf.unstack and re-packs with tf.stack
        self.h0 = tf.zeros([self.batch_size, self.hidden_dim])
        self.h0 = tf.stack([self.h0, self.h0])

        # one token per time step, filled by the two loops below
        gen_x = tensor_array_ops.TensorArray(dtype=tf.int32, size=self.sequence_length,
                                             dynamic_size=False, infer_shape=True)

        def _g_recurrence_1(i, x_t, h_tm1, given_num, gen_x):
            # prefix phase (i < given_num): copy real token i and feed its
            # embedding as the next input
            h_t = self.g_recurrent_unit(x_t, h_tm1)
            x_tp1 = ta_emb_x.read(i)
            gen_x = gen_x.write(i, ta_x.read(i))
            return i + 1, x_tp1, h_t, given_num, gen_x

        def _g_recurrence_2(i, x_t, h_tm1, given_num, gen_x):
            # roll-out phase (i >= given_num): sample token i from the model
            # and feed it back as the next input
            h_t = self.g_recurrent_unit(x_t, h_tm1)
            o_t = self.g_output_unit(h_t)  # logits, [batch_size, vocab_size]
            # log_softmax is the numerically stable form of log(softmax(.))
            log_prob = tf.nn.log_softmax(o_t)
            next_token = tf.cast(
                tf.reshape(tf.multinomial(log_prob, 1), [self.batch_size]), tf.int32)
            x_tp1 = tf.nn.embedding_lookup(self.g_embeddings, next_token)
            gen_x = gen_x.write(i, next_token)
            return i + 1, x_tp1, h_t, given_num, gen_x

        # phase 1: teacher-force the first given_num tokens
        i, x_t, h_tm1, given_num, self.gen_x = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, given_num, _4: i < given_num,
            body=_g_recurrence_1,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       tf.nn.embedding_lookup(self.g_embeddings, self.start_token),
                       self.h0, self.given_num, gen_x))
        # phase 2: sample the remaining tokens, continuing from phase-1 state
        _, _, _, _, self.gen_x = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3, _4: i < self.sequence_length,
            body=_g_recurrence_2,
            loop_vars=(i, x_t, h_tm1, given_num, self.gen_x))
        # stacked as [sequence_length, batch_size]; transpose to [batch_size, sequence_length]
        self.gen_x = tf.transpose(self.gen_x.stack(), perm=[1, 0])

    def get_reward(self, sess, input_x, rollout_num, discriminator):
        """Estimate per-token rewards by Monte Carlo roll-out.

        For each prefix length ``given_num`` in ``1 .. sequence_length - 1``,
        the first ``given_num`` tokens of ``input_x`` are kept and the rest are
        sampled by ``self.gen_x``. Column 1 of the discriminator's
        ``ypred_for_auc`` on the completed batch is the reward for that step.
        The last step's reward is the discriminator's score of ``input_x``
        itself. Rewards are averaged over ``rollout_num`` roll-outs.

        :param sess: TF session
        :param input_x: token sequences fed to ``self.x`` ([batch_size, sequence_length])
        :param rollout_num: number of Monte Carlo roll-outs; must be >= 1
        :param discriminator: object exposing ``input_x``, ``dropout_keep_prob``
            and ``ypred_for_auc``
        :return: numpy array of rewards, shape [batch_size, sequence_length]
        :raises ValueError: if ``rollout_num`` < 1
        """
        if rollout_num < 1:
            raise ValueError('rollout_num must be >= 1, got %r' % (rollout_num,))

        def _score(samples):
            # column 1 of the discriminator output for each sequence, dropout disabled
            feed = {discriminator.input_x: samples, discriminator.dropout_keep_prob: 1.0}
            ypred_for_auc = sess.run(discriminator.ypred_for_auc, feed)
            return np.array([item[1] for item in ypred_for_auc])

        # the full real sequence involves no sampling, so score it once
        # (assumes the discriminator is deterministic with keep_prob 1.0)
        last_reward = _score(input_x)

        total = None
        for _ in range(rollout_num):
            step_rewards = []
            for given_num in range(1, self.sequence_length):
                samples = sess.run(self.gen_x, {self.x: input_x, self.given_num: given_num})
                step_rewards.append(_score(samples))
            step_rewards.append(last_reward)
            step_rewards = np.array(step_rewards)  # [sequence_length, batch_size]
            total = step_rewards if total is None else total + step_rewards

        # average and transpose to [batch_size, sequence_length]
        return np.transpose(total) / (1.0 * rollout_num)

    def _build_recurrent_unit(self):
        """Return an LSTM step closure ``unit(x, hidden_memory_tm1)``.

        The closure reads the ``self.W*``/``self.U*``/``self.b*`` attributes
        at call time, i.e. when the graph that uses it is built. The state is
        a stacked (hidden, cell) tensor, and the closure returns the new
        stacked state.
        """
        def unit(x, hidden_memory_tm1):
            previous_hidden_state, c_prev = tf.unstack(hidden_memory_tm1)
            # Input Gate
            i = tf.sigmoid(
                tf.matmul(x, self.Wi) +
                tf.matmul(previous_hidden_state, self.Ui) + self.bi
            )
            # Forget Gate
            f = tf.sigmoid(
                tf.matmul(x, self.Wf) +
                tf.matmul(previous_hidden_state, self.Uf) + self.bf
            )
            # Output Gate
            o = tf.sigmoid(
                tf.matmul(x, self.Wog) +
                tf.matmul(previous_hidden_state, self.Uog) + self.bog
            )
            # New Memory Cell
            c_ = tf.nn.tanh(
                tf.matmul(x, self.Wc) +
                tf.matmul(previous_hidden_state, self.Uc) + self.bc
            )
            # Final Memory cell
            c = f * c_prev + i * c_
            # Current Hidden state
            current_hidden_state = o * tf.nn.tanh(c)
            return tf.stack([current_hidden_state, c])

        return unit

    def _build_output_unit(self):
        """Return a closure mapping the stacked (hidden, cell) state to token logits."""
        def unit(hidden_memory_tuple):
            hidden_state, _ = tf.unstack(hidden_memory_tuple)
            # hidden_state : batch x hidden_dim
            return tf.matmul(hidden_state, self.Wo) + self.bo

        return unit

    def create_recurrent_unit(self):
        """Copy the generator's LSTM weights (roll-out mirrors the generator) and return the step function."""
        for name in self._RECURRENT_PARAM_NAMES:
            setattr(self, name, tf.identity(getattr(self.lstm, name)))
        return self._build_recurrent_unit()

    def update_recurrent_unit(self):
        """Blend the LSTM weights toward the generator's and return a new step function.

        Each weight becomes ``update_rate * own + (1 - update_rate) * generator``.
        """
        for name in self._RECURRENT_PARAM_NAMES:
            blended = (self.update_rate * getattr(self, name) +
                       (1 - self.update_rate) * tf.identity(getattr(self.lstm, name)))
            setattr(self, name, blended)
        return self._build_recurrent_unit()

    def create_output_unit(self):
        """Copy the generator's output-layer weights and return the logit function."""
        for name in self._OUTPUT_PARAM_NAMES:
            setattr(self, name, tf.identity(getattr(self.lstm, name)))
        return self._build_output_unit()

    def update_output_unit(self):
        """Blend the output-layer weights toward the generator's and return a new logit function."""
        for name in self._OUTPUT_PARAM_NAMES:
            blended = (self.update_rate * getattr(self, name) +
                       (1 - self.update_rate) * tf.identity(getattr(self.lstm, name)))
            setattr(self, name, blended)
        return self._build_output_unit()

    def update_params(self):
        """Refresh the embeddings and blend the recurrent/output weights toward the generator.

        NOTE(review): as explained in the class docstring, this rebinds
        attributes but does not rebuild ``self.gen_x``, so the already-built
        sampling graph is unaffected.
        """
        self.g_embeddings = tf.identity(self.lstm.g_embeddings)
        self.g_recurrent_unit = self.update_recurrent_unit()
        self.g_output_unit = self.update_output_unit()