Skip to content

Commit

Permalink
refactor: revise workflow for react agent (agiresearch#197)
Browse files Browse the repository at this point in the history
  • Loading branch information
dongyuanjushi authored Jul 30, 2024
1 parent b019261 commit bae8bac
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 34 deletions.
2 changes: 1 addition & 1 deletion aios/llm_core/llm_classes/gemini_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def process(self,
)
else:
if message_return_type == "json":
result = self.json_parse_format(result)
result = self.parse_json_format(result)
agent_process.set_response(
Response(
response_message=result,
Expand Down
2 changes: 2 additions & 0 deletions aios/llm_core/llm_classes/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ def process(self,
# print(response)
result = response[0].outputs[0].text

print(f"***** Result: {result} *****")

tool_calls = self.parse_tool_calls(
result
)
Expand Down
12 changes: 9 additions & 3 deletions pyopenagi/agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(self,
agent_process_factory,
log_mode: str
):

self.agent_name = agent_name
self.config = self.load_config()
self.tool_names = self.config["tools"]
Expand All @@ -47,6 +48,8 @@ def __init__(self,

self.tool_list = dict()
self.tools = []
self.tool_info = [] # simplified information of the tool: {"name": "xxx", "description": "xxx"}

self.load_tools(self.tool_names)

self.start_time = None
Expand Down Expand Up @@ -130,16 +133,19 @@ def snake_to_camel(self, snake_str):
return ''.join(x.title() for x in components)

def load_tools(self, tool_names):

for tool_name in tool_names:
org, name = tool_name.split("/")
module_name = ".".join(["pyopenagi", "tools", org, name])
class_name = self.snake_to_camel(name)

tool_module = importlib.import_module(module_name)
tool_class = getattr(tool_module, class_name)

self.tool_list[name] = tool_class()
self.tools.append(tool_class().get_tool_call_format())
tool_format = tool_class().get_tool_call_format()
self.tools.append(tool_format)
self.tool_info.append(
{"name": tool_format["function"]["name"], "description": tool_format["function"]["description"]}
)

def pre_select_tools(self, tool_names):
pre_selected_tools = []
Expand Down
43 changes: 14 additions & 29 deletions pyopenagi/agents/react_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ def __init__(self,
log_mode
)

# self.tool_list = {}

self.plan_max_fail_times = 3
self.tool_call_max_fail_times = 3

Expand All @@ -33,10 +31,11 @@ def build_system_instruction(self):
"".join(self.config["description"])
]
)

plan_instruction = "".join(
[
f'You are given the available tools from the tool list: {json.dumps(self.tools)} to help you solve problems.',
'Generate a plan of steps you need to take.',
f'You are given the available tools from the tool list: {json.dumps(self.tool_info)} to help you solve problems. ',
'Generate a plan of steps you need to take. ',
'The plan must follow the json format as: ',
'[',
'{"message": "message_value1","tool_use": [tool_name1, tool_name2,...]}',
Expand All @@ -46,45 +45,27 @@ def build_system_instruction(self):
'In each step of the planned workflow, you must select the most related tool to use',
'Followings are some plan examples:',
'[',
'{"message": "Gather information from arxiv", "tool_use": ["arxiv"]},',
'{"message", "Based on the gathered information, write a summarization", "tool_use": []}',
'];',
'[',
'{"message": "identify the tool that you need to call to obtain information.", "tool_use": ["imdb_top_movies", "imdb_top_series"]},',
'{"message", "based on the information, give recommendations for the user based on the constrains.", "tool_use": []}',
'{"message": "gather information from arxiv. ", "tool_use": ["arxiv"]},',
'{"message", "write a summarization based on the gathered information. ", "tool_use": []}',
'];',
'[',
'{"message": "identify the tool that you need to call to obtain information.", "tool_use": ["imdb_top_movies", "imdb_top_series"]},',
'{"message", "based on the information, give recommendations for the user based on the constrains.", "tool_use": []}',
'{"message": "identify the tool that you need to call to obtain information. ", "tool_use": ["imdb_top_movies", "imdb_top_series"]},',
'{"message", "give recommendations for the user based on the information. ", "tool_use": []}',
'];',
'[',
'{"message": "identify the tool that you need to call to obtain information.", "tool_use": ["imdb_top_movies", "imdb_top_series"]},'
'{"message", "based on the information, give recommendations for the user based on the constrains.", "tool_use": []}',
']'
]
)
# exection_instruction = "".join(
# [
# 'To execute each step in the workflow, you need to output as the following json format:',
# '{"[Action]": "Your action that is indended to take",',
# '"[Observation]": "What will you do? If you will call an external tool, give a valid tool call of the tool name and tool parameters"}'
# ]
# )

if self.workflow_mode == "manual":
self.messages.append(
{"role": "system", "content": prefix}
)
# self.messages.append(
# {"role": "user", "content": exection_instruction}
# )

else:
assert self.workflow_mode == "automatic"
self.messages.append(
{"role": "system", "content": prefix + plan_instruction}
)
# self.messages.append(
# {"role": "user", "content": plan_instruction}
# )


def automatic_workflow(self):
return super().automatic_workflow()
Expand Down Expand Up @@ -136,6 +117,10 @@ def run(self):
{"role": "assistant", "content": f"[Thinking]: The workflow generated for the problem is {json.dumps(workflow)}"}
)

self.messages.append(
{"role": "user", "content": "[Thinking]: Follow the workflow to solve the problem step by step. "}
)

self.logger.log(f"Generated workflow is: {workflow}\n", level="info")

if workflow:
Expand Down
2 changes: 1 addition & 1 deletion pyopenagi/tools/stability-ai/sdxl_turbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def get_tool_call_format(self):
"type": "function",
"function": {
"name": "sdxl_turbo",
"description": "generate images by calling sdxl-turbo model",
"description": "generate images with the given texts",
"parameters": {
"type": "object",
"properties": {
Expand Down

0 comments on commit bae8bac

Please sign in to comment.