//
// Software License for AMDiS
//
// Copyright (c) 2010 Dresden University of Technology 
// All rights reserved.
// Authors: Simon Vey, Thomas Witkowski et al.
//
// This file is part of AMDiS
//
// See also license.opensource.txt in the distribution.


#include "ResidualEstimator.h"

#include <cmath>

#include "Operator.h"
#include "DOFMatrix.h"
#include "DOFVector.h"
#include "Assembler.h"
#include "Traverse.h"
#include "Parameters.h"

#ifdef HAVE_PARALLEL_DOMAIN_AMDIS
#include <mpi.h>
#endif

namespace AMDiS {

  // Constructs the estimator; name and r are forwarded to the Estimator base.
  //
  // The weights C0, C1, C2 and C3 are requested from Parameters under the keys
  // name + "->C0" ... name + "->C3". Each weight is afterwards stored squared;
  // a weight that is not larger than 1e-25 is stored as 0.0.
  // If only C1 is nonzero (C0 and C3 are zero), the estimator is flagged to
  // compute the jump residual only. A nonzero C2 aborts with an error message.
  ResidualEstimator::ResidualEstimator(std::string name, int r) 
    : Estimator(name, r),
      C0(0.0), C1(0.0), C2(0.0), C3(0.0),
      jumpResidualOnly(false)
  {
    FUNCNAME("ResidualEstimator::ResidualEstimator()");

    // Read the raw weights.
    Parameters::get(name + "->C0", C0);
    Parameters::get(name + "->C1", C1);
    Parameters::get(name + "->C2", C2);
    Parameters::get(name + "->C3", C3);

    // Keep the squared weights; tiny or non-positive values are dropped.
    if (C0 > 1.e-25)
      C0 = sqr(C0);
    else
      C0 = 0.0;

    if (C1 > 1.e-25)
      C1 = sqr(C1);
    else
      C1 = 0.0;

    if (C2 > 1.e-25)
      C2 = sqr(C2);
    else
      C2 = 0.0;

    if (C3 > 1.e-25)
      C3 = sqr(C3);
    else
      C3 = 0.0;

    // Only the jump term is active when neither the element residual (C0)
    // nor the time term (C3) has a weight.
    jumpResidualOnly = (C1 != 0.0 && C0 == 0.0 && C3 == 0.0);

    TEST_EXIT(C2 == 0.0)("C2 is not used! Please remove it or set it to 0.0!\n");
  }


  void ResidualEstimator::init(double ts)
  {
    FUNCNAME("ResidualEstimator::init()");

    timestep = ts;
    nSystems = static_cast<int>(uh.size());

    TEST_EXIT_DBG(nSystems > 0)("no system set\n");

    dim = mesh->getDim();
    basFcts = new const BasisFunction*[nSystems];
    quadFast = new FastQuadrature*[nSystems];

    degree = 0;
    for (int system = 0; system < nSystems; system++) {
      basFcts[system] = uh[system]->getFeSpace()->getBasisFcts();
      degree = std::max(degree, basFcts[system]->getDegree());
    }
    degree *= 2;

    quad = Quadrature::provideQuadrature(dim, degree);
    nPoints = quad->getNumPoints();

    Flag flag = INIT_PHI | INIT_GRD_PHI;
    if (degree > 2)
      flag |= INIT_D2_PHI;    

    for (int system = 0; system < nSystems; system++)
      quadFast[system] = FastQuadrature::provideFastQuadrature(basFcts[system], 
							       *quad, 
							       flag);    
  
    uhEl.resize(nSystems);
    uhNeigh.resize(nSystems);
    if (timestep)
      uhOldEl.resize(nSystems);

    for (int system = 0; system < nSystems; system++) {
      uhEl[system].change_dim(basFcts[system]->getNumber()); 
      uhNeigh[system].change_dim(basFcts[system]->getNumber());
      if (timestep)
	uhOldEl[system].change_dim(basFcts[system]->getNumber());
    }

    if (timestep) {
      uhQP.change_dim(nPoints);
      uhOldQP.change_dim(nPoints);
    }

    riq.change_dim(nPoints);
    grdUh_qp = NULL;
    D2uhqp = NULL;

    // clear error indicators and mark elements for jumpRes
    TraverseStack stack;
    ElInfo *elInfo = stack.traverseFirst(mesh, -1, Mesh::CALL_LEAF_EL);
    while (elInfo) {
      // SIMON: DEL LINE BELOW
      elInfo->getElement()->setEstimation(0.0, row);

      elInfo->getElement()->setMark(1);
      elInfo = stack.traverseNext(elInfo);
    }

    est_sum = 0.0;
    est_max = 0.0;
    est_t_sum = 0.0;
    est_t_max = 0.0;

    traverseFlag = 
      Mesh::FILL_NEIGH      |
      Mesh::FILL_COORDS     |
      Mesh::FILL_OPP_COORDS |
      Mesh::FILL_BOUND      |
      Mesh::FILL_GRD_LAMBDA |
      Mesh::FILL_DET        |
      Mesh::CALL_LEAF_EL;
    neighInfo = mesh->createNewElInfo();

    // === Prepare date for computing jump residual. ===
    if (C1 > 0.0 && dim > 1) {
      surfaceQuad = Quadrature::provideQuadrature(dim - 1, degree);
      nPointsSurface = surfaceQuad->getNumPoints();
      grdUhEl.resize(nPointsSurface);
      grdUhNeigh.resize(nPointsSurface);
      jump.resize(nPointsSurface);
      localJump.resize(nPointsSurface);
      nNeighbours = Global::getGeo(NEIGH, dim);
      lambdaNeigh = new DimVec<WorldVector<double> >(dim, NO_INIT);
      lambda = new DimVec<double>(dim, NO_INIT);

      secondOrderTerms.resize(nSystems);
      for (int system = 0; system < nSystems; system++) {
	secondOrderTerms[system] = false;

	if (matrix[system] == NULL)
	  continue;

	for (std::vector<Operator*>::iterator it = matrix[system]->getOperators().begin();
	     it != matrix[system]->getOperators().end(); ++it)
	  secondOrderTerms[system] = secondOrderTerms[system] || (*it)->secondOrderTerms();
      }
    }
  }


  // Finishes one estimation pass.
  //
  // In a parallel build the accumulated sums and maxima are first reduced
  // over MPI::COMM_WORLD (sum for the sums, max for the maxima). Afterwards
  // est_sum and est_t_sum are replaced by their square roots and, if output
  // is true, the estimate (and, if C3 is nonzero, the time estimate) is
  // printed. Finally all buffers allocated for the pass are released.
  //
  // Every released pointer is reset to NULL so that no member is left
  // dangling after this call.
  void ResidualEstimator::exit(bool output)
  {
    FUNCNAME("ResidualEstimator::exit()");

#ifdef HAVE_PARALLEL_DOMAIN_AMDIS
    double send_est_sum = est_sum;
    double send_est_max = est_max;
    double send_est_t_sum = est_t_sum;
    double send_est_t_max = est_t_max;

    MPI::COMM_WORLD.Allreduce(&send_est_sum, &est_sum, 1, MPI_DOUBLE, MPI_SUM);
    MPI::COMM_WORLD.Allreduce(&send_est_max, &est_max, 1, MPI_DOUBLE, MPI_MAX);
    MPI::COMM_WORLD.Allreduce(&send_est_t_sum, &est_t_sum, 1, MPI_DOUBLE, MPI_SUM);
    MPI::COMM_WORLD.Allreduce(&send_est_t_max, &est_t_max, 1, MPI_DOUBLE, MPI_MAX);
#endif

    // The sums were accumulated as squared quantities.
    est_sum = sqrt(est_sum);
    est_t_sum = sqrt(est_t_sum);

    if (output) {
      MSG("estimate for component %d = %.8e\n", row, est_sum);
      if (C3)
        MSG("time estimate for component %d = %.8e\n", row, est_t_sum);
    }

    delete [] basFcts;
    basFcts = NULL;
    delete [] quadFast;
    quadFast = NULL;

    // delete [] on a null pointer is a no-op, so no explicit check is needed.
    delete [] grdUh_qp;
    grdUh_qp = NULL;
    delete [] D2uhqp;
    D2uhqp = NULL;

    // These two are only allocated when the jump residual is active.
    if (C1 && (dim > 1)) {
      delete lambdaNeigh;
      lambdaNeigh = NULL;
      delete lambda;
      lambda = NULL;
    }

    delete neighInfo;
    neighInfo = NULL;
  }


