// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h"

#include <algorithm>
#include <cfloat>
#include <cmath>
#include <functional>
#include <limits>
#include <vector>

#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
namespace tensorforest {

using tensorflow::Tensor;

DataColumnTypes FindDenseFeatureSpec(
    int32 input_feature, const tensorforest::TensorForestDataSpec& spec) {
  return static_cast<DataColumnTypes>(spec.GetDenseFeatureType(input_feature));
}

DataColumnTypes FindSparseFeatureSpec(
    int32 input_feature, const tensorforest::TensorForestDataSpec& spec) {
  // TODO(thomaswc): Binary search here, especially when we start using more
  // than one sparse column.
  int32 size_sum = spec.sparse(0).size();
  int32 column_num = 0;
  // Stop before column_num runs off the end of the sparse columns.
  while (input_feature >= size_sum && column_num + 1 < spec.sparse_size()) {
    ++column_num;
    size_sum += spec.sparse(column_num).size();
  }

  return static_cast<DataColumnTypes>(spec.sparse(column_num).original_type());
}
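
// Example for FindSparseFeatureSpec (illustrative): with two sparse columns of
// sizes 4 and 6, input_feature 5 falls past the first column (5 >= 4), so it
// resolves to column 1, which covers features [4, 10).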

void GetTwoBest(int max, const std::function<float(int)>& score_fn,
                float* best_score, int* best_index, float* second_best_score,
                int* second_best_index) {
  *best_index = -1;
  *second_best_index = -1;
  *best_score = FLT_MAX;
  *second_best_score = FLT_MAX;
  for (int i = 0; i < max; i++) {
    float score = score_fn(i);
    if (score < *best_score) {
      *second_best_score = *best_score;
      *second_best_index = *best_index;
      *best_score = score;
      *best_index = i;
    } else if (score < *second_best_score) {
      *second_best_score = score;
      *second_best_index = i;
    }
  }
}
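
// Usage sketch for GetTwoBest (hypothetical scores): lower is better, so with
// scores {0.5, 0.2, 0.9} the call below yields best_i = 1 (score 0.2) and
// second_best_i = 0 (score 0.5).
//
//   const float kScores[] = {0.5f, 0.2f, 0.9f};
//   float best, second_best;
//   int best_i, second_best_i;
//   GetTwoBest(3, [&kScores](int i) { return kScores[i]; },
//              &best, &best_i, &second_best, &second_best_i);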

float ClassificationSplitScore(
    const Eigen::Tensor<float, 1, Eigen::RowMajor>& splits,
    const Eigen::Tensor<float, 1, Eigen::RowMajor>& rights, int32 num_classes,
    int i) {
  Eigen::array<int, 1> offsets;
  // Class counts are stored with the total in [0], so the length of each
  // count vector is num_classes + 1.
  offsets[0] = i * (num_classes + 1) + 1;
  Eigen::array<int, 1> extents;
  extents[0] = num_classes;
  return WeightedGiniImpurity(splits.slice(offsets, extents)) +
         WeightedGiniImpurity(rights.slice(offsets, extents));
}
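
// Layout example (illustrative): with num_classes = 2, the flattened counts
// for split i are [total, class_0, class_1], so its class counts occupy flat
// indices [3 * i + 1, 3 * i + 2] and the slice above starts at 3 * i + 1.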

void GetTwoBestClassification(const Tensor& total_counts,
                              const Tensor& split_counts, int32 accumulator,
                              float* best_score, int* best_index,
                              float* second_best_score,
                              int* second_best_index) {
  const int32 num_splits = static_cast<int32>(split_counts.shape().dim_size(1));
  const int32 num_classes =
      static_cast<int32>(split_counts.shape().dim_size(2)) - 1;

  // Ideally we would use Eigen::Tensor::chip here, but it results in
  // segfaults, so we use flat views of these tensors instead.  This is still
  // fairly efficient because evaluation is deferred until the score is
  // actually computed.
  const auto tc =
      total_counts.Slice(accumulator, accumulator + 1).unaligned_flat<float>();

  // TODO(gilberth): See if we can delay evaluation here by templating the
  // arguments to ClassificationSplitScore.
  const Eigen::Tensor<float, 1, Eigen::RowMajor> splits =
      split_counts.Slice(accumulator, accumulator + 1).unaligned_flat<float>();
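  // Broadcasting tiles the accumulator's total counts across all splits, so
  // each split's right-branch counts come out as simply (total - left).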
  Eigen::array<int, 1> bcast;
  bcast[0] = num_splits;
  const Eigen::Tensor<float, 1, Eigen::RowMajor> rights =
      tc.broadcast(bcast) - splits;

  std::function<float(int)> score_fn =
      std::bind(ClassificationSplitScore, splits, rights, num_classes,
                std::placeholders::_1);

  GetTwoBest(num_splits, score_fn, best_score, best_index, second_best_score,
             second_best_index);
}

int32 BestFeatureClassification(const Tensor& total_counts,
                                const Tensor& split_counts, int32 accumulator) {
  float best_score;
  float second_best_score;
  int best_feature_index;
  int second_best_index;
  GetTwoBestClassification(total_counts, split_counts, accumulator, &best_score,
                           &best_feature_index, &second_best_score,
                           &second_best_index);
  return best_feature_index;
}

float RegressionSplitScore(
    const Eigen::Tensor<float, 3, Eigen::RowMajor>& splits_count_accessor,
    const Eigen::Tensor<float, 2, Eigen::RowMajor>& totals_count_accessor,
    const Eigen::Tensor<float, 1, Eigen::RowMajor>& splits_sum,
    const Eigen::Tensor<float, 1, Eigen::RowMajor>& splits_square,
    const Eigen::Tensor<float, 1, Eigen::RowMajor>& right_sums,
    const Eigen::Tensor<float, 1, Eigen::RowMajor>& right_squares,
    int32 accumulator, int32 num_regression_dims, int i) {
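  // Sums are stored with the example count in position 0 (mirroring the
  // classification layout), so split i's per-output values start at
  // i * num_regression_dims + 1 and span num_regression_dims - 1 entries.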
  Eigen::array<int, 1> offsets = {i * num_regression_dims + 1};
  Eigen::array<int, 1> extents = {num_regression_dims - 1};
  float left_count = splits_count_accessor(accumulator, i, 0);
  float right_count = totals_count_accessor(accumulator, 0) - left_count;

  float score = 0;

  // Guard against divide-by-zero.
  if (left_count > 0) {
    score +=
        WeightedVariance(splits_sum.slice(offsets, extents),
                         splits_square.slice(offsets, extents), left_count);
  }

  if (right_count > 0) {
    score +=
        WeightedVariance(right_sums.slice(offsets, extents),
                         right_squares.slice(offsets, extents), right_count);
  }
  return score;
}

void GetTwoBestRegression(const Tensor& total_sums, const Tensor& total_squares,
                          const Tensor& split_sums, const Tensor& split_squares,
                          int32 accumulator, float* best_score, int* best_index,
                          float* second_best_score, int* second_best_index) {
  const int32 num_splits = static_cast<int32>(split_sums.shape().dim_size(1));
  const int32 num_regression_dims =
      static_cast<int32>(split_sums.shape().dim_size(2));
  // Ideally we would use Eigen::Tensor::chip here, but it results in
  // segfaults, so we use flat views of these tensors instead.  This is still
  // fairly efficient because evaluation is deferred until the score is
  // actually computed.
  const auto tc_sum =
      total_sums.Slice(accumulator, accumulator + 1).unaligned_flat<float>();
  const auto tc_square =
      total_squares.Slice(accumulator, accumulator + 1).unaligned_flat<float>();
  const auto splits_sum =
      split_sums.Slice(accumulator, accumulator + 1).unaligned_flat<float>();
  const auto splits_square =
      split_squares.Slice(accumulator, accumulator + 1).unaligned_flat<float>();
  // Eigen is infuriating to work with, usually resulting in all kinds of
  // unhelpful compiler errors when trying something that seems sane.  These
  // accessors help us do a simple thing like read the first element (the
  // counts) of these tensors so we can calculate the expected value in
  // Variance().
  const auto splits_count_accessor = split_sums.tensor<float, 3>();
  const auto totals_count_accessor = total_sums.tensor<float, 2>();

  Eigen::array<int, 1> bcast;
  bcast[0] = num_splits;
  const auto right_sums = tc_sum.broadcast(bcast) - splits_sum;
  const auto right_squares = tc_square.broadcast(bcast) - splits_square;

  GetTwoBest(num_splits,
             std::bind(RegressionSplitScore, splits_count_accessor,
                       totals_count_accessor, splits_sum, splits_square,
                       right_sums, right_squares, accumulator,
                       num_regression_dims, std::placeholders::_1),
             best_score, best_index, second_best_score, second_best_index);
}

int32 BestFeatureRegression(const Tensor& total_sums,
                            const Tensor& total_squares,
                            const Tensor& split_sums,
                            const Tensor& split_squares, int32 accumulator) {
  float best_score;
  float second_best_score;
  int best_feature_index;
  int second_best_index;
  GetTwoBestRegression(total_sums, total_squares, split_sums, split_squares,
                       accumulator, &best_score, &best_feature_index,
                       &second_best_score, &second_best_index);
  return best_feature_index;
}

bool BestSplitDominatesRegression(const Tensor& total_sums,
                                  const Tensor& total_squares,
                                  const Tensor& split_sums,
                                  const Tensor& split_squares,
                                  int32 accumulator) {
  // TODO(thomaswc): Implement this, probably as part of v3.
  return false;
}

int BootstrapGini(int n, int s, const random::DistributionSampler& ds,
                  random::SimplePhilox* rand) {
  std::vector<int> counts(s, 0);
  for (int i = 0; i < n; i++) {
    int j = ds.Sample(rand);
    counts[j] += 1;
  }
  int g = 0;
  for (int j = 0; j < s; j++) {
    g += counts[j] * counts[j];
  }
  // The true Gini impurity is 1 - g / n^2; since n is fixed across the scores
  // we compare, returning the un-normalized -g preserves the ranking.
  return -g;
}
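
// Worked example for BootstrapGini (illustrative): n = 4 draws over s = 2
// bins landing as counts {3, 1} give g = 3^2 + 1^2 = 10, a true Gini impurity
// of 1 - 10/16 = 0.375, and a return value of -10.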

// Populate *weights with the smoothed per-class frequencies needed to
// initialize a DistributionSampler.  Returns the total number of samples
// seen by this accumulator.
int MakeBootstrapWeights(const Tensor& total_counts, const Tensor& split_counts,
                         int32 accumulator, int index,
                         std::vector<float>* weights) {
  const int32 num_classes =
      static_cast<int32>(split_counts.shape().dim_size(2)) - 1;

  auto tc = total_counts.tensor<float, 2>();
  auto lc = split_counts.tensor<float, 3>();

  int n = tc(accumulator, 0);

  float denom = static_cast<float>(n) + static_cast<float>(num_classes);

  weights->resize(num_classes * 2);
  for (int i = 0; i < num_classes; i++) {
    // Use the Laplace smoothed per-class probabilities when generating the
    // bootstrap samples.
    float left_count = lc(accumulator, index, i + 1);
    (*weights)[i] = (left_count + 1.0) / denom;
    float right_count = tc(accumulator, i + 1) - left_count;
    (*weights)[num_classes + i] = (right_count + 1.0) / denom;
  }

  return n;
}
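
// Worked example for MakeBootstrapWeights (illustrative): with n = 8,
// num_classes = 2, left counts {3, 1}, and class totals {5, 3}, denom = 10,
// the right counts are {2, 2}, and the weights come out as
// {0.4, 0.2, 0.3, 0.3}.  (The weights need not sum to 1, assuming the
// DistributionSampler normalizes internally.)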

bool BestSplitDominatesClassificationBootstrap(const Tensor& total_counts,
                                               const Tensor& split_counts,
                                               int32 accumulator,
                                               float dominate_fraction,
                                               random::SimplePhilox* rand) {
  float best_score;
  float second_best_score;
  int best_feature_index;
  int second_best_index;
  GetTwoBestClassification(total_counts, split_counts, accumulator, &best_score,
                           &best_feature_index, &second_best_score,
                           &second_best_index);

  std::vector<float> weights1;
  int n1 = MakeBootstrapWeights(total_counts, split_counts, accumulator,
                                best_feature_index, &weights1);
  random::DistributionSampler ds1(weights1);

  std::vector<float> weights2;
  int n2 = MakeBootstrapWeights(total_counts, split_counts, accumulator,
                                second_best_index, &weights2);
  random::DistributionSampler ds2(weights2);

  const int32 num_classes =
      static_cast<int32>(split_counts.shape().dim_size(2)) - 1;

  float p = 1.0 - dominate_fraction;
  if (p <= 0 || p > 1.0) {
    LOG(FATAL) << "Invalid dominate fraction " << dominate_fraction;
  }

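  // Doubling p until it reaches 1 makes bootstrap_samples roughly
  // 1 + ceil(log2(1 / (1 - dominate_fraction))): the stricter the required
  // dominance, the more bootstrap replicates we draw.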
  int bootstrap_samples = 1;
  while (p < 1.0) {
    bootstrap_samples += 1;
    p = p * 2;
  }

  // BootstrapGini returns -g, which is negative whenever n1 > 0, so start the
  // running maximum below any attainable value.
  int worst_g1 = std::numeric_limits<int>::min();
  for (int i = 0; i < bootstrap_samples; i++) {
    int g1 = BootstrapGini(n1, 2 * num_classes, ds1, rand);
    worst_g1 = std::max(worst_g1, g1);
  }

  // Likewise, start the running minimum above any attainable value.
  int best_g2 = std::numeric_limits<int>::max();
  for (int i = 0; i < bootstrap_samples; i++) {
    int g2 = BootstrapGini(n2, 2 * num_classes, ds2, rand);
    best_g2 = std::min(best_g2, g2);
  }

  return worst_g1 < best_g2;
}

bool BestSplitDominatesClassificationHoeffding(const Tensor& total_counts,
                                               const Tensor& split_counts,
                                               int32 accumulator,
                                               float dominate_fraction) {
  float best_score;
  float second_best_score;
  int best_feature_index;
  int second_best_index;
  VLOG(1) << "BSDC for accumulator " << accumulator;
  GetTwoBestClassification(total_counts, split_counts, accumulator, &best_score,
                           &best_feature_index, &second_best_score,
                           &second_best_index);
  VLOG(1) << "Best score = " << best_score;
  VLOG(1) << "2nd best score = " << second_best_score;

  const int32 num_classes =
      static_cast<int32>(split_counts.shape().dim_size(2)) - 1;
  const float n = total_counts.Slice(accumulator, accumulator + 1)
                      .unaligned_flat<float>()(0);

  // Each per-class term p * (1 - p) in the Gini impurity is at most
  // 0.5 * 0.5, so the n-weighted score ranges over an interval of width
  // 0.25 * num_classes * n.
  float range = 0.25 * static_cast<float>(num_classes) * n;

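  // Hoeffding's inequality: an empirical mean of n samples of a quantity
  // confined to an interval of width `range` exceeds its true mean by more
  // than range * sqrt(ln(1 / delta) / (2 * n)) with probability at most
  // delta; here delta = 1 - dominate_fraction.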
  float hoeffding_bound =
      range * sqrt(log(1.0 / (1.0 - dominate_fraction)) / (2.0 * n));

  VLOG(1) << "num_classes = " << num_classes;
  VLOG(1) << "n = " << n;
  VLOG(1) << "range = " << range;
  VLOG(1) << "hoeffding_bound = " << hoeffding_bound;
  return (second_best_score - best_score) > hoeffding_bound;
}

double DirichletCovarianceTrace(const Tensor& total_counts,
                                const Tensor& split_counts, int32 accumulator,
                                int index) {
  const int32 num_classes =
      static_cast<int32>(split_counts.shape().dim_size(2)) - 1;

  auto tc = total_counts.tensor<float, 2>();
  auto lc = split_counts.tensor<float, 3>();

  double leftc = 0.0;
  double leftc2 = 0.0;
  double rightc = 0.0;
  double rightc2 = 0.0;
  for (int i = 1; i <= num_classes; i++) {
    double l = lc(accumulator, index, i) + 1.0;
    leftc += l;
    leftc2 += l * l;

    double r = tc(accumulator, i) - lc(accumulator, index, i) + 1.0;
    rightc += r;
    rightc2 += r * r;
  }

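  // For a Dirichlet(alpha) posterior, the covariance trace has the closed
  // form (1 - sum_i alpha_i^2 / alpha_0^2) / (alpha_0 + 1), where
  // alpha_0 = sum_i alpha_i; here alpha_i = count_i + 1 (Laplace smoothing),
  // computed separately for the left and right branches.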
  double left_trace = (1.0 - leftc2 / (leftc * leftc)) / (leftc + 1.0);
  double right_trace = (1.0 - rightc2 / (rightc * rightc)) / (rightc + 1.0);
  return left_trace + right_trace;
}

void getDirichletMean(const Tensor& total_counts, const Tensor& split_counts,
                      int32 accumulator, int index, std::vector<float>* mu) {
  const int32 num_classes =
      static_cast<int32>(split_counts.shape().dim_size(2)) - 1;

  mu->resize(num_classes * 2);
  auto tc = total_counts.tensor<float, 2>();
  auto lc = split_counts.tensor<float, 3>();

  double total = tc(accumulator, 0);

  for (int i = 0; i < num_classes; i++) {
    double l = lc(accumulator, index, i + 1);
    mu->at(i) = (l + 1.0) / (total + num_classes);

    // Class counts start at index 1 (index 0 holds the total), so the
    // right-branch count for class i is tc(accumulator, i + 1) - l.
    double r = tc(accumulator, i + 1) - l;
    mu->at(i + num_classes) = (r + 1.0) / (total + num_classes);
  }
}

// Given lambda3, returns the squared distance from (mu1, mu2) to the
// corresponding point (x, y) on the surface.
double getDistanceFromLambda3(double lambda3, const std::vector<float>& mu1,
                              const std::vector<float>& mu2) {
  // The expressions for x and y below divide by (2 - 2 lambda_3) and
  // (2 + 2 lambda_3), so lambda3 = +/-1 would divide by zero.
  if (fabs(lambda3) == 1.0) {
    return 0.0;
  }

  int n = mu1.size();
  double lambda1 = -2.0 * lambda3 / n;
  double lambda2 = 2.0 * lambda3 / n;
  // From the derivation in getChebyshevEpsilon below,
  //   x = (lambda_1 1 + 2 mu1) / (2 - 2 lambda_3)
  //   y = (lambda_2 1 + 2 mu2) / (2 + 2 lambda_3)
  double dist = 0.0;
  for (size_t i = 0; i < mu1.size(); i++) {
    double diff = (lambda1 + 2.0 * mu1[i]) / (2.0 - 2.0 * lambda3) - mu1[i];
    dist += diff * diff;
    diff = (lambda2 + 2.0 * mu2[i]) / (2.0 + 2.0 * lambda3) - mu2[i];
    dist += diff * diff;
  }
  return dist;
}

// Returns the squared distance between (mu1, mu2) and (x, y), where (x, y) is
// the nearest point that lies on the surface defined by
// {x dot 1 = 1, y dot 1 = 1, x dot x - y dot y = 0}.
double getChebyshevEpsilon(const std::vector<float>& mu1,
                           const std::vector<float>& mu2) {
  // Math time!!
  // We are trying to minimize d = |mu1 - x|^2 + |mu2 - y|^2 over the surface.
  // Using Lagrange multipliers, we get
  //   partial d / partial x = -2 mu1 + 2 x = lambda_1 1 + 2 lambda_3 x
  //   partial d / partial y = -2 mu2 + 2 y = lambda_2 1 - 2 lambda_3 y
  // or
  //   x = (lambda_1 1 + 2 mu1) / (2 - 2 lambda_3)
  //   y = (lambda_2 1 + 2 mu2) / (2 + 2 lambda_3)
  // which implies
  //   2 - 2 lambda_3 = lambda_1 1 dot 1 + 2 mu1 dot 1
  //   2 + 2 lambda_3 = lambda_2 1 dot 1 + 2 mu2 dot 1
  //   |lambda_1 1 + 2 mu1|^2 (2 + 2 lambda_3)^2 =
  //     |lambda_2 1 + 2 mu2|^2 (2 - 2 lambda_3)^2
  // So solving for the lambda's and using the fact that
  // mu1 dot 1 = 1 and mu2 dot 1 = 1,
  //   lambda_1 = -2 lambda_3 / (1 dot 1)
  //   lambda_2 = 2 lambda_3 / (1 dot 1)
  // and (letting n = 1 dot 1)
  //   | - lambda_3 1 + n mu1 |^2 (1 + lambda_3)^2 =
  //   | lambda_3 1 + n mu2 |^2 (1 - lambda_3)^2
  // or
  // (lambda_3^2 n - 2 n lambda_3 + n^2 mu1 dot mu1)(1 + lambda_3)^2 =
  // (lambda_3^2 n + 2 n lambda_3 + n^2 mu2 dot mu2)(1 - lambda_3)^2
  // or
  // (lambda_3^2 - 2 lambda_3 + n mu1 dot mu1)(1 + 2 lambda_3 + lambda_3^2) =
  // (lambda_3^2 + 2 lambda_3 + n mu2 dot mu2)(1 - 2 lambda_3 + lambda_3^2)
  // or
  // lambda_3^2 - 2 lambda_3 + n mu1 dot mu1
  // + 2 lambda_3^3 - 2 lambda_3^2 + 2n lambda_3 mu1 dot mu1
  // + lambda_3^4 - 2 lambda_3^3 + n lambda_3^2 mu1 dot mu1
  // =
  // lambda_3^2 + 2 lambda_3 + n mu2 dot mu2
  // - 2 lambda_3^3 - 4 lambda_3^2 - 2n lambda_3 mu2 dot mu2
  // + lambda_3^4 + 2 lambda_3^3 + n lambda_3^2 mu2 dot mu2
  // or
  // - 2 lambda_3 + n mu1 dot mu1
  // - 2 lambda_3^2 + 2n lambda_3 mu1 dot mu1
  // + n lambda_3^2 mu1 dot mu1
  // =
  // + 2 lambda_3 + n mu2 dot mu2
  // - 4 lambda_3^2 - 2n lambda_3 mu2 dot mu2
  // + n lambda_3^2 mu2 dot mu2
  // or
  // lambda_3^2 (2 + n mu1 dot mu1 + n mu2 dot mu2)
  // + lambda_3 (2n mu1 dot mu1 + 2n mu2 dot mu2 - 4)
  // + n mu1 dot mu1 - n mu2 dot mu2 = 0
  // which can be solved using the quadratic formula.
  int n = mu1.size();
  double len1 = 0.0;
  for (float m : mu1) {
    len1 += m * m;
  }
  double len2 = 0.0;
  for (float m : mu2) {
    len2 += m * m;
  }
  double a = 2 + n * (len1 + len2);
  double b = 2 * n * (len1 + len2) - 4;
  double c = n * (len1 - len2);
  double discrim = b * b - 4 * a * c;
  if (discrim < 0.0) {
    LOG(WARNING) << "Negative discriminant " << discrim;
    return 0.0;
  }

  double sdiscrim = sqrt(discrim);
  // TODO(thomaswc): Analyze whether one of these is always closer.
  double v1 = (-b + sdiscrim) / (2 * a);
  double v2 = (-b - sdiscrim) / (2 * a);
  double dist1 = getDistanceFromLambda3(v1, mu1, mu2);
  double dist2 = getDistanceFromLambda3(v2, mu1, mu2);
  return std::min(dist1, dist2);
}

bool BestSplitDominatesClassificationChebyshev(const Tensor& total_counts,
                                               const Tensor& split_counts,
                                               int32 accumulator,
                                               float dominate_fraction) {
  float best_score;
  float second_best_score;
  int best_feature_index;
  int second_best_index;
  VLOG(1) << "BSDC for accumulator " << accumulator;
  GetTwoBestClassification(total_counts, split_counts, accumulator, &best_score,
                           &best_feature_index, &second_best_score,
                           &second_best_index);
  VLOG(1) << "Best score = " << best_score;
  VLOG(1) << "2nd best score = " << second_best_score;

  const int32 num_classes =
      static_cast<int32>(split_counts.shape().dim_size(2)) - 1;
  const float n = total_counts.Slice(accumulator, accumulator + 1)
                      .unaligned_flat<float>()(0);

  VLOG(1) << "num_classes = " << num_classes;
  VLOG(1) << "n = " << n;
  double trace = DirichletCovarianceTrace(total_counts, split_counts,
                                          accumulator, best_feature_index) +
                 DirichletCovarianceTrace(total_counts, split_counts,
                                          accumulator, second_best_index);

  std::vector<float> mu1;
  getDirichletMean(total_counts, split_counts, accumulator, best_feature_index,
                   &mu1);
  std::vector<float> mu2;
  getDirichletMean(total_counts, split_counts, accumulator, second_best_index,
                   &mu2);
  double epsilon = getChebyshevEpsilon(mu1, mu2);

  if (epsilon == 0.0) {
    return false;
  }

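  // Chebyshev-style bound: P(|X - mu| >= epsilon) <= trace(Cov(X)) / epsilon^2,
  // so at least (1 - trace / epsilon^2) of the posterior mass lies within
  // epsilon of the mean.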
  double dirichlet_bound = 1.0 - trace / (epsilon * epsilon);
  return dirichlet_bound > dominate_fraction;
}

GetFeatureFnType GetDenseFunctor(const Tensor& dense) {
  if (dense.shape().dims() == 2) {
    const auto dense_features = dense.matrix<float>();
    // Here we capture by value, which shouldn't incur a copy of the data
    // because of the underlying use of Eigen::TensorMap.
    return [dense_features](int32 i, int32 feature) {
      return dense_features(i, feature);
    };
  } else {
    return [](int32 i, int32 feature) {
      LOG(ERROR) << "Trying to access nonexistent dense features.";
      return 0.0f;
    };
  }
}
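
// Usage sketch (hypothetical tensor name): given a [num_examples, num_features]
// float Tensor `dense_data`,
//
//   GetFeatureFnType get_feature = GetDenseFunctor(dense_data);
//   float v = get_feature(row, col);  // == dense_data.matrix<float>()(row, col)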

GetFeatureFnType GetSparseFunctor(const Tensor& sparse_indices,
                                  const Tensor& sparse_values) {
  if (sparse_indices.shape().dims() == 2) {
    const auto indices = sparse_indices.matrix<int64>();
    const auto values = sparse_values.vec<float>();
    // Here we capture by value, which shouldn't incur a copy of the data
    // because of the underlying use of Eigen::TensorMap.
    return [indices, values](int32 i, int32 feature) {
      return tensorforest::FindSparseValue(indices, values, i, feature);
    };
  } else {
    return [](int32 i, int32 feature) {
      LOG(ERROR) << "Trying to access nonexistent sparse features.";
      return 0.0f;
    };
  }
}

bool DecideNode(const GetFeatureFnType& get_dense,
                const GetFeatureFnType& get_sparse, int32 i, int32 feature,
                float bias, const tensorforest::TensorForestDataSpec& spec) {
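  // Feature indices are laid out dense-first: [0, dense_features_size()) are
  // dense features, and sparse feature j arrives as dense_features_size() + j.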
  if (feature < spec.dense_features_size()) {
    return Decide(get_dense(i, feature), bias,
                  FindDenseFeatureSpec(feature, spec));
  } else {
    const int32 sparse_feature = feature - spec.dense_features_size();
    return Decide(get_sparse(i, sparse_feature), bias,
                  FindSparseFeatureSpec(sparse_feature, spec));
  }
}

bool Decide(float value, float bias, DataColumnTypes type) {
  switch (type) {
    case kDataFloat:
      return value >= bias;

    case kDataCategorical:
      // We arbitrarily define categorical equality as going left.
      return value != bias;

    default:
      LOG(ERROR) << "Got unknown column type: " << type;
      return false;
  }
}
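
// Example for Decide (illustrative): Decide(3.2f, 1.5f, kDataFloat) is true
// since 3.2 >= 1.5, while Decide(2.0f, 2.0f, kDataCategorical) is false,
// sending an exact categorical match down the left branch.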

void GetParentWeightedMean(float leaf_sum, const float* leaf_data,
                           float parent_sum, const float* parent_data,
                           float valid_leaf_threshold, int num_outputs,
                           std::vector<float>* mean) {
  float parent_weight = 0.0;
  if (leaf_sum < valid_leaf_threshold && parent_sum > 0) {
    VLOG(1) << "Not enough samples at leaf, including parent counts. "
            << "Child sum = " << leaf_sum;
    // Weight the parent's counts just enough so that the new sum is
    // valid_leaf_threshold, but never give any counts a weight of
    // more than 1.
    parent_weight =
        std::min(1.0f, (valid_leaf_threshold - leaf_sum) / parent_sum);
    leaf_sum += parent_weight * parent_sum;
    VLOG(1) << "Sum w/ parent included = " << leaf_sum;
  }

  for (int c = 0; c < num_outputs; c++) {
    float w = leaf_data[c];
    if (parent_weight > 0.0) {
      w += parent_weight * parent_data[c];
    }
    (*mean)[c] = w / leaf_sum;
  }
}
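
// Worked example for GetParentWeightedMean (illustrative): with leaf_sum = 2,
// valid_leaf_threshold = 5, and parent_sum = 10, the parent weight is
// min(1, (5 - 2) / 10) = 0.3 and the blended sum becomes 2 + 0.3 * 10 = 5.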

}  // namespace tensorforest
}  // namespace tensorflow