MCTS.py
import math
from typing import Dict, Hashable, List

import numpy as np

from c4zero import C4Zero
from game.base import Game

class Node:

    pb_c_base: int = 19652
    pb_c_init: float = 1.25
    prior: float
    value_sum: float
    children: Dict[int, "Node"]
    game: Game
    _visit_count: int

    def __init__(self, prior: float, game: Game):
        self.prior = prior  # prob of selecting node
        self.value_sum = 0  # total value from all visits
        self.children = {}  # legal child positions
        self.game = game
        self._visit_count = 0

    @property
    def n(self):
        return self._visit_count

    @property
    def expanded(self):
        return len(self.children) > 0

    @property
    def value(self):
        if self.n == 0:
            return 0
        return self.value_sum / self.n
    def select_action(self, temperature):
        """
        Select an action from this node.

        Actions are chosen according to the visit count distribution and temperature.
        """
        visit_counts = np.array([child.n for child in self.children.values()])
        actions = [action for action in self.children.keys()]
        if temperature == 0:
            action = actions[np.argmax(visit_counts)]
        elif temperature == float("inf"):
            action = np.random.choice(actions)
        else:
            # See paper appendix: Data Generation
            visit_count_distribution = visit_counts ** (1 / temperature)
            visit_count_distribution /= sum(visit_count_distribution)
            action = np.random.choice(actions, p=visit_count_distribution)
        return action
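
    # Illustrative calls (hypothetical usage, assuming the tree has already been
    # searched so that children carry visit counts):
    #   root.select_action(temperature=0)    # greedy: pick the most-visited child
    #   root.select_action(temperature=1)    # sample in proportion to visit counts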
    def select_child(self):
        """Select the child with the highest UCB score"""
        if not self.expanded:
            raise UserWarning("Node not expanded, cannot select a child")
        children_scores = [(self.ucb_score(c), c) for c in self.children.values()]
        max_score = max(score for score, c in children_scores)
        best_child = next(c for score, c in children_scores if score == max_score)
        return best_child
    def expand(self, action_probs: np.ndarray):
        """
        Expand a node and keep track of the prior policy probability
        """
        actions = self.game.get_valid_actions()
        valid_actions = [action is not None for action in actions]
        action_probs = action_probs * valid_actions  # Mask invalid moves
        action_probs /= np.sum(action_probs)  # Normalise new probs
        for a, prob in enumerate(action_probs.flatten()):
            if prob == 0:
                continue
            self.children[a] = Node(prob, self.game.move(actions[a]))
    def ucb_score(self, child: "Node"):
        """Calculate the upper confidence bound score between nodes"""
        # Exploration bonus based on prior score
        c_puct = (
            math.log((self.n + self.pb_c_base + 1) / self.pb_c_base) + self.pb_c_init
        )
        c_puct *= math.sqrt(self.n) / (child.n + 1)
        prior_score = c_puct * child.prior
        # The value of the child is from the perspective of the opposing player
        value_score = -child.value
        return value_score + prior_score
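
    # Note: ucb_score computes the PUCT rule from the AlphaZero pseudocode
    # (notation assumed from the paper):
    #   score = Q(s, a) + C(s) * P(s, a) * sqrt(N(s)) / (1 + N(s, a)),
    # where C(s) = log((N(s) + pb_c_base + 1) / pb_c_base) + pb_c_init grows
    # slowly with the parent visit count, and Q is the child's mean value
    # negated to reflect the opponent's perspective.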
    def add_exploration_noise(self, alpha=0.3, e_frac=0.25):
        """
        Add Dirichlet noise to a Node's priors to increase exploratory behaviour

        Parameters
        ----------
        alpha : float (default = 0.3)
            The shape of the gamma distribution. Must be positive.
        e_frac : float (default = 0.25)
            The exploration fraction - how much noise to apply to the priors
        """
        if not self.expanded:
            raise UserWarning("Cannot add noise to Node before expansion")
        actions = self.children.keys()
        noise = np.random.gamma(alpha, 1, len(actions))
        for a, n in zip(actions, noise):
            self.children[a].prior = self.children[a].prior * (1 - e_frac) + n * e_frac
        return self
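
    # Each child's prior is blended as prior' = (1 - e_frac) * prior + e_frac * noise,
    # mirroring the root exploration noise scheme used in AlphaZero self-play.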
    def __str__(self):
        """Pretty print node info"""
        p = "{0:.2f}".format(self.prior)
        return (
            f"{self.game.state}\nPrior: {p} Count: {self.n} "
            + f"Value: {self.value}\nExpanded: {self.expanded}"
        )

    def __repr__(self) -> str:
        return f"Node({self.prior}, {self.game!r})"

class MCTS:
    """Class to perform a Monte Carlo Tree Search"""

    game: Game
    model: C4Zero
    _Ps: Dict[Hashable, np.ndarray]
    _vs: Dict[Hashable, float]

    def __init__(self, game: Game, model: C4Zero):
        self.game = game
        self.model = model
        # Cache for prediction values
        self._Ps = {}
        self._vs = {}
    def run(
        self,
        n_simulations: int,
        root_dirichlet_alpha: float = 0.3,
        root_explore_frac: float = 0.25,
    ):
        """Run Monte Carlo Tree Search"""
        root = Node(0, self.game)
        # Expand root: get policy and value from the model
        action_probs, _ = self.cached_predict(root)
        root.expand(action_probs)
        root.add_exploration_noise(root_dirichlet_alpha, root_explore_frac)
        # Simulate gameplay from this position
        for _ in range(n_simulations):
            node = root
            search_path: List[Node] = [node]
            # Select
            while node.expanded:
                node = node.select_child()
                search_path.append(node)
            value = node.game.reward_player()
            if value is None:  # Game not over
                # Expand & evaluate
                # TODO: Randomly reflect/rotate board along game symmetry line here
                # See Methods: Expand and Evaluate
                action_probs, value = self.cached_predict(node)
                node.expand(action_probs)
            # Backup
            MCTS.backpropagate(search_path, value, node.game.state.current_player)
        return root
    @staticmethod
    def backpropagate(search_path: List[Node], value: float, current_player: int):
        """
        At the end of a simulation, propagate the evaluation all the way up the
        tree to the root.
        """
        for node in reversed(search_path):
            node.value_sum += (
                value if node.game.state.current_player == current_player else -value
            )
            node._visit_count += 1
    def cached_predict(self, node: Node):
        """Predict the policy and value at a given node and cache the result"""
        s = node.game.state.hash()
        if s not in self._Ps:
            self._Ps[s], self._vs[s] = self.model.predict(node.game.encode())
        return self._Ps[s], self._vs[s]
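
# Minimal usage sketch (illustrative only; `my_game` and `model` are placeholders
# for a concrete Game implementation and a trained C4Zero network defined
# elsewhere in this project):
#
#   mcts = MCTS(my_game, model)
#   root = mcts.run(n_simulations=100)
#   action_idx = root.select_action(temperature=1.0)
#   # action_idx indexes into my_game.get_valid_actions(), matching Node.expand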