agent_eval.py
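"""Agent-based evaluation helpers (an AgentEval-style pipeline).

The functions below generate evaluation criteria for a task with a critic agent
(optionally refined by a subcritic), can summarize criteria collected across
multiple runs, and quantify how well a given test case performs against those
criteria with a quantifier agent.
"""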
from typing import Dict, List, Optional
from autogen_core import CancellationToken
from autogen_core.models import ChatCompletionClient
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_agentchat.conditions import TextMentionTermination, MaxMessageTermination
from autogen_agentchat.ui import Console
from autogen_agentchat.messages import TextMessage
from .criterion import Criterion
from .critic_agent import CriticAgent
from .quantifier_agent import QuantifierAgent
from .subcritic_agent import SubCriticAgent
from .verifier_agent import CriticSummarizerAgent
from .quantification import Quantification
from .task import Task

async def generate_criteria(
    model_client: ChatCompletionClient,
    task: Optional[Task] = None,
    additional_instructions: str = "",
    max_round: int = 2,
    use_subcritic: bool = False,
) -> List[Criterion]:
"""
Creates a list of criteria for evaluating the utility of a given task.
Args:
model_client (ChatCompletionClient): The model client for ChatCompletion inference.
task (Task): The task to evaluate.
additional_instructions (str): Additional instructions for the criteria agent.
max_round (int): The maximum number of rounds to run the conversation.
use_subcritic (bool): Whether to use the subcritic agent to generate subcriteria.
Returns:
list: A list of Criterion objects for evaluating the utility of the given task.
"""
critic = CriticAgent(
system_message=CriticAgent.DEFAULT_SYSTEM_MESSAGE + "\n" + additional_instructions,
model_client=model_client,
)
agents = [critic]
if use_subcritic:
subcritic = SubCriticAgent(
model_client=model_client,
)
agents.append(subcritic)
text_mention_termination = TextMentionTermination("TERMINATE")
max_messages_termination = MaxMessageTermination(max_messages=max_round)
termination = text_mention_termination | max_messages_termination
team = RoundRobinGroupChat(agents, termination_condition=termination)
group_chat_messages = team.run_stream(task=task.get_sys_message())
criteria_messages = await Console(group_chat_messages)
content = criteria_messages.messages[-1].content
    # Strip any extra text around the returned JSON list before parsing.
content = content[content.find("[") : content.rfind("]") + 1]
criteria = Criterion.parse_json_str(content)
return criteria
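

# Illustrative usage sketch, not called anywhere in this module: generate criteria for
# a task with an OpenAI-backed client. OpenAIChatCompletionClient comes from the
# separate autogen_ext package, and the Task fields below follow the upstream
# AgentEval implementation; both are assumptions and may need adjusting for this repo.
async def _example_generate_criteria() -> List[Criterion]:
    from autogen_ext.models.openai import OpenAIChatCompletionClient

    client = OpenAIChatCompletionClient(model="gpt-4o")
    task = Task(
        name="math_problem_solving",
        description="Solve grade-school math word problems.",
        successful_response="The final answer is 42.",
        failed_response="I am not sure how to solve this.",
    )
    # Returns the Criterion list parsed from the critic conversation's final JSON message.
    return await generate_criteria(model_client=client, task=task, use_subcritic=True)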

async def generate_summarized_criteria_multiple_seeds(
    model_client: ChatCompletionClient,
    task: Optional[Task] = None,
    additional_instructions: str = "",
    max_round: int = 2,
    use_subcritic: bool = False,
    seed: int = 10,
) -> List[Criterion]:
"""
    Creates a list of summarized criteria by running generate_criteria multiple times (seed times) and then summarizing the results.
Args:
model_client (ChatCompletionClient): The model client for ChatCompletion inference.
task (Task): The task to evaluate.
additional_instructions (str): Additional instructions for the criteria agent.
max_round (int): The maximum number of rounds to run the conversation.
use_subcritic (bool): Whether to use the subcritic agent to generate subcriteria.
seed (int): The number of times to run the generate_criteria function.
Returns:
list: A list of Criterion objects for evaluating the utility of the given task.
"""
all_criteria = ""
    for _ in range(seed):
criteria = await generate_criteria(
model_client=model_client,
task=task,
additional_instructions=additional_instructions,
max_round=max_round,
use_subcritic=use_subcritic,
)
all_criteria += Criterion.write_json(criteria) + "\n"
summarized_criteria_agent = CriticSummarizerAgent(
model_client=model_client,
)
response = await summarized_criteria_agent.on_messages(
[TextMessage(
source="summarized_criteria_user",
content=all_criteria,
)],
cancellation_token=CancellationToken(),
)
content = response.chat_message.content
content = content[content.find("[") : content.rfind("]") + 1]
summarized_criteria = Criterion.parse_json_str(content)
return summarized_criteria
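

# Illustrative usage sketch, not called anywhere in this module: run criteria generation
# several times and let CriticSummarizerAgent merge the runs into one list. The client
# construction is the same assumption as in the sketch above.
async def _example_generate_summarized_criteria(task: Task) -> List[Criterion]:
    from autogen_ext.models.openai import OpenAIChatCompletionClient

    client = OpenAIChatCompletionClient(model="gpt-4o")
    # `seed` is the number of independent generate_criteria runs to summarize.
    return await generate_summarized_criteria_multiple_seeds(
        model_client=client, task=task, seed=5
    )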

async def quantify_criteria(
    model_client: ChatCompletionClient,
    criteria: Optional[List[Criterion]] = None,
    task: Optional[Task] = None,
    test_case: str = "",
    ground_truth: str = "",
) -> Dict[str, str]:
"""
Quantifies the performance of a system using the provided criteria.
Args:
model_client (ChatCompletionClient): The model client for ChatCompletion inference.
        criteria (List[Criterion]): A list of criteria for evaluating the utility of a given task.
task (Task): The task to evaluate.
test_case (str): The test case to evaluate.
ground_truth (str): The ground truth for the test case.
Returns:
        dict: A dictionary where the keys are the criteria and the values are the assessed performance based on the accepted values for each criterion.
"""
quantifier = QuantifierAgent(
model_client=model_client,
)
response = await quantifier.on_messages(
[TextMessage(
source="quantifier_user",
content=task.get_sys_message()
+ "Evaluation dictionary: "
+ Criterion.write_json(criteria)
+ "actual test case to evaluate: "
+ test_case,
)],
cancellation_token=CancellationToken(),
)
quantified_results = response.chat_message.content
return {"actual_success": ground_truth, "estimated_performance": quantified_results}

async def quantify_criteria_multiple_seeds(
    model_client: ChatCompletionClient,
    criteria: Optional[List[Criterion]] = None,
    task: Optional[Task] = None,
    test_case: str = "",
    ground_truth: str = "",
    seeds: int = 10,
) -> Dict[int, List[Quantification]]:
"""
Quantifies the performance of a system using the provided criteria multiple times (seeds) and returns all the results.
Args:
model_client (ChatCompletionClient): The model client for ChatCompletion inference.
        criteria (List[Criterion]): A list of criteria for evaluating the utility of a given task.
task (Task): The task to evaluate.
test_case (str): The test case to evaluate.
ground_truth (str): The ground truth for the test case.
seeds (int): The number of times to run the quantify_criteria function.
Returns:
dict: A dictionary where the keys are the seeds and the values are the quantified results.
"""
results = {}
for i in range(seeds):
response = await quantify_criteria(
model_client=model_client,
criteria=criteria,
task=task,
test_case=test_case,
ground_truth=ground_truth,
)
results[i] = Quantification.parse_json_str(response["estimated_performance"])
return results
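

# Illustrative usage sketch, not called anywhere in this module: score the same test case
# over several seeds and inspect the per-seed Quantification lists, e.g. to gauge how
# stable the quantifier's judgments are. Client setup is the same assumption as above.
async def _example_quantify_multiple_seeds(
    task: Task, criteria: List[Criterion], test_case: str, ground_truth: str
) -> None:
    from autogen_ext.models.openai import OpenAIChatCompletionClient

    client = OpenAIChatCompletionClient(model="gpt-4o")
    runs = await quantify_criteria_multiple_seeds(
        model_client=client,
        criteria=criteria,
        task=task,
        test_case=test_case,
        ground_truth=ground_truth,
        seeds=3,
    )
    for run_index, quantifications in runs.items():
        print(f"seed {run_index}: {len(quantifications)} quantified criteria")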