Home | History | Annotate | Download | only in ceres
      1 // Ceres Solver - A fast non-linear least squares minimizer
      2 // Copyright 2012 Google Inc. All rights reserved.
      3 // http://code.google.com/p/ceres-solver/
      4 //
      5 // Redistribution and use in source and binary forms, with or without
      6 // modification, are permitted provided that the following conditions are met:
      7 //
      8 // * Redistributions of source code must retain the above copyright notice,
      9 //   this list of conditions and the following disclaimer.
     10 // * Redistributions in binary form must reproduce the above copyright notice,
     11 //   this list of conditions and the following disclaimer in the documentation
     12 //   and/or other materials provided with the distribution.
     13 // * Neither the name of Google Inc. nor the names of its contributors may be
     14 //   used to endorse or promote products derived from this software without
     15 //   specific prior written permission.
     16 //
     17 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
     18 // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
     19 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
     20 // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
     21 // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
     22 // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
     23 // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
     24 // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
     25 // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
     26 // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
     27 // POSSIBILITY OF SUCH DAMAGE.
     28 //
     29 // Author: sameeragarwal (at) google.com (Sameer Agarwal)
     30 //
     31 // Generic loop for line search based optimization algorithms.
     32 //
     33 // This is primarily inspired by the minFunc package written by Mark
     34 // Schmidt.
     35 //
     36 // http://www.di.ens.fr/~mschmidt/Software/minFunc.html
     37 //
     38 // For details on the theory and implementation see "Numerical
     39 // Optimization" by Nocedal & Wright.
     40 
     41 #ifndef CERES_NO_LINE_SEARCH_MINIMIZER
     42 
     43 #include "ceres/line_search_minimizer.h"
     44 
     45 #include <algorithm>
     46 #include <cstdlib>
     47 #include <cmath>
     48 #include <string>
     49 #include <vector>
     50 
     51 #include "Eigen/Dense"
     52 #include "ceres/array_utils.h"
     53 #include "ceres/evaluator.h"
     54 #include "ceres/internal/eigen.h"
     55 #include "ceres/internal/port.h"
     56 #include "ceres/internal/scoped_ptr.h"
     57 #include "ceres/line_search.h"
     58 #include "ceres/line_search_direction.h"
     59 #include "ceres/stringprintf.h"
     60 #include "ceres/types.h"
     61 #include "ceres/wall_time.h"
     62 #include "glog/logging.h"
     63 
     64 namespace ceres {
     65 namespace internal {
     66 namespace {
// Small constant for various floating point issues. In this file it is
// used as the lower bound on the initial gradient max norm in
// LineSearchMinimizer::Minimize(), so that the relative gradient norm
// can be computed without dividing by zero.
//
// TODO(sameeragarwal): Change to a better name if this has only one
// use.
const double kEpsilon = 1e-12;
     71 
     72 bool Evaluate(Evaluator* evaluator,
     73               const Vector& x,
     74               LineSearchMinimizer::State* state) {
     75   const bool status = evaluator->Evaluate(x.data(),
     76                                           &(state->cost),
     77                                           NULL,
     78                                           state->gradient.data(),
     79                                           NULL);
     80   if (status) {
     81     state->gradient_squared_norm = state->gradient.squaredNorm();
     82     state->gradient_max_norm = state->gradient.lpNorm<Eigen::Infinity>();
     83   }
     84 
     85   return status;
     86 }
     87 
     88 }  // namespace
     89 
     90 void LineSearchMinimizer::Minimize(const Minimizer::Options& options,
     91                                    double* parameters,
     92                                    Solver::Summary* summary) {
     93   double start_time = WallTimeInSeconds();
     94   double iteration_start_time =  start_time;
     95 
     96   Evaluator* evaluator = CHECK_NOTNULL(options.evaluator);
     97   const int num_parameters = evaluator->NumParameters();
     98   const int num_effective_parameters = evaluator->NumEffectiveParameters();
     99 
    100   summary->termination_type = NO_CONVERGENCE;
    101   summary->num_successful_steps = 0;
    102   summary->num_unsuccessful_steps = 0;
    103 
    104   VectorRef x(parameters, num_parameters);
    105 
    106   State current_state(num_parameters, num_effective_parameters);
    107   State previous_state(num_parameters, num_effective_parameters);
    108 
    109   Vector delta(num_effective_parameters);
    110   Vector x_plus_delta(num_parameters);
    111 
    112   IterationSummary iteration_summary;
    113   iteration_summary.iteration = 0;
    114   iteration_summary.step_is_valid = false;
    115   iteration_summary.step_is_successful = false;
    116   iteration_summary.cost_change = 0.0;
    117   iteration_summary.gradient_max_norm = 0.0;
    118   iteration_summary.step_norm = 0.0;
    119   iteration_summary.linear_solver_iterations = 0;
    120   iteration_summary.step_solver_time_in_seconds = 0;
    121 
    122   // Do initial cost and Jacobian evaluation.
    123   if (!Evaluate(evaluator, x, &current_state)) {
    124     LOG(WARNING) << "Terminating: Cost and gradient evaluation failed.";
    125     summary->termination_type = NUMERICAL_FAILURE;
    126     return;
    127   }
    128 
    129   summary->initial_cost = current_state.cost + summary->fixed_cost;
    130   iteration_summary.cost = current_state.cost + summary->fixed_cost;
    131 
    132   iteration_summary.gradient_max_norm = current_state.gradient_max_norm;
    133 
    134   // The initial gradient max_norm is bounded from below so that we do
    135   // not divide by zero.
    136   const double initial_gradient_max_norm =
    137       max(iteration_summary.gradient_max_norm, kEpsilon);
    138   const double absolute_gradient_tolerance =
    139       options.gradient_tolerance * initial_gradient_max_norm;
    140 
    141   if (iteration_summary.gradient_max_norm <= absolute_gradient_tolerance) {
    142     summary->termination_type = GRADIENT_TOLERANCE;
    143     VLOG(1) << "Terminating: Gradient tolerance reached."
    144             << "Relative gradient max norm: "
    145             << iteration_summary.gradient_max_norm / initial_gradient_max_norm
    146             << " <= " << options.gradient_tolerance;
    147     return;
    148   }
    149 
    150   iteration_summary.iteration_time_in_seconds =
    151       WallTimeInSeconds() - iteration_start_time;
    152   iteration_summary.cumulative_time_in_seconds =
    153       WallTimeInSeconds() - start_time
    154       + summary->preprocessor_time_in_seconds;
    155   summary->iterations.push_back(iteration_summary);
    156 
    157   LineSearchDirection::Options line_search_direction_options;
    158   line_search_direction_options.num_parameters = num_effective_parameters;
    159   line_search_direction_options.type = options.line_search_direction_type;
    160   line_search_direction_options.nonlinear_conjugate_gradient_type =
    161       options.nonlinear_conjugate_gradient_type;
    162   line_search_direction_options.max_lbfgs_rank = options.max_lbfgs_rank;
    163   line_search_direction_options.use_approximate_eigenvalue_bfgs_scaling =
    164       options.use_approximate_eigenvalue_bfgs_scaling;
    165   scoped_ptr<LineSearchDirection> line_search_direction(
    166       LineSearchDirection::Create(line_search_direction_options));
    167 
    168   LineSearchFunction line_search_function(evaluator);
    169 
    170   LineSearch::Options line_search_options;
    171   line_search_options.interpolation_type =
    172       options.line_search_interpolation_type;
    173   line_search_options.min_step_size = options.min_line_search_step_size;
    174   line_search_options.sufficient_decrease =
    175       options.line_search_sufficient_function_decrease;
    176   line_search_options.max_step_contraction =
    177       options.max_line_search_step_contraction;
    178   line_search_options.min_step_contraction =
    179       options.min_line_search_step_contraction;
    180   line_search_options.max_num_iterations =
    181       options.max_num_line_search_step_size_iterations;
    182   line_search_options.sufficient_curvature_decrease =
    183       options.line_search_sufficient_curvature_decrease;
    184   line_search_options.max_step_expansion =
    185       options.max_line_search_step_expansion;
    186   line_search_options.function = &line_search_function;
    187 
    188   scoped_ptr<LineSearch>
    189       line_search(LineSearch::Create(options.line_search_type,
    190                                      line_search_options,
    191                                      &summary->error));
    192   if (line_search.get() == NULL) {
    193     LOG(ERROR) << "Ceres bug: Unable to create a LineSearch object, please "
    194                << "contact the developers!, error: " << summary->error;
    195     summary->termination_type = DID_NOT_RUN;
    196     return;
    197   }
    198 
    199   LineSearch::Summary line_search_summary;
    200   int num_line_search_direction_restarts = 0;
    201 
    202   while (true) {
    203     if (!RunCallbacks(options.callbacks, iteration_summary, summary)) {
    204       return;
    205     }
    206 
    207     iteration_start_time = WallTimeInSeconds();
    208     if (iteration_summary.iteration >= options.max_num_iterations) {
    209       summary->termination_type = NO_CONVERGENCE;
    210       VLOG(1) << "Terminating: Maximum number of iterations reached.";
    211       break;
    212     }
    213 
    214     const double total_solver_time = iteration_start_time - start_time +
    215         summary->preprocessor_time_in_seconds;
    216     if (total_solver_time >= options.max_solver_time_in_seconds) {
    217       summary->termination_type = NO_CONVERGENCE;
    218       VLOG(1) << "Terminating: Maximum solver time reached.";
    219       break;
    220     }
    221 
    222     iteration_summary = IterationSummary();
    223     iteration_summary.iteration = summary->iterations.back().iteration + 1;
    224     iteration_summary.step_is_valid = false;
    225     iteration_summary.step_is_successful = false;
    226 
    227     bool line_search_status = true;
    228     if (iteration_summary.iteration == 1) {
    229       current_state.search_direction = -current_state.gradient;
    230     } else {
    231       line_search_status = line_search_direction->NextDirection(
    232           previous_state,
    233           current_state,
    234           &current_state.search_direction);
    235     }
    236 
    237     if (!line_search_status &&
    238         num_line_search_direction_restarts >=
    239         options.max_num_line_search_direction_restarts) {
    240       // Line search direction failed to generate a new direction, and we
    241       // have already reached our specified maximum number of restarts,
    242       // terminate optimization.
    243       summary->error =
    244           StringPrintf("Line search direction failure: specified "
    245                        "max_num_line_search_direction_restarts: %d reached.",
    246                        options.max_num_line_search_direction_restarts);
    247       LOG(WARNING) << summary->error << " terminating optimization.";
    248       summary->termination_type = NUMERICAL_FAILURE;
    249       break;
    250 
    251     } else if (!line_search_status) {
    252       // Restart line search direction with gradient descent on first iteration
    253       // as we have not yet reached our maximum number of restarts.
    254       CHECK_LT(num_line_search_direction_restarts,
    255                options.max_num_line_search_direction_restarts);
    256 
    257       ++num_line_search_direction_restarts;
    258       LOG(WARNING)
    259           << "Line search direction algorithm: "
    260           << LineSearchDirectionTypeToString(options.line_search_direction_type)
    261           << ", failed to produce a valid new direction at iteration: "
    262           << iteration_summary.iteration << ". Restarting, number of "
    263           << "restarts: " << num_line_search_direction_restarts << " / "
    264           << options.max_num_line_search_direction_restarts << " [max].";
    265       line_search_direction.reset(
    266           LineSearchDirection::Create(line_search_direction_options));
    267       current_state.search_direction = -current_state.gradient;
    268     }
    269 
    270     line_search_function.Init(x, current_state.search_direction);
    271     current_state.directional_derivative =
    272         current_state.gradient.dot(current_state.search_direction);
    273 
    274     // TODO(sameeragarwal): Refactor this into its own object and add
    275     // explanations for the various choices.
    276     //
    277     // Note that we use !line_search_status to ensure that we treat cases when
    278     // we restarted the line search direction equivalently to the first
    279     // iteration.
    280     const double initial_step_size =
    281         (iteration_summary.iteration == 1 || !line_search_status)
    282         ? min(1.0, 1.0 / current_state.gradient_max_norm)
    283         : min(1.0, 2.0 * (current_state.cost - previous_state.cost) /
    284               current_state.directional_derivative);
    285     // By definition, we should only ever go forwards along the specified search
    286     // direction in a line search, most likely cause for this being violated
    287     // would be a numerical failure in the line search direction calculation.
    288     if (initial_step_size < 0.0) {
    289       summary->error =
    290           StringPrintf("Numerical failure in line search, initial_step_size is "
    291                        "negative: %.5e, directional_derivative: %.5e, "
    292                        "(current_cost - previous_cost): %.5e",
    293                        initial_step_size, current_state.directional_derivative,
    294                        (current_state.cost - previous_state.cost));
    295       LOG(WARNING) << summary->error;
    296       summary->termination_type = NUMERICAL_FAILURE;
    297       break;
    298     }
    299 
    300     line_search->Search(initial_step_size,
    301                         current_state.cost,
    302                         current_state.directional_derivative,
    303                         &line_search_summary);
    304 
    305     current_state.step_size = line_search_summary.optimal_step_size;
    306     delta = current_state.step_size * current_state.search_direction;
    307 
    308     previous_state = current_state;
    309     iteration_summary.step_solver_time_in_seconds =
    310         WallTimeInSeconds() - iteration_start_time;
    311 
    312     // TODO(sameeragarwal): Collect stats.
    313     if (!evaluator->Plus(x.data(), delta.data(), x_plus_delta.data()) ||
    314         !Evaluate(evaluator, x_plus_delta, &current_state)) {
    315       LOG(WARNING) << "Evaluation failed.";
    316     } else {
    317       x = x_plus_delta;
    318     }
    319 
    320     iteration_summary.gradient_max_norm = current_state.gradient_max_norm;
    321     if (iteration_summary.gradient_max_norm <= absolute_gradient_tolerance) {
    322       summary->termination_type = GRADIENT_TOLERANCE;
    323       VLOG(1) << "Terminating: Gradient tolerance reached."
    324               << "Relative gradient max norm: "
    325               << iteration_summary.gradient_max_norm / initial_gradient_max_norm
    326               << " <= " << options.gradient_tolerance;
    327       break;
    328     }
    329 
    330     iteration_summary.cost_change = previous_state.cost - current_state.cost;
    331     const double absolute_function_tolerance =
    332         options.function_tolerance * previous_state.cost;
    333     if (fabs(iteration_summary.cost_change) < absolute_function_tolerance) {
    334       VLOG(1) << "Terminating. Function tolerance reached. "
    335               << "|cost_change|/cost: "
    336               << fabs(iteration_summary.cost_change) / previous_state.cost
    337               << " <= " << options.function_tolerance;
    338       summary->termination_type = FUNCTION_TOLERANCE;
    339       return;
    340     }
    341 
    342     iteration_summary.cost = current_state.cost + summary->fixed_cost;
    343     iteration_summary.step_norm = delta.norm();
    344     iteration_summary.step_is_valid = true;
    345     iteration_summary.step_is_successful = true;
    346     iteration_summary.step_norm = delta.norm();
    347     iteration_summary.step_size =  current_state.step_size;
    348     iteration_summary.line_search_function_evaluations =
    349         line_search_summary.num_function_evaluations;
    350     iteration_summary.line_search_gradient_evaluations =
    351         line_search_summary.num_gradient_evaluations;
    352     iteration_summary.line_search_iterations =
    353         line_search_summary.num_iterations;
    354     iteration_summary.iteration_time_in_seconds =
    355         WallTimeInSeconds() - iteration_start_time;
    356     iteration_summary.cumulative_time_in_seconds =
    357         WallTimeInSeconds() - start_time
    358         + summary->preprocessor_time_in_seconds;
    359 
    360     summary->iterations.push_back(iteration_summary);
    361     ++summary->num_successful_steps;
    362   }
    363 }
    364 
    365 }  // namespace internal
    366 }  // namespace ceres
    367 
    368 #endif  // CERES_NO_LINE_SEARCH_MINIMIZER
    369