Skip to content

Commit

Permalink
Runtime model configurability (jupyterlab#146)
Browse files Browse the repository at this point in the history
* Refactored provider load, decompose logic, aded model provider list api

* Renamed model

* Sorted the provider names

* WIP: Embedding providers

* Added embeddings provider api

* Added missing import

* Moved providers to ray actor, added config actor

* Ability to load llm and embeddings from config

* Moved llm creation to specific actors

* Added apis for fetching, updating config. Fixed config update, error handling

* Updated as per PR feedback

* Fixes issue with cohere embeddings, api keys not working

* Added an error check when embedding change causes read error

* Delete and re-index docs when embedding model changes (jupyterlab#137)

* Added an error check when embedding change causes read error

* Refactored provider load, decompose logic, aded model provider list api

* Re-indexes dirs when embeddings change, learn list command

* Fixed typo, simplified adding metadata

* Moved index dir, metadata path to constants

* Chat settings UI (jupyterlab#141)

* remove unused div

* automatically create config if not present

* allow all-caps envvars in config

* implement basic chat settings UI

* hide API key text inputs

* limit popup size, show success banner

* show welcome message if no LM is selected

* fix buggy UI with no selected LM/EM

* exclude legacy OpenAI chat provider used in magics

* Added a button with welcome message

---------

Co-authored-by: Jain <pijain@3c22fb64c9fa.amazon.com>

* Various chat chain enhancements and fixes (jupyterlab#144)

* fix /clear command

* use model IDs to compare LLMs instead

* specify stop sequence in chat chain

* add empty AI message, improve system prompt

* add RTD configuration

---------

Co-authored-by: Piyush Jain <piyushjain@duck.com>
Co-authored-by: Jain <pijain@3c22fb64c9fa.amazon.com>
  • Loading branch information
3 people authored May 5, 2023
1 parent f21445e commit 0fb8306
Show file tree
Hide file tree
Showing 27 changed files with 1,324 additions and 184 deletions.
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,5 @@
*.gif filter=lfs diff=lfs merge=lfs -text
*.ipynb filter=lfs diff=lfs merge=lfs -text
/packages/jupyter-ai-magics/LICENSE filter=lfs diff=lfs merge=lfs -text
*.yaml filter=lfs diff=lfs merge=lfs -text
* !text !filter !merge !diff
17 changes: 17 additions & 0 deletions .readthedocs.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# .readthedocs.yaml
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details

version: 2

build:
os: ubuntu-22.04
tools:
python: "3.11"

sphinx:
configuration: docs/source/conf.py

python:
install:
- requirements: docs/requirements.txt
7 changes: 7 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,15 @@
HfHubProvider,
OpenAIProvider,
ChatOpenAIProvider,
ChatOpenAINewProvider,
SmEndpointProvider
)
# expose embedding model providers on the package root
from .embedding_providers import (
OpenAIEmbeddingsProvider,
CohereEmbeddingsProvider,
HfHubEmbeddingsProvider
)
from .providers import BaseProvider

def load_ipython_extension(ipython):
Expand Down
6 changes: 6 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/aliases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
MODEL_ID_ALIASES = {
"gpt2": "huggingface_hub:gpt2",
"gpt3": "openai:text-davinci-003",
"chatgpt": "openai-chat:gpt-3.5-turbo",
"gpt4": "openai-chat:gpt-4",
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from typing import ClassVar, List, Type
from jupyter_ai_magics.providers import AuthStrategy, EnvAuthStrategy
from pydantic import BaseModel, Extra
from langchain.embeddings import OpenAIEmbeddings, CohereEmbeddings, HuggingFaceHubEmbeddings
from langchain.embeddings.base import Embeddings


class BaseEmbeddingsProvider(BaseModel):
"""Base class for embedding providers"""

class Config:
extra = Extra.allow

id: ClassVar[str] = ...
"""ID for this provider class."""

name: ClassVar[str] = ...
"""User-facing name of this provider."""

models: ClassVar[List[str]] = ...
"""List of supported models by their IDs. For registry providers, this will
be just ["*"]."""

model_id_key: ClassVar[str] = ...
"""Kwarg expected by the upstream LangChain provider."""

pypi_package_deps: ClassVar[List[str]] = []
"""List of PyPi package dependencies."""

auth_strategy: ClassVar[AuthStrategy] = None
"""Authentication/authorization strategy. Declares what credentials are
required to use this model provider. Generally should not be `None`."""

model_id: str

provider_klass: ClassVar[Type[Embeddings]]


class OpenAIEmbeddingsProvider(BaseEmbeddingsProvider):
id = "openai"
name = "OpenAI"
models = [
"text-embedding-ada-002"
]
model_id_key = "model"
pypi_package_deps = ["openai"]
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")
provider_klass = OpenAIEmbeddings


class CohereEmbeddingsProvider(BaseEmbeddingsProvider):
id = "cohere"
name = "Cohere"
models = [
'large',
'multilingual-22-12',
'small'
]
model_id_key = "model"
pypi_package_deps = ["cohere"]
auth_strategy = EnvAuthStrategy(name="COHERE_API_KEY")
provider_klass = CohereEmbeddings


class HfHubEmbeddingsProvider(BaseEmbeddingsProvider):
id = "huggingface_hub"
name = "HuggingFace Hub"
models = ["*"]
model_id_key = "repo_id"
# ipywidgets needed to suppress tqdm warning
# https://stackoverflow.com/questions/67998191
# tqdm is a dependency of huggingface_hub
pypi_package_deps = ["huggingface_hub", "ipywidgets"]
auth_strategy = EnvAuthStrategy(name="HUGGINGFACEHUB_API_TOKEN")
provider_klass = HuggingFaceHubEmbeddings
40 changes: 5 additions & 35 deletions packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,14 @@
import json
import os
import re
import traceback
import warnings
from typing import Optional

from importlib_metadata import entry_points
from IPython import get_ipython
from IPython.core.magic import Magics, magics_class, line_cell_magic
from IPython.core.magic_arguments import magic_arguments, argument, parse_argstring
from IPython.display import HTML, Image, JSON, Markdown, Math

from IPython.display import HTML, JSON, Markdown, Math
from jupyter_ai_magics.utils import decompose_model_id, load_providers
from .providers import BaseProvider


Expand All @@ -36,8 +34,8 @@ def _repr_mimebundle_(self, include=None, exclude=None):
}
)

class TextWithMetadata:

class TextWithMetadata(object):
def __init__(self, text, metadata):
self.text = text
self.metadata = metadata
Expand Down Expand Up @@ -109,18 +107,7 @@ def __init__(self, shell):
"no longer supported. Instead, please use: "
"`from langchain.chat_models import ChatOpenAI`")

# load model providers from entry point
self.providers = {}
eps = entry_points()
model_provider_eps = eps.select(group="jupyter_ai.model_providers")
for model_provider_ep in model_provider_eps:
try:
Provider = model_provider_ep.load()
except:
print(f"Unable to load entry point {model_provider_ep.name}");
traceback.print_exc()
continue
self.providers[Provider.id] = Provider
self.providers = load_providers()

def _ai_help_command_markdown(self):
table = ("| Command | Description |\n"
Expand Down Expand Up @@ -272,24 +259,7 @@ def _append_exchange_openai(self, prompt: str, output: str):
})

