-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmcts.py
More file actions
216 lines (185 loc) · 7.04 KB
/
mcts.py
File metadata and controls
216 lines (185 loc) · 7.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
"""
Monte carlo Tree search
Negative eval: Black
Positive eval: White
"""
from dataclasses import dataclass, field
import math
from typing import Optional, Self, cast
import torch
# import torch.nn.functional as F
import numpy as np
import numpy.typing as npt
from scipy.special import softmax
from arguments import Arguments
from game import GameState
from piece import Turn, convert_status_to_score
def ucb_score(parent, child, expl_param=math.sqrt(2)):
    """
    UCB-style selection score for the edge from ``parent`` to ``child``.

    Combines an exploitation term (the child's negated mean value — the child
    stores values from the opposing player's perspective) with an exploration
    term driven by the child's prior probability and the relative visit counts
    of parent and child, weighted by ``expl_param``.
    """
    exploration = child.prior * math.sqrt(parent.visit_count + 1)
    exploration /= child.visit_count + 1
    # An unvisited child contributes no exploitation term; otherwise negate
    # the child's value because it is scored for the opposing player.
    exploitation = -child.value() if child.visit_count > 0 else 0
    return exploitation + expl_param * exploration
@dataclass
class Node:
    """
    A single node of the MCTS search tree.

    Attributes:
        prior: Prior probability (from the policy network) of the action
            that led to this node.
        state: Game state this node represents.
        visit_count: How many times this node has been visited during search.
        value_sum: Accumulated backed-up value over all visits.
        children: Maps an action index to the resulting child node.
    """

    prior: float
    state: GameState
    visit_count: int = 0
    value_sum: float = 0
    children: dict[int, Self] = field(default_factory=dict)

    def expanded(self) -> bool:
        """
        Return True if this node has at least one child.
        """
        return len(self.children) > 0

    def value(self) -> float:
        """
        Mean backed-up value of this node (0 if it was never visited).
        """
        if self.visit_count == 0:
            return 0
        return self.value_sum / self.visit_count

    def select_action(self, temperature: float) -> int:
        """
        Select an action according to the visit-count distribution and the
        temperature: near-zero temperature picks the most-visited action
        greedily, infinite temperature picks uniformly at random, otherwise
        visit counts are exponentiated by ``1 / temperature`` and sampled.
        """
        visit_counts = np.array([child.visit_count for child in self.children.values()])
        actions = list(self.children.keys())
        if temperature < 1e-5:
            action = actions[np.argmax(visit_counts)]
        elif temperature == float("inf"):
            action = np.random.choice(actions)
        else:
            # See paper appendix, Data Generation: lower temperatures
            # concentrate probability mass on the most-visited actions.
            visit_count_distribution = visit_counts ** (1 / temperature)
            assert np.sum(visit_count_distribution) > 0, visit_count_distribution
            visit_count_distribution = visit_count_distribution / np.sum(
                visit_count_distribution
            )
            action = np.random.choice(actions, p=visit_count_distribution)
        return action

    def select_child(self):
        """
        Return ``(action, child)`` for the child with the highest UCB score.
        Returns ``(-1, None)`` if this node has no children.
        """
        best_score = -np.inf
        best_action = -1
        best_child = None
        for action, child in self.children.items():
            score = ucb_score(self, child)
            if score > best_score:
                best_score = score
                best_action = action
                best_child = child
        return best_action, best_child

    def expand(self, state: GameState, action_probs: npt.ArrayLike):
        """
        Expand this node, creating one child per action with non-zero prior
        probability as given by the neural network.
        """
        self.state = state
        assert np.sum(action_probs == 0) < len(action_probs), f"{action_probs}"
        for a, prob in enumerate(action_probs):
            if prob != 0:
                new_state = self.state.move(a)
                self.children[a] = cast(Self, Node(prob, new_state))

    def __hash__(self) -> int:
        return self.state.__hash__()

    def __repr__(self):
        """
        Debugger pretty print of node info.
        """
        # BUG FIX: the original wrote ``{self.visit_count,}`` — the trailing
        # comma builds a one-element tuple, printing e.g. "(5,)". It also used
        # a backslash continuation inside an f-string replacement field, which
        # is a SyntaxError on Python < 3.12.
        return (
            f"{self.state} Prior: {self.prior:.2f} Count: {self.visit_count} "
            f"Value: {self.value()} children: {len(self.children)}"
        )
@dataclass
class MCTS:
    """
    Monte Carlo Tree Search agent driven by a policy/value network.

    Attributes:
        model: Network exposing ``predict(state) -> (action_logits, value)``.
        args: Search hyper-parameters; only ``num_simulations`` is read here.
        root: Root of the current (reusable) search tree, if any.
    """

    model: torch.nn.Module
    args: Arguments
    root: Optional[Node] = None

    def select_action(self, state: GameState):
        """
        Query the network for masked, normalized action probabilities.

        Returns:
            ``(action_probs, value)`` — a probability distribution over all
            actions with invalid moves zeroed out, and the network's value
            estimate for ``state``.
        """
        st = state.canonical_representation()
        action_probs, value = self.model.predict(st)
        valid_moves = np.array(state.get_valid_moves())
        action_probs = softmax(action_probs.astype(np.float32))
        action_probs = action_probs * valid_moves  # mask invalid moves
        action_probs /= max(np.sum(action_probs), 1e-10)  # renormalize; 1e-10 guards against division by zero
        if np.sum(action_probs == 0) == len(action_probs):
            # Network put ~0 mass on every valid move: fall back to a uniform
            # distribution over the valid moves.
            action_probs = valid_moves.astype(np.float32) / np.sum(valid_moves.astype(np.float32))
        return action_probs, value

    def move_head(self, action: int):
        """
        Advance the tree root to the child reached by ``action`` (no-op when
        no tree has been built yet).
        """
        if self.root is not None:
            self.root = self.root.children[action]

    def run(self, state: GameState, action: Optional[int] = None):
        """
        Run ``args.num_simulations`` MCTS simulations from ``state``.

        If ``action`` leads from the current root to a child whose state
        matches ``state``, that subtree is reused; otherwise a fresh root
        (prior 0) is created.
        """
        if (
            self.root is not None
            and action is not None
            and action in self.root.children.keys()
            and self.root.children[action].state == state
        ):
            self.root = self.root.children[action]
        else:
            self.root = Node(0, state)
        self.root.state = state
        assert self.root is not None
        if not self.root.expanded():
            action_probs, value = self.select_action(state)
            self.root.expand(state, action_probs)
        assert self.root.expanded()
        for _ in range(self.args.num_simulations):
            node = self.root
            search_path: list[Optional[Node]] = [node]
            action = None
            # SELECT: walk down the tree, always following the child with the
            # highest UCB score, until an unexpanded (leaf) node is reached.
            while node is not None and node.expanded():
                action, node = node.select_child()
                search_path.append(node)
            assert len(search_path) > 1, f"Search path not expanded: {node} {self.root}"
            parent = cast(Node, search_path[-2])
            state = parent.state
            # Now we're at a leaf node and we would like to expand.
            # Players always play from their own perspective, so re-apply the
            # selected action to the parent's state to get the leaf state.
            state = state.move(cast(int, action))
            # Terminal check: is_winning() returning non-None means the game
            # ended; convert that status into a score for the player to move.
            value = state.is_winning()
            value = (
                convert_status_to_score(value, state.turn)
                if value is not None
                else None
            )
            if value is None:
                # Game has not ended.
                # EXPAND: evaluate the leaf with the network and create its
                # children; the network's value estimate is backed up instead.
                action_probs, value = self.select_action(state)
                node = cast(Node, node)
                node.expand(state, action_probs)
            self.backpropagate(cast(list[Node], search_path), value, Turn(state.turn))

    def backpropagate(self, search_path: list[Node], value: float, to_play: Turn):
        """
        At the end of a simulation, we propagate the evaluation all the way up
        the tree to the root.
        """
        for node in reversed(search_path):
            # Each node accumulates the value from the perspective of its own
            # player to move: nodes where the opponent moves receive -value.
            node.value_sum += value if node.state.turn == to_play else -value
            node.visit_count += 1