  // Computes the error estimate of a single element and adds it to the
  // global sums.
  //
  // elInfo describes the element. If dualElInfo is not NULL, the assemblers
  // are initialized with its small/large element infos instead of elInfo.
  // The element's current estimate (for component row) is taken as the start
  // value, the element residual (if C0 or C3 is nonzero) and the jump
  // residual (if C1 is nonzero and dim > 1) are added, and the result is
  // written back. The element's mark is cleared afterwards.
  void ResidualEstimator::estimateElement(ElInfo *elInfo, DualElInfo *dualElInfo)
  {    
    FUNCNAME("ResidualEstimator::estimateElement()");

    TEST_EXIT_DBG(nSystems > 0)("no system set\n");

    Element *element = elInfo->getElement();

    // Start from the value already stored on the element.
    // (Original author's note: "SIMON    double est_el = 0.0;")
    double estimate = element->getEstimation(row);

    std::vector<Operator*>::iterator opIt;
    std::vector<double*>::iterator facIt;

    // === Init assemblers. ===
    for (int sys = 0; sys < nSystems; sys++) {
      if (matrix[sys] == NULL) 
        continue;

      DOFMatrix *mat = const_cast<DOFMatrix*>(matrix[sys]);

      for (opIt = mat->getOperatorsBegin(), facIt = mat->getOperatorEstFactorBegin();
           opIt != mat->getOperatorsEnd(); ++opIt, ++facIt) {
        // Operators whose estimation factor is explicitly zero are ignored.
        if (*facIt != NULL && **facIt == 0.0)
          continue;

        // If the estimator must only compute the jump residual, operators
        // without second order terms are not needed.
        if (jumpResidualOnly && !(*opIt)->secondOrderTerms())
          continue;

        if (dualElInfo) {
          (*opIt)->getAssembler()->initElement(dualElInfo->smallElInfo, 
                                               dualElInfo->largeElInfo,
                                               quad);
        } else {
          (*opIt)->getAssembler()->initElement(elInfo, NULL, quad);
        }
      }

      // The right hand side operators are only needed for the element residual.
      if (C0 > 0.0) {
        DOFVector<double> *vec = const_cast<DOFVector<double>*>(fh[sys]);

        for (opIt = vec->getOperatorsBegin(); opIt != vec->getOperatorsEnd(); ++opIt) {
          if (dualElInfo) {
            (*opIt)->getAssembler()->initElement(dualElInfo->smallElInfo, 
                                                 dualElInfo->largeElInfo,
                                                 quad);
          } else {
            (*opIt)->getAssembler()->initElement(elInfo, NULL, quad);
          }
        }
      }
    }

    // === Compute element residuals and time error estimation. ===
    if (C0 != 0.0 || C3 != 0.0)
      estimate += computeElementResidual(elInfo, dualElInfo);

    // === Compute jump residuals. ===
    if (C1 != 0.0 && dim > 1)
      estimate += computeJumpResidual(elInfo, dualElInfo);

    // === Update global residual variables. ===
    element->setEstimation(estimate, row);
    element->setMark(0);
    est_sum += estimate;
    if (est_max < estimate)
      est_max = estimate;
  }


