From 2c9cd7c506e186ae4060144e6bb84515536a514d Mon Sep 17 00:00:00 2001
From: Lisa Julia Nebel <lisa_julia.nebel@tu-dresden.de>
Date: Wed, 2 Sep 2020 17:05:10 +0200
Subject: [PATCH] Adjust Riemannian trust-region solver for parallel runs

In case the solve step (which is only done on process 0) goes wrong
communicate it to all processes.
---
 dune/gfe/riemanniantrsolver.cc | 21 +++++++++++++++------
 1 file changed, 15 insertions(+), 6 deletions(-)

diff --git a/dune/gfe/riemanniantrsolver.cc b/dune/gfe/riemanniantrsolver.cc
index 8ccf0b99..b8604dfc 100644
--- a/dune/gfe/riemanniantrsolver.cc
+++ b/dune/gfe/riemanniantrsolver.cc
@@ -462,14 +462,14 @@ void RiemannianTrustRegionSolver<Basis,TargetSpace>::solve()
             } catch (Dune::Exception &e) {
                 std::cerr << "Error while solving: " << e << std::endl;
                 solved = false;
-                corr_global = 0;
             }
             std::cout << "Solving the quadratic problem took " << solutionTimer.elapsed() << " seconds." << std::endl;
             totalSolverTime += solutionTimer.elapsed();
 
-            if (mgStep && solved)
+            if (mgStep && solved) {
                 corr_global = mgStep->getSol();
                 std::cout << "Two norm of the correction: " << corr_global.two_norm() << std::endl;
+            }
         }
 
         // Distribute solution
@@ -477,7 +477,13 @@ void RiemannianTrustRegionSolver<Basis,TargetSpace>::solve()
             std::cout << "Transfer solution back to root process ..." << std::endl;
 
 #if HAVE_MPI
-        corr = vectorComm.scatter(corr_global);
+        solved = grid_->comm().min(solved);
+        if (solved) {
+            corr = vectorComm.scatter(corr_global);
+        } else  {
+            corr_global = 0;
+            corr = 0;
+        }
 #else
         corr = corr_global;
 #endif
@@ -583,11 +589,14 @@ void RiemannianTrustRegionSolver<Basis,TargetSpace>::solve()
             } catch (Dune::Exception &e) {
                 std::cerr << "Error while computing the energy of the new Iterate: " << e << std::endl;
                 std::cerr << "Redoing trust region step with smaller radius..." << std::endl;
-                newIterate = x_;
                 solved = false;
-                energy = oldEnergy;
             }
-            if (solved) {
+            solved = grid_->comm().min(solved);
+
+            if (!solved) {
+                newIterate = x_;
+                energy = oldEnergy;
+            } else {
                 energy = grid_->comm().sum(energy);
 
                 // compute the model decrease
-- 
GitLab