diff --git a/django_healthy/health_checks/db.py b/django_healthy/health_checks/db.py index 7b9b8bc..401192a 100644 --- a/django_healthy/health_checks/db.py +++ b/django_healthy/health_checks/db.py @@ -8,9 +8,8 @@ from django.utils.crypto import get_random_string from django.utils.translation import gettext_lazy as _ -from django_healthy.models import Test - from .base import HealthCheck, HealthCheckResult +from django_healthy.models import Test if TYPE_CHECKING: from .types import MessageDict diff --git a/django_healthy/health_checks/handler.py b/django_healthy/health_checks/handler.py index 209714a..e927f95 100644 --- a/django_healthy/health_checks/handler.py +++ b/django_healthy/health_checks/handler.py @@ -1,14 +1,20 @@ from __future__ import annotations -from typing import Any, Iterator, cast +from typing import TYPE_CHECKING, Any -from django.conf import settings from django.core.exceptions import ImproperlyConfigured +from django.utils.connection import BaseConnectionHandler from django.utils.module_loading import import_string -from django_healthy._compat import Mapping, MutableMapping, NotRequired, TypeAlias, TypedDict +from django_healthy._compat import ( + Mapping, + NotRequired, + TypeAlias, + TypedDict, +) -from .base import HealthCheck +if TYPE_CHECKING: + from .base import HealthCheck class InvalidHealthCheckError(ImproperlyConfigured): @@ -23,36 +29,12 @@ class BackendConfig(TypedDict): HealthCheckConfig: TypeAlias = Mapping[str, BackendConfig] -class HealthCheckHandler(Mapping[str, HealthCheck]): - __slots__: tuple[str, ...] = ("_backends", "_health_checks") +class HealthCheckHandler(BaseConnectionHandler): + settings_name = "HEALTH_CHECKS" + exception_class = InvalidHealthCheckError - def __init__(self, backends: HealthCheckConfig | None = None): - self._backends = ( - backends if backends is not None else cast(HealthCheckConfig, getattr(settings, "HEALTH_CHECKS", {})) - ) - self._health_checks: MutableMapping[str, HealthCheck] = {} - - def __getitem__(self, alias: str) -> HealthCheck: - try: - return self._health_checks[alias] - except KeyError: - try: - params = self._backends[alias] - except KeyError as exc: - msg = f"Could not find config for '{alias}' in settings.HEALTH_CHECKS." - raise InvalidHealthCheckError(msg) from exc - else: - health_check = self.create_health_check(params) - self._health_checks[alias] = health_check - return health_check - - def __iter__(self) -> Iterator[str]: - return iter(self._backends) - - def __len__(self) -> int: - return len(self._backends) - - def create_health_check(self, params: BackendConfig) -> HealthCheck: + def create_connection(self, alias: str) -> HealthCheck: + params = self.settings[alias] backend = params["BACKEND"] options = params.get("OPTIONS", {}) diff --git a/django_healthy/health_checks/service.py b/django_healthy/health_checks/service.py index 84945ca..9cb83eb 100644 --- a/django_healthy/health_checks/service.py +++ b/django_healthy/health_checks/service.py @@ -6,10 +6,9 @@ from timeit import default_timer as timer from typing import Any -from django_healthy._compat import Mapping # noqa: TCH001 - from .base import HealthCheck, HealthStatus from .handler import HealthCheckHandler, health_checks +from django_healthy._compat import Mapping # noqa: TCH001 @dataclass @@ -37,8 +36,7 @@ def __init__(self, handler: HealthCheckHandler | None = None): async def check_health(self) -> HealthReport: start_time: float = timer() task_map: dict[str, asyncio.Task[HealthReportEntry]] = { - name: asyncio.create_task(self.run_health_check(health_check)) - for name, health_check in self._handler.items() + name: asyncio.create_task(self.run_health_check(self._handler[name])) for name in self._handler } await asyncio.gather(*task_map.values()) end_time: float = timer() diff --git a/ruff_defaults.toml b/ruff_defaults.toml index cdf79d7..72c3fdd 100644 --- a/ruff_defaults.toml +++ b/ruff_defaults.toml @@ -544,7 +544,7 @@ select = [ ban-relative-imports = "all" [lint.isort] -known-first-party = ["django_healthy"] +known-local-folder = ["django_healthy"] [lint.flake8-pytest-style] fixture-parentheses = false diff --git a/tests/health_checks/test_handler.py b/tests/health_checks/test_handler.py index 2ea7d14..523cff0 100644 --- a/tests/health_checks/test_handler.py +++ b/tests/health_checks/test_handler.py @@ -2,31 +2,34 @@ from django.conf import settings from django_healthy.health_checks.cache import CacheHealthCheck -from django_healthy.health_checks.handler import HealthCheckHandler, InvalidHealthCheckError +from django_healthy.health_checks.handler import ( + HealthCheckHandler, + InvalidHealthCheckError, +) class TestHealthCheckHandler: def test_with_custom_settings(self): handler = HealthCheckHandler( - backends={ + { "test": { "BACKEND": "django_healthy.health_checks.cache.CacheHealthCheck", } } ) - items = set(handler.keys()) + items = set(handler) assert items == {"test"} def test_with_default_settings(self): handler = HealthCheckHandler() - items = set(handler.keys()) + items = set(handler) assert items == set(settings.HEALTH_CHECKS) def test_get_existing_item(self): handler = HealthCheckHandler( - backends={ + { "test": { "BACKEND": "django_healthy.health_checks.cache.CacheHealthCheck", } @@ -39,7 +42,7 @@ def test_get_existing_item(self): def test_get_missing_item(self): handler = HealthCheckHandler( - backends={ + { "test": { "BACKEND": "django_healthy.health_checks.cache.CacheHealthCheck", } diff --git a/tests/health_checks/test_service.py b/tests/health_checks/test_service.py index 562107a..37c1447 100644 --- a/tests/health_checks/test_service.py +++ b/tests/health_checks/test_service.py @@ -21,7 +21,7 @@ async def test_with_default_handler(self): async def test_with_custom_handler(self): service = HealthCheckService( HealthCheckHandler( - backends={ + { "test": { "BACKEND": "django_healthy.health_checks.cache.CacheHealthCheck", } @@ -37,7 +37,7 @@ async def test_with_custom_handler(self): async def test_with_unhealthy_service(self): service = HealthCheckService( HealthCheckHandler( - backends={ + { "test": { "BACKEND": "django_healthy.health_checks.db.DatabasePingHealthCheck", "OPTIONS": { @@ -56,7 +56,7 @@ async def test_with_unhealthy_service(self): async def test_with_multiple_service_status_gets_worst_case(self): service = HealthCheckService( HealthCheckHandler( - backends={ + { "healthy": { "BACKEND": "django_healthy.health_checks.db.DatabasePingHealthCheck", "OPTIONS": {