Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use new provider interface in magics #23

Merged
merged 1 commit into from
Apr 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,6 @@ playground/

# jupyter releaser checkout
.jupyter_releaser_checkout

# reserve path for a dev script
dev.sh
10 changes: 10 additions & 0 deletions packages/jupyter-ai/jupyter_ai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,18 @@
# imports to expose entry points. DO NOT REMOVE.
from .engine import GPT3ModelEngine
from .tasks import tasks
from .providers import (
AI21Provider,
AnthropicProvider,
CohereProvider,
HfHubProvider,
OpenAIProvider,
ChatOpenAIProvider,
SmEndpointProvider
)

# imports to expose types to other AI modules. DO NOT REMOVE.
from .providers import BaseProvider
from .tasks import DefaultTaskDefinition

def _jupyter_labextension_paths():
Expand Down
181 changes: 62 additions & 119 deletions packages/jupyter-ai/jupyter_ai/magics.py
Original file line number Diff line number Diff line change
@@ -1,81 +1,20 @@
import os
import json
from typing import Dict, Optional, Any
from typing import Optional
from importlib_metadata import entry_points

from jupyter_ai.providers import BaseProvider
from IPython import get_ipython
from IPython.core.magic import Magics, magics_class, line_cell_magic
from IPython.core.magic_arguments import magic_arguments, argument, parse_argstring
from IPython.display import display, HTML, Markdown, Math, JSON

from pydantic import ValidationError
from langchain.llms import type_to_cls_dict
from langchain.llms.openai import OpenAIChat

LANGCHAIN_LLM_DICT = type_to_cls_dict.copy()
LANGCHAIN_LLM_DICT["openai-chat"] = OpenAIChat

# Static registry of known model providers, keyed by provider ID. Each schema
# lists the model names the provider serves ("*" = any model ID accepted), the
# constructor kwarg that receives the model ID, and optionally the environment
# variable holding the provider's auth token.
# NOTE(review): this constant sits on the removed side of a rendered diff; the
# PR replaces it with entry-point-registered provider classes.
PROVIDER_SCHEMAS: Dict[str, Any] = {
    # AI21 model provider currently only provides its prompt via the `prompt`
    # key, which limits the models that can actually be used.
    # Reference: https://docs.ai21.com/reference/j2-complete-api
    "ai21": {
        "models": ["j1-large", "j1-grande", "j1-jumbo", "j1-grande-instruct", "j2-large", "j2-grande", "j2-jumbo", "j2-grande-instruct", "j2-jumbo-instruct"],
        "model_id_key": "model",
        "auth_envvar": "AI21_API_KEY"
    },
    # Anthropic model provider supports any model available via
    # `anthropic.Client#completion()`.
    # Reference: https://console.anthropic.com/docs/api/reference
    "anthropic": {
        "models": ["claude-v1", "claude-v1.0", "claude-v1.2", "claude-instant-v1", "claude-instant-v1.0"],
        "model_id_key": "model",
        "auth_envvar": "ANTHROPIC_API_KEY"
    },
    # Cohere model provider supports any model available via
    # `cohere.Client#generate()`.
    # Reference: https://docs.cohere.ai/reference/generate
    "cohere": {
        "models": ["medium", "xlarge"],
        "model_id_key": "model",
        "auth_envvar": "COHERE_API_KEY"
    },
    # Hugging Face Hub models are addressed by repo ID; the endpoint variant
    # below takes a raw inference-endpoint URL instead.
    "huggingface_hub": {
        "models": ["*"],
        "model_id_key": "repo_id",
        "auth_envvar": "HUGGINGFACEHUB_API_TOKEN"
    },
    "huggingface_endpoint": {
        "models": ["*"],
        "model_id_key": "endpoint_url",
        "auth_envvar": "HUGGINGFACEHUB_API_TOKEN"
    },
    # OpenAI model provider supports any model available via
    # `openai.Completion`.
    # Reference: https://platform.openai.com/docs/models/model-endpoint-compatibility
    "openai": {
        "models": ['text-davinci-003', 'text-davinci-002', 'text-curie-001', 'text-babbage-001', 'text-ada-001', 'davinci', 'curie', 'babbage', 'ada'],
        "model_id_key": "model_name",
        "auth_envvar": "OPENAI_API_KEY"
    },
    # OpenAI chat model provider supports any model available via
    # `openai.ChatCompletion`.
    # Reference: https://platform.openai.com/docs/models/model-endpoint-compatibility
    "openai-chat": {
        "models": ['gpt-4', 'gpt-4-0314', 'gpt-4-32k', 'gpt-4-32k-0314', 'gpt-3.5-turbo', 'gpt-3.5-turbo-0301'],
        "model_id_key": "model_name",
        "auth_envvar": "OPENAI_API_KEY"
    },
    # No "auth_envvar" here — presumably SageMaker endpoints authenticate via
    # AWS credentials rather than an API-key env var; TODO confirm.
    "sagemaker_endpoint": {
        "models": ["*"],
        "model_id_key": "endpoint_name",
    },
}

