From de9b228eebe644bd1bd5e1cee1f39bd9cc94c58c Mon Sep 17 00:00:00 2001
From: Oliver Sander <sander@igpm.rwth-aachen.de>
Date: Sun, 15 Mar 2015 12:47:34 +0000
Subject: [PATCH] GeodesicFEAssembler and RiemannianTRSolver now directly
 accept dune-functions bases

At least the harmonicmaps.cc example still works as before.
[caveat: only tested without HAVE_MPI]
Internally, at various places in the code, a dune-fufem basis is still constructed
from the dune-functions one.  These parts can be ported one by one, as dune-fufem
slowly replaces its own bases by the ones from dune-functions.

[[Imported from SVN: r10080]]
---
 dune/gfe/geodesicfeassembler.hh | 95 ++++++++++++++++++++++-----------
 dune/gfe/riemanniantrsolver.cc  | 13 ++---
 src/harmonicmaps.cc             | 28 +++++-----
 3 files changed, 86 insertions(+), 50 deletions(-)

diff --git a/dune/gfe/geodesicfeassembler.hh b/dune/gfe/geodesicfeassembler.hh
index c6e07f12..8528b911 100644
--- a/dune/gfe/geodesicfeassembler.hh
+++ b/dune/gfe/geodesicfeassembler.hh
@@ -28,19 +28,21 @@ class GeodesicFEAssembler {
 
 public:
     const Basis basis_;
+    const typename Basis::IndexSet basisIndexSet_;
 
 protected:
 
     LocalGeodesicFEStiffness<GridView,
-                             typename Basis::LocalFiniteElement,
+                             typename Basis::LocalView::Tree::FiniteElement,
                              TargetSpace>* localStiffness_;
 
 public:
 
     /** \brief Constructor for a given grid */
     GeodesicFEAssembler(const Basis& basis,
-                        LocalGeodesicFEStiffness<GridView,typename Basis::LocalFiniteElement, TargetSpace>* localStiffness)
+                        LocalGeodesicFEStiffness<GridView,typename Basis::LocalView::Tree::FiniteElement, TargetSpace>* localStiffness)
         : basis_(basis),
+          basisIndexSet_(basis_.indexSet()),
           localStiffness_(localStiffness)
     {}
 
