Skip to content

Commit

Permalink
Add user profile tools and put the profile into the state
Browse files Browse the repository at this point in the history
  • Loading branch information
YoanSallami committed Jul 4, 2024
1 parent b5b2170 commit df9524f
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 17 deletions.
19 changes: 11 additions & 8 deletions hybridagi/agents/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def __init__(
# DSPy reasoners
# The interpreter model used to navigate, only contains decision signatures
# With that, DSPy should better optimize the graph navigation task
self.current_hop = 0
self.decision_hop = 0
self.decisions = [
dspy.ChainOfThought(DecisionSignature) for i in range(0, self.max_iters)
Expand All @@ -109,7 +110,7 @@ def run_step(self) -> Union[AgentAction, AgentDecision, ProgramCall, ProgramEnd]
"""Method to run a step of the program"""
current_node = self.agent_state.get_current_node()
if current_node:
self.agent_state.current_hop += 1
self.current_hop += 1
if current_node.labels[0] == "Program":
try:
program_purpose = current_node.properties["name"]
Expand Down Expand Up @@ -250,7 +251,7 @@ def act(
else:
self.agent_state.variables[output] = dict(prediction)[list(dict(prediction).keys())[0]]
action = AgentAction(
hop = self.agent_state.current_hop,
hop = self.current_hop,
objective = self.agent_state.objective,
purpose = purpose,
tool = tool,
Expand Down Expand Up @@ -305,7 +306,7 @@ def decide(self, purpose: str, question:str, options: List[str], inputs: List[st
)
next_node = result.result_set[0][0]
decision = AgentDecision(
hop = self.agent_state.current_hop,
hop = self.current_hop,
objective = self.agent_state.objective,
purpose = purpose,
question = question,
Expand All @@ -316,9 +317,9 @@ def decide(self, purpose: str, question:str, options: List[str], inputs: List[st
self.agent_state.set_current_node(next_node)
return decision

def forward(self, objective: str):
def forward(self, objective: str, user_profile: str = "An average user"):
"""DSPy forward prediction"""
self.start(objective)
self.start(objective, user_profile = user_profile)
for i in range(self.max_iters):
if not self.finished():
self.run_step()
Expand Down Expand Up @@ -377,7 +378,7 @@ def call_program(
first_node = self.program_memory.get_next_node(starting_node, program_called)
self.agent_state.call_program(first_node, program_called)
program_call = ProgramCall(
hop = self.agent_state.current_hop,
hop = self.current_hop,
purpose = purpose,
program = program_name,
)
Expand All @@ -394,19 +395,21 @@ def end_current_program(self) -> Optional[ProgramEnd]:
if next_node:
self.agent_state.set_current_node(next_node)
program_end = ProgramEnd(
hop = self.agent_state.current_hop,
hop = self.current_hop,
program = program_name,
)
return program_end

def start(self, objective: str):
def start(self, objective: str, user_profile: str = "An average User"):
"""Start the interpreter"""
self.agent_state.init()
self.current_hop = 0
self.decision_hop = 0
self.agent_state.chat_history.append(
{"role": "User", "message": objective}
)
self.agent_state.objective = objective
self.agent_state.user_profile = user_profile
first_step = self.call_program(objective, self.entrypoint)
self.agent_state.program_trace.append(str(first_step))
if self.trace_memory:
Expand Down
14 changes: 9 additions & 5 deletions hybridagi/hybridagi.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,13 @@

ReadFileTool,
ReadProgramTool,
ReadUserProfileTool,
RevertTraceTool,

SpeakTool,

UpdateObjectiveTool,
UpdateUserProfileTool,
UploadTool,

WriteFileTool,
Expand Down Expand Up @@ -182,7 +184,6 @@ def __init__(
),
AskUserTool(
agent_state = self.agent_state,
user_profile = self.user_profile,
),
BrowseWebsiteTool(),
CallProgramTool(
Expand Down Expand Up @@ -225,6 +226,9 @@ def __init__(
ReadProgramTool(
program_memory = self.program_memory,
),
ReadUserProfileTool(
agent_state = self.agent_state,
),
RevertTraceTool(
agent_state = self.agent_state,
),
Expand Down Expand Up @@ -276,9 +280,9 @@ def add_programs_from_folders(self, folders: List[str]):
def add_knowledge_from_folders(self, folders: List[str]):
self.knowleddge_loader.from_folders(folders)

def execute(self, objective: str, verbose: bool = False):
def execute(self, objective: str, user_profile: str = "An average user", verbose: bool = False):
self.interpreter.verbose = verbose
prediction = self.interpreter(objective)
prediction = self.interpreter(objective, user_profile = user_profile)
self.interpreter.verbose = self.verbose
return prediction

Expand All @@ -296,8 +300,8 @@ def optimize(
verbose: bool = False,
):
config = dict(
max_bootstrapped_demos=max_bootstrapped_demos,
max_labeled_demos=0,
max_bootstrapped_demos = max_bootstrapped_demos,
max_labeled_demos = 0,
)
if len(trainset) <= max_bootstrapped_demos:
optimizer = BootstrapFewShot(
Expand Down
6 changes: 6 additions & 0 deletions hybridagi/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
from .clear_trace import ClearTraceTool
from .revert_trace import RevertTraceTool

from .update_user_profile import UpdateUserProfileTool
from .read_user_profile import ReadUserProfileTool

from .query_facts import QueryFactsTool

from .code_interpreter import CodeInterpreterTool
Expand Down Expand Up @@ -62,6 +65,9 @@
ClearTraceTool,
RevertTraceTool,

UpdateUserProfileTool,
ReadUserProfileTool,

QueryFactsTool,

CodeInterpreterTool,
Expand Down
2 changes: 1 addition & 1 deletion hybridagi/tools/ask_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ class AskUserTool(BaseTool):
def __init__(
self,
agent_state: AgentState,
user_profile: str = "An average user",
ask_user_func: Optional[Callable[[str], str]] = None,
num_history: int = 50,
simulated: bool = True,
Expand All @@ -56,6 +55,7 @@ def simulate_ask_user(self, question: str):
chat_history = json.dumps(self.agent_state.chat_history[:-self.num_history], indent=2)
pred = self.simulate(
objective = self.agent_state.objective,
user_profile = self.agent_state.user_profile,
chat_history = chat_history,
question = question,
)
Expand Down
35 changes: 35 additions & 0 deletions hybridagi/tools/read_user_profile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""The read user profile tool. Copyright (C) 2024 SynaLinks. License: GPL-3.0"""

import copy
import dspy
from .base import BaseTool
from typing import Optional, Callable
from ..types.state import AgentState

class ReadUserProfileTool(BaseTool):

def __init__(
self,
agent_state: AgentState,
):
super().__init__(name = "ReadUserProfile")
self.agent_state = agent_state

def forward(
self,
context: str,
objective: str,
purpose: str,
prompt: str,
disable_inference: bool = False,
) -> dspy.Prediction:
"""Method to perform DSPy forward prediction"""
return dspy.Prediction(
user_profile = self.agent_state.user_profile,
)

def __deepcopy__(self, memo):
cpy = (type)(self)(
agent_state = self.agent_state,
)
return cpy
73 changes: 73 additions & 0 deletions hybridagi/tools/update_user_profile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""The update user profile tool. Copyright (C) 2024 SynaLinks. License: GPL-3.0"""

import copy
import dspy
from .base import BaseTool
from typing import Optional, Callable
from ..types.state import AgentState
from ..output_parsers.prediction import PredictionOutputParser

class UpdateUserProfileSignature(dspy.Signature):
"""You will be given an objective, purpose and context
Using the prompt to help you, you will infer the correct user profile
Note: Never give an apology or explain what you are doing."""
objective = dspy.InputField(desc = "The long-term objective (what you are doing)")
context = dspy.InputField(desc = "The previous actions (what you have done)")
purpose = dspy.InputField(desc = "The purpose of the action (what you have to do now)")
prompt = dspy.InputField(desc = "The action specific instructions (How to do it)")
user_profile = dspy.OutputField(desc = "The user profile")

class UpdateUserProfileTool(BaseTool):

def __init__(
self,
agent_state: AgentState,
lm: Optional[dspy.LM] = None,
):
super().__init__(name = "UpdateUserProfile")
self.agent_state = agent_state
self.prediction_parser = PredictionOutputParser()
self.predict = dspy.Predictor(UpdateUserProfileSignature)

def update_user_profile(self, user_profile: str):
self.agent_state.user_profile = user_profile

def forward(
self,
context: str,
objective: str,
purpose: str,
prompt: str,
disable_inference: bool = False,
) -> dspy.Prediction:
"""Method to perform DSPy forward prediction"""
if not disable_inference:
pred = self.predict(
objective = objective,
context = context,
purpose = purpose,
prompt = prompt,
)
pred.user_profile = self.prediction_parser(
pred.user_profile,
prefix="User Profile:",
)
self.update_user_profile(pred.user_profile)
return dspy.Prediction(
user_profile = pred.user_profile,
observation = "Successfully updated",
)
else:
self.update_user_profile(prompt)
return dspy.Prediction(
user_profile = prompt,
observation = "Successfully updated",
)

def __deepcopy__(self, memo):
cpy = (type)(self)(
agent_state = self.agent_state,
lm = self.lm,
)
return cpy
6 changes: 3 additions & 3 deletions hybridagi/types/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class AgentState():

def __init__(self):
self.init()
self.user_profile = "An average user"

def get_current_program(self) -> Optional[Graph]:
"""Method to retreive the current program from the stack"""
Expand Down Expand Up @@ -43,10 +44,9 @@ def update_variable(self, key:str, value:str):

def init(self):
self.objective = "N/A"
self.current_hop = 0
self.user_profile = "An average user"
self.program_trace = []
self.program_stack = deque()
self.context = FileSystemContext()
self.chat_history = []
self.variables = {}
self.plots = []
self.variables = {}

0 comments on commit df9524f

Please sign in to comment.