/** \file BlockMTLMatrix.hpp
 *  \brief Fixed-size block matrix whose blocks are MTL matrices, usable in MTL/ITL solvers.
 */

#pragma once

#include <array>

#include <boost/numeric/mtl/matrices.hpp>

#include <dune/amdis/Basic.hpp>
#include <dune/amdis/Loops.hpp>
#include <dune/amdis/linear_algebra/LinearAlgebraBase.hpp>

namespace AMDiS
{
  /// A wrapper for AMDiS::SolverMatrix to be used in MTL/ITL solvers
  template <class MTLMatrix, size_t _N, size_t _M>
  class BlockMTLMatrix
    : public std::array<std::array<MTLMatrix, _M>, _N>
  {
    using Self = BlockMTLMatrix;
    
  public:
23
    /// The index/size - type
24
    using size_type  = typename MTLMatrix::size_type;
25
26
    
    /// The type of the elements of the MTLMatrix
27
    using value_type = typename MTLMatrix::value_type;
28
29
30
    
    /// The underlying mtl matrix type
    using BaseMatrix = MTLMatrix;
31

32
33
  public:
    /// Return the (R,C)'th matrix block
34
35
36
37
38
39
40
    template <size_t R, size_t C>
    auto& operator()(const index_<R>, const index_<C>)
    {
      static_assert(R < N() && C < M(), "Indices out of range [0,N)x[0,M)");
      return std::get<C>(std::get<R>(*this));
    }
    
41
    /// Return the (R,C)'th matrix block
42
43
44
45
46
47
48
    template <size_t R, size_t C>
    auto const& operator()(const index_<R>, const index_<C>) const
    {
      static_assert(R < N() && C < M(), "Indices out of range [0,N)x[0,M)");
      return std::get<C>(std::get<R>(*this));
    }
    
49
    /// Return the number of row blocks
50
    static constexpr size_t N() { return _N; }
51
52
    
    /// Return the number of column blocks
53
54
55
56
57
58
59
60
61
62
63
    static constexpr size_t M() { return _M; }

    /// perform blockwise multiplication A*b -> x
    template <class VectorIn, class VectorOut, class Assign>
    void mult(VectorIn const& b, VectorOut& x, Assign) const
    {
      // create iranges to access array blocks
      std::array<mtl::irange, _N> r_rows;
      std::array<mtl::irange, _M> r_cols;
      getRanges(r_rows, r_cols);
      
64
      For<0, _N>::loop([&](const auto _i) {
65
        bool first = true;
66
67
        
        // a reference to the i'th block of x
68
        VectorOut x_i(x[r_rows[_i]]);
69
        For<0, _M>::loop([&](const auto _j) {
70
          auto const& A_ij = this->operator()(_i, _j);
71
          if (num_rows(A_ij) > 0) {
72
            // a reference to the j'th block of b
73
            const VectorIn b_j(b[r_cols[_j]]);
74
            if (first) {
75
              Assign::first_update(x_i, A_ij * b_j);
76
77
78
              first = false;
            }
            else {
79
              Assign::update(x_i, A_ij * b_j);
80
81
82
83
84
85
            }
          }
        });
      });
    }

86
87
    /// A Multiplication operator returns a multiplication-expresssion.
    /// Calls \ref mult internally.
88
    template <class VectorIn>
89
    mtl::vec::mat_cvec_multiplier<Self, VectorIn> 
90
91
92
93
94
    operator*(VectorIn const& v) const
    {
      return {*this, v};
    }
    
95
96
    /// Fill an array of irange corresponding to the row-sizes, used 
    /// to access sub-vectors
97
98
99
100
101
102
103
104
105
106
    void getRowRanges(std::array<mtl::irange, _N>& r_rows) const
    {      
      size_t start = 0;
      For<0, _N>::loop([&](const auto _r) {
        size_t finish = start + num_rows((*this)(_r,index_<0>()));
        r_rows[_r].set(start, finish);
        start = finish;
      });
    }
    
107
108
    /// Fill an array of irange corresponding to the column-sizes, used 
    /// to access sub-vectors
109
110
111
112
113
114
115
116
117
118
    void getColRanges(std::array<mtl::irange, _M>& r_cols) const
    {      
      size_t start = 0;
      For<0, _M>::loop([&](const auto _c) {
        size_t finish = start + num_cols((*this)(index_<0>(),_c));
        r_cols[_c].set(start, finish);
        start = finish;
      });
    }
    
119
120
    /// Fill two arrays of irange corresponding to row and column sizes.
    /// \see getRowRanges() and \see getColRanges()
121
122
    void getRanges(std::array<mtl::irange, _N>& r_rows, 
                   std::array<mtl::irange, _M>& r_cols) const
123
124
125
126
127
    {      
      getRowRanges(r_rows);
      getColRanges(r_cols);
    }
  };
128
129
130
131
132
133
134
135
136
137
138
  
  
  namespace Impl
  {
    /// Specialization of Impl::BaseMatrix (see LinearAlgebraBase.hpp):
    /// the base matrix type of a BlockMTLMatrix is the type of its blocks.
    template <class MTLMatrix, size_t N_, size_t M_>
    struct BaseMatrix<BlockMTLMatrix<MTLMatrix, N_, M_>>
    {
      using type = MTLMatrix;
    };
  }

