Skip to content
Snippets Groups Projects
Commit cf2a20fd authored by Chloé Taurel's avatar Chloé Taurel
Browse files

brain strategy

parent f0c7dec7
No related branches found
No related tags found
1 merge request!10Merge taurelch/minimax
......@@ -14,6 +14,7 @@ from game.strategy import (
StrategyAlphaBeta,
NegascoutStrategy,
RandomStrategy,
BrainStrategy,
)
from game.helpers.mcts import UCB
from game.utils import other
......@@ -139,11 +140,37 @@ def run(n=100, debug=False):
return wins, ttimes, niters, n
w, t, ni, n = run(1, True)
#w, t, ni, n = run(1, True)
print(
f"MCTS | Wins percentage: {w[mcts_index] / n * 100:.2f}% | Time: {t[mcts_index] / ni[mcts_index]:.2f}"
)
print(
f"AlphaBeta | Wins percentage: {w[alpha_index] / n * 100:.2f}% | Time: {t[alpha_index] / ni[alpha_index]:.2f}"
)
#print(
#f"MCTS | Wins percentage: {w[mcts_index] / n * 100:.2f}% | Time: {t[mcts_index] / ni[mcts_index]:.2f}"
#)
#print(
#f"AlphaBeta | Wins percentage: {w[alpha_index] / n * 100:.2f}% | Time: {t[alpha_index] / ni[alpha_index]:.2f}"
#)
grid = HexGrid.split_state(4, 2, 1)
rules = Dodo(grid)
random = RandomStrategy(grid, rules, 1)
brain = BrainStrategy(grid, rules, 2)
total_time = 0
n = 0
t = (grid, rules)
ply = 1
while not rules.game_over(grid):
if ply == 1:
s = time.time()
grid.move(random.get_action(grid, ply), ply)
total_time += time.time() - s
n += 1
grid.debug_plot()
else:
grid.move(brain.get_action(grid, ply), ply)
grid.debug_plot()
ply = 3 - ply
......@@ -16,7 +16,7 @@ from game.helpers.mcts import MCTSNode, UCB
from game.helpers.minimax import Minimax, NegaScout
from . import utils
from .rules import Rules
from .rules import Rules, Dodo, Gopher
class Strategy:
......@@ -487,3 +487,61 @@ class MCTSStrategy(Strategy):
i = choice(range(len(moves)), 1, p=probabilities)[0]
return moves[i]
class BrainStrategy(Strategy):
""" Implementation of the human brain strategy"""
def get_action(self, grid: HexGrid, ply: Player) -> Action:
"""
Asks the user what move they want to play
Args:
grid: HexGrid
ply: Player
Returns:
Action desired by the user
"""
def get_int_input(prompt):
while True:
value = input(prompt)
if value.isdigit():
return int(value)
else:
print("Invalid input. Please enter an integer.")
print("Legal moves : ")
print(self.rules.get_legal_moves(grid, ply))
print("Please insert:")
if isinstance(self.rules, Dodo):
origin_cell_x = get_int_input("Horizontal coordinates of origin cell of desired action: ")
origin_cell_y = get_int_input("Vertical coordinates of origin cell of desired action: ")
move_cell_x = get_int_input("Horizontal coordinates of landing cell of desired action: ")
move_cell_y = get_int_input("Vertical coordinates of landing cell of desired action: ")
while not (((origin_cell_x, origin_cell_y), (move_cell_x, move_cell_y)) in self.rules.get_legal_moves(grid,
ply)):
print("This move is not legal. Please try another.")
origin_cell_x = get_int_input("Horizontal coordinates of origin cell of desired action: ")
origin_cell_y = get_int_input("Vertical coordinates of origin cell of desired action: ")
move_cell_x = get_int_input("Horizontal coordinates of landing cell of desired action: ")
move_cell_y = get_int_input("Vertical coordinates of landing cell of desired action: ")
return (origin_cell_x, origin_cell_y), (move_cell_x, move_cell_y)
if isinstance(self.rules, Gopher):
move_cell_x = get_int_input("Horizontal coordinates of landing cell of desired action: ")
move_cell_y = get_int_input("Vertical coordinates of landing cell of desired action: ")
while not ((move_cell_x, move_cell_y) in self.rules.get_legal_moves(grid, ply)):
print("This move is not legal. Please try another.")
move_cell_x = get_int_input("Horizontal coordinates of landing cell of desired action: ")
move_cell_y = get_int_input("Vertical coordinates of landing cell of desired action: ")
return move_cell_x, move_cell_y
return None
\ No newline at end of file
......@@ -565,7 +565,7 @@ class HexGrid:
return color
@staticmethod
def __add_hexagon(ax, x_center, y_center, player=0, **kwargs):
def __add_hexagon(ax, x_center, y_center, q, r, player=0, **kwargs):
color = HexGrid.__get_color(player)
hexagon = RegularPolygon(
......@@ -578,6 +578,8 @@ class HexGrid:
)
ax.add_patch(hexagon)
# Add text annotation for the hex coordinates (q, r)
ax.text(x_center, y_center, f'({q},{r})', ha='center', va='center', fontsize=8, color='black')
def debug_plot(self, bold=None, save=False, labels=None):
"""
......@@ -602,23 +604,24 @@ class HexGrid:
x = q * np.sin(np.radians(45)) - r * np.cos(np.radians(45))
y = (q * np.cos(np.radians(45)) + r * np.sin(np.radians(45))) * 9 / 16
HexGrid.__add_hexagon(ax, x, y, player=player, edgecolor="black")
HexGrid.__add_hexagon(ax, x, y, q, r, player=player, edgecolor="black")
ax.set_aspect("equal")
ax.axis("off")
plt.autoscale(enable=True)
legend = []
for label, player in labels:
color = HexGrid.__get_color(player)
line = mlines.Line2D(
[], [], color=color, marker="o", markersize=10, label=label
if labels is not None:
legend = []
for label, player in labels:
color = HexGrid.__get_color(player)
line = mlines.Line2D(
[], [], color=color, marker="o", markersize=10, label=label
)
legend.append(line)
ax.legend(
handles=legend, loc="upper right", bbox_to_anchor=(1.3, 1.12)
)
legend.append(line)
ax.legend(
handles=legend, loc="upper right", bbox_to_anchor=(1.3, 1.12)
)
if save:
plt.savefig(f"plots/hexgrid_{HexGrid.Plots}.png")
......@@ -627,3 +630,4 @@ class HexGrid:
return
plt.show()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment