From bbc34b2e55a69fd0af6b7fb579f521fb4916a64a Mon Sep 17 00:00:00 2001 From: Oliver Sander <sander@igpm.rwth-aachen.de> Date: Wed, 14 May 2014 15:36:35 +0000 Subject: [PATCH] Add a method VectorCommunicator::scatter, and use it [[Imported from SVN: r9732]] --- dune/gfe/parallel/vectorcommunicator.hh | 43 +++++++++++++------------ dune/gfe/riemanniantrsolver.cc | 18 ++++------- 2 files changed, 29 insertions(+), 32 deletions(-) diff --git a/dune/gfe/parallel/vectorcommunicator.hh b/dune/gfe/parallel/vectorcommunicator.hh index 7dca7e45..2fcb7541 100644 --- a/dune/gfe/parallel/vectorcommunicator.hh +++ b/dune/gfe/parallel/vectorcommunicator.hh @@ -55,6 +55,24 @@ private: return globalVector; } + VectorType createLocalSolution() { + const int localSize = localVectorEntriesSizes[guIndex.getGridView().comm().rank()]; + + // Create vector for transfer data + std::vector<TransferVectorTuple> localVectorEntries(localSize); + + MPIFunctions::scatterv(guIndex.getGridView(), localVectorEntries, globalVectorEntries, localVectorEntriesSizes, root_rank); + + // Create vector for local solution + VectorType x(localSize); + + // And translate solution again + for (size_t k = 0; k < localVectorEntries.size(); ++k) + x[guIndex.localIndex(localVectorEntries[k].row)] = localVectorEntries[k].entry; + + return x; + } + public: VectorCommunicator(const GUIndex& gi, const int& root) : guIndex(gi), root_rank(root) @@ -75,29 +93,12 @@ public: return copyIntoGlobalVector(); } - void fillEntriesFromVector(const VectorType& x_global) { + VectorType scatter(const VectorType& global) + { for (size_t k = 0; k < globalVectorEntries.size(); ++k) - globalVectorEntries[k].entry = x_global[globalVectorEntries[k].row]; - } - - - VectorType createLocalSolution() { - const int localSize = localVectorEntriesSizes[guIndex.getGridView().comm().rank()]; - - // Create vector for transfer data - std::vector<TransferVectorTuple> localVectorEntries(localSize); - - MPIFunctions::scatterv(guIndex.getGridView(), localVectorEntries, globalVectorEntries, localVectorEntriesSizes, root_rank); + globalVectorEntries[k].entry = global[globalVectorEntries[k].row]; - // Create vector for local solution - VectorType x(localSize); - - // And translate solution again - for (size_t k = 0; k < localVectorEntries.size(); ++k) - x[guIndex.localIndex(localVectorEntries[k].row)] = localVectorEntries[k].entry; - - - return x; + return createLocalSolution(); } private: diff --git a/dune/gfe/riemanniantrsolver.cc b/dune/gfe/riemanniantrsolver.cc index 322d6053..01b48563 100644 --- a/dune/gfe/riemanniantrsolver.cc +++ b/dune/gfe/riemanniantrsolver.cc @@ -355,11 +355,11 @@ void RiemannianTrustRegionSolver<GridType,TargetSpace>::solve() } + CorrectionType corr_global(rhs_global.size()); + corr_global = 0; + if (rank==0) { - CorrectionType corr_global(rhs_global.size()); - corr_global = 0; - mgStep->setProblem(stiffnessMatrix, corr_global, rhs_global); trustRegionObstacles.back() = trustRegion.obstacles(); @@ -378,18 +378,14 @@ void RiemannianTrustRegionSolver<GridType,TargetSpace>::solve() if (mgStep) corr_global = mgStep->getSol(); - // Translate solution back - if (mpiHelper.size()>1) - std::cout << "Translating solution back on root process ..." << std::endl; - - // Recycle VectorCommunicator by using it for the solution vector - vectorComm.fillEntriesFromVector(corr_global); - //std::cout << "Correction: " << std::endl << corr_global << std::endl; } // Distribute solution - corr = CorrectionType(vectorComm.createLocalSolution()); + if (mpiHelper.size()>1) + std::cout << "Transfer solution back to root process ..." << std::endl; + + corr = vectorComm.scatter(corr_global); if (instrumented_) { -- GitLab