Smurty version #170

Draft
wants to merge 10 commits into base: main
125 changes: 125 additions & 0 deletions src/agentlab/agents/generic_agent/generic_agent.py
@@ -260,3 +260,128 @@ def get_action_post_hoc(agent: GenericAgent, obs: dict, ans_dict: dict):
output += f"\n<action>\n{action}\n</action>"

return system_prompt, instruction_prompt, output


def get_action_post_hoc(agent: GenericAgent, step_info):
"""
Comment on lines +263 to +266

Collaborator Author: I think this can be safely ported to the unsupervised project.

Collaborator: We can remove this too. This is some old stuff I was trying.

    Get the action post-hoc for the agent.

    This function is used to get the action after the agent has already been run.
    Its goal is to recreate the prompt and the output of the agent a posteriori.
    The purpose is to build datasets for training the agents.

    Args:
        agent (GenericAgent): The agent for which the action is being determined.
        step_info: The recorded step info; its obs, agent_info, and action fields are
            used to rebuild the prompt and the agent output.

    Returns:
        Tuple[str, str, str]: The system prompt, the instruction prompt, and the
            reconstructed agent output, or an error string ("missing_bid" /
            "action_output_mismatch") when validation fails.
    """
    system_prompt = dp.SystemPrompt().prompt

    agent.obs_history.append(step_info.obs)

    main_prompt = MainPrompt(
        action_set=agent.action_set,
        obs_history=agent.obs_history,
        actions=agent.actions,
        memories=agent.memories,
        thoughts=agent.thoughts,
        previous_plan=agent.plan,
        step=agent.plan_step,
        flags=agent.flags,
    )

    max_prompt_tokens, max_trunc_itr = agent._get_maxes()

    fit_function = partial(
        dp.fit_tokens,
        max_prompt_tokens=max_prompt_tokens,
        model_name=agent.chat_model_args.model_name,
        max_iterations=max_trunc_itr,
    )

    instruction_prompt = fit_function(shrinkable=main_prompt)

    if isinstance(instruction_prompt, list):
        # NOTE: this is when we have images
        instruction_prompt = instruction_prompt[0]["text"]

    def parser(text):
        try:
            ans_dict = main_prompt._parse_answer(text)
        except ParseError as e:
            # these parse errors will be caught by the retry function and
            # the chat_llm will have a chance to recover
            return None, False, str(e)
        return ans_dict, True, ""

    og_agent_output = step_info.agent_info["chat_messages"][-1].content
    if og_agent_output.startswith("assistant\n"):
        og_agent_output = og_agent_output[10:]

    ans_dict = parser(og_agent_output)[0]
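    # NOTE: parser() returns (None, False, error) on a ParseError, so ans_dict can be
    # None here; the .get() calls below assume the stored output parses cleanly.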

    # self.plan = ans_dict.get("plan", self.plan)
    # self.plan_step = ans_dict.get("step", self.plan_step)
    # self.actions.append(ans_dict["action"])
    # self.memories.append(ans_dict.get("memory", None))
    # self.thoughts.append(ans_dict.get("think", None))

    agent_output = ""

    # TODO: validate this
    thought = ans_dict.get("think", None)
    agent.thoughts.append(thought)
    if thought is not None:
        agent_output += f"\n<think>\n{thought}\n</think>\n"

    agent.plan = ans_dict.get("plan", agent.plan)
    if agent.plan != "No plan yet":
        agent_output += f"\n<plan>\n{agent.plan}\n</plan>\n"

    agent.plan_step = ans_dict.get("step", agent.plan_step)
    if agent.plan_step != -1:
        agent_output += f"\n<step>{agent.plan_step}</step>\n"

    memory = ans_dict.get("memory", None)
    agent.memories.append(memory)
    if memory is not None:
        agent_output += f"\n<memory>\n{memory}\n</memory>\n"

    action = step_info.action
    agent.actions.append(action)
    if action is not None:
        agent_output += f"\n<action>\n{action}\n</action>"

    def find_bid(string):
        # Try to find 'a' followed by digits within single or double quotes
        match = re.search(r"[\"'](a\d+)[\"']", string)

        # If not found, search digits within single or double quotes
        if not match:
            match = re.search(r"[\"'](\d+)[\"']", string)

        # Return the matched pattern or None if no match found
        if match:
            return match.group(1)  # Return the match inside the quotes
        else:
            return None
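
    # Illustrative behaviour (not part of the PR): find_bid('click("a324")') returns "a324",
    # find_bid('select_option("12", "x")') returns "12", and find_bid("noop()") returns None.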

    # TODO: finish this
    bid = find_bid(action) if action is not None else None
    if bid is not None:
        if bid not in instruction_prompt:
            logging.info("Bid is not in the instruction prompt.")
            return "missing_bid"

    # NOTE: keep in mind the original agent output can be more verbose
    if agent_output not in og_agent_output:
        logging.info("Agent output does not exactly match the last chat message.")
        if not set(agent_output.split()).issubset(set(og_agent_output.split())):
            logging.info("Agent output does not match the last chat message.")
            return "action_output_mismatch"

    # TODO: make sure the bid is in the prompt
    return (system_prompt, instruction_prompt, agent_output)
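
A minimal usage sketch (not part of the PR) of how get_action_post_hoc could be driven to turn a recorded episode into training examples. It assumes a freshly constructed GenericAgent configured with the same flags as the original run and a list of step infos loaded from the saved experiment; build_sft_examples and the example dict keys are illustrative names, not existing AgentLab API.

def build_sft_examples(agent: GenericAgent, step_infos: list) -> list[dict]:
    """Rebuild (system, user, assistant) triples for supervised fine-tuning (sketch only)."""
    examples = []
    for step_info in step_infos:
        result = get_action_post_hoc(agent, step_info)
        if isinstance(result, str):
            # "missing_bid" or "action_output_mismatch": skip steps that fail validation
            continue
        system_prompt, instruction_prompt, agent_output = result
        examples.append(
            {"system": system_prompt, "user": instruction_prompt, "assistant": agent_output}
        )
    return examples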
26 changes: 18 additions & 8 deletions src/agentlab/analyze/agent_xray.py
Collaborator Author: I am not sure what these changes do. Worst case, you could copy this file and run it from your laptop. We could also make changes to the main x-ray if they are generic use cases.

Collaborator: There were some issues with the earlier version of agentlab, which this fixes. But now, I think we can discard these.
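
For reference, the height to max_height edits below appear to track a parameter rename in newer Gradio releases, where gr.DataFrame caps its rendered height via max_height. A minimal before/after sketch, assuming such a Gradio version:

import gradio as gr

with gr.Blocks() as demo:
    # Older Gradio accepted a fixed height:
    #   gr.DataFrame(height=500, show_label=False, interactive=False)
    # Newer Gradio (as used in this diff) caps the rendered height instead:
    gr.DataFrame(max_height=500, show_label=False, interactive=False)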

@@ -6,6 +6,7 @@
from logging import warning
from pathlib import Path

from finetuning.data import data_collection_library
import gradio as gr
import matplotlib.patches as patches
import matplotlib.pyplot as plt
@@ -220,7 +221,7 @@ def run_gradio(results_dir: Path):
content. You have to sort back with the Idx column to align the click with
the order."""
)
agent_table = gr.DataFrame(height=500, show_label=False, interactive=False)
agent_table = gr.DataFrame(max_height=500, show_label=False, interactive=False)
with gr.Tab("Select Task and Seed", id="Select Task"):
with gr.Row():
with gr.Column(scale=4):
@@ -236,7 +237,9 @@
)
refresh_results_button = gr.Button("↺", scale=0, size="sm")

task_table = gr.DataFrame(height=500, show_label=False, interactive=False)
task_table = gr.DataFrame(
max_height=500, show_label=False, interactive=False
)

with gr.Column(scale=2):
with gr.Accordion("Seed Selector (click for help)", open=False):
@@ -249,7 +252,9 @@
the order."""
)

seed_table = gr.DataFrame(height=500, show_label=False, interactive=False)
seed_table = gr.DataFrame(
max_height=500, show_label=False, interactive=False
)

with gr.Tab("Constants and Variables"):
with gr.Row():
@@ -261,7 +266,9 @@
**all** agents. They are displayed as a table with the name and value of the
constant."""
)
constants = gr.DataFrame(height=500, show_label=False, interactive=False)
constants = gr.DataFrame(
max_height=500, show_label=False, interactive=False
)
with gr.Column(scale=2):
with gr.Accordion("Variables", open=False):
gr.Markdown(
@@ -270,9 +277,11 @@
They are displayed as a table with the name, value and count of unique
values. A maximum of 3 different values are displayed."""
)
variables = gr.DataFrame(height=500, show_label=False, interactive=False)
variables = gr.DataFrame(
max_height=500, show_label=False, interactive=False
)
with gr.Tab("Global Stats"):
global_stats = gr.DataFrame(height=500, show_label=False, interactive=False)
global_stats = gr.DataFrame(max_height=500, show_label=False, interactive=False)

with gr.Row():
episode_info = gr.Markdown(label="Episode Info", elem_classes="my-markdown")
@@ -345,7 +354,7 @@ def run_gradio(results_dir: Path):
logs = gr.Code(language=None, **code_args)

with gr.Tab("Stats") as tab_stats:
stats = gr.DataFrame(height=500, show_label=False, interactive=False)
stats = gr.DataFrame(max_height=500, show_label=False, interactive=False)

with gr.Tab("Agent Info HTML") as tab_agent_info_html:
with gr.Row():
@@ -1131,7 +1140,8 @@ def plot_profiling(ax, step_info_list: list[StepInfo], summary_info: dict, progr


def main():
run_gradio(RESULTS_DIR)
# run_gradio(RESULTS_DIR)
run_gradio(data_collection_library.WORKARENA_V1_TRACES_PATHS[0])


if __name__ == "__main__":