Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Short term memory #5

Merged
merged 3 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ __pycache__/
*.py[cod]
*$py.class

# chromadb
escargot_memory/

# C extensions
*.so

Expand Down
2 changes: 1 addition & 1 deletion escargot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__version__ = '0.0.1'
from .escargot import Escargot
from escargot.escargot import *
128 changes: 98 additions & 30 deletions escargot/escargot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@


import escargot.language_models as language_models
import escargot.controller as controller
from escargot.parser import ESCARGOTParser
from escargot.prompter import ESCARGOTPrompter
from escargot import operations
import logging
import io

Expand All @@ -28,6 +24,7 @@ def __init__(self, config: str, node_types:str = "", relationship_types:str = ""
self.log = ""
self.lm = language_models.AzureGPT(config, model_name=model_name, logger=logger)
self.vdb = WeaviateClient(config, self.logger)
self.memory = None
self.node_types = ""
self.relationship_types = ""
self.question = ""
Expand All @@ -45,24 +42,7 @@ def __init__(self, config: str, node_types:str = "", relationship_types:str = ""
self.node_types = node_types
self.relationship_types = relationship_types

#debug_level: 0, 1, 2, 3
#0: no debug, only output
#1: output, instructions, and exceptions
#2: output, instructions, exceptions, and debug info
#3: output, instructions, exceptions, debug info, and LLM output
def ask(self, question, answer_type = 'natural', num_strategies=3, debug_level = 0):
"""
Ask a question and get an answer.

:param question: The question to ask.
:type question: str
:param answer_type: The type of answer to expect. Defaults to 'natural'. Options are 'natural', 'array'.
:type answer_type: str
:param num_strategies: The number of strategies to generate. Defaults to 3.
:type num_strategies: int
:return: The answer to the question.
:rtype: str
"""
def setup_logger(self, debug_level):
log_stream = io.StringIO()
f_handler = logging.StreamHandler(log_stream)
c_handler = logging.StreamHandler()
Expand All @@ -88,6 +68,39 @@ def ask(self, question, answer_type = 'natural', num_strategies=3, debug_level =
self.logger.addHandler(f_handler)
self.logger.addHandler(c_handler)

return log_stream, c_handler, f_handler

def finalize_logger(self,log_stream, c_handler, f_handler):
self.log += log_stream.getvalue()
#reset logger
self.logger.removeHandler(c_handler)
c_handler.close()
self.logger.removeHandler(f_handler)
f_handler.close()

#debug_level: 0, 1, 2, 3
#0: no debug, only output
#1: output, instructions, and exceptions
#2: output, instructions, exceptions, and debug info
#3: output, instructions, exceptions, debug info, and LLM output
def ask(self, question, answer_type = 'natural', num_strategies=3, debug_level = 0, memory_name = "escargot_memory"):
"""
Ask a question and get an answer.

:param question: The question to ask.
:type question: str
:param answer_type: The type of answer to expect. Defaults to 'natural'. Options are 'natural', 'array'.
:type answer_type: str
:param num_strategies: The number of strategies to generate. Defaults to 3.
:type num_strategies: int
:return: The answer to the question.
:rtype: str
"""
import escargot.memory as memory
self.memory = memory.Memory(self.lm, collection_name = memory_name)
#setup logger
log_stream, c_handler, f_handler = self.setup_logger(debug_level)
from escargot import operations
def got() -> operations.GraphOfOperations:
operations_graph = operations.GraphOfOperations()

Expand All @@ -99,6 +112,9 @@ def got() -> operations.GraphOfOperations:
# Create the Controller
got = got()
try:
from escargot.parser import ESCARGOTParser
from escargot.prompter import ESCARGOTPrompter
import escargot.controller as controller
self.controller = controller.Controller(
self.lm,
got,
Expand All @@ -118,14 +134,6 @@ def got() -> operations.GraphOfOperations:
except Exception as e:
self.logger.error("Error executing controller: %s", e)

self.log = log_stream.getvalue()

#reset logger
self.logger.removeHandler(c_handler)
c_handler.close()
self.logger.removeHandler(f_handler)
f_handler.close()

self.operations_graph = self.controller.graph.operations
output = ""
if self.controller.final_thought is not None:
Expand All @@ -136,6 +144,66 @@ def got() -> operations.GraphOfOperations:
output = list(self.controller.coder.step_output.values())[-1]

self.logger.warning(f"Output: {output}")

#remove logger
self.finalize_logger(log_stream, c_handler, f_handler)

return output

def initialize_controller(self, question, answer_type = 'natural', num_strategies=3, debug_level = 0):

#setup logger
log_stream, c_handler, f_handler = self.setup_logger(debug_level)

def got() -> operations.GraphOfOperations:
operations_graph = operations.GraphOfOperations()

instruction_node = operations.Generate(1, 1)
operations_graph.append_operation(instruction_node)

return operations_graph

# Create the Controller
got = got()
try:
from escargot.parser import ESCARGOTParser
from escargot.prompter import ESCARGOTPrompter
self.controller = controller.Controller(
self.lm,
got,
ESCARGOTPrompter(memgraph_client = self.memgraph_client,vector_db = self.vdb, lm=self.lm,node_types=self.node_types,relationship_types=self.relationship_types, logger = self.logger),
ESCARGOTParser(self.logger),
self.logger,
{
"question": question,
"input": "",
"phase": "planning",
"method" : "got",
"num_branches_response": num_strategies,
"answer_type": answer_type
}
)
except Exception as e:
self.logger.error("Error initializing controller: %s", e)
self.finalize_logger(log_stream, c_handler, f_handler)

def step(self):
#setup logger
log_stream, c_handler, f_handler = self.setup_logger(self.logger.level)
try:
self.controller.execute_step()
except Exception as e:
self.logger.error("Error executing controller: %s", e)

output = ""
if self.controller.final_thought is not None:
self.operations_graph = self.controller.graph.operations
output = self.controller.final_thought.state['input']

self.logger.warning(f"Output: {output}")

#remove logger
self.finalize_logger(log_stream, c_handler, f_handler)
return output

def quick_chat(self,chat, num_responses=1):
Expand Down
1 change: 1 addition & 0 deletions escargot/memory/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from escargot.memory import Memory
51 changes: 51 additions & 0 deletions escargot/memory/memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import chromadb
from chromadb.config import Settings
import pandas as pd
class Memory:
def __init__(self, lm, collection_name="escargot_memory"):
# Initialize ChromaDB client and specify the collection for storing vectors
self.client = chromadb.PersistentClient(path="./escargot_memory", settings=Settings(allow_reset=True))
self.collection_name = collection_name
self.lm = lm

self.collection = self.client.get_or_create_collection(collection_name)

def reset_collection(self):
self.client.reset()

def store_memory(self, text, metadata={None:None},collection_name = None):
# Embed the text using the lm's embed function
vector = self.lm.get_embedding(text)

# Add the embedded vector to the collection
if collection_name is None:
self.collection.add(ids=text, embeddings=[vector], metadatas=[metadata])
else:
collection = self.client.get_collection(collection_name)
collection.add(ids=text, embeddings=[vector], metadatas=[metadata])

def query_collection(self,query, max_results=10, collection_name = "escargot_memory", metadata = None):
collection = self.client.get_collection(collection_name)
query_embeddings = self.lm.get_embedding(query)
if metadata is not None:
results = collection.query(
query_embeddings=query_embeddings,
n_results=max_results,
where=metadata,
include=['distances']
)
else:
results = collection.query(
query_embeddings=query_embeddings,
n_results=max_results,
include=['distances']
)
return results

def get_all_vectors(self):
# Return all vectors stored in the collection
return self.collection.get()

def delete_vector(self, text):
# Delete a vector by id (text)
self.collection.delete(ids=text)
6 changes: 3 additions & 3 deletions escargot/operations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .thought import Thought
from .graph_of_operations import GraphOfOperations
from .operations import (
from escargot.operations.thought import Thought
from escargot.operations.graph_of_operations import GraphOfOperations
from escargot.operations.operations import (
Operation,
Generate
)
2 changes: 1 addition & 1 deletion escargot/parser/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .parser import ESCARGOTParser
from escargot.parser import ESCARGOTParser
2 changes: 1 addition & 1 deletion escargot/prompter/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .prompter import ESCARGOTPrompter
from escargot.prompter import ESCARGOTPrompter
4 changes: 2 additions & 2 deletions escargot/vector_db/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .azure_embedding import *
from .weaviate import *
from escargot.vector_db.azure_embedding import *
from escargot.vector_db.weaviate import *
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ dependencies = [
"backoff>=2.2.1,<3.0.0",
"weaviate-client>=4.6.5",
"dill>=0.3.8",
"chromadb>=0.5.5"
]

[project.urls]
Expand Down