Commit 071a86f2 authored by Praetorius, Simon's avatar Praetorius, Simon
Browse files

Add utility eliminateRows

parent 832af288
Pipeline #5668 failed with stage
in 35 minutes and 22 seconds
......@@ -25,6 +25,13 @@ namespace AMDiS
/* do nothing */
warning("periodicBC not implemented for this matrix type.");
}
template <class Mat, class BitVec, class Assoc>
static void eliminateRows(Mat& /*matrix*/, BitVec const& /*left*/, bool /*setDiagonal*/ = true)
{
/* do nothing */
warning("eliminateRows not implemented for this matrix type.");
}
};
template <class Mat, class Sol, class Rhs, class BitVec>
......@@ -39,6 +46,12 @@ namespace AMDiS
Constraints<Mat>::periodicBC(matrix, solution, rhs, left, association, setDiagonal);
}
template <class Mat, class BitVec>
void eliminateRows(Mat& matrix, BitVec const& nodes, bool setDiagonal = true)
{
Constraints<Mat>::eliminateRows(matrix, nodes, setDiagonal);
}
template <class RB, class CB, class T, class Traits>
struct Constraints<BiLinearForm<RB,CB,T,Traits>>
......@@ -56,6 +69,12 @@ namespace AMDiS
{
AMDiS::periodicBC(matrix.impl(), solution.impl(), rhs.impl(), left, association, setDiagonal);
}
template <class BitVec>
static void eliminateRows(Matrix& matrix, BitVec const& nodes, bool setDiagonal = true)
{
AMDiS::eliminateRows(matrix.impl(), nodes, setDiagonal);
}
};
} // end namespace AMDiS
......@@ -4,7 +4,6 @@
#include <Eigen/SparseCore>
#include <amdis/common/Index.hpp>
#include <amdis/linearalgebra/Constraints.hpp>
#include <amdis/linearalgebra/eigen/MatrixBackend.hpp>
#include <amdis/linearalgebra/eigen/VectorBackend.hpp>
......@@ -20,7 +19,7 @@ namespace AMDiS
template <class BitVector>
static void dirichletBC(Matrix& mat, Vector& sol, Vector& rhs, BitVector const& nodes, bool setDiagonal = true)
{
clearDirichletRow(mat.matrix(), nodes, setDiagonal);
eliminateRows(mat, nodes, setDiagonal);
// copy solution dirichlet data to rhs vector
for (typename Vector::size_type i = 0; i < sol.vector().size(); ++i) {
......@@ -36,9 +35,14 @@ namespace AMDiS
error_exit("Not implemented");
}
protected:
template <class BitVector>
static void clearDirichletRow(Eigen::SparseMatrix<T, Eigen::ColMajor>& mat, BitVector const& nodes, bool setDiagonal)
static void eliminateRows(Matrix& mat, BitVector const& nodes, bool setDiagonal)
{
AMDiS::eliminateRows(mat.matrix(), nodes, setDiagonal);
}
template <class BitVector>
static void eliminateRows(Eigen::SparseMatrix<T, Eigen::ColMajor>& mat, BitVector const& nodes, bool setDiagonal)
{
using Mat = Eigen::SparseMatrix<T, Eigen::ColMajor>;
for (typename Mat::Index c = 0; c < mat.outerSize(); ++c) {
......@@ -52,7 +56,7 @@ namespace AMDiS
}
template <class BitVector>
static void clearDirichletRow(Eigen::SparseMatrix<T, Eigen::RowMajor>& mat, BitVector const& nodes, bool setDiagonal)
static void eliminateRows(Eigen::SparseMatrix<T, Eigen::RowMajor>& mat, BitVector const& nodes, bool setDiagonal)
{
using Mat = Eigen::SparseMatrix<T, Eigen::RowMajor>;
for (typename Mat::Index r = 0; r < mat.outerSize(); ++r) {
......
#pragma once
#include <dune/common/ftraits.hh>
#include <amdis/Output.hpp>
#include <amdis/linearalgebra/Constraints.hpp>
#include <amdis/linearalgebra/istl/MatrixBackend.hpp>
......@@ -11,35 +13,55 @@ namespace AMDiS
struct Constraints<ISTLBCRSMatrix<T,C>>
{
using Matrix = ISTLBCRSMatrix<T,C>;
using Vector = ISTLBlockVector<T>;
using VectorX = ISTLBlockVector<T>;
using VectorY = ISTLBlockVector<T>;
template <class BitVector>
static void dirichletBC(Matrix& mat, Vector& sol, Vector& rhs, BitVector const& nodes, bool setDiagonal = true)
static void dirichletBC(Matrix& mat, VectorX& sol, VectorY& rhs, BitVector const& nodes, bool setDiagonal = true)
{
// loop over the matrix rows
for (std::size_t i = 0; i < mat.matrix().N(); ++i) {
if (nodes[i]) {
auto cIt = mat.matrix()[i].begin();
auto cEndIt = mat.matrix()[i].end();
// loop over nonzero matrix entries in current row
for (; cIt != cEndIt; ++cIt)
*cIt = (setDiagonal && i == cIt.index() ? T(1) : T(0));
}
}
eliminateRows(mat, nodes, setDiagonal);
// copy solution dirichlet data to rhs vector
for (std::size_t i = 0; i < sol.vector().size(); ++i) {
for (std::size_t i = 0; i < sol.size(); ++i) {
if (nodes[i])
rhs.vector()[i] = sol.vector()[i];
rhs[i] = sol[i];
}
}
template <class BitVector, class Associations>
static void periodicBC(Matrix& mat, Vector& sol, Vector& rhs, BitVector const& left, Associations const& left2right,
static void periodicBC(Matrix& mat, VectorX& sol, VectorY& rhs, BitVector const& left, Associations const& left2right,
bool setDiagonal = true)
{
error_exit("Not implemented");
}
template <class BitVector>
static void eliminateRows(Matrix& mat, BitVector const& nodes, bool setDiagonal = true)
{
AMDiS::eliminateRows(mat.matrix(), nodes, setDiagonal);
}
};
template <class B, class A>
struct Constraints<Dune::BCRSMatrix<B,A>>
{
using Matrix = Dune::BCRSMatrix<B,A>;
using T = typename Dune::FieldTraits<B>::field_type;
template <class BitVector>
static void eliminateRows(Matrix& mat, BitVector const& nodes, bool setDiagonal)
{
// loop over the matrix rows
for (std::size_t i = 0; i < mat.N(); ++i) {
if (nodes[i]) {
auto cIt = mat[i].begin();
auto cEndIt = mat[i].end();
// loop over nonzero matrix entries in current row
for (; cIt != cEndIt; ++cIt)
*cIt = (setDiagonal && i == cIt.index() ? T(1) : T(0));
}
}
}
};
} // end namespace AMDiS
......@@ -25,10 +25,18 @@ namespace AMDiS
static void dirichletBC(Matrix& mat, Vector& sol, Vector& rhs, BitVector const& nodes, bool setDiagonal = true)
{
SymmetryStructure const symmetry = mat.symmetry();
if (symmetry == SymmetryStructure::spd || symmetry == SymmetryStructure::symmetric || symmetry == SymmetryStructure::hermitian)
if (symmetry == SymmetryStructure::spd ||
symmetry == SymmetryStructure::symmetric ||
symmetry == SymmetryStructure::hermitian)
symmetricDirichletBC(mat.matrix(), sol.vector(), rhs.vector(), nodes, setDiagonal);
else
unsymmetricDirichletBC(mat.matrix(), sol.vector(), rhs.vector(), nodes, setDiagonal);
eliminateRows(mat.matrix(), nodes, setDiagonal);
// copy solution dirichlet data to rhs vector
for (typename Vec::size_type i = 0; i < mtl::size(sol); ++i) {
if (nodes[i])
rhs[i] = sol[i];
}
}
template <class Mat, class Vec, class BitVector>
......@@ -65,55 +73,15 @@ namespace AMDiS
ins[i][i] = 1;
}
}
// copy solution dirichlet data to rhs vector
for (typename Vec::size_type i = 0; i < mtl::size(sol); ++i) {
if (nodes[i])
rhs[i] = sol[i];
}
}
template <class Mat, class Vec, class BitVector>
static void unsymmetricDirichletBC(Mat& mat, Vec& sol, Vec& rhs, BitVector const& nodes, bool setDiagonal = true)
{
// Define the property maps
auto row = mtl::mat::row_map(mat);
auto col = mtl::mat::col_map(mat);
auto value = mtl::mat::value_map(mat);
// iterate over the matrix
for (auto r : mtl::rows_of(mat)) { // rows of the matrix
if (nodes[r.value()]) {
for (auto i : mtl::nz_of(r)) { // non-zeros within
// set identity row
value(i, (setDiagonal && row(i) == col(i) ? 1 : 0) );
}
}
}
// copy solution dirichlet data to rhs vector
for (typename Vec::size_type i = 0; i < mtl::size(sol); ++i) {
if (nodes[i])
rhs[i] = sol[i];
}
}
template <class Associations>
static std::size_t at(Associations const& m, std::size_t idx)
{
auto it = m.find(idx);
assert(it != m.end());
return it->second;
}
template <class BitVector, class Associations>
static void periodicBC(Matrix& mat, Vector& sol, Vector& rhs, BitVector const& left, Associations const& left2right,
bool setDiagonal = true)
static void periodicBC(Matrix& mat, Vector& sol, Vector& rhs, BitVector const& left, Associations const& left2right, bool setDiagonal = true)
{
SymmetryStructure const symmetry = mat.symmetry();
if (symmetry == SymmetryStructure::spd || symmetry == SymmetryStructure::symmetric || symmetry == SymmetryStructure::hermitian)
if (symmetry == SymmetryStructure::spd ||
symmetry == SymmetryStructure::symmetric ||
symmetry == SymmetryStructure::hermitian)
symmetricPeriodicBC(mat.matrix(), sol.vector(), rhs.vector(), left, left2right, setDiagonal);
else
unsymmetricPeriodicBC(mat.matrix(), sol.vector(), rhs.vector(), left, left2right, setDiagonal);
......@@ -121,8 +89,7 @@ namespace AMDiS
template <class Mat, class Vec, class BitVector, class Associations>
static void symmetricPeriodicBC(Mat& mat, Vec& sol, Vec& rhs, BitVector const& left, Associations const& left2right,
bool setDiagonal = true)
static void symmetricPeriodicBC(Mat& mat, Vec& sol, Vec& rhs, BitVector const& left, Associations const& left2right, bool setDiagonal = true)
{
error_exit("Not implemented");
}
......@@ -136,8 +103,7 @@ namespace AMDiS
};
template <class Mat, class Vec, class BitVector, class Associations>
static void unsymmetricPeriodicBC(Mat& mat, Vec& sol, Vec& rhs, BitVector const& left, Associations const& left2right,
bool setDiagonal = true)
static void unsymmetricPeriodicBC(Mat& mat, Vec& sol, Vec& rhs, BitVector const& left, Associations const& left2right, bool setDiagonal = true)
{
std::vector<Triplet<typename Mat::value_type>> rowValues;
rowValues.reserve(left2right.size()*std::size_t(mat.nnz()/(0.9*num_rows(mat))));
......@@ -152,7 +118,7 @@ namespace AMDiS
for (auto r : mtl::rows_of(mat)) {
if (left[r.value()]) {
slotSize = std::max(slotSize, std::size_t(mat.nnz_local(r.value())));
std::size_t right = at(left2right,r.value());
std::size_t right = left2right.at(r.value());
for (auto i : mtl::nz_of(r)) {
rowValues.push_back({right,col(i),value(i)});
......@@ -167,7 +133,7 @@ namespace AMDiS
for (std::size_t i = 0; i < mtl::size(left); ++i) {
if (left[i]) {
std::size_t j = at(left2right,i);
std::size_t j = left2right.at(i);
if (setDiagonal) {
ins[i][i] = 1;
ins[i][j] = -1;
......@@ -181,6 +147,31 @@ namespace AMDiS
}
}
template <class BitVector>
static void eliminateRows(Matrix& mat, BitVector const& nodes, bool setDiagonal = true)
{
AMDiS::eliminateRows(mat.matrix(), nodes, setDiagonal);
}
template <class Mat, class BitVector>
static void eliminateRows(Mat& mat, BitVector const& nodes, bool setDiagonal = true)
{
// Define the property maps
auto row = mtl::mat::row_map(mat);
auto col = mtl::mat::col_map(mat);
auto value = mtl::mat::value_map(mat);
// iterate over the matrix
for (auto r : mtl::rows_of(mat)) { // rows of the matrix
if (nodes[r.value()]) {
for (auto i : mtl::nz_of(r)) { // non-zeros within
// set identity row
value(i, (setDiagonal && row(i) == col(i) ? 1 : 0) );
}
}
}
}
};
} // end namespace AMDiS
......@@ -57,7 +57,7 @@ namespace AMDiS
for (PetscInt i = 0; i < PetscInt(left.size()); ++i) {
if (left[i]) {
// get global row index
PetscInt row_local[2] = {i, at(left2right,i)};
PetscInt row_local[2] = {i, left2right.at(i)};
PetscInt row[2] = {dofMap.global(row_local[0]), dofMap.global(row_local[1])};
rows.push_back(row[0]);
......@@ -108,13 +108,17 @@ namespace AMDiS
VecAssemblyEnd(x.vector());
}
private:
template <class Associations>
static PetscInt at(Associations const& m, std::size_t idx)
template <class BitVector>
static void eliminateRows(Matrix& mat, BitVector const& nodes, bool setDiagonal = true)
{
auto it = m.find(idx);
assert(it != m.end());
return it->second;
thread_local std::vector<PetscInt> rows;
rows.clear();
auto const& dofMap = mat.dofMap_;
for (std::size_t i = 0; i < nodes.size(); ++i)
if (nodes[i])
rows.push_back(dofMap.global(i));
MatZeroRows(mat.matrix(), rows.size(), rows.data(), setDiagonal ? 1.0 : 0.0, PETSC_NULL, PETSC_NULL);
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment