// Ceres Solver - A fast non-linear least squares minimizer
// Copyright 2010, 2011, 2012 Google Inc. All rights reserved.
// http://code.google.com/p/ceres-solver/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
// * Neither the name of Google Inc. nor the names of its contributors may be
//   used to endorse or promote products derived from this software without
//   specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
28 // 29 // Author: sameeragarwal (at) google.com (Sameer Agarwal) 30 31 #include "ceres/loss_function.h" 32 33 #include <cstddef> 34 35 #include "glog/logging.h" 36 #include "gtest/gtest.h" 37 38 namespace ceres { 39 namespace internal { 40 namespace { 41 42 // Helper function for testing a LossFunction callback. 43 // 44 // Compares the values of rho'(s) and rho''(s) computed by the 45 // callback with estimates obtained by symmetric finite differencing 46 // of rho(s). 47 void AssertLossFunctionIsValid(const LossFunction& loss, double s) { 48 CHECK_GT(s, 0); 49 50 // Evaluate rho(s), rho'(s) and rho''(s). 51 double rho[3]; 52 loss.Evaluate(s, rho); 53 54 // Use symmetric finite differencing to estimate rho'(s) and 55 // rho''(s). 56 const double kH = 1e-4; 57 // Values at s + kH. 58 double fwd[3]; 59 // Values at s - kH. 60 double bwd[3]; 61 loss.Evaluate(s + kH, fwd); 62 loss.Evaluate(s - kH, bwd); 63 64 // First derivative. 65 const double fd_1 = (fwd[0] - bwd[0]) / (2 * kH); 66 ASSERT_NEAR(fd_1, rho[1], 1e-6); 67 68 // Second derivative. 69 const double fd_2 = (fwd[0] - 2*rho[0] + bwd[0]) / (kH * kH); 70 ASSERT_NEAR(fd_2, rho[2], 1e-6); 71 } 72 } // namespace 73 74 // Try two values of the scaling a = 0.7 and 1.3 75 // (where scaling makes sense) and of the squared norm 76 // s = 0.357 and 1.792 77 // 78 // Note that for the Huber loss the test exercises both code paths 79 // (i.e. both small and large values of s). 
80 81 TEST(LossFunction, TrivialLoss) { 82 AssertLossFunctionIsValid(TrivialLoss(), 0.357); 83 AssertLossFunctionIsValid(TrivialLoss(), 1.792); 84 } 85 86 TEST(LossFunction, HuberLoss) { 87 AssertLossFunctionIsValid(HuberLoss(0.7), 0.357); 88 AssertLossFunctionIsValid(HuberLoss(0.7), 1.792); 89 AssertLossFunctionIsValid(HuberLoss(1.3), 0.357); 90 AssertLossFunctionIsValid(HuberLoss(1.3), 1.792); 91 } 92 93 TEST(LossFunction, SoftLOneLoss) { 94 AssertLossFunctionIsValid(SoftLOneLoss(0.7), 0.357); 95 AssertLossFunctionIsValid(SoftLOneLoss(0.7), 1.792); 96 AssertLossFunctionIsValid(SoftLOneLoss(1.3), 0.357); 97 AssertLossFunctionIsValid(SoftLOneLoss(1.3), 1.792); 98 } 99 100 TEST(LossFunction, CauchyLoss) { 101 AssertLossFunctionIsValid(CauchyLoss(0.7), 0.357); 102 AssertLossFunctionIsValid(CauchyLoss(0.7), 1.792); 103 AssertLossFunctionIsValid(CauchyLoss(1.3), 0.357); 104 AssertLossFunctionIsValid(CauchyLoss(1.3), 1.792); 105 } 106 107 TEST(LossFunction, ArctanLoss) { 108 AssertLossFunctionIsValid(ArctanLoss(0.7), 0.357); 109 AssertLossFunctionIsValid(ArctanLoss(0.7), 1.792); 110 AssertLossFunctionIsValid(ArctanLoss(1.3), 0.357); 111 AssertLossFunctionIsValid(ArctanLoss(1.3), 1.792); 112 } 113 114 TEST(LossFunction, TolerantLoss) { 115 AssertLossFunctionIsValid(TolerantLoss(0.7, 0.4), 0.357); 116 AssertLossFunctionIsValid(TolerantLoss(0.7, 0.4), 1.792); 117 AssertLossFunctionIsValid(TolerantLoss(0.7, 0.4), 55.5); 118 AssertLossFunctionIsValid(TolerantLoss(1.3, 0.1), 0.357); 119 AssertLossFunctionIsValid(TolerantLoss(1.3, 0.1), 1.792); 120 AssertLossFunctionIsValid(TolerantLoss(1.3, 0.1), 55.5); 121 // Check the value at zero is actually zero. 122 double rho[3]; 123 TolerantLoss(0.7, 0.4).Evaluate(0.0, rho); 124 ASSERT_NEAR(rho[0], 0.0, 1e-6); 125 // Check that loss before and after the approximation threshold are good. 126 // A threshold of 36.7 is used by the implementation. 
127 AssertLossFunctionIsValid(TolerantLoss(20.0, 1.0), 20.0 + 36.6); 128 AssertLossFunctionIsValid(TolerantLoss(20.0, 1.0), 20.0 + 36.7); 129 AssertLossFunctionIsValid(TolerantLoss(20.0, 1.0), 20.0 + 36.8); 130 AssertLossFunctionIsValid(TolerantLoss(20.0, 1.0), 20.0 + 1000.0); 131 } 132 133 TEST(LossFunction, ComposedLoss) { 134 { 135 HuberLoss f(0.7); 136 CauchyLoss g(1.3); 137 ComposedLoss c(&f, DO_NOT_TAKE_OWNERSHIP, &g, DO_NOT_TAKE_OWNERSHIP); 138 AssertLossFunctionIsValid(c, 0.357); 139 AssertLossFunctionIsValid(c, 1.792); 140 } 141 { 142 CauchyLoss f(0.7); 143 HuberLoss g(1.3); 144 ComposedLoss c(&f, DO_NOT_TAKE_OWNERSHIP, &g, DO_NOT_TAKE_OWNERSHIP); 145 AssertLossFunctionIsValid(c, 0.357); 146 AssertLossFunctionIsValid(c, 1.792); 147 } 148 } 149 150 TEST(LossFunction, ScaledLoss) { 151 // Wrap a few loss functions, and a few scale factors. This can't combine 152 // construction with the call to AssertLossFunctionIsValid() because Apple's 153 // GCC is unable to eliminate the copy of ScaledLoss, which is not copyable. 
154 { 155 ScaledLoss scaled_loss(NULL, 6, TAKE_OWNERSHIP); 156 AssertLossFunctionIsValid(scaled_loss, 0.323); 157 } 158 { 159 ScaledLoss scaled_loss(new TrivialLoss(), 10, TAKE_OWNERSHIP); 160 AssertLossFunctionIsValid(scaled_loss, 0.357); 161 } 162 { 163 ScaledLoss scaled_loss(new HuberLoss(0.7), 0.1, TAKE_OWNERSHIP); 164 AssertLossFunctionIsValid(scaled_loss, 1.792); 165 } 166 { 167 ScaledLoss scaled_loss(new SoftLOneLoss(1.3), 0.1, TAKE_OWNERSHIP); 168 AssertLossFunctionIsValid(scaled_loss, 1.792); 169 } 170 { 171 ScaledLoss scaled_loss(new CauchyLoss(1.3), 10, TAKE_OWNERSHIP); 172 AssertLossFunctionIsValid(scaled_loss, 1.792); 173 } 174 { 175 ScaledLoss scaled_loss(new ArctanLoss(1.3), 10, TAKE_OWNERSHIP); 176 AssertLossFunctionIsValid(scaled_loss, 1.792); 177 } 178 { 179 ScaledLoss scaled_loss( 180 new TolerantLoss(1.3, 0.1), 10, TAKE_OWNERSHIP); 181 AssertLossFunctionIsValid(scaled_loss, 1.792); 182 } 183 { 184 ScaledLoss scaled_loss( 185 new ComposedLoss( 186 new HuberLoss(0.8), TAKE_OWNERSHIP, 187 new TolerantLoss(1.3, 0.5), TAKE_OWNERSHIP), 10, TAKE_OWNERSHIP); 188 AssertLossFunctionIsValid(scaled_loss, 1.792); 189 } 190 } 191 192 TEST(LossFunction, LossFunctionWrapper) { 193 // Initialization 194 HuberLoss loss_function1(1.0); 195 LossFunctionWrapper loss_function_wrapper(new HuberLoss(1.0), 196 TAKE_OWNERSHIP); 197 198 double s = 0.862; 199 double rho_gold[3]; 200 double rho[3]; 201 loss_function1.Evaluate(s, rho_gold); 202 loss_function_wrapper.Evaluate(s, rho); 203 for (int i = 0; i < 3; ++i) { 204 EXPECT_NEAR(rho[i], rho_gold[i], 1e-12); 205 } 206 207 // Resetting 208 HuberLoss loss_function2(0.5); 209 loss_function_wrapper.Reset(new HuberLoss(0.5), TAKE_OWNERSHIP); 210 loss_function_wrapper.Evaluate(s, rho); 211 loss_function2.Evaluate(s, rho_gold); 212 for (int i = 0; i < 3; ++i) { 213 EXPECT_NEAR(rho[i], rho_gold[i], 1e-12); 214 } 215 216 // Not taking ownership. 
217 HuberLoss loss_function3(0.3); 218 loss_function_wrapper.Reset(&loss_function3, DO_NOT_TAKE_OWNERSHIP); 219 loss_function_wrapper.Evaluate(s, rho); 220 loss_function3.Evaluate(s, rho_gold); 221 for (int i = 0; i < 3; ++i) { 222 EXPECT_NEAR(rho[i], rho_gold[i], 1e-12); 223 } 224 } 225 226 } // namespace internal 227 } // namespace ceres 228