Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Model allowlist and blocklists #446

Merged
merged 4 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 90 additions & 20 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import shutil
import time
from typing import Optional, Union
from typing import List, Optional, Union

from deepmerge import always_merger as Merger
from jsonschema import Draft202012Validator as Validator
Expand All @@ -12,10 +12,8 @@
AnyProvider,
EmProvidersDict,
LmProvidersDict,
ProviderRestrictions,
get_em_provider,
get_lm_provider,
is_provider_allowed,
)
from jupyter_core.paths import jupyter_data_dir
from traitlets import Integer, Unicode
Expand Down Expand Up @@ -57,6 +55,10 @@ class KeyEmptyError(Exception):
pass


class BlockedModelError(Exception):
pass


def _validate_provider_authn(config: GlobalConfig, provider: AnyProvider):
# TODO: handle non-env auth strategies
if not provider.auth_strategy or provider.auth_strategy.type != "env":
Expand Down Expand Up @@ -99,27 +101,34 @@ def __init__(
log: Logger,
lm_providers: LmProvidersDict,
em_providers: EmProvidersDict,
restrictions: ProviderRestrictions,
allowed_providers: Optional[List[str]],
blocked_providers: Optional[List[str]],
allowed_models: Optional[List[str]],
blocked_models: Optional[List[str]],
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.log = log
"""List of LM providers."""

self._lm_providers = lm_providers
"""List of EM providers."""
"""List of LM providers."""
self._em_providers = em_providers
"""Provider restrictions."""
self._restrictions = restrictions
"""List of EM providers."""

self._allowed_providers = allowed_providers
self._blocked_providers = blocked_providers
self._allowed_models = allowed_models
self._blocked_models = blocked_models

self._last_read: Optional[int] = None
"""When the server last read the config file. If the file was not
modified after this time, then we can return the cached
`self._config`."""
self._last_read: Optional[int] = None

self._config: Optional[GlobalConfig] = None
"""In-memory cache of the `GlobalConfig` object parsed from the config
file."""
self._config: Optional[GlobalConfig] = None

self._init_config_schema()
self._init_validator()
Expand All @@ -140,6 +149,26 @@ def _init_config(self):
if os.path.exists(self.config_path):
with open(self.config_path, encoding="utf-8") as f:
config = GlobalConfig(**json.loads(f.read()))
lm_id = config.model_provider_id
em_id = config.embeddings_provider_id

# if the currently selected language or embedding model are
# forbidden, set them to `None` and log a warning.
if lm_id is not None and not self._validate_model(
lm_id, raise_exc=False
):
self.log.warning(
f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None."
)
config.model_provider_id = None
if em_id is not None and not self._validate_model(
em_id, raise_exc=False
):
self.log.warning(
f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None."
)
config.embeddings_provider_id = None

# re-write to the file to validate the config and apply any
# updates to the config file immediately
self._write_config(config)
Expand Down Expand Up @@ -181,33 +210,74 @@ def _validate_config(self, config: GlobalConfig):
_, lm_provider = get_lm_provider(
config.model_provider_id, self._lm_providers
)
# do not check config for blocked providers
if not is_provider_allowed(config.model_provider_id, self._restrictions):
assert not lm_provider
return

# verify model is declared by some provider
if not lm_provider:
raise ValueError(
f"No language model is associated with '{config.model_provider_id}'."
)

# verify model is not blocked
self._validate_model(config.model_provider_id)

# verify model is authenticated
_validate_provider_authn(config, lm_provider)

# validate embedding model config
if config.embeddings_provider_id:
_, em_provider = get_em_provider(
config.embeddings_provider_id, self._em_providers
)
# do not check config for blocked providers
if not is_provider_allowed(
config.embeddings_provider_id, self._restrictions
):
assert not em_provider
return

# verify model is declared by some provider
if not em_provider:
raise ValueError(
f"No embedding model is associated with '{config.embeddings_provider_id}'."
)

# verify model is not blocked
self._validate_model(config.embeddings_provider_id)

# verify model is authenticated
_validate_provider_authn(config, em_provider)

def _validate_model(self, model_id: str, raise_exc=True):
"""
Validates a model against the set of allow/blocklists specified by the
traitlets configuration, returning `True` if the model is allowed, and
raising a `BlockedModelError` otherwise. If `raise_exc=False`, this
function returns `False` if the model is not allowed.
"""

assert model_id is not None
components = model_id.split(":", 1)
assert len(components) == 2
provider_id, _ = components

try:
if self._allowed_providers and provider_id not in self._allowed_providers:
raise BlockedModelError(
"Model provider not included in the provider allowlist."
)

if self._blocked_providers and provider_id in self._blocked_providers:
raise BlockedModelError(
"Model provider included in the provider blocklist."
)

if self._allowed_models and model_id not in self._allowed_models:
raise BlockedModelError("Model not included in the model allowlist.")

if self._blocked_models and model_id in self._blocked_models:
raise BlockedModelError("Model included in the model blocklist.")
except BlockedModelError as e:
if raise_exc:
raise e
else:
return False

return True

def _write_config(self, new_config: GlobalConfig):
"""Updates configuration and persists it to disk. This accepts a
complete `GlobalConfig` object, and should not be called publicly."""
Expand Down
44 changes: 42 additions & 2 deletions packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,50 @@ class AiExtension(ExtensionApp):
config=True,
)

