From 0099e1f4d7333f593f9aadf41a09ef2ea6ee9984 Mon Sep 17 00:00:00 2001
From: Oliver Sander <sander@igpm.rwth-aachen.de>
Date: Wed, 14 May 2014 16:20:25 +0000
Subject: [PATCH] Introduce methods reduceAdd and reduceCopy for the
 MatrixCommunicator

Much easier to use than the previous API.

[[Imported from SVN: r9735]]
---
 dune/gfe/parallel/matrixcommunicator.hh | 39 +++++++++++++------------
 dune/gfe/riemanniantrsolver.cc          | 15 +++-------
 2 files changed, 25 insertions(+), 29 deletions(-)

diff --git a/dune/gfe/parallel/matrixcommunicator.hh b/dune/gfe/parallel/matrixcommunicator.hh
index d3014da2..fcc6318b 100644
--- a/dune/gfe/parallel/matrixcommunicator.hh
+++ b/dune/gfe/parallel/matrixcommunicator.hh
@@ -11,7 +11,7 @@
 
 template<typename GUIndex, typename MatrixType>
 class MatrixCommunicator {
-public:
+
   struct TransferMatrixTuple {
     typedef typename MatrixType::block_type EntryType;
 
@@ -22,20 +22,6 @@ public:
     TransferMatrixTuple(const size_t& r, const size_t& c, const EntryType& e) : row(r), col(c), entry(e) {}
   };
 
-public:
-  MatrixCommunicator(const GUIndex& rowIndex, const int& root)
-  : guIndex1_(rowIndex),
-    guIndex2_(rowIndex),
-    root_rank(root)
-  {}
-
-  MatrixCommunicator(const GUIndex& rowIndex, const GUIndex& colIndex, const int& root)
-  : guIndex1_(rowIndex),
-    guIndex2_(colIndex),
-    root_rank(root)
-  {}
-
-
   void transferMatrix(const MatrixType& localMatrix) {
     // Create vector for transfer data
     std::vector<TransferMatrixTuple> localMatrixEntries;
@@ -58,8 +44,23 @@ public:
     globalMatrixEntries = MPIFunctions::gatherv(guIndex1_.getGridView(), localMatrixEntries, localMatrixEntriesSizes, root_rank);
   }
 
+public:
+  MatrixCommunicator(const GUIndex& rowIndex, const int& root)
+  : guIndex1_(rowIndex),
+    guIndex2_(rowIndex),
+    root_rank(root)
+  {}
+
+  MatrixCommunicator(const GUIndex& rowIndex, const GUIndex& colIndex, const int& root)
+  : guIndex1_(rowIndex),
+    guIndex2_(colIndex),
+    root_rank(root)
+  {}
+
+  MatrixType reduceAdd(const MatrixType& local)
+  {
+    transferMatrix(local);
 
-  MatrixType createGlobalMatrix() const {
     MatrixType globalMatrix;
 
     // Create occupation pattern in matrix
@@ -79,11 +80,13 @@ public:
     for(size_t k = 0; k < globalMatrixEntries.size(); ++k)
       globalMatrix[globalMatrixEntries[k].row][globalMatrixEntries[k].col] += globalMatrixEntries[k].entry;
 
-
     return globalMatrix;
   }
 
-  MatrixType copyIntoGlobalMatrix() const {
+  MatrixType reduceCopy(const MatrixType& local)
+  {
+    transferMatrix(local);
+
     MatrixType globalMatrix;
 
     // Create occupation pattern in matrix
diff --git a/dune/gfe/riemanniantrsolver.cc b/dune/gfe/riemanniantrsolver.cc
index 9b9a3337..78fa3d23 100644
--- a/dune/gfe/riemanniantrsolver.cc
+++ b/dune/gfe/riemanniantrsolver.cc
@@ -195,12 +195,10 @@ setup(const GridType& grid,
 
         typedef typename TruncatedCompressedMGTransfer<CorrectionType>::TransferOperatorType TransferOperatorType;
         MatrixCommunicator<LevelGUIndex, TransferOperatorType> matrixComm(fineGUIndex, coarseGUIndex, 0);
-        matrixComm.transferMatrix(newTransferOp->getMatrix());
 
-        if (rank==0) {
-            mmgStep->mgTransfer_[i] = new TruncatedCompressedMGTransfer<CorrectionType>
-                 (Dune::make_shared<TransferOperatorType>(matrixComm.copyIntoGlobalMatrix()));
-        }
+        mmgStep->mgTransfer_[i] = new TruncatedCompressedMGTransfer<CorrectionType>
+             (Dune::make_shared<TransferOperatorType>(matrixComm.reduceCopy(newTransferOp->getMatrix())));
+
     }
 #endif
 
@@ -341,16 +339,11 @@ void RiemannianTrustRegionSolver<GridType,TargetSpace>::solve()
               std::cout << "Assembly took " << gradientTimer.elapsed() << " sec." << std::endl;
 
             // Transfer matrix data
-            matrixComm.transferMatrix(*hessianMatrix_);
+            stiffnessMatrix = matrixComm.reduceAdd(*hessianMatrix_);
 
             // Transfer vector data
             rhs_global = vectorComm.reduceAdd(rhs);
 
-            if (rank ==0) {
-              // Create global stiffnessMatrix
-              stiffnessMatrix = matrixComm.createGlobalMatrix();
-            }
-
             recomputeGradientHessian = false;
 
         }
-- 
GitLab