Commit 305aa9e3 authored by Praetorius, Simon's avatar Praetorius, Simon
Browse files

flat indexing for blocked basis

parent 73513f33
install(FILES
FlatIndex.hpp
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/amdis/functions)
#pragma once
#include <vector>
#include <dune/common/fvector.hh>
#include <dune/istl/io.hh>
#include <dune/functions/common/indexaccess.hh>
#include <amdis/common/Logical.hpp>
#include <amdis/Output.hpp>
namespace AMDiS
{
template <class MI, class Vec, class = void>
class FlatIndex
{
using size_type = typename MI::size_type;
using Vector = Vec;
public:
template <class SizeProvider>
FlatIndex(SizeProvider const& sizeProvider)
{
typename SizeProvider::SizePrefix prefix;
prefix.resize(0);
resize(shifts_, sizeProvider, prefix, 0);
// Dune::printvector(std::cout, shifts_, "shifts", "");
}
size_type operator()(MI const& index) const
{
return Dune::Functions::resolveDynamicMultiIndex(shifts_, index) + index.back();
}
private:
// Template aliases for using detection idiom.
template<class C>
using dynamicIndexAccess_t = decltype(std::declval<C>()[0]);
template<class C>
using staticIndexAccess_t = decltype(std::declval<C>()[Dune::Indices::_0]);
template<class C>
using resizeMethod_t = decltype(std::declval<C>().resize(0));
// Short cuts for feature detection
template<class C>
using hasDynamicIndexAccess = Dune::Std::is_detected<dynamicIndexAccess_t, std::remove_reference_t<C>>;
template<class C>
using hasStaticIndexAccess = Dune::Std::is_detected<staticIndexAccess_t, std::remove_reference_t<C>>;
template<class C>
using hasResizeMethod = Dune::Std::is_detected<resizeMethod_t, std::remove_reference_t<C>>;
template<class C>
using isDynamicVector = Dune::Std::is_detected<dynamicIndexAccess_t, std::remove_reference_t<C>>;
template<class C>
using isStaticVector = Dune::Std::bool_constant<
Dune::Std::is_detected_v<staticIndexAccess_t, std::remove_reference_t<C>>
and not Dune::Std::is_detected_v<dynamicIndexAccess_t, std::remove_reference_t<C>>>;
template<class C>
using isScalar = Dune::Std::bool_constant<not Dune::Std::is_detected_v<staticIndexAccess_t, std::remove_reference_t<C>>>;
template<class C>
using isVector = Dune::Std::bool_constant<Dune::Std::is_detected_v<staticIndexAccess_t, std::remove_reference_t<C>>>;
template <class C, class SizeProvider,
std::enable_if_t<hasResizeMethod<C>::value, int> = 0>
size_type resize(C&& c, const SizeProvider& sizeProvider, typename SizeProvider::SizePrefix prefix, size_type shift)
{
size_type size = sizeProvider.size(prefix);
if (size == 0) {
// c = shift;
return size;
}
c.resize(size);
prefix.push_back(0);
size_type sub_size = 0;
for (size_type i = 0; i < size; ++i) {
prefix.back() = i;
sub_size += resize(c[i], sizeProvider, prefix, shift + sub_size);
}
return sub_size;
}
template <class C, class SizeProvider,
std::enable_if_t<not hasResizeMethod<C>::value, int> = 0,
std::enable_if_t<isVector<C>::value, int> = 0>
size_type resize(C&& c, const SizeProvider& sizeProvider, typename SizeProvider::SizePrefix prefix, size_type shift)
{
auto size = sizeProvider.size(prefix);
if (size == 0 || c.size() != size) {
c = shift;
return size;
}
// Recursively resize all entries of c now.
using namespace Dune::Hybrid;
prefix.push_back(0);
size_type sub_size = 0;
Dune::Hybrid::forEach(integralRange(Dune::Hybrid::size(c)), [&](auto&& i) {
prefix.back() = i;
sub_size += this->resize(c[i], sizeProvider, prefix, shift + sub_size);
});
return sub_size;
}
template<class C, class SizeProvider,
std::enable_if_t<not hasResizeMethod<C>::value, int> = 0,
std::enable_if_t<isScalar<C>::value, int> = 0>
size_type resize(C&& c, const SizeProvider& sizeProvider, typename SizeProvider::SizePrefix prefix, size_type shift)
{
auto size = sizeProvider.size(prefix);
c = shift;
return size;
}
private:
Vector shifts_;
};
template <class MI, class Vec>
class FlatIndex<MI, Vec, std::enable_if_t<(MI::max_size() == 1)>>
{
using size_type = typename MI::size_type;
public:
template <class SizeProvider>
FlatIndex(SizeProvider const&) {}
size_type operator()(MI const& multiIndex) const
{
return multiIndex[0];
}
};
} // end namespace AMDiS
...@@ -30,6 +30,9 @@ dune_add_test(SOURCES FiniteElementTypeTest.cpp ...@@ -30,6 +30,9 @@ dune_add_test(SOURCES FiniteElementTypeTest.cpp
dune_add_test(SOURCES FilesystemTest.cpp dune_add_test(SOURCES FilesystemTest.cpp
LINK_LIBRARIES amdis) LINK_LIBRARIES amdis)
dune_add_test(SOURCES FlatIndexTest.cpp
LINK_LIBRARIES amdis)
dune_add_test(SOURCES IntegrateTest.cpp dune_add_test(SOURCES IntegrateTest.cpp
LINK_LIBRARIES amdis LINK_LIBRARIES amdis
CMD_ARGS "${CMAKE_SOURCE_DIR}/examples/init/ellipt.dat.2d") CMD_ARGS "${CMAKE_SOURCE_DIR}/examples/init/ellipt.dat.2d")
......
#include <dune/common/timer.hh>
#include <dune/grid/yaspgrid.hh>
#include <dune/istl/bvector.hh>
#include <dune/functions/functionspacebases/compositebasis.hh>
#include <dune/functions/functionspacebases/powerbasis.hh>
#include <dune/functions/functionspacebases/lagrangebasis.hh>
#include <amdis/functions/FlatIndex.hpp>
#include <amdis/typetree/Traversal.hpp>
#include "Tests.hpp"
using namespace AMDiS;
template <class MI>
std::string to_string(MI const& multiIndex)
{
std::stringstream out;
for (std::size_t i = 0; i < multiIndex.size(); ++i)
out << multiIndex[i] << " ";
return out.str();
}
template <class Vec, class Basis>
void test(Basis const& basis)
{
Dune::Timer t;
using MultiIndex = typename Basis::MultiIndex;
FlatIndex<MultiIndex,Vec> flatIndex(basis);
auto localView = basis.localView();
std::vector<int> vector(basis.dimension(),0);
for (auto const& element : elements(basis.gridView())) {
localView.bind(element);
for_each_leaf_node(localView.tree(), [&](auto const& node, auto tp) {
// std::cout << tp << '\n';
std::size_t size = node.finiteElement().size();
for (std::size_t i = 0; i < size; ++i) {
auto multiIndex = localView.index(node.localIndex(i));
auto flat = flatIndex(multiIndex);
// std::cout << " i=" << i << ", mi=" << to_string(multiIndex) << ", flat=" << flat << '\n';
vector[flat] = 1;
}
});
localView.unbind();
}
AMDIS_TEST(std::all_of(std::begin(vector), std::end(vector), [](int v) { return v == 1; }));
}
template <class GridView>
void run_test(GridView const& gridView)
{
using namespace Dune::Functions::BasisBuilder;
static const int k = 1;
auto basis1 = makeBasis(gridView,
composite(
power<GridView::dimensionworld>(lagrange<k+1>(), blockedInterleaved()),
lagrange<k>(), blockedLexicographic()
));
auto basis2 = makeBasis(gridView,
composite(
power<GridView::dimensionworld>(lagrange<k+1>(), flatInterleaved()),
lagrange<k>(), flatLexicographic()
));
auto basis3 = makeBasis(gridView,
composite(
power<GridView::dimensionworld>(lagrange<k+1>(), flatInterleaved()),
lagrange<k>(), blockedLexicographic()
));
Dune::Timer t;
t.reset();
test<Dune::BlockVector<Dune::BlockVector<Dune::FieldVector<std::size_t,1>>>>(basis1);
msg("time (basis1) = {}", t.elapsed());
t.reset();
test<std::size_t>(basis2);
msg("time (basis2) = {}", t.elapsed());
t.reset();
test<std::vector<std::size_t>>(basis3);
msg("time (basis3) = {}", t.elapsed());
}
int main()
{
// create grid
Dune::YaspGrid<2> grid({1.0,1.0}, std::array<int,2>{20,20});
auto gridView = grid.leafGridView();
run_test(gridView);
// create grid
Dune::YaspGrid<3> grid2({1.0,1.0,1.0}, std::array<int,3>{10,10,10});
auto gridView2 = grid2.leafGridView();
run_test(gridView2);
return report_errors();
}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment