sample.py (forked from hzy46/Char-RNN-TensorFlow)
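"""Sample text from a trained Char-RNN model.

Example invocation (paths are illustrative; converter.pkl and the
checkpoint directory are produced by this repo's training script):

    python sample.py \
        --converter_path model/name/converter.pkl \
        --checkpoint_path model/name/ \
        --max_length 1000
"""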
import os

import tensorflow as tf

from read_utils import TextConverter
from model import CharRNN

FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_integer('lstm_size', 128, 'size of hidden state of lstm')
tf.flags.DEFINE_integer('num_layers', 2, 'number of lstm layers')
tf.flags.DEFINE_boolean('use_embedding', False, 'whether to use embedding')
tf.flags.DEFINE_integer('embedding_size', 128, 'size of embedding')
tf.flags.DEFINE_string('converter_path', '', 'path to the text converter pickle, e.g. model/name/converter.pkl')
tf.flags.DEFINE_string('checkpoint_path', '', 'checkpoint file, or a directory containing checkpoints')
tf.flags.DEFINE_string('start_string', '', 'use this string to start generating')
tf.flags.DEFINE_integer('max_length', 30, 'max length to generate')


def main(_):
    # Under Python 3 the flag value is already a unicode string; only
    # Python 2 hands back bytes that need decoding.
    if hasattr(FLAGS.start_string, 'decode'):
        FLAGS.start_string = FLAGS.start_string.decode('utf-8')

    # Rebuild the char<->id mapping saved during training.
    converter = TextConverter(filename=FLAGS.converter_path)

    # Accept either a checkpoint file or a directory of checkpoints.
    if os.path.isdir(FLAGS.checkpoint_path):
        FLAGS.checkpoint_path = \
            tf.train.latest_checkpoint(FLAGS.checkpoint_path)

    # Build the model in sampling mode and restore the trained weights.
    model = CharRNN(converter.vocab_size, sampling=True,
                    lstm_size=FLAGS.lstm_size, num_layers=FLAGS.num_layers,
                    use_embedding=FLAGS.use_embedding,
                    embedding_size=FLAGS.embedding_size)
    model.load(FLAGS.checkpoint_path)

    # Seed the sampler with the start string and generate max_length ids.
    start = converter.text_to_arr(FLAGS.start_string)
    arr = model.sample(FLAGS.max_length, start, converter.vocab_size)
    print(converter.arr_to_text(arr))


if __name__ == '__main__':
    tf.app.run()
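For quick experiments, the same pieces can be driven without the flag machinery. A minimal sketch, assuming a converter pickle and checkpoint directory produced by training (the paths, seed string, and reliance on CharRNN's default embedding arguments are illustrative):

    import tensorflow as tf
    from read_utils import TextConverter
    from model import CharRNN

    # Hypothetical paths; both artifacts come from the training run.
    converter = TextConverter(filename='model/name/converter.pkl')
    model = CharRNN(converter.vocab_size, sampling=True,
                    lstm_size=128, num_layers=2)
    model.load(tf.train.latest_checkpoint('model/name/'))
    arr = model.sample(1000, converter.text_to_arr(u'The '),
                       converter.vocab_size)
    print(converter.arr_to_text(arr))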