#include <config.h>
#include <array>
#include <vector>
#include <fstream>

#include <iostream>
#include <dune/common/indices.hh>
#include <dune/common/bitsetvector.hh>
#include <dune/common/parametertree.hh>
#include <dune/common/parametertreeparser.hh>
#include <dune/common/float_cmp.hh>
#include <dune/common/math.hh>


#include <dune/geometry/quadraturerules.hh>

#include <dune/grid/uggrid.hh>
#include <dune/grid/yaspgrid.hh>
// #include <dune/grid/utility/structuredgridfactory.hh> //TEST
#include <dune/grid/io/file/vtk/subsamplingvtkwriter.hh>

#include <dune/istl/matrix.hh>
#include <dune/istl/bcrsmatrix.hh>
#include <dune/istl/multitypeblockmatrix.hh>
#include <dune/istl/multitypeblockvector.hh>
#include <dune/istl/matrixindexset.hh>
#include <dune/istl/solvers.hh>
#include <dune/istl/spqr.hh>
#include <dune/istl/preconditioners.hh>
#include <dune/istl/io.hh>

#include <dune/functions/functionspacebases/interpolate.hh>
#include <dune/functions/backends/istlvectorbackend.hh>
#include <dune/functions/functionspacebases/powerbasis.hh>
#include <dune/functions/functionspacebases/compositebasis.hh>
#include <dune/functions/functionspacebases/lagrangebasis.hh>
#include <dune/functions/functionspacebases/periodicbasis.hh>
#include <dune/functions/functionspacebases/subspacebasis.hh>
#include <dune/functions/functionspacebases/boundarydofs.hh>
#include <dune/functions/gridfunctions/discreteglobalbasisfunction.hh>
#include <dune/functions/gridfunctions/gridviewfunction.hh>
#include <dune/functions/functionspacebases/hierarchicvectorwrapper.hh>

#include <dune/common/fvector.hh>
#include <dune/common/fmatrix.hh> 

#include <dune/microstructure/prestrain_material_geometry.hh>
#include <dune/microstructure/matrix_operations.hh>
#include <dune/microstructure/vtk_filler.hh>    //TEST
#include <dune/microstructure/CorrectorComputer.hh>    
#include <dune/microstructure/EffectiveQuantitiesComputer.hh>  
#include <dune/microstructure/prestrainedMaterial.hh>  

#include <dune/solvers/solvers/umfpacksolver.hh>  //TEST 
#include <dune/istl/eigenvalue/test/matrixinfo.hh> // TEST: compute condition Number 

// #include <dune/fufem/discretizationerror.hh>
#include <dune/fufem/dunepython.hh>
#include <python2.7/Python.h>

// #include <dune/fufem/functions/virtualgridfunction.hh> //TEST 

// #include <boost/multiprecision/cpp_dec_float.hpp>
#include <any>
#include <variant>
#include <string>
#include <iomanip>   // needed when working with relative paths e.g. from python-scripts

using namespace Dune;
using namespace MatrixOperations;

//////////////////////////////////////////////////////////////////////
// Helper functions for Table-Output
//////////////////////////////////////////////////////////////////////
/*! Center-aligns string s within a field of width w by padding it with blanks.
    If the amount of padding is odd, the extra blank goes to the right.
    If s is already at least w characters long it is returned unchanged
    (it is never truncated). */
std::string center(const std::string& s, const int w) {
    // Compute the excess room in signed arithmetic: `w - s.size()` would be
    // evaluated as an unsigned value and only become negative through an
    // implementation-defined narrowing conversion (before C++20).
    const long long padding = static_cast<long long>(w) - static_cast<long long>(s.size());
    if (padding <= 0)
        return s;

    const std::size_t left  = static_cast<std::size_t>(padding / 2);
    const std::size_t right = static_cast<std::size_t>(padding - padding / 2);  // receives the odd extra blank

    std::string result;
    result.reserve(s.size() + left + right);
    result.append(left, ' ').append(s).append(right, ' ');
    return result;
}

/* Format x right-aligned (blank-filled) in a field of at least `width`
   characters. Floating-point values are written in scientific notation with
   `decDigits` digits after the decimal point; integral values and strings are
   not affected by the notation/precision settings and are only padded. */
template<class type>
std::string prd(const type x, const int decDigits, const int width) {
    std::ostringstream out;
    out.setf(std::ios_base::scientific, std::ios_base::floatfield);  // scientific output
    out.setf(std::ios_base::right, std::ios_base::adjustfield);      // pad on the left
    out.fill(' ');
    out.precision(decDigits);
    out.width(width);
    out << x;
    return out.str();
}

