Skip to content

Commit

Permalink
Fix specifying empty list in provider and model allow/denylists (#1185)
Browse files Browse the repository at this point in the history
* Fix empty list in provider and model allow/deny-listing

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update unit tests

---------

Co-authored-by: maico <maico.timmerman@adyen.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: David L. Qiu <david@qiu.dev>
  • Loading branch information
4 people authored Jan 6, 2025
1 parent f44e62b commit 3f11712
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
"restrictions",
[
{"allowed_providers": None, "blocked_providers": None},
{"allowed_providers": [], "blocked_providers": []},
{"allowed_providers": [], "blocked_providers": [KNOWN_LM_B]},
{"allowed_providers": None, "blocked_providers": []},
{"allowed_providers": None, "blocked_providers": [KNOWN_LM_B]},
{"allowed_providers": [KNOWN_LM_A], "blocked_providers": []},
{"allowed_providers": [KNOWN_LM_A], "blocked_providers": None},
],
)
def test_get_lm_providers_not_restricted(restrictions):
Expand All @@ -25,8 +26,11 @@ def test_get_lm_providers_not_restricted(restrictions):
@pytest.mark.parametrize(
"restrictions",
[
{"allowed_providers": [], "blocked_providers": None},
{"allowed_providers": [], "blocked_providers": [KNOWN_LM_A]},
{"allowed_providers": None, "blocked_providers": [KNOWN_LM_A]},
{"allowed_providers": [KNOWN_LM_B], "blocked_providers": []},
{"allowed_providers": [KNOWN_LM_B], "blocked_providers": None},
],
)
def test_get_lm_providers_restricted(restrictions):
Expand Down
4 changes: 2 additions & 2 deletions packages/jupyter-ai-magics/jupyter_ai_magics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ def get_em_provider(
def is_provider_allowed(provider_id: str, restrictions: ProviderRestrictions) -> bool:
allowed = restrictions["allowed_providers"]
blocked = restrictions["blocked_providers"]
if blocked and provider_id in blocked:
if blocked is not None and provider_id in blocked:
return False
if allowed and provider_id not in allowed:
if allowed is not None and provider_id not in allowed:
return False
return True

Expand Down
7 changes: 5 additions & 2 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,10 +325,13 @@ def _validate_model(self, model_id: str, raise_exc=True):
"Model provider included in the provider blocklist."
)

if self._allowed_models and model_id not in self._allowed_models:
if (
self._allowed_models is not None
and model_id not in self._allowed_models
):
raise BlockedModelError("Model not included in the model allowlist.")

if self._blocked_models and model_id in self._blocked_models:
if self._blocked_models is not None and model_id in self._blocked_models:
raise BlockedModelError("Model included in the model blocklist.")
except BlockedModelError as e:
if raise_exc:
Expand Down

0 comments on commit 3f11712

Please sign in to comment.