import numpy as np
import matplotlib.pyplot as plt
import sympy as sym
import math
import os
import subprocess
import fileinput
import re
import matlab.engine
import sys
from ClassifyMin import *
from HelperFunctions import *
# from CellScript import *
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.cm as cm
from vtk.util import numpy_support
from pyevtk.hl import gridToVTK
import time
import matplotlib.ticker as ticker

import matplotlib as mpl
from matplotlib.ticker import MultipleLocator,FormatStrFormatter,MaxNLocator
import pandas as pd

import seaborn as sns
import matplotlib.colors as mcolors


from mpl_toolkits.mplot3d.proj3d import proj_transform
from mpl_toolkits.mplot3d import proj3d
# from mpl_toolkits.mplot3d.axes3d import Axes3D
from matplotlib.text import Annotation
from matplotlib.patches import FancyArrowPatch


import mayavi.mlab as mlab
from mayavi.api import OffScreenEngine
mlab.options.offscreen = True

from chart_studio import plotly
import plotly.graph_objs as go
# from matplotlib import rc
# rc('text', usetex=True) # Use LaTeX font
#
# import seaborn as sns
# sns.set(color_codes=True)

# set the colormap and centre the colorbar
class MidpointNormalize(mcolors.Normalize):
    """
    Normalise the colorbar so that diverging bars work their way either side
    of a prescribed midpoint value.

    Values are mapped piecewise-linearly: vmin -> 0, midpoint -> 0.5,
    vmax -> 1. Values outside [vmin, vmax] are clamped to 0 / 1 by
    ``np.interp``; NaN inputs are masked in the result.

    If vmin / vmax are not given they are inferred from the data on the first
    call (like ``Normalize``); if midpoint is not given it defaults to the
    centre of [vmin, vmax].

    e.g. im=ax1.imshow(array, norm=MidpointNormalize(midpoint=0.,vmin=-100, vmax=100))
    """
    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        self.midpoint = midpoint
        mcolors.Normalize.__init__(self, vmin, vmax, clip)

    def __call__(self, value, clip=None):
        # ``clip`` is accepted for API compatibility but unused: np.interp
        # already clamps to [0, 1]. Masked inputs get no special treatment.
        if self.vmin is None or self.vmax is None:
            # np.interp cannot handle None breakpoints; infer the missing
            # limits from the (NaN-free) data, as Normalize.__call__ does.
            self.autoscale_None(np.ma.masked_invalid(value))
        midpoint = self.midpoint
        if midpoint is None:
            midpoint = 0.5 * (self.vmin + self.vmax)
        x, y = [self.vmin, midpoint, self.vmax], [0, 0.5, 1]
        return np.ma.masked_array(np.interp(value, x, y), np.isnan(value))



def format_func(value, tick_number):
    """Format a tick location as a LaTeX multiple of pi/4.

    Intended for ``plt.FuncFormatter``. The value is rounded to the nearest
    multiple N of pi/4 and the fraction N/4 is reduced, giving labels such as
    ``0``, ``pi/4``, ``pi/2``, ``3pi/4``, ``pi``, ``3pi/2``, ``2pi``
    (and their negatives).

    value:        tick location
    tick_number:  tick index (unused; required by FuncFormatter)
    returns:      label string in math-text form, e.g. r"$3\\pi/4$"
    """
    # number of multiples of pi/4
    N = int(np.round(4 * value / np.pi))
    if N == 0:
        return "0"
    # Reduce N/4: whole multiples of pi, then of pi/2, otherwise of pi/4.
    if N % 4 == 0:
        coeff, unit = N // 4, r"\pi"
    elif N % 2 == 0:
        coeff, unit = N // 2, r"\pi/2"
    else:
        coeff, unit = N, r"\pi/4"
    if coeff == 1:
        return "$" + unit + "$"
    if coeff == -1:
        return "$-" + unit + "$"
    return "${0}{1}$".format(coeff, unit)



def find_nearest(array, value):
    """Return the element of ``array`` closest to ``value``.

    Works for array-likes of any dimension; ties resolve to the first
    element in C (row-major) order.
    """
    array = np.asarray(array)
    idx = (np.abs(array - value)).argmin()
    # argmin() without an axis returns an index into the flattened array, so
    # look it up via .flat (plain array[idx] selects a row for n-d input).
    return array.flat[idx]


def find_nearestIdx(array, value):
    """Return the index of the element of ``array`` closest to ``value``.

    The index refers to the flattened array (``argmin`` without an axis);
    ties resolve to the first occurrence.
    """
    distances = np.abs(np.asarray(array) - value)
    return distances.argmin()



def energy(a1,a2,q1,q2,q12,q3,b1,b2):
    """Evaluate the quadratic energy at a = (a1, a2):

        E(a) = 1/2 * a^T H a  -  2 * a^T A b,   b = (b1, b2),

    with
        H = [[2*q1,      q12+2*q3],
             [q12+2*q3,  2*q2    ]],
        A = [[q1,        1/2*q12 ],
             [1/2*q12,   q2      ]].
    """
    a = np.array([a1, a2])
    b = np.array([b1, b2])
    off_diagonal = q12 + 2*q3
    H = np.array([[2*q1, off_diagonal], [off_diagonal, 2*q2]])
    A = np.array([[q1, 1/2*q12], [1/2*q12, q2]])

    quadratic_part = (1/2)*a.dot(H.dot(a))
    linear_part = 2*a.dot(A.dot(b))
    return quadratic_part - linear_part



def evaluate(x,y):
    """Return sqrt(x*y) (elementwise for arrays), i.e. the height of the
    surface z = sqrt(a1*a2) at (a1, a2) = (x, y)."""
    product = x * y
    return np.sqrt(product)


# def energy(a1,a2,q1,q2,q12,q3,b1,b2):
#
#
#     b = np.array([b1,b2])
#     H = np.array([[2*q1, q12+2*q3], [q12+2*q3,2*q2] ])
#     A = np.array([[q1,1/2*q12], [1/2*q12,q2] ])
#
#
#     tmp = H.dot(a)
#
#     print('H',H)
#     print('A',A)
#     print('b',b)
#     print('a',a)
#     print('tmp',tmp)
#
#     tmp = (1/2)*a.dot(tmp)
#     print('tmp',tmp)
#
#     tmp2 = A.dot(b)
#     print('tmp2',tmp2)
#     tmp2 = 2*a.dot(tmp2)
#
#     print('tmp2',tmp2)
#     energy = tmp - tmp2
#     print('energy',energy)
#
#
#     # energy_axial1.append(energy_1)
#
#     return energy
#

def add_arrow(line, position=None, direction='right', size=15, color=None):
    """
    add an arrow to a line, pointing along the line between two neighbouring
    data points.

    line:       Line2D object
    position:   x-position of the arrow. If None, mean of xdata is taken.
                The arrow is anchored at the data point whose x is closest.
    direction:  'right' points towards the next data point; any other value
                (e.g. 'left') points towards the previous one.
    size:       size of the arrow in fontsize points
    color:      if None, line color is taken.

    Raises ValueError if the line has fewer than two points.
    """
    if color is None:
        color = line.get_color()

    # get_xdata()/get_ydata() may return the original sequence passed to
    # plot() (e.g. a list), so convert before using .mean() / arithmetic.
    xdata = np.asarray(line.get_xdata())
    ydata = np.asarray(line.get_ydata())
    n = len(xdata)
    if n < 2:
        raise ValueError("add_arrow needs a line with at least two points")

    if position is None:
        position = xdata.mean()
    # find closest index
    start_ind = int(np.argmin(np.absolute(xdata - position)))
    # Keep both indices inside [0, n-1]: at the last point a 'right' arrow
    # would index past the end, and at the first point a 'left' arrow would
    # silently wrap around to index -1 (the last point).
    if direction == 'right':
        start_ind = min(start_ind, n - 2)
        end_ind = start_ind + 1
    else:
        start_ind = max(start_ind, 1)
        end_ind = start_ind - 1

    line.axes.annotate('',
        xytext=(xdata[start_ind], ydata[start_ind]),
        xy=(xdata[end_ind], ydata[end_ind]),
        arrowprops=dict(arrowstyle="->", color=color),
        size=size
    )



