Skip to content

Commit

Permalink
Tool-Memory LTM (#1007)
Browse files Browse the repository at this point in the history
  • Loading branch information
AdityaSharma13064 committed Aug 10, 2023
1 parent 936f1be commit 5a42280
Show file tree
Hide file tree
Showing 13 changed files with 338 additions and 21 deletions.
6 changes: 3 additions & 3 deletions superagi/agent/agent_iteration_step_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def execute_step(self):
assistant_reply = response['content']
output_handler = get_output_handler(iteration_workflow_step.output_type,
agent_execution_id=self.agent_execution_id,
agent_config=agent_config, agent_tools=agent_tools)
agent_config=agent_config,memory=self.memory,agent_tools=agent_tools)
response = output_handler.handle(self.session, assistant_reply)
if response.status == "COMPLETE":
execution.status = "COMPLETED"
Expand Down Expand Up @@ -153,7 +153,7 @@ def _build_tools(self, agent_config: dict, agent_execution_config: dict):
agent_tools.append(tool_builder.build_tool(tool))

agent_tools = [tool_builder.set_default_params_tool(tool, agent_config, agent_execution_config,
model_api_key, resource_summary) for tool in agent_tools]
model_api_key, resource_summary,self.memory) for tool in agent_tools]
return agent_tools

