Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

docs: update langchain notebook #91

Merged
merged 1 commit into from
Feb 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 6 additions & 80 deletions demos/python-langchain-notebook/python-langchain.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"metadata": {},
"outputs": [],
"source": [
"%pip install dewy-client langchain langchain-openai"
"%pip install dewy-langchain langchain langchain-openai"
]
},
{
Expand Down Expand Up @@ -53,91 +53,13 @@
"This example shows what the previous chain looks like using Dewy to retrieve relevant chunks."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create the Dewy Client\n",
"The following cell creates the Dewy client. It assumes you wish to connect to a Dewy service running on your local machine on port 8000. Change the URL as appropriate to your situation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dewy_client import Client\n",
"client = Client(base_url=\"http://localhost:8000\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The following retrieves a collection ID from Dewy.\n",
"# In general use you could hard-code the collection ID.\n",
"# This may switch to using the names directly.\n",
"from dewy_client.api.kb import list_collections\n",
"collection = list_collections.sync(name=\"main\", client=client)[0]\n",
"print(f\"Collection: {collection.to_dict()}\")\n",
"collection_id = collection.id"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Retrieving documents in a chain"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Langchain retriever using Dewy.\n",
"#\n",
"# This will be added to Dewy or LangChain.\n",
"from langchain_core.callbacks.manager import AsyncCallbackManagerForRetrieverRun\n",
"from langchain_core.retrievers import BaseRetriever\n",
"from langchain_core.callbacks import CallbackManagerForRetrieverRun\n",
"from langchain_core.documents import Document\n",
"from typing import Any, Coroutine, List\n",
"\n",
"from dewy_client.api.kb import retrieve_chunks\n",
"from dewy_client.models import RetrieveRequest, TextResult\n",
"\n",
"class DewyRetriever(BaseRetriever):\n",
"\n",
" collection_id: int\n",
"\n",
" def _make_request(self, query: str) -> RetrieveRequest:\n",
" return RetrieveRequest(\n",
" collection_id=self.collection_id,\n",
" query=query,\n",
" include_image_chunks=False,\n",
" )\n",
"\n",
" def _make_document(self, chunk: TextResult) -> Document:\n",
" return Document(page_content=chunk.text, metadata = { \"chunk_id\": chunk.chunk_id })\n",
"\n",
" def _get_relevant_documents(\n",
" self, query: str, *, run_manager: CallbackManagerForRetrieverRun\n",
" ) -> List[Document]:\n",
" retrieved = retrieve_chunks.sync(client=client, body=self._make_request(query))\n",
" return [self._make_document(chunk) for chunk in retrieved.text_results]\n",
"\n",
" async def _aget_relevant_documents(\n",
" self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun\n",
" ) -> Coroutine[Any, Any, List[Document]]:\n",
" retrieved = await retrieve_chunks.asyncio(client=client, body=self._make_request(query))\n",
" return [self._make_document(chunk) for chunk in retrieved.text_results]"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -148,7 +70,9 @@
"from langchain_core.output_parsers import StrOutputParser\n",
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"retriever = DewyRetriever(collection_id=collection_id)\n",
"from dewy_langchain import DewyRetriever\n",
"\n",
"retriever = DewyRetriever.for_collection(\"main\", base_url=\"http://localhost:8000\")\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\n",
Expand Down Expand Up @@ -267,8 +191,10 @@
"metadata": {},
"outputs": [],
"source": [
"from dewy_client import Client\n",
"from dewy_client.api.kb import add_document\n",
"from dewy_client.models import AddDocumentRequest\n",
"client = Client(base_url=\"http://localhost:8000\")\n",
"add_document.sync(client=client, body=AddDocumentRequest(\n",
" url = \"https://arxiv.org/pdf/2305.14283.pdf\",\n",
" collection_id=collection_id,\n",
Expand Down
Loading