Commit f406081a authored by Joseph Walton-Rivers's avatar Joseph Walton-Rivers 🐦

ensure that getUCTNode is handled correctly on unexpanded nodes

parent 0e4b6fc5
Pipeline #2152 passed with stages
in 2 minutes and 48 seconds
......@@ -25,6 +25,7 @@ public class MCTS implements Agent {
public static final int DEFAULT_ROLLOUT_DEPTH = 18;
public static final int DEFAULT_TREE_DEPTH_MUL = 1;
public static final int NO_LIMIT = 100;
protected static final boolean OLD_UCT_BEHAVIOUR = false;
protected final int roundLength;
protected final int rolloutDepth;
......
......@@ -9,7 +9,9 @@ import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.stream.Collectors;
......@@ -36,7 +38,9 @@ public class MCTSNode {
private double score;
private int visits;
private int parentWasVisitedAndIWasLegal;
private int parentWasVisitedAndIWasLegalOld;
protected Map<Action, Integer> legalChildVisits;
protected final StatsSummary rolloutScores;
protected final StatsSummary rolloutMoves;
......@@ -73,6 +77,8 @@ public class MCTSNode {
this.random = new Random();
this.depth = (parent == null) ? 0 : parent.depth + 1;
this.legalChildVisits = new HashMap<>();
this.rolloutScores = new BasicStats();
this.rolloutMoves = new BasicStats();
......@@ -89,7 +95,8 @@ public class MCTSNode {
return 0;
}
return ((score / MAX_SCORE) / visits) + (expConst * Math.sqrt(Math.log(parentWasVisitedAndIWasLegal) / visits));
int legalVisits = MCTS.OLD_UCT_BEHAVIOUR ? parentWasVisitedAndIWasLegalOld : parent.legalChildVisits.get(moveToState);
return ((score / MAX_SCORE) / visits) + (expConst * Math.sqrt(Math.log(legalVisits) / visits));
}
public List<MCTSNode> getChildren() {
......@@ -123,7 +130,10 @@ public class MCTSNode {
if (!moveToMake.isLegal(child.agentId, state)) {
continue;
}
child.parentWasVisitedAndIWasLegal++;
child.parentWasVisitedAndIWasLegalOld++;
updateVisitCount(moveToMake);
double childScore = child.getUCTValue() + (random.nextDouble() * EPSILON);
if (childScore > bestScore) {
......@@ -131,9 +141,22 @@ public class MCTSNode {
bestChild = child;
}
}
//now, update all children we haven't expanded yet, but we could have done
int nextPlayer = (getAgent() + 1) % state.getPlayerCount();
for (Action unexpandedAction : allUnexpandedActions) {
if (unexpandedAction.isLegal(nextPlayer, state)) {
updateVisitCount(unexpandedAction);
}
}
return bestChild;
}
protected void updateVisitCount(Action action) {
int current = legalChildVisits.getOrDefault(action, 0);
legalChildVisits.put(action, current + 1);
}
public int getAgent() {
return agentId;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment