"""
Strategy representation and implementation (minimax, MCTS, etc...)
"""

import multiprocessing
import time
from typing import Optional
from copy import deepcopy
from math import inf
from numpy.random import randint, choice

from api import Player, Action
from hexgrid.grid import HexGrid

from game.helpers.mcts import MCTSNode, UCB
from game.helpers.minimax import Minimax, NegaScout

from . import utils
from .rules import Rules, Dodo, Gopher


class Strategy:
    """
    Base class for a strategy.
    """

    def __init__(self, grid: HexGrid, rules: Rules, ply: Player):
        self.grid = grid
        self.rules = rules
        self.player = ply

    def get_action(self, grid: HexGrid, ply: Player) -> Action:
        """
        Get the action to play for a player.

        Args:
            ply: The player

        Returns:
            The action to play
        """
        raise NotImplementedError


class RandomStrategy(Strategy):
    """
    Random strategy.

    This is the simplest strategy, it chooses a random action from the
    possible and alloweds moves according to a rule.
    """

    def get_action(self, grid: HexGrid, ply: Player) -> Action:
        moves = self.rules.get_legal_moves(grid, ply)
        if moves == []:
            return None

        return moves[randint(0, len(moves))]


class StrategyAlphaBeta(Strategy):
    """
    Implementation of the minimax with Alpha-Beta pruning
    """

    def __init__(
        self,
        grid: HexGrid,
        rules: Rules,
        ply: Player,
        depth: int = 3,
    ):
        super().__init__(grid, rules, ply)
        self.depth = depth
        self.memo = {}
        self.current_state = None


    def minmax(
        self,
        ply: Player,
        alpha: float,
        beta: float,
        depth: int = 30,
    ) -> tuple[float, Optional[Action]]:
        """
        The actual Minimax search algorithm, with Alpha-Beta pruning.

        Args:
            grid: The grid
            ply: The player
            alpha: The alpha value
            beta: The beta value
            depth: The depth of the search

        Returns:
            The best score and the best action based on the minimax algorithm
        """

        if depth == 0 or self.rules.game_over(self.current_state):
            score = self.rules.score(self.current_state)
            self.memo[self.current_state] = score
            return score, None

        best_score = -inf if ply == 1 else inf
        best_action = None

        for action in self.rules.get_legal_moves(self.current_state, ply):
            self.current_state.move(action, ply)

            current_score = None
            if self.current_state in self.memo:
                current_score = self.memo[self.current_state]
            else:
                current_score, _ = self.minmax(utils.other(ply), alpha, beta, depth - 1)

            self.current_state.undo(action)

            if ply == 1:  # Maximizing player
                if current_score > best_score:
                    best_score = current_score
                    best_action = action
                alpha = max(alpha, best_score)
            else:  # Minimizing player
                if current_score < best_score:
                    best_score = current_score
                    best_action = action
                beta = min(beta, best_score)

            if alpha >= beta:
                break

        self.memo[self.current_state] = best_score
        return best_score, best_action

    def precompute_first_move(self, grid: HexGrid, ply: Player) -> Action:
        """
        Calculates the most efficient first move for an empty board of the Gopher game, using a very high depth
        This method is only meant to be used once, the results will be manually typed down in the get_first_move method

        Returns:
            The most efficient Gopher Action for the first move of the game
        """
        self.current_state = grid
        return self.minmax(ply, -inf, inf, depth=50)[1]  # Use a very high depth for precomputation

    def get_first_move(self) -> Action:
        """
        This function returns the first move for the Gopher game, calculated by the precompute_firs_move
        It is meant to save the time wasted by the alphabeta when calculating the very first move of the game, which
        takes a very long time because there are a lot of possibilities when the board is empty.
        Therefore, the first move is calculated beforehand
        Returns:
            The Gopher action of the most efficient first move, depending on the size of the grid

        """
        return -self.grid.size[0]+1, -self.grid.size[0]+1

    def get_action(self, grid: HexGrid, ply: Player) -> Action:
        self.current_state = grid

        action = None
        if self.rules.is_first_turn(grid) and isinstance(self.rules, Gopher):
            action = self.get_first_move()

        if action is not None:
            return action
        return self.minmax(ply, -inf, inf, self.depth)[1]


class NegascoutStrategy(Strategy):
    """
    Implementation of the NegaScout algorithm
    """

    def __init__(self, grid: HexGrid, rules: Rules, ply: Player, depth=3):
        super().__init__(grid, rules, ply)
        self.negascout = NegaScout(rules, depth, ply)

    def get_action(self, grid: HexGrid, ply: Player) -> Action:
        return self.negascout.get_move(grid, ply)


class MCTSStrategy(Strategy):
    """
    Implementation of the Monte-Carlo Tree Search algorithm.

    Inspired from:
        - https://gist.github.com/qpwo/c538c6f73727e254fdc7fab81024f6e1
        - https://ai-boson.github.io/mcts/

    References:
        - https://www.lri.fr/~sebag/Slides/InvitedTutorial_CP12.pdf
        - [1] Chaslot, Guillaume & Winands, Mark & Herik, H. & Uiterwijk, Jos & Bouzy, Bruno. (2008).
          Progressive Strategies for Monte-Carlo Tree Search.
          New Mathematics and Natural Computation. 04. 343-357. 10.1142/S1793005708001094.
        - [2] H. Baier and M. H. M. Winands, "Monte-Carlo Tree Search and minimax hybrids,"
          2013 IEEE Conference on Computational Inteligence in Games (CIG),
          Niagara Falls, ON, Canada, 2013, pp. 1-8, doi: 10.1109/CIG.2013.6633630.
    """

    # From the statistics got by making two random strategies
    # play agains't each others, in average, 60 moves are made
    # before a game ends. We don't have the time to wait until
    # more than 60 moves in the playout phase
    MAX_DEPTH = 30

    def __init__(
        self,
        grid: HexGrid,
        rules: Rules,
        ply: Player,
        simulations: int = 128,
        exploration_weight: float = 1.4,
        threshold: int = 50,
        ucb=UCB.ucb,
        minimax_level: int = 3,
    ):
        super().__init__(grid, rules, ply)
        atype = utils.get_action_type(rules.move_repr())
        untried = rules.get_legal_moves(grid, ply).copy()
        self.size = grid.size[0]
        self.root = MCTSNode(self.size, untried, atype=atype)
        self.simulations = simulations
        self.ce = exploration_weight
        self.threshold = threshold
        self.ucb = ucb
        self.minimax_level = minimax_level

        self.players: list[StrategyAlphaBeta] = [
            StrategyAlphaBeta(grid, rules, 1, minimax_level),
            StrategyAlphaBeta(grid, rules, 1, minimax_level),
        ]

        self.current_state = deepcopy(grid)

    def __playout(self, node: MCTSNode, rules: Rules, ply: Player) -> float:
        """
        Playout phase of the MCTS algorithm.

        Args:
            grid: The grid
            rules: The rules of the game
            player: The player

        Returns:
            The result of the game
        """
        state = deepcopy(self.current_state)
        action = node.action
        current_player = ply

        if action is not None:
            state.move(action, ply)
            current_player = utils.other(ply)

        # Citation from [1]:
        # > If the strategy is too stochastic (e.g., if it selects moves nearly randomly),
        # > then the moves played are often weak, and the level of the Monte-Carlo program
        # > is decreasing. In contrast, if the strategy is too deterministic (e.g., if the
        # > selected move for a given position is almost always the same, i.e., too much
        # > exploitation takes place) then the exploration of the search space becomes too
        # > selective, and the level of the Monte-Carlo program is decreasing too.
        #
        # Thus, the two players in the rollout phase are AlphaBeta players with a fixed depth
        # Note: the original idea was found by me, but after further research I found [2] that
        # describes a clever way of doing this (way more clever than what I wanted to do first).
        # The points that they talk about are the use of AlphaBeta algorithms not only in the
        # playout phase, but also in the selection, expansion and backpropagation phases

        depth = 0

        while True:
            if depth >= self.MAX_DEPTH:
                break

            legals = rules.get_legal_moves(state, current_player)

            if not legals:
                break

            action = None
            best = 0

            # Before using the Minimax strategy to choose the move,
            # we check that the branching factor isn't too high to
            # prevent the computing time from being too long
            if len(legals) <= 5:
                strategy: StrategyAlphaBeta = self.players[current_player - 1]

                for move in legals:
                    state.move(move, current_player)
                    strategy.current_state = state
                    score, _ = strategy.minmax(
                        utils.other(current_player), -inf, inf, self.minimax_level
                    )
                    state.undo(move)
                    if score > best:
                        best = score
                        action = move

            if action is None:
                i = choice(range(len(legals)), 1)[0]
                action = legals[i]

            state.move(action, current_player)
            current_player = utils.other(current_player)

            depth += 1

        # To accelerate the convergence, we can give a bonus
        # if the game is by multiplying it by a coefficient
        # based on the depth that decreases as the depth increase
        score = rules.player_score(state, self.player)
        if score == 1:
            return 1 + (0.8 * (2 + depth)) / (1 + depth)
        # If the returned score is zero (which means that noboby won),
        # we can still try to evaluate the likelihood of the situation;
        if score == 0:
            pmobility = len(rules.get_legal_moves(state, self.player))
            emobility = len(rules.get_legal_moves(state, utils.other(self.player)))

            if pmobility + emobility == 0:
                return 0.5

            # We will prefer situations where the player has less mobility than
            # its ennemy
            return 1 / (pmobility + 1) - pmobility / (pmobility + emobility)

        return score

    def __select(self, node: MCTSNode) -> MCTSNode:
        """
        Select a child node to explore.

        If the node has been visited enough, we use the UTC to choose the best child.
        If not, we choose the next child according to our playout strategy.

        Args:
            node: The node to start selection from

        Returns:
            The selected node
        """
        current_node: MCTSNode = node
        while current_node.is_expanded():
            if current_node.visits >= self.threshold:
                current_node = current_node.best_child(
                    exploration_weight=self.ce, ucb=self.ucb
                )
            else:
                current_node = current_node.get_max_prior()
        return current_node

    def __expand(self, node: MCTSNode, rules: Rules, ply: Player):
        """
        Expansion strategy for the MCTS algorithm.

        We're following the strategy from the paper given in reference above.

        We're expanding one node per simulation, chosen by following the selection
        strategy.

        If the node visit count exceed the threshold, we expand all the children of
        this node.

        Args:
            rules: The rules of the game.
            player: The player.
        """

        if node.visits >= self.threshold and node.untried_actions:
            # We expand all the children of the node
            # since we passed the threshold

            moves, probabilities = rules.get_moves_distribution(
                self.current_state, ply, node.untried_actions
            )

            for i, move in enumerate(moves):
                action = move
                new_state: HexGrid = deepcopy(self.current_state)
                new_state.move(action, ply)
                untried = rules.get_legal_moves(new_state, ply).copy()
                new_child: MCTSNode = MCTSNode(
                    self.size, untried, parent=node, action=action
                )
                probability: float = probabilities[i]
                new_child.prior = probability
                node.children[action] = new_child
            node.untried_actions = None

        elif node.untried_actions:
            # If the visits count is below the threshold, we're going
            # to only expand two children of the node
            moves, probabilities = rules.get_moves_distribution(
                self.current_state, ply, node.untried_actions
            )
            if not moves:
                # There are no legal moves left
                node.untried_actions = None
                return

            # Replace False unsure that we got two different childs
            # thanks stackoverflow omg
            for i in choice(
                range(len(moves)), min(6, len(moves)), p=probabilities, replace=False
            ):
                action = moves[i]
                new_state: HexGrid = deepcopy(self.current_state)
                new_state.move(action, ply)
                untried: list[Action] = rules.get_legal_moves(new_state, ply).copy()
                new_child: MCTSNode = MCTSNode(
                    self.size, untried, parent=node, action=action
                )
                probability: float = probabilities[i]
                new_child.prior = probability
                node.children[action] = new_child
                node.untried_actions.remove(action)

    def __backpropagate(self, node: MCTSNode, result: float) -> None:
        """
        Backpropagation phase of the MCTS algorithm.

        Args:
            node: The node
            result: The result of the game
        """
        while isinstance(node, MCTSNode):
            node.visits += 1
            node.value += result
            node = node.parent

    def train(self):
        """
        Performs a training of the MCTS tree.
        """
        node = self.root

        # s = time.time()
        node = self.__select(node)
        # print("Select:", time.time() - s)

        # s = time.time()
        self.__expand(node, self.rules, self.player)
        # print("Expand:", time.time() - s)

        # s = time.time()
        reward = self.__playout(node, self.rules, self.player)
        # print("Playout:", time.time() - s)

        # s = time.time()
        self.__backpropagate(node, reward)
        # print("Backpropagate:", time.time() - s)

    def __pickup(self, node, legals: list[Action]):
        while isinstance(node, MCTSNode):
            for child in node.children.values():
                best: MCTSNode = child.best_child(self.ce, ucb=self.ucb)
                maxp: MCTSNode = child.get_max_prior()
                if best and best.action in legals:
                    return best
                if maxp and maxp.action in legals:
                    return maxp
                if child.action in legals:
                    return child
            node = node.parent

        return None

    def get_action(self, grid: HexGrid, ply: Player = None) -> Action:
        """
        Get the action to play for a player using MCTS.

        Args:
            grid: The game grid
            rules: The rules of the game
            ply: The player

        Returns:
            The action to play
        """

        self.current_state = deepcopy(grid)
        # We train the MCTS algorithm
        # before choosing the best possible move

        # The training is ran in parallel in hope of
        # improving performances
        cpus = multiprocessing.cpu_count()
        with multiprocessing.Pool(cpus) as pool:
            pool.map(MCTSStrategy.train, [self] * self.simulations)

        # When choosing the best child, we need to also ensure
        # that the move is a legal one
        legals: list[Action] = self.rules.get_legal_moves(grid, ply)

        best_child: MCTSNode = self.root.best_child(self.ce, ucb=self.ucb)

        if best_child.action in legals:
            self.root = best_child
            return best_child.action

        # We're in the case where the best child we just computed is not available anymore
        # We can then delete this child and recompute the best child

        best_child: Optional[MCTSNode] = self.__pickup(best_child, legals)
        if best_child is not None:
            self.root = best_child
            return best_child.action

        # Now it's time to choose a kinda random move based
        # on the distributions of the moves
        moves, probabilities = self.rules.get_moves_distribution(grid, ply, legals)

        if not moves:  # Weird, the game should have finished
            return None

        i = choice(range(len(moves)), 1, p=probabilities)[0]
        return moves[i]

class BrainStrategy(Strategy):
    """ Implementation of the human brain strategy"""

    def get_action(self, grid: HexGrid, ply: Player) -> Action:
        """
        Asks the user what move they want to play

        Args:
            grid: HexGrid
            ply: Player

        Returns:
            Action desired by the user
        """

        def get_int_input(prompt):
            while True:
                value = input(prompt)
                try:
                    return int(value)
                except ValueError:
                    print("Invalid input. Please enter an integer.")

        print("Legal moves : ")
        print(self.rules.get_legal_moves(grid, ply))
        print("Please insert:")

        if isinstance(self.rules, Dodo):
            origin_cell_x = get_int_input("Horizontal coordinates of origin cell of desired action: ")
            origin_cell_y = get_int_input("Vertical coordinates of origin cell of desired action: ")
            move_cell_x = get_int_input("Horizontal coordinates of landing cell of desired action: ")
            move_cell_y = get_int_input("Vertical coordinates of landing cell of desired action: ")

            while not (((origin_cell_x, origin_cell_y), (move_cell_x, move_cell_y)) in self.rules.get_legal_moves(grid,
                                                                                                                  ply)):
                print("This move is not legal. Please try another.")
                origin_cell_x = get_int_input("Horizontal coordinates of origin cell of desired action: ")
                origin_cell_y = get_int_input("Vertical coordinates of origin cell of desired action: ")
                move_cell_x = get_int_input("Horizontal coordinates of landing cell of desired action: ")
                move_cell_y = get_int_input("Vertical coordinates of landing cell of desired action: ")

            return (origin_cell_x, origin_cell_y), (move_cell_x, move_cell_y)

        if isinstance(self.rules, Gopher):
            move_cell_x = get_int_input("Horizontal coordinates of landing cell of desired action: ")
            move_cell_y = get_int_input("Vertical coordinates of landing cell of desired action: ")

            while not ((move_cell_x, move_cell_y) in self.rules.get_legal_moves(grid, ply)):
                print("This move is not legal. Please try another.")
                move_cell_x = get_int_input("Horizontal coordinates of landing cell of desired action: ")
                move_cell_y = get_int_input("Vertical coordinates of landing cell of desired action: ")

            return move_cell_x, move_cell_y

        return None