diff --git a/examples/rag/rag_self_correction.py b/examples/rag/rag_self_correction.py new file mode 100644 index 0000000000..a3455dd128 --- /dev/null +++ b/examples/rag/rag_self_correction.py @@ -0,0 +1,152 @@ +from typing import List, Any, Optional, Dict + +import logging +from pprint import pprint + +from canals.component.types import Variadic +from haystack import Pipeline, Document, component, default_to_dict, default_from_dict, DeserializationError +from haystack.document_stores import InMemoryDocumentStore +from haystack.components.retrievers import InMemoryBM25Retriever +from haystack.components.generators import GPTGenerator +from haystack.components.builders.prompt_builder import PromptBuilder +from haystack.components.others import Multiplexer +from haystack.components.routers.conditional_router import ConditionalRouter + + +logging.getLogger().setLevel(logging.DEBUG) + + +@component +class PaginatedRetriever: + """ + This component is used to paginate the results of a retriever. + It is useful when the retriever returns a large number of documents, and we want to pass them to the LLM + in batches. + + It is useful in cases where the LLM's context length is limited, and we want to avoid passing too many + documents to it at once. + """ + + def __init__(self, retriever: Any, page_size: int = 1, top_k: int = 100): + self.retriever = retriever + self.page_size = page_size + self.top_k = top_k + self.retrieved_documents = None + + def to_dict(self): + return default_to_dict(self, retriever=self.retriever.to_dict(), page_size=self.page_size) + + @classmethod + def from_dict(cls, data): + if not "retriever" in data["init_parameters"]: + raise DeserializationError("Missing required field 'retriever' in SlidingWindowRetriever") + + retriever_data = data["init_parameters"]["retriever"] + if "type" not in retriever_data: + raise DeserializationError("Missing 'type' in retriever's serialization data") + if retriever_data["type"] not in component.registry: + raise DeserializationError(f"Component type '{retriever_data['type']}' not found") + retriever_class = component.registry[retriever_data["type"]] + + data["init_parameters"]["retriever"] = retriever_class.from_dict(retriever_data) + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run( + self, + query: Variadic[str], + top_k: Optional[int] = None, + filters: Optional[Dict[str, Any]] = None, + scale_score: Optional[bool] = None, + ): + if not top_k: + top_k = self.top_k + + if self.retrieved_documents is None: + self.retrieved_documents = self.retriever.run( + query=query[0], filters=filters, top_k=top_k, scale_score=scale_score + )["documents"] + + if not self.retrieved_documents: + raise ValueError("No more documents available :(") + + next_page = self.retrieved_documents[: self.page_size] + self.retrieved_documents = self.retrieved_documents[self.page_size :] + return {"documents": next_page} + + +def self_correcting_pipeline(): + # Create the RAG pipeline + rag_pipeline = Pipeline(max_loops_allowed=10) + rag_pipeline.add_component(instance=Multiplexer(str), name="query_multiplexer") + rag_pipeline.add_component( + instance=PaginatedRetriever(InMemoryBM25Retriever(document_store=InMemoryDocumentStore())), name="retriever" + ) + rag_pipeline.add_component( + instance=PromptBuilder( + template=""" + Given these documents, answer the question. + If the documents don't provide enough information to answer the question, answer with the string "UNKNOWN". + + Documents: + {% for doc in documents %} + {{ doc.content }} + {% endfor %} + + Question: {{question}} + Answer: + """ + ), + name="prompt_builder", + ) + rag_pipeline.add_component(instance=GPTGenerator(), name="llm") + rag_pipeline.add_component( + instance=ConditionalRouter( + routes=[ + { + "condition": "{{ 'UNKNOWN' in replies|join(' ') }}", + "output": "{{ query }}", + "output_name": "unanswered_query", + "output_type": str, + }, + { + "condition": "{{ 'UNKNOWN' not in replies|join(' ') }}", + "output": "{{ replies }}", + "output_name": "replies", + "output_type": List[str], + }, + ] + ), + name="answer_checker", + ) + + rag_pipeline.connect("query_multiplexer", "retriever") + rag_pipeline.connect("query_multiplexer", "prompt_builder.question") + rag_pipeline.connect("query_multiplexer", "answer_checker.query") + rag_pipeline.connect("retriever", "prompt_builder.documents") + rag_pipeline.connect("prompt_builder", "llm") + rag_pipeline.connect("llm.replies", "answer_checker.replies") + rag_pipeline.connect("answer_checker.unanswered_query", "query_multiplexer") + + # Draw the pipeline + rag_pipeline.draw("self_correcting_pipeline.png") + + # Populate the document store + documents = [ + Document(content="My name is Jean and I live in Paris."), + Document(content="My name is Mark and I live in Berlin."), + Document(content="My name is Giorgio and I live in Rome."), + Document(content="My name is Juan and I live in Madrid."), + ] + rag_pipeline.get_component("retriever").retriever.document_store.write_documents(documents) + + # Query and assert + question = "Who lives in Germany?" + + result = rag_pipeline.run({"query_multiplexer": {"value": question}}) + + pprint(result) + + +if __name__ == "__main__": + self_correcting_pipeline() diff --git a/releasenotes/notes/self-correcting-rag-2e77ac94b89dfe5b.yaml b/releasenotes/notes/self-correcting-rag-2e77ac94b89dfe5b.yaml new file mode 100644 index 0000000000..f0d3c1f92b --- /dev/null +++ b/releasenotes/notes/self-correcting-rag-2e77ac94b89dfe5b.yaml @@ -0,0 +1,2 @@ +enhancements: + - Add RAG self correction loop example