//////////////////////////////////////////////////
//   Infrastructure for handling periodicity
//////////////////////////////////////////////////
// Periodic point comparison on R/Z x R/Z x R: x and y are considered equal if
// each of their first two coordinates agrees up to a shift of -1, 0 or +1
// (period 1), and their third coordinates agree. All comparisons use
// FloatCmp::eq with its default tolerance.
auto equivalent = [](const FieldVector<double,3>& x, const FieldVector<double,3>& y)
{
    // Coordinate i matches if y[i] equals x[i], x[i]+1 or x[i]-1.
    const auto matchesModOne = [&x, &y](int i)
    {
        return FloatCmp::eq(x[i], y[i])
            or FloatCmp::eq(x[i] + 1, y[i])
            or FloatCmp::eq(x[i] - 1, y[i]);
    };
    // The third coordinate is not periodic.
    return matchesModOne(0) and matchesModOne(1) and FloatCmp::eq(x[2], y[2]);
};




////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
/*! Driver for the periodic cell (corrector) problem.
 *
 *  Reads the parameter file given as argv[1] (falling back to
 *  "../../inputs/cellsolver.parset"; further command-line options override the
 *  file), then for every grid level in [numLevels[0], numLevels[1]]:
 *    - builds a YaspGrid on (-1/2,1/2)^3 with 2^level elements per direction,
 *    - builds a first-order Lagrange basis whose vertices are identified via
 *      `equivalent` (periodic in the first two coordinates),
 *    - solves the corrector problems and computes the effective quantities
 *      Qeff and Beff,
 *    - prints them and appends them to "<outputPath>/output.txt".
 *  Finally a summary table over all levels is written to stdout and the log.
 *
 *  Returns 0 on success, 1 if `numLevels` contains a negative entry.
 */
