import numpy as np
import matplotlib.pyplot as plt
import sympy as sym
import math
import os
import subprocess
import fileinput
import re
import matlab.engine
import sys
from ClassifyMin import *
from HelperFunctions import *
# from CellScript import *
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.cm as cm
from vtk.util import numpy_support
from pyevtk.hl import gridToVTK
import time
import matplotlib.ticker as ticker

import matplotlib as mpl
from matplotlib.ticker import MultipleLocator,FormatStrFormatter,MaxNLocator
import pandas as pd

# from matplotlib import rc
# rc('text', usetex=True) # Use LaTeX font
#
# import seaborn as sns
# sns.set(color_codes=True)


def format_func(value, tick_number):
    """Format a tick position as a reduced multiple of pi/4 in LaTeX.

    The value is rounded to the nearest integer multiple of pi/4 and the
    resulting fraction is reduced, e.g. ``pi/4``, ``pi/2``, ``3*pi/4``,
    ``pi``, ``3*pi/2`` are rendered as ``$\\pi/4$``, ``$\\pi/2$``,
    ``$3\\pi/4$``, ``$\\pi$`` and ``$3\\pi/2$``.

    Args:
        value: Tick position in radians.
        tick_number: Tick index; unused, but part of the two-argument
            callback signature this function is registered with
            (``FuncFormatter``).

    Returns:
        ``"0"`` for a zero multiple, otherwise a LaTeX math string.
    """
    quarters = int(np.round(4 * value / np.pi))
    if quarters == 0:
        return "0"

    sign = "-" if quarters < 0 else ""
    # Reduce quarters/4 to lowest terms so that e.g. 2/4 is shown as 1/2
    # and 4/4 as a plain pi.
    divisor = math.gcd(abs(quarters), 4)
    numerator = abs(quarters) // divisor
    denominator = 4 // divisor

    # A coefficient of one is left implicit ("\pi" rather than "1\pi").
    coefficient = "" if numerator == 1 else str(numerator)
    label = sign + coefficient + r"\pi"
    if denominator != 1:
        label += "/" + str(denominator)
    return "$" + label + "$"



def find_nearest(array, value):
    """Return the entry of ``array`` whose value is closest to ``value``.

    ``array`` is converted with ``np.asarray`` first, so plain sequences
    are accepted. When several entries are equally close, the one that
    ``argmin`` reports (the first occurrence) is returned.
    """
    candidates = np.asarray(array)
    distances = np.abs(candidates - value)
    return candidates[np.argmin(distances)]


def find_nearestIdx(array, value):
    """Return the index of the entry of ``array`` closest to ``value``.

    ``array`` is converted with ``np.asarray`` first, so plain sequences
    are accepted. When several entries are equally close, the index that
    ``argmin`` reports (the first occurrence) is returned.
    """
    candidates = np.asarray(array)
    distances = np.abs(candidates - value)
    return np.argmin(distances)


# Locations of the parameter file and the result file, given relative to
# the directory one level above the current working directory.
InputFile = "/inputs/computeMuGamma.parset"
OutputFile = "/outputs/outputMuGamma.txt"

# The script is meant to be started from the src folder: step up to its
# parent and make that the working directory for everything that follows.
path_parent = os.path.dirname(os.getcwd())
os.chdir(path_parent)
path = os.getcwd()
print(path)

# Absolute file locations are formed by plain string concatenation; the
# relative names above already start with a path separator.
InputFilePath, OutputFilePath = (path + suffix for suffix in (InputFile, OutputFile))
print("InputFilepath: ", InputFilePath)
print("OutputFilepath: ", OutputFilePath)
print("Path: ", path)

print('---- Input parameters: -----')

# Scalar coefficients from which the matrices below are assembled.
q1 = 1
q2 = 2
q12 = 1 / 2
q3 = ((4 * q1 * q2) ** 0.5 - q12) / 2

