from PIL import ImageFile 
import numpy as np 
import matplotlib.pyplot as plt 
import matplotlib.patches as patches 
import cv2
import random
import os
from PIL import Image
import csv
# Size (width, height) that every ROI is resized to. Dataset statistics put
# the average label bbox at about (127, 145) px; crops are nevertheless
# resized to this smaller fixed size.
AVERAGE_SIZE = (32, 32)

# Data label keys; a class's integer id is its index in this list.
CLASSES = ["danger", "interdiction", "obligation", "stop", "ceder", "frouge", "forange", "fvert", "ff", "empty"]

# Dictionary for mapping class names to integers (derived from CLASSES so the
# two can never disagree).
CLASSE_TO_INT = {name: idx for idx, name in enumerate(CLASSES)}

# Dictionary for mapping integers to class names.
# NOTE(review): variable name reconstructed from the CLASSE_TO_INT pattern —
# confirm against any other module that imports it.
INT_TO_CLASSE = dict(enumerate(CLASSES))

# Number of classes
NB_CLASSES = len(CLASSES)

def _crop_and_resize(image, box):
    """Crop ``box`` (x_min, y_min, x_max, y_max, ...) from a PIL image and resize it to AVERAGE_SIZE as an array."""
    roi = image.crop((box[0], box[1], box[2], box[3]))
    return np.array(roi.resize(AVERAGE_SIZE))


def load_dataset(image_dir, label_dir):
    """Build a classification dataset of fixed-size ROIs from labelled images.

    Each label file is named ``<int>.<ext>`` and holds CSV rows
    ``x_min, y_min, x_max, y_max, class_name``; its image is
    ``<image_dir>/<int zero-padded to 4 digits>.jpg``. Every labelled box is
    cropped, resized to ``AVERAGE_SIZE`` and paired with its integer class.
    A label file without any box contributes three random crops labelled
    ``"empty"``. Per-file errors are printed and the file is skipped.

    Args:
        image_dir: Directory containing the ``.jpg`` images.
        label_dir: Directory containing one CSV label file per image.

    Returns:
        Tuple ``(X, Y)`` of numpy arrays: resized ROIs and integer labels.
    """
    X, Y = [], []
    # Sorted so the sample order is reproducible across file systems.
    for label_file in sorted(os.listdir(label_dir)):
        label_path = os.path.join(label_dir, label_file)
        file_name = int(label_file.split('.')[0])  # Extract the file name to find corresponding image file
        try:
            image_path = os.path.join(image_dir, str(file_name).zfill(4) + ".jpg")
            # Context manager closes the underlying file; convert() returns an independent copy.
            with Image.open(image_path) as img:
                image = img.convert("RGB")

            with open(label_path, "r", newline="") as file:
                # Drop blank rows: an empty file yields [] and a file holding only
                # a newline yields [[]] -- both mean "no bounding boxes".
                bboxes = [row for row in csv.reader(file) if row]

            if bboxes:
                for box in bboxes:
                    # Class name -> int, then every field (coords + class) -> int.
                    box[4] = CLASSE_TO_INT[box[4]]
                    box[:] = map(int, box)
                    X.append(_crop_and_resize(image, box))
                    Y.append(box[4])
            else:
                # No object in this image: sample background crops as "empty".
                # PIL's Image.size is (width, height).
                width, height = image.size
                for _ in range(3):
                    box = generate_empty_bbox(image_width=width, image_height=height)
                    X.append(_crop_and_resize(image, box))
                    Y.append(CLASSE_TO_INT["empty"])

        except FileNotFoundError:
            print(f"Image file not found for {file_name}")
        except Exception as e:
            print(f"Error when processing index {file_name}: {e}")

    # Convert the lists X and Y to numpy arrays and return them
    return np.array(X), np.array(Y)

# Function to calculate Intersection over Union (IoU) 
def iou(box1, box2):
    """Compute the Intersection over Union (IoU) of two bounding boxes.

    Args:
        box1: Box ``(x1, y1, x2, y2)`` where ``(x1, y1)`` is the top-left corner
            and ``(x2, y2)`` the bottom-right corner.
        box2: Second box in the same ``(x1, y1, x2, y2)`` format.

    Returns:
        Intersection area divided by union area, or ``0`` when the union area
        is not positive (e.g. two degenerate boxes).
    """
    # Corner coordinates ([axis][corner_idx]_[box_idx]).
    x0_1, y0_1, x1_1, y1_1 = box1
    x0_2, y0_2, x1_2, y1_2 = box2

    # Corners of the intersection rectangle.
    ix0 = max(x0_1, x0_2)
    iy0 = max(y0_1, y0_2)
    ix1 = min(x1_1, x1_2)
    iy1 = min(y1_1, y1_2)

    # Clamp at 0 so non-overlapping boxes give an empty intersection.
    inter_area = max(0, ix1 - ix0) * max(0, iy1 - iy0)

    area1 = (x1_1 - x0_1) * (y1_1 - y0_1)
    area2 = (x1_2 - x0_2) * (y1_2 - y0_2)
    union_area = area1 + area2 - inter_area

    return inter_area / union_area if union_area > 0 else 0

# Function to calculate Non Maximum Suppression (NMS) 
def non_maximum_suppression(bboxes, threshold=0.5):
    """Filter overlapping bounding boxes with greedy non-maximum suppression.

    Boxes are visited from highest to lowest score. Each kept box suppresses
    every remaining box whose overlap with it is strictly greater than
    ``threshold``. Areas use the inclusive-pixel convention
    ``(x2 - x1 + 1) * (y2 - y1 + 1)``.

    :param bboxes: Sequence of boxes indexable as ``(x1, y1, x2, y2, score, ...)``
    :param threshold: IoU threshold for suppression
    :return: List of the kept boxes (the original objects), highest score first
    """
    if len(bboxes) == 0:
        return []

    # Extract the coordinates and scores.
    x1 = np.array([bbox[0] for bbox in bboxes], dtype=float)
    y1 = np.array([bbox[1] for bbox in bboxes], dtype=float)
    x2 = np.array([bbox[2] for bbox in bboxes], dtype=float)
    y2 = np.array([bbox[3] for bbox in bboxes], dtype=float)
    scores = np.array([bbox[4] for bbox in bboxes], dtype=float)

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = np.argsort(-scores, kind="stable")  # indices, highest score first

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        rest = order[1:]
        if rest.size == 0:
            break

        # Intersection of box i with every remaining candidate.
        xx1 = np.maximum(x1[rest], x1[i])
        yy1 = np.maximum(y1[rest], y1[i])
        xx2 = np.minimum(x2[rest], x2[i])
        yy2 = np.minimum(y2[rest], y2[i])
        w = np.clip(xx2 - xx1 + 1, 0, None)
        h = np.clip(yy2 - yy1 + 1, 0, None)
        inter = w * h
        ovr = inter / (areas[i] + areas[rest] - inter)

        # Keep only candidates with an overlap not exceeding the threshold.
        # Boolean masking always yields a 1-D array, even with one survivor.
        order = rest[ovr <= threshold]

    return [bboxes[idx] for idx in keep]

# Function to plot images with bounding boxes and class labels 
def plot_bbox_image(image, boxes):
    """Display ``image`` with its bounding boxes and class names overlaid.

    Args:
        image: Image accepted by ``matplotlib.axes.Axes.imshow``.
        boxes: Iterable of boxes ``(x_min, y_min, x_max, y_max[, class_id])`` in
            pixel coordinates. Boxes without a class id are drawn as class 1;
            boxes of class ``"empty"`` are not drawn.

    NOTE: boxes are read in corner format (width = x_max - x_min); proposals in
    ``(x, y, w, h)`` form must be converted before being passed here.
    """
    # One colour per class, sampled evenly across the "tab20b" colour map.
    colour_map = plt.get_cmap("tab20b")
    colors = [colour_map(i) for i in np.linspace(0, 1, NB_CLASSES)]

    _fig, ax = plt.subplots(1)
    ax.imshow(image)

    empty_class = CLASSE_TO_INT["empty"]
    for box in boxes:
        if len(box) > 4:
            class_pred = int(box[4])
        else:
            # No class attached (maybe because of selective search): fall back
            # to class 1 so the box can still be drawn.
            class_pred = 1
        if class_pred == empty_class:
            continue

        x, y = box[0], box[1]
        width = box[2] - x
        height = box[3] - y
        colour = colors[class_pred]

        rect = patches.Rectangle(
            (x, y), width, height,
            linewidth=2, edgecolor=colour, facecolor="none",
        )
        ax.add_patch(rect)
        # Class name drawn at the top-left corner of the box.
        ax.text(
            x, y, s=CLASSES[class_pred], color="white",
            verticalalignment="top",
            bbox={"color": colour, "pad": 0},
        )

    plt.show()

def selective_search(image, visualize=False, visulize_count=100):
    """Propose candidate object regions with OpenCV selective search (fast mode).

    Args:
        image: RGB image as a numpy array; it is converted to BGR for OpenCV
            and is not modified.
        visualize: If True, draw the first ``visulize_count`` proposals on a
            copy of the image and display it with matplotlib.
        visulize_count: Number of proposals to draw when visualizing.

    Returns:
        The proposals returned by ``ss.process()``, each ``(x, y, w, h)``.
    """
    # OpenCV works in BGR; cvtColor returns a new array.
    bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    ss = cv2.ximgproc.segmentation.createSelectiveSearchSegmentation()
    # The base image must be set before choosing a strategy and calling process().
    ss.setBaseImage(bgr)
    ss.switchToSelectiveSearchFast()  # Fast mode
    # ss.switchToSelectiveSearchQuality()  # Slower, more accurate alternative

    # Candidate regions.
    roi = ss.process()

    if visualize:
        canvas = bgr.copy()
        for (x, y, w, h) in roi[:visulize_count]:
            cv2.rectangle(canvas, (int(x), int(y)), (int(x + w), int(y + h)), (0, 255, 0), 2)

        plt.figure(figsize=(10, 10))
        # Back to RGB so matplotlib shows the true colours.
        plt.imshow(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))
        plt.axis("off")
        plt.show()

    return roi

# Generate a random "empty" box for images without labels
def generate_empty_bbox(image_width, image_height, box_size=None):
    """Return a random box lying inside an image, used as an "empty" sample.

    Args:
        image_width: Image width in pixels.
        image_height: Image height in pixels.
        box_size: ``(width, height)`` of the box; defaults to ``AVERAGE_SIZE``.
            When the image is smaller than the box along an axis, the box is
            clamped to span that whole axis instead of raising.

    Returns:
        Tuple ``(x_min, y_min, x_max, y_max)``.
    """
    if box_size is None:
        box_size = AVERAGE_SIZE
    box_w, box_h = box_size

    # max(0, ...) keeps randint's range valid for images smaller than the box.
    x_min = random.randint(0, max(0, image_width - box_w))
    y_min = random.randint(0, max(0, image_height - box_h))
    x_max = min(x_min + box_w, image_width)
    y_max = min(y_min + box_h, image_height)
    return (x_min, y_min, x_max, y_max)