import numpy as np
import matplotlib.pyplot as plt
import sympy as sym
import math
import os
import subprocess
import fileinput
import re
import matlab.engine
from HelperFunctions import *
from ClassifyMin import *

import matplotlib.ticker as ticker
# from subprocess import Popen, PIPE
#import sys

import matplotlib.ticker as tickers
import matplotlib as mpl
from matplotlib.ticker import MultipleLocator,FormatStrFormatter,MaxNLocator
import pandas as pd

###################### makePlot.py #########################
#  Generalized Plot-Script giving the option to define
#  quantity of interest and the parameter it depends on
#  to create a plot
#
#  Input: Define y & x for "x-y plot" as Strings
#  - Run the 'Cell-Problem' for the different Parameter-Points
#  (alternatively run 'Compute_MuGamma' if quantity of interest
#   is q3=muGamma for a significant Speedup)

###########################################################

def format_func(value, tick_number):
    """Format a tick position as a multiple of pi/2 (LaTeX math string).

    ``value`` is rounded to the nearest integer multiple ``N`` of pi/2 and
    rendered as "0", pi/2, pi, N*pi/2 (odd N) or (N/2)*pi (even N).

    ``tick_number`` is accepted but not used. The two-argument signature
    presumably targets ``matplotlib.ticker.FuncFormatter``; the function is
    not called in the visible part of this file.
    """
    half_pi_count = int(np.round(2 * value / np.pi))

    # The three smallest non-negative multiples get a hand-written label.
    fixed_labels = {0: "0", 1: r"$\pi/2$", 2: r"$\pi$"}
    if half_pi_count in fixed_labels:
        return fixed_labels[half_pi_count]

    # Odd multiples stay in units of pi/2, even ones reduce to units of pi.
    if half_pi_count % 2:
        return r"${0}\pi/2$".format(half_pi_count)
    return r"${0}\pi$".format(half_pi_count // 2)





def find_nearest(array, value):
    """Return the element of ``array`` with the smallest absolute distance to ``value``.

    ``array`` is converted with ``np.asarray``; on ties ``argmin`` selects the
    first occurrence.
    """
    candidates = np.asarray(array)
    distances = np.abs(candidates - value)
    return candidates[distances.argmin()]


def find_nearestIdx(array, value):
    """Return the index of the element of ``array`` closest to ``value``.

    ``array`` is converted with ``np.asarray``; on ties ``argmin`` selects the
    first occurrence.
    """
    return np.abs(np.asarray(array) - value).argmin()



# TODO
# - Add a case distinction (speedup) for when the requested value is mu_gamma = q3
# - Also add an option to plot the minimization output


# ----- Setup Paths -----
# Locations of the solver input/output files, relative to the working
# directory that is selected below.
InputFile  = "/inputs/cellsolver.parset"
OutputFile = "/outputs/output.txt"

# The script is meant to be run from the src folder: step up one directory
# and make that the working directory before building the absolute paths.
# (To run from the current directory instead, skip the chdir.)
path_parent = os.path.dirname(os.getcwd())
os.chdir(path_parent)
path = os.getcwd()
print(path)

InputFilePath = path + InputFile
OutputFilePath = path + OutputFile
print("InputFilepath: ", InputFilePath)
print("OutputFilepath: ", OutputFilePath)
print("Path: ", path)

#---------------------------------------------------------------

print('---- Input parameters: -----')

# Parameter values in effect for this run. Only the final assignment of each
# name matters, so each one is set exactly once here.
# Alternatives that were tried before (kept for reference):
#   rho1 = 10.0, beta = 2.0, theta = 1.0/4.0, theta = 1.0/12.0,
#   lambda1 = 10.0, gamma = 1.0/4.0, gamma = 'infinity'
mu1 = 10.0
rho1 = 1.0
alpha = 5.0
beta = 10.0
# NOTE: theta is overwritten by the sweep loop further down the script.
theta = 1.0/2.0
lambda1 = 0.0
# gamma is compared against the strings '0' and 'infinity' later on.
gamma = '0'

for label, val in (
    ('mu1: ', mu1),
    ('rho1: ', rho1),
    ('alpha: ', alpha),
    ('beta: ', beta),
    ('theta: ', theta),
    ('gamma:', gamma),
):
    print(label, val)
print('----------------------------')


# Optional - TODO? :
# -Ask User for Input ...
# function = input("Enter value you want to plot (y-value):\n")
# print(f'You entered {function}')
# parameter = input("Enter Parameter this value depends on (x-value) :\n")
# print(f'You entered {parameter}')

# -Add Option to change NumberOfElements used for computation of Cell-Problem


# --- Quantity of interest (y-axis) ---
# Options: 'q1', 'q2', 'q3', 'q12' ,'q21', 'q31', 'q13' , 'q23', 'q32' , 'b1', 'b2' ,'b3'
# plus the minimization outputs 'angle', 'type', 'curvature'.
# TODO: EXTRA (Minimization Output) 'Minimizer (norm?)'
# Previously used: 'q12', 'b1', 'q3', 'angle'
yName = 'curvature'

# --- Parameter the quantity depends on (x-axis) ---
# Options: mu1 ,lambda1, rho1 , alpha, beta, theta, gamma
# Previously used: 'gamma', 'lambda1', 'alpha'
xName = 'theta'

# --- Interval of x-values ---
# Previously used intervals: [0, 30], [0.245, 0.99], [0.14, 0.19], [0.01, 3.0]
xmin = 0.01
xmax = 0.4

# Equidistant sample points on [xmin, xmax], both end points included.
numPoints = 100
X_Values = np.linspace(xmin, xmax, numPoints)
print(X_Values)


Y_Values = []

# Quantities available directly from the closed-form expressions in the loop
# body, and quantities that require the classification of the minimizer.
direct_quantities = ('q1', 'q2', 'q3', 'q12', 'b1', 'b2', 'b3')
classified_quantities = ('angle', 'type', 'curvature')

# Fail early with a clear message instead of silently producing an empty
# result list for a misspelled quantity name.
if yName not in direct_quantities + classified_quantities:
    raise ValueError("Unknown quantity of interest yName=" + repr(yName))

# q3 is only defined for gamma == '0' and gamma == 'infinity'. Every quantity
# that depends on it needs one of these two settings.
if (yName == 'q3' or yName in classified_quantities) and gamma not in ('0', 'infinity'):
    raise ValueError("q3 is only available for gamma '0' or 'infinity', got gamma="
                     + repr(gamma))

# NOTE: the swept parameter is hard-coded to theta (this rebinds the global
# theta); swap the loop header to sweep another parameter, e.g.
#   for alpha in X_Values:
for theta in X_Values:

    print('Situation of Lemma1.4')
    q12 = 0.0
    q1 = (1.0/6.0)*harmonicMean(mu1, beta, theta)
    q2 = (1.0/6.0)*arithmeticMean(mu1, beta, theta)
    b1 = prestrain_b1(rho1, beta, alpha, theta)
    b2 = prestrain_b2(rho1, beta, alpha, theta)
    b3 = 0.0
    if gamma == '0':
        q3 = q2
    elif gamma == 'infinity':
        q3 = q1
    else:
        # Only reachable when q3 is not needed (checked before the loop).
        q3 = None

    if yName in direct_quantities:
        direct_values = {'q1': q1, 'q2': q2, 'q3': q3, 'q12': q12,
                         'b1': b1, 'b2': b2, 'b3': b3}
        print(yName + ' used')
        Y_Values.append(direct_values[yName])
    else:
        G, angle, Type, curvature = classifyMin_ana(alpha, beta, theta, q3, mu1, rho1)
        # Fix: 'type' used to append the builtin `type` (and print
        # 'angle used') instead of the classification result `Type`.
        classified_values = {'angle': angle, 'type': Type, 'curvature': curvature}
        print(yName + ' used')
        Y_Values.append(classified_values[yName])


print("(Output) Values of " + yName + ": ", Y_Values)


# Locate the sample whose y-value is closest to zero and report both the
# y-value and the x-position at which it occurs.
idx = find_nearestIdx(Y_Values, 0)
print(' Idx of value  closest to 0:', idx)
ValueClose = Y_Values[idx]
# Fix: both labels used to print the y-value; the second line announced the
# x-position (theta) but never looked it up in X_Values.
print(yName + ' value closest to 0:', ValueClose)
print(xName + '(Idx) with ' + yName + ' closest to 0:', X_Values[idx])





# Scan consecutive samples and record every index at which the y-value differs
# from its predecessor by more than 1 (a "jump"). For each jump the index, the
# x-value and the y-value of the sample *after* the jump are stored.
jump_idx = []
jump_xValues = []
jump_yValues = []
tmp = X_Values[0]
for idx, x in enumerate(X_Values):
    print(idx, x)
    if idx == 0:
        # The first sample has no predecessor to compare against.
        continue
    step = abs(Y_Values[idx] - Y_Values[idx - 1])
    if step > 1:
        print('jump candidate')
        jump_idx.append(idx)
        jump_xValues.append(x)
        jump_yValues.append(Y_Values[idx])

print("Jump Indices", jump_idx)
print("Jump X-values:", jump_xValues)
print("Jump Y-values:", jump_yValues)

# Support points of a reduced polyline: the first sample, the two samples on
# either side of every detected jump, and the last sample.
y_plotValues = [Y_Values[0]]
x_plotValues = [X_Values[0]]
for i in jump_idx:
    y_plotValues += [Y_Values[i-1], Y_Values[i]]
    x_plotValues += [X_Values[i-1], X_Values[i]]
y_plotValues += [Y_Values[-1]]
x_plotValues += [X_Values[-1]]

print("y_plotValues:", y_plotValues)
print("x_plotValues:", x_plotValues)

# Boolean indexing below needs Y_Values as a numpy array, not a list.
Y_Values = np.array(Y_Values)

# Samples strictly to the right of the second support point, i.e. the last
# sample before the first jump (or the last sample if there is no jump).
right_of_second_point = X_Values > x_plotValues[1]
x_rest = X_Values[right_of_second_point]
y_rest = Y_Values[right_of_second_point]
print('X_Values:', X_Values)
print('Y_Values:', Y_Values)
print('x_rest:', x_rest)
print('y_rest:', y_rest)
print('np.nonzero(X_Values>x_plotValues[1]', np.nonzero(right_of_second_point) )


# --- Convert to numpy array
Y_Values = np.array(Y_Values)
X_Values = np.array(X_Values)


# ---------------- Create Plot -------------------
# Global Matplotlib settings: LaTeX text rendering with a 9pt serif font.
mpl.rcParams['text.usetex'] = True
mpl.rcParams["font.family"] = "serif"
mpl.rcParams["font.size"] = "9"

# Figure size in inches (width as measured in inkscape), applied before saving.
width = 6.28 *0.5
height = width / 1.618

fig = plt.figure()
ax = plt.axes((0.15,0.18,0.8,0.8))
ax.tick_params(axis='x', which='major', direction='out', pad=3)
ax.tick_params(axis='y', which='major', length=3, width=1, direction='out', pad=3)
ax.xaxis.set_major_locator(MultipleLocator(0.05))
ax.xaxis.set_minor_locator(MultipleLocator(0.025))
ax.grid(True, which='major', axis='both', alpha=0.3)

# Axis labels are fixed to the theta/curvature plot (xName/yName are not used).
ax.set_xlabel(r"volume fraction $\theta$")
ax.set_ylabel(r"curvature $\kappa$")

# Keyword sets shared by both gamma regimes.
vline_style = dict(ymin=0, ymax=1, color='orange', alpha=0.5,
                   linestyle='dashed', linewidth=1)
marker_style = dict(s=6, marker='o', cmap=None, norm=None, facecolor='black',
                    edgecolor='black', vmin=None, vmax=None, alpha=None,
                    linewidths=None, zorder=3)
label_box = dict(boxstyle="circle", facecolor='white', alpha=1.0, pad=0.1,
                 linewidth=0.5)

# Add transition points (hard-coded x-positions) and the curve itself.
if gamma == '0':
    transition_point1 =  0.13663316582914573
    transition_point2 =  0.20899497487437185
    for tp in (transition_point1, transition_point2):
        plt.axvline(tp, **vline_style)

    # Draw the curve in three pieces separated by the first two detected
    # jumps. The strict inequalities leave out the samples located exactly at
    # the jump x-values.
    left_part = X_Values < jump_xValues[0]
    middle_part = np.logical_and(X_Values > jump_xValues[0], X_Values < jump_xValues[1])
    right_part = X_Values > jump_xValues[1]
    for part in (left_part, middle_part, right_part):
        ax.plot(X_Values[part], Y_Values[part], 'royalblue')

    ax.scatter([transition_point1, transition_point2],
               [jump_yValues[0], jump_yValues[1]], **marker_style)

    ax.text(transition_point1-0.02, jump_yValues[0]-0.02, r"$4$", size=6,
            bbox=dict(label_box))
    ax.text(transition_point2+0.012, jump_yValues[1]+0.02, r"$5$", size=6,
            bbox=dict(label_box))

elif gamma == 'infinity':
    transition_point1 = 0.13663316582914573
    transition_point2 = 0.1929145728643216
    transition_point3 = 0.24115577889447234
    for tp in (transition_point1, transition_point2, transition_point3):
        plt.axvline(tp, **vline_style)

    # Only the first detected jump splits the curve in this regime.
    before_jump = X_Values < jump_xValues[0]
    after_jump = X_Values > jump_xValues[0]
    for part in (before_jump, after_jump):
        ax.plot(X_Values[part], Y_Values[part], 'royalblue')

    # y-values of the samples nearest to the first two transition points.
    idx1 = find_nearestIdx(X_Values, transition_point1)
    idx2 = find_nearestIdx(X_Values, transition_point2)
    print('idx1', idx1)
    print('idx2', idx2)
    Y_TP1 = Y_Values[idx1]
    Y_TP2 = Y_Values[idx2]
    print('Y_TP1', Y_TP1)
    print('Y_TP2', Y_TP2)

    ax.scatter([transition_point1, transition_point2], [Y_TP1, Y_TP2],
               **marker_style)

    ax.text(transition_point1-0.02, Y_TP1-0.02, r"$6$", size=6,
            bbox=dict(label_box))
    ax.text(transition_point2+0.015, Y_TP2+0.020, r"$7$", size=6,
            bbox=dict(label_box))


fig.set_size_inches(width, height)
fig.savefig('Plot-Curvature-Theta.pdf')

plt.show()
# #---------------------------------------------------------------