import numpy as np
import matplotlib.pyplot as plt
import sympy as sym
import math
import os
import subprocess
import fileinput
import re
import matlab.engine
import sys
# from ClassifyMin import *
from ClassifyMin_New import *
from HelperFunctions import *
# from CellScript import *
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.cm as cm
from vtk.util import numpy_support
from pyevtk.hl import gridToVTK

import time

from chart_studio import plotly
import plotly.graph_objs as go

import mayavi.mlab as mlab
from mayavi.api import OffScreenEngine

import scipy.signal

import matplotlib as mpl
from matplotlib.ticker import MultipleLocator,FormatStrFormatter,MaxNLocator
import pandas as pd

import seaborn as sns
import matplotlib.colors as mcolors

from mpl_toolkits.axes_grid1.inset_locator import inset_axes
# mlab.options.offscreen = True

# print(sys.executable)

# --------------------------------------------------------------------
# START :
# INPUT (Parameters):   alpha, beta, theta, gamma, mu1, rho1
#
# -Option 1 : (Case lambda = 0 => q12 = 0)
#   compute q1,q2,b1,b2 from Formula
#       Option 1.1 :
#           set mu_gamma = 'q1' or 'q2' (extreme regimes: gamma \in {0,\infty})
#       Option 1.2 :
#           compute mu_gamma with 'Compute_MuGamma' (2D problem, much faster than the Cell-Problem)
# -Option 2 :
#   compute Q_hom & B_eff by running 'Cell-Problem'
#
# -> CLASSIFY ...
#
# OUTPUT: minimizer G, angle, type, curvature
# -----------------------------------------------------------------------
#
#
# def GetMuGamma(beta,theta,gamma,mu1,rho1, InputFilePath = os.path.dirname(os.getcwd()) +"/inputs/computeMuGamma.parset",
#                 OutputFilePath = os.path.dirname(os.getcwd()) + "/outputs/outputMuGamma.txt" ):
#     # ------------------------------------ get mu_gamma ------------------------------
#     # ---Scenario 1.1: extreme regimes
#     if gamma == '0':
#         print('extreme regime: gamma = 0')
#         mu_gamma = (1.0/6.0)*arithmeticMean(mu1, beta, theta) # = q2
#         print("mu_gamma:", mu_gamma)
#     elif gamma == 'infinity':
#         print('extreme regime: gamma = infinity')
#         mu_gamma = (1.0/6.0)*harmonicMean(mu1, beta, theta)   # = q1
#         print("mu_gamma:", mu_gamma)
#     else:
#         # --- Scenario 1.2:  compute mu_gamma with 'Compute_MuGamma' (much faster than running full Cell-Problem)
#         # print("Run computeMuGamma for Gamma = ", gamma)
#         with open(InputFilePath, 'r') as file:
#             filedata = file.read()
#         filedata = re.sub('(?m)^gamma=.*','gamma='+str(gamma),filedata)
#         # filedata = re.sub('(?m)^alpha=.*','alpha='+str(alpha),filedata)
#         filedata = re.sub('(?m)^beta=.*','beta='+str(beta),filedata)
#         filedata = re.sub('(?m)^theta=.*','theta='+str(theta),filedata)
#         filedata = re.sub('(?m)^mu1=.*','mu1='+str(mu1),filedata)
#         filedata = re.sub('(?m)^rho1=.*','rho1='+str(rho1),filedata)
#         f = open(InputFilePath,'w')
#         f.write(filedata)
#         f.close()
#         # --- Run Cell-Problem
#
#         # Check Time
#         # t = time.time()
#         # subprocess.run(['./build-cmake/src/Cell-Problem', './inputs/cellsolver.parset'],
#         #                                      capture_output=True, text=True)
#         # --- Run Cell-Problem_muGama   -> faster
#         # subprocess.run(['./build-cmake/src/Cell-Problem_muGamma', './inputs/cellsolver.parset'],
#         #                                              capture_output=True, text=True)
#         # --- Run Compute_muGamma (2D Problem much much faster)
#
#         subprocess.run(['./build-cmake/src/Compute_MuGamma', './inputs/computeMuGamma.parset'],
#                                                              capture_output=True, text=True)
#         # print('elapsed time:', time.time() - t)
#
#         #Extract mu_gamma from Output-File                                           TODO: GENERALIZED THIS FOR QUANTITIES OF INTEREST
#         with open(OutputFilePath, 'r') as file:
#             output = file.read()
#         tmp = re.search(r'(?m)^mu_gamma=.*',output).group()                           # Not necessary for Intention of Program t output Minimizer etc.....
#         s = re.findall(r"[-+]?\d*\.\d+|\d+", tmp)
#         mu_gamma = float(s[0])
#         # print("mu_gamma:", mu_gammaValue)
#     # --------------------------------------------------------------------------------------
#     return mu_gamma
#



# ----------- SETUP PATHS
# Parameter file read by Compute_MuGamma and the file it writes mu_gamma to.
# Both are appended to the working directory chosen below.
# Cell-Problem alternative:
#   InputFile  = "/inputs/cellsolver.parset"
#   OutputFile = "/outputs/output.txt"
InputFile, OutputFile = "/inputs/computeMuGamma.parset", "/outputs/outputMuGamma.txt"

# The script is meant to be launched from the src folder. Move one level up so
# that relative paths used later (e.g. "outputs/...") resolve from there.
path_parent = os.path.dirname(os.getcwd())
os.chdir(path_parent)
path = os.getcwd()
print(path)
InputFilePath = path + InputFile
OutputFilePath = path + OutputFile
for _label, _value in (("InputFilepath: ", InputFilePath),
                       ("OutputFilepath: ", OutputFilePath),
                       ("Path: ", path)):
    print(_label, _value)


# -------------------------- Input Parameters --------------------
# TODO: these values must match the ones in the parset used by Compute_MuGamma.
mu1 = 1.0            # alternative used: 10.0
rho1 = 1.0           # test value used: -1.0
alpha = 2.0
beta = 2.0           # alternative used: 5.0
theta = 1.0 / 4.0
# gamma is either the string '0', the string 'infinity' (extreme regimes) or a
# positive number (values used: 0.25, 0.5, 1.0, 5.0).
gamma = '0'
lambda1 = 0.0        # alternative used: 10.0

print('---- Input parameters: -----')
for _label, _value in (
    ('mu1: ', mu1),
    ('rho1: ', rho1),
    ('alpha: ', alpha),
    ('beta: ', beta),
    ('theta: ', theta),
    ('gamma:', gamma),
    ('lambda1: ', lambda1),
):
    print(_label, _value)
print('----------------------------')
# ----------------------------------------------------------------

# Values of gamma swept by the main loop below; the loop variable rebinds the
# single `gamma` set above.
# Other sweeps used: ['0', 'infinity'], ['infinity'], ['infinity', '0'],
#   np.linspace(0.5, 1.0, num=3), np.linspace(gamma_min, gamma_max, num=13)
print('type of gamma:', type(gamma))
Gamma_Values = ['0']
print('(Input) Gamma_Values:', Gamma_Values)

