From 0689eaaeb3a4cd72c097e9edbbafa0e5a64288be Mon Sep 17 00:00:00 2001 From: "Praetorius, Simon" <simon.praetorius@tu-dresden.de> Date: Wed, 13 Mar 2019 12:46:28 +0100 Subject: [PATCH] Reduce errors and compiletime in tree traversal --- cmake/modules/AmdisCXXFeatures.cmake | 14 +- config.h.cmake | 1 + examples/CMakeLists.txt | 2 +- src/amdis/BoundaryManager.hpp | 7 +- src/amdis/DataTransfer.inc.hpp | 14 +- src/amdis/DirichletBC.hpp | 2 +- src/amdis/PeriodicBC.inc.hpp | 2 +- src/amdis/ProblemStat.inc.hpp | 14 +- src/amdis/common/ForEach.hpp | 39 +- src/amdis/common/TypeTraits.hpp | 6 + .../gridfunctions/DiscreteFunction.inc.hpp | 4 +- src/amdis/linearalgebra/DOFMatrixBase.inc.hpp | 6 +- src/amdis/linearalgebra/DOFVectorBase.inc.hpp | 4 +- src/amdis/typetree/CMakeLists.txt | 1 - src/amdis/typetree/Traversal.hpp | 350 ++++++++++++------ src/amdis/typetree/TreeData.hpp | 8 +- src/amdis/typetree/Visitor.hpp | 120 ------ test/DataTransferTest.hpp | 4 +- test/TreeDataTest.cpp | 8 +- 19 files changed, 319 insertions(+), 287 deletions(-) delete mode 100644 src/amdis/typetree/Visitor.hpp diff --git a/cmake/modules/AmdisCXXFeatures.cmake b/cmake/modules/AmdisCXXFeatures.cmake index a033c651..58823c5e 100644 --- a/cmake/modules/AmdisCXXFeatures.cmake +++ b/cmake/modules/AmdisCXXFeatures.cmake @@ -1,6 +1,4 @@ -#include(CheckIncludeFileCXX) include(CheckCXXSourceCompiles) -#include(CheckCXXSymbolExists) # fold expressions (a + ...) check_cxx_source_compiles(" @@ -30,4 +28,16 @@ check_cxx_source_compiles(" return f<1>(); } " AMDIS_HAS_CXX_CONSTEXPR_IF +) + +check_cxx_source_compiles(" + #include <iostream> + #include <tuple> + int main() + { + auto tup = std::make_tuple(0, 'a', 3.14); + for... (auto elem : tup) + std::cout << elem << std::endl; + } +" AMDIS_HAS_EXPANSION_STATEMENTS ) \ No newline at end of file diff --git a/config.h.cmake b/config.h.cmake index a9f4d0ed..471450a3 100644 --- a/config.h.cmake +++ b/config.h.cmake @@ -49,6 +49,7 @@ /* some detected compiler features may be used in AMDiS */ #cmakedefine AMDIS_HAS_CXX_FOLD_EXPRESSIONS 1 #cmakedefine AMDIS_HAS_CXX_CONSTEXPR_IF 1 +#cmakedefine AMDIS_HAS_EXPANSION_STATEMENTS 1 /* end amdis Everything below here will be overwritten diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index a22ffee3..a2a8ee3f 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -24,4 +24,4 @@ add_dependencies(examples stokes1.2d stokes3.2d navier_stokes.2d - convection_diffusion.2d) \ No newline at end of file + convection_diffusion.2d) diff --git a/src/amdis/BoundaryManager.hpp b/src/amdis/BoundaryManager.hpp index 375c334a..a43e28d1 100644 --- a/src/amdis/BoundaryManager.hpp +++ b/src/amdis/BoundaryManager.hpp @@ -146,9 +146,10 @@ namespace AMDiS if (!segment.boundary()) continue; - auto index = segment.boundarySegmentIndex(); - Dune::Hybrid::ifElse(Dune::Std::is_detected<HasBoundaryId, Segment>{}, - [&](auto id) { boundaryIds_[index] = id(segment).boundaryId(); }); + Dune::Hybrid::ifElse(Dune::Std::is_detected<HasBoundaryId, Segment>{}, [&](auto id) { + auto index = segment.boundarySegmentIndex(); + boundaryIds_[index] = id(segment).boundaryId(); + }); } } } diff --git a/src/amdis/DataTransfer.inc.hpp b/src/amdis/DataTransfer.inc.hpp index a181ffe1..de25624a 100644 --- a/src/amdis/DataTransfer.inc.hpp +++ b/src/amdis/DataTransfer.inc.hpp @@ -22,8 +22,8 @@ #include <amdis/Output.hpp> #include <amdis/common/ConcurrentCache.hpp> +#include <amdis/typetree/Traversal.hpp> #include <amdis/typetree/TreeContainer.hpp> -#include <amdis/typetree/Visitor.hpp> namespace AMDiS { @@ -135,7 +135,7 @@ namespace AMDiS auto lv = basis_->localView(); auto const& idSet = gv.grid().localIdSet(); - forEachLeafNode_(lv.tree(), [&](auto const& node, auto const& tp) { + for_each_leaf_node(lv.tree(), [&](auto const& node, auto const& tp) { nodeDataTransfer_[tp].preAdaptInit(lv, coeff, node); }); @@ -148,7 +148,7 @@ namespace AMDiS lv.bind(e); auto& treeContainer = it.first->second; - forEachLeafNode_(lv.tree(), [&](auto const& node, auto const& tp) { + for_each_leaf_node(lv.tree(), [&](auto const& node, auto const& tp) { nodeDataTransfer_[tp].cacheLocal(treeContainer[tp]); }); } @@ -201,7 +201,7 @@ namespace AMDiS }; restrictLocalCompleted = true; - forEachLeafNode_(lv.tree(), [&](auto const& node, auto const& tp) { + for_each_leaf_node(lv.tree(), [&](auto const& node, auto const& tp) { restrictLocalCompleted &= nodeDataTransfer_[tp].restrictLocal(father, treeContainer[tp], xInChildCached, childContainer[tp], init); @@ -224,7 +224,7 @@ namespace AMDiS auto gv = basis_->gridView(); auto lv = basis_->localView(); auto const& idSet = gv.grid().localIdSet(); - forEachLeafNode_(lv.tree(), [&](auto const& node, auto const& tp) { + for_each_leaf_node(lv.tree(), [&](auto const& node, auto const& tp) { nodeDataTransfer_[tp].postAdaptInit(lv, coeff, node); }); @@ -243,7 +243,7 @@ namespace AMDiS if (it != persistentContainer_.end()) { lv.bind(e); auto const& treeContainer = it->second; - forEachLeafNode_(lv.tree(), [&](auto const& node, auto const& tp) { + for_each_leaf_node(lv.tree(), [&](auto const& node, auto const& tp) { nodeDataTransfer_[tp].copyLocal(treeContainer[tp]); }); finished_[index] = true; @@ -275,7 +275,7 @@ namespace AMDiS return fatherGeo.local(childGeo.global(x)); }; - forEachLeafNode_(lv.tree(), [&](auto const& node, auto const& tp) { + for_each_leaf_node(lv.tree(), [&](auto const& node, auto const& tp) { nodeDataTransfer_[tp].prolongLocal(father, treeContainer[tp], xInFather, init); }); diff --git a/src/amdis/DirichletBC.hpp b/src/amdis/DirichletBC.hpp index a74b0eab..df1947d1 100644 --- a/src/amdis/DirichletBC.hpp +++ b/src/amdis/DirichletBC.hpp @@ -11,8 +11,8 @@ #include <amdis/BoundaryCondition.hpp> #include <amdis/common/Concepts.hpp> #include <amdis/typetree/RangeType.hpp> +#include <amdis/typetree/Traversal.hpp> #include <amdis/typetree/TreeData.hpp> -#include <amdis/typetree/Visitor.hpp> namespace AMDiS { diff --git a/src/amdis/PeriodicBC.inc.hpp b/src/amdis/PeriodicBC.inc.hpp index 2b449f3b..7c5ebbf1 100644 --- a/src/amdis/PeriodicBC.inc.hpp +++ b/src/amdis/PeriodicBC.inc.hpp @@ -196,7 +196,7 @@ std::vector<D> PeriodicBC<D,MI>:: coords(Node const& tree, std::vector<std::size_t> const& localIndices) const { std::vector<D> dofCoords(localIndices.size()); - AMDiS::forEachLeafNode_(tree, [&](auto const& node, auto const& tp) + for_each_leaf_node(tree, [&](auto const& node, auto const& tp) { std::size_t size = node.finiteElement().size(); auto geometry = node.element().geometry(); diff --git a/src/amdis/ProblemStat.inc.hpp b/src/amdis/ProblemStat.inc.hpp index 2e8021df..7713a8af 100644 --- a/src/amdis/ProblemStat.inc.hpp +++ b/src/amdis/ProblemStat.inc.hpp @@ -177,7 +177,7 @@ void ProblemStat<Traits>::createMatricesAndVectors() rhs_ = std::make_shared<SystemVector>(*globalBasis_, NO_OPERATION); auto localView = globalBasis_->localView(); - AMDiS::forEachNode_(localView.tree(), [&,this](auto const& node, auto treePath) + for_each_node(localView.tree(), [&,this](auto const& node, auto treePath) { std::string i = to_string(treePath); estimates_[i].resize(globalBasis_->gridView().indexSet().size(0)); @@ -205,7 +205,7 @@ void ProblemStat<Traits>::createMarker() { marker_.clear(); auto localView = globalBasis_->localView(); - AMDiS::forEachNode_(localView.tree(), [&,this](auto const& node, auto treePath) + for_each_node(localView.tree(), [&,this](auto const& node, auto treePath) { std::string componentName = name_ + "->marker[" + to_string(treePath) + "]"; @@ -232,7 +232,7 @@ void ProblemStat<Traits>::createFileWriter() { filewriter_.clear(); auto localView = globalBasis_->localView(); - forEachNode_(localView.tree(), [&,this](auto const& node, auto treePath) + for_each_node(localView.tree(), [&,this](auto const& node, auto treePath) { std::string componentName = name_ + "->output[" + to_string(treePath) + "]"; @@ -428,9 +428,9 @@ buildAfterAdapt(AdaptInfo& /*adaptInfo*/, Flag /*flag*/, bool asmMatrix, bool as rhs_->init(asmVector); auto localView = globalBasis_->localView(); - forEachNode_(localView.tree(), [&,this](auto const& rowNode, auto rowTp) { + for_each_node(localView.tree(), [&,this](auto const& rowNode, auto rowTp) { auto rowBasis = Dune::Functions::subspaceBasis(*globalBasis_, rowTp); - forEachNode_(localView.tree(), [&,this](auto const& colNode, auto colTp) { + for_each_node(localView.tree(), [&,this](auto const& colNode, auto colTp) { auto colBasis = Dune::Functions::subspaceBasis(*globalBasis_, colTp); for (auto bc : dirichletBCs_[rowNode][colNode]) bc->init(rowBasis, colBasis); @@ -456,8 +456,8 @@ buildAfterAdapt(AdaptInfo& /*adaptInfo*/, Flag /*flag*/, bool asmMatrix, bool as systemMatrix_->finish(asmMatrix); rhs_->finish(asmVector); - forEachNode_(localView.tree(), [&,this](auto const& rowNode, auto) { - forEachNode_(localView.tree(), [&,this](auto const& colNode, auto) { + for_each_node(localView.tree(), [&,this](auto const& rowNode, auto) { + for_each_node(localView.tree(), [&,this](auto const& colNode, auto) { // finish boundary condition for (auto bc : dirichletBCs_[rowNode][colNode]) bc->fillBoundaryCondition(*systemMatrix_, *solution_, *rhs_, rowNode, colNode); diff --git a/src/amdis/common/ForEach.hpp b/src/amdis/common/ForEach.hpp index dbe64d7d..7768a561 100644 --- a/src/amdis/common/ForEach.hpp +++ b/src/amdis/common/ForEach.hpp @@ -2,7 +2,6 @@ #include <initializer_list> -#include <amdis/common/Apply.hpp> #include <amdis/common/Index.hpp> #include <amdis/common/Range.hpp> @@ -16,34 +15,52 @@ namespace AMDiS void ignored_evaluation(std::initializer_list<T>&&) { /* do nothing */ } } + template <std::size_t... I, class Tuple, class Functor> + constexpr void for_each(std::index_sequence<I...>, Tuple&& tuple, Functor&& f) + { + using std::get; +#if AMDIS_HAS_EXPANSION_STATEMENTS + for... (auto&& t : tuple) { f(FWD(t)); } +#elif AMDIS_HAS_CXX_FOLD_EXPRESSIONS + (f(get<I>(tuple)),...); +#else + Impl_::ignored_evaluation<int>({0, (f(get<I>(tuple)), 0)...}); +#endif + } + template <class Tuple, class Functor> constexpr void for_each(Tuple&& tuple, Functor&& f) { - #if AMDIS_HAS_CXX_FOLD_EXPRESSIONS - Tools::apply([f=std::move(f)](auto&&... t) { (f(FWD(t)),...); }, tuple); - #else - Tools::apply([f=std::move(f)](auto&&... t) { - Impl_::ignored_evaluation<int>({0, (f(FWD(t)), 0)...}); - }, tuple); - #endif + Tools::for_each(std::make_index_sequence<Size_v<std::remove_reference_t<Tuple>>>{}, FWD(tuple), FWD(f)); + } + + + template <std::size_t I0 = 0, std::size_t... I, class Functor> + constexpr void for_range(std::index_sequence<I...>, Functor&& f) + { +#if AMDIS_HAS_CXX_FOLD_EXPRESSIONS + (f(index_t<I0+I>{}),...); +#else + Impl_::ignored_evaluation<int>({0, (f(index_t<I0+I>{}), 0)...}); +#endif } template <std::size_t I0, std::size_t I1, class Functor> constexpr void for_range(index_t<I0> i0, index_t<I1> i1, Functor&& f) { - Tools::for_each(range_t<I0,I1>{}, FWD(f)); + Tools::for_range<I0>(std::make_index_sequence<std::size_t(I1-I0)>{}, FWD(f)); } template <std::size_t N, class Functor> constexpr void for_range(index_t<N>, Functor&& f) { - Tools::for_each(range_t<0,N>{}, FWD(f)); + Tools::for_range(std::make_index_sequence<N>{}, FWD(f)); } template <std::size_t I0, std::size_t I1, class Functor> constexpr void for_range(Functor&& f) { - Tools::for_each(range_t<I0,I1>{}, FWD(f)); + Tools::for_range<I0>(std::make_index_sequence<std::size_t(I1-I0)>{}, FWD(f)); } } // end namespace Tools diff --git a/src/amdis/common/TypeTraits.hpp b/src/amdis/common/TypeTraits.hpp index 5d63c077..3270cb24 100644 --- a/src/amdis/common/TypeTraits.hpp +++ b/src/amdis/common/TypeTraits.hpp @@ -65,6 +65,12 @@ namespace AMDiS template <class T> using owner = T; + /// A functor with no operation + struct NoOp + { + template <class... T> + constexpr void operator()(T&&...) const { /* no nothing */ } + }; /// Create a unique_ptr by copy/move construction template <class Obj> diff --git a/src/amdis/gridfunctions/DiscreteFunction.inc.hpp b/src/amdis/gridfunctions/DiscreteFunction.inc.hpp index 52c47b3e..d9998b04 100644 --- a/src/amdis/gridfunctions/DiscreteFunction.inc.hpp +++ b/src/amdis/gridfunctions/DiscreteFunction.inc.hpp @@ -147,7 +147,7 @@ LocalFunction::operator()(Domain const& x) const auto&& coefficients = *globalFunction_.dofVector_; auto&& nodeToRangeEntry = globalFunction_.nodeToRangeEntry_; - forEachLeafNode_(*subTree_, [&,this](auto const& node, auto const& tp) + for_each_leaf_node(*subTree_, [&,this](auto const& node, auto const& tp) { auto&& fe = node.finiteElement(); auto&& localBasis = fe.localBasis(); @@ -193,7 +193,7 @@ GradientLocalFunction::operator()(Domain const& x) const auto&& coefficients = *globalFunction_.dofVector_; auto&& nodeToRangeEntry = globalFunction_.nodeToRangeEntry_; - forEachLeafNode_(*subTree_, [&,this](auto const& node, auto const& tp) + for_each_leaf_node(*subTree_, [&,this](auto const& node, auto const& tp) { // TODO: may DOFVectorView::Range to FieldVector type if necessary using LocalDerivativeTraits diff --git a/src/amdis/linearalgebra/DOFMatrixBase.inc.hpp b/src/amdis/linearalgebra/DOFMatrixBase.inc.hpp index 7088f001..f06bbde6 100644 --- a/src/amdis/linearalgebra/DOFMatrixBase.inc.hpp +++ b/src/amdis/linearalgebra/DOFMatrixBase.inc.hpp @@ -2,7 +2,7 @@ #include <amdis/Assembler.hpp> #include <amdis/LocalOperator.hpp> -#include <amdis/typetree/Visitor.hpp> +#include <amdis/typetree/Traversal.hpp> #include <amdis/utility/AssembleOperators.hpp> namespace AMDiS { @@ -75,8 +75,8 @@ assemble(RowLocalView const& rowLocalView, ColLocalView const& colLocalView) auto const& element = rowLocalView.element(); auto geometry = element.geometry(); - forEachNode_(rowLocalView.tree(), [&](auto const& rowNode, auto) { - forEachNode_(colLocalView.tree(), [&](auto const& colNode, auto) { + for_each_node(rowLocalView.tree(), [&](auto const& rowNode, auto) { + for_each_node(colLocalView.tree(), [&](auto const& colNode, auto) { auto& matOp = operators_[rowNode][colNode]; if (matOp) { matOp.bind(element, geometry); diff --git a/src/amdis/linearalgebra/DOFVectorBase.inc.hpp b/src/amdis/linearalgebra/DOFVectorBase.inc.hpp index 3be8021a..4795c768 100644 --- a/src/amdis/linearalgebra/DOFVectorBase.inc.hpp +++ b/src/amdis/linearalgebra/DOFVectorBase.inc.hpp @@ -2,7 +2,7 @@ #include <amdis/Assembler.hpp> #include <amdis/LocalOperator.hpp> -#include <amdis/typetree/Visitor.hpp> +#include <amdis/typetree/Traversal.hpp> #include <amdis/utility/AssembleOperators.hpp> namespace AMDiS { @@ -66,7 +66,7 @@ assemble(LocalView const& localView) auto const& element = localView.element(); auto geometry = element.geometry(); - forEachNode_(localView.tree(), [&](auto const& node, auto) { + for_each_node(localView.tree(), [&](auto const& node, auto) { auto& rhsOp = operators_[node]; if (rhsOp) { rhsOp.bind(element, geometry); diff --git a/src/amdis/typetree/CMakeLists.txt b/src/amdis/typetree/CMakeLists.txt index 60c45aeb..e2544920 100644 --- a/src/amdis/typetree/CMakeLists.txt +++ b/src/amdis/typetree/CMakeLists.txt @@ -7,5 +7,4 @@ install(FILES TreeContainer.hpp TreeData.hpp TreePath.hpp - Visitor.hpp DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/amdis/typetree) diff --git a/src/amdis/typetree/Traversal.hpp b/src/amdis/typetree/Traversal.hpp index 45751bcd..8dfd9cdf 100644 --- a/src/amdis/typetree/Traversal.hpp +++ b/src/amdis/typetree/Traversal.hpp @@ -3,139 +3,257 @@ #include <dune/common/hybridutilities.hh> #include <dune/common/rangeutilities.hh> +#include <dune/typetree/childextraction.hh> #include <dune/typetree/nodetags.hh> #include <dune/typetree/treepath.hh> #include <dune/typetree/visitor.hh> #include <amdis/common/ForEach.hpp> +#include <amdis/common/Logical.hpp> +#include <amdis/common/Range.hpp> #include <amdis/common/TypeTraits.hpp> -namespace AMDiS -{ - // forward declaration of main engine struct - template <typename NodeTag, bool visit = true> - struct TraverseTree; - - - // Do not visit nodes the visitor is not interested in - template <typename NodeTag> - struct TraverseTree<NodeTag, false> - { - template <typename Node, typename Visitor, typename TreePath> - static void apply(const Node& node, const Visitor& visitor, TreePath const& tp) - {} - }; - -#ifndef DOXYGEN - - // some implementation details - - template <class Node, class Index> - struct HybridChildType - : HybridChildType<std::remove_const_t<Node>, std::remove_const_t<Index>> {}; - - template <class Node> - struct HybridChildType<Node, std::size_t> - { - using type = typename Node::template Child<0>::Type; - }; - - template <class Node, std::size_t K> - struct HybridChildType<Node, Dune::index_constant<K>> - { - using type = typename Node::template Child<K>::Type; - }; - - template <class NodeTag, class Node> - constexpr std::size_t hybridDegree(NodeTag, Node const& node) - { - return Dune::TypeTree::degree(node); - } - - template <class Node> - constexpr auto hybridDegree(Dune::TypeTree::CompositeNodeTag, Node const& node) - { - return Dune::index_constant<Node::CHILDREN>{}; - } - - - template <std::size_t k, std::size_t n> - constexpr bool notLastChild(Dune::index_constant<k> const&, Dune::index_constant<n> const&) - { - return k < n-1; - } - - constexpr bool notLastChild(std::size_t k, std::size_t n) - { - return k < n-1; - } - -#endif - - - template <class NodeTag> - struct TraverseTree<NodeTag, true> - { - template <typename N, typename V, typename TreePath> - static void apply(N&& n, V&& v, TreePath const& tp) +// NOTE: backport of dune/typetree/traversal.hpp from Dune 2.7 + +namespace AMDiS { + + enum class TreePathType { - using Node = std::remove_reference_t<N>; - using Visitor = std::remove_reference_t<V>; + DYNAMIC, STATIC + }; + + namespace Impl { + + // This is a constexpr version of the ternery operator c?t1:t1. + // In contrast to the latter the type of t1 and t2 can be different. + // Notice that std::conditional would not do the trick, because + // it only selects between types. + template<bool c, class T1, class T2, + std::enable_if_t<c, int> = 0> + constexpr auto conditionalValue(T1&& t1, T2&& t2) { + return std::forward<T1>(t1); + } + + template<bool c, class T1, class T2, + std::enable_if_t<not c, int> = 0> + constexpr auto conditionalValue(T1&& t1, T2&& t2) { + return std::forward<T2>(t2); + } + + + /* The signature is the same as for the public applyToTree + * function in Dune::Typetree, despite the additionally passed + * treePath argument. The path passed here is associated to + * the tree and the relative paths of the children (wrt. to tree) + * are appended to this. Hence the behavior of the public function + * is resembled by passing an empty treePath. + */ + + /* + * This is the overload for leaf traversal + */ + template<class T, class TP, class V, + std::enable_if_t<remove_cvref_t<T>::isLeaf, int> = 0> + void apply_to_tree(T&& tree, TP treePath, V&& visitor) + { + visitor.leaf(tree, treePath); + } + + /* + * This is the general overload doing child traversal. + */ + template<class T, class TP, class V, + std::enable_if_t<not remove_cvref_t<T>::isLeaf, int> = 0> + void apply_to_tree(T&& tree, TP treePath, V&& visitor) + { + // Do we really want to take care for const-ness of the Tree + // when instantiating VisitChild below? I'd rather expect this: + // using Tree = remove_cvref_t<T>; + // using Visitor = remove_cvref_t<V>; + using Tree = std::remove_reference_t<T>; + using Visitor = std::remove_reference_t<V>; + visitor.pre(tree, treePath); + + // Use statically encoded degree unless tree + // is a power node and dynamic traversal is requested. + constexpr auto useDynamicTraversal = (Tree::isPower and Visitor::treePathType==TreePathType::DYNAMIC); + auto degree = conditionalValue<useDynamicTraversal>(Tree::degree(), Dune::index_constant<Tree::degree()>{}); + + auto indices = Dune::range(degree); + Dune::Hybrid::forEach(indices, [&](auto i) { + auto childTP = Dune::TypeTree::push_back(treePath, i); + auto&& child = tree.child(i); + using Child = TYPEOF(child); + + visitor.beforeChild(tree, child, treePath, i); + + // This requires that visiotor.in(...) can always be instantiated, + // even if there's a single child only. + if (i>0) + visitor.in(tree, treePath); + static constexpr auto visitChild = Visitor::template VisitChild<Tree,Child,TP>::value; + #if AMDIS_HAS_CXX_CONSTEXPR_IF + if constexpr(visitChild) + applyToTree(child, childTP, visitor); + #else // AMDIS_HAS_CXX_CONSTEXPR_IF + Dune::Hybrid::ifElse(bool_t<visitChild>{}, [&] (auto /*id*/) { + applyToTree(child, childTP, visitor); + }); + #endif // AMDIS_HAS_CXX_CONSTEXPR_IF + visitor.afterChild(tree, child, treePath, i); + }); + visitor.post(tree, treePath); + } + + // Overload for leaf nodes + template<class Tree, class TP, class Pre, class Leaf, class Post, + std::enable_if_t<remove_cvref_t<Tree>::isLeaf, int> = 0> + void for_each_node(Tree&& tree, TP treePath, Pre&& /*preFunc*/, Leaf&& leafFunc, Post&& /*postFunc*/) + { + leafFunc(tree, treePath); + } + + // Overload for non-leaf nodes + // Forward declaration needed for recursion + template<class Tree, class TP, class Pre, class Leaf, class Post, + std::enable_if_t<not remove_cvref_t<Tree>::isLeaf,int> = 0> + void for_each_node(Tree&& tree, TP treePath, Pre&& preFunc, Leaf&& leafFunc, Post&& postFunc); + + // Helper for power nodes + template<class Tree, class TP, class Pre, class Leaf, class Post, std::size_t... I, + std::enable_if_t<remove_cvref_t<Tree>::isPower, int> = 0> + void for_each_node_unfold(Tree&& tree, TP treePath, Pre&& preFunc, Leaf&& leafFunc, Post&& postFunc, std::index_sequence<I...>) + { + for (std::size_t i = 0; i < sizeof...(I); ++i) + Impl::for_each_node(tree.child(i), Dune::TypeTree::push_back(treePath, i), preFunc, leafFunc, postFunc); + } - v.pre(FWD(n),tp); + // Helper for composite nodes + template<class Tree, class TP, class Pre, class Leaf, class Post, std::size_t... I, + std::enable_if_t<not remove_cvref_t<Tree>::isPower, int> = 0> + void for_each_node_unfold(Tree&& tree, TP treePath, Pre&& preFunc, Leaf&& leafFunc, Post&& postFunc, std::index_sequence<I...>) + { + (void)std::initializer_list<int>{( + Impl::for_each_node(tree.child(Dune::index_constant<I>{}), + Dune::TypeTree::push_back(treePath, Dune::index_constant<I>{}), preFunc, leafFunc, postFunc),0)... + }; + } - auto const deg = hybridDegree(NodeTag{}, n); - Dune::Hybrid::forEach(Dune::range(deg), [&](auto const _k) + /* + * Traverse tree and visit each node. The signature is the same + * as for the public for_each_node function in Dune::Typtree, + * despite the additionally passed treePath argument. The path + * passed here is associated to the tree and the relative + * paths of the children (wrt. to tree) are appended to this. + * Hence the behavior of the public function is resembled + * by passing an empty treePath. + * + * See also the specialization for leaf-nodes. + */ + template<class Tree, class TP, class Pre, class Leaf, class Post, + std::enable_if_t<not remove_cvref_t<Tree>::isLeaf,int>> + void for_each_node(Tree&& tree, TP treePath, Pre&& preFunc, Leaf&& leafFunc, Post&& postFunc) { - // always call beforeChild(), regardless of the value of visit - v.beforeChild(FWD(n),n.child(_k),tp,_k); + auto indices = std::make_index_sequence<TYPEOF(tree)::degree()>{}; + preFunc(tree, treePath); + Impl::for_each_node_unfold(tree, treePath, preFunc, leafFunc, postFunc, indices); + postFunc(tree, treePath); + } + + } // namespace Impl + + + // ******************************************************************************** + // Public Interface + // ******************************************************************************** - // descend to child - using C = typename HybridChildType<Node, decltype(_k)>::type; - const bool visit = Visitor::template VisitChild<Node,C,TreePath>::value; - TraverseTree<Dune::TypeTree::NodeTag<C>,visit>::apply(n.child(_k),FWD(v),push_back(tp, _k)); + //! Apply visitor to TypeTree. + /** + * \code + * #include <amdis/typetree/Traversal.hpp> + * \endcode + * This function applies the given visitor to the given tree. Both visitor and tree may be const + * or non-const (if the compiler supports rvalue references, they may even be a non-const temporary). + * + * \note The visitor must implement the interface laid out by DefaultVisitor (most easily achieved by + * inheriting from it) and specify the required type of tree traversal (static or dynamic) by + * inheriting from either StaticTraversal or DynamicTraversal. + * + * \param tree The tree the visitor will be applied to. + * \param visitor The visitor to apply to the tree. + */ + template<typename Tree, typename Visitor> + void apply_to_tree(Tree&& tree, Visitor&& visitor) + { + auto root = Dune::TypeTree::hybridTreePath(); + Impl::apply_to_tree(tree, root, visitor); + } - // always call afterChild(), regardless of the value of visit - v.afterChild(FWD(n),n.child(_k),tp,_k); + /** + * \brief Traverse tree and visit each node + * + * All passed callback functions are called with the + * node and corresponding treepath as arguments. + * + * \param tree The tree to traverse + * \param preFunc This function is called for each inner node before visiting its children + * \param leafFunc This function is called for each leaf node + * \param postFunc This function is called for each inner node after visiting its children + */ + template<class Tree, class Pre, class Leaf, class Post> + void for_each_node(Tree&& tree, Pre&& preFunc, Leaf&& leafFunc, Post&& postFunc) + { + auto root = Dune::TypeTree::hybridTreePath(); + Impl::for_each_node(tree, root, preFunc, leafFunc, postFunc); + } - // if this is not the last child, call infix callback - if (notLastChild(_k, deg)) - v.in(FWD(n),tp); - }); + /** + * \brief Traverse tree and visit each node + * + * All passed callback functions are called with the + * node and corresponding treepath as arguments. + * + * \param tree The tree to traverse + * \param innerFunc This function is called for each inner node before visiting its children + * \param leafFunc This function is called for each leaf node + */ + template<class Tree, class Inner, class Leaf> + void for_each_node(Tree&& tree, Inner&& innerFunc, Leaf&& leafFunc) + { + auto root = Dune::TypeTree::hybridTreePath(); + Impl::for_each_node(tree, root, innerFunc, leafFunc, NoOp{}); + } - v.post(FWD(n),tp); + /** + * \brief Traverse tree and visit each node + * + * The passed callback function is called with the + * node and corresponding treepath as arguments. + * + * \param tree The tree to traverse + * \param nodeFunc This function is called for each node + */ + template<class Tree, class NodeFunc> + void for_each_node(Tree&& tree, NodeFunc&& nodeFunc) + { + auto root = Dune::TypeTree::hybridTreePath(); + Impl::for_each_node(tree, root, nodeFunc, nodeFunc, NoOp{}); } - }; - - // LeafNode - just call the leaf() callback - template <> - struct TraverseTree<Dune::TypeTree::LeafNodeTag, true> - { - template <typename N, typename V, typename TreePath> - static void apply(N&& n, V&& v, TreePath const& tp) + + /** + * \brief Traverse tree and visit each leaf node + * + * The passed callback function is called with the + * node and corresponding treepath as arguments. + * + * \param tree The tree to traverse + * \param leafFunc This function is called for each leaf node + */ + template<class Tree, class Leaf> + void for_each_leaf_node(Tree&& tree, Leaf&& leafFunc) { - v.leaf(FWD(n),tp); + auto root = Dune::TypeTree::hybridTreePath(); + Impl::for_each_node(tree, root, NoOp{}, leafFunc, NoOp{}); } - }; - - //! Apply visitor to TypeTree. - /** - * This function applies the given visitor to the given tree. Both visitor and tree may be const - * or non-const (if the compiler supports rvalue references, they may even be a non-const temporary). - * - * \note The visitor must implement the interface laid out by DefaultVisitor (most easily achieved by - * inheriting from it). - * - * \param tree The tree the visitor will be applied to. - * \param visitor The visitor to apply to the tree. - */ - template <typename Tree, typename Visitor> - void traverseTree(Tree&& tree, Visitor&& visitor) - { - using Node = std::remove_reference_t<Tree>; - using NodeTag = Dune::TypeTree::NodeTag<Node>; - using TreePath = Dune::TypeTree::HybridTreePath<>; - TraverseTree<NodeTag>::apply(FWD(tree), FWD(visitor), TreePath{}); - } } // end namespace AMDiS diff --git a/src/amdis/typetree/TreeData.hpp b/src/amdis/typetree/TreeData.hpp index 2e558ace..d6ecfe62 100644 --- a/src/amdis/typetree/TreeData.hpp +++ b/src/amdis/typetree/TreeData.hpp @@ -6,7 +6,7 @@ #include <vector> #include <dune/typetree/typetree.hh> -#include <amdis/typetree/Visitor.hpp> +#include <amdis/typetree/Traversal.hpp> namespace AMDiS { @@ -195,10 +195,10 @@ namespace AMDiS } template <class Func> - void applyImpl(Func&& func, std::true_type) { forEachLeafNode_(basis_->localView().tree(), func); } + void applyImpl(Func&& func, std::true_type) { for_each_leaf_node(basis_->localView().tree(), func); } template <class Func> - void applyImpl(Func&& func, std::false_type) { forEachNode_(basis_->localView().tree(), func); } + void applyImpl(Func&& func, std::false_type) { for_each_node(basis_->localView().tree(), func); } protected: Basis const* basis_ = nullptr; @@ -238,7 +238,7 @@ namespace AMDiS void init(RowBasis const& rowBasis, ColBasis const& colBasis) { Super::init(rowBasis); - forEachNode_(rowBasis.localView().tree(), [&](auto const& node, auto&&) + for_each_node(rowBasis.localView().tree(), [&](auto const& node, auto&&) { (*this)[node].init(colBasis); }); diff --git a/src/amdis/typetree/Visitor.hpp b/src/amdis/typetree/Visitor.hpp deleted file mode 100644 index 153e0e1e..00000000 --- a/src/amdis/typetree/Visitor.hpp +++ /dev/null @@ -1,120 +0,0 @@ -#pragma once - -#include <dune/typetree/visitor.hh> -#include <amdis/typetree/Traversal.hpp> - -namespace AMDiS -{ - // from dune-typetree merge-request !2 - namespace Impl - { - template <class PreFunc, class LeafFunc, class PostFunc> - class CallbackVisitor - : public Dune::TypeTree::TreeVisitor - { - public: - CallbackVisitor(PreFunc& preFunc, LeafFunc& leafFunc, PostFunc& postFunc) - : preFunc_(preFunc) - , leafFunc_(leafFunc) - , postFunc_(postFunc) - {} - - template <typename Node, typename TreePath> - void pre(Node&& node, TreePath treePath) - { - preFunc_(node, treePath); - } - - template <typename Node, typename TreePath> - void leaf(Node&& node, TreePath treePath) - { - leafFunc_(node, treePath); - } - - template <typename Node, typename TreePath> - void post(Node&& node, TreePath treePath) - { - postFunc_(node, treePath); - } - - private: - PreFunc& preFunc_; - LeafFunc& leafFunc_; - PostFunc& postFunc_; - }; - - template <class PreFunc, class LeafFunc, class PostFunc> - auto callbackVisitor(PreFunc& preFunc, LeafFunc& leafFunc, PostFunc& postFunc) - { - return CallbackVisitor<PreFunc, LeafFunc, PostFunc>(preFunc, leafFunc, postFunc); - } - - } // namespace Impl - - - /** - * \brief Traverse tree and visit each node - * - * All passed callback functions are called with the - * node and corresponding treepath as arguments. - * - * \param tree The tree to traverse - * \param preFunc This function is called for each inner node before visiting its children - * \param leafFunc This function is called for each leaf node - * \param postFunc This function is called for each inner node after visiting its children - */ - template <class Tree, class PreFunc, class LeafFunc, class PostFunc> - void forEachNode_(Tree&& tree, PreFunc&& preFunc, LeafFunc&& leafFunc, PostFunc&& postFunc) - { - traverseTree(tree, Impl::callbackVisitor(preFunc, leafFunc, postFunc)); - } - - /** - * \brief Traverse tree and visit each node - * - * All passed callback functions are called with the - * node and corresponding treepath as arguments. - * - * \param tree The tree to traverse - * \param innerFunc This function is called for each inner node before visiting its children - * \param leafFunc This function is called for each leaf node - */ - template <class Tree, class InnerFunc, class LeafFunc> - void forEachNode_(Tree&& tree, InnerFunc&& innerFunc, LeafFunc&& leafFunc) - { - auto nop = [](auto&&... args) {}; - forEachNode_(tree, innerFunc, leafFunc, nop); - } - - /** - * \brief Traverse tree and visit each node - * - * The passed callback function is called with the - * node and corresponding treepath as arguments. - * - * \param tree The tree to traverse - * \param nodeFunc This function is called for each node - */ - template <class Tree, class NodeFunc> - void forEachNode_(Tree&& tree, NodeFunc&& nodeFunc) - { - forEachNode_(tree, nodeFunc, nodeFunc); - } - - /** - * \brief Traverse tree and visit each leaf node - * - * The passed callback function is called with the - * node and corresponding treepath as arguments. - * - * \param tree The tree to traverse - * \param leafFunc This function is called for each leaf node - */ - template <class Tree, class LeafFunc> - void forEachLeafNode_(Tree&& tree, LeafFunc&& leafFunc) - { - auto nop = [](auto&&... args) {}; - forEachNode_(tree, nop, leafFunc, nop); - } - -} // end namespace AMDiS diff --git a/test/DataTransferTest.hpp b/test/DataTransferTest.hpp index 5502b105..11f0e301 100644 --- a/test/DataTransferTest.hpp +++ b/test/DataTransferTest.hpp @@ -68,7 +68,7 @@ auto makeProblem(typename BasisCreator::GlobalBasis::GridView::Grid& grid, Fcts // interpolate given function to initial grid int k = 0; - AMDiS::forEachLeafNode_(localView.tree(), [&](auto const& node, auto tp) + for_each_leaf_node(localView.tree(), [&](auto const& node, auto tp) { interpolate(globalBasis, tp, prob.solution(tp).coefficients(), funcs[k]); k++; @@ -90,7 +90,7 @@ double calcError(Problem const& prob, Fcts const& funcs) int k = 0; // interpolate given function onto reference vector - AMDiS::forEachLeafNode_(localView.tree(), [&](auto const& node, auto tp) + for_each_leaf_node(localView.tree(), [&](auto const& node, auto tp) { interpolate(globalBasis, tp, ref, funcs[k]); k++; diff --git a/test/TreeDataTest.cpp b/test/TreeDataTest.cpp index d0320281..103b4d68 100644 --- a/test/TreeDataTest.cpp +++ b/test/TreeDataTest.cpp @@ -25,7 +25,7 @@ bool operator==(TreeData<Basis,NodeData,false> const& t1, TreeData<Basis,NodeDat AMDIS_TEST(t1.basis() == t2.basis() && t1.basis() != nullptr); bool same = true; - AMDiS::forEachNode_(t1.basis()->localView().tree(), [&](auto const& node, auto) { + for_each_node(t1.basis()->localView().tree(), [&](auto const& node, auto) { same = same && (t1[node] == t2[node]); }); @@ -38,7 +38,7 @@ bool operator==(TreeData<Basis,NodeData,true> const& t1, TreeData<Basis,NodeData AMDIS_TEST(t1.basis() == t2.basis() && t1.basis() != nullptr); bool same = true; - AMDiS::forEachLeafNode_(t1.basis()->localView().tree(), [&](auto const& node, auto) { + for_each_leaf_node(t1.basis()->localView().tree(), [&](auto const& node, auto) { same = same && (t1[node] == t2[node]); }); @@ -73,7 +73,7 @@ int main () TreeData<Basis, NodeData, false> treeData; treeData.init(basis); - AMDiS::forEachNode_(tree, [&](auto const& node, auto) { + for_each_node(tree, [&](auto const& node, auto) { treeData[node] = double(node.treeIndex()); }); @@ -109,7 +109,7 @@ int main () TreeData<Basis, NodeData, true> treeData; treeData.init(basis); - AMDiS::forEachLeafNode_(tree, [&](auto const& node, auto) { + for_each_leaf_node(tree, [&](auto const& node, auto) { treeData[node] = double(node.treeIndex()); }); -- GitLab