// -*- tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 2 -*-
// vi: set et ts=4 sw=2 sts=2:
#ifndef DUNE_GFE_L2_DISTANCE_SQUARED_ENERGY_HH
#define DUNE_GFE_L2_DISTANCE_SQUARED_ENERGY_HH

#include <cstddef>
#include <memory>
#include <stdexcept>
#include <vector>

#include <dune/geometry/quadraturerules.hh>

#include <dune/gfe/localgeodesicfestiffness.hh>
#include <dune/gfe/localgeodesicfefunction.hh>

/** \brief Weighted squared L2 distance between a geodesic finite element function and a given function
 *
 * On each element the energy is approximated by quadrature as
 * \f[
 *   \int_T \| D o \|_F \, \operatorname{distanceSquared}\bigl(o(x), u_h(x)\bigr) \, dx,
 * \f]
 * where
 * - \f$ u_h \f$ is the LocalGeodesicFEFunction built from the element coefficients,
 * - \f$ o \f$ is the function stored in origin_, and
 * - \f$ \| D o \|_F \f$ is the Frobenius norm of the derivative of origin_ on the actual
 *   (not the reference) element.
 *
 * The weight factor can be switched off in energy() (see the comment there), which
 * turns the energy into the plain squared L2 distance.
 *
 * \tparam Basis        Function space basis providing GridView and LocalView
 * \tparam TargetSpace  Target space; must provide ctype, rebind<double>::other and distanceSquared()
 */
template<class Basis, class TargetSpace>
class L2DistanceSquaredEnergy
  : public LocalGeodesicFEStiffness<Basis,TargetSpace>
{
  // grid types
  typedef typename Basis::GridView GridView;
  typedef typename TargetSpace::ctype RT;

  // Value type of origin_: the target space with double as its number type
  typedef typename TargetSpace::template rebind<double>::other OriginValueType;

  // Type of the function we compute the distance to
  typedef VirtualGridViewFunction<GridView, OriginValueType> OriginFunctionType;

  // some other sizes
  enum {gridDim=GridView::dimension};

public:

  /** \brief The function that we are computing the (weighted) L2-distance to
   *
   * Must be set before energy() is called; energy() throws std::logic_error otherwise.
   */
  std::shared_ptr<OriginFunctionType> origin_;

  /** \brief Assemble the energy for a single element
   *
   * \param localView      Local view bound to the element to integrate over
   * \param localSolution  Coefficients of the geodesic finite element function on that element
   * \return Quadrature approximation of the weighted squared distance on this element
   * \throws std::logic_error if origin_ has not been set
   */
  RT energy (const typename Basis::LocalView& localView,
             const std::vector<TargetSpace>& localSolution) const override
  {
    // Dereferencing an unset origin_ would be undefined behaviour; fail loudly instead
    if (!origin_)
      throw std::logic_error("L2DistanceSquaredEnergy: origin_ has not been set");

    const auto& localFiniteElement = localView.tree().finiteElement();
    typedef LocalGeodesicFEFunction<gridDim, double, decltype(localFiniteElement), TargetSpace> LocalGFEFunctionType;
    LocalGFEFunctionType localGeodesicFEFunction(localFiniteElement,localSolution);

    // Just guessing an appropriate quadrature order
    const auto quadOrder = localFiniteElement.localBasis().order() * 2 * gridDim;
    const auto& quad = Dune::QuadratureRules<double, gridDim>::rule(localFiniteElement.type(), quadOrder);

    const auto& element = localView.element();

    // The geometry does not depend on the quadrature point: build it once per element
    const auto geometry = element.geometry();

    RT localEnergy = 0;

    for (const auto& quadPoint : quad)
    {
      // Local position of the quadrature point
      const Dune::FieldVector<double,gridDim>& quadPos = quadPoint.position();

      const double weight = quadPoint.weight() * geometry.integrationElement(quadPos);

      // Values of the finite element function and of the 'origin' function
      const auto value = localGeodesicFEFunction.evaluate(quadPos);
      OriginValueType originValue;
      origin_->evaluateLocal(element, quadPos, originValue);

      // The derivative of the 'origin' function
      // First: as function defined on the reference element
      typename OriginFunctionType::DerivativeType originReferenceDerivative;
      origin_->evaluateDerivativeLocal(element, quadPos, originReferenceDerivative);

      // Then: the derivative of the function defined on the actual element,
      // obtained by applying the inverse transposed Jacobian to each row
      typename OriginFunctionType::DerivativeType originDerivative(0);
      const auto jacobianInverseTransposed = geometry.jacobianInverseTransposed(quadPos);
      for (std::size_t comp = 0; comp < originReferenceDerivative.N(); ++comp)
        jacobianInverseTransposed.umv(originReferenceDerivative[comp], originDerivative[comp]);

      double weightFactor = originDerivative.frobenius_norm();
      // un-comment the following line to switch off the weight factor
      //weightFactor = 1.0;

      // Add the local energy density
      localEnergy += weight * weightFactor * TargetSpace::distanceSquared(originValue, value);
    }

    return localEnergy;
  }

};

#endif