Commit 0e4b6fc5 authored by Joseph Walton-Rivers 🐦

ensure that mcts/pmcts forwards the state during expand

parent 6795dcc6
Pipeline #2151 passed with stages in 2 minutes and 51 seconds
......@@ -144,16 +144,20 @@ public class MCTS implements Agent {
protected MCTSNode select(MCTSNode root, GameState state, IterationObject iterationObject) {
MCTSNode current = root;
int treeDepth = calculateTreeDepthLimit(state);
while (!state.isGameOver() && current.getDepth() < treeDepth) {
boolean expandedNode = false;
while (!state.isGameOver() && current.getDepth() < treeDepth && !expandedNode) {
MCTSNode next;
if (current.fullyExpanded(state)) {
next = current.getUCTNode(state);
} else {
next = expand(current, state);
return next;
expandedNode = true;
}
if (next == null) {
//XXX if all follow on states explored so far are null, we are now a leaf node
//ok to early return here - we will have applied current last time round the loop!
return current;
}
current = next;
......
......@@ -76,8 +76,9 @@ public class MCTSPredictor extends MCTS {
protected MCTSNode select(MCTSNode root, GameState state, IterationObject iterationObject) {
MCTSNode current = root;
int treeDepth = calculateTreeDepthLimit(state);
boolean expanded = false;
while (!state.isGameOver() && current.getDepth() < treeDepth) {
while (!state.isGameOver() && current.getDepth() < treeDepth && !expanded) {
MCTSNode next;
if (current.fullyExpanded(state)) {
next = current.getUCTNode(state);
......@@ -92,11 +93,14 @@ public class MCTSPredictor extends MCTS {
if (numChildren != current.getChildSize()) {
// It is new
return next;
expanded = true;
//return next;
}
}
// Forward the state
if (next == null) {
//ok to early return, current advanced in last game tick
return current;
}
current = next;
......
......@@ -129,13 +129,16 @@ public class MCTSExpConst implements Agent {
protected MCTSNode select(MCTSNode root, GameState state, IterationObject iterationObject) {
MCTSNode current = root;
int treeDepth = calculateTreeDepthLimit(state);
while (!state.isGameOver() && current.getDepth() < treeDepth) {
boolean expanded = false;
while (!state.isGameOver() && current.getDepth() < treeDepth && !expanded) {
MCTSNode next;
if (current.fullyExpanded(state)) {
next = current.getUCTNode(state);
} else {
next = expand(current, state);
return next;
expanded = true;
//return next;
}
if (next == null) {
//XXX if all follow on states explored so far are null, we are now a leaf node
......
......@@ -73,8 +73,9 @@ public class MCTSPredictorExpConst extends MCTSExpConst {
protected MCTSNode select(MCTSNode root, GameState state, IterationObject iterationObject) {
MCTSNode current = root;
int treeDepth = calculateTreeDepthLimit(state);
boolean expanded = false;
while (!state.isGameOver() && current.getDepth() < treeDepth) {
while (!state.isGameOver() && current.getDepth() < treeDepth && !expanded) {
MCTSNode next;
if (current.fullyExpanded(state)) {
next = current.getUCTNode(state);
......@@ -89,7 +90,8 @@ public class MCTSPredictorExpConst extends MCTSExpConst {
if (numChildren != current.getChildSize()) {
// It is new
return next;
expanded = true;
//return next;
}
}
// Forward the state
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment