1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <vector> 17 18 #include "tensorflow/core/kernels/range_sampler.h" 19 #include "tensorflow/core/lib/core/status_test_util.h" 20 #include "tensorflow/core/lib/io/path.h" 21 #include "tensorflow/core/lib/random/simple_philox.h" 22 #include "tensorflow/core/platform/env.h" 23 #include "tensorflow/core/platform/logging.h" 24 #include "tensorflow/core/platform/test.h" 25 26 namespace tensorflow { 27 namespace { 28 29 using gtl::ArraySlice; 30 using gtl::MutableArraySlice; 31 32 class RangeSamplerTest : public ::testing::Test { 33 protected: 34 void CheckProbabilitiesSumToOne() { 35 double sum = 0; 36 for (int i = 0; i < sampler_->range(); i++) { 37 sum += sampler_->Probability(i); 38 } 39 EXPECT_NEAR(sum, 1.0, 1e-4); 40 } 41 void CheckHistogram(int num_samples, float tolerance) { 42 const int range = sampler_->range(); 43 std::vector<int> h(range); 44 std::vector<int64> a(num_samples); 45 // Using a fixed random seed to make the test deterministic. 46 random::PhiloxRandom philox(123, 17); 47 random::SimplePhilox rnd(&philox); 48 sampler_->SampleBatch(&rnd, false, &a); 49 for (int i = 0; i < num_samples; i++) { 50 int64 val = a[i]; 51 ASSERT_GE(val, 0); 52 ASSERT_LT(val, range); 53 h[val]++; 54 } 55 for (int val = 0; val < range; val++) { 56 EXPECT_NEAR((h[val] + 0.0) / num_samples, sampler_->Probability(val), 57 tolerance); 58 } 59 } 60 void Update1() { 61 // Add the value 3 ten times. 62 std::vector<int64> a(10); 63 for (int i = 0; i < 10; i++) { 64 a[i] = 3; 65 } 66 sampler_->Update(a); 67 } 68 void Update2() { 69 // Add the value n times. 70 int64 a[10]; 71 for (int i = 0; i < 10; i++) { 72 a[i] = i; 73 } 74 for (int64 i = 1; i < 10; i++) { 75 sampler_->Update(ArraySlice<int64>(a + i, 10 - i)); 76 } 77 } 78 std::unique_ptr<RangeSampler> sampler_; 79 }; 80 81 TEST_F(RangeSamplerTest, UniformProbabilities) { 82 sampler_.reset(new UniformSampler(10)); 83 for (int i = 0; i < 10; i++) { 84 CHECK_EQ(sampler_->Probability(i), sampler_->Probability(0)); 85 } 86 } 87 88 TEST_F(RangeSamplerTest, UniformChecksum) { 89 sampler_.reset(new UniformSampler(10)); 90 CheckProbabilitiesSumToOne(); 91 } 92 93 TEST_F(RangeSamplerTest, UniformHistogram) { 94 sampler_.reset(new UniformSampler(10)); 95 CheckHistogram(1000, 0.05); 96 } 97 98 TEST_F(RangeSamplerTest, LogUniformProbabilities) { 99 int range = 1000000; 100 sampler_.reset(new LogUniformSampler(range)); 101 for (int i = 100; i < range; i *= 2) { 102 float ratio = sampler_->Probability(i) / sampler_->Probability(i / 2); 103 EXPECT_NEAR(ratio, 0.5, 0.1); 104 } 105 } 106 107 TEST_F(RangeSamplerTest, LogUniformChecksum) { 108 sampler_.reset(new LogUniformSampler(10)); 109 CheckProbabilitiesSumToOne(); 110 } 111 112 TEST_F(RangeSamplerTest, LogUniformHistogram) { 113 sampler_.reset(new LogUniformSampler(10)); 114 CheckHistogram(1000, 0.05); 115 } 116 117 TEST_F(RangeSamplerTest, UnigramProbabilities1) { 118 sampler_.reset(new UnigramSampler(10)); 119 Update1(); 120 EXPECT_NEAR(sampler_->Probability(3), 0.55, 1e-4); 121 for (int i = 0; i < 10; i++) { 122 if (i != 3) { 123 ASSERT_NEAR(sampler_->Probability(i), 0.05, 1e-4); 124 } 125 } 126 } 127 TEST_F(RangeSamplerTest, UnigramProbabilities2) { 128 sampler_.reset(new UnigramSampler(10)); 129 Update2(); 130 for (int i = 0; i < 10; i++) { 131 ASSERT_NEAR(sampler_->Probability(i), (i + 1) / 55.0, 1e-4); 132 } 133 } 134 TEST_F(RangeSamplerTest, UnigramChecksum) { 135 sampler_.reset(new UnigramSampler(10)); 136 Update1(); 137 CheckProbabilitiesSumToOne(); 138 } 139 140 TEST_F(RangeSamplerTest, UnigramHistogram) { 141 sampler_.reset(new UnigramSampler(10)); 142 Update1(); 143 CheckHistogram(1000, 0.05); 144 } 145 146 static const char kVocabContent[] = 147 "w1,1\n" 148 "w2,2\n" 149 "w3,4\n" 150 "w4,8\n" 151 "w5,16\n" 152 "w6,32\n" 153 "w7,64\n" 154 "w8,128\n" 155 "w9,256"; 156 TEST_F(RangeSamplerTest, FixedUnigramProbabilities) { 157 Env* env = Env::Default(); 158 string fname = io::JoinPath(testing::TmpDir(), "vocab_file"); 159 TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent)); 160 sampler_.reset(new FixedUnigramSampler(env, 9, fname, 0.8, 0, 1, 0)); 161 // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05 162 for (int i = 0; i < 9; i++) { 163 ASSERT_NEAR(sampler_->Probability(i), pow(2, i * 0.8) / 197.05, 1e-4); 164 } 165 } 166 TEST_F(RangeSamplerTest, FixedUnigramChecksum) { 167 Env* env = Env::Default(); 168 string fname = io::JoinPath(testing::TmpDir(), "vocab_file"); 169 TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent)); 170 sampler_.reset(new FixedUnigramSampler(env, 9, fname, 0.8, 0, 1, 0)); 171 CheckProbabilitiesSumToOne(); 172 } 173 174 TEST_F(RangeSamplerTest, FixedUnigramHistogram) { 175 Env* env = Env::Default(); 176 string fname = io::JoinPath(testing::TmpDir(), "vocab_file"); 177 TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent)); 178 sampler_.reset(new FixedUnigramSampler(env, 9, fname, 0.8, 0, 1, 0)); 179 CheckHistogram(1000, 0.05); 180 } 181 TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesReserve1) { 182 Env* env = Env::Default(); 183 string fname = io::JoinPath(testing::TmpDir(), "vocab_file"); 184 TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent)); 185 sampler_.reset(new FixedUnigramSampler(env, 10, fname, 0.8, 1, 1, 0)); 186 ASSERT_NEAR(sampler_->Probability(0), 0, 1e-4); 187 // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05 188 for (int i = 1; i < 10; i++) { 189 ASSERT_NEAR(sampler_->Probability(i), pow(2, (i - 1) * 0.8) / 197.05, 1e-4); 190 } 191 } 192 TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesReserve2) { 193 Env* env = Env::Default(); 194 string fname = io::JoinPath(testing::TmpDir(), "vocab_file"); 195 TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent)); 196 sampler_.reset(new FixedUnigramSampler(env, 11, fname, 0.8, 2, 1, 0)); 197 ASSERT_NEAR(sampler_->Probability(0), 0, 1e-4); 198 ASSERT_NEAR(sampler_->Probability(1), 0, 1e-4); 199 // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05 200 for (int i = 2; i < 11; i++) { 201 ASSERT_NEAR(sampler_->Probability(i), pow(2, (i - 2) * 0.8) / 197.05, 1e-4); 202 } 203 } 204 TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesFromVector) { 205 std::vector<float> weights = {1, 2, 4, 8, 16, 32, 64, 128, 256}; 206 sampler_.reset(new FixedUnigramSampler(9, weights, 0.8, 0, 1, 0)); 207 // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05 208 for (int i = 0; i < 9; i++) { 209 ASSERT_NEAR(sampler_->Probability(i), pow(2, i * 0.8) / 197.05, 1e-4); 210 } 211 } 212 TEST_F(RangeSamplerTest, FixedUnigramChecksumFromVector) { 213 std::vector<float> weights = {1, 2, 4, 8, 16, 32, 64, 128, 256}; 214 sampler_.reset(new FixedUnigramSampler(9, weights, 0.8, 0, 1, 0)); 215 CheckProbabilitiesSumToOne(); 216 } 217 TEST_F(RangeSamplerTest, FixedUnigramHistogramFromVector) { 218 std::vector<float> weights = {1, 2, 4, 8, 16, 32, 64, 128, 256}; 219 sampler_.reset(new FixedUnigramSampler(9, weights, 0.8, 0, 1, 0)); 220 CheckHistogram(1000, 0.05); 221 } 222 TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesReserve1FromVector) { 223 std::vector<float> weights = {1, 2, 4, 8, 16, 32, 64, 128, 256}; 224 sampler_.reset(new FixedUnigramSampler(10, weights, 0.8, 1, 1, 0)); 225 ASSERT_NEAR(sampler_->Probability(0), 0, 1e-4); 226 // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05 227 for (int i = 1; i < 10; i++) { 228 ASSERT_NEAR(sampler_->Probability(i), pow(2, (i - 1) * 0.8) / 197.05, 1e-4); 229 } 230 } 231 TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesReserve2FromVector) { 232 std::vector<float> weights = {1, 2, 4, 8, 16, 32, 64, 128, 256}; 233 sampler_.reset(new FixedUnigramSampler(11, weights, 0.8, 2, 1, 0)); 234 ASSERT_NEAR(sampler_->Probability(0), 0, 1e-4); 235 ASSERT_NEAR(sampler_->Probability(1), 0, 1e-4); 236 // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05 237 for (int i = 2; i < 11; i++) { 238 ASSERT_NEAR(sampler_->Probability(i), pow(2, (i - 2) * 0.8) / 197.05, 1e-4); 239 } 240 } 241 242 // AllSampler cannot call Sample or Probability directly. 243 // We will test SampleBatchGetExpectedCount instead. 244 TEST_F(RangeSamplerTest, All) { 245 int batch_size = 10; 246 sampler_.reset(new AllSampler(10)); 247 std::vector<int64> batch(batch_size); 248 std::vector<float> batch_expected(batch_size); 249 std::vector<int64> extras(2); 250 std::vector<float> extras_expected(2); 251 extras[0] = 0; 252 extras[1] = batch_size - 1; 253 sampler_->SampleBatchGetExpectedCount(nullptr, // no random numbers needed 254 false, &batch, &batch_expected, extras, 255 &extras_expected); 256 for (int i = 0; i < batch_size; i++) { 257 EXPECT_EQ(i, batch[i]); 258 EXPECT_EQ(1, batch_expected[i]); 259 } 260 EXPECT_EQ(1, extras_expected[0]); 261 EXPECT_EQ(1, extras_expected[1]); 262 } 263 264 TEST_F(RangeSamplerTest, Unique) { 265 // We sample num_batches batches, each without replacement. 266 // 267 // We check that the returned expected counts roughly agree with each other 268 // and with the average observed frequencies over the set of batches. 269 random::PhiloxRandom philox(123, 17); 270 random::SimplePhilox rnd(&philox); 271 const int range = 100; 272 const int batch_size = 50; 273 const int num_batches = 100; 274 sampler_.reset(new LogUniformSampler(range)); 275 std::vector<int> histogram(range); 276 std::vector<int64> batch(batch_size); 277 std::vector<int64> all_values(range); 278 for (int i = 0; i < range; i++) { 279 all_values[i] = i; 280 } 281 std::vector<float> expected(range); 282 283 // Sample one batch and get the expected counts of all values 284 sampler_->SampleBatchGetExpectedCount( 285 &rnd, true, &batch, MutableArraySlice<float>(), all_values, &expected); 286 // Check that all elements are unique 287 std::set<int64> s(batch.begin(), batch.end()); 288 CHECK_EQ(batch_size, s.size()); 289 290 for (int trial = 0; trial < num_batches; trial++) { 291 std::vector<float> trial_expected(range); 292 sampler_->SampleBatchGetExpectedCount(&rnd, true, &batch, 293 MutableArraySlice<float>(), 294 all_values, &trial_expected); 295 for (int i = 0; i < range; i++) { 296 EXPECT_NEAR(expected[i], trial_expected[i], expected[i] * 0.5); 297 } 298 for (int i = 0; i < batch_size; i++) { 299 histogram[batch[i]]++; 300 } 301 } 302 for (int i = 0; i < range; i++) { 303 // Check that the computed expected count agrees with the average observed 304 // count. 305 const float average_count = static_cast<float>(histogram[i]) / num_batches; 306 EXPECT_NEAR(expected[i], average_count, 0.2); 307 } 308 } 309 310 TEST_F(RangeSamplerTest, Avoid) { 311 random::PhiloxRandom philox(123, 17); 312 random::SimplePhilox rnd(&philox); 313 sampler_.reset(new LogUniformSampler(100)); 314 std::vector<int64> avoided(2); 315 avoided[0] = 17; 316 avoided[1] = 23; 317 std::vector<int64> batch(98); 318 319 // We expect to pick all elements of [0, 100) except the avoided two. 320 sampler_->SampleBatchGetExpectedCountAvoid( 321 &rnd, true, &batch, MutableArraySlice<float>(), ArraySlice<int64>(), 322 MutableArraySlice<float>(), avoided); 323 324 int sum = 0; 325 for (auto val : batch) { 326 sum += val; 327 } 328 const int expected_sum = 100 * 99 / 2 - avoided[0] - avoided[1]; 329 EXPECT_EQ(expected_sum, sum); 330 } 331 332 } // namespace 333 334 } // namespace tensorflow 335