Skip to content

Commit

Permalink
Merge pull request #148 from circleguard/new-loadables
Browse files Browse the repository at this point in the history
new loadables
  • Loading branch information
tybug authored May 25, 2020
2 parents b7d4271 + 2529072 commit 987f769
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 5 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# v4.2.0

* add several new ``Loadable`` convenience classes. ``ReplayCache`` for accessing random elements from a circlecore database, ``ReplayDir`` for folders of ``.osr`` files, and ``ReplayID`` for when you only know the replay id.

# v4.1.2

* correctly account for skips in replays
Expand Down
7 changes: 5 additions & 2 deletions circleguard/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import logging

from circleguard.circleguard import Circleguard, set_options
from circleguard.loadable import Check, Replay, ReplayMap, ReplayPath, Map, User, MapUser, ReplayContainer, LoadableContainer, Loadable
from circleguard.loadable import (Check, Replay, ReplayMap, ReplayPath, Map, User,
MapUser, ReplayDir, ReplayContainer, LoadableContainer, Loadable, ReplayCache,
CachedReplay, ReplayID)
from circleguard.enums import Key, RatelimitWeight, Detect, ResultType
from circleguard.mod import Mod
from circleguard.utils import TRACE, ColoredFormatter
Expand All @@ -26,7 +28,8 @@
"Circleguard", "set_options",
# loadables
"Check", "ReplayContainer", "LoadableContainer", "Map", "User", "MapUser",
"Replay", "ReplayMap", "ReplayPath", "Loadable",
"ReplayCache", "Replay", "ReplayMap", "ReplayPath", "CachedReplay", "Loadable",
"ReplayID", "ReplayDir",
# enums
"Key", "RatelimitWeight", "Detect", "ResultType",
# mod
Expand Down
144 changes: 144 additions & 0 deletions circleguard/loadable.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import abc
import logging
from pathlib import Path
import os

import circleparse
import numpy as np
import sqlite3
import random
import wtc

from circleguard.enums import RatelimitWeight
from circleguard.mod import Mod
Expand Down Expand Up @@ -431,6 +436,105 @@ def __eq__(self, loadable):
and self.span == loadable.span)


class ReplayCache(ReplayContainer):
    """
    Contains replays represented by a circlecore database. Primarily useful
    to randomly sample these replays, rather than directly access them.

    Parameters
    ----------
    path: string
        The path to the database to load replays from.
    num_maps: int
        How many (randomly chosen) distinct maps to load replays from. If the
        database holds fewer distinct maps than this, all of them are used.
    num_replays: int
        How many replays to load for each map. Note that this is currently
        enforced only as an overall cap of ``num_replays * num_maps`` rows
        (stored as ``limit``), not per map.

    Notes
    -----
    :meth:`~.load_info` is an expensive operation for large databases
    (likely because of inefficient sql queries). Consider using as few instances
    of this object as possible.
    """
    def __init__(self, path, num_maps, num_replays):
        super().__init__(False)
        self.path = path
        self.num_maps = num_maps
        self.limit = num_replays * num_maps
        self.replays = []
        conn = sqlite3.connect(path)
        self.cursor = conn.cursor()

    def load_info(self, loader):
        """
        Picks random maps from the database and creates a
        :class:`CachedReplay` for each matching row, appending them to
        ``self.replays``.

        Parameters
        ----------
        loader: object
            Unused; accepted for consistency with other ``load_info`` methods.
        """
        rows = self.cursor.execute(
            """
            SELECT DISTINCT map_id
            FROM replays
            """
        ).fetchall()
        # each row is a 1-tuple, flatten to a plain list of ids
        map_ids = [row[0] for row in rows]
        if not map_ids or self.num_maps <= 0:
            # nothing to sample from (or nothing requested); an empty
            # selection would otherwise produce an invalid query
            return

        # sample *without* replacement so we really get distinct maps, and
        # never ask for more maps than the database contains
        chosen_maps = random.sample(map_ids, min(self.num_maps, len(map_ids)))

        # bind values as parameters instead of formatting them into the sql
        placeholders = ", ".join("?" * len(chosen_maps))

        # TODO LIMIT clause isn't quite right here, some maps will have less
        # than ``num_replays`` stored
        infos = self.cursor.execute(
            f"""
            SELECT user_id, map_id, replay_data, replay_id, mods
            FROM replays
            WHERE map_id IN ({placeholders})
            LIMIT ?
            """,
            (*chosen_maps, self.limit)
        )

        for user_id, map_id, replay_data, replay_id, mods in infos:
            r = CachedReplay(user_id, map_id, mods, replay_data, replay_id)
            self.replays.append(r)

    def all_replays(self):
        """
        Returns the list of replays created so far by :meth:`~.load_info`.
        """
        return self.replays

    def __eq__(self, other):
        # other loadables may also expose a ``path`` attribute; only another
        # ReplayCache pointing at the same database is considered equal
        if not isinstance(other, ReplayCache):
            return NotImplemented
        return self.path == other.path


