-
Notifications
You must be signed in to change notification settings - Fork 6
/
target_chat.py
100 lines (88 loc) · 4.34 KB
/
target_chat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import tensorflow as tf
import importlib
import random
import os
from preprocess.data_utils import utter_preprocess, is_reach_goal
from utils.log_utils import create_logs, add_log
import time
class Target_Chat:
    """Interactive target-guided chat driven by a retrieval agent.

    Each session starts with a randomly drawn opening utterance and target
    keyword. The session is reported as ``[SUCCESS]`` when the last two
    utterances contain the target keyword (as judged by ``is_reach_goal``),
    or as ``[FAIL]`` once the agent has replied more than ``max_turns`` times.
    Every utterance and end message is written to the conversation log.
    """

    def __init__(self, model, config_model, config_data):
        """Build the agent, open a TF session and create the conversation log.

        Args:
            model: agent module exposing a ``Predictor`` class.
            config_model: agent config module; provides
                ``_conversation_save_path``.
            config_data: data config module; provides the simulation target
                keywords, the start corpus and ``_max_turns``.
        """
        self.agent = model.Predictor(config_model, config_data, 'test')
        self.sess = tf.Session(config=self.agent.gpu_config)
        # Let the agent initialise its state inside this session
        # (presumably restores weights; see Predictor.retrieve_init).
        self.agent.retrieve_init(self.sess)
        self.target_set = config_data._target_keywords_for_simulation
        self.start_corpus = config_data._start_corpus
        self.max_turns = config_data._max_turns
        self.conversation_save_path = config_model._conversation_save_path
        self.current_sessions = 0
        create_logs(self.conversation_save_path)

    def chat(self, user_input=None):
        """Advance the conversation by one agent turn.

        Call with ``user_input=None`` to start a new session: a fresh target
        and opening utterance are drawn, and the opening utterance becomes the
        agent's reply. Otherwise ``user_input`` is appended to the history and
        the agent retrieves a reply conditioned on the full history.

        Note: a session must have been started (``chat()`` with no argument)
        before passing ``user_input``; the history is created by ``_reset``.

        Returns:
            list[str]: ``[reply]`` while the session continues, or
            ``[reply, end_message]`` once the target is reached or the turn
            budget is exhausted.
        """
        responses = []
        if user_input is None:
            # Beginning of a conversation: new target and opening utterance.
            self._reset()
            reply = self.start_utterance
            add_log(self.conversation_save_path, '-------- Session {} --------'.format(self.current_sessions))
            add_log(self.conversation_save_path, 'START: {}'.format(reply))
        else:
            self.history.append(user_input)
            source = utter_preprocess(self.history, self.agent.data_config._max_seq_len)
            reply = self.agent.retrieve(source, self.sess)
            add_log(self.conversation_save_path, 'HUMAN: {}'.format(user_input), print_details=False)
            add_log(self.conversation_save_path, 'AGENT: {}'.format(reply))
        self.history.append(reply)
        responses.append(reply)
        self.current_turns += 1
        # Success when the target keyword appears in the last two utterances
        # (human input + agent reply, or just the opener at session start).
        if is_reach_goal(' '.join(self.history[-2:]), self.target_keyword):
            end_message = '[SUCCESS] target: \'{}\'.'.format(self.target_keyword)
            add_log(self.conversation_save_path, end_message)
            responses.append(end_message)
        # Out of the maximum number of agent turns.
        elif self.current_turns > self.max_turns:
            end_message = '[FAIL] out of the max dialogue turns, target: \'{}\'.'.format(self.target_keyword)
            add_log(self.conversation_save_path, end_message)
            responses.append(end_message)
        return responses

    def _reset(self):
        """Start a new session: clear history and draw a new opener and target."""
        self.current_turns = 0
        self.current_sessions += 1
        self.history = []
        # random.choice over a tuple draws the same element as
        # random.sample(x, 1)[0] for the same RNG state, but also accepts
        # sets, which random.sample rejects since Python 3.11.
        self.start_utterance = random.choice(tuple(self.start_corpus))
        self.target_keyword = random.choice(tuple(self.target_set))
        # Push the new target to the agent and clear its per-session state.
        self.agent.target = self.target_keyword
        self.agent.score = 0.
        self.agent.reply_list = []
def init_target_chat(agent_name, dataset):
    """Load the configs and agent module for a dataset and build a Target_Chat.

    Args:
        agent_name: name of the agent module, looked up both as
            ``model.<agent_name>`` and ``<config_dir><agent_name>``
            (e.g. 'neural_dkr').
        dataset: 'TGPC' (Target-Guided PersonaChat) or 'CWC' (Chinese Weibo
            Conversation). Also exported as the ``is_weibo`` environment
            variable (presumably read by the imported config/model modules).

    Returns:
        Target_Chat: the initialised chat instance.

    Raises:
        ValueError: if ``dataset`` is not 'TGPC' or 'CWC'.
    """
    # Target-Guided PersonaChat Dataset
    if dataset == 'TGPC':
        config_dir = 'config.'
        os.environ['is_weibo'] = 'False'
    # Chinese Weibo Conversation Dataset
    elif dataset == 'CWC':
        config_dir = 'config_weibo.'
        os.environ['is_weibo'] = 'True'
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError(
            "Unsupported dataset {!r}; expected 'TGPC' or 'CWC'.".format(dataset))
    config_data = importlib.import_module(config_dir + 'data_config')
    config_model = importlib.import_module(config_dir + agent_name)
    model = importlib.import_module('model.' + agent_name)
    # Target_Chat constructs its own Predictor, so none is built here.
    init_start_time = time.time()
    print("生成 TGODC-{}-{} Model 实例.................".format(agent_name, dataset))
    target_chat_instance = Target_Chat(model, config_model, config_data)
    print("TGODC-{}-{} Model 实例生成完成...............".format(agent_name, dataset))
    init_end_time = time.time()
    print('初始化花费时间: {:.2f}s'.format(init_end_time - init_start_time))
    return target_chat_instance
if __name__ == '__main__':
    # Command-line entry point: run FLAGS.times interactive chat sessions.
    flags = tf.flags
    flags.DEFINE_string('dataset', 'TGPC', 'The dataset, supports TGPC / CWC.')
    flags.DEFINE_string('agent', 'neural_dkr',
                        'The agent type, supports neural_dkr / kernel / matrix / '
                        'neural / retrieval / retrieval_stgy.')
    flags.DEFINE_integer('times', 10, 'Conversation times.')
    FLAGS = flags.FLAGS

    target_chat_instance = init_target_chat(FLAGS.agent, FLAGS.dataset)
    for _ in range(FLAGS.times):
        # The opening turn can already end the session (the opener may contain
        # the target keyword), so keep its result instead of discarding it.
        responses = target_chat_instance.chat()
        # chat() appends an end message as a second element once the session
        # is over; keep asking the human until then.
        while len(responses) < 2:
            responses = target_chat_instance.chat(input('HUMAN: '))