import numpy as np
import matplotlib.pyplot as plt
import sympy as sym
import math
import os
import subprocess
import fileinput
import re
import matlab.engine
import sys
from ClassifyMin import *
from HelperFunctions import *
# from CellScript import *
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.cm as cm
from vtk.util import numpy_support
from pyevtk.hl import gridToVTK

import time

# ----------- SETUP PATHS
# Alternative parset/output pair (cell solver), currently disabled:
#   InputFile  = "/inputs/cellsolver.parset"
#   OutputFile = "/outputs/output.txt"
# Both names start with "/" because they are appended verbatim to the
# working directory determined below.
InputFile = "/inputs/computeMuGamma.parset"
OutputFile = "/outputs/outputMuGamma.txt"

# --------- Run from src folder:
# The parent of the launch directory becomes the working directory, and all
# file paths are built relative to it.
path_parent = os.path.dirname(os.getcwd())
os.chdir(path_parent)
path = os.getcwd()
print(path)

InputFilePath = f"{path}{InputFile}"
OutputFilePath = f"{path}{OutputFile}"
print("InputFilepath: ", InputFilePath)
print("OutputFilepath: ", OutputFilePath)
print("Path: ", path)


# -------------------------- Input Parameters --------------------
# TODO : these must be the same values as in the Parset for computeMuGamma
mu1 = 1.0            # previously tried: 10.0
rho1 = 1.0
alpha = 2.0
beta = 5.0           # previously tried: 2.0
theta = 1.0/4.0

# gamma is set to one of: 1. '0'  2. 'infinity'  3. a numerical positive value
# (numerical values tried so far: 0.25, 0.5, 1.0, 5.0)
gamma = 'infinity'

#added
lambda1 = 0.0        # previously tried: 10.0


print('---- Input parameters: -----')
print('mu1: ', mu1)
print('rho1: ', rho1)
print('alpha: ', alpha)
print('beta: ', beta)
print('theta: ', theta)
print('gamma:', gamma)

print('lambda1: ', lambda1)
print('----------------------------')
# ----------------------------------------------------------------

# Disabled alternative: sweep a numerical gamma range instead of the string cases.
#   gamma_min = 0.5      (also tried: gamma_min = gamma_max = 1)
#   gamma_max = 1.0
#   Gamma_Values = np.linspace(gamma_min, gamma_max, num=3)     # also tried: num=13
#   TODO variable Input Parameters...alpha,beta...

print('type of gamma:', type(gamma))

# List of gamma values to run; other choices used before: ['0', 'infinity'] and ['0']
Gamma_Values = ['infinity']
print('(Input) Gamma_Values:', Gamma_Values)

# ------- Options --------
# These settings do not depend on gamma, so they are fixed once before the loop.
# generalCase = True: Read Output from Cell-Problem instead of using Lemma1.4 (special case)   #TODO
generalCase = False

make_2D_plot = False
make_3D_plot = False
make_3D_PhaseDiagram = False    # NOTE: the 3D phase diagram code in the loop is currently commented out
make_2D_PhaseDiagram = True

# ---------------------- MAKE PLOT / Write to VTK ----------------------
# Number of sample points in each direction
SamplePoints_3D = 100   # previously tried: 10, 150, 200, 300, 400
SamplePoints_2D = 500   # previously tried: 10, 4000, 7500