def _handle_wait_for_permission(self, agent_execution, agent_config: dict, agent_execution_config: dict,
Expand All @@ -179,7 +179,7 @@ def _handle_wait_for_permission(self, agent_execution, agent_config: dict, agent
return False
if agent_execution_permission.status == "APPROVED":
agent_tools = self._build_tools(agent_config, agent_execution_config)
tool_output_handler = ToolOutputHandler(self.agent_execution_id, agent_config, agent_tools)
tool_output_handler = ToolOutputHandler(self.agent_execution_id, agent_config, agent_tools,self.memory)
tool_result = tool_output_handler.handle_tool_response(self.session,
agent_execution_permission.assistant_reply)
result = tool_result.result
Expand Down
4 changes: 2 additions & 2 deletions superagi/agent/agent_tool_step_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def execute_step(self):
assistant_reply = self._process_input_instruction(agent_config, agent_execution_config, step_tool,
workflow_step)
tool_obj = self._build_tool_obj(agent_config, agent_execution_config, step_tool.tool_name)
tool_output_handler = ToolOutputHandler(self.agent_execution_id, agent_config, [tool_obj],
tool_output_handler = ToolOutputHandler(self.agent_execution_id, agent_config, [tool_obj],self.memory,
output_parser=AgentSchemaToolOutputParser())
final_response = tool_output_handler.handle(self.session, assistant_reply)
step_response = "default"
Expand Down Expand Up @@ -119,7 +119,7 @@ def _build_tool_obj(self, agent_config, agent_execution_config, tool_name: str):
tool = self.session.query(Tool).filter(Tool.name == tool_name).first()
tool_obj = tool_builder.build_tool(tool)
tool_obj = tool_builder.set_default_params_tool(tool_obj, agent_config, agent_execution_config, model_api_key,
resource_summary)
resource_summary,self.memory)
return tool_obj

def _process_output_instruction(self, final_response: str, step_tool: AgentWorkflowStepTool,
Expand Down
49 changes: 43 additions & 6 deletions superagi/agent/output_handler.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,34 @@
import json
from superagi.agent.common_types import TaskExecutorResponse, ToolExecutorResponse
from superagi.agent.output_parser import AgentSchemaOutputParser
from superagi.agent.task_queue import TaskQueue
from superagi.agent.tool_executor import ToolExecutor
from superagi.helper.json_cleaner import JsonCleaner
from superagi.lib.logger import logger
from langchain.text_splitter import TokenTextSplitter
from superagi.models.agent import Agent
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent_execution_feed import AgentExecutionFeed
from superagi.vector_store.base import VectorStore
import numpy as np

from superagi.models.agent_execution_permission import AgentExecutionPermission


class ToolOutputHandler:
"""Handles the tool output response from the thinking step"""
def __init__(self, agent_execution_id: int, agent_config: dict,
tools: list, output_parser=AgentSchemaOutputParser()):
def __init__(self,
agent_execution_id: int,
agent_config: dict,
tools: list,
memory:VectorStore=None,
output_parser=AgentSchemaOutputParser()):
self.agent_execution_id = agent_execution_id
self.task_queue = TaskQueue(str(agent_execution_id))
self.agent_config = agent_config
self.tools = tools
self.output_parser = output_parser
self.memory=memory

def handle(self, session, assistant_reply):
"""Handles the tool output response from the thinking step.
Expand Down Expand Up @@ -54,7 +62,36 @@ def handle(self, session, assistant_reply):
if not tool_response.retry:
tool_response = self._check_for_completion(tool_response)
# print("Tool Response:", tool_response)
print("Here is the assistant reply: ",assistant_reply,"ENDD")
self.add_text_to_memory(assistant_reply, tool_response.result)
return tool_response

def add_text_to_memory(self, assistant_reply,tool_response_result):
"""
Adds the text generated by the assistant and tool response to the memory.
Args:
assistant_reply (str): The assistant reply.
tool_response_result (str): The tool response.
Returns:
None
"""
if self.memory is not None:
data = json.loads(assistant_reply)
task_description = data['thoughts']['text']
final_tool_response = tool_response_result
prompt = task_description + final_tool_response
text_splitter = TokenTextSplitter(chunk_size=1024, chunk_overlap=10)
chunk_response = text_splitter.split_text(prompt)
metadata = {"agent_execution_id": self.agent_execution_id}
metadatas = []
for _ in chunk_response:
metadatas.append(metadata)

self.memory.add_texts(chunk_response, metadatas)



def handle_tool_response(self, session, assistant_reply):
"""Only handle processing of tool response"""
Expand Down Expand Up @@ -134,7 +171,7 @@ class ReplaceTaskOutputHandler:
def __init__(self, agent_execution_id: int, agent_config: dict):
self.agent_execution_id = agent_execution_id
self.task_queue = TaskQueue(str(agent_execution_id))
self.agent_config = agent_config
self.agent_config = agent_config

def handle(self, session, assistant_reply):
assistant_reply = JsonCleaner.extract_json_array_section(assistant_reply)
Expand All @@ -149,11 +186,11 @@ def handle(self, session, assistant_reply):
return TaskExecutorResponse(status=status, retry=False)


def get_output_handler(output_type: str, agent_execution_id: int, agent_config: dict, agent_tools: list = []):
def get_output_handler(output_type: str, agent_execution_id: int, agent_config: dict, agent_tools: list = [],memory=None):
if output_type == "tools":
return ToolOutputHandler(agent_execution_id, agent_config, agent_tools)
return ToolOutputHandler(agent_execution_id, agent_config, agent_tools,memory=memory)
elif output_type == "replace_tasks":
return ReplaceTaskOutputHandler(agent_execution_id, agent_config)
elif output_type == "tasks":
return TaskOutputHandler(agent_execution_id, agent_config)
return ToolOutputHandler(agent_execution_id, agent_config, agent_tools)
return ToolOutputHandler(agent_execution_id, agent_config, agent_tools,memory=memory)
5 changes: 3 additions & 2 deletions superagi/agent/tool_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(self, session, agent_id: int, agent_execution_id: int = None):
self.session = session
self.agent_id = agent_id
self.agent_execution_id = agent_execution_id


def __validate_filename(self, filename):
"""
Expand Down Expand Up @@ -78,7 +79,7 @@ def build_tool(self, tool: Tool):
return new_object

def set_default_params_tool(self, tool, agent_config, agent_execution_config, model_api_key: str,
resource_summary: str = ""):
resource_summary: str = "",memory=None):
"""
Set the default parameters for the tools.
Expand Down Expand Up @@ -110,7 +111,7 @@ def set_default_params_tool(self, tool, agent_config, agent_execution_config, mo
agent_execution_id=self.agent_execution_id)
if hasattr(tool, 'tool_response_manager'):
tool.tool_response_manager = ToolResponseQueryManager(session=self.session,
agent_execution_id=self.agent_execution_id)
agent_execution_id=self.agent_execution_id,memory=memory)

if tool.name == "QueryResourceTool":
tool.description = tool.description.replace("{summary}", resource_summary)
Expand Down
6 changes: 4 additions & 2 deletions superagi/jobs/agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from superagi.types.vector_store_types import VectorStoreType
from superagi.vector_store.embedding.openai import OpenAiEmbedding
from superagi.vector_store.vector_factory import VectorFactory
from superagi.vector_store.redis import Redis
from superagi.config.config import get_config

# from superagi.helper.tool_helper import get_tool_config_by_key

Expand Down Expand Up @@ -54,11 +56,11 @@ def execute_next_step(self, agent_execution_id):
model_api_key = AgentConfiguration.get_model_api_key(session, agent_execution.agent_id, agent_config["model"])
model_llm_source = ModelSourceType.get_model_source_from_model(agent_config["model"]).value
try:
vector_store_type = VectorStoreType.get_vector_store_type(agent_config["LTM_DB"])
vector_store_type = VectorStoreType.get_vector_store_type(get_config("LTM_DB","Redis"))
memory = VectorFactory.get_vector_storage(vector_store_type, "super-agent-index1",
AgentExecutor.get_embedding(model_llm_source, model_api_key))
except:
logger.info("Unable to setup the pinecone connection...")
logger.info("Unable to setup the connection...")
memory = None

agent_workflow_step = session.query(AgentWorkflowStep).filter(
Expand Down
3 changes: 3 additions & 0 deletions superagi/tools/thinking/prompts/thinking.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ and the following task, `{task_description}`.
Below is last tool response:
`{last_tool_response}`

Below is the relevant tool response:
`{relevant_tool_response}`

Perform the task by understanding the problem, extracting variables, and being smart
and efficient. Provide a descriptive response, make decisions yourself when
confronted with choices and provide reasoning for ideas / decisions.
4 changes: 4 additions & 0 deletions superagi/tools/thinking/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class ThinkingTool(BaseTool):
)
args_schema: Type[ThinkingSchema] = ThinkingSchema
goals: List[str] = []
agent_execution_id:int=None
permission_required: bool = False
tool_response_manager: Optional[ToolResponseQueryManager] = None

Expand All @@ -56,6 +57,9 @@ def _execute(self, task_description: str):
prompt = prompt.replace("{task_description}", task_description)
last_tool_response = self.tool_response_manager.get_last_response()
prompt = prompt.replace("{last_tool_response}", last_tool_response)
metadata = {"agent_execution_id":self.agent_execution_id}
relevant_tool_response = self.tool_response_manager.get_relevant_response(query=task_description,metadata=metadata)
prompt = prompt.replace("{relevant_tool_response}",relevant_tool_response)
messages = [{"role": "system", "content": prompt}]
result = self.llm.chat_completion(messages, max_tokens=self.max_token_limit)
return result["content"]
Expand Down
14 changes: 12 additions & 2 deletions superagi/tools/tool_response_query_manager.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
from sqlalchemy.orm import Session

from superagi.models.agent_execution_feed import AgentExecutionFeed
from superagi.vector_store.base import VectorStore


class ToolResponseQueryManager:
def __init__(self, session: Session, agent_execution_id: int):
def __init__(self, session: Session, agent_execution_id: int,memory:VectorStore):
self.session = session
self.agent_execution_id = agent_execution_id

self.memory=memory


def get_last_response(self, tool_name: str = None):
return AgentExecutionFeed.get_last_tool_response(self.session, self.agent_execution_id, tool_name)

def get_relevant_response(self, query: str,metadata:dict, top_k: int = 5):
documents = self.memory.get_matching_text(query, metadata=metadata)
relevant_responses = ""
for document in documents["documents"]:
relevant_responses += document.text_content
return relevant_responses
12 changes: 12 additions & 0 deletions superagi/vector_store/embedding/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,18 @@ class OpenAiEmbedding:
def __init__(self, api_key, model="text-embedding-ada-002"):
self.model = model
self.api_key = api_key

async def get_embedding_async(self, text: str):
try:
# openai.api_key = get_config("OPENAI_API_KEY")
openai.api_key = self.api_key
response = await openai.Embedding.create(
input=[text],
engine=self.model
)
return response['data'][0]['embedding']
except Exception as exception:
return {"error": exception}

def get_embedding(self, text):
try:
Expand Down
Loading

0 comments on commit 5a42280

Please sign in to comment.