from threading import Thread
from typing import Optional, Callable, ByteString, Union

import zmq
from zmq import Frame


class Manager:
    """Pair a ZeroMQ SUB receiver (run on a background thread) with a PUSH sender.

    Incoming messages are read from ``recv_addr`` by a SUB socket subscribed to
    every topic and passed to the callback given to :meth:`start`.  Outgoing
    messages are pushed to ``send_addr`` from the caller's thread via
    :meth:`send_message`.  Both sockets ``connect`` (they do not ``bind``).

    The receive thread is not a daemon thread; call :meth:`stop` to end it so
    the process can exit.
    """

    def __init__(self, recv_addr: str = 'tcp://127.0.0.1:5557', send_addr: str = 'tcp://127.0.0.1:5558',
                 poll_timeout_ms: int = 100):
        """Create the manager; no sockets are opened until :meth:`start`.

        Args:
            recv_addr: Endpoint the SUB socket connects to.
            send_addr: Endpoint the PUSH socket connects to.
            poll_timeout_ms: Maximum time the receive thread waits for a message
                before re-checking whether it was asked to stop. This bounds
                the latency of :meth:`stop` and of a restarting :meth:`start`.
        """
        self.__ctx = zmq.Context()
        self.__recv_addr = recv_addr
        self.__send_addr = send_addr
        self.__poll_timeout_ms = poll_timeout_ms
        self.__send_socket: Optional[zmq.Socket] = None
        self.__recv_socket: Optional[zmq.Socket] = None
        self.__recv_thread: Optional[Thread] = None
        self.__recv_callback: Optional[Callable[[Union[Frame, ByteString]], None]] = None
        self.__running = False

    def start(self, recv_callback: Callable[[Union[Frame, ByteString]], None]) -> None:
        """Open both sockets and start delivering received messages to ``recv_callback``.

        Calling ``start`` again restarts the manager: the previous receive
        thread is stopped and its sockets are closed before new ones are made.
        Must not be called from inside ``recv_callback`` (the thread would try
        to join itself).
        """
        # Tear down any previous run first so sockets are not leaked and the
        # old thread cannot keep consuming messages.
        self.stop()
        self.__recv_callback = recv_callback
        self._create_recv()
        self._create_send()
        self._start_recv_loop()

    def stop(self) -> None:
        """Stop the receive thread and close both sockets.

        Safe to call before :meth:`start` or multiple times. Blocks for at most
        roughly ``poll_timeout_ms`` plus the duration of an in-flight callback.
        Must not be called from inside the receive callback.
        """
        self.__running = False
        if self.__recv_thread is not None:
            self.__recv_thread.join()
            self.__recv_thread = None
        # After join() the receive thread no longer touches its socket, so it
        # is safe to close it from this thread.
        for sock in (self.__recv_socket, self.__send_socket):
            if sock is not None:
                sock.close()
        self.__recv_socket = None
        self.__send_socket = None

    def _create_recv(self):
        """Create the SUB socket, connect it, and subscribe to all messages."""
        self.__recv_socket = self.__ctx.socket(zmq.SUB)
        self.__recv_socket.connect(self.__recv_addr)
        # An empty prefix subscribes to every message.
        self.__recv_socket.setsockopt_string(zmq.SUBSCRIBE, "")

    def _start_recv_loop(self):
        """(Re)start the background thread that runs :meth:`_recv_loop`."""
        # Ask any existing loop to exit. It notices within poll_timeout_ms,
        # so this join cannot block forever on a quiet feed.
        self.__running = False
        if self.__recv_thread is not None:
            self.__recv_thread.join()
        self.__running = True
        self.__recv_thread = Thread(target=self._recv_loop)
        self.__recv_thread.start()

    def _recv_loop(self):
        """Receive messages and hand each one to the callback until stopped."""
        # Bind the socket once: a later start() replaces the attribute, and
        # this thread must only ever use the socket it was started with.
        sock = self.__recv_socket
        timeout = self.__poll_timeout_ms
        while self.__running:
            # Bounded wait instead of a blocking recv(), so a stop request is
            # observed even when no traffic arrives.
            if not sock.poll(timeout, zmq.POLLIN):
                continue
            msg = sock.recv()
            if self.__recv_callback:
                self.__recv_callback(msg)

    def _create_send(self):
        """Create the PUSH socket and connect it to the send endpoint."""
        self.__send_socket = self.__ctx.socket(zmq.PUSH)
        self.__send_socket.connect(self.__send_addr)

    def send_message(self, msg: Union[Frame, ByteString]) -> None:
        """Push ``msg`` to the send endpoint.

        ZeroMQ sockets are not thread-safe, so concurrent callers must
        serialize calls to this method themselves.

        Raises:
            RuntimeError: if called before :meth:`start` or after :meth:`stop`.
        """
        if self.__send_socket is None:
            raise RuntimeError("Manager.send_message() called before start()")
        self.__send_socket.send(msg)