Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <limits>
     17 
     18 #include "tensorflow/core/kernels/hinge-loss.h"
     19 #include "tensorflow/core/kernels/logistic-loss.h"
     20 #include "tensorflow/core/kernels/smooth-hinge-loss.h"
     21 #include "tensorflow/core/kernels/squared-loss.h"
     22 #include "tensorflow/core/lib/core/errors.h"
     23 #include "tensorflow/core/lib/core/status.h"
     24 #include "tensorflow/core/lib/core/status_test_util.h"
     25 #include "tensorflow/core/platform/test.h"
     26 
     27 namespace tensorflow {
     28 namespace {
     29 
     30 // TODO(sibyl-Aix6ihai): add a test to show the improvements of the Newton
     31 // modification detailed in readme.md
     32 
     33 TEST(LogisticLoss, ComputePrimalLoss) {
     34   LogisticLossUpdater loss_updater;
     35   EXPECT_NEAR(0.693147,
     36               loss_updater.ComputePrimalLoss(0 /* wx */, 1 /* label */,
     37                                              1 /* example weight */),
     38               1e-3);
     39   EXPECT_NEAR(0.0,
     40               loss_updater.ComputePrimalLoss(70 /* wx */, 1 /* label */,
     41                                              1 /* example weight */),
     42               1e-3);
     43   EXPECT_NEAR(0.0,
     44               loss_updater.ComputePrimalLoss(-70 /* wx */, -1 /* label */,
     45                                              1 /* example weight */),
     46               1e-3);
     47 }
     48 
     49 TEST(LogisticLoss, ComputeDualLoss) {
     50   LogisticLossUpdater loss_updater;
     51   EXPECT_NEAR(0.0,
     52               loss_updater.ComputeDualLoss(0 /* current dual */, 1 /* label */,
     53                                            1 /* example weight */),
     54               1e-3);
     55   EXPECT_NEAR(0.0,
     56               loss_updater.ComputeDualLoss(1 /* current dual */, 1 /* label */,
     57                                            1 /* example weight */),
     58               1e-3);
     59   EXPECT_NEAR(
     60       -0.693147,
     61       loss_updater.ComputeDualLoss(0.5 /* current dual */, 1 /* label */,
     62                                    1 /* example weight */),
     63       1e-3);
     64 }
     65 
     66 TEST(LogisticLoss, ComputeUpdatedDual) {
     67   LogisticLossUpdater loss_updater;
     68   EXPECT_NEAR(0.479,
     69               loss_updater.ComputeUpdatedDual(
     70                   1 /* num partitions */, 1.0 /* label */,
     71                   1.0 /* example weight */, 0.5 /* current_dual */,
     72                   0.3 /* wx */, 10.0 /* weighted_example_norm */),
     73               1e-3);
     74 
     75   EXPECT_NEAR(-0.031,
     76               loss_updater.ComputeUpdatedDual(
     77                   2 /* num partitions */, -1.0 /* label */,
     78                   1.0 /* example weight */, 0.1 /* current_dual */,
     79                   -0.8 /* wx */, 10.0 /* weighted_example_norm */),
     80               1e-3);
     81 }
     82 
     83 TEST(SquaredLoss, ComputePrimalLoss) {
     84   SquaredLossUpdater loss_updater;
     85   EXPECT_NEAR(0.5,
     86               loss_updater.ComputePrimalLoss(0.0 /* wx */, 1.0 /* label */,
     87                                              1.0 /* example weight */),
     88               1e-3);
     89   EXPECT_NEAR(40.5,
     90               loss_updater.ComputePrimalLoss(10.0 /* wx */, 1.0 /* label */,
     91                                              1.0 /* example weight */),
     92               1e-3);
     93   EXPECT_NEAR(0.125,
     94               loss_updater.ComputePrimalLoss(-0.5 /* wx */, -1.0 /* label */,
     95                                              1.0 /* example weight */),
     96               1e-3);
     97   EXPECT_NEAR(4.84,
     98               loss_updater.ComputePrimalLoss(1.2 /* wx */, -1.0 /* label */,
     99                                              2.0 /* example weight */),
    100               1e-3);
    101 }
    102 
    103 TEST(SquaredLoss, ComputeDualLoss) {
    104   SquaredLossUpdater loss_updater;
    105   EXPECT_NEAR(
    106       0.0,
    107       loss_updater.ComputeDualLoss(0.0 /* current dual */, -1.0 /* label */,
    108                                    1.0 /* example weight */),
    109       1e-3);
    110   EXPECT_NEAR(
    111       0.66,
    112       loss_updater.ComputeDualLoss(0.2 /* current dual */, -1.0 /* label */,
    113                                    3.0 /* example weight */),
    114       1e-3);
    115   EXPECT_NEAR(
    116       -0.375,
    117       loss_updater.ComputeDualLoss(1.5 /* current dual */, 1.0 /* label */,
    118                                    1.0 /* example weight */),
    119       1e-3);
    120   EXPECT_NEAR(
    121       -1.125,
    122       loss_updater.ComputeDualLoss(0.5 /* current dual */, 1.0 /* label */,
    123                                    3.0 /* example weight */),
    124       1e-3);
    125 }
    126 
    127 TEST(SquaredLoss, ComputeUpdatedDual) {
    128   SquaredLossUpdater loss_updater;
    129   EXPECT_NEAR(0.336,
    130               loss_updater.ComputeUpdatedDual(
    131                   1 /* num partitions */, 1.0 /* label */,
    132                   1.0 /* example weight */, 0.3 /* current_dual */,
    133                   0.3 /* wx */, 10.0 /* weighted_example_norm */),
    134               1e-3);
    135 
    136   EXPECT_NEAR(-0.427,
    137               loss_updater.ComputeUpdatedDual(
    138                   5 /* num partitions */, -1.0 /* label */,
    139                   1.0 /* example weight */, -0.4 /* current_dual */,
    140                   0.8 /* wx */, 10.0 /* weighted_example_norm */),
    141               1e-3);
    142 }
    143 
    144 TEST(HingeLoss, ComputePrimalLoss) {
    145   HingeLossUpdater loss_updater;
    146   EXPECT_NEAR(1.0,
    147               loss_updater.ComputePrimalLoss(0.0 /* wx */, 1.0 /* label */,
    148                                              1.0 /* example weight */),
    149               1e-3);
    150   EXPECT_NEAR(0.0,
    151               loss_updater.ComputePrimalLoss(10.0 /* wx */, 1.0 /* label */,
    152                                              1.0 /* example weight */),
    153               1e-3);
    154   EXPECT_NEAR(0.5,
    155               loss_updater.ComputePrimalLoss(-0.5 /* wx */, -1.0 /* label */,
    156                                              1.0 /* example weight */),
    157               1e-3);
    158   EXPECT_NEAR(4.4,
    159               loss_updater.ComputePrimalLoss(1.2 /* wx */, -1.0 /* label */,
    160                                              2.0 /* example weight */),
    161               1e-3);
    162 }
    163 
    164 TEST(HingeLoss, ComputeDualLoss) {
    165   HingeLossUpdater loss_updater;
    166   EXPECT_NEAR(
    167       0.0,
    168       loss_updater.ComputeDualLoss(0.0 /* current dual */, -1.0 /* label */,
    169                                    1.0 /* example weight */),
    170       1e-3);
    171   EXPECT_NEAR(
    172       std::numeric_limits<double>::max(),
    173       loss_updater.ComputeDualLoss(0.2 /* current dual */, -1.0 /* label */,
    174                                    3.0 /* example weight */),
    175       1e-3);
    176   EXPECT_NEAR(
    177       std::numeric_limits<double>::max(),
    178       loss_updater.ComputeDualLoss(1.5 /* current dual */, 1.0 /* label */,
    179                                    1.0 /* example weight */),
    180       1e-3);
    181   EXPECT_NEAR(
    182       -1.5,
    183       loss_updater.ComputeDualLoss(0.5 /* current dual */, 1.0 /* label */,
    184                                    3.0 /* example weight */),
    185       1e-3);
    186 }
    187 
    188 TEST(HingeLoss, ConvertLabel) {
    189   HingeLossUpdater loss_updater;
    190   float example_label = 1.0;
    191   Status status;
    192 
    193   // A label with value 1.0 should remain intact.
    194   TF_EXPECT_OK(loss_updater.ConvertLabel(&example_label));
    195   EXPECT_EQ(1.0, example_label);
    196 
    197   // A label with value 0.0 should be converted to -1.0.
    198   example_label = 0.0;
    199   TF_EXPECT_OK(loss_updater.ConvertLabel(&example_label));
    200   EXPECT_EQ(-1.0, example_label);
    201 
    202   // Any other initial value should throw an error.
    203   example_label = 0.5;
    204   status = loss_updater.ConvertLabel(&example_label);
    205   EXPECT_FALSE(status.ok());
    206 }
    207 
    208 TEST(HingeLoss, ComputeUpdatedDual) {
    209   HingeLossUpdater loss_updater;
    210   // When label=1.0, example_weight=1.0, current_dual=0.5, wx=0.3 and
    211   // weighted_example_norm=100.0, it turns out that the optimal value to update
    212   // the dual to is 0.507 which is within the permitted range and thus should be
    213   // the value returned.
    214   EXPECT_NEAR(0.507,
    215               loss_updater.ComputeUpdatedDual(
    216                   1 /* num partitions */, 1.0 /* label */,
    217                   1.0 /* example weight */, 0.5 /* current_dual */,
    218                   0.3 /* wx */, 100.0 /* weighted_example_norm */),
    219               1e-3);
    220   // When label=-1.0, example_weight=1.0, current_dual=0.4, wx=0.6,
    221   // weighted_example_norm=10.0 and num_loss_partitions=10, it turns out that
    222   // the optimal value to update the dual to is 0.384 which is within the
    223   // permitted range and thus should be the value returned.
    224   EXPECT_NEAR(-0.416,
    225               loss_updater.ComputeUpdatedDual(
    226                   10 /* num partitions */, -1.0 /* label */,
    227                   1.0 /* example weight */, -0.4 /* current_dual */,
    228                   0.6 /* wx */, 10.0 /* weighted_example_norm */),
    229               1e-3);
    230   // When label=1.0, example_weight=1.0, current_dual=-0.5, wx=0.3 and
    231   // weighted_example_norm=10.0, it turns out that the optimal value to update
    232   // the dual to is -0.43. However, this is outside the allowed [0.0, 1.0] range
    233   // and hence the closest permitted value (0.0) should be returned instead.
    234   EXPECT_NEAR(0.0,
    235               loss_updater.ComputeUpdatedDual(
    236                   1 /* num partitions */, 1.0 /* label */,
    237                   1.0 /* example weight */, -0.5 /* current_dual */,
    238                   0.3 /* wx */, 10.0 /* weighted_example_norm */),
    239               1e-3);
    240 
    241   // When label=-1.0, example_weight=2.0, current_dual=-1.0, wx=0.3 and
    242   // weighted_example_norm=10.0, it turns out that the optimal value to update
    243   // the dual to is -1.065. However, this is outside the allowed [-1.0, 0.0]
    244   // range and hence the closest permitted value (-1.0) should be returned
    245   // instead.
    246   EXPECT_NEAR(-1.0,
    247               loss_updater.ComputeUpdatedDual(
    248                   1 /* num partitions */, -1.0 /* label */,
    249                   2.0 /* example weight */, -1.0 /* current_dual */,
    250                   0.3 /* wx */, 10.0 /* weighted_example_norm */),
    251               1e-3);
    252 }
    253 
    254 TEST(SmoothHingeLoss, ComputePrimalLoss) {
    255   SmoothHingeLossUpdater loss_updater;
    256   EXPECT_NEAR(0.5,
    257               loss_updater.ComputePrimalLoss(0.0 /* wx */, 1.0 /* label */,
    258                                              1.0 /* example weight */),
    259               1e-3);
    260   EXPECT_NEAR(0.0,
    261               loss_updater.ComputePrimalLoss(10.0 /* wx */, 1.0 /* label */,
    262                                              1.0 /* example weight */),
    263               1e-3);
    264   EXPECT_NEAR(0.125,
    265               loss_updater.ComputePrimalLoss(-0.5 /* wx */, -1.0 /* label */,
    266                                              1.0 /* example weight */),
    267               1e-3);
    268   EXPECT_NEAR(3.4,
    269               loss_updater.ComputePrimalLoss(1.2 /* wx */, -1.0 /* label */,
    270                                              2.0 /* example weight */),
    271               1e-3);
    272 }
    273 
    274 TEST(SmoothHingeLoss, ComputeDualLoss) {
    275   SmoothHingeLossUpdater loss_updater;
    276   EXPECT_NEAR(
    277       0.0,
    278       loss_updater.ComputeDualLoss(0.0 /* current dual */, -1.0 /* label */,
    279                                    1.0 /* example weight */),
    280       1e-3);
    281   EXPECT_NEAR(
    282       std::numeric_limits<double>::max(),
    283       loss_updater.ComputeDualLoss(0.2 /* current dual */, -1.0 /* label */,
    284                                    3.0 /* example weight */),
    285       1e-3);
    286   EXPECT_NEAR(
    287       std::numeric_limits<double>::max(),
    288       loss_updater.ComputeDualLoss(1.5 /* current dual */, 1.0 /* label */,
    289                                    1.0 /* example weight */),
    290       1e-3);
    291   EXPECT_NEAR(
    292       -1.125,
    293       loss_updater.ComputeDualLoss(0.5 /* current dual */, 1.0 /* label */,
    294                                    3.0 /* example weight */),
    295       1e-3);
    296 }
    297 
    298 TEST(SmoothHingeLoss, ComputeUpdatedDual) {
    299   SmoothHingeLossUpdater loss_updater;
    300   EXPECT_NEAR(0.336,
    301               loss_updater.ComputeUpdatedDual(
    302                   1 /* num partitions */, 1.0 /* label */,
    303                   1.0 /* example weight */, 0.3 /* current_dual */,
    304                   0.3 /* wx */, 10.0 /* weighted_example_norm */),
    305               1e-3);
    306 
    307   EXPECT_NEAR(-0.427,
    308               loss_updater.ComputeUpdatedDual(
    309                   5 /* num partitions */, -1.0 /* label */,
    310                   1.0 /* example weight */, -0.4 /* current_dual */,
    311                   0.8 /* wx */, 10.0 /* weighted_example_norm */),
    312               1e-3);
    313 }
    314 
    315 }  // namespace
    316 }  // namespace tensorflow
    317