# Symmetric 2x2 matrices built from the coefficients.
H = np.array([[2 * q1, q12 + 2 * q3], [q12 + 2 * q3, 2 * q2]])
A = np.array([[q1, 1 / 2 * q12], [1 / 2 * q12, q2]])

# abar is the second column of H scaled to unit Euclidean length.
abar = np.array([q12 + 2 * q3, 2 * q2])
abar = (1.0 / math.sqrt((q12 + 2 * q3) ** 2 + (2 * q2) ** 2)) * abar

print('abar:', abar)

# Solve A b = abar in the least-squares sense. rcond is passed explicitly:
# omitting it triggers a FutureWarning on NumPy 1.14-1.x, and None selects
# the machine-precision based cutoff that newer NumPy uses by default.
b = np.linalg.lstsq(A, abar, rcond=None)[0]
print('b', b)

sstar = (1 / (q1 + q2)) * abar.dot(A.dot(b))
print('sstar', sstar)

# abar rotated by a quarter turn: (x, y) -> (y, -x).
abarperp = np.array([abar[1], -abar[0]])
print('abarperp:', abarperp)

print('----------------------------')

# ----------------------------------------------------------------

# Sample N values of t and, for each, evaluate the point
# sstar*abar + t*abarperp together with derived quantities.
N = 1000
T = np.linspace(
    -sstar * (q12 + 2 * q3) / (2 * q2),
    sstar * (2 * q2) / (q12 + 2 * q3),
    num=N,
)
print('T:', T)

# One entry per sample of t.
kappas = []
alphas = []
G_container = []
abar_container = []

for t in T:
    abar_current = sstar * abar + t * abarperp
    # Every component below 1e-10 (negative ones included) is set to zero,
    # presumably so the square roots further down stay well-defined.
    abar_current[abar_current < 1e-10] = 0
    first, second = abar_current

    G = [first, second, (2 * first * second) ** 0.5]
    G_container.append(G)
    abar_container.append(abar_current)

    # kappa is the component sum; e holds the square roots of the two
    # components' shares of that sum, and alpha is the angle of e.
    kappa = first + second
    e = [(first / kappa) ** 0.5, (second / kappa) ** 0.5]
    alpha = math.atan2(e[1], e[0])
    print('angle current:', alpha)
    kappas.append(kappa)
    alphas.append(alpha)




def _first_match_or_nearest(values, target, decimals=3):
    """Locate samples in ``values`` that equal ``target`` after rounding.

    Returns the ``np.where`` result (a one-element tuple holding an index
    array) for all entries matching ``target`` to ``decimals`` places. If
    no entry matches, the tuple instead holds the single index of the
    nearest entry, so that ``result[0][0]`` is always valid.
    """
    values = np.asarray(values)
    hits = np.where(np.round(values, decimals) == round(target, decimals))
    if hits[0].size == 0:
        # The sampling may step over the rounding bin of the target; use
        # the closest sample rather than failing with an IndexError later.
        hits = (np.array([find_nearestIdx(values, target)]),)
    return hits


alphas = np.array(alphas)
kappas = np.array(kappas)

# Samples whose angle is (approximately) pi/3, 0 and pi/4.
idx_1 = _first_match_or_nearest(alphas, np.pi / 3)
idx_2 = _first_match_or_nearest(alphas, 0.0)
idx_3 = _first_match_or_nearest(alphas, np.pi / 4)

print('Index idx_1:', idx_1)
print('Index idx_2:', idx_2)
print('Index idx_3:', idx_3)
print('Index idx_1[0][0]:', idx_1[0][0])
print('Index idx_2[0][0]:', idx_2[0][0])
print('Index idx_3[0][0]:', idx_3[0][0])

