feat: Create notebook based demos #67

Merged (1 commit, Feb 2, 2024)
8 changes: 7 additions & 1 deletion .github/workflows/ci.yml
@@ -61,4 +61,10 @@ jobs:
      - name: Ruff Format (Check)
        uses: chartboost/ruff-action@v1
        with:
          args: format --check

  verify_clean_notebooks:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: ResearchSoftwareActions/EnsureCleanNotebooksAction@1.1
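The new `verify_clean_notebooks` job fails the build when a committed notebook still contains outputs or execution counts. To clear those locally before committing, here is a minimal sketch using the standard `nbformat` library (the script name and the exact cleaning rules the action enforces are assumptions):

```python
# clean_notebooks.py -- hypothetical helper, not part of this PR.
# Strips outputs and execution counts so a notebook passes a
# "clean notebook" CI check.
import sys

import nbformat

def clean_notebook(path: str) -> None:
    nb = nbformat.read(path, as_version=4)
    for cell in nb.cells:
        if cell.cell_type == "code":
            cell.outputs = []            # drop any captured output
            cell.execution_count = None  # reset the [n] counter
    nbformat.write(nb, path)

if __name__ == "__main__":
    for notebook_path in sys.argv[1:]:
        clean_notebook(notebook_path)
```

Run as `python clean_notebooks.py demos/python-langchain-notebook/python-langchain.ipynb` before committing.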
4 changes: 2 additions & 2 deletions Dockerfile
@@ -9,7 +9,7 @@ COPY ./pyproject.toml ./poetry.lock* /tmp/
RUN poetry export -f requirements.txt --output requirements.txt --without-hashes

######
# 2. Compile the frontend
FROM node:20.9.0-alpine as frontend-stage
WORKDIR /app
COPY ./frontend/package.json ./package.json
@@ -31,6 +31,6 @@ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
COPY ./dewy /code/dewy
COPY --from=frontend-stage /app/dist /code/dewy/frontend/dist

-COPY ./migrations/0001_schema.sql /code/migrations/0001_schema.sql
+COPY ./migrations/*.sql /code/migrations/

CMD ["uvicorn", "dewy.main:app", "--host", "0.0.0.0", "--port", "8000"]
1 change: 1 addition & 0 deletions demos/python-langchain-notebook/.gitignore
@@ -0,0 +1 @@
.env
300 changes: 300 additions & 0 deletions demos/python-langchain-notebook/python-langchain.ipynb
@@ -0,0 +1,300 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Setup\n",
"\n",
"Create a `.env` file containing:\n",
"```\n",
"OPENAI_API_KEY=\"<your key here>\"\n",
"```\n",
"\n",
"Install langchain and dewy-client as shown below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install dewy-client langchain langchain-openai"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Example LangChain without RAG\n",
"This example shows a simple LangChain application which attempts to answer questions without retrieval."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_openai import ChatOpenAI\n",
"# MODEL=\"gpt-4-0125-preview\"\n",
"MODEL=\"gpt-3.5-turbo\"\n",
"llm = ChatOpenAI(temperature=0.9, model_name=MODEL)\n",
"\n",
"llm.invoke(\"What is RAG useful for?\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Example LangChain with RAG (using Dewy)\n",
"This example shows what the previous chain looks like using Dewy to retrieve relevant chunks."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create the Dewy Client\n",
"The following cell creates the Dewy client. It assumes you wish to connect to a Dewy service running on your local machine on port 8000. Change the URL as appropriate to your situation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dewy_client import Client\n",
"client = Client(base_url=\"http://localhost:8000\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The following retrieves a collection ID from Dewy.\n",
"# In general use you could hard-code the collection ID.\n",
"# This may switch to using the names directly.\n",
"from dewy_client.api.default import list_collections\n",
"collection = list_collections.sync(name=\"main\", client=client)[0]\n",
"print(f\"Collection: {collection.to_dict()}\")\n",
"collection_id = collection.id"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Retrieving documents in a chain"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Langchain retriever using Dewy.\n",
"#\n",
"# This will be added to Dewy or LangChain.\n",
"from langchain_core.callbacks.manager import AsyncCallbackManagerForRetrieverRun\n",
"from langchain_core.retrievers import BaseRetriever\n",
"from langchain_core.callbacks import CallbackManagerForRetrieverRun\n",
"from langchain_core.documents import Document\n",
"from typing import Any, Coroutine, List\n",
"\n",
"from dewy_client.api.default import retrieve_chunks\n",
"from dewy_client.models import RetrieveRequest, TextResult\n",
"\n",
"class DewyRetriever(BaseRetriever):\n",
"\n",
" collection_id: int\n",
"\n",
" def _make_request(self, query: str) -> RetrieveRequest:\n",
" return RetrieveRequest(\n",
" collection_id=self.collection_id,\n",
" query=query,\n",
" include_image_chunks=False,\n",
" )\n",
"\n",
" def _make_document(self, chunk: TextResult) -> Document:\n",
" return Document(page_content=chunk.text, metadata = { \"chunk_id\": chunk.chunk_id })\n",
"\n",
" def _get_relevant_documents(\n",
" self, query: str, *, run_manager: CallbackManagerForRetrieverRun\n",
" ) -> List[Document]:\n",
" retrieved = retrieve_chunks.sync(client=client, body=self._make_request(query))\n",
" return [self._make_document(chunk) for chunk in retrieved.text_results]\n",
"\n",
" async def _aget_relevant_documents(\n",
" self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun\n",
" ) -> Coroutine[Any, Any, List[Document]]:\n",
" retrieved = await retrieve_chunks.asyncio(client=client, body=self._make_request(query))\n",
" return [self._make_document(chunk) for chunk in retrieved.text_results]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.runnables import RunnableParallel, RunnablePassthrough\n",
"from langchain_core.output_parsers import StrOutputParser\n",
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"retriever = DewyRetriever(collection_id=collection_id)\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\n",
" \"system\",\n",
" \"\"\"\n",
" You're a helpful AI assistant. Given a user question and some retrieved content, answer the user question.\n",
" If none of the articles answer the question, just say you don't know.\n",
"\n",
" Here is the retrieved content:\n",
" {context}\n",
" \"\"\",\n",
" ),\n",
" (\"human\", \"{question}\"),\n",
" ]\n",
")\n",
"\n",
"def format_chunks(chunks):\n",
" return \"\\n\\n\".join([d.page_content for d in chunks])\n",
"\n",
"chain = (\n",
" { \"context\": retriever | format_chunks, \"question\": RunnablePassthrough() }\n",
" | prompt\n",
" | llm\n",
" | StrOutputParser()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"chain.invoke(\"What is RAG useful for?\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Langchain with Citations\n",
"Based on https://python.langchain.com/docs/use_cases/question_answering/citations#cite-documents."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.pydantic_v1 import BaseModel, Field\n",
"from operator import itemgetter\n",
"from langchain_core.runnables import (\n",
" RunnableLambda,\n",
")\n",
"\n",
"class cited_answer(BaseModel):\n",
" \"\"\"Answer the user question based only on the given sources, and cite the sources used.\"\"\"\n",
"\n",
" answer: str = Field(\n",
" ...,\n",
" description=\"The answer to the user question, which is based only on the given sources.\",\n",
" )\n",
" citations: List[int] = Field(\n",
" ...,\n",
" description=\"The integer IDs of the SPECIFIC sources which justify the answer.\",\n",
" )\n",
"\n",
"def format_docs_with_id(docs: List[Document]) -> str:\n",
" formatted = [\n",
" f\"Source ID: {doc.metadata['chunk_id']}\\nArticle Snippet: {doc.page_content}\"\n",
" for doc in docs\n",
" ]\n",
" return \"\\n\\n\" + \"\\n\\n\".join(formatted)\n",
"\n",
"format = itemgetter(\"docs\") | RunnableLambda(format_docs_with_id)\n",
"\n",
"# Setup a \"cited_answer\" tool.\n",
"from langchain.output_parsers.openai_tools import JsonOutputKeyToolsParser\n",
"output_parser = JsonOutputKeyToolsParser(key_name=\"cited_answer\", return_single=True)\n",
"\n",
"llm_with_tool = llm.bind_tools(\n",
" [cited_answer],\n",
" tool_choice=\"cited_answer\",\n",
")\n",
"answer = prompt | llm_with_tool | output_parser\n",
"\n",
"citation_chain = (\n",
" RunnableParallel(docs = retriever, question=RunnablePassthrough())\n",
" .assign(context=format)\n",
" .assign(cited_answer=answer)\n",
" # Can't include `docs` because they're not JSON serializable.\n",
" .pick([\"cited_answer\"])\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"citation_chain.invoke(\"What is RAG useful for?\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Bonus: Adding documents to the collection"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dewy_client.api.default import add_document\n",
"from dewy_client.models import AddDocumentRequest\n",
"add_document.sync(client=client, body=AddDocumentRequest(\n",
" url = \"https://arxiv.org/pdf/2305.14283.pdf\",\n",
" collection_id=collection_id,\n",
"))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "knowledge-7QbvxqGg-py3.11",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
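One caveat with the notebook's setup cell: it asks for a `.env` file but never explicitly loads it, so `ChatOpenAI` only finds `OPENAI_API_KEY` if your environment loads `.env` automatically (as some IDEs do). If not, a minimal sketch using `python-dotenv` (an illustration, not part of the PR):

```python
# Load OPENAI_API_KEY from .env into os.environ so ChatOpenAI can
# pick it up. Requires: pip install python-dotenv
from dotenv import load_dotenv

load_dotenv()  # reads .env from the current working directory
```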
2 changes: 1 addition & 1 deletion dewy/common/collection_embeddings.py
@@ -35,7 +35,7 @@ def __init__(
        self.extract_images = False

        # TODO: Look at a sentence window splitter?
-        self._splitter = SentenceSplitter()
+        self._splitter = SentenceSplitter(chunk_size=256)
        self._embedding = _resolve_embedding_model(self.text_embedding_model)

        field = f"embedding::vector({text_embedding_dimensions})"
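The `SentenceSplitter` change above shrinks chunks from the library default (1024 tokens in LlamaIndex at the time) to 256 tokens, trading context per chunk for more precise retrieval. A minimal sketch of the effect (the import path is an assumption; it varies across llama-index versions):

```python
# Compare chunk counts at the default and the new chunk_size.
# Older llama-index versions import this as
# `from llama_index.node_parser import SentenceSplitter`.
from llama_index.core.node_parser import SentenceSplitter

text = (
    "Retrieval-Augmented Generation (RAG) grounds LLM answers "
    "in retrieved documents. "
) * 200

default_chunks = SentenceSplitter().split_text(text)              # chunk_size=1024
small_chunks = SentenceSplitter(chunk_size=256).split_text(text)  # more, smaller chunks

print(len(default_chunks), len(small_chunks))
```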