import sys
from typing import Callable
import threading
import time
from nmotion_transport import USBInterface, getValidUSBInterfaces, getValidNLinkInterfaces
from utils.logger import Logger
import os
# TODO: Verify behaviour on macOS (this applies to the C++ layer as well).

class DeviceManager:
    """Background hot-plug monitor for NLink devices and NMotion components.

    ``__init__`` starts a worker thread that repeatedly (sleeping 0.5 s between
    passes) opens newly appeared USB ports, closes ports that have vanished,
    and reports devices appearing on / disappearing from the open interfaces
    through the user-supplied callbacks.

    Devices are identified by strings of the form
    ``'<dev_id>::<dev_type>::<last path component of the port>'``.

    Note: the callbacks are invoked from the worker thread, not from the
    thread that created the manager.
    """

    # Only immutable defaults live at class level; they keep ``__del__`` safe
    # even if ``__init__`` failed before assigning the instance attributes.
    # The mutable registries are created per instance in ``__init__`` so that
    # several managers never share (and mutate) the same dicts.
    __stop_loop: bool = False
    __dev_manager_thread_f: 'threading.Thread | None' = None

    def __init__(self, logger: Logger, on_new_device_cb: 'Callable[[str, USBInterface], None] | None' = None, on_remove_device_cb: 'Callable[[str], None] | None' = None) -> None:
        """Create the manager and immediately start its polling thread.

        Args:
            logger: Logger used for connect/disconnect and error messages.
            on_new_device_cb: Optional; called as ``cb(identifier, interface)``
                when a device becomes visible.
            on_remove_device_cb: Optional; called as ``cb(identifier)`` when a
                previously reported device disappears.
        """
        self.__logger = logger
        # port name -> open interface.
        self.__interfaces: 'dict[str, USBInterface]' = {}
        # device identifier -> interface it was seen on during the last poll.
        self.__connected_devices_on_last_update: 'dict[str, USBInterface]' = {}
        self.__on_new_dev_cb = on_new_device_cb or (lambda *args: None)
        self.__on_remove_dev_cb = on_remove_device_cb or (lambda *args: None)
        self.__dev_manager_thread_f = threading.Thread(target=self.__devManagerThread)
        self.__dev_manager_thread_f.start()

    def __initialiseInterface(self, interface_name, is_device = False) -> USBInterface:
        """Open ``interface_name`` and return the new ``USBInterface``.

        ``is_device`` is False for an NLink device and True for a directly
        attached NMotion component; in the latter case ``setAsDevice()`` is
        also called on the interface.
        """
        interface = USBInterface(interface_name)
        if is_device:
            self.__logger.notify(f'Connected to NMotion component at {interface_name}')
            interface.setAsDevice()    # TODO: some error here -> leads to a SEG FAULT.
        else:
            self.__logger.notify(f'Connected to NLink device at {interface_name}.')
        return interface

    def __removeInterface(self, interface_name):
        """Report the interface's devices as removed, then close and forget it.

        Devices are dropped from the last-update snapshot here so the next
        ``__updateDevices`` pass does not report them a second time.
        """
        interface = self.__interfaces[interface_name]
        if interface.isDevice():
            self.__logger.warn(f'NMotion component removed at {interface_name}')
        else:
            self.__logger.warn(f'NLink device removed at {interface_name}')
        try:
            # Work from the snapshot instead of interface.getConnectedDevices():
            # the snapshot holds exactly the devices that were announced, and it
            # does not require querying hardware that has just been unplugged.
            # Popping identifiers from getConnectedDevices() raised KeyError for
            # devices that appeared after the last poll, aborting the removal.
            announced = [
                dev_identifier
                for dev_identifier, iface in self.__connected_devices_on_last_update.items()
                if iface is interface
            ]
            for dev_identifier in announced:
                self.__on_remove_dev_cb(dev_identifier)
                del self.__connected_devices_on_last_update[dev_identifier]

            interface.close()
            del self.__interfaces[interface_name]
        except Exception as err:
            # Left in __interfaces on failure, so removal is retried next poll.
            self.__logger.warn(f'Exception on remove interface: {err}')
        # TODO: `del` of an interface has caused "terminate called after throwing
        # an instance of Serial::SerialException" in the native serial layer;
        # make the serial library more fault tolerant.

    def __devManagerThread(self):
        """Worker loop: poll interfaces and devices until ``stopThreads``."""
        while not self.__stop_loop:
            try:
                self.__updateInterfaces()
                self.__updateDevices()
            except Exception as err:
                # An uncaught exception would end this thread and stop hot-plug
                # detection for the rest of the session; log it and keep polling.
                self.__logger.warn(f'Exception in device manager thread: {err}')
            time.sleep(0.5)

    def __updateInterfaces(self):
        """Open newly appeared ports and tear down ports that have vanished."""
        nlink_ports = set(getValidNLinkInterfaces())
        # Every other valid USB interface is treated as an NMotion component.
        component_ports = set(getValidUSBInterfaces()) - nlink_ports
        current_ports = nlink_ports | component_ports
        known_ports = set(self.__interfaces)

        for port in current_ports - known_ports:
            self.__interfaces[port] = self.__initialiseInterface(str(port), port not in nlink_ports)

        for port in known_ports - current_ports:
            self.__removeInterface(port)

    def __updateDevices(self):
        """Diff devices on all open interfaces against the previous poll.

        Calls the new-device callback with ``(identifier, interface)`` for each
        newly visible device and the remove callback with ``identifier`` for
        each device no longer visible, then stores the current view.
        """
        current = {}
        for iface in self.__interfaces.values():
            devices = iface.getConnectedDevices()
            port_basename = iface.getName().split("/")[-1]
            for dev_id, dev_type in devices.items():
                current[f'{dev_id}::{dev_type}::{port_basename}'] = iface

        previous = self.__connected_devices_on_last_update
        for dev in current.keys() - previous.keys():
            self.__on_new_dev_cb(dev, current[dev])
        for dev in previous.keys() - current.keys():
            self.__on_remove_dev_cb(dev)

        self.__connected_devices_on_last_update = current

    def stopThreads(self):
        """Ask the polling thread to exit after its current pass (non-blocking)."""
        self.__stop_loop = True

    def __del__(self):
        """Stop the polling thread and wait for it, unless running on it."""
        self.__stop_loop = True
        thread = self.__dev_manager_thread_f
        # The thread's target is a bound method of self, so the last reference
        # may be released from inside the worker thread as it shuts down;
        # Thread.join() on the current thread raises RuntimeError.
        if thread is not None and thread.is_alive() and thread is not threading.current_thread():
            thread.join()