diff --git a/dune/gfe/nonplanarcosseratshellenergy.hh b/dune/gfe/nonplanarcosseratshellenergy.hh
index 09e4c0c22d009bccfe56bce3c7bccacdaf727aa0..8fd0aceb4de0effe76dc7d854a990bf5cd7396d2 100644
--- a/dune/gfe/nonplanarcosseratshellenergy.hh
+++ b/dune/gfe/nonplanarcosseratshellenergy.hh
@@ -16,6 +16,10 @@
 #include <dune/gfe/tensor3.hh>
 #include <dune/gfe/localprojectedfefunction.hh>
 
+#if HAVE_DUNE_CURVEDGEOMETRY
+#include <dune/curvedgeometry/curvedgeometry.hh>
+#include <dune/localfunctions/lagrange/lfecache.hh>
+#endif
 
 template<class Basis, int dim, class field_type=double>
 class NonplanarCosseratShellEnergy
@@ -39,12 +43,10 @@ public:
    * \param parameters The material parameters
    */
   NonplanarCosseratShellEnergy(const Dune::ParameterTree& parameters,
-                               const std::vector<UnitVector<double,3> >& vertexNormals,
                                const BoundaryPatch<GridView>* neumannBoundary,
                                const std::function<Dune::FieldVector<double,3>(Dune::FieldVector<double,dimworld>)> neumannFunction,
                                const std::function<Dune::FieldVector<double,3>(Dune::FieldVector<double,dimworld>)> volumeLoad)
-  : vertexNormals_(vertexNormals),
-    neumannBoundary_(neumannBoundary),
+  : neumannBoundary_(neumannBoundary),
     neumannFunction_(neumannFunction),
     volumeLoad_(volumeLoad)
   {
@@ -109,9 +111,6 @@ public:
   /** \brief Curvature parameters */
   double b1_, b2_, b3_;
 
-  /** \brief The normal vectors at the grid vertices.  This are used to compute the reference surface curvature. */
-  std::vector<UnitVector<double,3> > vertexNormals_;
-
   /** \brief The Neumann boundary */
   const BoundaryPatch<GridView>* neumannBoundary_;
 
@@ -135,22 +134,6 @@ energy(const typename Basis::LocalView& localView,
   // The set of shape functions on this element
   const auto& localFiniteElement = localView.tree().finiteElement();
 
-  ////////////////////////////////////////////////////////////////////////////////////
-  //  Construct a linear (i.e., non-constant!) normal field on each element
-  ////////////////////////////////////////////////////////////////////////////////////
-  auto gridView = localView.globalBasis().gridView();
-
-  assert(vertexNormals_.size() == gridView.indexSet().size(gridDim));
-  std::vector<UnitVector<double,3> > cornerNormals(element.subEntities(gridDim));
-  for (size_t i=0; i<cornerNormals.size(); i++)
-    cornerNormals[i] = vertexNormals_[gridView.indexSet().subIndex(element,i,2)];
-
-  typedef typename Dune::PQkLocalFiniteElementCache<DT, double, gridDim, 1> P1FiniteElementCache;
-  typedef typename P1FiniteElementCache::FiniteElementType P1LocalFiniteElement;
-  P1FiniteElementCache p1FiniteElementCache;
-  const auto& p1LocalFiniteElement = p1FiniteElementCache.get(element.type());
-  Dune::GFE::LocalProjectedFEFunction<gridDim, DT, P1LocalFiniteElement, UnitVector<double,3> > unitNormals(p1LocalFiniteElement, cornerNormals);
-
   ////////////////////////////////////////////////////////////////////////////////////
   //  Set up the local nonlinear finite element function
   ////////////////////////////////////////////////////////////////////////////////////
@@ -231,9 +214,22 @@ energy(const typename Basis::LocalView& localView,
       for (int beta=0; beta<2; beta++)
         c += aScalar * eps[alpha][beta] * Dune::GFE::dyadicProduct(aContravariant[alpha], aContravariant[beta]);
 
-    // Second fundamental form
-    // The derivative of the normal field
-    auto normalDerivative = unitNormals.evaluateDerivative(quadPos);
+#if HAVE_DUNE_CURVEDGEOMETRY
+    // Construct a curved geometry to evaluate the derivative of the normal field on each quadrature point
+    // The variable local holds the local coordinates in the reference element
+    // and localGeometry.global maps them to the world coordinates
+    // we want to take the derivative of the normal field on the element in world coordinates
+    Dune::CurvedGeometry<DT, gridDim, dimworld, Dune::CurvedGeometryTraits<DT, Dune::LagrangeLFECache<DT,DT,gridDim>>> curvedGeometry(referenceElement(element),
+      [localGeometry=element.geometry()](const auto& local) {
+        return localGeometry.global(local);
+      }, 1); //order = 1
+
+    // Second fundamental form: The derivative of the normal field
+    auto normalDerivative = curvedGeometry.normalGradient(quad[pt].position());
+#else
+    //In case dune-curvedgeometry is not installed, the normal derivative is set to zero.
+    Dune::FieldMatrix<double,3,3> normalDerivative(0);
+#endif
 
     Dune::FieldMatrix<double,3,3> b(0);
     for (int alpha=0; alpha<gridDim; alpha++)
diff --git a/src/cosserat-continuum.cc b/src/cosserat-continuum.cc
index 3ed30be55becc0c59b444c9ce426abdba0abe555..eb763ee627ae9684bef0a6436126649715778155 100644
--- a/src/cosserat-continuum.cc
+++ b/src/cosserat-continuum.cc
@@ -47,7 +47,6 @@
 #include <dune/gfe/cosseratvtkreader.hh>
 #include <dune/gfe/geodesicfeassembler.hh>
 #include <dune/gfe/riemanniantrsolver.hh>
-#include <dune/gfe/vertexnormals.hh>
 #include <dune/gfe/embeddedglobalgfefunction.hh>
 #include <dune/gfe/mixedgfeassembler.hh>
 #include <dune/gfe/mixedriemanniantrsolver.hh>
@@ -56,6 +55,7 @@
 #include <dune/vtk/vtkreader.hh>
 #endif
 
+
 // grid dimension
 const int dim = 2;
 const int dimworld = 2;
@@ -440,9 +440,7 @@ int main (int argc, char *argv[]) try
     }
     else
     {
-      std::vector<UnitVector<double,3> > vertexNormals = computeVertexNormals(gridView);
       localCosseratEnergy = std::make_shared<NonplanarCosseratShellEnergy<FEBasis,3,adouble> >(materialParameters,
-                                                                                               std::move(vertexNormals),
                                                                                                &neumannBoundary,
                                                                                                neumannFunction,
                                                                                                volumeLoad);