Home | History | Annotate | Download | only in partitioners
      1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 // =============================================================================
     15 #include "tensorflow/contrib/boosted_trees/lib/learner/common/partitioners/example_partitioner.h"
     16 #include "tensorflow/core/framework/tensor_testutil.h"
     17 #include "tensorflow/core/lib/core/status_test_util.h"
     18 #include "tensorflow/core/platform/test.h"
     19 
     20 namespace tensorflow {
     21 namespace boosted_trees {
     22 namespace learner {
     23 namespace {
     24 
     25 class ExamplePartitionerTest : public ::testing::Test {
     26  protected:
     27   ExamplePartitionerTest()
     28       : thread_pool_(tensorflow::Env::Default(), "test_pool", 2),
     29         batch_features_(2) {
     30     dense_matrix_ = test::AsTensor<float>({7.0f, -2.0f}, {2, 1});
     31     TF_EXPECT_OK(
     32         batch_features_.Initialize({dense_matrix_}, {}, {}, {}, {}, {}, {}));
     33   }
     34 
     35   thread::ThreadPool thread_pool_;
     36   Tensor dense_matrix_;
     37   boosted_trees::utils::BatchFeatures batch_features_;
     38 };
     39 
     40 TEST_F(ExamplePartitionerTest, EmptyTree) {
     41   boosted_trees::trees::DecisionTreeConfig tree_config;
     42   std::vector<int32> example_partition_ids(2);
     43   ExamplePartitioner::UpdatePartitions(tree_config, batch_features_, 1,
     44                                        &thread_pool_,
     45                                        example_partition_ids.data());
     46   EXPECT_EQ(0, example_partition_ids[0]);
     47   EXPECT_EQ(0, example_partition_ids[1]);
     48 }
     49 
     50 TEST_F(ExamplePartitionerTest, UpdatePartitions) {
     51   // Create tree with one split.
     52   // TODO(salehay): figure out if we can use PARSE_TEXT_PROTO.
     53   boosted_trees::trees::DecisionTreeConfig tree_config;
     54   auto* split = tree_config.add_nodes()->mutable_dense_float_binary_split();
     55   split->set_feature_column(0);
     56   split->set_threshold(3.0f);
     57   split->set_left_id(1);
     58   split->set_right_id(2);
     59   tree_config.add_nodes()->mutable_leaf();
     60   tree_config.add_nodes()->mutable_leaf();
     61 
     62   // Partition input:
     63   // Instance 1 has !(7 <= 3) => go right => leaf 2.
     64   // Instance 2 has (-2 <= 3) => go left => leaf 1.
     65   std::vector<int32> example_partition_ids(2);
     66   ExamplePartitioner::UpdatePartitions(tree_config, batch_features_, 1,
     67                                        &thread_pool_,
     68                                        example_partition_ids.data());
     69   EXPECT_EQ(2, example_partition_ids[0]);
     70   EXPECT_EQ(1, example_partition_ids[1]);
     71 }
     72 
     73 TEST_F(ExamplePartitionerTest, PartitionExamples) {
     74   // Create tree with one split.
     75   // TODO(salehay): figure out if we can use PARSE_TEXT_PROTO.
     76   boosted_trees::trees::DecisionTreeConfig tree_config;
     77   auto* split = tree_config.add_nodes()->mutable_dense_float_binary_split();
     78   split->set_feature_column(0);
     79   split->set_threshold(3.0f);
     80   split->set_left_id(1);
     81   split->set_right_id(2);
     82   tree_config.add_nodes()->mutable_leaf();
     83   tree_config.add_nodes()->mutable_leaf();
     84 
     85   // Partition input:
     86   // Instance 1 has !(7 <= 3) => go right => leaf 2.
     87   // Instance 2 has (-2 <= 3) => go left => leaf 1.
     88   std::vector<int32> example_partition_ids(2);
     89   ExamplePartitioner::PartitionExamples(tree_config, batch_features_, 1,
     90                                         &thread_pool_,
     91                                         example_partition_ids.data());
     92   EXPECT_EQ(2, example_partition_ids[0]);
     93   EXPECT_EQ(1, example_partition_ids[1]);
     94 }
     95 
     96 }  // namespace
     97 }  // namespace learner
     98 }  // namespace boosted_trees
     99 }  // namespace tensorflow
    100