From d7cb77f94cba60367e94776977303214bddbc8b4 Mon Sep 17 00:00:00 2001
From: Oliver Sander <oliver.sander@tu-dresden.de>
Date: Thu, 3 Dec 2015 14:16:16 +0100
Subject: [PATCH] Use a MultiTypeBlockMatrix instead of four separate matrices

---
 dune/gfe/mixedriemanniantrsolver.cc | 49 +++++++++++++----------------
 dune/gfe/mixedriemanniantrsolver.hh |  8 ++---
 2 files changed, 25 insertions(+), 32 deletions(-)

diff --git a/dune/gfe/mixedriemanniantrsolver.cc b/dune/gfe/mixedriemanniantrsolver.cc
index 345919c4..9e770d1b 100644
--- a/dune/gfe/mixedriemanniantrsolver.cc
+++ b/dune/gfe/mixedriemanniantrsolver.cc
@@ -184,10 +184,7 @@ setup(const GridType& grid,
 
     // \todo Why are the hessianMatrix objects class members at all, and not local to 'solve'?
 
-    hessianMatrix00_ = std::unique_ptr<MatrixType00>(new MatrixType00);
-    hessianMatrix01_ = std::unique_ptr<MatrixType01>(new MatrixType01);
-    hessianMatrix10_ = std::unique_ptr<MatrixType10>(new MatrixType10);
-    hessianMatrix11_ = std::unique_ptr<MatrixType11>(new MatrixType11);
+    hessianMatrix_ = std::make_unique<MatrixType>();
 
     // ////////////////////////////////////
     //   Create the transfer operators
@@ -321,16 +318,15 @@ void MixedRiemannianTrustRegionSolver<GridType,Basis0,TargetSpace0,Basis1,Target
     //   Trust-Region Solver
     // /////////////////////////////////////////////////////
 
+    using namespace Dune::TypeTree::Indices;
+
     double oldEnergy = assembler_->computeEnergy(x0_, x1_);
     oldEnergy = mpiHelper.getCollectiveCommunication().sum(oldEnergy);
 
     bool recomputeGradientHessian = true;
     CorrectionType0 rhs0;
     CorrectionType1 rhs1;
-    MatrixType00 stiffnessMatrix00;
-    MatrixType01 stiffnessMatrix01;
-    MatrixType10 stiffnessMatrix10;
-    MatrixType11 stiffnessMatrix11;
+    MatrixType stiffnessMatrix;
     CorrectionType0 rhs_global0;
     CorrectionType1 rhs_global1;
 #if 0
@@ -362,10 +358,10 @@ void MixedRiemannianTrustRegionSolver<GridType,Basis0,TargetSpace0,Basis1,Target
                                                    x1_,
                                                    rhs0,
                                                    rhs1,
-                                                   *hessianMatrix00_,
-                                                   *hessianMatrix01_,
-                                                   *hessianMatrix10_,
-                                                   *hessianMatrix11_,
+                                                   (*hessianMatrix_)[_0][_0],
+                                                   (*hessianMatrix_)[_0][_1],
+                                                   (*hessianMatrix_)[_1][_0],
+                                                   (*hessianMatrix_)[_1][_1],
                                                    i==0    // assemble occupation pattern only for the first call
                                                    );
 
@@ -382,10 +378,7 @@ void MixedRiemannianTrustRegionSolver<GridType,Basis0,TargetSpace0,Basis1,Target
             stiffnessMatrix10 = matrixComm.reduceAdd(*hessianMatrix10_);
             stiffnessMatrix11 = matrixComm.reduceAdd(*hessianMatrix11_);
 #endif
-            stiffnessMatrix00 = *hessianMatrix00_;
-            stiffnessMatrix01 = *hessianMatrix01_;
-            stiffnessMatrix10 = *hessianMatrix10_;
-            stiffnessMatrix11 = *hessianMatrix11_;
+            stiffnessMatrix = *hessianMatrix_;
 
             // Transfer vector data
 #if 0
@@ -412,13 +405,13 @@ void MixedRiemannianTrustRegionSolver<GridType,Basis0,TargetSpace0,Basis1,Target
             CorrectionType0 residual0 = rhs_global0;
             CorrectionType1 residual1 = rhs_global1;
 
-            mmgStep0->setProblem(stiffnessMatrix00, corr_global0, residual0);
+            mmgStep0->setProblem(stiffnessMatrix[_0][_0], corr_global0, residual0);
             trustRegionObstacles0 = trustRegion0.obstacles();
             mmgStep0->obstacles_ = &trustRegionObstacles0;
 
             mmgStep0->preprocess();
 
-            mmgStep1->setProblem(stiffnessMatrix11, corr_global1, residual1);
+            mmgStep1->setProblem(stiffnessMatrix[_1][_1], corr_global1, residual1);
             trustRegionObstacles1 = trustRegion1.obstacles();
             mmgStep1->obstacles_ = &trustRegionObstacles1;
 
