// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h"

#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/platform/test.h"

using std::vector;
using tensorflow::test::AsTensor;

namespace tensorflow {
namespace boosted_trees {
namespace learner {
namespace stochastic {
namespace {

const double kDelta = 1e-5;

TEST(NodeStatsTest, AlmostZero) {
  LearnerConfig learner_config;
  learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
  NodeStats node_stats(learner_config, GradientStats(1e-8f, 1e-8f));
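  // With both the gradient and the Hessian at ~1e-8 the node is effectively
  // empty, so the expected weight contribution and gain are exactly zero
  // rather than the ratio of the two tiny values.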
  EXPECT_EQ(0, node_stats.weight_contribution[0]);
  EXPECT_EQ(0, node_stats.gain);
}

TEST(NodeStatsTest, LessThanMinWeightConstraint) {
  LearnerConfig learner_config;
  learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
  learner_config.mutable_constraints()->set_min_node_weight(3.2f);
  NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f));
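  // The Hessian sum (1.63) is below the configured min_node_weight (3.2), so
  // the node is discarded: both the weight contribution and the gain are zero.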
  EXPECT_EQ(0, node_stats.weight_contribution[0]);
  EXPECT_EQ(0, node_stats.gain);
}

TEST(NodeStatsTest, L1RegSquashed) {
  LearnerConfig learner_config;
  learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
  learner_config.mutable_regularization()->set_l1(10.0f);
  NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f));
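  // L1 soft-thresholding: the gradient magnitude (7.32) is smaller than the
  // l1 penalty (10.0), so the clipped gradient, and with it the weight
  // contribution and the gain, collapse to zero.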
  EXPECT_EQ(0, node_stats.weight_contribution[0]);
  EXPECT_EQ(0, node_stats.gain);
}

TEST(NodeStatsTest, L1RegPos) {
  LearnerConfig learner_config;
  learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
  learner_config.mutable_regularization()->set_l1(5.0f);
  NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f));
  const float expected_clipped_grad = 7.32f - 5.0f;
  const float expected_weight_contribution = -expected_clipped_grad / 1.63f;
  const float expected_gain =
      expected_clipped_grad * expected_clipped_grad / 1.63f;
  EXPECT_FLOAT_EQ(expected_weight_contribution,
                  node_stats.weight_contribution[0]);
  EXPECT_FLOAT_EQ(expected_gain, node_stats.gain);
}

TEST(NodeStatsTest, L1RegNeg) {
  LearnerConfig learner_config;
  learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
  learner_config.mutable_regularization()->set_l1(5.0f);
  NodeStats node_stats(learner_config, GradientStats(-7.32f, 1.63f));
  const float expected_clipped_grad = -7.32f + 5.0f;
  const float expected_weight_contribution = -expected_clipped_grad / 1.63f;
  const float expected_gain =
      expected_clipped_grad * expected_clipped_grad / 1.63f;
  EXPECT_FLOAT_EQ(expected_weight_contribution,
                  node_stats.weight_contribution[0]);
  EXPECT_FLOAT_EQ(expected_gain, node_stats.gain);
}

TEST(NodeStatsTest, L2Reg) {
  LearnerConfig learner_config;
  learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
  learner_config.mutable_regularization()->set_l2(8.0f);
  NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f));
  const float expected_denom = 1.63f + 8.0f;
  const float expected_weight_contribution = -7.32f / expected_denom;
  const float expected_gain = 7.32f * 7.32f / expected_denom;
  EXPECT_FLOAT_EQ(expected_weight_contribution,
                  node_stats.weight_contribution[0]);
  EXPECT_FLOAT_EQ(expected_gain, node_stats.gain);
}

TEST(NodeStatsTest, L1L2Reg) {
  LearnerConfig learner_config;
  learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
  learner_config.mutable_regularization()->set_l1(5.0f);
  learner_config.mutable_regularization()->set_l2(8.0f);
  NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f));
  const float expected_clipped_grad = 7.32f - 5.0f;
  const float expected_denom = 1.63f + 8.0f;
  const float expected_weight_contribution =
      -expected_clipped_grad / expected_denom;
  const float expected_gain =
      expected_clipped_grad * expected_clipped_grad / expected_denom;
  EXPECT_FLOAT_EQ(expected_weight_contribution,
                  node_stats.weight_contribution[0]);
  EXPECT_FLOAT_EQ(expected_gain, node_stats.gain);
}
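
// All of the scalar (TREE_PER_CLASS) cases above follow the same closed form:
// soft-threshold the gradient by l1, then divide by the Hessian plus l2. The
// next case is an illustrative sketch under the assumption that the same
// closed form also covers a negative gradient with both penalties enabled.
TEST(NodeStatsTest, L1NegL2Reg) {
  LearnerConfig learner_config;
  learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS);
  learner_config.mutable_regularization()->set_l1(5.0f);
  learner_config.mutable_regularization()->set_l2(8.0f);
  NodeStats node_stats(learner_config, GradientStats(-7.32f, 1.63f));
  const float expected_clipped_grad = -7.32f + 5.0f;
  const float expected_denom = 1.63f + 8.0f;
  const float expected_weight_contribution =
      -expected_clipped_grad / expected_denom;
  const float expected_gain =
      expected_clipped_grad * expected_clipped_grad / expected_denom;
  EXPECT_FLOAT_EQ(expected_weight_contribution,
                  node_stats.weight_contribution[0]);
  EXPECT_FLOAT_EQ(expected_gain, node_stats.gain);
}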

TEST(NodeStatsTest, MulticlassFullHessianTest) {
  LearnerConfig learner_config;
  learner_config.set_multi_class_strategy(LearnerConfig::FULL_HESSIAN);
  learner_config.mutable_regularization()->set_l2(0.3f);

  const int kNumClasses = 4;
  const auto& g_shape = TensorShape({1, kNumClasses});
  Tensor g = AsTensor<float>({0.5, 0.33, -9, 1}, g_shape);
  const auto& hessian_shape = TensorShape({1, kNumClasses, kNumClasses});
  Tensor h = AsTensor<float>({3, 5, 7, 8, 5, 4, 1, 5, 7, 1, 8, 4, 8, 5, 4, 9},
                             hessian_shape);

  NodeStats node_stats(learner_config, GradientStats(g, h));

  // The expected weights satisfy (H + l2 * I) * w = -g for the full Hessian,
  // and the expected gain below equals -g^T * w.
  std::vector<float> expected_weight = {0.9607576, 0.4162569, 0.9863192,
                                        -1.5820024};

  EXPECT_EQ(kNumClasses, node_stats.weight_contribution.size());
  for (int i = 0; i < kNumClasses; ++i) {
    EXPECT_NEAR(expected_weight[i], node_stats.weight_contribution[i], kDelta);
  }
  EXPECT_NEAR(9.841132, node_stats.gain, kDelta);
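
  // Extra sanity check, sketched under the assumption that the gain equals
  // -g^T * w: the reported gain should match the negated dot product of the
  // gradient with the returned weight contribution. The tolerance is slightly
  // looser than kDelta to absorb float accumulation error.
  const std::vector<float> grad = {0.5f, 0.33f, -9.0f, 1.0f};
  float neg_dot = 0;
  for (int i = 0; i < kNumClasses; ++i) {
    neg_dot -= grad[i] * node_stats.weight_contribution[i];
  }
  EXPECT_NEAR(neg_dot, node_stats.gain, 1e-4);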
}

TEST(NodeStatsTest, MulticlassDiagonalHessianTest) {
  // Normal case.
  {
    LearnerConfig learner_config;
    learner_config.set_multi_class_strategy(LearnerConfig::FULL_HESSIAN);
    learner_config.mutable_regularization()->set_l2(0.3f);

    const int kNumClasses = 4;
    const auto& g_shape = TensorShape({1, kNumClasses});
    Tensor g = AsTensor<float>({0.5, 0.33, -9, 1}, g_shape);
    Tensor h;
    // Full hessian.
    {
      const auto& hessian_shape = TensorShape({1, kNumClasses, kNumClasses});
      // Construct full hessian.
      h = AsTensor<float>({3, 0, 0, 0, 0, 4, 0, 0, 0, 0, 8, 0, 0, 0, 0, 9},
                          hessian_shape);
    }
    NodeStats full_node_stats(learner_config, GradientStats(g, h));

    // Diagonal only.
    {
      const auto& hessian_shape = TensorShape({1, kNumClasses});
      // Construct diagonal of hessian.
      h = AsTensor<float>({3, 4, 8, 9}, hessian_shape);
    }
    learner_config.set_multi_class_strategy(LearnerConfig::DIAGONAL_HESSIAN);
    NodeStats diag_node_stats(learner_config, GradientStats(g, h));

    // Full and diagonal hessian should return the same results.
    EXPECT_EQ(full_node_stats.weight_contribution.size(),
              diag_node_stats.weight_contribution.size());
    for (int i = 0; i < full_node_stats.weight_contribution.size(); ++i) {
      EXPECT_FLOAT_EQ(full_node_stats.weight_contribution[i],
                      diag_node_stats.weight_contribution[i]);
    }
    EXPECT_FLOAT_EQ(full_node_stats.gain, diag_node_stats.gain);
  }
  // Zero entries in the diagonal, no regularization.
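  // With no l2 and zero entries on the diagonal, the per-class denominators
  // vanish; the only expectation here is that both strategies handle this
  // degenerate case identically.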
  {
    LearnerConfig learner_config;
    learner_config.set_multi_class_strategy(LearnerConfig::FULL_HESSIAN);

    const int kNumClasses = 4;
    const auto& g_shape = TensorShape({1, kNumClasses});
    Tensor g = AsTensor<float>({0.5, 0.33, -9, 1}, g_shape);
    Tensor h;
    // Full hessian.
    {
      const auto& hessian_shape = TensorShape({1, kNumClasses, kNumClasses});
      // Construct full hessian.
      h = AsTensor<float>({3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0},
                          hessian_shape);
    }
    NodeStats full_node_stats(learner_config, GradientStats(g, h));

    // Diagonal only.
    {
      const auto& hessian_shape = TensorShape({1, kNumClasses});
      // Diagonal of hessian, two entries are 0
      h = AsTensor<float>({3, 0, 8, 0}, hessian_shape);
    }
    learner_config.set_multi_class_strategy(LearnerConfig::DIAGONAL_HESSIAN);
    NodeStats diag_node_stats(learner_config, GradientStats(g, h));

    // Full and diagonal hessian should return the same results.
    EXPECT_EQ(full_node_stats.weight_contribution.size(),
              diag_node_stats.weight_contribution.size());
    for (int i = 0; i < full_node_stats.weight_contribution.size(); ++i) {
      EXPECT_FLOAT_EQ(full_node_stats.weight_contribution[i],
                      diag_node_stats.weight_contribution[i]);
    }
    EXPECT_FLOAT_EQ(full_node_stats.gain, diag_node_stats.gain);
  }
}

}  // namespace
}  // namespace stochastic
}  // namespace learner
}  // namespace boosted_trees
}  // namespace tensorflow