1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // ============================================================================= 15 #include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h" 16 17 #include "tensorflow/core/framework/tensor_testutil.h" 18 #include "tensorflow/core/platform/test.h" 19 20 using std::vector; 21 using tensorflow::test::AsTensor; 22 23 namespace tensorflow { 24 namespace boosted_trees { 25 namespace learner { 26 namespace stochastic { 27 namespace { 28 29 const double kDelta = 1e-5; 30 31 TEST(NodeStatsTest, AlmostZero) { 32 LearnerConfig learner_config; 33 learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS); 34 NodeStats node_stats(learner_config, GradientStats(1e-8f, 1e-8f)); 35 EXPECT_EQ(0, node_stats.weight_contribution[0]); 36 EXPECT_EQ(0, node_stats.gain); 37 } 38 39 TEST(NodeStatsTest, LessThanMinWeightConstraint) { 40 LearnerConfig learner_config; 41 learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS); 42 learner_config.mutable_constraints()->set_min_node_weight(3.2f); 43 NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f)); 44 EXPECT_EQ(0, node_stats.weight_contribution[0]); 45 EXPECT_EQ(0, node_stats.gain); 46 } 47 48 TEST(NodeStatsTest, L1RegSquashed) { 49 LearnerConfig learner_config; 50 learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS); 51 learner_config.mutable_regularization()->set_l1(10.0f); 52 NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f)); 53 EXPECT_EQ(0, node_stats.weight_contribution[0]); 54 EXPECT_EQ(0, node_stats.gain); 55 } 56 57 TEST(NodeStatsTest, L1RegPos) { 58 LearnerConfig learner_config; 59 learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS); 60 learner_config.mutable_regularization()->set_l1(5.0f); 61 NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f)); 62 const float expected_clipped_grad = 7.32f - 5.0f; 63 const float expected_weight_contribution = -expected_clipped_grad / 1.63f; 64 const float expected_gain = 65 expected_clipped_grad * expected_clipped_grad / 1.63f; 66 EXPECT_FLOAT_EQ(expected_weight_contribution, 67 node_stats.weight_contribution[0]); 68 EXPECT_FLOAT_EQ(expected_gain, node_stats.gain); 69 } 70 71 TEST(NodeStatsTest, L1RegNeg) { 72 LearnerConfig learner_config; 73 learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS); 74 learner_config.mutable_regularization()->set_l1(5.0f); 75 NodeStats node_stats(learner_config, GradientStats(-7.32f, 1.63f)); 76 const float expected_clipped_grad = -7.32f + 5.0f; 77 const float expected_weight_contribution = -expected_clipped_grad / 1.63f; 78 const float expected_gain = 79 expected_clipped_grad * expected_clipped_grad / 1.63f; 80 EXPECT_FLOAT_EQ(expected_weight_contribution, 81 node_stats.weight_contribution[0]); 82 EXPECT_FLOAT_EQ(expected_gain, node_stats.gain); 83 } 84 85 TEST(NodeStatsTest, L2Reg) { 86 LearnerConfig learner_config; 87 learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS); 88 learner_config.mutable_regularization()->set_l2(8.0f); 89 NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f)); 90 const float expected_denom = 1.63f + 8.0f; 91 const float expected_weight_contribution = -7.32f / expected_denom; 92 const float expected_gain = 7.32f * 7.32f / expected_denom; 93 EXPECT_FLOAT_EQ(expected_weight_contribution, 94 node_stats.weight_contribution[0]); 95 EXPECT_FLOAT_EQ(expected_gain, node_stats.gain); 96 } 97 98 TEST(NodeStatsTest, L1L2Reg) { 99 LearnerConfig learner_config; 100 learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS); 101 learner_config.mutable_regularization()->set_l1(5.0f); 102 learner_config.mutable_regularization()->set_l2(8.0f); 103 NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f)); 104 const float expected_clipped_grad = 7.32f - 5.0f; 105 const float expected_denom = 1.63f + 8.0f; 106 const float expected_weight_contribution = 107 -expected_clipped_grad / expected_denom; 108 const float expected_gain = 109 expected_clipped_grad * expected_clipped_grad / expected_denom; 110 EXPECT_FLOAT_EQ(expected_weight_contribution, 111 node_stats.weight_contribution[0]); 112 EXPECT_FLOAT_EQ(expected_gain, node_stats.gain); 113 } 114 115 TEST(NodeStatsTest, MulticlassFullHessianTest) { 116 LearnerConfig learner_config; 117 learner_config.set_multi_class_strategy(LearnerConfig::FULL_HESSIAN); 118 learner_config.mutable_regularization()->set_l2(0.3f); 119 120 const int kNumClasses = 4; 121 const auto& g_shape = TensorShape({1, kNumClasses}); 122 Tensor g = AsTensor<float>({0.5, 0.33, -9, 1}, g_shape); 123 const auto& hessian_shape = TensorShape({1, kNumClasses, kNumClasses}); 124 Tensor h = AsTensor<float>({3, 5, 7, 8, 5, 4, 1, 5, 7, 1, 8, 4, 8, 5, 4, 9}, 125 hessian_shape); 126 127 NodeStats node_stats(learner_config, GradientStats(g, h)); 128 129 // Index 1 has 0 value because of l1 regularization, 130 std::vector<float> expected_weight = {0.9607576, 0.4162569, 0.9863192, 131 -1.5820024}; 132 133 EXPECT_EQ(kNumClasses, node_stats.weight_contribution.size()); 134 for (int i = 0; i < kNumClasses; ++i) { 135 EXPECT_NEAR(expected_weight[i], node_stats.weight_contribution[i], kDelta); 136 } 137 EXPECT_NEAR(9.841132, node_stats.gain, kDelta); 138 } 139 140 TEST(NodeStatsTest, MulticlassDiagonalHessianTest) { 141 // Normal case. 142 { 143 LearnerConfig learner_config; 144 learner_config.set_multi_class_strategy(LearnerConfig::FULL_HESSIAN); 145 learner_config.mutable_regularization()->set_l2(0.3f); 146 147 const int kNumClasses = 4; 148 const auto& g_shape = TensorShape({1, kNumClasses}); 149 Tensor g = AsTensor<float>({0.5, 0.33, -9, 1}, g_shape); 150 Tensor h; 151 // Full hessian. 152 { 153 const auto& hessian_shape = TensorShape({1, kNumClasses, kNumClasses}); 154 // Construct full hessian. 155 h = AsTensor<float>({3, 0, 0, 0, 0, 4, 0, 0, 0, 0, 8, 0, 0, 0, 0, 9}, 156 hessian_shape); 157 } 158 NodeStats full_node_stats(learner_config, GradientStats(g, h)); 159 160 // Diagonal only. 161 { 162 const auto& hessian_shape = TensorShape({1, kNumClasses}); 163 // Construct diagonal of hessian. 164 h = AsTensor<float>({3, 4, 8, 9}, hessian_shape); 165 } 166 learner_config.set_multi_class_strategy(LearnerConfig::DIAGONAL_HESSIAN); 167 NodeStats diag_node_stats(learner_config, GradientStats(g, h)); 168 169 // Full and diagonal hessian should return the same results. 170 EXPECT_EQ(full_node_stats.weight_contribution.size(), 171 diag_node_stats.weight_contribution.size()); 172 for (int i = 0; i < full_node_stats.weight_contribution.size(); ++i) { 173 EXPECT_FLOAT_EQ(full_node_stats.weight_contribution[i], 174 diag_node_stats.weight_contribution[i]); 175 } 176 EXPECT_EQ(full_node_stats.gain, diag_node_stats.gain); 177 } 178 // Zero entries in diagonal, no regularization 179 { 180 LearnerConfig learner_config; 181 learner_config.set_multi_class_strategy(LearnerConfig::FULL_HESSIAN); 182 183 const int kNumClasses = 4; 184 const auto& g_shape = TensorShape({1, kNumClasses}); 185 Tensor g = AsTensor<float>({0.5, 0.33, -9, 1}, g_shape); 186 Tensor h; 187 // Full hessian. 188 { 189 const auto& hessian_shape = TensorShape({1, kNumClasses, kNumClasses}); 190 // Construct full hessian. 191 h = AsTensor<float>({3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0}, 192 hessian_shape); 193 } 194 NodeStats full_node_stats(learner_config, GradientStats(g, h)); 195 196 // Diagonal only. 197 { 198 const auto& hessian_shape = TensorShape({1, kNumClasses}); 199 // Diagonal of hessian, two entries are 0 200 h = AsTensor<float>({3, 0, 8, 0}, hessian_shape); 201 } 202 learner_config.set_multi_class_strategy(LearnerConfig::DIAGONAL_HESSIAN); 203 NodeStats diag_node_stats(learner_config, GradientStats(g, h)); 204 205 // Full and diagonal hessian should return the same results. 206 EXPECT_EQ(full_node_stats.weight_contribution.size(), 207 diag_node_stats.weight_contribution.size()); 208 for (int i = 0; i < full_node_stats.weight_contribution.size(); ++i) { 209 EXPECT_FLOAT_EQ(full_node_stats.weight_contribution[i], 210 diag_node_stats.weight_contribution[i]); 211 } 212 EXPECT_EQ(full_node_stats.gain, diag_node_stats.gain); 213 } 214 } 215 216 } // namespace 217 } // namespace stochastic 218 } // namespace learner 219 } // namespace boosted_trees 220 } // namespace tensorflow 221