Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
#include <set>
#include <vector>

#include "tensorflow/core/kernels/range_sampler.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
     25 
     26 namespace tensorflow {
     27 namespace {
     28 
     29 using gtl::ArraySlice;
     30 using gtl::MutableArraySlice;
     31 
     32 class RangeSamplerTest : public ::testing::Test {
     33  protected:
     34   void CheckProbabilitiesSumToOne() {
     35     double sum = 0;
     36     for (int i = 0; i < sampler_->range(); i++) {
     37       sum += sampler_->Probability(i);
     38     }
     39     EXPECT_NEAR(sum, 1.0, 1e-4);
     40   }
     41   void CheckHistogram(int num_samples, float tolerance) {
     42     const int range = sampler_->range();
     43     std::vector<int> h(range);
     44     std::vector<int64> a(num_samples);
     45     // Using a fixed random seed to make the test deterministic.
     46     random::PhiloxRandom philox(123, 17);
     47     random::SimplePhilox rnd(&philox);
     48     sampler_->SampleBatch(&rnd, false, &a);
     49     for (int i = 0; i < num_samples; i++) {
     50       int64 val = a[i];
     51       ASSERT_GE(val, 0);
     52       ASSERT_LT(val, range);
     53       h[val]++;
     54     }
     55     for (int val = 0; val < range; val++) {
     56       EXPECT_NEAR((h[val] + 0.0) / num_samples, sampler_->Probability(val),
     57                   tolerance);
     58     }
     59   }
     60   void Update1() {
     61     // Add the value 3 ten times.
     62     std::vector<int64> a(10);
     63     for (int i = 0; i < 10; i++) {
     64       a[i] = 3;
     65     }
     66     sampler_->Update(a);
     67   }
     68   void Update2() {
     69     // Add the value n times.
     70     int64 a[10];
     71     for (int i = 0; i < 10; i++) {
     72       a[i] = i;
     73     }
     74     for (int64 i = 1; i < 10; i++) {
     75       sampler_->Update(ArraySlice<int64>(a + i, 10 - i));
     76     }
     77   }
     78   std::unique_ptr<RangeSampler> sampler_;
     79 };
     80 
     81 TEST_F(RangeSamplerTest, UniformProbabilities) {
     82   sampler_.reset(new UniformSampler(10));
     83   for (int i = 0; i < 10; i++) {
     84     CHECK_EQ(sampler_->Probability(i), sampler_->Probability(0));
     85   }
     86 }
     87 
     88 TEST_F(RangeSamplerTest, UniformChecksum) {
     89   sampler_.reset(new UniformSampler(10));
     90   CheckProbabilitiesSumToOne();
     91 }
     92 
     93 TEST_F(RangeSamplerTest, UniformHistogram) {
     94   sampler_.reset(new UniformSampler(10));
     95   CheckHistogram(1000, 0.05);
     96 }
     97 
     98 TEST_F(RangeSamplerTest, LogUniformProbabilities) {
     99   int range = 1000000;
    100   sampler_.reset(new LogUniformSampler(range));
    101   for (int i = 100; i < range; i *= 2) {
    102     float ratio = sampler_->Probability(i) / sampler_->Probability(i / 2);
    103     EXPECT_NEAR(ratio, 0.5, 0.1);
    104   }
    105 }
    106 
    107 TEST_F(RangeSamplerTest, LogUniformChecksum) {
    108   sampler_.reset(new LogUniformSampler(10));
    109   CheckProbabilitiesSumToOne();
    110 }
    111 
    112 TEST_F(RangeSamplerTest, LogUniformHistogram) {
    113   sampler_.reset(new LogUniformSampler(10));
    114   CheckHistogram(1000, 0.05);
    115 }
    116 
    117 TEST_F(RangeSamplerTest, UnigramProbabilities1) {
    118   sampler_.reset(new UnigramSampler(10));
    119   Update1();
    120   EXPECT_NEAR(sampler_->Probability(3), 0.55, 1e-4);
    121   for (int i = 0; i < 10; i++) {
    122     if (i != 3) {
    123       ASSERT_NEAR(sampler_->Probability(i), 0.05, 1e-4);
    124     }
    125   }
    126 }
    127 TEST_F(RangeSamplerTest, UnigramProbabilities2) {
    128   sampler_.reset(new UnigramSampler(10));
    129   Update2();
    130   for (int i = 0; i < 10; i++) {
    131     ASSERT_NEAR(sampler_->Probability(i), (i + 1) / 55.0, 1e-4);
    132   }
    133 }
    134 TEST_F(RangeSamplerTest, UnigramChecksum) {
    135   sampler_.reset(new UnigramSampler(10));
    136   Update1();
    137   CheckProbabilitiesSumToOne();
    138 }
    139 
    140 TEST_F(RangeSamplerTest, UnigramHistogram) {
    141   sampler_.reset(new UnigramSampler(10));
    142   Update1();
    143   CheckHistogram(1000, 0.05);
    144 }
    145 
    146 static const char kVocabContent[] =
    147     "w1,1\n"
    148     "w2,2\n"
    149     "w3,4\n"
    150     "w4,8\n"
    151     "w5,16\n"
    152     "w6,32\n"
    153     "w7,64\n"
    154     "w8,128\n"
    155     "w9,256";
    156 TEST_F(RangeSamplerTest, FixedUnigramProbabilities) {
    157   Env* env = Env::Default();
    158   string fname = io::JoinPath(testing::TmpDir(), "vocab_file");
    159   TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent));
    160   sampler_.reset(new FixedUnigramSampler(env, 9, fname, 0.8, 0, 1, 0));
    161   // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05
    162   for (int i = 0; i < 9; i++) {
    163     ASSERT_NEAR(sampler_->Probability(i), pow(2, i * 0.8) / 197.05, 1e-4);
    164   }
    165 }
    166 TEST_F(RangeSamplerTest, FixedUnigramChecksum) {
    167   Env* env = Env::Default();
    168   string fname = io::JoinPath(testing::TmpDir(), "vocab_file");
    169   TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent));
    170   sampler_.reset(new FixedUnigramSampler(env, 9, fname, 0.8, 0, 1, 0));
    171   CheckProbabilitiesSumToOne();
    172 }
    173 
    174 TEST_F(RangeSamplerTest, FixedUnigramHistogram) {
    175   Env* env = Env::Default();
    176   string fname = io::JoinPath(testing::TmpDir(), "vocab_file");
    177   TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent));
    178   sampler_.reset(new FixedUnigramSampler(env, 9, fname, 0.8, 0, 1, 0));
    179   CheckHistogram(1000, 0.05);
    180 }
    181 TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesReserve1) {
    182   Env* env = Env::Default();
    183   string fname = io::JoinPath(testing::TmpDir(), "vocab_file");
    184   TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent));
    185   sampler_.reset(new FixedUnigramSampler(env, 10, fname, 0.8, 1, 1, 0));
    186   ASSERT_NEAR(sampler_->Probability(0), 0, 1e-4);
    187   // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05
    188   for (int i = 1; i < 10; i++) {
    189     ASSERT_NEAR(sampler_->Probability(i), pow(2, (i - 1) * 0.8) / 197.05, 1e-4);
    190   }
    191 }
    192 TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesReserve2) {
    193   Env* env = Env::Default();
    194   string fname = io::JoinPath(testing::TmpDir(), "vocab_file");
    195   TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent));
    196   sampler_.reset(new FixedUnigramSampler(env, 11, fname, 0.8, 2, 1, 0));
    197   ASSERT_NEAR(sampler_->Probability(0), 0, 1e-4);
    198   ASSERT_NEAR(sampler_->Probability(1), 0, 1e-4);
    199   // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05
    200   for (int i = 2; i < 11; i++) {
    201     ASSERT_NEAR(sampler_->Probability(i), pow(2, (i - 2) * 0.8) / 197.05, 1e-4);
    202   }
    203 }
    204 TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesFromVector) {
    205   std::vector<float> weights = {1, 2, 4, 8, 16, 32, 64, 128, 256};
    206   sampler_.reset(new FixedUnigramSampler(9, weights, 0.8, 0, 1, 0));
    207   // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05
    208   for (int i = 0; i < 9; i++) {
    209     ASSERT_NEAR(sampler_->Probability(i), pow(2, i * 0.8) / 197.05, 1e-4);
    210   }
    211 }
    212 TEST_F(RangeSamplerTest, FixedUnigramChecksumFromVector) {
    213   std::vector<float> weights = {1, 2, 4, 8, 16, 32, 64, 128, 256};
    214   sampler_.reset(new FixedUnigramSampler(9, weights, 0.8, 0, 1, 0));
    215   CheckProbabilitiesSumToOne();
    216 }
    217 TEST_F(RangeSamplerTest, FixedUnigramHistogramFromVector) {
    218   std::vector<float> weights = {1, 2, 4, 8, 16, 32, 64, 128, 256};
    219   sampler_.reset(new FixedUnigramSampler(9, weights, 0.8, 0, 1, 0));
    220   CheckHistogram(1000, 0.05);
    221 }
    222 TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesReserve1FromVector) {
    223   std::vector<float> weights = {1, 2, 4, 8, 16, 32, 64, 128, 256};
    224   sampler_.reset(new FixedUnigramSampler(10, weights, 0.8, 1, 1, 0));
    225   ASSERT_NEAR(sampler_->Probability(0), 0, 1e-4);
    226   // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05
    227   for (int i = 1; i < 10; i++) {
    228     ASSERT_NEAR(sampler_->Probability(i), pow(2, (i - 1) * 0.8) / 197.05, 1e-4);
    229   }
    230 }
    231 TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesReserve2FromVector) {
    232   std::vector<float> weights = {1, 2, 4, 8, 16, 32, 64, 128, 256};
    233   sampler_.reset(new FixedUnigramSampler(11, weights, 0.8, 2, 1, 0));
    234   ASSERT_NEAR(sampler_->Probability(0), 0, 1e-4);
    235   ASSERT_NEAR(sampler_->Probability(1), 0, 1e-4);
    236   // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05
    237   for (int i = 2; i < 11; i++) {
    238     ASSERT_NEAR(sampler_->Probability(i), pow(2, (i - 2) * 0.8) / 197.05, 1e-4);
    239   }
    240 }
    241 
    242 // AllSampler cannot call Sample or Probability directly.
    243 // We will test SampleBatchGetExpectedCount instead.
    244 TEST_F(RangeSamplerTest, All) {
    245   int batch_size = 10;
    246   sampler_.reset(new AllSampler(10));
    247   std::vector<int64> batch(batch_size);
    248   std::vector<float> batch_expected(batch_size);
    249   std::vector<int64> extras(2);
    250   std::vector<float> extras_expected(2);
    251   extras[0] = 0;
    252   extras[1] = batch_size - 1;
    253   sampler_->SampleBatchGetExpectedCount(nullptr,  // no random numbers needed
    254                                         false, &batch, &batch_expected, extras,
    255                                         &extras_expected);
    256   for (int i = 0; i < batch_size; i++) {
    257     EXPECT_EQ(i, batch[i]);
    258     EXPECT_EQ(1, batch_expected[i]);
    259   }
    260   EXPECT_EQ(1, extras_expected[0]);
    261   EXPECT_EQ(1, extras_expected[1]);
    262 }
    263 
    264 TEST_F(RangeSamplerTest, Unique) {
    265   // We sample num_batches batches, each without replacement.
    266   //
    267   // We check that the returned expected counts roughly agree with each other
    268   // and with the average observed frequencies over the set of batches.
    269   random::PhiloxRandom philox(123, 17);
    270   random::SimplePhilox rnd(&philox);
    271   const int range = 100;
    272   const int batch_size = 50;
    273   const int num_batches = 100;
    274   sampler_.reset(new LogUniformSampler(range));
    275   std::vector<int> histogram(range);
    276   std::vector<int64> batch(batch_size);
    277   std::vector<int64> all_values(range);
    278   for (int i = 0; i < range; i++) {
    279     all_values[i] = i;
    280   }
    281   std::vector<float> expected(range);
    282 
    283   // Sample one batch and get the expected counts of all values
    284   sampler_->SampleBatchGetExpectedCount(
    285       &rnd, true, &batch, MutableArraySlice<float>(), all_values, &expected);
    286   // Check that all elements are unique
    287   std::set<int64> s(batch.begin(), batch.end());
    288   CHECK_EQ(batch_size, s.size());
    289 
    290   for (int trial = 0; trial < num_batches; trial++) {
    291     std::vector<float> trial_expected(range);
    292     sampler_->SampleBatchGetExpectedCount(&rnd, true, &batch,
    293                                           MutableArraySlice<float>(),
    294                                           all_values, &trial_expected);
    295     for (int i = 0; i < range; i++) {
    296       EXPECT_NEAR(expected[i], trial_expected[i], expected[i] * 0.5);
    297     }
    298     for (int i = 0; i < batch_size; i++) {
    299       histogram[batch[i]]++;
    300     }
    301   }
    302   for (int i = 0; i < range; i++) {
    303     // Check that the computed expected count agrees with the average observed
    304     // count.
    305     const float average_count = static_cast<float>(histogram[i]) / num_batches;
    306     EXPECT_NEAR(expected[i], average_count, 0.2);
    307   }
    308 }
    309 
    310 TEST_F(RangeSamplerTest, Avoid) {
    311   random::PhiloxRandom philox(123, 17);
    312   random::SimplePhilox rnd(&philox);
    313   sampler_.reset(new LogUniformSampler(100));
    314   std::vector<int64> avoided(2);
    315   avoided[0] = 17;
    316   avoided[1] = 23;
    317   std::vector<int64> batch(98);
    318 
    319   // We expect to pick all elements of [0, 100) except the avoided two.
    320   sampler_->SampleBatchGetExpectedCountAvoid(
    321       &rnd, true, &batch, MutableArraySlice<float>(), ArraySlice<int64>(),
    322       MutableArraySlice<float>(), avoided);
    323 
    324   int sum = 0;
    325   for (auto val : batch) {
    326     sum += val;
    327   }
    328   const int expected_sum = 100 * 99 / 2 - avoided[0] - avoided[1];
    329   EXPECT_EQ(expected_sum, sum);
    330 }
    331 
    332 }  // namespace
    333 
    334 }  // namespace tensorflow
    335