Skip to content

Commit

Permalink
fix: close RedisStore client in case the with_client classmethod is u…
Browse files Browse the repository at this point in the history
…sed (#3111)



---------

Co-authored-by: Janek Nouvertné <provinzkraut@posteo.de>
  • Loading branch information
euri10 and provinzkraut authored Feb 27, 2024
1 parent 1ff7f1e commit fe72143
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 8 deletions.
8 changes: 8 additions & 0 deletions docs/usage/stores.rst
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,11 @@ with minimal boilerplate:

Without any extra configuration, every call to ``app.stores.get`` with a unique name will return a namespace for this
name only, while re-using the underlying Redis instance.


Store lifetime
++++++++++++++

Stores may not be automatically closed when the application is shut down.
This is the case in particular for the RedisStore if you are not using the class method :meth:`RedisStore.with_client <.redis.RedisStore.with_client>` and passing in your own Redis instance.
In this case you're responsible to close the Redis instance yourself.
9 changes: 5 additions & 4 deletions litestar/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,12 @@ def __init__(

self._openapi_schema: OpenAPI | None = None
self._debug: bool = True
self.stores: StoreRegistry = (
config.stores if isinstance(config.stores, StoreRegistry) else StoreRegistry(config.stores)
)
self._lifespan_managers = config.lifespan
for store in self.stores._stores.values():
self._lifespan_managers.append(store)
self._server_lifespan_managers = [p.server_lifespan for p in config.plugins or [] if isinstance(p, CLIPlugin)]
self.experimental_features = frozenset(config.experimental_features or [])
self.get_logger: GetLogger = get_logger_placeholder
Expand Down Expand Up @@ -471,10 +476,6 @@ def __init__(

self.asgi_handler = self._create_asgi_handler()

self.stores: StoreRegistry = (
config.stores if isinstance(config.stores, StoreRegistry) else StoreRegistry(config.stores)
)

@property
@deprecated(version="2.6.0", kind="property", info="Use create_static_files router instead")
def static_files_config(self) -> list[StaticFilesConfig]:
Expand Down
13 changes: 13 additions & 0 deletions litestar/stores/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from msgspec.msgpack import encode as msgpack_encode

if TYPE_CHECKING:
from types import TracebackType

from typing_extensions import Self


Expand Down Expand Up @@ -76,6 +78,17 @@ async def expires_in(self, key: str) -> int | None:
"""
raise NotImplementedError

async def __aenter__(self) -> None: # noqa: B027
pass

async def __aexit__( # noqa: B027
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
pass


class NamespacedStore(Store):
"""A subclass of :class:`Store`, offering hierarchical namespacing.
Expand Down
35 changes: 31 additions & 4 deletions litestar/stores/redis.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from datetime import timedelta
from typing import cast
from typing import TYPE_CHECKING, cast

from redis.asyncio import Redis
from redis.asyncio.connection import ConnectionPool
Expand All @@ -14,23 +14,30 @@

__all__ = ("RedisStore",)

if TYPE_CHECKING:
from types import TracebackType


class RedisStore(NamespacedStore):
"""Redis based, thread and process safe asynchronous key/value store."""

__slots__ = ("_redis",)

def __init__(self, redis: Redis, namespace: str | None | EmptyType = Empty) -> None:
def __init__(
self, redis: Redis, namespace: str | None | EmptyType = Empty, handle_client_shutdown: bool = False
) -> None:
"""Initialize :class:`RedisStore`
Args:
redis: An :class:`redis.asyncio.Redis` instance
namespace: A key prefix to simulate a namespace in redis. If not given,
defaults to ``LITESTAR``. Namespacing can be explicitly disabled by passing
``None``. This will make :meth:`.delete_all` unavailable.
handle_client_shutdown: If ``True``, handle the shutdown of the `redis` instance automatically during the store's lifespan. Should be set to `True` unless the shutdown is handled externally
"""
self._redis = redis
self.namespace: str | None = value_or_default(namespace, "LITESTAR")
self.handle_client_shutdown = handle_client_shutdown

# script to get and renew a key in one atomic step
self._get_and_renew_script = self._redis.register_script(
Expand Down Expand Up @@ -64,6 +71,18 @@ def __init__(self, redis: Redis, namespace: str | None | EmptyType = Empty) -> N
"""
)

async def _shutdown(self) -> None:
if self.handle_client_shutdown:
await self._redis.aclose(close_connection_pool=True) # type: ignore[attr-defined]

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self._shutdown()

@classmethod
def with_client(
cls,
Expand Down Expand Up @@ -93,14 +112,22 @@ def with_client(
username=username,
password=password,
)
return cls(redis=Redis(connection_pool=pool), namespace=namespace)
return cls(
redis=Redis(connection_pool=pool),
namespace=namespace,
handle_client_shutdown=True,
)

def with_namespace(self, namespace: str) -> RedisStore:
"""Return a new :class:`RedisStore` with a nested virtual key namespace.
The current instances namespace will serve as a prefix for the namespace, so it
can be considered the parent namespace.
"""
return type(self)(redis=self._redis, namespace=f"{self.namespace}_{namespace}" if self.namespace else namespace)
return type(self)(
redis=self._redis,
namespace=f"{self.namespace}_{namespace}" if self.namespace else namespace,
handle_client_shutdown=self.handle_client_shutdown,
)

def _make_key(self, key: str) -> str:
prefix = f"{self.namespace}:" if self.namespace else ""
Expand Down
13 changes: 13 additions & 0 deletions tests/unit/test_stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,3 +366,16 @@ async def test_file_store_handle_rename_fail(file_store: FileStore, mocker: Mock
await file_store.set("foo", "bar")
mock_unlink.assert_called_once()
assert Path(mock_unlink.call_args_list[0].args[0]).with_suffix("") == file_store.path.joinpath("foo")


async def test_redis_store_with_client_shutdown() -> None:
redis_store = RedisStore.with_client(url="redis://localhost:6397")
assert await redis_store._redis.ping()
# remove the private shutdown and the assert below fails
# the check on connection is a mimic of https://github.com/redis/redis-py/blob/d529c2ad8d2cf4dcfb41bfd93ea68cfefd81aa66/tests/test_asyncio/test_connection_pool.py#L35-L39
await redis_store._shutdown()
assert not any(
x.is_connected
for x in redis_store._redis.connection_pool._available_connections
+ list(redis_store._redis.connection_pool._in_use_connections)
)

0 comments on commit fe72143

Please sign in to comment.