From af612cf1b9534c21bf27ddc3e8ec4d22893fa11e Mon Sep 17 00:00:00 2001
From: Ben Chambers <35960+bjchambers@users.noreply.github.com>
Date: Thu, 1 Feb 2024 13:38:42 -0800
Subject: [PATCH] feat: Create notebook-based demos

1. Creates a Jupyter notebook using Dewy as a LangChain retriever.
   Demonstrates question answering both with and without citations.
2. Updates chunking to use a smaller chunk size. This makes it easier
   to use models with smaller context limits.
3. Fixes the `Dockerfile` to include all SQL migrations.
4. Adds a CI rule ensuring we don't check in notebook outputs.
---
 .github/workflows/ci.yml                   |   8 +-
 Dockerfile                                 |   4 +-
 demos/python-langchain-notebook/.gitignore |   1 +
 .../python-langchain.ipynb                 | 360 ++++++++++++++++++
 dewy/common/collection_embeddings.py       |   2 +-
 5 files changed, 371 insertions(+), 4 deletions(-)
 create mode 100644 demos/python-langchain-notebook/.gitignore
 create mode 100644 demos/python-langchain-notebook/python-langchain.ipynb

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 1df6f86..0acdb95 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -61,4 +61,10 @@ jobs:
       - name: Ruff Format (Check)
         uses: chartboost/ruff-action@v1
         with:
-          args: format --check
\ No newline at end of file
+          args: format --check
+
+  verify_clean_notebooks:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - uses: ResearchSoftwareActions/EnsureCleanNotebooksAction@1.1
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
index d91444e..287b5b4 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -9,7 +9,7 @@ COPY ./pyproject.toml ./poetry.lock* /tmp/
 RUN poetry export -f requirements.txt --output requirements.txt --without-hashes
 
 ######
-# 2. Compile the frontend 
+# 2. Compile the frontend
 FROM node:20.9.0-alpine as frontend-stage
 WORKDIR /app
 COPY ./frontend/package.json ./package.json
@@ -31,6 +31,6 @@ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
 COPY ./dewy /code/dewy
 COPY --from=frontend-stage /app/dist /code/dewy/frontend/dist
 
-COPY ./migrations/0001_schema.sql /code/migrations/0001_schema.sql
+COPY ./migrations/*.sql /code/migrations/
 
 CMD ["uvicorn", "dewy.main:app", "--host", "0.0.0.0", "--port", "8000"]
\ No newline at end of file
diff --git a/demos/python-langchain-notebook/.gitignore b/demos/python-langchain-notebook/.gitignore
new file mode 100644
index 0000000..2eea525
--- /dev/null
+++ b/demos/python-langchain-notebook/.gitignore
@@ -0,0 +1 @@
+.env
\ No newline at end of file
diff --git a/demos/python-langchain-notebook/python-langchain.ipynb b/demos/python-langchain-notebook/python-langchain.ipynb
new file mode 100644
index 0000000..0939321
--- /dev/null
+++ b/demos/python-langchain-notebook/python-langchain.ipynb
@@ -0,0 +1,360 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Setup\n",
+    "\n",
+    "Create a `.env` file containing:\n",
+    "```\n",
+    "OPENAI_API_KEY=\"\"\n",
+    "```\n",
+    "\n",
+    "Install langchain and dewy-client as shown below:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%pip install dewy-client langchain langchain-openai"
+   ]
+  },
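+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Load the `.env` file so that `OPENAI_API_KEY` is available to the OpenAI client. This is a minimal sketch assuming `python-dotenv` is installed; add it to the `%pip install` line above if needed:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Load OPENAI_API_KEY from the .env file created above (assumes python-dotenv).\n",
+    "from dotenv import load_dotenv\n",
+    "load_dotenv()"
+   ]
+  },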
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Example LangChain without RAG\n",
+    "This example shows a simple LangChain application which attempts to answer questions without retrieval."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain_openai import ChatOpenAI\n",
+    "# MODEL = \"gpt-4-0125-preview\"\n",
+    "MODEL = \"gpt-3.5-turbo\"\n",
+    "llm = ChatOpenAI(temperature=0.9, model=MODEL)\n",
+    "\n",
+    "llm.invoke(\"What is RAG useful for?\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Example LangChain with RAG (using Dewy)\n",
+    "This example shows what the previous chain looks like when Dewy is used to retrieve relevant chunks."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Create the Dewy Client\n",
+    "The following cell creates the Dewy client. It assumes you wish to connect to a Dewy service running on your local machine on port 8000. Change the URL as appropriate for your situation."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from dewy_client import Client\n",
+    "client = Client(base_url=\"http://localhost:8000\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# The following retrieves the collection ID from Dewy.\n",
+    "# In real use you could hard-code the collection ID;\n",
+    "# a future version may support using collection names directly.\n",
+    "from dewy_client.api.default import list_collections\n",
+    "collection = list_collections.sync(name=\"main\", client=client)[0]\n",
+    "print(f\"Collection: {collection.to_dict()}\")\n",
+    "collection_id = collection.id"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Retrieving documents in a chain"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# LangChain retriever using Dewy.\n",
+    "#\n",
+    "# A retriever like this may eventually be provided by Dewy or LangChain.\n",
+    "from langchain_core.callbacks.manager import AsyncCallbackManagerForRetrieverRun\n",
+    "from langchain_core.retrievers import BaseRetriever\n",
+    "from langchain_core.callbacks import CallbackManagerForRetrieverRun\n",
+    "from langchain_core.documents import Document\n",
+    "from typing import List\n",
+    "\n",
+    "from dewy_client.api.default import retrieve_chunks\n",
+    "from dewy_client.models import RetrieveRequest, TextResult\n",
+    "\n",
+    "class DewyRetriever(BaseRetriever):\n",
+    "\n",
+    "    collection_id: int\n",
+    "\n",
+    "    def _make_request(self, query: str) -> RetrieveRequest:\n",
+    "        return RetrieveRequest(\n",
+    "            collection_id=self.collection_id,\n",
+    "            query=query,\n",
+    "            include_image_chunks=False,\n",
+    "        )\n",
+    "\n",
+    "    def _make_document(self, chunk: TextResult) -> Document:\n",
+    "        return Document(page_content=chunk.text, metadata={\"chunk_id\": chunk.chunk_id})\n",
+    "\n",
+    "    def _get_relevant_documents(\n",
+    "        self, query: str, *, run_manager: CallbackManagerForRetrieverRun\n",
+    "    ) -> List[Document]:\n",
+    "        retrieved = retrieve_chunks.sync(client=client, body=self._make_request(query))\n",
+    "        return [self._make_document(chunk) for chunk in retrieved.text_results]\n",
+    "\n",
+    "    async def _aget_relevant_documents(\n",
+    "        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun\n",
+    "    ) -> List[Document]:\n",
+    "        retrieved = await retrieve_chunks.asyncio(client=client, body=self._make_request(query))\n",
+    "        return [self._make_document(chunk) for chunk in retrieved.text_results]"
+   ]
+  },
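+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Before wiring the retriever into a chain, it can be sanity-checked on its own. This is a minimal sketch; it assumes the `main` collection already contains at least one ingested document (see the bonus section below for adding one):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Hypothetical smoke test: invoke the retriever directly, without an LLM.\n",
+    "retriever = DewyRetriever(collection_id=collection_id)\n",
+    "for doc in retriever.invoke(\"What is RAG useful for?\"):\n",
+    "    print(doc.metadata[\"chunk_id\"], doc.page_content[:80])"
+   ]
+  },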
"metadata": {}, + "outputs": [], + "source": [ + "from langchain_core.runnables import RunnableParallel, RunnablePassthrough\n", + "from langchain_core.output_parsers import StrOutputParser\n", + "from langchain_core.prompts import ChatPromptTemplate\n", + "\n", + "retriever = DewyRetriever(collection_id=collection_id)\n", + "prompt = ChatPromptTemplate.from_messages(\n", + " [\n", + " (\n", + " \"system\",\n", + " \"\"\"\n", + " You're a helpful AI assistant. Given a user question and some retrieved content, answer the user question.\n", + " If none of the articles answer the question, just say you don't know.\n", + "\n", + " Here is the retrieved content:\n", + " {context}\n", + " \"\"\",\n", + " ),\n", + " (\"human\", \"{question}\"),\n", + " ]\n", + ")\n", + "\n", + "def format_chunks(chunks):\n", + " return \"\\n\\n\".join([d.page_content for d in chunks])\n", + "\n", + "chain = (\n", + " { \"context\": retriever | format_chunks, \"question\": RunnablePassthrough() }\n", + " | prompt\n", + " | llm\n", + " | StrOutputParser()\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "chain.invoke(\"What is RAG useful for?\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Langchain with Citations\n", + "Based on https://python.langchain.com/docs/use_cases/question_answering/citations#cite-documents." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_core.pydantic_v1 import BaseModel, Field\n", + "from operator import itemgetter\n", + "from langchain_core.runnables import (\n", + " RunnableLambda,\n", + ")\n", + "\n", + "class cited_answer(BaseModel):\n", + " \"\"\"Answer the user question based only on the given sources, and cite the sources used.\"\"\"\n", + "\n", + " answer: str = Field(\n", + " ...,\n", + " description=\"The answer to the user question, which is based only on the given sources.\",\n", + " )\n", + " citations: List[int] = Field(\n", + " ...,\n", + " description=\"The integer IDs of the SPECIFIC sources which justify the answer.\",\n", + " )\n", + "\n", + "def format_docs_with_id(docs: List[Document]) -> str:\n", + " formatted = [\n", + " f\"Source ID: {doc.metadata['chunk_id']}\\nArticle Snippet: {doc.page_content}\"\n", + " for doc in docs\n", + " ]\n", + " return \"\\n\\n\" + \"\\n\\n\".join(formatted)\n", + "\n", + "format = itemgetter(\"docs\") | RunnableLambda(format_docs_with_id)\n", + "\n", + "# Setup a \"cited_answer\" tool.\n", + "from langchain.output_parsers.openai_tools import JsonOutputKeyToolsParser\n", + "output_parser = JsonOutputKeyToolsParser(key_name=\"cited_answer\", return_single=True)\n", + "\n", + "llm_with_tool = llm.bind_tools(\n", + " [cited_answer],\n", + " tool_choice=\"cited_answer\",\n", + ")\n", + "answer = prompt | llm_with_tool | output_parser\n", + "\n", + "citation_chain = (\n", + " RunnableParallel(docs = retriever, question=RunnablePassthrough())\n", + " .assign(context=format)\n", + " .assign(cited_answer=answer)\n", + " # Can't include `docs` because they're not JSON serializable.\n", + " .pick([\"cited_answer\"])\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "citation_chain.invoke(\"What is RAG useful for?\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Bonus: Adding documents to the collection" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from dewy_client.api.default import add_document\n", + "from dewy_client.models import AddDocumentRequest\n", + "add_document.sync(client=client, body=AddDocumentRequest(\n", + " url = \"https://arxiv.org/pdf/2305.14283.pdf\",\n", + " collection_id=collection_id,\n", + "))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "knowledge-7QbvxqGg-py3.11", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/dewy/common/collection_embeddings.py b/dewy/common/collection_embeddings.py index 71f0491..6a64925 100644 --- a/dewy/common/collection_embeddings.py +++ b/dewy/common/collection_embeddings.py @@ -35,7 +35,7 @@ def __init__( self.extract_images = False # TODO: Look at a sentence window splitter? - self._splitter = SentenceSplitter() + self._splitter = SentenceSplitter(chunk_size=256) self._embedding = _resolve_embedding_model(self.text_embedding_model) field = f"embedding::vector({text_embedding_dimensions})"