diff --git a/src/battle/controllers/Piers/MCTSNode.java b/src/battle/controllers/Piers/MCTSNode.java index 96658374a60dfd7798195aa5bc1999a15d9ed944..e832bf96a6e20f8393e0725e6595e5f33ec620b9 100644 --- a/src/battle/controllers/Piers/MCTSNode.java +++ b/src/battle/controllers/Piers/MCTSNode.java @@ -1,7 +1,7 @@ package battle.controllers.Piers; import asteroids.Action; -import asteroids.GameState; +import battle.SimpleBattle; import java.util.Random; @@ -19,17 +19,21 @@ public class MCTSNode { private MCTSNode parent; private MCTSNode[] children; private int numberOfChildrenExpanded; + private boolean ourNode; private int currentDepth; + private int playerID; private double totalValue; private int numberOfVisits = 1; private double explorationConstant; - public MCTSNode(double explorationConstant) { + public MCTSNode(double explorationConstant, int playerID) { this.explorationConstant = explorationConstant; currentDepth = 0; + this.playerID = playerID; + ourNode = true; } private MCTSNode(MCTSNode parent, Action moveToThisState) { @@ -38,6 +42,8 @@ public class MCTSNode { this.parent = parent; this.children = new MCTSNode[MCTSNode.allActions.length]; this.currentDepth = parent.currentDepth + 1; + this.playerID = parent.playerID; + this.ourNode = parent.ourNode; } public static void setAllActions() { @@ -53,7 +59,7 @@ public class MCTSNode { } } - public MCTSNode select(GameState state, int maxDepth) { + public MCTSNode select(SimpleBattle state, int maxDepth) { MCTSNode current = this; while (current.currentDepth <= maxDepth) { if (current.fullyExpanded()) { @@ -93,17 +99,22 @@ public class MCTSNode { (explorationConstant * (Math.sqrt(Math.log(parent.numberOfVisits) / numberOfVisits))); } - public void updateValues(double value) { + public void updateValues(double value, double enemyScore) { MCTSNode current = this; while (current.parent != null) { current.numberOfVisits++; - // todo work out scoring system. + current.totalValue += (current.ourNode) ? value : enemyScore; current = current.parent; } } - public double rollout(GameState state) { - return 0.0d; + public double[] rollout(SimpleBattle state) { + while (!state.isGameOver()) { + Action first = allActions[random.nextInt(allActions.length)]; + Action second = allActions[random.nextInt(allActions.length)]; + state.update(first, second); + } + return new double[]{state.getPoints(playerID), state.getPoints((playerID == 1) ? 0 : 1)}; } public Action getBestAction() { diff --git a/src/battle/controllers/Piers/PiersBattleTest.java b/src/battle/controllers/Piers/PiersBattleTest.java new file mode 100644 index 0000000000000000000000000000000000000000..8c5602ffe19672dc1ff62c4fae9ead891a21dbe1 --- /dev/null +++ b/src/battle/controllers/Piers/PiersBattleTest.java @@ -0,0 +1,11 @@ +package battle.controllers.Piers; + +/** + * Created by pwillic on 12/06/2015. + */ +public class PiersBattleTest { + + public static void main(String[] args) { + + } +} diff --git a/src/battle/controllers/Piers/PiersMCTS.java b/src/battle/controllers/Piers/PiersMCTS.java index 6a7fc8883f950e528006666ac83c3a11d38388b7..e4054fb7bea2ece85394723c591e3f47dc4eae00 100644 --- a/src/battle/controllers/Piers/PiersMCTS.java +++ b/src/battle/controllers/Piers/PiersMCTS.java @@ -11,7 +11,7 @@ public class PiersMCTS implements BattleController { @Override public Action getAction(SimpleBattle gameStateCopy, int playerId) { - MCTSNode root = new MCTSNode(2.0); + MCTSNode root = new MCTSNode(2.0, playerId); return null; }