diff --git a/dune/gfe/geodesicfeassembler.hh b/dune/gfe/geodesicfeassembler.hh index c6958e50a890e5a995849d64dbe119a447750e6f..7be9133aa8839008cef76454992ffd251fac41f4 100644 --- a/dune/gfe/geodesicfeassembler.hh +++ b/dune/gfe/geodesicfeassembler.hh @@ -8,13 +8,16 @@ #include "localgeodesicfestiffness.hh" +#include <dune/solvers/common/wrapownshare.hh> /** \brief A global FE assembler for problems involving functions that map into non-Euclidean spaces */ template <class Basis, class TargetSpace> class GeodesicFEAssembler { + using field_type = typename TargetSpace::field_type; typedef typename Basis::GridView GridView; + using LocalStiffness = LocalGeodesicFEStiffness<Basis, TargetSpace>; //! Dimension of the grid. enum { gridDim = GridView::dimension }; @@ -25,29 +28,58 @@ class GeodesicFEAssembler { //! typedef Dune::FieldMatrix<double, blocksize, blocksize> MatrixBlock; -public: - const Basis basis_; protected: - LocalGeodesicFEStiffness<Basis,TargetSpace>* localStiffness_; + //! The global basis + const Basis basis_; + + //! The local stiffness operator + std::shared_ptr<LocalStiffness> localStiffness_; public: /** \brief Constructor for a given grid */ + GeodesicFEAssembler(const Basis& basis) + : basis_(basis) + {} + + /** \brief Constructor for a given grid */ + template <class LocalStiffnessT> GeodesicFEAssembler(const Basis& basis, - LocalGeodesicFEStiffness<Basis, TargetSpace>* localStiffness) + LocalStiffnessT&& localStiffness) : basis_(basis), - localStiffness_(localStiffness) + localStiffness_(Dune::Solvers::wrap_own_share<LocalStiffness>(std::forward<LocalStiffnessT>(localStiffness))) {} + /** \brief Set the local stiffness assembler. This can be a temporary, l-value or shared pointer. */ + template <class LocalStiffnessT> + void setLocalStiffness(LocalStiffnessT&& localStiffness) { + localStiffness_ = Dune::Solvers::wrap_own_share<LocalStiffness>(std::forward<LocalStiffnessT>(localStiffness)); + } + + /** \brief Get the local stiffness operator. */ + const LocalStiffness& getLocalStiffness() const { + return *localStiffness_; + } + + /** \brief Get the local stiffness operator. */ + LocalStiffness& getLocalStiffness() { + return *localStiffness_; + } + + /** \brief Get the basis. */ + const Basis& getBasis() const { + return basis_; + } + /** \brief Assemble the tangent stiffness matrix and the functional gradient together * * This is more efficient than computing them separately, because you need the gradient * anyway to compute the Riemannian Hessian. */ virtual void assembleGradientAndHessian(const std::vector<TargetSpace>& sol, - Dune::BlockVector<Dune::FieldVector<double, blocksize> >& gradient, + Dune::BlockVector<Dune::FieldVector<field_type, blocksize> >& gradient, Dune::BCRSMatrix<MatrixBlock>& hessian, bool computeOccupationPattern=true) const; @@ -103,7 +135,7 @@ getNeighborsPerVertex(Dune::MatrixIndexSet& nb) const template <class Basis, class TargetSpace> void GeodesicFEAssembler<Basis,TargetSpace>:: assembleGradientAndHessian(const std::vector<TargetSpace>& sol, - Dune::BlockVector<Dune::FieldVector<double, blocksize> >& gradient, + Dune::BlockVector<Dune::FieldVector<field_type, blocksize> > &gradient, Dune::BCRSMatrix<MatrixBlock>& hessian, bool computeOccupationPattern) const { diff --git a/dune/gfe/geodesicfeassemblerwrapper.hh b/dune/gfe/geodesicfeassemblerwrapper.hh index 8dee1f1ffc276fadf8180a460c13b10eff6f3cc0..4570b4077798ce8fbe3e14a55444e91adf4ec1bc 100644 --- a/dune/gfe/geodesicfeassemblerwrapper.hh +++ b/dune/gfe/geodesicfeassemblerwrapper.hh @@ -62,6 +62,11 @@ public: /** \brief Get the occupation structure of the Hessian */ virtual void getNeighborsPerVertex(Dune::MatrixIndexSet& nb) const; + /** \brief Get the basis. */ + const ScalarBasis& getBasis() const { + return basis_; + } + private: Dune::TupleVector<std::vector<MixedSpace0>,std::vector<MixedSpace1>> splitVector(const std::vector<TargetSpace>& sol) const; std::unique_ptr<MatrixType> hessianMixed_; @@ -186,4 +191,4 @@ computeEnergy(const std::vector<TargetSpace>& sol) const auto solutionSplit = splitVector(sol); return mixedAssembler_->computeEnergy(solutionSplit[_0], solutionSplit[_1]); } -#endif //GLOBAL_GEODESIC_FE_ASSEMBLERWRAPPER_HH \ No newline at end of file +#endif //GLOBAL_GEODESIC_FE_ASSEMBLERWRAPPER_HH diff --git a/dune/gfe/riemanniantrsolver.cc b/dune/gfe/riemanniantrsolver.cc index b45d502d01d5cf19495d1430d9e31c572112b456..88d4e7713aa48b760478a93cbd3ddeb323979acc 100644 --- a/dune/gfe/riemanniantrsolver.cc +++ b/dune/gfe/riemanniantrsolver.cc @@ -117,7 +117,7 @@ setup(const GridType& grid, // ////////////////////////////////////////////////////////////////////////////////////// typedef DuneFunctionsBasis<Basis> FufemBasis; - FufemBasis basis(assembler_->basis_); + FufemBasis basis(assembler_->getBasis()); OperatorAssembler<FufemBasis,FufemBasis> operatorAssembler(basis, basis); LaplaceAssembler<GridType, typename FufemBasis::LocalFiniteElement, typename FufemBasis::LocalFiniteElement> laplaceStiffness; diff --git a/dune/gfe/rodassembler.cc b/dune/gfe/rodassembler.cc index 797c006e8e167509ee3b61731ecfbbdd6e9b8da8..32300bd3fe334971b8bd850a69255938bb5afbfa 100644 --- a/dune/gfe/rodassembler.cc +++ b/dune/gfe/rodassembler.cc @@ -98,7 +98,7 @@ getStrain(const std::vector<RigidBodyMotion<double,3> >& sol, double weight = quad[pt].weight() * element.geometry().integrationElement(quadPos); - FieldVector<double,blocksize> localStrain = dynamic_cast<RodLocalStiffness<GridView, double>* >(this->localStiffness_)->getStrain(localSolution, element, quad[pt].position()); + FieldVector<double,blocksize> localStrain = std::dynamic_pointer_cast<RodLocalStiffness<GridView, double> >(this->localStiffness_)->getStrain(localSolution, element, quad[pt].position()); // Sum it all up strain[elementIdx].axpy(weight, localStrain); diff --git a/dune/gfe/rodassembler.hh b/dune/gfe/rodassembler.hh index a800cc8ba1bc1626a4ca33ce16483b228fea3507..64c1327ecae83ff146d6675be04243003a8158ae 100644 --- a/dune/gfe/rodassembler.hh +++ b/dune/gfe/rodassembler.hh @@ -41,7 +41,7 @@ class RodAssembler<Basis,3> : public GeodesicFEAssembler<Basis, RigidBodyMotion< public: //! ??? RodAssembler(const Basis& basis, - LocalGeodesicFEStiffness<Basis, RigidBodyMotion<double,3> >* localStiffness) + LocalGeodesicFEStiffness<Basis, RigidBodyMotion<double,3> >& localStiffness) : GeodesicFEAssembler<Basis, RigidBodyMotion<double,3> >(basis,localStiffness) { std::vector<RigidBodyMotion<double,3> > referenceConfiguration(basis.size()); @@ -62,7 +62,7 @@ public: auto rodEnergy() { // TODO: Does not work for other stiffness implementations - auto localFDStiffness = dynamic_cast<LocalGeodesicFEFDStiffness<Basis, RigidBodyMotion<double,3> >*>(this->localStiffness_); + auto localFDStiffness = std::dynamic_pointer_cast<LocalGeodesicFEFDStiffness<Basis, RigidBodyMotion<double,3> > >(this->localStiffness_); return const_cast<RodLocalStiffness<GridView,double>*>(dynamic_cast<const RodLocalStiffness<GridView,double>*>(localFDStiffness->localEnergy_)); } diff --git a/src/cosserat-continuum.cc b/src/cosserat-continuum.cc index a64eefe9aac68578bd3b64fb6de8474c524b0706..294d78a0beee1011538b0ef7109decd7cd7f1523 100644 --- a/src/cosserat-continuum.cc +++ b/src/cosserat-continuum.cc @@ -462,7 +462,7 @@ int main (int argc, char *argv[]) try LocalGeodesicFEADOLCStiffness<FEBasis, TargetSpace> localGFEADOLCStiffness(localCosseratEnergy.get()); - GeodesicFEAssembler<FEBasis,TargetSpace> assembler(gridView, &localGFEADOLCStiffness); + GeodesicFEAssembler<FEBasis,TargetSpace> assembler(gridView, localGFEADOLCStiffness); #endif // ///////////////////////////////////////////////// diff --git a/src/gradient-flow.cc b/src/gradient-flow.cc index 3879b56e0fd23777645367382f2b99439832b2e3..ac87cc85783c30796c7d2cceb34b054d895f5ed5 100644 --- a/src/gradient-flow.cc +++ b/src/gradient-flow.cc @@ -211,7 +211,7 @@ int main (int argc, char *argv[]) try LocalGeodesicFEADOLCStiffness<FEBasis,TargetSpace> localGFEADOLCStiffness(sumEnergy.get()); - GeodesicFEAssembler<FEBasis,TargetSpace> assembler(feBasis, &localGFEADOLCStiffness); + GeodesicFEAssembler<FEBasis,TargetSpace> assembler(feBasis, localGFEADOLCStiffness); /////////////////////////////////////////////////// // Create a Riemannian trust-region solver diff --git a/src/harmonicmaps.cc b/src/harmonicmaps.cc index 3cfa88935e56fe0dbc0d73e4d15b308097622056..68c123ebf6a6a324e6b0e665f84e5d420cd31c38 100644 --- a/src/harmonicmaps.cc +++ b/src/harmonicmaps.cc @@ -305,7 +305,7 @@ int main (int argc, char *argv[]) LocalGeodesicFEADOLCStiffness<FEBasis,TargetSpace> localGFEADOLCStiffness(localEnergy.get()); - GeodesicFEAssembler<FEBasis,TargetSpace> assembler(feBasis, &localGFEADOLCStiffness); + GeodesicFEAssembler<FEBasis,TargetSpace> assembler(feBasis, localGFEADOLCStiffness); // ///////////////////////////////////////////////// // Create a Riemannian trust-region solver diff --git a/src/rod3d.cc b/src/rod3d.cc index 90ac7842faf41c97a7c261a70d45e12982fe0cd6..8d4a9abcac5d91b5051fddb28b0debbe0e5a30ca 100644 --- a/src/rod3d.cc +++ b/src/rod3d.cc @@ -130,7 +130,7 @@ int main (int argc, char *argv[]) try LocalGeodesicFEFDStiffness<FEBasis,RigidBodyMotion<double,3> > localFDStiffness(&localStiffness); - RodAssembler<FEBasis,3> rodAssembler(gridView, &localFDStiffness); + RodAssembler<FEBasis,3> rodAssembler(gridView, localFDStiffness); RiemannianTrustRegionSolver<FEBasis,RigidBodyMotion<double,3> > rodSolver; diff --git a/test/frameinvariancetest.cc b/test/frameinvariancetest.cc index 56c00781b97548aa00aecbdbc3ea8db1501b53c6..dcf53d480975d78a4b2fce61765376dd61a5e121 100644 --- a/test/frameinvariancetest.cc +++ b/test/frameinvariancetest.cc @@ -79,7 +79,7 @@ int main (int argc, char *argv[]) try LocalGeodesicFEFDStiffness<FEBasis,RigidBodyMotion<double,3> > localFDStiffness(&localRodFirstOrderModel); - RodAssembler<FEBasis,3> assembler(feBasis, &localFDStiffness); + RodAssembler<FEBasis,3> assembler(feBasis, localFDStiffness); if (std::abs(assembler.computeEnergy(x) - assembler.computeEnergy(rotatedX)) > 1e-6) DUNE_THROW(Dune::Exception, "Rod energy not invariant under rigid body motions!"); diff --git a/test/harmonicmaptest.cc b/test/harmonicmaptest.cc index 1b11a59d11f49b89f9b64fc3f1b92e3b745d269e..adc1467b0ac1e2da8e3dd219f3ec9c13a1d61f61 100644 --- a/test/harmonicmaptest.cc +++ b/test/harmonicmaptest.cc @@ -160,7 +160,7 @@ int main (int argc, char *argv[]) LocalGeodesicFEADOLCStiffness<FEBasis,TargetSpace> localGFEADOLCStiffness(localEnergy.get()); - GeodesicFEAssembler<FEBasis,TargetSpace> assembler(feBasis, &localGFEADOLCStiffness); + GeodesicFEAssembler<FEBasis,TargetSpace> assembler(feBasis, localGFEADOLCStiffness); /////////////////////////////////////////////////// // Create a Riemannian trust-region solver diff --git a/test/rodassemblertest.cc b/test/rodassemblertest.cc index 20c186d3ee3c1ff42fee6676ad14c7a560ad5179..554007834a7010f3f3a0eb1f35572d9ac4d88685 100644 --- a/test/rodassemblertest.cc +++ b/test/rodassemblertest.cc @@ -553,7 +553,7 @@ int main (int argc, char *argv[]) try LocalGeodesicFEFDStiffness<Basis,RigidBodyMotion<double,3> > localFDStiffness(&localStiffness); - RodAssembler<Basis,3> rodAssembler(basis, &localFDStiffness); + RodAssembler<Basis,3> rodAssembler(basis, localFDStiffness); std::cout << "Energy: " << rodAssembler.computeEnergy(x) << std::endl;