  // Computes the weighted element residual of one element and, for time
  // dependent problems, accumulates the time error estimate.
  //
  // Parameters:
  //   elInfo      element to evaluate.
  //   dualElInfo  must be NULL; the dual mesh case aborts with
  //               "Not yet implemented!".
  //
  // Returns C0 * h2^2 * det * Q if a time step is set or norm is NO_NORM or
  // L2_NORM, and C0 * h2 * det * Q otherwise, where Q is the quadrature sum
  // of the squared residual values in riq.
  //
  // Side effects: est_t_sum and est_t_max are updated when C3 > 0; the member
  // buffers uhQP, uhOldQP, uhOldEl, grdUh_qp, D2uhqp and riq are overwritten.
  //
  // The point values of uh are re-evaluated for every element and every
  // system. The buffers are members that outlive a single call, so reusing
  // their content would pass values of a previously visited element (or of
  // another system) to r().
  double ResidualEstimator::computeElementResidual(ElInfo *elInfo, 
                                                   DualElInfo *dualElInfo)
  {
    FUNCNAME("ResidualEstimator::computeElementResidual()");

    TEST_EXIT(!dualElInfo)("Not yet implemented!\n");

    std::vector<Operator*>::iterator it;
    std::vector<double*>::iterator itfac;
    double det = elInfo->getDet();
    double h2 = h2_from_det(det, dim);
    riq = 0.0;

    for (int system = 0; system < nSystems; system++) {
      if (matrix[system] == NULL) 
        continue;

      // True once uhQP holds uh[system] evaluated on the current element.
      bool uhQPUpToDate = false;

      if (timestep && uhOld[system]) {
        uhOld[system]->getLocalVector(elInfo->getElement(), uhOldEl[system]);
  
        // === Compute time error. ===

        if (C0 > 0.0 || C3 > 0.0) {   
          uh[system]->getVecAtQPs(elInfo, NULL, quadFast[system], uhQP);
          uhOld[system]->getVecAtQPs(elInfo, NULL, quadFast[system], uhOldQP);
          uhQPUpToDate = true;
          
          // The time error is only accumulated for the system matching row.
          if (C3 > 0.0 && system == std::max(row, 0)) {
            double result = 0.0;
            for (int iq = 0; iq < nPoints; iq++) {
              double tiq = (uhQP[iq] - uhOldQP[iq]);
              result += quad->getWeight(iq) * tiq * tiq;
            }
            double v = C3 * det * result;
            est_t_sum += v;
            est_t_max = std::max(est_t_max, v);
          }
        }
      }
           
      // === Compute element residual. ===
      if (C0 > 0.0) {
        DOFMatrix *dofMat = const_cast<DOFMatrix*>(matrix[system]);
        DOFVector<double> *dofVec = const_cast<DOFVector<double>*>(fh[system]);

        // Find out which point values the active operators of this system
        // require. Operators with an explicit zero factor are ignored.
        bool needUh = false;
        bool needGrdUh = false;
        bool needD2Uh = false;
  
        for (it = dofMat->getOperatorsBegin(), itfac = dofMat->getOperatorEstFactorBegin();
             it != dofMat->getOperatorsEnd();  ++it, ++itfac) {
          if (*itfac == NULL || **itfac != 0.0) {
            needUh = needUh || (*it)->zeroOrderTerms();
            needGrdUh = needGrdUh || 
              (*it)->firstOrderTermsGrdPsi() || (*it)->firstOrderTermsGrdPhi();
            needD2Uh = needD2Uh || (degree > 2 && (*it)->secondOrderTerms());
          }
        }

        // Evaluate the required quantities for this element and system.
        // The buffers are allocated on first use and reused afterwards.
        if (needUh && !uhQPUpToDate) {
          if (static_cast<int>(num_rows(uhQP)) != nPoints)
            uhQP.change_dim(nPoints);
          uh[system]->getVecAtQPs(elInfo, NULL, quadFast[system], uhQP);
        }
        if (needGrdUh) {
          if (grdUh_qp == NULL)
            grdUh_qp = new WorldVector<double>[nPoints];
          uh[system]->getGrdAtQPs(elInfo, NULL, quadFast[system], grdUh_qp);
        }
        if (needD2Uh) {
          if (D2uhqp == NULL)
            D2uhqp = new WorldMatrix<double>[nPoints];
          uh[system]->getD2AtQPs(elInfo, NULL, quadFast[system], D2uhqp);
        }
        
        // === Compute the element residual and store it in riq. ===

        r(elInfo,
          nPoints, 
          uhQP,
          grdUh_qp,
          D2uhqp,
          uhOldQP,
          NULL,  // grdUhOldQP 
          NULL,  // D2UhOldQP
          dofMat, 
          dofVec,
          quad,
          riq);
      }     
    }

    // add integral over r square
    double result = 0.0;
    for (int iq = 0; iq < nPoints; iq++)
      result += quad->getWeight(iq) * riq[iq] * riq[iq];
   
    if (timestep != 0.0 || norm == NO_NORM || norm == L2_NORM)
      result = C0 * h2 * h2 * det * result;
    else
      result = C0 * h2 * det * result;
    
    return result;
  }


