#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jul  6 13:17:28 2022

@author: stefan
"""
import os

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors

from helper_functions import *


def energy(kappa, alpha, Q, B):
    """Evaluate the quadratic energy G^T Q G with G = kappa * v(alpha) - B.

    v(alpha) is the column vector
    (cos(alpha)^2, sin(alpha)^2, sqrt(2) * cos(alpha) * sin(alpha))^T.

    Parameters
    ----------
    kappa : float
        Scalar curvature (may be negative).
    alpha : float
        Angle in radians.
    Q : ndarray, shape (3, 3)
        Matrix of the quadratic form.
    B : ndarray, shape (3, 1)
        Column vector subtracted from kappa * v(alpha).

    Returns
    -------
    float
        The single entry of the 1x1 product G^T (Q G).
    """
    cos_a = np.cos(alpha)
    sin_a = np.sin(alpha)
    direction = np.array([[cos_a**2], [sin_a**2], [np.sqrt(2) * cos_a * sin_a]])
    residual = kappa * direction - B
    # Evaluate Q @ residual first, matching the original evaluation order.
    return (residual.T @ (Q @ residual))[0, 0]

def xytokappaalpha(x, y):
    """Map a Cartesian plot point (x, y) to a curvature pair [kappa, alpha].

    |kappa| is the distance sqrt(x^2 + y^2) from the origin, and
    alpha = |atan2(y, x)| lies in [0, pi]. Points strictly above the x-axis
    (y > 0) get a positive kappa. All other points (y <= 0, including the
    x-axis itself) get a negative kappa. The lower half-plane is therefore
    the mirror image of the upper one, with the sign of the curvature flipped.

    Returns
    -------
    list
        [kappa, alpha]
    """
    radius = np.sqrt(x**2 + y**2)
    angle = np.abs(np.arctan2(y, x))
    kappa = radius if y > 0 else -radius
    return [kappa, angle]


#
# Choose the quantities used by energy():
#   Q -- 3x3 matrix of the quadratic form G^T Q G
#   B -- 3x1 column vector subtracted from the curvature vector
# Cases 0-4 are hard-coded examples; case -1 reads Q and B from output files.
case = -1


def _assemble_Q(q1, q2, q12, q3):
    """Return the matrix [[q1, q12/2, 0], [q12/2, q2, 0], [0, 0, q3]].

    The coupling coefficient q12 is split evenly between the two
    off-diagonal entries, so the result is symmetric.
    """
    return np.array([[q1, q12 / 2, 0], [q12 / 2, q2, 0], [0, 0, q3]])


if case == 0:  # Example: Figure 10 (c)
    q1 = 1; q2 = 2; q12 = 1 / 2; q3 = 0.5 * (np.sqrt(4 * q1 * q2) - q12)
    Q = _assemble_Q(q1, q2, q12, q3)
    B = np.array([[0.491], [0.347], [0]])
elif case == 1:  # Example: Figure 10 (b)
    q1 = 1; q2 = 2; q12 = 0; q3 = 1
    Q = _assemble_Q(q1, q2, q12, q3)
    B = np.array([[2], [1.5], [0]])
elif case == 2:  # Example: Figure 10 (a)
    q1 = 1; q2 = 2; q12 = 0; q3 = 1
    Q = _assemble_Q(q1, q2, q12, q3)
    B = np.array([[0], [0.5], [0]])
elif case == 3:  # Homogeneous case
    q1 = 1; q2 = 1; q12 = 0; q3 = 1
    Q = _assemble_Q(q1, q2, q12, q3)
    B = np.array([[.5], [.5], [0]])
elif case == 4:  # Section 6.1
    theta = 0.2
    rho = 1
    q1 = 1 / (3 * (2 - theta)); q2 = (1 + theta) / 6; q12 = 0; q3 = q1
    Q = _assemble_Q(q1, q2, q12, q3)
    B = np.array([[rho * 3 / 2 * (1 - theta)],
                  [rho * 3 / 2 * (1 - theta) / (1 + theta)],
                  [0]])
elif case == -1:  # Read from output files
    # Assumes the files live in an "outputs" folder that is a sibling of the
    # current working directory.
    DataPath = os.path.join(os.path.dirname(os.getcwd()), 'outputs')
    # DataPath = '/home/stefan/DUNE/dune-microstructure/outputs/Results/4x4/1'
    QFilePath = os.path.join(DataPath, 'QMatrix.txt')
    BFilePath = os.path.join(DataPath, 'BMatrix.txt')
    # NOTE(review): ReadEffectiveQuantities presumably comes from the
    # helper_functions star import -- confirm.
    Q, B = ReadEffectiveQuantities(QFilePath, BFilePath)
    # Only the symmetric part of Q contributes to G^T Q G, so drop any
    # antisymmetric part.
    Q = 0.5 * (np.transpose(Q) + Q)
    # Reshape B into a column vector; assumes B is read as a flat length-3
    # array -- TODO confirm.
    B = np.transpose([B])
else:
    # Fail here rather than with a NameError on Q/B further down.
    raise ValueError(f"Unknown case {case!r}; expected one of -1, 0, 1, 2, 3, 4")
#

#
# Sample the energy on an N x N grid covering [-length/2, length/2) in both
# x and y, and show it as a colour map. Uses energy(), xytokappaalpha(),
# Q and B defined above.
length = 0.4
N = 200
h = length / N
# Coordinates (i - N/2) * h for i = 0..N-1.
# Axis 0 of X/Y/E runs along x, axis 1 along y.
coords = (np.arange(N) - N / 2) * h
X, Y = np.meshgrid(coords, coords, indexing='ij')
E = np.zeros([N, N])
for i, j in np.ndindex(N, N):
    kappa, alpha = xytokappaalpha(X[i, j], Y[i, j])
    E[i, j] = energy(kappa, alpha, Q, B)

fig = plt.figure(figsize=(7, 6))
ax = plt.gca()
ax.set_aspect(1)
ax.set_xticks([-length / 4, 0, length / 4])
ax.set_yticks([])
# PowerNorm with gamma < 1 stretches the low end of the colour scale, so
# energy minima stand out.
# X and Y have the same shape as E, so request shading='auto' (nearest)
# explicitly. Matplotlib < 3.5 defaults to 'flat', which drops the last
# row/column of E.
pcm = plt.pcolor(X, Y, E, norm=colors.PowerNorm(gamma=0.2), cmap='brg',
                 shading='auto')
plt.colorbar(pcm, extend='max')
# TODO: axis labels should span [-h*N/2, h*N/2].
# TODO: draw the circle of radius 1.
plt.show()
#