From a383f9b42e17cae6c6ea6f2fef0bc668dd891cd4 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Mon, 3 Apr 2023 23:59:21 +0000 Subject: [PATCH] use new provider interface in magics --- .gitignore | 3 + packages/jupyter-ai/jupyter_ai/__init__.py | 10 + packages/jupyter-ai/jupyter_ai/magics.py | 181 ++++++------------ packages/jupyter-ai/jupyter_ai/providers.py | 170 ++++++++++++++++ .../jupyter_ai/tests/test_handlers.py | 21 +- .../jupyter_ai/tests/test_providers.py | 11 ++ packages/jupyter-ai/pyproject.toml | 10 + 7 files changed, 277 insertions(+), 129 deletions(-) create mode 100644 packages/jupyter-ai/jupyter_ai/providers.py create mode 100644 packages/jupyter-ai/jupyter_ai/tests/test_providers.py diff --git a/.gitignore b/.gitignore index 94a97a79e..6fc57c1d2 100644 --- a/.gitignore +++ b/.gitignore @@ -123,3 +123,6 @@ playground/ # jupyter releaser checkout .jupyter_releaser_checkout + +# reserve path for a dev script +dev.sh diff --git a/packages/jupyter-ai/jupyter_ai/__init__.py b/packages/jupyter-ai/jupyter_ai/__init__.py index ff19b2aa7..33075f9f8 100644 --- a/packages/jupyter-ai/jupyter_ai/__init__.py +++ b/packages/jupyter-ai/jupyter_ai/__init__.py @@ -5,8 +5,18 @@ # imports to expose entry points. DO NOT REMOVE. from .engine import GPT3ModelEngine from .tasks import tasks +from .providers import ( + AI21Provider, + AnthropicProvider, + CohereProvider, + HfHubProvider, + OpenAIProvider, + ChatOpenAIProvider, + SmEndpointProvider +) # imports to expose types to other AI modules. DO NOT REMOVE. +from .providers import BaseProvider from .tasks import DefaultTaskDefinition def _jupyter_labextension_paths(): diff --git a/packages/jupyter-ai/jupyter_ai/magics.py b/packages/jupyter-ai/jupyter_ai/magics.py index 04faa1745..a39afad31 100644 --- a/packages/jupyter-ai/jupyter_ai/magics.py +++ b/packages/jupyter-ai/jupyter_ai/magics.py @@ -1,81 +1,20 @@ import os import json -from typing import Dict, Optional, Any +from typing import Optional +from importlib_metadata import entry_points +from jupyter_ai.providers import BaseProvider from IPython import get_ipython from IPython.core.magic import Magics, magics_class, line_cell_magic from IPython.core.magic_arguments import magic_arguments, argument, parse_argstring from IPython.display import display, HTML, Markdown, Math, JSON -from pydantic import ValidationError -from langchain.llms import type_to_cls_dict -from langchain.llms.openai import OpenAIChat - -LANGCHAIN_LLM_DICT = type_to_cls_dict.copy() -LANGCHAIN_LLM_DICT["openai-chat"] = OpenAIChat - -PROVIDER_SCHEMAS: Dict[str, Any] = { - # AI21 model provider currently only provides its prompt via the `prompt` - # key, which limits the models that can actually be used. - # Reference: https://docs.ai21.com/reference/j2-complete-api - "ai21": { - "models": ["j1-large", "j1-grande", "j1-jumbo", "j1-grande-instruct", "j2-large", "j2-grande", "j2-jumbo", "j2-grande-instruct", "j2-jumbo-instruct"], - "model_id_key": "model", - "auth_envvar": "AI21_API_KEY" - }, - # Anthropic model provider supports any model available via - # `anthropic.Client#completion()`. - # Reference: https://console.anthropic.com/docs/api/reference - "anthropic": { - "models": ["claude-v1", "claude-v1.0", "claude-v1.2", "claude-instant-v1", "claude-instant-v1.0"], - "model_id_key": "model", - "auth_envvar": "ANTHROPIC_API_KEY" - }, - # Cohere model provider supports any model available via - # `cohere.Client#generate()`.` - # Reference: https://docs.cohere.ai/reference/generate - "cohere": { - "models": ["medium", "xlarge"], - "model_id_key": "model", - "auth_envvar": "COHERE_API_KEY" - }, - "huggingface_hub": { - "models": ["*"], - "model_id_key": "repo_id", - "auth_envvar": "HUGGINGFACEHUB_API_TOKEN" - }, - "huggingface_endpoint": { - "models": ["*"], - "model_id_key": "endpoint_url", - "auth_envvar": "HUGGINGFACEHUB_API_TOKEN" - }, - # OpenAI model provider supports any model available via - # `openai.Completion`. - # Reference: https://platform.openai.com/docs/models/model-endpoint-compatibility - "openai": { - "models": ['text-davinci-003', 'text-davinci-002', 'text-curie-001', 'text-babbage-001', 'text-ada-001', 'davinci', 'curie', 'babbage', 'ada'], - "model_id_key": "model_name", - "auth_envvar": "OPENAI_API_KEY" - }, - # OpenAI chat model provider supports any model available via - # `openai.ChatCompletion`. - # Reference: https://platform.openai.com/docs/models/model-endpoint-compatibility - "openai-chat": { - "models": ['gpt-4', 'gpt-4-0314', 'gpt-4-32k', 'gpt-4-32k-0314', 'gpt-3.5-turbo', 'gpt-3.5-turbo-0301'], - "model_id_key": "model_name", - "auth_envvar": "OPENAI_API_KEY" - }, - "sagemaker_endpoint": { - "models": ["*"], - "model_id_key": "endpoint_name", - }, -} MODEL_ID_ALIASES = { - "gpt2": "huggingface_hub:gpt2", - "gpt3": "openai:text-davinci-003", - "chatgpt": "openai-chat:gpt-3.5-turbo", - "gpt4": "openai-chat:gpt-4", + "gpt2": "huggingface_hub::gpt2", + "gpt3": "openai::text-davinci-003", + "chatgpt": "openai-chat::gpt-3.5-turbo", + "gpt4": "openai-chat::gpt-4", } DISPLAYS_BY_FORMAT = { @@ -86,40 +25,6 @@ "json": JSON, } -def decompose_model_id(model_id: str): - """Breaks down a model ID into a two-tuple (provider_id, local_model_id). Returns (None, None) if indeterminate.""" - if model_id in MODEL_ID_ALIASES: - model_id = MODEL_ID_ALIASES[model_id] - - if ":" not in model_id: - # case: model ID was not provided with a prefix indicating the provider - # ID. try to infer the provider ID before returning (None, None). - - # naively search through the dictionary and return the first provider - # that provides a model of the same ID. - for provider_id, provider_schema in PROVIDER_SCHEMAS.items(): - if model_id in provider_schema["models"]: - return (provider_id, model_id) - - return (None, None) - - possible_provider_id, local_model_id = model_id.split(":", 1) - - if possible_provider_id not in ("http", "https"): - # case (happy path): provider ID was specified - return (possible_provider_id, local_model_id) - - # else case: model ID is a URL to some endpoint. right now, the only - # provider that accepts a raw URL is huggingface_endpoint. - return ("huggingface_endpoint", local_model_id) - -def get_provider(provider_id: Optional[str]): - """Returns the model provider ID and class for a model ID. Returns None if indeterminate.""" - if provider_id is None or provider_id not in LANGCHAIN_LLM_DICT: - return None - - return LANGCHAIN_LLM_DICT[provider_id] - class FormatDict(dict): """Subclass of dict to be passed to str#format(). Suppresses KeyError and leaves replacement field unchanged if replacement field is not associated @@ -135,8 +40,19 @@ class AiMagics(Magics): def __init__(self, shell): super(AiMagics, self).__init__(shell) self.transcript_openai = [] + + # load model providers from entry point + self.providers = {} + eps = entry_points() + model_provider_eps = eps.select(group="jupyter_ai.model_providers") + for model_provider_ep in model_provider_eps: + try: + Provider = model_provider_ep.load() + except: + continue + self.providers[Provider.id] = Provider - def append_exchange_openai(self, prompt: str, output: str): + def _append_exchange_openai(self, prompt: str, output: str): """Appends a conversational exchange between user and an OpenAI Chat model to a transcript that will be included in future exchanges.""" self.transcript_openai.append({ @@ -148,6 +64,33 @@ def append_exchange_openai(self, prompt: str, output: str): "content": output }) + def _decompose_model_id(self, model_id: str): + """Breaks down a model ID into a two-tuple (provider_id, local_model_id). Returns (None, None) if indeterminate.""" + if model_id in MODEL_ID_ALIASES: + model_id = MODEL_ID_ALIASES[model_id] + + if "::" not in model_id: + # case: model ID was not provided with a prefix indicating the provider + # ID. try to infer the provider ID before returning (None, None). + + # naively search through the dictionary and return the first provider + # that provides a model of the same ID. + for provider_id, Provider in self.providers.items(): + if model_id in Provider.models: + return (provider_id, model_id) + + return (None, None) + + provider_id, local_model_id = model_id.split("::", 1) + return (provider_id, local_model_id) + + def _get_provider(self, provider_id: Optional[str]) -> BaseProvider: + """Returns the model provider ID and class for a model ID. Returns None if indeterminate.""" + if provider_id is None or provider_id not in self.providers: + return None + + return self.providers[provider_id] + @magic_arguments() @argument('model_id', help="""Model to run, specified as a model ID that may be @@ -178,9 +121,10 @@ def ai(self, line, cell=None): prompt = cell # determine provider and local model IDs - provider_id, local_model_id = decompose_model_id(args.model_id) - if provider_id is None: - return display("Cannot determine model provider.") + provider_id, local_model_id = self._decompose_model_id(args.model_id) + Provider = self._get_provider(provider_id) + if Provider is None: + return display(f"Cannot determine model provider from model ID {args.model_id}.") # if `--reset` is specified, reset transcript and return early if (provider_id == "openai-chat" and args.reset): @@ -188,26 +132,25 @@ def ai(self, line, cell=None): return # validate presence of authn credentials - auth_envvar = PROVIDER_SCHEMAS[provider_id].get("auth_envvar") - if auth_envvar and auth_envvar not in os.environ: - raise EnvironmentError( - f"Authentication environment variable {auth_envvar} not provided.\n" - f"An authentication token is required to use models from the {provider_id} provider.\n" - f"Please specify it via `%env {auth_envvar}=token`. " - ) from None + auth_strategy = self.providers[provider_id].auth_strategy + if auth_strategy: + # TODO: handle auth strategies besides EnvAuthStrategy + if auth_strategy.type == "env" and auth_strategy.name not in os.environ: + raise EnvironmentError( + f"Authentication environment variable {auth_strategy.name} not provided.\n" + f"An authentication token is required to use models from the {Provider.name} provider.\n" + f"Please specify it via `%env {auth_strategy.name}=token`. " + ) from None # interpolate user namespace into prompt ip = get_ipython() prompt = prompt.format_map(FormatDict(ip.user_ns)) # configure and instantiate provider - ProviderClass = get_provider(provider_id) - model_id_key = PROVIDER_SCHEMAS[provider_id]['model_id_key'] - provider_params = {} - provider_params[model_id_key] = local_model_id + provider_params = { "model_id": local_model_id } if provider_id == "openai-chat": provider_params["prefix_messages"] = self.transcript_openai - provider = ProviderClass(**provider_params) + provider = Provider(**provider_params) # generate output from model via provider result = provider.generate([prompt]) @@ -215,7 +158,7 @@ def ai(self, line, cell=None): # if openai-chat, append exchange to transcript if provider_id == "openai-chat": - self.append_exchange_openai(prompt, output) + self._append_exchange_openai(prompt, output) # build output display DisplayClass = DISPLAYS_BY_FORMAT[args.format] diff --git a/packages/jupyter-ai/jupyter_ai/providers.py b/packages/jupyter-ai/jupyter_ai/providers.py new file mode 100644 index 000000000..a1f94829a --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/providers.py @@ -0,0 +1,170 @@ +from typing import ClassVar, List, Union, Literal, Optional + +from langchain.schema import BaseLanguageModel as BaseLangchainProvider +from langchain.llms import ( + AI21, + Anthropic, + Cohere, + HuggingFaceHub, + OpenAI, + OpenAIChat, + SagemakerEndpoint +) +from pydantic import BaseModel + +class EnvAuthStrategy(BaseModel): + """Require one auth token via an environment variable.""" + type: Literal["env"] = "env" + name: str + + +class MultiEnvAuthStrategy(BaseModel): + """Require multiple auth tokens via multiple environment variables.""" + type: Literal["file"] = "file" + names: List[str] + + +class AwsAuthStrategy(BaseModel): + """Require AWS authentication via Boto3""" + type: Literal["aws"] = "aws" + + +AuthStrategy = Optional[ + Union[ + EnvAuthStrategy, + MultiEnvAuthStrategy, + AwsAuthStrategy, + ] +] + +class BaseProvider(BaseLangchainProvider): + # + # class attrs + # + id: ClassVar[str] = ... + """ID for this provider class.""" + + name: ClassVar[str] = ... + """User-facing name of this provider.""" + + models: ClassVar[List[str]] = ... + """List of supported models by their IDs. For registry providers, this will + be just ["*"].""" + + pypi_package_deps: ClassVar[List[str]] = [] + """List of PyPi package dependencies.""" + + auth_strategy: ClassVar[AuthStrategy] = None + """Authentication/authorization strategy. Declares what credentials are + required to use this model provider. Generally should not be `None`.""" + + # + # instance attrs + # + model_id: str + + # define readonly aliases to self.model_id for LangChain model providers. + @property + def model(self): + return self.model_id + + @property + def model_name(self): + return self.model_id + + @property + def repo_id(self): + return self.model_id + + @property + def endpoint_url(self): + return self.model_id + + @property + def endpoint_name(self): + return self.model_id + +class AI21Provider(BaseProvider, AI21): + id = "ai21" + name = "AI21" + models = [ + "j1-large", + "j1-grande", + "j1-jumbo", + "j1-grande-instruct", + "j2-large", + "j2-grande", + "j2-jumbo", + "j2-grande-instruct", + "j2-jumbo-instruct", + ] + pypi_package_deps = ["ai21"] + auth_strategy = EnvAuthStrategy(name="AI21_API_KEY") + +class AnthropicProvider(BaseProvider, Anthropic): + id = "anthropic" + name = "Anthropic" + models = [ + "claude-v1", + "claude-v1.0", + "claude-v1.2", + "claude-instant-v1", + "claude-instant-v1.0", + ] + pypi_package_deps = ["anthropic"] + auth_strategy = EnvAuthStrategy(name="ANTHROPIC_API_KEY") + +class CohereProvider(BaseProvider, Cohere): + id = "cohere" + name = "Cohere" + models = ["medium", "xlarge"] + pypi_package_deps = ["cohere"] + auth_strategy = EnvAuthStrategy(name="COHERE_API_KEY") + +class HfHubProvider(BaseProvider, HuggingFaceHub): + id = "huggingface_hub" + name = "HuggingFace Hub" + models = ["*"] + # ipywidgets needed to suppress tqdm warning + # https://stackoverflow.com/questions/67998191 + # tqdm is a dependency of huggingface_hub + pypi_package_deps = ["huggingface_hub", "ipywidgets"] + auth_strategy = EnvAuthStrategy(name="HUGGINGFACEHUB_API_TOKEN") + +class OpenAIProvider(BaseProvider, OpenAI): + id = "openai" + name = "OpenAI" + models = [ + "text-davinci-003", + "text-davinci-002", + "text-curie-001", + "text-babbage-001", + "text-ada-001", + "davinci", + "curie", + "babbage", + "ada", + ] + pypi_package_deps = ["openai"] + auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY") + +class ChatOpenAIProvider(BaseProvider, OpenAIChat): + id = "openai-chat" + name = "OpenAI" + models = [ + "gpt-4", + "gpt-4-0314", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-3.5-turbo", + "gpt-3.5-turbo-0301", + ] + pypi_package_deps = ["openai"] + auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY") + +class SmEndpointProvider(BaseProvider, SagemakerEndpoint): + id = "sagemaker-endpoint" + name = "Sagemaker Endpoint" + models = ["*"] + pypi_package_deps = ["boto3"] + auth_strategy = AwsAuthStrategy() diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py b/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py index eac035dbd..620ec8990 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_handlers.py @@ -1,13 +1,14 @@ -import json +# TODO +# import json -async def test_get_example(jp_fetch): - # When - response = await jp_fetch("jupyter-ai", "get_example") +# async def test_get_example(jp_fetch): +# # When +# response = await jp_fetch("jupyter-ai", "get_example") - # Then - assert response.code == 200 - payload = json.loads(response.body) - assert payload == { - "data": "This is /jupyter-ai/get_example endpoint!" - } \ No newline at end of file +# # Then +# assert response.code == 200 +# payload = json.loads(response.body) +# assert payload == { +# "data": "This is /jupyter-ai/get_example endpoint!" +# } \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_providers.py b/packages/jupyter-ai/jupyter_ai/tests/test_providers.py new file mode 100644 index 000000000..2012365da --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/tests/test_providers.py @@ -0,0 +1,11 @@ +from pydantic import ValidationError + +from jupyter_ai.providers import AI21Provider + +def test_model_id_required(): + try: + AI21Provider(ai21_api_key="asdf") + assert False + except ValidationError as e: + assert "1 validation error" in str(e) + assert "model_id" in str(e) diff --git a/packages/jupyter-ai/pyproject.toml b/packages/jupyter-ai/pyproject.toml index 8121df036..358f65f54 100644 --- a/packages/jupyter-ai/pyproject.toml +++ b/packages/jupyter-ai/pyproject.toml @@ -39,8 +39,18 @@ gpt3 = "jupyter_ai:GPT3ModelEngine" [project.entry-points."jupyter_ai.default_tasks"] core_default_tasks = "jupyter_ai:tasks" +[project.entry-points."jupyter_ai.model_providers"] +ai21 = "jupyter_ai:AI21Provider" +anthropic = "jupyter_ai:AnthropicProvider" +cohere = "jupyter_ai:CohereProvider" +huggingface_hub = "jupyter_ai:HfHubProvider" +openai = "jupyter_ai:OpenAIProvider" +openai-chat = "jupyter_ai:ChatOpenAIProvider" +sagemaker-endpoint = "jupyter_ai:SmEndpointProvider" + [project.optional-dependencies] test = [ + "jupyter-server[test]>=1.6,<3", "coverage", "pytest", "pytest-asyncio",