-
Notifications
You must be signed in to change notification settings - Fork 0
/
react.py
172 lines (135 loc) · 6.06 KB
/
react.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import os
from typing import List, Optional

import dotenv
import gym
import tiktoken
from langchain import OpenAI
from langchain.llms.base import BaseLLM
from langchain.prompts import PromptTemplate

from environment import QAEnv
from prompts import reflect_prompt, react_agent_prompt, react_reflect_agent_prompt, REFLECTION_HEADER
from fewshots import WEBTHINK_SIMPLE6, REFLECTIONS
# Load variables from a .env file (python-dotenv) into os.environ before any
# code below reads them, e.g. OPENAI_API_KEY used to build the default OpenAI
# clients for the agents.
dotenv.load_dotenv()
class ReactAgent:
    """
    A question answering ReAct Agent.

    The agent alternates Thought / Action / Observation steps. Each step is
    appended to ``self.scratchpad``. The LLM is prompted with the few-shot
    examples, the question and the scratchpad so far. Actions are executed by
    ``env``.
    """

    # Prompt size (in tiktoken tokens) above which the episode is treated as
    # truncated. Exposed as a class attribute so subclasses/instances can
    # override it. NOTE(review): presumably chosen to leave room for the
    # completion inside text-davinci-003's context window — confirm.
    max_prompt_tokens: int = 3896

    def __init__(self,
                 question: str,
                 env: QAEnv,
                 agent_prompt: PromptTemplate = react_agent_prompt,
                 react_llm: Optional[BaseLLM] = None,
                 ) -> None:
        """
        Args:
            question: The question the agent must answer.
            env: Environment that executes actions; it is reset here.
            agent_prompt: Template formatted with ``examples``, ``question``
                and ``scratchpad``.
            react_llm: LLM used for thoughts and actions. When ``None``, a
                ``text-davinci-003`` OpenAI client (temperature 0, 100 max
                tokens, stop at newline) is created for this agent. Building
                it here, rather than in the default argument, means the API
                key is read only when an agent is constructed, not at import
                time, and agents no longer share one default client.
        """
        if react_llm is None:
            react_llm = OpenAI(
                temperature=0,
                max_tokens=100,
                model_name="text-davinci-003",
                model_kwargs={"stop": "\n"},
                openai_api_key=os.environ['OPENAI_API_KEY'])
        self.question = question
        self.agent_prompt = agent_prompt
        self.react_examples = WEBTHINK_SIMPLE6
        self.env = env
        self.env.reset()
        self.reset()
        self.truncated, self.reward, self.terminated = False, False, False
        self.llm = react_llm
        # Token counter used only to measure prompt length.
        self.enc = tiktoken.encoding_for_model("text-davinci-003")

    def run(self, reset: bool = True) -> None:
        """Step the agent until the episode is terminated or truncated.

        Args:
            reset: When true, reset the environment and the scratchpad first.
        """
        if reset:
            self.env.reset()
            self.reset()
        while not (self.is_truncated() or self.is_terminated()):
            self.step()

    def step(self) -> None:
        """Run one Thought / Action / Observation cycle, printing each new line."""
        # Think
        self.scratchpad += f'\nThought {self.curr_step}:'
        self.scratchpad += ' ' + self.prompt_agent()
        print(self.scratchpad.split('\n')[-1])

        # Act
        self.scratchpad += f'\nAction {self.curr_step}:'
        action = self.prompt_agent()
        self.scratchpad += ' ' + action
        print(self.scratchpad.split('\n')[-1])

        # Observe: the env's fifth return value becomes the new step counter.
        self.scratchpad += f'\nObservation {self.curr_step}: '
        observation, self.reward, self.terminated, self.truncated, self.curr_step = self.env.step(action)
        self.scratchpad += observation
        print(self.scratchpad.split('\n')[-1])

    def prompt_agent(self) -> str:
        """Query the LLM with the current prompt and return a single-line reply."""
        return format_step(self.llm(self._build_agent_prompt()))

    def _build_agent_prompt(self) -> str:
        """Fill the agent prompt template with examples, question and scratchpad."""
        return self.agent_prompt.format(
            examples=self.react_examples,
            question=self.question,
            scratchpad=self.scratchpad)

    def is_terminated(self) -> bool:
        """Delegate to the environment's termination check."""
        return self.env.is_terminated()

    def is_correct(self) -> bool:
        """Delegate to the environment's correctness check."""
        return self.env.is_correct()

    def is_truncated(self) -> bool:
        """True if the env reports truncation or the prompt exceeds ``max_prompt_tokens``."""
        return self.env.is_truncated() or (
            len(self.enc.encode(self._build_agent_prompt())) > self.max_prompt_tokens)

    def reset(self) -> None:
        """Clear the scratchpad and restart step numbering at 1."""
        self.scratchpad = ''
        self.curr_step = 1
