-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
example: self-correcting loop for RAG (#6420)
* add example * docstrings * reno * use condrouter * move functions * tests * reno * add component * reno * add tests * mypy * pylint * logger * module name * multiplexer * draw * query_multiplexer * reno * typo
- Loading branch information
Showing
2 changed files
with
154 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
from typing import List, Any, Optional, Dict | ||
|
||
import logging | ||
from pprint import pprint | ||
|
||
from canals.component.types import Variadic | ||
from haystack import Pipeline, Document, component, default_to_dict, default_from_dict, DeserializationError | ||
from haystack.document_stores import InMemoryDocumentStore | ||
from haystack.components.retrievers import InMemoryBM25Retriever | ||
from haystack.components.generators import GPTGenerator | ||
from haystack.components.builders.prompt_builder import PromptBuilder | ||
from haystack.components.others import Multiplexer | ||
from haystack.components.routers.conditional_router import ConditionalRouter | ||
|
||
|
||
# Lower the root logger's threshold to DEBUG so debug-level records are not
# filtered out by the root logger's level.
logging.getLogger(name=None).setLevel(level=logging.DEBUG)
|
||
|
||
@component
class PaginatedRetriever:
    """
    Wrap a retriever and hand out its results one page at a time.

    On the first call to `run` (or whenever the query changes) the wrapped retriever is invoked once
    and its documents are cached. Every call then returns the next `page_size` documents from that
    cache. This is useful when the retriever returns many documents and the LLM's context length is
    limited, so the documents have to be passed to it in batches.
    """

    def __init__(self, retriever: Any, page_size: int = 1, top_k: int = 100):
        """
        :param retriever: The wrapped retriever. It must expose `run(query=..., filters=..., top_k=...,
            scale_score=...)` returning a dict with a `"documents"` key, and `to_dict()` for serialization.
        :param page_size: Number of documents returned by each call to `run`.
        :param top_k: Default number of documents requested from the wrapped retriever.
        """
        self.retriever = retriever
        self.page_size = page_size
        self.top_k = top_k
        # Documents retrieved but not served yet; None means nothing has been retrieved so far.
        self.retrieved_documents: Optional[List[Document]] = None
        # Query the cached documents were retrieved for, so a different query triggers a new retrieval.
        self._cached_query: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component, including the wrapped retriever and every init parameter.
        """
        # top_k is an init parameter too: leaving it out would reset it to the default on reload.
        return default_to_dict(
            self, retriever=self.retriever.to_dict(), page_size=self.page_size, top_k=self.top_k
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "PaginatedRetriever":
        """
        Rebuild the component from the output of `to_dict`.

        :param data: Serialized component; `data["init_parameters"]["retriever"]` must hold the
            serialized wrapped retriever, including its `"type"`.
        :raises DeserializationError: If the retriever data is missing, has no `"type"`, or its type
            is not present in `component.registry`.
        """
        # Work on a shallow copy so the caller's dictionary is left untouched.
        init_parameters = dict(data.get("init_parameters", {}))
        if "retriever" not in init_parameters:
            raise DeserializationError("Missing required field 'retriever' in PaginatedRetriever")

        retriever_data = init_parameters["retriever"]
        if "type" not in retriever_data:
            raise DeserializationError("Missing 'type' in retriever's serialization data")
        if retriever_data["type"] not in component.registry:
            raise DeserializationError(f"Component type '{retriever_data['type']}' not found")
        retriever_class = component.registry[retriever_data["type"]]

        # Replace the serialized retriever with a live instance before building this component.
        init_parameters["retriever"] = retriever_class.from_dict(retriever_data)
        return default_from_dict(cls, {**data, "init_parameters": init_parameters})

    @component.output_types(documents=List[Document])
    def run(
        self,
        query: Variadic[str],
        top_k: Optional[int] = None,
        filters: Optional[Dict[str, Any]] = None,
        scale_score: Optional[bool] = None,
    ):
        """
        Return the next page of documents for the query.

        :param query: Variadic input; only the first value is used as the query.
        :param top_k: Number of documents to request from the wrapped retriever. A falsy value
            (None or 0) falls back to the `top_k` given at init.
        :param filters: Forwarded unchanged to the wrapped retriever.
        :param scale_score: Forwarded unchanged to the wrapped retriever.
        :returns: A dict with key `"documents"` holding at most `page_size` documents.
        :raises ValueError: If no documents are left to serve.
        """
        if not top_k:
            top_k = self.top_k

        current_query = query[0]
        # Retrieve only once per query: later calls with the same query just advance through the cache.
        if self.retrieved_documents is None or current_query != self._cached_query:
            self.retrieved_documents = self.retriever.run(
                query=current_query, filters=filters, top_k=top_k, scale_score=scale_score
            )["documents"]
            self._cached_query = current_query

        if not self.retrieved_documents:
            raise ValueError("No more documents available :(")

        # Serve the head of the cache and keep the remainder for the next call.
        next_page = self.retrieved_documents[: self.page_size]
        self.retrieved_documents = self.retrieved_documents[self.page_size :]
        return {"documents": next_page}
|
||
|
||
def self_correcting_pipeline():
    """
    Build, draw, populate and run a RAG pipeline that loops until the LLM stops answering "UNKNOWN".

    The query enters through a multiplexer and is sent to the paginated retriever, the prompt builder
    and the answer checker. When the LLM's replies contain "UNKNOWN", the answer checker routes the
    query back into the multiplexer, so the retriever serves its next page of documents. Otherwise
    the replies are emitted as the pipeline output. The final result is pretty-printed.
    """
    rag_pipeline = Pipeline(max_loops_allowed=10)

    # Entry point of the loop: receives the user query and, later, the re-routed unanswered query.
    rag_pipeline.add_component(instance=Multiplexer(str), name="query_multiplexer")

    paginated_retriever = PaginatedRetriever(InMemoryBM25Retriever(document_store=InMemoryDocumentStore()))
    rag_pipeline.add_component(instance=paginated_retriever, name="retriever")

    prompt_template = """
Given these documents, answer the question.
If the documents don't provide enough information to answer the question, answer with the string "UNKNOWN".
Documents:
{% for doc in documents %}
{{ doc.content }}
{% endfor %}
Question: {{question}}
Answer:
"""
    rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder")

    rag_pipeline.add_component(instance=GPTGenerator(), name="llm")

    # Two mutually exclusive routes: loop again on "UNKNOWN", otherwise emit the replies.
    retry_route = {
        "condition": "{{ 'UNKNOWN' in replies|join(' ') }}",
        "output": "{{ query }}",
        "output_name": "unanswered_query",
        "output_type": str,
    }
    answered_route = {
        "condition": "{{ 'UNKNOWN' not in replies|join(' ') }}",
        "output": "{{ replies }}",
        "output_name": "replies",
        "output_type": List[str],
    }
    rag_pipeline.add_component(
        instance=ConditionalRouter(routes=[retry_route, answered_route]), name="answer_checker"
    )

    # Wire the components; the last edge closes the loop back to the multiplexer.
    connections = [
        ("query_multiplexer", "retriever"),
        ("query_multiplexer", "prompt_builder.question"),
        ("query_multiplexer", "answer_checker.query"),
        ("retriever", "prompt_builder.documents"),
        ("prompt_builder", "llm"),
        ("llm.replies", "answer_checker.replies"),
        ("answer_checker.unanswered_query", "query_multiplexer"),
    ]
    for sender, receiver in connections:
        rag_pipeline.connect(sender, receiver)

    # Save a picture of the pipeline graph.
    rag_pipeline.draw("self_correcting_pipeline.png")

    # Fill the document store that sits behind the paginated retriever.
    contents = [
        "My name is Jean and I live in Paris.",
        "My name is Mark and I live in Berlin.",
        "My name is Giorgio and I live in Rome.",
        "My name is Juan and I live in Madrid.",
    ]
    documents = [Document(content=text) for text in contents]
    inner_retriever = rag_pipeline.get_component("retriever").retriever
    inner_retriever.document_store.write_documents(documents)

    # Run the pipeline on a sample question and print whatever it returns.
    question = "Who lives in Germany?"
    pprint(rag_pipeline.run({"query_multiplexer": {"value": question}}))
|
||
|
||
# Run the example only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    self_correcting_pipeline()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
enhancements: | ||
- Add RAG self-correction loop example |