import numpy as np
import matplotlib.pyplot as plt
import sympy as sym
import math
import os
import subprocess
import fileinput
import re
import matlab.engine
import matplotlib.ticker as tickers
import matplotlib as mpl
from matplotlib.ticker import MultipleLocator,FormatStrFormatter,MaxNLocator
from mpl_toolkits.mplot3d import Axes3D
import pandas as pd
import matplotlib.colors as mcolors
from matplotlib import cm

from mpl_toolkits.mplot3d.proj3d import proj_transform
# from mpl_toolkits.mplot3d.axes3d import Axes3D
from matplotlib.text import Annotation
from matplotlib.patches import FancyArrowPatch

# Extra packages :
# from HelperFunctions import *
# from ClassifyMin import *
# from subprocess import Popen, PIPE
#import sys

###################### Documentation #########################

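# Plots the image of the parameter rectangle [min(x),max(x)] x [min(y),max(y)]
# under the bending map u(x, kappa, e). The pair (kappa, e) is picked from a
# one-parameter family computed from the coefficients q1, q2, q12, q3 below.
# The flat (kappa = 0) surface is drawn for reference, and the reversed
# surface normals are drawn at the two ends of the bent surface.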

###########################################################


def rot(v, alpha):
    """Return the 3x3 matrix that rotates by ``alpha`` radians about axis ``v``.

    Implements Rodrigues' rotation formula in matrix form; the rotation follows
    the right-hand rule about ``v``.  Apply it as ``rot(v, alpha) @ point``.

    Parameters
    ----------
    v : sequence of 3 floats
        Rotation axis. Assumed to be a unit vector -- the matrix is only a
        proper rotation (orthogonal, det = +1) when ``|v| == 1``.
    alpha : float
        Rotation angle in radians.

    Returns
    -------
    numpy.ndarray of shape (3, 3)
    """
    x, y, z = v[0], v[1], v[2]
    c = np.cos(alpha)
    s = np.sin(alpha)
    t = 1 - c

    # Entry (1, 2) must be ``y*z*t - x*s``; a ``+x*s`` there (as in an earlier
    # version) turns e.g. rot([1, 0, 0], pi/2) into a reflection.
    return np.array([
        [x*x*t + c,   x*y*t - z*s, x*z*t + y*s],
        [x*y*t + z*s, y*y*t + c,   y*z*t - x*s],
        [z*x*t - y*s, z*y*t + x*s, z*z*t + c],
    ])



def rotate_data(X, R):
    """Apply the matrix ``R`` to every point of a gridded vector field.

    Parameters
    ----------
    X : numpy.ndarray of shape (k, *grid)
        Stacked coordinate arrays, e.g. ``np.array([u1, u2, u3])`` where each
        ``u_i`` has the same grid shape (such as arrays from ``np.meshgrid``).
        ``X[:, i, j]`` is one point.
    R : numpy.ndarray of shape (m, k)
        Matrix applied to each point (typically a 3x3 rotation from ``rot``).

    Returns
    -------
    numpy.ndarray of shape (m, *grid)
        ``Out[:, i, j] == R @ X[:, i, j]``.

    Raises
    ------
    ValueError
        If ``X.shape[0]`` does not match ``R.shape[1]``.
    """
    print('ROTATE DATA FUNCTION ---------------')

    X = np.asarray(X)
    R = np.asarray(R)
    if X.shape[0] != R.shape[1]:
        # Fail early with a clear message instead of an opaque reshape error.
        raise ValueError(
            f"X has {X.shape[0]} components but R expects {R.shape[1]}"
        )

    grid_shape = X.shape[1:]
    # Flatten the grid so one matrix product transforms every point, then
    # restore the grid shape for each output component.
    B = np.dot(R, X.reshape(X.shape[0], -1))
    Out = B.reshape((R.shape[0],) + grid_shape)
    print('shape of Out', Out.shape)

    return Out

# def rotate_data(X, v,alpha): #(Old Version)
# #rotate about axis v with degree deg in radians:
# # X : DataSet
#     print('ROTATE DATA FUNCTION ---------------')
#     # v = np.array([1,0,0])
#     # rotM = rot(v,np.pi/2)
#     # print('rotM:', rotM)
#     rot_matrix = rot(v,alpha)
#     # print('rot_matrix:', rot_matrix)
#     # print('rot_matrix.shape:', rot_matrix.shape)
#
#     # print('X', X)
#     # print('shape of X[0]', X.shape[0])
#     B = np.dot(rot_matrix, X.reshape(rot_matrix.shape[1],-1))
#
#     # print('shape of B', B.shape)
#     # print('B',B)
#     # print('B[0,:]', B[0,:])
#     # print('B[0,:].shape', B[0,:].shape)
#     Out = np.array([B[0,:].reshape(X.shape[1],X.shape[2]), B[1,:].reshape(X.shape[1],X.shape[2]), B[2,:].reshape(X.shape[1],X.shape[2])])
#     print('shape of Out', Out.shape)
#
#     return Out


# def translate_data(X, v):  ...
# #rotate about axis v with degree deg in radians:
# # X : DataSet
#     print('ROTATE DATA FUNCTION ---------------')
#     # v = np.array([1,0,0])
#     # rotM = rot(v,np.pi/2)
#     # print('rotM:', rotM)
#
#     print('X', X)
#     print('shape of X[0]', X.shape[0])
#
#     Out = X + v
#     return Out

#
# def u(x,kappa,e):
#
#     tmp = (x.dot(e))*kappa
#     # print('tmp for u',tmp)
#     if kappa == 0 :
#         tmp = np.array([0*x[0],  x[0]*e[0] + x[1]*e[1], x[1]*e[0] - x[0]*e[1] ])
#     else :
#         tmp = np.array([-(1/kappa)*np.cos(tmp)+(1/kappa),  (1/kappa)*np.sin(tmp), -x[0]*e[1]+x[1]*e[0] ])
#     return tmp
#
#
#
#
# def grad_u(x,kappa,e):
#
#     tmp = (x.dot(e))*kappa
#     # print('tmp',tmp)
#
#     grad_u = np.array([ [np.sin(tmp)*e[0], np.sin(tmp)*e[1]], [np.cos(tmp)*e[0], np.cos(tmp)*e[1]], [-e[1], e[0]] ])
#     # print('produkt', grad_u.dot(e) )
#     mapped_e = grad_u.dot(e)
#     # print('mapped_e:', mapped_e)
#     # print('siize of mapped_e', mapped_e.shape)
#     # mapped_e = mapped_e.transpose()
#     # print('mapped_e:', mapped_e)
#     # print('siize of mapped_e', mapped_e.shape)
#     return mapped_e
#
# def compute_normal(x,kappa,e):
#     tmp = (x.dot(e))*kappa
#     partial1_u = np.array([ np.sin(tmp)*e[0] ,np.cos(tmp)*e[0], -e[1] ])
#     partial2_u = np.array([ np.sin(tmp)*e[1], np.cos(tmp)*e[1], e[0]      ])
#     normal = np.cross(partial1_u,partial2_u)
#     # print('normal=',normal)
#     return normal



def u(x, kappa, e):
    """Map the parameter point ``x`` onto the bent surface.

    With ``s = x . e`` and ``w = x[1]*e[0] - x[0]*e[1]`` (the coordinate
    perpendicular to ``e``) the result is

    * ``kappa == 0``: the flat embedding ``[s, w, 0]``;
    * otherwise ``[-sin(-kappa*s)/kappa, w, (cos(-kappa*s) - 1)/kappa]``, i.e.
      ``s`` is wrapped onto a circle of radius ``1/|kappa|`` through the origin
      in the plane of the first and third coordinates, while ``w`` runs along
      the second coordinate.

    Parameters
    ----------
    x, e : numpy.ndarray of shape (2,)
        Parameter point and bending direction.
    kappa : float
        Curvature; exactly ``0`` selects the flat branch.

    Returns
    -------
    numpy.ndarray of shape (3,)
    """
    arc = x.dot(e)
    if kappa != 0:
        radius = 1/kappa
        phase = arc*((-1)*kappa)
        return np.array([-radius*np.sin(phase),
                         -x[0]*e[1] + x[1]*e[0],
                         radius*np.cos(phase) - radius])
    return np.array([x[0]*e[0] + x[1]*e[1],
                     x[1]*e[0] - x[0]*e[1],
                     0*x[0]])



# def grad_u(x,kappa,e):
#
#     tmp = (x.dot(e))*kappa
#     # print('tmp',tmp)
#
#     grad_u = np.array([ [np.sin(tmp)*e[0], np.sin(tmp)*e[1]], [np.cos(tmp)*e[0], np.cos(tmp)*e[1]], [-e[1], e[0]] ])
#     # print('produkt', grad_u.dot(e) )
#     mapped_e = grad_u.dot(e)
#     # print('mapped_e:', mapped_e)
#     # print('siize of mapped_e', mapped_e.shape)
#     # mapped_e = mapped_e.transpose()
#     # print('mapped_e:', mapped_e)
#     # print('siize of mapped_e', mapped_e.shape)
#     return mapped_e
#


def grad_u(x, kappa, e):
    """Return the derivative of ``u`` at ``x`` in the direction ``e``.

    Builds the 3x2 Jacobian of ``u`` (same phase convention as ``u``:
    ``phase = -kappa * (x . e)``) and applies it to ``e``.  The Jacobian rows
    are ``cos(phase)*e``, ``[-e[1], e[0]]`` and ``sin(phase)*e``; for
    ``kappa == 0`` this reduces to the Jacobian of the flat branch of ``u``.

    Returns
    -------
    numpy.ndarray of shape (3,)
        Tangent vector of the mapped surface along ``e``.
    """
    phase = -kappa * x.dot(e)
    cos_p = np.cos(phase)
    sin_p = np.sin(phase)
    jacobian = np.array([
        [cos_p*e[0], cos_p*e[1]],
        [-e[1], e[0]],
        [sin_p*e[0], sin_p*e[1]],
    ])
    return jacobian.dot(e)


def compute_normal(x, kappa, e):
    """Return the surface normal of ``u`` at the parameter point ``x``.

    The normal is the cross product of the partial derivatives of ``u`` with
    respect to the first and second parameter coordinates, using the same
    phase ``-kappa * (x . e)`` as ``u``.  It is not explicitly normalised; its
    length works out to ``|e|**2``.
    """
    phase = -kappa * x.dot(e)
    c, s = np.cos(phase), np.sin(phase)
    d_first = np.array([c*e[0], -e[1], s*e[0]])
    d_second = np.array([c*e[1], e[0], s*e[1]])
    return np.cross(d_first, d_second)



class Annotation3D(Annotation):
    """Text annotation anchored at a 3-D data point of an ``Axes3D``.

    The 2-D anchor handed to :class:`matplotlib.text.Annotation` is only a
    placeholder; on every draw the stored ``xyz`` point is projected with the
    axes' current projection matrix, so the label follows view changes.
    """

    def __init__(self, text, xyz, *args, **kwargs):
        # Pass the placeholder anchor positionally: with ``xy=(0, 0)`` as a
        # keyword, any extra positional argument (e.g. ``xytext``) would bind
        # to ``xy`` and raise "got multiple values for argument 'xy'".
        super().__init__(text, (0, 0), *args, **kwargs)
        self._xyz = xyz

    def draw(self, renderer):
        # Re-project the 3-D anchor with the current view before drawing.
        x2, y2, _ = proj_transform(*self._xyz, self.axes.M)
        self.xy = (x2, y2)
        super().draw(renderer)

def _annotate3D(ax, text, xyz, *args, **kwargs):
    '''Add annotation `text` at the 3-D point `xyz` of an `Axes3D` instance.

    Extra positional/keyword arguments are forwarded to `Annotation3D`
    (and from there to `matplotlib.text.Annotation`).

    Returns the created `Annotation3D`, mirroring `Axes.annotate`, so the
    caller can restyle or remove it later.
    '''
    annotation = Annotation3D(text, xyz, *args, **kwargs)
    ax.add_artist(annotation)
    return annotation

# Expose the helper as a method, so any Axes3D supports ``ax.annotate3D(...)``.
setattr(Axes3D, 'annotate3D', _annotate3D)

class Arrow3D(FancyArrowPatch):
    """Arrow from ``(x, y, z)`` to ``(x + dx, y + dy, z + dz)`` in an ``Axes3D``.

    ``FancyArrowPatch`` is a 2-D patch; both 3-D endpoints are projected with
    the axes' projection matrix whenever the axes are drawn.
    """

    def __init__(self, x, y, z, dx, dy, dz, *args, **kwargs):
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._xyz = (x, y, z)
        self._dxdydz = (dx, dy, dz)

    def _project_endpoints(self):
        """Project both endpoints to 2-D, update the patch, return their depths."""
        x1, y1, z1 = self._xyz
        dx, dy, dz = self._dxdydz
        x2, y2, z2 = (x1 + dx, y1 + dy, z1 + dz)

        xs, ys, zs = proj_transform((x1, x2), (y1, y2), (z1, z2), self.axes.M)
        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
        return zs

    def draw(self, renderer):
        self._project_endpoints()
        super().draw(renderer)

    def do_3d_projection(self, renderer=None):
        """Depth hook used by ``Axes3D.draw``.

        Since Matplotlib 3.5, ``Axes3D.draw`` calls this on every patch among
        its children, including patches added via ``add_artist``. Without it,
        drawing fails with ``AttributeError``. ``renderer`` is accepted for the
        older call signature.  Returns the smallest projected depth, which
        the axes use to order patches.
        """
        return np.min(self._project_endpoints())


def _arrow3D(ax, x, y, z, dx, dy, dz, *args, **kwargs):
    '''Add a 3D arrow from (x, y, z) to (x+dx, y+dy, z+dz) to an `Axes3D` instance.

    Extra arguments are forwarded to `Arrow3D` / `FancyArrowPatch`
    (e.g. ``arrowstyle``, ``mutation_scale``, ``fc``, ``ec``).

    Returns the created `Arrow3D` (as `Axes.add_artist` does) so the caller
    can restyle or remove it later.
    '''
    arrow = Arrow3D(x, y, z, dx, dy, dz, *args, **kwargs)
    ax.add_artist(arrow)
    return arrow

# Expose the helper as a method, so any Axes3D supports ``ax.arrow3D(...)``.
setattr(Axes3D, 'arrow3D', _arrow3D)
################################################################################################################
################################################################################################################
################################################################################################################



# Model coefficients (the commented H line below is the same matrix written
# in MATLAB syntax).
q1 = 1
q2 = 2
q12 = 1/2
q3 = ((4*q1*q2)**0.5 - q12)/2
# H=[2*q1,q12+2*q3;q12+2*q3,2*q2];

H = np.array([[2*q1, q12+2*q3], [q12+2*q3, 2*q2]])
A = np.array([[q1, 1/2*q12], [1/2*q12, q2]])
# abar: the vector [q12+2*q3, 2*q2] scaled to unit length.
abar = np.array([q12+2*q3, 2*q2])
abar = (1.0/math.sqrt((q12+2*q3)**2+(2*q2)**2))*abar

print('abar:',abar)

# Solve A b = abar. A is symmetric with positive determinant for these
# coefficients, so lstsq returns the exact solution. rcond=None selects
# NumPy's machine-precision cutoff explicitly and silences the FutureWarning
# that older NumPy versions emit when rcond is omitted.
b = np.linalg.lstsq(A, abar, rcond=None)[0]
print('b',b)


# Since A @ b == abar, sstar = |abar|^2 / (q1 + q2) = 1 / (q1 + q2) up to rounding.
sstar = (1/(q1+q2))*abar.dot(A.dot(b))
print('sstar', sstar)
# abar rotated by -90 degrees: orthogonal to abar and of the same length.
abarperp = np.array([abar[1], -abar[0]])
print('abarperp:',abarperp)
print('----------------------------')

# ----------------------------------------------------------------
# Sample the line  t -> sstar*abar + t*abarperp  at N points. Because abar is
# proportional to [q12+2*q3, 2*q2] and abarperp = [abar[1], -abar[0]], the
# first component of the point vanishes at the lower end of T. The second
# component vanishes at the upper end, so in between both are non-negative.

N=1000;
T = np.linspace(-sstar*(q12+2*q3)/(2*q2), sstar*(2*q2)/(q12+2*q3), num=N)
print('T:', T)

kappas = []   # kappa for every sample t
alphas = []   # angle (radians) of the direction e for every sample t
# G.append(float(s[0]))

G_container = []     # per sample: [a0, a1, sqrt(2*a0*a1)]
abar_container = []  # per sample: the point a = sstar*abar + t*abarperp
e_container = []     # per sample: the unit direction e


for t in T :

    abar_current = sstar*abar+t*abarperp;
    # print('abar_current', abar_current)
    # Zero out components below 1e-10 (negative ones included) so the square
    # roots below stay real. Presumably this absorbs rounding at the end
    # points of T, where one component is zero in exact arithmetic.
    abar_current[abar_current < 1e-10] = 0
    # print('abar_current', abar_current)
    # G = np.array([[2*q1, q12+2*q3], [q12+2*q3,2*q2] ])
    G = [abar_current[0], abar_current[1] , (2*abar_current[0]*abar_current[1])**0.5 ]
    G_container.append(G)
    abar_container.append(abar_current)
    # e = (sqrt(a0), sqrt(a1)) / sqrt(a0 + a1): unit length, both entries >= 0.
    e = [(abar_current[0]/(abar_current[0]+abar_current[1]))**0.5, (abar_current[1]/(abar_current[0]+abar_current[1]))**0.5]
    e_container.append(e)
    # kappa = a0 + a1, i.e. the squared length of (sqrt(a0), sqrt(a1)).
    kappa = abar_current[0]+abar_current[1]
    # Both components of e are >= 0, so alpha lies in [0, pi/2].
    alpha = math.atan2(e[1], e[0])
    # print('angle current:', alpha)
    kappas.append(kappa)
    alphas.append(alpha)


G_container = np.array(G_container)
abar_container = np.array(abar_container)

e_container = np.array(e_container)

print('G_container', G_container)
print('G_container.shape', G_container.shape)

# idx_1 = np.where(alphas == np.pi/4)
# Indices of the samples whose angle, rounded to 2 decimals, equals pi/3
# (idx_1), 0 (idx_2) and pi/4 (idx_3); np.where returns a tuple of arrays.
# NOTE(review): this needs at least one sample to round to each target;
# otherwise the idx_*[0][0] lookups below raise IndexError.
idx_1 = np.where(np.round(alphas,2) == round(np.pi/3,2))
idx_2 = np.where(np.round(alphas,2) == 0.0)
idx_3 = np.where(np.round(alphas,2) == round(np.pi/4,2))

# idx_3 = np.where(alphas == 0)

print('Index idx_1:', idx_1)
print('Index idx_2:', idx_2)
print('Index idx_3:', idx_3)
print('Index idx_1[0][0]:', idx_1[0][0])
print('Index idx_2[0][0]:', idx_2[0][0])
print('Index idx_3[0][0]:', idx_3[0][0])

alphas = np.array(alphas)
kappas = np.array(kappas)


# print('kappas:',kappas)
# print('alphas:',alphas)
print('min alpha:', min(alphas))
print('min kappa:', min(kappas))


# Diagnostics: the G triple and direction e of the first matching sample.
print('G_container[idx_1[0][0]]', G_container[idx_1[0][0]])
print('G_container[idx_2[0][0]]', G_container[idx_2[0][0]])
print('G_container[idx_3[0][0]]', G_container[idx_3[0][0]])


print('e_container[idx_1[0][0]]', e_container[idx_1[0][0]])
print('e_container[idx_2[0][0]]', e_container[idx_2[0][0]])
print('e_container[idx_3[0][0]]', e_container[idx_3[0][0]])



###################################################
# Choose which sample to plot:
#   idx = 1 -> alpha ~ pi/3 (and e is reflected below),
#   idx = 2 -> alpha ~ 0,
#   idx = 3 -> alpha ~ pi/4.
# NOTE(review): for any other idx, e and kappa keep their values from the last
# loop iteration above and `angle` is never assigned, so print('angle:', ...)
# raises NameError.

reflect = False
# reflect = True



idx = 3

if idx == 1 :
    e = e_container[idx_1[0][0]]
    kappa = kappas[idx_1[0][0]]
    angle = alphas[idx_1[0][0]]
    reflect = True

if idx == 2 :
    e = e_container[idx_2[0][0]]
    kappa = kappas[idx_2[0][0]]
    angle = alphas[idx_2[0][0]]


if idx == 3 :
    e = e_container[idx_3[0][0]]
    kappa = kappas[idx_3[0][0]]
    angle = alphas[idx_3[0][0]]

### SCALE ?
# kappa = kappa*2


print('kappa:',kappa)
print('angle:',angle)
###################################################
#### TEST apply reflection
#
# G_tmp = G_container[idx_1[0][0]]
#
# print('G_tmp', G_tmp)
#
# # Basis:
# G_1 = np.array([[1.0,0.0], [0.0,0.0]])
# G_2 = np.array([[0.0,0.0], [0.0,1.0]])
# G_3 = (1/np.sqrt(2))*np.array([[0.0,1.0], [1.0,0.0]])
# print('G_1', G_1)
# print('G_2', G_2)
# print('G_3', G_3)
#
# G = G_tmp[0] * G_1 + G_tmp[1]*G_2 + G_tmp[2]*G_3
# print('G:', G )
#
# T = np.array([[1.0 , -1.0] , [-1.0,1.0]])
#
# TG = np.multiply(T,G)
# print('TG', TG)
#
#
#
# v = np.array([np.sqrt(TG[0][0]),np.sqrt(TG[1][1]) ])
# print('v', v)
# print('norm(v):', np.linalg.norm(v))
# norm_v = np.linalg.norm(v)
#
# kappa = norm_v**2
#
# e = (1/norm_v)*v
# print('e:', e)
# print('kappa:', kappa)






if reflect == True:
    reflected_e = np.array([e[0], -1*e[1]])
    e = reflected_e                                         # Reflect e across the x-axis (negate e[1]). NOTE(review): original author marked this as unverified.
    print('reflected_e:', reflected_e)


############################################################################################################################################
####################################################################### KAPPA NEGATIVE ####################################################
############################################################################################################################################
# The section title is historical: kappa here is the value selected above
# (a0 + a1 >= 0) unless the `kappa = -2` override below is re-enabled.
# kappa = -2
num_Points = 200
# num_Points = 100


# e = np.array([1,0])
# e = np.array([0,1])
# e = np.array([1/np.sqrt(2),1/np.sqrt(2)])
# e = np.array([1/2,np.sqrt(3)/2])
# e = np.array([np.sqrt(3)/2,1/2])
# e = np.array([-1,0])
# e = np.array([0,-1])

###--- Creating dataset
# Only the last assignment of x (and of y) takes effect: x in [-3, 3] and
# y in [-1/4, 1/4], each with num_Points samples.
x = np.linspace(-2,2,num_Points)
x = np.linspace(-2.5,2.5,num_Points)
x = np.linspace(-3,3,num_Points)
# x = np.linspace(-4,4,num_Points)

# x = np.linspace(-1.5,1.5,num_Points)
# x = np.linspace(-1,1,num_Points)
y = np.linspace(-1/2,1/2,num_Points)
y = np.linspace(-1/4,1/4,num_Points)

print('type of x', type(x))
print('max of x:', max(x))
print('max of y:', max(y))
# print('x:', x)

# Parameter grid; x1, x2 have shape (len(y), len(x)).
x1, x2 = np.meshgrid(x,y)
zero = 0*x1

# NOTE: this uses the older convention of the commented-out u() above
# (bending into the first two coordinates), not the active u(). In the
# visible code the result is only rotated into T further below and never
# plotted; T is recomputed in the KAPPA POSITIVE section.
if kappa == 0 :
    u1 = 0*x1
    u2 = x1*e[0] + x2*e[1]
    u3 = x2*e[0] - x1*e[1]

else :
    u1 = -(1/kappa)*np.cos(kappa*(x1*e[0]+x2*e[1])) + (1/kappa)
    u2 = (1/kappa)*np.sin(kappa*(x1*e[0]+x2*e[1]))
    u3 = x2*e[0] -x1*e[1]

# print('np.size(u1)',np.size(u1))
# print('u1.shape',u1.shape)
# colorfunction=(u1**2+u2**2)
# print('colofunction',colorfunction)

# print('u1.size:',np.size(u1))
# tmp = np.ones(np.size(u1))*kappa
# print('np.size(tmp)',np.size(tmp))
# Constant colour field equal to kappa over the grid. colorfunction and norm
# are only referenced by commented-out plot calls in this chunk; with a
# constant field, Normalize receives vmin == vmax.
B = np.full_like(u1, 1)
# colorfunction=(u3)                                              # TODO Color by angle
# colorfunction=(np.ones(np.size(u1))*kappa)
colorfunction=(B*kappa)
# print('colofunction',colorfunction)
norm=mcolors.Normalize(colorfunction.min(),colorfunction.max())


# -----------------------------------------------------
# Display the mesh
fig = plt.figure()


# Candidate figure sizes: only the last assignment of width/height takes
# effect, and neither is referenced elsewhere in this chunk.
width = 6.28 *0.5
width = 6.28 * 0.333
height = width / 1.618
height = width / 2.5
height = width



# 3-D axes on the current figure; all surfaces and arrows below go here.
ax = plt.axes(projection ='3d', adjustable='box')


###---TEST MAP e-vectors!
# e1 = np.array([1,0])
# e2 = np.array([0,1])
# e3 = np.array([1/np.sqrt(2),1/np.sqrt(2)])
# e1 = np.array([0,1])
# e2 = np.array([-1,0])
# e3 = np.array([-1/np.sqrt(2),1/np.sqrt(2)])
# e1_mapped = u(e1,kappa,e1)
# e2_mapped = u(e2,kappa,e2)
# e3_mapped = u(e3,kappa,e3)
# print('e1 mapped:',e1_mapped)
# print('e2 mapped:',e2_mapped)
# print('e3 mapped:',e3_mapped)
### -----------------------------------

#--e1 :
# Rotation_angle = -np.pi/2
# Rotation_vector = np.array([0,1,0])

#--e2:
Rotation_angle = np.pi/2
Rotation_vector = np.array([1,0,0])

###--e = np.array([1/np.sqrt(2),1/np.sqrt(2)])
# Rotation_angle = -np.pi/2
# Rotation_vector = np.array([1,0,0])
# #2nd rotation :
# Rotation_angle = np.pi/4
# Rotation_vector = np.array([0,0,1])



# Only the last assignments count (angle 0 about [1,0,0]). In this chunk
# Rotation_angle / Rotation_vector are referenced only by commented-out code.
Rotation_angle = -np.pi/2
Rotation_angle = 0
# Rotation_angle = np.pi/2
Rotation_vector = np.array([0,1,0])
Rotation_vector = np.array([1,0,0])

# rot(np.array([0,1,0]),np.pi/2)

# ZERO ROTATION (overwritten below)
Rotation = rot(np.array([0,1,0]),0)

# if idx == 1:
# Rotation = rot(np.array([1,0,0]),np.pi)
# Rotation = rot(np.array([1,0,0]),np.pi).dot(rot(np.array([0,0,1]),angle))

# Also overwritten below.
Rotation = rot(np.array([1,0,0]),np.pi)

# TEST :

#DETERMINE ANGLE:
# Angle of the (possibly reflected) direction e.
angle = math.atan2(e[1], e[0])
print('angle:', angle)

## GENERAL TRANSFORMATION / ROTATION:
# Final transform used for everything plotted below, applied as
# Rotation @ point: first rotate about the z axis by `angle`, then by pi
# about the x axis.
# Rotation = rot(np.array([0,0,1]),angle).dot(Rotation)
Rotation = rot(np.array([1,0,0]),np.pi).dot(rot(np.array([0,0,1]),angle))
# Rotation = rot(np.array([0,0,1]),angle).dot(rot(np.array([0,1,0]),-np.pi/2))

# Rotation = rot(np.array([0,0,1]),+np.pi/4).dot(Rotation)
# Rotation = rot(np.array([0,0,1]),+np.pi/16).dot(Rotation)


# Rotation = rot(np.array([0,0,1]),np.pi/2).dot(Rotation)


### if e1:
# Rotation = rot(np.array([0,0,1]),-np.pi/4).dot(Rotation)
# Rotation = rot(np.array([0,0,1]),+np.pi/16).dot(Rotation)

# Add another rotation around z-axis:
# Rotation = rot(np.array([0,0,1]),+np.pi).dot(Rotation)
# Rotation = rot(np.array([0,0,1]),+np.pi/4).dot(Rotation)


# Rotation = rot(np.array([0,0,1]),+np.pi/8).dot(Rotation)

#e3 :
# Rotation = rot(np.array([0,1,0]),-np.pi/2)
# Rotation = rot(np.array([0,0,1]),np.pi/4).dot(rot(np.array([0,1,0]),-np.pi/2))

# Rotation = rot(np.array([0,0,1]),np.pi/4)
# Rotation = rot(np.array([1,0,0]),np.pi/4)
#### if e1 :
# Rotation = rot(np.array([0,1,0]),-np.pi/2)
#### if e2:
# Rotation = rot(np.array([0,1,0]),-np.pi/2).dot(rot(np.array([1,0,0]),-np.pi/2))
# # #### if e3 :
# Coincidence that np.pi/4 corresponds exactly to the angle alpha?:
# Would rotating about the z axis make no difference for e_2?!

# Rotation = rot(np.array([0,0,1]),np.pi/4).dot(rot(np.array([0,1,0]),-np.pi/2).dot(rot(np.array([1,0,0]),-np.pi/2)))
# Rotation = rot(np.array([0,0,1]),np.pi/2).dot(rot(np.array([0,1,0]),-np.pi/2).dot(rot(np.array([1,0,0]),-np.pi/2)))

# Rotation = rot(np.array([1,0,0]),np.pi/2)

# Rotation_vector = e3_mapped  #TEST
# Rotation_vector = np.array([-1/np.sqrt(2),1/np.sqrt(2)])
# Rotation_vector = np.array([0,0,1])

# v = np.array([1,0,0])
# X = np.array([u1,u2,u3])




# Rotate the legacy-convention surface from the previous section. In this
# chunk the result is not plotted (the calls below are commented out) and T
# is overwritten in the KAPPA POSITIVE section.
# T = rotate_data(np.array([u1,u2,u3]),Rotation_vector,Rotation_angle)
T = rotate_data(np.array([u1,u2,u3]),Rotation)

# ax.plot_surface(T[0], T[1], T[2], color = 'w', rstride = 2, cstride = 2, facecolors=cm.brg(colorfunction), alpha=.4, zorder=4)
# ax.plot_surface(T[0], T[1], T[2], color = 'w', rstride = 1, cstride = 1, facecolors=cm.viridis(colorfunction), alpha=.4, zorder=4)


###---- PLOT PARAMETER-PLANE:
# ax.plot_surface(x1,x2,zero,color = 'w', rstride = 1, cstride = 1 )


print('------------------ Kappa : ', kappa)

#midpoint:
# Centre of the parameter rectangle [min(x),max(x)] x [min(y),max(y)].
midpoint = np.array([(max(x)+min(x))/2,(max(y)+min(y))/2])
print('midpoint',midpoint)

# Map midpoint (active u() convention, not rotated):
midpoint_mapped = u(midpoint,kappa,e)
print('mapped midpoint', midpoint_mapped)

#map origin
origin = np.array([0,0])
origin_mapped = u(origin,kappa,e)


# Tangent along e and surface normal at the midpoint (unrotated); in this
# chunk they are only printed, since the plotting calls below are commented out.
mapped_e = grad_u(midpoint,kappa,e)
normal = compute_normal(midpoint,kappa,e)

print('mapped_e', mapped_e)
print('normal',normal )

#
# mapped_e = Rotation.dot(mapped_e)
# normal = Rotation.dot(normal)


# Plot Mapped_midPoint
# ax.plot(midpoint_mapped[0],midpoint_mapped[1],midpoint_mapped[2],    # data
# marker='o',     # each marker will be rendered as a circle
# markersize=4,   # marker size
# markerfacecolor='orange',   # marker facecolor
# markeredgecolor='black',  # marker edgecolor
# markeredgewidth=1,       # marker edge width
# linewidth=1,
# zorder=4)          # line width

# ax.quiver([midpoint_mapped[0]], [midpoint_mapped[1]], [midpoint_mapped[2]], [mapped_e[0]], [mapped_e[1]], [mapped_e[2]], color="red")
# ax.quiver([midpoint_mapped[0]], [midpoint_mapped[1]], [midpoint_mapped[2]], [normal[0]], [normal[1]], [normal[2]], color="blue")

# ax.arrow3D(midpoint_mapped[0],midpoint_mapped[1],midpoint_mapped[2],
#            mapped_e[0],mapped_e[1],mapped_e[2],
#            mutation_scale=15,
#            arrowstyle="-|>",
#            linestyle='dashed',fc='green',
#            lw = 2,
#            ec ='green',
#            zorder=3)
#
# ax.arrow3D(midpoint_mapped[0],midpoint_mapped[1],midpoint_mapped[2],
#            normal[0],normal[1],normal[2],
#            mutation_scale=15,
#             lw = 2,
#            arrowstyle="-|>",
#            linestyle='dashed',fc='blue',
#            ec ='blue',
#            zorder = 3)


###-- TEST Rotation :
# v = np.array([1,0,0])
# t = np.array([0,1,0])
#
# ax.arrow3D(0,0,0,
#            t[0],t[1],t[2],
#            mutation_scale=10,
#            arrowstyle="-|>",
#            linestyle='dashed',fc='blue',
#            ec ='blue')
#
# # e_extend
#
# rotM = rot(v,np.pi/2)
#
# print('rotM:', rotM)
#
# rot_t = rotM.dot(t)
#
# print('rot_t:', rot_t)
#
# ax.arrow3D(0,0,0,
#            rot_t[0],rot_t[1],rot_t[2],
#            mutation_scale=10,
#            arrowstyle="-|>",
#            linestyle='dashed',fc='blue',
#            ec ='blue')

### -------------------------------------------

############################################################################################################################################
####################################################################### KAPPA POSITIVE  ####################################################
############################################################################################################################################
# Main surface: image of the (x1, x2) grid under the active u() convention,
# rotated by `Rotation` and drawn in orange. kappa is still the value selected
# above (the sign flip below is commented out).

# kappa = (-1)*kappa

# if kappa == 0 :
#     u1 = 0*x1
#     u2 = x1*e[0] + x2*e[1]
#     u3 = x2*e[0] - x1*e[1]
# else :
#     u1 = -(1/kappa)*np.cos(kappa*(x1*e[0]+x2*e[1])) + (1/kappa)
#     u2 = (1/kappa)*np.sin(kappa*(x1*e[0]+x2*e[1]))
#     u3 = x2*e[0] -x1*e[1]

# Vectorised form of u(x, kappa, e) over the whole grid.
if kappa == 0 :
    # u1 = 0*x1
    # u2 = x1*e[0] + x2*e[1]
    # u3 = x2*e[0] - x1*e[1]
    u1 = x1*e[0] + x2*e[1]
    u2 = x2*e[0] - x1*e[1]
    u3 = 0*x1

else :
    # u1 = -(1/kappa)*np.cos(kappa*(x1*e[0]+x2*e[1])) + (1/kappa)
    # u2 = (1/kappa)*np.sin(kappa*(x1*e[0]+x2*e[1]))
    # u3 = x2*e[0] -x1*e[1]
    u1 = -(1/kappa)*np.sin((-1)*kappa*(x1*e[0]+x2*e[1]))
    u2 =  x2*e[0] -x1*e[1]
    u3 = (1/kappa)*np.cos((-1)*kappa*(x1*e[0]+x2*e[1]))-(1/kappa)



# ax.plot_surface(u1, u2, u3, color = 'w', rstride = 1, cstride = 1, facecolors=cm.autumn(colorfunction), alpha=.3)  ##This one!


# T = rotate_data(X,Rotation_vector,Rotation_angle)

T = rotate_data(np.array([u1,u2,u3]),Rotation)
# T = rotate_data(T,np.array([0,1,0]),Rotation_angle)
# T = rotate_data(T,np.array([0,0,1]),-1*Rotation_angle/2)

# ax.plot_surface(T[0], T[1], T[2], rstride = 1, cstride = 1, facecolors=cm.autumn(colorfunction), alpha=.4, zorder=4, antialiased=False)
# ax.plot_surface(T[0], T[1], T[2], rstride = 1, cstride = 1, facecolors=cm.autumn(colorfunction), alpha=.4, zorder=4, antialiased=True)
# ax.plot_surface(T[0], T[1], T[2], rstride = 2, cstride = 2, facecolors=cm.autumn(colorfunction), alpha=.4, zorder=4)
# ax.plot_surface(T[0], T[1], T[2], rstride = 10, cstride = 10, facecolors=cm.brg(colorfunction), alpha=.8, zorder=4)
ax.plot_surface(T[0], T[1], T[2], rstride = 5, cstride = 5, color='orange', alpha=.8, zorder=4)
# ax.plot_surface(T[0], T[1], T[2], rstride = 10, cstride = 10, color='blue', alpha=.8, zorder=4, shade=True)
# ax.plot_surface(T[0], T[1], T[2], rstride = 1, cstride = 1, facecolors=cm.autumn(colorfunction), alpha=.4, zorder=4, shade=True)
# ax.plot_surface(T[0], T[1], T[2], color = 'w', rstride = 1, cstride = 1, facecolors=cm.autumn(colorfunction), alpha=0.8, zorder=4)
# ax.plot_surface(T[0], T[1], T[2],  rstride = 1, cstride = 1, facecolors=cm.autumn(colorfunction), alpha=1, zorde5r=5)

# midpoint = np.array([(max(x)+min(x))/2,(max(y)+min(y))/2])
# print('midpoint',midpoint)
print('------------------ Kappa : ', kappa)
# Map midpoint (not rotated; only printed):
midpoint_mapped = u(midpoint,kappa,e)
print('mapped midpoint', midpoint_mapped)

#map origin
origin = np.array([0,0])
origin_mapped = u(origin,kappa,e)


mapped_e = grad_u(midpoint,kappa,e)
normal = compute_normal(midpoint,kappa,e)

print('mapped_e', mapped_e)
print('normal',normal )


# Rotate the midpoint tangent/normal into the plotted frame. They are
# recomputed for the end points below before being used again.
mapped_e = Rotation.dot(mapped_e)
normal = Rotation.dot(normal)



# Plot MIDPOINT:
# ax.plot(midpoint_mapped[0],midpoint_mapped[1],midpoint_mapped[2],    # data
# marker='o',     # each marker will be rendered as a circle
# markersize=4,   # marker size
# markerfacecolor='orange',   # marker facecolor
# markeredgecolor='black',  # marker edgecolor
# markeredgewidth=1,       # marker edge width
# linewidth=1,
# zorder=5)          # line width


# First end point: left edge of the parameter rectangle, middle of y.
endpoint = np.array([min(x),(max(y)+min(y))/2])
print('endpoint',endpoint)

# Map the end point onto the surface, then rotate it into the plotted frame:
endpoint_mapped = u(endpoint,kappa,e)
print('mapped endpoint', endpoint_mapped)

endpoint_mapped = Rotation.dot(endpoint_mapped)

mapped_e = grad_u(endpoint,kappa,e)
normal = compute_normal(endpoint,kappa,e)


mapped_e = Rotation.dot(mapped_e)
normal = Rotation.dot(normal)




# The purple arrow below is drawn along the negated (rotated) normal.
reverse_normal = np.array([ (-1)*normal[0], (-1)*normal[1], (-1)*normal[2]])

ax.plot(endpoint_mapped[0],endpoint_mapped[1],endpoint_mapped[2],    # data
marker='o',     # each marker will be rendered as a circle
markersize=1,   # marker size
markerfacecolor='black',   # marker facecolor
markeredgecolor='black',  # marker edgecolor
markeredgewidth=0.5,       # marker edge width
linewidth=1,
zorder=5)          # requested above the surfaces (zorder 4 and 1)



# ax.arrow3D(endpoint_mapped[0],endpoint_mapped[1],endpoint_mapped[2],
#            mapped_e[0],mapped_e[1],mapped_e[2],
#            mutation_scale=15,
#            arrowstyle="-|>",
#            linestyle='dashed',fc='green',
#            lw = 1.5,
#            ec ='green',
#            zorder=5)

# ax.arrow3D(endpoint_mapped[0],endpoint_mapped[1],endpoint_mapped[2],
#            normal[0],normal[1],normal[2],
#            mutation_scale=15,
#             lw = 1.5,
#            arrowstyle="-|>",
#            linestyle='dashed',fc='blue',
#            ec ='blue',
#            zorder = 5)

ax.arrow3D(endpoint_mapped[0],endpoint_mapped[1],endpoint_mapped[2],
           reverse_normal[0],reverse_normal[1],reverse_normal[2],
           mutation_scale=10,
            lw = 1.5,
           arrowstyle="-|>",
           linestyle='-',fc='purple', alpha=0.75,
           ec ='purple',
           zorder = 5)




# Second end point: right edge of the parameter rectangle, middle of y.
endpoint = np.array([max(x),(max(y)+min(y))/2])
print('endpoint',endpoint)

# Map the end point onto the surface, then rotate it into the plotted frame:
endpoint_mapped = u(endpoint,kappa,e)
print('mapped endpoint', endpoint_mapped)

endpoint_mapped = Rotation.dot(endpoint_mapped)

mapped_e = grad_u(endpoint,kappa,e)
normal = compute_normal(endpoint,kappa,e)


mapped_e = Rotation.dot(mapped_e)
normal = Rotation.dot(normal)




reverse_normal = np.array([ (-1)*normal[0], (-1)*normal[1], (-1)*normal[2]])

ax.plot(endpoint_mapped[0],endpoint_mapped[1],endpoint_mapped[2],    # data
marker='o',     # each marker will be rendered as a circle
markersize=1,   # marker size
markerfacecolor='black',   # marker facecolor
markeredgecolor='black',  # marker edgecolor
markeredgewidth=0.5,       # marker edge width
linewidth=1,
zorder=5)          # requested above the surfaces (zorder 4 and 1)



# ax.arrow3D(endpoint_mapped[0],endpoint_mapped[1],endpoint_mapped[2],
#            mapped_e[0],mapped_e[1],mapped_e[2],
#            mutation_scale=15,
#            arrowstyle="-|>",
#            linestyle='dashed',fc='green',
#            lw = 1.5,
#            ec ='green',
#            zorder=5)

# ax.arrow3D(endpoint_mapped[0],endpoint_mapped[1],endpoint_mapped[2],
#            normal[0],normal[1],normal[2],
#            mutation_scale=15,
#             lw = 1.5,
#            arrowstyle="-|>",
#            linestyle='dashed',fc='blue',
#            ec ='blue',
#            zorder = 5)

ax.arrow3D(endpoint_mapped[0],endpoint_mapped[1],endpoint_mapped[2],
           reverse_normal[0],reverse_normal[1],reverse_normal[2],
           mutation_scale=10,
            lw = 1.5,
           arrowstyle="-|>",
           linestyle='-',fc='purple', alpha=0.75,
           ec ='purple',
           zorder = 5)


# ax.arrow3D(midpoint_mapped[0],midpoint_mapped[1],midpoint_mapped[2],
#            mapped_e[0],mapped_e[1],mapped_e[2],
#            mutation_scale=15,
#            arrowstyle="-|>",
#            linestyle='dashed',fc='green',
#            lw = 1.5,
#            ec ='green',
#            zorder=5)
#
# ax.arrow3D(midpoint_mapped[0],midpoint_mapped[1],midpoint_mapped[2],
#            normal[0],normal[1],normal[2],
#            mutation_scale=15,
#             lw = 1.5,
#            arrowstyle="-|>",
#            linestyle='dashed',fc='blue',
#            ec ='blue',
#            zorder = 5)


############################################################################################################################################
####################################################################### KAPPA ZERO #########################################################
############################################################################################################################################
# Reference surface: with kappa = 0 the flat branch below is taken (the else
# branch is unreachable here). The plane [x.e, x2*e[0] - x1*e[1], 0] is
# rotated by `Rotation` and drawn in gray. NOTE: kappa stays 0 from here on.
kappa = 0

# if kappa == 0 :
#     u1 = 0*x1
#     u2 = x1*e[0] + x2*e[1]
#     u3 = x2*e[0] - x1*e[1]
# else :
#     u1 = -(1/kappa)*np.cos(kappa*(x1*e[0]+x2*e[1])) + (1/kappa)
#     u2 = (1/kappa)*np.sin(kappa*(x1*e[0]+x2*e[1]))
#     u3 = x2*e[0] -x1*e[1]

if kappa == 0 :
    # u1 = 0*x1
    # u2 = x1*e[0] + x2*e[1]
    # u3 = x2*e[0] - x1*e[1]
    u1 = x1*e[0] + x2*e[1]
    u2 = x2*e[0] - x1*e[1]
    u3 = 0*x1

else :
    # u1 = -(1/kappa)*np.cos(kappa*(x1*e[0]+x2*e[1])) + (1/kappa)
    # u2 = (1/kappa)*np.sin(kappa*(x1*e[0]+x2*e[1]))
    # u3 = x2*e[0] -x1*e[1]
    u1 = -(1/kappa)*np.sin((-1)*kappa*(x1*e[0]+x2*e[1]))
    u2 =  x2*e[0] -x1*e[1]
    u3 = (1/kappa)*np.cos((-1)*kappa*(x1*e[0]+x2*e[1]))-(1/kappa)
# ax.plot_surface(u1, u2, u3,  rstride = 1, cstride = 1, color = 'white', alpha=0.85)

# T = rotate_data(np.array([u1,u2,u3]),Rotation_vector,Rotation_angle)

# T now holds the rotated flat surface (it is read again in the last section).
T = rotate_data(np.array([u1,u2,u3]),Rotation)
# T = rotate_data(T,np.array([0,1,0]),Rotation_angle)
# T = rotate_data(T,np.array([0,0,1]),-1*Rotation_angle/2)


# ax.plot_surface(T[0], T[1], T[2],  rstride = 1, cstride = 1, color = 'white', alpha=0.55, zorder=2, antialiased=True)
# ax.plot_surface(T[0], T[1], T[2],  rstride =1 , cstride = 1, color = 'white', alpha=0.55, zorder=3)

# ax.plot_surface(T[0], T[1], T[2],  rstride = 1, cstride = 1, color = 'white', alpha=0.55, zorder=2)
# ax.plot_surface(T[0], T[1], T[2],  rstride = 1, cstride = 1, color = 'white', alpha=0.5, zorder=2, antialiased=True)
# ax.plot_surface(T[0], T[1], T[2],  rstride = 10, cstride = 10, color = 'white', alpha=0.55, zorder=2)
ax.plot_surface(T[0], T[1], T[2],  rstride = 20, cstride = 20, color = 'gray', alpha=0.35, zorder=1, shade=True)
# ax.plot_surface(T[0], T[1], T[2], color = 'white', alpha=0.55, zorder=2)

# Tangent and normal of the flat surface at the midpoint (NOT rotated).
# midpoint = np.array([(max(x)+min(x))/2,(max(y)+min(y))/2])
mapped_e = grad_u(midpoint,kappa,e)
normal_zeroCurv = compute_normal(midpoint,kappa,e)

# Map midpoint:
midpoint_mapped = u(midpoint,kappa,e)
print('mapped midpoint', midpoint_mapped)

##-----  PLOT MAPPED MIDPOINT :::

# ax.plot(midpoint_mapped[0],midpoint_mapped[1],midpoint_mapped[2],    # data
# marker='o',     # each marker will be rendered as a circle
# markersize=4,   # marker size
# markerfacecolor='orange',   # marker facecolor
# markeredgecolor='black',  # marker edgecolor
# markeredgewidth=1,       # marker edge width
# # linestyle='--',            # line style will be dash line
# linewidth=1,
# zorder=5)

# ax.arrow3D(midpoint_mapped[0],midpoint_mapped[1],midpoint_mapped[2],
#            mapped_e[0],mapped_e[1],mapped_e[2],
#            mutation_scale=10,
#            arrowstyle="-|>",
#            linestyle='dashed',fc='red',
#            ec ='red')
#
# ax.arrow3D(midpoint_mapped[0],midpoint_mapped[1],midpoint_mapped[2],
#            normal_zeroCurv[0],normal_zeroCurv[1],normal_zeroCurv[2],
#            mutation_scale=10,
#            arrowstyle="-|>",
#            linestyle='dashed',fc='blue',
#            ec ='blue')


##----------  PLOT MAPPED ORIGIN :::
# origin = np.array([0,0])
# origin_mapped = u(origin,kappa,e)
# print('origin_mapped', origin_mapped)
#
# ax.plot(origin_mapped[0],origin_mapped[1],origin_mapped[2],    # data
# marker='o',     # each marker will be rendered as a circle
# markersize=4,   # marker size
# markerfacecolor='green',   # marker facecolor
# markeredgecolor='black',  # marker edgecolor
# markeredgewidth=1,       # marker edge width
# linewidth=1,
# zorder=5)          # line width
#
# # rotate mapped origin
# # v = np.array([1,0,0])
# # alpha  = Rotation_angle
#
# rotM = rot(Rotation_vector,Rotation_angle)
# # origin_mRot = rotate_data(origin_mapped,v,alpha)
# origin_mRot = rotM.dot(origin_mapped)
# print('origin_mapped Rotated', origin_mRot)
#
# # --- Compute Distance to Origin 3D
# origin_3D=np.array([0,0,0])
# distance = origin_mapped-origin_3D
# print('distance', distance)

## --------------------------------------------------------

# COMPUTE ANGLE WITH Z AXIS
# Angle (in radians) between normal_zeroCurv and the z-axis, from
#   cos(theta) = <n, z> / (|n| * |z|).
z = np.array([0,0,1])
print('test', normal_zeroCurv*z)
cos_angle_z = normal_zeroCurv.dot(z) / (np.linalg.norm(z) * np.linalg.norm(normal_zeroCurv))
# Floating-point round-off can push |cos| slightly above 1 when the normal is
# (anti-)parallel to z; arccos would then return NaN, so clip to [-1, 1] first.
angle_z = np.arccos(np.clip(cos_angle_z, -1.0, 1.0))
print('angle between normal and z-axis', angle_z)

## unfinished...




###-------------------------------------  PLOT :
# Switch off the axis decorations.
# NOTE(review): with the axis switched off, the labels set below are probably
# not drawn; they are kept in case the axis is turned back on.
# Alternatives tried earlier: plt.axis('tight'), a ScalarMappable colorbar
# (cmap='inferno'), and an ax.set_title showing kappa and e.
plt.axis('off')

axis_label_setters = (ax.set_xlabel, ax.set_ylabel, ax.set_zlabel)
axis_label_texts = (r"x-axis", r"y-axis", r"z-axis")
for set_label, label_text in zip(axis_label_setters, axis_label_texts):
    set_label(label_text)

# TEST :
# ax.annotate3D('point 1', (0, 0, 0), xytext=(3, 3), textcoords='offset points')
# ax.annotate3D('point 2', (0, 1, 0),
#               xytext=(-30, -30),
#               textcoords='offset points',
#               arrowprops=dict(ec='black', fc='white', shrink=2.5))
# ax.annotate3D('point 3', (0, 0, 1),
#               xytext=(30, -30),
#               textcoords='offset points',
#               bbox=dict(boxstyle="round", fc="lightyellow"),
#               arrowprops=dict(arrowstyle="-|>", ec='black', fc='white', lw=5))

#######################################################################################################################

# --- Axis limits ---------------------------------------------------------------
# The three coordinate arrays of the plotted data (T is built earlier).
u1 = T[0]
u2 = T[1]
u3 = T[2]

# Half-width of the view box: the largest coordinate extent divided by a zoom
# divisor (divisors tried previously: 2, 3, 6, 8, 10, 12).
zoom_divisor = 14
max_range = np.array([u1.max()-u1.min(), u2.max()-u2.min(), u3.max()-u3.min()]).max() / zoom_divisor
mid_u1 = (u1.max()+u1.min()) * 0.5
mid_u2 = (u2.max()+u2.min()) * 0.5
mid_u3 = (u3.max()+u3.min()) * 0.5

# Extra margin around the box: x and y are padded on both sides, z only on top.
# (Margins of 0.5, 1, 1.5 and 3 were tried previously.)
axis_pad = 2
ax.set_xlim(mid_u1 - max_range - axis_pad, mid_u1 + max_range + axis_pad)
ax.set_ylim(mid_u2 - max_range - axis_pad, mid_u2 + max_range + axis_pad)
# z is centred on the z-midpoint mid_u3. The previous code used mid_u2 (the
# y-midpoint) here, which looks like a copy-paste slip from the y-limits.
ax.set_zlim(mid_u3 - max_range, mid_u3 + max_range + axis_pad)

##----- CHANGE CAMERA POSITION:
# ax.view_init(elev=10., azim=0)
# ax.view_init(elev=38, azim=90)
# ax.view_init(elev=38, azim=120)
# ax.view_init(elev=38)

# if e1 ::
# ax.view_init(elev=44)
# ax.view_init(elev=38, azim=-90)
# ax.view_init(elev=38, azim=0)
# ax.view_init(elev=25, azim=-30)
# ax.view_init(elev=18, azim=-30)
# Active camera placement: 25 degrees above the x-y plane, azimuth -125 degrees.
camera_elevation, camera_azimuth = 25, -125
ax.view_init(elev=camera_elevation, azim=camera_azimuth)
# ax.view_init(elev=25, azim=150)
# ax.view_init(elev=25, azim=135)

# ax.view_init(elev=25, azim=125)   #idx2



# ax.view_init(elev=25, azim=145)


# if e3 ::
# ax.view_init(elev=25)

# ax.set_xlim3d(-2, 2)
# ax.set_ylim3d(-1.0,3.0)
# ax.set_zlim3d(-1.5,2.5)

# ax.set_ylim3d(-10,10)
# ax.set_xlim(mid_u1 - max_range-0.2, mid_u1 + max_range+0.2)
# ax.set_zlim(mid_u3 - max_range-0.2, mid_u3 + max_range+0.2)
# ax.set_ylim(mid_u2 - max_range-0.2, mid_u2 + max_range+0.2)

# width = 6.28 *0.5
# height = width / 1.618
# # height = width / 2.5
# fig.set_size_inches(width, height)
# fig.savefig('Test-Cylindrical.pdf')

# Figurename = r'Cylindrical minimizer_$\kappa$='+ str(kappa)+ '_$e$=' +  str(e)
# Figurename = r'Cylindrical minimizer' + '_$e$=' +  str(e)
# Output file stem (earlier variants encoded kappa and e in the name).
Figurename = r'1-ParFamMinimizer_idx' + str(idx)

# NOTE(review): fig.set_size_inches(width, height) was flagged as a to-do.
# width/height are not set by live code in this section, so the figure is
# saved at its current size.
fig.savefig(Figurename+".pdf")
# Save the PNG through the same explicit handle as the PDF. plt.savefig writes
# whatever figure is *current*, which is not guaranteed to be `fig`.
fig.savefig(Figurename+".png", bbox_inches='tight', dpi=300)
plt.show()



# #---------------------------------------------------------------