import numpy as np
import matplotlib.pyplot as plt
import sympy as sym
import math
import os
import subprocess
import fileinput
import re
import matlab.engine
from HelperFunctions import *
from ClassifyMinVec import *
import matplotlib.cm as cm
from matplotlib.colors import Normalize

import matplotlib.ticker as ticker
# from subprocess import Popen, PIPE
#import sys

###################### makePlot.py #########################
#  Generalized Plot-Script giving the option to define
#  quantity of interest and the parameter it depends on
#  to create a plot
#
#  Input: Define y & x for "x-y plot" as Strings
#  - Run the 'Cell-Problem' for the different Parameter-Points
#  (alternatively run 'Compute_MuGamma' if quantity of interest
#   is q3=muGamma for a significant Speedup)

###########################################################

def format_func(value, tick_number):
    """Format a tick position as a multiple of pi/2 in LaTeX notation.

    The position is rounded to the nearest multiple of pi/2. The multiples
    0, 1 and 2 get the short labels "0", pi/2 and pi; every other odd
    multiple N is written as N*pi/2 and every other even multiple as
    (N/2)*pi.

    Args:
        value: Tick position to label.
        tick_number: Accepted for the two-argument formatter-callback
            signature; not used.

    Returns:
        The label string.
    """
    half_pi_count = int(np.round(2 * value / np.pi))

    # Labels that are written without an explicit integer factor.
    short_labels = {0: "0", 1: r"$\pi/2$", 2: r"$\pi$"}
    if half_pi_count in short_labels:
        return short_labels[half_pi_count]

    if half_pi_count % 2 > 0:
        return r"${0}\pi/2$".format(half_pi_count)
    return r"${0}\pi$".format(half_pi_count // 2)





def find_nearest(array, value):
    """Return the element of ``array`` that lies closest to ``value``.

    ``array`` may be any array-like; on ties the first closest element wins
    (``argmin`` returns the first minimum).
    """
    candidates = np.asarray(array)
    distances = np.abs(candidates - value)
    return candidates[distances.argmin()]


def find_nearestIdx(array, value):
    """Return the index of the element of ``array`` closest to ``value``.

    ``array`` may be any array-like; on ties the first closest index wins
    (``argmin`` returns the first minimum).
    """
    distances = np.abs(np.asarray(array) - value)
    return distances.argmin()



# TODO
# - Case distinction (speedup) if the requested quantity is mu_gamma = q3
# - Also add an option to plot the minimization output


# ----- Set up paths -----
# Both suffixes start with a separator and are appended to the working
# directory by plain string concatenation.
InputFile = "/inputs/cellsolver.parset"
OutputFile = "/outputs/output.txt"

# --------- Run from src folder:
# Step up one directory level and make that directory the working directory.
path_parent = os.path.dirname(os.getcwd())
os.chdir(path_parent)
path = os.getcwd()
print(path)

# Alternative (stay in the current directory instead of stepping up):
# path = os.getcwd()
InputFilePath = path + InputFile
OutputFilePath = path + OutputFile
print("InputFilepath: ", InputFilePath)
print("OutputFilepath: ", OutputFilePath)
print("Path: ", path)

#---------------------------------------------------------------

print('---- Input parameters: -----')
# Only the values that are effective for the current run are assigned;
# previously tried alternatives are kept as trailing comments.
mu1 = 10.0
rho1 = 1.0           # alternative: 10.0
alpha = 5.0
beta = 10.0          # alternative: 2.0
theta = 1.0/2.0      # alternatives: 1.0/4.0, 1.0/12.0
lambda1 = 0.0        # alternative: 10.0

# The parameter sweep further down compares gamma against the strings
# '0' and 'infinity'.
gamma = '0'          # alternatives: 'infinity', 1.0/4.0

for _label, _value in (
    ('mu1: ', mu1),
    ('rho1: ', rho1),
    ('alpha: ', alpha),
    ('beta: ', beta),
    ('theta: ', theta),
    ('gamma:', gamma),
):
    print(_label, _value)
print('----------------------------')


# TODO? : Ask User for Input ...
# function = input("Enter value you want to plot (y-value):\n")
# print(f'You entered {function}')
# parameter = input("Enter Parameter this value depends on (x-value) :\n")
# print(f'You entered {parameter}')

# Add Option to change NumberOfElements used for computation of Cell-Problem


# --- Quantity of interest (plotted value):
# Options: 'q1', 'q2', 'q3', 'q12' ,'q21', 'q31', 'q13' , 'q23', 'q32' , 'b1', 'b2' ,'b3'
# Minimization output: 'angle', 'type', 'curvature', 'MinVec'
yName = 'MinVec'

# --- Parameter the quantity depends on:
# Options: mu1 ,lambda1, rho1 , alpha, beta, theta, gamma
xName = 'theta'

# --- Interval of x-values.
# Previously tried intervals: [0, 30], [0.245, 0.99], [0.14, 0.19],
# [0.01, 3.0], [0.125, 0.250], [0.05, 0.3], [0.15, 0.3], [0.193, 0.24].
xmin = 0.01
xmiddle = 0.24  # 0.24242424242424246
xmax = 0.4

# Number of sample points per segment.
numPoints_1 = 15
numPoints_2 = 15

# x-position at which the sampling switches from the first segment to the
# following ones.
JumpVal = 0.194  # 0.19515151515151516

# Spacing used to end the first segment one step before JumpVal.
tick = (JumpVal - xmin) / numPoints_1
print('tick:' ,tick)

# Segment 1: xmin up to just below JumpVal.
X_Values_one = np.linspace(xmin, JumpVal - tick, num=numPoints_1)
# Segment 2: JumpVal up to xmiddle.
X_Values_middle = np.linspace(JumpVal, xmiddle, num=numPoints_2)
# Segment 3: JumpVal up to xmax.
# NOTE(review): this segment starts at JumpVal, not xmiddle, so it overlaps
# segment 2 and X_Values is not monotone -- confirm that this is intended.
X_Values_two = np.linspace(JumpVal, xmax, num=numPoints_1)

X_Values = np.concatenate((X_Values_one, X_Values_middle, X_Values_two))
print('X_values_one:', X_Values_one)
print('X_values_two:', X_Values_two)
print('X_values:', X_Values)

Y_Values = []
Angle_Values = []   # only filled when yName == 'MinVec'

# Enables the additional per-arrow figure at the end of the script.
other = False
# other = True

# Quantities that need the minimizer classification instead of a direct
# formula value.
_minimizer_quantities = ('angle', 'type', 'curvature', 'MinVec')

# Sweep over the sample points. The loop variable rebinds the module-level
# parameter of the same name; to sweep another parameter, change the loop
# variable (and xName, which is used as axis label).
for theta in X_Values:
# for alpha in X_Values:

    print('Situation of Lemma1.4')
    # Effective quantities and prestrain components, evaluated with the
    # imported helper formulas for the current parameter point.
    q12 = 0.0
    q1 = (1.0/6.0)*harmonicMean(mu1, beta, theta)
    q2 = (1.0/6.0)*arithmeticMean(mu1, beta, theta)
    b1 = prestrain_b1(rho1, beta, alpha, theta)
    b2 = prestrain_b2(rho1, beta, alpha, theta)
    b3 = 0.0

    # q3 is only defined for the two limit cases of gamma.
    if gamma == '0':
        q3 = q2
    elif gamma == 'infinity':
        q3 = q1
    else:
        q3 = None

    # Fail with a clear message instead of a NameError, but only if q3 is
    # actually required for the requested quantity.
    if q3 is None and (yName == 'q3' or yName in _minimizer_quantities):
        raise ValueError(
            "gamma must be '0' or 'infinity' to evaluate '{}', got {!r}".format(yName, gamma)
        )

    direct_quantities = {
        'q1': q1, 'q2': q2, 'q3': q3, 'q12': q12,
        'b1': b1, 'b2': b2, 'b3': b3,
    }

    if yName in direct_quantities:
        print(yName + ' used')
        Y_Values.append(direct_quantities[yName])
    elif yName in _minimizer_quantities:
        G, angle, Type, curvature = classifyMin_ana(alpha, beta, theta, q3, mu1, rho1)
        if yName == 'angle':
            print('angle used')
            Y_Values.append(angle)
        elif yName == 'type':
            print('type used')
            # Append the classification result, not the builtin `type`.
            Y_Values.append(Type)
        elif yName == 'curvature':
            print('curvature used')
            Y_Values.append(curvature)
        else:  # 'MinVec'
            print('Minvec used')
            Y_Values.append(G)
            # The angle is kept alongside for colouring the arrows.
            Angle_Values.append(angle)


print("(Output) Values of " + yName + ": ", Y_Values)


# idx = find_nearestIdx(Y_Values, 0)
# print(' Idx of value  closest to 0', idx)
# ValueClose = Y_Values[idx]
# print('GammaValue(Idx) with mu_gamma closest to q_3^*', ValueClose)
#
#
#
# # Find Indices where the difference between the next one is larger than epsilon...
# jump_idx = []
# jump_xValues = []
# jump_yValues = []
# tmp = X_Values[0]
# for idx, x in enumerate(X_Values):
#     print(idx, x)
#     if idx > 0:
#         if abs(Y_Values[idx]-Y_Values[idx-1]) > 1:
#             print('jump candidate')
#             jump_idx.append(idx)
#             jump_xValues.append(x)
#             jump_yValues.append(Y_Values[idx])
#




#
#
# print("Jump Indices", jump_idx)
# print("Jump X-values:", jump_xValues)
# print("Jump Y-values:", jump_yValues)
#
# y_plotValues = [Y_Values[0]]
# x_plotValues = [X_Values[0]]
# # y_plotValues.extend(jump_yValues)
# for i in jump_idx:
#     y_plotValues.extend([Y_Values[i-1], Y_Values[i]])
#     x_plotValues.extend([X_Values[i-1], X_Values[i]])
#
#
# y_plotValues.append(Y_Values[-1])
# # x_plotValues = [X_Values[0]]
# # x_plotValues.extend(jump_xValues)
# x_plotValues.append(X_Values[-1])
#
#
# print("y_plotValues:", y_plotValues)
# print("x_plotValues:", x_plotValues)
# # Y_Values[np.diff(y) >= 0.5] = np.nan

#
# #get values bigger than jump position
# x_rest = X_Values[X_Values>x_plotValues[1]]
#
# Y_Values = np.array(Y_Values)  #convert the np array
#
# y_rest = Y_Values[X_Values>x_plotValues[1]]
# # y_rest = Y_Values[np.nonzero(X_Values>x_plotValues[1]]
# print('X_Values:', X_Values)
# print('Y_Values:', Y_Values)
# print('x_rest:', x_rest)
# print('y_rest:', y_rest)
# print('np.nonzero(X_Values>x_plotValues[1]', np.nonzero(X_Values>x_plotValues[1]) )


print('X_values:', X_Values)
print('Y_values:', Y_Values)



# ---------------- Create Plot -------------------
# plt.subplots() creates its own figure. A preceding plt.figure() call would
# only leave an additional empty figure that plt.show() displays as well.
f, ax = plt.subplots(1)

# plt.title(r''+ yName + '-Plot')
# plt.plot(X_Values, Y_Values,linewidth=2, '.k')
# plt.plot(X_Values, Y_Values,'.k',markersize=1)
# plt.plot(X_Values, Y_Values,'.',markersize=0.8)

# plt.plot(X_Values, Y_Values)

# ax.plot([[0],X_Values[-1]], [Y_Values[0],Y_Values[-1]])
# ax.plot([x_plotValues[0],x_plotValues[1]], [y_plotValues[0],y_plotValues[1]] , 'b')
# ax.plot(x_rest, y_rest, 'b')

 #Define jump
# x-position of the first dashed guide line drawn in the plot.
# NOTE(review): differs from JumpVal = 0.194 used when building X_Values --
# confirm which value is intended.
JumpVal = 0.19

# Convert the collected results to arrays for slicing and arithmetic.
X_Values = np.asarray(X_Values, dtype=float)
Y_Values = np.array(Y_Values)
Angle_Values = np.asarray(Angle_Values, dtype=float)
Y_arr = np.asarray(Y_Values, dtype=float)

# Earlier experiment: split the data at the jump position.
# X_one = X_Values[X_Values<0.19]
# Y_one = Y_Values[X_Values<0.19]
# Angle_one=Angle_Values[X_Values<0.19]
# X_two = X_Values[X_Values>=0.19]
# Y_two = Y_Values[X_Values>=0.19]
# Angle_two=Angle_Values[X_Values>=0.19]

color=['r','b','g']
# matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# pyplot.get_cmap is available in old and new versions.
cmap = plt.get_cmap('rainbow')

# Y_one = np.asarray(Y_one, dtype=float)
# Angle_one = np.asarray(Angle_one, dtype=float)
# X_one = np.asarray(X_one, dtype=float)



# print('X_one:', X_one)
# print('Y_one:', Y_one)
# print('Angle_one:', Angle_one)

print('X_Values:', X_Values)
print('Y_arr:', Y_arr)

# Debug output on the shape handling used below.
print('np.transpose(X_Values)', np.transpose(X_Values))
print('X_Values.shape', X_Values.shape[0] )
print('reshape X_Values', X_Values.reshape(-1, 1).shape)

print('ones.', np.ones((5,1),dtype=float))

# Turn the sample points into a column and pair each x with y = 0:
# row i of Or is the anchor point (x_i, 0) of arrow i.
X_Values = X_Values.reshape(-1, 1)
Or = np.column_stack((X_Values, np.zeros(X_Values.shape[0], dtype=float)))
print('Or:', Or)
# -----------------------------------------------------------------------------
# print('Or_one:', Or_one)
# -----------------------------------------------------------------------------

# Normalize every row of Y_arr to unit Euclidean length.
# Requires Y_arr to be two-dimensional (one vector per sample point).
# NOTE(review): a zero row would divide by zero and yield NaN -- confirm that
# this cannot occur for the vectors collected above.
sum_of_rows = Y_arr.sum(axis=1)   # only printed, not used for the normalization
print('sum_of_rows:', sum_of_rows)
Y_arrN = Y_arr / np.linalg.norm(Y_arr, ord=2, axis=1, keepdims=True)
print('normalized Y_arrN:', Y_arrN)

# The visibility flag is passed positionally: its keyword name changed from
# `b` to `visible` in Matplotlib 3.5 and the old name was removed later.
plt.grid(True, which='major')


print(Or[:,1])
print(Or[:,0])
print(Y_arrN[:,0])

#
# print('Or_one[:,1]',Or_one[:,1])
# print(Or_one[:,0])
# print(Y_oneN[:,0])

# Scale the colour normalization to the range of all collected angles.
norm = Normalize()
norm.autoscale(Angle_Values)    #here full array needed?!
# norm.autoscale(Angle_one)

colormap = cm.RdBu

# Index helper selecting every second entry (only referenced by the
# commented-out quiver variants).
skip = np.s_[::2]

# Linearly increasing values, one per sample point (likewise only referenced
# by the commented-out quiver variants).
widths = np.linspace(0, 2, num=X_Values.size)

# Q = ax.quiver(Or[:,0][skip], Or[:,1][skip] , Y_arrN[:,0][skip], Y_arrN[:,1][skip], color = colormap(norm(Angle_arr)), angles='xy', scale=5, units='xy', alpha=0.8,
# headwidth=2)
# Q_one = ax.quiver(Or_one[:,0][skip], Or_one[:,1][skip] , Y_one[:,0][skip], Y_one[:,1][skip], color = colormap(norm(Angle_Values)), angles='xy', scale=5, units='xy', alpha=0.8,
# headwidth=2)

# Q = ax.quiver(Or[:,0], Or[:,1] , Y_arrN[:,0], Y_arrN[:,1], color = colormap(norm(Angle_Values)), angles='xy', scale=15, units='xy', alpha=0.8,
# headwidth=2)

# Q = ax.quiver(Or[:,0], Or[:,1] , Y_arrN[:,0], Y_arrN[:,1], color = colormap(norm(Angle_Values)), angles='xy',  units='xy', alpha=0.8, scale=10,
# headwidth=2, linewidths=widths, edgecolors='k')

# Q = ax.quiver(Or[:,0], Or[:,1] , Y_arrN[:,0], Y_arrN[:,1], color = colormap(norm(Angle_Values)), angles='xy',  units='xy', alpha=0.8, scale=20,
# headwidth=0.01, headlength=5, width=0.01, edgecolors='k')

# Q = ax.quiver(Or[:,0], Or[:,1] , Y_arrN[:,0], Y_arrN[:,1], color = colormap(norm(Angle_Values)), angles='xy',  units='xy', alpha=0.8, scale=20, linewidth=0.1, edgecolors='k')
# Q = ax.quiver(Or[:,0], Or[:,1] , Y_arrN[:,0], Y_arrN[:,1], color = colormap(norm(Angle_Values)) ,  alpha=0.8, scale=20, linewidth=0.05, edgecolors='k', scale_units='width')
# Q = ax.quiver(Or[:,0], Or[:,1] , Y_arrN[:,0], Y_arrN[:,1], color = colormap(norm(Angle_Values)) ,  alpha=0.8,  scale=15, scale_units='x', linewidth=0.3)
# One arrow per sample point: anchored at (x_i, 0), pointing along the
# normalized vector and coloured by the normalized angle array.
arrow_colors = colormap(norm(Angle_Values))
Q = ax.quiver(
    Or[:, 0], Or[:, 1],
    Y_arrN[:, 0], Y_arrN[:, 1],
    color=arrow_colors, alpha=1.0, scale=10,
)

# Mark the arrow anchor points.
ax.scatter(Or[:, 0], Or[:, 1])

# Dashed vertical guide lines at selected x-positions.
for x_marker in (JumpVal, xmiddle, 0.13606060606060608):
    plt.axvline(x_marker, ymin=0, ymax=1, color='g', alpha=0.5, linestyle='dashed')

# Fixed view window; the y-axis carries no tick labels.
ax.set_xlim((0.1, X_Values[:, 0].max() + 0.1))
ax.set_ylim((-0.05, 0.1))
ax.set_xticks(np.arange(0, 0.45, step=0.05))
ax.tick_params(labelleft=False)

# Colorbar for the angle colouring with ticks at 0 and pi/2.
angle_mappable = cm.ScalarMappable(norm=norm, cmap=colormap)
cbar = f.colorbar(angle_mappable, ax=ax, ticks=[0, np.pi / 2])
cbar.ax.set_yticklabels(['0', r"$\pi/2$"])

plt.show()
# plt.quiver(Or , Y_arrN[:,0], Y_arrN[:,1])


# Optional second figure: the arrows are drawn one by one, each with its own
# colour. Everything is created inside the branch so that no unused figure
# is opened when `other` is False.
if other:
    plt.scatter(Or[:,0], Or[:,1], c='black', s=10)

    # View window and arrow shaft width are the same for every arrow.
    plt.xlim([-0.1, 0.4])
    plt.ylim([-0.1, 4])
    w = 0.005

    for i, tmp in enumerate(Y_Values):
        print('tmp:', tmp)

        tmp_normalized = tmp / np.sqrt(np.sum(tmp**2))
        print('tmp_normalized:', tmp_normalized)

        # X_Values is an (n, 1) column at this point, so the scalar is indexed
        # explicitly; combining a length-1 array with a scalar is a ragged
        # construction that recent NumPy versions reject.
        origin = np.array([X_Values[i, 0], 1])
        print('origin:', origin)

        # Colour the single arrow by its own angle rather than by the colour
        # array of all arrows.
        plt.quiver(Or[i,0], Or[i,1], Y_arrN[i,0], Y_arrN[i,1],
                   color=colormap(norm(Angle_Values[i])), angles='xy', units='xy', alpha=0.8,
                   headwidth=2, width=w, edgecolors='k')

    # Visibility flag passed positionally (keyword `b` was renamed to
    # `visible` in Matplotlib 3.5 and removed later).
    plt.grid(True, which='major')

    # Labels have to be set before plt.show() to appear in the figure.
    plt.xlabel(xName)
    # plt.ylabel(yName)
    # Raw string: '\k' is not a valid escape sequence in a normal literal.
    plt.ylabel(r'$\kappa$')

    plt.show()

# # if angle PLOT :
# ax.yaxis.set_major_locator(plt.MultipleLocator(np.pi / 2))
# ax.yaxis.set_minor_locator(plt.MultipleLocator(np.pi / 12))
#
# ax.yaxis.set_major_formatter(plt.FuncFormatter(format_func))
#
# # Plot every other line.. not the jumps...
# tmp = 1
# for idx, x in enumerate(x_plotValues):
#     if idx > 0 and tmp == 1:
#         # plt.plot([x_plotValues[idx-1],x_plotValues[idx]] ,[y_plotValues[idx-1],y_plotValues[idx]] )
#         ax.plot([x_plotValues[idx-1],x_plotValues[idx]] ,[y_plotValues[idx-1],y_plotValues[idx]] ,'b')
#         tmp = 0
#     else:
#         tmp = 1

# plt.plot([x_plotValues[0],x_plotValues[1]] ,[y_plotValues[0],y_plotValues[1]] )
# plt.plot([x_plotValues[2],x_plotValues[3]] ,[y_plotValues[2],y_plotValues[3]] )
# plt.plot([x_plotValues[4],x_plotValues[5]] ,[y_plotValues[4],y_plotValues[5]] )
# plt.plot([x_plotValues[6],x_plotValues[7]] ,[y_plotValues[6],y_plotValues[7]] )

#
# for x in jump_xValues:
#     plt.axvline(x,ymin=0, ymax= 1, color = 'g',alpha=0.5, linestyle = 'dashed')

# plt.axvline(x_plotValues[1],ymin=0, ymax= 1, color = 'g',alpha=0.5, linestyle = 'dashed')

# plt.axhline(y = 1.90476, color = 'b', linestyle = ':', label='$q_1$')
# plt.axhline(y = 2.08333, color = 'r', linestyle = 'dashed', label='$q_2$')
# plt.legend()
# plt.show()
# #---------------------------------------------------------------