int main(int argc, char *argv[])
{
  MPIHelper::instance(argc, argv);

  Dune::Timer globalTimer;

  ParameterTree parameterSet;
  if (argc < 2)
    ParameterTreeParser::readINITree("../../inputs/cellsolver.parset", parameterSet);
  else
  {
    ParameterTreeParser::readINITree(argv[1], parameterSet);
    ParameterTreeParser::readOptions(argc, argv, parameterSet);
  }

  //--- Output setter
  std::string outputPath = parameterSet.get("outputPath", "../../outputs");

  //--- Set up log file. A failed open used to go unnoticed and every write to
  //    `log` below was silently discarded, so at least warn about it.
  const std::string logFileName = outputPath + "/output.txt";
  std::fstream log;
  log.open(logFileName, std::ios::out);
  if (!log.is_open())
    std::cerr << "Warning: could not open log file '" << logFileName
              << "' (does the output directory exist?)" << std::endl;

  std::cout << "outputPath:" << outputPath << std::endl;

  // parameterSet.report(log); // short alternative

  //--- Get path for material/geometry functions
  // std::string geometryFunctionPath = parameterSet.get<std::string>("geometryFunctionPath",);
  //--- Start Python interpreter
  Python::start();
  Python::Reference main = Python::import("__main__");
  Python::run("import math");
  Python::runStream()
      << std::endl << "import sys"
      // << std::endl << "sys.path.append('" << geometryFunctionPath << "')"
      << std::endl;

  constexpr int dim = 3;

  // Debug/print options
  const bool print_debug = parameterSet.get<bool>("print_debug", false);

  // VTK-write options
  const bool write_materialFunctions = parameterSet.get<bool>("write_materialFunctions", false);
  [[maybe_unused]] const bool write_prestrainFunctions = parameterSet.get<bool>("write_prestrainFunctions", false);  // not used below

  ///////////////////////////////////
  // Generate the grid
  ///////////////////////////////////
  // --- Corrector problem domain (-1/2,1/2)^3:
  FieldVector<double,dim> lower({-1.0/2.0, -1.0/2.0, -1.0/2.0});
  FieldVector<double,dim> upper({1.0/2.0, 1.0/2.0, 1.0/2.0});

  std::array<int,2> numLevels = parameterSet.get<std::array<int,2>>("numLevels", {1,3});
  // The level loop uses an unsigned counter: a negative entry would wrap around
  // to a huge level (and 2^level would overflow), so reject it up front.
  if (numLevels[0] < 0 || numLevels[1] < 0)
  {
    std::cerr << "Error: numLevels must be non-negative, got "
              << numLevels[0] << " " << numLevels[1] << std::endl;
    log << "Error: numLevels must be non-negative, got "
        << numLevels[0] << " " << numLevels[1] << std::endl;
    return 1;
  }
  const size_t firstLevel = static_cast<size_t>(numLevels[0]);
  const size_t lastLevel  = static_cast<size_t>(numLevels[1]);

  ///////////////////////////////////
  // Create data storage
  ///////////////////////////////////
  //--- Storage_Error:: #1 level #2 L2SymError #3 L2SymErrorOrder #4 L2Norm(sym) #5 L2Norm(sym-analytic) #6 L2Norm(phi_1)
  //    (currently only the level is stored)
  std::vector<std::variant<std::string, size_t, double>> Storage_Error;
  //--- Storage_Quantities:: | level | q1 | q2 | q3 | q12 | q23 | b1 | b2 | b3 |
  std::vector<std::variant<std::string, size_t, double>> Storage_Quantities;

  //--- GridLevel-Loop:
  for (size_t level = firstLevel; level <= lastLevel; level++)
  {
    std::cout << " ----------------------------------" << std::endl;
    std::cout << "GridLevel: " << level << std::endl;
    std::cout << " ----------------------------------" << std::endl;

    Storage_Error.push_back(level);
    Storage_Quantities.push_back(level);
    const int elementsPerDirection = static_cast<int>(std::pow(2, level));
    std::array<int, dim> nElements = {elementsPerDirection, elementsPerDirection, elementsPerDirection};
    std::cout << "Number of Grid-Elements in each direction: " << nElements << std::endl;
    log << "Number of Grid-Elements in each direction: " << nElements << std::endl;

    using CellGridType = YaspGrid<dim, EquidistantOffsetCoordinates<double, dim> >;
    CellGridType grid_CE(lower, upper, nElements);
    using GridView = CellGridType::LeafGridView;
    const GridView gridView_CE = grid_CE.leafGridView();
    if (print_debug)
      std::cout << "Host grid has " << gridView_CE.size(dim) << " vertices." << std::endl;

    //--- Choose a finite element space for the cell problem
    using namespace Functions::BasisFactory;
    Functions::BasisFactory::Experimental::PeriodicIndexSet periodicIndices;

    //--- Collect the periodic vertex identifications for the periodic basis.
    //    NOTE: quadratic run-time in the number of vertices (only acceptable for small cell grids).
    for (const auto& v1 : vertices(gridView_CE))
      for (const auto& v2 : vertices(gridView_CE))
        if (equivalent(v1.geometry().corner(0), v2.geometry().corner(0)))
        {
          periodicIndices.unifyIndexPair({gridView_CE.indexSet().index(v1)}, {gridView_CE.indexSet().index(v2)});
        }

    //--- Set up first-order periodic Lagrange basis (dim components, flat lexicographic indexing)
    auto Basis_CE = makeBasis(
      gridView_CE,
      power<dim>(                                                   // TODO: should this be the world dimension? (same as dim = 3 here)
        Functions::BasisFactory::Experimental::periodic(lagrange<1>(), periodicIndices),
        flatLexicographic()
        //blockedInterleaved()   // Not Implemented
      ));
    if (print_debug)
      std::cout << "power<periodic> basis has " << Basis_CE.dimension() << " degrees of freedom" << std::endl;

    ///////////////////////////////////
    //  Create prestrained material object
    ///////////////////////////////////
    auto material_ = prestrainedMaterial(gridView_CE, parameterSet);

    // --- get scale ratio
    double gamma = parameterSet.get<double>("gamma", 1.0);

    //------------------------------------------------------------------------------------------------
    //--- Compute correctors
    auto correctorComputer = CorrectorComputer(Basis_CE, material_, gamma, log, parameterSet);
    correctorComputer.solve();

    //--- Check correctors (options):
    if (parameterSet.get<bool>("write_L2Error", false))
      correctorComputer.computeNorms();
    if (parameterSet.get<bool>("write_VTK", false))
      correctorComputer.writeCorrectorsVTK(level);
    //--- Additional test: check orthogonality (75) from paper:
    if (parameterSet.get<bool>("write_checkOrthogonality", false))
      correctorComputer.check_Orthogonality();
    //--- Check symmetry of stiffness matrix
    if (print_debug)
      correctorComputer.checkSymmetry();

    //--- Compute effective quantities
    auto effectiveQuantitiesComputer = EffectiveQuantitiesComputer(correctorComputer, material_);
    effectiveQuantitiesComputer.computeEffectiveQuantities();

    //--- write material indicator function to VTK
    if (write_materialFunctions)
    {
      material_.writeVTKMaterialFunctions(nElements, level);
    }

    //--- TEST:: Compute Qeff without using the orthogonality (75)...
    // only really makes a difference whenever the orthogonality is not satisfied!
    // std::cout << "----------computeFullQ-----------"<< std::endl;  //TEST
    // effectiveQuantitiesComputer.computeFullQ();

    //--- get effective quantities
    auto Qeff = effectiveQuantitiesComputer.getQeff();
    auto Beff = effectiveQuantitiesComputer.getBeff();
    printmatrix(std::cout, Qeff, "Matrix Qeff", "--");
    printvector(std::cout, Beff, "Beff", "--");

    //--- write effective quantities to matlab folder (for symbolic minimization)
    if (parameterSet.get<bool>("write_toMATLAB", false))
      effectiveQuantitiesComputer.writeToMatlab(outputPath);

    {
      // std::fixed and the precision are sticky stream state. Restore std::cout's
      // previous format afterwards so the next level's q1/q2 and the final timing
      // line are not unintentionally printed in fixed notation.
      const std::ios_base::fmtflags oldFlags = std::cout.flags();
      const std::streamsize oldPrecision = std::cout.precision();
      std::cout.precision(10);
      std::cout << "q1 : " << Qeff[0][0] << std::endl;
      std::cout << "q2 : " << Qeff[1][1] << std::endl;
      std::cout << "q3 : " << std::fixed << Qeff[2][2] << std::endl;
      std::cout << std::fixed << std::setprecision(6) << "q_onetwo=" << Qeff[0][1] << std::endl;
      std::cout.flags(oldFlags);
      std::cout.precision(oldPrecision);
    }

    // --- Fill output table: | level | q1 | q2 | q3 | q12 | q23 | b1 | b2 | b3 |
    Storage_Quantities.push_back(Qeff[0][0]);
    Storage_Quantities.push_back(Qeff[1][1]);
    Storage_Quantities.push_back(Qeff[2][2]);
    Storage_Quantities.push_back(Qeff[0][1]);
    Storage_Quantities.push_back(Qeff[1][2]);
    Storage_Quantities.push_back(Beff[0]);
    Storage_Quantities.push_back(Beff[1]);
    Storage_Quantities.push_back(Beff[2]);

    log << "size of FiniteElementBasis: " << Basis_CE.size() << std::endl;
    log << "q1="  << Qeff[0][0] << std::endl;
    log << "q2="  << Qeff[1][1] << std::endl;
    log << "q3="  << Qeff[2][2] << std::endl;
    log << "q12=" << Qeff[0][1] << std::endl;
    log << "q23=" << Qeff[1][2] << std::endl;
    // NOTE: std::fixed/setprecision(6) stay in effect on `log` for all later output
    // (including following levels). Left unchanged on purpose: these lines are
    // presumably parsed by post-processing (Python) scripts.
    log << std::fixed << std::setprecision(6) << "q_onetwo=" << Qeff[0][1] << std::endl;
    log << "b1=" << Beff[0] << std::endl;
    log << "b2=" << Beff[1] << std::endl;
    log << "b3=" << Beff[2] << std::endl;
    log << "mu_gamma=" << Qeff[2][2] << std::endl;           // added for Python-Script
  } // GridLevel-Loop End

  //////////////////////////////////////////
  //--- Print summary table
  //////////////////////////////////////////
  const int tableWidth = 12;
  const int numColumns = 9;   // Levels | q1 | q2 | q3 | q12 | q23 | b1 | b2 | b3
  const std::string separator(tableWidth * numColumns + 3 * numColumns, '-');   // 3 = width of " | "

  // Writes the column titles, each centered in tableWidth and followed by " | ".
  const auto writeHeader = [&](std::ostream& os)
  {
    for (const char* title : {"Levels ", "q1", "q2", "q3", "q12", "q23", "b1", "b2", "b3"})
      os << center(title, tableWidth) << " | ";
    os << "\n";
  };

  writeHeader(std::cout);
  std::cout << separator << "\n";
  log << separator << "\n";
  writeHeader(log);
  log << separator << "\n";

  // One row per level; stdout separates columns with " | ", the log file with " & ".
  int cellCount = 0;
  for (const auto& entry : Storage_Quantities)
  {
    const std::string cell = std::visit([&](const auto& arg) { return center(prd(arg, 5, 1), tableWidth); }, entry);
    std::cout << cell << " | ";
    log << cell << " & ";
    if (++cellCount % numColumns == 0)
    {
      std::cout << std::endl;
      log << std::endl;
    }
  }
  std::cout << separator << "\n";
  log << separator << "\n";

  log.close(); // close log file

  std::cout << "Total time elapsed: " << globalTimer.elapsed() << std::endl;
}