// Ceres Solver - A fast non-linear least squares minimizer
// Copyright 2012 Google Inc. All rights reserved.
// http://code.google.com/p/ceres-solver/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
// * Neither the name of Google Inc. nor the names of its contributors may be
//   used to endorse or promote products derived from this software without
//   specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Author: sameeragarwal (at) google.com (Sameer Agarwal)
//
// Generic loop for line search based optimization algorithms.
//
// This is primarily inspired by the minFunc package written by Mark
// Schmidt.
35 // 36 // http://www.di.ens.fr/~mschmidt/Software/minFunc.html 37 // 38 // For details on the theory and implementation see "Numerical 39 // Optimization" by Nocedal & Wright. 40 41 #include "ceres/line_search_minimizer.h" 42 43 #include <algorithm> 44 #include <cstdlib> 45 #include <cmath> 46 #include <string> 47 #include <vector> 48 49 #include "Eigen/Dense" 50 #include "ceres/array_utils.h" 51 #include "ceres/evaluator.h" 52 #include "ceres/internal/eigen.h" 53 #include "ceres/internal/port.h" 54 #include "ceres/internal/scoped_ptr.h" 55 #include "ceres/line_search.h" 56 #include "ceres/line_search_direction.h" 57 #include "ceres/stringprintf.h" 58 #include "ceres/types.h" 59 #include "ceres/wall_time.h" 60 #include "glog/logging.h" 61 62 namespace ceres { 63 namespace internal { 64 namespace { 65 66 // TODO(sameeragarwal): I think there is a small bug here, in that if 67 // the evaluation fails, then the state can contain garbage. Look at 68 // this more carefully. 69 bool Evaluate(Evaluator* evaluator, 70 const Vector& x, 71 LineSearchMinimizer::State* state, 72 string* message) { 73 if (!evaluator->Evaluate(x.data(), 74 &(state->cost), 75 NULL, 76 state->gradient.data(), 77 NULL)) { 78 *message = "Gradient evaluation failed."; 79 return false; 80 } 81 82 Vector negative_gradient = -state->gradient; 83 Vector projected_gradient_step(x.size()); 84 if (!evaluator->Plus(x.data(), 85 negative_gradient.data(), 86 projected_gradient_step.data())) { 87 *message = "projected_gradient_step = Plus(x, -gradient) failed."; 88 return false; 89 } 90 91 state->gradient_squared_norm = (x - projected_gradient_step).squaredNorm(); 92 state->gradient_max_norm = 93 (x - projected_gradient_step).lpNorm<Eigen::Infinity>(); 94 return true; 95 } 96 97 } // namespace 98 99 void LineSearchMinimizer::Minimize(const Minimizer::Options& options, 100 double* parameters, 101 Solver::Summary* summary) { 102 const bool is_not_silent = !options.is_silent; 103 double start_time = 
WallTimeInSeconds(); 104 double iteration_start_time = start_time; 105 106 Evaluator* evaluator = CHECK_NOTNULL(options.evaluator); 107 const int num_parameters = evaluator->NumParameters(); 108 const int num_effective_parameters = evaluator->NumEffectiveParameters(); 109 110 summary->termination_type = NO_CONVERGENCE; 111 summary->num_successful_steps = 0; 112 summary->num_unsuccessful_steps = 0; 113 114 VectorRef x(parameters, num_parameters); 115 116 State current_state(num_parameters, num_effective_parameters); 117 State previous_state(num_parameters, num_effective_parameters); 118 119 Vector delta(num_effective_parameters); 120 Vector x_plus_delta(num_parameters); 121 122 IterationSummary iteration_summary; 123 iteration_summary.iteration = 0; 124 iteration_summary.step_is_valid = false; 125 iteration_summary.step_is_successful = false; 126 iteration_summary.cost_change = 0.0; 127 iteration_summary.gradient_max_norm = 0.0; 128 iteration_summary.gradient_norm = 0.0; 129 iteration_summary.step_norm = 0.0; 130 iteration_summary.linear_solver_iterations = 0; 131 iteration_summary.step_solver_time_in_seconds = 0; 132 133 // Do initial cost and Jacobian evaluation. 134 if (!Evaluate(evaluator, x, ¤t_state, &summary->message)) { 135 summary->termination_type = FAILURE; 136 summary->message = "Initial cost and jacobian evaluation failed. " 137 "More details: " + summary->message; 138 LOG_IF(WARNING, is_not_silent) << "Terminating: " << summary->message; 139 return; 140 } 141 142 summary->initial_cost = current_state.cost + summary->fixed_cost; 143 iteration_summary.cost = current_state.cost + summary->fixed_cost; 144 145 iteration_summary.gradient_max_norm = current_state.gradient_max_norm; 146 iteration_summary.gradient_norm = sqrt(current_state.gradient_squared_norm); 147 148 if (iteration_summary.gradient_max_norm <= options.gradient_tolerance) { 149 summary->message = StringPrintf("Gradient tolerance reached. 
" 150 "Gradient max norm: %e <= %e", 151 iteration_summary.gradient_max_norm, 152 options.gradient_tolerance); 153 summary->termination_type = CONVERGENCE; 154 VLOG_IF(1, is_not_silent) << "Terminating: " << summary->message; 155 return; 156 } 157 158 iteration_summary.iteration_time_in_seconds = 159 WallTimeInSeconds() - iteration_start_time; 160 iteration_summary.cumulative_time_in_seconds = 161 WallTimeInSeconds() - start_time 162 + summary->preprocessor_time_in_seconds; 163 summary->iterations.push_back(iteration_summary); 164 165 LineSearchDirection::Options line_search_direction_options; 166 line_search_direction_options.num_parameters = num_effective_parameters; 167 line_search_direction_options.type = options.line_search_direction_type; 168 line_search_direction_options.nonlinear_conjugate_gradient_type = 169 options.nonlinear_conjugate_gradient_type; 170 line_search_direction_options.max_lbfgs_rank = options.max_lbfgs_rank; 171 line_search_direction_options.use_approximate_eigenvalue_bfgs_scaling = 172 options.use_approximate_eigenvalue_bfgs_scaling; 173 scoped_ptr<LineSearchDirection> line_search_direction( 174 LineSearchDirection::Create(line_search_direction_options)); 175 176 LineSearchFunction line_search_function(evaluator); 177 178 LineSearch::Options line_search_options; 179 line_search_options.interpolation_type = 180 options.line_search_interpolation_type; 181 line_search_options.min_step_size = options.min_line_search_step_size; 182 line_search_options.sufficient_decrease = 183 options.line_search_sufficient_function_decrease; 184 line_search_options.max_step_contraction = 185 options.max_line_search_step_contraction; 186 line_search_options.min_step_contraction = 187 options.min_line_search_step_contraction; 188 line_search_options.max_num_iterations = 189 options.max_num_line_search_step_size_iterations; 190 line_search_options.sufficient_curvature_decrease = 191 options.line_search_sufficient_curvature_decrease; 192 
line_search_options.max_step_expansion = 193 options.max_line_search_step_expansion; 194 line_search_options.function = &line_search_function; 195 196 scoped_ptr<LineSearch> 197 line_search(LineSearch::Create(options.line_search_type, 198 line_search_options, 199 &summary->message)); 200 if (line_search.get() == NULL) { 201 summary->termination_type = FAILURE; 202 LOG_IF(ERROR, is_not_silent) << "Terminating: " << summary->message; 203 return; 204 } 205 206 LineSearch::Summary line_search_summary; 207 int num_line_search_direction_restarts = 0; 208 209 while (true) { 210 if (!RunCallbacks(options, iteration_summary, summary)) { 211 break; 212 } 213 214 iteration_start_time = WallTimeInSeconds(); 215 if (iteration_summary.iteration >= options.max_num_iterations) { 216 summary->message = "Maximum number of iterations reached."; 217 summary->termination_type = NO_CONVERGENCE; 218 VLOG_IF(1, is_not_silent) << "Terminating: " << summary->message; 219 break; 220 } 221 222 const double total_solver_time = iteration_start_time - start_time + 223 summary->preprocessor_time_in_seconds; 224 if (total_solver_time >= options.max_solver_time_in_seconds) { 225 summary->message = "Maximum solver time reached."; 226 summary->termination_type = NO_CONVERGENCE; 227 VLOG_IF(1, is_not_silent) << "Terminating: " << summary->message; 228 break; 229 } 230 231 iteration_summary = IterationSummary(); 232 iteration_summary.iteration = summary->iterations.back().iteration + 1; 233 iteration_summary.step_is_valid = false; 234 iteration_summary.step_is_successful = false; 235 236 bool line_search_status = true; 237 if (iteration_summary.iteration == 1) { 238 current_state.search_direction = -current_state.gradient; 239 } else { 240 line_search_status = line_search_direction->NextDirection( 241 previous_state, 242 current_state, 243 ¤t_state.search_direction); 244 } 245 246 if (!line_search_status && 247 num_line_search_direction_restarts >= 248 options.max_num_line_search_direction_restarts) { 
249 // Line search direction failed to generate a new direction, and we 250 // have already reached our specified maximum number of restarts, 251 // terminate optimization. 252 summary->message = 253 StringPrintf("Line search direction failure: specified " 254 "max_num_line_search_direction_restarts: %d reached.", 255 options.max_num_line_search_direction_restarts); 256 summary->termination_type = FAILURE; 257 LOG_IF(WARNING, is_not_silent) << "Terminating: " << summary->message; 258 break; 259 } else if (!line_search_status) { 260 // Restart line search direction with gradient descent on first iteration 261 // as we have not yet reached our maximum number of restarts. 262 CHECK_LT(num_line_search_direction_restarts, 263 options.max_num_line_search_direction_restarts); 264 265 ++num_line_search_direction_restarts; 266 LOG_IF(WARNING, is_not_silent) 267 << "Line search direction algorithm: " 268 << LineSearchDirectionTypeToString( 269 options.line_search_direction_type) 270 << ", failed to produce a valid new direction at " 271 << "iteration: " << iteration_summary.iteration 272 << ". Restarting, number of restarts: " 273 << num_line_search_direction_restarts << " / " 274 << options.max_num_line_search_direction_restarts 275 << " [max]."; 276 line_search_direction.reset( 277 LineSearchDirection::Create(line_search_direction_options)); 278 current_state.search_direction = -current_state.gradient; 279 } 280 281 line_search_function.Init(x, current_state.search_direction); 282 current_state.directional_derivative = 283 current_state.gradient.dot(current_state.search_direction); 284 285 // TODO(sameeragarwal): Refactor this into its own object and add 286 // explanations for the various choices. 287 // 288 // Note that we use !line_search_status to ensure that we treat cases when 289 // we restarted the line search direction equivalently to the first 290 // iteration. 
291 const double initial_step_size = 292 (iteration_summary.iteration == 1 || !line_search_status) 293 ? min(1.0, 1.0 / current_state.gradient_max_norm) 294 : min(1.0, 2.0 * (current_state.cost - previous_state.cost) / 295 current_state.directional_derivative); 296 // By definition, we should only ever go forwards along the specified search 297 // direction in a line search, most likely cause for this being violated 298 // would be a numerical failure in the line search direction calculation. 299 if (initial_step_size < 0.0) { 300 summary->message = 301 StringPrintf("Numerical failure in line search, initial_step_size is " 302 "negative: %.5e, directional_derivative: %.5e, " 303 "(current_cost - previous_cost): %.5e", 304 initial_step_size, current_state.directional_derivative, 305 (current_state.cost - previous_state.cost)); 306 summary->termination_type = FAILURE; 307 LOG_IF(WARNING, is_not_silent) << "Terminating: " << summary->message; 308 break; 309 } 310 311 line_search->Search(initial_step_size, 312 current_state.cost, 313 current_state.directional_derivative, 314 &line_search_summary); 315 if (!line_search_summary.success) { 316 summary->message = 317 StringPrintf("Numerical failure in line search, failed to find " 318 "a valid step size, (did not run out of iterations) " 319 "using initial_step_size: %.5e, initial_cost: %.5e, " 320 "initial_gradient: %.5e.", 321 initial_step_size, current_state.cost, 322 current_state.directional_derivative); 323 LOG_IF(WARNING, is_not_silent) << "Terminating: " << summary->message; 324 summary->termination_type = FAILURE; 325 break; 326 } 327 328 current_state.step_size = line_search_summary.optimal_step_size; 329 delta = current_state.step_size * current_state.search_direction; 330 331 previous_state = current_state; 332 iteration_summary.step_solver_time_in_seconds = 333 WallTimeInSeconds() - iteration_start_time; 334 335 if (!evaluator->Plus(x.data(), delta.data(), x_plus_delta.data())) { 336 summary->termination_type 
= FAILURE; 337 summary->message = 338 "x_plus_delta = Plus(x, delta) failed. This should not happen " 339 "as the step was valid when it was selected by the line search."; 340 LOG_IF(WARNING, is_not_silent) << "Terminating: " << summary->message; 341 break; 342 } else if (!Evaluate(evaluator, 343 x_plus_delta, 344 ¤t_state, 345 &summary->message)) { 346 summary->termination_type = FAILURE; 347 summary->message = 348 "Step failed to evaluate. This should not happen as the step was " 349 "valid when it was selected by the line search. More details: " + 350 summary->message; 351 LOG_IF(WARNING, is_not_silent) << "Terminating: " << summary->message; 352 break; 353 } else { 354 x = x_plus_delta; 355 } 356 357 iteration_summary.gradient_max_norm = current_state.gradient_max_norm; 358 iteration_summary.gradient_norm = sqrt(current_state.gradient_squared_norm); 359 iteration_summary.cost_change = previous_state.cost - current_state.cost; 360 iteration_summary.cost = current_state.cost + summary->fixed_cost; 361 iteration_summary.step_norm = delta.norm(); 362 iteration_summary.step_is_valid = true; 363 iteration_summary.step_is_successful = true; 364 iteration_summary.step_norm = delta.norm(); 365 iteration_summary.step_size = current_state.step_size; 366 iteration_summary.line_search_function_evaluations = 367 line_search_summary.num_function_evaluations; 368 iteration_summary.line_search_gradient_evaluations = 369 line_search_summary.num_gradient_evaluations; 370 iteration_summary.line_search_iterations = 371 line_search_summary.num_iterations; 372 iteration_summary.iteration_time_in_seconds = 373 WallTimeInSeconds() - iteration_start_time; 374 iteration_summary.cumulative_time_in_seconds = 375 WallTimeInSeconds() - start_time 376 + summary->preprocessor_time_in_seconds; 377 378 summary->iterations.push_back(iteration_summary); 379 ++summary->num_successful_steps; 380 381 if (iteration_summary.gradient_max_norm <= options.gradient_tolerance) { 382 summary->message = 
StringPrintf("Gradient tolerance reached. " 383 "Gradient max norm: %e <= %e", 384 iteration_summary.gradient_max_norm, 385 options.gradient_tolerance); 386 summary->termination_type = CONVERGENCE; 387 VLOG_IF(1, is_not_silent) << "Terminating: " << summary->message; 388 break; 389 } 390 391 const double absolute_function_tolerance = 392 options.function_tolerance * previous_state.cost; 393 if (fabs(iteration_summary.cost_change) < absolute_function_tolerance) { 394 summary->message = 395 StringPrintf("Function tolerance reached. " 396 "|cost_change|/cost: %e <= %e", 397 fabs(iteration_summary.cost_change) / 398 previous_state.cost, 399 options.function_tolerance); 400 summary->termination_type = CONVERGENCE; 401 VLOG_IF(1, is_not_silent) << "Terminating: " << summary->message; 402 break; 403 } 404 } 405 } 406 407 } // namespace internal 408 } // namespace ceres 409