@@ -72,23 +74,31 @@ template <class Basis, class TargetSpace>
 void GeodesicFEAssembler<Basis,TargetSpace>::
 getNeighborsPerVertex(Dune::MatrixIndexSet& nb) const
 {
-    int n = basis_.size();
+    auto n = basisIndexSet_.size();
 
     nb.resize(n, n);
 
-    ElementIterator it    = basis_.getGridView().template begin<0,Dune::Interior_Partition>();
-    ElementIterator endit = basis_.getGridView().template end<0,Dune::Interior_Partition>  ();
+    // A view on the FE basis on a single element
+    typename Basis::LocalView localView(&basis_);
+    auto localIndexSet = basisIndexSet_.localIndexSet();
+
+    ElementIterator it    = basis_.gridView().template begin<0,Dune::Interior_Partition>();
+    ElementIterator endit = basis_.gridView().template end<0,Dune::Interior_Partition>  ();
 
     for (; it!=endit; ++it) {
 
-        const typename Basis::LocalFiniteElement& lfe = basis_.getLocalFiniteElement(*it);
+        // Bind the local FE basis view to the current element
+        localView.bind(*it);
+        localIndexSet.bind(localView);
+
+        const auto& lfe = localView.tree().finiteElement();
 
-        for (size_t i=0; i<lfe.localBasis().size(); i++) {
+        for (size_t i=0; i<lfe.size(); i++) {
 
-            for (size_t j=0; j<lfe.localBasis().size(); j++) {
+            for (size_t j=0; j<lfe.size(); j++) {
 
-                int iIdx = basis_.index(*it,i);
-                int jIdx = basis_.index(*it,j);
+                auto iIdx = localIndexSet.index(i)[0];
+                auto jIdx = localIndexSet.index(j)[0];
 
                 nb.add(iIdx, jIdx);
 
@@ -120,32 +130,39 @@ assembleGradientAndHessian(const std::vector<TargetSpace>& sol,
     gradient.resize(sol.size());
     gradient = 0;
 
-    ElementIterator it    = basis_.getGridView().template begin<0,Dune::Interior_Partition>();
-    ElementIterator endit = basis_.getGridView().template end<0,Dune::Interior_Partition>  ();
+    // A view on the FE basis on a single element
+    typename Basis::LocalView localView(&basis_);
+    auto localIndexSet = basisIndexSet_.localIndexSet();
+
+    ElementIterator it    = basis_.gridView().template begin<0,Dune::Interior_Partition>();
+    ElementIterator endit = basis_.gridView().template end<0,Dune::Interior_Partition>  ();
 
     for( ; it != endit; ++it ) {
 
-        const int numOfBaseFct = basis_.getLocalFiniteElement(*it).localBasis().size();
+        localView.bind(*it);
+        localIndexSet.bind(localView);
+
+        const int numOfBaseFct = localView.tree().size();
 
         // Extract local solution
         std::vector<TargetSpace> localSolution(numOfBaseFct);
 
         for (int i=0; i<numOfBaseFct; i++)
-            localSolution[i] = sol[basis_.index(*it,i)];
+            localSolution[i] = sol[localIndexSet.index(i)[0]];
 
         std::vector<Dune::FieldVector<double,blocksize> > localGradient(numOfBaseFct);
 
         // setup local matrix and gradient
-        localStiffness_->assembleGradientAndHessian(*it, basis_.getLocalFiniteElement(*it), localSolution, localGradient);
+        localStiffness_->assembleGradientAndHessian(*it, localView.tree().finiteElement(), localSolution, localGradient);
 
         // Add element matrix to global stiffness matrix
         for(int i=0; i<numOfBaseFct; i++) {
 
-            int row = basis_.index(*it,i);
+            auto row = localIndexSet.index(i)[0];
 
             for (int j=0; j<numOfBaseFct; j++ ) {
 
-                int col = basis_.index(*it,j);
+                auto col = localIndexSet.index(j)[0];
                 hessian[row][col] += localStiffness_->A_[i][j];
 
             }
@@ -153,7 +170,7 @@ assembleGradientAndHessian(const std::vector<TargetSpace>& sol,
 
         // Add local gradient to global gradient
         for (int i=0; i<numOfBaseFct; i++)
-            gradient[basis_.index(*it,i)] += localGradient[i];
+            gradient[localIndexSet.index(i)[0]] += localGradient[i];
 
     }
 
@@ -164,35 +181,42 @@ void GeodesicFEAssembler<Basis,TargetSpace>::
 assembleGradient(const std::vector<TargetSpace>& sol,
                  Dune::BlockVector<Dune::FieldVector<double, blocksize> >& grad) const
 {
-    if (sol.size()!=basis_.size())
+    if (sol.size()!=basisIndexSet_.size())
         DUNE_THROW(Dune::Exception, "Solution vector doesn't match the grid!");
 
     grad.resize(sol.size());
     grad = 0;
 
-    ElementIterator it    = basis_.getGridView().template begin<0,Dune::Interior_Partition>();
-    ElementIterator endIt = basis_.getGridView().template end<0,Dune::Interior_Partition>();
+    // A view on the FE basis on a single element
+    typename Basis::LocalView localView(&basis_);
+    auto localIndexSet = basisIndexSet_.localIndexSet();
+
+    ElementIterator it    = basis_.gridView().template begin<0,Dune::Interior_Partition>();
+    ElementIterator endIt = basis_.gridView().template end<0,Dune::Interior_Partition>();
 
     // Loop over all elements
     for (; it!=endIt; ++it) {
 
+        localView.bind(*it);
+        localIndexSet.bind(localView);
+
         // A 1d grid has two vertices
-        const int nDofs = basis_.getLocalFiniteElement(*it).localBasis().size();
+        const auto nDofs = localView.tree().size();
 
         // Extract local solution
         std::vector<TargetSpace> localSolution(nDofs);
 
         for (int i=0; i<nDofs; i++)
-            localSolution[i] = sol[basis_.index(*it,i)];
+            localSolution[i] = sol[localIndexSet.index(i)[0]];
 
         // Assemble local gradient
         std::vector<Dune::FieldVector<double,blocksize> > localGradient(nDofs);
 
-        localStiffness_->assembleGradient(*it, basis_.getLocalFiniteElement(*it), localSolution, localGradient);
+        localStiffness_->assembleGradient(*it, localView.tree().finiteElement(), localSolution, localGradient);
 
         // Add to global gradient
         for (int i=0; i<nDofs; i++)
-            grad[basis_.index(*it,i)] += localGradient[i];
+            grad[localIndexSet.index(i)[0]] += localGradient[i];
 
     }
 
@@ -205,24 +229,31 @@ computeEnergy(const std::vector<TargetSpace>& sol) const
 {
     double energy = 0;
 
-    if (sol.size()!=basis_.size())
-        DUNE_THROW(Dune::Exception, "Solution vector doesn't match the grid!");
+    if (sol.size() != basisIndexSet_.size())
+        DUNE_THROW(Dune::Exception, "Coefficient vector doesn't match the function space basis!");
 
-    ElementIterator it    = basis_.getGridView().template begin<0,Dune::Interior_Partition>();
-    ElementIterator endIt = basis_.getGridView().template end<0,Dune::Interior_Partition>();
+    // A view on the FE basis on a single element
+    typename Basis::LocalView localView(&basis_);
+    auto localIndexSet = basisIndexSet_.localIndexSet();
+
+    ElementIterator it    = basis_.gridView().template begin<0,Dune::Interior_Partition>();
+    ElementIterator endIt = basis_.gridView().template end<0,Dune::Interior_Partition>();
 
     // Loop over all elements
     for (; it!=endIt; ++it) {
 
+        localView.bind(*it);
+        localIndexSet.bind(localView);
+
         // Number of degrees of freedom on this element
-        size_t nDofs = basis_.getLocalFiniteElement(*it).localBasis().size();
+        size_t nDofs = localView.tree().size();
 
         std::vector<TargetSpace> localSolution(nDofs);
 
         for (size_t i=0; i<nDofs; i++)
-            localSolution[i] = sol[basis_.index(*it,i)];
+            localSolution[i] = sol[localIndexSet.index(i)[0]];
 
-        energy += localStiffness_->energy(*it, basis_.getLocalFiniteElement(*it), localSolution);
+        energy += localStiffness_->energy(*it, localView.tree().finiteElement(), localSolution);
 
     }
 
diff --git a/dune/gfe/riemanniantrsolver.cc b/dune/gfe/riemanniantrsolver.cc
index 1702f064..65c0f0dd 100644
--- a/dune/gfe/riemanniantrsolver.cc
+++ b/dune/gfe/riemanniantrsolver.cc
@@ -120,10 +120,11 @@ setup(const GridType& grid,
     //   Assemble a Laplace matrix to create a norm that's equivalent to the H1-norm
     // //////////////////////////////////////////////////////////////////////////////////////
 
-    Basis basis(grid.leafGridView());
-    OperatorAssembler<Basis,Basis> operatorAssembler(basis, basis);
+    typedef DuneFunctionsBasis<Basis> FufemBasis;
+    FufemBasis basis(grid.leafGridView());
+    OperatorAssembler<FufemBasis,FufemBasis> operatorAssembler(basis, basis);
 
-    LaplaceAssembler<GridType, typename Basis::LocalFiniteElement, typename Basis::LocalFiniteElement> laplaceStiffness;
+    LaplaceAssembler<GridType, typename FufemBasis::LocalFiniteElement, typename FufemBasis::LocalFiniteElement> laplaceStiffness;
     typedef Dune::BCRSMatrix<Dune::FieldMatrix<double,1,1> > ScalarMatrixType;
     ScalarMatrixType localA;
 
@@ -158,7 +159,7 @@ setup(const GridType& grid,
     //   This will be used to monitor the gradient
     // //////////////////////////////////////////////////////////////////////////////////////
 
-    MassAssembler<GridType, typename Basis::LocalFiniteElement, typename Basis::LocalFiniteElement> massStiffness;
+    MassAssembler<GridType, typename Basis::LocalView::Tree::FiniteElement, typename Basis::LocalView::Tree::FiniteElement> massStiffness;
     ScalarMatrixType localMassMatrix;
 
     operatorAssembler.assemble(massStiffness, localMassMatrix);
@@ -216,7 +217,7 @@ setup(const GridType& grid,
         TransferOperatorType pkToP1TransferMatrix;
         assembleBasisInterpolationMatrix<TransferOperatorType,
                                          P1NodalBasis<typename GridType::LeafGridView,double>,
-                                         Basis>(pkToP1TransferMatrix,p1Basis,basis);
+                                         FufemBasis>(pkToP1TransferMatrix,p1Basis,basis);
 #if HAVE_MPI
         // If we are on more than 1 processors, join all local transfer matrices on rank 0,
         // and construct a single global transfer operator there.
@@ -316,7 +317,7 @@ void RiemannianTrustRegionSolver<Basis,TargetSpace>::solve()
     MaxNormTrustRegion<blocksize> trustRegion(globalMapper_->size(), initialTrustRegionRadius_);
 #else
     Basis basis(grid_->leafGridView());
-    MaxNormTrustRegion<blocksize> trustRegion(basis.size(), initialTrustRegionRadius_);
+    MaxNormTrustRegion<blocksize> trustRegion(basis.indexSet().size(), initialTrustRegionRadius_);
 #endif
     trustRegion.set(initialTrustRegionRadius_, scaling_);
 
diff --git a/src/harmonicmaps.cc b/src/harmonicmaps.cc
index 1aad240f..8c0d0500 100644
--- a/src/harmonicmaps.cc
+++ b/src/harmonicmaps.cc
@@ -132,10 +132,14 @@ int main (int argc, char *argv[]) try
     //  Construct the scalar function space basis corresponding to the GFE space
     //////////////////////////////////////////////////////////////////////////////////
 
-    typedef DuneFunctionsBasis<Dune::Functions::PQKNodalBasis<typename GridType::LeafGridView, 3> > FEBasis;
+    typedef Dune::Functions::PQKNodalBasis<typename GridType::LeafGridView, 3> FEBasis;
+
     FEBasis feBasis(grid->leafGridView());
 
-    SolutionType x(feBasis.size());
+    typedef DuneFunctionsBasis<FEBasis> FufemFEBasis;
+    FufemFEBasis fufemFeBasis(feBasis);
+
+    SolutionType x(fufemFeBasis.size());
 
     // /////////////////////////////////////////
     //   Read Dirichlet values
@@ -146,7 +150,7 @@ int main (int argc, char *argv[]) try
     BoundaryPatch<typename GridType::LeafGridView> dirichletBoundary(grid->leafGridView(), allNodes);
 
     BitSetVector<blocksize> dirichletNodes;
-    constructBoundaryDofs(dirichletBoundary,feBasis,dirichletNodes);
+    constructBoundaryDofs(dirichletBoundary,fufemFeBasis,dirichletNodes);
 
     // //////////////////////////
     //   Initial iterate
@@ -159,7 +163,7 @@ int main (int argc, char *argv[]) try
     auto pythonInitialIterate = module.get("fdf").toC<std::shared_ptr<FBase>>();
 
     std::vector<TargetSpace::CoordinateType> v;
-    ::Functions::interpolate(feBasis, v, *pythonInitialIterate);
+    ::Functions::interpolate(fufemFeBasis, v, *pythonInitialIterate);
 
     for (size_t i=0; i<x.size(); i++)
       x[i] = v[i];
@@ -173,24 +177,24 @@ int main (int argc, char *argv[]) try
 
     // Assembler using ADOL-C
     typedef TargetSpace::rebind<adouble>::other ATargetSpace;
-    std::shared_ptr<LocalGeodesicFEStiffness<GridType::LeafGridView,FEBasis::LocalFiniteElement,ATargetSpace> > localEnergy;
+    std::shared_ptr<LocalGeodesicFEStiffness<GridType::LeafGridView,FEBasis::LocalView::Tree::FiniteElement,ATargetSpace> > localEnergy;
 
     std::string energy = parameterSet.get<std::string>("energy");
     if (energy == "harmonic")
     {
 
-      localEnergy.reset(new HarmonicEnergyLocalStiffness<GridType::LeafGridView, FEBasis::LocalFiniteElement, ATargetSpace>);
+      localEnergy.reset(new HarmonicEnergyLocalStiffness<GridType::LeafGridView, FEBasis::LocalView::Tree::FiniteElement, ATargetSpace>);
 
     } else if (energy == "chiral_skyrmion")
     {
 
-      localEnergy.reset(new GFE::ChiralSkyrmionEnergy<GridType::LeafGridView, FEBasis::LocalFiniteElement, adouble>(parameterSet.sub("energyParameters")));
+      localEnergy.reset(new GFE::ChiralSkyrmionEnergy<GridType::LeafGridView, FEBasis::LocalView::Tree::FiniteElement, adouble>(parameterSet.sub("energyParameters")));
 
     } else
       DUNE_THROW(Exception, "Unknown energy type '" << energy << "'");
 
     LocalGeodesicFEADOLCStiffness<GridType::LeafGridView,
-                                  FEBasis::LocalFiniteElement,
+                                  FEBasis::LocalView::Tree::FiniteElement,
                                   TargetSpace> localGFEADOLCStiffness(localEnergy.get());
 
     GeodesicFEAssembler<FEBasis,TargetSpace> assembler(feBasis, &localGFEADOLCStiffness);
@@ -239,9 +243,9 @@ int main (int argc, char *argv[]) try
     }
 
     VTKWriter<GridType::LeafGridView> vtkWriter(grid->leafGridView());
-    Dune::shared_ptr<VTKBasisGridFunction<FEBasis,EmbeddedVectorType> > vtkVectorField
-        = Dune::shared_ptr<VTKBasisGridFunction<FEBasis,EmbeddedVectorType> >
-               (new VTKBasisGridFunction<FEBasis,EmbeddedVectorType>(feBasis, xEmbedded, "orientation"));
+    Dune::shared_ptr<VTKBasisGridFunction<FufemFEBasis,EmbeddedVectorType> > vtkVectorField
+        = Dune::shared_ptr<VTKBasisGridFunction<FufemFEBasis,EmbeddedVectorType> >
+               (new VTKBasisGridFunction<FufemFEBasis,EmbeddedVectorType>(fufemFeBasis, xEmbedded, "orientation"));
     vtkWriter.addVertexData(vtkVectorField);
 
     vtkWriter.write(resultPath + "_" + energy + "_result");
@@ -259,7 +263,7 @@ int main (int argc, char *argv[]) try
       auto referenceSolution = module.get("fdf").toC<std::shared_ptr<FBase>>();
 
       // The numerical solution, as a grid function
-      GFE::EmbeddedGlobalGFEFunction<FEBasis, TargetSpace> numericalSolution(feBasis, x);
+      GFE::EmbeddedGlobalGFEFunction<FufemFEBasis, TargetSpace> numericalSolution(feBasis, x);
 
       // QuadratureRule for the integral of the L^2 error
       QuadratureRuleKey quadKey(dim,6);
-- 
GitLab