-
Notifications
You must be signed in to change notification settings - Fork 0
/
arena.py
57 lines (48 loc) · 1.59 KB
/
arena.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from typing import Dict
import joblib
from joblib.parallel import delayed
from game.base import Game, Player
from util import ProgressParallel
class Arena:
    """
    Pit two players against each other at a game

    Attributes
    ----------
    players: Dict[int, Player]
        Maps the player id reported by ``game.state.current_player``
        (+1 for player_1, -1 for player_2) to the player object
    game: Game
        Position every episode starts from
    n_jobs: int
        Number of parallel jobs, as resolved by joblib
    """

    players: Dict[int, Player]
    game: Game
    n_jobs: int

    def __init__(
        self, player_1: Player, player_2: Player, game: Game, n_jobs: int = 1
    ) -> None:
        """
        Parameters
        ----------
        player_1: Player
            Player registered under id +1
        player_2: Player
            Player registered under id -1
        game: Game
            Starting position used for every episode
        n_jobs: int
            Requested number of parallel jobs; passed through
            ``effective_n_jobs`` so the stored value is the resolved count
        """
        self.players = {1: player_1, -1: player_2}
        self.game = game
        self.n_jobs = joblib.parallel.effective_n_jobs(n_jobs)

    def play_game(self):
        """
        Execute a single episode of a game

        The episode starts from ``self.game``; only a local name is rebound
        to the value returned by ``game.move``, ``self.game`` itself is
        never reassigned here.
        NOTE(review): this presumes ``move`` returns the successor position
        without mutating the one it was called on -- confirm against Game.

        Returns
        -------
        result: int
            +1 if player_1 won, -1 if player_2 won, 0 if draw
        """
        game = self.game
        while not game.over:
            # Ask whichever player is to move for its action
            p = self.players[game.state.current_player]
            action = p.play(game)
            game = game.move(action)
        # Outcome is always reported from player_1's (id +1) point of view
        return game.reward_player(1)

    def _run_batch(self, n_games: int):
        """
        Run n_games independent episodes through ProgressParallel and
        return the collected ``play_game`` results
        """
        return ProgressParallel(self.n_jobs, total=n_games, leave=False)(
            delayed(self.play_game)() for _ in range(n_games)
        )

    def play_games(self, n_games: int):
        """
        Execute n_games, swapping the starting player for the second half

        The first ``n_games // 2`` episodes use the configured starting
        player, the remaining ones use the opposite starting player (so the
        second batch is one game larger when n_games is odd). The game
        configuration is restored to its original starting player before
        returning, even if a batch raises.

        Parameters
        ----------
        n_games: int
            Total number of episodes to play

        Returns
        -------
        (wins_1, wins_2, draws)
            Number of results equal to +1, -1 and 0 respectively
        """
        n_games1 = n_games // 2
        n_games2 = n_games - n_games1

        # Remember the configured starter: the config object is shared state
        # and must not stay flipped once this call is done.
        original_start = self.game.cfg.start_player
        try:
            res1 = self._run_batch(n_games1)
            # Switch starting player
            self.game.cfg.start_player *= -1
            res2 = self._run_batch(n_games2)
        finally:
            self.game.cfg.start_player = original_start

        res = res1 + res2
        return res.count(1), res.count(-1), res.count(0)