import numpy as np
import matplotlib.pyplot as plt
import sympy as sym
import math
import os
import subprocess
import fileinput
import re
import matlab.engine
import sys
# from ClassifyMin import *
from ClassifyMin_New import *
from HelperFunctions import *
# from CellScript import *
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.cm as cm
from vtk.util import numpy_support
from pyevtk.hl import gridToVTK
import time
import matplotlib.ticker as ticker

import matplotlib as mpl
from matplotlib.ticker import MultipleLocator,FormatStrFormatter,MaxNLocator
import pandas as pd
import seaborn as sns

# from matplotlib import rc
# rc('text', usetex=True) # Use LaTeX font
#
# import seaborn as sns
# sns.set(color_codes=True)


def format_func(value, tick_number):
    """Format a tick value given in radians as a LaTeX multiple of pi/4.

    Meant to be used with ``plt.FuncFormatter``; ``tick_number`` is part of
    the matplotlib formatter signature and is not used here.

    The value is rounded to the nearest multiple of pi/4 and the fraction
    N/4 is reduced to lowest terms, so e.g. pi/4 -> "$\\pi/4$",
    pi/2 -> "$\\pi/2$", 3pi/4 -> "$3\\pi/4$", pi -> "$\\pi$",
    3pi/2 -> "$3\\pi/2$", 2pi -> "$2\\pi$". Zero is rendered as "0".
    """
    # Number of multiples of pi/4 closest to the value.
    N = int(np.round(4 * value / np.pi))
    if N == 0:
        return "0"
    # Reduce N/4 so that e.g. 2*(pi/4) is printed as pi/2 and 4*(pi/4) as pi.
    common = math.gcd(N, 4)
    numerator, denominator = N // common, 4 // common
    # A coefficient of +-1 is written as just "pi" / "-pi".
    if numerator == 1:
        coeff = ""
    elif numerator == -1:
        coeff = "-"
    else:
        coeff = str(numerator)
    if denominator == 1:
        return r"${0}\pi$".format(coeff)
    return r"${0}\pi/{1}$".format(coeff, denominator)



def find_nearest(array, value):
    """Return the element of ``array`` whose value is closest to ``value``.

    ``array`` may be any array-like. On ties the first closest element wins
    (``argmin`` semantics); an empty input raises ``ValueError``.
    """
    values = np.asarray(array)
    distances = np.abs(values - value)
    return values[np.argmin(distances)]


def find_nearestIdx(array, value):
    """Return the index of the element of ``array`` closest to ``value``.

    ``array`` may be any array-like. On ties the lowest index wins
    (``argmin`` semantics); an empty input raises ``ValueError``.
    """
    gaps = np.abs(np.asarray(array) - value)
    return np.argmin(gaps)


InputFile  = "/inputs/computeMuGamma.parset"
OutputFile = "/outputs/outputMuGamma.txt"
# --------- Run  from src folder:
# Move one directory up (from src/) so that the paths above, which start
# with '/', can simply be appended to the working directory.
path_parent = os.path.dirname(os.getcwd())
os.chdir(path_parent)
path = os.getcwd()
print(path)
# Plain concatenation on purpose: os.path.join would discard `path` because
# InputFile/OutputFile begin with a separator.
InputFilePath, OutputFilePath = path + InputFile, path + OutputFile
print("InputFilepath: ", InputFilePath)
print("OutputFilepath: ", OutputFilePath)
print("Path: ", path)

print('---- Input parameters: -----')
# Only the values below are in effect. Earlier revisions reassigned
# alpha/beta/theta several times before first use; those presets are kept
# as a comment at the end of this section.
mu1 = 1.0
rho1 = 1.0

# `alpha` is only used below for the prestrain diagnostics b1/b2; the
# figures are produced for the three values alpha_1, alpha_2, alpha_3.
alpha = -0.5

alpha_1 = -1.0
alpha_2 = -0.75
alpha_3 = -0.70

