Home | History | Annotate | Download | only in partitioners
      1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 // =============================================================================
     15 #include "tensorflow/contrib/boosted_trees/lib/learner/common/partitioners/example_partitioner.h"
     16 #include "tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h"
     17 
     18 namespace tensorflow {
     19 namespace boosted_trees {
     20 namespace learner {
     21 
     22 void ExamplePartitioner::UpdatePartitions(
     23     const boosted_trees::trees::DecisionTreeConfig& tree_config,
     24     const boosted_trees::utils::BatchFeatures& features,
     25     const int desired_parallelism, thread::ThreadPool* const thread_pool,
     26     int32* example_partition_ids) {
     27   // Get batch size.
     28   const int64 batch_size = features.batch_size();
     29   if (batch_size <= 0) {
     30     return;
     31   }
     32 
     33   // Lambda for doing a block of work.
     34   auto partition_examples_subset = [&tree_config, &features,
     35                                     &example_partition_ids](const int64 start,
     36                                                             const int64 end) {
     37     if (TF_PREDICT_TRUE(tree_config.nodes_size() > 0)) {
     38       auto examples_iterable = features.examples_iterable(start, end);
     39       for (const auto& example : examples_iterable) {
     40         int32& example_partition = example_partition_ids[example.example_idx];
     41         example_partition = boosted_trees::trees::DecisionTree::Traverse(
     42             tree_config, example_partition, example);
     43         DCHECK_GE(example_partition, 0);
     44       }
     45     } else {
     46       std::fill(example_partition_ids + start, example_partition_ids + end, 0);
     47     }
     48   };
     49 
     50   // Parallelize partitioning over the batch.
     51   boosted_trees::utils::ParallelFor(batch_size, desired_parallelism,
     52                                     thread_pool, partition_examples_subset);
     53 }
     54 
     55 void ExamplePartitioner::PartitionExamples(
     56     const boosted_trees::trees::DecisionTreeConfig& tree_config,
     57     const boosted_trees::utils::BatchFeatures& features,
     58     const int desired_parallelism, thread::ThreadPool* const thread_pool,
     59     int32* example_partition_ids) {
     60   // Get batch size.
     61   const int64 batch_size = features.batch_size();
     62   if (batch_size <= 0) {
     63     return;
     64   }
     65 
     66   // Lambda for doing a block of work.
     67   auto partition_examples_subset = [&tree_config, &features,
     68                                     &example_partition_ids](const int64 start,
     69                                                             const int64 end) {
     70     if (TF_PREDICT_TRUE(tree_config.nodes_size() > 0)) {
     71       auto examples_iterable = features.examples_iterable(start, end);
     72       for (const auto& example : examples_iterable) {
     73         uint32 partition = boosted_trees::trees::DecisionTree::Traverse(
     74             tree_config, 0, example);
     75         example_partition_ids[example.example_idx] = partition;
     76         DCHECK_GE(partition, 0);
     77       }
     78     } else {
     79       std::fill(example_partition_ids + start, example_partition_ids + end, 0);
     80     }
     81   };
     82 
     83   // Parallelize partitioning over the batch.
     84   boosted_trees::utils::ParallelFor(batch_size, desired_parallelism,
     85                                     thread_pool, partition_examples_subset);
     86 }
     87 
     88 }  // namespace learner
     89 }  // namespace boosted_trees
     90 }  // namespace tensorflow
     91