import numpy as np
import matplotlib.pyplot as plt
import sympy as sym
import math
import os
import subprocess
import fileinput
import re
import matlab.engine
import sys
from ClassifyMin import *
from HelperFunctions import *
# from CellScript import *
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.cm as cm
from vtk.util import numpy_support
from pyevtk.hl import gridToVTK
import time
import matplotlib.ticker as ticker

import matplotlib as mpl
from matplotlib.ticker import MultipleLocator,FormatStrFormatter,MaxNLocator
import pandas as pd

import seaborn as sns
import matplotlib.colors as mcolors

from chart_studio import plotly
import plotly.graph_objs as go
import plotly.express as px
import plotly.colors
# from matplotlib import rc
# rc('text', usetex=True) # Use LaTeX font
#
# import seaborn as sns
# sns.set(color_codes=True)


def show(fig):
    """Render a Plotly figure to an in-memory image and open it in the system viewer.

    The figure is written with ``plotly.io.write_image`` into a bytes buffer and
    then displayed via ``PIL.Image.show``; nothing is written to disk.
    """
    from io import BytesIO
    import plotly.io as pio
    from PIL import Image

    image_bytes = BytesIO()
    pio.write_image(fig, image_bytes)
    Image.open(image_bytes).show()




def add_arrow(line, position=None, direction='right', size=15, color=None):
    """
    Add an arrow to a line, pointing from one data sample to its neighbour.

    line:       Line2D object (anything with get_xdata/get_ydata/get_color and
                an ``axes`` attribute providing ``annotate``)
    position:   x-position of the arrow. If None, mean of xdata is taken
    direction:  'left' or 'right' (any value other than 'right' means 'left')
    size:       size of the arrow in fontsize points
    color:      if None, line color is taken.

    Returns the created annotation, or None if the line has fewer than two
    points (an arrow needs two samples).

    The arrow starts at the sample whose x-value is closest to ``position``.
    If there is no neighbour in the requested direction (closest sample is the
    last one for 'right' or the first one for 'left'), the arrow is shifted back
    by one sample. This keeps the same direction and stays on the line, instead
    of raising IndexError or wrapping around to the opposite end via index -1.
    """
    if color is None:
        color = line.get_color()

    # get_xdata()/get_ydata() may return plain lists; np.asarray gives .mean()
    # and element-wise arithmetic.
    xdata = np.asarray(line.get_xdata())
    ydata = np.asarray(line.get_ydata())
    n_points = xdata.size
    if n_points < 2:
        return None

    if position is None:
        position = xdata.mean()

    step = 1 if direction == 'right' else -1
    # find closest index
    start_ind = int(np.argmin(np.absolute(xdata - position)))
    end_ind = start_ind + step
    if not 0 <= end_ind < n_points:
        # No neighbour in that direction: use the previous pair instead.
        end_ind = start_ind
        start_ind = start_ind - step

    return line.axes.annotate('',
        xytext=(xdata[start_ind], ydata[start_ind]),
        xy=(xdata[end_ind], ydata[end_ind]),
        arrowprops=dict(arrowstyle="->", color=color),
        size=size
    )


# set the colormap and centre the colorbar
class MidpointNormalize(mcolors.Normalize):
    """
    Normalise the colorbar so that diverging colormaps extend either side of a prescribed midpoint value.

    e.g. im=ax1.imshow(array, norm=MidpointNormalize(midpoint=0.,vmin=-100, vmax=100))

    Values are mapped piecewise-linearly: [vmin, midpoint] -> [0, 0.5] and
    [midpoint, vmax] -> [0.5, 1]. Values outside [vmin, vmax] are clamped to
    0 or 1 by ``np.interp``, and NaN entries are masked. ``clip`` is accepted
    for API compatibility but is not used by ``__call__``. ``np.interp`` needs
    vmin <= midpoint <= vmax for a meaningful result.
    """

    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        self.midpoint = midpoint
        super().__init__(vmin, vmax, clip)

    def __call__(self, value, clip=None):
        # Missing limits are taken from the data (Normalize.autoscale_None);
        # otherwise np.interp would receive None as a breakpoint.
        if self.vmin is None or self.vmax is None:
            self.autoscale_None(value)
        # Without an explicit midpoint, centre between the limits, which is
        # then an ordinary linear normalisation.
        midpoint = self.midpoint
        if midpoint is None:
            midpoint = 0.5 * (self.vmin + self.vmax)
        x, y = [self.vmin, midpoint, self.vmax], [0, 0.5, 1]
        return np.ma.masked_array(np.interp(value, x, y), np.isnan(value))



def set_size(width, fraction=1):
    """Return a (width, height) figure size in inches for embedding in LaTeX.

    Parameters
    ----------
    width: float
            Document textwidth or columnwidth in pts
    fraction: float, optional
            Fraction of the width which you wish the figure to occupy

    Returns
    -------
    fig_dim: tuple
            (width, height) in inches; the height is the width scaled by the
            golden ratio (sqrt(5) - 1) / 2 (see https://disq.us/p/2940ij3).
    """
    # TeX points: 72.27 pt per inch.
    width_in = width * fraction * (1 / 72.27)
    aspect = (5**.5 - 1) / 2
    return width_in, width_in * aspect



def format_func(value, tick_number):
    """Format a tick value as a LaTeX multiple of pi/4 (for ``plt.FuncFormatter``).

    ``value`` is rounded to the nearest multiple N of pi/4, and N*pi/4 is
    printed in lowest terms: ``0``, ``$\\pi/4$``, ``$\\pi/2$``, ``$3\\pi/4$``,
    ``$\\pi$``, ``$3\\pi/2$``, ``$2\\pi$``, ... (with a leading minus for
    negatives). ``tick_number`` is required by the FuncFormatter signature and
    is unused.
    """
    # Number of multiples of pi/4.
    N = int(np.round(4 * value / np.pi))
    if N == 0:
        return "0"

    # Reduce N/4 to lowest terms: k*pi, k*pi/2 (k odd) or N*pi/4 (N odd).
    if N % 4 == 0:
        numerator, denominator = N // 4, ""
    elif N % 2 == 0:
        numerator, denominator = N // 2, "/2"
    else:
        numerator, denominator = N, "/4"

    # A coefficient of +-1 is written as just the sign.
    if numerator == 1:
        coefficient = ""
    elif numerator == -1:
        coefficient = "-"
    else:
        coefficient = str(numerator)
    return r"${0}\pi{1}$".format(coefficient, denominator)



def find_nearest(array, value):
    """Return the entry of *array* closest to *value*.

    Works for input of any dimension (the search runs over all entries) and
    for scalars. On ties the first entry in flattened order wins, consistent
    with ``find_nearestIdx``.
    """
    # argmin() yields an index into the flattened array, so index the
    # flattened array too; indexing the original would select a whole row
    # for 2-D input.
    flat = np.asarray(array).ravel()
    idx = np.abs(flat - value).argmin()
    return flat[idx]


def find_nearestIdx(array, value):
    """Return the index of the entry of *array* closest to *value*.

    For multi-dimensional input the index refers to the flattened array (use
    ``np.unravel_index`` to recover coordinates); on ties the first
    occurrence wins.
    """
    distance = np.abs(np.asarray(array) - value)
    return np.argmin(distance)



def energy(a1,a2,q1,q2,q12,q3,b1,b2):
    """Evaluate E(a) = 1/2 * a.H.a - 2 * a.A.b for a = (a1, a2), b = (b1, b2).

    H = [[2*q1, q12 + 2*q3], [q12 + 2*q3, 2*q2]] and
    A = [[q1, q12/2], [q12/2, q2]].

    a1 and a2 are expected to be scalars: they are stacked into a length-2
    vector before the matrix products.
    """
    h_offdiag = q12 + 2*q3
    a_offdiag = 1/2*q12
    H = np.array([[2*q1, h_offdiag], [h_offdiag, 2*q2]])
    A = np.array([[q1, a_offdiag], [a_offdiag, q2]])

    vec_a = np.array([a1, a2])
    vec_b = np.array([b1, b2])

    quadratic_part = (1/2) * vec_a.dot(H.dot(vec_a))
    linear_part = 2 * vec_a.dot(A.dot(vec_b))
    return quadratic_part - linear_part