class Annotation3D(Annotation):
    """Text annotation anchored at a 3-D data point ``xyz`` of an Axes3D.

    The 2-D anchor ``xy`` is recomputed on every draw by projecting ``xyz``
    with the axes' current projection matrix, so the label follows the point
    when the view changes.
    """

    def __init__(self, text, xyz, *args, **kwargs):
        # Placeholder 2-D anchor; the real one is computed in draw().
        super().__init__(text, xy=(0, 0), *args, **kwargs)
        self._xyz = xyz

    def draw(self, renderer):
        projected = proj_transform(*self._xyz, self.axes.M)
        self.xy = projected[:2]
        super().draw(renderer)

def _annotate3D(ax, text, xyz, *args, **kwargs):
    '''Attach an `Annotation3D` showing `text` at the 3-D point `xyz` to `ax`.

    Extra arguments are forwarded to `Annotation3D` / `Annotation`.
    '''
    ax.add_artist(Annotation3D(text, xyz, *args, **kwargs))

# Expose the helper as a method: ax.annotate3D('label', (x, y, z), ...)
setattr(Axes3D, 'annotate3D', _annotate3D)



class Arrow3D(FancyArrowPatch):
    """A ``FancyArrowPatch`` whose two end points live in 3-D data space.

    ``xs``, ``ys`` and ``zs`` each hold the (start, end) coordinate along one
    axis, e.g. ``Arrow3D([x0, x1], [y0, y1], [z0, z1], arrowstyle='-|>')``.
    Remaining arguments are forwarded to ``FancyArrowPatch``.
    """

    def __init__(self, xs, ys, zs, *args, **kwargs):
        # Placeholder 2-D positions; the real ones are set on projection.
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def _update_2d_positions(self):
        """Project the 3-D end points with the owning axes' view matrix,
        update the 2-D patch positions and return the projected depths."""
        xs3d, ys3d, zs3d = self._verts3d
        # Use self.axes.M: renderer.M is no longer set in matplotlib >= 3.5.
        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
        return zs

    def do_3d_projection(self, renderer=None):
        """Called by Axes3D.draw (matplotlib >= 3.5) to depth-sort patches;
        returns the smallest projected depth of the arrow."""
        return np.min(self._update_2d_positions())

    def draw(self, renderer):
        self._update_2d_positions()
        super().draw(renderer)


def _arrow3D(ax, x, y, z, dx, dy, dz, *args, **kwargs):
    '''Add a 3d arrow from (x, y, z) to (x+dx, y+dy, z+dz) to an `Axes3D` instance.

    Extra arguments (e.g. mutation_scale, arrowstyle, ec) are forwarded to
    `Arrow3D` / `FancyArrowPatch`. Returns the created artist.
    '''
    # Arrow3D expects per-axis (start, end) pairs; passing dx, dy, dz
    # positionally would feed them into FancyArrowPatch's own positional
    # parameters instead.
    arrow = Arrow3D([x, x + dx], [y, y + dy], [z, z + dz], *args, **kwargs)
    ax.add_artist(arrow)
    return arrow

setattr(Axes3D, 'arrow3D', _arrow3D)

################################################################################################################
################################################################################################################
################################################################################################################

# Parameter-set input file and result file, relative to the directory the
# script switches into below.
InputFile  = "/inputs/computeMuGamma.parset"
OutputFile = "/outputs/outputMuGamma.txt"

# --------- Run  from src folder: move up to the parent directory first.
path_parent = os.path.dirname(os.getcwd())
os.chdir(path_parent)
path = os.getcwd()
print(path)

# Both file names start with '/', so they are appended by plain string
# concatenation (os.path.join would discard the directory prefix).
InputFilePath = path + InputFile
OutputFilePath = path + OutputFile
print("InputFilepath: ", InputFilePath)
print("OutputFilepath: ", OutputFilePath)
print("Path: ", path)

