import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
from astropy import units as u, constants as ac

from Calculator import *
from Helpers import *
from Potentials import *

#####################################################################
# PLOTTING                                                          #
#####################################################################

def generate_label(v, dv):
    """Format a frequency and its uncertainty as a legend label.

    Args:
        v: Frequency value (labelled in Hz unless rescaled).
        dv: Uncertainty on ``v``, in the same unit as ``v``.

    Returns:
        A string ``'ν = <v> ± <dv> <unit>'`` with ``v`` shown to one decimal
        and ``dv`` to two. Non-positive ``v`` makes both numbers NaN. When
        ``orderOfMagnitude(v) > 2`` both numbers are divided by 1e3 and the
        unit becomes ``kHz``.
    """
    scale, unit = 1.0, 'Hz'
    if v <= 0.0:
        # A non-positive frequency is not meaningful: render it as NaN.
        v = dv = np.nan
    elif v > 0.0 and orderOfMagnitude(v) > 2:
        # The explicit ``v > 0.0`` keeps NaN input from reaching orderOfMagnitude.
        scale, unit = 1e3, 'kHz'
    return '\u03BD = %.1f \u00B1 %.2f %s' % (v / scale, dv / scale, unit)

def plotHarmonicFit(Positions, TrappingPotential, TrapDepthsInKelvin, axis, popt, pcov):
    """Plot a harmonic fit against one cut of the trap potential, plus residuals.

    Args:
        Positions: astropy Quantity indexed as ``Positions[axis, :]``.
        TrappingPotential: indexable by ``axis``; ``TrappingPotential[axis]`` is
            an astropy Quantity holding the potential along that cut.
        TrapDepthsInKelvin: indexable; the negated value of
            ``TrapDepthsInKelvin[0]`` is used as the lower y-limit of the fit
            panel.
        axis: Index selecting which cut to plot.
        popt: Best-fit parameters passed to ``harmonic_potential``; ``popt[0]``
            is quoted in the legend as the frequency in Hz.
        pcov: Covariance of ``popt``; ``sqrt(pcov[0][0])`` is the quoted
            uncertainty (presumably from scipy's curve_fit -- confirm).
    """
    # Strip units once and plot bare floats everywhere. The original drew the
    # data as a Quantity but the fit and residuals as plain arrays, which
    # mixes unit-ful and unit-less data on the same axes.
    positions = Positions[axis, :].value
    potential = TrappingPotential[axis].value

    v = popt[0]
    dv = pcov[0][0]**0.5
    happrox = harmonic_potential(positions, *popt)

    fig = plt.figure(figsize=(12, 6))

    ax = fig.add_subplot(121)
    ax.set_title('Fit to Potential')
    ax.plot(positions, happrox, '-r', label=f'\u03BD = {v:.1f} \u00B1 {dv:.2f} Hz')
    ax.plot(positions, potential, 'ob', label='Gaussian Potential')
    ax.set_xlabel('Distance (um)', fontsize=12, fontweight='bold')
    ax.set_ylabel('Trap Potential (uK)', fontsize=12, fontweight='bold')
    ax.set_ylim([-TrapDepthsInKelvin[0].value, max(potential)])
    ax.grid(visible=1)
    ax.legend(prop={'size': 12, 'weight': 'bold'})

    bx = fig.add_subplot(122)
    bx.set_title('Fit Residuals')
    bx.plot(positions, potential - happrox, 'ob')
    bx.set_xlabel('Distance (um)', fontsize=12, fontweight='bold')
    bx.set_ylabel('$U_{trap} - U_{Harmonic}$', fontsize=12, fontweight='bold')
    # Fixed zoom window on the residuals near the trap centre.
    bx.set_xlim([-10, 10])
    bx.set_ylim([-1e-2, 1e-2])
    bx.grid(visible=1)

    fig.tight_layout()
    plt.show()

def plotGaussianFit(Positions, TrappingPotential, popt, pcov):
    """Show a Gaussian fit to a 1-D trap potential beside its residuals.

    Args:
        Positions: Positions passed directly to ``gaussian_potential`` and
            used as x data.
        TrappingPotential: Potential values at ``Positions``.
        popt: Fit parameters for ``gaussian_potential``; ``popt[1]`` is
            reported in the legend as the waist.
        pcov: Covariance of ``popt``; ``sqrt(pcov[1][1])`` is the quoted
            waist uncertainty.
    """
    waist_fit = popt[1]
    waist_err = pcov[1][1]**0.5
    model = gaussian_potential(Positions, *popt)
    fit_label = 'waist = %.1f \u00B1 %.2f um' % (waist_fit, waist_err)
    bold = dict(fontsize=12, fontweight='bold')

    fig = plt.figure(figsize=(12, 6))

    # Left panel: data and fitted curve.
    fit_ax = fig.add_subplot(121)
    fit_ax.set_title('Fit to Potential')
    fit_ax.plot(Positions, model, '-r', label=fit_label)
    fit_ax.plot(Positions, TrappingPotential, 'ob', label='Gaussian Potential')
    fit_ax.set_xlabel('Distance (um)', **bold)
    fit_ax.set_ylabel('Trap Potential (uK)', **bold)
    fit_ax.set_ylim([min(TrappingPotential), max(TrappingPotential)])
    fit_ax.grid(visible=1)
    fit_ax.legend(prop={'size': 12, 'weight': 'bold'})

    # Right panel: data minus model, in a fixed window.
    resid_ax = fig.add_subplot(122)
    resid_ax.set_title('Fit Residuals')
    resid_ax.plot(Positions, TrappingPotential - model, 'ob')
    resid_ax.set_xlabel('Distance (um)', **bold)
    resid_ax.set_ylabel('$U_{trap} - U_{Gaussian}$', **bold)
    resid_ax.set_xlim([-10, 10])
    resid_ax.set_ylim([-1, 1])
    resid_ax.grid(visible=1)

    fig.tight_layout()
    plt.show()

def _trap_depth_label(prefix, depth, v, dv):
    # Build '<prefix><depth rounded to 2 dp> <depth unit>; <frequency label>'.
    return (prefix + str(round(depth.value, 2)) + ' ' + str(depth.unit)
            + '; ' + generate_label(v, dv))


def plotPotential(Positions, ComputedPotentials, options, Params = None, listToIterateOver = None, save = False):
    """Plot 1-D cuts of one or more computed trap potentials.

    Curve ``i`` uses parameter set ``Params[i // 2]``. Normally curves come in
    pairs: even ``i`` is drawn dashed with the ideal depth and frequency, and
    odd ``i`` is drawn solid with the effective depth and frequency. When
    ``listToIterateOver`` is non-empty and has exactly one entry per curve,
    every curve is drawn solid with the effective values instead.

    Args:
        Positions: Indexable by axis; ``Positions[axis]`` is the x data.
        ComputedPotentials: Sequence of potentials;
            ``ComputedPotentials[i][axis]`` is curve ``i``.
        options: Dict with keys ``'axis'`` (0, 1 or other -> X, Y or Z) and
            ``'extract_trap_frequencies'``.
        Params: Per-pair parameters with this layout:
            ``Params[j][0][0]`` and ``Params[j][0][1]`` are the ideal and
            effective depths (objects with ``.value`` and ``.unit``);
            ``Params[j][2][0]`` is the ideal (v, dv) pair;
            ``Params[j][2][1]`` is the effective (v, dv) pair, read only when
            extract_trap_frequencies is set.
        listToIterateOver: Optional sequence; see above.
        save: If true, also save the figure as ``'pot_<direction>.png'``.
    """
    if Params is None:
        Params = []  # avoid a shared mutable default
    axis = options['axis']

    plt.figure(figsize=(9, 7))
    n_curves = np.size(ComputedPotentials, 0)
    for i in range(n_curves):
        # Curves 2j and 2j + 1 share the parameter set Params[j].
        j = i // 2

        IdealTrapDepthInKelvin = Params[j][0][0]
        EffectiveTrapDepthInKelvin = Params[j][0][1]

        idealv, idealdv = Params[j][2][0][0], Params[j][2][0][1]

        if options['extract_trap_frequencies']:
            v, dv = Params[j][2][1][0], Params[j][2][1][1]
        else:
            v = dv = np.nan

        x = Positions[axis]
        y = ComputedPotentials[i][axis]
        if listToIterateOver and n_curves == len(listToIterateOver):
            # One curve per iterated value: all drawn solid with effective values.
            plt.plot(x, y, label=_trap_depth_label('Trap Depth = ', EffectiveTrapDepthInKelvin, v, dv))
        elif i % 2 == 0:
            plt.plot(x, y, '--', label=_trap_depth_label('Trap Depth = ', IdealTrapDepthInKelvin, idealv, idealdv))
        else:
            plt.plot(x, y, label=_trap_depth_label('Effective Trap Depth = ', EffectiveTrapDepthInKelvin, v, dv))

    direction = {0: 'X - Horizontal', 1: 'Y - Propagation'}.get(axis, 'Z - Vertical')

    plt.ylim(top = 0)
    plt.xlabel(direction + ' Direction (um)', fontsize= 12, fontweight='bold')
    plt.ylabel('Trap Potential (uK)', fontsize= 12, fontweight='bold')
    plt.tight_layout()
    plt.grid(visible=1)
    plt.legend(loc=3, prop={'size': 12, 'weight': 'bold'})
    if save:
        plt.savefig('pot_' + direction + '.png')
    plt.show()

