Skip to content

Commit

Permalink
Option to not swap players in data.py
Browse files Browse the repository at this point in the history
  • Loading branch information
vladfi1 committed Oct 23, 2024
1 parent b1b0def commit 053032f
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion slippi_ai/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class DatasetConfig:
# comma-separated lists of characters, or "all"
allowed_characters: str = 'all'
allowed_opponents: str = 'all'
swap: bool = True # yield swapped versions of each replay
seed: int = 0


Expand All @@ -109,6 +110,24 @@ def replays_from_meta(config: DatasetConfig) -> List[ReplayInfo]:
replay_meta = ReplayMeta.from_metadata(row)
replay_path = os.path.join(config.data_dir, replay_meta.slp_md5)

if not config.swap:
is_banned = False
for name in [replay_meta.p0.name, replay_meta.p1.name]:
if nametags.is_banned_name(name):
banned_counts[name] += 1
is_banned = True

if is_banned:
continue

if (replay_meta.p0.character not in allowed_characters
or replay_meta.p1.character not in allowed_opponents):
continue

replays.append(ReplayInfo(replay_path, False, replay_meta))

continue

for swap in [False, True]:
players = [replay_meta.p0, replay_meta.p1]
if swap:
Expand Down Expand Up @@ -155,7 +174,7 @@ def train_test_split(
replays.append(ReplayInfo(replay_path, False))
replays.append(ReplayInfo(replay_path, True))

# reproducible train/test split
# TODO: stable partition
rng = random.Random(config.seed)
rng.shuffle(replays)
num_test = int(config.test_ratio * len(replays))
Expand Down

0 comments on commit 053032f

Please sign in to comment.