1 // Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // ============================================================================= 15 #include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h" 16 17 #include "tensorflow/core/lib/random/philox_random.h" 18 #include "tensorflow/core/lib/random/simple_philox.h" 19 #include "tensorflow/core/platform/test.h" 20 #include "tensorflow/core/platform/test_benchmark.h" 21 22 namespace tensorflow { 23 namespace { 24 25 using Buffer = 26 boosted_trees::quantiles::WeightedQuantilesBuffer<double, double>; 27 using BufferEntry = 28 boosted_trees::quantiles::WeightedQuantilesBuffer<double, 29 double>::BufferEntry; 30 31 class WeightedQuantilesBufferTest : public ::testing::Test {}; 32 33 TEST_F(WeightedQuantilesBufferTest, Invalid) { 34 EXPECT_DEATH( 35 ({ 36 boosted_trees::quantiles::WeightedQuantilesBuffer<double, double> 37 buffer(2, 0); 38 }), 39 "Invalid buffer specification"); 40 EXPECT_DEATH( 41 ({ 42 boosted_trees::quantiles::WeightedQuantilesBuffer<double, double> 43 buffer(0, 2); 44 }), 45 "Invalid buffer specification"); 46 } 47 48 TEST_F(WeightedQuantilesBufferTest, PushEntryNotFull) { 49 Buffer buffer(20, 100); 50 buffer.PushEntry(5, 9); 51 buffer.PushEntry(2, 3); 52 buffer.PushEntry(-1, 7); 53 buffer.PushEntry(3, 0); // This entry will be ignored. 54 55 EXPECT_FALSE(buffer.IsFull()); 56 EXPECT_EQ(buffer.Size(), 3); 57 } 58 59 TEST_F(WeightedQuantilesBufferTest, PushEntryFull) { 60 // buffer capacity is 4. 61 Buffer buffer(2, 100); 62 buffer.PushEntry(5, 9); 63 buffer.PushEntry(2, 3); 64 buffer.PushEntry(-1, 7); 65 buffer.PushEntry(2, 1); 66 67 std::vector<BufferEntry> expected; 68 expected.emplace_back(-1, 7); 69 expected.emplace_back(2, 4); 70 expected.emplace_back(5, 9); 71 72 // At this point, we have pushed 4 entries and we expect the buffer to be 73 // full. 74 EXPECT_TRUE(buffer.IsFull()); 75 EXPECT_EQ(buffer.GenerateEntryList(), expected); 76 EXPECT_FALSE(buffer.IsFull()); 77 } 78 79 TEST_F(WeightedQuantilesBufferTest, PushEntryFullDeath) { 80 // buffer capacity is 4. 81 Buffer buffer(2, 100); 82 buffer.PushEntry(5, 9); 83 buffer.PushEntry(2, 3); 84 buffer.PushEntry(-1, 7); 85 buffer.PushEntry(2, 1); 86 87 std::vector<BufferEntry> expected; 88 expected.emplace_back(-1, 7); 89 expected.emplace_back(2, 4); 90 expected.emplace_back(5, 9); 91 92 // At this point, we have pushed 4 entries and we expect the buffer to be 93 // full. 94 EXPECT_TRUE(buffer.IsFull()); 95 // Can't push any more entries before clearing. 96 EXPECT_DEATH(({ buffer.PushEntry(6, 6); }), "Buffer already full"); 97 } 98 99 } // namespace 100 } // namespace tensorflow 101