-
Notifications
You must be signed in to change notification settings - Fork 129
/
chat.py
145 lines (103 loc) · 4.22 KB
/
chat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#!/usr/bin/env python
__docformat__ = 'restructedtext en'
__authors__ = ("Iulian Serban, Alessandro Sordoni")
__contact__ = "Iulian Serban <julianserban@gmail.com>"
import argparse
import cPickle
import traceback
import itertools
import logging
import time
import sys
import search
import collections
import string
import os
import numpy
import codecs
import nltk
from random import randint
from dialog_encdec import DialogEncoderDecoder
from numpy_compat import argpartition
from state import prototype_state
import theano
logger = logging.getLogger(__name__)
class Timer(object):
    """Accumulate elapsed wall-clock seconds over start()/finish() pairs."""

    def __init__(self):
        # Running sum of all completed intervals, in seconds.
        self.total = 0

    def start(self):
        """Mark the beginning of a timed interval."""
        self.start_time = time.time()

    def finish(self):
        """Close the current interval and fold its duration into ``total``."""
        elapsed = time.time() - self.start_time
        self.total += elapsed
def sample(model, seqs=None, n_samples=1, sampler=None, ignore_unk=False):
    """Generate responses for each token-sequence context in ``seqs``.

    :param model: the dialogue model (unused here; the sampler wraps it)
    :param seqs: list of contexts, each a list of token ids/strings;
        defaults to a single empty context
    :param n_samples: number of candidate responses per context
    :param sampler: object exposing ``sample(seqs, n_samples=..., n_turns=...,
        ignore_unk=..., verbose=...)``; required
    :returns: the samples returned by the sampler (costs are discarded)
    :raises Exception: if no sampler is given
    """
    # Avoid the mutable-default-argument pitfall: the old default ``[[]]``
    # was shared across calls. It was never mutated here, but a None
    # sentinel is the safe, idiomatic form and preserves behavior exactly.
    if seqs is None:
        seqs = [[]]
    if sampler:
        context_samples, context_costs = sampler.sample(
            seqs,
            n_samples=n_samples,
            n_turns=1,
            ignore_unk=ignore_unk,
            verbose=True)
        return context_samples
    else:
        raise Exception("I don't know what to do")
def remove_speaker_tokens(s):
    """Strip the corpus speaker/annotation tags (with their trailing space)
    from ``s`` and return the cleaned string."""
    tags = (
        '<first_speaker> ',
        '<second_speaker> ',
        '<third_speaker> ',
        '<minor_speaker> ',
        '<voice_over> ',
        '<off_screen> ',
    )
    for tag in tags:
        s = s.replace(tag, '')
    return s
def parse_args():
    """Define and parse the command-line interface of the chat script."""
    arg_parser = argparse.ArgumentParser(
        "Sample (with beam-search) from the session model")

    # NOTE(review): with action="store_true" AND default=True this flag can
    # never be switched off from the command line — presumably intentional,
    # as unknown-word tokens are always suppressed during sampling.
    arg_parser.add_argument(
        "--ignore-unk",
        default=True,
        action="store_true",
        help="Ignore unknown words")
    arg_parser.add_argument(
        "model_prefix",
        help="Path to the model prefix (without _model.npz or _state.pkl)")
    arg_parser.add_argument(
        "--normalize",
        action="store_true",
        default=False,
        help="Normalize log-prob with the word count")

    return arg_parser.parse_args()
def main():
    """Interactive chat loop: load the model named by the ``model_prefix``
    CLI argument, then repeatedly read a user utterance and print a
    beam-search-sampled reply.

    NOTE: this is Python 2 code (``raw_input``, ``print`` statement,
    ``cPickle``) and depends on the project modules ``state``,
    ``dialog_encdec`` and ``search``.
    """
    args = parse_args()
    state = prototype_state()

    # Model artifacts live at <prefix>_state.pkl (hyperparameters/state
    # dict) and <prefix>_model.npz (weights).
    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    model = DialogEncoderDecoder(state)
    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    logger.info("This model uses " + model.decoder_bias_type + " bias type")

    #sampler = search.RandomSampler(model)
    sampler = search.BeamSampler(model)

    # Start chat loop
    utterances = collections.deque()

    while (True):
       var = raw_input("User - ")

       # Increase number of utterances. We just set it to zero for simplicity so that model has no memory.
       # But it works fine if we increase this number
       while len(utterances) > 0:
           utterances.popleft()

       # Wrap the raw user text with end-of-utterance and speaker markers
       # so it matches the tokenization the model was trained on.
       current_utterance = [ model.end_sym_utterance ] + ['<first_speaker>'] + var.split() + [ model.end_sym_utterance ]

       utterances.append(current_utterance)

       #TODO Sample a random reply. To spice it up, we could pick the longest reply or the reply with the fewest placeholders...
       # Flatten the (currently single-utterance) context into one token list.
       seqs = list(itertools.chain(*utterances))

       #TODO Retrieve only replies which are generated for second speaker...
       sentences = sample(model, \
            seqs=[seqs], ignore_unk=args.ignore_unk, \
            sampler=sampler, n_samples=5)

       if len(sentences) == 0:
           raise ValueError("Generation error, no sentences were produced!")

       # Keep the top beam candidate; append it so a longer-memory variant
       # of this loop could feed it back as context.
       utterances.append(sentences[0][0].split())

       reply = sentences[0][0].encode('utf-8')
       print "AI - ", remove_speaker_tokens(reply)
if __name__ == "__main__":
    # Run with THEANO_FLAGS=mode=FAST_RUN,floatX=float32,allow_gc=True,scan.allow_gc=False,nvcc.flags=-use_fast_math python chat.py Model_Name
    # Models only run with float32.
    # Use an explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would silently skip this required precondition.
    if theano.config.floatX != 'float32':
        raise RuntimeError(
            "Models only run with floatX=float32 (got floatX=%s); "
            "set THEANO_FLAGS=floatX=float32" % theano.config.floatX)
    main()