diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6d5716db..51c7ab69 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,7 @@
+# v4.2.0
+
+* add several new ``Loadable`` convenience classes. ``ReplayCache`` for accessing random elements from a circlecore database, ``ReplayDir`` for folders of ``.osr`` files, and ``ReplayID`` for when you only know the replay id.
+
 # v4.1.2
 
 * correctly account for skips in replays
diff --git a/circleguard/__init__.py b/circleguard/__init__.py
index b519fa76..ce884e67 100644
--- a/circleguard/__init__.py
+++ b/circleguard/__init__.py
@@ -1,7 +1,9 @@
 import logging
 
 from circleguard.circleguard import Circleguard, set_options
-from circleguard.loadable import Check, Replay, ReplayMap, ReplayPath, Map, User, MapUser, ReplayContainer, LoadableContainer, Loadable
+from circleguard.loadable import (Check, Replay, ReplayMap, ReplayPath, Map, User,
+        MapUser, ReplayDir, ReplayContainer, LoadableContainer, Loadable, ReplayCache,
+        CachedReplay, ReplayID)
 from circleguard.enums import Key, RatelimitWeight, Detect, ResultType
 from circleguard.mod import Mod
 from circleguard.utils import TRACE, ColoredFormatter
@@ -26,7 +28,8 @@
 "Circleguard", "set_options",
 # loadables
 "Check", "ReplayContainer", "LoadableContainer", "Map", "User", "MapUser",
-"Replay", "ReplayMap", "ReplayPath", "Loadable",
+"ReplayCache", "Replay", "ReplayMap", "ReplayPath", "CachedReplay", "Loadable",
+"ReplayID", "ReplayDir",
 # enums
 "Key", "RatelimitWeight", "Detect", "ResultType",
 # mod
diff --git a/circleguard/loadable.py b/circleguard/loadable.py
index dd5f93d4..ed130932 100644
--- a/circleguard/loadable.py
+++ b/circleguard/loadable.py
@@ -1,8 +1,13 @@
 import abc
 import logging
+from pathlib import Path
+import os
 
 import circleparse
 import numpy as np
+import sqlite3
+import random
+import wtc
 
 from circleguard.enums import RatelimitWeight
 from circleguard.mod import Mod
@@ -431,6 +436,124 @@ def __eq__(self, loadable):
                 and self.span == loadable.span)
 
 
