diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_utils.py b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_utils.py index e1c517ebe..860bb02bc 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_utils.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_utils.py @@ -12,9 +12,10 @@ "restrictions", [ {"allowed_providers": None, "blocked_providers": None}, - {"allowed_providers": [], "blocked_providers": []}, - {"allowed_providers": [], "blocked_providers": [KNOWN_LM_B]}, + {"allowed_providers": None, "blocked_providers": []}, + {"allowed_providers": None, "blocked_providers": [KNOWN_LM_B]}, {"allowed_providers": [KNOWN_LM_A], "blocked_providers": []}, + {"allowed_providers": [KNOWN_LM_A], "blocked_providers": None}, ], ) def test_get_lm_providers_not_restricted(restrictions): @@ -25,8 +26,11 @@ def test_get_lm_providers_not_restricted(restrictions): @pytest.mark.parametrize( "restrictions", [ + {"allowed_providers": [], "blocked_providers": None}, {"allowed_providers": [], "blocked_providers": [KNOWN_LM_A]}, + {"allowed_providers": None, "blocked_providers": [KNOWN_LM_A]}, {"allowed_providers": [KNOWN_LM_B], "blocked_providers": []}, + {"allowed_providers": [KNOWN_LM_B], "blocked_providers": None}, ], ) def test_get_lm_providers_restricted(restrictions): diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py b/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py index 5fb8f4fee..92aab3980 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py @@ -121,9 +121,9 @@ def get_em_provider( def is_provider_allowed(provider_id: str, restrictions: ProviderRestrictions) -> bool: allowed = restrictions["allowed_providers"] blocked = restrictions["blocked_providers"] - if blocked and provider_id in blocked: + if blocked is not None and provider_id in blocked: return False - if allowed and provider_id not in allowed: + if allowed is not None and provider_id not in allowed: return False return True diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index 4732e152c..79b710e3a 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -325,10 +325,13 @@ def _validate_model(self, model_id: str, raise_exc=True): "Model provider included in the provider blocklist." ) - if self._allowed_models and model_id not in self._allowed_models: + if ( + self._allowed_models is not None + and model_id not in self._allowed_models + ): raise BlockedModelError("Model not included in the model allowlist.") - if self._blocked_models and model_id in self._blocked_models: + if self._blocked_models is not None and model_id in self._blocked_models: raise BlockedModelError("Model included in the model blocklist.") except BlockedModelError as e: if raise_exc: