-
Notifications
You must be signed in to change notification settings - Fork 5
/
predict.py
122 lines (92 loc) · 3.85 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import numpy as np
from random import randint
# Used for visualizing the policy network's predictions.
def get_policy(pb, pw, state):
    """Return (policy over legal moves, probability of the best move).

    Picks the black or white policy model based on whose turn it is,
    builds a (1, 15, 15, 2) input from the two board planes, and
    normalizes the network output restricted to legal moves.
    """
    height, width, channels = 15, 15, 2
    black_to_move = state.check_turn()
    # Current player's stones always go in the first plane.
    planes = [state.black, state.white] if black_to_move else [state.white, state.black]
    x = np.array(planes).reshape(channels, height, width)
    x = x.transpose(1, 2, 0).reshape(1, height, width, channels)
    model = pb if black_to_move else pw
    y = model.predict(x, batch_size=1)
    # Restrict the raw output to the legal moves reported by the referee.
    policies = y[0][list(state.referee()[0])]
    total = sum(policies)
    if total != 0:
        policies /= total
    best = np.argmax(policies)
    return policies, policies[best]
# Used for printing the expected win rate.
def get_value(vb, vw, state):
    """Return the value-network estimate for the player to move.

    Chooses the black or white value model by turn, feeds it a
    (1, 15, 15, 2) input with the mover's stones in the first plane,
    and returns the first row of the prediction.
    """
    height, width, channels = 15, 15, 2
    if state.check_turn():
        planes, model = [state.black, state.white], vb
    else:
        planes, model = [state.white, state.black], vw
    x = np.array(planes).reshape(channels, height, width)
    x = x.transpose(1, 2, 0).reshape(1, height, width, channels)
    y = model.predict(x, batch_size=1)
    return y[0]
# Move using only the policy network (~0.02 s per move).
def predict_p(pb, pw, state):
    """Select a move via rule-based shortcuts, falling back to the policy net.

    Priority: opening center > immediate wins > forced defenses > attacks >
    policy-network sample (optionally overridden by a secondary defense).

    Args:
        pb: policy model used on black's turn (when state.check_turn() is truthy).
        pw: policy model used on white's turn.
        state: game state exposing board planes and rule-checking helpers.

    Returns:
        An action index in [0, 224] encoded as 15 * row + col.
    """
    # Opening move on an empty board: play the center (15 * 7 + 7 = 112).
    if state.black == [[0]*15]*15 and state.white == [[0]*15]*15:
        return 112
    win_action = []      # moves that win immediately
    defend_action = []   # moves that block an opponent's immediate threat
    attack_action = []   # moves that create a strong threat of our own
    defend2_action = []  # secondary (weaker) defensive moves
    if state.check_turn():
        me = state.black
        enemy = state.white
    else:
        me = state.white
        enemy = state.black
    for i in range(15):
        for j in range(15):
            # Empty cell completing five in a row: immediate win.
            if state.black[i][j] == 0 and state.white[i][j] == 0 and state.check_5(i, j):
                win_action.append(15 * i + j)
            # NOTE(review): check_6 is only consulted for the non-check_turn
            # side — presumably an overline-rule asymmetry (Renju); confirm.
            if not state.check_turn():
                if state.black[i][j] == 0 and state.white[i][j] == 0 and state.check_6(i, j):
                    win_action.append(15 * i + j)
            # No win found yet: collect legal blocks around enemy stones.
            if not win_action and enemy[i][j] == 1:
                for k in state.check_defend(i, j):
                    if k not in defend_action and state.check_legal(k // 15, k % 15)[0]:
                        defend_action.append(k)
            # Nothing urgent: collect legal attacking moves around our stones.
            if not win_action and not defend_action and me[i][j] == 1:
                for k in state.check_attack(i, j):
                    if k not in attack_action and state.check_legal(k // 15, k % 15)[0]:
                        attack_action.append(k)
            # Empty cells that finish the game also count as attacks.
            if not win_action and not defend_action and state.black[i][j] == 0 and state.white[i][j] == 0:
                if state.check_finish(i, j):
                    attack_action.append(15 * i + j)
            # Lowest priority: secondary defenses around enemy stones.
            if not win_action and not defend_action and not attack_action and enemy[i][j] == 1:
                for k in state.check_defend2(i, j):
                    if k not in defend2_action and state.check_legal(k // 15, k % 15)[0]:
                        defend2_action.append(k)
    if win_action:
        return win_action[randint(0, len(win_action) - 1)]
    elif defend_action:
        return defend_action[randint(0, len(defend_action) - 1)]
    elif attack_action:
        return attack_action[randint(0, len(attack_action) - 1)]
    # No rule fired: sample a move from the policy network over legal moves.
    a, b, c = (15, 15, 2)
    if state.check_turn():
        x = np.array([state.black, state.white])
        x = x.reshape(c, a, b).transpose(1, 2, 0).reshape(1, a, b, c)
        y = pb.predict(x, batch_size=1)
    else:
        x = np.array([state.white, state.black])
        x = x.reshape(c, a, b).transpose(1, 2, 0).reshape(1, a, b, c)
        y = pw.predict(x, batch_size=1)
    legal = list(state.referee()[0])
    policies = y[0][legal]
    total = sum(policies)
    if total != 0:
        policies /= total
        move = np.random.choice(legal, p=policies)
    else:
        # BUG FIX: the network can assign zero mass to every legal move;
        # passing an all-zero p to np.random.choice raises
        # "ValueError: probabilities do not sum to 1". Fall back to a
        # uniform choice over the legal moves instead of crashing.
        move = np.random.choice(legal)
    # Override the sampled move with a secondary defense unless it already
    # creates a four of our own (count_4 > 0).
    if defend2_action and move not in defend2_action:
        if state.count_4(move // 15, move % 15) == 0:
            return defend2_action[randint(0, len(defend2_action) - 1)]
    return move