From 814135cb77c391e31e559f162e34002b2a002fad Mon Sep 17 00:00:00 2001
From: Oliver Sander <sander@igpm.rwth-aachen.de>
Date: Thu, 30 Apr 2009 12:46:29 +0000
Subject: [PATCH] modify hesse matrix and rhs to account for Dirichlet nodes

[[Imported from SVN: r4131]]
---
 src/riemanniantrsolver.cc | 67 ++++++++++++++++++++++++++++++++-------
 src/riemanniantrsolver.hh |  3 ++
 2 files changed, 58 insertions(+), 12 deletions(-)

diff --git a/src/riemanniantrsolver.cc b/src/riemanniantrsolver.cc
index e9378398..32389f07 100644
--- a/src/riemanniantrsolver.cc
+++ b/src/riemanniantrsolver.cc
@@ -51,6 +51,7 @@ setup(const GridType& grid,
     innerIterations_          = multigridIterations;
     innerTolerance_           = mgTolerance;
     instrumented_             = instrumented;
+    ignoreNodes_              = &dirichletNodes;
 
     int numLevels = grid_->maxLevel()+1;
 
@@ -166,6 +167,7 @@ setupTCG(const GridType& grid,
     innerIterations_          = innerIterations;
     innerTolerance_           = innerTolerance;
     instrumented_             = instrumented;
+    ignoreNodes_              = &dirichletNodes;
 
     // ////////////////////////////////////////////////////////////
     //    Create Hessian matrix and its occupation structure
@@ -195,14 +197,10 @@ setupTCG(const GridType& grid,
     //   Create a truncated conjugate gradient solver
     // ////////////////////////////////////////////////////
 
-    innerSolver_ = new TruncatedCGSolver<MatrixType,CorrectionType>(*hessianMatrix_,
-                                                                    x,
-                                                                    this->rhs_,
-                                                                    innerIterations_,
+    innerSolver_ = new TruncatedCGSolver<MatrixType,CorrectionType>(innerIterations_,
                                                                     innerTolerance_,
                                                                     h1SemiNorm_,
-                                                                    initialTrustRegionRadius,
-                                                                    Solver::QUIET);
+                                                                    Solver::FULL);
 
     // Write all intermediate solutions, if requested
     if (instrumented_
@@ -224,7 +222,9 @@ void RiemannianTrustRegionSolver<GridType,TargetSpace>::solve()
 
     MaxNormTrustRegion<blocksize> trustRegion(x_.size(), initialTrustRegionRadius_);
 
-    std::vector<std::vector<BoxConstraint<field_type,blocksize> > > trustRegionObstacles(mgStep->numLevels_);
+    std::vector<std::vector<BoxConstraint<field_type,blocksize> > > trustRegionObstacles((mgStep) 
+                                                                                         ? mgStep->numLevels_
+                                                                                         : 0);
 
    // /////////////////////////////////////////////////////
     //   Set up the log file, if requested
@@ -264,23 +264,66 @@ void RiemannianTrustRegionSolver<GridType,TargetSpace>::solve()
         //gradientFDCheck(x_, rhs, *rodAssembler_);
         //hessianFDCheck(x_, *hessianMatrix_, *rodAssembler_);
 
+        // The right hand side is the _negative_ gradient
         rhs *= -1;
 
-        mgStep->setProblem(*hessianMatrix_, corr, rhs, grid_->maxLevel()+1);
 
-        trustRegionObstacles.back() = trustRegion.obstacles();
-        mgStep->obstacles_ = &trustRegionObstacles;
+        // //////////////////////////////////////////////////////////////////////
+        //   Modify matrix and right-hand side to account for Dirichlet values
+        // //////////////////////////////////////////////////////////////////////
 
+        typedef typename MatrixType::row_type::Iterator ColumnIterator;
+  
+        for (size_t j=0; j<ignoreNodes_->size(); j++) {
+            
+            if (ignoreNodes_->operator[](j).count() > 0) {
+                
+                // make matrix row an identity row
+                ColumnIterator cIt    = (*hessianMatrix_)[j].begin();
+                ColumnIterator cEndIt = (*hessianMatrix_)[j].end();
+                
+                for (; cIt!=cEndIt; ++cIt) {
+                    for (int k=0; k<blocksize; k++) {
+                        if (ignoreNodes_->operator[](j)[k])
+                            (*cIt)[k] = 0;
+                        if (j==cIt.index())
+                            (*cIt)[k][k] = 1;
+                    }
+                }
+
+                // Dirichlet value.  Zero, because we are solving defect problems
+                for (int k=0; k<blocksize; k++)
+                    if (ignoreNodes_->operator[](j)[k])
+                        rhs[j][k] = 0;
+            }
+
+        }
+
+
+        if (mgStep) {  // inner solver is a monotone multigrid
+
+            mgStep->setProblem(*hessianMatrix_, corr, rhs, grid_->maxLevel()+1);
+
+            trustRegionObstacles.back() = trustRegion.obstacles();
+            mgStep->obstacles_ = &trustRegionObstacles;
         
+        } else {       // inner solver is a truncated cg
+
+            assert((dynamic_cast<TruncatedCGSolver<MatrixType,CorrectionType>*>(innerSolver_)));
+            dynamic_cast<TruncatedCGSolver<MatrixType,CorrectionType>*>(innerSolver_)->setProblem(*hessianMatrix_, &corr, &rhs, trustRegion.radius());
+
+        }
+
         innerSolver_->preprocess();
         
-        
         // /////////////////////////////
         //    Solve !
         // /////////////////////////////
+        
         innerSolver_->solve();
         
-        corr = mgStep->getSol();
+        if (mgStep)
+            corr = mgStep->getSol();
         
         //std::cout << "Correction: " << std::endl << corr << std::endl;
         
diff --git a/src/riemanniantrsolver.hh b/src/riemanniantrsolver.hh
index 2f6e7528..cb6065ac 100644
--- a/src/riemanniantrsolver.hh
+++ b/src/riemanniantrsolver.hh
@@ -111,6 +111,9 @@ protected:
         expects them :-( */
     std::vector<Dune::BitSetVector<1> > hasObstacle_;
 
+    /** \brief The Dirichlet nodes */
+    const Dune::BitSetVector<blocksize>* ignoreNodes_;
+
     /** \brief The norm used to measure multigrid convergence */
     H1SemiNorm<CorrectionType>* h1SemiNorm_;
     
-- 
GitLab