import numpy as np
import matplotlib.pyplot as plt
import sympy as sym
import math
import os
import subprocess
import fileinput
import re
import matlab.engine
import sys
from ClassifyMin import *
from HelperFunctions import *
# from CellScript import *
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.cm as cm
from vtk.util import numpy_support
from pyevtk.hl import gridToVTK

import time

from chart_studio import plotly
import plotly.graph_objs as go

import mayavi.mlab as mlab
from mayavi.api import OffScreenEngine
from mayavi.mlab import *
import tvtk

import scipy.signal

import matplotlib as mpl
from matplotlib.ticker import MultipleLocator,FormatStrFormatter,MaxNLocator
import pandas as pd

import seaborn as sns
import matplotlib.colors as mcolors

from mpl_toolkits.axes_grid1.inset_locator import inset_axes, zoomed_inset_axes
# mlab.options.offscreen = True

# print(sys.executable)

# --------------------------------------------------------------------
# START :
# INPUT (Parameters):   alpha, beta, theta, gamma, mu1, rho1
#
# -Option 1 : (Case lambda = 0 => q12 = 0)
#   compute q1,q2,b1,b2 from Formula
#       Option 1.1 :
#           set mu_gamma = 'q1' or 'q2' (extreme regimes: gamma \in {0,\infty})
#       Option 1.2 :
#           compute mu_gamma with 'Compute_MuGamma' (a 2D problem, much faster than the Cell-Problem)
# -Option 2 :
#   compute Q_hom & B_eff by running 'Cell-Problem'
#
# -> CLASSIFY ...
#
# OUTPUT: Minimizer G, angle , type, curvature
# -----------------------------------------------------------------------
#
#
# def GetMuGamma(beta,theta,gamma,mu1,rho1, InputFilePath = os.path.dirname(os.getcwd()) +"/inputs/computeMuGamma.parset",
#                 OutputFilePath = os.path.dirname(os.getcwd()) + "/outputs/outputMuGamma.txt" ):
#     # ------------------------------------ get mu_gamma ------------------------------
#     # ---Scenario 1.1: extreme regimes
#     if gamma == '0':
#         print('extreme regime: gamma = 0')
#         mu_gamma = (1.0/6.0)*arithmeticMean(mu1, beta, theta) # = q2
#         print("mu_gamma:", mu_gamma)
#     elif gamma == 'infinity':
#         print('extreme regime: gamma = infinity')
#         mu_gamma = (1.0/6.0)*harmonicMean(mu1, beta, theta)   # = q1
#         print("mu_gamma:", mu_gamma)
#     else:
#         # --- Scenario 1.2:  compute mu_gamma with 'Compute_MuGamma' (much faster than running full Cell-Problem)
#         # print("Run computeMuGamma for Gamma = ", gamma)
#         with open(InputFilePath, 'r') as file:
#             filedata = file.read()
#         filedata = re.sub('(?m)^gamma=.*','gamma='+str(gamma),filedata)
#         # filedata = re.sub('(?m)^alpha=.*','alpha='+str(alpha),filedata)
#         filedata = re.sub('(?m)^beta=.*','beta='+str(beta),filedata)
#         filedata = re.sub('(?m)^theta=.*','theta='+str(theta),filedata)
#         filedata = re.sub('(?m)^mu1=.*','mu1='+str(mu1),filedata)
#         filedata = re.sub('(?m)^rho1=.*','rho1='+str(rho1),filedata)
#         f = open(InputFilePath,'w')
#         f.write(filedata)
#         f.close()
#         # --- Run Cell-Problem
#
#         # Check Time
#         # t = time.time()
#         # subprocess.run(['./build-cmake/src/Cell-Problem', './inputs/cellsolver.parset'],
#         #                                      capture_output=True, text=True)
#         # --- Run Cell-Problem_muGama   -> faster
#         # subprocess.run(['./build-cmake/src/Cell-Problem_muGamma', './inputs/cellsolver.parset'],
#         #                                              capture_output=True, text=True)
#         # --- Run Compute_muGamma (2D Problem much much faster)
#
#         subprocess.run(['./build-cmake/src/Compute_MuGamma', './inputs/computeMuGamma.parset'],
#                                                              capture_output=True, text=True)
#         # print('elapsed time:', time.time() - t)
#
#         #Extract mu_gamma from Output-File                                           TODO: GENERALIZED THIS FOR QUANTITIES OF INTEREST
#         with open(OutputFilePath, 'r') as file:
#             output = file.read()
#         tmp = re.search(r'(?m)^mu_gamma=.*',output).group()                           # Not necessary for Intention of Program t output Minimizer etc.....
#         s = re.findall(r"[-+]?\d*\.\d+|\d+", tmp)
#         mu_gamma = float(s[0])
#         # print("mu_gamma:", mu_gammaValue)
#     # --------------------------------------------------------------------------------------
#     return mu_gamma
#



