Skip to content

Commit

Permalink
Various magic enhancements and fixes (#32)
Browse files Browse the repository at this point in the history
* fix huggingface_hub provider

* add [all] optional dep group

* improve provider table in docs
  • Loading branch information
dlqqq authored Apr 6, 2023
1 parent 0031a9a commit cf5abf8
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 27 deletions.
17 changes: 9 additions & 8 deletions docs/source/users/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@ in environment variables, and you will have to install relevant Python packages
You can find the environment variables you need to set, and the Python packages you need, in
[`packages/jupyter-ai/jupyter_ai/providers.py`](https://github.com/jupyterlab/jupyter-ai/blob/main/packages/jupyter-ai/jupyter_ai/providers.py).

| Provider | Environment variable | Python package(s) |
| ------------| -------------------------- | -------------- |
| AI21 | `AI21_API_KEY` | `ai21` |
| Anthropic | `ANTHROPIC_API_KEY` | `anthropic` |
| Cohere | `COHERE_API_KEY` | `cohere` |
| HuggingFace | `HUGGINGFACEHUB_API_TOKEN` | `huggingface_hub`, `ipywidgets` |
| OpenAI | `OPENAI_API_KEY` | `openai` |
| SageMaker | N/A | `boto3` |
| Provider | Provider ID | Environment variable | Python package(s) |
|---------------------|----------------------|----------------------------|---------------------------------|
| AI21 | `ai21` | `AI21_API_KEY` | `ai21` |
| Anthropic | `anthropic` | `ANTHROPIC_API_KEY` | `anthropic` |
| Cohere | `cohere` | `COHERE_API_KEY` | `cohere` |
| HuggingFace Hub | `huggingface_hub` | `HUGGINGFACEHUB_API_TOKEN` | `huggingface_hub`, `ipywidgets` |
| OpenAI | `openai` | `OPENAI_API_KEY` | `openai` |
| OpenAI (chat) | `openai-chat` | `OPENAI_API_KEY` | `openai` |
| SageMaker Endpoints | `sagemaker-endpoint` | N/A | `boto3` |

To use SageMaker's models, you will need to authenticate via
[boto3](https://github.com/boto/boto3).
Expand Down
45 changes: 26 additions & 19 deletions packages/jupyter-ai/jupyter_ai/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
OpenAIChat,
SagemakerEndpoint
)
from pydantic import BaseModel
from pydantic import BaseModel, Extra

class EnvAuthStrategy(BaseModel):
"""Require one auth token via an environment variable."""
Expand Down Expand Up @@ -38,6 +38,12 @@ class AwsAuthStrategy(BaseModel):
]

class BaseProvider(BaseLangchainProvider):
#
# pydantic config
#
class Config:
extra = Extra.allow

#
# class attrs
#
Expand All @@ -51,6 +57,9 @@ class BaseProvider(BaseLangchainProvider):
"""List of supported models by their IDs. For registry providers, this will
be just ["*"]."""

model_id_key: ClassVar[str] = ...
"""Kwarg expected by the upstream LangChain provider."""

pypi_package_deps: ClassVar[List[str]] = []
"""List of PyPi package dependencies."""

Expand All @@ -63,26 +72,17 @@ class BaseProvider(BaseLangchainProvider):
#
model_id: str

# define readonly aliases to self.model_id for LangChain model providers.
@property
def model(self):
return self.model_id
def __init__(self, *args, **kwargs):
try:
assert kwargs["model_id"]
except:
raise AssertionError("model_id was not specified. Please specify it as a keyword argument.")

@property
def model_name(self):
return self.model_id

@property
def repo_id(self):
return self.model_id
model_kwargs = {}
model_kwargs[self.__class__.model_id_key] = kwargs["model_id"]

super().__init__(*args, **kwargs, **model_kwargs)

@property
def endpoint_url(self):
return self.model_id

@property
def endpoint_name(self):
return self.model_id

class AI21Provider(BaseProvider, AI21):
id = "ai21"
Expand All @@ -98,6 +98,7 @@ class AI21Provider(BaseProvider, AI21):
"j2-grande-instruct",
"j2-jumbo-instruct",
]
model_id_key = "model"
pypi_package_deps = ["ai21"]
auth_strategy = EnvAuthStrategy(name="AI21_API_KEY")

Expand All @@ -111,20 +112,23 @@ class AnthropicProvider(BaseProvider, Anthropic):
"claude-instant-v1",
"claude-instant-v1.0",
]
model_id_key = "model"
pypi_package_deps = ["anthropic"]
auth_strategy = EnvAuthStrategy(name="ANTHROPIC_API_KEY")

class CohereProvider(BaseProvider, Cohere):
id = "cohere"
name = "Cohere"
models = ["medium", "xlarge"]
model_id_key = "model"
pypi_package_deps = ["cohere"]
auth_strategy = EnvAuthStrategy(name="COHERE_API_KEY")

class HfHubProvider(BaseProvider, HuggingFaceHub):
id = "huggingface_hub"
name = "HuggingFace Hub"
models = ["*"]
model_id_key = "repo_id"
# ipywidgets needed to suppress tqdm warning
# https://stackoverflow.com/questions/67998191
# tqdm is a dependency of huggingface_hub
Expand All @@ -145,6 +149,7 @@ class OpenAIProvider(BaseProvider, OpenAI):
"babbage",
"ada",
]
model_id_key = "model_name"
pypi_package_deps = ["openai"]
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")

Expand All @@ -159,12 +164,14 @@ class ChatOpenAIProvider(BaseProvider, OpenAIChat):
"gpt-3.5-turbo",
"gpt-3.5-turbo-0301",
]
model_id_key = "model_name"
pypi_package_deps = ["openai"]
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")

class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
id = "sagemaker-endpoint"
name = "Sagemaker Endpoint"
models = ["*"]
model_id_key = "endpoint_name"
pypi_package_deps = ["boto3"]
auth_strategy = AwsAuthStrategy()
10 changes: 10 additions & 0 deletions packages/jupyter-ai/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,16 @@ test = [
"pytest-tornasync"
]

all = [
"ai21",
"anthropic",
"cohere",
"huggingface_hub",
"ipywidgets",
"openai",
"boto3"
]

[tool.hatch.version]
source = "nodejs"

Expand Down

0 comments on commit cf5abf8

Please sign in to comment.