Commit 1a410159 authored by Praetorius, Simon's avatar Praetorius, Simon
Browse files

Merge branch 'feature/operator_localoperator_restructuring' into 'master'

Implement Operator and LocalOperator with TypeErasure

See merge request !204
parents 832af288 453b30f3
...@@ -145,11 +145,12 @@ namespace AMDiS ...@@ -145,11 +145,12 @@ namespace AMDiS
} }
/// Assemble the matrix operators on the bound element. /// Assemble the matrix operators on the bound element.
template <class RowLocalView, class ColLocalView, template <class RowLocalView, class ColLocalView, class LocalOperators,
REQUIRES(Concepts::LocalView<RowLocalView>), REQUIRES(Concepts::LocalView<RowLocalView>),
REQUIRES(Concepts::LocalView<ColLocalView>)> REQUIRES(Concepts::LocalView<ColLocalView>)>
void assemble(RowLocalView const& rowLocalView, void assemble(RowLocalView const& rowLocalView,
ColLocalView const& colLocalView); ColLocalView const& colLocalView,
LocalOperators& localOperators);
/// Assemble all matrix operators, TODO: incorporate boundary conditions /// Assemble all matrix operators, TODO: incorporate boundary conditions
void assemble(); void assemble();
...@@ -179,8 +180,19 @@ namespace AMDiS ...@@ -179,8 +180,19 @@ namespace AMDiS
/// Set the flag that forces an update of the pattern since the underlying /// Set the flag that forces an update of the pattern since the underlying
/// basis that defines the indexset has been changed /// basis that defines the indexset has been changed
void updateImpl(event::adapt e, index_t<0> i) override { updatePattern_ = true; } void updateImpl(event::adapt e, index_t<0> i) override { updateImpl2(*rowBasis_); }
void updateImpl(event::adapt e, index_t<1> i) override { updatePattern_ = true; } void updateImpl(event::adapt e, index_t<1> i) override { updateImpl2(*colBasis_); }
auto& operators() { return operators_; }
private:
template <class Basis>
void updateImpl2(Basis const& basis)
{
if (!updatePattern_)
Recursive::forEach(operators_, [&](auto& op) { op.update(basis.gridView()); });
updatePattern_ = true;
}
protected: protected:
/// The symmetry property if the bilinear form /// The symmetry property if the bilinear form
......
...@@ -2,10 +2,10 @@ ...@@ -2,10 +2,10 @@
#include <utility> #include <utility>
#include <amdis/Assembler.hpp> #include <amdis/ContextGeometry.hpp>
#include <amdis/GridFunctionOperator.hpp>
#include <amdis/LocalOperator.hpp> #include <amdis/LocalOperator.hpp>
#include <amdis/typetree/Traversal.hpp> #include <amdis/typetree/Traversal.hpp>
#include <amdis/utility/AssembleOperators.hpp>
namespace AMDiS { namespace AMDiS {
...@@ -21,42 +21,35 @@ addOperator(ContextTag contextTag, Expr const& expr, ...@@ -21,42 +21,35 @@ addOperator(ContextTag contextTag, Expr const& expr,
"col must be a valid treepath, or an integer/index-constant"); "col must be a valid treepath, or an integer/index-constant");
auto i = makeTreePath(row); auto i = makeTreePath(row);
auto node_i = child(this->rowBasis().localView().treeCache(), i);
auto j = makeTreePath(col); auto j = makeTreePath(col);
auto node_j = child(this->colBasis().localView().treeCache(), j);
using LocalContext = typename ContextTag::type; using LocalContext = typename ContextTag::type;
using Tr = DefaultAssemblerTraits<LocalContext, ElementMatrix>; auto op = makeOperator<LocalContext>(expr, this->rowBasis().gridView());
auto op = makeLocalOperator<LocalContext>(expr, this->rowBasis().gridView()); operators_[i][j].push(contextTag, std::move(op));
auto localAssembler = makeUniquePtr(makeAssembler<Tr>(std::move(op), node_i, node_j));
operators_[i][j].push(contextTag, std::move(localAssembler));
updatePattern_ = true; updatePattern_ = true;
} }
template <class RB, class CB, class T, class Traits> template <class RB, class CB, class T, class Traits>
template <class RowLocalView, class ColLocalView, template <class RowLocalView, class ColLocalView, class LocalOperators,
REQUIRES_(Concepts::LocalView<RowLocalView>), REQUIRES_(Concepts::LocalView<RowLocalView>),
REQUIRES_(Concepts::LocalView<ColLocalView>)> REQUIRES_(Concepts::LocalView<ColLocalView>)>
void BiLinearForm<RB,CB,T,Traits>:: void BiLinearForm<RB,CB,T,Traits>::
assemble(RowLocalView const& rowLocalView, ColLocalView const& colLocalView) assemble(RowLocalView const& rowLocalView, ColLocalView const& colLocalView,
LocalOperators& localOperators)
{ {
elementMatrix_.resize(rowLocalView.size(), colLocalView.size()); elementMatrix_.resize(rowLocalView.size(), colLocalView.size());
elementMatrix_ = 0; elementMatrix_ = 0;
auto const& gv = this->rowBasis().gridView(); auto const& gv = this->rowBasis().gridView();
auto const& element = rowLocalView.element(); GlobalContext context{gv, rowLocalView.element(), rowLocalView.element().geometry()};
auto geometry = element.geometry();
Traversal::forEachNode(rowLocalView.treeCache(), [&](auto const& rowNode, auto rowTp) { Traversal::forEachNode(rowLocalView.treeCache(), [&](auto const& rowNode, auto rowTp) {
Traversal::forEachNode(colLocalView.treeCache(), [&](auto const& colNode, auto colTp) { Traversal::forEachNode(colLocalView.treeCache(), [&](auto const& colNode, auto colTp) {
auto& matOp = operators_[rowTp][colTp]; auto& matOp = localOperators[rowTp][colTp];
if (matOp) { matOp.bind(context.element());
matOp.bind(element, geometry); matOp.assemble(context, rowNode, colNode, elementMatrix_);
assembleOperators(gv, element, matOp, makeMatrixAssembler(rowNode, colNode, elementMatrix_)); matOp.unbind();
matOp.unbind();
}
}); });
}); });
...@@ -72,13 +65,14 @@ assemble() ...@@ -72,13 +65,14 @@ assemble()
auto colLocalView = this->colBasis().localView(); auto colLocalView = this->colBasis().localView();
this->init(); this->init();
auto localOperators = AMDiS::localOperators(operators_);
for (auto const& element : elements(this->rowBasis().gridView(), typename Traits::PartitionSet{})) { for (auto const& element : elements(this->rowBasis().gridView(), typename Traits::PartitionSet{})) {
rowLocalView.bind(element); rowLocalView.bind(element);
if (this->rowBasis_ == this->colBasis_) if (this->rowBasis_ == this->colBasis_)
this->assemble(rowLocalView, rowLocalView); this->assemble(rowLocalView, rowLocalView, localOperators);
else { else {
colLocalView.bind(element); colLocalView.bind(element);
this->assemble(rowLocalView, colLocalView); this->assemble(rowLocalView, colLocalView, localOperators);
colLocalView.unbind(); colLocalView.unbind();
} }
rowLocalView.unbind(); rowLocalView.unbind();
......
...@@ -21,8 +21,6 @@ install(FILES ...@@ -21,8 +21,6 @@ install(FILES
AdaptiveGrid.hpp AdaptiveGrid.hpp
AdaptStationary.hpp AdaptStationary.hpp
AMDiS.hpp AMDiS.hpp
Assembler.hpp
AssemblerInterface.hpp
BackupRestore.hpp BackupRestore.hpp
BiLinearForm.hpp BiLinearForm.hpp
BiLinearForm.inc.hpp BiLinearForm.inc.hpp
...@@ -56,6 +54,7 @@ install(FILES ...@@ -56,6 +54,7 @@ install(FILES
MeshCreator.hpp MeshCreator.hpp
Observer.hpp Observer.hpp
Operations.hpp Operations.hpp
Operator.hpp
OperatorList.hpp OperatorList.hpp
Output.hpp Output.hpp
PeriodicBC.hpp PeriodicBC.hpp
......
...@@ -160,4 +160,51 @@ namespace AMDiS ...@@ -160,4 +160,51 @@ namespace AMDiS
mutable std::optional<LocalGeometry> localGeometry_; mutable std::optional<LocalGeometry> localGeometry_;
}; };
template <class GV>
class GlobalContext
{
public:
using GridView = GV;
using Element = typename GV::template Codim<0>::Entity;
using Geometry = typename Element::Geometry;
enum {
dim = GridView::dimension, //< the dimension of the grid element
dow = GridView::dimensionworld //< the dimension of the world
};
/// Constructor. Stores a copy of gridView and a pointer to element and geometry.
GlobalContext(GridView const& gridView, Element const& element,
Geometry const& geometry)
: gridView_(gridView)
, element_(&element)
, geometry_(&geometry)
{}
public:
/// Return the GridView this context is bound to
GridView const& gridView() const
{
return gridView_;
}
/// Return the bound element (entity of codim 0)
Element const& element() const
{
return *element_;
}
/// Return the geometry of the \ref Element
Geometry const& geometry() const
{
return *geometry_;
}
private:
GridView gridView_;
Element const* element_;
Geometry const* geometry_;
};
} // end namespace AMDiS } // end namespace AMDiS
...@@ -3,11 +3,9 @@ ...@@ -3,11 +3,9 @@
#include <cassert> #include <cassert>
#include <type_traits> #include <type_traits>
#include <amdis/GridFunctions.hpp> #include <amdis/GridFunctionOperatorTransposed.hpp>
#include <amdis/LocalOperator.hpp> #include <amdis/common/Order.hpp>
#include <amdis/common/Transposed.hpp>
#include <amdis/common/TypeTraits.hpp> #include <amdis/common/TypeTraits.hpp>
#include <amdis/typetree/FiniteElementType.hpp>
#include <amdis/utility/QuadratureFactory.hpp> #include <amdis/utility/QuadratureFactory.hpp>
namespace AMDiS namespace AMDiS
...@@ -17,256 +15,246 @@ namespace AMDiS ...@@ -17,256 +15,246 @@ namespace AMDiS
* @{ * @{
**/ **/
/// \brief The main implementation of an CRTP-base class for operators using a grid-function template <class LF, class Imp>
/// coefficient to be used in an \ref Assembler. class GridFunctionLocalOperator;
/// \brief The main implementation of an operator depending on a grid-function
/** /**
* An Operator that takes a GridFunction as coefficient. * An Operator that takes a grid-function as coefficient.
* Provides quadrature rules and handles the evaluation of the GridFunction at * Generates a \ref GridFunctionLocalOperator on \ref localOperator()
* local coordinates.
* *
* The class is specialized, by deriving from it, in \ref GridFunctionOperator. * The class implements the interface of an \ref Operator.
* *
* \tparam Derived The class derived from GridFunctionOperatorBase * \tparam GF The class type of the grid-function
* \tparam LC The Element or Intersection type * \tparam Imp Class providing the local assembling method, forwarded to
* \tparam GF The GridFunction, a LocalFunction is created from, and * GridFunctionLocalOperator class
* that is evaluated at quadrature points.
* *
* **Requirements:** * **Requirements:**
* - `LC` models either Entity (of codim 0) or Intersection
* - `GF` models the \ref Concepts::GridFunction * - `GF` models the \ref Concepts::GridFunction
**/ **/
template <class Derived, class LC, class GF> template <class GF, class Imp>
class GridFunctionOperatorBase class GridFunctionOperator
: public LocalOperator<Derived, LC>
{ {
template <class, class> public:
friend class LocalOperator;
using ContextType = Impl::ContextTypes<LC>;
using Super = LocalOperator<Derived, LC>;
private:
using GridFunction = GF; using GridFunction = GF;
using Implementation = Imp;
/// The type of the localFunction associated with the GridFunction /// \brief Constructor. Stores a copy of `gridFct` and `impl`.
using LocalFunction = decltype(localFunction(std::declval<GF>()));
/// The Codim=0 entity of the grid, the localFunction can be bound to
using Element = typename ContextType::Entity;
/// The geometry-type of the grid element
using Geometry = typename Element::Geometry;
/// The type of the local coordinates in the \ref Element
using LocalCoordinate = typename GF::EntitySet::LocalCoordinate;
/// A factory for QuadratureRules that incooperate the order of the LocalFunction
using QuadFactory = QuadratureFactory<typename Geometry::ctype, LC::mydimension, LocalFunction>;
public:
/// \brief Constructor. Stores a copy of `gridFct`.
/** /**
* A GridFunctionOperator takes a gridFunction and the * A GridFunctionOperator takes a grid-function and
* differentiation order of the operator, to calculate the * an implementation class for the assemble method.
* quadrature degree in \ref getDegree.
**/ **/
template <class GridFct> template <class GridFct, class Impl>
GridFunctionOperatorBase(GridFct&& gridFct, int termOrder) GridFunctionOperator(GridFct&& gridFct, Impl&& impl,
int derivDeg, int gridFctOrder)
: gridFct_(FWD(gridFct)) : gridFct_(FWD(gridFct))
, termOrder_(termOrder) , impl_(FWD(impl))
, derivDeg_(derivDeg)
, gridFctOrder_(gridFctOrder)
{} {}
/// Create a quadrature factory from a PreQuadratureFactory, e.g. class derived from \ref QuadratureFactory template <class GridView>
/// \tparam PQF A PreQuadratureFactory void update(GridView const&) { /* do nothing */ }
template <class PQF>
void setQuadFactory(PQF&& pre)
{
using ctype = typename Geometry::ctype;
quadFactory_ = makeUniquePtr(
makeQuadratureFactory<ctype, LC::mydimension, LocalFunction>(FWD(pre)));
}
protected:
/// Return expression value at LocalCoordinates
auto coefficient(LocalCoordinate const& local) const
{
assert( this->bound_ );
return (*localFct_)(local);
}
/// Create a quadrature rule using the \ref QuadratureFactory by calculating the friend auto localOperator(GridFunctionOperator const& op)
/// quadrature order necessary to integrate the (bi)linear-form accurately.
template <class... Nodes>
auto const& getQuadratureRule(Dune::GeometryType type, Nodes const&... nodes) const
{ {
assert( bool(quadFactory_) ); return GridFunctionLocalOperator{localFunction(op.gridFct_), op.impl_,
int quadDegree = this->getDegree(termOrder_, quadFactory_->order(), nodes...); op.derivDeg_, op.gridFctOrder_};
return quadFactory_->rule(type, quadDegree);
} }
private: private:
/// \brief Binds operator to `element` and `geometry`. /// The grid-function associated to this operator
/**
* Binding an operator to the currently visited element in grid traversal.
* Since all operators need geometry information, the `Element::Geometry`
* object `geometry` is created once, during grid traversal, and passed in.
*
* By default, it binds the \ref localFct_ to the `element` and the Quadrature
* factory to the localFunction.
**/
void bind_impl(Element const& element, Geometry const& geometry)
{
assert( bool(quadFactory_) );
localFct_.emplace(localFunction(gridFct_));
localFct_->bind(element);
quadFactory_->bind(localFct_.value());
}
/// Unbinds operator from element.
void unbind_impl()
{
localFct_->unbind();
localFct_.reset();
}
private:
/// The gridFunction to be used within the operator
GridFunction gridFct_; GridFunction gridFct_;
/// localFunction associated with gridFunction. Mus be updated whenever gridFunction changes. /// An implementation class for the assembling
std::optional<LocalFunction> localFct_; Implementation impl_;
/// Assign each element type a quadrature rule /// Maximal degree of derivative this operator represents
std::shared_ptr<QuadFactory> quadFactory_; int derivDeg_;
/// the derivative order of this operator (in {0, 1, 2}) /// Polynomial degree of the grid-function (or -1)
int termOrder_ = 0; int gridFctOrder_;
}; };
template <class GridFct, class Impl>
GridFunctionOperator(GridFct const& gridFct, Impl const& impl, int, int)
-> GridFunctionOperator<GridFct, Impl>;
/// \brief The transposed operator, implemented in term of its transposed by /// \brief The main implementation of a local-operator depending on a local-function
/// calling \ref getElementMatrix with inverted arguments. /**
template <class Derived, class Transposed> * A LocalOperator that takes a local-function as coefficient.
class GridFunctionOperatorTransposed * Provides quadrature rules and passes the local-function, bound to an element,
: public LocalOperator<Derived, typename Transposed::LocalContext> * to the assemble method of an implementation class.
*
* The class implements the interface of a \ref LocalOperator.
*
* \tparam LF The class type of the local-function
* \tparam Imp Class providing the local assembling method
*
* **Requirements:**
* - `LF` models the \ref Concepts::LocalFunction
**/
template <class LF, class Imp>
class GridFunctionLocalOperator
{ {
template <class, class> private:
friend class LocalOperator; /// The type of the localFunction
using LocalFunction = LF;
template <class T, class... Args> /// Type of the implementation class
using Constructable = decltype( new T(std::declval<Args>()...) ); using Implementation = Imp;
public: public:
template <class... Args, /// \brief Constructor. Stores a copy of `localFct` and `impl`.
std::enable_if_t<Dune::Std::is_detected<Constructable, Transposed, Args...>::value, int> = 0> /**
GridFunctionOperatorTransposed(Args&&... args) * A GridFunctionLocalOperator takes a local-function, an implementation class,
: transposedOp_(FWD(args)...) * the differentiation order of the operator and the local-function polynomial
* degree, to calculate the quadrature degree of the operator
**/
template <class LocalFct, class Impl>
GridFunctionLocalOperator(LocalFct&& localFct, Impl&& impl,
int derivDeg, int localFctOrder)
: localFct_(FWD(localFct))
, impl_(FWD(impl))
, derivDeg_(derivDeg)
, localFctOrder_(localFctOrder)
{} {}
/// Redirects the setQuadFactory call top the transposed operator /// \brief Binds operator to `element`.
/// \tparam PQF A PreQuadratureFactory /**
template <class PQF> * By default, it binds the \ref localFct_ to the `element`.
void setQuadFactory(PQF&& pre) **/
template <class Element>
void bind(Element const& element)
{ {
transposedOp_.setQuadFactory(FWD(pre)); localFct_.bind(element);
} }
private: /// Unbinds operator from element.
/// Redirects the bind call top the transposed operator void unbind()
template <class Element, class Geometry>
void bind_impl(Element const& element, Geometry const& geometry)
{ {
transposedOp_.bind(element, geometry); localFct_.unbind();
} }
/// Redirects the unbind call top the transposed operator /// Assemble a local element matrix on the element that is bound.
void unbind_impl() /**
* This function calls the assemble method from the impl_ class and
* additionally passes a quadrature rule and the localFct_ to that method.
**/
template <class CG, class RN, class CN, class Mat>
void assemble(CG const& contextGeo, RN const& rowNode, CN const& colNode,
Mat& elementMatrix) const
{ {
transposedOp_.unbind(); auto const& quad = getQuadratureRule(contextGeo.localContext().geometry(),
derivDeg_, localFctOrder(), rowNode, colNode);
impl().assemble(contextGeo, rowNode, colNode, quad, localFct_, elementMatrix);
} }
/// Apply the assembling to the transposed elementMatrix with interchanged row-/colNode /// Assemble a local element vector on the element that is bound.
/** /**
* \tparam CG ContextGeometry * This function calls the assemble method from the impl_ class and
* \tparam RN RowNode * additionally passes a quadrature rule and the localFct_ to that method.
* \tparam CN ColNode
* \tparam Mat ElementMatrix
**/ **/
template <class CG, class RN, class CN, class Mat> template <class CG, class Node, class Vec>
void getElementMatrix(CG const& contextGeometry, RN const& rowNode, CN const& colNode, Mat& elementMatrix) void assemble(CG const& contextGeo, Node const& node,
Vec& elementVector) const
{ {
auto elementMatrixTransposed = transposed(elementMatrix); auto const& quad = getQuadratureRule(contextGeo.localContext().geometry(),
transposedOp_.getElementMatrix( derivDeg_, localFctOrder(), node);
contextGeometry, colNode, rowNode, elementMatrixTransposed); impl().assemble(contextGeo, node, quad, localFct_, elementVector);
}
Implementation & impl() { return impl_; }
Implementation const& impl() const { return impl_; }
protected:
// calculated polynomial order of local-function. Fallback to localFctOrder_;
int localFctOrder() const
{
if constexpr (Concepts::Polynomial<LF>)
return order(localFct_);
else
return localFctOrder_;
} }
private: private:
Transposed transposedOp_; /// The local-function to be used within the operator
LocalFunction localFct_;
/// Implementation details of the assembling
Implementation impl_;
/// Maximal degree of derivative this operator represents
int derivDeg_;
/// Polynomial degree of the local-function (or -1)
int localFctOrder_;
}; };
// deduction guide