
APIs for fetching and setting model and embedding providers #101

Merged · 13 commits · May 1, 2023
7 changes: 7 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py
@@ -10,8 +10,15 @@
HfHubProvider,
OpenAIProvider,
ChatOpenAIProvider,
ChatOpenAINewProvider,
SmEndpointProvider
)
# expose embedding model providers on the package root
from .embedding_providers import (
OpenAIEmbeddingsProvider,
CohereEmbeddingsProvider,
HfHubEmbeddingsProvider
)
from .providers import BaseProvider

def load_ipython_extension(ipython):
6 changes: 6 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/aliases.py
@@ -0,0 +1,6 @@
MODEL_ID_ALIASES = {
"gpt2": "huggingface_hub:gpt2",
"gpt3": "openai:text-davinci-003",
"chatgpt": "openai-chat:gpt-3.5-turbo",
"gpt4": "openai-chat:gpt-4",
}
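
These aliases are consumed by decompose_model_id (added in utils.py below), which first substitutes an alias and then splits on the first colon to recover the provider ID. A minimal sketch of that resolution, assuming the package is importable:

    # Sketch only, not part of the diff: resolving an alias by hand.
    from jupyter_ai_magics.aliases import MODEL_ID_ALIASES

    model_id = "chatgpt"
    resolved = MODEL_ID_ALIASES.get(model_id, model_id)  # "openai-chat:gpt-3.5-turbo"
    provider_id, local_model_id = resolved.split(":", 1)
    print(provider_id, local_model_id)  # openai-chat gpt-3.5-turbo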
75 changes: 75 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/embedding_providers.py
@@ -0,0 +1,75 @@
from typing import ClassVar, List, Type
from jupyter_ai_magics.providers import AuthStrategy, EnvAuthStrategy
from pydantic import BaseModel, Extra
from langchain.embeddings import OpenAIEmbeddings, CohereEmbeddings, HuggingFaceHubEmbeddings
from langchain.embeddings.base import Embeddings


class BaseEmbeddingsProvider(BaseModel):
"""Base class for embedding providers"""

class Config:
extra = Extra.allow

id: ClassVar[str] = ...
"""ID for this provider class."""

name: ClassVar[str] = ...
"""User-facing name of this provider."""

models: ClassVar[List[str]] = ...
"""List of supported models by their IDs. For registry providers, this will
be just ["*"]."""

model_id_key: ClassVar[str] = ...
"""Kwarg expected by the upstream LangChain provider."""

pypi_package_deps: ClassVar[List[str]] = []
"""List of PyPi package dependencies."""

auth_strategy: ClassVar[AuthStrategy] = None
"""Authentication/authorization strategy. Declares what credentials are
required to use this model provider. Generally should not be `None`."""

model_id: str

provider_klass: ClassVar[Type[Embeddings]]


class OpenAIEmbeddingsProvider(BaseEmbeddingsProvider):
id = "openai"
name = "OpenAI"
models = [
"text-embedding-ada-002"
]
model_id_key = "model"
pypi_package_deps = ["openai"]
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")
provider_klass = OpenAIEmbeddings


class CohereEmbeddingsProvider(BaseEmbeddingsProvider):
id = "cohere"
name = "Cohere"
models = [
'large',
'multilingual-22-12',
'small'
]
model_id_key = "model"
pypi_package_deps = ["cohere"]
auth_strategy = EnvAuthStrategy(name="COHERE_API_KEY")
provider_klass = CohereEmbeddings


class HfHubEmbeddingsProvider(BaseEmbeddingsProvider):
id = "huggingface_hub"
name = "HuggingFace Hub"
models = ["*"]
model_id_key = "repo_id"
# ipywidgets needed to suppress tqdm warning
# https://stackoverflow.com/questions/67998191
# tqdm is a dependency of huggingface_hub
pypi_package_deps = ["huggingface_hub", "ipywidgets"]
auth_strategy = EnvAuthStrategy(name="HUGGINGFACEHUB_API_TOKEN")
provider_klass = HuggingFaceHubEmbeddings
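
A usage sketch (not part of the diff) of how a provider maps onto its LangChain class via provider_klass and model_id_key; it assumes OPENAI_API_KEY is set, as declared by the provider's auth_strategy:

    from jupyter_ai_magics import OpenAIEmbeddingsProvider

    provider = OpenAIEmbeddingsProvider(model_id="text-embedding-ada-002")
    # Pass the model ID under the kwarg name the upstream LangChain class expects.
    embeddings = provider.provider_klass(**{provider.model_id_key: provider.model_id})
    vector = embeddings.embed_query("hello world")  # list of floats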
33 changes: 4 additions & 29 deletions packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
@@ -4,11 +4,11 @@
import warnings
from typing import Optional

from importlib_metadata import entry_points
from IPython import get_ipython
from IPython.core.magic import Magics, magics_class, line_cell_magic
from IPython.core.magic_arguments import magic_arguments, argument, parse_argstring
from IPython.display import HTML, Markdown, Math, JSON
from jupyter_ai_magics.utils import decompose_model_id, load_providers

from .providers import BaseProvider

@@ -34,6 +34,7 @@ def _repr_mimebundle_(self, include=None, exclude=None):
}
)


class TextWithMetadata(object):

def __init__(self, text, metadata):
@@ -93,16 +94,7 @@ def __init__(self, shell):
"no longer supported. Instead, please use: "
"`from langchain.chat_models import ChatOpenAI`")

# load model providers from entry point
self.providers = {}
eps = entry_points()
model_provider_eps = eps.select(group="jupyter_ai.model_providers")
for model_provider_ep in model_provider_eps:
try:
Provider = model_provider_ep.load()
except:
continue
self.providers[Provider.id] = Provider
self.providers = load_providers()

def _ai_help_command_markdown(self):
table = ("| Command | Description |\n"
@@ -254,24 +246,7 @@ def _append_exchange_openai(self, prompt: str, output: str):
})

def _decompose_model_id(self, model_id: str):
"""Breaks down a model ID into a two-tuple (provider_id, local_model_id). Returns (None, None) if indeterminate."""
if model_id in MODEL_ID_ALIASES:
model_id = MODEL_ID_ALIASES[model_id]

if ":" not in model_id:
# case: model ID was not provided with a prefix indicating the provider
# ID. try to infer the provider ID before returning (None, None).

# naively search through the dictionary and return the first provider
# that provides a model of the same ID.
for provider_id, Provider in self.providers.items():
if model_id in Provider.models:
return (provider_id, model_id)

return (None, None)

provider_id, local_model_id = model_id.split(":", 1)
return (provider_id, local_model_id)
return decompose_model_id(model_id, self.providers)

def _get_provider(self, provider_id: Optional[str]) -> BaseProvider:
"""Returns the model provider ID and class for a model ID. Returns None if indeterminate."""
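The magic's behavior is unchanged by this refactor: it still accepts aliases, bare model IDs, and provider-prefixed IDs, now resolved through the shared utility. A usage sketch:

    %%ai chatgpt
    Explain what a model provider is in one sentence.
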
2 changes: 2 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
@@ -206,3 +206,5 @@ class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
model_id_key = "endpoint_name"
pypi_package_deps = ["boto3"]
auth_strategy = AwsAuthStrategy()


70 changes: 70 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/utils.py
@@ -0,0 +1,70 @@
import logging
from typing import Dict, Optional, Tuple, Union
from importlib_metadata import entry_points
from jupyter_ai_magics.aliases import MODEL_ID_ALIASES
from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider

from jupyter_ai_magics.providers import BaseProvider


Logger = Union[logging.Logger, logging.LoggerAdapter]


def load_providers(log: Optional[Logger] = None) -> Dict[str, BaseProvider]:
if not log:
log = logging.getLogger()
log.addHandler(logging.NullHandler())
providers = {}
eps = entry_points()
model_provider_eps = eps.select(group="jupyter_ai.model_providers")
for model_provider_ep in model_provider_eps:
try:
provider = model_provider_ep.load()
except:
log.error(f"Unable to load model provider class from entry point `{model_provider_ep.name}`.")
continue
providers[provider.id] = provider
log.info(f"Registered model provider `{provider.id}`.")

return providers


def load_embedding_providers(log: Optional[Logger] = None) -> Dict[str, BaseEmbeddingsProvider]:
if not log:
log = logging.getLogger()
log.addHandler(logging.NullHandler())
providers = {}
eps = entry_points()
model_provider_eps = eps.select(group="jupyter_ai.embeddings_model_providers")
for model_provider_ep in model_provider_eps:
try:
provider = model_provider_ep.load()
except:
log.error(f"Unable to load embeddings model provider class from entry point `{model_provider_ep.name}`.")
continue
providers[provider.id] = provider
log.info(f"Registered embeddings model provider `{provider.id}`.")

return providers



def decompose_model_id(model_id: str, providers: Dict[str, BaseProvider]) -> Tuple[str, str]:
"""Breaks down a model ID into a two-tuple (provider_id, local_model_id). Returns (None, None) if indeterminate."""
if model_id in MODEL_ID_ALIASES:
model_id = MODEL_ID_ALIASES[model_id]

if ":" not in model_id:
# case: model ID was not provided with a prefix indicating the provider
# ID. try to infer the provider ID before returning (None, None).

# naively search through the dictionary and return the first provider
# that provides a model of the same ID.
for provider_id, provider in providers.items():
if model_id in provider.models:
return (provider_id, model_id)

return (None, None)

provider_id, local_model_id = model_id.split(":", 1)
return (provider_id, local_model_id)
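
A short sketch of the three resolution paths, assuming the built-in providers above are registered:

    # Sketch only: expected results given the providers and aliases in this PR.
    from jupyter_ai_magics.utils import decompose_model_id, load_providers

    providers = load_providers()
    decompose_model_id("gpt4", providers)                  # alias    -> ("openai-chat", "gpt-4")
    decompose_model_id("text-davinci-003", providers)      # inferred -> ("openai", "text-davinci-003")
    decompose_model_id("huggingface_hub:gpt2", providers)  # explicit -> ("huggingface_hub", "gpt2")
    decompose_model_id("no-such-model", providers)         # -> (None, None)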
6 changes: 6 additions & 0 deletions packages/jupyter-ai-magics/pyproject.toml
@@ -53,8 +53,14 @@ cohere = "jupyter_ai_magics:CohereProvider"
huggingface_hub = "jupyter_ai_magics:HfHubProvider"
openai = "jupyter_ai_magics:OpenAIProvider"
openai-chat = "jupyter_ai_magics:ChatOpenAIProvider"
openai-chat-new = "jupyter_ai_magics:ChatOpenAINewProvider"
sagemaker-endpoint = "jupyter_ai_magics:SmEndpointProvider"

[project.entry-points."jupyter_ai.embeddings_model_providers"]
cohere = "jupyter_ai_magics:CohereEmbeddingsProvider"
huggingface_hub = "jupyter_ai_magics:HfHubEmbeddingsProvider"
openai = "jupyter_ai_magics:OpenAIEmbeddingsProvider"

[tool.hatch.version]
source = "nodejs"

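The same entry point group lets third-party packages register embeddings providers without modifying jupyter-ai. A hypothetical registration (package and class names invented for illustration):

    # pyproject.toml of a hypothetical third-party package
    [project.entry-points."jupyter_ai.embeddings_model_providers"]
    my-embeddings = "my_package:MyEmbeddingsProvider"
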
3 changes: 0 additions & 3 deletions packages/jupyter-ai/.gitignore
@@ -119,9 +119,6 @@ dmypy.json
# OSX files
.DS_Store

# local config storing authn credentials
config.py

# vscode
.vscode

62 changes: 41 additions & 21 deletions packages/jupyter-ai/jupyter_ai/actors/ask.py
@@ -1,10 +1,12 @@
import argparse
from typing import Dict, List, Type
from jupyter_ai_magics.providers import BaseProvider

import ray
from ray.util.queue import Queue

from langchain import OpenAI
from langchain.chains import ConversationalRetrievalChain
from langchain.schema import BaseRetriever, Document

from jupyter_ai.models import HumanChatMessage
from jupyter_ai.actors.base import ACTOR_TYPE, BaseActor, Logger
@@ -21,21 +23,18 @@ class AskActor(BaseActor):

def __init__(self, reply_queue: Queue, log: Logger):
super().__init__(reply_queue=reply_queue, log=log)
index_actor = ray.get_actor(ACTOR_TYPE.LEARN.value)
handle = index_actor.get_index.remote()
vectorstore = ray.get(handle)
if not vectorstore:
return

self.chat_history = []
self.chat_provider = ConversationalRetrievalChain.from_llm(
OpenAI(temperature=0, verbose=True),
vectorstore.as_retriever()
)

self.parser.prog = '/ask'
self.parser.add_argument('query', nargs=argparse.REMAINDER)

def create_llm_chain(self, provider: Type[BaseProvider], provider_params: Dict[str, str]):
retriever = Retriever()
self.llm = provider(**provider_params)
self.chat_history = []
self.llm_chain = ConversationalRetrievalChain.from_llm(
self.llm,
retriever
)

def _process_message(self, message: HumanChatMessage):
args = self.parse_args(message)
@@ -46,13 +45,34 @@ def _process_message(self, message: HumanChatMessage):
self.reply(f"{self.parser.format_usage()}", message)
return

self.get_llm_chain()

try:
result = self.llm_chain({"question": query, "chat_history": self.chat_history})
response = result['answer']
self.chat_history.append((query, response))
self.reply(response, message)
except AssertionError as e:
self.log.error(e)
response = """Sorry, an error occurred while reading the from the learned documents.
If you have changed the embedding provider, try deleting the existing index by running
`/learn -d` command and then re-submitting the `learn <directory>` to learn the documents,
and then asking the question again.
"""
self.reply(response, message)


class Retriever(BaseRetriever):
"""Wrapper retriever class to get relevant docs
from the vector store, this is important because
of inconsistent de-serialization of index when it's
accessed directly from the ask actor.
"""

def get_relevant_documents(self, question: str):
index_actor = ray.get_actor(ACTOR_TYPE.LEARN.value)
handle = index_actor.get_index.remote()
vectorstore = ray.get(handle)
# Have to reference the latest index
self.chat_provider.retriever = vectorstore.as_retriever()

result = self.chat_provider({"question": query, "chat_history": self.chat_history})
response = result['answer']
self.chat_history.append((query, response))
self.reply(response, message)
docs = ray.get(index_actor.get_relevant_documents.remote(question))
return docs

async def aget_relevant_documents(self, query: str) -> List[Document]:
return await super().aget_relevant_documents(query)
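
Taken together, create_llm_chain and Retriever keep the chain decoupled from the index: each retrieval round-trips through the LEARN actor rather than holding a possibly stale deserialized copy. A sketch of the pieces at query time, using names from the diff:

    # Sketch only, not part of the diff.
    retriever = Retriever()  # proxies document lookups to the LEARN actor
    chain = ConversationalRetrievalChain.from_llm(llm, retriever)
    result = chain({"question": query, "chat_history": chat_history})
    answer = result["answer"]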