// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h"
#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
#include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h"
#include "tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h"

namespace tensorflow {
namespace boosted_trees {
namespace models {

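// A minimal calling sketch for reference (illustrative only; the ensemble,
// feature, and thread-pool variables below are assumed to be constructed and
// populated elsewhere, and output_predictions must be pre-sized to
// [batch_size, logits_dimension]):
//
//   std::vector<int32> trees_to_include(ensemble_config.trees_size());
//   std::iota(trees_to_include.begin(), trees_to_include.end(), 0);
//   MultipleAdditiveTrees::Predict(ensemble_config, trees_to_include,
//                                  batch_features, worker_threads,
//                                  output_predictions);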
void MultipleAdditiveTrees::Predict(
    const boosted_trees::trees::DecisionTreeEnsembleConfig& config,
    const std::vector<int32>& trees_to_include,
    const boosted_trees::utils::BatchFeatures& features,
    tensorflow::thread::ThreadPool* const worker_threads,
    tensorflow::TTypes<float>::Matrix output_predictions) {
  // Zero out predictions as the model is additive.
  output_predictions.setZero();

  // Get the batch size and exit early if there is nothing to predict.
  const int64 batch_size = features.batch_size();
  if (batch_size <= 0) {
    return;
  }

  // Lambda for doing a block of work.
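  // For each example in [start, end), every tree listed in trees_to_include
  // is traversed to a leaf, and the leaf values, scaled by the corresponding
  // tree weight, are accumulated into that example's row of
  // output_predictions.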
  auto update_predictions = [&config, &features, &trees_to_include,
                             &output_predictions](int64 start, int64 end) {
    auto examples_iterable = features.examples_iterable(start, end);
    for (const auto& example : examples_iterable) {
      for (const int32 tree_idx : trees_to_include) {
        const boosted_trees::trees::DecisionTreeConfig& tree =
            config.trees(tree_idx);
        const float tree_weight = config.tree_weights(tree_idx);
        // Route the example from the root (node 0) down to a leaf.
        const int leaf_idx = trees::DecisionTree::Traverse(tree, 0, example);
        QCHECK(leaf_idx >= 0) << "Invalid tree: " << tree.DebugString();
        const auto& leaf_node = tree.nodes(leaf_idx);
        QCHECK(leaf_node.has_leaf())
            << "Invalid leaf node: " << leaf_node.DebugString();
        if (leaf_node.leaf().has_sparse_vector()) {
          // Sparse leaf: only the listed logit dimensions are updated.
          const auto& leaf = leaf_node.leaf().sparse_vector();
          QCHECK_EQ(leaf.index_size(), leaf.value_size());
          for (int logit_dim = 0; logit_dim < leaf.index_size(); ++logit_dim) {
            const float value = tree_weight * leaf.value(logit_dim);
            output_predictions(example.example_idx, leaf.index(logit_dim)) +=
                value;
          }
        } else {
          // Dense leaf: every logit dimension is updated.
          QCHECK(leaf_node.leaf().has_vector()) << "Unknown leaf type";
          const auto& leaf = leaf_node.leaf().vector();
          for (int i = 0; i < leaf.value_size(); ++i) {
            const float value = tree_weight * leaf.value(i);
            output_predictions(example.example_idx, i) += value;
          }
        }
      }
    }
  };

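  // Shard the examples of the batch across the worker threads; each shard
  // invokes update_predictions on its own [start, end) range.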
  // TODO(salehay): parallelize this for low latency in serving path where
  // batch size tends to be small but ensemble size tends to be large.
  boosted_trees::utils::ParallelFor(batch_size, worker_threads->NumThreads(),
                                    worker_threads, update_predictions);
}

}  // namespace models
}  // namespace boosted_trees
}  // namespace tensorflow