From 7a0cb59bcadab42b6e0ed6ccfa4805696baf51bf Mon Sep 17 00:00:00 2001 From: Juliano Fernandes Date: Tue, 28 May 2024 16:43:38 -0300 Subject: [PATCH] feat: implement cache protocol for redis client * Add custom Redis client class to override the `get` and `set` methods in order to comply with the Cache protocol from auth module. --- fastapi_extras/databases/redis.py | 16 +++++++++++++--- tests/databases/test_redis.py | 25 ++++++++++++++++++------- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/fastapi_extras/databases/redis.py b/fastapi_extras/databases/redis.py index 0e81781..a26e693 100644 --- a/fastapi_extras/databases/redis.py +++ b/fastapi_extras/databases/redis.py @@ -1,15 +1,25 @@ -from typing import AsyncGenerator, Union +from typing import Any, AsyncGenerator, Optional, Union import redis.asyncio as redis from pydantic import RedisDsn +class Redis(redis.Redis): + async def get(self, key: str, *args: Any, **kwargs: Any) -> Optional[str]: + return await super().get(key, *args, **kwargs) + + async def set( + self, key: str, value: str, ttl: Optional[int] = None, *args: Any, **kwargs: Any + ) -> None: + await super().set(key, value, ttl, *args, **kwargs) + + class RedisManager: def __init__(self, url: Union[RedisDsn, str]): self.pool = redis.ConnectionPool.from_url(str(url)) - async def __call__(self) -> AsyncGenerator[redis.Redis, None]: - cli = redis.Redis(connection_pool=self.pool) + async def __call__(self) -> AsyncGenerator[Redis, None]: + cli = Redis(connection_pool=self.pool) try: yield cli diff --git a/tests/databases/test_redis.py b/tests/databases/test_redis.py index 4cc340d..2c6d292 100644 --- a/tests/databases/test_redis.py +++ b/tests/databases/test_redis.py @@ -1,29 +1,35 @@ -from typing import Union +from typing import Any, Optional import pytest from fastapi import Depends, FastAPI, HTTPException, status from fastapi.testclient import TestClient from pydantic import BaseModel from pytest import MonkeyPatch -from redis.asyncio import Redis from typing_extensions import Annotated -from fastapi_extras.databases.redis import RedisManager +from fastapi_extras.databases.redis import Redis, RedisManager class FakeRedis: db = {} closed = [] - async def get(self, key: str) -> Union[bytes, None]: + async def get(self, key: str) -> Optional[str]: return FakeRedis.db.get(key) - async def set(self, key: str, val: str): - FakeRedis.db[key] = val.encode() + async def set( + self, key: str, val: str, ttl: Optional[int] = None, *args: Any, **kwargs: Any + ) -> None: + FakeRedis.db[key] = val async def aclose(self): FakeRedis.closed.append(id(self)) + @classmethod + def flush(cls): + cls.db.clear() + cls.closed.clear() + class Item(BaseModel): key: str @@ -36,7 +42,7 @@ class Item(BaseModel): @app.get("/items/{key}", response_model=Item) async def read(key: str, redis: Annotated[Redis, Depends(redis_manager)]): - item = await redis.get(key) + item = await redis.get(key) or "" if not item: HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Item not found") @@ -60,6 +66,11 @@ def redis_mock(monkeypatch: MonkeyPatch): monkeypatch.setattr("redis.asyncio.Redis.aclose", FakeRedis.aclose) +@pytest.fixture(autouse=True) +def redis_flush(): + FakeRedis.flush() + + def test_redis_manager(): data = [ {"key": "foo", "val": "bar"},