Skip to content

Commit

Permalink
fix Config source priority and add tests (#200)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier authored Nov 11, 2023
1 parent 34f4c5b commit 16264f3
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 8 deletions.
18 changes: 10 additions & 8 deletions ragna/core/_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from pathlib import Path
from typing import Union
from typing import Type, Union

import tomlkit
from pydantic import Field, ImportString, field_validator
Expand All @@ -17,12 +17,14 @@
from ._utils import RagnaException


class ConfigBase:
class ConfigBase(BaseSettings):
@classmethod
def customise_sources(
def settings_customise_sources(
cls,
settings_cls: Type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
# This order is needed to prioritize values from environment variables over
Expand All @@ -32,15 +34,15 @@ def customise_sources(
# explicitly passed to the constructor. For example, if the environment variable
# 'RAGNA_RAG_DATABASE_URL' is set, any values passed to
# `RagnaConfig(rag=RagConfig(database_url=...))` is ignored.
# TODO: Find a way to achieve the following priorities:
# FIXME: Find a way to achieve the following priorities:
# 1. Explicitly passed to Python object
# 2. Environment variable
# 3. Configuration file
# 4. Default
return env_settings, init_settings


class CoreConfig(BaseSettings):
class CoreConfig(ConfigBase):
model_config = SettingsConfigDict(env_prefix="ragna_core_")

queue_url: str = "memory"
Expand All @@ -54,7 +56,7 @@ class CoreConfig(BaseSettings):
]


class ApiConfig(BaseSettings):
class ApiConfig(ConfigBase):
model_config = SettingsConfigDict(env_prefix="ragna_api_")

url: str = "http://127.0.0.1:31476"
Expand All @@ -67,15 +69,15 @@ class ApiConfig(BaseSettings):
] = "ragna.core.RagnaDemoAuthentication" # type: ignore[assignment]


class UiConfig(BaseSettings):
class UiConfig(ConfigBase):
model_config = SettingsConfigDict(env_prefix="ragna_ui_")

url: str = "http://127.0.0.1:31477"
# FIXME: this needs to be dynamic for the url
origins: list[str] = ["http://127.0.0.1:31477"]


class Config(BaseSettings):
class Config(ConfigBase):
"""Ragna configuration"""

model_config = SettingsConfigDict(env_prefix="ragna_")
Expand Down
33 changes: 33 additions & 0 deletions tests/core/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import os

import pytest

from ragna import Config


@pytest.mark.xfail()
def test_explicit_gt_env_var(mocker, tmp_path):
explicit = tmp_path / "explicit"

env_var = tmp_path / "env_var"
mocker.patch.dict(os.environ, values={"RAGNA_LOCAL_CACHE_ROOT": str(env_var)})

config = Config(local_cache_root=explicit)

assert config.local_cache_root == explicit


def test_env_var_gt_config_file(mocker, tmp_path):
config_file = tmp_path / "config_file"
config = Config(local_cache_root=config_file)
assert config.local_cache_root == config_file

config_path = tmp_path / "ragna.toml"
config.to_file(config_path)

env_var = tmp_path / "env_var"
mocker.patch.dict(os.environ, values={"RAGNA_LOCAL_CACHE_ROOT": str(env_var)})

config = Config.from_file(config_path)

assert config.local_cache_root == env_var

0 comments on commit 16264f3

Please sign in to comment.