Tool flexibility #40
Merged · 8 commits · May 27, 2024
Changes from all commits
34 changes: 34 additions & 0 deletions README.md
@@ -89,6 +89,40 @@ cd streamlit_app
streamlit run app.py
```

## Usage
```python
from memary.agent.chat_agent import ChatAgent

system_persona_txt = "data/system_persona.txt"
user_persona_txt = "data/user_persona.txt"
past_chat_json = "data/past_chat.json"
memory_stream_json = "data/memory_stream.json"
entity_knowledge_store_json = "data/entity_knowledge_store.json"
chat_agent = ChatAgent(
"Personal Agent",
memory_stream_json,
entity_knowledge_store_json,
system_persona_txt,
user_persona_txt,
past_chat_json,
)
```
Pass a subset of `['search', 'vision', 'locate', 'stocks']` as `include_from_defaults` to initialize the agent with a different set of default tools.
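For example, a minimal sketch (reusing the persona and memory file paths defined in the snippet above) that starts the agent with only the `search` and `stocks` default tools:
```python
# Minimal sketch: initialize the agent with only a subset of the default tools.
# The file path variables are the same ones defined in the snippet above.
chat_agent = ChatAgent(
    "Personal Agent",
    memory_stream_json,
    entity_knowledge_store_json,
    system_persona_txt,
    user_persona_txt,
    past_chat_json,
    include_from_defaults=["search", "stocks"],
)
```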
### Adding Custom Tools
```python
def multiply(a: int, b: int) -> int:
"""Multiply two integers and returns the result integer"""
return a * b

chat_agent.add_tool({"multiply": multiply})
```
More information about creating custom tools for the LlamaIndex ReAct Agent can be found [here](https://docs.llamaindex.ai/en/stable/examples/agent/react_agent/).
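Any plain Python function with type hints and a docstring can be registered this way; memary wraps each one in a `FunctionTool`, which by default derives the tool's name and description from the function itself. The `divide` and `add` helpers below are hypothetical, shown only to illustrate registering several tools in one call:
```python
# Hypothetical helpers for illustration only.
def divide(a: int, b: int) -> float:
    """Divide a by b and return the quotient."""
    return a / b

def add(a: int, b: int) -> int:
    """Add two integers and return the result."""
    return a + b

# add_tool accepts a dict, so several tools can be registered at once.
chat_agent.add_tool({"divide": divide, "add": add})
```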

### Removing Tools
```python
chat_agent.remove_tool("multiply")
```
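The agent also exposes `update_tools`, which resets the available tools to a chosen subset of the defaults; a minimal sketch (note that any custom tools added earlier are cleared by the reset):
```python
# Sketch: keep only the default search and locate tools;
# previously added custom tools are removed by the reset.
chat_agent.update_tools(["search", "locate"])
```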

## Detailed Component Breakdown

### Routing Agent
Binary file added diagrams/context_window.png
Binary file added diagrams/memary_logo_bw.png
112 changes: 73 additions & 39 deletions src/memary/agent/base_agent.py
@@ -2,19 +2,16 @@
import os
import sys
from pathlib import Path
from typing import Any, Callable, Dict, List

import geocoder
import googlemaps
import numpy as np
import requests
from ansistrip import ansi_strip
from dotenv import load_dotenv
from llama_index.core import (
KnowledgeGraphIndex,
Settings,
SimpleDirectoryReader,
StorageContext,
)
from llama_index.core import (KnowledgeGraphIndex, Settings,
SimpleDirectoryReader, StorageContext)
from llama_index.core.agent import ReActAgent
from llama_index.core.llms import ChatMessage
from llama_index.core.query_engine import RetrieverQueryEngine
@@ -28,10 +25,8 @@
from llama_index.multi_modal_llms.openai import OpenAIMultiModal

from memary.agent.data_types import Context, Message
from memary.agent.llm_api.tools import (
ollama_chat_completions_request,
openai_chat_completions_request,
)
from memary.agent.llm_api.tools import (ollama_chat_completions_request,
openai_chat_completions_request)
from memary.memory import EntityKnowledgeStore, MemoryStream
from memary.synonym_expand.synonym import custom_synonym_expand_fn

@@ -65,6 +60,7 @@ def __init__(
past_chat_json,
llm_model_name="llama3",
vision_model_name="llava",
include_from_defaults=["search", "locate", "vision", "stocks"],
debug=True,
):
load_dotenv()
@@ -99,7 +95,6 @@ def __init__(
)

self.vantage_key = os.getenv("ALPHA_VANTAGE_API_KEY")
# self.news_data_key = os.getenv("NEWS_DATA_API_KEY")

self.storage_context = StorageContext.from_defaults(
graph_store=self.graph_store
@@ -116,18 +111,9 @@
graph_rag_retriever,
)

search_tool = FunctionTool.from_defaults(fn=self.search)
locate_tool = FunctionTool.from_defaults(fn=self.locate)
vision_tool = FunctionTool.from_defaults(fn=self.vision)
stock_tool = FunctionTool.from_defaults(fn=self.stock_price)
# news_tool = FunctionTool.from_defaults(fn=self.get_news)

self.debug = debug
self.routing_agent = ReActAgent.from_tools(
[search_tool, locate_tool, vision_tool, stock_tool],
llm=self.llm,
verbose=True,
)
self.tools = {}
self._init_default_tools(default_tools=include_from_defaults)

self.memory_stream = MemoryStream(memory_stream_json)
self.entity_knowledge_store = EntityKnowledgeStore(entity_knowledge_store_json)
@@ -211,7 +197,7 @@ def vision(self, query: str, img_url: str) -> str:
os.remove(query_image_path) # delete image after use
return response

def stock_price(self, query: str) -> str:
def stocks(self, query: str) -> str:
"""Get the stock price of the company given the ticker"""
request_api = requests.get(
r"https://www.alphavantage.co/query?function=GLOBAL_QUOTE&symbol="
@@ -435,19 +421,67 @@ def get_entity(self, retrieve) -> list[str]:
entities.remove(exceptions)
return entities

def update_tools(self, updatedTools):
print("recieved update tools")
tools = []
for tool in updatedTools:
if tool == "Search":
tools.append(FunctionTool.from_defaults(fn=self.search))
elif tool == "Location":
tools.append(FunctionTool.from_defaults(fn=self.locate))
elif tool == "Vision":
tools.append(FunctionTool.from_defaults(fn=self.vision))
elif tool == "Stocks":
tools.append(FunctionTool.from_defaults(fn=self.stock_price))
# elif tool == "News":
# tools.append(FunctionTool.from_defaults(fn=self.get_news))

self.routing_agent = ReActAgent.from_tools(tools, llm=self.llm, verbose=True)
def _init_ReAct_agent(self):
"""Initializes ReAct Agent with list of tools in self.tools."""
tool_fns = []
for func in self.tools.values():
tool_fns.append(FunctionTool.from_defaults(fn=func))
self.routing_agent = ReActAgent.from_tools(tool_fns, llm=self.llm, verbose=True)

def _init_default_tools(self, default_tools: List[str]):
"""Initializes ReAct Agent from the default list of tools memary provides.
List of strings passed in during initialization denoting which default tools to include.
Args:
default_tools (list(str)): list of tool names in string form
"""

for tool in default_tools:
if tool == "search":
self.tools["search"] = self.search
elif tool == "locate":
self.tools["locate"] = self.locate
elif tool == "vision":
self.tools["vision"] = self.vision
elif tool == "stocks":
self.tools["stocks"] = self.stocks
self._init_ReAct_agent()

def add_tool(self, tool_additions: Dict[str, Callable[..., Any]]):
"""Adds specified tools to be used by the ReAct Agent.
Args:
tools (dict(str, func)): dictionary of tools with names as keys and associated functions as values
"""

for tool_name in tool_additions:
self.tools[tool_name] = tool_additions[tool_name]
self._init_ReAct_agent()

def remove_tool(self, tool_name: str):
"""Removes specified tool from list of available tools for use by the ReAct Agent.
Args:
tool_name (str): name of tool to be removed in string form
"""

if tool_name in self.tools:
del self.tools[tool_name]
self._init_ReAct_agent()
else:
raise ("Unknown tool_name provided for removal.")

def update_tools(self, updated_tools: List[str]):
"""Resets ReAct Agent tools to only include subset of default tools.
Args:
updated_tools (list(str)): list of default tools to include
"""

self.tools.clear()
for tool in updated_tools:
if tool == "search":
self.tools["search"] = self.search
elif tool == "locate":
self.tools["locate"] = self.locate
elif tool == "vision":
self.tools["vision"] = self.vision
elif tool == "stocks":
self.tools["stocks"] = self.stocks
self._init_ReAct_agent()
59 changes: 37 additions & 22 deletions src/memary/agent/chat_agent.py
@@ -1,22 +1,39 @@
from typing import Optional, List

from memary.agent.base_agent import Agent
import logging


class ChatAgent(Agent):
"""ChatAgent currently able to support Llama3 running on Ollama (default) and gpt-3.5-turbo for llm models,
and LLaVA running on Ollama (default) and gpt-4-vision-preview for the vision tool.
"""
def __init__(self, name, memory_stream_json, entity_knowledge_store_json,
system_persona_txt, user_persona_txt, past_chat_json, llm_model_name="llama3", vision_model_name="llava"):
super().__init__(name, memory_stream_json, entity_knowledge_store_json,
system_persona_txt, user_persona_txt, past_chat_json, llm_model_name, vision_model_name)


def add_chat(self,
role: str,
content: str,
entities: Optional[List[str]] = None):
def __init__(
self,
name,
memory_stream_json,
entity_knowledge_store_json,
system_persona_txt,
user_persona_txt,
past_chat_json,
llm_model_name="llama3",
vision_model_name="llava",
include_from_defaults=["search", "locate", "vision", "stocks"],
):
super().__init__(
name,
memory_stream_json,
entity_knowledge_store_json,
system_persona_txt,
user_persona_txt,
past_chat_json,
llm_model_name,
vision_model_name,
include_from_defaults,
)

def add_chat(self, role: str, content: str, entities: Optional[List[str]] = None):
"""Add a chat to the agent's memory.

Args:
Expand All @@ -30,8 +47,7 @@ def add_chat(self,
if entities:
self.memory_stream.add_memory(entities)
self.memory_stream.save_memory()
self.entity_knowledge_store.add_memory(
self.memory_stream.get_memory())
self.entity_knowledge_store.add_memory(self.memory_stream.get_memory())
self.entity_knowledge_store.save_memory()

self._replace_memory_from_llm_message()
@@ -41,28 +57,27 @@ def get_chat(self):
return self.contexts

def clearMemory(self):
"""Clears Neo4j database and memory stream/entity knowledge store."""

logging.info("Deleting memory stream and entity knowledge store...")
self.memory_stream.clear_memory()
self.entity_knowledge_store.clear_memory()

# print("removed from mem stream and entity knowdlege store ")
"clears knowledge neo4j database"

print("Deleting nodes from Neo4j...")
logging.info("Deleting nodes from Neo4j...")
try:
self.graph_store.query("MATCH (n) DETACH DELETE n")
except Exception as e:
print(f"Error deleting nodes: {e}")
print("Nodes deleted from Neo4j.")
logging.error(f"Error deleting nodes: {e}")
logging.info("Nodes deleted from Neo4j.")

def _replace_memory_from_llm_message(self):
"""Replace the memory_stream from the llm_message."""
self.message.llm_message[
"memory_stream"] = self.memory_stream.get_memory()
self.message.llm_message["memory_stream"] = self.memory_stream.get_memory()

def _replace_eks_to_from_message(self):
"""Replace the entity knowledge store from the llm_message.
eks = entity knowledge store"""

self.message.llm_message[
"knowledge_entity_store"] = self.entity_knowledge_store.get_memory(
)
self.message.llm_message["knowledge_entity_store"] = (
self.entity_knowledge_store.get_memory()
)
14 changes: 6 additions & 8 deletions streamlit_app/app.py
@@ -101,7 +101,7 @@ def get_models(llm_models, vision_models):
external_response = ""
st.title("memary")

llm_models = ["gpt-3.5.turbo"]
llm_models = ["gpt-3.5-turbo"]
vision_models = ["gpt-4-vision-preview"]
get_models(llm_models, vision_models)

@@ -136,13 +136,12 @@ def get_models(llm_models, vision_models):

tools = st.multiselect(
"Select tools to include:",
# ["Search", "Location", "Vision", "Stocks", "News"], #all options available
# ["Search", "Location", "Vision", "Stocks", "News"],) #options that are selected by default
["Search", "Location", "Vision", "Stocks"], # all options available
["Search", "Location", "Vision", "Stocks"],
) # options that are selected by default
["search", "locate", "vision", "stocks"], # all options available
["search", "locate", "vision", "stocks"], # options that are selected by default
)

if "Vision" in tools:
img_url = ""
if "vision" in tools:
img_url = st.text_input("URL of image, leave blank if no image to provide")
if img_url:
st.image(img_url, caption="Uploaded Image", use_column_width=True)
@@ -167,7 +166,6 @@ def get_models(llm_models, vision_models):
st.write("Please select at least one tool")
st.stop()

print("start update tools")
chat_agent.update_tools(tools)

if img_url: