import datetime
import threading

import pygame
import cv2
import numpy as np
import torch
from threading import Event, Thread
from typing import List
from PyQt5.QtCore import QThread, pyqtSlot, pyqtSignal, QUrl, QDir, pyqtProperty
from matplotlib import pyplot as plt
from detector import Detector
from detector.utils import get_bbox_by_point
from tracker import Tracker
from video_streamer.videostreamer import VideoStreamer
from time import sleep
import time
from PyQt5.QtCore import QObject, pyqtSignal
import ctypes
from ctypes import c_int64
from server import run_server, RTSPServer, get_local_ip

# Debug flag: when True, Core.__tracking opens a fullscreen Pygame window
# and renders each tracked frame with timestamp overlays.
showTrack = False

class Core(QThread):
    """Video-processing core.

    Continuously publishes two camera feeds over RTSP and runs detection /
    tracking worker threads against a user-selected source.

    Signals:
        newFrame(bboxes, source_id, is_detection, timestamp_ms):
            detection results (is_detection=True) or a single tracker bbox
            (is_detection=False), stamped with a millisecond c_int64 time.
        coordsUpdated(source_id, center, success):
            centre point of the tracked object on every tracker update.
    """

    newFrame = pyqtSignal(object, int, bool, ctypes.c_int64)
    coordsUpdated = pyqtSignal(int, object, bool)

    def __init__(self, video_sources, tracker=None, detector=None, parent=None):
        # NOTE(review): the original called super(QThread, self).__init__(parent),
        # which starts the MRO walk AFTER QThread and therefore skips
        # QThread.__init__ entirely; the plain super() form is correct.
        super().__init__(parent)

        self.__detector = detector
        self.__tracker = tracker
        # Last bbox (x, y, w, h) reported by the tracker and its success flag.
        # The original initialised a misspelled "__tracker_rio", so the
        # overlay guard in __stream could never fire; the name now matches
        # the attribute __tracking actually writes.
        self.__tracker_roi = None
        self.__tracker_success = False

        # One RTSP publisher per source, each served from a daemon thread.
        self.__rtspserver_0 = RTSPServer(get_local_ip(), 41231, "/stream0")
        threading.Thread(target=run_server, args=[self.__rtspserver_0], daemon=True).start()

        self.__rtspserver_1 = RTSPServer(get_local_ip(), 41232, "/stream1")
        threading.Thread(target=run_server, args=[self.__rtspserver_1], daemon=True).start()

        self.__video_sources = video_sources
        self.__processing_source = video_sources[0]

        self.__detection_roi = list()
        self.__is_detecting = False
        self.__detection_thread = None
        self.__thickness = 2
        # Empty (0, 4) array so len()/iteration work before the first
        # detection completes (np.empty([]) is 0-d and raises on len()).
        self.__detection_bboxes = np.empty((0, 4))
        self.__detection_frame = None

        self.__is_tracking = False
        self.__tracking_thread = None

        self.__processing_id = 0

        self.__frame = None  # Frame property for Pygame

        # Start the continuous streaming thread; daemon so it cannot keep
        # the interpreter alive on shutdown (matches the RTSP threads).
        self.__is_streaming = True
        self.__streaming_thread = Thread(target=self.__stream, daemon=True)
        self.__streaming_thread.start()

    @pyqtProperty(np.ndarray)
    def frame(self):
        """Most recent frame held for the Pygame preview (may be None)."""
        return self.__frame

    def set_thickness(self, thickness: int):
        """Set the rectangle line thickness used by __draw_bbox."""
        self.__thickness = thickness

    def set_source(self, source_id: int):
        """Select which video source detection/tracking operate on."""
        self.__processing_source = self.__video_sources[source_id]
        self.__processing_id = source_id

    def set_video_sources(self, video_sources):
        """Replace the source list (requires at least two) and reset to source 0."""
        if len(video_sources) >= 2:
            self.__video_sources = video_sources
            self.set_source(0)

    def stop_stream(self):
        """Ask the continuous streaming loop to terminate (new, additive API)."""
        self.__is_streaming = False

    def __publish(self, frame, source_id, server):
        """Overlay the tracker bbox (when tracking this source) and push the
        frame to the given RTSP server. No-op when frame is None."""
        if frame is None:
            return
        if (self.__is_tracking and self.__tracker is not None
                and self.__processing_id == source_id
                and self.__tracker_roi is not None):
            x, y, w, h = map(int, self.__tracker_roi)
            # Green while the tracker reports success, blue otherwise (BGR).
            box_color = (0, 255, 0) if self.__tracker_success else (255, 0, 0)
            cv2.rectangle(frame, (x, y), (x + w, y + h), box_color, 2)
        server.update_frame(frame)

    def __stream(self):
        """Continuous streaming of both video sources (worker thread)."""
        while self.__is_streaming:
            try:
                frame_0 = self.__video_sources[0].get_frame()
                frame_1 = self.__video_sources[1].get_frame()

                # Same publish order as before: stream1 first, then stream0.
                self.__publish(frame_1, 1, self.__rtspserver_1)
                self.__publish(frame_0, 0, self.__rtspserver_0)

                sleep(0.03)
            except Exception as e:
                # Best-effort loop: log and keep streaming.
                print(e)
                sleep(0.1)

    def __detection(self):
        """Detector worker: crop the configured ROI, run the detector and
        emit bboxes translated back to full-frame coordinates."""
        while self.__is_detecting:
            try:
                torch.cuda.empty_cache()
                source = self.__processing_source
                roi = self.__detection_roi
                frame = source.get_frame()
                # ROI is [x1, y1, x2, y2]; numpy indexes rows (y) first.
                cropped_frame = frame[roi[1]:roi[3], roi[0]:roi[2]]
                results = self.__detector.predict(cropped_frame)
                global_bboxes = list()
                for result in results:
                    # result[0] is the class id (unused); the rest is the
                    # bbox, whose origin is shifted by the ROI origin.
                    bbox = result[1:]
                    bbox[:2] += roi[:2]
                    global_bboxes.append(bbox)

                self.newFrame.emit(global_bboxes, self.__processing_id, True,
                                   c_int64(int(time.time() * 1e3)))
                self.__detection_bboxes = np.array(global_bboxes)
                self.__detection_frame = frame.copy()
                sleep(0.03)
            except Exception as e:
                print(e)
                sleep(0.1)

    def __tracking(self):
        """Tracker worker: update the tracker on each new frame, emit the
        bbox/centre signals and optionally render a Pygame debug window."""
        source = self.__processing_source
        if showTrack:
            pygame.init()

            # Fullscreen debug window sized to the actual display.
            info = pygame.display.Info()
            screen_width, screen_height = info.current_w, info.current_h
            screen = pygame.display.set_mode((screen_width, screen_height), pygame.FULLSCREEN)
            pygame.display.set_caption('Tracking Frame')

            clock = pygame.time.Clock()  # Add a clock to control frame rate

        while self.__is_tracking:
            if showTrack:
                for event in pygame.event.get():  # Prevent freezing by handling events
                    if event.type == pygame.QUIT:
                        pygame.quit()
                        return

            ctime = c_int64(int(time.time() * 1000))  # Millisecond timestamp

            frame = source.get_frame()
            if frame is None:
                # Source hiccup: skip this iteration instead of letting the
                # worker die on frame.shape / tracker.update (this loop has
                # no try/except, unlike __stream and __detection).
                sleep(0.03)
                continue
            print(f"intial frame size :{frame.shape}")
            bbox, success = self.__tracker.update(frame)
            self.__tracker_roi = bbox
            self.__tracker_success = success

            if bbox is not None:
                center = bbox[:2] + bbox[2:] // 2
                self.coordsUpdated.emit(self.__processing_id, center, success)
                self.newFrame.emit([bbox], self.__processing_id, False, ctime)

                x, y, w, h = map(int, bbox)
                box_color = (0, 255, 0) if success else (255, 0, 0)
                cv2.rectangle(frame, (x, y), (x + w, y + h), box_color, 2)

                if showTrack:
                    # Convert OpenCV frame (BGR) to RGB for Pygame.
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                    font = cv2.FONT_HERSHEY_SIMPLEX
                    font_scale = 2.25
                    font_color = (255, 255, 0)
                    thickness = 6
                    position = (50, 450)  # Bottom-left corner of the text in the image
                    now = datetime.datetime.now()
                    time_string = now.strftime("%H:%M:%S.%f")[:-3]

                    # Stamp wall-clock time and the emitted timestamp on the frame.
                    cv2.putText(frame, time_string, position, font, font_scale, font_color, thickness, cv2.LINE_AA)
                    cv2.putText(frame, f"{ctime}", (50, 380), font, font_scale, (255, 255, 255), thickness, cv2.LINE_AA)

                    frame = cv2.flip(frame, 1)  # Flip horizontally

                    # Resize frame while maintaining aspect ratio.
                    frame_height, frame_width, _ = frame.shape
                    aspect_ratio = frame_width / frame_height

                    if aspect_ratio > (screen_width / screen_height):  # Wider than screen
                        new_width = screen_width
                        new_height = int(screen_width / aspect_ratio)
                    else:  # Taller than screen
                        new_height = screen_height
                        new_width = int(screen_height * aspect_ratio)

                    resized_frame = cv2.resize(frame, (new_width, new_height))

                    frame_surface = pygame.surfarray.make_surface(resized_frame)
                    # make_surface transposes the axes; rotate back upright.
                    frame_surface = pygame.transform.rotate(frame_surface, -90)

                    # Centre the frame on screen.
                    x_offset = (screen_width - new_width) // 2
                    y_offset = (screen_height - new_height) // 2

                    screen.fill((0, 0, 0))  # Clear screen
                    screen.blit(frame_surface, (x_offset, y_offset))
                    pygame.display.flip()

                    clock.tick(30)  # Limit FPS to prevent excessive CPU usage

    def start_detect(self, x: int, y: int, w: int, h: int):
        """Start the detection worker on ROI (x, y, w, h) of the current source."""
        self.__detection_roi = [x, y, x + w, y + h]

        if not self.__is_detecting:
            if self.__detection_thread is not None:
                self.__detection_thread.join()
            self.__is_detecting = True
            self.__detection_thread = Thread(target=self.__detection)
            self.__detection_thread.start()

    def stop_detection(self):
        """Stop the detection worker and wait for it to finish."""
        self.__is_detecting = False
        if self.__detection_thread is not None and self.__detection_thread.is_alive():
            self.__detection_thread.join()

        self.__detection_thread = None

    def start_track(self, x: int, y: int, w: int = 0, h: int = 0):
        """Start (or restart) tracking.

        With w == 0 the target is the detection bbox containing point (x, y)
        (point-click mode); otherwise (x, y, w, h) is used directly.
        """
        print(f"start tracking: {x}, {y}, {w}, {h}")
        try:
            self.__is_detecting = False
            self.__is_tracking = False

            bbox = None
            if w == 0:
                if len(self.__detection_bboxes):
                    bbox = get_bbox_by_point(self.__detection_bboxes, np.array([x, y]))
                frame = self.__detection_frame
            else:
                bbox = np.array([x, y, w, h])
                frame = self.__processing_source.get_frame()

            self.__tracker.stop()

            # Join the previous worker BEFORE (re)initialising the tracker.
            # The original joined and ran stop_track() AFTER init(), which
            # called tracker.stop() and undid the initialisation it had
            # just performed.
            if self.__tracking_thread is not None:
                self.__tracking_thread.join()
                self.__tracking_thread = None

            if bbox is None:
                return
            self.__tracker.init(frame, bbox)
        except Exception as e:
            print(e)
            return

        self.__is_tracking = True
        self.__tracking_thread = Thread(target=self.__tracking)
        self.__tracking_thread.start()
        sleep(0.03)

    def stop_track(self):
        """Stop tracking (and any running detection) and release the tracker."""
        if showTrack:
            pygame.quit()
        print("stop tracking")
        self.stop_detection()
        self.__tracker.stop()
        self.__is_tracking = False
        self.__tracking_thread = None

    def __draw_bbox(self, img: np.ndarray, bbox, color):
        """Draw an (x, y, w, h) bbox on img with the configured thickness."""
        thickness = self.__thickness
        cv2.rectangle(img, (bbox[0], bbox[1]), (bbox[2] + bbox[0], bbox[3] + bbox[1]),
                      color, thickness)