Home | History | Annotate | Download | only in ceres
      1 // Ceres Solver - A fast non-linear least squares minimizer
      2 // Copyright 2014 Google Inc. All rights reserved.
      3 // http://code.google.com/p/ceres-solver/
      4 //
      5 // Redistribution and use in source and binary forms, with or without
      6 // modification, are permitted provided that the following conditions are met:
      7 //
      8 // * Redistributions of source code must retain the above copyright notice,
      9 //   this list of conditions and the following disclaimer.
     10 // * Redistributions in binary form must reproduce the above copyright notice,
     11 //   this list of conditions and the following disclaimer in the documentation
     12 //   and/or other materials provided with the distribution.
     13 // * Neither the name of Google Inc. nor the names of its contributors may be
     14 //   used to endorse or promote products derived from this software without
     15 //   specific prior written permission.
     16 //
     17 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
     18 // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
     19 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
     20 // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
     21 // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
     22 // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
     23 // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
     24 // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
     25 // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
     26 // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
     27 // POSSIBILITY OF SUCH DAMAGE.
     28 //
     29 // Author: sameeragarwal (at) google.com (Sameer Agarwal)
     30 
     31 #include "ceres/solver.h"
     32 
     33 #include <limits>
     34 #include <cmath>
     35 #include <vector>
     36 #include "gtest/gtest.h"
     37 #include "ceres/internal/scoped_ptr.h"
     38 #include "ceres/autodiff_cost_function.h"
     39 #include "ceres/sized_cost_function.h"
     40 #include "ceres/problem.h"
     41 #include "ceres/problem_impl.h"
     42 
     43 namespace ceres {
     44 namespace internal {
     45 
     46 TEST(SolverOptions, DefaultTrustRegionOptionsAreValid) {
     47   Solver::Options options;
     48   options.minimizer_type = TRUST_REGION;
     49   string error;
     50   EXPECT_TRUE(options.IsValid(&error)) << error;
     51 }
     52 
     53 TEST(SolverOptions, DefaultLineSearchOptionsAreValid) {
     54   Solver::Options options;
     55   options.minimizer_type = LINE_SEARCH;
     56   string error;
     57   EXPECT_TRUE(options.IsValid(&error)) << error;
     58 }
     59 
     60 struct QuadraticCostFunctor {
     61   template <typename T> bool operator()(const T* const x,
     62                                         T* residual) const {
     63     residual[0] = T(5.0) - *x;
     64     return true;
     65   }
     66 
     67   static CostFunction* Create() {
     68     return new AutoDiffCostFunction<QuadraticCostFunctor, 1, 1>(
     69         new QuadraticCostFunctor);
     70   }
     71 };
     72 
     73 struct RememberingCallback : public IterationCallback {
     74   explicit RememberingCallback(double *x) : calls(0), x(x) {}
     75   virtual ~RememberingCallback() {}
     76   virtual CallbackReturnType operator()(const IterationSummary& summary) {
     77     x_values.push_back(*x);
     78     return SOLVER_CONTINUE;
     79   }
     80   int calls;
     81   double *x;
     82   vector<double> x_values;
     83 };
     84 
     85 TEST(Solver, UpdateStateEveryIterationOption) {
     86   double x = 50.0;
     87   const double original_x = x;
     88 
     89   scoped_ptr<CostFunction> cost_function(QuadraticCostFunctor::Create());
     90   Problem::Options problem_options;
     91   problem_options.cost_function_ownership = DO_NOT_TAKE_OWNERSHIP;
     92   Problem problem(problem_options);
     93   problem.AddResidualBlock(cost_function.get(), NULL, &x);
     94 
     95   Solver::Options options;
     96   options.linear_solver_type = DENSE_QR;
     97 
     98   RememberingCallback callback(&x);
     99   options.callbacks.push_back(&callback);
    100 
    101   Solver::Summary summary;
    102 
    103   int num_iterations;
    104 
    105   // First try: no updating.
    106   Solve(options, &problem, &summary);
    107   num_iterations = summary.num_successful_steps +
    108                    summary.num_unsuccessful_steps;
    109   EXPECT_GT(num_iterations, 1);
    110   for (int i = 0; i < callback.x_values.size(); ++i) {
    111     EXPECT_EQ(50.0, callback.x_values[i]);
    112   }
    113 
    114   // Second try: with updating
    115   x = 50.0;
    116   options.update_state_every_iteration = true;
    117   callback.x_values.clear();
    118   Solve(options, &problem, &summary);
    119   num_iterations = summary.num_successful_steps +
    120                    summary.num_unsuccessful_steps;
    121   EXPECT_GT(num_iterations, 1);
    122   EXPECT_EQ(original_x, callback.x_values[0]);
    123   EXPECT_NE(original_x, callback.x_values[1]);
    124 }
    125 
    126 // The parameters must be in separate blocks so that they can be individually
    127 // set constant or not.
    128 struct Quadratic4DCostFunction {
    129   template <typename T> bool operator()(const T* const x,
    130                                         const T* const y,
    131                                         const T* const z,
    132                                         const T* const w,
    133                                         T* residual) const {
    134     // A 4-dimension axis-aligned quadratic.
    135     residual[0] = T(10.0) - *x +
    136                   T(20.0) - *y +
    137                   T(30.0) - *z +
    138                   T(40.0) - *w;
    139     return true;
    140   }
    141 
    142   static CostFunction* Create() {
    143     return new AutoDiffCostFunction<Quadratic4DCostFunction, 1, 1, 1, 1, 1>(
    144         new Quadratic4DCostFunction);
    145   }
    146 };
    147 
    148 // A cost function that simply returns its argument.
    149 class UnaryIdentityCostFunction : public SizedCostFunction<1, 1> {
    150  public:
    151   virtual bool Evaluate(double const* const* parameters,
    152                         double* residuals,
    153                         double** jacobians) const {
    154     residuals[0] = parameters[0][0];
    155     if (jacobians != NULL && jacobians[0] != NULL) {
    156       jacobians[0][0] = 1.0;
    157     }
    158     return true;
    159   }
    160 };
    161 
    162 TEST(Solver, TrustRegionProblemHasNoParameterBlocks) {
    163   Problem problem;
    164   Solver::Options options;
    165   options.minimizer_type = TRUST_REGION;
    166   Solver::Summary summary;
    167   Solve(options, &problem, &summary);
    168   EXPECT_EQ(summary.termination_type, CONVERGENCE);
    169   EXPECT_EQ(summary.message,
    170             "Function tolerance reached. "
    171             "No non-constant parameter blocks found.");
    172 }
    173 
    174 TEST(Solver, LineSearchProblemHasNoParameterBlocks) {
    175   Problem problem;
    176   Solver::Options options;
    177   options.minimizer_type = LINE_SEARCH;
    178   Solver::Summary summary;
    179   Solve(options, &problem, &summary);
    180   EXPECT_EQ(summary.termination_type, CONVERGENCE);
    181   EXPECT_EQ(summary.message,
    182             "Function tolerance reached. "
    183             "No non-constant parameter blocks found.");
    184 }
    185 
    186 TEST(Solver, TrustRegionProblemHasZeroResiduals) {
    187   Problem problem;
    188   double x = 1;
    189   problem.AddParameterBlock(&x, 1);
    190   Solver::Options options;
    191   options.minimizer_type = TRUST_REGION;
    192   Solver::Summary summary;
    193   Solve(options, &problem, &summary);
    194   EXPECT_EQ(summary.termination_type, CONVERGENCE);
    195   EXPECT_EQ(summary.message,
    196             "Function tolerance reached. "
    197             "No non-constant parameter blocks found.");
    198 }
    199 
    200 TEST(Solver, LineSearchProblemHasZeroResiduals) {
    201   Problem problem;
    202   double x = 1;
    203   problem.AddParameterBlock(&x, 1);
    204   Solver::Options options;
    205   options.minimizer_type = LINE_SEARCH;
    206   Solver::Summary summary;
    207   Solve(options, &problem, &summary);
    208   EXPECT_EQ(summary.termination_type, CONVERGENCE);
    209   EXPECT_EQ(summary.message,
    210             "Function tolerance reached. "
    211             "No non-constant parameter blocks found.");
    212 }
    213 
    214 TEST(Solver, TrustRegionProblemIsConstant) {
    215   Problem problem;
    216   double x = 1;
    217   problem.AddResidualBlock(new UnaryIdentityCostFunction, NULL, &x);
    218   problem.SetParameterBlockConstant(&x);
    219   Solver::Options options;
    220   options.minimizer_type = TRUST_REGION;
    221   Solver::Summary summary;
    222   Solve(options, &problem, &summary);
    223   EXPECT_EQ(summary.termination_type, CONVERGENCE);
    224   EXPECT_EQ(summary.initial_cost, 1.0 / 2.0);
    225   EXPECT_EQ(summary.final_cost, 1.0 / 2.0);
    226 }
    227 
    228 TEST(Solver, LineSearchProblemIsConstant) {
    229   Problem problem;
    230   double x = 1;
    231   problem.AddResidualBlock(new UnaryIdentityCostFunction, NULL, &x);
    232   problem.SetParameterBlockConstant(&x);
    233   Solver::Options options;
    234   options.minimizer_type = LINE_SEARCH;
    235   Solver::Summary summary;
    236   Solve(options, &problem, &summary);
    237   EXPECT_EQ(summary.termination_type, CONVERGENCE);
    238   EXPECT_EQ(summary.initial_cost, 1.0 / 2.0);
    239   EXPECT_EQ(summary.final_cost, 1.0 / 2.0);
    240 }
    241 
    242 #if defined(CERES_NO_SUITESPARSE)
    243 TEST(Solver, SparseNormalCholeskyNoSuiteSparse) {
    244   Solver::Options options;
    245   options.sparse_linear_algebra_library_type = SUITE_SPARSE;
    246   options.linear_solver_type = SPARSE_NORMAL_CHOLESKY;
    247   string message;
    248   EXPECT_FALSE(options.IsValid(&message));
    249 }
    250 #endif
    251 
    252 #if defined(CERES_NO_CXSPARSE)
    253 TEST(Solver, SparseNormalCholeskyNoCXSparse) {
    254   Solver::Options options;
    255   options.sparse_linear_algebra_library_type = CX_SPARSE;
    256   options.linear_solver_type = SPARSE_NORMAL_CHOLESKY;
    257   string message;
    258   EXPECT_FALSE(options.IsValid(&message));
    259 }
    260 #endif
    261 
    262 TEST(Solver, IterativeLinearSolverForDogleg) {
    263   Solver::Options options;
    264   options.trust_region_strategy_type = DOGLEG;
    265   string message;
    266   options.linear_solver_type = ITERATIVE_SCHUR;
    267   EXPECT_FALSE(options.IsValid(&message));
    268 
    269   options.linear_solver_type = CGNR;
    270   EXPECT_FALSE(options.IsValid(&message));
    271 }
    272 
    273 TEST(Solver, LinearSolverTypeNormalOperation) {
    274   Solver::Options options;
    275   options.linear_solver_type = DENSE_QR;
    276 
    277   string message;
    278   EXPECT_TRUE(options.IsValid(&message));
    279 
    280   options.linear_solver_type = DENSE_NORMAL_CHOLESKY;
    281   EXPECT_TRUE(options.IsValid(&message));
    282 
    283   options.linear_solver_type = DENSE_SCHUR;
    284   EXPECT_TRUE(options.IsValid(&message));
    285 
    286   options.linear_solver_type = SPARSE_SCHUR;
    287 #if defined(CERES_NO_SUITESPARSE) && defined(CERES_NO_CXSPARSE)
    288   EXPECT_FALSE(options.IsValid(&message));
    289 #else
    290   EXPECT_TRUE(options.IsValid(&message));
    291 #endif
    292 
    293   options.linear_solver_type = ITERATIVE_SCHUR;
    294   EXPECT_TRUE(options.IsValid(&message));
    295 }
    296 
    297 }  // namespace internal
    298 }  // namespace ceres
    299