Commit 48c338e6 authored by Gabriel Santamaria

Adding statistics for the first alpha-beta depth; tuned the algorithm

parent 3778777a
@@ -9,10 +9,11 @@ import time
 from random import choice
 import numpy as np
 from tqdm import tqdm
+import gc
 from hexgrid.grid import HexGrid
 from game.rules import Rules, Dodo
-from game.strategy import MCTSStrategy
+from game.strategy import MCTSStrategy, StrategyAlphaBeta
 from game.helpers.mcts import UCB
 from game.utils import other
@@ -171,5 +172,82 @@ def write_csv(wins, times):
                f.write(f"{i},{j},{k},{wins[i, j, k]},{times[i, j, k]}\n")
 
-data_wins, data_times = iterate()
-write_csv(data_wins, data_times)
+# data_wins, data_times = iterate()
+# write_csv(data_wins, data_times)
+
+
+def run_one():
+    grid: HexGrid = HexGrid.split_state(4, 2, 1)
+    rules: Rules = Dodo()
+    rules.clear_cache()
+
+    player_two = MCTSStrategy(grid, rules, 2)
+    player_one = StrategyAlphaBeta(grid, rules, 1, depth=10)
+
+    ply = 1
+    ttimes = [0, 0]  # cumulative move times: index 0 = MCTS, 1 = AlphaBeta
+    niter = [0, 0]   # number of moves played by each player
+    wins = [0, 0]
+
+    # grid.debug_plot(labels=[("MCTS", 2)], save=True)
+    while not rules.game_over(grid):
+        if ply == 1:
+            s = time.time()
+            try:
+                action = player_one.get_action(grid, ply)
+                ttimes[1] += time.time() - s
+                niter[1] += 1
+                grid.move(action, ply)
+            except RuntimeError:
+                print("Runtime error, retrying...")
+                continue
+        else:
+            s = time.time()
+            try:
+                action = player_two.get_action(grid, ply)
+                ttimes[0] += time.time() - s
+                niter[0] += 1
+                grid.move(action, ply)
+            except RuntimeError:
+                print("Runtime error, retrying...")
+                continue
+        ply = other(ply)
+    # grid.debug_plot(labels=[("MCTS", 2)], save=True)
+
+    ply_one_won = rules.has_won(grid, 1)
+    if ply_one_won:
+        wins[1] += 1
+    else:
+        wins[0] += 1
+
+    gc.collect()
+    return wins, ttimes, niter
+
+
+def get_stats(n=100):
+    wins = [0, 0]
+    ttimes = [0, 0]
+    niters = [0, 0]
+    for _ in tqdm(range(n)):
+        w, t, it = run_one()  # "it" = per-game move counts; avoids shadowing n
+        wins[0] += w[0]
+        wins[1] += w[1]
+        ttimes[0] += t[0]
+        ttimes[1] += t[1]
+        niters[0] += it[0]
+        niters[1] += it[1]
+    return wins, ttimes, niters
+
+
+wins, ttimes, niters = get_stats(50)
+
+print(f"MCTS wins: {wins[0]}")
+print(f"AlphaBeta wins: {wins[1]}")
+print(f"MCTS time: {ttimes[0] / niters[0]:.10f}")
+print(f"AlphaBeta time: {ttimes[1] / niters[1]:.10f}")
@@ -186,15 +186,14 @@ We compared the performance of our algorithm with that of the algo…
 ## For Dodo
 
-The statistics were gathered on grids of size $4 \times 4$ over 200 games (100 as player 1, 100 as player 2) for each algorithm. The MCTS algorithm is always the same, with the same number of simulations (namely $512$).
+The statistics were gathered on grids of size $4 \times 4$ over 200 games (100 as player 1, 100 as player 2) for each algorithm. The MCTS algorithm is always the same, with the same number of simulations (namely $256$).
 
 | Minimax Depth | % Win (Minimax) | % Win (MCTS) | Time / move (Minimax) | Time / move (MCTS) |
 | ------------- | --------------- | ------------ | --------------------- | ------------------- |
-| 10            | 1.00%           | 99.00%       | 0.03s                 | 1.76s               |
+| 10            | 0.00%           | 100.00%      | 0.01s                 | 0.66s               |
 | 15            |                 |              |                       |                     |
 | 20            |                 |              |                       |                     |
 | 25            |                 |              |                       |                     |
 | 35            | 0.00%           | 100.00%      | 2.31s                 | 1.86s               |
 # References
@@ -349,6 +349,10 @@ class Dodo(Rules):
         return 0
 
+    @staticmethod
+    def clear_cache():
+        Dodo.__legals = {}
+
 
 class Gopher(Rules):
     """
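`clear_cache` resets what is evidently a class-level memo (name-mangled to `Dodo._Dodo__legals`), so stale entries do not leak between the benchmark games started in `run_one`. A standalone sketch of that caching pattern — the key derivation and `_compute_legals` are assumptions for illustration, not the repository's actual code:

```python
class Dodo:
    """Sketch of a class-level legal-move cache that clear_cache resets."""

    __legals: dict = {}  # shared by all instances, persists across games

    def legals(self, grid_key, player):
        key = (grid_key, player)
        if key not in Dodo.__legals:
            Dodo.__legals[key] = self._compute_legals(grid_key, player)
        return Dodo.__legals[key]

    def _compute_legals(self, grid_key, player):
        return []  # placeholder for the real move generation

    @staticmethod
    def clear_cache():
        # Inside the class body, __legals mangles to _Dodo__legals,
        # so this rebinds exactly the attribute the cache reads.
        Dodo.__legals = {}
```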
@@ -190,7 +190,7 @@ class MCTSStrategy(Strategy):
         ply: Player,
         simulations: int = 256,
         exploration_weight: float = 0.7,
-        threshold: int = 10,
+        threshold: int = 30,
         ucb=UCB.ucb,
         minimax_level: int = 10,
     ):
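`exploration_weight` is presumably the constant $c$ in the UCB1 selection rule behind `UCB.ucb`, and `threshold` commonly gates when a node gets expanded. A sketch of the conventional UCB1 formula under that assumption, not the repository's exact implementation:

```python
import math


def ucb1(value_sum: float, visits: int, parent_visits: int, c: float = 0.7) -> float:
    """Conventional UCB1: mean value plus a c-weighted exploration bonus."""
    if visits == 0:
        return float("inf")  # unvisited children are tried first
    return value_sum / visits + c * math.sqrt(math.log(parent_visits) / visits)
```

Raising `threshold` from 10 to 30 here, together with cutting `self.niter` from 10 to 2 in the next hunk, appears to be the tuning the commit message refers to.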
@@ -205,7 +205,7 @@ class MCTSStrategy(Strategy):
         self.minimax_level = minimax_level
         self.legals = []
-        self.niter = 10
+        self.niter = 2
 
         self.players: list[StrategyAlphaBeta] = [
             StrategyAlphaBeta(grid, rules, 1, minimax_level),
@@ -52,67 +52,3 @@ def get_action_type(action: Action) -> bool:
             return True
         case (_, _):
             return False
-
-
-def __pos_flattner(pos, size):
-    offset = size - 1
-    q, r = pos
-    return (q + offset) + size * 2 * (r + offset)
-
-
-def __pos_hexitifier(value, size):
-    offset = size - 1
-    r = value // (size * 2)
-    q = value % (size * 2)
-    return (q - offset, r - offset)
-
-
-def flatify(action: Action, size: int) -> int:
-    """
-    For the sake of computational efficiency (in an attempt to optimize MCTS),
-    we need a way to "flatten" an action into a single integer, which is then
-    used to index the action inside numpy arrays and allows vectorized
-    computation.
-
-    Flatify performs the transform Action -> int.
-
-    Args:
-        action: The action to flatify
-        size: The size of the grid
-
-    Returns:
-        The flatified action
-    """
-    match action:
-        case ((x1, y1), (x2, y2)):  # it's a Dodo action
-            max_coord = size * 2
-            a = __pos_flattner((x1, y1), size)
-            b = __pos_flattner((x2, y2), size)
-            return a + (b * (max_coord * max_coord))
-        case (x1, y1):
-            return __pos_flattner((x1, y1), size)
-
-
-def unflatify(action: int, size: int, atype: bool) -> Action:
-    """
-    Inverse operation of the flatify function.
-
-    Args:
-        action: The flatified action
-        size: The size of the grid
-        atype: The type of the action to decode, True for Dodo, False for Gopher
-
-    Returns:
-        The unflatified action
-    """
-    match atype:
-        case True:  # a Dodo action
-            max_coord = size * 2
-            b = action // (max_coord * max_coord)
-            a = action % (max_coord * max_coord)
-            x = __pos_hexitifier(a, size)
-            y = __pos_hexitifier(b, size)
-            return (x, y)
-        case False:  # a Gopher action
-            return __pos_hexitifier(action, size)
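The deleted helpers pack an axial position $(q, r)$, offset into $[0, 2\,\mathrm{size}-2]$, into one integer, and a Dodo move into a pair of such integers in base $(2\,\mathrm{size})^2$. A standalone round-trip check of the same arithmetic, with the name mangling dropped so it runs at module level (`pos_flattner`/`pos_hexitifier` are local stand-ins):

```python
def pos_flattner(pos, size):
    offset = size - 1
    q, r = pos
    return (q + offset) + size * 2 * (r + offset)


def pos_hexitifier(value, size):
    offset = size - 1
    r, q = divmod(value, size * 2)
    return (q - offset, r - offset)


size = 4
start, end = (0, -1), (1, 0)  # a hypothetical Dodo move
flat = pos_flattner(start, size) + pos_flattner(end, size) * (size * 2) ** 2
b, a = divmod(flat, (size * 2) ** 2)
assert (pos_hexitifier(a, size), pos_hexitifier(b, size)) == (start, end)
```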
@@ -5,16 +5,12 @@ Entry point for the hexgrid module.
 from typing import Generator, Callable
-import struct
 import numpy as np
-from matplotlib import use
 import matplotlib.pyplot as plt
 from matplotlib.patches import RegularPolygon
 import matplotlib.lines as mlines
-from xxhash import xxh128_intdigest as xxh
+from xxhash import xxh32_intdigest as xxh
 
 from api import State, Cell, Player, Action
 from game.utils import other
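Switching from `xxh128_intdigest` to `xxh32_intdigest` suggests state hashing is a hot path. A minimal sketch of hashing a numpy board this way — the board layout here is an assumption, not the module's actual `State` representation:

```python
import numpy as np
from xxhash import xxh32_intdigest as xxh

state = np.zeros((8, 8), dtype=np.int8)  # hypothetical board encoding
state[0, 1] = 1                          # a player-1 pawn

key = xxh(state.tobytes())  # 32-bit int, a cheap key for a cache or table
print(key)
```

The trade-off: a 32-bit digest is faster and smaller than the 128-bit one but collides much sooner, which is usually acceptable for short-lived per-game caches.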