Home | History | Annotate | Download | only in ceres
      1 // Ceres Solver - A fast non-linear least squares minimizer
      2 // Copyright 2010, 2011, 2012 Google Inc. All rights reserved.
      3 // http://code.google.com/p/ceres-solver/
      4 //
      5 // Redistribution and use in source and binary forms, with or without
      6 // modification, are permitted provided that the following conditions are met:
      7 //
      8 // * Redistributions of source code must retain the above copyright notice,
      9 //   this list of conditions and the following disclaimer.
     10 // * Redistributions in binary form must reproduce the above copyright notice,
     11 //   this list of conditions and the following disclaimer in the documentation
     12 //   and/or other materials provided with the distribution.
     13 // * Neither the name of Google Inc. nor the names of its contributors may be
     14 //   used to endorse or promote products derived from this software without
     15 //   specific prior written permission.
     16 //
     17 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
     18 // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
     19 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
     20 // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
     21 // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
     22 // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
     23 // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
     24 // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
     25 // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
     26 // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
     27 // POSSIBILITY OF SUCH DAMAGE.
     28 //
     29 // Author: sameeragarwal (at) google.com (Sameer Agarwal)
     30 //
     31 // A preconditioned conjugate gradients solver
     32 // (ConjugateGradientsSolver) for positive semidefinite linear
     33 // systems.
     34 //
     35 // We have also augmented the termination criterion used by this
     36 // solver to support not just residual based termination but also
     37 // termination based on decrease in the value of the quadratic model
     38 // that CG optimizes.
     39 
     40 #include "ceres/conjugate_gradients_solver.h"
     41 
     42 #include <cmath>
     43 #include <cstddef>
     44 #include "ceres/fpclassify.h"
     45 #include "ceres/internal/eigen.h"
     46 #include "ceres/linear_operator.h"
     47 #include "ceres/stringprintf.h"
     48 #include "ceres/types.h"
     49 #include "glog/logging.h"
     50 
     51 namespace ceres {
     52 namespace internal {
     53 namespace {
     54 
     55 bool IsZeroOrInfinity(double x) {
     56   return ((x == 0.0) || (IsInfinite(x)));
     57 }
     58 
     59 }  // namespace
     60 
     61 ConjugateGradientsSolver::ConjugateGradientsSolver(
     62     const LinearSolver::Options& options)
     63     : options_(options) {
     64 }
     65 
     66 LinearSolver::Summary ConjugateGradientsSolver::Solve(
     67     LinearOperator* A,
     68     const double* b,
     69     const LinearSolver::PerSolveOptions& per_solve_options,
     70     double* x) {
     71   CHECK_NOTNULL(A);
     72   CHECK_NOTNULL(x);
     73   CHECK_NOTNULL(b);
     74   CHECK_EQ(A->num_rows(), A->num_cols());
     75 
     76   LinearSolver::Summary summary;
     77   summary.termination_type = LINEAR_SOLVER_NO_CONVERGENCE;
     78   summary.message = "Maximum number of iterations reached.";
     79   summary.num_iterations = 0;
     80 
     81   const int num_cols = A->num_cols();
     82   VectorRef xref(x, num_cols);
     83   ConstVectorRef bref(b, num_cols);
     84 
     85   const double norm_b = bref.norm();
     86   if (norm_b == 0.0) {
     87     xref.setZero();
     88     summary.termination_type = LINEAR_SOLVER_SUCCESS;
     89     summary.message = "Convergence. |b| = 0.";
     90     return summary;
     91   }
     92 
     93   Vector r(num_cols);
     94   Vector p(num_cols);
     95   Vector z(num_cols);
     96   Vector tmp(num_cols);
     97 
     98   const double tol_r = per_solve_options.r_tolerance * norm_b;
     99 
    100   tmp.setZero();
    101   A->RightMultiply(x, tmp.data());
    102   r = bref - tmp;
    103   double norm_r = r.norm();
    104   if (norm_r <= tol_r) {
    105     summary.termination_type = LINEAR_SOLVER_SUCCESS;
    106     summary.message =
    107         StringPrintf("Convergence. |r| = %e <= %e.", norm_r, tol_r);
    108     return summary;
    109   }
    110 
    111   double rho = 1.0;
    112 
    113   // Initial value of the quadratic model Q = x'Ax - 2 * b'x.
    114   double Q0 = -1.0 * xref.dot(bref + r);
    115 
    116   for (summary.num_iterations = 1;
    117        summary.num_iterations < options_.max_num_iterations;
    118        ++summary.num_iterations) {
    119     // Apply preconditioner
    120     if (per_solve_options.preconditioner != NULL) {
    121       z.setZero();
    122       per_solve_options.preconditioner->RightMultiply(r.data(), z.data());
    123     } else {
    124       z = r;
    125     }
    126 
    127     double last_rho = rho;
    128     rho = r.dot(z);
    129     if (IsZeroOrInfinity(rho)) {
    130       summary.termination_type = LINEAR_SOLVER_FAILURE;
    131       summary.message = StringPrintf("Numerical failure. rho = r'z = %e.", rho);
    132       break;
    133     };
    134 
    135     if (summary.num_iterations == 1) {
    136       p = z;
    137     } else {
    138       double beta = rho / last_rho;
    139       if (IsZeroOrInfinity(beta)) {
    140         summary.termination_type = LINEAR_SOLVER_FAILURE;
    141         summary.message = StringPrintf(
    142             "Numerical failure. beta = rho_n / rho_{n-1} = %e.", beta);
    143         break;
    144       }
    145       p = z + beta * p;
    146     }
    147 
    148     Vector& q = z;
    149     q.setZero();
    150     A->RightMultiply(p.data(), q.data());
    151     const double pq = p.dot(q);
    152     if ((pq <= 0) || IsInfinite(pq))  {
    153       summary.termination_type = LINEAR_SOLVER_FAILURE;
    154       summary.message = StringPrintf("Numerical failure. p'q = %e.", pq);
    155       break;
    156     }
    157 
    158     const double alpha = rho / pq;
    159     if (IsInfinite(alpha)) {
    160       summary.termination_type = LINEAR_SOLVER_FAILURE;
    161       summary.message =
    162           StringPrintf("Numerical failure. alpha = rho / pq = %e", alpha);
    163       break;
    164     }
    165 
    166     xref = xref + alpha * p;
    167 
    168     // Ideally we would just use the update r = r - alpha*q to keep
    169     // track of the residual vector. However this estimate tends to
    170     // drift over time due to round off errors. Thus every
    171     // residual_reset_period iterations, we calculate the residual as
    172     // r = b - Ax. We do not do this every iteration because this
    173     // requires an additional matrix vector multiply which would
    174     // double the complexity of the CG algorithm.
    175     if (summary.num_iterations % options_.residual_reset_period == 0) {
    176       tmp.setZero();
    177       A->RightMultiply(x, tmp.data());
    178       r = bref - tmp;
    179     } else {
    180       r = r - alpha * q;
    181     }
    182 
    183     // Quadratic model based termination.
    184     //   Q1 = x'Ax - 2 * b' x.
    185     const double Q1 = -1.0 * xref.dot(bref + r);
    186 
    187     // For PSD matrices A, let
    188     //
    189     //   Q(x) = x'Ax - 2b'x
    190     //
    191     // be the cost of the quadratic function defined by A and b. Then,
    192     // the solver terminates at iteration i if
    193     //
    194     //   i * (Q(x_i) - Q(x_i-1)) / Q(x_i) < q_tolerance.
    195     //
    196     // This termination criterion is more useful when using CG to
    197     // solve the Newton step. This particular convergence test comes
    198     // from Stephen Nash's work on truncated Newton
    199     // methods. References:
    200     //
    201     //   1. Stephen G. Nash & Ariela Sofer, Assessing A Search
    202     //   Direction Within A Truncated Newton Method, Operation
    203     //   Research Letters 9(1990) 219-221.
    204     //
    205     //   2. Stephen G. Nash, A Survey of Truncated Newton Methods,
    206     //   Journal of Computational and Applied Mathematics,
    207     //   124(1-2), 45-59, 2000.
    208     //
    209     const double zeta = summary.num_iterations * (Q1 - Q0) / Q1;
    210     if (zeta < per_solve_options.q_tolerance) {
    211       summary.termination_type = LINEAR_SOLVER_SUCCESS;
    212       summary.message =
    213           StringPrintf("Convergence: zeta = %e < %e",
    214                        zeta,
    215                        per_solve_options.q_tolerance);
    216       break;
    217     }
    218     Q0 = Q1;
    219 
    220     // Residual based termination.
    221     norm_r = r. norm();
    222     if (norm_r <= tol_r) {
    223       summary.termination_type = LINEAR_SOLVER_SUCCESS;
    224       summary.message =
    225           StringPrintf("Convergence. |r| = %e <= %e.", norm_r, tol_r);
    226       break;
    227     }
    228   }
    229 
    230   return summary;
    231 };
    232 
    233 }  // namespace internal
    234 }  // namespace ceres
    235