Skip to content

Commit

Permalink
Azure OpenAI and OpenAI proxy support (jupyterlab#322)
Browse files Browse the repository at this point in the history
* implement Azure OpenAI support and model ID labels

* pre-commit
  • Loading branch information
dlqqq authored and Marchlak committed Oct 28, 2024
1 parent a318e49 commit 37691ab
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 9 deletions.
2 changes: 1 addition & 1 deletion docs/source/users/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ see the following interface:
Each of the additional fields under "Language model" is required. These fields
should contain the following data:

- **Local model ID**: The name of your endpoint. This can be retrieved from the
- **Endpoint name**: The name of your endpoint. This can be retrieved from the
AWS Console at the URL
`https://<region>.console.aws.amazon.com/sagemaker/home?region=<region>#/endpoints`.

Expand Down
1 change: 1 addition & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .providers import (
AI21Provider,
AnthropicProvider,
AzureChatOpenAIProvider,
BaseProvider,
BedrockProvider,
ChatOpenAINewProvider,
Expand Down
50 changes: 46 additions & 4 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from jsonpath_ng import parse
from langchain import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain.llms import (
AI21,
Anthropic,
Expand Down Expand Up @@ -107,6 +107,9 @@ class Config:
model_id_key: ClassVar[str] = ...
"""Kwarg expected by the upstream LangChain provider."""

model_id_label: ClassVar[str] = ""
"""Human-readable label of the model ID."""

pypi_package_deps: ClassVar[List[str]] = []
"""List of PyPi package dependencies."""

Expand Down Expand Up @@ -464,6 +467,40 @@ class ChatOpenAINewProvider(BaseProvider, ChatOpenAI):
pypi_package_deps = ["openai"]
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")

fields = [
TextField(
key="openai_api_base", label="Base API URL (optional)", format="text"
),
TextField(
key="openai_organization", label="Organization (optional)", format="text"
),
TextField(key="openai_proxy", label="Proxy (optional)", format="text"),
]


class AzureChatOpenAIProvider(BaseProvider, AzureChatOpenAI):
id = "azure-chat-openai"
name = "Azure OpenAI"
models = ["*"]
model_id_key = "deployment_name"
model_id_label = "Deployment name"
pypi_package_deps = ["openai"]
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")
registry = True

fields = [
TextField(
key="openai_api_base", label="Base API URL (required)", format="text"
),
TextField(
key="openai_api_version", label="API version (required)", format="text"
),
TextField(
key="openai_organization", label="Organization (optional)", format="text"
),
TextField(key="openai_proxy", label="Proxy (optional)", format="text"),
]


class JsonContentHandler(LLMContentHandler):
content_type = "application/json"
Expand Down Expand Up @@ -501,6 +538,7 @@ class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
name = "SageMaker endpoint"
models = ["*"]
model_id_key = "endpoint_name"
model_id_label = "Endpoint name"
# This all needs to be on one line of markdown, for use in a table
help = (
"Specify an endpoint name as the model ID. "
Expand All @@ -513,9 +551,13 @@ class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
auth_strategy = AwsAuthStrategy()
registry = True
fields = [
TextField(key="region_name", label="Region name", format="text"),
MultilineTextField(key="request_schema", label="Request schema", format="json"),
TextField(key="response_path", label="Response path", format="jsonpath"),
TextField(key="region_name", label="Region name (required)", format="text"),
MultilineTextField(
key="request_schema", label="Request schema (required)", format="json"
),
TextField(
key="response_path", label="Response path (required)", format="jsonpath"
),
]

def __init__(self, *args, **kwargs):
Expand Down
1 change: 1 addition & 0 deletions packages/jupyter-ai-magics/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ huggingface_hub = "jupyter_ai_magics:HfHubProvider"
openai = "jupyter_ai_magics:OpenAIProvider"
openai-chat = "jupyter_ai_magics:ChatOpenAIProvider"
openai-chat-new = "jupyter_ai_magics:ChatOpenAINewProvider"
azure-chat-openai = "jupyter_ai_magics:AzureChatOpenAIProvider"
sagemaker-endpoint = "jupyter_ai_magics:SmEndpointProvider"
amazon-bedrock = "jupyter_ai_magics:BedrockProvider"

Expand Down
15 changes: 12 additions & 3 deletions packages/jupyter-ai/jupyter_ai/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import uuid
from asyncio import AbstractEventLoop
from dataclasses import asdict
from typing import Dict, List
from typing import TYPE_CHECKING, Dict, List

import tornado
from jupyter_ai.chat_handlers import BaseChatHandler
Expand All @@ -29,6 +29,10 @@
Message,
)

if TYPE_CHECKING:
from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider
from jupyter_ai_magics.providers import BaseProvider


class ChatHistoryHandler(BaseAPIHandler):
"""Handler to return message history"""
Expand Down Expand Up @@ -237,7 +241,7 @@ def on_close(self):

class ModelProviderHandler(BaseAPIHandler):
@property
def lm_providers(self):
def lm_providers(self) -> Dict[str, "BaseProvider"]:
return self.settings["lm_providers"]

@web.authenticated
Expand All @@ -248,6 +252,10 @@ def get(self):
if provider.id == "openai-chat":
continue

optionals = {}
if provider.model_id_label:
optionals["model_id_label"] = provider.model_id_label

providers.append(
ListProvidersEntry(
id=provider.id,
Expand All @@ -256,6 +264,7 @@ def get(self):
auth_strategy=provider.auth_strategy,
registry=provider.registry,
fields=provider.fields,
**optionals,
)
)

Expand All @@ -267,7 +276,7 @@ def get(self):

class EmbeddingsModelProviderHandler(BaseAPIHandler):
@property
def em_providers(self):
def em_providers(self) -> Dict[str, "BaseEmbeddingsProvider"]:
return self.settings["em_providers"]

@web.authenticated
Expand Down
1 change: 1 addition & 0 deletions packages/jupyter-ai/jupyter_ai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class ListProvidersEntry(BaseModel):

id: str
name: str
model_id_label: Optional[str]
models: List[str]
auth_strategy: AuthStrategy
registry: bool
Expand Down
2 changes: 1 addition & 1 deletion packages/jupyter-ai/src/components/chat-settings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ export function ChatSettings(): JSX.Element {
</Select>
{showLmLocalId && (
<TextField
label="Local model ID"
label={lmProvider?.model_id_label || 'Local model ID'}
value={lmLocalId}
onChange={e => setLmLocalId(e.target.value)}
fullWidth
Expand Down
1 change: 1 addition & 0 deletions packages/jupyter-ai/src/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ export namespace AiService {
export type ListProvidersEntry = {
id: string;
name: string;
model_id_label?: string;
models: string[];
auth_strategy: AuthStrategy;
registry: boolean;
Expand Down

0 comments on commit 37691ab

Please sign in to comment.