diff --git a/.gitignore b/.gitignore index e321f6b..3973969 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,9 @@ __pycache__/ *.py[cod] *$py.class +# chromadb +escargot_memory/ + # C extensions *.so diff --git a/escargot/__init__.py b/escargot/__init__.py index 7f24ab1..5bbceb2 100644 --- a/escargot/__init__.py +++ b/escargot/__init__.py @@ -1,2 +1,2 @@ __version__ = '0.0.1' -from .escargot import Escargot \ No newline at end of file +from escargot.escargot import * \ No newline at end of file diff --git a/escargot/escargot.py b/escargot/escargot.py index 96d2954..e7c59c8 100644 --- a/escargot/escargot.py +++ b/escargot/escargot.py @@ -7,10 +7,6 @@ import escargot.language_models as language_models -import escargot.controller as controller -from escargot.parser import ESCARGOTParser -from escargot.prompter import ESCARGOTPrompter -from escargot import operations import logging import io @@ -28,6 +24,7 @@ def __init__(self, config: str, node_types:str = "", relationship_types:str = "" self.log = "" self.lm = language_models.AzureGPT(config, model_name=model_name, logger=logger) self.vdb = WeaviateClient(config, self.logger) + self.memory = None self.node_types = "" self.relationship_types = "" self.question = "" @@ -45,24 +42,7 @@ def __init__(self, config: str, node_types:str = "", relationship_types:str = "" self.node_types = node_types self.relationship_types = relationship_types - #debug_level: 0, 1, 2, 3 - #0: no debug, only output - #1: output, instructions, and exceptions - #2: output, instructions, exceptions, and debug info - #3: output, instructions, exceptions, debug info, and LLM output - def ask(self, question, answer_type = 'natural', num_strategies=3, debug_level = 0): - """ - Ask a question and get an answer. - - :param question: The question to ask. - :type question: str - :param answer_type: The type of answer to expect. Defaults to 'natural'. Options are 'natural', 'array'. 
def finalize_logger(self, log_stream, c_handler, f_handler):
    """Append the captured log text to ``self.log`` and detach both handlers.

    :param log_stream: Stream whose ``getvalue()`` returns the captured log text.
    :param c_handler: Handler removed from ``self.logger`` and closed first.
    :param f_handler: Handler removed from ``self.logger`` and closed second.
    """
    captured_text = log_stream.getvalue()
    self.log += captured_text

    # Reset the logger: each handler is removed from it and then closed,
    # c_handler before f_handler.
    for handler in (c_handler, f_handler):
        self.logger.removeHandler(handler)
        handler.close()