  /// Return the number of overall rows of a BlockMTLMatrix
141
142
143
144
145
146
147
148
149
150
  template <class MTLMatrix, size_t _N, size_t _M>
  inline size_t num_rows(BlockMTLMatrix<MTLMatrix, _N, _M> const& A)
  {
    size_t nRows = 0;
    For<0, _N>::loop([&](const auto _r) {
      nRows += num_rows(A(_r,index_<0>()));
    });
    return nRows;
  }

151
  /// Return the number of overall columns of a BlockMTLMatrix
152
153
154
155
156
157
158
159
160
161
  template <class MTLMatrix, size_t _N, size_t _M>
  inline size_t num_cols(BlockMTLMatrix<MTLMatrix, _N, _M> const& A)
  {
    size_t nCols = 0;
    For<0, _M>::loop([&](const auto _c) {
      nCols += num_cols(A(index_<0>(),_c));
    });
    return nCols;
  }

162
  /// Return the size, i.e. rows*columns of a BlockMTLMatrix
163
164
165
166
167
168
  template <class MTLMatrix, size_t _N, size_t _M>
  inline size_t size(BlockMTLMatrix<MTLMatrix, _N, _M> const& A)
  {
    return num_rows(A) * num_cols(A);
  }

169
  /// Nullify a BlockMTLMatrix, i.e. nullify each block.
170
171
172
173
174
175
176
177
178
179
180
181
  template <class MTLMatrix, size_t _N, size_t _M>
  inline void set_to_zero(BlockMTLMatrix<MTLMatrix, _N, _M>& A) 
  {
    For<0, _N>::loop([&](const auto _r) {
      For<0, _M>::loop([&](const auto _c) {
        set_to_zero(A(_r,_c));
      });
    });
  }

} // end namespace AMDiS

182
183

/// \cond HIDDEN_SYMBOLS
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
namespace mtl
{
  template <class MTLMatrix, size_t _N, size_t _M>
  struct Collection<AMDiS::BlockMTLMatrix<MTLMatrix, _N, _M>>
  {
    using value_type = typename MTLMatrix::value_type;
    using size_type  = typename MTLMatrix::size_type;
  };

  namespace ashape
  {
    template <class MTLMatrix, size_t _N, size_t _M>
    struct ashape_aux<AMDiS::BlockMTLMatrix<MTLMatrix, _N, _M>>
    {
      using type = nonscal;
    };

  } // end namespace ashape

} // end namespace mtl
204
/// \endcond