  // Computes the jump residual contribution for the faces of one element.
  //
  // Parameters:
  //   elInfo      element whose faces are processed.
  //   dualElInfo  not used by this function.
  //
  // Returns the sum of the weighted jump terms of all processed faces plus
  // the weighted boundary residual of the element.
  //
  // Only neighbours that exist and still carry a nonzero mark are processed.
  // init() marks every leaf element and estimateElement() clears the mark of
  // an element once it has been handled. The value computed for a face is
  // added to the return value and also to the stored estimate of the
  // neighbour.
  double ResidualEstimator::computeJumpResidual(ElInfo *elInfo, 
                                                DualElInfo *dualElInfo)
  {
    FUNCNAME("ResidualEstimator::computeJumpResidual()");

    double result = 0.0;
    int dow = Global::getGeo(WORLD);
    Element *el = elInfo->getElement();
    const DimVec<WorldVector<double> > &grdLambda = elInfo->getGrdLambda();
    double det = elInfo->getDet();
    double h2 = h2_from_det(det, dim);

    for (int face = 0; face < nNeighbours; face++) {  
      Element *neigh = const_cast<Element*>(elInfo->getNeighbour(face));

      // Skip missing neighbours and neighbours whose mark is already cleared.
      if (!(neigh && neigh->getMark()))
        continue;

      int oppV = elInfo->getOppVertex(face);
        
      // === Build the element info of the neighbour. ===

      el->sortFaceIndices(face, &faceIndEl);
      neigh->sortFaceIndices(oppV, &faceIndNeigh);    
      neighInfo->setElement(const_cast<Element*>(neigh));
      neighInfo->setFillFlag(Mesh::FILL_COORDS);
      
      // The neighbour's vertex opposite to the common face.
      for (int i = 0; i < dow; i++)
        neighInfo->getCoord(oppV)[i] = elInfo->getOppCoord(face)[i];
      
      // periodic leaf data ?
      ElementData *ldp = el->getElementData()->getElementData(PERIODIC);      
      bool periodicCoords = false;
      
      if (ldp) {
        typedef std::list<LeafDataPeriodic::PeriodicInfo> PerInfList;
        PerInfList& infoList = dynamic_cast<LeafDataPeriodic*>(ldp)->getInfoList();
        
        // If periodic data is stored for this face, the coordinates of the
        // face vertices are taken from it.
        for (PerInfList::iterator it = infoList.begin(); it != infoList.end(); ++it) {
          if (it->elementSide == face) {
            for (int i = 0; i < dim; i++) {
              int i1 = faceIndEl[i];
              int i2 = faceIndNeigh[i];
              
              // Position j of vertex i1 within the face.
              int j = 0;
              for (; j < dim; j++)
                if (i1 == el->getVertexOfPosition(INDEX_OF_DIM(dim - 1, dim), face, j))
                  break;
              
              TEST_EXIT_DBG(j != dim)("vertex i1 not on face ???\n");
              
              neighInfo->getCoord(i2) = (*(it->periodicCoords))[j];
            }
            periodicCoords = true;
            break;
          }
        }
      }  // if (ldp)
      
      // Otherwise the face vertices share the coordinates of this element.
      if (!periodicCoords) {
        for (int i = 0; i < dim; i++) {
          int i1 = faceIndEl[i];
          int i2 = faceIndNeigh[i];
          for (int j = 0; j < dow; j++)
            neighInfo->getCoord(i2)[j] = elInfo->getCoord(i1)[j];
        }
      }
        
      Parametric *parametric = mesh->getParametric();
      if (parametric)
        neighInfo = parametric->addParametricInfo(neighInfo);   

      // std::abs is used explicitly so that the floating point overload is
      // selected independent of which other abs declarations are visible.
      double detNeigh = std::abs(neighInfo->calcGrdLambda(*lambdaNeigh));
           
      // === Accumulate the jump over all systems. ===

      for (int iq = 0; iq < nPointsSurface; iq++)
        jump[iq].set(0.0);     
      
      for (int system = 0; system < nSystems; system++) {
        // Only systems with second order terms contribute.
        if (matrix[system] == NULL || secondOrderTerms[system] == false) 
          continue;
              
        uh[system]->getLocalVector(el, uhEl[system]); 
        uh[system]->getLocalVector(neigh, uhNeigh[system]);
          
        // Difference of the gradients of uh, evaluated from both sides at
        // the surface quadrature points; the result is stored in grdUhEl.
        for (int iq = 0; iq < nPointsSurface; iq++) {
          (*lambda)[face] = 0.0;
          for (int i = 0; i < dim; i++)
            (*lambda)[faceIndEl[i]] = surfaceQuad->getLambda(iq, i);
          
          basFcts[system]->evalGrdUh(*lambda, grdLambda, uhEl[system], &grdUhEl[iq]);
          
          (*lambda)[oppV] = 0.0;
          for (int i = 0; i < dim; i++)
            (*lambda)[faceIndNeigh[i]] = surfaceQuad->getLambda(iq, i);             
          
          basFcts[system]->evalGrdUh(*lambda, *lambdaNeigh, uhNeigh[system], &grdUhNeigh[iq]);
          
          grdUhEl[iq] -= grdUhNeigh[iq];
        } // for iq                           
        
        std::vector<double*>::iterator fac;
        std::vector<Operator*>::iterator it;
        DOFMatrix *mat = const_cast<DOFMatrix*>(matrix[system]);
        for (it = mat->getOperatorsBegin(), fac = mat->getOperatorEstFactorBegin(); 
             it != mat->getOperatorsEnd(); ++it, ++fac) {
        
          // Operators with an explicit zero factor are ignored.
          if (*fac == NULL || **fac != 0.0) {
            for (int iq = 0; iq < nPointsSurface; iq++)
              localJump[iq].set(0.0);
            
            (*it)->weakEvalSecondOrder(grdUhEl, localJump);

            // A missing factor counts as 1.0.
            double factor = *fac ? **fac : 1.0;
            if (factor != 1.0)
              for (int i = 0; i < nPointsSurface; i++)
                localJump[i] *= factor;
            
            for (int i = 0; i < nPointsSurface; i++)
              jump[i] += localJump[i];
          }           
        } // for (it = ...
      } // for system
    
      // === Integrate the squared jump over the face and weight it. ===

      double val = 0.0;
      for (int iq = 0; iq < nPointsSurface; iq++)
        val += surfaceQuad->getWeight(iq) * (jump[iq] * jump[iq]);

      // Mean of the determinants of both elements.
      double d = 0.5 * (det + detNeigh);
   
      if (norm == NO_NORM || norm == L2_NORM)
        val *= C1 * h2_from_det(d, dim) * d;
      else
        val *= C1 * d;
      
      if (parametric)
        neighInfo = parametric->removeParametricInfo(neighInfo);
      
      // The face value is added to the neighbour and to this element.
      neigh->setEstimation(neigh->getEstimation(row) + val, row);
      result += val;
    } // for face
    
    // === Add the boundary residual of the element. ===

    double val = fh[std::max(row, 0)]->getBoundaryManager()->
      boundResidual(elInfo, matrix[std::max(row, 0)], uh[std::max(row, 0)]);    
    if (norm == NO_NORM || norm == L2_NORM)
      val *= C1 * h2;
    else
      val *= C1;   
    result += val;

    return result;
  }