# def energy(a1,a2,q1,q2,q12,q3,b1,b2):
#
#
#     b = np.array([b1,b2])
#     H = np.array([[2*q1, q12+2*q3], [q12+2*q3,2*q2] ])
#     A = np.array([[q1,1/2*q12], [1/2*q12,q2] ])
#
#
#     tmp = H.dot(a)
#
#     print('H',H)
#     print('A',A)
#     print('b',b)
#     print('a',a)
#     print('tmp',tmp)
#
#     tmp = (1/2)*a.dot(tmp)
#     print('tmp',tmp)
#
#     tmp2 = A.dot(b)
#     print('tmp2',tmp2)
#     tmp2 = 2*a.dot(tmp2)
#
#     print('tmp2',tmp2)
#     energy = tmp - tmp2
#     print('energy',energy)
#
#
#     # energy_axial1.append(energy_1)
#
#     return energy
#





################################################################################################################
################################################################################################################
################################################################################################################

InputFile  = "/inputs/computeMuGamma.parset"
OutputFile = "/outputs/outputMuGamma.txt"
# --------- Run from src folder: step up to the parent directory and resolve
# the input/output files relative to it.
path_parent = os.path.dirname(os.getcwd())
os.chdir(path_parent)
path = os.getcwd()
print(path)
# InputFile/OutputFile begin with "/", so they are appended to the directory
# by plain string concatenation.
InputFilePath = path + InputFile
OutputFilePath = path + OutputFile
print("InputFilepath: ", InputFilePath)
print("OutputFilepath: ", OutputFilePath)
print("Path: ", path)

# -------------------------- Input Parameters --------------------

mu1 = 1.0
rho1 = 1.0
alpha = 5.0
theta = 1.0/2
# theta = 0.1
beta = 5.0

# Alternative parameter sets used for other runs (mu1, rho1, alpha, theta, beta):
#   1.0, 1.0, -0.75, 1/2, 5.0
#   1.0, 1.0,  2.0,  1/2, 5.0
#   Figure3: 1.0, 1.0, 2.0, 1/8, 2.0
#   alpha = -5

# set gamma either to 1. '0' 2. 'infinity' or 3. a numerical positive value
gamma = '0'
# gamma = 'infinity'

lambda1 = 0.0

print('---- Input parameters: -----')
print('mu1: ', mu1)
print('rho1: ', rho1)
print('alpha: ', alpha)
print('beta: ', beta)
print('theta: ', theta)
print('gamma:', gamma)
print('lambda1: ', lambda1)
print('----------------------------')

# Quadratic-form coefficients (q1, q2, q12, q3) and prestrain (b1, b2) from the
# star-imported project helpers (ClassifyMin / HelperFunctions).
# NOTE(review): all six names are reassigned by the 1-parameter-family test case
# that follows before they are used. These values, including the GetMuGamma
# call that receives InputFilePath/OutputFilePath, therefore do not affect the
# plot. Confirm whether this computation is still needed.
q1 = (1.0/6.0)*harmonicMean(mu1, beta, theta)
q2 = (1.0/6.0)*arithmeticMean(mu1, beta, theta)
q12 = 0.0
q3 = GetMuGamma(beta, theta, gamma, mu1, rho1, InputFilePath, OutputFilePath)
b1 = prestrain_b1(rho1, beta, alpha, theta)
b2 = prestrain_b2(rho1, beta, alpha, theta)


########################################
# 1-ParameterFamilyCase:
# Hand-picked coefficients that override the values computed above. q3 is
# chosen so that (q12 + 2*q3)**2 == 4*q1*q2, which makes det(H) == 0.

q1 = 1
q2 = 2
q12 = 1/2
q3 = ((4*q1*q2)**0.5 - q12)/2

H = np.array([[2*q1, q12 + 2*q3], [q12 + 2*q3, 2*q2]])
A = np.array([[q1, 1/2*q12], [1/2*q12, q2]])
# abar: the second column of H, normalised to unit length.
abar = np.array([q12 + 2*q3, 2*q2])
abar = abar / np.linalg.norm(abar)

print('abar:', abar)

# Solve A b = abar. Here det(A) = q1*q2 - q12**2/4 > 0, so the least-squares
# solution is the exact one. rcond=None selects NumPy's current default cutoff
# and avoids the FutureWarning older NumPy versions emit when it is omitted.
b = np.linalg.lstsq(A, abar, rcond=None)[0]
print('b', b)

b1 = b[0]
b2 = b[1]

##----------------------
print('q1 = ', q1)
print('q2 = ', q2)
print('q3 = ', q3)
print('q12 = ', q12)
print('b1 = ', b1)
print('b2 = ', b2)
##--------------

# sstar: coefficient along abar, (abar . A b) / (q1 + q2).
sstar = (1/(q1+q2))*abar.dot(A.dot(b))
print('sstar', sstar)
# abarperp: abar rotated by -90 degrees (unit vector orthogonal to abar).
abarperp = np.array([abar[1], -abar[0]])
print('abarperp:', abarperp)

print('sstar*abar:', sstar*abar)
print('----------------------------')


# ---- Points of the 1-parameter family  sstar*abar + t*abarperp -------------
N = 10  # number of samples (1000 gives a smooth curve)
scale_domain = 5
translate_startpoint = -1.8
# Wider parameter range, used for the dashed reference line in the plot.
T_line = np.linspace(-sstar*(q12+2*q3)/(2*q2)*scale_domain + translate_startpoint, sstar*(2*q2)/(q12+2*q3)*scale_domain , num=N)
line_values = []
for t in T_line:
    point = sstar*abar + t*abarperp
    print('sstar*abar+t*abarperp', point)
    line_values.append(point)

# With abar = (q12+2*q3, 2*q2)/|.| and abarperp = (abar[1], -abar[0]), the
# first component of sstar*abar + t*abarperp vanishes at T.min() and the
# second at T.max().
T = np.linspace(-sstar*(q12+2*q3)/(2*q2), sstar*(2*q2)/(q12+2*q3), num=N)

print('T.min():', T.min())
print('T.max():', T.max())

kappas = []
alphas = []
G_container = []
abar_container = []

test = sstar*abar

for t in T:
    abar_current = sstar*abar + t*abarperp
    # Zero out entries below 1e-10 (round-off at the segment ends and any
    # negative entries): projection onto the closed first quadrant.
    abar_current[abar_current < 1e-10] = 0
    print('abar_current', abar_current)
    G = [abar_current[0], abar_current[1], (2*abar_current[0]*abar_current[1])**0.5]
    G_container.append(G)
    abar_container.append(abar_current)
    kappa = abar_current[0] + abar_current[1]
    e = [(abar_current[0]/kappa)**0.5, (abar_current[1]/kappa)**0.5]
    # Named `angle` (not `alpha`) so the material parameter alpha is not clobbered.
    angle = math.atan2(e[1], e[0])
    kappas.append(kappa)
    alphas.append(angle)

alphas = np.array(alphas)
kappas = np.array(kappas)

G = np.array(G_container)
# NOTE: abar is rebound here from the unit direction vector to the (N, 2)
# array of sampled points; the plotting code below relies on this.
abar = np.array(abar_container)

print('G', G)
print('abar', abar)
print('abar.shape', abar.shape)










######################################
# ---- Energy on a regular grid ----------------------------------------------
# Other extents tried: [-5, 5], [-10, 10], [-20, 20], [-40, 40], [-60, 60];
# other resolutions: num_Points = 20 or 400.
num_Points = 200
x = np.linspace(-2, 2, num_Points)
y = np.linspace(-2, 2, num_Points)
a1, a2 = np.meshgrid(x, y)

print('a1:', a1)
print('a2:', a2)
print('a1.shape', a1.shape)

# `energy` stacks a1, a2 into np.array([a1, a2]) and multiplies by 2x2
# matrices, which is only meaningful for scalar a1, a2; hence the element-wise
# np.vectorize over the grid.
energyVec = np.vectorize(energy)
Z = energyVec(a1, a2, q1, q2, q12, q3, b1, b2)

print('Z:', Z)
print('any', np.any(Z < 0))

#


# negZ_a1 = a1[np.where(Z<0)]
# negZ_a2 = a2[np.where(Z<0)]
# negativeValues = Z[np.where(Z<0)]
# print('negativeValues:',negativeValues)
#
# print('negZ_a1',negZ_a1)
# print('negZ_a2',negZ_a2)
#
#
# negZ_a1 = negZ_a1.reshape(-1,5)
# negZ_a2 = negZ_a2.reshape(-1,5)
# negativeValues = negativeValues.reshape(-1,5)
#
# Z_pos =  energyVec(tmp1_pos,tmp2_pos,q1,q2,q12,q3,b1,b2)
# Z_neg = energyVec(tmp1_neg,tmp2_neg,q1,q2,q12,q3,b1,b2)





# print('Test energy:' , energy(np.array([1,1]),q1,q2,q12,q3,b1,b2))




# print('Z_pos.shape', Z_pos.shape)







## -- PLOT :
# Matplotlib 3.6 renamed the bundled seaborn style to "seaborn-v0_8" and 3.8
# removed the old name; try the new name first and fall back for older versions.
try:
    plt.style.use("seaborn-v0_8")
except OSError:
    plt.style.use("seaborn")

mpl.rcParams['text.usetex'] = True
mpl.rcParams["font.family"] = "serif"
mpl.rcParams["font.size"] = "10"
mpl.rcParams['xtick.bottom'] = True
mpl.rcParams['xtick.major.size'] = 3
mpl.rcParams['xtick.minor.size'] = 1.5
mpl.rcParams['xtick.major.width'] = 0.75
mpl.rcParams['ytick.left'] = True
mpl.rcParams['ytick.major.size'] = 3
mpl.rcParams['ytick.minor.size'] = 1.5
mpl.rcParams['ytick.major.width'] = 0.75

mpl.rcParams.update({'font.size': 10})
mpl.rcParams['axes.labelpad'] = 3.0

# Figure size in inches (square); applied just before saving.
width = 6.28 *0.5
height = width

f_size = 10
fig, ax = plt.subplots()
ax.grid(True, which='major', axis='both', alpha=0.3)

print('Z.min()', Z.min())
print('Z.max()', Z.max())

# Optional colour background of the energy landscape (currently not drawn):
# divnorm = mcolors.TwoSlopeNorm(vmin=Z.min(), vcenter=(Z.max()+Z.min())/2, vmax=Z.max())
# I = ax.imshow(Z, extent=[x.min(), x.max(), y.min(), y.max()], origin='lower',
#               norm=divnorm, cmap=mpl.cm.magma_r, alpha=0.6)
# fig.colorbar(I)

##----- Shade the quadrants where a1 and a2 have the same sign (labelled G^+).
# epsilon leaves a thin unshaded gap along the coordinate axes.
epsilon = 0.001
fillcolor = 'royalblue'
ax.fill_between([x.min(), 0-epsilon], y.min(), 0-epsilon, alpha=.25, color=fillcolor, zorder=4)
ax.fill_between([0+epsilon, x.max()], 0+epsilon, y.max(), alpha=.25, color=fillcolor, zorder=4)

ax.text(1, 1, r"$\mathcal{G}^+_{\mathbf{R}^2}$", color='royalblue', size=15)

# PLOT 1-PARAMETER FAMILY (abar holds the sampled points, one row per point)
print('abar:', abar)
print('abar[0,:]:', abar[0,:])
print('abar[1,:]:', abar[1,:])

# Unpack the single Line2D so it can be passed to add_arrow(line, ...).
line, = ax.plot(abar[:,0], abar[:,1], linewidth=1.5, color='darkorange', linestyle='-', zorder=4)
# add_arrow(line, color='darkorange')

# Dashed line through the family's points over the wider parameter range.
line_values = np.array(line_values)
ax.plot(line_values[:,0], line_values[:,1], linestyle='--', linewidth=1, color='orange', alpha=0.5)

### Coordinate axes: axhline/axvline span the whole view, unlike segments drawn
# from ax.get_xlim()/get_ylim() before the final limits are set.
ax.axhline(0, color='k', linestyle='--', linewidth=0.5)
ax.axvline(0, color='k', linestyle='--', linewidth=0.5)

ax.set_xlabel(r"$a_1$", fontsize=f_size, labelpad=0)
ax.set_ylabel(r"$a_2$", fontsize=f_size, labelpad=0, rotation=0)
ax.tick_params(axis='both', which='major', labelsize=f_size)
ax.tick_params(axis='both', which='minor', labelsize=f_size)

print('x.max()', x.max())
print('y.max()', y.max())
ax.set_xlim(x.min(), x.max())
ax.set_ylim(y.min(), y.max())

fig.subplots_adjust(left=0.2, bottom=0.15)

fig.set_size_inches(width, height)
fig.savefig('Energy_ContourG+_Flat.pdf')

plt.show()


#
#
#
# # Curve parametrised by \theta_rho = alpha in parameter space
# N=100;
# theta_rho = np.linspace(1, 3, num=N)
# print('theta_rho:', theta_rho)
#
#
# theta_values = []
#
#
# for t in theta_rho:
#
#         s = (1.0/10.0)*t+0.1
#         theta_values.append(s)
#
#
#
#
#
# theta_rho = np.array(theta_rho)
# theta_values = np.array(theta_values)
#
# betas_ = 2.0
#

# alphas, betas, thetas = np.meshgrid(theta_rho, betas_, theta_values, indexing='ij')
#
#
# harmonicMeanVec = np.vectorize(harmonicMean)
# arithmeticMeanVec = np.vectorize(arithmeticMean)
# prestrain_b1Vec = np.vectorize(prestrain_b1)
# prestrain_b2Vec = np.vectorize(prestrain_b2)
#
# GetMuGammaVec = np.vectorize(GetMuGamma)
# muGammas = GetMuGammaVec(betas,thetas,gamma,mu1,rho1,InputFilePath ,OutputFilePath )
#
# q1_vec = harmonicMeanVec(mu1, betas, thetas)
# q2_vec = arithmeticMeanVec(mu1, betas, thetas)
#
# b1_vec = prestrain_b1Vec(rho1, betas, alphas, thetas)
# b2_vec = prestrain_b2Vec(rho1, betas, alphas, thetas)

# special case: q12 == 0!! .. braucht eigentlich nur b1 & b2 ...

# print('type b1_values:', type(b1_values))



# print('size(q1)',q1.shape)
#
#
# energy_axial1 = []
# energy_axial2 = []
#
# # for b1 in b1_values:
# for i in range(len(theta_rho)):
#     print('index i:', i)
#
#     print('theta_rho[i]',theta_rho[i])
#     print('theta_values[i]',theta_values[i])
#
#     q1 = (1.0/6.0)*harmonicMean(mu1, beta, theta_values[i])
#     q2 = (1.0/6.0)*arithmeticMean(mu1, beta, theta_values[i])
#     q12 = 0.0
#     q3 = GetMuGamma(beta, theta_values[i],gamma,mu1,rho1,InputFilePath ,OutputFilePath )
#     b1 = prestrain_b1(rho1,beta, theta_rho[i],theta_values[i] )
#     b2 = prestrain_b2(rho1,beta, theta_rho[i],theta_values[i] )
#
#
#     # q2_vec = arithmeticMean(mu1, betas, thetas)
#     #
#     # b1_vec = prestrain_b1Vec(rho1, betas, alphas, thetas)
#     # b2_vec = prestrain_b2Vec(rho1, betas, alphas, thetas)
#     print('q1[i]',q1)
#     print('q2[i]',q2)
#     print('q3[i]',q3)
#     print('b1[i]',b1)
#     print('b2[i]',b2)
#     # print('q1[i]',q1[0][i])
#     # print('q2[i]',q2[i])
#     # print('b1[i]',b1[i])
#     # print('b2[i]',b2[i])
#     #compute axial energy #1 ...
#
#     a_axial1 = np.array([b1,0])
#     a_axial2 = np.array([0,b2])
#     b = np.array([b1,b2])
#
#     H = np.array([[2*q1, q12+2*q3], [q12+2*q3,2*q2] ])
#     A = np.array([[q1,1/2*q12], [1/2*q12,q2] ])
#
#
#     tmp = H.dot(a_axial1)
#
#     print('H',H)
#     print('A',A)
#     print('b',b)
#     print('a_axial1',a_axial1)
#     print('tmp',tmp)
#
#     tmp = (1/2)*a_axial1.dot(tmp)
#     print('tmp',tmp)
#
#     tmp2 = A.dot(b)
#     print('tmp2',tmp2)
#     tmp2 = 2*a_axial1.dot(tmp2)
#
#     print('tmp2',tmp2)
#     energy_1 = tmp - tmp2
#     print('energy_1',energy_1)
#
#
#     energy_axial1.append(energy_1)
#
#
#     tmp = H.dot(a_axial2)
#
#     print('H',H)
#     print('A',A)
#     print('b',b)
#     print('a_axial2',a_axial2)
#     print('tmp',tmp)
#
#     tmp = (1/2)*a_axial2.dot(tmp)
#     print('tmp',tmp)
#
#     tmp2 = A.dot(b)
#     print('tmp2',tmp2)
#     tmp2 = 2*a_axial2.dot(tmp2)
#
#     print('tmp2',tmp2)
#     energy_2 = tmp - tmp2
#     print('energy_2',energy_2)
#
#
#     energy_axial2.append(energy_2)
#
#
#
#
#
# print('theta_values', theta_values)
#
#
#