# Short aliases users may pass as a model ID, mapped to the fully-qualified
# "<provider><sep><model>" form consumed by the decompose step.
# NOTE(review): the duplicated keys below are a diff-rendering artifact — the
# old ":"-separated entries are immediately followed by their new
# "::"-separated replacements. As literal Python, later duplicates win, so the
# "::" values are the ones that take effect.
MODEL_ID_ALIASES = {
    "gpt2": "huggingface_hub:gpt2",
    "gpt3": "openai:text-davinci-003",
    "chatgpt": "openai-chat:gpt-3.5-turbo",
    "gpt4": "openai-chat:gpt-4",
    "gpt2": "huggingface_hub::gpt2",
    "gpt3": "openai::text-davinci-003",
    "chatgpt": "openai-chat::gpt-3.5-turbo",
    "gpt4": "openai-chat::gpt-4",
}

DISPLAYS_BY_FORMAT = {
Expand All @@ -86,40 +25,6 @@
"json": JSON,
}

def decompose_model_id(model_id: str):
    """Split a model ID into a two-tuple ``(provider_id, local_model_id)``.

    Returns ``(None, None)`` when the provider cannot be determined.
    """
    # Resolve well-known short aliases to their fully-qualified form first.
    model_id = MODEL_ID_ALIASES.get(model_id, model_id)

    if ":" in model_id:
        prefix, rest = model_id.split(":", 1)
        if prefix in ("http", "https"):
            # The model ID is a URL to some endpoint; right now the only
            # provider that accepts a raw URL is huggingface_endpoint.
            return ("huggingface_endpoint", rest)
        # Happy path: an explicit provider prefix was supplied.
        return (prefix, rest)

    # No provider prefix was given. Naively search the registry and return the
    # first provider that lists a model of the same ID.
    for provider_id, provider_schema in PROVIDER_SCHEMAS.items():
        if model_id in provider_schema["models"]:
            return (provider_id, model_id)
    return (None, None)

def get_provider(provider_id: Optional[str]):
    """Return the LangChain LLM class registered under `provider_id`.

    Returns None when `provider_id` is None or unknown.
    """
    if provider_id is None:
        return None
    # dict.get returns None for unknown IDs, matching the membership test.
    return LANGCHAIN_LLM_DICT.get(provider_id)

