Support client side caching with ConnectionPool #3099

Merged: 7 commits, Jan 7, 2024
6 changes: 5 additions & 1 deletion redis/cache.py → redis/_cache.py
@@ -178,7 +178,11 @@ class _LocalCache:
"""

def __init__(
self, max_size: int, ttl: int, eviction_policy: EvictionPolicy, **kwargs
self,
max_size: int = 100,
ttl: int = 0,
eviction_policy: EvictionPolicy = DEFAULT_EVICTION_POLICY,
**kwargs,
):
self.max_size = max_size
self.ttl = ttl
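
With the defaults added above, the cache can now be constructed without arguments. A minimal sketch of how the (private) _LocalCache API is used elsewhere in this PR, assuming a cache miss returns None and that plain string keys are passed in:

from redis._cache import _LocalCache

cache = _LocalCache()  # max_size=100, ttl=0, DEFAULT_EVICTION_POLICY
cache.set(("GET", "foo"), b"bar", ["foo"])  # cache the response and record the key it depends on
print(cache.get(("GET", "foo")))  # b'bar'
cache.invalidate("foo")           # drop every cached command that touched "foo"
print(cache.get(("GET", "foo")))  # None after invalidation
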
106 changes: 18 additions & 88 deletions redis/asyncio/client.py
@@ -25,6 +25,12 @@
cast,
)

from redis._cache import (
DEFAULT_BLACKLIST,
DEFAULT_EVICTION_POLICY,
DEFAULT_WHITELIST,
_LocalCache,
)
from redis._parsers.helpers import (
_RedisCallbacks,
_RedisCallbacksRESP2,
@@ -39,12 +45,6 @@
)
from redis.asyncio.lock import Lock
from redis.asyncio.retry import Retry
from redis.cache import (
DEFAULT_BLACKLIST,
DEFAULT_EVICTION_POLICY,
DEFAULT_WHITELIST,
_LocalCache,
)
from redis.client import (
EMPTY_RESPONSE,
NEVER_DECODE,
@@ -67,7 +67,7 @@
TimeoutError,
WatchError,
)
from redis.typing import ChannelT, EncodableT, KeysT, KeyT, ResponseT
from redis.typing import ChannelT, EncodableT, KeyT
from redis.utils import (
HIREDIS_AVAILABLE,
_set_info_logger,
@@ -294,6 +294,13 @@ def __init__(
"lib_version": lib_version,
"redis_connect_func": redis_connect_func,
"protocol": protocol,
"cache_enable": cache_enable,
Reviewer comment (Contributor), see the usage sketch after this hunk:
Let's have you, @vladvildanov and @uglide sync on these just to finalize before the next release. Getting them into the specs finally too - WDYT?

"client_cache": client_cache,
"cache_max_size": cache_max_size,
"cache_ttl": cache_ttl,
"cache_eviction_policy": cache_eviction_policy,
"cache_blacklist": cache_blacklist,
"cache_whitelist": cache_whitelist,
}
# based on input, setup appropriate connection args
if unix_socket_path is not None:
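
Taken together, the new keyword arguments above suggest the following usage. This is a minimal sketch, assuming the asyncio Redis client forwards these options through connection_kwargs to every connection in the pool, and that RESP3 is required for the invalidation push messages:

import asyncio

import redis.asyncio as redis


async def main():
    r = redis.Redis(
        protocol=3,          # client-side caching relies on RESP3 push messages
        cache_enable=True,   # each connection builds its own _LocalCache
        cache_max_size=100,
        cache_ttl=0,         # 0 disables TTL-based expiry
    )
    await r.set("foo", "bar")
    print(await r.get("foo"))  # first read is served by the server
    print(await r.get("foo"))  # repeated read may be served from the local cache
    await r.aclose()


asyncio.run(main())
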
@@ -350,16 +357,6 @@ def __init__(
# on a set of redis commands
self._single_conn_lock = asyncio.Lock()

self.client_cache = client_cache
if cache_enable:
self.client_cache = _LocalCache(
cache_max_size, cache_ttl, cache_eviction_policy
)
if self.client_cache is not None:
self.cache_blacklist = cache_blacklist
self.cache_whitelist = cache_whitelist
self.client_cache_initialized = False

def __repr__(self):
return (
f"<{self.__class__.__module__}.{self.__class__.__name__}"
@@ -374,10 +371,6 @@ async def initialize(self: _RedisT) -> _RedisT:
async with self._single_conn_lock:
if self.connection is None:
self.connection = await self.connection_pool.get_connection("_")
if self.client_cache is not None:
self.connection._parser.set_invalidation_push_handler(
self._cache_invalidation_process
)
return self

def set_response_callback(self, command: str, callback: ResponseCallbackT):
@@ -596,8 +589,6 @@ async def aclose(self, close_connection_pool: Optional[bool] = None) -> None:
close_connection_pool is None and self.auto_close_connection_pool
):
await self.connection_pool.disconnect()
if self.client_cache:
self.client_cache.flush()

@deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
async def close(self, close_connection_pool: Optional[bool] = None) -> None:
@@ -626,89 +617,28 @@ async def _disconnect_raise(self, conn: Connection, error: Exception):
):
raise error

def _cache_invalidation_process(
self, data: List[Union[str, Optional[List[str]]]]
) -> None:
"""
Invalidate (delete) all redis commands associated with a specific key.
`data` is a list of strings, where the first string is the invalidation message
and the second string is the list of keys to invalidate.
(if the list of keys is None, then all keys are invalidated)
"""
if data[1] is not None:
for key in data[1]:
self.client_cache.invalidate(str_if_bytes(key))
else:
self.client_cache.flush()

async def _get_from_local_cache(self, command: str):
"""
If the command is in the local cache, return the response
"""
if (
self.client_cache is None
or command[0] in self.cache_blacklist
or command[0] not in self.cache_whitelist
):
return None
while not self.connection._is_socket_empty():
await self.connection.read_response(push_request=True)
return self.client_cache.get(command)

def _add_to_local_cache(
self, command: Tuple[str], response: ResponseT, keys: List[KeysT]
):
"""
Add the command and response to the local cache if the command
is allowed to be cached
"""
if (
self.client_cache is not None
and (self.cache_blacklist == [] or command[0] not in self.cache_blacklist)
and (self.cache_whitelist == [] or command[0] in self.cache_whitelist)
):
self.client_cache.set(command, response, keys)

def delete_from_local_cache(self, command: str):
"""
Delete the command from the local cache
"""
try:
self.client_cache.delete(command)
except AttributeError:
pass

# COMMAND EXECUTION AND PROTOCOL PARSING
async def execute_command(self, *args, **options):
"""Execute a command and return a parsed response"""
await self.initialize()
command_name = args[0]
keys = options.pop("keys", None) # keys are used only for client side caching
response_from_cache = await self._get_from_local_cache(args)
pool = self.connection_pool
conn = self.connection or await pool.get_connection(command_name, **options)
response_from_cache = await conn._get_from_local_cache(args)
if response_from_cache is not None:
return response_from_cache
else:
pool = self.connection_pool
conn = self.connection or await pool.get_connection(command_name, **options)

if self.single_connection_client:
await self._single_conn_lock.acquire()
try:
if self.client_cache is not None and not self.client_cache_initialized:
await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, "CLIENT", *("CLIENT", "TRACKING", "ON")
),
lambda error: self._disconnect_raise(conn, error),
)
self.client_cache_initialized = True
response = await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda error: self._disconnect_raise(conn, error),
)
self._add_to_local_cache(args, response, keys)
conn._add_to_local_cache(args, response, keys)
return response
finally:
if self.single_connection_client:
88 changes: 86 additions & 2 deletions redis/asyncio/connection.py
@@ -47,9 +47,15 @@
ResponseError,
TimeoutError,
)
from redis.typing import EncodableT
from redis.typing import EncodableT, KeysT, ResponseT
from redis.utils import HIREDIS_AVAILABLE, get_lib_version, str_if_bytes

from .._cache import (
DEFAULT_BLACKLIST,
DEFAULT_EVICTION_POLICY,
DEFAULT_WHITELIST,
_LocalCache,
)
from .._parsers import (
BaseParser,
Encoder,
@@ -114,6 +120,9 @@ class AbstractConnection:
"encoder",
"ssl_context",
"protocol",
"client_cache",
"cache_blacklist",
"cache_whitelist",
"_reader",
"_writer",
"_parser",
@@ -148,6 +157,13 @@ def __init__(
encoder_class: Type[Encoder] = Encoder,
credential_provider: Optional[CredentialProvider] = None,
protocol: Optional[int] = 2,
cache_enable: bool = False,
client_cache: Optional[_LocalCache] = None,
cache_max_size: int = 100,
cache_ttl: int = 0,
cache_eviction_policy: str = DEFAULT_EVICTION_POLICY,
cache_blacklist: List[str] = DEFAULT_BLACKLIST,
cache_whitelist: List[str] = DEFAULT_WHITELIST,
):
if (username or password) and credential_provider is not None:
raise DataError(
@@ -205,6 +221,14 @@ def __init__(
if p < 2 or p > 3:
raise ConnectionError("protocol must be either 2 or 3")
self.protocol = protocol
if cache_enable:
Reviewer comment from @chayim (Contributor), Jan 7, 2024, see the sketch after this hunk:
Should we do an enforce check on RESP version, like others did - just for CSC? It might provide a better error for users.

_cache = _LocalCache(cache_max_size, cache_ttl, cache_eviction_policy)
else:
_cache = None
self.client_cache = client_cache if client_cache is not None else _cache
if self.client_cache is not None:
self.cache_blacklist = cache_blacklist
self.cache_whitelist = cache_whitelist

def __del__(self, _warnings: Any = warnings):
# For some reason, the individual streams don't get properly garbage
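
A guard along the lines @chayim suggests might look roughly like the following. This is a hypothetical sketch and not part of the diff; the helper name, its placement, and the exception type are assumptions:

from redis.exceptions import RedisError


def enforce_resp3_for_caching(protocol, cache_enable, client_cache):
    # Hypothetical check: client-side caching depends on RESP3 invalidation
    # push messages, so a RESP2 connection with caching enabled could fail
    # fast with a clearer error than a parser failure later on.
    if (cache_enable or client_cache is not None) and int(protocol) != 3:
        raise RedisError(
            f"Client-side caching requires RESP3 (protocol=3); got protocol={protocol!r}"
        )
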
@@ -395,6 +419,11 @@ async def on_connect(self) -> None:
# if a database is specified, switch to it. Also pipeline this
if self.db:
await self.send_command("SELECT", self.db)
# if client caching is enabled, start tracking
if self.client_cache:
await self.send_command("CLIENT", "TRACKING", "ON")
await self.read_response()
self._parser.set_invalidation_push_handler(self._cache_invalidation_process)

# read responses from pipeline
for _ in (sent for sent in (self.lib_name, self.lib_version) if sent):
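
For reference, the handler registered above receives a two-element push message once CLIENT TRACKING ON is active. The payloads below are illustrative (key names invented) and match the handling in _cache_invalidation_process later in this diff:

# Illustrative payloads only: what the RESP3 parser hands to the invalidation
# push handler after CLIENT TRACKING ON.
invalidate_some = ["invalidate", [b"user:1", b"user:2"]]  # invalidate just these keys
invalidate_all = ["invalidate", None]  # None means flush the entire local cache
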
@@ -429,6 +458,9 @@ async def disconnect(self, nowait: bool = False) -> None:
raise TimeoutError(
f"Timed out closing connection after {self.socket_connect_timeout}"
) from None
finally:
if self.client_cache:
self.client_cache.flush()

async def _send_ping(self):
"""Send PING, expect PONG in return"""
@@ -646,10 +678,62 @@ def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]
output.append(SYM_EMPTY.join(pieces))
return output

def _is_socket_empty(self):
def _socket_is_empty(self):
"""Check if the socket is empty"""
return not self._reader.at_eof()

def _cache_invalidation_process(
self, data: List[Union[str, Optional[List[str]]]]
) -> None:
"""
Invalidate (delete) all redis commands associated with a specific key.
`data` is a list of strings, where the first string is the invalidation message
and the second string is the list of keys to invalidate.
(if the list of keys is None, then all keys are invalidated)
"""
if data[1] is None:
self.client_cache.flush()
else:
for key in data[1]:
self.client_cache.invalidate(str_if_bytes(key))

async def _get_from_local_cache(self, command: str):
"""
If the command is in the local cache, return the response
"""
if (
self.client_cache is None
or command[0] in self.cache_blacklist
or command[0] not in self.cache_whitelist
):
return None
while not self._socket_is_empty():
await self.read_response(push_request=True)
return self.client_cache.get(command)

def _add_to_local_cache(
self, command: Tuple[str], response: ResponseT, keys: List[KeysT]
):
"""
Add the command and response to the local cache if the command
is allowed to be cached
"""
if (
self.client_cache is not None
and (self.cache_blacklist == [] or command[0] not in self.cache_blacklist)
and (self.cache_whitelist == [] or command[0] in self.cache_whitelist)
):
self.client_cache.set(command, response, keys)

def delete_from_local_cache(self, command: str):
"""
Delete the command from the local cache
"""
try:
self.client_cache.delete(command)
except AttributeError:
pass


class Connection(AbstractConnection):
"Manages TCP communication to and from a Redis server"