def plotIntensityProfileAndPotentials(positions, waists, I, U):
    """Show a 2-D intensity map beside X and Z cuts of the trap potential.

    Args:
        positions: ``(x_positions, z_positions)``; both are astropy Quantities
            (``.value`` is read for the image extent).
        waists: ``(w_x, dw_x, w_z, dw_z)``; the waists and their uncertainties.
        I: 2-D astropy Quantity; its transpose is drawn with ``imshow``.
        U: 2-D potential indexed as ``U[x_index, z_index]``.
    """
    x_Positions = positions[0]
    z_Positions = positions[1]

    w_x, dw_x = waists[0], waists[1]
    # BUG FIX: the original read waists[3] into dw_x, clobbering it and never
    # defining dw_z.
    w_z, dw_z = waists[2], waists[3]

    # Aspect ratio with first-order error propagation of independent relative
    # uncertainties on the two waists.
    ar = w_x / w_z
    dar = ar * np.sqrt((dw_x / w_x)**2 + (dw_z / w_z)**2)

    fig = plt.figure(figsize=(12, 6))
    ax = fig.add_subplot(121)
    # The aspect ratio is dimensionless, so no unit is printed after it.
    ax.set_title('Intensity Profile ($MW/cm^2$)\n Aspect Ratio = %.2f \u00B1 %.2f' % (ar, dar))
    im = plt.imshow(np.transpose(I.value), cmap="coolwarm", extent=[np.min(x_Positions.value), np.max(x_Positions.value), np.min(z_Positions.value), np.max(z_Positions.value)])
    plt.xlabel('X - Horizontal (um)', fontsize= 12, fontweight='bold')
    plt.ylabel('Z - Vertical (um)', fontsize= 12, fontweight='bold')
    ax.set_aspect('equal')
    fig.colorbar(im, fraction=0.046, pad=0.04, orientation='vertical')

    # Take the cuts at the grid point closest to the origin. This matches the
    # exact-zero index when 0 is on the grid, and no longer raises IndexError
    # when it is not.
    iz0 = np.argmin(np.abs(z_Positions))
    ix0 = np.argmin(np.abs(x_Positions))

    bx = fig.add_subplot(122)
    bx.set_title('Trap Potential')
    plt.plot(x_Positions, U[:, iz0], label = 'X - Horizontal')
    plt.plot(z_Positions, U[ix0, :], label = 'Z - Vertical')
    plt.ylim(top = 0)
    plt.xlabel('Extent (um)', fontsize= 12, fontweight='bold')
    plt.ylabel('Depth (uK)', fontsize= 12, fontweight='bold')
    plt.tight_layout()
    plt.grid(visible=1)
    plt.legend(prop={'size': 12, 'weight': 'bold'})
    plt.show()

def plotAlphas():
    """Plot alpha against modulation depth.

    Depths 0, 0.1, ..., 1.0 are passed to
    ``convert_modulation_depth_to_alpha``. The two alpha series it returns,
    labelled "From Horz TF" and "From Vert TF", are drawn with error bars. The
    first returned array is drawn as a dashed green line over the full depth
    grid.
    """
    depths = np.arange(0, 1.1, 0.1)
    (curve_alpha, measured_depths,
     alpha_horz, alpha_vert,
     err_horz, err_vert) = convert_modulation_depth_to_alpha(depths)

    bar_style = dict(markersize=5, capsize=5)
    bold = dict(fontsize=12, fontweight='bold')

    plt.figure()
    plt.errorbar(measured_depths, alpha_horz, yerr=err_horz, fmt='ob', label='From Horz TF', **bar_style)
    plt.errorbar(measured_depths, alpha_vert, yerr=err_vert, fmt='or', label='From Vert TF', **bar_style)
    plt.plot(depths, curve_alpha, '--g')
    plt.xlabel('Modulation depth', **bold)
    plt.ylabel('$\\alpha$', **bold)
    plt.tight_layout()
    plt.grid(visible=1)
    plt.legend(prop={'size': 12, 'weight': 'bold'})
    plt.show()

