Skip to content

Commit

Permalink
docs: update langchain notebook (#91)
Browse files Browse the repository at this point in the history
  • Loading branch information
bjchambers authored Feb 16, 2024
1 parent 81100d2 commit e4226fc
Showing 1 changed file with 6 additions and 80 deletions.
86 changes: 6 additions & 80 deletions demos/python-langchain-notebook/python-langchain.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"metadata": {},
"outputs": [],
"source": [
"%pip install dewy-client langchain langchain-openai"
"%pip install dewy-langchain langchain langchain-openai"
]
},
{
Expand Down Expand Up @@ -53,91 +53,13 @@
"This example shows what the previous chain looks like using Dewy to retrieve relevant chunks."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create the Dewy Client\n",
"The following cell creates the Dewy client. It assumes you wish to connect to a Dewy service running on your local machine on port 8000. Change the URL as appropriate to your situation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dewy_client import Client\n",
"client = Client(base_url=\"http://localhost:8000\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The following retrieves a collection ID from Dewy.\n",
"# In general use you could hard-code the collection ID.\n",
"# This may switch to using the names directly.\n",
"from dewy_client.api.kb import list_collections\n",
"collection = list_collections.sync(name=\"main\", client=client)[0]\n",
"print(f\"Collection: {collection.to_dict()}\")\n",
"collection_id = collection.id"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Retrieving documents in a chain"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Langchain retriever using Dewy.\n",
"#\n",
"# This will be added to Dewy or LangChain.\n",
"from langchain_core.callbacks.manager import AsyncCallbackManagerForRetrieverRun\n",
"from langchain_core.retrievers import BaseRetriever\n",
"from langchain_core.callbacks import CallbackManagerForRetrieverRun\n",
"from langchain_core.documents import Document\n",
"from typing import Any, Coroutine, List\n",
"\n",
"from dewy_client.api.kb import retrieve_chunks\n",
"from dewy_client.models import RetrieveRequest, TextResult\n",
"\n",
"class DewyRetriever(BaseRetriever):\n",
"\n",
" collection_id: int\n",
"\n",
" def _make_request(self, query: str) -> RetrieveRequest:\n",
" return RetrieveRequest(\n",
" collection_id=self.collection_id,\n",
" query=query,\n",
" include_image_chunks=False,\n",
" )\n",
"\n",
" def _make_document(self, chunk: TextResult) -> Document:\n",
" return Document(page_content=chunk.text, metadata = { \"chunk_id\": chunk.chunk_id })\n",
"\n",
" def _get_relevant_documents(\n",
" self, query: str, *, run_manager: CallbackManagerForRetrieverRun\n",
" ) -> List[Document]:\n",
" retrieved = retrieve_chunks.sync(client=client, body=self._make_request(query))\n",
" return [self._make_document(chunk) for chunk in retrieved.text_results]\n",
"\n",
" async def _aget_relevant_documents(\n",
" self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun\n",
" ) -> Coroutine[Any, Any, List[Document]]:\n",
" retrieved = await retrieve_chunks.asyncio(client=client, body=self._make_request(query))\n",
" return [self._make_document(chunk) for chunk in retrieved.text_results]"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -148,7 +70,9 @@
"from langchain_core.output_parsers import StrOutputParser\n",
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"retriever = DewyRetriever(collection_id=collection_id)\n",
"from dewy_langchain import DewyRetriever\n",
"\n",
"retriever = DewyRetriever.for_collection(\"main\", base_url=\"http://localhost:8000\")\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\n",
Expand Down Expand Up @@ -267,8 +191,10 @@
"metadata": {},
"outputs": [],
"source": [
"from dewy_client import Client\n",
"from dewy_client.api.kb import add_document\n",
"from dewy_client.models import AddDocumentRequest\n",
"client = Client(base_url=\"http://localhost:8000\")\n",
"add_document.sync(client=client, body=AddDocumentRequest(\n",
" url = \"https://arxiv.org/pdf/2305.14283.pdf\",\n",
" collection_id=collection_id,\n",
Expand Down

0 comments on commit e4226fc

Please sign in to comment.