import serial
import time
import numpy as np
import re
from matplotlib import pyplot as plt
from matplotlib import animation
from scipy import *
from scipy.special import eval_hermitenorm

fig, ax = plt.subplots(figsize=(5, 3))
ax.set(xlim=(0,120),ylim=(0,120))

# Same double value as the previous hand-typed literal; use NumPy's constant.
pi = np.pi

# Half-width of the evaluation grid: np.arange(-range, range) spans [-32, 31].
x_range = 32
y_range = 32

alpha = 3.6835
beta = 1.8591

# Slope used to convert amplitude readings: 10**(-grad * amp / scale_var**2).
grad = np.sqrt((1-beta/alpha)/(1+(beta/alpha)))*np.sqrt(pi/2.0)   # ~0.719 for these alpha/beta
# Radial scale factor applied per mode inside functional_mapping (rho*scale*(n+1)).
scale = np.sqrt(pi/2.0)/np.log(20.0)   # ~0.418
# Algebraically 20*sqrt(pi/2); scales the radius in functional_map and normalises readings.
scale_var = scale*20.0*np.log(20.0)  # ~25.07

# Number of modes: length of every per-mode list and of the summation in functional_map.
samples = 32

# Maximum number of bytes requested from the serial port per read.
cls_write = samples*4

def fac(n):
    """Return n! as a float.

    Args:
        n: Non-negative integer. Values below 1 yield 1.0 (the empty product),
            exactly as the original loop did.

    Returns:
        The factorial as a float. It was previously an int for n == 0 and a
        float otherwise; it is now always a float.
    """
    result = 1.0
    # Starting at 2 skips the no-op multiplication by 1.
    for k in range(2, n + 1):
        result *= k
    return result

def phi_map(phi_range):
    """Return a fresh list of ``phi_range`` integer zeros (one phase slot per mode)."""
    return [0 for _ in range(phi_range)]

def amp_map(amp_range):
    """Return a fresh list of ``amp_range`` integer zeros (one amplitude slot per mode)."""
    # len(range(...)) validates the argument like range() does and yields a plain int.
    return [0] * len(range(amp_range))

def functional_mapping(n,rho,theta,phi_val):
    """Return mode ``n``'s contribution at polar coordinates (rho, theta).

    With k = n + 1 and u = rho * scale * k (``scale`` is the module-level
    constant), the result is

        sqrt(k) / (2**(n+2) * k!) * He_k(u) * exp(-u**2) * cos(theta) * phi_val[n]

    where He_k is scipy's probabilists' Hermite polynomial (eval_hermitenorm).
    """
    order = n + 1
    u = rho*scale*order
    prefactor = np.sqrt(order)/(2**(n+2)*fac(order))
    # Multiplication order matches the original single expression exactly.
    return prefactor*eval_hermitenorm(order, u)*np.exp(-u**2)*np.cos(theta)*phi_val[n]

def functional_map(tat, pos_x, pos_y, scale):
    """Sum the per-mode contributions over the coordinate grid (pos_x, pos_y).

    Args:
        tat: Sequence whose items 0..3 are [phi1, amp1, phi2, amp2]. The field
            is driven by sin(phi2) - sin(phi1) and amp2 - amp1.
        pos_x, pos_y: Coordinates (scalars or broadcastable arrays).
        scale: Accepted but not used here; functional_mapping reads the
            module-level ``scale`` instead.

    Returns:
        The summed field over modes 0 .. samples-1.
    """
    radius = scale_var*np.lib.scimath.sqrt(pos_x**2+pos_y**2)
    # NOTE(review): arctan2 is called as (pos_x, pos_y), so the angle is measured
    # from the +y axis rather than +x -- confirm this orientation is intended.
    angle = np.arctan2(pos_x,pos_y)

    sin_phi1 = np.sin(tat[0])
    amps1 = tat[1]
    sin_phi2 = np.sin(tat[2])
    amps2 = tat[3]

    phase_diff = np.subtract(sin_phi2,sin_phi1)
    amp_diff = np.subtract(amps2,amps1)

    total = 0.0
    for mode in range(samples):
        total += amp_diff[mode]*functional_mapping(mode,radius,angle,phase_diff)
    return total

def get_value(pos_x, pos_y, tat, scale):
    """Evaluate the field at (pos_x, pos_y) for the measurement lists in ``tat``.

    Thin adapter that forwards to functional_map, which takes ``tat`` first.
    """
    return functional_map(tat=tat, pos_x=pos_x, pos_y=pos_y, scale=scale)

def _request_tokens():
    """Send one "1" request over the module-level ``ser`` and return the numbers in the reply.

    Tokens containing a '.' are parsed as float, all others as int.
    """
    ser.write(bytes("1", 'utf-8'))
    # Decode the raw bytes rather than calling str() on them: str(b'...') gives the
    # "b'...'" repr, whose escape sequences (e.g. "\x01") would feed spurious
    # digits into the regex below.
    reply = ser.read(cls_write).decode('utf-8', errors='replace')
    return [float(s) if '.' in s else int(s) for s in re.findall(r'-?\d+\.?\d*', reply)]


def _store_records(tokens, amp_values, phi_values, samples):
    """Convert up to ``samples`` (mode, amp, phi) records from ``tokens`` into the per-mode lists.

    Records are assumed to be consecutive (mode, amp, phi) triplets -- TODO
    confirm against the device firmware. An incomplete trailing triplet (e.g.
    from a read that hit the timeout) is ignored, and modes outside the lists
    are skipped instead of silently wrapping via negative indexing.
    """
    limit = min(len(amp_values), len(phi_values))
    for k in range(min(samples, len(tokens) // 3)):
        mode, amp, phi = (int(t) for t in tokens[3 * k:3 * k + 3])
        if not 0 <= mode < limit:
            continue
        amp_values[mode] = 10**(-grad*float(amp/(scale_var**2)))
        phi_values[mode] = pi*float(phi/(scale_var**2))


def take_data(phi1_values, amp1_values, phi2_values, amp2_values, samples):
    """Read two frames from the serial device into the supplied per-mode lists.

    The first reply fills amp1_values/phi1_values and the second fills
    amp2_values/phi2_values. Each amplitude reading is stored as
    10**(-grad * amp / scale_var**2) and each phase reading as
    pi * phi / scale_var**2. The lists are modified in place and printed.

    Args:
        phi1_values, amp1_values, phi2_values, amp2_values: Mutable per-mode lists.
        samples: Maximum number of records consumed from each reply.

    Returns:
        [phi1_values, amp1_values, phi2_values, amp2_values], the layout
        functional_map reads as tat[0..3].
    """
    _store_records(_request_tokens(), amp1_values, phi1_values, samples)
    _store_records(_request_tokens(), amp2_values, phi2_values, samples)

    print(amp1_values)
    print(phi1_values)
    print(amp2_values)
    print(phi2_values)

    return [phi1_values, amp1_values, phi2_values, amp2_values]

def animate(i):
    """FuncAnimation callback: read fresh data from the serial port and redraw the image.

    Args:
        i: Frame index supplied by FuncAnimation (unused).

    Returns:
        A one-element list containing the updated image artist.
    """
    # pyserial opens the port as soon as Serial() is given a port name, and
    # Serial.open() raises on an already-open port, so only open when needed.
    if not ser.is_open:
        ser.open()
    try:
        newdata = take_data(phi1_values, amp1_values, phi2_values, amp2_values, samples)
    finally:
        # Release the port even if the read or parse fails.
        ser.close()
    Z = get_value(X,Y,newdata,scale)
    im.set_data(Z)
    return [im]

# Serial link settings: 115200 baud, 8 data bits, no parity, 1 stop bit,
# 1 s read timeout. pyserial opens the port as soon as a port name is given.
_SERIAL_SETTINGS = dict(
    port='/dev/ttyACM1',
    baudrate=115200,
    bytesize=serial.EIGHTBITS,
    parity=serial.PARITY_NONE,
    stopbits=serial.STOPBITS_ONE,
    timeout=1,
)
ser = serial.Serial(**_SERIAL_SETTINGS)

# Zero-initialised per-mode buffers (one slot per mode).
phi_values, amp_values = phi_map(samples), amp_map(samples)

phi1_values, amp1_values = phi_map(samples), amp_map(samples)
phi2_values, amp2_values = phi_map(samples), amp_map(samples)

def _read_startup_frame(amp_values, phi_values):
    """Request one frame from the device and store it into the given per-mode lists.

    The numbers in the reply are assumed to be consecutive (mode, amp, phi)
    triplets -- TODO confirm against the device firmware. Both start-up reads
    now use the same conversion: amplitude -> 10**(-grad * amp / scale_var**2)
    and phase -> pi * phi / scale_var**2. Previously the first read stored the
    bare exponent and took mode, amp and phi from the same token.
    """
    ser.write(bytes("1", 'utf-8'))
    # Decode instead of str(bytes) so escape sequences in the "b'...'" repr
    # cannot contribute digits to the regex.
    reply = ser.read(cls_write).decode('utf-8', errors='replace')
    tokens = [float(s) if '.' in s else int(s) for s in re.findall(r'-?\d+\.?\d*', reply)]
    limit = min(len(amp_values), len(phi_values))
    # Only complete triplets are consumed; short (timed-out) reads no longer raise IndexError.
    for k in range(min(samples, len(tokens) // 3)):
        mode, amp, phi = (int(t) for t in tokens[3 * k:3 * k + 3])
        if 0 <= mode < limit:
            amp_values[mode] = 10**(-grad*float(amp/(scale_var**2)))
            phi_values[mode] = pi*float(phi/(scale_var**2))


_read_startup_frame(amp1_values, phi1_values)

# Pause between the two start-up reads.
time.sleep(6)

_read_startup_frame(amp2_values, phi2_values)
    
# Total number of animation frames.
max_index = samples*32*4

# One fresh reading so the first frame shows live data.
newdata = take_data(phi1_values, amp1_values, phi2_values, amp2_values, samples)

x = np.arange(-x_range,x_range,1)
y = np.arange(-y_range,y_range,1)
X, Y = np.meshgrid(x,y)

# functional_map reads tat[0]..tat[3], so it needs the four-list layout that
# take_data returns. The previous three-element [phi_values, amp_values, samples]
# list raised IndexError here before anything was plotted.
Z = get_value(X,Y,newdata,scale)

print(amp1_values)
print(phi1_values)
print(amp2_values)
print(phi2_values)

im = plt.imshow(Z, cmap='bone', origin='lower', vmin=-0.0005, vmax=0.0005)

# animate() opens the port for each frame; close it first, because pyserial
# raises when open() is called on a port that is already open.
ser.close()

ani = animation.FuncAnimation(fig, animate, frames=max_index)
plt.colorbar()
try:
    plt.show()
finally:
    # close() is a no-op on an already-closed pyserial port.
    ser.close()