def _decompose_model_id(self, model_id: str):
"""Breaks down a model ID into a two-tuple (provider_id, local_model_id). Returns (None, None) if indeterminate."""
if model_id in MODEL_ID_ALIASES:
model_id = MODEL_ID_ALIASES[model_id]

if ":" not in model_id:
# case: model ID was not provided with a prefix indicating the provider
# ID. try to infer the provider ID before returning (None, None).

# naively search through the dictionary and return the first provider
# that provides a model of the same ID.
for provider_id, Provider in self.providers.items():
if model_id in Provider.models:
return (provider_id, model_id)

return (None, None)

provider_id, local_model_id = model_id.split(":", 1)
return (provider_id, local_model_id)
return decompose_model_id(model_id, self.providers)

def _get_provider(self, provider_id: Optional[str]) -> BaseProvider:
"""Returns the model provider ID and class for a model ID. Returns None if indeterminate."""
Expand Down
3 changes: 3 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
SagemakerEndpoint
)
from langchain.utils import get_from_dict_or_env
from langchain.llms.utils import enforce_stop_tokens

from pydantic import BaseModel, Extra, root_validator
from langchain.chat_models import ChatOpenAI
Expand Down Expand Up @@ -298,3 +299,5 @@ class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
model_id_key = "endpoint_name"
pypi_package_deps = ["boto3"]
auth_strategy = AwsAuthStrategy()


70 changes: 70 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import logging
from typing import Dict, Optional, Tuple, Union
from importlib_metadata import entry_points
from jupyter_ai_magics.aliases import MODEL_ID_ALIASES

from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider

from jupyter_ai_magics.providers import BaseProvider


Logger = Union[logging.Logger, logging.LoggerAdapter]


def load_providers(log: Optional[Logger] = None) -> Dict[str, BaseProvider]:
if not log:
log = logging.getLogger()
log.addHandler(logging.NullHandler())

providers = {}
eps = entry_points()
model_provider_eps = eps.select(group="jupyter_ai.model_providers")
for model_provider_ep in model_provider_eps:
try:
provider = model_provider_ep.load()
except:
log.error(f"Unable to load model provider class from entry point `{model_provider_ep.name}`.")
continue
providers[provider.id] = provider
log.info(f"Registered model provider `{provider.id}`.")

return providers


def load_embedding_providers(log: Optional[Logger] = None) -> Dict[str, BaseEmbeddingsProvider]:
if not log:
log = logging.getLogger()
log.addHandler(logging.NullHandler())
providers = {}
eps = entry_points()
model_provider_eps = eps.select(group="jupyter_ai.embeddings_model_providers")
for model_provider_ep in model_provider_eps:
try:
provider = model_provider_ep.load()
except:
log.error(f"Unable to load embeddings model provider class from entry point `{model_provider_ep.name}`.")
continue
providers[provider.id] = provider
log.info(f"Registered embeddings model provider `{provider.id}`.")

