-
Notifications
You must be signed in to change notification settings - Fork 81
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add user profile tools and put the profile into the state
- Loading branch information
1 parent
b5b2170
commit df9524f
Showing
7 changed files
with
138 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
"""The read user profile tool. Copyright (C) 2024 SynaLinks. License: GPL-3.0""" | ||
|
||
import copy | ||
import dspy | ||
from .base import BaseTool | ||
from typing import Optional, Callable | ||
from ..types.state import AgentState | ||
|
||
class ReadUserProfileTool(BaseTool): | ||
|
||
def __init__( | ||
self, | ||
agent_state: AgentState, | ||
): | ||
super().__init__(name = "ReadUserProfile") | ||
self.agent_state = agent_state | ||
|
||
def forward( | ||
self, | ||
context: str, | ||
objective: str, | ||
purpose: str, | ||
prompt: str, | ||
disable_inference: bool = False, | ||
) -> dspy.Prediction: | ||
"""Method to perform DSPy forward prediction""" | ||
return dspy.Prediction( | ||
user_profile = self.agent_state.user_profile, | ||
) | ||
|
||
def __deepcopy__(self, memo): | ||
cpy = (type)(self)( | ||
agent_state = self.agent_state, | ||
) | ||
return cpy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
"""The update user profile tool. Copyright (C) 2024 SynaLinks. License: GPL-3.0""" | ||
|
||
import copy | ||
import dspy | ||
from .base import BaseTool | ||
from typing import Optional, Callable | ||
from ..types.state import AgentState | ||
from ..output_parsers.prediction import PredictionOutputParser | ||
|
||
class UpdateUserProfileSignature(dspy.Signature): | ||
"""You will be given an objective, purpose and context | ||
Using the prompt to help you, you will infer the correct user profile | ||
Note: Never give an apology or explain what you are doing.""" | ||
objective = dspy.InputField(desc = "The long-term objective (what you are doing)") | ||
context = dspy.InputField(desc = "The previous actions (what you have done)") | ||
purpose = dspy.InputField(desc = "The purpose of the action (what you have to do now)") | ||
prompt = dspy.InputField(desc = "The action specific instructions (How to do it)") | ||
user_profile = dspy.OutputField(desc = "The user profile") | ||
|
||
class UpdateUserProfileTool(BaseTool): | ||
|
||
def __init__( | ||
self, | ||
agent_state: AgentState, | ||
lm: Optional[dspy.LM] = None, | ||
): | ||
super().__init__(name = "UpdateUserProfile") | ||
self.agent_state = agent_state | ||
self.prediction_parser = PredictionOutputParser() | ||
self.predict = dspy.Predictor(UpdateUserProfileSignature) | ||
|
||
def update_user_profile(self, user_profile: str): | ||
self.agent_state.user_profile = user_profile | ||
|
||
def forward( | ||
self, | ||
context: str, | ||
objective: str, | ||
purpose: str, | ||
prompt: str, | ||
disable_inference: bool = False, | ||
) -> dspy.Prediction: | ||
"""Method to perform DSPy forward prediction""" | ||
if not disable_inference: | ||
pred = self.predict( | ||
objective = objective, | ||
context = context, | ||
purpose = purpose, | ||
prompt = prompt, | ||
) | ||
pred.user_profile = self.prediction_parser( | ||
pred.user_profile, | ||
prefix="User Profile:", | ||
) | ||
self.update_user_profile(pred.user_profile) | ||
return dspy.Prediction( | ||
user_profile = pred.user_profile, | ||
observation = "Successfully updated", | ||
) | ||
else: | ||
self.update_user_profile(prompt) | ||
return dspy.Prediction( | ||
user_profile = prompt, | ||
observation = "Successfully updated", | ||
) | ||
|
||
def __deepcopy__(self, memo): | ||
cpy = (type)(self)( | ||
agent_state = self.agent_state, | ||
lm = self.lm, | ||
) | ||
return cpy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters