-
Notifications
You must be signed in to change notification settings - Fork 5
/
mcts.py
260 lines (197 loc) · 8.1 KB
/
mcts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
from math import sqrt
import numpy as np
from random import randint
MCTS_COUNT = 100  # number of MCTS simulations run per move selection
PRUNING_COUNT = 5  # only the top-N children by prior policy are searched
def predict_policy(pb, pw, state):
    """Return a normalized policy vector over the legal moves of *state*.

    Picks the black-player policy network ``pb`` when it is black's turn,
    otherwise the white network ``pw``.  The two 15x15 board planes are
    stacked with the side to move first, reshaped to the model's
    (1, 15, 15, 2) input layout, and the raw output is restricted to the
    legal actions and renormalized (left as-is when it sums to zero).
    """
    rows, cols, planes = 15, 15, 2
    # Plane order encodes "current player first"; pick the matching model.
    if state.check_turn():
        stacked = np.array([state.black, state.white])
        model = pb
    else:
        stacked = np.array([state.white, state.black])
        model = pw
    net_input = (stacked
                 .reshape(planes, rows, cols)
                 .transpose(1, 2, 0)
                 .reshape(1, rows, cols, planes))
    prediction = model.predict(net_input, batch_size=1)
    # Keep only the probabilities of the legal actions.
    legal_actions = list(state.referee()[0])
    policies = prediction[0][legal_actions]
    total = sum(policies)
    if total != 0:
        policies /= total
    return policies
def predict_value(vb, vw, state):
    """Estimate the win value of *state* for the player to move.

    Tactical shortcut rules are evaluated first; only when they are
    inconclusive is the value network consulted (``vb`` for black's turn,
    ``vw`` for white's).  Returns 1.0 for a position judged winning,
    0.0 for one judged lost, otherwise the network's output.
    """
    win_action = []      # moves that appear to win immediately
    defend_action = []   # moves forced to block an opponent threat
    attack_action = []   # own threatening moves
    defend2_action = []  # secondary (weaker) defensive moves
    # "me"/"enemy" boards from the perspective of the side to move.
    if state.check_turn():
        me = state.black
        enemy = state.white
    else:
        me = state.white
        enemy = state.black
    # Scan the whole 15x15 board, filling the lists by priority.  Later
    # categories are only collected while all earlier ones are still
    # empty, so the outcome depends on scan order.
    for i in range(15):
        for j in range(15):
            # Empty square completing a five-in-a-row -> immediate win.
            if state.black[i][j] == 0 and state.white[i][j] == 0 and state.check_5(i, j):
                win_action.append(15 * i + j)
            if not state.check_turn():
                # NOTE(review): check_6 is only consulted on white's turn —
                # presumably an overline (6+) counts as a win for white but
                # is forbidden for black (renju-style rule); confirm.
                if state.black[i][j] == 0 and state.white[i][j] == 0 and state.check_6(i, j):
                    win_action.append(15 * i + j)
            # Opponent stone: collect legal blocking moves.
            if not win_action and enemy[i][j] == 1:
                for k in state.check_defend(i, j):
                    if k not in defend_action and state.check_legal(k // 15, k % 15)[0]:
                        defend_action.append(k)
            # Own stone: collect legal attacking follow-ups.
            if not win_action and not defend_action and me[i][j] == 1:
                for k in state.check_attack(i, j):
                    if k not in attack_action and state.check_legal(k // 15, k % 15)[0]:
                        attack_action.append(k)
            # Empty square that "finishes" something also counts as attack.
            if not win_action and not defend_action and state.black[i][j] == 0 and state.white[i][j] == 0:
                if state.check_finish(i, j):
                    attack_action.append(15 * i + j)
            # Last resort: secondary defensive moves.
            if not win_action and not defend_action and not attack_action and enemy[i][j] == 1:
                for k in state.check_defend2(i, j):
                    if k not in defend2_action and state.check_legal(k // 15, k % 15)[0]:
                        defend2_action.append(k)
    # Tactical verdicts take precedence over the network.
    if win_action:
        return 1.0
    elif len(defend_action) == 0 and attack_action:
        # Free attacking move and nothing to defend -> judged winning.
        return 1.0
    elif len(defend_action) + len(defend2_action) >= 2:
        # Two or more forced defenses: cannot block all -> judged lost.
        return 0.0
    # Fall back to the value network on the two-plane board encoding
    # (side to move first), reshaped to (1, 15, 15, 2).
    a, b, c = (15, 15, 2)
    if state.check_turn():
        x = np.array([state.black, state.white])
        x = x.reshape(c, a, b).transpose(1, 2, 0).reshape(1, a, b, c)
        y = vb.predict(x, batch_size=1)
    else:
        x = np.array([state.white, state.black])
        x = x.reshape(c, a, b).transpose(1, 2, 0).reshape(1, a, b, c)
        y = vw.predict(x, batch_size = 1)
    value = y[0]
    return value
def nodes_to_scores(nodes):
    """Return the visit count ``n`` of every node, in the given order."""
    return [node.n for node in nodes]
def p_pruning(nodes):
    """Return the PRUNING_COUNT nodes with the highest prior ``p``.

    ``sorted`` is stable, so nodes with equal priors keep their original
    relative order; the input list itself is not modified.
    """
    ranked = sorted(nodes, key=lambda node: node.p, reverse=True)
    return ranked[:PRUNING_COUNT]
# NOTE(review): this wrapper looks stale/broken — it calls
# pv_mcts_scores(p_model, v_model, state, temperature), but pv_mcts_scores
# below is defined as pv_mcts_scores(pb, pw, vb, vw, state): four separate
# networks and no temperature parameter.  Calling the returned closure
# would raise a TypeError.  mcts_action calls pv_mcts_scores directly,
# so this appears to be dead code from an older API; confirm and fix or
# remove.
def pv_mcts_action(p_model, v_model, temperature = 0):
    """Build a closure that samples an action from the MCTS score vector."""
    # The inner function deliberately(?) shadows the outer name — confirm.
    def pv_mcts_action(state):
        scores = pv_mcts_scores(p_model, v_model, state, temperature)
        # Sample a legal action with probability proportional to its score.
        return np.random.choice(state.referee()[0], p = scores)
    return pv_mcts_action
def pv_mcts_scores(pb, pw, vb, vw, state):
    """Run MCTS_COUNT simulations from *state* and return a one-hot score
    vector over the root's legal actions (1 at the most-visited child,
    0 elsewhere).

    pb/pw are the policy networks and vb/vw the value networks for black
    and white respectively.
    """
    class Node:
        # One search-tree node.  The move leading to the node is applied
        # lazily: previous_action holds it until the node is first expanded.
        def __init__(self, p_s, p_action, p):
            self.state = p_s                 # board state (move applied lazily)
            self.previous_action = p_action  # pending move; None once applied
            self.p = p                       # prior from the policy network
            self.w = 0                       # accumulated value
            self.n = 0                       # visit count
            self.child_nodes = None          # None until first expansion
        def evaluate(self):
            """One simulation step; returns this node's value for its player."""
            # NOTE(review): referee() runs BEFORE the pending previous_action
            # is applied below, so `result` describes the parent position,
            # yet result[0] is later zipped against policies computed from
            # the post-move position.  The two legal-move lists can differ
            # in length and zip truncates silently — verify this is intended.
            result = self.state.referee()
            # When the game is over at this node
            if result[2] != 0:
                if result[2] == 1:
                    value = -1  # loss
                else:
                    value = 0  # draw
                self.w += value
                self.n += 1
                return value
            if not self.child_nodes:
                # Leaf: apply the pending move, score with the networks,
                # then expand one child per legal action.
                if self.previous_action != None:
                    self.state = self.state.next(self.previous_action)
                    self.previous_action = None
                policies = predict_policy(pb, pw, self.state)
                value = predict_value(vb, vw, self.state)
                self.w += value
                self.n += 1
                self.child_nodes = []
                for action, policy in zip(result[0], policies):
                    self.child_nodes.append(Node(self.state, action, policy))
                return value
            else:
                # Internal node: recurse into the PUCB-best child; negate
                # because the child's value is from the opponent's view.
                value = -self.next_child_node().evaluate()
                self.w += value
                self.n += 1
                return value
        def next_child_node(self):
            """Select the child to descend into via the PUCB formula,
            restricted to the top-PRUNING_COUNT children by prior."""
            C_PUCT = 1.0  # exploration constant
            t = sum(nodes_to_scores(self.child_nodes))  # total child visits
            pucb_values = []
            pruned_child_nodes = p_pruning(self.child_nodes)
            #print(self.child_nodes[84].previous_action, self.child_nodes[84].p)
            for child_node in pruned_child_nodes:
                if child_node.n != 0:
                    # Mean value, negated to the parent's perspective.
                    a = (-child_node.w / child_node.n)
                else:
                    a = 0
                pucb_values.append(a + C_PUCT * child_node.p * sqrt(t) / (1 + child_node.n))
            #for _ in range(5):
                #print(pruned_child_nodes[_].previous_action, pruned_child_nodes[_].p)
            return pruned_child_nodes[np.argmax(pucb_values)]
    # Build the root (no pending move, zero prior) and simulate.
    root_node = Node(state, None, 0)
    for _ in range(MCTS_COUNT):
        root_node.evaluate()
    # Deterministic output: one-hot at the most-visited root child.
    scores = nodes_to_scores(root_node.child_nodes)
    action = np.argmax(scores)
    scores = np.zeros(len(scores))
    scores[action] = 1
    return scores
def mcts_action(pb, pw, vb, vw, state):
    """Choose a move for *state*: forced tactical moves first, MCTS otherwise.

    pb/pw are the policy networks and vb/vw the value networks for black
    and white.  Returns an action index in [0, 224] (15 * row + col).
    """
    # Empty board: always open in the center (7, 7) = 112.
    if state.black == [[0]*15]*15 and state.white == [[0]*15]*15:
        return 112
    win_action = []      # moves that appear to win immediately
    defend_action = []   # moves forced to block an opponent threat
    attack_action = []   # own threatening moves
    defend2_action = []  # secondary (weaker) defensive moves
    # "me"/"enemy" boards from the perspective of the side to move.
    if state.check_turn():
        me = state.black
        enemy = state.white
    else:
        me = state.white
        enemy = state.black
    # Same priority-ordered tactical scan as predict_value: later
    # categories are only collected while earlier ones are still empty,
    # so statement order inside the loop is load-bearing.
    for i in range(15):
        for j in range(15):
            # Empty square completing a five-in-a-row -> immediate win.
            if state.black[i][j] == 0 and state.white[i][j] == 0 and state.check_5(i, j):
                win_action.append(15 * i + j)
            if not state.check_turn():
                # NOTE(review): overline (6+) only checked for white —
                # presumably a renju-style rule; confirm check_6 semantics.
                if state.black[i][j] == 0 and state.white[i][j] == 0 and state.check_6(i, j):
                    win_action.append(15 * i + j)
            if not win_action and enemy[i][j] == 1:
                for k in state.check_defend(i, j):
                    if k not in defend_action and state.check_legal(k // 15, k % 15)[0]:
                        defend_action.append(k)
            if not win_action and not defend_action and me[i][j] == 1:
                for k in state.check_attack(i, j):
                    if k not in attack_action and state.check_legal(k // 15, k % 15)[0]:
                        attack_action.append(k)
            if not win_action and not defend_action and state.black[i][j] == 0 and state.white[i][j] == 0:
                if state.check_finish(i, j):
                    attack_action.append(15 * i + j)
            if not win_action and not defend_action and not attack_action and enemy[i][j] == 1:
                for k in state.check_defend2(i, j):
                    if k not in defend2_action and state.check_legal(k // 15, k % 15)[0]:
                        defend2_action.append(k)
    # Forced moves short-circuit the search (random pick among equals).
    if win_action:
        return win_action[randint(0, len(win_action) - 1)]
    elif defend_action:
        return defend_action[randint(0, len(defend_action) - 1)]
    elif attack_action:
        return attack_action[randint(0, len(attack_action) - 1)]
    # No forced move: run the full MCTS.
    scores = pv_mcts_scores(pb, pw, vb, vw, state)
    # Add randomness (sample proportionally to scores) — disabled:
    #action = np.random.choice(state.referee()[0], p = scores)
    # Play the best (most-visited) move
    action = state.referee()[0][np.argmax(scores)]
    # If secondary defenses exist and the MCTS move is not one of them and
    # does not itself create a four (count_4 == 0 — confirm semantics),
    # override with a random secondary defense.
    if defend2_action and action not in defend2_action:
        if state.count_4(action//15, action%15) == 0:
            return defend2_action[randint(0, len(defend2_action) - 1)]
    return action