Skip to content

Commit

Permalink
[#31] feat: Add docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
FacerAin committed Dec 10, 2023
1 parent eab2cfb commit cd3fac0
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 3 deletions.
45 changes: 45 additions & 0 deletions app/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,22 @@


class ExecutorAgent:
"""
ExecutorAgent class represents an agent that executes queries using a set of tools.
Attributes:
retriever (PineconeRetriever): The retriever used to retrieve relevant document strings.
llm (ChatOpenAI): The language model used for generating responses.
tools (list[Tool]): The list of tools available for executing queries.
agent_prompt (AgentPromptTemplate): The template for agent prompts.
output_parser (CustomAgentOutputParser): The output parser for parsing agent responses.
agent (LLMSingleActionAgent): The language model agent for executing queries.
executor (AgentExecutor): The executor for running queries using the agent and tools.
Methods:
run(query): Executes a query using the executor and returns the response.
"""

def __init__(self):
self.retriever = PineconeRetriever(index_name="khugpt")

Expand Down Expand Up @@ -58,16 +74,45 @@ def __init__(self):
)

def run(self, query):
"""
Executes a query using the executor and returns the response.
Args:
query (str): The query to be executed.
Returns:
str: The response generated by the agent.
"""
response = self.executor.run(query)
return response


class RetrieverAgent:
"""
A class representing a Retriever Agent.
Attributes:
index_name (str): The name of the index used for retrieval.
Methods:
__init__(self, index_name: str = "khugpt") -> None: Initializes the RetrieverAgent object.
run(self, query: str): Runs the retrieval process and returns the answer.
"""

def __init__(self, index_name: str = "khugpt") -> None:
self.llm = ChatOpenAI(model_name="gpt-4-1106-preview", temperature=0, openai_api_key=settings.OPENAI_API_KEY)
self.retreiver = PineconeRetriever(index_name=index_name)

def run(self, query: str):
"""
Runs the retrieval process and returns the answer.
Args:
query (str): The query string.
Returns:
str: The answer retrieved from the system.
"""
context = self.retreiver.get_relevant_doc_string(query)
system_prompt = retriever_prompt_template.format(
question=query, context=context, current_date=datetime.datetime.now()
Expand Down
89 changes: 86 additions & 3 deletions app/agent/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,45 @@


class Retriever(ABC):
"""Abstract base class for retrievers."""

@abstractmethod
def similarity_search(self, query: str, top_k: int = 5, **kwargs: Any):
"""Return docs most similar to query."""
raise NotImplementedError


class PineconeRetriever(Retriever):
"""
A retriever class that uses Pinecone for similarity search.
Args:
index_name (str): The name of the Pinecone index.
embedding_model (Union[Embeddings, Callable]): The embedding model used for generating embeddings.
Defaults to OpenAIEmbeddings with the provided OpenAI API key.
Attributes:
_index (pinecone.Index): The Pinecone index object.
_embedding_model (Union[Embeddings, Callable]): The embedding model used for generating embeddings.
Methods:
get_pinecone_index(index_name: str) -> pinecone.Index:
Retrieves the Pinecone index with the given name.
_convert_response_to_string(item: Dict) -> str:
Converts a response item to a formatted string.
_combine_documents(responses: List[Dict]) -> str:
Combines multiple response documents into a single string.
similarity_search(query: str, top_k: int = 5, **kwargs: Any) -> List[Dict]:
Performs a similarity search using the given query and returns the top-k matching documents.
get_relevant_doc_string(query: str, top_k: int = 5) -> str:
Retrieves the relevant document string for the given query.
"""

def __init__(
self,
index_name: str,
Expand All @@ -28,6 +60,18 @@ def __init__(
self._embedding_model = embedding_model

def get_pinecone_index(self, index_name: str):
"""
Retrieves the Pinecone index with the given name.
Args:
index_name (str): The name of the Pinecone index.
Returns:
pinecone.Index: The Pinecone index object.
Raises:
ValueError: If the index with the given name is not found in the Pinecone project.
"""
indexes = pinecone.list_indexes()

if index_name in indexes:
Expand All @@ -39,23 +83,62 @@ def get_pinecone_index(self, index_name: str):
return index

def _convert_response_to_string(self, item: Dict) -> str:
"""
Converts a response item to a formatted string.
Args:
item (Dict): The response item containing metadata.
Returns:
str: The formatted string representation of the response item.
"""
doc = f"""
page_url: {item['metadata']['page_url']}
document: {item['metadata']['text']}
"""
return doc

def _combine_documents(self, responses: List[Dict]) -> List[str]:
def _combine_documents(self, responses: List[Dict]) -> str:
"""
Combines multiple response documents into a single string.
Args:
responses (List[Dict]): The list of response items.
Returns:
str: The combined document string.
"""
docs = [self._convert_response_to_string(response) for response in responses]
doc_string = DOCUMENT_SEPERATOR.join(docs)
return doc_string

def similarity_search(self, query: str, top_k: int = 5, **kwargs: Any):
def similarity_search(self, query: str, top_k: int = 5, **kwargs: Any) -> List[Dict]:
"""
Performs a similarity search using the given query and returns the top-k matching documents.
Args:
query (str): The query string.
top_k (int): The number of top matching documents to retrieve. Defaults to 5.
**kwargs (Any): Additional keyword arguments to be passed to the similarity search.
Returns:
List[Dict]: The list of top-k matching documents with metadata.
"""
embeddings = self._embedding_model.embed_query(query)
responses = self._index.query([embeddings], top_k=top_k, include_metadata=True)
return responses

def get_relevant_doc_string(self, query: str, top_k: int = 5):
def get_relevant_doc_string(self, query: str, top_k: int = 5) -> str:
"""
Retrieves the relevant document string for the given query.
Args:
query (str): The query string.
top_k (int): The number of top matching documents to retrieve. Defaults to 5.
Returns:
str: The combined document string of the top-k matching documents.
"""
responses = self.similarity_search(query=query, top_k=top_k)
doc_string = self._combine_documents(responses=responses["matches"])
return doc_string

0 comments on commit cd3fac0

Please sign in to comment.