  void r(const ElInfo *elInfo,
	 int nPoints,
	 const ElementVector& uhIq,
	 const WorldVector<double> *grdUhIq,
	 const WorldMatrix<double> *D2UhIq,
	 const ElementVector& uhOldIq,
	 const WorldVector<double> *grdUhOldIq,
	 const WorldMatrix<double> *D2UhOldIq,
	 DOFMatrix *A, 
	 DOFVector<double> *fh,
	 Quadrature *quad,
	 ElementVector& result)
  {
    std::vector<Operator*>::iterator it;
    std::vector<double*>::iterator fac;

    // lhs
    for (it = A->getOperatorsBegin(), fac = A->getOperatorEstFactorBegin(); 
	 it != A->getOperatorsEnd(); ++it, ++fac) {
     
      double factor = *fac ? **fac : 1.0;

      if (factor) {
	if (D2UhIq)
	  (*it)->evalSecondOrder(nPoints, uhIq, grdUhIq, D2UhIq, result, -factor);	

	if (grdUhIq) {
	  (*it)->evalFirstOrderGrdPsi(nPoints, uhIq, grdUhIq, D2UhIq, result, factor);
	  (*it)->evalFirstOrderGrdPhi(nPoints, uhIq, grdUhIq, D2UhIq, result, factor);
	}
	
	if (num_rows(uhIq) > 0)
	  (*it)->evalZeroOrder(nPoints, uhIq, grdUhIq, D2UhIq, result, factor);	
      }
    }
    
    // rhs
    for (it = fh->getOperatorsBegin(), fac = fh->getOperatorEstFactorBegin(); 
	 it != fh->getOperatorsEnd(); ++it, ++fac) {

      double factor = *fac ? **fac : 1.0;

      if (factor) {
	if ((*it)->getUhOld()) {
	  if (D2UhOldIq)
	    (*it)->evalSecondOrder(nPoints, uhOldIq, grdUhOldIq, D2UhOldIq, result, factor);
	  
	  if (grdUhOldIq) {
	    (*it)->evalFirstOrderGrdPsi(nPoints, uhOldIq, grdUhOldIq, D2UhOldIq, result, -factor);
	    (*it)->evalFirstOrderGrdPhi(nPoints, uhOldIq, grdUhOldIq, D2UhOldIq, result, -factor);
	  }

	  if (num_rows(uhOldIq) > 0)
	    (*it)->evalZeroOrder(nPoints, uhOldIq, grdUhOldIq, D2UhOldIq, result, -factor);	  
	} else {
	  ElementVector fx(nPoints, 0.0);
	  (*it)->getC(elInfo, nPoints, fx);

	  for (int iq = 0; iq < nPoints; iq++)
	    result[iq] -= factor * fx[iq];
	}
      }
    }    
  }


}