# Placeholders; overwritten below by the classifyMin_anaVec results.
angles_1 = []
angles_2 = []
angles_3 = []

beta = 2.0
theta = 0.25

# Other (alpha, beta, theta) presets marked as interesting previously:
#   (-0.5, 40.0, 1/8)   "from pi/2"
#   (-0.2, 25.0, 1/2)   "from pi/2"
#   (-0.5,  5.0, 1/30)
#   (-0.25, 10.0, 3/4)
#   (-0.25, 10.0, 1/8)
#   (-0.25,  5.0, 1/8)
#   (-0.5, 10.0, 1/8)
# and the original defaults alpha = 10, beta = 2.0 (or 5.0), theta = 1/8.

print('mu1: ', mu1)
print('rho1: ', rho1)
print('alpha_1: ', alpha_1)
print('alpha_2: ', alpha_2)
print('alpha_3: ', alpha_3)
print('beta: ', beta)
print('theta: ', theta)
print('----------------------------')

# ----------------------------------------------------------------


# Uniform grid of gamma values (a finer grid with num=200 was used before).
gamma_min, gamma_max = 0.01, 1.5
Gamma_Values = np.linspace(gamma_min, gamma_max, 50)    # TODO variable Input Parameters...alpha,beta...
print('(Input) Gamma_Values:', Gamma_Values)

# Evaluate mu_gamma for every gamma; np.vectorize broadcasts the scalar
# arguments against the Gamma_Values array.
GetMuGammaVec = np.vectorize(GetMuGamma)
muGammas = GetMuGammaVec(beta, theta, Gamma_Values, mu1, rho1, InputFilePath, OutputFilePath)
print('muGammas:', muGammas)

# Coefficients from the harmonic / arithmetic means (scaled by 1/6), the
# critical value q3^* = sqrt(q1 * q2), and the prestrain values b1, b2.
q12 = 0.0
q1 = (1.0/6.0)*harmonicMean(mu1, beta, theta)
q2 = (1.0/6.0)*arithmeticMean(mu1, beta, theta)
for _label, _value in (('q1: ', q1), ('q2: ', q2)):
    print(_label, _value)
# NOTE(review): b1/b2 use the single `alpha`, not alpha_1..alpha_3 used for
# the plots; they are only printed here — confirm this is intended.
b1 = prestrain_b1(rho1, beta, alpha,theta)
b2 = prestrain_b2(rho1, beta, alpha,theta)
q3_star = math.sqrt(q1*q2)
print('q3_star:', q3_star)

# TODO these have to be compatible with input parameters!!!
# compute certain ParameterValues that this makes sense
for _label, _value in (('b1: ', b1), ('b2: ', b2)):
    print(_label, _value)



# Classify the minimizer for each of the three alpha values. Every
# vectorized call yields (G, angles, Types, curvature) over muGammas; G and
# Types are not used further and end up holding the last call's values.
classifyMin_anaVec = np.vectorize(classifyMin_ana)
_classified = [
    classifyMin_anaVec(alpha_i, beta, theta, muGammas,  mu1, rho1)
    for alpha_i in (alpha_1, alpha_2, alpha_3)
]
G, angles_1, Types, curvature_1 = _classified[0]
G, angles_2, Types, curvature_2 = _classified[1]
G, angles_3, Types, curvature_3 = _classified[2]

for _label, _data in (
    ('angles_1:', angles_1),
    ('angles_2:', angles_2),
    ('angles_3:', angles_3),
    ('curvature_1:', curvature_1),
    ('curvature_2:', curvature_2),
    ('curvature_3:', curvature_3),
):
    print(_label, _data)


# Grid gamma whose mu_gamma is closest to the critical value q3^*.
idx = find_nearestIdx(muGammas, q3_star)
print('GammaValue Idx closest to q_3^*', idx)
gammaClose = Gamma_Values[idx]
print('GammaValue(Idx) with mu_gamma closest to q_3^*', gammaClose)


# Determinant for every mu_gamma on the grid.
determinantVec = np.vectorize(determinant)

detValues = determinantVec(q1,q2,muGammas,q12)
print('detValues:', detValues)


# Grid gamma where the determinant is closest to zero. NOTE: this overwrites
# the q3^*-based gammaClose from above; this value is the one marked as
# gamma^* in the figure.
detZeroidx = find_nearestIdx(detValues, 0)
# Report the determinant-based index (previously printed `idx` by mistake).
print('idx where det nearest to zero', detZeroidx)
gammaClose = Gamma_Values[detZeroidx]
print('gammaClose:', gammaClose)


# --- Make sure all plotted data are numpy arrays
Gamma_Values = np.array(Gamma_Values)
angles_1, angles_2, angles_3 = map(np.array, (angles_1, angles_2, angles_3))
curvature_1, curvature_2, curvature_3 = map(np.array, (curvature_1, curvature_2, curvature_3))

# ---------------- Create Plot -------------------

# Styling. matplotlib 3.6 renamed its bundled seaborn styles to
# "seaborn-v0_8-*" and matplotlib 3.8 removed the old names, so use the new
# spelling when this installation provides it and fall back otherwise.
# The styles are applied in the original order: darkgrid, then seaborn.
for _style in ("seaborn-darkgrid", "seaborn"):
    _renamed = _style.replace("seaborn", "seaborn-v0_8", 1)
    plt.style.use(_renamed if _renamed in plt.style.available else _style)

# Text: rendered with LaTeX (requires a working LaTeX installation) in a
# serif font.
mpl.rcParams['text.usetex'] = True
mpl.rcParams["font.family"] = "serif"
mpl.rcParams["font.size"] = "10"

# Ticks: enable them on the bottom/left axes and keep them small.
mpl.rcParams['xtick.bottom'] = True
mpl.rcParams['xtick.major.size'] = 3
mpl.rcParams['xtick.minor.size'] = 1.5
mpl.rcParams['xtick.major.width'] = 0.75
mpl.rcParams['ytick.left'] = True
mpl.rcParams['ytick.major.size'] = 3
mpl.rcParams['ytick.minor.size'] = 1.5
mpl.rcParams['ytick.major.width'] = 0.75

mpl.rcParams.update({'font.size': 10})
# No extra space between the axis labels and the axes.
mpl.rcParams['axes.labelpad'] = 0.0

# Figure size in inches: fixed width, height from the golden ratio (1.618).
width = 6.28
height = width / 1.618
# height = width / 2.5
fig = plt.figure(figsize=(width,height))

# fig,ax = plt.subplots(nrows=2,ncols=3,figsize=(width,height)) # more than one plot
# fig,ax = plt.subplots(nrows=1,ncols=3,figsize=(width,height),sharey=True) # Share Y-axis
# fig.tight_layout()

# 2 x 3 panel grid: the top row (ax1-ax3) shows the angle and the bottom row
# (ax4-ax6) the curvature, one column each for alpha_1, alpha_2, alpha_3.
gs = fig.add_gridspec(nrows=2,ncols=3, hspace=0.15, wspace=0.1)
# ax = gs.subplots(sharey=True)

# Create the six axes objects. The bottom row is created first because ax1
# is created with sharex=ax4.


ax4 = fig.add_subplot(gs[1, 0])
ax5 = fig.add_subplot(gs[1, 1],sharey=ax4)
ax6 = fig.add_subplot(gs[1, 2],sharey=ax4)
# ax5/ax6 share ax4's y-axis: hide their duplicate y tick labels.
plt.setp(ax5.get_yticklabels(), visible=False)
plt.setp(ax6.get_yticklabels(), visible=False)

ax1 = fig.add_subplot(gs[0, 0],sharex=ax4)
ax2 = fig.add_subplot(gs[0, 1],sharey=ax1)
ax3 = fig.add_subplot(gs[0, 2],sharey=ax1)
# Top row: hide the x tick labels (the bottom row carries the gamma axis)
# and the duplicate y tick labels of the panels sharing ax1's y-axis.
plt.setp(ax1.get_xticklabels(), visible=False)
plt.setp(ax2.get_xticklabels(), visible=False)
plt.setp(ax3.get_xticklabels(), visible=False)
plt.setp(ax2.get_yticklabels(), visible=False)
plt.setp(ax3.get_yticklabels(), visible=False)
# ---- Top row: angle vs. gamma, one panel per alpha_1, alpha_2, alpha_3 ----

# Angle axis (ax2/ax3 share ax1's y-axis): major ticks every pi/8, minor
# ticks every pi/16. NOTE: set_yticks/set_yticklabels further down replace
# the major locator/formatter again, so the explicit tick list is what is
# finally displayed; the minor locator is kept.
ax1.yaxis.set_major_locator(plt.MultipleLocator(np.pi / 8))
ax1.yaxis.set_minor_locator(plt.MultipleLocator(np.pi / 16))
ax1.yaxis.set_major_formatter(plt.FuncFormatter(format_func))
ax2.yaxis.set_major_locator(plt.MultipleLocator(np.pi / 8))
ax2.yaxis.set_minor_locator(plt.MultipleLocator(np.pi / 16))
ax2.yaxis.set_major_formatter(plt.FuncFormatter(format_func))
ax3.yaxis.set_major_locator(plt.MultipleLocator(np.pi / 8))
ax3.yaxis.set_minor_locator(plt.MultipleLocator(np.pi / 16))
ax3.yaxis.set_major_formatter(plt.FuncFormatter(format_func))

ax1.grid(True,which='major',axis='both',alpha=0.3)
ax2.grid(True,which='major',axis='both',alpha=0.3)
ax3.grid(True,which='major',axis='both',alpha=0.3)


ax1.plot(Gamma_Values, angles_1, 'royalblue', zorder=3)
ax2.plot(Gamma_Values, angles_2, 'royalblue', zorder=3)
ax3.plot(Gamma_Values, angles_3, 'royalblue', zorder=3)

# Only the left panel gets a y label; gamma ticks every 0.5 (minor 0.25).
ax1.set_ylabel(r"angle  $\alpha$")
ax1.xaxis.set_minor_locator(MultipleLocator(0.25))
ax1.xaxis.set_major_locator(MultipleLocator(0.5))

ax2.xaxis.set_minor_locator(MultipleLocator(0.25))
ax2.xaxis.set_major_locator(MultipleLocator(0.5))

ax3.xaxis.set_minor_locator(MultipleLocator(0.25))
ax3.xaxis.set_major_locator(MultipleLocator(0.5))

# Fixed angle ticks 0, pi/8, ..., pi/2 with hand-written LaTeX labels.
labels = ['$0$', r'$\pi/8$', r'$\pi/4$', r'$3\pi/8$', r'$\pi/2$']
ax1.set_yticks([0, np.pi/8, np.pi/4, 3*np.pi/8, np.pi/2])
ax2.set_yticks([0, np.pi/8, np.pi/4, 3*np.pi/8, np.pi/2])
ax3.set_yticks([0, np.pi/8, np.pi/4, 3*np.pi/8, np.pi/2])

ax1.set_yticklabels(labels)
ax2.set_yticklabels(labels)
ax3.set_yticklabels(labels)

# Show [0, pi/2] with a small margin so curves on the boundary stay visible.
ax1.set_ylim([0-0.1, np.pi/2+0.1])
ax2.set_ylim([0-0.1, np.pi/2+0.1])
ax3.set_ylim([0-0.1, np.pi/2+0.1])

# Mark gammaClose (computed above) with a dashed vertical line. Raw strings
# avoid the invalid "\g" escape warning while keeping the same label text.
l1 = ax1.axvline(x = gammaClose, color = 'midnightblue', linestyle = 'dashed', linewidth=1, label=r'$\gamma^*$')
l2 = ax2.axvline(x = gammaClose, color = 'midnightblue', linestyle = 'dashed', linewidth=1, label=r'$\gamma^*$')
l3 = ax3.axvline(x = gammaClose, color = 'midnightblue', linestyle = 'dashed', linewidth=1, label=r'$\gamma^*$')


# -------------------------------------------------------------------------------------
# ---- Bottom row: curvature vs. gamma, one panel per alpha_1, alpha_2, alpha_3 ----

ax4.grid(True,which='major',axis='both',alpha=0.3)
ax5.grid(True,which='major',axis='both',alpha=0.3)
ax6.grid(True,which='major',axis='both',alpha=0.3)
ax4.plot(Gamma_Values, curvature_1, 'forestgreen', zorder=3)
ax5.plot(Gamma_Values, curvature_2, 'forestgreen', zorder=3)
ax6.plot(Gamma_Values, curvature_3, 'forestgreen', zorder=3)

# Gamma label on every bottom panel; y label only on the left one.
ax4.set_xlabel(r"$\gamma$", fontsize=10 ,labelpad=0)
ax4.set_ylabel(r"curvature $\kappa$")
ax4.xaxis.set_minor_locator(MultipleLocator(0.25))
ax4.xaxis.set_major_locator(MultipleLocator(0.5))
# Curvature ticks every 0.05 (ax5/ax6 were created with sharey=ax4).
ax4.yaxis.set_major_locator(MultipleLocator(0.05))
ax5.set_xlabel(r"$\gamma$", fontsize=10 ,labelpad=0)
ax5.xaxis.set_minor_locator(MultipleLocator(0.25))
ax5.xaxis.set_major_locator(MultipleLocator(0.5))
ax6.set_xlabel(r"$\gamma$", fontsize=10 ,labelpad=0)
ax6.xaxis.set_minor_locator(MultipleLocator(0.25))
ax6.xaxis.set_major_locator(MultipleLocator(0.5))

# Dashed marker at gammaClose, drawn above the curves (zorder 4 > 3). Raw
# strings avoid the invalid "\g" escape warning; the label text is unchanged.
l4 = ax4.axvline(x = gammaClose, color = 'midnightblue', linestyle = 'dashed', linewidth=1, label=r'$\gamma^*$', zorder=4)
l5 = ax5.axvline(x = gammaClose, color = 'midnightblue', linestyle = 'dashed', linewidth=1, label=r'$\gamma^*$', zorder=4)
l6 = ax6.axvline(x = gammaClose, color = 'midnightblue', linestyle = 'dashed', linewidth=1, label=r'$\gamma^*$', zorder=4)





#
#



## LEGEND
line_labels = [r"$\gamma^*$"]
# One shared legend entry for the dashed gamma^* marker, centred between the
# two rows of panels (position given in figure coordinates).
legend = fig.legend(
    [l1],
    [r"$\gamma^*$"],
    loc='center',
    bbox_to_anchor=[0.52, 0.58],
    ncol=3,
    frameon=True,
)
frame = legend.get_frame()
frame.set_edgecolor('gray')

# Align the y labels of both rows, then write the PDF and show the figure.
fig.align_ylabels()
fig.set_size_inches(width, height)
fig.savefig('Plot-AngleCurv-Gamma.pdf')
plt.show()




# plt.figure()
# plt.title(r'angle$-\mu_\gamma(\gamma)$-Plot')
# plt.plot(muGammas, angles)
# plt.scatter(muGammas, angles)
# # plt.axis([0, 6, 0, 20])
# # plt.axhline(y = 1.90476, color = 'b', linestyle = ':', label='$q_1$')
# # plt.axhline(y = 2.08333, color = 'r', linestyle = 'dashed', label='$q_2$')
# plt.axvline(x = 1.90476, color = 'b', linestyle = ':', label='$q_1$')
# plt.axvline(x = 2.08333, color = 'r', linestyle = 'dashed', label='$q_2$')
# plt.xlabel("$\mu_\gamma$")
# plt.ylabel("angle")
# plt.legend()
# plt.show()
#