diff --git a/backend/endpoints/tests/test_identity.py b/backend/endpoints/tests/test_identity.py index 7c8d9ba79..854ae8079 100644 --- a/backend/endpoints/tests/test_identity.py +++ b/backend/endpoints/tests/test_identity.py @@ -2,7 +2,7 @@ import pytest from fastapi.testclient import TestClient -from handler.redis_handler import cache +from handler.redis_handler import sync_cache from main import app from models.user import Role @@ -12,7 +12,7 @@ @pytest.fixture(autouse=True) def clear_cache(): yield - cache.flushall() + sync_cache.flushall() def test_login_logout(admin_user): diff --git a/backend/handler/metadata/base_hander.py b/backend/handler/metadata/base_hander.py index 0de05cc81..f163b8fdf 100644 --- a/backend/handler/metadata/base_hander.py +++ b/backend/handler/metadata/base_hander.py @@ -4,7 +4,7 @@ import unicodedata from typing import Final -from handler.redis_handler import cache +from handler.redis_handler import async_cache, sync_cache from logger.logger import log from tasks.update_switch_titledb import ( SWITCH_PRODUCT_ID_KEY, @@ -18,9 +18,9 @@ def conditionally_set_cache( index_key: str, filename: str, parent_dir: str = os.path.dirname(__file__) ) -> None: fixtures_path = os.path.join(parent_dir, "fixtures") - if not cache.exists(index_key): + if not sync_cache.exists(index_key): index_data = json.loads(open(os.path.join(fixtures_path, filename)).read()) - with cache.pipeline() as pipe: + with sync_cache.pipeline() as pipe: for data_batch in batched(index_data.items(), 2000): data_map = {k: json.dumps(v) for k, v in dict(data_batch).items()} pipe.hset(index_key, mapping=data_map) @@ -99,7 +99,7 @@ def _normalize_exact_match(name: str) -> str: async def _ps2_opl_format(self, match: re.Match[str], search_term: str) -> str: serial_code = match.group(1) - index_entry = cache.hget(PS2_OPL_KEY, serial_code) + index_entry = await async_cache.hget(PS2_OPL_KEY, serial_code) if index_entry: index_entry = json.loads(index_entry) search_term = index_entry["Name"] # type: ignore @@ -107,7 +107,7 @@ async def _ps2_opl_format(self, match: re.Match[str], search_term: str) -> str: return search_term async def _sony_serial_format(self, index_key: str, serial_code: str) -> str | None: - index_entry = cache.hget(index_key, serial_code) + index_entry = await async_cache.hget(index_key, serial_code) if index_entry: index_entry = json.loads(index_entry) return index_entry["title"] @@ -140,15 +140,15 @@ async def _switch_titledb_format( ) -> tuple[str, dict | None]: title_id = match.group(1) - if not cache.exists(SWITCH_TITLEDB_INDEX_KEY): + if not (await async_cache.exists(SWITCH_TITLEDB_INDEX_KEY)): log.warning("Fetching the Switch titleID index file...") await update_switch_titledb_task.run(force=True) - if not cache.exists(SWITCH_TITLEDB_INDEX_KEY): + if not (await async_cache.exists(SWITCH_TITLEDB_INDEX_KEY)): log.error("Could not fetch the Switch titleID index file") return search_term, None - index_entry = cache.hget(SWITCH_TITLEDB_INDEX_KEY, title_id) + index_entry = await async_cache.hget(SWITCH_TITLEDB_INDEX_KEY, title_id) if index_entry: index_entry = json.loads(index_entry) return index_entry["name"], index_entry @@ -165,15 +165,15 @@ async def _switch_productid_format( product_id[-3] = "0" product_id = "".join(product_id) - if not cache.exists(SWITCH_PRODUCT_ID_KEY): + if not (await async_cache.exists(SWITCH_PRODUCT_ID_KEY)): log.warning("Fetching the Switch productID index file...") await update_switch_titledb_task.run(force=True) - if not cache.exists(SWITCH_PRODUCT_ID_KEY): + if not (await async_cache.exists(SWITCH_PRODUCT_ID_KEY)): log.error("Could not fetch the Switch productID index file") return search_term, None - index_entry = cache.hget(SWITCH_PRODUCT_ID_KEY, product_id) + index_entry = await async_cache.hget(SWITCH_PRODUCT_ID_KEY, product_id) if index_entry: index_entry = json.loads(index_entry) return index_entry["name"], index_entry @@ -183,7 +183,7 @@ async def _switch_productid_format( async def _mame_format(self, search_term: str) -> str: from handler.filesystem import fs_rom_handler - index_entry = cache.hget(MAME_XML_KEY, search_term) + index_entry = await async_cache.hget(MAME_XML_KEY, search_term) if index_entry: index_entry = json.loads(index_entry) search_term = fs_rom_handler.get_file_name_with_no_tags( diff --git a/backend/handler/metadata/igdb_handler.py b/backend/handler/metadata/igdb_handler.py index 59b36c35d..b4ec6b694 100644 --- a/backend/handler/metadata/igdb_handler.py +++ b/backend/handler/metadata/igdb_handler.py @@ -8,7 +8,7 @@ import requests from config import IGDB_CLIENT_ID, IGDB_CLIENT_SECRET from fastapi import HTTPException, status -from handler.redis_handler import cache +from handler.redis_handler import sync_cache from logger.logger import log from requests.exceptions import HTTPError, Timeout from typing_extensions import TypedDict @@ -592,8 +592,8 @@ def _update_twitch_token(self) -> str: return "" # Set token in redis to expire in seconds - cache.set("romm:twitch_token", token, ex=expires_in - 10) # type: ignore[attr-defined] - cache.set("romm:twitch_token_expires_at", time.time() + expires_in - 10) # type: ignore[attr-defined] + sync_cache.set("romm:twitch_token", token, ex=expires_in - 10) # type: ignore[attr-defined] + sync_cache.set("romm:twitch_token_expires_at", time.time() + expires_in - 10) # type: ignore[attr-defined] log.info("Twitch token fetched!") @@ -608,8 +608,8 @@ def get_oauth_token(self) -> str: return "" # Fetch the token cache - token = cache.get("romm:twitch_token") # type: ignore[attr-defined] - token_expires_at = cache.get("romm:twitch_token_expires_at") # type: ignore[attr-defined] + token = sync_cache.get("romm:twitch_token") # type: ignore[attr-defined] + token_expires_at = sync_cache.get("romm:twitch_token_expires_at") # type: ignore[attr-defined] if not token or time.time() > float(token_expires_at or 0): log.warning("Twitch token invalid: fetching a new one...") diff --git a/backend/handler/redis_handler.py b/backend/handler/redis_handler.py index 4e3513115..ec15dddf6 100644 --- a/backend/handler/redis_handler.py +++ b/backend/handler/redis_handler.py @@ -3,7 +3,8 @@ from config import REDIS_DB, REDIS_HOST, REDIS_PASSWORD, REDIS_PORT, REDIS_USERNAME from logger.logger import log -from redis import Redis, StrictRedis +from redis import Redis +from redis.asyncio import Redis as AsyncRedis from rq import Queue @@ -31,12 +32,12 @@ class QueuePrio(Enum): low_prio_queue = Queue(name=QueuePrio.LOW.value, connection=redis_client) -def __get_cache() -> StrictRedis: +def __get_sync_cache() -> Redis: if "pytest" in sys.modules: # Only import fakeredis when running tests, as it is a test dependency. - from fakeredis import FakeStrictRedis + from fakeredis import FakeRedis - return FakeStrictRedis(version=7) + return FakeRedis(version=7) log.info(f"Connecting to redis in {sys.argv[0]}...") # A separate client that auto-decodes responses is needed @@ -52,4 +53,26 @@ def __get_cache() -> StrictRedis: return client -cache = __get_cache() +def __get_async_cache() -> AsyncRedis: + if "pytest" in sys.modules: + # Only import fakeredis when running tests, as it is a test dependency. + from fakeredis import FakeAsyncRedis + + return FakeAsyncRedis(version=7) + + log.info(f"Connecting to redis in {sys.argv[0]}...") + # A separate client that auto-decodes responses is needed + client = AsyncRedis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + username=REDIS_USERNAME, + db=REDIS_DB, + decode_responses=True, + ) + log.info(f"Redis connection established in {sys.argv[0]}!") + return client + + +sync_cache = __get_sync_cache() +async_cache = __get_async_cache() diff --git a/backend/models/firmware.py b/backend/models/firmware.py index 197756f26..cb79190f2 100644 --- a/backend/models/firmware.py +++ b/backend/models/firmware.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING from handler.metadata.base_hander import conditionally_set_cache -from handler.redis_handler import cache +from handler.redis_handler import sync_cache from models.base import BaseModel from sqlalchemy import BigInteger, ForeignKey, String from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -55,7 +55,7 @@ def platform_name(self) -> str: @cached_property def is_verified(self) -> bool: - cache_entry = cache.hget( + cache_entry = sync_cache.hget( KNOWN_BIOS_KEY, f"{self.platform_slug}:{self.file_name}" ) if cache_entry: diff --git a/backend/tasks/update_switch_titledb.py b/backend/tasks/update_switch_titledb.py index 062113094..c78be002a 100644 --- a/backend/tasks/update_switch_titledb.py +++ b/backend/tasks/update_switch_titledb.py @@ -5,7 +5,7 @@ ENABLE_SCHEDULED_UPDATE_SWITCH_TITLEDB, SCHEDULED_UPDATE_SWITCH_TITLEDB_CRON, ) -from handler.redis_handler import cache +from handler.redis_handler import async_cache from logger.logger import log from tasks.tasks import RemoteFilePullTask from utils.iterators import batched @@ -32,10 +32,10 @@ async def run(self, force: bool = False) -> None: index_json = json.loads(content) relevant_data = {k: v for k, v in index_json.items() if k and v} - with cache.pipeline() as pipe: + async with async_cache.pipeline() as pipe: for data_batch in batched(relevant_data.items(), 2000): titledb_map = {k: json.dumps(v) for k, v in dict(data_batch).items()} - pipe.hset(SWITCH_TITLEDB_INDEX_KEY, mapping=titledb_map) + await pipe.hset(SWITCH_TITLEDB_INDEX_KEY, mapping=titledb_map) for data_batch in batched(relevant_data.items(), 2000): product_map = { v["id"]: json.dumps(v) @@ -43,8 +43,8 @@ async def run(self, force: bool = False) -> None: if v.get("id") } if product_map: - pipe.hset(SWITCH_PRODUCT_ID_KEY, mapping=product_map) - pipe.execute() + await pipe.hset(SWITCH_PRODUCT_ID_KEY, mapping=product_map) + await pipe.execute() log.info("Scheduled switch titledb update completed!")