# AIManager.py
import configparser
from os import path
import openai
import re
from queue import Queue
from threading import Thread
# import subprocess
# import rospy
# from std_msgs.msg import String


class AIManager:
    l = None

    def __init__(self, logger, config_file, display_man, sound_man):
        self.l = logger
        self.config = configparser.ConfigParser()
        self.config.read(config_file)
        self.parse_config()
        self.display_man = display_man
        self.sound_man = sound_man
        self.result_outputs = Queue()
        self.valid_command = {"painting1": 0, "painting1+painting2": 0}
        # rospy.init_node("send_audio", anonymous=True)
        # self.pub = rospy.Publisher("audio", String, queue_size=100)

    # This is called from the main thread, so we pass the work off to a
    # separate thread to avoid hanging the program.
    def handle_command(self, text):
        async_thread = Thread(target=self.handle_command_async,
                              args=(text,))
        async_thread.start()
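
    # Hedged usage sketch (the call site below is an assumption, not code from
    # this project): a speech front end would hand recognized text straight to
    # handle_command, e.g.
    #
    #     ai_manager.handle_command("turn on painting1")
    #
    # and parsed results can later be drained from ai_manager.result_outputs.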
    def handle_command_async(self, text):
        prompt = self.prompt_text
        # print("$$$$$$$$$$$$$$$$$$$$$$$$$$")
        print("############################# before if: {}\n prompt name: {}".format(prompt, self.prompt_name))
        # Could add separate handlers for different prompts if desired...
        if self.prompt_name == "prompt1":
            prompt = prompt + text + "\n" + "Object:"
        print("############################# after if: {}".format(prompt))
        response = openai.Completion.create(
            engine="davinci",
            prompt=prompt,
            temperature=self.temperature,
            max_tokens=self.resp_len,
            # top_p=1.0,
            top_p=self.top_p,
            frequency_penalty=0.0,
            presence_penalty=0.0,
            stop=[self.stop_token]
        )
        result = {}
        # Currently using prompt1; this runs every time that prompt is active
        if self.prompt_name == "prompt1":
            result = self.parse_prompt1_response(response)
        if result is None:
            self.sound_man.play_blocking("ai failed")
        else:
            self.display_man.got_ai_result(result)
            self.result_outputs.put(result)
        self.l.log(f"Response: {result}", "DEBUG")
    def parse_prompt1_response(self, response):
        # There's lots of weird stuff that GPT-3 could return, so we try to
        # filter out as many of these errors as possible before attempting
        # to parse the command.
        try:
            print("!!!!!!!!!!!!!!!!!! Response: {}".format(response))
            reason = response["choices"][0]["finish_reason"]
            text = response["choices"][0]["text"]
            self.l.log(reason, "CHECK")
            self.l.log(text, "CHECK")
            if reason != "stop":
                self.l.log(f"Response didn't stop: {response}", "RUN")
                raise ValueError
            elif "State" not in text or "Response" not in text:
                self.l.log(f"Bad response text: {response}", "RUN")
                raise ValueError
            # Sometimes this can help if the text is only slightly off
            if text[-1] != "\n":
                text = text + "\n"
            # Response looks halfway decent; let's parse
            result = {}
            result["prompt_name"] = self.prompt_name
            result["object"] = self.between_strings(" ", "\n", text).lower()
            obj = result["object"]
            print(obj)
            # ROS publishing stays disabled while the rospy import above is
            # commented out; re-enable both together.
            # if obj in self.valid_command:
            #     subprocess.call(['sh', './helper.sh', obj])
            # if obj:
            #     msg = String()
            #     msg.data = obj
            #     self.pub.publish(msg)
            result["state"] = \
                self.between_strings("\nDesired State: ", "\n", text).lower()
            result["quip"] = \
                self.between_strings("\nResponse: ", "\n", text)
            return result
        except KeyError as e:
            self.l.log(e, "DEBUG")
            self.l.log("KeyError in Response", "DEBUG")
        except AttributeError as e:
            self.l.log(e, "DEBUG")
            self.l.log("AttributeError in Response", "DEBUG")
        except ValueError as e:
            self.l.log(e, "DEBUG")
            self.l.log("ValueError in Response", "DEBUG")
        self.l.log(f"Error parsing response: {response}", "RUN")
        self.display_man.ai_result_failed(response)
        return None
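
    # A hedged sketch of the completion text that parse_prompt1_response
    # (above) expects. The real wording comes from the prompt file, which is
    # not part of this module, so the example below is an assumption built
    # from the fields the parser extracts ("Object", "Desired State",
    # "Response"):
    #
    #     " painting1\nDesired State: on\nResponse: Lighting up the first painting.\n"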
    def parse_config(self):
        prompt_file = self.config["AI"]["prompt_file_path"]
        with open(prompt_file, "r") as f:
            prompt = f.read()
        # Emacs adds a newline on save
        if prompt[-1] == "\n":
            prompt = prompt[:-1]
        self.prompt_text = prompt
        # From test/val1.txt this grabs val1
        prompt_name = path.split(prompt_file)[-1].split(".")[0]
        self.temperature = float(self.config[prompt_name]["temperature"])
        self.top_p = float(self.config[prompt_name]["top_p"])
        self.resp_len = int(self.config[prompt_name]["response_length"])
        self.stop_token = self.config[prompt_name]["stop_token"]
        self.prompt_name = prompt_name
        # You'll get an error here if the key file doesn't exist
        with open(self.config["AI"]["key_path"], "r") as f:
            key_text = f.read()
        if key_text[-1] == "\n":
            key_text = key_text[:-1]
        openai.api_key = key_text
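
    # A hedged sketch of the config layout parse_config reads. Section and key
    # names are taken from the lookups above; the values are placeholders, not
    # the project's real settings:
    #
    #     [AI]
    #     prompt_file_path = prompts/prompt1.txt
    #     key_path = openai_key.txt
    #
    #     [prompt1]
    #     temperature = 0.7
    #     top_p = 1.0
    #     response_length = 64
    #     stop_token = ###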
    # Get the text between two strings
    def between_strings(self, start, end, text):
        search_re = start + "(.*)" + end
        return re.search(search_re, text).group(1)
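

# Minimal, runnable sketch of the parsing helper above. The sample completion
# text is an assumption that mirrors the fields parse_prompt1_response pulls
# out; AIManager.__new__ is used so no config file or API key is needed here.
if __name__ == "__main__":
    sample = " painting1\nDesired State: on\nResponse: Lighting up the first painting.\n"
    demo = AIManager.__new__(AIManager)
    print(demo.between_strings(" ", "\n", sample))                  # painting1
    print(demo.between_strings("\nDesired State: ", "\n", sample))  # on
    print(demo.between_strings("\nResponse: ", "\n", sample))       # Lighting up the first painting.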