// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_
#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/Eigen/Eigenvalues"
#include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/gradient-stats.h"
#include "tensorflow/contrib/boosted_trees/proto/learner.pb.h"
#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"

namespace tensorflow {
namespace boosted_trees {
namespace learner {
namespace stochastic {

using tensorflow::boosted_trees::learner::LearnerConfig;
using tensorflow::boosted_trees::learner::LearnerConfig_MultiClassStrategy;
using tensorflow::boosted_trees::learner::
    LearnerConfig_MultiClassStrategy_DIAGONAL_HESSIAN;
using tensorflow::boosted_trees::learner::
    LearnerConfig_MultiClassStrategy_FULL_HESSIAN;
using tensorflow::boosted_trees::learner::
    LearnerConfig_MultiClassStrategy_TREE_PER_CLASS;

// NodeStats holds aggregate gradient stats as well as metadata about the node.
struct NodeStats {
  // Initializes the NodeStats with zero stats. The output length determines
  // how many entries weight_contribution holds (one per output dimension).
  explicit NodeStats(const int output_length)
      : weight_contribution(output_length, 0.0f), gain(0) {}
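  // For example, NodeStats(2) represents a two-output node whose
  // contributions default to {0.0f, 0.0f} with zero gain.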

  NodeStats(const LearnerConfig& learner_config,
            const GradientStats& grad_stats)
      : NodeStats(learner_config.regularization().l1(),
                  learner_config.regularization().l2(),
                  learner_config.constraints().min_node_weight(),
                  learner_config.multi_class_strategy(), grad_stats) {}

  NodeStats(float l1_reg, float l2_reg, float min_node_weight,
            const LearnerConfig_MultiClassStrategy& strategy,
            const GradientStats& grad_stats)
      : gradient_stats(grad_stats), gain(0) {
    switch (strategy) {
      case LearnerConfig_MultiClassStrategy_TREE_PER_CLASS: {
        float g;
        float h;
        // Initialize now in case of early return.
        weight_contribution.push_back(0.0f);

        if (grad_stats.first.t.NumElements() == 0 ||
            grad_stats.second.t.NumElements() == 0) {
          return;
        }

        g = grad_stats.first.t.unaligned_flat<float>()(0);
        h = grad_stats.second.t.unaligned_flat<float>()(0);

        if (grad_stats.IsAlmostZero() || h <= min_node_weight) {
          return;
        }

        // Apply L1 regularization via soft thresholding: shrink the gradient
        // toward zero by l1_reg. If |g| <= l1_reg, the optimal weight is
        // exactly zero, so keep the default zero contribution.
        if (l1_reg > 0) {
          if (g > l1_reg) {
            g -= l1_reg;
          } else if (g < -l1_reg) {
            g += l1_reg;
          } else {
            return;
          }
        }
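        // For instance (hypothetical numbers): g = 1.5 with l1_reg = 1.0
        // shrinks g to 0.5, while g = 0.8 would fall inside the dead zone
        // and the node would keep a zero contribution.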

        // The node gain is (l')^2 / (l'' + l2_reg) and the node weight
        // contribution is -l' / (l'' + l2_reg). Note that l'' can't be zero
        // here because of the min node weight check, since min_node_weight
        // must be >= 0.
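        // Worked example with hypothetical values: g = 2, h = 3, l2_reg = 1
        // gives weight_contribution[0] = -2 / 4 = -0.5 and
        // gain = -0.5 * -2 = 1.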
        weight_contribution[0] = -g / (h + l2_reg);
        gain = (weight_contribution[0] * -g);
        break;
      }
      case LearnerConfig_MultiClassStrategy_FULL_HESSIAN: {
        weight_contribution.clear();

        if (grad_stats.first.t.NumElements() == 0 ||
            grad_stats.second.t.NumElements() == 0) {
          return;
        }

        QCHECK(grad_stats.first.t.dims() == 2)
            << strings::Printf("Gradient should be of rank 2, got rank %d",
                               grad_stats.first.t.dims());
        // Read the output dimension only after the rank check above.
        const int64 grad_dim = grad_stats.first.t.dim_size(1);
        QCHECK(grad_stats.first.t.dim_size(0) == 1) << strings::Printf(
            "Gradient must be of shape 1 x %lld, got %lld x %lld", grad_dim,
            grad_stats.first.t.dim_size(0), grad_dim);
        QCHECK(grad_stats.second.t.dims() == 3)
            << strings::Printf("Hessian should be of rank 3, got rank %d",
                               grad_stats.second.t.dims());
        QCHECK(grad_stats.second.t.shape() ==
               TensorShape({1, grad_dim, grad_dim}))
            << strings::Printf(
                   "Hessian must be of shape 1 x %lld x %lld, got %lld x "
                   "%lld x %lld",
                   grad_dim, grad_dim, grad_stats.second.t.shape().dim_size(0),
                   grad_stats.second.t.shape().dim_size(1),
                   grad_stats.second.t.shape().dim_size(2));

        // Check if we're violating the min weight constraint.
        if (grad_stats.IsAlmostZero() ||
            grad_stats.second.Magnitude() <= min_node_weight) {
          return;
        }
        // TODO(nponomareva): figure out l1 in matrix form.
        // g is the gradient vector; H is the Hessian matrix.
        Eigen::VectorXf g = TensorToEigenVector(grad_stats.first.t, grad_dim);
        Eigen::MatrixXf hessian =
            TensorToEigenMatrix(grad_stats.second.t, grad_dim, grad_dim);
        // I is the identity matrix.
        // The gain in general form is g^T (H + l2 I)^-1 g.
        // The node weights are -(H + l2 I)^-1 g.
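        // Worked example with hypothetical values: for g = [1, 2]^T,
        // H = [[3, 0], [0, 4]] and l2 = 1, (H + l2 I) = diag(4, 5), so the
        // weights are [-1/4, -2/5] and the gain is 1/4 + 4/5 = 1.05.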
        Eigen::MatrixXf identity;
        identity.setIdentity(grad_dim, grad_dim);

        Eigen::MatrixXf hessian_and_reg = hessian + l2_reg * identity;

        CalculateWeightAndGain(hessian_and_reg, g);
        break;
      }
      case LearnerConfig_MultiClassStrategy_DIAGONAL_HESSIAN: {
        weight_contribution.clear();
        if (grad_stats.first.t.NumElements() == 0 ||
            grad_stats.second.t.NumElements() == 0) {
          return;
        }

        QCHECK(grad_stats.first.t.dims() == 2)
            << strings::Printf("Gradient should be of rank 2, got rank %d",
                               grad_stats.first.t.dims());
        // Read the output dimension only after the rank check above.
        const int64 grad_dim = grad_stats.first.t.dim_size(1);
        QCHECK(grad_stats.first.t.dim_size(0) == 1) << strings::Printf(
            "Gradient must be of shape 1 x %lld, got %lld x %lld", grad_dim,
            grad_stats.first.t.dim_size(0), grad_dim);
        QCHECK(grad_stats.second.t.dims() == 2)
            << strings::Printf("Hessian should be of rank 2, got rank %d",
                               grad_stats.second.t.dims());
        QCHECK(grad_stats.second.t.shape() == TensorShape({1, grad_dim}))
            << strings::Printf(
                   "Hessian must be of shape 1 x %lld, got %lld x %lld",
                   grad_dim, grad_stats.second.t.shape().dim_size(0),
                   grad_stats.second.t.shape().dim_size(1));

        // Check if we're violating the min weight constraint.
        if (grad_stats.IsAlmostZero() ||
            grad_stats.second.Magnitude() <= min_node_weight) {
          return;
        }
        // TODO(nponomareva): figure out l1 in matrix form.
        // Diagonal of the Hessian.
        Eigen::ArrayXf hessian =
            TensorToEigenArray(grad_stats.second.t, grad_dim);
        Eigen::ArrayXf hessian_and_reg = hessian + l2_reg;

        // Check if any of the regularized diagonal elements are zero.
        bool invertible = true;
        for (int i = 0; i < hessian_and_reg.size(); ++i) {
          if (hessian_and_reg[i] == 0.0) {
            invertible = false;
            break;
          }
        }
        if (invertible) {
          Eigen::ArrayXf g = TensorToEigenArray(grad_stats.first.t, grad_dim);
          // Array operations are element-wise. These are the full-Hessian
          // formulas specialized to a diagonal Hessian: weight_i =
          // -g_i / (h_i + l2_reg) and gain = sum_i g_i^2 / (h_i + l2_reg).
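          // With the same hypothetical numbers as the full-Hessian example
          // above (g = [1, 2], diagonal h = [3, 4], l2_reg = 1), the weights
          // are [-1/4, -2/5] and the gain is 1.05.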
          Eigen::ArrayXf ones = Eigen::ArrayXf::Ones(grad_dim);
          Eigen::ArrayXf temp = ones / hessian_and_reg;
          Eigen::ArrayXf weight = -temp * g;

          // Copy over weights to weight_contribution.
          weight_contribution =
              std::vector<float>(weight.data(), weight.data() + weight.rows());
          gain = (-g * weight).sum();
        } else {
          // The Hessian is not invertible, so take the same route as the
          // full-Hessian case to get an approximate solution.
          Eigen::VectorXf g = TensorToEigenVector(grad_stats.first.t, grad_dim);
          CalculateWeightAndGain(hessian_and_reg.matrix().asDiagonal(), g);
        }
        break;
      }
      default:
        LOG(FATAL) << "Unknown multi-class strategy " << strategy;
        break;
    }
  }

  string DebugString() const {
    return strings::StrCat(
        gradient_stats.DebugString(), "\n",
        "Weight_contrib = ", str_util::Join(weight_contribution, ","), "\n",
        "Gain = ", gain);
  }

  // Use these node stats to populate a Leaf's model. A class_id of -1 fills
  // the dense vector with one value per output; otherwise a single sparse
  // entry is written for the given class.
  void FillLeaf(const int class_id, boosted_trees::trees::Leaf* leaf) const {
    if (class_id == -1) {
      for (int i = 0; i < weight_contribution.size(); i++) {
        leaf->mutable_vector()->add_value(weight_contribution[i]);
      }
    } else {
      CHECK(weight_contribution.size() == 1)
          << "Weight contribution size = " << weight_contribution.size();
      leaf->mutable_sparse_vector()->add_index(class_id);
      leaf->mutable_sparse_vector()->add_value(weight_contribution[0]);
    }
  }

  // Sets the weight_contribution and gain member variables based on the
  // given regularized Hessian and gradient vector g.
  void CalculateWeightAndGain(const Eigen::MatrixXf& hessian_and_reg,
                              const Eigen::VectorXf& g) {
    // The gain in general form is g^T (hessian_and_reg)^-1 g.
    // The node weights are -(hessian_and_reg)^-1 g.
    Eigen::VectorXf weight;
    // To compute x = K^-1 v, solve the linear system K x = v directly
    // instead of explicitly forming K^-1 and multiplying by v; this is
    // cheaper and more numerically stable.
    weight = -hessian_and_reg.colPivHouseholderQr().solve(g);
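    // Sanity check with hypothetical numbers: for hessian_and_reg =
    // [[2, 0], [0, 4]] and g = [2, 4]^T, solve() returns [1, 1]^T, so
    // weight = [-1, -1]^T and the gain below is -g^T * weight = 6.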
    // Copy over weights to weight_contribution.
    weight_contribution =
        std::vector<float>(weight.data(), weight.data() + weight.rows());

    gain = -g.transpose() * weight;
  }

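  // The helpers below view the tensor's flat buffer through Eigen::Map; the
  // data is copied only when the Map converts to the returned Eigen type.
  // Eigen defaults to column-major storage while Tensor data is row-major,
  // which is harmless here: the mapped Hessians are symmetric and the other
  // mappings are one-dimensional.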
  static Eigen::MatrixXf TensorToEigenMatrix(const Tensor& tensor,
                                             const int num_rows,
                                             const int num_cols) {
    return Eigen::Map<const Eigen::MatrixXf>(tensor.flat<float>().data(),
                                             num_rows, num_cols);
  }

  static Eigen::VectorXf TensorToEigenVector(const Tensor& tensor,
                                             const int num_elements) {
    return Eigen::Map<const Eigen::VectorXf>(tensor.flat<float>().data(),
                                             num_elements);
  }

  static Eigen::ArrayXf TensorToEigenArray(const Tensor& tensor,
                                           const int num_elements) {
    return Eigen::Map<const Eigen::ArrayXf>(tensor.flat<float>().data(),
                                            num_elements);
  }

  GradientStats gradient_stats;
  std::vector<float> weight_contribution;
  float gain;
};

// Helper macro to check std::vector<float> approximate equality.
#define EXPECT_VECTOR_FLOAT_EQ(x, y)       \
  {                                        \
    EXPECT_EQ((x).size(), (y).size());     \
    for (int i = 0; i < (x).size(); ++i) { \
      EXPECT_FLOAT_EQ((x)[i], (y)[i]);     \
    }                                      \
  }

// Helper macro to check node stats approximate equality.
#define EXPECT_NODE_STATS_EQ(val1, val2)                                      \
  EXPECT_GRADIENT_STATS_EQ(val1.gradient_stats, val2.gradient_stats);         \
  EXPECT_VECTOR_FLOAT_EQ(val1.weight_contribution, val2.weight_contribution); \
  EXPECT_FLOAT_EQ(val1.gain, val2.gain);
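// Example use inside a test body (hypothetical):
//   NodeStats expected(2);
//   NodeStats actual(2);
//   EXPECT_NODE_STATS_EQ(expected, actual);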

}  // namespace stochastic
}  // namespace learner
}  // namespace boosted_trees
}  // namespace tensorflow

#endif  // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_