1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // ============================================================================= 15 #include "tensorflow/contrib/boosted_trees/lib/learner/common/partitioners/example_partitioner.h" 16 #include "tensorflow/core/framework/tensor_testutil.h" 17 #include "tensorflow/core/lib/core/status_test_util.h" 18 #include "tensorflow/core/platform/test.h" 19 20 namespace tensorflow { 21 namespace boosted_trees { 22 namespace learner { 23 namespace { 24 25 class ExamplePartitionerTest : public ::testing::Test { 26 protected: 27 ExamplePartitionerTest() 28 : thread_pool_(tensorflow::Env::Default(), "test_pool", 2), 29 batch_features_(2) { 30 dense_matrix_ = test::AsTensor<float>({7.0f, -2.0f}, {2, 1}); 31 TF_EXPECT_OK( 32 batch_features_.Initialize({dense_matrix_}, {}, {}, {}, {}, {}, {})); 33 } 34 35 thread::ThreadPool thread_pool_; 36 Tensor dense_matrix_; 37 boosted_trees::utils::BatchFeatures batch_features_; 38 }; 39 40 TEST_F(ExamplePartitionerTest, EmptyTree) { 41 boosted_trees::trees::DecisionTreeConfig tree_config; 42 std::vector<int32> example_partition_ids(2); 43 ExamplePartitioner::UpdatePartitions(tree_config, batch_features_, 1, 44 &thread_pool_, 45 example_partition_ids.data()); 46 EXPECT_EQ(0, example_partition_ids[0]); 47 EXPECT_EQ(0, example_partition_ids[1]); 48 } 49 50 TEST_F(ExamplePartitionerTest, UpdatePartitions) { 51 // Create tree with one split. 52 // TODO(salehay): figure out if we can use PARSE_TEXT_PROTO. 53 boosted_trees::trees::DecisionTreeConfig tree_config; 54 auto* split = tree_config.add_nodes()->mutable_dense_float_binary_split(); 55 split->set_feature_column(0); 56 split->set_threshold(3.0f); 57 split->set_left_id(1); 58 split->set_right_id(2); 59 tree_config.add_nodes()->mutable_leaf(); 60 tree_config.add_nodes()->mutable_leaf(); 61 62 // Partition input: 63 // Instance 1 has !(7 <= 3) => go right => leaf 2. 64 // Instance 2 has (-2 <= 3) => go left => leaf 1. 65 std::vector<int32> example_partition_ids(2); 66 ExamplePartitioner::UpdatePartitions(tree_config, batch_features_, 1, 67 &thread_pool_, 68 example_partition_ids.data()); 69 EXPECT_EQ(2, example_partition_ids[0]); 70 EXPECT_EQ(1, example_partition_ids[1]); 71 } 72 73 TEST_F(ExamplePartitionerTest, PartitionExamples) { 74 // Create tree with one split. 75 // TODO(salehay): figure out if we can use PARSE_TEXT_PROTO. 76 boosted_trees::trees::DecisionTreeConfig tree_config; 77 auto* split = tree_config.add_nodes()->mutable_dense_float_binary_split(); 78 split->set_feature_column(0); 79 split->set_threshold(3.0f); 80 split->set_left_id(1); 81 split->set_right_id(2); 82 tree_config.add_nodes()->mutable_leaf(); 83 tree_config.add_nodes()->mutable_leaf(); 84 85 // Partition input: 86 // Instance 1 has !(7 <= 3) => go right => leaf 2. 87 // Instance 2 has (-2 <= 3) => go left => leaf 1. 88 std::vector<int32> example_partition_ids(2); 89 ExamplePartitioner::PartitionExamples(tree_config, batch_features_, 1, 90 &thread_pool_, 91 example_partition_ids.data()); 92 EXPECT_EQ(2, example_partition_ids[0]); 93 EXPECT_EQ(1, example_partition_ids[1]); 94 } 95 96 } // namespace 97 } // namespace learner 98 } // namespace boosted_trees 99 } // namespace tensorflow 100