print('---- Input parameters: -----')

# -------------------------- Input Parameters --------------------
# Material/geometry parameters fed to the HelperFunctions routines
# (harmonicMean, arithmeticMean, GetMuGamma, prestrain_b1/b2) below.
mu1 = 1.0
rho1 = 1.0
alpha = 5.0
theta = 1.0/2      # also tried: theta = 0.1
beta = 5.0

# Other parameter sets used previously:
#   mu1 = 1.0, rho1 = 1.0, alpha = 2.0, theta = 1.0/2, beta = 5.0
#   Figure3: mu1 = 1.0, rho1 = 1.0, alpha = 2.0, theta = 1.0/8, beta = 2.0
#   alpha = -5

# gamma: the string '0', the string 'infinity', or a positive number
# (previously '0').
gamma = 'infinity'

lambda1 = 0.0

print('---- Input parameters: -----')
print('mu1: ', mu1)
print('rho1: ', rho1)
# print('alpha: ', alpha)
print('beta: ', beta)
# print('theta: ', theta)
print('gamma:', gamma)

print('lambda1: ', lambda1)
print('----------------------------')
# ----------------------------------------------------------------
print('----------------------------')

# ----------------------------------------------------------------






# Moduli and prestrain derived from the material parameters (HelperFunctions).
# NOTE(review): every value assigned here is overwritten by the
# 1-parameter-family test case below; GetMuGamma is still invoked (it is
# given the input/output file paths) - drop these lines if that call is not
# needed for its side effects.
q1 = (1.0/6.0)*harmonicMean(mu1, beta, theta)
q2 = (1.0/6.0)*arithmeticMean(mu1, beta, theta)
q12 = 0.0
q3 = GetMuGamma(beta, theta,gamma,mu1,rho1,InputFilePath ,OutputFilePath )
b1 = prestrain_b1(rho1,beta, alpha, theta )
b2 = prestrain_b2(rho1,beta, alpha, theta )


# 1-ParameterFamilyCase:
# q3 is chosen so that q12 + 2*q3 = sqrt(4*q1*q2), which makes
# det(H) = 4*q1*q2 - (q12 + 2*q3)**2 vanish.
q1 = 1
q2 = 2
q12 = 1/2
q3 = ((4*q1*q2)**0.5 - q12)/2

H = np.array([[2*q1, q12+2*q3], [q12+2*q3, 2*q2]])
A = np.array([[q1, 1/2*q12], [1/2*q12, q2]])
# Unit vector along (q12+2*q3, 2*q2), the second column of H.
abar = np.array([q12+2*q3, 2*q2])
abar = (1.0/math.sqrt((q12+2*q3)**2+(2*q2)**2))*abar

print('abar:',abar)

# Solve A b = abar. rcond=None selects the machine-precision cutoff and
# silences the FutureWarning older NumPy versions emit when it is omitted.
b = np.linalg.lstsq(A, abar, rcond=None)[0]
print('b',b)

b1 = b[0]
b2 = b[1]

# Scale applied to abar; sstar*abar is the t = 0 point of the curve sampled below.
sstar = (1/(q1+q2))*abar.dot(A.dot(b))
print('sstar', sstar)
# abar rotated by -90 degrees: the direction along which the family varies.
abarperp = np.array([abar[1], -abar[0]])
print('abarperp:',abarperp)

print('----------------------------')

# ----------------------------------------------------------------

# ---- Sample the 1-parameter family  abar(t) = sstar*abar + t*abarperp ----
N = 1000
# Domain of t (previously: scale_domain = 5, translate_startpoint = -5).
scale_domain = 5
translate_startpoint = -1.8

T = np.linspace(-sstar*(q12+2*q3)/(2*q2)*scale_domain + translate_startpoint,
                sstar*(2*q2)/(q12+2*q3)*scale_domain, num=N)

print('T.min():', T.min())
print('T.max():', T.max())

kappas = []
alphas = []
G_container = []
abar_container = []

# Keep the unit direction: `abar` is rebound to the sampled curve below.
abar_tmp = abar

for t in T:
    abar_current = sstar*abar + t*abarperp
    # NOTE(review): this zeroes every component below 1e-10, i.e. all
    # negative values as well as round-off noise.
    abar_current[abar_current < 1e-10] = 0
    G = [abar_current[0], abar_current[1], (2*abar_current[0]*abar_current[1])**0.5]
    G_container.append(G)
    abar_container.append(abar_current)
    e = [(abar_current[0]/(abar_current[0]+abar_current[1]))**0.5,
         (abar_current[1]/(abar_current[0]+abar_current[1]))**0.5]
    kappa = abar_current[0] + abar_current[1]
    # Separate name so the loop no longer overwrites the material
    # parameter `alpha` from the input section.
    angle = math.atan2(e[1], e[0])
    kappas.append(kappa)
    alphas.append(angle)


alphas = np.array(alphas)
kappas = np.array(kappas)

G = np.array(G_container)        # shape (N, 3)
abar = np.array(abar_container)  # shape (N, 2): one row per sample t

print('G', G)
print('abar', abar)
print('abar.shape',abar.shape)





print('q1 = ', q1)
print('q2 = ', q2)
print('q3 = ', q3)
print('q12 = ', q12)
print('b1 = ', b1)
print('b2 = ', b2)

# Grid resolution (previously also 20 and 50).
num_Points = 200

# Creating dataset (extent previously also +-5, +-10, +-40, +-60)
x = np.linspace(-20,20,num_Points)
y = np.linspace(-20,20,num_Points)

# Half-width of the two quadrant patches of z = sqrt(a1*a2) that are plotted.
# Named to avoid shadowing the builtin `range`.
domain_range = 2

x_1 = np.linspace(0,domain_range,num_Points)
y_1 = np.linspace(0,domain_range,num_Points)
x_2 = np.linspace(-domain_range,0,num_Points)
y_2 = np.linspace(-domain_range,0,num_Points)

X_1,Y_1 = np.meshgrid(x_1,y_1)
X_2,Y_2 = np.meshgrid(x_2,y_2)

a1, a2 = np.meshgrid(x,y)

ContourRange=20

x_in = np.linspace(-ContourRange,ContourRange,num_Points)
y_in = np.linspace(-ContourRange,ContourRange,num_Points)
a1_in, a2_in = np.meshgrid(x_in,y_in)

#-- FILTER OUT VALUES for G+ : keep grid points with a1*a2 >= 0
tmp1 = a1[a1*a2 >= 0]
tmp2 = a2[a1*a2 >= 0]

print('tmp1.shape',tmp1.shape)
print('tmp1.shape[0]',tmp1.shape[0])
print('tmp2.shape',tmp2.shape)
print('tmp2.shape[0]',tmp2.shape[0])

# Arrange the kept points in two rows.
# NOTE(review): assumes an even number of kept points (true for this
# symmetric grid with an even num_Points).
tmp1 = tmp1.reshape(-1,int(tmp1.shape[0]/2))
tmp2 = tmp2.reshape(-1,int(tmp2.shape[0]/2))

print('tmp1.shape',tmp1.shape)
print('tmp1.shape[0]',tmp1.shape[0])
print('tmp2.shape',tmp2.shape)
print('tmp2.shape[0]',tmp2.shape[0])

energyVec = np.vectorize(energy)

Z = np.sqrt(tmp1*tmp2)

# Heights of the surface z = sqrt(a1*a2) on the first and third quadrant patches.
Z1 = np.sqrt(X_1*Y_1)
Z2 = np.sqrt(X_2*Y_2)

# Transpose to shape (2, N): row 0 holds the a1 values, row 1 the a2 values
# of the sampled curve. (The former Z_bar expression before this transpose
# indexed the wrong axis and was overwritten below, so it was removed.)
abar = abar.T

v1 = abar[0,:]
v2 = abar[1,:]