+class ReplayCache(ReplayContainer):
+    """
+    Contains replays represented by a circlecore database. Primarily useful
+    to randomly sample these replays, rather than directly access them.
+
+    Parameters
+    ----------
+    path: string
+        The path to the database to load replays from.
+    num_maps: int
+        How many (randomly chosen) maps to load replays from.
+    num_replays: int
+        How many replays to load for each map.
+
+    Notes
+    -----
+    :meth:`~.load_info` is an expensive operation for large databases
+    (likely because of inefficient sql queries). Consider using as few instances
+    of this object as possible.
+    """
+    def __init__(self, path, num_maps, num_replays):
+        super().__init__(False)
+        self.path = path
+        self.num_maps = num_maps
+        self.limit = num_replays * num_maps
+        self.replays = []
+        conn = sqlite3.connect(path)
+        self.cursor = conn.cursor()
+
+    def load_info(self, loader):
+        map_ids = self.cursor.execute(
+            """
+            SELECT DISTINCT map_id
+            FROM replays
+            """
+        ).fetchall()
+        # flatten map_ids, because it's actually a list of 1-tuples
+        map_ids = [item[0] for item in map_ids]
+        # sample without replacement so the same map is never chosen twice,
+        # and never ask for more maps than the database holds
+        num_maps = min(self.num_maps, len(map_ids))
+        if num_maps <= 0:
+            return
+        chosen_maps = random.sample(map_ids, k=num_maps)
+
+        # bind the ids as parameters rather than formatting them into the sql
+        placeholders = ", ".join("?" for _ in chosen_maps)
+
+        # TODO LIMIT clause isn't quite right here, some maps will have less
+        # than ``num_replays`` stored
+        infos = self.cursor.execute(
+            f"""
+            SELECT user_id, map_id, replay_data, replay_id, mods
+            FROM replays
+            WHERE map_id IN ({placeholders})
+            LIMIT ?
+            """,
+            (*chosen_maps, self.limit)
+        )
+
+        for info in infos:
+            r = CachedReplay(info[0], info[1], info[4], info[2], info[3])
+            self.replays.append(r)
+
+    def all_replays(self):
+        return self.replays
+
+    def __eq__(self, other):
+        if not isinstance(other, ReplayCache):
+            return False
+        return self.path == other.path
+
+
+class ReplayDir(ReplayContainer):
+    """
+    A folder with replay files inside it.
+
+    Parameters
+    ----------
+    dir_path: string or :class:`pathlib.Path`
+        The path to the directory to load replays from.
+    cache: bool
+        Forwarded unchanged to ``ReplayContainer.__init__``.
+
+    Notes
+    -----
+    Any files not ending in ``.osr`` are ignored.
+
+    Warnings
+    --------
+    Nested directories are not supported (yet). Any folders encountered will be
+    ignored.
+    """
+    def __init__(self, dir_path, cache=None):
+        super().__init__(cache)
+        self.dir_path = Path(dir_path)
+        if not self.dir_path.is_dir():
+            raise ValueError(f"Expected path pointing to {self.dir_path} to be "
+                             "a directory")
+        self.replays = []
+
+    def load_info(self, loader):
+        for name in os.listdir(self.dir_path):
+            path = self.dir_path / name
+            # skip folders too, even ones whose name happens to end in ``.osr``
+            if not name.endswith(".osr") or not path.is_file():
+                continue
+            self.replays.append(ReplayPath(path))
+
+    def all_replays(self):
+        return self.replays
+
+    def __eq__(self, other):
+        if not isinstance(other, ReplayDir):
+            return False
+        return self.dir_path == other.dir_path
+
+
 class Replay(Loadable):
     """
     A replay played by a player.
@@ -779,3 +902,79 @@ def __str__(self):
             return f"Loaded ReplayPath by {self.username} on {self.map_id} at {self.path}"
         else:
             return f"Unloaded ReplayPath at {self.path}"
+
+
+class ReplayID(Replay):
+    """
+    A replay identified only by its replay id.
+
+    Parameters
+    ----------
+    replay_id: int
+        The id of the replay to load.
+    cache: bool
+        Overrides the ``cache`` argument given to :meth:`~.load` when it is
+        not ``None``.
+    """
+    def __init__(self, replay_id, cache=None):
+        super().__init__(RatelimitWeight.HEAVY, cache)
+        self.replay_id = replay_id
+
+    def load(self, loader, cache):
+        # already loaded, don't spend another (heavy) api call
+        if self.loaded:
+            return
+        # TODO file github issue about loading info from replay id,
+        # right now we can literally only load the replay data which
+        # is pretty useless if we don't have a map id or the mods used
+        cache = cache if self.cache is None else self.cache
+        replay_data = loader.replay_data_from_id(self.replay_id, cache)
+        self._process_replay_data(replay_data)
+        self.loaded = True
+
+    def __eq__(self, other):
+        if not isinstance(other, ReplayID):
+            return False
+        return self.replay_id == other.replay_id
+
+
+class CachedReplay(Replay):
+    """
+    A replay built from a row of a circlecore database (see
+    :class:`~.ReplayCache`). Loading it does not use the loader.
+
+    Parameters
+    ----------
+    user_id
+        The ``user_id`` column of the row.
+    map_id
+        The ``map_id`` column of the row.
+    mods
+        The ``mods`` column of the row, converted with :class:`~.Mod`.
+    replay_data
+        The ``replay_data`` column of the row. Decompressed with ``wtc`` and
+        parsed by ``circleparse`` on :meth:`~.load`.
+    replay_id
+        The ``replay_id`` column of the row.
+    """
+    def __init__(self, user_id, map_id, mods, replay_data, replay_id):
+        super().__init__(RatelimitWeight.NONE, False)
+        self.user_id = user_id
+        self.map_id = map_id
+        self.mods = Mod(mods)
+        self.replay_data = replay_data
+        self.replay_id = replay_id
+
+    def load(self, loader, cache):
+        if self.loaded:
+            return
+        decompressed = wtc.decompress(self.replay_data)
+        replay_data = circleparse.parse_replay(decompressed, pure_lzma=True).play_data
+        self._process_replay_data(replay_data)
+        self.loaded = True
+
+    def __eq__(self, other):
+        if not isinstance(other, CachedReplay):
+            return False
+        # could check more but replay_id is already a primary key, guaranteed unique
+        return self.replay_id == other.replay_id
diff --git a/circleguard/loader.py b/circleguard/loader.py
index 533ed361..2a1dbe62 100644
--- a/circleguard/loader.py
+++ b/circleguard/loader.py
@@ -358,6 +358,29 @@ def replay_data(self, replay_info, cache=None):
             self.cacher.cache(lzma_bytes, replay_info)
         return replay_data
 
+    # TODO make this check cache for the replay
+    def replay_data_from_id(self, replay_id, cache):
+        """
+        Retrieves replay data from the api, given a replay id.
+
+        Parameters
+        ----------
+        replay_id: int
+            The id of the replay to retrieve data for.
+        cache: bool
+            Currently unused; the replay is never cached (see the TODOs).
+
+        Returns
+        -------
+        The ``play_data`` of the replay, as parsed by ``circleparse``.
+        """
+        response = self.api.get_replay({"s": replay_id})
+        Loader.check_response(response)
+        lzma = base64.b64decode(response["content"])
+        replay_data = circleparse.parse_replay(lzma, pure_lzma=True).play_data
+        # TODO cache the replay here if the api ever gives us the info we need
+        return replay_data
+
     @lru_cache()
     @request
     def map_id(self, map_hash):
diff --git a/circleguard/result.py b/circleguard/result.py
index ede139b9..7dbe1e3e 100644
--- a/circleguard/result.py
+++ b/circleguard/result.py
@@ -66,10 +66,12 @@ class StealResult(ComparisonResult):
         The other replay involved.
     earlier_replay: :class:`~circleguard.loadable.Replay`
         The earlier of the two replays (when the score was made). This is a
-        reference to either replay1 or replay2.
+        reference to either replay1 or replay2, or ``None`` if one of the
+        replays did not provide a timestamp.
     later_replay: :class:`~circleguard.loadable.Replay`
         The later of the two replays (when the score was made). This is a
-        reference to either replay1 or replay2.
+        reference to either replay1 or replay2, or ``None`` if one of the
+        replays did not provide a timestamp.
     similarity: int
         The similarity of the two replays (the lower, the more similar).
         Similarity is, roughly speaking, a measure of the average pixel
@@ -82,6 +84,14 @@ class StealResult(ComparisonResult):
     def __init__(self, replay1: Replay, replay2: Replay):
         super().__init__(replay1, replay2, ResultType.STEAL)
 
+        # default to ``None`` so both attributes always exist, as documented
+        self.earlier_replay = None
+        self.later_replay = None
+
+        # can't compare ``None`` timestamps
+        if not self.replay1.timestamp or not self.replay2.timestamp:
+            return
+
         if self.replay1.timestamp < self.replay2.timestamp:
             self.earlier_replay: Replay = self.replay1
             self.later_replay: Replay = self.replay2
diff --git a/circleguard/version.py b/circleguard/version.py
index 13ffcf42..0fd7811c 100644
--- a/circleguard/version.py
+++ b/circleguard/version.py
@@ -1 +1 @@
-__version__ = "4.1.2"
+__version__ = "4.2.0"