From 3fd15b9ed271b3400bd207dee13ee478d7fdfa76 Mon Sep 17 00:00:00 2001
From: Oliver Sander <oliver.sander@tu-dresden.de>
Date: Fri, 20 Nov 2020 14:39:31 +0100
Subject: [PATCH] Support rod discretizations of any order

So far, the Cosserat rod energy implementation hat a first-order
finite element space hardcoded.  This patch removes that restriction.
As for the other models in this Dune module, the finite element basis
is now a template parameter of the model energy, and can be set to
any reasonable basis.
---
 dune/gfe/cosseratrodenergy.hh        | 146 ++++++++++++---------------
 dune/gfe/localgeodesicfefunction.hh  |  14 +++
 dune/gfe/localprojectedfefunction.hh |  14 +++
 problems/staticrod.parset            |   3 +
 src/rod3d.cc                         |  96 ++++++++++++------
 test/frameinvariancetest.cc          |  12 ++-
 6 files changed, 167 insertions(+), 118 deletions(-)

diff --git a/dune/gfe/cosseratrodenergy.hh b/dune/gfe/cosseratrodenergy.hh
index 6e35b707..a7c1386d 100644
--- a/dune/gfe/cosseratrodenergy.hh
+++ b/dune/gfe/cosseratrodenergy.hh
@@ -17,19 +17,18 @@
 #include <dune/fufem/boundarypatch.hh>
 
 #include <dune/gfe/localenergy.hh>