# Height of the curve on the surface: sqrt(a1*a2) for every sample.
evaluateVec = np.vectorize(evaluate)
Z_bar = evaluateVec(abar[0,:],abar[1,:])

print('v1.shape', v1.shape)
print('v1', v1)

print('Z:', Z)
print('Z_bar:', Z_bar)



## -- PLOT :
# Apply the style sheet BEFORE touching rcParams: plt.style.use() overwrites
# every rcParam the sheet defines (the seaborn sheets set font.family), which
# previously discarded the serif font chosen here.
# 'seaborn-whitegrid' was renamed 'seaborn-v0_8-whitegrid' in matplotlib 3.6
# and the old name was removed in 3.8.
if 'seaborn-v0_8-whitegrid' in plt.style.available:
    plt.style.use('seaborn-v0_8-whitegrid')
else:
    plt.style.use('seaborn-whitegrid')

mpl.rcParams['text.usetex'] = True
mpl.rcParams["font.family"] = "serif"
mpl.rcParams["font.size"] = "9"

label_size = 9
mpl.rcParams['xtick.labelsize'] = label_size
mpl.rcParams['ytick.labelsize'] = label_size

width = 6.28 *0.5
height = width   # square figure (width / 1.618 was used before)
fig = plt.figure()

ax = plt.axes(projection ='3d', adjustable='box')
ax.grid(True,which='major',axis='xy',alpha=0.3)

# Single-colour colormap: both surface patches are drawn in royalblue.
cmap = mpl.colors.ListedColormap(["royalblue"], name='from_list', N=None)

print('abar', abar)
print('abar.shape',abar.shape)

# Whole curve, drawn underneath the surface.
line = ax.plot(abar[0,:],abar[1,:],Z_bar, linewidth=2, color='coral', zorder=1)

# Redraw the curve on top (zorder=5) starting at the first sample whose
# height rounds to the target value.
Z_bar_target = 0.02823972
idx = np.where(np.round(Z_bar,3) == np.round(Z_bar_target,3))
if idx[0].size:
    start_idx = idx[0][0]
else:
    # No sample rounds to the target: use the closest one instead of
    # failing with an IndexError.
    start_idx = find_nearestIdx(Z_bar, Z_bar_target)
print('idx[0][0]', start_idx)
# Slice through the end of the curve (the former `:-1` dropped the last sample).
line = ax.plot(abar[0,start_idx:],abar[1,start_idx:],Z_bar[start_idx:], linewidth=2, color='coral', zorder=5)

# Surface z = sqrt(a1*a2) on the first and third quadrant patches.
ax.plot_surface(X_1,Y_1,Z1 ,cmap=cmap,
                       linewidth=0, antialiased=True,alpha=.35, zorder=5)
ax.plot_surface(X_2,Y_2,Z2 ,cmap=cmap,
                       linewidth=0, antialiased=True,alpha=.35, zorder=5)

# End points of an optional arrow along the a1-axis
# (draw with e.g. ax.arrow3D(*start, *arrow_dir, mutation_scale=20, arrowstyle="->")).
start = np.array([1,0,0])
end = np.array([2.5,0,0])

print('start:', start)
print('end:', end)

arrow_dir = end-start   # named to avoid shadowing the builtin `dir`
print('dir:', arrow_dir)

## ADD Annotation
ax.text(0, 0, 1.5, r"$\mathcal G^+$", color='royalblue', size=12)

### ---- Check proof:
# value at t = 0:
abar_zero = sstar*abar_tmp
value_0 = evaluate(abar_zero[0], abar_zero[1])
print('value_0', value_0)

### PLOT X AND Y AXIS:
ax.plot(ax.get_xlim(),[0,0],'k--', linewidth=1 ,zorder=5)
ax.plot([0,0],ax.get_ylim(), 'k--', linewidth=1)

ax.set_xlabel(r"$a_1$", fontsize=10 ,labelpad=0)
ax.set_ylabel(r"$a_2$", fontsize=10 ,labelpad=0)

fig.set_size_inches(width, height)
fig.savefig('1-ParameterFamily_G+.pdf')

plt.show()