From cd3fac0511a1e4ed66f44e7a1ca1491f105d228f Mon Sep 17 00:00:00 2001 From: FacerAin Date: Sun, 10 Dec 2023 20:19:19 +0900 Subject: [PATCH] [#31] feat: Add docstrings --- app/agent/agent.py | 45 +++++++++++++++++++++ app/agent/retriever.py | 89 ++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 131 insertions(+), 3 deletions(-) diff --git a/app/agent/agent.py b/app/agent/agent.py index 6f6473f..5a559d4 100644 --- a/app/agent/agent.py +++ b/app/agent/agent.py @@ -17,6 +17,22 @@ class ExecutorAgent: + """ + ExecutorAgent class represents an agent that executes queries using a set of tools. + + Attributes: + retriever (PineconeRetriever): The retriever used to retrieve relevant document strings. + llm (ChatOpenAI): The language model used for generating responses. + tools (list[Tool]): The list of tools available for executing queries. + agent_prompt (AgentPromptTemplate): The template for agent prompts. + output_parser (CustomAgentOutputParser): The output parser for parsing agent responses. + agent (LLMSingleActionAgent): The language model agent for executing queries. + executor (AgentExecutor): The executor for running queries using the agent and tools. + + Methods: + run(query): Executes a query using the executor and returns the response. + """ + def __init__(self): self.retriever = PineconeRetriever(index_name="khugpt") @@ -58,16 +74,45 @@ def __init__(self): ) def run(self, query): + """ + Executes a query using the executor and returns the response. + + Args: + query (str): The query to be executed. + + Returns: + str: The response generated by the agent. + """ response = self.executor.run(query) return response class RetrieverAgent: + """ + A class representing a Retriever Agent. + + Attributes: + index_name (str): The name of the index used for retrieval. + + Methods: + __init__(self, index_name: str = "khugpt") -> None: Initializes the RetrieverAgent object. + run(self, query: str): Runs the retrieval process and returns the answer.
+ """ + def __init__(self, index_name: str = "khugpt") -> None: self.llm = ChatOpenAI(model_name="gpt-4-1106-preview", temperature=0, openai_api_key=settings.OPENAI_API_KEY) self.retreiver = PineconeRetriever(index_name=index_name) def run(self, query: str): + """ + Runs the retrieval process and returns the answer. + + Args: + query (str): The query string. + + Returns: + str: The answer retrieved from the system. + """ context = self.retreiver.get_relevant_doc_string(query) system_prompt = retriever_prompt_template.format( question=query, context=context, current_date=datetime.datetime.now() diff --git a/app/agent/retriever.py b/app/agent/retriever.py index e317a2e..1518b25 100644 --- a/app/agent/retriever.py +++ b/app/agent/retriever.py @@ -11,6 +11,8 @@ class Retriever(ABC): + """Abstract base class for retrievers.""" + @abstractmethod def similarity_search(self, query: str, top_k: int = 5, **kwargs: Any): """Return docs most similar to query.""" @@ -18,6 +20,36 @@ def similarity_search(self, query: str, top_k: int = 5, **kwargs: Any): class PineconeRetriever(Retriever): + """ + A retriever class that uses Pinecone for similarity search. + + Args: + index_name (str): The name of the Pinecone index. + embedding_model (Union[Embeddings, Callable]): The embedding model used for generating embeddings. + Defaults to OpenAIEmbeddings with the provided OpenAI API key. + + Attributes: + _index (pinecone.Index): The Pinecone index object. + _embedding_model (Union[Embeddings, Callable]): The embedding model used for generating embeddings. + + Methods: + get_pinecone_index(index_name: str) -> pinecone.Index: + Retrieves the Pinecone index with the given name. + + _convert_response_to_string(item: Dict) -> str: + Converts a response item to a formatted string. + + _combine_documents(responses: List[Dict]) -> str: + Combines multiple response documents into a single string. 
+ + similarity_search(query: str, top_k: int = 5, **kwargs: Any) -> List[Dict]: + Performs a similarity search using the given query and returns the top-k matching documents. + + get_relevant_doc_string(query: str, top_k: int = 5) -> str: + Retrieves the relevant document string for the given query. + + """ + def __init__( self, index_name: str, @@ -28,6 +60,18 @@ def __init__( self._embedding_model = embedding_model def get_pinecone_index(self, index_name: str): + """ + Retrieves the Pinecone index with the given name. + + Args: + index_name (str): The name of the Pinecone index. + + Returns: + pinecone.Index: The Pinecone index object. + + Raises: + ValueError: If the index with the given name is not found in the Pinecone project. + """ indexes = pinecone.list_indexes() if index_name in indexes: @@ -39,23 +83,62 @@ def get_pinecone_index(self, index_name: str): return index def _convert_response_to_string(self, item: Dict) -> str: + """ + Converts a response item to a formatted string. + + Args: + item (Dict): The response item containing metadata. + + Returns: + str: The formatted string representation of the response item. + """ doc = f""" page_url: {item['metadata']['page_url']} document: {item['metadata']['text']} """ return doc - def _combine_documents(self, responses: List[Dict]) -> List[str]: + def _combine_documents(self, responses: List[Dict]) -> str: + """ + Combines multiple response documents into a single string. + + Args: + responses (List[Dict]): The list of response items. + + Returns: + str: The combined document string. + """ docs = [self._convert_response_to_string(response) for response in responses] doc_string = DOCUMENT_SEPERATOR.join(docs) return doc_string - def similarity_search(self, query: str, top_k: int = 5, **kwargs: Any): + def similarity_search(self, query: str, top_k: int = 5, **kwargs: Any) -> List[Dict]: + """ + Performs a similarity search using the given query and returns the top-k matching documents.
+ + Args: + query (str): The query string. + top_k (int): The number of top matching documents to retrieve. Defaults to 5. + **kwargs (Any): Additional keyword arguments to be passed to the similarity search. + + Returns: + List[Dict]: The list of top-k matching documents with metadata. + """ embeddings = self._embedding_model.embed_query(query) responses = self._index.query([embeddings], top_k=top_k, include_metadata=True) return responses - def get_relevant_doc_string(self, query: str, top_k: int = 5): + def get_relevant_doc_string(self, query: str, top_k: int = 5) -> str: + """ + Retrieves the relevant document string for the given query. + + Args: + query (str): The query string. + top_k (int): The number of top matching documents to retrieve. Defaults to 5. + + Returns: + str: The combined document string of the top-k matching documents. + """ responses = self.similarity_search(query=query, top_k=top_k) doc_string = self._combine_documents(responses=responses["matches"]) return doc_string