class ReplayDir(ReplayContainer):
    """
    A folder with replay files inside it.

    Parameters
    ----------
    dir_path: str or os.PathLike
        The path to the directory containing the ``.osr`` files.
    cache: bool
        Passed through unchanged to the parent container.

    Raises
    ------
    ValueError
        If ``dir_path`` does not point to a directory.

    Notes
    -----
    Any files not ending in ``.osr`` are ignored.

    Warnings
    --------
    Nested directories are not supported (yet). Any folders encountered will be
    ignored.
    """
    def __init__(self, dir_path, cache=None):
        super().__init__(cache)
        self.dir_path = Path(dir_path)
        if not self.dir_path.is_dir():
            raise ValueError(f"Expected path pointing to {self.dir_path} to be "
                "a directory")
        self.replays = []

    def load_info(self, loader):
        """
        Creates a :class:`ReplayPath` for every ``.osr`` file directly inside
        the directory and appends it to ``self.replays``.

        Parameters
        ----------
        loader: object
            Unused; accepted for consistency with other ``load_info`` methods.
        """
        for entry in self.dir_path.iterdir():
            if not entry.name.endswith(".osr"):
                continue
            # a folder may itself be named ``something.osr``; skip it so that
            # folders really are ignored as documented
            if not entry.is_file():
                continue
            self.replays.append(ReplayPath(entry))

    def all_replays(self):
        """
        Returns the list of replays created so far by :meth:`~.load_info`.
        """
        return self.replays

    def __eq__(self, other):
        # only another ReplayDir pointing at the same folder is equal
        if not isinstance(other, ReplayDir):
            return NotImplemented
        return self.dir_path == other.dir_path


