Skip to content

Commit

Permalink
Support client side caching with ConnectionPool (#3099)
Browse files Browse the repository at this point in the history
* sync

* async

* fixs connection mocks

* fix async connection mock

* fix test_asyncio/test_connection.py::test_single_connection

* add test for cache blacklist and flushdb at the end of each test

* fix review comments
  • Loading branch information
dvora-h authored Jan 7, 2024
1 parent 6d77c6d commit 8cbf7f5
Show file tree
Hide file tree
Showing 10 changed files with 318 additions and 246 deletions.
6 changes: 5 additions & 1 deletion redis/cache.py → redis/_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,11 @@ class _LocalCache:
"""

def __init__(
self, max_size: int, ttl: int, eviction_policy: EvictionPolicy, **kwargs
self,
max_size: int = 100,
ttl: int = 0,
eviction_policy: EvictionPolicy = DEFAULT_EVICTION_POLICY,
**kwargs,
):
self.max_size = max_size
self.ttl = ttl
Expand Down
106 changes: 18 additions & 88 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@
cast,
)

from redis._cache import (
DEFAULT_BLACKLIST,
DEFAULT_EVICTION_POLICY,
DEFAULT_WHITELIST,
_LocalCache,
)
from redis._parsers.helpers import (
_RedisCallbacks,
_RedisCallbacksRESP2,
Expand All @@ -39,12 +45,6 @@
)
from redis.asyncio.lock import Lock
from redis.asyncio.retry import Retry
from redis.cache import (
DEFAULT_BLACKLIST,
DEFAULT_EVICTION_POLICY,
DEFAULT_WHITELIST,
_LocalCache,
)
from redis.client import (
EMPTY_RESPONSE,
NEVER_DECODE,
Expand All @@ -67,7 +67,7 @@
TimeoutError,
WatchError,
)
from redis.typing import ChannelT, EncodableT, KeysT, KeyT, ResponseT
from redis.typing import ChannelT, EncodableT, KeyT
from redis.utils import (
HIREDIS_AVAILABLE,
_set_info_logger,
Expand Down Expand Up @@ -294,6 +294,13 @@ def __init__(
"lib_version": lib_version,
"redis_connect_func": redis_connect_func,
"protocol": protocol,
"cache_enable": cache_enable,
"client_cache": client_cache,
"cache_max_size": cache_max_size,
"cache_ttl": cache_ttl,
"cache_eviction_policy": cache_eviction_policy,
"cache_blacklist": cache_blacklist,
"cache_whitelist": cache_whitelist,
}
# based on input, setup appropriate connection args
if unix_socket_path is not None:
Expand Down Expand Up @@ -350,16 +357,6 @@ def __init__(
# on a set of redis commands
self._single_conn_lock = asyncio.Lock()

self.client_cache = client_cache
if cache_enable:
self.client_cache = _LocalCache(
cache_max_size, cache_ttl, cache_eviction_policy
)
if self.client_cache is not None:
self.cache_blacklist = cache_blacklist
self.cache_whitelist = cache_whitelist
self.client_cache_initialized = False

def __repr__(self):
return (
f"<{self.__class__.__module__}.{self.__class__.__name__}"
Expand All @@ -374,10 +371,6 @@ async def initialize(self: _RedisT) -> _RedisT:
async with self._single_conn_lock:
if self.connection is None:
self.connection = await self.connection_pool.get_connection("_")
if self.client_cache is not None:
self.connection._parser.set_invalidation_push_handler(
self._cache_invalidation_process
)
return self

def set_response_callback(self, command: str, callback: ResponseCallbackT):
Expand Down Expand Up @@ -596,8 +589,6 @@ async def aclose(self, close_connection_pool: Optional[bool] = None) -> None:
close_connection_pool is None and self.auto_close_connection_pool
):
await self.connection_pool.disconnect()
if self.client_cache:
self.client_cache.flush()

@deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
async def close(self, close_connection_pool: Optional[bool] = None) -> None:
Expand Down Expand Up @@ -626,89 +617,28 @@ async def _disconnect_raise(self, conn: Connection, error: Exception):
):
raise error

def _cache_invalidation_process(
self, data: List[Union[str, Optional[List[str]]]]
) -> None:
"""
Invalidate (delete) all redis commands associated with a specific key.
`data` is a list of strings, where the first string is the invalidation message
and the second string is the list of keys to invalidate.
(if the list of keys is None, then all keys are invalidated)
"""
if data[1] is not None:
for key in data[1]:
self.client_cache.invalidate(str_if_bytes(key))
else:
self.client_cache.flush()

async def _get_from_local_cache(self, command: str):
"""
If the command is in the local cache, return the response
"""
if (
self.client_cache is None
or command[0] in self.cache_blacklist
or command[0] not in self.cache_whitelist
):
return None
while not self.connection._is_socket_empty():
await self.connection.read_response(push_request=True)
return self.client_cache.get(command)

def _add_to_local_cache(
self, command: Tuple[str], response: ResponseT, keys: List[KeysT]
):
"""
Add the command and response to the local cache if the command
is allowed to be cached
"""
if (
self.client_cache is not None
and (self.cache_blacklist == [] or command[0] not in self.cache_blacklist)
and (self.cache_whitelist == [] or command[0] in self.cache_whitelist)
):
self.client_cache.set(command, response, keys)

def delete_from_local_cache(self, command: str):
"""
Delete the command from the local cache
"""
try:
self.client_cache.delete(command)
except AttributeError:
pass

# COMMAND EXECUTION AND PROTOCOL PARSING
async def execute_command(self, *args, **options):
"""Execute a command and return a parsed response"""
await self.initialize()
command_name = args[0]
keys = options.pop("keys", None) # keys are used only for client side caching
response_from_cache = await self._get_from_local_cache(args)
pool = self.connection_pool
conn = self.connection or await pool.get_connection(command_name, **options)
response_from_cache = await conn._get_from_local_cache(args)
if response_from_cache is not None:
return response_from_cache
else:
pool = self.connection_pool
conn = self.connection or await pool.get_connection(command_name, **options)

if self.single_connection_client:
await self._single_conn_lock.acquire()
try:
if self.client_cache is not None and not self.client_cache_initialized:
await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, "CLIENT", *("CLIENT", "TRACKING", "ON")
),
lambda error: self._disconnect_raise(conn, error),
)
self.client_cache_initialized = True
response = await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda error: self._disconnect_raise(conn, error),
)
self._add_to_local_cache(args, response, keys)
conn._add_to_local_cache(args, response, keys)
return response
finally:
if self.single_connection_client:
Expand Down
88 changes: 86 additions & 2 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,15 @@
ResponseError,
TimeoutError,
)
from redis.typing import EncodableT
from redis.typing import EncodableT, KeysT, ResponseT
from redis.utils import HIREDIS_AVAILABLE, get_lib_version, str_if_bytes

from .._cache import (
DEFAULT_BLACKLIST,
DEFAULT_EVICTION_POLICY,
DEFAULT_WHITELIST,
_LocalCache,
)
from .._parsers import (
BaseParser,
Encoder,
Expand Down Expand Up @@ -114,6 +120,9 @@ class AbstractConnection:
"encoder",
"ssl_context",
"protocol",
"client_cache",
"cache_blacklist",
"cache_whitelist",
"_reader",
"_writer",
"_parser",
Expand Down Expand Up @@ -148,6 +157,13 @@ def __init__(
encoder_class: Type[Encoder] = Encoder,
credential_provider: Optional[CredentialProvider] = None,
protocol: Optional[int] = 2,
cache_enable: bool = False,
client_cache: Optional[_LocalCache] = None,
cache_max_size: int = 100,
cache_ttl: int = 0,
cache_eviction_policy: str = DEFAULT_EVICTION_POLICY,
cache_blacklist: List[str] = DEFAULT_BLACKLIST,
cache_whitelist: List[str] = DEFAULT_WHITELIST,
):
if (username or password) and credential_provider is not None:
raise DataError(
Expand Down Expand Up @@ -205,6 +221,14 @@ def __init__(
if p < 2 or p > 3:
raise ConnectionError("protocol must be either 2 or 3")
self.protocol = protocol
if cache_enable:
_cache = _LocalCache(cache_max_size, cache_ttl, cache_eviction_policy)
else:
_cache = None
self.client_cache = client_cache if client_cache is not None else _cache
if self.client_cache is not None:
self.cache_blacklist = cache_blacklist
self.cache_whitelist = cache_whitelist

def __del__(self, _warnings: Any = warnings):
# For some reason, the individual streams don't get properly garbage
Expand Down Expand Up @@ -395,6 +419,11 @@ async def on_connect(self) -> None:
# if a database is specified, switch to it. Also pipeline this
if self.db:
await self.send_command("SELECT", self.db)
# if client caching is enabled, start tracking
if self.client_cache:
await self.send_command("CLIENT", "TRACKING", "ON")
await self.read_response()
self._parser.set_invalidation_push_handler(self._cache_invalidation_process)

# read responses from pipeline
for _ in (sent for sent in (self.lib_name, self.lib_version) if sent):
Expand Down Expand Up @@ -429,6 +458,9 @@ async def disconnect(self, nowait: bool = False) -> None:
raise TimeoutError(
f"Timed out closing connection after {self.socket_connect_timeout}"
) from None
finally:
if self.client_cache:
self.client_cache.flush()

async def _send_ping(self):
"""Send PING, expect PONG in return"""
Expand Down Expand Up @@ -646,10 +678,62 @@ def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]
output.append(SYM_EMPTY.join(pieces))
return output

def _is_socket_empty(self):
def _socket_is_empty(self):
"""Check if the socket is empty"""
return not self._reader.at_eof()

def _cache_invalidation_process(
self, data: List[Union[str, Optional[List[str]]]]
) -> None:
"""
Invalidate (delete) all redis commands associated with a specific key.
`data` is a list of strings, where the first string is the invalidation message
and the second string is the list of keys to invalidate.
(if the list of keys is None, then all keys are invalidated)
"""
if data[1] is not None:
self.client_cache.flush()
else:
for key in data[1]:
self.client_cache.invalidate(str_if_bytes(key))

async def _get_from_local_cache(self, command: str):
"""
If the command is in the local cache, return the response
"""
if (
self.client_cache is None
or command[0] in self.cache_blacklist
or command[0] not in self.cache_whitelist
):
return None
while not self._socket_is_empty():
await self.read_response(push_request=True)
return self.client_cache.get(command)

def _add_to_local_cache(
self, command: Tuple[str], response: ResponseT, keys: List[KeysT]
):
"""
Add the command and response to the local cache if the command
is allowed to be cached
"""
if (
self.client_cache is not None
and (self.cache_blacklist == [] or command[0] not in self.cache_blacklist)
and (self.cache_whitelist == [] or command[0] in self.cache_whitelist)
):
self.client_cache.set(command, response, keys)

def delete_from_local_cache(self, command: str):
"""
Delete the command from the local cache
"""
try:
self.client_cache.delete(command)
except AttributeError:
pass


class Connection(AbstractConnection):
"Manages TCP communication to and from a Redis server"
Expand Down
Loading

0 comments on commit 8cbf7f5

Please sign in to comment.