feat: Create notebook based demos (#67)
1. Creates a Jupyter notebook that uses Dewy as a LangChain retriever,
   demonstrating question answering both with and without citations.
2. Updates chunking to use a smaller chunk size. This makes it easier
   to use models with smaller context limits.
3. Fixes the `Dockerfile` to include all SQL migrations.
4. Adds a CI rule that ensures we don't check in the results of
   notebooks.
bjchambers authored Feb 2, 2024
1 parent f877c3c commit 24279cd
Showing 5 changed files with 311 additions and 4 deletions.
8 changes: 7 additions & 1 deletion .github/workflows/ci.yml
@@ -61,4 +61,10 @@ jobs:
       - name: Ruff Format (Check)
         uses: chartboost/ruff-action@v1
         with:
-          args: format --check
+          args: format --check
+
+  verify_clean_notebooks:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - uses: ResearchSoftwareActions/EnsureCleanNotebooksAction@1.1
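
The `EnsureCleanNotebooksAction` referenced above presumably fails the build when a committed notebook still contains cell outputs or execution counts; its exact checks aren't shown in this diff. A minimal Python sketch of the same idea, checking the notebook added by this commit:

```python
# Sketch of a "clean notebook" check (an assumption about what the CI action
# verifies): a notebook is considered clean when every code cell has no
# outputs and no execution count.
import json
import sys

def is_clean(path: str) -> bool:
    with open(path) as f:
        nb = json.load(f)
    for cell in nb.get("cells", []):
        if cell.get("cell_type") != "code":
            continue
        if cell.get("outputs") or cell.get("execution_count") is not None:
            return False
    return True

if __name__ == "__main__":
    path = "demos/python-langchain-notebook/python-langchain.ipynb"
    if not is_clean(path):
        sys.exit(f"{path} contains saved outputs; clear them before committing")
```

Locally, `jupyter nbconvert --clear-output --inplace <notebook>` strips outputs before committing.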
4 changes: 2 additions & 2 deletions Dockerfile
@@ -9,7 +9,7 @@ COPY ./pyproject.toml ./poetry.lock* /tmp/
RUN poetry export -f requirements.txt --output requirements.txt --without-hashes

######
-# 2. Compile the frontend
+# 2. Compile the frontend
FROM node:20.9.0-alpine as frontend-stage
WORKDIR /app
COPY ./frontend/package.json ./package.json
@@ -31,6 +31,6 @@ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
COPY ./dewy /code/dewy
COPY --from=frontend-stage /app/dist /code/dewy/frontend/dist

-COPY ./migrations/0001_schema.sql /code/migrations/0001_schema.sql
+COPY ./migrations/*.sql /code/migrations/

CMD ["uvicorn", "dewy.main:app", "--host", "0.0.0.0", "--port", "8000"]
1 change: 1 addition & 0 deletions demos/python-langchain-notebook/.gitignore
@@ -0,0 +1 @@
.env
300 changes: 300 additions & 0 deletions demos/python-langchain-notebook/python-langchain.ipynb
@@ -0,0 +1,300 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Setup\n",
"\n",
"Create a `.env` file containing:\n",
"```\n",
"OPENAI_API_KEY=\"<your key here>\"\n",
"```\n",
"\n",
"Install langchain and dewy-client as shown below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install dewy-client langchain langchain-openai"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Example LangChain without RAG\n",
"This example shows a simple LangChain application which attempts to answer questions without retrieval."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_openai import ChatOpenAI\n",
"# MODEL=\"gpt-4-0125-preview\"\n",
"MODEL=\"gpt-3.5-turbo\"\n",
"llm = ChatOpenAI(temperature=0.9, model_name=MODEL)\n",
"\n",
"llm.invoke(\"What is RAG useful for?\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Example LangChain with RAG (using Dewy)\n",
"This example shows what the previous chain looks like using Dewy to retrieve relevant chunks."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create the Dewy Client\n",
"The following cell creates the Dewy client. It assumes you wish to connect to a Dewy service running on your local machine on port 8000. Change the URL as appropriate to your situation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dewy_client import Client\n",
"client = Client(base_url=\"http://localhost:8000\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The following retrieves a collection ID from Dewy.\n",
"# In general use you could hard-code the collection ID.\n",
"# This may switch to using the names directly.\n",
"from dewy_client.api.default import list_collections\n",
"collection = list_collections.sync(name=\"main\", client=client)[0]\n",
"print(f\"Collection: {collection.to_dict()}\")\n",
"collection_id = collection.id"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Retrieving documents in a chain"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Langchain retriever using Dewy.\n",
"#\n",
"# This will be added to Dewy or LangChain.\n",
"from langchain_core.callbacks.manager import AsyncCallbackManagerForRetrieverRun\n",
"from langchain_core.retrievers import BaseRetriever\n",
"from langchain_core.callbacks import CallbackManagerForRetrieverRun\n",
"from langchain_core.documents import Document\n",
"from typing import Any, Coroutine, List\n",
"\n",
"from dewy_client.api.default import retrieve_chunks\n",
"from dewy_client.models import RetrieveRequest, TextResult\n",
"\n",
"class DewyRetriever(BaseRetriever):\n",
"\n",
" collection_id: int\n",
"\n",
" def _make_request(self, query: str) -> RetrieveRequest:\n",
" return RetrieveRequest(\n",
" collection_id=self.collection_id,\n",
" query=query,\n",
" include_image_chunks=False,\n",
" )\n",
"\n",
" def _make_document(self, chunk: TextResult) -> Document:\n",
" return Document(page_content=chunk.text, metadata = { \"chunk_id\": chunk.chunk_id })\n",
"\n",
" def _get_relevant_documents(\n",
" self, query: str, *, run_manager: CallbackManagerForRetrieverRun\n",
" ) -> List[Document]:\n",
" retrieved = retrieve_chunks.sync(client=client, body=self._make_request(query))\n",
" return [self._make_document(chunk) for chunk in retrieved.text_results]\n",
"\n",
" async def _aget_relevant_documents(\n",
" self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun\n",
" ) -> Coroutine[Any, Any, List[Document]]:\n",
" retrieved = await retrieve_chunks.asyncio(client=client, body=self._make_request(query))\n",
" return [self._make_document(chunk) for chunk in retrieved.text_results]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.runnables import RunnableParallel, RunnablePassthrough\n",
"from langchain_core.output_parsers import StrOutputParser\n",
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"retriever = DewyRetriever(collection_id=collection_id)\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\n",
" \"system\",\n",
" \"\"\"\n",
" You're a helpful AI assistant. Given a user question and some retrieved content, answer the user question.\n",
" If none of the articles answer the question, just say you don't know.\n",
"\n",
" Here is the retrieved content:\n",
" {context}\n",
" \"\"\",\n",
" ),\n",
" (\"human\", \"{question}\"),\n",
" ]\n",
")\n",
"\n",
"def format_chunks(chunks):\n",
" return \"\\n\\n\".join([d.page_content for d in chunks])\n",
"\n",
"chain = (\n",
" { \"context\": retriever | format_chunks, \"question\": RunnablePassthrough() }\n",
" | prompt\n",
" | llm\n",
" | StrOutputParser()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"chain.invoke(\"What is RAG useful for?\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Langchain with Citations\n",
"Based on https://python.langchain.com/docs/use_cases/question_answering/citations#cite-documents."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.pydantic_v1 import BaseModel, Field\n",
"from operator import itemgetter\n",
"from langchain_core.runnables import (\n",
" RunnableLambda,\n",
")\n",
"\n",
"class cited_answer(BaseModel):\n",
" \"\"\"Answer the user question based only on the given sources, and cite the sources used.\"\"\"\n",
"\n",
" answer: str = Field(\n",
" ...,\n",
" description=\"The answer to the user question, which is based only on the given sources.\",\n",
" )\n",
" citations: List[int] = Field(\n",
" ...,\n",
" description=\"The integer IDs of the SPECIFIC sources which justify the answer.\",\n",
" )\n",
"\n",
"def format_docs_with_id(docs: List[Document]) -> str:\n",
" formatted = [\n",
" f\"Source ID: {doc.metadata['chunk_id']}\\nArticle Snippet: {doc.page_content}\"\n",
" for doc in docs\n",
" ]\n",
" return \"\\n\\n\" + \"\\n\\n\".join(formatted)\n",
"\n",
"format = itemgetter(\"docs\") | RunnableLambda(format_docs_with_id)\n",
"\n",
"# Setup a \"cited_answer\" tool.\n",
"from langchain.output_parsers.openai_tools import JsonOutputKeyToolsParser\n",
"output_parser = JsonOutputKeyToolsParser(key_name=\"cited_answer\", return_single=True)\n",
"\n",
"llm_with_tool = llm.bind_tools(\n",
" [cited_answer],\n",
" tool_choice=\"cited_answer\",\n",
")\n",
"answer = prompt | llm_with_tool | output_parser\n",
"\n",
"citation_chain = (\n",
" RunnableParallel(docs = retriever, question=RunnablePassthrough())\n",
" .assign(context=format)\n",
" .assign(cited_answer=answer)\n",
" # Can't include `docs` because they're not JSON serializable.\n",
" .pick([\"cited_answer\"])\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"citation_chain.invoke(\"What is RAG useful for?\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Bonus: Adding documents to the collection"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dewy_client.api.default import add_document\n",
"from dewy_client.models import AddDocumentRequest\n",
"add_document.sync(client=client, body=AddDocumentRequest(\n",
" url = \"https://arxiv.org/pdf/2305.14283.pdf\",\n",
" collection_id=collection_id,\n",
"))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "knowledge-7QbvxqGg-py3.11",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
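
One gap worth noting in the notebook above: the setup cell asks for a `.env` file containing `OPENAI_API_KEY`, but no cell loads it, so `ChatOpenAI` only works if the variable is already present in the environment. A minimal sketch of loading it explicitly with python-dotenv (an extra dependency, not installed by the notebook's pip cell):

```python
# Load OPENAI_API_KEY from the .env file into the process environment so
# ChatOpenAI can find it. Requires: pip install python-dotenv
from dotenv import load_dotenv

load_dotenv()  # reads .env from the current working directory
```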
2 changes: 1 addition & 1 deletion dewy/common/collection_embeddings.py
@@ -35,7 +35,7 @@ def __init__(
        self.extract_images = False

        # TODO: Look at a sentence window splitter?
-        self._splitter = SentenceSplitter()
+        self._splitter = SentenceSplitter(chunk_size=256)
        self._embedding = _resolve_embedding_model(self.text_embedding_model)

        field = f"embedding::vector({text_embedding_dimensions})"
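
For context on the change above: `SentenceSplitter` is LlamaIndex's sentence-aware text splitter, and `chunk_size` is measured in tokens, with a default of 1024. A quick sketch of the effect; the import path is an assumption that depends on the installed llama-index version:

```python
# Compare default chunking against the smaller 256-token chunks this commit
# switches to. Import path is for llama-index 0.10+; earlier releases used
# `from llama_index.node_parser import SentenceSplitter`.
from llama_index.core.node_parser import SentenceSplitter

long_text = open("some_document.txt").read()  # hypothetical input

default_chunks = SentenceSplitter().split_text(long_text)              # ~1024 tokens each
small_chunks = SentenceSplitter(chunk_size=256).split_text(long_text)  # fits small-context models

print(len(default_chunks), len(small_chunks))  # smaller chunks -> more of them
```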