for gamma in Gamma_Values:

    print('Run for gamma = ', gamma)
    print('type of gamma:', type(gamma))

    # Stays None unless a phase diagram is computed in this run; the plotting
    # stage uses it to detect that there is nothing to draw.
    angles = None

    # Disabled quick check of a single mu_gamma value:
    # muGamma = GetMuGamma(beta,theta,gamma,mu1,rho1,InputFilePath)
    # # muGamma = GetMuGamma(beta,theta,gamma,mu1,rho1)
    # print('Test MuGamma:', muGamma)

    # Disabled 3D phase diagram (kept for reference):
    # if make_3D_PhaseDiagram:
        # alphas_ = np.linspace(-20, 20, SamplePoints_3D)
        # # alphas_ = np.linspace(-10, 10, SamplePoints_3D)
        # # betas_  = np.linspace(0.01,40.01,SamplePoints_3D) # Full Range
        # # betas_  = np.linspace(0.01,20.01,SamplePoints_3D) # FULL Range
        # # betas_  = np.linspace(0.01,0.99,SamplePoints_3D)  # weird part
        # betas_  = np.linspace(1.01,40.01,SamplePoints_3D)     #TEST !!!!!  For Beta <1 weird things happen...
        # thetas_ = np.linspace(0.01,0.99,SamplePoints_3D)
        #
        # alphas, betas, thetas = np.meshgrid(alphas_, betas_, thetas_, indexing='ij')
        # classifyMin_anaVec = np.vectorize(classifyMin_ana)
        #
        # # Get MuGamma values ...
        # GetMuGammaVec = np.vectorize(GetMuGamma)
        # muGammas = GetMuGammaVec(betas, thetas, gamma, mu1, rho1)
        # # Classify Minimizers....
        # G, angles, Types, curvature = classifyMin_anaVec(alphas, betas, thetas, muGammas,  mu1, rho1)   # Sets q12 to zero!!!
        #
        # # G, angles, Types, curvature = classifyMin_anaVec(alphas, betas, thetas, muGammas,  mu1, rho1, True, True)
        # # print('size of G:', G.shape)
        # # print('G:', G)
        #
        # # Option to print angles
        # # print('angles:', angles)
        #
        # # --- Write to VTK
        # GammaString = str(gamma)
        # VTKOutputName = "outputs/PhaseDiagram3D" + "Gamma" + GammaString
        # gridToVTK(VTKOutputName , alphas, betas, thetas, pointData = {'Type': Types, 'angles': angles, 'curvature': curvature} )
        # print('Written to VTK-File:', VTKOutputName )

    if make_2D_PhaseDiagram:
        # alpha is held fixed; the diagram is sampled over beta and theta.
        # Other alpha choices used before: -0.5, -3.0, or ranges such as
        # np.linspace(-20, 20, SamplePoints_2D), np.linspace(-5, 15, SamplePoints_2D).
        alphas_ = 5.0
        thetas_ = np.linspace(0.01, 0.99, SamplePoints_2D)
        # NOTE(review): betas_ is sampled with SamplePoints_3D although thetas_ uses
        # SamplePoints_2D - possibly a leftover from the 3D setup. Confirm before
        # changing, since it alters the grid resolution.
        betas_ = np.linspace(1.01, 10.01, SamplePoints_3D)     #TEST !!!!!  For Beta <1 weird things happen...

        alphas, betas, thetas = np.meshgrid(alphas_, betas_, thetas_, indexing='ij')

        # Element-wise wrappers around the scalar helper functions.
        harmonicMeanVec = np.vectorize(harmonicMean)
        arithmeticMeanVec = np.vectorize(arithmeticMean)
        prestrain_b1Vec = np.vectorize(prestrain_b1)
        prestrain_b2Vec = np.vectorize(prestrain_b2)
        GetMuGammaVec = np.vectorize(GetMuGamma)
        classifyMin_anaVec = np.vectorize(classifyMin_ana)

        # mu_gamma is evaluated exactly once per grid point and reused for the
        # classification below (it used to be recomputed with identical arguments).
        muGammas = GetMuGammaVec(betas, thetas, gamma, mu1, rho1, InputFilePath, OutputFilePath)

        q1 = (1.0/6.0)*harmonicMeanVec(mu1, betas, thetas)
        q2 = (1.0/6.0)*arithmeticMeanVec(mu1, betas, thetas)

        b1 = prestrain_b1Vec(rho1, betas, alphas, thetas)
        b2 = prestrain_b2Vec(rho1, betas, alphas, thetas)

        elasticRatio = q1/q2
        prestrainRatio = b1/b2

        print('type( q1) :', type(q1))

        print('q1:', q1)
        print('q2:', q2)
        print('q1/q2:', q1/q2)

        print('prestrain ratio b1/b2:', prestrainRatio)
        print('max prestrain ratio:', np.max(prestrainRatio))
        print('min prestrain ratio:', np.min(prestrainRatio))

        # NOTE(review): the gamma suffix is disabled, so runs for several gamma
        # values would presumably write to the same output name - confirm.
        GammaString = str(gamma)
        VTKOutputName = "outputs/ElasticRatio" #+ "Gamma_" + GammaString

        G, angles, Types, curvature = classifyMin_anaVec(alphas, betas, thetas, muGammas, mu1, rho1)    # Sets q12 to zero!!!

        # --- Write to VTK
        gridToVTK(VTKOutputName, alphas, betas, thetas,
                  pointData={'elasticRatio': elasticRatio,
                             'prestrainRatio': prestrainRatio,
                             'Type': Types,
                             'angles': angles,
                             'curvature': curvature})
        print('Written to VTK-File:', VTKOutputName )

    # --- Make 3D Scatter plot
    if make_3D_plot or make_2D_plot:
        if angles is None:
            # Without phase diagram data the plot would fail with a NameError on
            # the undefined grids, so it is skipped with an explicit message.
            print('No phase diagram data available for gamma =', gamma, '- skipping plot.')
        else:
            fig = plt.figure()
            ax = fig.add_subplot(111, projection='3d')
            # Alternative colouring by type instead of angle:
            # if make_2D_plot: pnt3d=ax.scatter(alphas,thetas,c=Types.flat)
            # if make_3D_plot: pnt3d=ax.scatter(alphas,betas,thetas,c=Types.flat)
            if make_2D_plot: pnt3d = ax.scatter(alphas, thetas, c=angles.flat)
            if make_3D_plot: pnt3d = ax.scatter(alphas, betas, thetas, c=angles.flat)
            # cbar=plt.colorbar(pnt3d)
            # cbar.set_label("Values (units)")
            plt.axvline(x = 8, color = 'b', linestyle = ':', label='$q_1$')
            plt.axhline(y = 0.083333333, color = 'b', linestyle = ':', label='$q_1$')

            ax.set_xlabel('alpha')
            if make_3D_plot:
                ax.set_ylabel('beta')
                ax.set_zlabel('theta')
            else:
                # The 2D scatter plots alphas against thetas.
                ax.set_ylabel('theta')
            plt.show()





# ALTERNATIVE
# colors = ("red", "green", "blue")
# groups = ("Type 1", "Type2", "Type3")
#
# # Create plot
# fig = plt.figure()
# ax = fig.add_subplot(1, 1, 1)
#
# for data, color, group in zip(Types, colors, groups):
#     # x, y = data
#     ax.scatter(alphas, thetas, alpha=0.8, c=color, edgecolors='none', label=group)
#
# plt.title('Matplot scatter plot')
# plt.legend(loc=2)
# plt.show()