This repository has been archived by the owner on Jan 30, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 20
/
language_generation.py
684 lines (577 loc) · 25 KB
/
language_generation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
"""Generate language using XLNet"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import numpy as np
import os
import re
from tqdm import tqdm
import absl.logging as _logging # pylint: disable=unused-import
import tensorflow as tf
import sentencepiece as spm
import model_utils
from data_utils import CLS_ID, special_symbols, EOD_ID
import xlnet
from prepro_utils import preprocess_text, encode_ids
# Id of the "<eop>" special symbol. main() encodes newlines inside a prompt
# as "<eop>" and renders this id back as a blank line when decoding samples.
EOP_ID = special_symbols["<eop>"]
# Command line interface. The parsed values are read by the graph-building
# and sampling functions of this module through the module-level FLAGS.
parser = argparse.ArgumentParser()

# Model
parser.add_argument("--model_config_path", default=None,
                    help="Model config path.", type=str)
parser.add_argument("--clamp_len", default=-1,
                    help="Clamp length", type=int)
parser.add_argument("--same_length", default=False,
                    help="Same length attention", action='store_true')

# Data and memory
parser.add_argument("--batch_size", default=1, help='batch size', type=int)
parser.add_argument("--max_mem_length", default=128,
                    help="Max sequence length for cached hidden states"
                    " which each predicted token is conditioned upon"
                    ". Directly increases the memory requirement", type=int)
parser.add_argument("--uncased", default=False,
                    help="Use uncased inputs or not.", action='store_true')

# I/O paths
parser.add_argument("--init_checkpoint", default=None,
                    help="checkpoint path for initializing the model. "
                    "Could be a pretrained model or a finetuned model.")
parser.add_argument("--spiece_model_file", default="",
                    help="Sentence Piece model path.")
parser.add_argument("--input_file", default="",
                    help="File containing prompts separated by empty new line "
                    "for conditional sampling")

# prediction
parser.add_argument("--num_samples", default=1,
                    help="Number of samples to predict per instance", type=int)
parser.add_argument(
    "--interactive",
    default=False,
    help="Flag for interactive prediction through command line",
    action='store_true')
parser.add_argument(
    "--unconditional",
    default=False,
    help="Prints samples without any prompt",
    action='store_true')
parser.add_argument(
    "--top_p",
    default=0,
    help="Top-p coverage to use. Set 0 to use top_k sampling",
    type=float)
parser.add_argument(
    "--top_k",
    default=40,
    help="Top-k sampling strategy parameter. Use only when top-p is zero. "
    "Set -1 to use all the samples",
    type=int)
# The temperature divides the logits, so fractional values (e.g. 0.7) must be
# accepted; parsing it as an int rejected them on the command line.
parser.add_argument("--temperature", default=1.0,
                    help="Scaling factor for logits", type=float)
parser.add_argument("--num_toks_pred", default=1024,
                    help="Number of tokens to predict", type=int)
parser.add_argument("--bidirectional_eachstep", help="Compute bidirectional "
                    "attention every step. Consumes a lot of time but better results",
                    action='store_true')
# Parsed once when the module is loaded; the functions below read FLAGS as a
# module-level global instead of receiving the options as arguments.
FLAGS = parser.parse_args()
def _create_mask(qlen, mlen):
    """Return an all-zero attention mask of shape (qlen, qlen + mlen).

    The mask is fully bi-directional: every query position may attend to all
    tokens of the sequence and of the memory.
    """
    return tf.zeros((qlen, qlen + mlen))
def get_preprocessor(examples, tokenize_fn, pad_ids):
    """Build a generator function that yields the encoded examples.

    Input:
      examples: [List[str]] input texts
      tokenize_fn: [function] encodes text into IDs
      pad_ids: [List[int]] IDs placed in front of every encoded example
    Output:
      A zero-argument callable. Each call returns a fresh generator yielding
      `pad_ids + tokenize_fn(example)` for every example, in order.
    """
    def generator():
        # Tokenization happens lazily, one example per requested item.
        yield from (pad_ids + tokenize_fn(text) for text in examples)

    return generator
def get_input_dataset(preprocessor):
    """Returns tf.data.Dataset for input.

    Input:
      preprocessor: zero-argument callable returning a generator of token-ID
        lists (the kind of callable `get_preprocessor` builds).
    Output:
      A dataset batched by FLAGS.batch_size (the last batch may be smaller).
      Every element is a dict with the keys
        input_k: the last FLAGS.max_mem_length IDs, left-padded with 0
        input_mask: float32, 1 on the left padding and 0 on real tokens
        seg_id: all zeros
      each of length FLAGS.max_mem_length per example.
    """
    batch_size = FLAGS.batch_size
    max_mem_length = FLAGS.max_mem_length

    def mask(ids):
        example = {'input_k': ids}
        # Keep only the most recent `max_mem_length` tokens of the prompt.
        input_k = example['input_k'][-max_mem_length:]
        seq_len = tf.shape(input_k)[0]
        input_mask = tf.zeros([seq_len], dtype=tf.float32)
        pad_len = tf.maximum(0, max_mem_length - seq_len)
        # Paddings of the form [[pad_len, 0]]: pad on the left side only.
        pad_tensor = tf.concat([[[pad_len]], [[0]]], axis=-1)
        input_k = tf.pad(input_k, pad_tensor, constant_values=0)
        input_mask = tf.pad(input_mask, pad_tensor, constant_values=1)
        example['input_mask'] = input_mask
        example['input_k'] = input_k
        example['seg_id'] = tf.convert_to_tensor([0] * max_mem_length)
        return example

    dataset = tf.data.Dataset.from_generator(preprocessor,
                                             output_types=tf.int32)
    dataset = dataset.map(mask)
    dataset = dataset.batch(batch_size,
                            drop_remainder=False)
    # Dataset transformations return a new dataset; the result has to be
    # kept, otherwise the prefetch is silently dropped.
    dataset = dataset.prefetch(1)
    return dataset
def inputs_and_mask(latest_tokens, batch_size):
    """Build the model inputs for one step of the prediction loop.

    A dummy token ([CLS]) is appended at the end of the previous tokens; it
    is the position whose content gets predicted.

    Input:
      latest_tokens: Tensor [batch_size,1] or None
        If None, only the dummy token is fed and the sequence dimension of
        the returned tensors is 1 instead of 2.
    Output:
      input_k: [batch_size,2] latest_tokens followed by the dummy token
      seg_id: [batch_size,2] all zeros
      attn_masks: [batch_size,2,2] float32; the column of the dummy token is
        1 and the column of latest_tokens is 0
      input_q: [batch_size,2] float32; 1 marks the token to predict, which
        is the last one
      target_mapping: [1,2,batch_size] float32 one-hot on the last position
    """
    def per_example(row):
        # Repeat a single row for every element of the batch.
        return tf.tile([row], [batch_size, 1])

    token_ids = per_example([CLS_ID])
    segments = per_example([0])
    query_flags = per_example([1])

    if latest_tokens is None:
        mask_template = [[1]]
        mapping_template = [[[1]]]
    else:
        token_ids = tf.concat([latest_tokens, token_ids], axis=-1)
        segments = tf.tile(segments, [1, 2])
        query_flags = tf.concat([per_example([0]), query_flags], axis=-1)
        mask_template = [[0, 1], [0, 1]]
        mapping_template = [[[0], [1]]]

    target_mapping = tf.tile(
        tf.constant(mapping_template, dtype=tf.float32), [1, 1, batch_size])
    attn_masks = tf.convert_to_tensor(mask_template, dtype=tf.float32)
    attn_masks = tf.tile(attn_masks[None, :, :], [batch_size, 1, 1])
    query_flags = tf.cast(query_flags, tf.float32)
    return token_ids, segments, attn_masks, query_flags, target_mapping
def get_logits(xlnet_model, xlnet_config):
    """Build the graph that turns the last hidden state into vocabulary logits.

    Only the final position of the sequence output is projected. The
    projection matrix is the model's embedding table (tied weights) and the
    bias is the variable `model/lm_loss/bias`.

    Returns a tensor shaped like the sliced hidden state with the model
    dimension replaced by the vocabulary size (einsum 'ibd,nd->ibn').
    """
    embedding_table = xlnet_model.get_embedding_table()
    tie_weight = True

    with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
        initializer = xlnet_model.get_initializer()
        last_hidden = xlnet_model.get_sequence_output()[-1:, :, :]
        n_token = xlnet_config.n_token
        d_model = xlnet_config.d_model

        with tf.variable_scope('lm_loss'):
            if not tie_weight:
                softmax_w = tf.get_variable(
                    'weight', [n_token, d_model],
                    dtype=last_hidden.dtype, initializer=initializer)
            else:
                assert embedding_table is not None, \
                    'lookup_table cannot be None for tie_weight'
                softmax_w = embedding_table

            softmax_b = tf.get_variable(
                'bias', [n_token],
                dtype=last_hidden.dtype,
                initializer=tf.zeros_initializer())
            logits = tf.einsum(
                'ibd,nd->ibn', last_hidden, softmax_w) + softmax_b

    return logits
def sampling_strategy():
    """Name the sampling scheme selected by the flags.

    Returns 'top_p' when FLAGS.top_p is non-zero, otherwise 'top_k'.
    """
    return 'top_k' if FLAGS.top_p == 0 else 'top_p'
def sample_token(logits):
    """Sample one token per position from temperature-scaled, truncated logits.

    Inputs:
      logits: tf.Tensor([batch_size,len,num_tokens])
    Outputs:
      samples: tf.Tensor([batch_size,len,1]) int32 ids of the sampled tokens
      confidences: tf.Tensor([batch_size,len,1]) probability of each sampled
        token under the softmax of the scaled but untruncated logits

    The truncation is chosen by `sampling_strategy()`:
      top_p: keep the most likely tokens until their cumulative probability
        reaches FLAGS.top_p (the most likely token is always kept).
      top_k: keep the FLAGS.top_k most likely tokens; a non-positive
        FLAGS.top_k disables truncation and samples from all tokens.
    """
    # credits: https://github.com/nshepperd/gpt-2
    logits /= FLAGS.temperature
    batch_size = tf.shape(logits)[0]
    seq_len = tf.shape(logits)[1]
    num_toks = tf.shape(logits)[2]

    strategy = sampling_strategy()
    if strategy == 'top_p':
        logits_sorted = tf.sort(logits,
                                direction="DESCENDING",
                                axis=-1)
        probs = tf.nn.softmax(logits_sorted, axis=-1)
        # Exclusive cumsum: probability mass strictly before each token.
        cum_probs = tf.math.cumsum(probs,
                                   axis=-1,
                                   exclusive=True)
        # Dropped tokens are replaced by the largest logit so that they can
        # never become the minimum, whatever the scale of the logits is.
        largest = tf.ones_like(logits_sorted) * logits_sorted[:, :, :1]
        logits_kept = tf.where(cum_probs < FLAGS.top_p,
                               logits_sorted,
                               largest)
        # keepdims makes the threshold [batch_size,len,1] so that it
        # broadcasts over the vocabulary axis for any batch size.
        min_logits = tf.reduce_min(logits_kept, axis=-1, keepdims=True)
        logits_masked = tf.where(logits < min_logits,
                                 tf.ones_like(logits) * -1e10,
                                 logits)
    elif strategy == "top_k":
        if FLAGS.top_k > 0:
            values, _ = tf.nn.top_k(logits, k=FLAGS.top_k)
            min_values = values[:, :, -1:]
            logits_masked = tf.where(
                logits < min_values,
                tf.ones_like(logits, dtype=logits.dtype) * -1e10,
                logits,
            )
        else:
            # 0 or -1: no truncation, sample from the whole vocabulary.
            logits_masked = logits
    else:
        raise NotImplementedError("Invalid sampling strategy")

    logits_masked = tf.reshape(logits_masked, (-1, num_toks))
    samples = tf.random.categorical(logits_masked,
                                    num_samples=1,
                                    dtype=tf.int32)

    probs = tf.nn.softmax(tf.reshape(logits, (-1, num_toks)), axis=-1)
    confidences = tf.gather_nd(params=probs, batch_dims=1, indices=samples)

    return tf.reshape(samples, (batch_size, seq_len, 1)),\
        tf.reshape(confidences, (batch_size, seq_len, 1))
def prediction_graph_memory():
    """Build the sampling graph that reuses cached hidden states (memory).

    The prompt is run through the model once with a fully bi-directional
    mask and its hidden states are kept as memory. Tokens are then sampled
    one at a time: every step feeds the previously sampled token followed by
    a dummy query token, and moves the memory window forward by one position.

    Output:
      predictions: Tuple(Tensors) sampled token ids and their confidences,
        each [batch_size, FLAGS.num_toks_pred]
      features: Dict[str, placeholder] to be fed by the caller:
        input_k
        seg_id
        input_mask
    """
    features = {
        name: tf.placeholder(dtype, (None, None))
        for name, dtype in (("input_k", tf.int32),
                            ("seg_id", tf.int32),
                            ("input_mask", tf.float32))
    }

    # Building prediction graph
    # Transforming features for batch channel on last axis
    prompt_ids = tf.transpose(features["input_k"], [1, 0])
    prompt_seg = tf.transpose(features["seg_id"], [1, 0])
    prompt_mask = tf.transpose(features["input_mask"], [1, 0])

    # Model config
    xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path)
    run_config = xlnet.create_run_config(False, True, FLAGS)
    run_config.mem_len = FLAGS.max_mem_length

    # Getting the hidden states for the prompts
    prompt_model = xlnet.XLNetModel(
        xlnet_config=xlnet_config,
        run_config=run_config,
        input_ids=prompt_ids,
        seg_ids=prompt_seg,
        input_mask=prompt_mask,
        perm_mask=_create_mask(tf.shape(prompt_ids)[0], 0)[:, :, None])

    # getting memory
    prompt_mems = prompt_model.get_new_memory()
    batch_size = tf.shape(prompt_mems[0])[1]

    def body(mems, latest_tokens, mem_mask, prev_tokens, prev_confs):
        """One sampling step.
        mems: cache memory--calculated hidden states of previous tokens
        latest_tokens: latest sampled tokens, None on the very first step
        mem_mask: 1 where a memory slot is padding and must be zeroed
        prev_tokens: all the previous tokens including latest_tokens
        prev_confs: confidences of respective tokens in prev_tokens
        """
        first_step = latest_tokens is None

        # get dummy input token and permutation mask
        input_k, seg_id, perm_mask, input_q, target_mapping = \
            inputs_and_mask(latest_tokens, batch_size)
        input_k = tf.transpose(input_k, (1, 0))
        input_q = tf.transpose(input_q, (1, 0))
        seg_id = tf.transpose(seg_id, (1, 0))
        perm_mask = tf.transpose(perm_mask, (1, 2, 0))

        # Set the hidden state of the padded tokens to be zero
        keep = 1 - mem_mask[:, :, None]
        mems = [keep * layer_mem for layer_mem in mems]

        # Get logits
        step_model = xlnet.XLNetModel(
            xlnet_config=xlnet_config,
            run_config=run_config,
            input_ids=input_k,
            seg_ids=seg_id,
            perm_mask=perm_mask,
            mems=mems,
            input_mask=None,
            inp_q=input_q,
            target_mapping=target_mapping)
        logits = get_logits(step_model, xlnet_config)

        # Getting new memory
        new_mems = step_model.get_new_memory()

        # sample a token and keep the one of the last position
        sampled_tokens, confs = sample_token(tf.transpose(logits, (1, 0, 2)))
        sampled_tokens = sampled_tokens[:, -1, :]
        confs = confs[:, -1, :]

        if prev_tokens is None:
            prev_tokens = sampled_tokens
        else:
            prev_tokens = tf.concat([prev_tokens, sampled_tokens], axis=1)
        if prev_confs is None:
            prev_confs = confs
        else:
            prev_confs = tf.concat([prev_confs, confs], axis=1)

        if first_step:
            # Nothing new to cache yet: only the dummy token was fed.
            return [mems, sampled_tokens, mem_mask, prev_tokens, prev_confs]

        # Cache the memory of latest_tokens: drop the oldest slot and append
        # the hidden state of the token fed just before the dummy token.
        shifted_mems = [
            tf.concat([old[1:], new[-2:-1]], axis=0)
            for old, new in zip(mems, new_mems)
        ]
        mem_mask = tf.concat(
            [mem_mask[1:], tf.zeros_like(mem_mask[:1])], axis=0)
        return [shifted_mems, sampled_tokens, mem_mask,
                prev_tokens, prev_confs]

    # The first step runs outside of the loop because it has no previous
    # token; it also fixes the structure of the loop variables.
    loop_vars = body(prompt_mems, None, prompt_mask, None, None)

    results = tf.while_loop(
        # Dummy condition since we stop based on iteration
        cond=lambda *_: True,
        body=body,
        maximum_iterations=FLAGS.num_toks_pred - 1,
        loop_vars=loop_vars,
        shape_invariants=[
            [tf.TensorShape([None, None, None]) for _ in loop_vars[0]],
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None])
        ]
    )
    predicted_tokens, predicted_confs = results[-2:]
    return (predicted_tokens, predicted_confs), features
def prediction_graph_no_memory():
    """Builds graphs and returns prediction and input features.

    No hidden state is cached: at every step the whole window of input ids
    is run through the model again with bi-directional attention. The window
    slides by one token per step (the oldest id is dropped, the sampled id
    is appended).

    Output:
      predictions: Tuple(Tensors) Currently returns sampled tokens and
        confidences, each [batch_size, FLAGS.num_toks_pred]
      features: Dict[str, placeholder] Contains following features:
        input_k
        seg_id
        input_mask
    """
    features = {
        "input_k": tf.placeholder(tf.int32, (None, None)),
        "seg_id": tf.placeholder(tf.int32, (None, None)),
        "input_mask": tf.placeholder(tf.float32, (None, None))
    }

    # Building prediction graph
    # Transforming features for batch channel on last axis
    inp = tf.transpose(features["input_k"], [1, 0])
    seg_id = tf.transpose(features["seg_id"], [1, 0])
    inp_mask = tf.transpose(features["input_mask"], [1, 0])

    # Model config
    xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path)
    run_config = xlnet.create_run_config(False, True, FLAGS)
    run_config.mem_len = FLAGS.max_mem_length

    perm_mask = _create_mask(tf.shape(inp)[0], 0)[:, :, None]
    prev_tokens = None
    prev_conf = None

    # target mapping: one-hot on the last position, which is the only one
    # that gets predicted
    seq_len = tf.shape(inp)[0]
    batch_size = tf.shape(inp)[-1]
    target_mapping = tf.concat(
        [tf.zeros((1, seq_len - 1, batch_size)),
         tf.ones((1, 1, batch_size))],
        axis=1)

    def cond(*_):
        """Dummy condition since we stop based on iteration"""
        return True

    def recalc(inp, inp_mask, seg_id, perm_mask):
        """Augment the inputs for the new token and slide the window.

        Appends one row (and one column for perm_mask) for the token to be
        predicted, then drops the oldest position so that the sequence
        length stays constant.
        """
        append_one = [[0, 1], [0, 0]]  # one extra row at the end of axis 0
        input_q = tf.zeros_like(inp, dtype=tf.float32)
        inp = tf.pad(inp, tf.convert_to_tensor(append_one),
                     constant_values=0)
        inp_mask = tf.pad(inp_mask, tf.convert_to_tensor(append_one),
                          constant_values=0)
        seg_id = tf.pad(seg_id, tf.convert_to_tensor(append_one),
                        constant_values=0)
        # New column of ones: no position may attend to the new token.
        col = tf.ones(tf.shape(perm_mask)[0:1], dtype=tf.float32)
        perm_mask = tf.concat([perm_mask, col[:, None, None]], axis=1)
        # New row: the new token attends to everything but itself.
        row = tf.concat(
            [tf.zeros(tf.shape(perm_mask)[1:2] - 1, dtype=tf.float32),
             tf.ones([1], dtype=tf.float32)], axis=0)
        perm_mask = tf.concat([perm_mask, row[None, :, None]], axis=0)
        # Only the appended position is a prediction target.
        input_q = tf.pad(input_q, tf.convert_to_tensor(append_one),
                         constant_values=1)
        return (inp[1:], inp_mask[1:], perm_mask[1:, 1:],
                input_q[1:], seg_id[1:])

    def body(inp, inp_mask, seg_id, perm_mask, prev_tokens, prev_conf):
        """The main body of sampling loop.
        inp: input ids
        inp_mask: input masks for paddings, etc.
        seg_id: segment id. Zeros here.
        perm_mask: permutation mask to pass to transformer
        prev_tokens: all the previous tokens including latest_tokens
        prev_conf: confidences of respective tokens in prev_tokens
        """
        # get dummy input token and permutation mask
        input_k, input_mask, perm_mask, input_q, seg_id = recalc(
            inp, inp_mask, seg_id, perm_mask)

        # Get logits. The mask has to be the shifted one returned by
        # recalc(): feeding the stale `inp_mask` leaves the padding mask one
        # position behind the shifted input ids.
        xlnet_model = xlnet.XLNetModel(
            xlnet_config=xlnet_config,
            run_config=run_config,
            input_ids=input_k,
            seg_ids=seg_id,
            input_mask=input_mask,
            perm_mask=perm_mask,
            inp_q=input_q,
            target_mapping=target_mapping)
        logits = get_logits(xlnet_model, xlnet_config)

        # sample a token
        logits = tf.transpose(logits, (1, 0, 2))
        sampled_tokens, confidences = sample_token(logits)
        sampled_tokens = sampled_tokens[:, -1, :]  # Last token
        confidences = confidences[:, -1, :]

        prev_tokens = sampled_tokens if prev_tokens is None \
            else tf.concat([prev_tokens, sampled_tokens], axis=1)
        prev_conf = confidences if prev_conf is None \
            else tf.concat([prev_conf, confidences], axis=1)

        # Replace the dummy last id by the sampled token and go back to a
        # fully visible mask for the next step.
        input_k = tf.concat(
            [input_k[:-1], tf.transpose(sampled_tokens, (1, 0))], axis=0)
        perm_mask = _create_mask(tf.shape(input_k)[0], 0)[:, :, None]
        return [input_k, input_mask, seg_id, perm_mask,
                prev_tokens, prev_conf]

    # First step outside of the loop: prev_tokens/prev_conf are still None.
    inp, inp_mask, seg_id, perm_mask, prev_tokens, prev_conf = body(
        inp, inp_mask, seg_id, perm_mask, prev_tokens, prev_conf)

    args = tf.while_loop(
        cond=cond,
        body=body,
        maximum_iterations=FLAGS.num_toks_pred - 1,
        loop_vars=[inp, inp_mask, seg_id, perm_mask, prev_tokens, prev_conf],
        shape_invariants=[
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None, None]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
        ]
    )
    predicted_tokens, predicted_confs = args[-2:]
    return (predicted_tokens, predicted_confs), features
def _read_prompts(path, num_samples):
    """Read blank-line separated prompts from `path`.

    Newlines inside a prompt are encoded as "<eop>". The newline that ends a
    prompt is dropped, both for prompts followed by an empty line and for
    the last prompt of the file. Every prompt appears `num_samples` times in
    a row in the returned list.
    """
    texts = []
    pending = []  # raw lines of the prompt currently being read

    def flush():
        if not pending:
            return
        raw = "".join(pending)
        del pending[:]
        if raw.endswith("\n"):
            raw = raw[:-1]
        texts.extend([raw.replace("\n", "<eop>")] * num_samples)

    with open(path) as f:
        for line in f:
            if line.strip() == "":
                flush()
            else:
                pending.append(line)
    flush()
    return texts


def main():
    """Main function routine.

    Loads the sentencepiece model, builds the prediction graph selected by
    FLAGS.bidirectional_eachstep, restores the checkpoint and then either
      * samples from prompts typed on the command line (--interactive) or
        from the empty prompt (--unconditional), printing the samples, or
      * samples FLAGS.num_samples continuations for every prompt of
        FLAGS.input_file and writes them to `<input_file>.xlnet`.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    # Text encoding
    sp = spm.SentencePieceProcessor()
    sp.Load(FLAGS.spiece_model_file)

    def tokenize_fn(text):
        text = preprocess_text(text, lower=FLAGS.uncased)
        return encode_ids(sp, text)

    # Temporary fix for context problem.
    pad_txt = """In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
remainder of the story. 1883 Western Siberia,
a young Grigori Rasputin is asked by his father and a group of men to perform magic.
Rasputin has a vision and denounces one of the men as a horse thief. Although his
father initially slaps him for making such an accusation, Rasputin watches as the
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. """
    pad_ids = tokenize_fn(pad_txt)
    pad_ids.append(EOD_ID)

    to_special_symbol = {v: k for k, v in special_symbols.items()}

    def parse_ids(toks):
        """Uses sentencepiece to convert to text. Substitute
        EOP_ID and EOD_ID with new lines, and rest with their names"""
        start = 0
        pieces = []
        for i, tok in enumerate(toks):
            if tok not in to_special_symbol:
                continue
            # Decode the run of ordinary tokens before the special symbol.
            if start < i:
                pieces.append(sp.decode_ids(toks[start:i]))
            if tok in [EOD_ID, EOP_ID]:
                replace_by = "\n\n"
            else:
                replace_by = to_special_symbol[tok]
            pieces.append(f" {replace_by} ")
            start = i + 1
        if start < len(toks):
            pieces.append(sp.decode_ids(toks[start:]))
        return "".join(pieces)

    if not FLAGS.bidirectional_eachstep:
        prediction_graph = prediction_graph_memory
    else:
        prediction_graph = prediction_graph_no_memory

    predictions, features = prediction_graph()
    gpu_options = tf.GPUOptions(allow_growth=True)
    model_utils.init_from_checkpoint(FLAGS, global_vars=False)

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          gpu_options=gpu_options)) as sess:
        sess.run(tf.global_variables_initializer())

        def predict(examples):
            """Given a list of texts in examples
            yield (token ids, confidences) for each of them, in order"""
            preprocessor = get_preprocessor(examples,
                                            tokenize_fn, pad_ids)
            dataset = get_input_dataset(preprocessor)
            example = dataset.make_one_shot_iterator().get_next()

            num_examples = len(examples)
            num_batches = int(np.ceil(num_examples / FLAGS.batch_size))

            for _ in tqdm(range(num_batches)):
                inputs = sess.run(example)
                output, conf = sess.run(
                    predictions, feed_dict={
                        features[k]: v for k, v in inputs.items()})
                for _output, _conf in zip(output, conf):
                    yield _output, _conf

        if FLAGS.unconditional or FLAGS.interactive:
            tf.logging.info("Interactive flag received."
                            " Ignoring input files if any.")
            while True:
                if FLAGS.unconditional:
                    text = ""
                else:
                    text = input("----PROMPT----\n")
                outputs = predict([text] * FLAGS.num_samples)
                for i, (output, _) in enumerate(outputs):
                    out = parse_ids(output.tolist())
                    print("======SAMPLE {}======".format(i))
                    print(out)
                    print("=====================")
                # Unconditional sampling has no prompt to ask for again.
                if FLAGS.unconditional:
                    break
        else:
            assert FLAGS.input_file != "", "Please provide either an"\
                " input file or set interactive flag for command line input"
            assert os.path.exists(FLAGS.input_file), FLAGS.input_file +\
                " does not exists"

            texts = _read_prompts(FLAGS.input_file, FLAGS.num_samples)

            tf.logging.info("Got %s lines in the input file",
                            len(texts) // FLAGS.num_samples)
            tf.logging.info("Sampling each line %s times", FLAGS.num_samples)

            outputs = iter(predict(texts))
            with open(FLAGS.input_file + ".xlnet", 'w') as f:
                # `texts` holds num_samples copies of each prompt in a row.
                for i in range(0, len(texts), FLAGS.num_samples):
                    f.write("\n======Example {}=================\n".format(i))
                    f.write(texts[i])
                    for j in range(FLAGS.num_samples):
                        output, _ = next(outputs)
                        out = parse_ids(output.tolist())
                        f.write(
                            "\n======Example {} SAMPLE {}======\n".format(i, j))
                        f.write(out)
                        f.write("\n==================================\n")
if __name__ == "__main__":
    # Fixed flags: not exposed on the command line, set here on FLAGS.
    # NOTE(review): presumably consumed by xlnet.create_run_config and
    # model_utils.init_from_checkpoint, which receive FLAGS - confirm.
    for _flag_name, _flag_value in (
            ("use_tpu", False),
            ("use_bfloat16", False),
            ("dropout", 0),
            ("dropatt", 0),
            ("init", "normal"),
            ("init_std", 0.02),
            ("init_range", 0.1)):
        setattr(FLAGS, _flag_name, _flag_value)

    print("Args: {}".format(vars(FLAGS)))
    main()