Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
package battle.controllers.Piers;
import asteroids.Action;
import battle.NeuroShip;
import battle.SimpleBattle;
import java.util.Random;
/**
 * A node in a Monte Carlo Tree Search over {@link SimpleBattle} states.
 *
 * <p>Each node stores our action on the edge leading to it ({@link #ourMoveToThisState}). Whenever
 * an edge is followed, or a rollout step is simulated, the opponent's action is sampled uniformly at
 * random from the shared action table. Children are indexed in parallel with that table.
 *
 * <p>Selection uses UCB1 ({@link #calculateChild()}). Backpropagation adds {@code value / 1000} to
 * every node from the evaluated node up to the root.
 *
 * <p>Created by Piers on 12/06/2015.
 */
public class BetterMCTSNode {
    /** Magnitude of the random tie-breaking noise added in {@link #getBestAction()}. */
    private static final double EPSILON = 1e-6;
    /** Shared action table; every node's {@link #children} array is indexed in parallel with it. */
    private static Action[] allActions;
    private static final Random random = new Random();

    /** Our action on the edge from {@link #parent} to this node; {@code null} for the root. */
    private Action ourMoveToThisState;
    private BetterMCTSNode parent;
    /** Expanded children; a {@code null} slot means that action has not been expanded yet. */
    private BetterMCTSNode[] children;
    private int numberOfChildrenExpanded;
    private boolean ourNode;
    private int currentDepth;
    private int playerID;
    private double totalValue = 0;
    // Starting at 1 keeps the divisions and the logarithm in calculateChild() well-defined.
    private int numberOfVisits = 1;
    private double explorationConstant;

    /**
     * Creates a root node at depth 0.
     *
     * @param explorationConstant UCB1 exploration weight used by {@link #calculateChild()}
     * @param playerID the player whose score {@link #rollout} evaluates
     */
    public BetterMCTSNode(double explorationConstant, int playerID) {
        if (allActions == null) {
            // Build the action table lazily so a root can be created before setAllActions().
            setAllActions();
        }
        this.explorationConstant = explorationConstant;
        this.currentDepth = 0;
        this.playerID = playerID;
        this.ourNode = true;
        this.children = new BetterMCTSNode[allActions.length];
    }

    /** Creates a child one level below {@code parent}, reached by {@code ourMoveToThisState}. */
    private BetterMCTSNode(BetterMCTSNode parent, Action ourMoveToThisState) {
        this.explorationConstant = parent.explorationConstant;
        this.ourMoveToThisState = ourMoveToThisState;
        this.parent = parent;
        this.children = new BetterMCTSNode[allActions.length];
        this.currentDepth = parent.currentDepth + 1;
        this.playerID = parent.playerID;
        this.ourNode = parent.ourNode;
    }

    /**
     * (Re)builds the shared action table with 6 actions.
     *
     * <p>Each action has thrust 1 and a turn in {-1, 0, 1}, with and without shooting. For each
     * turn value, the shooting variant comes first.
     */
    public static void setAllActions() {
        allActions = new Action[6];
        int i = 0;
        double thrust = 1;
        for (double turn = -1; turn <= 1; turn += 1) {
            allActions[i++] = new Action(thrust, turn, true);
            allActions[i++] = new Action(thrust, turn, false);
        }
    }

    /**
     * Runs the selection and expansion phases from this node, advancing {@code state} along the path.
     *
     * <p>While the current node is fully expanded, descends to its UCB1-best child. Otherwise it
     * expands one new child and returns it. Every edge taken, including the newly expanded one, is
     * applied to {@code state} together with a uniformly random opponent action.
     *
     * <p>NOTE(review): this assumes {@code SimpleBattle.update(a, b)} takes our action first; confirm
     * this also holds when {@code playerID != 0}.
     *
     * @param state game state, mutated in place
     * @param maxDepth depth limit for the descent
     * @return the newly expanded node, or the node reached when the depth limit was passed
     */
    public BetterMCTSNode select(SimpleBattle state, int maxDepth) {
        BetterMCTSNode current = this;
        while (current.currentDepth <= maxDepth) {
            if (!current.fullyExpanded()) {
                BetterMCTSNode expanded = current.expand(state);
                // Apply the new edge as well, so the rollout starts from the expanded node's state.
                state.update(expanded.ourMoveToThisState, randomAction());
                return expanded;
            }
            current = current.selectBestChild();
            state.update(current.ourMoveToThisState, randomAction());
        }
        return current;
    }

    /**
     * Expands one not-yet-expanded child, chosen uniformly at random.
     *
     * <p>This method does not touch {@code state}; {@link #select} applies the edge.
     *
     * @param state unused; kept for interface compatibility
     * @return the newly created child
     * @throws IllegalStateException if every child has already been expanded
     */
    public BetterMCTSNode expand(SimpleBattle state) {
        if (fullyExpanded()) {
            throw new IllegalStateException(
                    "All " + allActions.length + " children are already expanded");
        }
        // Rejection-sample an empty slot. At least one exists because of the guard above.
        int childToExpand;
        do {
            childToExpand = random.nextInt(allActions.length);
        } while (children[childToExpand] != null);
        BetterMCTSNode child = new BetterMCTSNode(this, allActions[childToExpand]);
        children[childToExpand] = child;
        numberOfChildrenExpanded++;
        return child;
    }

    /**
     * Returns the expanded child with the highest UCB1 score. On ties, the first such child wins.
     *
     * @throws IllegalStateException if no child has a comparable score (e.g. none are expanded)
     */
    public BetterMCTSNode selectBestChild() {
        double bestScore = -Double.MAX_VALUE;
        int bestIndex = -1;
        for (int i = 0; i < children.length; i++) {
            if (children[i] != null) {
                double score = children[i].calculateChild();
                if (score > bestScore) {
                    bestScore = score;
                    bestIndex = i;
                }
            }
        }
        if (bestIndex == -1) {
            throw new IllegalStateException("No expanded child to select");
        }
        return children[bestIndex];
    }

    /**
     * Computes this node's UCB1 score: mean value plus
     * {@code explorationConstant * sqrt(ln(parentVisits) / visits)}.
     *
     * <p>Must only be called on non-root nodes, because it reads {@code parent}.
     */
    public double calculateChild() {
        return (totalValue / numberOfVisits)
                + (explorationConstant * Math.sqrt(Math.log(parent.numberOfVisits) / numberOfVisits));
    }

    /**
     * Backpropagates a rollout result.
     *
     * <p>Adds {@code value / 1000} to the total and increments the visit count on this node and
     * every ancestor up to and including the root.
     */
    public void updateValues(double value) {
        double alteredValue = value / 1000;
        for (BetterMCTSNode node = this; node != null; node = node.parent) {
            node.numberOfVisits++;
            node.totalValue += alteredValue;
        }
    }

    /**
     * Simulates random play for both sides from this node's depth until {@code maxDepth} or game
     * over, then scores the result for {@link #playerID}.
     *
     * <p>The score is points divided by missiles used. NOTE(review): this assumes each player
     * starts with 100 missiles; confirm. The divisor is clamped to at least 1, so an agent that has
     * not fired yet does not cause a division by zero.
     *
     * @param state game state, mutated in place
     * @param maxDepth depth at which the rollout stops
     * @return points per missile used (floating-point division)
     */
    public double rollout(SimpleBattle state, int maxDepth) {
        // The depth counter must advance each step, otherwise only game over would end the loop.
        for (int depth = this.currentDepth; depth < maxDepth && !state.isGameOver(); depth++) {
            state.update(randomAction(), randomAction());
        }
        double missilesUsed = 100 - state.getMissilesLeft(playerID);
        return state.getPoints(playerID) / Math.max(1.0, missilesUsed);
    }

    /**
     * Returns our action on the child with the highest accumulated {@code totalValue}.
     *
     * <p>Ties are broken by noise smaller than {@link #EPSILON}. Returns the first entry of the
     * action table when no child has been expanded.
     */
    public Action getBestAction() {
        double bestScore = -Double.MAX_VALUE;
        int bestIndex = -1;
        if (children == null) {
            return allActions[0];
        }
        for (int i = 0; i < children.length; i++) {
            if (children[i] != null) {
                double childScore = children[i].totalValue + (random.nextFloat() * EPSILON);
                if (childScore > bestScore) {
                    bestScore = childScore;
                    bestIndex = i;
                }
            }
        }
        if (bestIndex == -1) {
            return allActions[0];
        }
        return children[bestIndex].ourMoveToThisState;
    }

    /** Returns true once a child exists for every entry of the action table. */
    private boolean fullyExpanded() {
        return numberOfChildrenExpanded == allActions.length;
    }

    /** Prints mean value, action and UCB1 score of each expanded child to stdout (debugging aid). */
    public void printAllChildren() {
        StringBuilder builder = new StringBuilder();
        for (BetterMCTSNode child : children) {
            if (child == null) {
                continue; // not expanded yet
            }
            builder.append("Value: ");
            builder.append(child.totalValue / child.numberOfVisits);
            builder.append(" Action: ");
            builder.append(child.ourMoveToThisState);
            builder.append(" UCB: ");
            builder.append(child.calculateChild());
            builder.append("\n");
        }
        System.out.println(builder.toString());
    }

    /** Returns an action drawn uniformly at random from the shared action table. */
    private static Action randomAction() {
        return allActions[random.nextInt(allActions.length)];
    }
}