qrnn_decode_eval.py
'''QRNN decode function for 2l-dr: headline generation (take 2).

Defines QRNN decoding for a deep-learning sequence-to-sequence model.
'''
import tensorflow as tf


def get_input_from_state(state, embeddings, output_projection):
    '''Greedy step: project the decoder state to vocabulary logits, take the
    argmax word id, and look up its embedding as the next decoder input.'''
    vocab = tf.nn.xw_plus_b(state, output_projection[0], output_projection[1])
    word_ids = tf.argmax(vocab, axis=1)
    return tf.nn.embedding_lookup(embeddings, word_ids)


def advance_step_input(step_input, new_input):
    '''Slide the fixed-width input window: append the new time step along
    axis 1 and drop the oldest one, keeping the window length constant.'''
    result = tf.concat([step_input, new_input], axis=1)
    return result[:, 1:, :]


def decode_evaluate(decoder, encode_outputs, embedded_dec_inputs,
                    embeddings):
    '''Run the QRNN decoder step by step for evaluation/inference.

    At each time step the previous greedy prediction is fed back in;
    intermediate layers condition on the final encoder output, and the top
    layer attends over all encoder outputs.
    '''
    H = []
    batch_size = tf.shape(embedded_dec_inputs)[0]
    layer_inputs = {}
    layer_outputs = {}
    for i in range(decoder.seq_length):
        if i == 0:
            # Start from an all-zero input window and the first decoder input.
            step_input = tf.fill([batch_size, decoder.conv_size,
                                  decoder.embedding_size], 0.0)
            layer_inputs[0] = step_input
            new_input = embedded_dec_inputs[:, 0, :]
        else:
            # Feed back the embedding of the previous greedy prediction.
            step_input = layer_inputs[0]
            new_input = get_input_from_state(H[-1], embeddings,
                                             decoder.output_projection)
        step_input = advance_step_input(step_input,
                                        tf.expand_dims(new_input, 1))
        for j in range(decoder.num_layers):
            enc_out = tf.squeeze(encode_outputs[j][:, -1, :])
            if i == 0:
                # First step: build the convolution state for every layer.
                if j < decoder.num_layers - 1:
                    input_size = decoder.embedding_size if j == 0 \
                        else decoder.num_convs
                    step_input, c_t = decoder.conv_with_encode_output(
                        j,
                        enc_out,
                        layer_inputs[j],
                        input_size,
                        seq_len=decoder.conv_size)
                    layer_inputs[j + 1] = step_input
                    layer_outputs[j] = c_t
                else:
                    input_size = decoder.embedding_size \
                        if decoder.num_layers == 1 \
                        else decoder.num_convs
                    h_t, c_t = decoder.conv_with_attention(
                        j, encode_outputs,
                        layer_inputs[j],
                        input_size,
                        seq_len=decoder.conv_size)
                    H.append(tf.squeeze(h_t[:, -1:, :]))
                    layer_outputs[j] = c_t
            else:
                # Later steps: advance each layer by one position, reusing the
                # cached convolution state from the previous step.
                input_shape = decoder.embedding_size if j == 0 else \
                    decoder.num_convs
                if j < decoder.num_layers - 1:
                    h_t, c_t = decoder.eval_conv_with_encode_output(
                        j,
                        enc_out,
                        layer_inputs[j],
                        input_shape,
                        layer_outputs[j])
                    layer_inputs[j + 1] = advance_step_input(
                        layer_inputs[j + 1], h_t)
                    layer_outputs[j] = c_t
                else:
                    h_t, c_t = decoder.eval_conv_with_attention(
                        j,
                        encode_outputs,
                        layer_inputs[j],
                        input_shape,
                        layer_outputs[j])
                    H.append(tf.squeeze(h_t))
                    layer_outputs[j] = c_t
    # tf.pack was renamed to tf.stack in TensorFlow 1.0.
    return tf.reshape(tf.stack(H), [batch_size,
                                    decoder.seq_length, decoder.num_convs])
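

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): illustrates how
# advance_step_input maintains a fixed-width window of recent decoder inputs
# and how get_input_from_state re-embeds the argmax prediction. The toy sizes
# below (batch=2, conv_size=3, embed=4, vocab=5) are assumptions chosen for
# demonstration only; the snippet assumes TensorFlow 1.x graph-mode sessions.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import numpy as np

    batch, conv_size, embed, vocab_size = 2, 3, 4, 5

    # Empty input window plus one new time step: shape stays [batch, conv_size, embed].
    window = tf.zeros([batch, conv_size, embed])
    new_tok = tf.ones([batch, 1, embed])
    window = advance_step_input(window, new_tok)

    # Fake embeddings, output projection, and decoder state for one greedy step.
    embeddings = tf.constant(np.random.rand(vocab_size, embed), tf.float32)
    proj_w = tf.constant(np.random.rand(embed, vocab_size), tf.float32)
    proj_b = tf.zeros([vocab_size])
    state = tf.ones([batch, embed])
    next_in = get_input_from_state(state, embeddings, (proj_w, proj_b))

    with tf.Session() as sess:
        w, n = sess.run([window, next_in])
        print(w.shape, n.shape)  # (2, 3, 4) (2, 4)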