1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // ============================================================================= 15 #include "tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h" 16 #include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h" 17 #include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h" 18 #include "tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h" 19 20 namespace tensorflow { 21 namespace boosted_trees { 22 namespace models { 23 24 void MultipleAdditiveTrees::Predict( 25 const boosted_trees::trees::DecisionTreeEnsembleConfig& config, 26 const std::vector<int32>& trees_to_include, 27 const boosted_trees::utils::BatchFeatures& features, 28 tensorflow::thread::ThreadPool* const worker_threads, 29 tensorflow::TTypes<float>::Matrix output_predictions) { 30 // Zero out predictions as the model is additive. 31 output_predictions.setZero(); 32 33 // Get batch size. 34 const int64 batch_size = features.batch_size(); 35 if (batch_size <= 0) { 36 return; 37 } 38 39 // Lambda for doing a block of work. 40 auto update_predictions = [&config, &features, &trees_to_include, 41 &output_predictions](int64 start, int64 end) { 42 auto examples_iterable = features.examples_iterable(start, end); 43 for (const auto& example : examples_iterable) { 44 for (const int32 tree_idx : trees_to_include) { 45 const boosted_trees::trees::DecisionTreeConfig& tree = 46 config.trees(tree_idx); 47 const float tree_weight = config.tree_weights(tree_idx); 48 const int leaf_idx = trees::DecisionTree::Traverse(tree, 0, example); 49 QCHECK(leaf_idx >= 0) << "Invalid tree: " << tree.DebugString(); 50 const auto& leaf_node = tree.nodes(leaf_idx); 51 QCHECK(leaf_node.has_leaf()) 52 << "Invalid leaf node: " << leaf_node.DebugString(); 53 if (leaf_node.leaf().has_sparse_vector()) { 54 const auto& leaf = leaf_node.leaf().sparse_vector(); 55 QCHECK_EQ(leaf.index_size(), leaf.value_size()); 56 for (size_t logit_dim = 0; logit_dim < leaf.index_size(); 57 ++logit_dim) { 58 const float value = tree_weight * leaf.value(logit_dim); 59 output_predictions(example.example_idx, leaf.index(logit_dim)) += 60 value; 61 } 62 } else { 63 QCHECK(leaf_node.leaf().has_vector()) << "Unknown leaf type"; 64 const auto& leaf = leaf_node.leaf().vector(); 65 for (size_t i = 0; i < leaf.value_size(); ++i) { 66 const float value = tree_weight * leaf.value(i); 67 output_predictions(example.example_idx, i) += value; 68 } 69 } 70 } 71 } 72 }; 73 74 // TODO(salehay): parallelize this for low latency in serving path where 75 // batch size tends to be small but ensemble size tends to be large. 76 boosted_trees::utils::ParallelFor(batch_size, worker_threads->NumThreads(), 77 worker_threads, update_predictions); 78 } 79 80 } // namespace models 81 } // namespace boosted_trees 82 } // namespace tensorflow 83