allowed_models = List(
Unicode(),
default_value=None,
help="""
Language models to allow, as a list of global model IDs in the format
`<provider>:<local-model-id>`. If `None`, all are allowed. Defaults to
`None`.

Note: Currently, if `allowed_providers` is also set, then this field is
ignored. This is subject to change in a future non-major release. Using
both traits is considered to be undefined behavior at this time.
""",
allow_none=True,
config=True,
)

blocked_models = List(
Unicode(),
default_value=None,
help="""
Language models to block, as a list of global model IDs in the format
`<provider>:<local-model-id>`. If `None`, none are blocked. Defaults to
`None`.
""",
allow_none=True,
config=True,
)

def initialize_settings(self):
start = time.time()

# Read from allowlist and blocklist
restrictions = {
"allowed_providers": self.allowed_providers,
"blocked_providers": self.blocked_providers,
}

self.settings["allowed_models"] = self.allowed_models
self.settings["blocked_models"] = self.blocked_models
self.log.info(f"Configured provider allowlist: {self.allowed_providers}")
self.log.info(f"Configured provider blocklist: {self.blocked_providers}")
self.log.info(f"Configured model allowlist: {self.allowed_models}")
self.log.info(f"Configured model blocklist: {self.blocked_models}")

# Fetch LM & EM providers
self.settings["lm_providers"] = get_lm_providers(
log=self.log, restrictions=restrictions
)
Expand All @@ -73,7 +110,10 @@ def initialize_settings(self):
log=self.log,
lm_providers=self.settings["lm_providers"],
em_providers=self.settings["em_providers"],
restrictions=restrictions,
allowed_providers=self.allowed_providers,
blocked_providers=self.blocked_providers,
allowed_models=self.allowed_models,
blocked_models=self.blocked_models,
)

self.log.info("Registered providers.")
Expand Down
70 changes: 57 additions & 13 deletions packages/jupyter-ai/jupyter_ai/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import uuid
from asyncio import AbstractEventLoop
from dataclasses import asdict
from typing import TYPE_CHECKING, Dict, List
from typing import TYPE_CHECKING, Dict, List, Optional

import tornado
from jupyter_ai.chat_handlers import BaseChatHandler
Expand Down Expand Up @@ -240,14 +240,58 @@ def on_close(self):
self.log.debug("Chat clients: %s", self.root_chat_handlers.keys())


class ModelProviderHandler(BaseAPIHandler):
class ProviderHandler(BaseAPIHandler):
"""
Helper base class used for HTTP handlers hosting endpoints relating to
providers. Wrapper around BaseAPIHandler.
"""

@property
def lm_providers(self) -> Dict[str, "BaseProvider"]:
return self.settings["lm_providers"]

@property
def em_providers(self) -> Dict[str, "BaseEmbeddingsProvider"]:
return self.settings["em_providers"]

@property
def allowed_models(self) -> Optional[List[str]]:
return self.settings["allowed_models"]

@property
def blocked_models(self) -> Optional[List[str]]:
return self.settings["blocked_models"]

def _filter_blocked_models(self, providers: List[ListProvidersEntry]):
"""
Satisfy the model-level allow/blocklist by filtering models accordingly.
The provider-level allow/blocklist is already handled in
`AiExtension.initialize_settings()`.
"""
if self.blocked_models is None and self.allowed_models is None:
return providers

def filter_predicate(local_model_id: str):
model_id = provider.id + ":" + local_model_id
if self.blocked_models:
return model_id not in self.blocked_models
else:
return model_id in self.allowed_models

# filter out every model w/ model ID according to allow/blocklist
for provider in providers:
provider.models = list(filter(filter_predicate, provider.models))

# filter out every provider with no models which satisfy the allow/blocklist, then return
return filter((lambda p: len(p.models) > 0), providers)


class ModelProviderHandler(ProviderHandler):
@web.authenticated
def get(self):
providers = []

# Step 1: gather providers
for provider in self.lm_providers.values():
# skip old legacy OpenAI chat provider used only in magics
if provider.id == "openai-chat":
Expand All @@ -270,17 +314,16 @@ def get(self):
)
)

response = ListProvidersResponse(
providers=sorted(providers, key=lambda p: p.name)
)
self.finish(response.json())
# Step 2: sort & filter providers
providers = self._filter_blocked_models(providers)
providers = sorted(providers, key=lambda p: p.name)

# Finally, yield response.
response = ListProvidersResponse(providers=providers)
self.finish(response.json())

class EmbeddingsModelProviderHandler(BaseAPIHandler):
@property
def em_providers(self) -> Dict[str, "BaseEmbeddingsProvider"]:
return self.settings["em_providers"]

class EmbeddingsModelProviderHandler(ProviderHandler):
@web.authenticated
def get(self):
providers = []
Expand All @@ -296,9 +339,10 @@ def get(self):
)
)

response = ListProvidersResponse(
providers=sorted(providers, key=lambda p: p.name)
)
providers = self._filter_blocked_models(providers)
providers = sorted(providers, key=lambda p: p.name)

response = ListProvidersResponse(providers=providers)
self.finish(response.json())


Expand Down
Loading