diff --git a/packages/jupyter-ai-magics/pyproject.toml b/packages/jupyter-ai-magics/pyproject.toml index 1cd7a513f..38d885e7a 100644 --- a/packages/jupyter-ai-magics/pyproject.toml +++ b/packages/jupyter-ai-magics/pyproject.toml @@ -22,9 +22,9 @@ dynamic = ["version", "description", "authors", "urls", "keywords"] dependencies = [ "ipython", - "pydantic", + "pydantic~=1.0", "importlib_metadata>=5.2.0", - "langchain==0.0.223", + "langchain==0.0.277", "typing_extensions>=4.5.0", "click~=8.0", "jsonpath-ng>=1.5.3,<2", diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py index 963eae6ce..7d88beac2 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py @@ -4,6 +4,7 @@ from typing import Any, Awaitable, Coroutine, List, Optional, Tuple from dask.distributed import Client as DaskClient +from jupyter_ai.config_manager import ConfigManager from jupyter_ai.document_loaders.directory import get_embeddings, split from jupyter_ai.document_loaders.splitter import ExtensionSplitter, NotebookSplitter from jupyter_ai.models import ( @@ -29,7 +30,7 @@ METADATA_SAVE_PATH = os.path.join(INDEX_SAVE_DIR, "metadata.json") -class LearnChatHandler(BaseChatHandler, BaseRetriever): +class LearnChatHandler(BaseChatHandler): def __init__( self, root_dir: str, dask_client_future: Awaitable[DaskClient], *args, **kwargs ): @@ -266,9 +267,6 @@ def load_metadata(self): j = json.loads(f.read()) self.metadata = IndexMetadata(**j) - def get_relevant_documents(self, query: str) -> List[Document]: - raise NotImplementedError() - async def aget_relevant_documents( self, query: str ) -> Coroutine[Any, Any, List[Document]]: @@ -291,3 +289,16 @@ def get_embedding_model(self): return None return em_provider_cls(**em_provider_args) + + +class Retriever(BaseRetriever): + learn_chat_handler: LearnChatHandler = None + + def _get_relevant_documents(self, query: str) -> List[Document]: + raise NotImplementedError() + + async def _aget_relevant_documents( + self, query: str + ) -> Coroutine[Any, Any, List[Document]]: + docs = await self.learn_chat_handler.aget_relevant_documents(query) + return docs diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index f9d0fee56..a6564679b 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -1,6 +1,7 @@ import time from dask.distributed import Client as DaskClient +from jupyter_ai.chat_handlers.learn import Retriever from jupyter_ai_magics.utils import get_em_providers, get_lm_providers from jupyter_server.extension.application import ExtensionApp @@ -93,9 +94,8 @@ def initialize_settings(self): dask_client_future=dask_client_future, ) help_chat_handler = HelpChatHandler(**chat_handler_kwargs) - ask_chat_handler = AskChatHandler( - **chat_handler_kwargs, retriever=learn_chat_handler - ) + retriever = Retriever(learn_chat_handler=learn_chat_handler) + ask_chat_handler = AskChatHandler(**chat_handler_kwargs, retriever=retriever) self.settings["jai_chat_handlers"] = { "default": default_chat_handler, "/ask": ask_chat_handler, diff --git a/packages/jupyter-ai/pyproject.toml b/packages/jupyter-ai/pyproject.toml index 7c7792dd3..71ae54ff4 100644 --- a/packages/jupyter-ai/pyproject.toml +++ b/packages/jupyter-ai/pyproject.toml @@ -24,11 +24,11 @@ classifiers = [ dependencies = [ "jupyter_server>=1.6,<3", "jupyterlab~=4.0", - "pydantic", + "pydantic~=1.0", "openai~=0.26", "aiosqlite>=0.18", "importlib_metadata>=5.2.0", - "langchain==0.0.223", + "langchain==0.0.277", "tiktoken", # required for OpenAIEmbeddings "jupyter_ai_magics", "dask[distributed]",