Home | History | Annotate | Download | only in ceres
      1 // Ceres Solver - A fast non-linear least squares minimizer
      2 // Copyright 2013 Google Inc. All rights reserved.
      3 // http://code.google.com/p/ceres-solver/
      4 //
      5 // Redistribution and use in source and binary forms, with or without
      6 // modification, are permitted provided that the following conditions are met:
      7 //
      8 // * Redistributions of source code must retain the above copyright notice,
      9 //   this list of conditions and the following disclaimer.
     10 // * Redistributions in binary form must reproduce the above copyright notice,
     11 //   this list of conditions and the following disclaimer in the documentation
     12 //   and/or other materials provided with the distribution.
     13 // * Neither the name of Google Inc. nor the names of its contributors may be
     14 //   used to endorse or promote products derived from this software without
     15 //   specific prior written permission.
     16 //
     17 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
     18 // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
     19 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
     20 // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
     21 // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
     22 // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
     23 // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
     24 // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
     25 // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
     26 // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
     27 // POSSIBILITY OF SUCH DAMAGE.
     28 //
     29 // Author: mierle (at) gmail.com (Keir Mierle)
     30 //
     31 // An incomplete C API for Ceres.
     32 //
     33 // TODO(keir): Figure out why logging does not seem to work.
     34 
     35 #include "ceres/c_api.h"
     36 
     37 #include <vector>
     38 #include <iostream>
     39 #include <string>
     40 #include "ceres/cost_function.h"
     41 #include "ceres/loss_function.h"
     42 #include "ceres/problem.h"
     43 #include "ceres/solver.h"
     44 #include "ceres/types.h"  // for std
     45 #include "glog/logging.h"
     46 
     47 using ceres::Problem;
     48 
     49 void ceres_init() {
     50   // This is not ideal, but it's not clear what to do if there is no gflags and
     51   // no access to command line arguments.
     52   char message[] = "<unknown>";
     53   google::InitGoogleLogging(message);
     54 }
     55 
     56 ceres_problem_t* ceres_create_problem() {
     57   return reinterpret_cast<ceres_problem_t*>(new Problem);
     58 }
     59 
     60 void ceres_free_problem(ceres_problem_t* problem) {
     61   delete reinterpret_cast<Problem*>(problem);
     62 }
     63 
     64 // This cost function wraps a C-level function pointer from the user, to bridge
     65 // between C and C++.
     66 class CallbackCostFunction : public ceres::CostFunction {
     67  public:
     68   CallbackCostFunction(ceres_cost_function_t cost_function,
     69                        void* user_data,
     70                        int num_residuals,
     71                        int num_parameter_blocks,
     72                        int* parameter_block_sizes)
     73       : cost_function_(cost_function),
     74         user_data_(user_data) {
     75     set_num_residuals(num_residuals);
     76     for (int i = 0; i < num_parameter_blocks; ++i) {
     77       mutable_parameter_block_sizes()->push_back(parameter_block_sizes[i]);
     78     }
     79   }
     80 
     81   virtual ~CallbackCostFunction() {}
     82 
     83   virtual bool Evaluate(double const* const* parameters,
     84                         double* residuals,
     85                         double** jacobians) const {
     86     return (*cost_function_)(user_data_,
     87                              const_cast<double**>(parameters),
     88                              residuals,
     89                              jacobians);
     90   }
     91 
     92  private:
     93   ceres_cost_function_t cost_function_;
     94   void* user_data_;
     95 };
     96 
     97 // This loss function wraps a C-level function pointer from the user, to bridge
     98 // between C and C++.
     99 class CallbackLossFunction : public ceres::LossFunction {
    100  public:
    101   explicit CallbackLossFunction(ceres_loss_function_t loss_function,
    102                                 void* user_data)
    103     : loss_function_(loss_function), user_data_(user_data) {}
    104   virtual void Evaluate(double sq_norm, double* rho) const {
    105     (*loss_function_)(user_data_, sq_norm, rho);
    106   }
    107 
    108  private:
    109   ceres_loss_function_t loss_function_;
    110   void* user_data_;
    111 };
    112 
    113 // Wrappers for the stock loss functions.
    114 void* ceres_create_huber_loss_function_data(double a) {
    115   return new ceres::HuberLoss(a);
    116 }
    117 void* ceres_create_softl1_loss_function_data(double a) {
    118   return new ceres::SoftLOneLoss(a);
    119 }
    120 void* ceres_create_cauchy_loss_function_data(double a) {
    121   return new ceres::CauchyLoss(a);
    122 }
    123 void* ceres_create_arctan_loss_function_data(double a) {
    124   return new ceres::ArctanLoss(a);
    125 }
    126 void* ceres_create_tolerant_loss_function_data(double a, double b) {
    127   return new ceres::TolerantLoss(a, b);
    128 }
    129 
    130 void ceres_free_stock_loss_function_data(void* loss_function_data) {
    131   delete reinterpret_cast<ceres::LossFunction*>(loss_function_data);
    132 }
    133 
    134 void ceres_stock_loss_function(void* user_data,
    135                                double squared_norm,
    136                                double out[3]) {
    137   reinterpret_cast<ceres::LossFunction*>(user_data)
    138       ->Evaluate(squared_norm, out);
    139 }
    140 
    141 ceres_residual_block_id_t* ceres_problem_add_residual_block(
    142     ceres_problem_t* problem,
    143     ceres_cost_function_t cost_function,
    144     void* cost_function_data,
    145     ceres_loss_function_t loss_function,
    146     void* loss_function_data,
    147     int num_residuals,
    148     int num_parameter_blocks,
    149     int* parameter_block_sizes,
    150     double** parameters) {
    151   Problem* ceres_problem = reinterpret_cast<Problem*>(problem);
    152 
    153   ceres::CostFunction* callback_cost_function =
    154       new CallbackCostFunction(cost_function,
    155                                cost_function_data,
    156                                num_residuals,
    157                                num_parameter_blocks,
    158                                parameter_block_sizes);
    159 
    160   ceres::LossFunction* callback_loss_function = NULL;
    161   if (loss_function != NULL) {
    162     callback_loss_function = new CallbackLossFunction(loss_function,
    163                                                       loss_function_data);
    164   }
    165 
    166   std::vector<double*> parameter_blocks(parameters,
    167                                         parameters + num_parameter_blocks);
    168   return reinterpret_cast<ceres_residual_block_id_t*>(
    169       ceres_problem->AddResidualBlock(callback_cost_function,
    170                                       callback_loss_function,
    171                                       parameter_blocks));
    172 }
    173 
    174 void ceres_solve(ceres_problem_t* c_problem) {
    175   Problem* problem = reinterpret_cast<Problem*>(c_problem);
    176 
    177   // TODO(keir): Obviously, this way of setting options won't scale or last.
    178   // Instead, figure out a way to specify some of the options without
    179   // duplicating everything.
    180   ceres::Solver::Options options;
    181   options.max_num_iterations = 100;
    182   options.linear_solver_type = ceres::DENSE_QR;
    183   options.minimizer_progress_to_stdout = true;
    184 
    185   ceres::Solver::Summary summary;
    186   ceres::Solve(options, problem, &summary);
    187   std::cout << summary.FullReport() << "\n";
    188 }
    189