diff --git a/src/amdis/common/FieldMatVec.hpp b/src/amdis/common/FieldMatVec.hpp index 86492c9dad4e063c49c2736a69f34f09e3c7ab96..5ad50b6492f5601a17e1505f7614a9598ae068a7 100644 --- a/src/amdis/common/FieldMatVec.hpp +++ b/src/amdis/common/FieldMatVec.hpp @@ -6,6 +6,21 @@ #include <dune/common/fmatrix.hh> #include <dune/common/fvector.hh> +namespace std +{ + template <class T, int N> + struct common_type<Dune::FieldVector<T,N>, T> + { + using type = T; + }; + + template <class T, int N, int M> + struct common_type<Dune::FieldMatrix<T,N,M>, T> + { + using type = T; + }; +} + namespace Dune { // some arithmetic operations with FieldVector @@ -14,25 +29,19 @@ namespace Dune FieldVector<T,N> operator-(FieldVector<T,N> v); template <class T, int N, class S, - std::enable_if_t<std::is_arithmetic<S>::value,int> = 0 > + std::enable_if_t<std::is_convertible<S,T>::value, int> = 0> FieldVector<T,N> operator*(FieldVector<T,N> v, S factor); - template <class S, class T, int N, - std::enable_if_t<std::is_arithmetic<S>::value,int> = 0 > + template <class T, int N, class S, + std::enable_if_t<std::is_convertible<S,T>::value, int> = 0> FieldVector<T,N> operator*(S factor, FieldVector<T,N> v); template <class T, int N, class S, - std::enable_if_t<std::is_arithmetic<S>::value,int> = 0 > + std::enable_if_t<std::is_convertible<S,T>::value, int> = 0> FieldVector<T,N> operator/(FieldVector<T,N> v, S factor); template <class T> - FieldVector<T,1> operator*(FieldVector<T,1> const& v, FieldVector<T,1> const& w); - - template <class T, int N> - FieldVector<T,N> operator*(FieldVector<T,1> const& factor, FieldVector<T,N> v); - - template <class T, int N> - FieldVector<T,N> operator*(FieldVector<T,N> v, FieldVector<T,1> const& factor); + FieldVector<T,1> operator*(FieldVector<T,1> v, FieldVector<T,1> w); // ---------------------------------------------------------------------------- @@ -48,9 +57,9 @@ namespace Dune template <class T, class S, int N> auto dot(FieldVector<T,N> const& vec1, FieldVector<S,N> const& vec2); - template <class T, int N, int M, - std::enable_if_t<( N!=1 && M!=1 ),int> = 0> - auto operator*(FieldVector<T,N> const& v, FieldVector<T,M> const& w); + // template <class T, int N, int M, + // std::enable_if_t<( N!=1 && M!=1 ),int> = 0> + // auto operator*(FieldVector<T,N> const& v, FieldVector<T,M> const& w); template <class T, class S, int N> auto dot(FieldMatrix<T,1,N> const& vec1, FieldMatrix<S,1,N> const& vec2); @@ -197,34 +206,19 @@ namespace Dune template <class T, int M, int N, class S, - std::enable_if_t<std::is_arithmetic<S>::value,int> = 0 > + std::enable_if_t<std::is_convertible<S,T>::value, int> = 0> FieldMatrix<T,M,N> operator*(S scalar, FieldMatrix<T, M, N> A); template <class T, int M, int N, class S, - std::enable_if_t<std::is_arithmetic<S>::value,int> = 0 > + std::enable_if_t<std::is_convertible<S,T>::value, int> = 0> FieldMatrix<T,M,N> operator*(FieldMatrix<T, M, N> A, S scalar); - template <class T, int M, int N > - FieldMatrix<T,M,N> operator*(FieldMatrix<T,1,1> scalar, FieldMatrix<T, M, N> A); - - template <class T, int M, int N > - FieldMatrix<T,M,N> operator*(FieldMatrix<T, M, N> A, FieldMatrix<T,1,1> scalar); - - template <class T, int N, int M> - FieldMatrix<T,N,M> operator*(FieldVector<T,1> scalar, FieldMatrix<T,N,M> mat); - - template <class T, int N> - FieldMatrix<T,N,1> operator*(FieldVector<T,1> scalar, FieldMatrix<T,N,1> mat); - - template <class T, int N, int M> - FieldMatrix<T,N,M> operator*(FieldMatrix<T,N,M> mat, FieldVector<T,1> scalar); - - template <class T, int N> - FieldMatrix<T,N,1> operator*(FieldMatrix<T,N,1> mat, FieldVector<T,1> scalar); + template <class T> + FieldMatrix<T,1,1> operator*(FieldMatrix<T,1,1> lhs, FieldMatrix<T,1,1> rhs); template <class T, int M, int N, class S, - std::enable_if_t<std::is_arithmetic<S>::value,int> = 0 > + std::enable_if_t<std::is_convertible<S,T>::value, int> = 0> FieldMatrix<T,M,N> operator/(FieldMatrix<T, M, N> A, S scalar); @@ -258,12 +252,25 @@ namespace Dune FieldMatrix<T,M,N>& multiplies_ABt(FieldMatrix<T, M, N> const& A, DiagonalMatrix<T, N> const& B, FieldMatrix<T,M,N>& C); +// ----------------------------------------------------------------------------- + + template <class T> + T operator*(FieldVector<T,1> lhs, FieldMatrix<T,1,1> rhs); + + template <class T> + T operator*(FieldMatrix<T,1,1> lhs, FieldVector<T,1> rhs); + +// ----------------------------------------------------------------------------- + template <class T, int N> T const& at(FieldMatrix<T,N,1> const& vec, std::size_t i); template <class T, int M> T const& at(FieldMatrix<T,1,M> const& vec, std::size_t i); + template <class T> + T const& at(FieldMatrix<T,1,1> const& vec, std::size_t i); + template <class T, int N> T const& at(FieldVector<T,N> const& vec, std::size_t i); diff --git a/src/amdis/common/FieldMatVec.inc.hpp b/src/amdis/common/FieldMatVec.inc.hpp index e974e4eca0f129606e3181264b19b5d7797520ec..bbccc158375fe0e220751cfb4d6fc4d3b5dd7343 100644 --- a/src/amdis/common/FieldMatVec.inc.hpp +++ b/src/amdis/common/FieldMatVec.inc.hpp @@ -20,42 +20,30 @@ FieldVector<T,N> operator-(FieldVector<T,N> v) } template <class T, int N, class S, - std::enable_if_t<std::is_arithmetic<S>::value,int> > + std::enable_if_t<std::is_convertible<S,T>::value, int>> FieldVector<T,N> operator*(FieldVector<T,N> v, S factor) { return v *= factor; } -template <class S, class T, int N, - std::enable_if_t<std::is_arithmetic<S>::value,int> > +template <class T, int N, class S, + std::enable_if_t<std::is_convertible<S,T>::value, int>> FieldVector<T,N> operator*(S factor, FieldVector<T,N> v) { return v *= factor; } template <class T, int N, class S, - std::enable_if_t<std::is_arithmetic<S>::value,int> > + std::enable_if_t<std::is_convertible<S,T>::value, int>> FieldVector<T,N> operator/(FieldVector<T,N> v, S factor) { return v /= factor; } template <class T> -FieldVector<T,1> operator*(FieldVector<T,1> const& v, FieldVector<T,1> const& w) -{ - return {v[0] * w[0]}; -} - -template <class T, int N> -FieldVector<T,N> operator*(FieldVector<T,1> const& factor, FieldVector<T,N> v) -{ - return v *= factor[0]; -} - -template <class T, int N> -FieldVector<T,N> operator*(FieldVector<T,N> v, FieldVector<T,1> const& factor) +FieldVector<T,1> operator*(FieldVector<T,1> v, FieldVector<T,1> w) { - return v *= factor[0]; + return v *= w[0]; } // ---------------------------------------------------------------------------- @@ -83,13 +71,13 @@ auto dot(FieldVector<T,N> const& vec1, FieldVector<S,N> const& vec2) return vec1.dot(vec2); } -template <class T, int N, int M, - std::enable_if_t<( N!=1 && M!=1 ),int> > -auto operator*(FieldVector<T,N> const& v, FieldVector<T,M> const& w) -{ - static_assert(M == N, "Requires vectors of the same type!"); - return v.dot(w); -} +// template <class T, int N, int M, +// std::enable_if_t<( N!=1 && M!=1 ),int> > +// auto operator*(FieldVector<T,N> const& v, FieldVector<T,M> const& w) +// { +// static_assert(M == N, "Requires vectors of the same type!"); +// return v.dot(w); +// } template <class T, class S, int N> auto dot(FieldMatrix<T,1,N> const& vec1, FieldMatrix<S,1,N> const& vec2) @@ -377,57 +365,27 @@ FieldMatrix<T,N,M> trans(FieldMatrix<T, M, N> const& A) template <class T, int M, int N, class S, - std::enable_if_t<std::is_arithmetic<S>::value,int> > + std::enable_if_t<std::is_convertible<S,T>::value, int>> FieldMatrix<T,M,N> operator*(S scalar, FieldMatrix<T, M, N> A) { return A *= scalar; } template <class T, int M, int N, class S, - std::enable_if_t<std::is_arithmetic<S>::value,int> > + std::enable_if_t<std::is_convertible<S,T>::value, int>> FieldMatrix<T,M,N> operator*(FieldMatrix<T, M, N> A, S scalar) { return A *= scalar; } -template <class T, int N, int M> -FieldMatrix<T,N,M> operator*(FieldMatrix<T,1,1> scalar, FieldMatrix<T,N,M> mat) -{ - return mat *= scalar[0][0]; -} - -template <class T, int N, int M> -FieldMatrix<T,N,M> operator*(FieldMatrix<T,N,M> mat, FieldMatrix<T,1,1> scalar) -{ - return mat *= scalar[0][0]; -} - -template <class T, int N, int M> -FieldMatrix<T,N,M> operator*(FieldVector<T,1> scalar, FieldMatrix<T,N,M> mat) -{ - return mat *= scalar[0]; -} - -template <class T, int N> -FieldMatrix<T,N,1> operator*(FieldVector<T,1> scalar, FieldMatrix<T,N,1> mat) -{ - return mat *= scalar[0]; -} - -template <class T, int N, int M> -FieldMatrix<T,N,M> operator*(FieldMatrix<T,N,M> mat, FieldVector<T,1> scalar) -{ - return mat *= scalar[0]; -} - -template <class T, int N> -FieldMatrix<T,N,1> operator*(FieldMatrix<T,N,1> mat, FieldVector<T,1> scalar) +template <class T> +FieldMatrix<T,1,1> operator*(FieldMatrix<T,1,1> lhs, FieldMatrix<T,1,1> rhs) { - return mat *= scalar[0]; + return lhs *= rhs[0][0]; } template <class T, int M, int N, class S, - std::enable_if_t<std::is_arithmetic<S>::value,int> > + std::enable_if_t<std::is_convertible<S,T>::value, int>> FieldMatrix<T,M,N> operator/(FieldMatrix<T, M, N> A, S scalar) { return A /= scalar; @@ -517,6 +475,19 @@ FieldMatrix<T,M,N>& multiplies_ABt(FieldMatrix<T, M, N> const& A, DiagonalMatri } +template <class T> +T operator*(FieldVector<T,1> lhs, FieldMatrix<T,1,1> rhs) +{ + return lhs[0]*rhs[0][0]; +} + +template <class T> +T operator*(FieldMatrix<T,1,1> lhs, FieldVector<T,1> rhs) +{ + return lhs[0][0]*rhs[0]; +} + + template <class T, int N> T const& at(FieldMatrix<T,N,1> const& vec, std::size_t i) { @@ -529,6 +500,12 @@ T const& at(FieldMatrix<T,1,M> const& vec, std::size_t i) return vec[0][i]; } +template <class T> +T const& at(FieldMatrix<T,1,1> const& vec, std::size_t i) +{ + return vec[0][i]; +} + template <class T, int N> T const& at(FieldVector<T,N> const& vec, std::size_t i) { diff --git a/test/FieldMatVecTest.cpp b/test/FieldMatVecTest.cpp index f5abf8c3588c60b5bb6fa46636fb5112d2a55b06..9f3768389b75c2da00afeec4c0e4b4d914d6cf57 100644 --- a/test/FieldMatVecTest.cpp +++ b/test/FieldMatVecTest.cpp @@ -128,12 +128,88 @@ void test2() AMDIS_TEST_EQ( sol, a ); } +// test of scalar wrapper FieldVector<T,1> and FieldMatrix<T,1,1> +void test3() +{ + using V = FieldVector<double, 3>; + using M = FieldMatrix<double, 3, 3>; + + V a{1.0, 2.0, 3.0}; + M A{ {1.0, 0.0, 0.0}, {0.0, 1.0, 0.0}, {0.0, 0.0, 1.0} }; + + using VD = FieldVector<double, 1>; + using MD = FieldMatrix<double, 1, 1>; + + VD vd1 = 1.0; + MD md1 = 1.0; + + using VI = FieldVector<int, 1>; + using MI = FieldMatrix<int, 1, 1>; + + VI vi1 = 1; + MI mi1 = 1; + + // scale a vector + AMDIS_TEST_EQ( 1*a, a ); + AMDIS_TEST_EQ( 1.0*a, a ); + AMDIS_TEST_EQ( a*1, a ); + AMDIS_TEST_EQ( a*1.0, a ); + AMDIS_TEST_EQ( vd1*a, a ); + AMDIS_TEST_EQ( a*vd1, a ); + AMDIS_TEST_EQ( vi1*a, a ); + AMDIS_TEST_EQ( a*vi1, a ); + AMDIS_TEST_EQ( md1*a, a ); + AMDIS_TEST_EQ( a*md1, a ); + AMDIS_TEST_EQ( mi1*a, a ); + AMDIS_TEST_EQ( a*mi1, a ); + AMDIS_TEST_EQ( a/1, a ); + AMDIS_TEST_EQ( a/1.0, a ); + AMDIS_TEST_EQ( a/vd1, a ); + AMDIS_TEST_EQ( a/vi1, a ); + AMDIS_TEST_EQ( a/md1, a ); + AMDIS_TEST_EQ( a/mi1, a ); + + // scale a matrix + AMDIS_TEST_EQ( 1*A, A ); + AMDIS_TEST_EQ( 1.0*A, A ); + AMDIS_TEST_EQ( A*1, A ); + AMDIS_TEST_EQ( A*1.0, A ); + AMDIS_TEST_EQ( vd1*A, A ); + AMDIS_TEST_EQ( A*vd1, A ); + AMDIS_TEST_EQ( vi1*A, A ); + AMDIS_TEST_EQ( A*vi1, A ); + AMDIS_TEST_EQ( md1*A, A ); + AMDIS_TEST_EQ( A*md1, A ); + AMDIS_TEST_EQ( mi1*A, A ); + AMDIS_TEST_EQ( A*mi1, A ); + AMDIS_TEST_EQ( A/1, A ); + AMDIS_TEST_EQ( A/1.0, A ); + AMDIS_TEST_EQ( A/vd1, A ); + AMDIS_TEST_EQ( A/vi1, A ); + AMDIS_TEST_EQ( A/md1, A ); + AMDIS_TEST_EQ( A/mi1, A ); + + AMDIS_TEST_EQ( vd1*vd1, 1.0 ); + AMDIS_TEST_EQ( vd1*md1, 1.0 ); + AMDIS_TEST_EQ( md1*md1, 1.0 ); + AMDIS_TEST_EQ( md1*vd1, 1.0 ); + AMDIS_TEST_EQ( vd1*1.0, 1.0 ); + AMDIS_TEST_EQ( md1*1.0, 1.0 ); + AMDIS_TEST_EQ( 1.0*md1, 1.0 ); + AMDIS_TEST_EQ( 1.0*vd1, 1.0 ); + AMDIS_TEST_EQ( vi1*vi1, 1 ); + AMDIS_TEST_EQ( vi1*mi1, 1 ); + AMDIS_TEST_EQ( mi1*mi1, 1 ); + AMDIS_TEST_EQ( mi1*vi1, 1 ); +} + int main() { test0(); test1(); test2(); + test3(); return report_errors(); }