1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/lib/random/weighted_picker.h" 17 18 #include <string.h> 19 #include <vector> 20 21 #include "tensorflow/core/lib/random/simple_philox.h" 22 #include "tensorflow/core/platform/logging.h" 23 #include "tensorflow/core/platform/macros.h" 24 #include "tensorflow/core/platform/test.h" 25 #include "tensorflow/core/platform/test_benchmark.h" 26 #include "tensorflow/core/platform/types.h" 27 28 namespace tensorflow { 29 namespace random { 30 31 static void TestPicker(SimplePhilox* rnd, int size); 32 static void CheckUniform(SimplePhilox* rnd, WeightedPicker* picker, int trials); 33 static void CheckSkewed(SimplePhilox* rnd, WeightedPicker* picker, int trials); 34 static void TestPickAt(int items, const int32* weights); 35 36 TEST(WeightedPicker, Simple) { 37 PhiloxRandom philox(testing::RandomSeed(), 17); 38 SimplePhilox rnd(&philox); 39 40 { 41 VLOG(0) << "======= Zero-length picker"; 42 WeightedPicker picker(0); 43 EXPECT_EQ(picker.Pick(&rnd), -1); 44 } 45 46 { 47 VLOG(0) << "======= Singleton picker"; 48 WeightedPicker picker(1); 49 EXPECT_EQ(picker.Pick(&rnd), 0); 50 EXPECT_EQ(picker.Pick(&rnd), 0); 51 EXPECT_EQ(picker.Pick(&rnd), 0); 52 } 53 54 { 55 VLOG(0) << "======= Grown picker"; 56 WeightedPicker picker(0); 57 for (int i = 0; i < 10; i++) { 58 picker.Append(1); 59 } 60 CheckUniform(&rnd, &picker, 100000); 61 } 62 63 { 64 VLOG(0) << "======= Grown picker with zero weights"; 65 WeightedPicker picker(1); 66 picker.Resize(10); 67 EXPECT_EQ(picker.Pick(&rnd), 0); 68 EXPECT_EQ(picker.Pick(&rnd), 0); 69 EXPECT_EQ(picker.Pick(&rnd), 0); 70 } 71 72 { 73 VLOG(0) << "======= Shrink picker and check weights"; 74 WeightedPicker picker(1); 75 picker.Resize(10); 76 EXPECT_EQ(picker.Pick(&rnd), 0); 77 EXPECT_EQ(picker.Pick(&rnd), 0); 78 EXPECT_EQ(picker.Pick(&rnd), 0); 79 for (int i = 0; i < 10; i++) { 80 picker.set_weight(i, i); 81 } 82 EXPECT_EQ(picker.total_weight(), 45); 83 picker.Resize(5); 84 EXPECT_EQ(picker.total_weight(), 10); 85 picker.Resize(2); 86 EXPECT_EQ(picker.total_weight(), 1); 87 picker.Resize(1); 88 EXPECT_EQ(picker.total_weight(), 0); 89 } 90 } 91 92 TEST(WeightedPicker, BigWeights) { 93 PhiloxRandom philox(testing::RandomSeed() + 1, 17); 94 SimplePhilox rnd(&philox); 95 VLOG(0) << "======= Check uniform with big weights"; 96 WeightedPicker picker(2); 97 picker.SetAllWeights(2147483646L / 3); // (2^31 - 2) / 3 98 CheckUniform(&rnd, &picker, 100000); 99 } 100 101 TEST(WeightedPicker, Deterministic) { 102 VLOG(0) << "======= Testing deterministic pick"; 103 static const int32 weights[] = {1, 0, 200, 5, 42}; 104 TestPickAt(TF_ARRAYSIZE(weights), weights); 105 } 106 107 TEST(WeightedPicker, Randomized) { 108 PhiloxRandom philox(testing::RandomSeed() + 10, 17); 109 SimplePhilox rnd(&philox); 110 TestPicker(&rnd, 1); 111 TestPicker(&rnd, 2); 112 TestPicker(&rnd, 3); 113 TestPicker(&rnd, 4); 114 TestPicker(&rnd, 7); 115 TestPicker(&rnd, 8); 116 TestPicker(&rnd, 9); 117 TestPicker(&rnd, 10); 118 TestPicker(&rnd, 100); 119 } 120 121 static void TestPicker(SimplePhilox* rnd, int size) { 122 VLOG(0) << "======= Testing size " << size; 123 124 // Check that empty picker returns -1 125 { 126 WeightedPicker picker(size); 127 picker.SetAllWeights(0); 128 for (int i = 0; i < 100; i++) EXPECT_EQ(picker.Pick(rnd), -1); 129 } 130 131 // Create zero weights array 132 std::vector<int32> weights(size); 133 for (int elem = 0; elem < size; elem++) { 134 weights[elem] = 0; 135 } 136 137 // Check that singleton picker always returns the same element 138 for (int elem = 0; elem < size; elem++) { 139 WeightedPicker picker(size); 140 picker.SetAllWeights(0); 141 picker.set_weight(elem, elem + 1); 142 for (int i = 0; i < 100; i++) EXPECT_EQ(picker.Pick(rnd), elem); 143 weights[elem] = 10; 144 picker.SetWeightsFromArray(size, &weights[0]); 145 for (int i = 0; i < 100; i++) EXPECT_EQ(picker.Pick(rnd), elem); 146 weights[elem] = 0; 147 } 148 149 // Check that uniform picker generates elements roughly uniformly 150 { 151 WeightedPicker picker(size); 152 CheckUniform(rnd, &picker, 100000); 153 } 154 155 // Check uniform picker that was grown piecemeal 156 if (size / 3 > 0) { 157 WeightedPicker picker(size / 3); 158 while (picker.num_elements() != size) { 159 picker.Append(1); 160 } 161 CheckUniform(rnd, &picker, 100000); 162 } 163 164 // Check that skewed distribution works 165 if (size <= 10) { 166 // When picker grows one element at a time 167 WeightedPicker picker(size); 168 int32 weight = 1; 169 for (int elem = 0; elem < size; elem++) { 170 picker.set_weight(elem, weight); 171 weights[elem] = weight; 172 weight *= 2; 173 } 174 CheckSkewed(rnd, &picker, 1000000); 175 176 // When picker is created from an array 177 WeightedPicker array_picker(0); 178 array_picker.SetWeightsFromArray(size, &weights[0]); 179 CheckSkewed(rnd, &array_picker, 1000000); 180 } 181 } 182 183 static void CheckUniform(SimplePhilox* rnd, WeightedPicker* picker, 184 int trials) { 185 const int size = picker->num_elements(); 186 int* count = new int[size]; 187 memset(count, 0, sizeof(count[0]) * size); 188 for (int i = 0; i < size * trials; i++) { 189 const int elem = picker->Pick(rnd); 190 EXPECT_GE(elem, 0); 191 EXPECT_LT(elem, size); 192 count[elem]++; 193 } 194 const int expected_min = int(0.9 * trials); 195 const int expected_max = int(1.1 * trials); 196 for (int i = 0; i < size; i++) { 197 EXPECT_GE(count[i], expected_min); 198 EXPECT_LE(count[i], expected_max); 199 } 200 delete[] count; 201 } 202 203 static void CheckSkewed(SimplePhilox* rnd, WeightedPicker* picker, int trials) { 204 const int size = picker->num_elements(); 205 int* count = new int[size]; 206 memset(count, 0, sizeof(count[0]) * size); 207 for (int i = 0; i < size * trials; i++) { 208 const int elem = picker->Pick(rnd); 209 EXPECT_GE(elem, 0); 210 EXPECT_LT(elem, size); 211 count[elem]++; 212 } 213 214 for (int i = 0; i < size - 1; i++) { 215 LOG(INFO) << i << ": " << count[i]; 216 const float ratio = float(count[i + 1]) / float(count[i]); 217 EXPECT_GE(ratio, 1.6f); 218 EXPECT_LE(ratio, 2.4f); 219 } 220 delete[] count; 221 } 222 223 static void TestPickAt(int items, const int32* weights) { 224 WeightedPicker picker(items); 225 picker.SetWeightsFromArray(items, weights); 226 int weight_index = 0; 227 for (int i = 0; i < items; ++i) { 228 for (int j = 0; j < weights[i]; ++j) { 229 int pick = picker.PickAt(weight_index); 230 EXPECT_EQ(pick, i); 231 ++weight_index; 232 } 233 } 234 EXPECT_EQ(weight_index, picker.total_weight()); 235 } 236 237 static void BM_Create(int iters, int arg) { 238 while (--iters > 0) { 239 WeightedPicker p(arg); 240 } 241 } 242 BENCHMARK(BM_Create)->Range(1, 1024); 243 244 static void BM_CreateAndSetWeights(int iters, int arg) { 245 std::vector<int32> weights(arg); 246 for (int i = 0; i < arg; i++) { 247 weights[i] = i * 10; 248 } 249 while (--iters > 0) { 250 WeightedPicker p(arg); 251 p.SetWeightsFromArray(arg, &weights[0]); 252 } 253 } 254 BENCHMARK(BM_CreateAndSetWeights)->Range(1, 1024); 255 256 static void BM_Pick(int iters, int arg) { 257 PhiloxRandom philox(301, 17); 258 SimplePhilox rnd(&philox); 259 WeightedPicker p(arg); 260 int result = 0; 261 while (--iters > 0) { 262 result += p.Pick(&rnd); 263 } 264 VLOG(4) << result; // Dummy use 265 } 266 BENCHMARK(BM_Pick)->Range(1, 1024); 267 268 } // namespace random 269 } // namespace tensorflow 270