diff --git a/docs/usage/stores.rst b/docs/usage/stores.rst index 52118f9332..fb580f75f7 100644 --- a/docs/usage/stores.rst +++ b/docs/usage/stores.rst @@ -259,3 +259,11 @@ with minimal boilerplate: Without any extra configuration, every call to ``app.stores.get`` with a unique name will return a namespace for this name only, while re-using the underlying Redis instance. + + +Store lifetime +++++++++++++++ + +Stores may not be automatically closed when the application is shut down. +This is the case in particular for the RedisStore if you are not using the class method :meth:`RedisStore.with_client <.redis.RedisStore.with_client>` and passing in your own Redis instance. +In this case you're responsible to close the Redis instance yourself. diff --git a/litestar/app.py b/litestar/app.py index 1ece9f016b..4bdb2648c3 100644 --- a/litestar/app.py +++ b/litestar/app.py @@ -388,7 +388,12 @@ def __init__( self._openapi_schema: OpenAPI | None = None self._debug: bool = True + self.stores: StoreRegistry = ( + config.stores if isinstance(config.stores, StoreRegistry) else StoreRegistry(config.stores) + ) self._lifespan_managers = config.lifespan + for store in self.stores._stores.values(): + self._lifespan_managers.append(store) self._server_lifespan_managers = [p.server_lifespan for p in config.plugins or [] if isinstance(p, CLIPlugin)] self.experimental_features = frozenset(config.experimental_features or []) self.get_logger: GetLogger = get_logger_placeholder @@ -471,10 +476,6 @@ def __init__( self.asgi_handler = self._create_asgi_handler() - self.stores: StoreRegistry = ( - config.stores if isinstance(config.stores, StoreRegistry) else StoreRegistry(config.stores) - ) - @property @deprecated(version="2.6.0", kind="property", info="Use create_static_files router instead") def static_files_config(self) -> list[StaticFilesConfig]: diff --git a/litestar/stores/base.py b/litestar/stores/base.py index 585a2cbf5b..34aa514fca 100644 --- a/litestar/stores/base.py +++ b/litestar/stores/base.py @@ -9,6 +9,8 @@ from msgspec.msgpack import encode as msgpack_encode if TYPE_CHECKING: + from types import TracebackType + from typing_extensions import Self @@ -76,6 +78,17 @@ async def expires_in(self, key: str) -> int | None: """ raise NotImplementedError + async def __aenter__(self) -> None: # noqa: B027 + pass + + async def __aexit__( # noqa: B027 + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + pass + class NamespacedStore(Store): """A subclass of :class:`Store`, offering hierarchical namespacing. diff --git a/litestar/stores/redis.py b/litestar/stores/redis.py index f0baa2538a..6697962fab 100644 --- a/litestar/stores/redis.py +++ b/litestar/stores/redis.py @@ -1,7 +1,7 @@ from __future__ import annotations from datetime import timedelta -from typing import cast +from typing import TYPE_CHECKING, cast from redis.asyncio import Redis from redis.asyncio.connection import ConnectionPool @@ -14,13 +14,18 @@ __all__ = ("RedisStore",) +if TYPE_CHECKING: + from types import TracebackType + class RedisStore(NamespacedStore): """Redis based, thread and process safe asynchronous key/value store.""" __slots__ = ("_redis",) - def __init__(self, redis: Redis, namespace: str | None | EmptyType = Empty) -> None: + def __init__( + self, redis: Redis, namespace: str | None | EmptyType = Empty, handle_client_shutdown: bool = False + ) -> None: """Initialize :class:`RedisStore` Args: @@ -28,9 +33,11 @@ def __init__(self, redis: Redis, namespace: str | None | EmptyType = Empty) -> N namespace: A key prefix to simulate a namespace in redis. If not given, defaults to ``LITESTAR``. Namespacing can be explicitly disabled by passing ``None``. This will make :meth:`.delete_all` unavailable. + handle_client_shutdown: If ``True``, handle the shutdown of the `redis` instance automatically during the store's lifespan. Should be set to `True` unless the shutdown is handled externally """ self._redis = redis self.namespace: str | None = value_or_default(namespace, "LITESTAR") + self.handle_client_shutdown = handle_client_shutdown # script to get and renew a key in one atomic step self._get_and_renew_script = self._redis.register_script( @@ -64,6 +71,18 @@ def __init__(self, redis: Redis, namespace: str | None | EmptyType = Empty) -> N """ ) + async def _shutdown(self) -> None: + if self.handle_client_shutdown: + await self._redis.aclose(close_connection_pool=True) # type: ignore[attr-defined] + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self._shutdown() + @classmethod def with_client( cls, @@ -93,14 +112,22 @@ def with_client( username=username, password=password, ) - return cls(redis=Redis(connection_pool=pool), namespace=namespace) + return cls( + redis=Redis(connection_pool=pool), + namespace=namespace, + handle_client_shutdown=True, + ) def with_namespace(self, namespace: str) -> RedisStore: """Return a new :class:`RedisStore` with a nested virtual key namespace. The current instances namespace will serve as a prefix for the namespace, so it can be considered the parent namespace. """ - return type(self)(redis=self._redis, namespace=f"{self.namespace}_{namespace}" if self.namespace else namespace) + return type(self)( + redis=self._redis, + namespace=f"{self.namespace}_{namespace}" if self.namespace else namespace, + handle_client_shutdown=self.handle_client_shutdown, + ) def _make_key(self, key: str) -> str: prefix = f"{self.namespace}:" if self.namespace else "" diff --git a/tests/unit/test_stores.py b/tests/unit/test_stores.py index da3761ff6c..c8178996bc 100644 --- a/tests/unit/test_stores.py +++ b/tests/unit/test_stores.py @@ -366,3 +366,16 @@ async def test_file_store_handle_rename_fail(file_store: FileStore, mocker: Mock await file_store.set("foo", "bar") mock_unlink.assert_called_once() assert Path(mock_unlink.call_args_list[0].args[0]).with_suffix("") == file_store.path.joinpath("foo") + + +async def test_redis_store_with_client_shutdown() -> None: + redis_store = RedisStore.with_client(url="redis://localhost:6397") + assert await redis_store._redis.ping() + # remove the private shutdown and the assert below fails + # the check on connection is a mimic of https://github.com/redis/redis-py/blob/d529c2ad8d2cf4dcfb41bfd93ea68cfefd81aa66/tests/test_asyncio/test_connection_pool.py#L35-L39 + await redis_store._shutdown() + assert not any( + x.is_connected + for x in redis_store._redis.connection_pool._available_connections + + list(redis_store._redis.connection_pool._in_use_connections) + )