return providers

def decompose_model_id(model_id: str, providers: Dict[str, BaseProvider]) -> Tuple[str, str]:
"""Breaks down a model ID into a two-tuple (provider_id, local_model_id). Returns (None, None) if indeterminate."""
if model_id in MODEL_ID_ALIASES:
model_id = MODEL_ID_ALIASES[model_id]

if ":" not in model_id:
# case: model ID was not provided with a prefix indicating the provider
# ID. try to infer the provider ID before returning (None, None).

# naively search through the dictionary and return the first provider
# that provides a model of the same ID.
for provider_id, provider in providers.items():
if model_id in provider.models:
return (provider_id, model_id)

return (None, None)

provider_id, local_model_id = model_id.split(":", 1)
return (provider_id, local_model_id)
6 changes: 6 additions & 0 deletions packages/jupyter-ai-magics/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,14 @@ cohere = "jupyter_ai_magics:CohereProvider"
huggingface_hub = "jupyter_ai_magics:HfHubProvider"
openai = "jupyter_ai_magics:OpenAIProvider"
openai-chat = "jupyter_ai_magics:ChatOpenAIProvider"
openai-chat-new = "jupyter_ai_magics:ChatOpenAINewProvider"
sagemaker-endpoint = "jupyter_ai_magics:SmEndpointProvider"

[project.entry-points."jupyter_ai.embeddings_model_providers"]
cohere = "jupyter_ai_magics:CohereEmbeddingsProvider"
huggingface_hub = "jupyter_ai_magics:HfHubEmbeddingsProvider"
openai = "jupyter_ai_magics:OpenAIEmbeddingsProvider"

[tool.hatch.version]
source = "nodejs"

Expand Down
3 changes: 0 additions & 3 deletions packages/jupyter-ai/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,6 @@ dmypy.json
# OSX files
.DS_Store

# local config storing authn credentials
config.py

# vscode
.vscode

Expand Down
62 changes: 41 additions & 21 deletions packages/jupyter-ai/jupyter_ai/actors/ask.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import argparse
from typing import Dict, List, Type
from jupyter_ai_magics.providers import BaseProvider

import ray
from ray.util.queue import Queue

from langchain import OpenAI
from langchain.chains import ConversationalRetrievalChain
from langchain.schema import BaseRetriever, Document

from jupyter_ai.models import HumanChatMessage
from jupyter_ai.actors.base import ACTOR_TYPE, BaseActor, Logger
Expand All @@ -21,21 +23,18 @@ class AskActor(BaseActor):

def __init__(self, reply_queue: Queue, log: Logger):
super().__init__(reply_queue=reply_queue, log=log)
index_actor = ray.get_actor(ACTOR_TYPE.LEARN.value)
handle = index_actor.get_index.remote()
vectorstore = ray.get(handle)
if not vectorstore:
return

self.chat_history = []
self.chat_provider = ConversationalRetrievalChain.from_llm(
OpenAI(temperature=0, verbose=True),
vectorstore.as_retriever()
)

self.parser.prog = '/ask'
self.parser.add_argument('query', nargs=argparse.REMAINDER)

def create_llm_chain(self, provider: Type[BaseProvider], provider_params: Dict[str, str]):
retriever = Retriever()
self.llm = provider(**provider_params)
self.chat_history = []
self.llm_chain = ConversationalRetrievalChain.from_llm(
self.llm,
retriever
)

def _process_message(self, message: HumanChatMessage):
args = self.parse_args(message)
Expand All @@ -46,13 +45,34 @@ def _process_message(self, message: HumanChatMessage):
self.reply(f"{self.parser.format_usage()}", message)
return

self.get_llm_chain()

try:
result = self.llm_chain({"question": query, "chat_history": self.chat_history})
response = result['answer']
self.chat_history.append((query, response))
self.reply(response, message)
except AssertionError as e:
self.log.error(e)
response = """Sorry, an error occurred while reading the from the learned documents.
If you have changed the embedding provider, try deleting the existing index by running
`/learn -d` command and then re-submitting the `learn <directory>` to learn the documents,
and then asking the question again.
"""
self.reply(response, message)


class Retriever(BaseRetriever):
"""Wrapper retriever class to get relevant docs
from the vector store, this is important because
of inconsistent de-serialization of index when it's
accessed directly from the ask actor.
"""

def get_relevant_documents(self, question: str):
index_actor = ray.get_actor(ACTOR_TYPE.LEARN.value)
handle = index_actor.get_index.remote()
vectorstore = ray.get(handle)
# Have to reference the latest index
self.chat_provider.retriever = vectorstore.as_retriever()

result = self.chat_provider({"question": query, "chat_history": self.chat_history})
response = result['answer']
self.chat_history.append((query, response))
self.reply(response, message)
docs = ray.get(index_actor.get_relevant_documents.remote(question))
return docs

async def aget_relevant_documents(self, query: str) -> List[Document]:
return await super().aget_relevant_documents(query)
Loading

0 comments on commit 0fb8306

Please sign in to comment.