@@ -431,23 +424,23 @@ void MixedRiemannianTrustRegionSolver<GridType,Basis0,TargetSpace0,Basis1,Target
             for (int ii=0; ii<innerIterations_; ii++)
             {
               residual0 = rhs_global0;
-              stiffnessMatrix01.mmv(corr_global1, residual0);
+              stiffnessMatrix[_0][_1].mmv(corr_global1, residual0);
               mmgStep0->setRhs(residual0);
               mmgStep0->iterate();
 
               residual1 = rhs_global1;
-              stiffnessMatrix10.mmv(corr_global0, residual1);
+              stiffnessMatrix[_1][_0].mmv(corr_global0, residual1);
               mmgStep1->setRhs(residual1);
               mmgStep1->iterate();
 
               // Compute energy
               CorrectionType0 tmp0(corr_global0);
-              stiffnessMatrix00.mv(corr_global0,tmp0);
-              stiffnessMatrix01.umv(corr_global1,tmp0);
+              stiffnessMatrix[_0][_0].mv(corr_global0,tmp0);
+              stiffnessMatrix[_0][_1].umv(corr_global1,tmp0);
 
               CorrectionType1 tmp1(corr_global1);
-              stiffnessMatrix10.mv(corr_global0,tmp1);
-              stiffnessMatrix11.umv(corr_global1,tmp1);
+              stiffnessMatrix[_1][_0].mv(corr_global0,tmp1);
+              stiffnessMatrix[_1][_1].umv(corr_global1,tmp1);
 
               double energy = 0.5 * (tmp0*corr_global0 + tmp1*corr_global1) - (rhs_global0*corr_global0 + rhs_global1*corr_global1);
 
@@ -504,13 +497,13 @@ void MixedRiemannianTrustRegionSolver<GridType,Basis0,TargetSpace0,Basis1,Target
         // Note that rhs = -g
         CorrectionType0 tmp0(corr0.size());
         tmp0 = 0;
-        hessianMatrix00_->umv(corr0, tmp0);
-        hessianMatrix01_->umv(corr1, tmp0);
+        (*hessianMatrix_)[_0][_0].umv(corr0, tmp0);
+        (*hessianMatrix_)[_0][_1].umv(corr1, tmp0);
 
         CorrectionType1 tmp1(corr1.size());
         tmp1 = 0;
-        hessianMatrix10_->umv(corr0, tmp1);
-        hessianMatrix11_->umv(corr1, tmp1);
+        (*hessianMatrix_)[_1][_0].umv(corr0, tmp1);
+        (*hessianMatrix_)[_1][_1].umv(corr1, tmp1);
 
         double modelDecrease = (rhs0*corr0+rhs1*corr1) - 0.5 * (corr0*tmp0+corr1*tmp1);
         modelDecrease = mpiHelper.getCollectiveCommunication().sum(modelDecrease);
diff --git a/dune/gfe/mixedriemanniantrsolver.hh b/dune/gfe/mixedriemanniantrsolver.hh
index 4c311e0c..4807dd90 100644
--- a/dune/gfe/mixedriemanniantrsolver.hh
+++ b/dune/gfe/mixedriemanniantrsolver.hh
@@ -7,6 +7,7 @@
 
 #include <dune/istl/bcrsmatrix.hh>
 #include <dune/istl/bvector.hh>
+#include <dune/istl/multitypeblockmatrix.hh>
 
 #include <dune/grid/utility/globalindexset.hh>
 
@@ -38,6 +39,8 @@ class MixedRiemannianTrustRegionSolver
     typedef Dune::BCRSMatrix<Dune::FieldMatrix<field_type, blocksize0, blocksize1> > MatrixType01;
     typedef Dune::BCRSMatrix<Dune::FieldMatrix<field_type, blocksize1, blocksize0> > MatrixType10;
     typedef Dune::BCRSMatrix<Dune::FieldMatrix<field_type, blocksize1, blocksize1> > MatrixType11;
+    typedef Dune::MultiTypeBlockMatrix<Dune::MultiTypeBlockVector<MatrixType00,MatrixType01>,
+                                       Dune::MultiTypeBlockVector<MatrixType10,MatrixType11> > MatrixType;
     typedef Dune::BlockVector<Dune::FieldVector<field_type, blocksize0> >             CorrectionType0;
     typedef Dune::BlockVector<Dune::FieldVector<field_type, blocksize1> >             CorrectionType1;
     typedef std::vector<TargetSpace0>                                                SolutionType0;
@@ -140,10 +143,7 @@ protected:
     double innerTolerance_;
 
     /** \brief Hessian matrix */
-    std::unique_ptr<MatrixType00> hessianMatrix00_;
-    std::unique_ptr<MatrixType01> hessianMatrix01_;
-    std::unique_ptr<MatrixType10> hessianMatrix10_;
-    std::unique_ptr<MatrixType11> hessianMatrix11_;
+    std::unique_ptr<MatrixType> hessianMatrix_;
 
     /** \brief The assembler for the material law */
     const MixedGFEAssembler<Basis0, TargetSpace0, Basis1, TargetSpace1>* assembler_;
-- 
GitLab