Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle config errors #470

Closed
wants to merge 11 commits into from
197 changes: 144 additions & 53 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@

from deepmerge import always_merger as Merger
from jsonschema import Draft202012Validator as Validator
from jupyter_ai.models import DescribeConfigResponse, GlobalConfig, UpdateConfigRequest
from jupyter_ai.models import (
ConfigErrorModel,
ConfigErrorType,
DescribeConfigResponse,
GlobalConfig,
UpdateConfigRequest,
)
from jupyter_ai_magics.utils import (
AnyProvider,
EmProvidersDict,
Expand All @@ -16,6 +22,7 @@
get_lm_provider,
)
from jupyter_core.paths import jupyter_data_dir
from pydantic import ValidationError
from traitlets import Integer, Unicode
from traitlets.config import Configurable

Expand Down Expand Up @@ -70,6 +77,16 @@ def _validate_provider_authn(config: GlobalConfig, provider: AnyProvider):
)


def _format_validation_errors(error: ValidationError):
"""Format Pydantic validation errors for user-friendly output."""
messages = []
for e in error.errors():
field_path = " -> ".join(map(str, e["loc"]))
error_message = f"Error in '{field_path}': {e['msg']}. Please review and correct this field."
messages.append(error_message)
return "Configuration Error: " + " | ".join(messages)


class ConfigManager(Configurable):
"""Provides model and embedding provider id along
with the credentials to authenticate providers.
Expand Down Expand Up @@ -111,6 +128,7 @@ def __init__(
super().__init__(*args, **kwargs)
self.log = log

self._config_errors = []
self._lm_providers = lm_providers
"""List of LM providers."""
self._em_providers = em_providers
Expand Down Expand Up @@ -146,53 +164,31 @@ def _init_validator(self) -> Validator:
self.validator = Validator(schema)

def _init_config(self):
# try:
if os.path.exists(self.config_path):
with open(self.config_path, encoding="utf-8") as f:
config = GlobalConfig(**json.loads(f.read()))
lm_id = config.model_provider_id
em_id = config.embeddings_provider_id

# if the currently selected language or embedding model are
# forbidden, set them to `None` and log a warning.
if lm_id is not None and not self._validate_model(
lm_id, raise_exc=False
):
self.log.warning(
f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None."
)
config.model_provider_id = None
if em_id is not None and not self._validate_model(
em_id, raise_exc=False
):
self.log.warning(
f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None."
)
config.embeddings_provider_id = None

# if the currently selected language or embedding model ids are
# not associated with models, set them to `None` and log a warning.
if (
lm_id is not None
and not get_lm_provider(lm_id, self._lm_providers)[1]
):
self.log.warning(
f"No language model is associated with '{lm_id}'. Setting to None."
)
config.model_provider_id = None
if (
em_id is not None
and not get_em_provider(em_id, self._em_providers)[1]
):
self.log.warning(
f"No embedding model is associated with '{em_id}'. Setting to None."
)
config.embeddings_provider_id = None
self._process_existing_config()
else:
self._create_default_config()
# except ValidationError as e:
# self._handle_validation_error(e)
# self._config = GlobalConfig(
# send_with_shift_enter=False, fields={}, api_keys={}
# )

def _process_existing_config(self):
with open(self.config_path, encoding="utf-8") as f:
raw_config = json.loads(f.read())

# re-write to the file to validate the config and apply any
# updates to the config file immediately
self._write_config(config)
return
validated_raw_config = self._validate_lm_em_id(raw_config)

try:
config = GlobalConfig(**validated_raw_config)
self._write_config(config)
except ValidationError as e:
corrected_config = self._handle_validation_error(e, validated_raw_config)
self._write_config(corrected_config)

def _create_default_config(self):
properties = self.validator.schema.get("properties", {})
field_list = GlobalConfig.__fields__.keys()
field_dict = {
Expand All @@ -201,6 +197,89 @@ def _init_config(self):
default_config = GlobalConfig(**field_dict)
self._write_config(default_config)

def _validate_lm_em_id(self, raw_config):
lm_id = raw_config.get("model_provider_id")
em_id = raw_config.get("embeddings_provider_id")

# if the currently selected language or embedding model are
# forbidden, set them to `None` and log a warning.
if lm_id is not None and not self._validate_model(lm_id, raise_exc=False):
warning_message = f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None."
self.log.warning(warning_message)
raw_config["model_provider_id"] = None
self._config_errors.append(
ConfigErrorModel(
error_type=ConfigErrorType.WARNING, message=warning_message
)
)

if em_id is not None and not self._validate_model(em_id, raise_exc=False):
warning_message = f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None."
self.log.warning(warning_message)
raw_config["embeddings_provider_id"] = None
self._config_errors.append(
ConfigErrorModel(
error_type=ConfigErrorType.WARNING, message=warning_message
)
)

# if the currently selected language or embedding model ids are
# not associated with models, set them to `None` and log a warning.
if lm_id is not None and not get_lm_provider(lm_id, self._lm_providers)[1]:
warning_message = (
f"No language model is associated with '{lm_id}'. Setting to None."
)
self.log.warning(warning_message)
raw_config["model_provider_id"] = None
self._config_errors.append(
ConfigErrorModel(
error_type=ConfigErrorType.WARNING, message=warning_message
)
)

if em_id is not None and not get_em_provider(em_id, self._em_providers)[1]:
warning_message = (
f"No embedding model is associated with '{em_id}'. Setting to None."
)
self.log.warning(warning_message)
raw_config["embeddings_provider_id"] = None
self._config_errors.append(
ConfigErrorModel(
error_type=ConfigErrorType.WARNING, message=warning_message
)
)

return raw_config

def _handle_validation_error(self, e: ValidationError, raw_config):
# Extract default values from schema
properties = self.validator.schema.get("properties", {})
field_list = GlobalConfig.__fields__.keys()
default_values = {
field: properties.get(field).get("default") for field in field_list
}

# Apply default values to erroneous fields
for error in e.errors():
field = error["loc"][0]
if field in default_values:
raw_config[field] = default_values[field]
warning_message = f"Error in '{field}': {error['msg']}. Resetting to default value ('{default_values[field]}')."
self.log.warning(warning_message)
self._config_errors.append(
ConfigErrorModel(
error_type=ConfigErrorType.WARNING, message=warning_message
)
)

# Create a config with default values for erroneous fields
config = GlobalConfig(**raw_config)
self.log.warning("\n\n\n Config \n\n\n")

self.log.warning(config)
self._validate_config(config)
return config

def _read_config(self) -> GlobalConfig:
"""Returns the user's current configuration as a GlobalConfig object.
This should never be sent to the client as it includes API keys. Prefer
Expand All @@ -210,12 +289,15 @@ def _read_config(self) -> GlobalConfig:
if last_write <= self._last_read:
return self._config

with open(self.config_path, encoding="utf-8") as f:
self._last_read = time.time_ns()
raw_config = json.loads(f.read())
config = GlobalConfig(**raw_config)
self._validate_config(config)
return config
with open(self.config_path, encoding="utf-8") as f:
self._last_read = time.time_ns()
raw_config = json.loads(f.read())
try:
config = GlobalConfig(**raw_config)
except ValidationError as e:
config = self._handle_validation_error(e, raw_config)
self._validate_config(config)
return config

def _validate_config(self, config: GlobalConfig):
"""Method used to validate the configuration. This is called after every
Expand Down Expand Up @@ -333,6 +415,12 @@ def delete_api_key(self, key_name: str):
config_dict["api_keys"].pop(key_name, None)
self._write_config(GlobalConfig(**config_dict))

def get_config_errors(self):
if self._config_errors:
return self._config_errors
else:
return None

def update_config(self, config_update: UpdateConfigRequest):
last_write = os.stat(self.config_path).st_mtime_ns
if config_update.last_read and config_update.last_read < last_write:
Expand All @@ -354,9 +442,12 @@ def update_config(self, config_update: UpdateConfigRequest):
def get_config(self):
config = self._read_config()
config_dict = config.dict(exclude_unset=True)
api_key_names = list(config_dict.pop("api_keys").keys())
api_key_names = list(config_dict.pop("api_keys", {}).keys())
return DescribeConfigResponse(
**config_dict, api_keys=api_key_names, last_read=self._last_read
**config_dict,
api_keys=api_key_names,
last_read=self._last_read,
config_errors=self.get_config_errors(),
)

@property
Expand Down
53 changes: 41 additions & 12 deletions packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
LearnChatHandler,
)
from .chat_handlers.help import HelpMessage
from .config_manager import ConfigManager
from .config_manager import ConfigErrorType, ConfigManager
from .handlers import (
ApiKeysHandler,
ChatHistoryHandler,
Expand All @@ -28,14 +28,7 @@

class AiExtension(ExtensionApp):
name = "jupyter_ai"
handlers = [
(r"api/ai/api_keys/(?P<api_key_name>\w+)", ApiKeysHandler),
(r"api/ai/config/?", GlobalConfigHandler),
(r"api/ai/chats/?", RootChatHandler),
(r"api/ai/chats/history?", ChatHistoryHandler),
(r"api/ai/providers?", ModelProviderHandler),
(r"api/ai/providers/embeddings?", EmbeddingsModelProviderHandler),
]
handlers = [(r"api/ai/config/?", GlobalConfigHandler)]

allowed_providers = List(
Unicode(),
Expand Down Expand Up @@ -130,10 +123,33 @@ def initialize_settings(self):
blocked_models=self.blocked_models,
)

self.log.info("Registered providers.")
config_errors = self.settings["jai_config_manager"].get_config_errors()
if config_errors is None or all(
error.error_type != ConfigErrorType.CRITICAL for error in config_errors
):
# Full functionality initialization
self._initialize_full_functionality()
else:
# Log the error and proceed with limited functionality
self.log.error(f"Configuration errors detected: {config_errors}")
self._initialize_limited_functionality(config_errors)

self.log.info(f"Registered {self.name} server extension")

latency_ms = round((time.time() - start) * 1000)
self.log.info(f"Initialized Jupyter AI server extension in {latency_ms} ms.")

def _initialize_full_functionality(self):
self.handlers.extend(
[
(r"api/ai/api_keys/(?P<api_key_name>\w+)", ApiKeysHandler),
(r"api/ai/chats/?", RootChatHandler),
(r"api/ai/chats/history?", ChatHistoryHandler),
(r"api/ai/providers?", ModelProviderHandler),
(r"api/ai/providers/embeddings?", EmbeddingsModelProviderHandler),
]
)

# Store chat clients in a dictionary
self.settings["chat_clients"] = {}
self.settings["jai_root_chat_handlers"] = {}
Expand Down Expand Up @@ -190,8 +206,21 @@ def initialize_settings(self):
"/help": help_chat_handler,
}

latency_ms = round((time.time() - start) * 1000)
self.log.info(f"Initialized Jupyter AI server extension in {latency_ms} ms.")
self.log.info("Registered providers.")

def _initialize_limited_functionality(self, config_errors):
"""
Initialize the extension with limited functionality due to configuration errors.
"""
self.log.warning(
"Initializing Jupyter AI extension with limited functionality due to configuration errors."
)

# Capture configuration error details
config_errors = self.settings["jai_config_manager"].get_config_errors()
self.settings["config_errors"] = config_errors

self.settings["jai_chat_handlers"] = []

async def _get_dask_client(self):
return DaskClient(processes=False, asynchronous=True)
16 changes: 16 additions & 0 deletions packages/jupyter-ai/jupyter_ai/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union

from jupyter_ai_magics.providers import AuthStrategy, Field
Expand Down Expand Up @@ -99,6 +100,20 @@ class IndexMetadata(BaseModel):
dirs: List[IndexedDir]


class ConfigErrorType(Enum):
CRITICAL = "Critical"
WARNING = "Warning"


class ConfigErrorModel(BaseModel):
error_type: ConfigErrorType
message: str
details: str = None

def __str__(self):
return f"{self.error_type.value} ConfigError: {self.message} - {self.details or ''}"


class DescribeConfigResponse(BaseModel):
model_provider_id: Optional[str]
embeddings_provider_id: Optional[str]
Expand All @@ -110,6 +125,7 @@ class DescribeConfigResponse(BaseModel):
# timestamp indicating when the configuration file was last read. should be
# passed to the subsequent UpdateConfig request.
last_read: int
config_errors: Optional[List[ConfigErrorModel]] = None


def forbid_none(cls, v):
Expand Down
Loading