# ----------- SETUP PATHS
# Locations, relative to the project root, of the parameter file passed to
# Compute_MuGamma and of the file its mu_gamma result is read from.
# Alternative (full cell problem):
#   InputFile  = "/inputs/cellsolver.parset"
#   OutputFile = "/outputs/output.txt"
InputFile = "/inputs/computeMuGamma.parset"
OutputFile = "/outputs/outputMuGamma.txt"

# --------- Run from src folder:
# Step one directory up so that relative paths such as "outputs/..." used
# further down resolve against the project root.
path_parent = os.path.dirname(os.getcwd())
os.chdir(path_parent)
path = os.getcwd()
print(path)

# `path` is the current working directory right after the chdir above.
InputFilePath, OutputFilePath = path + InputFile, path + OutputFile
for _label, _value in (("InputFilepath: ", InputFilePath),
                       ("OutputFilepath: ", OutputFilePath),
                       ("Path: ", path)):
    print(_label, _value)


# -------------------------- Input Parameters --------------------
# TODO: these must be the same values as in the parset read by Compute_MuGamma.
# mu1 = 10.0
mu1 = 1.0
rho1 = 1.0          # Test: rho1 = -1.0
alpha = 2.0
beta = 2.0          # alternative: beta = 5.0
theta = 1.0/4.0
# gamma is either '0', 'infinity' (extreme regimes) or a positive number,
# e.g. 0.5, 0.25, 1.0, 5.0. The sweep loop below rebinds it from Gamma_Values.
gamma = 'infinity'
# lambda1 = 10.0
lambda1 = 0.0

print('---- Input parameters: -----')
for _label, _value in (('mu1: ', mu1), ('rho1: ', rho1), ('alpha: ', alpha),
                       ('beta: ', beta), ('theta: ', theta), ('gamma:', gamma),
                       ('lambda1: ', lambda1)):
    print(_label, _value)
print('----------------------------')
# ----------------------------------------------------------------

print('type of gamma:', type(gamma))
# Gamma values swept by the main loop. Other choices used:
#   np.linspace(0.5, 1.0, num=3), ['0', 'infinity'], ['0']
Gamma_Values = ['infinity']
print('(Input) Gamma_Values:', Gamma_Values)