#
#
#
#
# kappas = []
# alphas = []
# # G.append(float(s[0]))
#
#
#
#
# for t in T :
#
#     abar_current = sstar*abar+t*abarperp;
#     # print('abar_current', abar_current)
#     abar_current[abar_current < 1e-10] = 0
#     # print('abar_current', abar_current)
#
#     # G = np.array([[2*q1, q12+2*q3], [q12+2*q3,2*q2] ])
#     G = [abar_current[0], abar_current[1] , (2*abar_current[0]*abar_current[1])**0.5 ]
#
#     e = [(abar_current[0]/(abar_current[0]+abar_current[1]))**0.5, (abar_current[1]/(abar_current[0]+abar_current[1]))**0.5]
#     kappa = abar_current[0]+abar_current[1]
#     alpha = math.atan2(e[1], e[0])
#
#     print('angle current:', alpha)
#
#     kappas.append(kappa)
#     alphas.append(alpha)
#
#
#
# alphas = np.array(alphas)
# kappas = np.array(kappas)
#
#
# print('kappas:',kappas)
# print('alphas:',alphas)
# print('min alpha:', min(alphas))
# print('min kappa:', min(kappas))
#
# mpl.rcParams['text.usetex'] = True
# mpl.rcParams["font.family"] = "serif"
# mpl.rcParams["font.size"] = "9"
# width = 6.28 *0.5
# height = width / 1.618
# fig = plt.figure()
# # ax = plt.axes((0.15,0.21 ,0.75,0.75))
# ax = plt.axes((0.15,0.21 ,0.8,0.75))
# ax.tick_params(axis='x',which='major', direction='out',pad=5)
# ax.tick_params(axis='y',which='major', length=3, width=1, direction='out',pad=3)
# # ax.xaxis.set_major_locator(MultipleLocator(0.1))
# # ax.xaxis.set_minor_locator(MultipleLocator(0.05))
# # ax.xaxis.set_major_locator(plt.MultipleLocator(np.pi / 8))
# # ax.xaxis.set_minor_locator(plt.MultipleLocator(np.pi / 16))
# ax.xaxis.set_major_locator(plt.MultipleLocator(np.pi / 2))
# ax.xaxis.set_minor_locator(plt.MultipleLocator(np.pi / 4))
# ax.xaxis.set_major_formatter(plt.FuncFormatter(format_func))
# ax.grid(True,which='major',axis='both',alpha=0.3)
#
#
#
#
# ax.plot(alphas, kappas, 'royalblue', zorder=3, )
# ax.plot(-1.0*alphas, kappas, 'red', zorder=3, )