+ :rtype: str + """ + import escargot.memory as memory + self.memory = memory.Memory(self.lm, collection_name = memory_name) + #setup logger + log_stream, c_handler, f_handler = self.setup_logger(debug_level) + from escargot import operations def got() -> operations.GraphOfOperations: operations_graph = operations.GraphOfOperations() @@ -99,6 +112,9 @@ def got() -> operations.GraphOfOperations: # Create the Controller got = got() try: + from escargot.parser import ESCARGOTParser + from escargot.prompter import ESCARGOTPrompter + import escargot.controller as controller self.controller = controller.Controller( self.lm, got, @@ -118,14 +134,6 @@ def got() -> operations.GraphOfOperations: except Exception as e: self.logger.error("Error executing controller: %s", e) - self.log = log_stream.getvalue() - - #reset logger - self.logger.removeHandler(c_handler) - c_handler.close() - self.logger.removeHandler(f_handler) - f_handler.close() - self.operations_graph = self.controller.graph.operations output = "" if self.controller.final_thought is not None: @@ -136,6 +144,66 @@ def got() -> operations.GraphOfOperations: output = list(self.controller.coder.step_output.values())[-1] self.logger.warning(f"Output: {output}") + + #remove logger + self.finalize_logger(log_stream, c_handler, f_handler) + + return output + + def initialize_controller(self, question, answer_type = 'natural', num_strategies=3, debug_level = 0): + + #setup logger + log_stream, c_handler, f_handler = self.setup_logger(debug_level) + + def got() -> operations.GraphOfOperations: + operations_graph = operations.GraphOfOperations() + + instruction_node = operations.Generate(1, 1) + operations_graph.append_operation(instruction_node) + + return operations_graph + + # Create the Controller + got = got() + try: + from escargot.parser import ESCARGOTParser + from escargot.prompter import ESCARGOTPrompter + self.controller = controller.Controller( + self.lm, + got, + ESCARGOTPrompter(memgraph_client = 
class Memory:
    """Vector memory stored in a persistent ChromaDB database.

    Texts are embedded with ``lm.get_embedding`` and stored under the text
    itself as the record id. The database lives in ``./escargot_memory``.

    :param lm: Object exposing ``get_embedding(text)``; used for every
        store and query.
    :param collection_name: Name of the collection this instance reads and
        writes by default. It is created if it does not exist.
    """

    def __init__(self, lm, collection_name="escargot_memory"):
        # Imported here so the block carries its own dependencies and the
        # class can be defined even where chromadb is not importable.
        import chromadb
        from chromadb.config import Settings

        # allow_reset=True is what permits reset_collection() to call
        # client.reset() later on.
        self.client = chromadb.PersistentClient(
            path="./escargot_memory",
            settings=Settings(allow_reset=True),
        )
        self.collection_name = collection_name
        self.lm = lm
        self.collection = self.client.get_or_create_collection(collection_name)

    def _resolve_collection(self, collection_name=None):
        """Return this instance's collection, or look up another one by name."""
        if collection_name is None:
            return self.collection
        return self.client.get_collection(collection_name)

    def reset_collection(self):
        """Reset the underlying database and re-create this instance's collection.

        NOTE: ``client.reset()`` acts on the whole client, not just this
        instance's collection. The handle is re-created afterwards so that
        ``self.collection`` never refers to a collection that no longer exists.
        """
        self.client.reset()
        self.collection = self.client.get_or_create_collection(self.collection_name)

    def store_memory(self, text, metadata=None, collection_name=None):
        """Embed ``text`` and add it to a collection, using ``text`` as the id.

        :param text: Text to embed; also used as the record id.
        :param metadata: Optional metadata dict stored with the record. When
            omitted (or empty) no metadata is sent.
        :param collection_name: Name of an existing collection to write to;
            ``None`` writes to this instance's own collection.
        """
        vector = self.lm.get_embedding(text)
        target = self._resolve_collection(collection_name)
        target.add(
            ids=[text],
            embeddings=[vector],
            # A ``{None: None}`` placeholder is not valid metadata and, as a
            # default argument, would be shared between calls; send nothing.
            metadatas=[metadata] if metadata else None,
        )

    def query_collection(self, query, max_results=10, collection_name=None, metadata=None):
        """Return the records nearest to ``query`` from a collection.

        :param query: Text to embed and search for.
        :param max_results: Maximum number of results requested.
        :param collection_name: Name of an existing collection to search;
            ``None`` searches this instance's own collection (consistent with
            ``store_memory``).
        :param metadata: Optional ``where`` filter passed to the query.
        :return: Whatever the collection's ``query`` call returns; only
            distances are requested via ``include``.
        """
        target = self._resolve_collection(collection_name)
        query_kwargs = {
            "query_embeddings": self.lm.get_embedding(query),
            "n_results": max_results,
            "include": ["distances"],
        }
        if metadata is not None:
            query_kwargs["where"] = metadata
        return target.query(**query_kwargs)

    def get_all_vectors(self):
        """Return the result of ``get()`` on this instance's collection."""
        return self.collection.get()

    def delete_vector(self, text):
        """Delete the record whose id is ``text`` from this instance's collection."""
        self.collection.delete(ids=[text])
# escargot/prompter/__init__.py
# Re-export ESCARGOTPrompter from the implementation submodule.
# Importing it from the package itself (``from escargot.prompter import ...``)
# runs while this __init__ is still executing, so the name is not yet defined
# on the package and the import fails; the absolute path must name the
# ``prompter`` submodule explicitly.
from escargot.prompter.prompter import ESCARGOTPrompter