class FormatDict(dict):
"""Subclass of dict to be passed to str#format(). Suppresses KeyError and
leaves replacement field unchanged if replacement field is not associated
Expand All @@ -135,8 +40,19 @@ class AiMagics(Magics):
def __init__(self, shell):
    """Initialize the magics and build the model-provider registry.

    Providers are discovered via the ``jupyter_ai.model_providers`` entry
    point group; each entry point resolves to a provider class stored under
    its ``id`` attribute.
    """
    super(AiMagics, self).__init__(shell)
    # Transcript of prior openai-chat exchanges, replayed on later calls.
    self.transcript_openai = []

    # load model providers from entry point
    self.providers = {}
    eps = entry_points()
    model_provider_eps = eps.select(group="jupyter_ai.model_providers")
    for model_provider_ep in model_provider_eps:
        try:
            Provider = model_provider_ep.load()
        except Exception:
            # Fix: the original bare `except:` also swallowed SystemExit and
            # KeyboardInterrupt. A provider whose import fails is still
            # skipped, but interpreter-level exits now propagate.
            continue
        self.providers[Provider.id] = Provider

def append_exchange_openai(self, prompt: str, output: str):
def _append_exchange_openai(self, prompt: str, output: str):
"""Appends a conversational exchange between user and an OpenAI Chat
model to a transcript that will be included in future exchanges."""
self.transcript_openai.append({
Expand All @@ -148,6 +64,33 @@ def append_exchange_openai(self, prompt: str, output: str):
"content": output
})

def _decompose_model_id(self, model_id: str):
    """Split a model ID into a two-tuple ``(provider_id, local_model_id)``.

    Returns ``(None, None)`` when no provider can be inferred.
    """
    # Resolve short aliases ("chatgpt", "gpt4", ...) up front.
    model_id = MODEL_ID_ALIASES.get(model_id, model_id)

    if "::" in model_id:
        # Happy path: "<provider>::<model>" was given explicitly.
        head, tail = model_id.split("::", 1)
        return (head, tail)

    # No provider prefix was supplied. Naively return the first registered
    # provider that lists a model with this ID.
    for provider_id, Provider in self.providers.items():
        if model_id in Provider.models:
            return (provider_id, model_id)
    return (None, None)

def _get_provider(self, provider_id: Optional[str]) -> BaseProvider:
    """Return the provider class registered under `provider_id`.

    Returns None when `provider_id` is None or not registered.
    """
    if provider_id is None:
        return None
    # .get() yields None for unknown IDs, same as the original membership check.
    return self.providers.get(provider_id)

@magic_arguments()
@argument('model_id',
help="""Model to run, specified as a model ID that may be
Expand Down Expand Up @@ -178,44 +121,44 @@ def ai(self, line, cell=None):
prompt = cell

# determine provider and local model IDs
provider_id, local_model_id = decompose_model_id(args.model_id)
if provider_id is None:
return display("Cannot determine model provider.")
provider_id, local_model_id = self._decompose_model_id(args.model_id)
Provider = self._get_provider(provider_id)
if Provider is None:
return display(f"Cannot determine model provider from model ID {args.model_id}.")

# if `--reset` is specified, reset transcript and return early
if (provider_id == "openai-chat" and args.reset):
self.transcript_openai = []
return

# validate presence of authn credentials
auth_envvar = PROVIDER_SCHEMAS[provider_id].get("auth_envvar")
if auth_envvar and auth_envvar not in os.environ:
raise EnvironmentError(
f"Authentication environment variable {auth_envvar} not provided.\n"
f"An authentication token is required to use models from the {provider_id} provider.\n"
f"Please specify it via `%env {auth_envvar}=token`. "
) from None
auth_strategy = self.providers[provider_id].auth_strategy
if auth_strategy:
# TODO: handle auth strategies besides EnvAuthStrategy
if auth_strategy.type == "env" and auth_strategy.name not in os.environ:
raise EnvironmentError(
f"Authentication environment variable {auth_strategy.name} not provided.\n"
f"An authentication token is required to use models from the {Provider.name} provider.\n"
f"Please specify it via `%env {auth_strategy.name}=token`. "
) from None

# interpolate user namespace into prompt
ip = get_ipython()
prompt = prompt.format_map(FormatDict(ip.user_ns))

# configure and instantiate provider
ProviderClass = get_provider(provider_id)
model_id_key = PROVIDER_SCHEMAS[provider_id]['model_id_key']
provider_params = {}
provider_params[model_id_key] = local_model_id
provider_params = { "model_id": local_model_id }
if provider_id == "openai-chat":
provider_params["prefix_messages"] = self.transcript_openai
provider = ProviderClass(**provider_params)
provider = Provider(**provider_params)

# generate output from model via provider
result = provider.generate([prompt])
output = result.generations[0][0].text

# if openai-chat, append exchange to transcript
if provider_id == "openai-chat":
self.append_exchange_openai(prompt, output)
self._append_exchange_openai(prompt, output)

# build output display
DisplayClass = DISPLAYS_BY_FORMAT[args.format]
Expand Down
Loading