diff --git a/dune/gfe/geodesicfeassembler.hh b/dune/gfe/geodesicfeassembler.hh
index c6958e50a890e5a995849d64dbe119a447750e6f..7be9133aa8839008cef76454992ffd251fac41f4 100644
--- a/dune/gfe/geodesicfeassembler.hh
+++ b/dune/gfe/geodesicfeassembler.hh
@@ -8,13 +8,16 @@
 
 #include "localgeodesicfestiffness.hh"
 
+#include <dune/solvers/common/wrapownshare.hh>
 
 /** \brief A global FE assembler for problems involving functions that map into non-Euclidean spaces
  */
 template <class Basis, class TargetSpace>
 class GeodesicFEAssembler {
 
+    using field_type = typename TargetSpace::field_type;
     typedef typename Basis::GridView GridView;
+    using LocalStiffness = LocalGeodesicFEStiffness<Basis, TargetSpace>;
 
     //! Dimension of the grid.
     enum { gridDim = GridView::dimension };
@@ -25,29 +28,58 @@ class GeodesicFEAssembler {
     //!
     typedef Dune::FieldMatrix<double, blocksize, blocksize> MatrixBlock;
 
-public:
-    const Basis basis_;
 
 protected:
 
-    LocalGeodesicFEStiffness<Basis,TargetSpace>* localStiffness_;
+    //! The global basis
+    const Basis basis_;
+
+    //! The local stiffness operator
+    std::shared_ptr<LocalStiffness> localStiffness_;
 
 public:
 
     /** \brief Constructor for a given grid */
+    GeodesicFEAssembler(const Basis& basis)
+        : basis_(basis)
+    {}
+
+    /** \brief Constructor for a given grid */
+    template <class LocalStiffnessT>
     GeodesicFEAssembler(const Basis& basis,
-                        LocalGeodesicFEStiffness<Basis, TargetSpace>* localStiffness)
+                        LocalStiffnessT&& localStiffness)
         : basis_(basis),
-          localStiffness_(localStiffness)
+          localStiffness_(Dune::Solvers::wrap_own_share<LocalStiffness>(std::forward<LocalStiffnessT>(localStiffness)))
     {}
 
+    /** \brief Set the local stiffness assembler. This can be a temporary, l-value or shared pointer. */
+    template <class LocalStiffnessT>
+    void setLocalStiffness(LocalStiffnessT&& localStiffness) {
+        localStiffness_ = Dune::Solvers::wrap_own_share<LocalStiffness>(std::forward<LocalStiffnessT>(localStiffness));
+    }
+
+    /** \brief Get the local stiffness operator. */
+    const LocalStiffness& getLocalStiffness() const {
+        return *localStiffness_;
+    }
+
+    /** \brief Get the local stiffness operator. */
+    LocalStiffness& getLocalStiffness() {
+        return *localStiffness_;
+    }
+
+    /** \brief Get the basis. */
+    const Basis& getBasis() const {
+        return basis_;
+    }
+
     /** \brief Assemble the tangent stiffness matrix and the functional gradient together
      *
      * This is more efficient than computing them separately, because you need the gradient
      * anyway to compute the Riemannian Hessian.
      */
     virtual void assembleGradientAndHessian(const std::vector<TargetSpace>& sol,
-                                            Dune::BlockVector<Dune::FieldVector<double, blocksize> >& gradient,
+                                            Dune::BlockVector<Dune::FieldVector<field_type, blocksize> >& gradient,
                                             Dune::BCRSMatrix<MatrixBlock>& hessian,
                                             bool computeOccupationPattern=true) const;
 
@@ -103,7 +135,7 @@ getNeighborsPerVertex(Dune::MatrixIndexSet& nb) const
 template <class Basis, class TargetSpace>
 void GeodesicFEAssembler<Basis,TargetSpace>::
 assembleGradientAndHessian(const std::vector<TargetSpace>& sol,
-                           Dune::BlockVector<Dune::FieldVector<double, blocksize> >& gradient,
+                           Dune::BlockVector<Dune::FieldVector<field_type, blocksize> > &gradient,
                            Dune::BCRSMatrix<MatrixBlock>& hessian,
                            bool computeOccupationPattern) const
 {
diff --git a/dune/gfe/geodesicfeassemblerwrapper.hh b/dune/gfe/geodesicfeassemblerwrapper.hh
index 8dee1f1ffc276fadf8180a460c13b10eff6f3cc0..4570b4077798ce8fbe3e14a55444e91adf4ec1bc 100644
--- a/dune/gfe/geodesicfeassemblerwrapper.hh
+++ b/dune/gfe/geodesicfeassemblerwrapper.hh
@@ -62,6 +62,11 @@ public:
     /** \brief Get the occupation structure of the Hessian */
     virtual void getNeighborsPerVertex(Dune::MatrixIndexSet& nb) const;
 
+    /** \brief Get the basis. */
+    const ScalarBasis& getBasis() const {
+        return basis_;
+    }
+
 private:
     Dune::TupleVector<std::vector<MixedSpace0>,std::vector<MixedSpace1>> splitVector(const std::vector<TargetSpace>& sol) const;
     std::unique_ptr<MatrixType> hessianMixed_;
@@ -186,4 +191,4 @@ computeEnergy(const std::vector<TargetSpace>& sol) const
     auto solutionSplit = splitVector(sol);
     return mixedAssembler_->computeEnergy(solutionSplit[_0], solutionSplit[_1]);
 }
-#endif //GLOBAL_GEODESIC_FE_ASSEMBLERWRAPPER_HH
\ No newline at end of file
+#endif //GLOBAL_GEODESIC_FE_ASSEMBLERWRAPPER_HH
diff --git a/dune/gfe/riemanniantrsolver.cc b/dune/gfe/riemanniantrsolver.cc
index b45d502d01d5cf19495d1430d9e31c572112b456..88d4e7713aa48b760478a93cbd3ddeb323979acc 100644
--- a/dune/gfe/riemanniantrsolver.cc
+++ b/dune/gfe/riemanniantrsolver.cc
@@ -117,7 +117,7 @@ setup(const GridType& grid,
     // //////////////////////////////////////////////////////////////////////////////////////
 
     typedef DuneFunctionsBasis<Basis> FufemBasis;
-    FufemBasis basis(assembler_->basis_);
+    FufemBasis basis(assembler_->getBasis());
     OperatorAssembler<FufemBasis,FufemBasis> operatorAssembler(basis, basis);
 
     LaplaceAssembler<GridType, typename FufemBasis::LocalFiniteElement, typename FufemBasis::LocalFiniteElement> laplaceStiffness;
diff --git a/dune/gfe/rodassembler.cc b/dune/gfe/rodassembler.cc
index 797c006e8e167509ee3b61731ecfbbdd6e9b8da8..32300bd3fe334971b8bd850a69255938bb5afbfa 100644
--- a/dune/gfe/rodassembler.cc
+++ b/dune/gfe/rodassembler.cc
@@ -98,7 +98,7 @@ getStrain(const std::vector<RigidBodyMotion<double,3> >& sol,
 
             double weight = quad[pt].weight() * element.geometry().integrationElement(quadPos);
 
-            FieldVector<double,blocksize> localStrain = dynamic_cast<RodLocalStiffness<GridView, double>* >(this->localStiffness_)->getStrain(localSolution, element, quad[pt].position());
+            FieldVector<double,blocksize> localStrain = std::dynamic_pointer_cast<RodLocalStiffness<GridView, double> >(this->localStiffness_)->getStrain(localSolution, element, quad[pt].position());
 
             // Sum it all up
             strain[elementIdx].axpy(weight, localStrain);
diff --git a/dune/gfe/rodassembler.hh b/dune/gfe/rodassembler.hh
index a800cc8ba1bc1626a4ca33ce16483b228fea3507..64c1327ecae83ff146d6675be04243003a8158ae 100644
--- a/dune/gfe/rodassembler.hh
+++ b/dune/gfe/rodassembler.hh
@@ -41,7 +41,7 @@ class RodAssembler<Basis,3> : public GeodesicFEAssembler<Basis, RigidBodyMotion<
 public:
         //! ???
     RodAssembler(const Basis& basis,
-                 LocalGeodesicFEStiffness<Basis, RigidBodyMotion<double,3> >* localStiffness)
+                 LocalGeodesicFEStiffness<Basis, RigidBodyMotion<double,3> >& localStiffness)
     : GeodesicFEAssembler<Basis, RigidBodyMotion<double,3> >(basis,localStiffness)
         {
             std::vector<RigidBodyMotion<double,3> > referenceConfiguration(basis.size());
@@ -62,7 +62,7 @@ public:
     auto rodEnergy()
     {
       // TODO: Does not work for other stiffness implementations
-      auto localFDStiffness = dynamic_cast<LocalGeodesicFEFDStiffness<Basis, RigidBodyMotion<double,3> >*>(this->localStiffness_);
+      auto localFDStiffness = std::dynamic_pointer_cast<LocalGeodesicFEFDStiffness<Basis, RigidBodyMotion<double,3> > >(this->localStiffness_);
       return const_cast<RodLocalStiffness<GridView,double>*>(dynamic_cast<const RodLocalStiffness<GridView,double>*>(localFDStiffness->localEnergy_));
     }
 
diff --git a/src/cosserat-continuum.cc b/src/cosserat-continuum.cc
index a64eefe9aac68578bd3b64fb6de8474c524b0706..294d78a0beee1011538b0ef7109decd7cd7f1523 100644
--- a/src/cosserat-continuum.cc
+++ b/src/cosserat-continuum.cc
@@ -462,7 +462,7 @@ int main (int argc, char *argv[]) try
     LocalGeodesicFEADOLCStiffness<FEBasis,
                                   TargetSpace> localGFEADOLCStiffness(localCosseratEnergy.get());
 
-    GeodesicFEAssembler<FEBasis,TargetSpace> assembler(gridView, &localGFEADOLCStiffness);
+    GeodesicFEAssembler<FEBasis,TargetSpace> assembler(gridView, localGFEADOLCStiffness);
 #endif
 
     // /////////////////////////////////////////////////
diff --git a/src/gradient-flow.cc b/src/gradient-flow.cc
index 3879b56e0fd23777645367382f2b99439832b2e3..ac87cc85783c30796c7d2cceb34b054d895f5ed5 100644
--- a/src/gradient-flow.cc
+++ b/src/gradient-flow.cc
@@ -211,7 +211,7 @@ int main (int argc, char *argv[]) try
 
   LocalGeodesicFEADOLCStiffness<FEBasis,TargetSpace> localGFEADOLCStiffness(sumEnergy.get());
 
-  GeodesicFEAssembler<FEBasis,TargetSpace> assembler(feBasis, &localGFEADOLCStiffness);
+  GeodesicFEAssembler<FEBasis,TargetSpace> assembler(feBasis, localGFEADOLCStiffness);
 
   ///////////////////////////////////////////////////
   //   Create a Riemannian trust-region solver
diff --git a/src/harmonicmaps.cc b/src/harmonicmaps.cc
index 3cfa88935e56fe0dbc0d73e4d15b308097622056..68c123ebf6a6a324e6b0e665f84e5d420cd31c38 100644
--- a/src/harmonicmaps.cc
+++ b/src/harmonicmaps.cc
@@ -305,7 +305,7 @@ int main (int argc, char *argv[])
 
     LocalGeodesicFEADOLCStiffness<FEBasis,TargetSpace> localGFEADOLCStiffness(localEnergy.get());
 
-    GeodesicFEAssembler<FEBasis,TargetSpace> assembler(feBasis, &localGFEADOLCStiffness);
+    GeodesicFEAssembler<FEBasis,TargetSpace> assembler(feBasis, localGFEADOLCStiffness);
 
     // /////////////////////////////////////////////////
     //   Create a Riemannian trust-region solver
diff --git a/src/rod3d.cc b/src/rod3d.cc
index 90ac7842faf41c97a7c261a70d45e12982fe0cd6..8d4a9abcac5d91b5051fddb28b0debbe0e5a30ca 100644
--- a/src/rod3d.cc
+++ b/src/rod3d.cc
@@ -130,7 +130,7 @@ int main (int argc, char *argv[]) try
 
     LocalGeodesicFEFDStiffness<FEBasis,RigidBodyMotion<double,3> > localFDStiffness(&localStiffness);
 
-    RodAssembler<FEBasis,3> rodAssembler(gridView, &localFDStiffness);
+    RodAssembler<FEBasis,3> rodAssembler(gridView, localFDStiffness);
 
     RiemannianTrustRegionSolver<FEBasis,RigidBodyMotion<double,3> > rodSolver;
 
diff --git a/test/frameinvariancetest.cc b/test/frameinvariancetest.cc
index 56c00781b97548aa00aecbdbc3ea8db1501b53c6..dcf53d480975d78a4b2fce61765376dd61a5e121 100644
--- a/test/frameinvariancetest.cc
+++ b/test/frameinvariancetest.cc
@@ -79,7 +79,7 @@ int main (int argc, char *argv[]) try
 
     LocalGeodesicFEFDStiffness<FEBasis,RigidBodyMotion<double,3> > localFDStiffness(&localRodFirstOrderModel);
 
-    RodAssembler<FEBasis,3> assembler(feBasis, &localFDStiffness);
+    RodAssembler<FEBasis,3> assembler(feBasis, localFDStiffness);
 
     if (std::abs(assembler.computeEnergy(x) - assembler.computeEnergy(rotatedX)) > 1e-6)
         DUNE_THROW(Dune::Exception, "Rod energy not invariant under rigid body motions!");
diff --git a/test/harmonicmaptest.cc b/test/harmonicmaptest.cc
index 1b11a59d11f49b89f9b64fc3f1b92e3b745d269e..adc1467b0ac1e2da8e3dd219f3ec9c13a1d61f61 100644
--- a/test/harmonicmaptest.cc
+++ b/test/harmonicmaptest.cc
@@ -160,7 +160,7 @@ int main (int argc, char *argv[])
 
   LocalGeodesicFEADOLCStiffness<FEBasis,TargetSpace> localGFEADOLCStiffness(localEnergy.get());
 
-  GeodesicFEAssembler<FEBasis,TargetSpace> assembler(feBasis, &localGFEADOLCStiffness);
+  GeodesicFEAssembler<FEBasis,TargetSpace> assembler(feBasis, localGFEADOLCStiffness);
 
   ///////////////////////////////////////////////////
   //  Create a Riemannian trust-region solver
diff --git a/test/rodassemblertest.cc b/test/rodassemblertest.cc
index 20c186d3ee3c1ff42fee6676ad14c7a560ad5179..554007834a7010f3f3a0eb1f35572d9ac4d88685 100644
--- a/test/rodassemblertest.cc
+++ b/test/rodassemblertest.cc
@@ -553,7 +553,7 @@ int main (int argc, char *argv[]) try
 
     LocalGeodesicFEFDStiffness<Basis,RigidBodyMotion<double,3> > localFDStiffness(&localStiffness);
 
-    RodAssembler<Basis,3> rodAssembler(basis, &localFDStiffness);
+    RodAssembler<Basis,3> rodAssembler(basis, localFDStiffness);
 
     std::cout << "Energy: " << rodAssembler.computeEnergy(x) << std::endl;