class ReactReflectAgent(ReactAgent):
    """
    A question answering Self-Reflecting React Agent.

    When a previous episode ended (terminated or truncated) without a correct
    answer, ``run`` first asks ``reflect_llm`` for a reflection on the
    (length-limited) scratchpad. All reflections gathered so far are included
    in later agent prompts.
    """

    def __init__(self,
                 question: str,
                 env: QAEnv,
                 agent_prompt: PromptTemplate = react_reflect_agent_prompt,
                 reflect_prompt: PromptTemplate = reflect_prompt,
                 react_llm: Optional[BaseLLM] = None,
                 reflect_llm: Optional[BaseLLM] = None,
                 ) -> None:
        """
        Args:
            question: The question the agent must answer.
            env: Environment that executes actions.
            agent_prompt: Template formatted with ``examples``,
                ``reflections``, ``question`` and ``scratchpad``.
            reflect_prompt: Template formatted with ``examples``,
                ``question`` and ``scratchpad``.
            react_llm: LLM for thoughts/actions. When ``None``, a
                ``text-davinci-003`` client (temperature 0, 100 max tokens,
                stop at newline) is created here.
            reflect_llm: LLM for reflections. When ``None``, a
                ``text-davinci-003`` client (temperature 0, 250 max tokens)
                is created here.

        Both defaults are built per instance, not in the default argument, so
        the API key is not read at import time and no client is shared.
        """
        if react_llm is None:
            react_llm = OpenAI(
                temperature=0,
                max_tokens=100,
                model_name="text-davinci-003",
                model_kwargs={"stop": "\n"},
                openai_api_key=os.environ['OPENAI_API_KEY'])
        if reflect_llm is None:
            reflect_llm = OpenAI(
                temperature=0,
                max_tokens=250,
                model_name="text-davinci-003",
                openai_api_key=os.environ['OPENAI_API_KEY'])
        super().__init__(question, env, agent_prompt, react_llm)
        self.reflect_llm = reflect_llm
        self.reflect_prompt = reflect_prompt
        self.reflect_examples = REFLECTIONS
        self.reflections: List[str] = []

    def run(self, reset: bool = True) -> None:
        """Reflect on a finished-but-incorrect previous attempt, then run again."""
        if (self.is_terminated() or self.is_truncated()) and not self.is_correct():
            self.reflect()
        super().run(reset)

    def reflect(self) -> None:
        """Ask the reflection LLM for a new reflection and store it."""
        self.reflections.append(self.prompt_reflection())

    def prompt_reflection(self) -> str:
        """Query the reflection LLM and return its reply collapsed to one line."""
        return format_step(self.reflect_llm(self._build_reflection_prompt()))

    def _build_reflection_prompt(self) -> str:
        """Fill the reflection template with examples, question and shortened scratchpad."""
        return self.reflect_prompt.format(
            examples=self.reflect_examples,
            question=self.question,
            scratchpad=self._format_scratchpad())

    def _build_agent_prompt(self) -> str:
        """Fill the agent template, including all reflections gathered so far."""
        return self.agent_prompt.format(
            examples=self.react_examples,
            reflections=format_reflections(self.reflections),
            question=self.question,
            scratchpad=self.scratchpad)

    def _format_scratchpad(self, max_tokens: int = 1600) -> str:
        """Return the scratchpad shortened to at most ``max_tokens`` tokens.

        Lines are shortened from the most to the fewest tokens until the joined
        text fits. A line containing ``':'`` keeps its label (the text before the
        first ``':'``) followed by ``': ...'``. A line without ``':'`` is
        replaced by ``'...'``, so it can never grow. If the text is still over
        budget after every line has been shortened, it is returned as is
        instead of raising.
        """
        encode = self.enc.encode
        lines = self.scratchpad.split('\n')
        # Indices from longest to shortest line; the stable sort keeps ties in
        # their original order.
        order = sorted(range(len(lines)), key=lambda i: len(encode(lines[i])), reverse=True)
        for i in order:
            if len(encode('\n'.join(lines))) <= max_tokens:
                break
            label, sep, _ = lines[i].partition(':')
            lines[i] = label + ': ...' if sep else '...'
        return '\n'.join(lines)
### String Operations ###
def format_reflections(reflections: List[str]) -> str:
    """Render reflections as a bulleted section for the agent prompt.

    Args:
        reflections: Reflection strings; each is stripped before rendering.

    Returns:
        ``''`` when there are no reflections (any empty sequence or ``None``).
        Otherwise ``REFLECTION_HEADER`` followed by ``'Reflections:'`` and one
        ``'- '`` bullet per reflection, separated by newlines.
    """
    # Truthiness rather than ``== []``: an empty tuple used to yield a header
    # with a dangling empty bullet, and ``None`` crashed in ``join``.
    if not reflections:
        return ''
    return REFLECTION_HEADER + 'Reflections:\n- ' + '\n- '.join(r.strip() for r in reflections)
def format_step(step: str) -> str:
    """Collapse an LLM completion into a single line.

    Surrounding whitespace is trimmed. Interior newlines are deleted
    outright, not replaced by spaces.
    """
    trimmed = step.strip()
    return ''.join(trimmed.split('\n'))