-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Create notebook based demos (#67)
1. Creates a Jupyter notebook using Dewy as a LangChain retriever. Demonstrates both with and without citations. 2. Updates chunking to use a smaller chunk size. This makes it easier to use models with smaller context limits. 3. Fix `Dockerfile` to include all SQL migrations. 4. Adds a CI rule that ensures we don't check in the results of notebooks.
- Loading branch information
1 parent
f877c3c
commit 24279cd
Showing
5 changed files
with
311 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
.env |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,300 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Setup\n", | ||
"\n", | ||
"Create a `.env` file containing:\n", | ||
"```\n", | ||
"OPENAI_API_KEY=\"<your key here>\"\n", | ||
"```\n", | ||
"\n", | ||
"Install langchain and dewy-client as shown below:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%pip install dewy-client langchain langchain-openai" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Example LangChain without RAG\n", | ||
"This example shows a simple LangChain application which attempts to answer questions without retrieval." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from langchain_openai import ChatOpenAI\n", | ||
"# MODEL=\"gpt-4-0125-preview\"\n", | ||
"MODEL=\"gpt-3.5-turbo\"\n", | ||
"llm = ChatOpenAI(temperature=0.9, model_name=MODEL)\n", | ||
"\n", | ||
"llm.invoke(\"What is RAG useful for?\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Example LangChain with RAG (using Dewy)\n", | ||
"This example shows what the previous chain looks like using Dewy to retrieve relevant chunks." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Create the Dewy Client\n", | ||
"The following cell creates the Dewy client. It assumes you wish to connect to a Dewy service running on your local machine on port 8000. Change the URL as appropriate to your situation." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from dewy_client import Client\n", | ||
"client = Client(base_url=\"http://localhost:8000\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# The following retrieves a collection ID from Dewy.\n", | ||
"# In general use you could hard-code the collection ID.\n", | ||
"# This may switch to using the names directly.\n", | ||
"from dewy_client.api.default import list_collections\n", | ||
"collection = list_collections.sync(name=\"main\", client=client)[0]\n", | ||
"print(f\"Collection: {collection.to_dict()}\")\n", | ||
"collection_id = collection.id" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Retrieving documents in a chain" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Langchain retriever using Dewy.\n", | ||
"#\n", | ||
"# This will be added to Dewy or LangChain.\n", | ||
"from langchain_core.callbacks.manager import AsyncCallbackManagerForRetrieverRun\n", | ||
"from langchain_core.retrievers import BaseRetriever\n", | ||
"from langchain_core.callbacks import CallbackManagerForRetrieverRun\n", | ||
"from langchain_core.documents import Document\n", | ||
"from typing import Any, Coroutine, List\n", | ||
"\n", | ||
"from dewy_client.api.default import retrieve_chunks\n", | ||
"from dewy_client.models import RetrieveRequest, TextResult\n", | ||
"\n", | ||
"class DewyRetriever(BaseRetriever):\n", | ||
"\n", | ||
" collection_id: int\n", | ||
"\n", | ||
" def _make_request(self, query: str) -> RetrieveRequest:\n", | ||
" return RetrieveRequest(\n", | ||
" collection_id=self.collection_id,\n", | ||
" query=query,\n", | ||
" include_image_chunks=False,\n", | ||
" )\n", | ||
"\n", | ||
" def _make_document(self, chunk: TextResult) -> Document:\n", | ||
" return Document(page_content=chunk.text, metadata = { \"chunk_id\": chunk.chunk_id })\n", | ||
"\n", | ||
" def _get_relevant_documents(\n", | ||
" self, query: str, *, run_manager: CallbackManagerForRetrieverRun\n", | ||
" ) -> List[Document]:\n", | ||
" retrieved = retrieve_chunks.sync(client=client, body=self._make_request(query))\n", | ||
" return [self._make_document(chunk) for chunk in retrieved.text_results]\n", | ||
"\n", | ||
" async def _aget_relevant_documents(\n", | ||
" self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun\n", | ||
" ) -> Coroutine[Any, Any, List[Document]]:\n", | ||
" retrieved = await retrieve_chunks.asyncio(client=client, body=self._make_request(query))\n", | ||
" return [self._make_document(chunk) for chunk in retrieved.text_results]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from langchain_core.runnables import RunnableParallel, RunnablePassthrough\n", | ||
"from langchain_core.output_parsers import StrOutputParser\n", | ||
"from langchain_core.prompts import ChatPromptTemplate\n", | ||
"\n", | ||
"retriever = DewyRetriever(collection_id=collection_id)\n", | ||
"prompt = ChatPromptTemplate.from_messages(\n", | ||
" [\n", | ||
" (\n", | ||
" \"system\",\n", | ||
" \"\"\"\n", | ||
" You're a helpful AI assistant. Given a user question and some retrieved content, answer the user question.\n", | ||
" If none of the articles answer the question, just say you don't know.\n", | ||
"\n", | ||
" Here is the retrieved content:\n", | ||
" {context}\n", | ||
" \"\"\",\n", | ||
" ),\n", | ||
" (\"human\", \"{question}\"),\n", | ||
" ]\n", | ||
")\n", | ||
"\n", | ||
"def format_chunks(chunks):\n", | ||
" return \"\\n\\n\".join([d.page_content for d in chunks])\n", | ||
"\n", | ||
"chain = (\n", | ||
" { \"context\": retriever | format_chunks, \"question\": RunnablePassthrough() }\n", | ||
" | prompt\n", | ||
" | llm\n", | ||
" | StrOutputParser()\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"chain.invoke(\"What is RAG useful for?\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Langchain with Citations\n", | ||
"Based on https://python.langchain.com/docs/use_cases/question_answering/citations#cite-documents." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from langchain_core.pydantic_v1 import BaseModel, Field\n", | ||
"from operator import itemgetter\n", | ||
"from langchain_core.runnables import (\n", | ||
" RunnableLambda,\n", | ||
")\n", | ||
"\n", | ||
"class cited_answer(BaseModel):\n", | ||
" \"\"\"Answer the user question based only on the given sources, and cite the sources used.\"\"\"\n", | ||
"\n", | ||
" answer: str = Field(\n", | ||
" ...,\n", | ||
" description=\"The answer to the user question, which is based only on the given sources.\",\n", | ||
" )\n", | ||
" citations: List[int] = Field(\n", | ||
" ...,\n", | ||
" description=\"The integer IDs of the SPECIFIC sources which justify the answer.\",\n", | ||
" )\n", | ||
"\n", | ||
"def format_docs_with_id(docs: List[Document]) -> str:\n", | ||
" formatted = [\n", | ||
" f\"Source ID: {doc.metadata['chunk_id']}\\nArticle Snippet: {doc.page_content}\"\n", | ||
" for doc in docs\n", | ||
" ]\n", | ||
" return \"\\n\\n\" + \"\\n\\n\".join(formatted)\n", | ||
"\n", | ||
"format = itemgetter(\"docs\") | RunnableLambda(format_docs_with_id)\n", | ||
"\n", | ||
"# Setup a \"cited_answer\" tool.\n", | ||
"from langchain.output_parsers.openai_tools import JsonOutputKeyToolsParser\n", | ||
"output_parser = JsonOutputKeyToolsParser(key_name=\"cited_answer\", return_single=True)\n", | ||
"\n", | ||
"llm_with_tool = llm.bind_tools(\n", | ||
" [cited_answer],\n", | ||
" tool_choice=\"cited_answer\",\n", | ||
")\n", | ||
"answer = prompt | llm_with_tool | output_parser\n", | ||
"\n", | ||
"citation_chain = (\n", | ||
" RunnableParallel(docs = retriever, question=RunnablePassthrough())\n", | ||
" .assign(context=format)\n", | ||
" .assign(cited_answer=answer)\n", | ||
" # Can't include `docs` because they're not JSON serializable.\n", | ||
" .pick([\"cited_answer\"])\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"citation_chain.invoke(\"What is RAG useful for?\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Bonus: Adding documents to the collection" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from dewy_client.api.default import add_document\n", | ||
"from dewy_client.models import AddDocumentRequest\n", | ||
"add_document.sync(client=client, body=AddDocumentRequest(\n", | ||
" url = \"https://arxiv.org/pdf/2305.14283.pdf\",\n", | ||
" collection_id=collection_id,\n", | ||
"))" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "knowledge-7QbvxqGg-py3.11", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.4" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters