From df9524f2bf8af39894cf1b844bca8d993cc87453 Mon Sep 17 00:00:00 2001 From: YoanSallami Date: Thu, 4 Jul 2024 12:48:37 +0200 Subject: [PATCH] Add user profile tools and put the profile into the state --- hybridagi/agents/interpreter.py | 19 ++++--- hybridagi/hybridagi.py | 14 +++-- hybridagi/tools/__init__.py | 6 +++ hybridagi/tools/ask_user.py | 2 +- hybridagi/tools/read_user_profile.py | 35 ++++++++++++ hybridagi/tools/update_user_profile.py | 73 ++++++++++++++++++++++++++ hybridagi/types/state.py | 6 +-- 7 files changed, 138 insertions(+), 17 deletions(-) create mode 100644 hybridagi/tools/read_user_profile.py create mode 100644 hybridagi/tools/update_user_profile.py diff --git a/hybridagi/agents/interpreter.py b/hybridagi/agents/interpreter.py index cd6a941..3d450e3 100644 --- a/hybridagi/agents/interpreter.py +++ b/hybridagi/agents/interpreter.py @@ -95,6 +95,7 @@ def __init__( # DSPy reasoners # The interpreter model used to navigate, only contains decision signatures # With that, DSPy should better optimize the graph navigation task + self.current_hop = 0 self.decision_hop = 0 self.decisions = [ dspy.ChainOfThought(DecisionSignature) for i in range(0, self.max_iters) @@ -109,7 +110,7 @@ def run_step(self) -> Union[AgentAction, AgentDecision, ProgramCall, ProgramEnd] """Method to run a step of the program""" current_node = self.agent_state.get_current_node() if current_node: - self.agent_state.current_hop += 1 + self.current_hop += 1 if current_node.labels[0] == "Program": try: program_purpose = current_node.properties["name"] @@ -250,7 +251,7 @@ def act( else: self.agent_state.variables[output] = dict(prediction)[list(dict(prediction).keys())[0]] action = AgentAction( - hop = self.agent_state.current_hop, + hop = self.current_hop, objective = self.agent_state.objective, purpose = purpose, tool = tool, @@ -305,7 +306,7 @@ def decide(self, purpose: str, question:str, options: List[str], inputs: List[st ) next_node = result.result_set[0][0] decision = AgentDecision( - hop = self.agent_state.current_hop, + hop = self.current_hop, objective = self.agent_state.objective, purpose = purpose, question = question, @@ -316,9 +317,9 @@ def decide(self, purpose: str, question:str, options: List[str], inputs: List[st self.agent_state.set_current_node(next_node) return decision - def forward(self, objective: str): + def forward(self, objective: str, user_profile: str = "An average user"): """DSPy forward prediction""" - self.start(objective) + self.start(objective, user_profile = user_profile) for i in range(self.max_iters): if not self.finished(): self.run_step() @@ -377,7 +378,7 @@ def call_program( first_node = self.program_memory.get_next_node(starting_node, program_called) self.agent_state.call_program(first_node, program_called) program_call = ProgramCall( - hop = self.agent_state.current_hop, + hop = self.current_hop, purpose = purpose, program = program_name, ) @@ -394,19 +395,21 @@ def end_current_program(self) -> Optional[ProgramEnd]: if next_node: self.agent_state.set_current_node(next_node) program_end = ProgramEnd( - hop = self.agent_state.current_hop, + hop = self.current_hop, program = program_name, ) return program_end - def start(self, objective: str): + def start(self, objective: str, user_profile: str = "An average User"): """Start the interpreter""" self.agent_state.init() + self.current_hop = 0 self.decision_hop = 0 self.agent_state.chat_history.append( {"role": "User", "message": objective} ) self.agent_state.objective = objective + self.agent_state.user_profile = user_profile first_step = self.call_program(objective, self.entrypoint) self.agent_state.program_trace.append(str(first_step)) if self.trace_memory: diff --git a/hybridagi/hybridagi.py b/hybridagi/hybridagi.py index 35ecd8f..a66ae16 100644 --- a/hybridagi/hybridagi.py +++ b/hybridagi/hybridagi.py @@ -43,11 +43,13 @@ ReadFileTool, ReadProgramTool, + ReadUserProfileTool, RevertTraceTool, SpeakTool, UpdateObjectiveTool, + UpdateUserProfileTool, UploadTool, WriteFileTool, @@ -182,7 +184,6 @@ def __init__( ), AskUserTool( agent_state = self.agent_state, - user_profile = self.user_profile, ), BrowseWebsiteTool(), CallProgramTool( @@ -225,6 +226,9 @@ def __init__( ReadProgramTool( program_memory = self.program_memory, ), + ReadUserProfileTool( + agent_state = self.agent_state, + ), RevertTraceTool( agent_state = self.agent_state, ), @@ -276,9 +280,9 @@ def add_programs_from_folders(self, folders: List[str]): def add_knowledge_from_folders(self, folders: List[str]): self.knowleddge_loader.from_folders(folders) - def execute(self, objective: str, verbose: bool = False): + def execute(self, objective: str, user_profile: str = "An average user", verbose: bool = False): self.interpreter.verbose = verbose - prediction = self.interpreter(objective) + prediction = self.interpreter(objective, user_profile = user_profile) self.interpreter.verbose = self.verbose return prediction @@ -296,8 +300,8 @@ def optimize( verbose: bool = False, ): config = dict( - max_bootstrapped_demos=max_bootstrapped_demos, - max_labeled_demos=0, + max_bootstrapped_demos = max_bootstrapped_demos, + max_labeled_demos = 0, ) if len(trainset) <= max_bootstrapped_demos: optimizer = BootstrapFewShot( diff --git a/hybridagi/tools/__init__.py b/hybridagi/tools/__init__.py index 55908bb..0f346e0 100644 --- a/hybridagi/tools/__init__.py +++ b/hybridagi/tools/__init__.py @@ -27,6 +27,9 @@ from .clear_trace import ClearTraceTool from .revert_trace import RevertTraceTool +from .update_user_profile import UpdateUserProfileTool +from .read_user_profile import ReadUserProfileTool + from .query_facts import QueryFactsTool from .code_interpreter import CodeInterpreterTool @@ -62,6 +65,9 @@ ClearTraceTool, RevertTraceTool, + UpdateUserProfileTool, + ReadUserProfileTool, + QueryFactsTool, CodeInterpreterTool, diff --git a/hybridagi/tools/ask_user.py b/hybridagi/tools/ask_user.py index 7223057..a32cf39 100644 --- a/hybridagi/tools/ask_user.py +++ b/hybridagi/tools/ask_user.py @@ -30,7 +30,6 @@ class AskUserTool(BaseTool): def __init__( self, agent_state: AgentState, - user_profile: str = "An average user", ask_user_func: Optional[Callable[[str], str]] = None, num_history: int = 50, simulated: bool = True, @@ -56,6 +55,7 @@ def simulate_ask_user(self, question: str): chat_history = json.dumps(self.agent_state.chat_history[:-self.num_history], indent=2) pred = self.simulate( objective = self.agent_state.objective, + user_profile = self.agent_state.user_profile, chat_history = chat_history, question = question, ) diff --git a/hybridagi/tools/read_user_profile.py b/hybridagi/tools/read_user_profile.py new file mode 100644 index 0000000..7818a6d --- /dev/null +++ b/hybridagi/tools/read_user_profile.py @@ -0,0 +1,35 @@ +"""The read user profile tool. Copyright (C) 2024 SynaLinks. License: GPL-3.0""" + +import copy +import dspy +from .base import BaseTool +from typing import Optional, Callable +from ..types.state import AgentState + +class ReadUserProfileTool(BaseTool): + + def __init__( + self, + agent_state: AgentState, + ): + super().__init__(name = "ReadUserProfile") + self.agent_state = agent_state + + def forward( + self, + context: str, + objective: str, + purpose: str, + prompt: str, + disable_inference: bool = False, + ) -> dspy.Prediction: + """Method to perform DSPy forward prediction""" + return dspy.Prediction( + user_profile = self.agent_state.user_profile, + ) + + def __deepcopy__(self, memo): + cpy = (type)(self)( + agent_state = self.agent_state, + ) + return cpy diff --git a/hybridagi/tools/update_user_profile.py b/hybridagi/tools/update_user_profile.py new file mode 100644 index 0000000..140879e --- /dev/null +++ b/hybridagi/tools/update_user_profile.py @@ -0,0 +1,73 @@ +"""The update user profile tool. Copyright (C) 2024 SynaLinks. License: GPL-3.0""" + +import copy +import dspy +from .base import BaseTool +from typing import Optional, Callable +from ..types.state import AgentState +from ..output_parsers.prediction import PredictionOutputParser + +class UpdateUserProfileSignature(dspy.Signature): + """You will be given an objective, purpose and context + Using the prompt to help you, you will infer the correct user profile + + Note: Never give an apology or explain what you are doing.""" + objective = dspy.InputField(desc = "The long-term objective (what you are doing)") + context = dspy.InputField(desc = "The previous actions (what you have done)") + purpose = dspy.InputField(desc = "The purpose of the action (what you have to do now)") + prompt = dspy.InputField(desc = "The action specific instructions (How to do it)") + user_profile = dspy.OutputField(desc = "The user profile") + +class UpdateUserProfileTool(BaseTool): + + def __init__( + self, + agent_state: AgentState, + lm: Optional[dspy.LM] = None, + ): + super().__init__(name = "UpdateUserProfile") + self.agent_state = agent_state + self.prediction_parser = PredictionOutputParser() + self.predict = dspy.Predictor(UpdateUserProfileSignature) + + def update_user_profile(self, user_profile: str): + self.agent_state.user_profile = user_profile + + def forward( + self, + context: str, + objective: str, + purpose: str, + prompt: str, + disable_inference: bool = False, + ) -> dspy.Prediction: + """Method to perform DSPy forward prediction""" + if not disable_inference: + pred = self.predict( + objective = objective, + context = context, + purpose = purpose, + prompt = prompt, + ) + pred.user_profile = self.prediction_parser( + pred.user_profile, + prefix="User Profile:", + ) + self.update_user_profile(pred.user_profile) + return dspy.Prediction( + user_profile = pred.user_profile, + observation = "Successfully updated", + ) + else: + self.update_user_profile(prompt) + return dspy.Prediction( + user_profile = prompt, + observation = "Successfully updated", + ) + + def __deepcopy__(self, memo): + cpy = (type)(self)( + agent_state = self.agent_state, + lm = self.lm, + ) + return cpy diff --git a/hybridagi/types/state.py b/hybridagi/types/state.py index 8cea898..b9ec667 100644 --- a/hybridagi/types/state.py +++ b/hybridagi/types/state.py @@ -9,6 +9,7 @@ class AgentState(): def __init__(self): self.init() + self.user_profile = "An average user" def get_current_program(self) -> Optional[Graph]: """Method to retreive the current program from the stack""" @@ -43,10 +44,9 @@ def update_variable(self, key:str, value:str): def init(self): self.objective = "N/A" - self.current_hop = 0 + self.user_profile = "An average user" self.program_trace = [] self.program_stack = deque() self.context = FileSystemContext() self.chat_history = [] - self.variables = {} - self.plots = [] \ No newline at end of file + self.variables = {} \ No newline at end of file