import threading
import multiprocess as mp
import rpyc
import uuid
from rpyc.utils.server import ThreadedServer

@rpyc.service
class CustomService(rpyc.Service):
    """rpyc service exposing a data callback (``get_data``) and a remote
    shutdown hook (``stop_server``) to connected clients.

    When one instance is passed to ``ThreadedServer`` (as ``LivePlotter.plot``
    does), all client connections share ``function``, ``lock`` and ``server``.
    """

    def __init__(self, function):
        """
        Args:
            function: zero-argument callable whose result ``get_data`` returns.
        """
        self.function = function
        # Serialises calls to ``function`` across concurrent client requests.
        self.lock = threading.Lock()
        # Server to close on ``stop_server``; registered via ``update_server``.
        self.server = None

    @rpyc.exposed
    def get_data(self):
        """Return ``function()``, or the string ``'ERROR'`` if it raises."""
        try:
            # ``with`` releases the lock even when ``function`` raises; the
            # former manual acquire/release left it held on error, so every
            # later call blocked forever.
            with self.lock:
                return self.function()
        except Exception:
            return 'ERROR'

    @rpyc.exposed
    def stop_server(self):
        """Close the server registered through ``update_server``, if any."""
        if self.server is not None:
            self.server.close()

    def update_server(self, server):
        """Register the server instance that ``stop_server`` should close."""
        self.server = server

class LivePlotter():
    """Live-plot the output of a zero-argument data callback.

    Each :meth:`plot` call starts an rpyc ``ThreadedServer`` (in a daemon
    thread of this process) serving the callback through ``CustomService``,
    plus a daemon plotting process that polls that server and animates one
    line per element of the returned 1-D data.
    """

    def __init__(self, logger, num_samples = 200, data_rate = 100, plot_rate = 10) -> None:
        """
        Args:
            logger: logger-like object; its ``warning`` and ``error`` methods
                are used.
            num_samples: number of most recent samples kept on screen.
            data_rate: callback polling frequency in Hz (the fetch loop
                sleeps ``1 / data_rate`` seconds between polls).
            plot_rate: plot refresh frequency in Hz.
        """
        # Number of plots started so far; plot N serves on port 18820 + N.
        self.count = 0
        self._num_samples = num_samples
        self._data_rate = data_rate
        self._plot_rate = plot_rate
        self._logger = logger

    def plot(self, get_data_callback):
        """Serve ``get_data_callback`` and open a live plot of its output.

        Returns immediately after starting a daemon server thread and a
        daemon plotting process. ``get_data_callback`` must take no arguments
        and return a 1-D sequence of numbers; each element becomes one line.
        When the plot window closes, the plotting process stops the server.
        """
        # Reserve this plot's port synchronously. ``self.count`` used to be
        # incremented inside the server thread while the main thread read it
        # for the plot process, so rapid calls could share a port or give the
        # plot process a different port than its server.
        port = 18820 + self.count
        self.count = self.count + 1
        logger = self._logger

        def plot_data(port, plot_rate, num_samples, data_rate):
            """Body of the plotting process: poll the server and animate."""
            import time
            import matplotlib
            # Select the backend before pyplot is imported.
            matplotlib.use('TkAgg')
            import matplotlib.animation as animation
            import matplotlib.pyplot as plt
            import numpy as np

            def is_error(data):
                # The server reports callback failures as the string 'ERROR'.
                # Check the type first: comparing an ndarray (or a netref to
                # one) with a string is element-wise and unusable in an ``if``.
                return isinstance(data, str) and data == 'ERROR'

            # The server thread is started concurrently and may not be
            # listening yet, so retry refused connections for a short while.
            deadline = time.monotonic() + 5.0
            while True:
                try:
                    conn = rpyc.connect("localhost", port=port, config={'allow_public_attrs': True, 'import_custom_exceptions': True})
                    break
                except ConnectionRefusedError:
                    if time.monotonic() >= deadline:
                        raise
                    time.sleep(0.05)

            cancellation_token = threading.Event()
            # Guards ``vals``/``time_vals``: written by the fetch thread, read
            # by the animation callback. Both live inside this process only.
            fetch_data_lock = threading.Lock()
            vals = None              # (samples, channels); set after init validation
            time_vals = np.empty(0)  # seconds since initial_time, one per row of vals
            initial_time = time.time()

            def closed(evt):
                cancellation_token.set()

            def fetch_data():
                nonlocal vals, time_vals
                while not cancellation_token.is_set():
                    data = conn.root.get_data()
                    timestamp = time.time() - initial_time

                    if is_error(data):
                        cancellation_token.set()
                        break

                    sample = np.array(data)
                    # Append and trim both buffers in one critical section so
                    # the reader never sees them with different lengths. New
                    # arrays are bound each time; nothing is mutated in place.
                    with fetch_data_lock:
                        vals = np.vstack((vals, sample))[-num_samples:]
                        time_vals = np.append(time_vals, timestamp)[-num_samples:]

                    time.sleep(1/data_rate)

            fetch_data_f = threading.Thread(target=fetch_data)

            fig, ax = plt.subplots()
            try:
                fig.canvas.manager.set_window_title('Plot ' + str(uuid.uuid4())[:4])
            except Exception:
                logger.warning("Unable to set window title, check if matplotlib version is correct as per requirements.txt")
            fig.canvas.mpl_connect('close_event', closed)

            try:
                init_data = conn.root.get_data()

                if is_error(init_data):
                    logger.error("Plot Initialisation Failed. Please check callback and try again")
                    return

                init_data = np.array(init_data)

                if init_data.ndim != 1:
                    logger.error("Plotter only supports 1D arrays.")
                    return

                # Empty 2-D buffer so every fetched sample is appended the same way.
                vals = np.empty((0, len(init_data)))

                lines = []
                for i, value in enumerate(init_data):
                    line, = ax.plot(0, value, label = f'{i}')
                    lines.append(line)

                plt.legend()

                fetch_data_f.start()

                def update(frame):
                    if cancellation_token.is_set():
                        plt.close(fig)
                        return tuple(lines)

                    # Snapshot both buffers together; safe to use unlocked
                    # afterwards because the writer only rebinds the names.
                    with fetch_data_lock:
                        y, t = vals, time_vals
                    if y.size == 0:
                        return tuple(lines)  # nothing fetched yet

                    for i, line in enumerate(lines):
                        line.set_data(t, y[:, i])

                    y_min, y_max = np.min(y), np.max(y)
                    span = y_max - y_min
                    # 20% padding; a flat signal gets a non-zero pad so the
                    # axis limits are never identical.
                    pad = 0.2 * span if span > 0 else (0.2 * abs(y_max) or 1.0)
                    ax.set_ylim([y_min - pad, y_max + pad])

                    t_min, t_max = np.min(t), np.max(t)
                    if t_max > t_min:
                        ax.set_xlim([t_min, t_max])

                    return tuple(lines)

                # ``interval`` is in milliseconds while plot_rate is a frequency
                # in Hz (like data_rate), so convert. Keep a reference to the
                # animation: it stops if garbage-collected.
                plot_anim = animation.FuncAnimation(
                    fig=fig,
                    func=update,
                    interval=max(1, round(1000 / plot_rate)),
                    cache_frame_data=False,
                )
                plt.show()
            finally:
                # Runs on normal close and on early returns: stop fetching,
                # release the figure, shut the server down, close the link.
                cancellation_token.set()
                if fetch_data_f.is_alive():
                    fetch_data_f.join()
                plt.close(fig)
                try:
                    conn.root.stop_server()
                except EOFError:
                    pass  # connection dropped while the server shut down
                conn.close()

        def server_thread():
            gen_data_service = CustomService(get_data_callback)
            data_server = ThreadedServer(gen_data_service, port=port, protocol_config={'allow_public_attrs': True})
            gen_data_service.update_server(data_server)
            data_server.start()

        data_server_t = threading.Thread(target=server_thread, daemon=True)
        plot_data_p = mp.Process(target=plot_data, args=(port, self._plot_rate, self._num_samples, self._data_rate), daemon=True)

        data_server_t.start()
        plot_data_p.start()
        