def plotTemperatures(w_x, w_z, plot_against_mod_depth = True):
    """Plot horizontal and vertical temperatures with a dashed reference curve.

    Args:
        w_x: Horizontal waist, scaled by the alpha returned from
            ``convert_modulation_depth_to_alpha`` to form aspect ratios.
        w_z: Vertical waist; the divisor of the aspect ratio.
        plot_against_mod_depth: If true, the x axis is modulation depth;
            otherwise it is the aspect ratio ``w_x * alpha / w_z``.
    """
    depths = np.arange(0, 1.1, 0.1)
    curve_ar = w_x * convert_modulation_depth_to_alpha(depths)[0] / w_z
    (curve_T, measured_depths,
     T_horz, T_vert,
     err_horz, err_vert) = convert_modulation_depth_to_temperature(depths)
    measured_ar = w_x * convert_modulation_depth_to_alpha(measured_depths)[0] / w_z

    if plot_against_mod_depth:
        x_points, x_curve, xlabel = measured_depths, depths, 'Modulation depth'
    else:
        x_points, x_curve, xlabel = measured_ar, curve_ar, 'Aspect Ratio'

    bar_style = dict(markersize=5, capsize=5)
    plt.figure()
    plt.errorbar(x_points, T_horz, yerr=err_horz, fmt='ob', label='Horz direction', **bar_style)
    plt.errorbar(x_points, T_vert, yerr=err_vert, fmt='or', label='Vert direction', **bar_style)
    plt.plot(x_curve, curve_T, '--g')

    plt.xlabel(xlabel, fontsize=12, fontweight='bold')
    plt.ylabel('Temperature (uK)', fontsize=12, fontweight='bold')
    plt.tight_layout()
    plt.grid(visible=1)
    plt.legend(prop={'size': 12, 'weight': 'bold'})
    plt.show()

def plotTrapFrequencies(v_x, v_y, v_z, modulation_depth, new_aspect_ratio, plot_against_mod_depth = True):
    """Plot three trap-frequency series on a shared x axis.

    ``v_x`` and ``v_z`` go on the left y axis and ``v_y`` on a twin right y
    axis. The x data is ``modulation_depth`` when ``plot_against_mod_depth``
    is true, otherwise ``new_aspect_ratio``. One legend collects all three
    lines.
    """
    fig, left_ax = plt.subplots(figsize=(8, 6))

    if plot_against_mod_depth:
        xdata, xlabel = modulation_depth, 'Modulation depth'
    else:
        xdata, xlabel = new_aspect_ratio, 'Aspect Ratio'

    x_lines = left_ax.plot(xdata, v_x, '-or', label='v_x')
    z_lines = left_ax.plot(xdata, v_z, '-^b', label='v_z')
    right_ax = left_ax.twinx()
    y_lines = right_ax.plot(xdata, v_y, '-*g', label='v_y')

    left_ax.set_xlabel(xlabel, fontsize=12, fontweight='bold')
    left_ax.set_ylabel('Trap Frequency (Hz)', fontsize=12, fontweight='bold')
    left_ax.tick_params(axis="y", labelcolor='b')
    right_ax.set_ylabel('Trap Frequency (Hz)', fontsize=12, fontweight='bold')
    right_ax.tick_params(axis="y", labelcolor='g')
    plt.tight_layout()
    # pyplot state: this grids the current (most recently created, twin) axes.
    plt.grid(visible=1)
    handles = x_lines + z_lines + y_lines
    left_ax.legend(handles, [h.get_label() for h in handles], prop={'size': 12, 'weight': 'bold'})
    plt.show()

def plotMeasuredTrapFrequencies(fx, dfx, fy, dfy, fz, dfz,  modulation_depth_radial, modulation_depth_axial, w_x, w_z, plot_against_mod_depth = True):
    """Plot measured trap frequencies against modulation depth or aspect ratio.

    Args:
        fx, dfx: Horizontal (radial) frequencies and their errors.
        fy, dfy: Axial frequencies and their errors.
        fz, dfz: Vertical frequencies and their errors.
        modulation_depth_radial: x data for the fx and fz series.
        modulation_depth_axial: x data for the fy series.
        w_x, w_z: Unmodulated waists used to form the aspect ratio.
        plot_against_mod_depth: If true, plot fx/fz (left axis) and fy (twin
            right axis) against modulation depth. Otherwise plot fx and fz
            against the aspect ratio inferred from the frequencies.
    """
    # Scale factor alpha inferred from each series, taking the first point as
    # alpha = 1. The exponents invert the scalings assumed here:
    # fx ~ alpha**(-3/2) and fy ~ alpha**(-1/2).
    # Float arrays keep the aspect-ratio product element-wise. The original
    # multiplied w_x by a Python list, which raises TypeError for a float w_x
    # and silently repeats the list for an int w_x.
    fx_vals = np.asarray(fx, dtype=float)
    fy_vals = np.asarray(fy, dtype=float)
    alpha_x = (fx_vals[0] / fx_vals) ** (2 / 3)
    alpha_y = (fy_vals[0] / fy_vals) ** 2

    # Point-by-point mean of the two estimates; like zip(), stop at the shorter.
    n_common = min(len(alpha_x), len(alpha_y))
    avg_alpha = (alpha_x[:n_common] + alpha_y[:n_common]) / 2
    new_aspect_ratio = (w_x * avg_alpha) / w_z

    if plot_against_mod_depth:
        fig, ax1 = plt.subplots(figsize=(8, 6))
        ax2 = ax1.twinx()
        ax1.errorbar(modulation_depth_radial, fx, yerr = dfx, fmt= 'or', label = 'v_x', markersize=5, capsize=5)
        ax2.errorbar(modulation_depth_axial, fy, yerr = dfy, fmt= '*g', label = 'v_y', markersize=5, capsize=5)
        ax1.errorbar(modulation_depth_radial, fz, yerr = dfz, fmt= '^b', label = 'v_z', markersize=5, capsize=5)
        ax1.set_xlabel('Modulation depth', fontsize= 12, fontweight='bold')
        ax1.set_ylabel('Trap Frequency (kHz)', fontsize= 12, fontweight='bold')
        ax1.tick_params(axis="y", labelcolor='b')
        ax2.set_ylabel('Trap Frequency (Hz)', fontsize= 12, fontweight='bold')
        ax2.tick_params(axis="y", labelcolor='g')
        # Merge the handles from both y axes into a single legend.
        h1, l1 = ax1.get_legend_handles_labels()
        h2, l2 = ax2.get_legend_handles_labels()
        ax1.legend(h1+h2, l1+l2, loc=0, prop={'size': 12, 'weight': 'bold'})
    else:
        plt.figure()
        plt.errorbar(new_aspect_ratio, fx, yerr = dfx, fmt= 'or', label = 'v_x', markersize=5, capsize=5)
        plt.errorbar(new_aspect_ratio, fz, yerr = dfz, fmt= '^b', label = 'v_z', markersize=5, capsize=5)
        plt.xlabel('Aspect Ratio', fontsize= 12, fontweight='bold')
        plt.ylabel('Trap Frequency (kHz)', fontsize= 12, fontweight='bold')
        plt.legend(prop={'size': 12, 'weight': 'bold'})

    plt.tight_layout()
    plt.grid(visible=1)
    plt.show()

def plotRatioOfTrapFrequencies(fx, fy, fz, dfx, dfy, dfz, v_x, v_y, v_z, modulation_depth, w_x, w_z, plot_against_mod_depth = True):
    """Plot measured/reference frequency ratios for the three directions.

    Each series is ``f / v`` with error bars ``df / v``. The x axis is
    ``modulation_depth`` when ``plot_against_mod_depth`` is true, otherwise
    the aspect ratio ``w_x * alpha / w_z``. Here alpha is the first value
    returned by ``convert_modulation_depth_to_alpha(modulation_depth)``.
    """
    aspect_ratio = w_x * convert_modulation_depth_to_alpha(modulation_depth)[0] / w_z

    plt.figure()

    if plot_against_mod_depth:
        xdata, xlabel = modulation_depth, 'Modulation depth'
    else:
        xdata, xlabel = aspect_ratio, 'Aspect Ratio'

    series = (
        (fx, dfx, v_x, 'or', 'b/w horz TF'),
        (fy, dfy, v_y, '*g', 'b/w axial TF'),
        (fz, dfz, v_z, '^b', 'b/w vert TF'),
    )
    for measured, err, reference, fmt, label in series:
        plt.errorbar(xdata, measured / reference, yerr=err / reference, fmt=fmt, label=label, markersize=5, capsize=5)

    plt.xlabel(xlabel, fontsize=12, fontweight='bold')
    plt.ylabel('Ratio', fontsize=12, fontweight='bold')
    plt.tight_layout()
    plt.grid(visible=1)
    plt.legend(prop={'size': 12, 'weight': 'bold'})
    plt.show()

def plotScatteringLengths(B_range = (0, 2.59)):
    """Plot the scattering length (in Bohr radii) against magnetic field.

    Args:
        B_range: ``(start, stop)`` field range in gauss, sampled every 1e-3 G
            with ``stop`` excluded. A tuple is used rather than a mutable
            list default.

    Raises:
        ValueError: if ``B_range`` yields no sample points. The original
            instead failed with an unbound ``a_bkg``.
    """
    BField = np.arange(B_range[0], B_range[1], 1e-3) * u.G
    if len(BField) == 0:
        raise ValueError('B_range %r contains no 1e-3 G sample points' % (B_range,))

    a_s_array = np.zeros(len(BField)) * ac.a0
    for idx, B in enumerate(BField):
        # scatteringLength returns (a_s, a_bkg); a_bkg from the last call is
        # used below for the reference line.
        a_s_array[idx], a_bkg = scatteringLength(B)

    # Also set the sample just before each infinite value to inf, presumably so
    # the plotted line is broken at the divergence instead of joining its two
    # sides. Skip index 0: idx - 1 == -1 would wrap round and blank the last point.
    for idx in np.flatnonzero(np.isinf(a_s_array.value)):
        if idx > 0:
            a_s_array[idx - 1] = np.inf * ac.a0

    a_bkg_in_a0 = (a_bkg / ac.a0).value

    plt.figure(figsize=(9, 7))
    plt.plot(BField, a_s_array/ac.a0, '-b')
    plt.axhline(y = a_bkg/ac.a0, color = 'r', linestyle = '--')
    plt.text(min(BField.value) + 0.5, a_bkg_in_a0 + 1, '$a_{bkg}$ = %.2f a0' % a_bkg_in_a0, fontsize=14, fontweight='bold')
    plt.xlim([min(BField.value), max(BField.value)])
    plt.ylim([65, 125])
    plt.xlabel('B field (G)', fontsize= 12, fontweight='bold')
    plt.ylabel('Scattering length (a0)', fontsize= 12, fontweight='bold')
    plt.tight_layout()
    plt.grid(visible=1)
    plt.show()

def plotCollisionRatesAndPSD(Gamma_elastic, PSD, modulation_depth, new_aspect_ratio, plot_against_mod_depth = True):
    """Plot elastic collision rate (left axis) and phase-space density (right axis).

    The x data is ``modulation_depth`` when ``plot_against_mod_depth`` is
    true, otherwise ``new_aspect_ratio``. PSD tick labels use ``%.1e``
    formatting.
    """
    fig, rate_ax = plt.subplots(figsize=(8, 6))
    psd_ax = rate_ax.twinx()

    if plot_against_mod_depth:
        xdata, xlabel = modulation_depth, 'Modulation depth'
    else:
        xdata, xlabel = new_aspect_ratio, 'Aspect Ratio'

    rate_ax.plot(xdata, Gamma_elastic, '-ob')
    psd_ax.plot(xdata, PSD, '-*r')
    psd_ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))

    rate_ax.set_xlabel(xlabel, fontsize=12, fontweight='bold')
    rate_ax.set_ylabel('Elastic Collision Rate', fontsize=12, fontweight='bold')
    rate_ax.tick_params(axis="y", labelcolor='b')
    psd_ax.set_ylabel('Phase Space Density', fontsize=12, fontweight='bold')
    psd_ax.tick_params(axis="y", labelcolor='r')
    plt.tight_layout()
    # pyplot state: applies to the current (twin) axes, as before.
    plt.grid(visible=1)
    plt.show()

#####################################################################