class Replay(Loadable):
"""
A replay played by a player.
Expand Down Expand Up @@ -779,3 +883,43 @@ def __str__(self):
return f"Loaded ReplayPath by {self.username} on {self.map_id} at {self.path}"
else:
return f"Unloaded ReplayPath at {self.path}"


class ReplayID(Replay):
    """
    A replay for which only the replay id is known.

    Parameters
    ----------
    replay_id: int
        The id of the replay.
    cache: bool
        Whether to cache this replay once it is loaded. If ``None``, the
        ``cache`` value passed to :meth:`~.load` is used instead.
    """
    def __init__(self, replay_id, cache=None):
        super().__init__(RatelimitWeight.HEAVY, cache)
        self.replay_id = replay_id

    def load(self, loader, cache):
        """
        Loads the replay data for this replay id through ``loader``.

        Parameters
        ----------
        loader: object
            Must provide ``replay_data_from_id(replay_id, cache)``.
        cache: bool
            Fallback cache setting, used only if this replay's own ``cache``
            is ``None``.
        """
        # loading is a heavy api call; don't repeat it for a loaded replay
        # (same guard as ``CachedReplay.load``)
        if self.loaded:
            return
        # TODO file github issue about loading info from replay id,
        # right now we can literally only load the replay data which
        # is pretty useless if we don't have a map id or the mods used
        cache = cache if self.cache is None else self.cache
        replay_data = loader.replay_data_from_id(self.replay_id, cache)
        self._process_replay_data(replay_data)
        self.loaded = True

    def __eq__(self, other):
        # only compare against other ReplayID instances
        if not isinstance(other, ReplayID):
            return NotImplemented
        return self.replay_id == other.replay_id


class CachedReplay(Replay):
    """
    A replay whose (still compressed) data was read from a circlecore
    database rather than requested from the api.

    Parameters
    ----------
    user_id: int
        The id of the user who played the replay.
    map_id: int
        The id of the map the replay was played on.
    mods: int
        The mods the replay was played with; wrapped in a :class:`Mod`.
    replay_data: bytes
        The replay data as stored in the database. It is decompressed and
        parsed in :meth:`~.load`.
    replay_id: int
        The id of the replay.
    """
    def __init__(self, user_id, map_id, mods, replay_data, replay_id):
        # no api call is needed to load this replay, and it is never cached
        super().__init__(RatelimitWeight.NONE, False)
        self.user_id = user_id
        self.map_id = map_id
        self.mods = Mod(mods)
        self.replay_data = replay_data
        self.replay_id = replay_id

    def load(self, loader, cache):
        """
        Decompresses and parses the stored replay data. Does nothing if the
        replay is already loaded. ``loader`` and ``cache`` are unused.
        """
        if not self.loaded:
            raw_lzma = wtc.decompress(self.replay_data)
            parsed = circleparse.parse_replay(raw_lzma, pure_lzma=True)
            self._process_replay_data(parsed.play_data)
            self.loaded = True

    def __eq__(self, other):
        # replay_id is a primary key in the database and therefore unique,
        # so comparing it alone is enough
        return self.replay_id == other.replay_id
17 changes: 17 additions & 0 deletions circleguard/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,23 @@ def replay_data(self, replay_info, cache=None):
self.cacher.cache(lzma_bytes, replay_info)
return replay_data

# TODO make this check cache for the replay
def replay_data_from_id(self, replay_id, cache):
    """
    Retrieves replay data from the api, given a replay id.

    Parameters
    ----------
    replay_id: int
        The id of the replay to retrieve data for.
    cache: bool
        Currently unused; the replay is never cached by this method.

    Returns
    -------
    The ``play_data`` of the parsed replay.
    """
    api_response = self.api.get_replay({"s": replay_id})
    Loader.check_response(api_response)
    # the api returns the lzma stream base64 encoded under "content"
    lzma_bytes = base64.b64decode(api_response["content"])
    parsed = circleparse.parse_replay(lzma_bytes, pure_lzma=True)
    # TODO cache the replay here if the api ever gives us the info we need
    return parsed.play_data

@lru_cache()
@request
def map_id(self, map_hash):
Expand Down
10 changes: 8 additions & 2 deletions circleguard/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,12 @@ class StealResult(ComparisonResult):
The other replay involved.
earlier_replay: :class:`~circleguard.loadable.Replay`
The earlier of the two replays (when the score was made). This is a
reference to either replay1 or replay2.
reference to either replay1 or replay2, or ``None`` if one of the
replays did not provide a timestamp.
later_replay: :class:`~circleguard.loadable.Replay`
The later of the two replays (when the score was made). This is a
reference to either replay1 or replay2.
reference to either replay1 or replay2, or ``None`` if one of the
replays did not provide a timestamp.
similarity: int
The similarity of the two replays (the lower, the more similar).
Similarity is, roughly speaking, a measure of the average pixel
Expand All @@ -82,6 +84,10 @@ class StealResult(ComparisonResult):
def __init__(self, replay1: Replay, replay2: Replay):
super().__init__(replay1, replay2, ResultType.STEAL)

# can't compare ``None`` timestamps
if not self.replay1.timestamp or not self.replay2.timestamp:
return

if self.replay1.timestamp < self.replay2.timestamp:
self.earlier_replay: Replay = self.replay1
self.later_replay: Replay = self.replay2
Expand Down
2 changes: 1 addition & 1 deletion circleguard/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "4.1.2"
__version__ = "4.2.0"

0 comments on commit 987f769

Please sign in to comment.