Home | History | Annotate | Download | only in quantiles
      1 // Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 // =============================================================================
#include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h"

#include <vector>

#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
     21 
     22 namespace tensorflow {
     23 namespace {
     24 
     25 using Buffer =
     26     boosted_trees::quantiles::WeightedQuantilesBuffer<double, double>;
     27 using BufferEntry =
     28     boosted_trees::quantiles::WeightedQuantilesBuffer<double,
     29                                                       double>::BufferEntry;
     30 
     31 class WeightedQuantilesBufferTest : public ::testing::Test {};
     32 
     33 TEST_F(WeightedQuantilesBufferTest, Invalid) {
     34   EXPECT_DEATH(
     35       ({
     36         boosted_trees::quantiles::WeightedQuantilesBuffer<double, double>
     37             buffer(2, 0);
     38       }),
     39       "Invalid buffer specification");
     40   EXPECT_DEATH(
     41       ({
     42         boosted_trees::quantiles::WeightedQuantilesBuffer<double, double>
     43             buffer(0, 2);
     44       }),
     45       "Invalid buffer specification");
     46 }
     47 
     48 TEST_F(WeightedQuantilesBufferTest, PushEntryNotFull) {
     49   Buffer buffer(20, 100);
     50   buffer.PushEntry(5, 9);
     51   buffer.PushEntry(2, 3);
     52   buffer.PushEntry(-1, 7);
     53   buffer.PushEntry(3, 0);  // This entry will be ignored.
     54 
     55   EXPECT_FALSE(buffer.IsFull());
     56   EXPECT_EQ(buffer.Size(), 3);
     57 }
     58 
     59 TEST_F(WeightedQuantilesBufferTest, PushEntryFull) {
     60   // buffer capacity is 4.
     61   Buffer buffer(2, 100);
     62   buffer.PushEntry(5, 9);
     63   buffer.PushEntry(2, 3);
     64   buffer.PushEntry(-1, 7);
     65   buffer.PushEntry(2, 1);
     66 
     67   std::vector<BufferEntry> expected;
     68   expected.emplace_back(-1, 7);
     69   expected.emplace_back(2, 4);
     70   expected.emplace_back(5, 9);
     71 
     72   // At this point, we have pushed 4 entries and we expect the buffer to be
     73   // full.
     74   EXPECT_TRUE(buffer.IsFull());
     75   EXPECT_EQ(buffer.GenerateEntryList(), expected);
     76   EXPECT_FALSE(buffer.IsFull());
     77 }
     78 
     79 TEST_F(WeightedQuantilesBufferTest, PushEntryFullDeath) {
     80   // buffer capacity is 4.
     81   Buffer buffer(2, 100);
     82   buffer.PushEntry(5, 9);
     83   buffer.PushEntry(2, 3);
     84   buffer.PushEntry(-1, 7);
     85   buffer.PushEntry(2, 1);
     86 
     87   std::vector<BufferEntry> expected;
     88   expected.emplace_back(-1, 7);
     89   expected.emplace_back(2, 4);
     90   expected.emplace_back(5, 9);
     91 
     92   // At this point, we have pushed 4 entries and we expect the buffer to be
     93   // full.
     94   EXPECT_TRUE(buffer.IsFull());
     95   // Can't push any more entries before clearing.
     96   EXPECT_DEATH(({ buffer.PushEntry(6, 6); }), "Buffer already full");
     97 }
     98 
     99 }  // namespace
    100 }  // namespace tensorflow
    101