Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

misc: Add Redis async cache #1010

Merged
merged 1 commit into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions backend/endpoints/tests/test_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest
from fastapi.testclient import TestClient
from handler.redis_handler import cache
from handler.redis_handler import sync_cache
from main import app
from models.user import Role

Expand All @@ -12,7 +12,7 @@
@pytest.fixture(autouse=True)
def clear_cache():
yield
cache.flushall()
sync_cache.flushall()


def test_login_logout(admin_user):
Expand Down
24 changes: 12 additions & 12 deletions backend/handler/metadata/base_hander.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import unicodedata
from typing import Final

from handler.redis_handler import cache
from handler.redis_handler import async_cache, sync_cache
from logger.logger import log
from tasks.update_switch_titledb import (
SWITCH_PRODUCT_ID_KEY,
Expand All @@ -18,9 +18,9 @@ def conditionally_set_cache(
index_key: str, filename: str, parent_dir: str = os.path.dirname(__file__)
) -> None:
fixtures_path = os.path.join(parent_dir, "fixtures")
if not cache.exists(index_key):
if not sync_cache.exists(index_key):
index_data = json.loads(open(os.path.join(fixtures_path, filename)).read())
with cache.pipeline() as pipe:
with sync_cache.pipeline() as pipe:
for data_batch in batched(index_data.items(), 2000):
data_map = {k: json.dumps(v) for k, v in dict(data_batch).items()}
pipe.hset(index_key, mapping=data_map)
Expand Down Expand Up @@ -99,15 +99,15 @@ def _normalize_exact_match(name: str) -> str:

async def _ps2_opl_format(self, match: re.Match[str], search_term: str) -> str:
serial_code = match.group(1)
index_entry = cache.hget(PS2_OPL_KEY, serial_code)
index_entry = await async_cache.hget(PS2_OPL_KEY, serial_code)
if index_entry:
index_entry = json.loads(index_entry)
search_term = index_entry["Name"] # type: ignore

return search_term

async def _sony_serial_format(self, index_key: str, serial_code: str) -> str | None:
index_entry = cache.hget(index_key, serial_code)
index_entry = await async_cache.hget(index_key, serial_code)
if index_entry:
index_entry = json.loads(index_entry)
return index_entry["title"]
Expand Down Expand Up @@ -140,15 +140,15 @@ async def _switch_titledb_format(
) -> tuple[str, dict | None]:
title_id = match.group(1)

if not cache.exists(SWITCH_TITLEDB_INDEX_KEY):
if not (await async_cache.exists(SWITCH_TITLEDB_INDEX_KEY)):
log.warning("Fetching the Switch titleID index file...")
await update_switch_titledb_task.run(force=True)

if not cache.exists(SWITCH_TITLEDB_INDEX_KEY):
if not (await async_cache.exists(SWITCH_TITLEDB_INDEX_KEY)):
log.error("Could not fetch the Switch titleID index file")
return search_term, None

index_entry = cache.hget(SWITCH_TITLEDB_INDEX_KEY, title_id)
index_entry = await async_cache.hget(SWITCH_TITLEDB_INDEX_KEY, title_id)
if index_entry:
index_entry = json.loads(index_entry)
return index_entry["name"], index_entry
Expand All @@ -165,15 +165,15 @@ async def _switch_productid_format(
product_id[-3] = "0"
product_id = "".join(product_id)

if not cache.exists(SWITCH_PRODUCT_ID_KEY):
if not (await async_cache.exists(SWITCH_PRODUCT_ID_KEY)):
log.warning("Fetching the Switch productID index file...")
await update_switch_titledb_task.run(force=True)

if not cache.exists(SWITCH_PRODUCT_ID_KEY):
if not (await async_cache.exists(SWITCH_PRODUCT_ID_KEY)):
log.error("Could not fetch the Switch productID index file")
return search_term, None

index_entry = cache.hget(SWITCH_PRODUCT_ID_KEY, product_id)
index_entry = await async_cache.hget(SWITCH_PRODUCT_ID_KEY, product_id)
if index_entry:
index_entry = json.loads(index_entry)
return index_entry["name"], index_entry
Expand All @@ -183,7 +183,7 @@ async def _switch_productid_format(
async def _mame_format(self, search_term: str) -> str:
from handler.filesystem import fs_rom_handler

index_entry = cache.hget(MAME_XML_KEY, search_term)
index_entry = await async_cache.hget(MAME_XML_KEY, search_term)
if index_entry:
index_entry = json.loads(index_entry)
search_term = fs_rom_handler.get_file_name_with_no_tags(
Expand Down
10 changes: 5 additions & 5 deletions backend/handler/metadata/igdb_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import requests
from config import IGDB_CLIENT_ID, IGDB_CLIENT_SECRET
from fastapi import HTTPException, status
from handler.redis_handler import cache
from handler.redis_handler import sync_cache
from logger.logger import log
from requests.exceptions import HTTPError, Timeout
from typing_extensions import TypedDict
Expand Down Expand Up @@ -592,8 +592,8 @@ def _update_twitch_token(self) -> str:
return ""

# Set token in redis to expire in <expires_in> seconds
cache.set("romm:twitch_token", token, ex=expires_in - 10) # type: ignore[attr-defined]
cache.set("romm:twitch_token_expires_at", time.time() + expires_in - 10) # type: ignore[attr-defined]
sync_cache.set("romm:twitch_token", token, ex=expires_in - 10) # type: ignore[attr-defined]
sync_cache.set("romm:twitch_token_expires_at", time.time() + expires_in - 10) # type: ignore[attr-defined]

log.info("Twitch token fetched!")

Expand All @@ -608,8 +608,8 @@ def get_oauth_token(self) -> str:
return ""

# Fetch the token cache
token = cache.get("romm:twitch_token") # type: ignore[attr-defined]
token_expires_at = cache.get("romm:twitch_token_expires_at") # type: ignore[attr-defined]
token = sync_cache.get("romm:twitch_token") # type: ignore[attr-defined]
token_expires_at = sync_cache.get("romm:twitch_token_expires_at") # type: ignore[attr-defined]

if not token or time.time() > float(token_expires_at or 0):
log.warning("Twitch token invalid: fetching a new one...")
Expand Down
33 changes: 28 additions & 5 deletions backend/handler/redis_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

from config import REDIS_DB, REDIS_HOST, REDIS_PASSWORD, REDIS_PORT, REDIS_USERNAME
from logger.logger import log
from redis import Redis, StrictRedis
from redis import Redis
from redis.asyncio import Redis as AsyncRedis
from rq import Queue


Expand Down Expand Up @@ -31,12 +32,12 @@ class QueuePrio(Enum):
low_prio_queue = Queue(name=QueuePrio.LOW.value, connection=redis_client)


def __get_cache() -> StrictRedis:
def __get_sync_cache() -> Redis:
if "pytest" in sys.modules:
# Only import fakeredis when running tests, as it is a test dependency.
from fakeredis import FakeStrictRedis
from fakeredis import FakeRedis

return FakeStrictRedis(version=7)
return FakeRedis(version=7)

log.info(f"Connecting to redis in {sys.argv[0]}...")
# A separate client that auto-decodes responses is needed
Expand All @@ -52,4 +53,26 @@ def __get_cache() -> StrictRedis:
return client


cache = __get_cache()
def __get_async_cache() -> AsyncRedis:
if "pytest" in sys.modules:
# Only import fakeredis when running tests, as it is a test dependency.
from fakeredis import FakeAsyncRedis

return FakeAsyncRedis(version=7)

log.info(f"Connecting to redis in {sys.argv[0]}...")
# A separate client that auto-decodes responses is needed
client = AsyncRedis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
username=REDIS_USERNAME,
db=REDIS_DB,
decode_responses=True,
)
log.info(f"Redis connection established in {sys.argv[0]}!")
return client


sync_cache = __get_sync_cache()
async_cache = __get_async_cache()
4 changes: 2 additions & 2 deletions backend/models/firmware.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import TYPE_CHECKING

from handler.metadata.base_hander import conditionally_set_cache
from handler.redis_handler import cache
from handler.redis_handler import sync_cache
from models.base import BaseModel
from sqlalchemy import BigInteger, ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
Expand Down Expand Up @@ -55,7 +55,7 @@ def platform_name(self) -> str:

@cached_property
def is_verified(self) -> bool:
cache_entry = cache.hget(
cache_entry = sync_cache.hget(
KNOWN_BIOS_KEY, f"{self.platform_slug}:{self.file_name}"
)
if cache_entry:
Expand Down
10 changes: 5 additions & 5 deletions backend/tasks/update_switch_titledb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
ENABLE_SCHEDULED_UPDATE_SWITCH_TITLEDB,
SCHEDULED_UPDATE_SWITCH_TITLEDB_CRON,
)
from handler.redis_handler import cache
from handler.redis_handler import async_cache
from logger.logger import log
from tasks.tasks import RemoteFilePullTask
from utils.iterators import batched
Expand All @@ -32,19 +32,19 @@ async def run(self, force: bool = False) -> None:
index_json = json.loads(content)
relevant_data = {k: v for k, v in index_json.items() if k and v}

with cache.pipeline() as pipe:
async with async_cache.pipeline() as pipe:
for data_batch in batched(relevant_data.items(), 2000):
titledb_map = {k: json.dumps(v) for k, v in dict(data_batch).items()}
pipe.hset(SWITCH_TITLEDB_INDEX_KEY, mapping=titledb_map)
await pipe.hset(SWITCH_TITLEDB_INDEX_KEY, mapping=titledb_map)
for data_batch in batched(relevant_data.items(), 2000):
product_map = {
v["id"]: json.dumps(v)
for v in dict(data_batch).values()
if v.get("id")
}
if product_map:
pipe.hset(SWITCH_PRODUCT_ID_KEY, mapping=product_map)
pipe.execute()
await pipe.hset(SWITCH_PRODUCT_ID_KEY, mapping=product_map)
await pipe.execute()

log.info("Scheduled switch titledb update completed!")

Expand Down
Loading