-
Notifications
You must be signed in to change notification settings - Fork 2
/
mod_mcts.py
115 lines (100 loc) · 4.17 KB
/
mod_mcts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import time
import math
import random
def randomPolicy(state):
    """Roll a state forward with uniformly random actions and score the result.

    Starting from ``state``, repeatedly pick one entry of
    ``getPossibleActions()`` with ``random.choice`` and advance with
    ``takeAction`` until ``isTerminal()`` reports true.

    Returns:
        The value of ``getReward()`` on the terminal state that was reached.

    Raises:
        Exception: if picking an action raises ``IndexError`` (which is what
            ``random.choice`` does for an empty sequence) on a state that
            claims not to be terminal.
    """
    current = state
    while True:
        # Guard clause: a terminal state ends the rollout immediately.
        if current.isTerminal():
            return current.getReward()
        try:
            chosen = random.choice(current.getPossibleActions())
        except IndexError:
            raise Exception("Non-terminal state has no possible actions: " + str(current))
        current = current.takeAction(chosen)
class treeNode():
    """A single node of the search tree: a state plus its visit statistics."""

    def __init__(self, state, parent):
        """Wrap ``state`` in a fresh, unvisited node hanging below ``parent``.

        ``parent`` is stored as given; ``None`` marks a node without a parent.
        """
        # Ask the state once and reuse the answer for both flags.
        terminal = state.isTerminal()
        self.state = state
        self.isTerminal = terminal
        # A terminal node starts out flagged as fully expanded.
        self.isFullyExpanded = terminal
        self.parent = parent
        self.numVisits = 0
        self.totalReward = 0
        # Child nodes, looked up by key; empty until something adds to it.
        self.children = {}

    def __str__(self):
        """Return a one-line summary of the node's statistics and child keys."""
        summary = ', '.join((
            "totalReward: %s" % (self.totalReward),
            "numVisits: %d" % (self.numVisits),
            "isTerminal: %s" % (self.isTerminal),
            "possibleActions: %s" % (self.children.keys()),
        ))
        return "%s: {%s}" % (self.__class__.__name__, summary)
class mcts():
    """Monte Carlo Tree Search over user-supplied state objects.

    States passed to :meth:`search` are expected to provide ``isTerminal()``,
    ``getPossibleActions()``, ``takeAction(action)`` and
    ``getCurrentPlayer()``. The rollout policy receives a state and must
    return a numeric reward.
    """

    def __init__(self, timeLimit=None, iterationLimit=None, explorationConstant=1 / math.sqrt(2),
                 rolloutPolicy=randomPolicy):
        """Configure the search budget and policies.

        Exactly one of ``timeLimit`` and ``iterationLimit`` must be given.

        Args:
            timeLimit: budget for each search, in milliseconds.
            iterationLimit: number of rounds per search; must be >= 1.
            explorationConstant: weight of the exploration term used while
                descending the tree.
            rolloutPolicy: callable mapping a state to a reward.

        Raises:
            ValueError: if both limits are given, if neither is given, or if
                ``iterationLimit`` is below one.
        """
        if timeLimit is not None:
            if iterationLimit is not None:
                raise ValueError("Cannot have both a time limit and an iteration limit")
            # time taken for each MCTS search in milliseconds
            self.timeLimit = timeLimit
            self.limitType = 'time'
        else:
            if iterationLimit is None:
                raise ValueError("Must have either a time limit or an iteration limit")
            # number of iterations of the search
            if iterationLimit < 1:
                # The check accepts 1, so the message says "at least one".
                raise ValueError("Iteration limit must be at least one")
            self.searchLimit = iterationLimit
            self.limitType = 'iterations'
        self.explorationConstant = explorationConstant
        self.rollout = rolloutPolicy
        # Not read or written by any other method of this class.
        self.tree = {}

    def search(self, initialState, needDetails=False):
        """Run the configured number of rounds from ``initialState`` and pick an action.

        Args:
            initialState: state to use as the root of a fresh tree.
            needDetails: when true, return a dict instead of the bare action.

        Returns:
            The action leading to the best root child (chosen with the
            exploration term switched off), or, if ``needDetails`` is true,
            ``{"action": ..., "expectedReward": ...}`` where the expected
            reward is that child's mean reward.

        Raises:
            IndexError: if the root ends up with no children (for example a
                terminal ``initialState``), because there is nothing to
                choose from.
        """
        self.root = treeNode(initialState, None)

        if self.limitType == 'time':
            # A monotonic clock keeps the budget immune to wall-clock jumps.
            deadline = time.monotonic() + self.timeLimit / 1000
            while time.monotonic() < deadline:
                self.executeRound()
        else:
            for _ in range(self.searchLimit):
                self.executeRound()

        # Exploration weight 0: pick purely on mean reward.
        bestChild = self.getBestChild(self.root, 0)
        action = next(act for act, child in self.root.children.items() if child is bestChild)
        if needDetails:
            return {"action": action, "expectedReward": bestChild.totalReward / bestChild.numVisits}
        return action

    def executeRound(self):
        """Perform one select/expand, rollout and backpropagation cycle."""
        node = self.selectNode(self.root)
        reward = self.rollout(node.state)
        self.backpropogate(node, reward)

    def selectNode(self, node):
        """Descend from ``node`` to the node the next rollout should start from.

        Fully expanded nodes are descended through via :meth:`getBestChild`;
        the first node that is not fully expanded gets one new child, which
        is returned. A terminal node is returned as is.
        """
        while not node.isTerminal:
            if not node.isFullyExpanded:
                return self.expand(node)
            node = self.getBestChild(node, self.explorationConstant)
        return node

    def expand(self, node):
        """Create and return the child for the first action of ``node`` not yet tried.

        Marks ``node`` as fully expanded once every possible action has a
        child.

        Raises:
            Exception: if every action already has a child, which callers
                are expected to rule out via ``isFullyExpanded``.
        """
        actions = node.state.getPossibleActions()
        for action in actions:
            if action not in node.children:
                newNode = treeNode(node.state.takeAction(action), node)
                node.children[action] = newNode
                if len(actions) == len(node.children):
                    node.isFullyExpanded = True
                return newNode

        raise Exception("Should never reach here")

    def backpropogate(self, node, reward):
        """Add one visit and ``reward`` to ``node`` and each of its ancestors.

        The method name's spelling is kept so existing callers keep working.
        """
        while node is not None:
            node.numVisits += 1
            node.totalReward += reward
            node = node.parent

    def getBestChild(self, node, explorationValue):
        """Return the child of ``node`` with the highest selection score.

        The score of a child is::

            player * totalReward / numVisits
                + explorationValue * sqrt(2 * ln(node.numVisits) / numVisits)

        where ``player`` is ``node.state.getCurrentPlayer()`` (presumably
        +1/-1 so rewards are seen from the mover's side; confirm against
        the state implementation). Ties are broken uniformly at random.

        Raises:
            IndexError: if ``node`` has no children.
        """
        # Same for every child, so ask the state only once.
        player = node.state.getCurrentPlayer()
        bestValue = float("-inf")
        bestNodes = []
        for child in node.children.values():
            nodeValue = player * child.totalReward / child.numVisits + explorationValue * math.sqrt(
                2 * math.log(node.numVisits) / child.numVisits)
            if nodeValue > bestValue:
                bestValue = nodeValue
                bestNodes = [child]
            elif nodeValue == bestValue:
                bestNodes.append(child)
        return random.choice(bestNodes)