diff --git a/demos/python-langchain-notebook/python-langchain.ipynb b/demos/python-langchain-notebook/python-langchain.ipynb index 6b6b1dd..608c8c5 100644 --- a/demos/python-langchain-notebook/python-langchain.ipynb +++ b/demos/python-langchain-notebook/python-langchain.ipynb @@ -20,7 +20,7 @@ "metadata": {}, "outputs": [], "source": [ - "%pip install dewy-client langchain langchain-openai" + "%pip install dewy-langchain langchain langchain-openai" ] }, { @@ -53,39 +53,6 @@ "This example shows what the previous chain looks like using Dewy to retrieve relevant chunks." ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Create the Dewy Client\n", - "The following cell creates the Dewy client. It assumes you wish to connect to a Dewy service running on your local machine on port 8000. Change the URL as appropriate to your situation." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from dewy_client import Client\n", - "client = Client(base_url=\"http://localhost:8000\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# The following retrieves a collection ID from Dewy.\n", - "# In general use you could hard-code the collection ID.\n", - "# This may switch to using the names directly.\n", - "from dewy_client.api.kb import list_collections\n", - "collection = list_collections.sync(name=\"main\", client=client)[0]\n", - "print(f\"Collection: {collection.to_dict()}\")\n", - "collection_id = collection.id" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -93,51 +60,6 @@ "## Retrieving documents in a chain" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Langchain retriever using Dewy.\n", - "#\n", - "# This will be added to Dewy or LangChain.\n", - "from langchain_core.callbacks.manager import AsyncCallbackManagerForRetrieverRun\n", - "from 
langchain_core.retrievers import BaseRetriever\n", - "from langchain_core.callbacks import CallbackManagerForRetrieverRun\n", - "from langchain_core.documents import Document\n", - "from typing import Any, Coroutine, List\n", - "\n", - "from dewy_client.api.kb import retrieve_chunks\n", - "from dewy_client.models import RetrieveRequest, TextResult\n", - "\n", - "class DewyRetriever(BaseRetriever):\n", - "\n", - " collection_id: int\n", - "\n", - " def _make_request(self, query: str) -> RetrieveRequest:\n", - " return RetrieveRequest(\n", - " collection_id=self.collection_id,\n", - " query=query,\n", - " include_image_chunks=False,\n", - " )\n", - "\n", - " def _make_document(self, chunk: TextResult) -> Document:\n", - " return Document(page_content=chunk.text, metadata = { \"chunk_id\": chunk.chunk_id })\n", - "\n", - " def _get_relevant_documents(\n", - " self, query: str, *, run_manager: CallbackManagerForRetrieverRun\n", - " ) -> List[Document]:\n", - " retrieved = retrieve_chunks.sync(client=client, body=self._make_request(query))\n", - " return [self._make_document(chunk) for chunk in retrieved.text_results]\n", - "\n", - " async def _aget_relevant_documents(\n", - " self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun\n", - " ) -> Coroutine[Any, Any, List[Document]]:\n", - " retrieved = await retrieve_chunks.asyncio(client=client, body=self._make_request(query))\n", - " return [self._make_document(chunk) for chunk in retrieved.text_results]" - ] - }, { "cell_type": "code", "execution_count": null, @@ -148,7 +70,9 @@ "from langchain_core.output_parsers import StrOutputParser\n", "from langchain_core.prompts import ChatPromptTemplate\n", "\n", - "retriever = DewyRetriever(collection_id=collection_id)\n", + "from dewy_langchain import DewyRetriever\n", + "\n", + "retriever = DewyRetriever.for_collection(\"main\", base_url=\"http://localhost:8000\")\n", "prompt = ChatPromptTemplate.from_messages(\n", " [\n", " (\n", @@ -267,8 +191,10 @@ 
"metadata": {}, "outputs": [], "source": [ + "from dewy_client import Client\n", - "from dewy_client.api.kb import add_document\n", + "from dewy_client.api.kb import add_document, list_collections\n", "from dewy_client.models import AddDocumentRequest\n", + "client = Client(base_url=\"http://localhost:8000\")\n", "add_document.sync(client=client, body=AddDocumentRequest(\n", " url = \"https://arxiv.org/pdf/2305.14283.pdf\",\n", - " collection_id=collection_id,\n", + " collection_id=list_collections.sync(name=\"main\", client=client)[0].id, # look up the \"main\" collection's ID\n",