From 896ddc8bda96523f6e0b9a2e60de71a4b7e99529 Mon Sep 17 00:00:00 2001
From: Piyush Jain
Date: Thu, 20 Apr 2023 20:34:21 -0700
Subject: [PATCH 01/16] Refactored provider load and decompose logic, added
 model provider list API
---
 .../jupyter_ai_magics/aliases.py              |  6 +++
 .../jupyter_ai_magics/magics.py               | 42 +++--------------
 .../jupyter_ai_magics/providers.py            |  2 +
 .../jupyter_ai_magics/utils.py                | 47 +++++++++++++++++++
 packages/jupyter-ai/jupyter_ai/extension.py   |  8 +++-
 packages/jupyter-ai/jupyter_ai/handlers.py    | 41 +++++++++++++++-
 packages/jupyter-ai/jupyter_ai/models.py      | 15 ++++++
 7 files changed, 124 insertions(+), 37 deletions(-)
 create mode 100644 packages/jupyter-ai-magics/jupyter_ai_magics/aliases.py
 create mode 100644 packages/jupyter-ai-magics/jupyter_ai_magics/utils.py

diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/aliases.py b/packages/jupyter-ai-magics/jupyter_ai_magics/aliases.py
new file mode 100644
index 000000000..74be10485
--- /dev/null
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/aliases.py
@@ -0,0 +1,6 @@
+MODEL_ID_ALIASES = {
+    "gpt2": "huggingface_hub:gpt2",
+    "gpt3": "openai:text-davinci-003",
+    "chatgpt": "openai-chat:gpt-3.5-turbo",
+    "gpt4": "openai-chat:gpt-4",
+}
\ No newline at end of file
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
index f79a70973..7fc61889d 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
@@ -2,15 +2,16 @@
 import json
 import os
 import re
-import traceback
 import warnings
 from typing import Optional
 
-from importlib_metadata import entry_points
 from IPython import get_ipython
 from IPython.core.magic import Magics, magics_class, line_cell_magic
 from IPython.core.magic_arguments import magic_arguments, argument, parse_argstring
-from IPython.display import HTML, Image, JSON, Markdown, Math
+
+from IPython.display import HTML, JSON, Markdown, Math
+
+from jupyter_ai_magics.utils import decompose_model_id, load_providers
 
 from .providers import BaseProvider
 
@@ -36,8 +37,7 @@ def _repr_mimebundle_(self, include=None, exclude=None):
             }
         )
 
-class TextWithMetadata:
-
+class TextWithMetadata(object):
     def __init__(self, text, metadata):
         self.text = text
         self.metadata = metadata
@@ -109,18 +109,7 @@ def __init__(self, shell):
             "no longer supported. Instead, please use: "
             "`from langchain.chat_models import ChatOpenAI`")
 
-        # load model providers from entry point
-        self.providers = {}
-        eps = entry_points()
-        model_provider_eps = eps.select(group="jupyter_ai.model_providers")
-        for model_provider_ep in model_provider_eps:
-            try:
-                Provider = model_provider_ep.load()
-            except:
-                print(f"Unable to load entry point {model_provider_ep.name}");
-                traceback.print_exc()
-                continue
-            self.providers[Provider.id] = Provider
+        self.providers = load_providers()
 
     def _ai_help_command_markdown(self):
         table = ("| Command | Description |\n"
@@ -272,24 +261,7 @@ def _append_exchange_openai(self, prompt: str, output: str):
         })
 
     def _decompose_model_id(self, model_id: str):
-        """Breaks down a model ID into a two-tuple (provider_id, local_model_id). Returns (None, None) if indeterminate."""
-        if model_id in MODEL_ID_ALIASES:
-            model_id = MODEL_ID_ALIASES[model_id]
-
-        if ":" not in model_id:
-            # case: model ID was not provided with a prefix indicating the provider
-            # ID. try to infer the provider ID before returning (None, None). 
- - # naively search through the dictionary and return the first provider - # that provides a model of the same ID. - for provider_id, Provider in self.providers.items(): - if model_id in Provider.models: - return (provider_id, model_id) - - return (None, None) - - provider_id, local_model_id = model_id.split(":", 1) - return (provider_id, local_model_id) + return decompose_model_id(model_id, self.providers) def _get_provider(self, provider_id: Optional[str]) -> BaseProvider: """Returns the model provider ID and class for a model ID. Returns None if indeterminate.""" diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index 852c6ce7f..3c765001f 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -298,3 +298,5 @@ class SmEndpointProvider(BaseProvider, SagemakerEndpoint): model_id_key = "endpoint_name" pypi_package_deps = ["boto3"] auth_strategy = AwsAuthStrategy() + + diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py b/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py new file mode 100644 index 000000000..d60f81455 --- /dev/null +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py @@ -0,0 +1,47 @@ +import logging +from typing import Dict, Optional, Tuple, Union +from importlib_metadata import entry_points +from jupyter_ai_magics.aliases import MODEL_ID_ALIASES + +from jupyter_ai_magics.providers import BaseProvider + + +Logger = Union[logging.Logger, logging.LoggerAdapter] + + +def load_providers(log: Optional[Logger] = None) -> Dict[str, BaseProvider]: + providers = {} + eps = entry_points() + model_provider_eps = eps.select(group="jupyter_ai.model_providers") + for model_provider_ep in model_provider_eps: + try: + provider = model_provider_ep.load() + except: + if log: + log.error(f"Unable to load model provider class from entry point `{model_provider_ep.name}`.") + continue + providers[provider.id] = provider + if log: + log.info(f"Registered model provider `{provider.id}`.") + + return providers + +def decompose_model_id(model_id: str, providers: Dict[str, BaseProvider]) -> Tuple[str, str]: + """Breaks down a model ID into a two-tuple (provider_id, local_model_id). Returns (None, None) if indeterminate.""" + if model_id in MODEL_ID_ALIASES: + model_id = MODEL_ID_ALIASES[model_id] + + if ":" not in model_id: + # case: model ID was not provided with a prefix indicating the provider + # ID. try to infer the provider ID before returning (None, None). + + # naively search through the dictionary and return the first provider + # that provides a model of the same ID. 
+ for provider_id, provider in providers.items(): + if model_id in provider.models: + return (provider_id, model_id) + + return (None, None) + + provider_id, local_model_id = model_id.split(":", 1) + return (provider_id, local_model_id) diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index e4b22a1f5..876d31f11 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -1,6 +1,7 @@ import asyncio import os import queue +from jupyter_ai_magics.utils import load_providers from langchain.memory import ConversationBufferWindowMemory from jupyter_ai.actors.default import DefaultActor from jupyter_ai.actors.ask import AskActor @@ -11,7 +12,7 @@ from jupyter_ai.actors.base import ACTOR_TYPE from jupyter_ai.reply_processor import ReplyProcessor from jupyter_server.extension.application import ExtensionApp -from .handlers import ChatHandler, ChatHistoryHandler, PromptAPIHandler, TaskAPIHandler +from .handlers import ChatHandler, ChatHistoryHandler, ModelProviderHandler, PromptAPIHandler, TaskAPIHandler from importlib_metadata import entry_points import inspect from .engine import BaseModelEngine @@ -29,6 +30,7 @@ class AiExtension(ExtensionApp): (r"api/ai/tasks/([\w\-:]*)", TaskAPIHandler), (r"api/ai/chats/?", ChatHandler), (r"api/ai/chats/history?", ChatHistoryHandler), + (r"api/ai/providers?", ModelProviderHandler), ] @property @@ -91,6 +93,10 @@ def initialize_settings(self): self.settings["ai_default_tasks"] = default_tasks self.log.info("Registered all default tasks.") + providers = load_providers(log=self.log) + self.settings["chat_providers"] = providers + self.log.info("Registered providers.") + if ChatOpenAINewProvider.auth_strategy.name not in os.environ: raise EnvironmentError(f"`{ChatOpenAINewProvider.auth_strategy.name}` value not set in environment. For chat to work, this value should be provided.") diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index 0bcfd8a62..e9c782738 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -16,7 +16,20 @@ from jupyter_server.utils import ensure_async from .task_manager import TaskManager -from .models import ChatHistory, PromptRequest, ChatRequest, ChatMessage, Message, AgentChatMessage, HumanChatMessage, ConnectionMessage, ChatClient, ChatUser +from .models import ( + ChatHistory, + ListProviderEntry, + ListProvidersResponse, + PromptRequest, + ChatRequest, + ChatMessage, + Message, + AgentChatMessage, + HumanChatMessage, + ConnectionMessage, + ChatClient +) + class APIHandler(BaseAPIHandler): @@ -254,3 +267,29 @@ def on_close(self): self.log.info(f"Client disconnected. 
ID: {self.client_id}") self.log.debug("Chat clients: %s", self.chat_handlers.keys()) + + +class ModelProviderHandler(BaseAPIHandler): + @property + def chat_providers(self): + return self.settings["chat_providers"] + + @web.authenticated + def get(self): + providers = [] + for provider in self.chat_providers.values(): + providers.append( + ListProviderEntry( + id=provider.id, + name=provider.name, + models=provider.models, + auth_strategy=provider.auth_strategy + ) + ) + response = ListProvidersResponse(providers=providers) + self.finish(response.json()) + + +class EmbeddingModelProviderHandler(BaseAPIHandler): + # Placeholder for embedding model provider handler + pass \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 7da9aa47f..e574e16e9 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -1,3 +1,4 @@ +from jupyter_ai_magics.providers import AuthStrategy from pydantic import BaseModel from typing import Dict, List, Union, Literal, Optional @@ -80,3 +81,17 @@ class DescribeTaskResponse(BaseModel): class ChatHistory(BaseModel): """History of chat messages""" messages: List[ChatMessage] + + +class ListProviderEntry(BaseModel): + """Model provider with supported models + and provider's authentication strategy + """ + id: str + name: str + models: List[str] + auth_strategy: AuthStrategy + + +class ListProvidersResponse(BaseModel): + providers: List[ListProviderEntry] \ No newline at end of file From 44aac3a3c7b62df027ee19eca95f701283f61287 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Fri, 21 Apr 2023 13:32:12 -0700 Subject: [PATCH 02/16] Renamed model --- packages/jupyter-ai-magics/jupyter_ai_magics/utils.py | 9 +++++---- packages/jupyter-ai/jupyter_ai/handlers.py | 6 +++--- packages/jupyter-ai/jupyter_ai/models.py | 4 ++-- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py b/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py index d60f81455..97221d985 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py @@ -10,6 +10,9 @@ def load_providers(log: Optional[Logger] = None) -> Dict[str, BaseProvider]: + if not log: + log = logging.getLogger() + log.addHandler(logging.NullHandler()) providers = {} eps = entry_points() model_provider_eps = eps.select(group="jupyter_ai.model_providers") @@ -17,12 +20,10 @@ def load_providers(log: Optional[Logger] = None) -> Dict[str, BaseProvider]: try: provider = model_provider_ep.load() except: - if log: - log.error(f"Unable to load model provider class from entry point `{model_provider_ep.name}`.") + log.error(f"Unable to load model provider class from entry point `{model_provider_ep.name}`.") continue providers[provider.id] = provider - if log: - log.info(f"Registered model provider `{provider.id}`.") + log.info(f"Registered model provider `{provider.id}`.") return providers diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index e9c782738..7d711cd4f 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -16,9 +16,10 @@ from jupyter_server.utils import ensure_async from .task_manager import TaskManager + from .models import ( ChatHistory, - ListProviderEntry, + ListProvidersEntry, ListProvidersResponse, PromptRequest, ChatRequest, @@ -31,7 +32,6 @@ ) - class APIHandler(BaseAPIHandler): 
@property def engines(self): @@ -279,7 +279,7 @@ def get(self): providers = [] for provider in self.chat_providers.values(): providers.append( - ListProviderEntry( + ListProvidersEntry( id=provider.id, name=provider.name, models=provider.models, diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index e574e16e9..decb90c62 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -83,7 +83,7 @@ class ChatHistory(BaseModel): messages: List[ChatMessage] -class ListProviderEntry(BaseModel): +class ListProvidersEntry(BaseModel): """Model provider with supported models and provider's authentication strategy """ @@ -94,4 +94,4 @@ class ListProviderEntry(BaseModel): class ListProvidersResponse(BaseModel): - providers: List[ListProviderEntry] \ No newline at end of file + providers: List[ListProvidersEntry] \ No newline at end of file From 273f67b61fd1f46e07a7e6f9c589d32f0fcc2a32 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Mon, 24 Apr 2023 12:56:24 -0700 Subject: [PATCH 03/16] Sorted the provider names --- packages/jupyter-ai/jupyter_ai/handlers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index 7d711cd4f..72e83c7f6 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -286,7 +286,8 @@ def get(self): auth_strategy=provider.auth_strategy ) ) - response = ListProvidersResponse(providers=providers) + + response = ListProvidersResponse(providers=sorted(providers, key=lambda p: p.name)) self.finish(response.json()) From ebd127e75c9e2e5fe9fd6a5d508ccad3acb975be Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Mon, 24 Apr 2023 15:25:06 -0700 Subject: [PATCH 04/16] WIP: Embedding providers --- .../jupyter_ai_magics/__init__.py | 6 ++ .../jupyter_ai_magics/embedding_providers.py | 83 +++++++++++++++++++ .../jupyter_ai_magics/utils.py | 22 +++++ packages/jupyter-ai-magics/pyproject.toml | 5 ++ packages/jupyter-ai/jupyter_ai/extension.py | 9 +- packages/jupyter-ai/jupyter_ai/handlers.py | 24 +++++- 6 files changed, 144 insertions(+), 5 deletions(-) create mode 100644 packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py index d7f5de232..1a947c8fc 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py @@ -12,6 +12,12 @@ ChatOpenAIProvider, SmEndpointProvider ) +# expose embedding model providers on the package root +from .embedding_providers import ( + OpenAIEmbeddingsProvider, + CohereEmbeddingsProvider, + HfHubEmbeddingsProvider +) from .providers import BaseProvider def load_ipython_extension(ipython): diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py new file mode 100644 index 000000000..a4c213216 --- /dev/null +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py @@ -0,0 +1,83 @@ +from typing import ClassVar, List, Type +from jupyter_ai_magics.providers import AuthStrategy, EnvAuthStrategy +from pydantic import BaseModel, Extra +from langchain.embeddings import OpenAIEmbeddings, CohereEmbeddings, HuggingFaceHubEmbeddings +from langchain.embeddings.base import Embeddings + + +class 
BaseEmbeddingsProvider(Embeddings): + """Base class for embedding providers""" + + class Config: + extra = Extra.allow + + id: ClassVar[str] = ... + """ID for this provider class.""" + + name: ClassVar[str] = ... + """User-facing name of this provider.""" + + models: ClassVar[List[str]] = ... + """List of supported models by their IDs. For registry providers, this will + be just ["*"].""" + + model_id_key: ClassVar[str] = ... + """Kwarg expected by the upstream LangChain provider.""" + + pypi_package_deps: ClassVar[List[str]] = [] + """List of PyPi package dependencies.""" + + auth_strategy: ClassVar[AuthStrategy] = None + """Authentication/authorization strategy. Declares what credentials are + required to use this model provider. Generally should not be `None`.""" + + model_id: str + + + def __init__(self, *args, **kwargs): + try: + assert kwargs["model_id"] + except: + raise AssertionError("model_id was not specified. Please specify it as a keyword argument.") + + model_kwargs = {} + model_kwargs[self.__class__.model_id_key] = kwargs["model_id"] + + super().__init__(*args, **kwargs, **model_kwargs) + + + +class OpenAIEmbeddingsProvider(BaseEmbeddingsProvider, OpenAIEmbeddings): + id = "openai" + name = "OpenAI" + models = [ + "text-embedding-ada-002" + ] + model_id_key = "model" + pypi_package_deps = ["openai"] + auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY") + + +class CohereEmbeddingsProvider(BaseEmbeddingsProvider, CohereEmbeddings): + id = "cohere" + name = "Cohere" + models = [ + 'large', + 'multilingual-22-12', + 'small' + ] + model_id_key = "model" + pypi_package_deps = ["cohere"] + auth_strategy = EnvAuthStrategy(name="COHERE_API_KEY") + + +class HfHubEmbeddingsProvider(BaseEmbeddingsProvider, HuggingFaceHubEmbeddings): + id = "huggingface_hub" + name = "HuggingFace Hub" + models = ["*"] + model_id_key = "repo_id" + # ipywidgets needed to suppress tqdm warning + # https://stackoverflow.com/questions/67998191 + # tqdm is a dependency of huggingface_hub + pypi_package_deps = ["huggingface_hub", "ipywidgets"] + auth_strategy = EnvAuthStrategy(name="HUGGINGFACEHUB_API_TOKEN") diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py b/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py index 97221d985..05d296c5e 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py @@ -2,6 +2,7 @@ from typing import Dict, Optional, Tuple, Union from importlib_metadata import entry_points from jupyter_ai_magics.aliases import MODEL_ID_ALIASES +from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider from jupyter_ai_magics.providers import BaseProvider @@ -27,6 +28,27 @@ def load_providers(log: Optional[Logger] = None) -> Dict[str, BaseProvider]: return providers + +def load_embedding_providers(log: Optional[Logger] = None) -> Dict[str, BaseEmbeddingsProvider]: + if not log: + log = logging.getLogger() + log.addHandler(logging.NullHandler()) + providers = {} + eps = entry_points() + model_provider_eps = eps.select(group="jupyter_ai.embeddings_model_providers") + for model_provider_ep in model_provider_eps: + try: + provider = model_provider_ep.load() + except: + log.error(f"Unable to load embeddings model provider class from entry point `{model_provider_ep.name}`.") + continue + providers[provider.id] = provider + log.info(f"Registered embeddings model provider `{provider.id}`.") + + return providers + + + def decompose_model_id(model_id: str, providers: Dict[str, BaseProvider]) -> Tuple[str, 
str]: """Breaks down a model ID into a two-tuple (provider_id, local_model_id). Returns (None, None) if indeterminate.""" if model_id in MODEL_ID_ALIASES: diff --git a/packages/jupyter-ai-magics/pyproject.toml b/packages/jupyter-ai-magics/pyproject.toml index 66a40e3a3..1ee2de714 100644 --- a/packages/jupyter-ai-magics/pyproject.toml +++ b/packages/jupyter-ai-magics/pyproject.toml @@ -56,6 +56,11 @@ openai = "jupyter_ai_magics:OpenAIProvider" openai-chat = "jupyter_ai_magics:ChatOpenAIProvider" sagemaker-endpoint = "jupyter_ai_magics:SmEndpointProvider" +[project.entry-points."jupyter_ai.embeddings_model_providers"] +cohere = "jupyter_ai_magics:CohereEmbeddingsProvider" +huggingface_hub = "jupyter_ai_magics:HfHubEmbeddingsProvider" +openai = "jupyter_ai_magics:OpenAIEmbeddingsProvider" + [tool.hatch.version] source = "nodejs" diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 876d31f11..1197c11de 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -1,7 +1,7 @@ import asyncio import os import queue -from jupyter_ai_magics.utils import load_providers +from jupyter_ai_magics.utils import load_embedding_providers, load_providers from langchain.memory import ConversationBufferWindowMemory from jupyter_ai.actors.default import DefaultActor from jupyter_ai.actors.ask import AskActor @@ -12,7 +12,7 @@ from jupyter_ai.actors.base import ACTOR_TYPE from jupyter_ai.reply_processor import ReplyProcessor from jupyter_server.extension.application import ExtensionApp -from .handlers import ChatHandler, ChatHistoryHandler, ModelProviderHandler, PromptAPIHandler, TaskAPIHandler +from .handlers import ChatHandler, ChatHistoryHandler, EmbeddingsModelProviderHandler, ModelProviderHandler, PromptAPIHandler, TaskAPIHandler from importlib_metadata import entry_points import inspect from .engine import BaseModelEngine @@ -31,6 +31,7 @@ class AiExtension(ExtensionApp): (r"api/ai/chats/?", ChatHandler), (r"api/ai/chats/history?", ChatHistoryHandler), (r"api/ai/providers?", ModelProviderHandler), + (r"api/ai/providers/embeddings?", EmbeddingsModelProviderHandler), ] @property @@ -97,6 +98,10 @@ def initialize_settings(self): self.settings["chat_providers"] = providers self.log.info("Registered providers.") + embeddings_providers = load_embedding_providers(log=self.log) + self.settings["embeddings_providers"] = embeddings_providers + self.log.info("Registered embeddings providers.") + if ChatOpenAINewProvider.auth_strategy.name not in os.environ: raise EnvironmentError(f"`{ChatOpenAINewProvider.auth_strategy.name}` value not set in environment. 
For chat to work, this value should be provided.") diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index 72e83c7f6..db7e7223f 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -291,6 +291,24 @@ def get(self): self.finish(response.json()) -class EmbeddingModelProviderHandler(BaseAPIHandler): - # Placeholder for embedding model provider handler - pass \ No newline at end of file +class EmbeddingsModelProviderHandler(BaseAPIHandler): + + @property + def embeddings_providers(self): + return self.settings['embeddings_providers'] + + @web.authenticated + def get(self): + providers = [] + for provider in self.embeddings_providers.values(): + providers.append( + ListProvidersEntry( + id=provider.id, + name=provider.name, + models=provider.models, + auth_strategy=provider.auth_strategy + ) + ) + + response = ListProvidersResponse(providers=sorted(providers, key=lambda p: p.name)) + self.finish(response.json()) \ No newline at end of file From 4286257e4dcef64fa6a6d1d6136e04c99851ed49 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Mon, 24 Apr 2023 21:59:01 -0700 Subject: [PATCH 05/16] Added embeddings provider api --- .../jupyter_ai_magics/embedding_providers.py | 24 +++++++------------ 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py index a4c213216..81eaadbf1 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py @@ -5,7 +5,7 @@ from langchain.embeddings.base import Embeddings -class BaseEmbeddingsProvider(Embeddings): +class BaseEmbeddingsProvider(BaseModel): """Base class for embedding providers""" class Config: @@ -33,21 +33,10 @@ class Config: model_id: str - - def __init__(self, *args, **kwargs): - try: - assert kwargs["model_id"] - except: - raise AssertionError("model_id was not specified. 
Please specify it as a keyword argument.") - - model_kwargs = {} - model_kwargs[self.__class__.model_id_key] = kwargs["model_id"] - - super().__init__(*args, **kwargs, **model_kwargs) - + provider_klass: ClassVar[Type[Embeddings]] -class OpenAIEmbeddingsProvider(BaseEmbeddingsProvider, OpenAIEmbeddings): +class OpenAIEmbeddingsProvider(BaseEmbeddingsProvider): id = "openai" name = "OpenAI" models = [ @@ -56,9 +45,10 @@ class OpenAIEmbeddingsProvider(BaseEmbeddingsProvider, OpenAIEmbeddings): model_id_key = "model" pypi_package_deps = ["openai"] auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY") + provider_klass: OpenAIEmbeddings -class CohereEmbeddingsProvider(BaseEmbeddingsProvider, CohereEmbeddings): +class CohereEmbeddingsProvider(BaseEmbeddingsProvider): id = "cohere" name = "Cohere" models = [ @@ -69,9 +59,10 @@ class CohereEmbeddingsProvider(BaseEmbeddingsProvider, CohereEmbeddings): model_id_key = "model" pypi_package_deps = ["cohere"] auth_strategy = EnvAuthStrategy(name="COHERE_API_KEY") + provider_klass: CohereEmbeddings -class HfHubEmbeddingsProvider(BaseEmbeddingsProvider, HuggingFaceHubEmbeddings): +class HfHubEmbeddingsProvider(BaseEmbeddingsProvider): id = "huggingface_hub" name = "HuggingFace Hub" models = ["*"] @@ -81,3 +72,4 @@ class HfHubEmbeddingsProvider(BaseEmbeddingsProvider, HuggingFaceHubEmbeddings): # tqdm is a dependency of huggingface_hub pypi_package_deps = ["huggingface_hub", "ipywidgets"] auth_strategy = EnvAuthStrategy(name="HUGGINGFACEHUB_API_TOKEN") + provider_klass: HuggingFaceHubEmbeddings From 8a40a6902e566d2b2d0def8bcde544dc8cb80a0d Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Mon, 24 Apr 2023 22:16:34 -0700 Subject: [PATCH 06/16] Added missing import --- packages/jupyter-ai/jupyter_ai/handlers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index db7e7223f..3f5a59d47 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -18,7 +18,8 @@ from .task_manager import TaskManager from .models import ( - ChatHistory, + ChatHistory, + ChatUser, ListProvidersEntry, ListProvidersResponse, PromptRequest, From 2a120681517a7e36c8e758387fcbbd8b3b3a20b0 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Mon, 24 Apr 2023 23:50:46 -0700 Subject: [PATCH 07/16] Moved providers to ray actor, added config actor --- packages/jupyter-ai/jupyter_ai/actors/base.py | 2 ++ .../jupyter-ai/jupyter_ai/actors/default.py | 7 ++++++- .../jupyter-ai/jupyter_ai/actors/providers.py | 20 +++++++++++++++++++ packages/jupyter-ai/jupyter_ai/extension.py | 19 ++++++++++-------- packages/jupyter-ai/jupyter_ai/handlers.py | 17 +++++++++++++--- packages/jupyter-ai/jupyter_ai/models.py | 11 ++++++++-- 6 files changed, 62 insertions(+), 14 deletions(-) create mode 100644 packages/jupyter-ai/jupyter_ai/actors/providers.py diff --git a/packages/jupyter-ai/jupyter_ai/actors/base.py b/packages/jupyter-ai/jupyter_ai/actors/base.py index 587f84560..fe5c31e95 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/base.py +++ b/packages/jupyter-ai/jupyter_ai/actors/base.py @@ -19,6 +19,8 @@ class ACTOR_TYPE(str, Enum): LEARN = 'learn' MEMORY = 'memory' GENERATE = 'generate' + PROVIDERS = 'providers' + CONFIG = 'config' COMMANDS = { '/ask': ACTOR_TYPE.ASK, diff --git a/packages/jupyter-ai/jupyter_ai/actors/default.py b/packages/jupyter-ai/jupyter_ai/actors/default.py index 6ab33ae08..08f32ed86 100644 --- 
a/packages/jupyter-ai/jupyter_ai/actors/default.py +++ b/packages/jupyter-ai/jupyter_ai/actors/default.py @@ -1,3 +1,4 @@ +from jupyter_ai_magics.utils import decompose_model_id import ray from ray.util.queue import Queue @@ -11,7 +12,7 @@ from jupyter_ai.actors.base import BaseActor, Logger, ACTOR_TYPE from jupyter_ai.actors.memory import RemoteMemory -from jupyter_ai.models import HumanChatMessage +from jupyter_ai.models import HumanChatMessage, ProviderConfig from jupyter_ai_magics.providers import ChatOpenAINewProvider SYSTEM_PROMPT = "The following is a friendly conversation between a human and an AI, whose name is Jupyter AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know." @@ -37,6 +38,10 @@ def __init__(self, reply_queue: Queue, log: Logger): ) self.chat_provider = chain + def update_chat_provider(config: ProviderConfig): + # Placeholder for updating chat provider + pass + def _process_message(self, message: HumanChatMessage): response = self.chat_provider.predict(input=message.body) self.reply(response, message) diff --git a/packages/jupyter-ai/jupyter_ai/actors/providers.py b/packages/jupyter-ai/jupyter_ai/actors/providers.py new file mode 100644 index 000000000..869ec047d --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/actors/providers.py @@ -0,0 +1,20 @@ +from jupyter_ai_magics.utils import load_embedding_providers, load_providers +import ray +from jupyter_ai.actors.base import BaseActor, Logger +from ray.util.queue import Queue + +@ray.remote +class ProvidersActor(): + + def __init__(self, log: Logger): + self.log = log + self.model_providers = load_providers(log=log) + self.embeddings_providers = load_embedding_providers(log=log) + + def get_model_providers(self): + return self.model_providers + + def get_embeddings_providers(self): + return self.embeddings_providers + + \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 1197c11de..85102e715 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -1,6 +1,8 @@ import asyncio import os import queue +from jupyter_ai.actors.config import ConfigActor +from jupyter_ai.actors.providers import ProvidersActor from jupyter_ai_magics.utils import load_embedding_providers, load_providers from langchain.memory import ConversationBufferWindowMemory from jupyter_ai.actors.default import DefaultActor @@ -94,14 +96,6 @@ def initialize_settings(self): self.settings["ai_default_tasks"] = default_tasks self.log.info("Registered all default tasks.") - providers = load_providers(log=self.log) - self.settings["chat_providers"] = providers - self.log.info("Registered providers.") - - embeddings_providers = load_embedding_providers(log=self.log) - self.settings["embeddings_providers"] = embeddings_providers - self.log.info("Registered embeddings providers.") - if ChatOpenAINewProvider.auth_strategy.name not in os.environ: raise EnvironmentError(f"`{ChatOpenAINewProvider.auth_strategy.name}` value not set in environment. 
For chat to work, this value should be provided.") @@ -128,6 +122,13 @@ def initialize_settings(self): reply_queue=reply_queue, log=self.log ) + providers_actor = ProvidersActor.options(name=ACTOR_TYPE.PROVIDERS.value).remote( + log=self.log + ) + config_actor = ConfigActor.options(name=ACTOR_TYPE.CONFIG.value).remote( + log=self.log, + root_dir=self.serverapp.root_dir, + ) default_actor = DefaultActor.options(name=ACTOR_TYPE.DEFAULT.value).remote( reply_queue=reply_queue, log=self.log @@ -152,6 +153,8 @@ def initialize_settings(self): ) self.settings['router'] = router + self.settings['providers_actor'] = providers_actor + self.settings['config_actor'] = config_actor self.settings["default_actor"] = default_actor self.settings["learn_actor"] = learn_actor self.settings["ask_actor"] = ask_actor diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index 3f5a59d47..4e21e4e69 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -273,7 +273,9 @@ def on_close(self): class ModelProviderHandler(BaseAPIHandler): @property def chat_providers(self): - return self.settings["chat_providers"] + actor = ray.get_actor("providers") + o = actor.get_model_providers.remote() + return ray.get(o) @web.authenticated def get(self): @@ -296,7 +298,9 @@ class EmbeddingsModelProviderHandler(BaseAPIHandler): @property def embeddings_providers(self): - return self.settings['embeddings_providers'] + actor = ray.get_actor("providers") + o = actor.get_embeddings_providers.remote() + return ray.get(o) @web.authenticated def get(self): @@ -312,4 +316,11 @@ def get(self): ) response = ListProvidersResponse(providers=sorted(providers, key=lambda p: p.name)) - self.finish(response.json()) \ No newline at end of file + self.finish(response.json()) + + +class ProviderConfigHandler(BaseAPIHandler): + + @web.authenticated + def get(self): + ... 
\ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index decb90c62..db46689e7 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -1,4 +1,5 @@ -from jupyter_ai_magics.providers import AuthStrategy +from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider +from jupyter_ai_magics.providers import AuthStrategy, BaseProvider from pydantic import BaseModel from typing import Dict, List, Union, Literal, Optional @@ -94,4 +95,10 @@ class ListProvidersEntry(BaseModel): class ListProvidersResponse(BaseModel): - providers: List[ListProvidersEntry] \ No newline at end of file + providers: List[ListProvidersEntry] + + +class ProviderConfig(BaseModel): + model_provider: str + embeddings_provider: str + api_keys: Dict[str, str] \ No newline at end of file From c25a80ee2c3c494ea20eade7b93a8a6f90f65805 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Tue, 25 Apr 2023 20:24:45 -0700 Subject: [PATCH 08/16] Ability to load llm and embeddings from config --- .../jupyter_ai_magics/__init__.py | 1 + .../jupyter_ai_magics/embedding_providers.py | 6 +-- packages/jupyter-ai-magics/pyproject.toml | 1 + packages/jupyter-ai/jupyter_ai/actors/ask.py | 29 ++++++++++--- packages/jupyter-ai/jupyter_ai/actors/base.py | 2 + .../jupyter_ai/actors/chat_provider.py | 40 ++++++++++++++++++ .../jupyter-ai/jupyter_ai/actors/default.py | 37 ++++++++++------ .../jupyter_ai/actors/embeddings_provider.py | 42 +++++++++++++++++++ .../jupyter-ai/jupyter_ai/actors/generate.py | 25 +++++++---- .../jupyter-ai/jupyter_ai/actors/learn.py | 35 ++++++++++++---- .../jupyter-ai/jupyter_ai/actors/providers.py | 13 ++++++ packages/jupyter-ai/jupyter_ai/extension.py | 16 +++++-- packages/jupyter-ai/jupyter_ai/handlers.py | 4 -- 13 files changed, 206 insertions(+), 45 deletions(-) create mode 100644 packages/jupyter-ai/jupyter_ai/actors/chat_provider.py create mode 100644 packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py index 1a947c8fc..ff35147db 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py @@ -10,6 +10,7 @@ HfHubProvider, OpenAIProvider, ChatOpenAIProvider, + ChatOpenAINewProvider, SmEndpointProvider ) # expose embedding model providers on the package root diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py index 81eaadbf1..cfc8481bc 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py @@ -45,7 +45,7 @@ class OpenAIEmbeddingsProvider(BaseEmbeddingsProvider): model_id_key = "model" pypi_package_deps = ["openai"] auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY") - provider_klass: OpenAIEmbeddings + provider_klass = OpenAIEmbeddings class CohereEmbeddingsProvider(BaseEmbeddingsProvider): @@ -59,7 +59,7 @@ class CohereEmbeddingsProvider(BaseEmbeddingsProvider): model_id_key = "model" pypi_package_deps = ["cohere"] auth_strategy = EnvAuthStrategy(name="COHERE_API_KEY") - provider_klass: CohereEmbeddings + provider_klass = CohereEmbeddings class HfHubEmbeddingsProvider(BaseEmbeddingsProvider): @@ -72,4 +72,4 @@ class HfHubEmbeddingsProvider(BaseEmbeddingsProvider): # tqdm is a 
dependency of huggingface_hub pypi_package_deps = ["huggingface_hub", "ipywidgets"] auth_strategy = EnvAuthStrategy(name="HUGGINGFACEHUB_API_TOKEN") - provider_klass: HuggingFaceHubEmbeddings + provider_klass = HuggingFaceHubEmbeddings diff --git a/packages/jupyter-ai-magics/pyproject.toml b/packages/jupyter-ai-magics/pyproject.toml index 1ee2de714..a95720ecc 100644 --- a/packages/jupyter-ai-magics/pyproject.toml +++ b/packages/jupyter-ai-magics/pyproject.toml @@ -54,6 +54,7 @@ cohere = "jupyter_ai_magics:CohereProvider" huggingface_hub = "jupyter_ai_magics:HfHubProvider" openai = "jupyter_ai_magics:OpenAIProvider" openai-chat = "jupyter_ai_magics:ChatOpenAIProvider" +openai-chat-new = "jupyter_ai_magics:ChatOpenAINewProvider" sagemaker-endpoint = "jupyter_ai_magics:SmEndpointProvider" [project.entry-points."jupyter_ai.embeddings_model_providers"] diff --git a/packages/jupyter-ai/jupyter_ai/actors/ask.py b/packages/jupyter-ai/jupyter_ai/actors/ask.py index 5676154f0..9e17ba458 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/ask.py +++ b/packages/jupyter-ai/jupyter_ai/actors/ask.py @@ -1,4 +1,5 @@ import argparse +from jupyter_ai_magics.providers import BaseProvider import ray from ray.util.queue import Queue @@ -21,20 +22,35 @@ class AskActor(BaseActor): def __init__(self, reply_queue: Queue, log: Logger): super().__init__(reply_queue=reply_queue, log=log) + + self.provider = None + self.chat_provider = None + self.parser.prog = '/ask' + self.parser.add_argument('query', nargs=argparse.REMAINDER) + + def create_chat_provider(self, provider: BaseProvider): index_actor = ray.get_actor(ACTOR_TYPE.LEARN.value) handle = index_actor.get_index.remote() vectorstore = ray.get(handle) if not vectorstore: - return - + return None + self.provider = provider self.chat_history = [] self.chat_provider = ConversationalRetrievalChain.from_llm( - OpenAI(temperature=0, verbose=True), + provider, vectorstore.as_retriever() ) + - self.parser.prog = '/ask' - self.parser.add_argument('query', nargs=argparse.REMAINDER) + def _get_chat_provider(self): + actor = ray.get_actor(ACTOR_TYPE.CHAT_PROVIDER) + o = actor.get_provider.remote() + provider = ray.get(o) + if not provider: + return None + if provider.__class__.__name__ != self.provider.__class__.__name__: + self.create_chat_provider(provider) + return self.chat_provider def _process_message(self, message: HumanChatMessage): @@ -49,6 +65,9 @@ def _process_message(self, message: HumanChatMessage): index_actor = ray.get_actor(ACTOR_TYPE.LEARN.value) handle = index_actor.get_index.remote() vectorstore = ray.get(handle) + + self._get_chat_provider() + # Have to reference the latest index self.chat_provider.retriever = vectorstore.as_retriever() diff --git a/packages/jupyter-ai/jupyter_ai/actors/base.py b/packages/jupyter-ai/jupyter_ai/actors/base.py index fe5c31e95..a0e4a070e 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/base.py +++ b/packages/jupyter-ai/jupyter_ai/actors/base.py @@ -21,6 +21,8 @@ class ACTOR_TYPE(str, Enum): GENERATE = 'generate' PROVIDERS = 'providers' CONFIG = 'config' + CHAT_PROVIDER = 'chat_provider' + EMBEDDINGS_PROVIDER = 'embeddings_provider' COMMANDS = { '/ask': ACTOR_TYPE.ASK, diff --git a/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py b/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py new file mode 100644 index 000000000..7c4e85d39 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py @@ -0,0 +1,40 @@ + +import os +from typing import Optional +from jupyter_ai.actors.base import Logger, ACTOR_TYPE 
+from jupyter_ai.models import ProviderConfig +from jupyter_ai_magics.utils import decompose_model_id +import ray + +@ray.remote +class ChatProviderActor(): + + def __init__(self, log: Logger): + self.log = log + self.provider = None + + def update(self, config: ProviderConfig): + providers_actor = ray.get_actor(ACTOR_TYPE.PROVIDERS.value) + o = providers_actor.get_model_providers.remote() + providers = ray.get(o) + provider_id, local_model_id = decompose_model_id(model_id=config.model_provider, providers=providers) + + p = providers_actor.get_model_provider.remote(provider_id) + provider = ray.get(p) + + if not provider: + return + auth_strategy = provider.auth_strategy + api_keys = config.api_keys + if auth_strategy: + if auth_strategy.type == "env" and auth_strategy.name.lower() not in api_keys: + # raise error? + return + + provider_params = { "model_id": local_model_id} + api_key_name = auth_strategy.name.lower() + provider_params[api_key_name] = api_keys[api_key_name] + self.provider = provider(**provider_params) + + def get_provider(self): + return self.provider \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/actors/default.py b/packages/jupyter-ai/jupyter_ai/actors/default.py index 08f32ed86..38a55be34 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/default.py +++ b/packages/jupyter-ai/jupyter_ai/actors/default.py @@ -12,8 +12,8 @@ from jupyter_ai.actors.base import BaseActor, Logger, ACTOR_TYPE from jupyter_ai.actors.memory import RemoteMemory -from jupyter_ai.models import HumanChatMessage, ProviderConfig -from jupyter_ai_magics.providers import ChatOpenAINewProvider +from jupyter_ai.models import HumanChatMessage +from jupyter_ai_magics.providers import BaseProvider SYSTEM_PROMPT = "The following is a friendly conversation between a human and an AI, whose name is Jupyter AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know." @@ -21,27 +21,40 @@ class DefaultActor(BaseActor): def __init__(self, reply_queue: Queue, log: Logger): super().__init__(reply_queue=reply_queue, log=log) - provider = ChatOpenAINewProvider(model_id="gpt-3.5-turbo") - - # Create a conversation memory + self.provider = None + self.chat_provider = None + + def create_chat_provider(self, provider: BaseProvider): memory = RemoteMemory(actor_name=ACTOR_TYPE.MEMORY) prompt_template = ChatPromptTemplate.from_messages([ SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT), MessagesPlaceholder(variable_name="history"), HumanMessagePromptTemplate.from_template("{input}") ]) - chain = ConversationChain( - llm=provider, + self.provider = provider + self.chat_provider = ConversationChain( + llm=provider, prompt=prompt_template, verbose=True, memory=memory ) - self.chat_provider = chain - def update_chat_provider(config: ProviderConfig): - # Placeholder for updating chat provider - pass + def _get_chat_provider(self): + actor = ray.get_actor(ACTOR_TYPE.CHAT_PROVIDER) + o = actor.get_provider.remote() + provider = ray.get(o) + + if not provider: + return None + + if provider.__class__.__name__ != self.provider.__class__.__name__: + self.create_chat_provider(provider) + return self.chat_provider def _process_message(self, message: HumanChatMessage): - response = self.chat_provider.predict(input=message.body) + chat_provider = self._get_chat_provider() + if not chat_provider: + response = "It seems like there is no chat provider set up currently to allow chat to work. 
Please follow the settings panel to configure a chat provider, before using the chat." + else: + response = self.chat_provider.predict(input=message.body) self.reply(response, message) diff --git a/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py b/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py new file mode 100644 index 000000000..744d790f8 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py @@ -0,0 +1,42 @@ + + + +from typing import Optional +from jupyter_ai.actors.base import Logger, ACTOR_TYPE +from jupyter_ai.models import ProviderConfig +from jupyter_ai_magics.utils import decompose_model_id +import ray + +@ray.remote +class EmbeddingsProviderActor(): + + def __init__(self, log: Logger): + self.log = log + self.provider = None + + def update(self, config: ProviderConfig): + providers_actor = ray.get_actor(ACTOR_TYPE.PROVIDERS.value) + o = providers_actor.get_embeddings_providers.remote() + providers = ray.get(o) + provider_id, local_model_id = decompose_model_id(model_id=config.embeddings_provider, providers=providers) + + p = providers_actor.get_embeddings_provider.remote(provider_id) + provider = ray.get(p) + if not provider: + return + + auth_strategy = provider.auth_strategy + api_keys = config.api_keys + if auth_strategy: + if auth_strategy.type == "env" and auth_strategy.name.lower() not in api_keys: + # raise error? + return + + provider_params = {} + provider_params[provider.model_id_key] = local_model_id + api_key_name = auth_strategy.name.lower() + provider_params[api_key_name] = api_keys[api_key_name] + self.provider = provider.provider_klass(**provider_params) + + def get_provider(self): + return self.provider \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/actors/generate.py b/packages/jupyter-ai/jupyter_ai/actors/generate.py index bbcdef88b..8c6223c42 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/generate.py +++ b/packages/jupyter-ai/jupyter_ai/actors/generate.py @@ -15,8 +15,8 @@ import nbformat from jupyter_ai.models import AgentChatMessage, HumanChatMessage -from jupyter_ai.actors.base import BaseActor, Logger -from jupyter_ai_magics.providers import ChatOpenAINewProvider +from jupyter_ai.actors.base import ACTOR_TYPE, BaseActor, Logger +from jupyter_ai_magics.providers import BaseProvider, ChatOpenAINewProvider schema = """{ "$schema": "http://json-schema.org/draft-07/schema#", @@ -67,8 +67,6 @@ def from_llm(cls, llm: BaseLLM, verbose: bool=False) -> LLMChain: def generate_outline(description, llm=None, verbose=False): """Generate an outline of sections given a description of a notebook.""" - if llm is None: - llm = ChatOpenAINewProvider(model_id='gpt-3.5-turbo') chain = NotebookOutlineChain.from_llm(llm=llm, verbose=verbose) outline = chain.predict(description=description, schema=schema) return json.loads(outline) @@ -125,8 +123,6 @@ def from_llm(cls, llm: BaseLLM, verbose: bool=False) -> LLMChain: def generate_code(outline, llm=None, verbose=False): """Generate source code for a section given a description of the notebook and section.""" - if llm is None: - llm = ChatOpenAINewProvider(model_id='gpt-3.5-turbo') chain = NotebookSectionCodeChain.from_llm(llm=llm, verbose=verbose) code_so_far = [] for section in outline['sections']: @@ -177,8 +173,6 @@ def from_llm(cls, llm: BaseLLM, verbose: bool=False) -> LLMChain: def generate_title_and_summary(outline, llm=None, verbose=False): """Generate a title and summary of a notebook outline using an LLM.""" - if llm is None: - llm = 
ChatOpenAINewProvider(model_id='gpt-3.5-turbo') summary_chain = NotebookSummaryChain.from_llm(llm=llm, verbose=verbose) title_chain = NotebookTitleChain.from_llm(llm=llm, verbose=verbose) summary = summary_chain.predict(content=outline) @@ -210,9 +204,22 @@ class GenerateActor(BaseActor): def __init__(self, reply_queue: Queue, root_dir: str, log: Logger): super().__init__(log=log, reply_queue=reply_queue) self.root_dir = os.path.abspath(os.path.expanduser(root_dir)) - self.llm = ChatOpenAINewProvider(model_id='gpt-3.5-turbo') + self.llm = None + + def _get_chat_provider(self): + actor = ray.get_actor(ACTOR_TYPE.CHAT_PROVIDER) + o = actor.get_provider.remote() + provider = ray.get(o) + if not provider: + return None + if provider.__class__.__name__ != self.provider.__class__.__name__: + self.llm = provider + return self.llm def _process_message(self, message: HumanChatMessage): + llm = self._get_chat_provider() + if not llm: + return response = "👍 Great, I will get started on your notebook. It may take a few minutes, but I will reply here when the notebook is ready. In the meantime, you can continue to ask me other questions." self.reply(response, message) diff --git a/packages/jupyter-ai/jupyter_ai/actors/learn.py b/packages/jupyter-ai/jupyter_ai/actors/learn.py index c32c1b054..e6610808a 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/learn.py +++ b/packages/jupyter-ai/jupyter_ai/actors/learn.py @@ -2,6 +2,7 @@ import traceback from collections import Counter import argparse +from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider import ray from ray.util.queue import Queue @@ -16,7 +17,7 @@ ) from jupyter_ai.models import HumanChatMessage -from jupyter_ai.actors.base import BaseActor, Logger +from jupyter_ai.actors.base import ACTOR_TYPE, BaseActor, Logger from jupyter_ai_magics.providers import ChatOpenAINewProvider from jupyter_ai.document_loaders.directory import RayRecursiveDirectoryLoader from jupyter_ai.document_loaders.splitter import ExtensionSplitter, NotebookSplitter @@ -37,16 +38,17 @@ def __init__(self, reply_queue: Queue, log: Logger, root_dir: str): self.parser.add_argument('path', nargs=argparse.REMAINDER) self.index_name = 'default' self.index = None - - if ChatOpenAINewProvider.auth_strategy.name not in os.environ: - return - + self.embeddings_provider = None + if not os.path.exists(self.index_save_dir): os.makedirs(self.index_save_dir) - - self.load_or_create() + + self.load_or_create() def _process_message(self, message: HumanChatMessage): + if not self.index: + self.load_or_create() + args = self.parse_args(message) if args is None: return @@ -101,8 +103,21 @@ def delete(self): os.remove(path) self.create() + def _get_embeddings_provider(self): + actor = ray.get_actor(ACTOR_TYPE.EMBEDDINGS_PROVIDER) + o = actor.get_provider.remote() + provider = ray.get(o) + + if not provider: + return None + if provider.__class__.__name__ != self.embeddings_provider.__class__.__name__: + self.embeddings_provider = provider + return self.embeddings_provider + def create(self): - embeddings = OpenAIEmbeddings() + embeddings = self._get_embeddings_provider() + if not embeddings: + return self.index = FAISS.from_texts(["Jupyter AI knows about your filesystem, to ask questions first use the /learn command."], embeddings) self.save() @@ -111,7 +126,9 @@ def save(self): self.index.save_local(self.index_save_dir, index_name=self.index_name) def load_or_create(self): - embeddings = OpenAIEmbeddings() + embeddings = self._get_embeddings_provider() + if not embeddings: + 
return if self.index is None: try: self.index = FAISS.load_local(self.index_save_dir, embeddings, index_name=self.index_name) diff --git a/packages/jupyter-ai/jupyter_ai/actors/providers.py b/packages/jupyter-ai/jupyter_ai/actors/providers.py index 869ec047d..3e2c5e4d3 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/providers.py +++ b/packages/jupyter-ai/jupyter_ai/actors/providers.py @@ -1,3 +1,4 @@ +from typing import Optional from jupyter_ai_magics.utils import load_embedding_providers, load_providers import ray from jupyter_ai.actors.base import BaseActor, Logger @@ -14,6 +15,18 @@ def __init__(self, log: Logger): def get_model_providers(self): return self.model_providers + def get_model_provider(self, provider_id: Optional[str]): + if provider_id is None or provider_id not in self.model_providers: + return None + + return self.model_providers[provider_id] + + def get_embeddings_provider(self, provider_id: Optional[str]): + if provider_id is None or provider_id not in self.embeddings_providers: + return None + + return self.embeddings_providers[provider_id] + def get_embeddings_providers(self): return self.embeddings_providers diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 85102e715..d22fdd6dc 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -1,7 +1,9 @@ import asyncio import os import queue +from jupyter_ai.actors.chat_provider import ChatProviderActor from jupyter_ai.actors.config import ConfigActor +from jupyter_ai.actors.embeddings_provider import EmbeddingsProviderActor from jupyter_ai.actors.providers import ProvidersActor from jupyter_ai_magics.utils import load_embedding_providers, load_providers from langchain.memory import ConversationBufferWindowMemory @@ -96,11 +98,11 @@ def initialize_settings(self): self.settings["ai_default_tasks"] = default_tasks self.log.info("Registered all default tasks.") - if ChatOpenAINewProvider.auth_strategy.name not in os.environ: - raise EnvironmentError(f"`{ChatOpenAINewProvider.auth_strategy.name}` value not set in environment. For chat to work, this value should be provided.") + #if ChatOpenAINewProvider.auth_strategy.name not in os.environ: + # raise EnvironmentError(f"`{ChatOpenAINewProvider.auth_strategy.name}` value not set in environment. 
For chat to work, this value should be provided.") ## load OpenAI provider - self.settings["openai_chat"] = ChatOpenAIProvider(model_id="gpt-3.5-turbo") + #self.settings["openai_chat"] = ChatOpenAIProvider(model_id="gpt-3.5-turbo") self.log.info(f"Registered {self.name} server extension") @@ -129,6 +131,12 @@ def initialize_settings(self): log=self.log, root_dir=self.serverapp.root_dir, ) + chat_provider_actor = ChatProviderActor.options(name=ACTOR_TYPE.CHAT_PROVIDER.value).remote( + log=self.log + ) + embeddings_provider_actor = EmbeddingsProviderActor.options(name=ACTOR_TYPE.EMBEDDINGS_PROVIDER.value).remote( + log=self.log + ) default_actor = DefaultActor.options(name=ACTOR_TYPE.DEFAULT.value).remote( reply_queue=reply_queue, log=self.log @@ -155,6 +163,8 @@ def initialize_settings(self): self.settings['router'] = router self.settings['providers_actor'] = providers_actor self.settings['config_actor'] = config_actor + self.settings['chat_provider_actor'] = chat_provider_actor + self.settings['embeddings_provider_actor'] = embeddings_provider_actor self.settings["default_actor"] = default_actor self.settings["learn_actor"] = learn_actor self.settings["ask_actor"] = ask_actor diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index 4e21e4e69..3a8637ddb 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -50,10 +50,6 @@ def task_manager(self): self.settings["task_manager"] = TaskManager(engines=self.engines, default_tasks=self.default_tasks) return self.settings["task_manager"] - @property - def openai_chat(self): - return self.settings["openai_chat"] - class PromptAPIHandler(APIHandler): @tornado.web.authenticated async def post(self): From acb616f1f85eb30797867bd08e8a55b8015eb7f9 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Wed, 26 Apr 2023 12:48:59 -0700 Subject: [PATCH 09/16] Moved llm creation to specific actors --- packages/jupyter-ai/jupyter_ai/actors/ask.py | 31 ++++--------- packages/jupyter-ai/jupyter_ai/actors/base.py | 43 ++++++++++++++++++- .../jupyter_ai/actors/chat_provider.py | 12 +++--- .../jupyter-ai/jupyter_ai/actors/default.py | 31 ++++--------- .../jupyter_ai/actors/embeddings_provider.py | 9 +++- .../jupyter-ai/jupyter_ai/actors/generate.py | 26 ++++------- .../jupyter-ai/jupyter_ai/actors/learn.py | 16 +------ 7 files changed, 83 insertions(+), 85 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/actors/ask.py b/packages/jupyter-ai/jupyter_ai/actors/ask.py index 9e17ba458..5829fbc81 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/ask.py +++ b/packages/jupyter-ai/jupyter_ai/actors/ask.py @@ -1,4 +1,5 @@ import argparse +from typing import Dict from jupyter_ai_magics.providers import BaseProvider import ray @@ -22,36 +23,22 @@ class AskActor(BaseActor): def __init__(self, reply_queue: Queue, log: Logger): super().__init__(reply_queue=reply_queue, log=log) - - self.provider = None - self.chat_provider = None + self.parser.prog = '/ask' self.parser.add_argument('query', nargs=argparse.REMAINDER) - def create_chat_provider(self, provider: BaseProvider): + def create_llm_chain(self, provider: BaseProvider, provider_params: Dict[str, str]): index_actor = ray.get_actor(ACTOR_TYPE.LEARN.value) handle = index_actor.get_index.remote() vectorstore = ray.get(handle) if not vectorstore: return None - self.provider = provider + self.llm = provider(**provider_params) self.chat_history = [] - self.chat_provider = ConversationalRetrievalChain.from_llm( - provider, + 
self.llm_chain = ConversationalRetrievalChain.from_llm( + self.llm, vectorstore.as_retriever() ) - - - def _get_chat_provider(self): - actor = ray.get_actor(ACTOR_TYPE.CHAT_PROVIDER) - o = actor.get_provider.remote() - provider = ray.get(o) - if not provider: - return None - if provider.__class__.__name__ != self.provider.__class__.__name__: - self.create_chat_provider(provider) - return self.chat_provider - def _process_message(self, message: HumanChatMessage): args = self.parse_args(message) @@ -66,12 +53,12 @@ def _process_message(self, message: HumanChatMessage): handle = index_actor.get_index.remote() vectorstore = ray.get(handle) - self._get_chat_provider() + self.get_llm_chain() # Have to reference the latest index - self.chat_provider.retriever = vectorstore.as_retriever() + self.llm_chain.retriever = vectorstore.as_retriever() - result = self.chat_provider({"question": query, "chat_history": self.chat_history}) + result = self.llm_chain({"question": query, "chat_history": self.chat_history}) response = result['answer'] self.chat_history.append((query, response)) self.reply(response, message) diff --git a/packages/jupyter-ai/jupyter_ai/actors/base.py b/packages/jupyter-ai/jupyter_ai/actors/base.py index a0e4a070e..f02b0ab54 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/base.py +++ b/packages/jupyter-ai/jupyter_ai/actors/base.py @@ -3,14 +3,16 @@ from uuid import uuid4 import time import logging -from typing import Union +from typing import Dict, Union import traceback +from jupyter_ai_magics.providers import BaseProvider +import ray + from ray.util.queue import Queue from jupyter_ai.models import HumanChatMessage, AgentChatMessage - Logger = Union[logging.Logger, logging.LoggerAdapter] class ACTOR_TYPE(str, Enum): @@ -41,6 +43,11 @@ def __init__( self.log = log self.reply_queue = reply_queue self.parser = argparse.ArgumentParser() + self.llm = None + self.llm_params = None + self.llm_chain = None + self.embeddings = None + self.embeddings_params = None def process_message(self, message: HumanChatMessage): """Processes the message passed by the `Router`""" @@ -63,6 +70,38 @@ def reply(self, response, message: HumanChatMessage): reply_to=message.id ) self.reply_queue.put(m) + + def get_llm_chain(self): + actor = ray.get_actor(ACTOR_TYPE.CHAT_PROVIDER) + handle = actor.get_provider.remote() + llm = ray.get(handle) + + handle = actor.get_provider_params.remote() + llm_params = ray.get(handle) + + if not llm: + return None + + if llm.__class__.__name__ != self.llm.__class__.__name__: + self.create_llm_chain(llm, llm_params) + return self.llm_chain + + def get_embeddings(self): + actor = ray.get_actor(ACTOR_TYPE.EMBEDDINGS_PROVIDER) + handle = actor.get_provider.remote() + provider = ray.get(handle) + + handle = actor.get_provider_params.remote() + embedding_params = ray.get(handle) + + if not provider: + return None + if provider.__class__.__name__ != self.embeddings.__class__.__name__: + self.embeddings = provider(**embedding_params) + return self.embeddings + + def create_llm_chain(self, provider: BaseProvider, provider_params: Dict[str, str]): + raise NotImplementedError("Should be implemented by subclasses") def parse_args(self, message): args = message.body.split(' ') diff --git a/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py b/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py index 7c4e85d39..afa295ce9 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py +++ b/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py @@ -1,6 +1,3 @@ - -import os 
-from typing import Optional from jupyter_ai.actors.base import Logger, ACTOR_TYPE from jupyter_ai.models import ProviderConfig from jupyter_ai_magics.utils import decompose_model_id @@ -12,6 +9,7 @@ class ChatProviderActor(): def __init__(self, log: Logger): self.log = log self.provider = None + self.provider_params = None def update(self, config: ProviderConfig): providers_actor = ray.get_actor(ACTOR_TYPE.PROVIDERS.value) @@ -34,7 +32,11 @@ def update(self, config: ProviderConfig): provider_params = { "model_id": local_model_id} api_key_name = auth_strategy.name.lower() provider_params[api_key_name] = api_keys[api_key_name] - self.provider = provider(**provider_params) + self.provider = provider + self.provider_params = provider_params def get_provider(self): - return self.provider \ No newline at end of file + return self.provider + + def get_provider_params(self): + return self.provider_params \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/actors/default.py b/packages/jupyter-ai/jupyter_ai/actors/default.py index 38a55be34..1574d13ec 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/default.py +++ b/packages/jupyter-ai/jupyter_ai/actors/default.py @@ -1,3 +1,4 @@ +from typing import Dict from jupyter_ai_magics.utils import decompose_model_id import ray from ray.util.queue import Queue @@ -21,40 +22,24 @@ class DefaultActor(BaseActor): def __init__(self, reply_queue: Queue, log: Logger): super().__init__(reply_queue=reply_queue, log=log) - self.provider = None - self.chat_provider = None - def create_chat_provider(self, provider: BaseProvider): + def create_llm_chain(self, provider: BaseProvider, provider_params: Dict[str, str]): + llm = provider(**provider_params) memory = RemoteMemory(actor_name=ACTOR_TYPE.MEMORY) prompt_template = ChatPromptTemplate.from_messages([ SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT), MessagesPlaceholder(variable_name="history"), HumanMessagePromptTemplate.from_template("{input}") ]) - self.provider = provider - self.chat_provider = ConversationChain( - llm=provider, + self.llm = llm + self.llm_chain = ConversationChain( + llm=llm, prompt=prompt_template, verbose=True, memory=memory ) - def _get_chat_provider(self): - actor = ray.get_actor(ACTOR_TYPE.CHAT_PROVIDER) - o = actor.get_provider.remote() - provider = ray.get(o) - - if not provider: - return None - - if provider.__class__.__name__ != self.provider.__class__.__name__: - self.create_chat_provider(provider) - return self.chat_provider - def _process_message(self, message: HumanChatMessage): - chat_provider = self._get_chat_provider() - if not chat_provider: - response = "It seems like there is no chat provider set up currently to allow chat to work. Please follow the settings panel to configure a chat provider, before using the chat." 
- else: - response = self.chat_provider.predict(input=message.body) + self.get_llm_chain() + response = self.llm_chain.predict(input=message.body) self.reply(response, message) diff --git a/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py b/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py index 744d790f8..b48173afe 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py +++ b/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py @@ -13,6 +13,7 @@ class EmbeddingsProviderActor(): def __init__(self, log: Logger): self.log = log self.provider = None + self.provider_params = None def update(self, config: ProviderConfig): providers_actor = ray.get_actor(ACTOR_TYPE.PROVIDERS.value) @@ -36,7 +37,11 @@ def update(self, config: ProviderConfig): provider_params[provider.model_id_key] = local_model_id api_key_name = auth_strategy.name.lower() provider_params[api_key_name] = api_keys[api_key_name] - self.provider = provider.provider_klass(**provider_params) + self.provider = provider.provider_klass + self.provider_params = provider_params def get_provider(self): - return self.provider \ No newline at end of file + return self.provider + + def get_provider_params(self): + return self.provider_params \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/actors/generate.py b/packages/jupyter-ai/jupyter_ai/actors/generate.py index 8c6223c42..6153b4480 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/generate.py +++ b/packages/jupyter-ai/jupyter_ai/actors/generate.py @@ -1,21 +1,19 @@ import json import os -import time -from uuid import uuid4 +from typing import Dict import ray from ray.util.queue import Queue from langchain.llms import BaseLLM -from langchain.chat_models import ChatOpenAI from langchain.prompts import PromptTemplate from langchain.llms import BaseLLM from langchain.chains import LLMChain import nbformat -from jupyter_ai.models import AgentChatMessage, HumanChatMessage -from jupyter_ai.actors.base import ACTOR_TYPE, BaseActor, Logger +from jupyter_ai.models import HumanChatMessage +from jupyter_ai.actors.base import BaseActor, Logger from jupyter_ai_magics.providers import BaseProvider, ChatOpenAINewProvider schema = """{ @@ -206,20 +204,14 @@ def __init__(self, reply_queue: Queue, root_dir: str, log: Logger): self.root_dir = os.path.abspath(os.path.expanduser(root_dir)) self.llm = None - def _get_chat_provider(self): - actor = ray.get_actor(ACTOR_TYPE.CHAT_PROVIDER) - o = actor.get_provider.remote() - provider = ray.get(o) - if not provider: - return None - if provider.__class__.__name__ != self.provider.__class__.__name__: - self.llm = provider - return self.llm + def create_llm_chain(self, provider: BaseProvider, provider_params: Dict[str, str]): + llm = provider(**provider_params) + self.llm = llm + return llm def _process_message(self, message: HumanChatMessage): - llm = self._get_chat_provider() - if not llm: - return + self.get_llm_chain() + response = "👍 Great, I will get started on your notebook. It may take a few minutes, but I will reply here when the notebook is ready. In the meantime, you can continue to ask me other questions." 
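         # This acknowledgement is sent immediately; generating the notebook
         # can take several minutes, and the agent replies again once the
         # notebook is ready.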
self.reply(response, message) diff --git a/packages/jupyter-ai/jupyter_ai/actors/learn.py b/packages/jupyter-ai/jupyter_ai/actors/learn.py index e6610808a..f890b3780 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/learn.py +++ b/packages/jupyter-ai/jupyter_ai/actors/learn.py @@ -38,7 +38,6 @@ def __init__(self, reply_queue: Queue, log: Logger, root_dir: str): self.parser.add_argument('path', nargs=argparse.REMAINDER) self.index_name = 'default' self.index = None - self.embeddings_provider = None if not os.path.exists(self.index_save_dir): os.makedirs(self.index_save_dir) @@ -102,20 +101,9 @@ def delete(self): if os.path.isfile(path): os.remove(path) self.create() - - def _get_embeddings_provider(self): - actor = ray.get_actor(ACTOR_TYPE.EMBEDDINGS_PROVIDER) - o = actor.get_provider.remote() - provider = ray.get(o) - - if not provider: - return None - if provider.__class__.__name__ != self.embeddings_provider.__class__.__name__: - self.embeddings_provider = provider - return self.embeddings_provider def create(self): - embeddings = self._get_embeddings_provider() + embeddings = self.get_embeddings() if not embeddings: return self.index = FAISS.from_texts(["Jupyter AI knows about your filesystem, to ask questions first use the /learn command."], embeddings) @@ -126,7 +114,7 @@ def save(self): self.index.save_local(self.index_save_dir, index_name=self.index_name) def load_or_create(self): - embeddings = self._get_embeddings_provider() + embeddings = self.get_embeddings() if not embeddings: return if self.index is None: From 4a7767789e04c873cc9feed0b0acdeed40da370c Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Wed, 26 Apr 2023 20:49:11 -0700 Subject: [PATCH 10/16] Added apis for fetching, updating config. Fixed config update, error handling --- .../jupyter_ai/actors/chat_provider.py | 6 +-- .../jupyter_ai/actors/embeddings_provider.py | 5 +-- packages/jupyter-ai/jupyter_ai/extension.py | 11 +++++- packages/jupyter-ai/jupyter_ai/handlers.py | 38 ++++++++++++++++++- 4 files changed, 51 insertions(+), 9 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py b/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py index afa295ce9..5f3f558fb 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py +++ b/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py @@ -21,13 +21,13 @@ def update(self, config: ProviderConfig): provider = ray.get(p) if not provider: - return + raise ValueError(f"No provider and model found with '{config.model_provider}'") + auth_strategy = provider.auth_strategy api_keys = config.api_keys if auth_strategy: if auth_strategy.type == "env" and auth_strategy.name.lower() not in api_keys: - # raise error? 
-                return
+                raise ValueError(f"Missing value for '{auth_strategy.name}' in the config.")
 
             provider_params = { "model_id": local_model_id}
             api_key_name = auth_strategy.name.lower()
             provider_params[api_key_name] = api_keys[api_key_name]
diff --git a/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py b/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py
index b48173afe..1d8d9075d 100644
--- a/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py
+++ b/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py
@@ -24,14 +24,13 @@ def update(self, config: ProviderConfig):
         p = providers_actor.get_embeddings_provider.remote(provider_id)
         provider = ray.get(p)
         if not provider:
-            return
+            raise ValueError(f"No provider and model found with '{config.embeddings_provider}'")
 
         auth_strategy = provider.auth_strategy
         api_keys = config.api_keys
         if auth_strategy:
             if auth_strategy.type == "env" and auth_strategy.name.lower() not in api_keys:
-                # raise error?
-                return
+                raise ValueError(f"Missing value for '{auth_strategy.name}' in the config.")
 
         provider_params = {}
         provider_params[provider.model_id_key] = local_model_id
diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py
index d22fdd6dc..27d28b3ef 100644
--- a/packages/jupyter-ai/jupyter_ai/extension.py
+++ b/packages/jupyter-ai/jupyter_ai/extension.py
@@ -16,7 +16,15 @@
 from jupyter_ai.actors.base import ACTOR_TYPE
 from jupyter_ai.reply_processor import ReplyProcessor
 from jupyter_server.extension.application import ExtensionApp
-from .handlers import ChatHandler, ChatHistoryHandler, EmbeddingsModelProviderHandler, ModelProviderHandler, PromptAPIHandler, TaskAPIHandler
+from .handlers import (
+    ChatHandler,
+    ChatHistoryHandler,
+    EmbeddingsModelProviderHandler,
+    ModelProviderHandler,
+    PromptAPIHandler,
+    TaskAPIHandler,
+    ProviderConfigHandler
+)
 from importlib_metadata import entry_points
 import inspect
 from .engine import BaseModelEngine
@@ -29,6 +37,7 @@ class AiExtension(ExtensionApp):
     name = "jupyter_ai"
 
     handlers = [
+        ("api/ai/config", ProviderConfigHandler),
         ("api/ai/prompt", PromptAPIHandler),
         (r"api/ai/tasks/?", TaskAPIHandler),
         (r"api/ai/tasks/([\w\-:]*)", TaskAPIHandler),
diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py
index 3a8637ddb..fd1aac216 100644
--- a/packages/jupyter-ai/jupyter_ai/handlers.py
+++ b/packages/jupyter-ai/jupyter_ai/handlers.py
@@ -1,6 +1,7 @@
 from dataclasses import asdict
 import json
 from typing import Dict, List
+from jupyter_ai.actors.base import ACTOR_TYPE
 import ray
 import tornado
 import uuid
@@ -29,7 +30,8 @@
     AgentChatMessage,
     HumanChatMessage,
     ConnectionMessage,
-    ChatClient
+    ChatClient,
+    ProviderConfig
 )
 
 
@@ -316,7 +318,39 @@ def get(self):
 
 class ProviderConfigHandler(BaseAPIHandler):
+    """API handler for fetching and setting the
+    model and embeddings provider config.
+    """
     @web.authenticated
     def get(self):
-        ...
\ No newline at end of file + actor = ray.get_actor(ACTOR_TYPE.CONFIG) + handle = actor.get_config.remote() + config = ray.get(handle) + if not config: + raise HTTPError(500, "No config found.") + + self.finish(config.json()) + + @web.authenticated + def post(self): + try: + config = ProviderConfig(**self.get_json_body()) + actor = ray.get_actor(ACTOR_TYPE.CONFIG) + handle = actor.update.remote(config) + ray.get(handle) + + self.set_status(204) + self.finish() + + except ValidationError as e: + self.log.exception(e) + raise HTTPError(500, str(e)) from e + except ValueError as e: + self.log.exception(e) + raise HTTPError(500, str(e.cause) if hasattr(e, 'cause') else str(e)) + except Exception as e: + self.log.exception(e) + raise HTTPError( + 500, "Unexpected error occurred while updating the config." + ) from e \ No newline at end of file From f2d4dfc8e598428e9a5f5904a1cdcf4eec5f0593 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Fri, 28 Apr 2023 21:46:21 -0700 Subject: [PATCH 11/16] Updated as per PR feedback --- packages/jupyter-ai/.gitignore | 3 -- packages/jupyter-ai/jupyter_ai/actors/ask.py | 14 +++-- packages/jupyter-ai/jupyter_ai/actors/base.py | 19 +++---- .../jupyter_ai/actors/chat_provider.py | 36 ++++++------- .../jupyter-ai/jupyter_ai/actors/config.py | 53 +++++++++++++++++++ .../jupyter-ai/jupyter_ai/actors/default.py | 5 +- .../jupyter_ai/actors/embeddings_provider.py | 43 +++++++-------- .../jupyter-ai/jupyter_ai/actors/generate.py | 4 +- .../jupyter-ai/jupyter_ai/actors/learn.py | 11 ++-- .../jupyter-ai/jupyter_ai/actors/providers.py | 35 +++++++----- packages/jupyter-ai/jupyter_ai/extension.py | 20 ++----- packages/jupyter-ai/jupyter_ai/handlers.py | 23 +++----- packages/jupyter-ai/jupyter_ai/models.py | 6 +-- 13 files changed, 145 insertions(+), 127 deletions(-) create mode 100644 packages/jupyter-ai/jupyter_ai/actors/config.py diff --git a/packages/jupyter-ai/.gitignore b/packages/jupyter-ai/.gitignore index 7fa065974..56891ff87 100644 --- a/packages/jupyter-ai/.gitignore +++ b/packages/jupyter-ai/.gitignore @@ -119,9 +119,6 @@ dmypy.json # OSX files .DS_Store -# local config storing authn credentials -config.py - # vscode .vscode diff --git a/packages/jupyter-ai/jupyter_ai/actors/ask.py b/packages/jupyter-ai/jupyter_ai/actors/ask.py index 5829fbc81..f19b2f71e 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/ask.py +++ b/packages/jupyter-ai/jupyter_ai/actors/ask.py @@ -1,11 +1,10 @@ import argparse -from typing import Dict +from typing import Dict, Type from jupyter_ai_magics.providers import BaseProvider import ray from ray.util.queue import Queue -from langchain import OpenAI from langchain.chains import ConversationalRetrievalChain from jupyter_ai.models import HumanChatMessage @@ -27,12 +26,12 @@ def __init__(self, reply_queue: Queue, log: Logger): self.parser.prog = '/ask' self.parser.add_argument('query', nargs=argparse.REMAINDER) - def create_llm_chain(self, provider: BaseProvider, provider_params: Dict[str, str]): + def create_llm_chain(self, provider: Type[BaseProvider], provider_params: Dict[str, str]): index_actor = ray.get_actor(ACTOR_TYPE.LEARN.value) - handle = index_actor.get_index.remote() - vectorstore = ray.get(handle) + vectorstore = ray.get(index_actor.get_index.remote()) if not vectorstore: return None + self.llm = provider(**provider_params) self.chat_history = [] self.llm_chain = ConversationalRetrievalChain.from_llm( @@ -50,9 +49,8 @@ def _process_message(self, message: HumanChatMessage): return index_actor = 
ray.get_actor(ACTOR_TYPE.LEARN.value) - handle = index_actor.get_index.remote() - vectorstore = ray.get(handle) - + vectorstore = ray.get(index_actor.get_index.remote()) + self.get_llm_chain() # Have to reference the latest index diff --git a/packages/jupyter-ai/jupyter_ai/actors/base.py b/packages/jupyter-ai/jupyter_ai/actors/base.py index f02b0ab54..c1485aeba 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/base.py +++ b/packages/jupyter-ai/jupyter_ai/actors/base.py @@ -3,7 +3,7 @@ from uuid import uuid4 import time import logging -from typing import Dict, Union +from typing import Dict, Type, Union import traceback from jupyter_ai_magics.providers import BaseProvider @@ -73,11 +73,8 @@ def reply(self, response, message: HumanChatMessage): def get_llm_chain(self): actor = ray.get_actor(ACTOR_TYPE.CHAT_PROVIDER) - handle = actor.get_provider.remote() - llm = ray.get(handle) - - handle = actor.get_provider_params.remote() - llm_params = ray.get(handle) + llm = ray.get(actor.get_provider.remote()) + llm_params = ray.get(actor.get_provider_params.remote()) if not llm: return None @@ -88,19 +85,17 @@ def get_llm_chain(self): def get_embeddings(self): actor = ray.get_actor(ACTOR_TYPE.EMBEDDINGS_PROVIDER) - handle = actor.get_provider.remote() - provider = ray.get(handle) - - handle = actor.get_provider_params.remote() - embedding_params = ray.get(handle) + provider = ray.get(actor.get_provider.remote()) + embedding_params = ray.get(actor.get_provider_params.remote()) if not provider: return None + if provider.__class__.__name__ != self.embeddings.__class__.__name__: self.embeddings = provider(**embedding_params) return self.embeddings - def create_llm_chain(self, provider: BaseProvider, provider_params: Dict[str, str]): + def create_llm_chain(self, provider: Type[BaseProvider], provider_params: Dict[str, str]): raise NotImplementedError("Should be implemented by subclasses") def parse_args(self, message): diff --git a/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py b/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py index 5f3f558fb..12e37181c 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py +++ b/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py @@ -1,6 +1,5 @@ from jupyter_ai.actors.base import Logger, ACTOR_TYPE -from jupyter_ai.models import ProviderConfig -from jupyter_ai_magics.utils import decompose_model_id +from jupyter_ai.models import GlobalConfig import ray @ray.remote @@ -11,29 +10,28 @@ def __init__(self, log: Logger): self.provider = None self.provider_params = None - def update(self, config: ProviderConfig): - providers_actor = ray.get_actor(ACTOR_TYPE.PROVIDERS.value) - o = providers_actor.get_model_providers.remote() - providers = ray.get(o) - provider_id, local_model_id = decompose_model_id(model_id=config.model_provider, providers=providers) - - p = providers_actor.get_model_provider.remote(provider_id) - provider = ray.get(p) + def update(self, config: GlobalConfig): + model_id = config.model_provider_id + actor = ray.get_actor(ACTOR_TYPE.PROVIDERS.value) + local_model_id, provider = ray.get( + actor.get_model_provider_data.remote(model_id) + ) if not provider: - raise ValueError(f"No provider and model found with '{config.model_provider}'") + raise ValueError(f"No provider and model found with '{model_id}'") + + provider_params = { "model_id": local_model_id} auth_strategy = provider.auth_strategy - api_keys = config.api_keys - if auth_strategy: - if auth_strategy.type == "env" and auth_strategy.name.lower() not in api_keys: + if 
auth_strategy and auth_strategy.type == "env":
+            api_keys = config.api_keys
+            name = auth_strategy.name.lower()
+            if name not in api_keys:
                 raise ValueError(f"Missing value for '{auth_strategy.name}' in the config.")
+            provider_params[name] = api_keys[name]
 
-        provider_params = { "model_id": local_model_id}
-        api_key_name = auth_strategy.name.lower()
-        provider_params[api_key_name] = api_keys[api_key_name]
-        self.provider = provider
-        self.provider_params = provider_params
+        self.provider = provider
+        self.provider_params = provider_params
 
     def get_provider(self):
         return self.provider
diff --git a/packages/jupyter-ai/jupyter_ai/actors/config.py b/packages/jupyter-ai/jupyter_ai/actors/config.py
new file mode 100644
index 000000000..9a534e872
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/actors/config.py
@@ -0,0 +1,53 @@
+import json
+import os
+from jupyter_ai.actors.base import ACTOR_TYPE, Logger
+from jupyter_ai.models import GlobalConfig
+import ray
+from jupyter_core.paths import jupyter_data_dir
+
+
+@ray.remote
+class ConfigActor():
+    """Provides the model and embeddings provider ids, along
+    with the credentials needed to authenticate the providers.
+    """
+
+    def __init__(self, log: Logger):
+        self.log = log
+        self.save_dir = os.path.join(jupyter_data_dir(), 'jupyter_ai')
+        self.save_path = os.path.join(self.save_dir, 'config.json')
+        self.config = None
+        self._load()
+
+    def update(self, config: GlobalConfig, save_to_disk: bool = True):
+        self._update_chat_provider(config)
+        self._update_embeddings_provider(config)
+        if save_to_disk:
+            self._save()
+        self.config = config
+
+    def _update_chat_provider(self, config: GlobalConfig):
+        actor = ray.get_actor(ACTOR_TYPE.CHAT_PROVIDER)
+        handle = actor.update.remote(config)
+        ray.get(handle)
+
+    def _update_embeddings_provider(self, config: GlobalConfig):
+        actor = ray.get_actor(ACTOR_TYPE.EMBEDDINGS_PROVIDER)
+        handle = actor.update.remote(config)
+        ray.get(handle)
+
+    def _save(self, config: GlobalConfig):
+        if not os.path.exists(self.save_dir):
+            os.makedirs(self.save_dir)
+
+        with open(self.save_path, 'w') as f:
+            f.write(json.dumps(config))
+
+    def _load(self):
+        if os.path.exists(self.save_path):
+            with open(self.save_path, 'r', encoding='utf-8') as f:
+                config = GlobalConfig(**json.loads(f.read()))
+                self.update(config, False)
+
+    def get_config(self):
+        return self.config
\ No newline at end of file
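For reference, the file this actor persists is just the `GlobalConfig` model serialized to JSON. A rough sketch of building one, with illustrative provider ids and a placeholder key (none of these values are taken from this patch):

    from jupyter_ai.models import GlobalConfig

    # Model ids follow the "<provider_id>:<local_model_id>" convention
    # understood by decompose_model_id(); the ids and key name below are
    # assumptions for illustration only.
    config = GlobalConfig(
        model_provider_id="openai-chat:gpt-3.5-turbo",
        embeddings_provider_id="openai:text-embedding-ada-002",
        api_keys={"openai_api_key": "sk-..."},
    )
    # ConfigActor.update() pushes this to the provider actors, then writes
    # it to <jupyter_data_dir>/jupyter_ai/config.json.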
diff --git a/packages/jupyter-ai/jupyter_ai/actors/default.py b/packages/jupyter-ai/jupyter_ai/actors/default.py
index 1574d13ec..83edc36f1 100644
--- a/packages/jupyter-ai/jupyter_ai/actors/default.py
+++ b/packages/jupyter-ai/jupyter_ai/actors/default.py
@@ -1,5 +1,4 @@
-from typing import Dict
-from jupyter_ai_magics.utils import decompose_model_id
+from typing import Dict, Type
 import ray
 from ray.util.queue import Queue
 
@@ -23,7 +22,7 @@ class DefaultActor(BaseActor):
     def __init__(self, reply_queue: Queue, log: Logger):
         super().__init__(reply_queue=reply_queue, log=log)
 
-    def create_llm_chain(self, provider: BaseProvider, provider_params: Dict[str, str]):
+    def create_llm_chain(self, provider: Type[BaseProvider], provider_params: Dict[str, str]):
         llm = provider(**provider_params)
         memory = RemoteMemory(actor_name=ACTOR_TYPE.MEMORY)
         prompt_template = ChatPromptTemplate.from_messages([
diff --git a/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py b/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py
index 1d8d9075d..328f0e3a5 100644
--- a/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py
+++ b/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py
@@ -1,10 +1,5 @@
-
-
-
-from typing import Optional
 from jupyter_ai.actors.base import Logger, ACTOR_TYPE
-from jupyter_ai.models import ProviderConfig
-from jupyter_ai_magics.utils import decompose_model_id
+from jupyter_ai.models import GlobalConfig
 import ray
 
 @ray.remote
@@ -15,29 +10,29 @@ def __init__(self, log: Logger):
         self.provider = None
         self.provider_params = None
 
-    def update(self, config: ProviderConfig):
-        providers_actor = ray.get_actor(ACTOR_TYPE.PROVIDERS.value)
-        o = providers_actor.get_embeddings_providers.remote()
-        providers = ray.get(o)
-        provider_id, local_model_id = decompose_model_id(model_id=config.embeddings_provider, providers=providers)
-
-        p = providers_actor.get_embeddings_provider.remote(provider_id)
-        provider = ray.get(p)
+    def update(self, config: GlobalConfig):
+        model_id = config.embeddings_provider_id
+        actor = ray.get_actor(ACTOR_TYPE.PROVIDERS.value)
+        local_model_id, provider = ray.get(
+            actor.get_embeddings_provider_data.remote(model_id)
+        )
+
         if not provider:
-            raise ValueError(f"No provider and model found with '{config.embeddings_provider}'")
+            raise ValueError(f"No provider and model found with '{model_id}'")
+
+        provider_params = {}
+        provider_params[provider.model_id_key] = local_model_id
 
         auth_strategy = provider.auth_strategy
-        api_keys = config.api_keys
-        if auth_strategy:
-            if auth_strategy.type == "env" and auth_strategy.name.lower() not in api_keys:
+        if auth_strategy and auth_strategy.type == "env":
+            api_keys = config.api_keys
+            name = auth_strategy.name.lower()
+            if name not in api_keys:
                 raise ValueError(f"Missing value for '{auth_strategy.name}' in the config.")
+            provider_params[name] = api_keys[name]
 
-        provider_params = {}
-        provider_params[provider.model_id_key] = local_model_id
-        api_key_name = auth_strategy.name.lower()
-        provider_params[api_key_name] = api_keys[api_key_name]
-        self.provider = provider.provider_klass
-        self.provider_params = provider_params
+        self.provider = provider.provider_klass
+        self.provider_params = provider_params
 
     def get_provider(self):
         return self.provider
diff --git a/packages/jupyter-ai/jupyter_ai/actors/generate.py b/packages/jupyter-ai/jupyter_ai/actors/generate.py
index 6153b4480..a240078b5 100644
--- a/packages/jupyter-ai/jupyter_ai/actors/generate.py
+++ b/packages/jupyter-ai/jupyter_ai/actors/generate.py
@@ -1,6 +1,6 @@
 import json
 import os
-from typing import Dict
+from typing import Dict, Type
 
 import ray
 from ray.util.queue import Queue
@@ -204,7 +204,7 @@ def __init__(self, reply_queue: Queue, root_dir: str, log: Logger):
         self.root_dir = os.path.abspath(os.path.expanduser(root_dir))
         self.llm = None
 
-    def create_llm_chain(self, provider: BaseProvider, provider_params: Dict[str, str]):
+    def create_llm_chain(self, provider: Type[BaseProvider], provider_params: Dict[str, str]):
         llm = provider(**provider_params)
         self.llm = llm
         return llm
diff --git a/packages/jupyter-ai/jupyter_ai/actors/learn.py b/packages/jupyter-ai/jupyter_ai/actors/learn.py
index f890b3780..86f238332 100644
--- a/packages/jupyter-ai/jupyter_ai/actors/learn.py
+++ b/packages/jupyter-ai/jupyter_ai/actors/learn.py
@@ -1,8 +1,5 @@
 import os
-import traceback
-from collections import Counter
 import argparse
-from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider
 
 import ray
 from ray.util.queue import Queue
@@ -10,15 +7,13 @@
 from jupyter_core.paths import jupyter_data_dir
 
 from langchain import FAISS
-from langchain.embeddings.openai import
OpenAIEmbeddings
 from langchain.text_splitter import (
     RecursiveCharacterTextSplitter, PythonCodeTextSplitter,
     MarkdownTextSplitter, LatexTextSplitter
 )
 
 from jupyter_ai.models import HumanChatMessage
-from jupyter_ai.actors.base import ACTOR_TYPE, BaseActor, Logger
-from jupyter_ai_magics.providers import ChatOpenAINewProvider
+from jupyter_ai.actors.base import BaseActor, Logger
 from jupyter_ai.document_loaders.directory import RayRecursiveDirectoryLoader
 from jupyter_ai.document_loaders.splitter import ExtensionSplitter, NotebookSplitter
 
@@ -48,6 +43,10 @@ def _process_message(self, message: HumanChatMessage):
         if not self.index:
             self.load_or_create()
 
+        # If the index is still not there, embeddings are not present
+        if not self.index:
+            self.reply("Sorry, please select an embedding provider before using the `/learn` command.", message)
+            return
+
         args = self.parse_args(message)
         if args is None:
             return
diff --git a/packages/jupyter-ai/jupyter_ai/actors/providers.py b/packages/jupyter-ai/jupyter_ai/actors/providers.py
index 3e2c5e4d3..fd249ede9 100644
--- a/packages/jupyter-ai/jupyter_ai/actors/providers.py
+++ b/packages/jupyter-ai/jupyter_ai/actors/providers.py
@@ -1,11 +1,17 @@
-from typing import Optional
-from jupyter_ai_magics.utils import load_embedding_providers, load_providers
+from typing import Optional, Tuple, Type
+from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider
+from jupyter_ai_magics.providers import BaseProvider
+from jupyter_ai_magics.utils import decompose_model_id, load_embedding_providers, load_providers
 import ray
 from jupyter_ai.actors.base import BaseActor, Logger
 from ray.util.queue import Queue
 
 @ray.remote
 class ProvidersActor():
+    """Actor that loads model and embedding providers from
+    entry points. Also provides utility functions to get the
+    providers and the provider class matching a model id.
+ """ def __init__(self, log: Logger): self.log = log @@ -13,21 +19,24 @@ def __init__(self, log: Logger): self.embeddings_providers = load_embedding_providers(log=log) def get_model_providers(self): + """Returns dictionary of registered LLM providers""" return self.model_providers - def get_model_provider(self, provider_id: Optional[str]): - if provider_id is None or provider_id not in self.model_providers: - return None - - return self.model_providers[provider_id] - - def get_embeddings_provider(self, provider_id: Optional[str]): - if provider_id is None or provider_id not in self.embeddings_providers: - return None - - return self.embeddings_providers[provider_id] + def get_model_provider_data(self, model_id: str) -> Tuple[str, Type[BaseProvider]]: + """Returns the model provider class that matches the provider id""" + provider_id, local_model_id = decompose_model_id(model_id, self.model_providers) + provider = self.model_providers.get(provider_id, None) + return local_model_id, provider def get_embeddings_providers(self): + """Returns dictionary of registered embedding providers""" return self.embeddings_providers + def get_embeddings_provider_data(self, model_id: str) -> Tuple[str, Type[BaseEmbeddingsProvider]]: + """Returns the embedding provider class that matches the provider id""" + provider_id, local_model_id = decompose_model_id(model_id, self.embeddings_providers) + provider = self.embeddings_providers.get(provider_id, None) + return local_model_id, provider + + \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 27d28b3ef..356ee7ae7 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -1,11 +1,8 @@ import asyncio -import os -import queue from jupyter_ai.actors.chat_provider import ChatProviderActor from jupyter_ai.actors.config import ConfigActor from jupyter_ai.actors.embeddings_provider import EmbeddingsProviderActor from jupyter_ai.actors.providers import ProvidersActor -from jupyter_ai_magics.utils import load_embedding_providers, load_providers from langchain.memory import ConversationBufferWindowMemory from jupyter_ai.actors.default import DefaultActor from jupyter_ai.actors.ask import AskActor @@ -23,12 +20,11 @@ ModelProviderHandler, PromptAPIHandler, TaskAPIHandler, - ProviderConfigHandler + GlobalConfigHandler ) from importlib_metadata import entry_points import inspect from .engine import BaseModelEngine -from jupyter_ai_magics.providers import ChatOpenAINewProvider, ChatOpenAIProvider import ray from ray.util.queue import Queue @@ -37,7 +33,7 @@ class AiExtension(ExtensionApp): name = "jupyter_ai" handlers = [ - ("api/ai/config", ProviderConfigHandler), + ("api/ai/config", GlobalConfigHandler), ("api/ai/prompt", PromptAPIHandler), (r"api/ai/tasks/?", TaskAPIHandler), (r"api/ai/tasks/([\w\-:]*)", TaskAPIHandler), @@ -107,17 +103,8 @@ def initialize_settings(self): self.settings["ai_default_tasks"] = default_tasks self.log.info("Registered all default tasks.") - #if ChatOpenAINewProvider.auth_strategy.name not in os.environ: - # raise EnvironmentError(f"`{ChatOpenAINewProvider.auth_strategy.name}` value not set in environment. 
For chat to work, this value should be provided.")
-
-        ## load OpenAI provider
-        #self.settings["openai_chat"] = ChatOpenAIProvider(model_id="gpt-3.5-turbo")
-
         self.log.info(f"Registered {self.name} server extension")
 
-        # Add a message queue to the settings to be used by the chat handler
-        self.settings["chat_message_queue"] = queue.Queue()
-
         # Store chat clients in a dictionary
         self.settings["chat_clients"] = {}
         self.settings["chat_handlers"] = {}
@@ -137,8 +124,7 @@ def initialize_settings(self):
             log=self.log
         )
         config_actor = ConfigActor.options(name=ACTOR_TYPE.CONFIG.value).remote(
-            log=self.log,
-            root_dir=self.serverapp.root_dir,
+            log=self.log
         )
         chat_provider_actor = ChatProviderActor.options(name=ACTOR_TYPE.CHAT_PROVIDER.value).remote(
             log=self.log
diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py
index fd1aac216..843203f28 100644
--- a/packages/jupyter-ai/jupyter_ai/handlers.py
+++ b/packages/jupyter-ai/jupyter_ai/handlers.py
@@ -31,7 +31,7 @@
     HumanChatMessage,
     ConnectionMessage,
     ChatClient,
-    ProviderConfig
+    GlobalConfig
 )
 
 
@@ -117,15 +117,6 @@ class ChatHandler(
     """
     A websocket handler for chat.
     """
-
-    _chat_provider = None
-    _chat_message_queue = None
-
-    @property
-    def chat_message_queue(self):
-        if self._chat_message_queue is None:
-            self._chat_message_queue = self.settings["chat_message_queue"]
-        return self._chat_message_queue
 
     @property
     def chat_handlers(self) -> Dict[str, 'ChatHandler']:
@@ -317,16 +308,15 @@ def get(self):
 
         self.finish(response.json())
 
-class ProviderConfigHandler(BaseAPIHandler):
+class GlobalConfigHandler(BaseAPIHandler):
     """API handler for fetching and setting the
-    model and embeddings provider config.
+    model and embeddings config.
     """
 
     @web.authenticated
     def get(self):
         actor = ray.get_actor(ACTOR_TYPE.CONFIG)
-        handle = actor.get_config.remote()
-        config = ray.get(handle)
+        config = ray.get(actor.get_config.remote())
         if not config:
             raise HTTPError(500, "No config found.")
 
@@ -335,10 +325,9 @@ def get(self):
     @web.authenticated
     def post(self):
         try:
-            config = ProviderConfig(**self.get_json_body())
+            config = GlobalConfig(**self.get_json_body())
             actor = ray.get_actor(ACTOR_TYPE.CONFIG)
-            handle = actor.update.remote(config)
-            ray.get(handle)
+            ray.get(actor.update.remote(config))
 
             self.set_status(204)
             self.finish()
diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py
index 98145e4d5..f9a282207 100644
--- a/packages/jupyter-ai/jupyter_ai/models.py
+++ b/packages/jupyter-ai/jupyter_ai/models.py
@@ -98,7 +98,7 @@ class ListProvidersResponse(BaseModel):
     providers: List[ListProvidersEntry]
 
-class ProviderConfig(BaseModel):
-    model_provider: str
-    embeddings_provider: str
+class GlobalConfig(BaseModel):
+    model_provider_id: str
+    embeddings_provider_id: str
     api_keys: Dict[str, str]
\ No newline at end of file

From 97156a64af9556ff0ba612147695d018dabfb348 Mon Sep 17 00:00:00 2001
From: Piyush Jain
Date: Sat, 29 Apr 2023 22:37:20 -0700
Subject: [PATCH 12/16] Fixes issue with cohere embeddings, api keys not
 working

---
 packages/jupyter-ai/jupyter_ai/actors/ask.py | 35 +++++++++++--------
 .../jupyter-ai/jupyter_ai/actors/learn.py    |  8 +++++
 2 files changed, 29 insertions(+), 14 deletions(-)

diff --git a/packages/jupyter-ai/jupyter_ai/actors/ask.py b/packages/jupyter-ai/jupyter_ai/actors/ask.py
index f19b2f71e..1c4f9c9b7 100644
--- a/packages/jupyter-ai/jupyter_ai/actors/ask.py
+++ b/packages/jupyter-ai/jupyter_ai/actors/ask.py
@@ -1,11 +1,12 @@
 import argparse
-from
typing import Dict, Type +from typing import Dict, List, Type from jupyter_ai_magics.providers import BaseProvider import ray from ray.util.queue import Queue from langchain.chains import ConversationalRetrievalChain +from langchain.schema import BaseRetriever, Document from jupyter_ai.models import HumanChatMessage from jupyter_ai.actors.base import ACTOR_TYPE, BaseActor, Logger @@ -27,16 +28,12 @@ def __init__(self, reply_queue: Queue, log: Logger): self.parser.add_argument('query', nargs=argparse.REMAINDER) def create_llm_chain(self, provider: Type[BaseProvider], provider_params: Dict[str, str]): - index_actor = ray.get_actor(ACTOR_TYPE.LEARN.value) - vectorstore = ray.get(index_actor.get_index.remote()) - if not vectorstore: - return None - + retriever = Retriever() self.llm = provider(**provider_params) self.chat_history = [] self.llm_chain = ConversationalRetrievalChain.from_llm( self.llm, - vectorstore.as_retriever() + retriever ) def _process_message(self, message: HumanChatMessage): @@ -48,15 +45,25 @@ def _process_message(self, message: HumanChatMessage): self.reply(f"{self.parser.format_usage()}", message) return - index_actor = ray.get_actor(ACTOR_TYPE.LEARN.value) - vectorstore = ray.get(index_actor.get_index.remote()) - self.get_llm_chain() - - # Have to reference the latest index - self.llm_chain.retriever = vectorstore.as_retriever() - + result = self.llm_chain({"question": query, "chat_history": self.chat_history}) response = result['answer'] self.chat_history.append((query, response)) self.reply(response, message) + + +class Retriever(BaseRetriever): + """Wrapper retriever class to get relevant docs + from the vector store, this is important because + of inconsistent de-serialization of index when it's + accessed directly from the ask actor. 
+ """ + + def get_relevant_documents(self, question: str): + index_actor = ray.get_actor(ACTOR_TYPE.LEARN.value) + docs = ray.get(index_actor.get_relevant_documents.remote(question)) + return docs + + async def aget_relevant_documents(self, query: str) -> List[Document]: + return await super().aget_relevant_documents(query) \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/actors/learn.py b/packages/jupyter-ai/jupyter_ai/actors/learn.py index 86f238332..cc9492092 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/learn.py +++ b/packages/jupyter-ai/jupyter_ai/actors/learn.py @@ -1,5 +1,6 @@ import os import argparse +from typing import List import ray from ray.util.queue import Queue @@ -11,6 +12,7 @@ RecursiveCharacterTextSplitter, PythonCodeTextSplitter, MarkdownTextSplitter, LatexTextSplitter ) +from langchain.schema import Document from jupyter_ai.models import HumanChatMessage from jupyter_ai.actors.base import BaseActor, Logger @@ -121,3 +123,9 @@ def load_or_create(self): self.index = FAISS.load_local(self.index_save_dir, embeddings, index_name=self.index_name) except Exception as e: self.create() + + def get_relevant_documents(self, question: str) -> List[Document]: + if self.index: + docs = self.index.similarity_search(question) + return docs + return [] From 03b840e8e8ddfef57c753a7dcd7210a487c7fbfd Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Sat, 29 Apr 2023 23:41:21 -0700 Subject: [PATCH 13/16] Added an error check when embedding change causes read error --- packages/jupyter-ai/jupyter_ai/actors/ask.py | 17 +++++++++++++---- packages/jupyter-ai/jupyter_ai/actors/base.py | 1 + packages/jupyter-ai/jupyter_ai/extension.py | 5 +++++ 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/actors/ask.py b/packages/jupyter-ai/jupyter_ai/actors/ask.py index 1c4f9c9b7..e78837ca4 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/ask.py +++ b/packages/jupyter-ai/jupyter_ai/actors/ask.py @@ -47,10 +47,19 @@ def _process_message(self, message: HumanChatMessage): self.get_llm_chain() - result = self.llm_chain({"question": query, "chat_history": self.chat_history}) - response = result['answer'] - self.chat_history.append((query, response)) - self.reply(response, message) + try: + result = self.llm_chain({"question": query, "chat_history": self.chat_history}) + response = result['answer'] + self.chat_history.append((query, response)) + self.reply(response, message) + except AssertionError as e: + self.log.error(e) + response = """Sorry, an error occurred while reading the from the learned documents. + If you have changed the embedding provider, try deleting the existing index by running + `/learn -d` command and then re-submitting the `learn ` to learn the documents, + and then asking the question again. 
+ """ + self.reply(response, message) class Retriever(BaseRetriever): diff --git a/packages/jupyter-ai/jupyter_ai/actors/base.py b/packages/jupyter-ai/jupyter_ai/actors/base.py index c1485aeba..d885084a0 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/base.py +++ b/packages/jupyter-ai/jupyter_ai/actors/base.py @@ -93,6 +93,7 @@ def get_embeddings(self): if provider.__class__.__name__ != self.embeddings.__class__.__name__: self.embeddings = provider(**embedding_params) + return self.embeddings def create_llm_chain(self, provider: Type[BaseProvider], provider_params: Dict[str, str]): diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 356ee7ae7..6a316a050 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -28,6 +28,7 @@ import ray from ray.util.queue import Queue +from jupyter_ai_magics.utils import load_providers class AiExtension(ExtensionApp): @@ -103,6 +104,10 @@ def initialize_settings(self): self.settings["ai_default_tasks"] = default_tasks self.log.info("Registered all default tasks.") + providers = load_providers(log=self.log) + self.settings["chat_providers"] = providers + self.log.info("Registered providers.") + self.log.info(f"Registered {self.name} server extension") # Store chat clients in a dictionary From f0f6958663489a541f9f3e20156462ce17a27dff Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Wed, 3 May 2023 11:40:40 -0700 Subject: [PATCH 14/16] Delete and re-index docs when embedding model changes (#137) * Added an error check when embedding change causes read error * Refactored provider load, decompose logic, aded model provider list api * Re-indexes dirs when embeddings change, learn list command * Fixed typo, simplified adding metadata * Moved index dir, metadata path to constants --- .../jupyter_ai_magics/magics.py | 4 +- .../jupyter_ai_magics/providers.py | 1 + .../jupyter_ai_magics/utils.py | 4 +- packages/jupyter-ai/jupyter_ai/actors/base.py | 10 +- .../jupyter-ai/jupyter_ai/actors/config.py | 10 +- .../jupyter_ai/actors/embeddings_provider.py | 13 ++- .../jupyter-ai/jupyter_ai/actors/learn.py | 109 +++++++++++++++--- packages/jupyter-ai/jupyter_ai/extension.py | 6 + packages/jupyter-ai/jupyter_ai/handlers.py | 3 +- packages/jupyter-ai/jupyter_ai/models.py | 15 ++- 10 files changed, 137 insertions(+), 38 deletions(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py index 7fc61889d..927766ba9 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py @@ -8,11 +8,8 @@ from IPython import get_ipython from IPython.core.magic import Magics, magics_class, line_cell_magic from IPython.core.magic_arguments import magic_arguments, argument, parse_argstring - from IPython.display import HTML, JSON, Markdown, Math - from jupyter_ai_magics.utils import decompose_model_id, load_providers - from .providers import BaseProvider @@ -37,6 +34,7 @@ def _repr_mimebundle_(self, include=None, exclude=None): } ) + class TextWithMetadata(object): def __init__(self, text, metadata): self.text = text diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index 3c765001f..22507acbc 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -15,6 +15,7 @@ SagemakerEndpoint ) from 
langchain.utils import get_from_dict_or_env +from langchain.llms.utils import enforce_stop_tokens from pydantic import BaseModel, Extra, root_validator from langchain.chat_models import ChatOpenAI diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py b/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py index 05d296c5e..aab722240 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py @@ -2,6 +2,7 @@ from typing import Dict, Optional, Tuple, Union from importlib_metadata import entry_points from jupyter_ai_magics.aliases import MODEL_ID_ALIASES + from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider from jupyter_ai_magics.providers import BaseProvider @@ -14,6 +15,7 @@ def load_providers(log: Optional[Logger] = None) -> Dict[str, BaseProvider]: if not log: log = logging.getLogger() log.addHandler(logging.NullHandler()) + providers = {} eps = entry_points() model_provider_eps = eps.select(group="jupyter_ai.model_providers") @@ -47,8 +49,6 @@ def load_embedding_providers(log: Optional[Logger] = None) -> Dict[str, BaseEmbe return providers - - def decompose_model_id(model_id: str, providers: Dict[str, BaseProvider]) -> Tuple[str, str]: """Breaks down a model ID into a two-tuple (provider_id, local_model_id). Returns (None, None) if indeterminate.""" if model_id in MODEL_ID_ALIASES: diff --git a/packages/jupyter-ai/jupyter_ai/actors/base.py b/packages/jupyter-ai/jupyter_ai/actors/base.py index d885084a0..365861553 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/base.py +++ b/packages/jupyter-ai/jupyter_ai/actors/base.py @@ -3,7 +3,7 @@ from uuid import uuid4 import time import logging -from typing import Dict, Type, Union +from typing import Dict, Optional, Type, Union import traceback from jupyter_ai_magics.providers import BaseProvider @@ -48,6 +48,7 @@ def __init__( self.llm_chain = None self.embeddings = None self.embeddings_params = None + self.embedding_model_id = None def process_message(self, message: HumanChatMessage): """Processes the message passed by the `Router`""" @@ -62,12 +63,12 @@ def _process_message(self, message: HumanChatMessage): """Processes the message passed by the `Router`""" raise NotImplementedError("Should be implemented by subclasses.") - def reply(self, response, message: HumanChatMessage): + def reply(self, response, message: Optional[HumanChatMessage] = None): m = AgentChatMessage( id=uuid4().hex, time=time.time(), body=response, - reply_to=message.id + reply_to=message.id if message else "" ) self.reply_queue.put(m) @@ -87,11 +88,12 @@ def get_embeddings(self): actor = ray.get_actor(ACTOR_TYPE.EMBEDDINGS_PROVIDER) provider = ray.get(actor.get_provider.remote()) embedding_params = ray.get(actor.get_provider_params.remote()) + embedding_model_id = ray.get(actor.get_model_id.remote()) if not provider: return None - if provider.__class__.__name__ != self.embeddings.__class__.__name__: + if embedding_model_id != self.embedding_model_id: self.embeddings = provider(**embedding_params) return self.embeddings diff --git a/packages/jupyter-ai/jupyter_ai/actors/config.py b/packages/jupyter-ai/jupyter_ai/actors/config.py index 9a534e872..65160c5dc 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/config.py +++ b/packages/jupyter-ai/jupyter_ai/actors/config.py @@ -23,25 +23,23 @@ def update(self, config: GlobalConfig, save_to_disk: bool = True): self._update_chat_provider(config) self._update_embeddings_provider(config) if save_to_disk: - self._save() + 
self._save()
+            self._save(config)
         self.config = config
 
     def _update_chat_provider(self, config: GlobalConfig):
         actor = ray.get_actor(ACTOR_TYPE.CHAT_PROVIDER)
-        handle = actor.update.remote(config)
-        ray.get(handle)
+        ray.get(actor.update.remote(config))
 
     def _update_embeddings_provider(self, config: GlobalConfig):
         actor = ray.get_actor(ACTOR_TYPE.EMBEDDINGS_PROVIDER)
-        handle = actor.update.remote(config)
-        ray.get(handle)
+        ray.get(actor.update.remote(config))
 
     def _save(self, config: GlobalConfig):
         if not os.path.exists(self.save_dir):
             os.makedirs(self.save_dir)
 
         with open(self.save_path, 'w') as f:
-            f.write(json.dumps(config))
+            f.write(config.json())
 
     def _load(self):
         if os.path.exists(self.save_path):
diff --git a/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py b/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py
index 328f0e3a5..af4881448 100644
--- a/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py
+++ b/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py
@@ -9,6 +9,7 @@ def __init__(self, log: Logger):
         self.log = log
         self.provider = None
         self.provider_params = None
+        self.model_id = None
 
     def update(self, config: GlobalConfig):
         model_id = config.embeddings_provider_id
@@ -33,9 +34,19 @@ def update(self, config: GlobalConfig):
 
         self.provider = provider.provider_klass
         self.provider_params = provider_params
+        previous_model_id = self.model_id
+        self.model_id = model_id
+
+        if previous_model_id and previous_model_id != model_id:
+            # delete the index and re-learn any previously indexed directories
+            actor = ray.get_actor(ACTOR_TYPE.LEARN)
+            actor.delete_and_relearn.remote()
 
     def get_provider(self):
         return self.provider
 
     def get_provider_params(self):
-        return self.provider_params
\ No newline at end of file
+        return self.provider_params
+
+    def get_model_id(self):
+        return self.model_id
\ No newline at end of file
diff --git a/packages/jupyter-ai/jupyter_ai/actors/learn.py b/packages/jupyter-ai/jupyter_ai/actors/learn.py
index cc9492092..d65002238 100644
--- a/packages/jupyter-ai/jupyter_ai/actors/learn.py
+++ b/packages/jupyter-ai/jupyter_ai/actors/learn.py
@@ -1,5 +1,7 @@
+import json
 import os
 import argparse
+import time
 from typing import List
 
 import ray
@@ -14,30 +16,34 @@
 )
 from langchain.schema import Document
 
-from jupyter_ai.models import HumanChatMessage
+from jupyter_ai.models import HumanChatMessage, IndexedDir, IndexMetadata
 from jupyter_ai.actors.base import BaseActor, Logger
 from jupyter_ai.document_loaders.directory import RayRecursiveDirectoryLoader
 from jupyter_ai.document_loaders.splitter import ExtensionSplitter, NotebookSplitter
 
+INDEX_SAVE_DIR = os.path.join(jupyter_data_dir(), 'jupyter_ai', 'indices')
+METADATA_SAVE_PATH = os.path.join(INDEX_SAVE_DIR, 'metadata.json')
+
 @ray.remote
 class LearnActor(BaseActor):
 
     def __init__(self, reply_queue: Queue, log: Logger, root_dir: str):
         super().__init__(reply_queue=reply_queue, log=log)
         self.root_dir = root_dir
-        self.index_save_dir = os.path.join(jupyter_data_dir(), 'jupyter_ai', 'indices')
         self.chunk_size = 2000
         self.chunk_overlap = 100
         self.parser.prog = '/learn'
         self.parser.add_argument('-v', '--verbose', action='store_true')
         self.parser.add_argument('-d', '--delete', action='store_true')
+        self.parser.add_argument('-l', '--list', action='store_true')
         self.parser.add_argument('path', nargs=argparse.REMAINDER)
         self.index_name = 'default'
         self.index = None
+        self.metadata = IndexMetadata(dirs=[])
 
-        if not os.path.exists(self.index_save_dir):
-            os.makedirs(self.index_save_dir)
+        if not os.path.exists(INDEX_SAVE_DIR):
+            os.makedirs(INDEX_SAVE_DIR)
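+
+        # Assumed on-disk layout implied by the constants above
+        # (paths shown for illustration only):
+        #   <jupyter_data_dir>/jupyter_ai/indices/default.faiss
+        #   <jupyter_data_dir>/jupyter_ai/indices/default.pkl
+        #   <jupyter_data_dir>/jupyter_ai/indices/metadata.json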
self.load_or_create()
 
@@ -57,6 +63,10 @@ def _process_message(self, message: HumanChatMessage):
             self.delete()
             self.reply(f"👍 I have deleted everything I previously learned.", message)
             return
+
+        if args.list:
+            self.reply(self._build_list_response())
+            return
 
         # Make sure the path exists.
         if not len(args.path) == 1:
@@ -72,6 +82,24 @@ def _process_message(self, message: HumanChatMessage):
         if args.verbose:
             self.reply(f"Loading and splitting files for {load_path}", message)
 
+        self.learn_dir(load_path)
+        self.save()
+
+        response = f"""🎉 I have learned documents at **{load_path}** and I am ready to answer questions about them.
+        You can ask questions about these docs by prefixing your message with **/ask**."""
+        self.reply(response, message)
+
+    def _build_list_response(self):
+        if not self.metadata.dirs:
+            return "There are no docs that have been learned yet."
+
+        dirs = [dir.path for dir in self.metadata.dirs]
+        dir_list = "\n- " + "\n- ".join(dirs) + "\n\n"
+        message = f"""I can answer questions from docs in these directories:
+        {dir_list}"""
+        return message
+
+    def learn_dir(self, path: str):
         splitters={
             '.py': PythonCodeTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap),
             '.md': MarkdownTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap),
             default_splitter=RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
         )
-        loader = RayRecursiveDirectoryLoader(load_path)
-        texts = loader.load_and_split(text_splitter=splitter)
+        loader = RayRecursiveDirectoryLoader(path)
+        texts = loader.load_and_split(text_splitter=splitter)
         self.index.add_documents(texts)
-        self.save()
-
-        response = f"""🎉 I have indexed documents at **{load_path}** and I am ready to answer questions about them.
-        You can ask questions about these docs by prefixing your message with **/ask**."""
-        self.reply(response, message)
-
-    def get_index(self):
-        return self.index
+        self._add_dir_to_metadata(path)
+
+    def _add_dir_to_metadata(self, path: str):
+        dirs = self.metadata.dirs
+        index = next((i for i, dir in enumerate(dirs) if dir.path == path), None)
+        if index is None:
+            dirs.append(IndexedDir(path=path))
+        self.metadata.dirs = dirs
+
+    def delete_and_relearn(self):
+        if not self.metadata.dirs:
+            self.delete()
+            return
+        message = """🔔 Hi there! It seems like you have updated the embeddings model. For the **/ask**
+        command to work with the new model, I have to re-learn the documents you had previously
+        submitted for learning. Please wait to use the **/ask** command until I am done with this task."""
+        self.reply(message)
+
+        metadata = self.metadata
+        self.delete()
+        self.relearn(metadata)
 
     def delete(self):
         self.index = None
-        paths = [os.path.join(self.index_save_dir, self.index_name+ext) for ext in ['.pkl', '.faiss']]
+        self.metadata = IndexMetadata(dirs=[])
+        paths = [os.path.join(INDEX_SAVE_DIR, self.index_name+ext) for ext in ['.pkl', '.faiss']]
         for path in paths:
             if os.path.isfile(path):
                 os.remove(path)
         self.create()
 
+    def relearn(self, metadata: IndexMetadata):
+        # Index all dirs in the metadata
+        if not metadata.dirs:
+            return
+
+        for dir in metadata.dirs:
+            self.learn_dir(dir.path)
+
+        self.save()
+
+        dir_list = "\n- " + "\n- ".join([dir.path for dir in self.metadata.dirs]) + "\n\n"
+        message = f"""🎉 I am done learning docs in these directories:
+        {dir_list} I am ready to answer questions about them.
+ You can ask questions about these docs by prefixing your message with **/ask**.""" + self.reply(message) + def create(self): embeddings = self.get_embeddings() if not embeddings: @@ -112,7 +170,13 @@ def create(self): def save(self): if self.index is not None: - self.index.save_local(self.index_save_dir, index_name=self.index_name) + self.index.save_local(INDEX_SAVE_DIR, index_name=self.index_name) + + self.save_metadata() + + def save_metadata(self): + with open(METADATA_SAVE_PATH, 'w') as f: + f.write(self.metadata.json()) def load_or_create(self): embeddings = self.get_embeddings() @@ -120,10 +184,19 @@ def load_or_create(self): return if self.index is None: try: - self.index = FAISS.load_local(self.index_save_dir, embeddings, index_name=self.index_name) + self.index = FAISS.load_local(INDEX_SAVE_DIR, embeddings, index_name=self.index_name) + self.load_metadata() except Exception as e: self.create() + def load_metadata(self): + if not os.path.exists(METADATA_SAVE_PATH): + return + + with open(METADATA_SAVE_PATH, 'r', encoding='utf-8') as f: + j = json.loads(f.read()) + self.metadata = IndexMetadata(**j) + def get_relevant_documents(self, question: str) -> List[Document]: if self.index: docs = self.index.similarity_search(question) diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 6a316a050..c7968a491 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -1,8 +1,12 @@ import asyncio + from jupyter_ai.actors.chat_provider import ChatProviderActor from jupyter_ai.actors.config import ConfigActor from jupyter_ai.actors.embeddings_provider import EmbeddingsProviderActor from jupyter_ai.actors.providers import ProvidersActor + +from jupyter_ai_magics.utils import load_providers + from langchain.memory import ConversationBufferWindowMemory from jupyter_ai.actors.default import DefaultActor from jupyter_ai.actors.ask import AskActor @@ -13,6 +17,7 @@ from jupyter_ai.actors.base import ACTOR_TYPE from jupyter_ai.reply_processor import ReplyProcessor from jupyter_server.extension.application import ExtensionApp + from .handlers import ( ChatHandler, ChatHistoryHandler, @@ -22,6 +27,7 @@ TaskAPIHandler, GlobalConfigHandler ) + from importlib_metadata import entry_points import inspect from .engine import BaseModelEngine diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index 843203f28..298aa2643 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -342,4 +342,5 @@ def post(self): self.log.exception(e) raise HTTPError( 500, "Unexpected error occurred while updating the config." 
- ) from e \ No newline at end of file + ) from e + diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 98145e4d5..f9a282207 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -1,5 +1,5 @@ -from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider -from jupyter_ai_magics.providers import AuthStrategy, BaseProvider +from jupyter_ai_magics.providers import AuthStrategy + from pydantic import BaseModel from typing import Dict, List, Union, Literal, Optional @@ -101,4 +101,13 @@ class ListProvidersResponse(BaseModel): class GlobalConfig(BaseModel): model_provider_id: str embeddings_provider_id: str - api_keys: Dict[str, str] \ No newline at end of file + api_keys: Dict[str, str] + + +class IndexedDir(BaseModel): + path: str + + +class IndexMetadata(BaseModel): + dirs: List[IndexedDir] + From 520b45ef4a1cb6e3b73c24036703387417709216 Mon Sep 17 00:00:00 2001 From: david qiu Date: Thu, 4 May 2023 08:33:59 -0700 Subject: [PATCH 15/16] Chat settings UI (#141) * remove unused div * automatically create config if not present * allow all-caps envvars in config * implement basic chat settings UI * hide API key text inputs * limit popup size, show success banner * show welcome message if no LM is selected * fix buggy UI with no selected LM/EM * exclude legacy OpenAI chat provider used in magics * Added a button with welcome message --------- Co-authored-by: Jain --- .../jupyter_ai/actors/chat_provider.py | 4 +- .../jupyter-ai/jupyter_ai/actors/config.py | 10 + .../jupyter_ai/actors/embeddings_provider.py | 4 +- packages/jupyter-ai/jupyter_ai/handlers.py | 4 + packages/jupyter-ai/jupyter_ai/models.py | 12 +- .../src/components/chat-settings.tsx | 266 ++++++++++++++++++ packages/jupyter-ai/src/components/chat.tsx | 125 ++++++-- packages/jupyter-ai/src/components/select.tsx | 48 ++++ packages/jupyter-ai/src/handler.ts | 63 ++++- 9 files changed, 499 insertions(+), 37 deletions(-) create mode 100644 packages/jupyter-ai/src/components/chat-settings.tsx create mode 100644 packages/jupyter-ai/src/components/select.tsx diff --git a/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py b/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py index 12e37181c..5885e8e9d 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py +++ b/packages/jupyter-ai/jupyter_ai/actors/chat_provider.py @@ -25,10 +25,10 @@ def update(self, config: GlobalConfig): auth_strategy = provider.auth_strategy if auth_strategy and auth_strategy.type == "env": api_keys = config.api_keys - name = auth_strategy.name.lower() + name = auth_strategy.name if name not in api_keys: raise ValueError(f"Missing value for '{auth_strategy.name}' in the config.") - provider_params[name] = api_keys[name] + provider_params[name.lower()] = api_keys[name] self.provider = provider self.provider_params = provider_params diff --git a/packages/jupyter-ai/jupyter_ai/actors/config.py b/packages/jupyter-ai/jupyter_ai/actors/config.py index 65160c5dc..dc9a719ce 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/config.py +++ b/packages/jupyter-ai/jupyter_ai/actors/config.py @@ -27,10 +27,16 @@ def update(self, config: GlobalConfig, save_to_disk: bool = True): self.config = config def _update_chat_provider(self, config: GlobalConfig): + if not config.model_provider_id: + return + actor = ray.get_actor(ACTOR_TYPE.CHAT_PROVIDER) ray.get(actor.update.remote(config)) def _update_embeddings_provider(self, config: GlobalConfig): + if not 
config.embeddings_provider_id: + return + actor = ray.get_actor(ACTOR_TYPE.EMBEDDINGS_PROVIDER) ray.get(actor.update.remote(config)) @@ -46,6 +52,10 @@ def _load(self): with open(self.save_path, 'r', encoding='utf-8') as f: config = GlobalConfig(**json.loads(f.read())) self.update(config, False) + return + + # otherwise, create a new empty config file + self.update(GlobalConfig(), True) def get_config(self): return self.config \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py b/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py index af4881448..068ce0388 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py +++ b/packages/jupyter-ai/jupyter_ai/actors/embeddings_provider.py @@ -27,10 +27,10 @@ def update(self, config: GlobalConfig): auth_strategy = provider.auth_strategy if auth_strategy and auth_strategy.type == "env": api_keys = config.api_keys - name = auth_strategy.name.lower() + name = auth_strategy.name if name not in api_keys: raise ValueError(f"Missing value for '{auth_strategy.name}' in the config.") - provider_params[name] = api_keys[name] + provider_params[name.lower()] = api_keys[name] self.provider = provider.provider_klass self.provider_params = provider_params diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index 298aa2643..24db73b2c 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -270,6 +270,10 @@ def chat_providers(self): def get(self): providers = [] for provider in self.chat_providers.values(): + # skip old legacy OpenAI chat provider used only in magics + if provider.id == "openai-chat": + continue + providers.append( ListProvidersEntry( id=provider.id, diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index f9a282207..a9f28768a 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -97,17 +97,13 @@ class ListProvidersEntry(BaseModel): class ListProvidersResponse(BaseModel): providers: List[ListProvidersEntry] - -class GlobalConfig(BaseModel): - model_provider_id: str - embeddings_provider_id: str - api_keys: Dict[str, str] - - class IndexedDir(BaseModel): path: str - class IndexMetadata(BaseModel): dirs: List[IndexedDir] +class GlobalConfig(BaseModel): + model_provider_id: Optional[str] = None + embeddings_provider_id: Optional[str] = None + api_keys: Dict[str, str] = {} diff --git a/packages/jupyter-ai/src/components/chat-settings.tsx b/packages/jupyter-ai/src/components/chat-settings.tsx new file mode 100644 index 000000000..d36a3bf63 --- /dev/null +++ b/packages/jupyter-ai/src/components/chat-settings.tsx @@ -0,0 +1,266 @@ +import React, { useEffect, useState } from 'react'; +import { Box } from '@mui/system'; +import { + Alert, + Button, + MenuItem, + TextField, + CircularProgress +} from '@mui/material'; + +import { Select } from './select'; +import { AiService } from '../handler'; + +enum ChatSettingsState { + // chat settings is making initial fetches + Loading, + // chat settings is ready (happy path) + Ready, + // chat settings failed to make initial fetches + FetchError, + // chat settings failed to submit the save request + SubmitError, + // chat settings successfully submitted the save request + Success +} + +export function ChatSettings() { + const [state, setState] = useState( + ChatSettingsState.Loading + ); + // error message from initial fetch + const [fetchEmsg, 
setFetchEmsg] = useState<string>(); + + // state fetched on initial render + const [config, setConfig] = useState<AiService.Config>(); + const [lmProviders, setLmProviders] = + useState<AiService.ListProvidersResponse>(); + const [emProviders, setEmProviders] = + useState<AiService.ListProvidersResponse>(); + + // user inputs + const [inputConfig, setInputConfig] = useState<AiService.Config>({ + model_provider_id: null, + embeddings_provider_id: null, + api_keys: {} + }); + + // whether the form is currently saving + const [saving, setSaving] = useState(false); + // error message from submission + const [saveEmsg, setSaveEmsg] = useState<string>(); + + /** + * Effect: call APIs on initial render + */ + useEffect(() => { + async function getConfig() { + try { + const [config, lmProviders, emProviders] = await Promise.all([ + AiService.getConfig(), + AiService.listLmProviders(), + AiService.listEmProviders() + ]); + setConfig(config); + setInputConfig(config); + setLmProviders(lmProviders); + setEmProviders(emProviders); + setState(ChatSettingsState.Ready); + } catch (e) { + console.error(e); + if (e instanceof Error) { + setFetchEmsg(e.message); + } + setState(ChatSettingsState.FetchError); + } + } + getConfig(); + }, []); + + /** + * Effect: re-initialize API keys object whenever the selected LM/EM changes. + */ + useEffect(() => { + const selectedLmpId = inputConfig.model_provider_id?.split(':')[0]; + const selectedEmpId = inputConfig.embeddings_provider_id?.split(':')[0]; + const lmp = lmProviders?.providers.find( + provider => provider.id === selectedLmpId + ); + const emp = emProviders?.providers.find( + provider => provider.id === selectedEmpId + ); + const newApiKeys: Record<string, string> = {}; + + if (lmp?.auth_strategy && lmp.auth_strategy.type === 'env') { + newApiKeys[lmp.auth_strategy.name] = + config?.api_keys[lmp.auth_strategy.name] || ''; + } + if (emp?.auth_strategy && emp.auth_strategy.type === 'env') { + newApiKeys[emp.auth_strategy.name] = + config?.api_keys[emp.auth_strategy.name] || ''; + } + + setInputConfig(inputConfig => ({ + ...inputConfig, + api_keys: { ...config?.api_keys, ...newApiKeys } + })); + }, [inputConfig.model_provider_id, inputConfig.embeddings_provider_id]); + + const handleSave = async () => { + const inputConfigCopy: AiService.Config = { + ...inputConfig, + api_keys: { ...inputConfig.api_keys } + }; + + // delete any empty api keys + for (const apiKey in inputConfigCopy.api_keys) { + if (inputConfigCopy.api_keys[apiKey] === '') { + delete inputConfigCopy.api_keys[apiKey]; + } + } + + setSaving(true); + try { + await AiService.updateConfig(inputConfigCopy); + } catch (e) { + console.error(e); + if (e instanceof Error) { + setSaveEmsg(e.message); + } + setState(ChatSettingsState.SubmitError); + setSaving(false); + // bail out so the error state is not overwritten below + return; + } + setState(ChatSettingsState.Success); + setSaving(false); + }; + + if (state === ChatSettingsState.Loading) { + return ( + + + ); + } + + if ( + state === ChatSettingsState.FetchError || + !lmProviders || + !emProviders || + !config + ) { + return ( + + + {fetchEmsg + ? `An error occurred. Error details:\n\n${fetchEmsg}` + : 'An unknown error occurred. Check the console for more details.'} + + + ); + } + + return ( + .MuiAlert-root': { marginBottom: 2 } + }} + > + {state === ChatSettingsState.SubmitError && ( + + {saveEmsg + ? `An error occurred. Error details:\n\n${saveEmsg}` + : 'An unknown error occurred. Check the console for more details.'} + + )} + {state === ChatSettingsState.Success && ( + Settings saved successfully. 
+ )} + + + {Object.entries(inputConfig.api_keys).map( + ([apiKey, apiKeyValue], idx) => ( + + setInputConfig(inputConfig => ({ + ...inputConfig, + api_keys: { + ...inputConfig.api_keys, + [apiKey]: e.target.value + } + })) + } + /> + ) + )} + + + + + ); +} diff --git a/packages/jupyter-ai/src/components/chat.tsx b/packages/jupyter-ai/src/components/chat.tsx index 0e68dad13..eaae48ebc 100644 --- a/packages/jupyter-ai/src/components/chat.tsx +++ b/packages/jupyter-ai/src/components/chat.tsx @@ -1,10 +1,14 @@ import React, { useState, useEffect } from 'react'; import { Box } from '@mui/system'; +import { Button, IconButton, Stack } from '@mui/material'; +import SettingsIcon from '@mui/icons-material/Settings'; +import ArrowBackIcon from '@mui/icons-material/ArrowBack'; import type { Awareness } from 'y-protocols/awareness'; import { JlThemeProvider } from './jl-theme-provider'; import { ChatMessages } from './chat-messages'; import { ChatInput } from './chat-input'; +import { ChatSettings } from './chat-settings'; import { AiService } from '../handler'; import { SelectionContextProvider, @@ -17,10 +21,12 @@ import { ScrollContainer } from './scroll-container'; type ChatBodyProps = { chatHandler: ChatHandler; + setChatView: (view: ChatView) => void }; -function ChatBody({ chatHandler }: ChatBodyProps): JSX.Element { +function ChatBody({ chatHandler, setChatView: chatViewHandler }: ChatBodyProps): JSX.Element { const [messages, setMessages] = useState([]); + const [showWelcomeMessage, setShowWelcomeMessage] = useState(false); const [includeSelection, setIncludeSelection] = useState(true); const [replaceSelection, setReplaceSelection] = useState(false); const [input, setInput] = useState(''); @@ -32,12 +38,17 @@ function ChatBody({ chatHandler }: ChatBodyProps): JSX.Element { useEffect(() => { async function fetchHistory() { try { - const history = await chatHandler.getHistory(); + const [history, config] = await Promise.all([ + chatHandler.getHistory(), + AiService.getConfig() + ]); setMessages(history.messages); + if (!config.model_provider_id) { + setShowWelcomeMessage(true); + } } catch (e) { - + console.error(e); } - } fetchHistory(); @@ -71,7 +82,9 @@ function ChatBody({ chatHandler }: ChatBodyProps): JSX.Element { const prompt = input + - (includeSelection && selection?.text ? '\n\n```\n' + selection.text + '```': ''); + (includeSelection && selection?.text + ? '\n\n```\n' + selection.text + '```' + : ''); // send message to backend const messageId = await chatHandler.sendMessage({ prompt }); @@ -90,23 +103,45 @@ function ChatBody({ chatHandler }: ChatBodyProps): JSX.Element { } }; + const openSettingsView = () => { + setShowWelcomeMessage(false) + chatViewHandler(ChatView.Settings) + } + + if (showWelcomeMessage) { + return ( + + +

+ Welcome to Jupyter AI! To get started, please select a language + model to chat with from the settings panel. You will also likely + need to provide API credentials, so be sure to have those handy. +

+ +
+
+ ); + } + return ( - + <> - {/* https://css-tricks.com/books/greatest-css-tricks/pin-scrolling-to-bottom/ */} - Press Shift + Enter to submit message} + helperText={ + + Press Shift + Enter to submit message + + } /> - + ); } @@ -138,14 +177,56 @@ export type ChatProps = { selectionWatcher: SelectionWatcher; chatHandler: ChatHandler; globalAwareness: Awareness | null; + chatView?: ChatView }; +enum ChatView { + Chat, + Settings +} + export function Chat(props: ChatProps) { + const [view, setView] = useState(props.chatView || ChatView.Chat); + return ( - + + {/* top bar */} + + {view !== ChatView.Chat ? ( + setView(ChatView.Chat)}> + + + ) : ( + + )} + {view === ChatView.Chat ? ( + setView(ChatView.Settings)}> + + + ) : ( + + )} + + {/* body */} + {view === ChatView.Chat && ( + + )} + {view === ChatView.Settings && } + diff --git a/packages/jupyter-ai/src/components/select.tsx b/packages/jupyter-ai/src/components/select.tsx new file mode 100644 index 000000000..2e709812d --- /dev/null +++ b/packages/jupyter-ai/src/components/select.tsx @@ -0,0 +1,48 @@ +import React from 'react'; +import { FormControl, InputLabel, Select as MuiSelect } from '@mui/material'; +import type { + SelectChangeEvent, + SelectProps as MuiSelectProps +} from '@mui/material'; + +export type SelectProps = Omit, 'value' | 'onChange'> & { + value: string | null; + onChange: ( + event: SelectChangeEvent, + child: React.ReactNode + ) => void; +}; + +/** + * A helpful wrapper around MUI's native `Select` component that provides the + * following services: + * + * - automatically wraps base `Select` component in `FormControl` context and + * prepends an input label derived from `props.label`. + * + * - limits max height of menu + * + * - handles `null` values by coercing them to the string `'null'`. The + * corresponding `MenuItem` should have the value `'null'`. 
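+ * When the `'null'` item is selected, the wrapped `onChange` callback + * receives `null` as its value rather than the string `'null'`.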
+ */ +export function Select(props: SelectProps) { + return ( + + {props.label} + { + if (e.target.value === 'null') { + e.target.value = null as any; + } + props.onChange?.(e, child); + }} + MenuProps={{ sx: { maxHeight: '50%', minHeight: 400 } }} + > + {props.children} + + + ); +} diff --git a/packages/jupyter-ai/src/handler.ts b/packages/jupyter-ai/src/handler.ts index 6ccad4545..55db58c49 100644 --- a/packages/jupyter-ai/src/handler.ts +++ b/packages/jupyter-ai/src/handler.ts @@ -99,11 +99,15 @@ export namespace AiService { }; export type ClearMessage = { - type: 'clear' - } + type: 'clear'; + }; export type ChatMessage = AgentChatMessage | HumanChatMessage; - export type Message = AgentChatMessage | HumanChatMessage | ConnectionMessage | ClearMessage; + export type Message = + | AgentChatMessage + | HumanChatMessage + | ConnectionMessage + | ClearMessage; export type ChatHistory = { messages: ChatMessage[]; @@ -160,4 +164,57 @@ export namespace AiService { ): Promise { return requestAPI(`tasks/${id}`); } + + export type Config = { + model_provider_id: string | null; + embeddings_provider_id: string | null; + api_keys: Record<string, string>; + }; + + export type GetConfigResponse = Config; + + export type UpdateConfigRequest = Config; + + export async function getConfig(): Promise<GetConfigResponse> { + return requestAPI<GetConfigResponse>('config'); + } + + export type EnvAuthStrategy = { + type: 'env'; + name: string; + }; + + export type AwsAuthStrategy = { + type: 'aws'; + }; + + export type AuthStrategy = EnvAuthStrategy | AwsAuthStrategy | null; + + export type ListProvidersEntry = { + id: string; + name: string; + models: string[]; + auth_strategy: AuthStrategy; + }; + + export type ListProvidersResponse = { + providers: ListProvidersEntry[]; + }; + + export async function listLmProviders(): Promise<ListProvidersResponse> { + return requestAPI<ListProvidersResponse>('providers'); + } + + export async function listEmProviders(): Promise<ListProvidersResponse> { + return requestAPI<ListProvidersResponse>('providers/embeddings'); + } + + export async function updateConfig( + config: UpdateConfigRequest + ): Promise<void> { + return requestAPI<void>('config', { + method: 'POST', + body: JSON.stringify(config) + }); + } } From 94630c78ffc708f12854616a7315c1905cb13892 Mon Sep 17 00:00:00 2001 From: david qiu Date: Fri, 5 May 2023 07:43:15 -0700 Subject: [PATCH 16/16] Various chat chain enhancements and fixes (#144) * fix /clear command * use model IDs to compare LLMs instead * specify stop sequence in chat chain * add empty AI message, improve system prompt * add RTD configuration --- .readthedocs.yaml | 17 ++++++ packages/jupyter-ai/jupyter_ai/actors/base.py | 19 ++++-- .../jupyter-ai/jupyter_ai/actors/default.py | 59 ++++++++++++++----- .../jupyter-ai/jupyter_ai/actors/router.py | 5 +- packages/jupyter-ai/jupyter_ai/extension.py | 30 ++++++---- 5 files changed, 95 insertions(+), 35 deletions(-) create mode 100644 .readthedocs.yaml diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 000000000..8654753a1 --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,17 @@ +# .readthedocs.yaml +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +version: 2 + +build: + os: ubuntu-22.04 + tools: + python: "3.11" + +sphinx: + configuration: docs/source/conf.py + +python: + install: + - requirements: docs/requirements.txt diff --git a/packages/jupyter-ai/jupyter_ai/actors/base.py b/packages/jupyter-ai/jupyter_ai/actors/base.py index 365861553..84a62bf8b 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/base.py +++ b/packages/jupyter-ai/jupyter_ai/actors/base.py @@ 
-16,7 +16,12 @@ Logger = Union[logging.Logger, logging.LoggerAdapter] class ACTOR_TYPE(str, Enum): + # the top level actor that routes incoming messages to the appropriate actor + ROUTER = "router" + + # the default actor that responds to messages using a language model DEFAULT = "default" + ASK = "ask" LEARN = 'learn' MEMORY = 'memory' @@ -74,14 +79,18 @@ def reply(self, response, message: Optional[HumanChatMessage] = None): def get_llm_chain(self): actor = ray.get_actor(ACTOR_TYPE.CHAT_PROVIDER) - llm = ray.get(actor.get_provider.remote()) - llm_params = ray.get(actor.get_provider_params.remote()) + lm_provider = ray.get(actor.get_provider.remote()) + lm_provider_params = ray.get(actor.get_provider_params.remote()) + + curr_lm_id = f'{self.llm.id}:{self.llm.model_id}' if self.llm else None + next_lm_id = f'{lm_provider.id}:{lm_provider_params["model_id"]}' if lm_provider else None - if not llm: + if not lm_provider: return None - if llm.__class__.__name__ != self.llm.__class__.__name__: - self.create_llm_chain(llm, llm_params) + if curr_lm_id != next_lm_id: + self.log.info(f"Switching chat language model from {curr_lm_id} to {next_lm_id}.") + self.create_llm_chain(lm_provider, lm_provider_params) return self.llm_chain def get_embeddings(self): diff --git a/packages/jupyter-ai/jupyter_ai/actors/default.py b/packages/jupyter-ai/jupyter_ai/actors/default.py index 83edc36f1..7914daad3 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/default.py +++ b/packages/jupyter-ai/jupyter_ai/actors/default.py @@ -1,44 +1,75 @@ -from typing import Dict, Type +from typing import Dict, Type, List import ray -from ray.util.queue import Queue from langchain import ConversationChain from langchain.prompts import ( ChatPromptTemplate, MessagesPlaceholder, - SystemMessagePromptTemplate, - HumanMessagePromptTemplate + HumanMessagePromptTemplate, + SystemMessagePromptTemplate ) +from langchain.schema import ( + AIMessage, ) -from jupyter_ai.actors.base import BaseActor, Logger, ACTOR_TYPE +from jupyter_ai.actors.base import BaseActor, ACTOR_TYPE from jupyter_ai.actors.memory import RemoteMemory -from jupyter_ai.models import HumanChatMessage +from jupyter_ai.models import HumanChatMessage, ClearMessage, ChatMessage from jupyter_ai_magics.providers import BaseProvider -SYSTEM_PROMPT = "The following is a friendly conversation between a human and an AI, whose name is Jupyter AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know." +SYSTEM_PROMPT = """ +You are Jupyter AI, a conversational assistant living in JupyterLab to help users. +You are not a language model, but rather an application built on a foundation model from {provider_name} called {local_model_id}. +You are talkative and provide lots of specific details from your context. +You may use Markdown to format your response. +Code blocks must be formatted in Markdown. +Math should be rendered with inline TeX markup, surrounded by $. +If you do not know the answer to a question, answer truthfully by responding that you do not know. +The following is a friendly conversation between you and a human. 
+""".strip() @ray.remote class DefaultActor(BaseActor): - def __init__(self, reply_queue: Queue, log: Logger): - super().__init__(reply_queue=reply_queue, log=log) + def __init__(self, chat_history: List[ChatMessage], *args, **kwargs): + super().__init__(*args, **kwargs) + self.memory = None + self.chat_history = chat_history def create_llm_chain(self, provider: Type[BaseProvider], provider_params: Dict[str, str]): llm = provider(**provider_params) - memory = RemoteMemory(actor_name=ACTOR_TYPE.MEMORY) + self.memory = RemoteMemory(actor_name=ACTOR_TYPE.MEMORY) prompt_template = ChatPromptTemplate.from_messages([ - SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT), + SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT).format(provider_name=llm.name, local_model_id=llm.model_id), MessagesPlaceholder(variable_name="history"), - HumanMessagePromptTemplate.from_template("{input}") + HumanMessagePromptTemplate.from_template("{input}"), + AIMessage(content="") ]) self.llm = llm self.llm_chain = ConversationChain( llm=llm, prompt=prompt_template, verbose=True, - memory=memory + memory=self.memory ) + + def clear_memory(self): + if not self.memory: + return + + # clear chain memory + self.memory.clear() + + # clear transcript for existing chat clients + reply_message = ClearMessage() + self.reply_queue.put(reply_message) + + # clear transcript for new chat clients + self.chat_history.clear() def _process_message(self, message: HumanChatMessage): self.get_llm_chain() - response = self.llm_chain.predict(input=message.body) + response = self.llm_chain.predict( + input=message.body, + stop=["\nHuman:"] + ) self.reply(response, message) diff --git a/packages/jupyter-ai/jupyter_ai/actors/router.py b/packages/jupyter-ai/jupyter_ai/actors/router.py index 7b417a0cd..fbc3234da 100644 --- a/packages/jupyter-ai/jupyter_ai/actors/router.py +++ b/packages/jupyter-ai/jupyter_ai/actors/router.py @@ -2,7 +2,6 @@ from ray.util.queue import Queue from jupyter_ai.actors.base import ACTOR_TYPE, COMMANDS, Logger, BaseActor -from jupyter_ai.models import ClearMessage @ray.remote class Router(BaseActor): @@ -25,7 +24,7 @@ def _process_message(self, message): actor = ray.get_actor(COMMANDS[command].value) actor.process_message.remote(message) if command == '/clear': - reply_message = ClearMessage() - self.reply_queue.put(reply_message) + actor = ray.get_actor(ACTOR_TYPE.DEFAULT) + actor.clear_memory.remote() else: default.process_message.remote(message) diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index c7968a491..837bb68f8 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -121,31 +121,35 @@ def initialize_settings(self): self.settings["chat_handlers"] = {} # store chat messages in memory for now + # this is only used to render the UI, and is not the conversational + # memory object used by the LM chain. 
self.settings["chat_history"] = [] reply_queue = Queue() self.settings["reply_queue"] = reply_queue - router = Router.options(name="router").remote( + router = Router.options(name=ACTOR_TYPE.ROUTER).remote( reply_queue=reply_queue, - log=self.log + log=self.log, ) + default_actor = DefaultActor.options(name=ACTOR_TYPE.DEFAULT.value).remote( + reply_queue=reply_queue, + log=self.log, + chat_history=self.settings["chat_history"] + ) + providers_actor = ProvidersActor.options(name=ACTOR_TYPE.PROVIDERS.value).remote( - log=self.log + log=self.log, ) config_actor = ConfigActor.options(name=ACTOR_TYPE.CONFIG.value).remote( - log=self.log + log=self.log, ) chat_provider_actor = ChatProviderActor.options(name=ACTOR_TYPE.CHAT_PROVIDER.value).remote( - log=self.log + log=self.log, ) embeddings_provider_actor = EmbeddingsProviderActor.options(name=ACTOR_TYPE.EMBEDDINGS_PROVIDER.value).remote( - log=self.log - ) - default_actor = DefaultActor.options(name=ACTOR_TYPE.DEFAULT.value).remote( - reply_queue=reply_queue, - log=self.log + log=self.log, ) learn_actor = LearnActor.options(name=ACTOR_TYPE.LEARN.value).remote( reply_queue=reply_queue, @@ -154,16 +158,16 @@ def initialize_settings(self): ) ask_actor = AskActor.options(name=ACTOR_TYPE.ASK.value).remote( reply_queue=reply_queue, - log=self.log + log=self.log, ) memory_actor = MemoryActor.options(name=ACTOR_TYPE.MEMORY.value).remote( log=self.log, - memory=ConversationBufferWindowMemory(return_messages=True, k=2) + memory=ConversationBufferWindowMemory(return_messages=True, k=2), ) generate_actor = GenerateActor.options(name=ACTOR_TYPE.GENERATE.value).remote( reply_queue=reply_queue, log=self.log, - root_dir=self.settings['server_root_dir'] + root_dir=self.settings['server_root_dir'], ) self.settings['router'] = router