-#include <dune/gfe/localgeodesicfefunction.hh>
 #include <dune/gfe/rigidbodymotion.hh>
 
 namespace Dune::GFE {
 
-template<class GridView, class RT>
+template<class Basis, class LocalInterpolationRule, class RT>
 class CosseratRodEnergy
-: public LocalEnergy<Functions::LagrangeBasis<GridView,1>, RigidBodyMotion<RT,3> >
+: public LocalEnergy<Basis, RigidBodyMotion<RT,3> >
 {
     typedef RigidBodyMotion<RT,3> TargetSpace;
-    typedef Functions::LagrangeBasis<GridView,1> Basis;
 
     // grid types
+    using GridView = typename Basis::GridView;
     typedef typename GridView::Grid::ctype DT;
     typedef typename GridView::template Codim<0>::Entity Entity;
 
@@ -56,11 +55,7 @@ public:
 public:
 
     //! Each block is x, y, theta in 2d, T (R^3 \times SO(3)) in 3d
-    enum { blocksize = 6 };
-
-    // define the number of components of your system, this is used outside
-    // to allocate the correct size of (dense) blocks with a FieldMatrix
-    enum {m=blocksize};
+    static constexpr auto blocksize = TargetSpace::EmbeddedTangentVector::dimension;
 
     // /////////////////////////////////
     //   The material parameters
@@ -102,8 +97,8 @@ public:
         A_[2] = E * A;
     }
 
-
-
+    /** \brief Set the stress-free configuration
+     */
     void setReferenceConfiguration(const std::vector<RigidBodyMotion<double,3> >& referenceConfiguration) {
         referenceConfiguration_ = referenceConfiguration;
     }
@@ -116,8 +111,8 @@ public:
      *
      * \tparam Number This is a member template because the method has to work for double and adouble
      */
-    template<class Number>
-    FieldVector<Number, 6> getStrain(const std::vector<RigidBodyMotion<Number,3> >& localSolution,
+    template<class ReboundLocalInterpolationRule>
+    auto getStrain(const ReboundLocalInterpolationRule& localSolution,
                                            const Entity& element,
                                            const FieldVector<double,1>& pos) const;
 
@@ -126,7 +121,7 @@ public:
      * \tparam Number This is a member template because the method has to work for double and adouble
      */
     template<class Number>
-    FieldVector<Number, 6> getStress(const std::vector<RigidBodyMotion<Number,3> >& localSolution,
+    auto getStress(const std::vector<RigidBodyMotion<Number,3> >& localSolution,
                                            const Entity& element,
                                            const FieldVector<double,1>& pos) const;
 
@@ -142,19 +137,20 @@ public:
 
      \note Linear run-time in the size of the grid */
     template <class PatchGridView>
-    FieldVector<double,6> getResultantForce(const BoundaryPatch<PatchGridView>& boundary,
+    auto getResultantForce(const BoundaryPatch<PatchGridView>& boundary,
                                                   const std::vector<RigidBodyMotion<double,3> >& sol) const;
 
 protected:
 
-    void getLocalReferenceConfiguration(const Entity& element,
-                                        std::vector<RigidBodyMotion<double,3> >& localReferenceConfiguration) const {
-
-        unsigned int numOfBaseFct = element.subEntities(dim);
-        localReferenceConfiguration.resize(numOfBaseFct);
+    std::vector<RigidBodyMotion<double,3> > getLocalReferenceConfiguration(const typename Basis::LocalView& localView) const
+    {
+        unsigned int numOfBaseFct = localView.size();
+        std::vector<RigidBodyMotion<double,3> > localReferenceConfiguration(numOfBaseFct);
 
         for (size_t i=0; i<numOfBaseFct; i++)
-            localReferenceConfiguration[i] = referenceConfiguration_[gridView_.indexSet().subIndex(element,i,dim)];
+            localReferenceConfiguration[i] = referenceConfiguration_[localView.index(i)];
+
+        return localReferenceConfiguration;
     }
 
       template <class T>
@@ -171,18 +167,21 @@ protected:
 
 };
 
-template <class GridView, class RT>
-RT CosseratRodEnergy<GridView, RT>::
+template<class Basis, class LocalInterpolationRule, class RT>
+RT CosseratRodEnergy<Basis, LocalInterpolationRule, RT>::
 energy(const typename Basis::LocalView& localView,
-       const std::vector<RigidBodyMotion<RT,3> >& localSolution) const
+       const std::vector<RigidBodyMotion<RT,3> >& localCoefficients) const
 {
-    assert(localSolution.size()==2);
+    const auto& localFiniteElement = localView.tree().finiteElement();
+    LocalInterpolationRule localConfiguration(localFiniteElement, localCoefficients);
+
     const auto& element = localView.element();
 
     RT energy = 0;
 
-    std::vector<RigidBodyMotion<double,3> > localReferenceConfiguration;
-    getLocalReferenceConfiguration(element, localReferenceConfiguration);
+    std::vector<RigidBodyMotion<double,3> > localReferenceCoefficients = getLocalReferenceConfiguration(localView);
+    using InactiveLocalInterpolationRule = typename LocalInterpolationRule::template rebind<RigidBodyMotion<double,3> >::other;
+    InactiveLocalInterpolationRule localReferenceConfiguration(localFiniteElement, localReferenceCoefficients);
 
     // ///////////////////////////////////////////////////////////////////////////////
     //   The following two loops are a reduced integration scheme.  We integrate
@@ -190,23 +189,18 @@ energy(const typename Basis::LocalView& localView,
     //   formula, even though it should be second order.  This prevents shear-locking.
     // ///////////////////////////////////////////////////////////////////////////////
 
-    const QuadratureRule<double, 1>& shearingQuad
-        = QuadratureRules<double, 1>::rule(element.type(), shearQuadOrder);
-
-    // hack: convert from std::array to std::vector
-    // TODO: REMOVE ME!
-    std::vector<RigidBodyMotion<RT,3> > localSolutionVector(localSolution.begin(), localSolution.end());
+    const auto& shearingQuad = QuadratureRules<double, 1>::rule(element.type(), shearQuadOrder);
 
     for (size_t pt=0; pt<shearingQuad.size(); pt++) {
 
         // Local position of the quadrature point
-        const FieldVector<double,1>& quadPos = shearingQuad[pt].position();
+        const auto quadPos = shearingQuad[pt].position();
 
         const double integrationElement = element.geometry().integrationElement(quadPos);
 
         double weight = shearingQuad[pt].weight() * integrationElement;
 
-        auto strain = getStrain(localSolutionVector, element, quadPos);
+        auto strain = getStrain(localConfiguration, element, quadPos);
 
         // The reference strain
         auto referenceStrain = getStrain(localReferenceConfiguration, element, quadPos);
@@ -217,8 +211,7 @@ energy(const typename Basis::LocalView& localView,
     }
 
     // Get quadrature rule
-    const QuadratureRule<double, 1>& bendingQuad
-        = QuadratureRules<double, 1>::rule(element.type(), bendingQuadOrder);
+    const auto& bendingQuad = QuadratureRules<double, 1>::rule(element.type(), bendingQuadOrder);
 
     for (size_t pt=0; pt<bendingQuad.size(); pt++) {
 
@@ -227,7 +220,7 @@ energy(const typename Basis::LocalView& localView,
 
         double weight = bendingQuad[pt].weight() * element.geometry().integrationElement(quadPos);
 
-        auto strain = getStrain(localSolutionVector, element, quadPos);
+        auto strain = getStrain(localConfiguration, element, quadPos);
 
         // The reference strain
         auto referenceStrain = getStrain(localReferenceConfiguration, element, quadPos);
@@ -242,30 +235,18 @@ energy(const typename Basis::LocalView& localView,
 }
 
 
-template <class GridView, class RT>
-template <class Number>
-FieldVector<Number, 6> CosseratRodEnergy<GridView, RT>::
-getStrain(const std::vector<RigidBodyMotion<Number,3> >& localSolution,
+template<class Basis, class LocalInterpolationRule, class RT>
+template <class ReboundLocalInterpolationRule>
+auto CosseratRodEnergy<Basis, LocalInterpolationRule, RT>::
+getStrain(const ReboundLocalInterpolationRule& localInterpolation,
           const Entity& element,
           const FieldVector<double,1>& pos) const
 {
-    if (!element.isLeaf())
-        DUNE_THROW(NotImplemented, "Only for leaf elements");
-
-    assert(localSolution.size() == 2);
-
-    // Extract local solution on this element
-    P1LocalFiniteElement<double,double,1> localFiniteElement;
-
     const auto jit = element.geometry().jacobianInverseTransposed(pos);
-    using LocalInterpolationRule = LocalGeodesicFEFunction<1, typename GridView::ctype,
-                                                           decltype(localFiniteElement),
-                                                           RigidBodyMotion<Number,3> >;
-    LocalInterpolationRule localInterpolationRule(localFiniteElement,localSolution);
 
-    auto value = localInterpolationRule.evaluate(pos);
+    auto value = localInterpolation.evaluate(pos);
 
-    auto referenceDerivative = localInterpolationRule.evaluateDerivative(pos);
+    auto referenceDerivative = localInterpolation.evaluateDerivative(pos);
 #if DUNE_VERSION_GTE(DUNE_COMMON, 2, 8)
     auto derivative = referenceDerivative * transpose(jit);
 #else
@@ -273,13 +254,14 @@ getStrain(const std::vector<RigidBodyMotion<Number,3> >& localSolution,
     derivative *= jit[0][0];
 #endif
 
-    FieldVector<Number,3> r_s = {derivative[0], derivative[1], derivative[2]};
+    using Number = std::decay_t<decltype(derivative[0][0])>;
+    FieldVector<Number,3> r_s = {derivative[0][0], derivative[1][0], derivative[2][0]};
 
     // Transformation from the reference element
-    Quaternion<Number> q_s(derivative[3],
-                           derivative[4],
-                           derivative[5],
-                           derivative[6]);
+    Quaternion<Number> q_s(derivative[3][0],
+                           derivative[4][0],
+                           derivative[5][0],
+                           derivative[6][0]);
 
     // /////////////////////////////////////////////
     //   Sum it all up
@@ -294,7 +276,7 @@ getStrain(const std::vector<RigidBodyMotion<Number,3> >& localSolution,
 
     // Part II: the Darboux vector
 
-    FieldVector<Number,3> u = darboux(value.q, q_s);
+    FieldVector<Number,3> u = darboux<Number>(value.q, q_s);
     strain[3] = u[0];
     strain[4] = u[1];
     strain[5] = u[2];
@@ -302,12 +284,12 @@ getStrain(const std::vector<RigidBodyMotion<Number,3> >& localSolution,
     return strain;
 }
 
-template <class GridView, class RT>
+template<class Basis, class LocalInterpolationRule, class RT>
 template <class Number>
-FieldVector<Number, 6> CosseratRodEnergy<GridView, RT>::
+auto CosseratRodEnergy<Basis, LocalInterpolationRule, RT>::
 getStress(const std::vector<RigidBodyMotion<Number,3> >& localSolution,
-              const Entity& element,
-                        const FieldVector<double, 1>& pos) const
+          const Entity& element,
+          const FieldVector<double, 1>& pos) const
 {
     const auto& indexSet = gridView_.indexSet();
     std::vector<TargetSpace> localRefConf = {referenceConfiguration_[indexSet.subIndex(element, 0, 1)],
@@ -325,8 +307,8 @@ getStress(const std::vector<RigidBodyMotion<Number,3> >& localSolution,
     return stress;
 }
 
-template <class GridView, class RT>
-void CosseratRodEnergy<GridView, RT>::
+template<class Basis, class LocalInterpolationRule, class RT>
+void CosseratRodEnergy<Basis, LocalInterpolationRule, RT>::
 getStrain(const std::vector<RigidBodyMotion<double,3> >& sol,
           BlockVector<FieldVector<double, blocksize> >& strain) const
 {
@@ -364,7 +346,7 @@ getStrain(const std::vector<RigidBodyMotion<double,3> >& sol,
 
             double weight = quad[pt].weight() * element.geometry().integrationElement(quadPos);
 
-            auto localStrain = std::dynamic_pointer_cast<CosseratRodEnergy<GridView, double> >(this->localStiffness_)->getStrain(localSolution, element, quad[pt].position());
+            auto localStrain = getStrain(localSolution, element, quad[pt].position());
 
             // Sum it all up
             strain[elementIdx].axpy(weight, localStrain);
@@ -380,8 +362,8 @@ getStrain(const std::vector<RigidBodyMotion<double,3> >& sol,
     }
 }
 
-template <class GridView, class RT>
-void CosseratRodEnergy<GridView, RT>::
+template<class Basis, class LocalInterpolationRule, class RT>
+void CosseratRodEnergy<Basis, LocalInterpolationRule, RT>::
 getStress(const std::vector<RigidBodyMotion<double,3> >& sol,
           BlockVector<FieldVector<double, blocksize> >& stress) const
 {
@@ -390,22 +372,22 @@ getStress(const std::vector<RigidBodyMotion<double,3> >& sol,
 
     // Get reference strain
     BlockVector<FieldVector<double, blocksize> > referenceStrain;
-    getStrain(dynamic_cast<CosseratRodEnergy<GridView, double>* >(this->localStiffness_)->referenceConfiguration_, referenceStrain);
+    getStrain(referenceConfiguration_, referenceStrain);
 
     // Linear diagonal constitutive law
     for (size_t i=0; i<stress.size(); i++)
     {
         for (int j=0; j<3; j++)
         {
-            stress[i][j]   = (stress[i][j]   - referenceStrain[i][j])   * dynamic_cast<CosseratRodEnergy<GridView, double>* >(this->localStiffness_)->A_[j];
-            stress[i][j+3] = (stress[i][j+3] - referenceStrain[i][j+3]) * dynamic_cast<CosseratRodEnergy<GridView, double>* >(this->localStiffness_)->K_[j];
+            stress[i][j]   = (stress[i][j]   - referenceStrain[i][j])   * A_[j];
+            stress[i][j+3] = (stress[i][j+3] - referenceStrain[i][j+3]) * K_[j];
         }
     }
 }
 
-template <class GridView, class RT>
+template<class Basis, class LocalInterpolationRule, class RT>
 template <class PatchGridView>
-FieldVector<double,6> CosseratRodEnergy<GridView, RT>::
+auto CosseratRodEnergy<Basis, LocalInterpolationRule, RT>::
 getResultantForce(const BoundaryPatch<PatchGridView>& boundary,
                   const std::vector<RigidBodyMotion<double,3> >& sol) const
 {
@@ -431,19 +413,19 @@ getResultantForce(const BoundaryPatch<PatchGridView>& boundary,
         localSolution[1] = sol[indexSet.subIndex(*facet.inside(),1,1)];
 
         std::vector<RigidBodyMotion<double,3> > localRefConf(2);
-        localRefConf[0] = dynamic_cast<CosseratRodEnergy<GridView, double>* >(this->localStiffness_)->referenceConfiguration_[indexSet.subIndex(*facet.inside(),0,1)];
-        localRefConf[1] = dynamic_cast<CosseratRodEnergy<GridView, double>* >(this->localStiffness_)->referenceConfiguration_[indexSet.subIndex(*facet.inside(),1,1)];
+        localRefConf[0] = referenceConfiguration_[indexSet.subIndex(*facet.inside(),0,1)];
+        localRefConf[1] = referenceConfiguration_[indexSet.subIndex(*facet.inside(),1,1)];
 
-        auto strain          = dynamic_cast<CosseratRodEnergy<GridView, double>* >(this->localStiffness_)->getStrain(localSolution, *facet.inside(), pos);
-        auto referenceStrain = dynamic_cast<CosseratRodEnergy<GridView, double>* >(this->localStiffness_)->getStrain(localRefConf, *facet.inside(), pos);
+        auto strain          = getStrain(localSolution, *facet.inside(), pos);
+        auto referenceStrain = getStrain(localRefConf, *facet.inside(), pos);
 
         FieldVector<double,3> localStress;
         for (int i=0; i<3; i++)
-            localStress[i] = (strain[i] - referenceStrain[i]) * dynamic_cast<CosseratRodEnergy<GridView, double>* >(this->localStiffness_)->A_[i];
+            localStress[i] = (strain[i] - referenceStrain[i]) * A_[i];
 
         FieldVector<double,3> localTorque;
         for (int i=0; i<3; i++)
-            localTorque[i] = (strain[i+3] - referenceStrain[i+3]) * dynamic_cast<CosseratRodEnergy<GridView, double>* >(this->localStiffness_)->K_[i];
+            localTorque[i] = (strain[i+3] - referenceStrain[i+3]) * K_[i];
 
         // Transform stress given with respect to the basis given by the three directors to
         // the canonical basis of R^3
diff --git a/dune/gfe/localgeodesicfefunction.hh b/dune/gfe/localgeodesicfefunction.hh
index 3502f371..e13acbd8 100644
--- a/dune/gfe/localgeodesicfefunction.hh
+++ b/dune/gfe/localgeodesicfefunction.hh
@@ -59,6 +59,13 @@ public:
         assert(localFiniteElement_.localBasis().size() == coefficients_.size());
     }
 
+    /** \brief Rebind the FEFunction to another TargetSpace */
+    template<class U>
+    struct rebind
+    {
+      using other = LocalGeodesicFEFunction<dim,ctype,LocalFiniteElement,U>;
+    };
+
     /** \brief The number of Lagrange points */
     unsigned int size() const
     {
@@ -589,6 +596,13 @@ public:
 
     }
 
+    /** \brief Rebind the FEFunction to another TargetSpace */
+    template<class U>
+    struct rebind
+    {
+      using other = LocalGeodesicFEFunction<dim,ctype,LocalFiniteElement,U>;
+    };
+
     /** \brief The number of Lagrange points */
     unsigned int size() const
     {
diff --git a/dune/gfe/localprojectedfefunction.hh b/dune/gfe/localprojectedfefunction.hh
index 259270cc..ba538c55 100644
--- a/dune/gfe/localprojectedfefunction.hh
+++ b/dune/gfe/localprojectedfefunction.hh
@@ -70,6 +70,13 @@ Dune::FieldMatrix< K, m, p > operator* ( const Dune::FieldMatrix< K, m, n > &A,
         assert(localFiniteElement_.localBasis().size() == coefficients_.size());
       }
 
+      /** \brief Rebind the FEFunction to another TargetSpace */
+      template<class U>
+      struct rebind
+      {
+        using other = LocalProjectedFEFunction<dim,ctype,LocalFiniteElement,U>;
+      };
+
       /** \brief The number of Lagrange points */
       unsigned int size() const
       {
@@ -452,6 +459,13 @@ Dune::FieldMatrix< K, m, p > operator* ( const Dune::FieldMatrix< K, m, n > &A,
         orientationFunction_ = std::make_unique<LocalProjectedFEFunction<dim,ctype,LocalFiniteElement,Rotation<field_type,3> > > (localFiniteElement,orientationCoefficients);
       }
 
+      /** \brief Rebind the FEFunction to another TargetSpace */
+      template<class U>
+      struct rebind
+      {
+        using other = LocalProjectedFEFunction<dim,ctype,LocalFiniteElement,U>;
+      };
+
       /** \brief The number of Lagrange points */
       unsigned int size() const
       {
diff --git a/problems/staticrod.parset b/problems/staticrod.parset
index 9821813d..2d001eca 100644
--- a/problems/staticrod.parset
+++ b/problems/staticrod.parset
@@ -51,6 +51,9 @@ instrumented = no
 #   Problem specifications
 ############################
 
+# Interpolation method
+interpolationMethod = geodesic
+
 A = 1
 J1 = 1
 J2 = 1
diff --git a/src/rod3d.cc b/src/rod3d.cc
index 0efdaff8..866addb2 100644
--- a/src/rod3d.cc
+++ b/src/rod3d.cc
@@ -25,6 +25,7 @@
 
 #if HAVE_DUNE_VTK
 #include <dune/vtk/vtkwriter.hh>
+#include <dune/vtk/datacollectors/lagrangedatacollector.hh>
 #else
 #include <dune/gfe/cosseratvtkwriter.hh>
 #endif
@@ -32,6 +33,8 @@
 #include <dune/gfe/cosseratrodenergy.hh>
 #include <dune/gfe/geodesicfeassembler.hh>
 #include <dune/gfe/localgeodesicfeadolcstiffness.hh>
+#include <dune/gfe/localgeodesicfefunction.hh>
+#include <dune/gfe/localprojectedfefunction.hh>
 #include <dune/gfe/rigidbodymotion.hh>
 #include <dune/gfe/rotation.hh>
 #include <dune/gfe/riemanniantrsolver.hh>
@@ -40,6 +43,9 @@ typedef RigidBodyMotion<double,3> TargetSpace;
 
 const int blocksize = TargetSpace::TangentVector::dimension;
 
+// Approximation order of the finite element space
+constexpr int order = 2;
+
 using namespace Dune;
 
 int main (int argc, char *argv[]) try
@@ -91,22 +97,37 @@ int main (int argc, char *argv[]) try
     using GridView = GridType::LeafGridView;
     GridView gridView = grid.leafGridView();
 
-    using FEBasis = Functions::LagrangeBasis<GridView,1>;
+    using FEBasis = Functions::LagrangeBasis<GridView,order>;
     FEBasis feBasis(gridView);
 
     SolutionType x(feBasis.size());
 
-    // //////////////////////////
-    //   Initial solution
-    // //////////////////////////
+    //////////////////////////////////////////////
+    //  Create the stress-free configuration
+    //////////////////////////////////////////////
+
+    std::vector<double> referenceConfigurationX(feBasis.size());
+
+    auto identity = [](const FieldVector<double,1>& x) { return x; };
 
-    for (size_t i=0; i<x.size(); i++) {
-        x[i].r[0] = 0;
-        x[i].r[1] = 0;
-        x[i].r[2] = double(i)/(x.size()-1);
-        x[i].q    = Rotation<double,3>::identity();
+    Functions::interpolate(feBasis, referenceConfigurationX, identity);
+
+    std::vector<RigidBodyMotion<double,3> > referenceConfiguration(feBasis.size());
+
+    for (std::size_t i=0; i<referenceConfiguration.size(); i++)
+    {
+        referenceConfiguration[i].r[0] = 0;
+        referenceConfiguration[i].r[1] = 0;
+        referenceConfiguration[i].r[2] = referenceConfigurationX[i];
+        referenceConfiguration[i].q = Rotation<double,3>::identity();
     }
 
+    /////////////////////////////////////////////////////////////////
+    //   Select the reference configuration as initial iterate
+    /////////////////////////////////////////////////////////////////
+
+    x = referenceConfiguration;
+
     // /////////////////////////////////////////
     //   Read Dirichlet values
     // /////////////////////////////////////////
@@ -135,37 +156,44 @@ int main (int argc, char *argv[]) try
         
     dirichletNodes[0] = true;
     dirichletNodes.back() = true;
-    
+
     //////////////////////////////////////////////
-    //  Create the stress-free configuration
+    //  Create the energy and assembler
     //////////////////////////////////////////////
 
-    auto localRodEnergy = std::make_shared<GFE::CosseratRodEnergy<GridView,adouble> >(gridView,
-                                                                                      A, J1, J2, E, nu);
+    using ATargetSpace = TargetSpace::rebind<adouble>::other;
+    using GeodesicInterpolationRule  = LocalGeodesicFEFunction<1, double, FEBasis::LocalView::Tree::FiniteElement, ATargetSpace>;
+    using ProjectedInterpolationRule = GFE::LocalProjectedFEFunction<1, double, FEBasis::LocalView::Tree::FiniteElement, ATargetSpace>;
 
-    std::vector<RigidBodyMotion<double,3> > referenceConfiguration(gridView.size(1));
+    // Assembler using ADOL-C
+    std::shared_ptr<GFE::LocalEnergy<FEBasis,ATargetSpace> > localRodEnergy;
 
-    for (const auto vertex : vertices(gridView))
+    if (parameterSet["interpolationMethod"] == "geodesic")
     {
-        auto idx = gridView.indexSet().index(vertex);
-
-        referenceConfiguration[idx].r[0] = 0;
-        referenceConfiguration[idx].r[1] = 0;
-        referenceConfiguration[idx].r[2] = vertex.geometry().corner(0)[0];
-        referenceConfiguration[idx].q = Rotation<double,3>::identity();
+        auto energy = std::make_shared<GFE::CosseratRodEnergy<FEBasis, GeodesicInterpolationRule, adouble> >(gridView,
+                                                                                                             A, J1, J2, E, nu);
+        energy->setReferenceConfiguration(referenceConfiguration);
+        localRodEnergy = energy;
     }
-
-    localRodEnergy->setReferenceConfiguration(referenceConfiguration);
-
-    // ///////////////////////////////////////////
-    //   Create a solver for the rod problem
-    // ///////////////////////////////////////////
+    else if (parameterSet["interpolationMethod"] == "projected")
+    {
+        auto energy = std::make_shared<GFE::CosseratRodEnergy<FEBasis, ProjectedInterpolationRule, adouble> >(gridView,
+                                                                                                              A, J1, J2, E, nu);
+        energy->setReferenceConfiguration(referenceConfiguration);
+        localRodEnergy = energy;
+    }
+    else
+        DUNE_THROW(Exception, "Unknown interpolation method " << parameterSet["interpolationMethod"] << " requested!");
 
     LocalGeodesicFEADOLCStiffness<FEBasis,
                                   TargetSpace> localStiffness(localRodEnergy.get());
 
     GeodesicFEAssembler<FEBasis,TargetSpace> rodAssembler(gridView, localStiffness);
 
+    /////////////////////////////////////////////
+    //   Create a solver for the rod problem
+    /////////////////////////////////////////////
+
     RiemannianTrustRegionSolver<FEBasis,RigidBodyMotion<double,3> > rodSolver;
 
     rodSolver.setup(grid, 
@@ -197,14 +225,16 @@ int main (int argc, char *argv[]) try
     //   Output result
     // //////////////////////////////
 #if HAVE_DUNE_VTK
-    VtkUnstructuredGridWriter<GridView> vtkWriter(gridView, Vtk::ASCII);
+    using DataCollector = Vtk::LagrangeDataCollector<GridView,order>;
+    DataCollector dataCollector(gridView);
+    VtkUnstructuredGridWriter<GridView,DataCollector> vtkWriter(gridView, Vtk::ASCII);
 
     // Make basis for R^3-valued data
     using namespace Functions::BasisFactory;
 
     auto worldBasis = makeBasis(
       gridView,
-      power<3>(lagrange<1>())
+      power<3>(lagrange<order>())
     );
 
     // The rod displacement field
@@ -227,14 +257,14 @@ int main (int argc, char *argv[]) try
     // The three director fields
     using FunctionType = decltype(displacementFunction);
     std::array<std::optional<FunctionType>, 3> directorFunction;
-
+    std::array<BlockVector<FieldVector<double, 3> >, 3> director;
     for (int i=0; i<3; i++)
     {
-      BlockVector<FieldVector<double, 3> > director(worldBasis.size());
+      director[i].resize(worldBasis.size());
       for (std::size_t j=0; j<x.size(); j++)
-        director[j] = x[j].q.director(i);
+        director[i][j] = x[j].q.director(i);
 
-      directorFunction[i] = Functions::makeDiscreteGlobalBasisFunction<FieldVector<double,3> >(worldBasis, std::move(director));
+      directorFunction[i] = Functions::makeDiscreteGlobalBasisFunction<FieldVector<double,3> >(worldBasis, std::move(director[i]));
       vtkWriter.addPointData(*directorFunction[i], "director " + std::to_string(i), 3);
     }
 
diff --git a/test/frameinvariancetest.cc b/test/frameinvariancetest.cc
index 361d0cdd..ff65271b 100644
--- a/test/frameinvariancetest.cc
+++ b/test/frameinvariancetest.cc
@@ -5,7 +5,7 @@
 #include <dune/functions/functionspacebases/lagrangebasis.hh>
 
 #include <dune/gfe/cosseratrodenergy.hh>
-#include <dune/gfe/quaternion.hh>
+#include <dune/gfe/localgeodesicfefunction.hh>
 #include <dune/gfe/rigidbodymotion.hh>
 
 
@@ -69,8 +69,14 @@ int main (int argc, char *argv[]) try
         rotatedX[i].q = rotation.mult(x[i].q);
     }
 
-    GFE::CosseratRodEnergy<GridView,double> localRodEnergy(gridView,
-                                                      1,1,1,1e6,0.3);
+    using GeodesicInterpolationRule  = LocalGeodesicFEFunction<1, double,
+                                                               FEBasis::LocalView::Tree::FiniteElement,
+                                                               RigidBodyMotion<double,3> >;
+
+    GFE::CosseratRodEnergy<FEBasis,
+                           GeodesicInterpolationRule,
+                           double> localRodEnergy(gridView,
+                                                  1,1,1,1e6,0.3);
 
     std::vector<RigidBodyMotion<double,3> > referenceConfiguration(gridView.size(1));
 
-- 
GitLab