diff --git a/src/battle/controllers/Piers/MCTSNode.java b/src/battle/controllers/Piers/MCTSNode.java index e832bf96a6e20f8393e0725e6595e5f33ec620b9..a1a5e3ee10b24f9eb034976180a65a1d8fff9500 100644 --- a/src/battle/controllers/Piers/MCTSNode.java +++ b/src/battle/controllers/Piers/MCTSNode.java @@ -11,10 +11,12 @@ import java.util.Random; public class MCTSNode { private static Action[] allActions; + private static Action[][] allActionPairs; private static Random random = new Random(); private static int numberOfActionsPerState = 15; - private Action moveToThisState; + private Action ourMoveToThisState; + private Action enemyMoveToThisState; private MCTSNode parent; private MCTSNode[] children; @@ -25,6 +27,7 @@ public class MCTSNode { private int playerID; private double totalValue; + private double enemyTotalValue; private int numberOfVisits = 1; private double explorationConstant; @@ -36,9 +39,10 @@ public class MCTSNode { ourNode = true; } - private MCTSNode(MCTSNode parent, Action moveToThisState) { + private MCTSNode(MCTSNode parent, Action ourMoveToThisState, Action enemyMoveToThisState) { this.explorationConstant = parent.explorationConstant; - this.moveToThisState = moveToThisState; + this.ourMoveToThisState = ourMoveToThisState; + this.enemyMoveToThisState = enemyMoveToThisState; this.parent = parent; this.children = new MCTSNode[MCTSNode.allActions.length]; this.currentDepth = parent.currentDepth + 1; @@ -57,6 +61,15 @@ public class MCTSNode { i++; } } + i = 0; + allActionPairs = new Action[144][2]; + for (Action action : allActions) { + for (Action otherAction : allActions) { + allActionPairs[i++] = new Action[]{ + action, otherAction + }; + } + } } public MCTSNode select(SimpleBattle state, int maxDepth) { @@ -64,6 +77,7 @@ public class MCTSNode { while (current.currentDepth <= maxDepth) { if (current.fullyExpanded()) { current = current.selectBestChild(); + state.update(current.ourMoveToThisState, current.enemyMoveToThisState); } else { return current.expand(); } @@ -72,11 +86,11 @@ public class MCTSNode { } public MCTSNode expand() { - int childToExpand = random.nextInt(allActions.length); + int childToExpand = random.nextInt(allActionPairs.length); while (children[childToExpand] != null) { - childToExpand = random.nextInt(allActions.length); + childToExpand = random.nextInt(allActionPairs.length); } - children[childToExpand] = new MCTSNode(this, allActions[childToExpand]); + children[childToExpand] = new MCTSNode(this, allActionPairs[childToExpand][0], allActionPairs[childToExpand][1]); return children[childToExpand]; } @@ -103,7 +117,8 @@ public class MCTSNode { MCTSNode current = this; while (current.parent != null) { current.numberOfVisits++; - current.totalValue += (current.ourNode) ? value : enemyScore; + current.totalValue += value; + current.enemyTotalValue += enemyScore; current = current.parent; } } @@ -128,7 +143,7 @@ public class MCTSNode { bestIndex = i; } } - return children[bestIndex].moveToThisState; + return children[bestIndex].ourMoveToThisState; } private boolean fullyExpanded() {