print('alphas[idx_1[0][0]]', alphas[idx_1[0][0]])
print('kappas[idx_1[0][0]]', kappas[idx_1[0][0]])
print('alphas[idx_2[0][0]]', alphas[idx_2[0][0]])
print('kappas[idx_2[0][0]]', kappas[idx_2[0][0]])
print('alphas[idx_3[0][0]]', alphas[idx_3[0][0]])
print('kappas[idx_3[0][0]]', kappas[idx_3[0][0]])

print('min alpha:', min(alphas))
print('min kappa:', min(kappas))

# Matplotlib 3.6 renamed its bundled seaborn style sheets to
# "seaborn-v0_8*" and 3.8 removed the old names; prefer the new name when
# it is available and fall back to the old one on older installations.
_seaborn_style = "seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn"
plt.style.use(_seaborn_style)

# Global figure settings applied on top of the style sheet.
mpl.rcParams.update({
    # Text is rendered through LaTeX with a serif font.
    'text.usetex': True,
    'font.family': 'serif',
    'font.size': 10,
    # Draw tick marks on the bottom and left axes.
    'xtick.bottom': True,
    'xtick.major.size': 3,
    'xtick.minor.size': 1.5,
    'xtick.major.width': 0.75,
    'ytick.left': True,
    'ytick.major.size': 3,
    'ytick.minor.size': 1.5,
    'ytick.major.width': 0.75,
    'axes.labelpad': 2.0,
})


# Figure size in inches; the height follows from the golden ratio.
width = 6.28 * 0.5
height = width / 1.618
fig = plt.figure()
# Axes rectangle as (left, bottom, width, height) in figure fractions.
ax = plt.axes((0.15, 0.21, 0.8, 0.75))
ax.tick_params(axis='x', which='major', direction='out', pad=5)
ax.tick_params(axis='y', which='major', length=3, width=1, direction='out', pad=3)
ax.xaxis.set_major_locator(plt.MultipleLocator(np.pi / 2))
ax.xaxis.set_minor_locator(plt.MultipleLocator(np.pi / 4))
# Label the x ticks as multiples of pi instead of decimal numbers.
ax.xaxis.set_major_formatter(plt.FuncFormatter(format_func))
ax.grid(True, which='major', axis='both', alpha=0.3)

# The sampled curve (blue) and its mirror image about alpha = 0 (red).
ax.plot(alphas, kappas, 'royalblue', zorder=3)
ax.plot(-1.0 * alphas, kappas, 'red', zorder=3)

# Cross markers at the three selected samples; the first one is placed on
# the mirrored curve.
ax.scatter(-1 * alphas[idx_1[0][0]], kappas[idx_1[0][0]], marker='x', color='black', zorder=5, s=40)
ax.scatter(alphas[idx_2[0][0]], kappas[idx_2[0][0]], marker='x', color='black', zorder=5, s=40)
ax.scatter(alphas[idx_3[0][0]], kappas[idx_3[0][0]], marker='x', color='black', zorder=5, s=40)

# Raw strings keep the LaTeX backslashes from being read as (invalid)
# Python escape sequences. This list is not passed to any call below.
label = [r'$\mathcal S^+_{Q,B}$', r'$T\mathcal S^+_{Q,B}$']

# Curve labels, positioned in data coordinates.
ax.annotate(r'$\mathcal S^+_{Q,B}$', np.array([np.pi/4, 0.55]), color='royalblue', size=12)
ax.annotate(r'$T\mathcal S^+_{Q,B}$', np.array([-np.pi/4-0.55, 0.55]), color='red', size=12)

ax.set_xlabel(r"angle $\alpha$", fontsize=10, labelpad=2)
ax.set_ylabel(r"curvature  $\kappa$", fontsize=10, labelpad=2)

# Explicit major tick positions from -pi/2 to pi/2 in steps of pi/4.
ax.set_xticks([-np.pi/2, -np.pi/4, 0, np.pi/4, np.pi/2])

fig.set_size_inches(width, height)
fig.savefig('Plot-1-ParameterFamily2.pdf')

plt.show()