Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix specifying empty list in provider and model allow/denylists #1185

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
"restrictions",
[
{"allowed_providers": None, "blocked_providers": None},
{"allowed_providers": [], "blocked_providers": []},
{"allowed_providers": [], "blocked_providers": [KNOWN_LM_B]},
{"allowed_providers": None, "blocked_providers": []},
{"allowed_providers": None, "blocked_providers": [KNOWN_LM_B]},
{"allowed_providers": [KNOWN_LM_A], "blocked_providers": []},
{"allowed_providers": [KNOWN_LM_A], "blocked_providers": None},
],
)
def test_get_lm_providers_not_restricted(restrictions):
Expand All @@ -25,8 +26,11 @@ def test_get_lm_providers_not_restricted(restrictions):
@pytest.mark.parametrize(
"restrictions",
[
{"allowed_providers": [], "blocked_providers": None},
{"allowed_providers": [], "blocked_providers": [KNOWN_LM_A]},
{"allowed_providers": None, "blocked_providers": [KNOWN_LM_A]},
{"allowed_providers": [KNOWN_LM_B], "blocked_providers": []},
{"allowed_providers": [KNOWN_LM_B], "blocked_providers": None},
],
)
def test_get_lm_providers_restricted(restrictions):
Expand Down
4 changes: 2 additions & 2 deletions packages/jupyter-ai-magics/jupyter_ai_magics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ def get_em_provider(
def is_provider_allowed(provider_id: str, restrictions: ProviderRestrictions) -> bool:
allowed = restrictions["allowed_providers"]
blocked = restrictions["blocked_providers"]
if blocked and provider_id in blocked:
if blocked is not None and provider_id in blocked:
return False
if allowed and provider_id not in allowed:
if allowed is not None and provider_id not in allowed:
return False
return True

Expand Down
7 changes: 5 additions & 2 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,10 +325,13 @@ def _validate_model(self, model_id: str, raise_exc=True):
"Model provider included in the provider blocklist."
)

if self._allowed_models and model_id not in self._allowed_models:
if (
self._allowed_models is not None
and model_id not in self._allowed_models
):
raise BlockedModelError("Model not included in the model allowlist.")

if self._blocked_models and model_id in self._blocked_models:
if self._blocked_models is not None and model_id in self._blocked_models:
raise BlockedModelError("Model included in the model blocklist.")
except BlockedModelError as e:
if raise_exc:
Expand Down
Loading