Home | History | Annotate | Download | only in random
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/lib/random/weighted_picker.h"
     17 
     18 #include <string.h>
     19 #include <vector>
     20 
     21 #include "tensorflow/core/lib/random/simple_philox.h"
     22 #include "tensorflow/core/platform/logging.h"
     23 #include "tensorflow/core/platform/macros.h"
     24 #include "tensorflow/core/platform/test.h"
     25 #include "tensorflow/core/platform/test_benchmark.h"
     26 #include "tensorflow/core/platform/types.h"
     27 
     28 namespace tensorflow {
     29 namespace random {
     30 
     31 static void TestPicker(SimplePhilox* rnd, int size);
     32 static void CheckUniform(SimplePhilox* rnd, WeightedPicker* picker, int trials);
     33 static void CheckSkewed(SimplePhilox* rnd, WeightedPicker* picker, int trials);
     34 static void TestPickAt(int items, const int32* weights);
     35 
     36 TEST(WeightedPicker, Simple) {
     37   PhiloxRandom philox(testing::RandomSeed(), 17);
     38   SimplePhilox rnd(&philox);
     39 
     40   {
     41     VLOG(0) << "======= Zero-length picker";
     42     WeightedPicker picker(0);
     43     EXPECT_EQ(picker.Pick(&rnd), -1);
     44   }
     45 
     46   {
     47     VLOG(0) << "======= Singleton picker";
     48     WeightedPicker picker(1);
     49     EXPECT_EQ(picker.Pick(&rnd), 0);
     50     EXPECT_EQ(picker.Pick(&rnd), 0);
     51     EXPECT_EQ(picker.Pick(&rnd), 0);
     52   }
     53 
     54   {
     55     VLOG(0) << "======= Grown picker";
     56     WeightedPicker picker(0);
     57     for (int i = 0; i < 10; i++) {
     58       picker.Append(1);
     59     }
     60     CheckUniform(&rnd, &picker, 100000);
     61   }
     62 
     63   {
     64     VLOG(0) << "======= Grown picker with zero weights";
     65     WeightedPicker picker(1);
     66     picker.Resize(10);
     67     EXPECT_EQ(picker.Pick(&rnd), 0);
     68     EXPECT_EQ(picker.Pick(&rnd), 0);
     69     EXPECT_EQ(picker.Pick(&rnd), 0);
     70   }
     71 
     72   {
     73     VLOG(0) << "======= Shrink picker and check weights";
     74     WeightedPicker picker(1);
     75     picker.Resize(10);
     76     EXPECT_EQ(picker.Pick(&rnd), 0);
     77     EXPECT_EQ(picker.Pick(&rnd), 0);
     78     EXPECT_EQ(picker.Pick(&rnd), 0);
     79     for (int i = 0; i < 10; i++) {
     80       picker.set_weight(i, i);
     81     }
     82     EXPECT_EQ(picker.total_weight(), 45);
     83     picker.Resize(5);
     84     EXPECT_EQ(picker.total_weight(), 10);
     85     picker.Resize(2);
     86     EXPECT_EQ(picker.total_weight(), 1);
     87     picker.Resize(1);
     88     EXPECT_EQ(picker.total_weight(), 0);
     89   }
     90 }
     91 
     92 TEST(WeightedPicker, BigWeights) {
     93   PhiloxRandom philox(testing::RandomSeed() + 1, 17);
     94   SimplePhilox rnd(&philox);
     95   VLOG(0) << "======= Check uniform with big weights";
     96   WeightedPicker picker(2);
     97   picker.SetAllWeights(2147483646L / 3);  // (2^31 - 2) / 3
     98   CheckUniform(&rnd, &picker, 100000);
     99 }
    100 
    101 TEST(WeightedPicker, Deterministic) {
    102   VLOG(0) << "======= Testing deterministic pick";
    103   static const int32 weights[] = {1, 0, 200, 5, 42};
    104   TestPickAt(TF_ARRAYSIZE(weights), weights);
    105 }
    106 
    107 TEST(WeightedPicker, Randomized) {
    108   PhiloxRandom philox(testing::RandomSeed() + 10, 17);
    109   SimplePhilox rnd(&philox);
    110   TestPicker(&rnd, 1);
    111   TestPicker(&rnd, 2);
    112   TestPicker(&rnd, 3);
    113   TestPicker(&rnd, 4);
    114   TestPicker(&rnd, 7);
    115   TestPicker(&rnd, 8);
    116   TestPicker(&rnd, 9);
    117   TestPicker(&rnd, 10);
    118   TestPicker(&rnd, 100);
    119 }
    120 
    121 static void TestPicker(SimplePhilox* rnd, int size) {
    122   VLOG(0) << "======= Testing size " << size;
    123 
    124   // Check that empty picker returns -1
    125   {
    126     WeightedPicker picker(size);
    127     picker.SetAllWeights(0);
    128     for (int i = 0; i < 100; i++) EXPECT_EQ(picker.Pick(rnd), -1);
    129   }
    130 
    131   // Create zero weights array
    132   std::vector<int32> weights(size);
    133   for (int elem = 0; elem < size; elem++) {
    134     weights[elem] = 0;
    135   }
    136 
    137   // Check that singleton picker always returns the same element
    138   for (int elem = 0; elem < size; elem++) {
    139     WeightedPicker picker(size);
    140     picker.SetAllWeights(0);
    141     picker.set_weight(elem, elem + 1);
    142     for (int i = 0; i < 100; i++) EXPECT_EQ(picker.Pick(rnd), elem);
    143     weights[elem] = 10;
    144     picker.SetWeightsFromArray(size, &weights[0]);
    145     for (int i = 0; i < 100; i++) EXPECT_EQ(picker.Pick(rnd), elem);
    146     weights[elem] = 0;
    147   }
    148 
    149   // Check that uniform picker generates elements roughly uniformly
    150   {
    151     WeightedPicker picker(size);
    152     CheckUniform(rnd, &picker, 100000);
    153   }
    154 
    155   // Check uniform picker that was grown piecemeal
    156   if (size / 3 > 0) {
    157     WeightedPicker picker(size / 3);
    158     while (picker.num_elements() != size) {
    159       picker.Append(1);
    160     }
    161     CheckUniform(rnd, &picker, 100000);
    162   }
    163 
    164   // Check that skewed distribution works
    165   if (size <= 10) {
    166     // When picker grows one element at a time
    167     WeightedPicker picker(size);
    168     int32 weight = 1;
    169     for (int elem = 0; elem < size; elem++) {
    170       picker.set_weight(elem, weight);
    171       weights[elem] = weight;
    172       weight *= 2;
    173     }
    174     CheckSkewed(rnd, &picker, 1000000);
    175 
    176     // When picker is created from an array
    177     WeightedPicker array_picker(0);
    178     array_picker.SetWeightsFromArray(size, &weights[0]);
    179     CheckSkewed(rnd, &array_picker, 1000000);
    180   }
    181 }
    182 
    183 static void CheckUniform(SimplePhilox* rnd, WeightedPicker* picker,
    184                          int trials) {
    185   const int size = picker->num_elements();
    186   int* count = new int[size];
    187   memset(count, 0, sizeof(count[0]) * size);
    188   for (int i = 0; i < size * trials; i++) {
    189     const int elem = picker->Pick(rnd);
    190     EXPECT_GE(elem, 0);
    191     EXPECT_LT(elem, size);
    192     count[elem]++;
    193   }
    194   const int expected_min = int(0.9 * trials);
    195   const int expected_max = int(1.1 * trials);
    196   for (int i = 0; i < size; i++) {
    197     EXPECT_GE(count[i], expected_min);
    198     EXPECT_LE(count[i], expected_max);
    199   }
    200   delete[] count;
    201 }
    202 
    203 static void CheckSkewed(SimplePhilox* rnd, WeightedPicker* picker, int trials) {
    204   const int size = picker->num_elements();
    205   int* count = new int[size];
    206   memset(count, 0, sizeof(count[0]) * size);
    207   for (int i = 0; i < size * trials; i++) {
    208     const int elem = picker->Pick(rnd);
    209     EXPECT_GE(elem, 0);
    210     EXPECT_LT(elem, size);
    211     count[elem]++;
    212   }
    213 
    214   for (int i = 0; i < size - 1; i++) {
    215     LOG(INFO) << i << ": " << count[i];
    216     const float ratio = float(count[i + 1]) / float(count[i]);
    217     EXPECT_GE(ratio, 1.6f);
    218     EXPECT_LE(ratio, 2.4f);
    219   }
    220   delete[] count;
    221 }
    222 
    223 static void TestPickAt(int items, const int32* weights) {
    224   WeightedPicker picker(items);
    225   picker.SetWeightsFromArray(items, weights);
    226   int weight_index = 0;
    227   for (int i = 0; i < items; ++i) {
    228     for (int j = 0; j < weights[i]; ++j) {
    229       int pick = picker.PickAt(weight_index);
    230       EXPECT_EQ(pick, i);
    231       ++weight_index;
    232     }
    233   }
    234   EXPECT_EQ(weight_index, picker.total_weight());
    235 }
    236 
    237 static void BM_Create(int iters, int arg) {
    238   while (--iters > 0) {
    239     WeightedPicker p(arg);
    240   }
    241 }
    242 BENCHMARK(BM_Create)->Range(1, 1024);
    243 
    244 static void BM_CreateAndSetWeights(int iters, int arg) {
    245   std::vector<int32> weights(arg);
    246   for (int i = 0; i < arg; i++) {
    247     weights[i] = i * 10;
    248   }
    249   while (--iters > 0) {
    250     WeightedPicker p(arg);
    251     p.SetWeightsFromArray(arg, &weights[0]);
    252   }
    253 }
    254 BENCHMARK(BM_CreateAndSetWeights)->Range(1, 1024);
    255 
    256 static void BM_Pick(int iters, int arg) {
    257   PhiloxRandom philox(301, 17);
    258   SimplePhilox rnd(&philox);
    259   WeightedPicker p(arg);
    260   int result = 0;
    261   while (--iters > 0) {
    262     result += p.Pick(&rnd);
    263   }
    264   VLOG(4) << result;  // Dummy use
    265 }
    266 BENCHMARK(BM_Pick)->Range(1, 1024);
    267 
    268 }  // namespace random
    269 }  // namespace tensorflow
    270