run_agents.py
import copy
import json
import os
import traceback
from datetime import datetime
from queue import Queue
from typing import Any, Callable, Dict, List, Optional

import pandas as pd
from datasets import Dataset
from langchain.agents import AgentExecutor
from langchain.tools.base import ToolException
from tqdm import tqdm
from transformers.agents.agents import AgentError


def run_agent(
    example: Dict,
    agent_executor: AgentExecutor,
    agent_name: str,
    agent_call_function: Callable,
    writer_queue: Optional[Queue] = None,
    **kwargs,
) -> dict:
    """Runs the agent on a single example and returns the annotated result."""
    start_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    augmented_question = example["augmented_question"]
    exception = None
    try:
        # Run the executor agent on the augmented question.
        response = agent_call_function(agent_executor, augmented_question, **kwargs)

        # Check for parsing errors, which indicate the LLM failed to follow the ReAct format.
        # This can be caused by a malformed tool call or by broken ReAct structure
        # (i.e. Thought, Action, Observation, etc.).
        parsing_error = any(
            "Could not parse LLM output" in step
            for step in response["intermediate_steps"]
        )

        # Check whether the agent hit its iteration or time limit.
        iteration_limit_exceeded = (
            "Agent stopped due to iteration limit or time limit." in response["output"]
        )
        raised_exception = False
    except (ValueError, ToolException) as e:
        print("Error on ", augmented_question, e)
        response = {"output": None, "intermediate_steps": None, "metrics": {}}
        parsing_error = False
        iteration_limit_exceeded = False
        exception = e
        raised_exception = True
    end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    intermediate_steps = response["intermediate_steps"]
    metrics = response["metrics"]
    annotated_example = {
        "agent_name": agent_name,
        "question": example["question"],
        "augmented_question": augmented_question,
        "prediction": response["output"],
        "intermediate_steps": intermediate_steps,
        "parsing_error": parsing_error,
        "iteration_limit_exceeded": iteration_limit_exceeded,
        "agent_error": str(exception) if raised_exception else None,
        "start_time": start_time,
        "end_time": end_time,
        "level": example["level"],
        "true_answer": example["true_answer"],
        "metrics": metrics,
    }
    # Merge the annotations into a copy of the original example so no input fields are lost.
    temp = copy.deepcopy(example)
    temp.update(annotated_example)
    annotated_example = temp

    if writer_queue:
        writer_queue.put(annotated_example)
    return annotated_example
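

# A minimal sketch (not part of the original script) of an `agent_call_function`
# compatible with `run_agent` above. `run_agent` expects the callable to return a
# dict with "output", "intermediate_steps", and "metrics" keys; the exact way the
# AgentExecutor is invoked here, the stringified steps, and the empty "metrics"
# dict are assumptions for illustration, not the repository's own implementation.
def example_agent_call_function(agent_executor: AgentExecutor, question: str, **kwargs) -> dict:
    # Assumes the executor was built with return_intermediate_steps=True so that
    # its output dict contains an "intermediate_steps" entry.
    result = agent_executor.invoke({"input": question})
    return {
        "output": result.get("output"),
        # Stringify each step so the substring check for parsing errors in run_agent applies.
        "intermediate_steps": [str(step) for step in result.get("intermediate_steps", [])],
        "metrics": {},  # placeholder: no token or latency metrics are collected in this sketch
    }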


def serialize_agent_error(obj):
    """Fallback JSON serializer for objects json.dump cannot handle natively."""
    if isinstance(obj, AgentError):
        return {"error_type": obj.__class__.__name__, "message": obj.message}
    else:
        return str(obj)
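

# For illustration (not in the original script): json.dump only calls this function
# for objects it cannot serialize itself, e.g.
#     json.dumps({"agent_error": ValueError("boom")}, default=serialize_agent_error)
# returns '{"agent_error": "boom"}', since non-AgentError objects fall back to str().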


def answer_questions(
    dataset: Dataset,
    agent: AgentExecutor,
    agent_name: str,
    agent_call_function: Callable,
    output_folder: str = "output",
) -> List[Dict[str, Any]]:
    """
    Evaluates the agent on a given dataset, caching results to a JSONL file so runs can be resumed.

    Args:
        dataset (Dataset): The dataset to test the agent on.
        agent (AgentExecutor): The agent.
        agent_name (str): The name of the agent model; used to name the output file.
        agent_call_function (Callable): The function used to call the agent on a single question.
        output_folder (str): The folder in which the JSONL results file is written.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries containing the evaluation results for each example in the dataset.
        Each dictionary includes the agent name, question, augmented question, prediction, intermediate steps,
        tool call parsing error flag, iteration limit exceeded flag, agent error (if any), start and end times,
        and example metadata (level, true answer).
    """
    os.makedirs(output_folder, exist_ok=True)
    output_path = f"{output_folder}/{agent_name}.jsonl"
    print(f"Loading answers from {output_path}...")
    if os.path.exists(output_path):
        # Resume from a previous run: already-answered questions are skipped below.
        results = pd.read_json(output_path, lines=True).to_dict(orient="records")
        print(f"Found {len(results)} previous results!")
    else:
        print("Found no usable records! 🤔 Starting new")
        results = []
    results_df = pd.DataFrame(results)

    for example in tqdm(dataset, total=len(dataset)):
        if len(results_df) > 0:
            if example["question"] in results_df["question"].unique():
                continue
        try:
            prompt_use_files = ""
            if example["file_name"]:
                prompt_use_files += "\n\nTo answer the question above, you will have to use these attached files:"
                prompt_use_files += f"\nAttached file: {example['file_name']}"
            else:
                prompt_use_files += "\n\nYou have been given no local files to access."
            example["augmented_question"] = (
                """It is paramount that you complete this task and provide a correct answer.
Give it all you can: I know for a fact that you have access to all the relevant tools to solve it. Failure or 'I cannot answer' will not be tolerated, success will be rewarded.
Here is the task:
"""
                + example["question"]
                + prompt_use_files
            )
            # Run the agent on the augmented question.
            result = run_agent(
                example=example,
                agent_executor=agent,
                agent_name=agent_name,
                agent_call_function=agent_call_function,
            )
        except Exception:
            # Record the traceback and keep going instead of aborting the whole run.
            error_trace = ("\n\n" + traceback.format_exc()).strip()
            result = example
            result["error_trace"] = error_trace
        # Add in example metadata.
        result.update(
            {
                "true_answer": example["true_answer"],
                "level": example["level"],
            }
        )
        results.append(result)
        # Rewrite the full results file after every example so progress is never lost.
        with open(output_path, "w") as f:
            for d in results:
                json.dump(d, f, default=serialize_agent_error)
                f.write("\n")  # add a newline for JSONL format
    return results
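

# Hypothetical end-to-end usage (a sketch, not part of the original script). The
# tiny in-memory dataset only illustrates the fields the code above reads
# ("question", "file_name", "level", "true_answer"); `build_my_agent()` is a
# made-up helper standing in for however the AgentExecutor is actually constructed.
# if __name__ == "__main__":
#     eval_dataset = Dataset.from_list(
#         [
#             {"question": "What is 2 + 2?", "file_name": "", "level": 1, "true_answer": "4"},
#         ]
#     )
#     agent_executor = build_my_agent()  # hypothetical: returns a configured AgentExecutor
#     results = answer_questions(
#         dataset=eval_dataset,
#         agent=agent_executor,
#         agent_name="my_agent",
#         agent_call_function=example_agent_call_function,
#         output_folder="output",
#     )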