for gamma in Gamma_Values:
    # One pass per gamma value: compute the requested phase diagram(s), write
    # them to VTK and optionally plot them. gamma is either '0', 'infinity'
    # (extreme regimes) or a positive number. It is forwarded unchanged to
    # GetMuGamma / GetCellOutput and used in the output file names.
    print('Run for gamma = ', gamma)
    print('type of gamma:', type(gamma))
    # muGamma = GetMuGamma(beta,theta,gamma,mu1,rho1,InputFilePath)
    # print('Test MuGamma:', muGamma)

    # ------- Options --------
    # generalCase = True  : classify with Q_hom / B_eff read from the Cell-Problem
    #                       output (GetCellOutput + classifyMin_mat).
    # generalCase = False : analytical special case (classifyMin_ana, q12 = 0).
    # Only the 2D phase diagram honours this flag; the 3D one is always analytical.
    generalCase = False

    # Each output toggle is assigned exactly once. make_3D_plot used to be set
    # to True and then silently overridden to False a few lines later.
    make_3D_PhaseDiagram = True
    make_2D_PhaseDiagram = False
    make_3D_plot = False
    make_2D_plot = False

    # ---------------------- Sampling ----------------------------------------
    # Number of sample points in each direction.
    # 3D values used before: 10, 25, 50, 150, 200, 300, 400.
    SamplePoints_3D = 100
    # 2D values used before: 10, 400, 1000, 4000, 7500. This used to be left
    # undefined (every assignment was commented out), so enabling
    # make_2D_PhaseDiagram raised a NameError.
    SamplePoints_2D = 400

    print('NUMBER OF POINTS USED(3D):', SamplePoints_3D)

    # ---------------------- 3D phase diagram (alpha x beta x theta) ---------
    if make_3D_PhaseDiagram:
        alphas_ = np.linspace(-5, 15, SamplePoints_3D)
        # The classification was observed to misbehave for beta < 1
        # ("weird things happen"), hence the sweep starts at 1.01.
        betas_ = np.linspace(1.01, 20.01, SamplePoints_3D)
        thetas_ = np.linspace(0.01, 0.99, SamplePoints_3D)
        # Other ranges used:
        #   alphas_ = np.linspace(-20, 20, SamplePoints_3D)
        #   alphas_ = np.linspace(-10, 10, SamplePoints_3D)
        #   alphas_ = np.linspace(-2, 2, SamplePoints_3D)
        #   alphas_ = np.linspace(-40, 40, SamplePoints_3D)
        #   betas_  = np.linspace(0.01, 40.01, SamplePoints_3D)  # full range
        #   betas_  = np.linspace(0.01, 20.01, SamplePoints_3D)  # full range
        #   betas_  = np.linspace(0.01, 80.01, SamplePoints_3D)  # full range
        #   betas_  = np.linspace(0.01, 0.99, SamplePoints_3D)   # beta < 1 region
        #   betas_  = np.linspace(1.01, 40.01, SamplePoints_3D)
        #   betas_  = np.linspace(1.01, 10.01, SamplePoints_3D)
        alphas, betas, thetas = np.meshgrid(alphas_, betas_, thetas_, indexing='ij')

        classifyMin_anaVec = np.vectorize(classifyMin_ana)
        GetMuGammaVec = np.vectorize(GetMuGamma)
        # Pass the parset/output paths explicitly, exactly as the 2D branch does,
        # instead of relying on GetMuGamma's defaults.
        muGammas = GetMuGammaVec(betas, thetas, gamma, mu1, rho1, InputFilePath, OutputFilePath)
        # Analytical classification (sets q12 to zero).
        G, angles, Types, curvature = classifyMin_anaVec(alphas, betas, thetas, muGammas, mu1, rho1)
        # Verbose variant:
        # G, angles, Types, curvature = classifyMin_anaVec(alphas, betas, thetas, muGammas, mu1, rho1, True, True)

        # --- Write to VTK
        GammaString = str(gamma)
        VTKOutputName = "outputs/PhaseDiagram3D" + "Gamma" + GammaString
        gridToVTK(VTKOutputName, alphas, betas, thetas,
                  pointData={'Type': Types, 'angles': angles, 'curvature': curvature})
        print('Written to VTK-File:', VTKOutputName)

    # ---------------------- 2D phase diagram (alpha x theta, beta fixed) ----
    if make_2D_PhaseDiagram:
        alphas_ = np.linspace(-5, 5, SamplePoints_2D)
        thetas_ = np.linspace(0.01, 0.99, SamplePoints_2D)
        betas_ = 2.0  # beta is held fixed, so the grid is alpha x 1 x theta
        # Other settings used:
        #   alphas_ = np.linspace(-20, 20, SamplePoints_2D)
        #   alphas_ = np.linspace(0, 1, SamplePoints_2D)
        #   alphas_ = np.linspace(-5, 15, SamplePoints_2D)
        #   thetas_ = np.linspace(0.05, 0.25, SamplePoints_2D)
        #   good range:   alphas_ in [9, 10], thetas_ in [0.075, 0.14]
        #   range used:   alphas_ in [8, 10], thetas_ in [0.05, 0.16]
        #                 alphas_ in [8, 12], thetas_ in [0.05, 0.2]
        #   intermediate: alphas_ in [-2, 1], thetas_ in [0.4, 0.6], betas_ = 10.0
        #   test:         alphas_ in [-8, 8], thetas_ in [0.01, 0.99]
        #   betas_ = 10.0, 5.0, 0.9, 0.5
        #   betas_ = 1.0  -> division by zero reported for alpha = 9, theta = 0.1
        alphas, betas, thetas = np.meshgrid(alphas_, betas_, thetas_, indexing='ij')

        if generalCase:
            classifyMin_matVec = np.vectorize(classifyMin_mat)
            GetCellOutputVec = np.vectorize(GetCellOutput, otypes=[np.ndarray, np.ndarray])
            Q, B = GetCellOutputVec(alphas, betas, thetas, gamma, mu1, rho1, lambda1,
                                    InputFilePath, OutputFilePath)
            G, angles, Types, curvature = classifyMin_matVec(Q, B)
        else:
            classifyMin_anaVec = np.vectorize(classifyMin_ana)
            GetMuGammaVec = np.vectorize(GetMuGamma)
            muGammas = GetMuGammaVec(betas, thetas, gamma, mu1, rho1, InputFilePath, OutputFilePath)
            # Analytical classification (sets q12 to zero).
            G, angles, Types, curvature = classifyMin_anaVec(alphas, betas, thetas, muGammas, mu1, rho1)

        # --- Write to VTK
        GammaString = str(gamma)
        VTKOutputName = "outputs/PhaseDiagram2D" + "Gamma_" + GammaString
        gridToVTK(VTKOutputName, alphas, betas, thetas,
                  pointData={'Type': Types, 'angles': angles, 'curvature': curvature})
        print('Written to VTK-File:', VTKOutputName)

    # ---------------------- Scatter plot -------------------------------------
    if make_3D_plot or make_2D_plot:
        # Both variants use alphas/betas/thetas/angles produced by a phase-diagram
        # run above in this iteration.
        # Figure size in inches. Defined here rather than inside the 3D branch so
        # that fig.set_size_inches below also works for 2D-only plots.
        width = 6.28 * 0.5
        height = width

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')

        if make_2D_plot:
            pnt3d = ax.scatter(alphas, thetas, c=angles.flat)
            # Colour by minimizer type instead: ax.scatter(alphas, thetas, c=Types.flat)

        if make_3D_plot:
            # The bare "seaborn" style was renamed "seaborn-v0_8" in matplotlib 3.6
            # and removed in 3.8; fall back to the new name when it is missing.
            try:
                plt.style.use("seaborn")
            except OSError:
                plt.style.use("seaborn-v0_8")

            mpl.rcParams['text.usetex'] = True
            mpl.rcParams["font.family"] = "serif"
            mpl.rcParams.update({'font.size': 10})
            mpl.rcParams['xtick.bottom'] = True
            mpl.rcParams['xtick.major.size'] = 3
            mpl.rcParams['xtick.minor.size'] = 1.5
            mpl.rcParams['xtick.major.width'] = 0.75
            mpl.rcParams['ytick.left'] = True
            mpl.rcParams['ytick.major.size'] = 3
            mpl.rcParams['ytick.minor.size'] = 1.5
            mpl.rcParams['ytick.major.width'] = 0.75
            mpl.rcParams['axes.labelpad'] = 1

            angles = angles.flatten()
            cmap = mpl.cm.RdBu_r
            # Alternatives tried: viridis_r, bwr, coolwarm, Blues_r, gnuplot,
            # magma_r, inferno_r, plasma(_r), cividis_r, and custom
            # blue-violet-red / blue-orange LinearSegmentedColormaps.

            # Map angles to [0, 1] for the colormap. Guard against an all-zero
            # field, where angles / angles.max() would be 0 / 0 = NaN.
            angle_max = angles.max()
            angle_scale = angle_max if angle_max != 0 else 1.0
            angles_normalized = angles / angle_scale

            # Opacity decreases with the angle: the point(s) with the largest
            # angle become fully transparent, those with the smallest fully opaque.
            opacity_list = 1 - angles_normalized
            print('opacity_list', opacity_list)
            print('opacity_list.max():', opacity_list.max())

            # N x 4 array of RGBA values, one row per sample point.
            colors = cmap(angles_normalized)
            print('colors:', colors)

            # The alpha channel must lie in [0, 1]; rescale so the most opaque point
            # has alpha 1. A constant angle field gives opacity_list == 0
            # everywhere (previously NaN alphas); draw it opaque instead.
            opacity_max = opacity_list.max()
            colors[:, -1] = opacity_list / opacity_max if opacity_max != 0 else 1.0

            S = ax.scatter(alphas, betas, thetas, c=colors)
            # Invisible scatter (s=0) that only serves as the mappable for the colorbar.
            S_2 = ax.scatter(alphas, betas, thetas, c=angles_normalized, cmap=cmap, s=0)
            ax.view_init(elev=25, azim=75)

            ## ADD COLORBAR:
            # Anchor the inset axes just right of the plot, in axes coordinates.
            # The previous bbox_to_anchor=[0.0, 0.5] was a 2-element list without a
            # transform, i.e. a zero-size box in display pixels. The "5%"/"100%"
            # sizes are relative to that box, so the colorbar axes had zero size.
            axins = inset_axes(ax,
                               width="5%",
                               height="100%",
                               loc='lower left',
                               bbox_to_anchor=(1.05, 0.0, 1, 1),
                               bbox_transform=ax.transAxes,
                               borderpad=0,
                               )
            cbar = fig.colorbar(S_2, cax=axins)
            # Put the ticks at the normalized positions of 0, pi/8, ..., pi/2 so the
            # labels below land on the matching values. Setting the labels alone
            # attached them to matplotlib's default 0, 0.2, 0.4, ... ticks.
            cbar.set_ticks(np.array([0, np.pi/8, np.pi/4, 3*np.pi/8, np.pi/2]) / angle_scale)
            cbar.ax.set_yticklabels([r'$0$', r'$\pi/8$', r'$\pi/4$', r'$3\pi/8$', r'$\pi/2$'])
            cbar.ax.set_title(r'$\alpha$')

        # ---- Alternative renderings (kept for reference) ----
        #### PLOTLY isosurface:
        # fig = go.Figure(data=go.Isosurface(
        #     x=alphas.flatten(), y=betas.flatten(), z=thetas.flatten(),
        #     value=angles.flatten(), isomin=0, isomax=1.565, opacity=1.0,
        #     colorscale='agsunset', flatshading=True))
        # fig.show()
        #### PLOTLY volume with Savitzky-Golay smoothing:
        # zhat = scipy.signal.savgol_filter(angles.flatten(), 5, 4)
        # fig = go.Figure(data=go.Volume(
        #     x=alphas.flatten(), y=betas.flatten(), z=thetas.flatten(),
        #     value=zhat, isomin=0.0, isomax=1.56,
        #     opacity=0.1,        # needs to be small to see through all surfaces
        #     surface_count=17,   # needs to be a large number for good volume rendering
        #     colorscale='RdBu'))
        # fig.show()
        #### MAYAVI:
        # s = angles
        # src = mlab.pipeline.scalar_field(s)
        # mlab.pipeline.iso_surface(src, contours=[s.min()+0.1*s.ptp(), ], opacity=0.3)
        # mlab.pipeline.iso_surface(src, contours=[s.max()-0.1*s.ptp(), ],)
        # mlab.colorbar(orientation='vertical', nb_labels=5)
        # mlab.show()

        ax.set_xlabel(r'$\theta_\rho$', labelpad=2)
        ax.set_ylabel(r"$\theta_\mu$", labelpad=2)
        if make_3D_plot:
            ax.set_zlabel(r'$\theta$', labelpad=2)

        fig.set_size_inches(width, height)
        # Other outputs used: 'PhaseDiagram3D.pdf', format='svg', dpi=90,
        # bbox_extra_artists=(cbar,), bbox_inches='tight'.
        fig.savefig('PhaseDiagram3D.png', format='png')
        plt.show()





# ALTERNATIVE
# colors = ("red", "green", "blue")
# groups = ("Type 1", "Type2", "Type3")
#
# # Create plot
# fig = plt.figure()
# ax = fig.add_subplot(1, 1, 1)
#
# for data, color, group in zip(Types, colors, groups):
#     # x, y = data
#     ax.scatter(alphas, thetas, alpha=0.8, c=color, edgecolors='none', label=group)
#
# plt.title('Matplot scatter plot')
# plt.legend(loc=2)
# plt.show()