# Sweep over the requested gamma values. For each gamma: optionally build the
# 3D (alpha, beta, theta) and/or 2D (alpha, theta at fixed beta) phase
# diagram, write it to a VTK file under outputs/, and optionally plot it.
for gamma in Gamma_Values:

    print('Run for gamma = ', gamma)
    print('type of gamma:', type(gamma))

    # ------- Options --------
    # generalCase = True  -> read Q_hom / B_eff from the Cell-Problem output
    #                        (GetCellOutput + classifyMin_mat)
    # generalCase = False -> analytical classification via classifyMin_ana
    #                        (special case of Lemma 1.4, q12 set to zero)
    # generalCase = True
    generalCase = False

    make_3D_plot = True
    make_3D_PhaseDiagram = True
    make_2D_plot = False
    make_2D_PhaseDiagram = False
    # make_3D_plot = False
    # make_3D_PhaseDiagram = False
    # make_2D_plot = True
    # make_2D_PhaseDiagram = True

    # The plotting section reads arrays (alphas, thetas, angles, ...) that
    # only a phase-diagram computation defines. Fail up front with a clear
    # message instead of a NameError after the computation.
    if (make_3D_plot or make_2D_plot) and not (make_3D_PhaseDiagram or make_2D_PhaseDiagram):
        raise ValueError('Plotting requires make_3D_PhaseDiagram or make_2D_PhaseDiagram to be enabled.')

    # ---------------------- MAKE PLOT / Write to VTK ----------------------
    # Number of sample points in each direction of the 3D diagram
    # (other values used: 10, 25, 100, 150, 200, 300, 400).
    SamplePoints_3D = 50
    # Number of sample points in each direction of the 2D diagram; must be
    # defined whenever make_2D_PhaseDiagram is enabled
    # (other values used: 10, 1000, 4000, 7500).
    SamplePoints_2D = 400

    print('NUMBER OF POINTS USED(3D):', SamplePoints_3D)

    if make_3D_PhaseDiagram:
        alphas_ = np.linspace(-20, 20, SamplePoints_3D)
        # alphas_ = np.linspace(-10, 10, SamplePoints_3D)
        # betas_  = np.linspace(0.01,40.01,SamplePoints_3D) # Full Range
        # betas_  = np.linspace(0.01,20.01,SamplePoints_3D) # FULL Range
        # betas_  = np.linspace(0.01,0.99,SamplePoints_3D)  # weird part
        betas_  = np.linspace(1.01,40.01,SamplePoints_3D)   # TEST: for beta < 1 weird things happen...
        thetas_ = np.linspace(0.01,0.99,SamplePoints_3D)
        # TEST:
        # alphas_ = np.linspace(-2, 2, SamplePoints_3D)
        # betas_  = np.linspace(1.01,10.01,SamplePoints_3D)
        # alphas_ = np.linspace(-40, 40, SamplePoints_3D)
        # betas_  = np.linspace(0.01,80.01,SamplePoints_3D) # Full Range

        # 'ij' indexing: axis 0 <-> alpha, axis 1 <-> beta, axis 2 <-> theta.
        alphas, betas, thetas = np.meshgrid(alphas_, betas_, thetas_, indexing='ij')
        classifyMin_anaVec = np.vectorize(classifyMin_ana)

        # GetMuGamma only receives (beta, theta, gamma, mu1, rho1), so evaluate
        # it once per (beta, theta) pair (the alpha = alphas_[0] slice) and
        # broadcast along the alpha axis. For numeric gamma each call
        # presumably runs the external Compute_MuGamma solver (see the
        # commented-out reference GetMuGamma at the top of the file).
        # Paths are passed explicitly, as in the 2D branch.
        GetMuGammaVec = np.vectorize(GetMuGamma)
        muGammas = np.broadcast_to(
            GetMuGammaVec(betas[0], thetas[0], gamma, mu1, rho1, InputFilePath, OutputFilePath),
            alphas.shape)

        # Classify Minimizers....
        G, angles, Types, curvature = classifyMin_anaVec(alphas, betas, thetas, muGammas,  mu1, rho1)   # Sets q12 to zero!!!
        # With case/output printing:
        # G, angles, Types, curvature = classifyMin_anaVec(alphas, betas, thetas, muGammas,  mu1, rho1, True, True)

        # --- Write to VTK
        GammaString = str(gamma)
        VTKOutputName = "outputs/PhaseDiagram3D" + "Gamma" + GammaString
        gridToVTK(VTKOutputName , alphas, betas, thetas, pointData = {'Type': Types, 'angles': angles, 'curvature': curvature} )
        print('Written to VTK-File:', VTKOutputName )

    if make_2D_PhaseDiagram:
        thetas_ = np.linspace(0.01,0.99,SamplePoints_2D)
        alphas_ = np.linspace(-5, 5, SamplePoints_2D)
        # alphas_ = np.linspace(-20, 20, SamplePoints_2D)
        # alphas_ = np.linspace(-5, 15, SamplePoints_2D)
        # thetas_ = np.linspace(0.05,0.25,SamplePoints_2D)
        # good range:
        # alphas_ = np.linspace(9, 10, SamplePoints_2D)
        # thetas_ = np.linspace(0.075,0.14,SamplePoints_2D)
        # range used:
        # alphas_ = np.linspace(8, 10, SamplePoints_2D)
        # thetas_ = np.linspace(0.05,0.16,SamplePoints_2D)
        # intermediate values:
        # alphas_ = np.linspace(-2, 1, SamplePoints_2D)
        # thetas_ = np.linspace(0.4,0.6,SamplePoints_2D)

        # beta is fixed to one value (other values used: 10.0, 5.0, 0.5, 0.9).
        # NOTE: betas_ = 1.0 gave a division by zero at alpha = 9, theta = 0.1.
        betas_ = 2.0
        # The scalar beta yields a singleton beta axis: shape (N, 1, N).
        alphas, betas, thetas = np.meshgrid(alphas_, betas_, thetas_, indexing='ij')

        if generalCase:
            classifyMin_matVec = np.vectorize(classifyMin_mat)
            GetCellOutputVec = np.vectorize(GetCellOutput, otypes=[np.ndarray, np.ndarray])
            Q, B = GetCellOutputVec(alphas,betas,thetas,gamma,mu1,rho1,lambda1, InputFilePath ,OutputFilePath )
            G, angles, Types, curvature = classifyMin_matVec(Q,B)

        else:
            classifyMin_anaVec = np.vectorize(classifyMin_ana)
            GetMuGammaVec = np.vectorize(GetMuGamma)
            # Evaluate mu_gamma once per (beta, theta) and broadcast along alpha
            # (see the 3D branch).
            muGammas = np.broadcast_to(
                GetMuGammaVec(betas[0], thetas[0], gamma, mu1, rho1, InputFilePath, OutputFilePath),
                alphas.shape)
            G, angles, Types, curvature = classifyMin_anaVec(alphas,betas,thetas, muGammas,  mu1, rho1)    # Sets q12 to zero!!!

        # --- Write to VTK
        GammaString = str(gamma)
        VTKOutputName = "outputs/PhaseDiagram2D" + "Gamma_" + GammaString
        gridToVTK(VTKOutputName , alphas, betas, thetas, pointData = {'Type': Types, 'angles': angles, 'curvature': curvature} )
        print('Written to VTK-File:', VTKOutputName )

    # --- Make 3D Scatter plot
    if make_3D_plot or make_2D_plot:
        # Figure size in inches; needed by fig.set_size_inches below in both
        # plot modes, so it is defined outside the 3D-only branch.
        width = 6.28 *0.5
        # width = 6.28
        # height = width / 1.618
        height = width

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')

        if make_2D_plot:
            pnt3d = ax.scatter(alphas,thetas,c=angles.flat)
            # pnt3d = ax.scatter(alphas,thetas,c=Types.flat)

        if make_3D_plot:
            # Matplotlib 3.6 renamed the "seaborn" style to "seaborn-v0_8" and
            # 3.8 removed the old name; use whichever this install provides.
            plt.style.use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn")

            mpl.rcParams['text.usetex'] = True
            mpl.rcParams["font.family"] = "serif"
            mpl.rcParams["font.size"] = "10"
            mpl.rcParams['xtick.bottom'] = True
            mpl.rcParams['xtick.major.size'] = 3
            mpl.rcParams['xtick.minor.size'] = 1.5
            mpl.rcParams['xtick.major.width'] = 0.75
            mpl.rcParams['ytick.left'] = True
            mpl.rcParams['ytick.major.size'] = 3
            mpl.rcParams['ytick.minor.size'] = 1.5
            mpl.rcParams['ytick.major.width'] = 0.75

            mpl.rcParams.update({'font.size': 10})
            mpl.rcParams['axes.labelpad'] = 3

            # Colormap used in earlier scatter experiments (not used currently):
            # cmap = mpl.cm.RdBu_r
            # S = ax.scatter(alphas, betas, thetas, c=angles.flatten(), cmap=cmap)
            # ax.view_init(elev=30, azim=75)

            # ---- MAYAVI: iso-surfaces of the angle field on the
            # (alpha, beta, theta) grid. Restricted to the 3D plot because the
            # 2D grid has a singleton beta axis.
            # NOTE(review): the mayavi scene is neither shown nor saved here;
            # enable mlab.show() (or mlab.savefig) to inspect it.
            mlab.contour3d(alphas, betas, thetas, angles)
            # mlab.colorbar( orientation='vertical', nb_labels=5)
            # mlab.show()

        fig.subplots_adjust(right=0.85)

        ax.set_xlabel(r'$\theta_\rho$', labelpad=2)
        ax.set_ylabel(r"$\theta_\mu$", labelpad=2)
        if make_3D_plot:
            ax.set_zlabel(r'$\theta$',labelpad=2)

        fig.set_size_inches(width, height)
        fig.savefig('PhaseDiagram3D.png', format='png', dpi=300)
        # fig.savefig('PhaseDiagram3D.pdf')
        # plt.show()





# ALTERNATIVE
# colors = ("red", "green", "blue")
# groups = ("Type 1", "Type2", "Type3")
#
# # Create plot
# fig = plt.figure()
# ax = fig.add_subplot(1, 1, 1)
#
# for data, color, group in zip(Types, colors, groups):
#     # x, y = data
#     ax.scatter(alphas, thetas, alpha=0.8, c=color, edgecolors='none', label=group)
#
# plt.title('Matplot scatter plot')
# plt.legend(loc=2)
# plt.show()