diff --git a/dune/gfe/nonplanarcosseratshellenergy.hh b/dune/gfe/nonplanarcosseratshellenergy.hh
index 8fd0aceb4de0effe76dc7d854a990bf5cd7396d2..21a1f2a9e390bff5db91d6984bfcce7b03c56343 100644
--- a/dune/gfe/nonplanarcosseratshellenergy.hh
+++ b/dune/gfe/nonplanarcosseratshellenergy.hh
@@ -9,6 +9,8 @@
 
 #include <dune/fufem/boundarypatch.hh>
 
+#include <dune/functions/gridfunctions/discreteglobalbasisfunction.hh>
+
 #include <dune/gfe/localenergy.hh>
 #include <dune/gfe/localgeodesicfefunction.hh>
 #include <dune/gfe/rigidbodymotion.hh>
@@ -21,7 +23,15 @@
 #include <dune/localfunctions/lagrange/lfecache.hh>
 #endif
 
-template<class Basis, int dim, class field_type=double>
+/** \brief Assembles the cosserat energy for a single element.
+ *
+ * \tparam Basis                       Type of the Basis used for assembling
+ * \tparam dim                         Dimension of the Targetspace, 3
+ * \tparam field_type                  The coordinate type of the TargetSpace
+ * \tparam StressFreeStateGridFunction Type of the GridFunction representing the Cosserat shell in a stress free state
+ */
+template<class Basis, int dim, class field_type=double, class StressFreeStateGridFunction = 
+  Dune::Functions::DiscreteGlobalBasisFunction<Basis,std::vector<Dune::FieldVector<double, Basis::GridView::dimensionworld>> > >
 class NonplanarCosseratShellEnergy
   : public Dune::GFE::LocalEnergy<Basis,RigidBodyMotion<field_type,dim> >
 {
@@ -40,13 +50,16 @@ class NonplanarCosseratShellEnergy
 public:
 
   /** \brief Constructor with a set of material parameters
-   * \param parameters The material parameters
+   * \param parameters                  The material parameters
+   * \param stressFreeStateGridFunction Pointer to a parametrization representing the Cosserat shell in a stress-free state
    */
   NonplanarCosseratShellEnergy(const Dune::ParameterTree& parameters,
+                               const StressFreeStateGridFunction* stressFreeStateGridFunction,
                                const BoundaryPatch<GridView>* neumannBoundary,
                                const std::function<Dune::FieldVector<double,3>(Dune::FieldVector<double,dimworld>)> neumannFunction,
                                const std::function<Dune::FieldVector<double,3>(Dune::FieldVector<double,dimworld>)> volumeLoad)
-  : neumannBoundary_(neumannBoundary),
+  : stressFreeStateGridFunction_(stressFreeStateGridFunction),
+    neumannBoundary_(neumannBoundary),
     neumannFunction_(neumannFunction),
     volumeLoad_(volumeLoad)
   {
@@ -111,6 +124,9 @@ public:
   /** \brief Curvature parameters */
   double b1_, b2_, b3_;
 
+  /** \brief The geometry used for assembling */
+  const StressFreeStateGridFunction* stressFreeStateGridFunction_;
+
   /** \brief The Neumann boundary */
   const BoundaryPatch<GridView>* neumannBoundary_;
 
@@ -121,15 +137,34 @@ public:
   const std::function<Dune::FieldVector<double,3>(Dune::FieldVector<double,dimworld>)> volumeLoad_;
 };
 
-template <class Basis, int dim, class field_type>
-typename NonplanarCosseratShellEnergy<Basis,dim,field_type>::RT
-NonplanarCosseratShellEnergy<Basis,dim,field_type>::
+template <class Basis, int dim, class field_type, class StressFreeStateGridFunction>
+typename NonplanarCosseratShellEnergy<Basis, dim, field_type, StressFreeStateGridFunction>::RT
+NonplanarCosseratShellEnergy<Basis,dim,field_type, StressFreeStateGridFunction>::
 energy(const typename Basis::LocalView& localView,
        const std::vector<RigidBodyMotion<field_type,dim> >& localSolution) const
 {
   // The element geometry
   auto element = localView.element();
+
+#if HAVE_DUNE_CURVEDGEOMETRY
+  // Construct a curved geometry of this element of the Cosserat shell in stress-free state
+  // When using element.geometry(), then the curvatures on the element are zero, when using a curved geometry, they are not
+  // If a parametrization representing the Cosserat shell in a stress-free state is given,
+  // this is used for the curved geometry approximation.
+  // The variable local holds the local coordinates in the reference element
+  // and localGeometry.global maps them to the world coordinates
+  Dune::CurvedGeometry<DT, gridDim, dimworld, Dune::CurvedGeometryTraits<DT, Dune::LagrangeLFECache<DT,DT,gridDim>>> geometry(referenceElement(element),
+    [this,element](const auto& local) {
+      if (not stressFreeStateGridFunction_) {
+        return element.geometry().global(local);
+      }
+      auto localGridFunction = localFunction(*stressFreeStateGridFunction_);
+      localGridFunction.bind(element);
+      return localGridFunction(local);
+    }, 2); /*order*/
+#else
   auto geometry = element.geometry();
+#endif
 
   // The set of shape functions on this element
   const auto& localFiniteElement = localView.tree().finiteElement();
@@ -215,17 +250,8 @@ energy(const typename Basis::LocalView& localView,
         c += aScalar * eps[alpha][beta] * Dune::GFE::dyadicProduct(aContravariant[alpha], aContravariant[beta]);
 
 #if HAVE_DUNE_CURVEDGEOMETRY
-    // Construct a curved geometry to evaluate the derivative of the normal field on each quadrature point
-    // The variable local holds the local coordinates in the reference element
-    // and localGeometry.global maps them to the world coordinates
-    // we want to take the derivative of the normal field on the element in world coordinates
-    Dune::CurvedGeometry<DT, gridDim, dimworld, Dune::CurvedGeometryTraits<DT, Dune::LagrangeLFECache<DT,DT,gridDim>>> curvedGeometry(referenceElement(element),
-      [localGeometry=element.geometry()](const auto& local) {
-        return localGeometry.global(local);
-      }, 1); //order = 1
-
-    // Second fundamental form: The derivative of the normal field
-    auto normalDerivative = curvedGeometry.normalGradient(quad[pt].position());
+    // Second fundamental form: The derivative of the normal field, on each quadrature point
+    auto normalDerivative = geometry.normalGradient(quad[pt].position());
 #else
     //In case dune-curvedgeometry is not installed, the normal derivative is set to zero.
     Dune::FieldMatrix<double,3,3> normalDerivative(0);
diff --git a/src/cosserat-continuum.cc b/src/cosserat-continuum.cc
index fb68dbcd00eb541d27049b8843c6eafc6612b472..f3299ff541dffd08817aa22ca0c589cde4c199a1 100644
--- a/src/cosserat-continuum.cc
+++ b/src/cosserat-continuum.cc
@@ -424,6 +424,7 @@ int main (int argc, char *argv[]) try
     else
     {
       localCosseratEnergy = std::make_shared<NonplanarCosseratShellEnergy<FEBasis,3,adouble> >(materialParameters,
+                                                                                               nullptr,
                                                                                                &neumannBoundary,
                                                                                                neumannFunction,
                                                                                                volumeLoad);