1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/kernels/range_sampler.h" 17 18 #include <unordered_set> 19 #include <vector> 20 21 #include "tensorflow/core/lib/core/errors.h" 22 #include "tensorflow/core/lib/gtl/map_util.h" 23 #include "tensorflow/core/lib/io/inputbuffer.h" 24 #include "tensorflow/core/lib/strings/numbers.h" 25 #include "tensorflow/core/lib/strings/str_util.h" 26 #include "tensorflow/core/platform/logging.h" 27 #include "tensorflow/core/platform/mutex.h" 28 #include "tensorflow/core/platform/types.h" 29 30 namespace tensorflow { 31 32 using gtl::ArraySlice; 33 using gtl::MutableArraySlice; 34 35 RangeSampler::~RangeSampler() {} 36 37 void RangeSampler::SampleBatch(random::SimplePhilox* rnd, bool unique, 38 gtl::MutableArraySlice<int64> batch) const { 39 SampleBatchGetExpectedCount( 40 rnd, unique, batch, gtl::MutableArraySlice<float>(), 41 gtl::ArraySlice<int64>(), gtl::MutableArraySlice<float>()); 42 } 43 44 void RangeSampler::SampleBatchGetExpectedCount( 45 random::SimplePhilox* rnd, bool unique, gtl::MutableArraySlice<int64> batch, 46 gtl::MutableArraySlice<float> batch_expected_count, 47 gtl::ArraySlice<int64> extras, 48 gtl::MutableArraySlice<float> extras_expected_count) const { 49 SampleBatchGetExpectedCountAvoid(rnd, unique, batch, batch_expected_count, 50 extras, extras_expected_count, 51 gtl::ArraySlice<int64>()); 52 } 53 54 namespace { 55 56 // Approximates the expected count of a value in the output of SampleBatch. 57 // 58 // If unique=false, then this is (Probability(value) * batch_size) 59 // 60 // We use batch_size and num_tries, where num_tries is the observed number of 61 // tries it took to get batch_size unique values. 62 // 63 // Assuming (falsely) that the number of tries to get a batch of batch_size 64 // distinct values is _always_ num_tries, the probability that the value 65 // is in a batch is (1 - (1-p)^num_tries) 66 static float ExpectedCountHelper(float p, int batch_size, int num_tries) { 67 if (num_tries == batch_size) { 68 // This shortcut will always be taken if unique=false 69 return p * batch_size; 70 } 71 // numerically stable version of (1 - (1-p)^num_tries) 72 return -expm1(num_tries * log1p(-p)); 73 } 74 75 } // namespace 76 77 void RangeSampler::SampleBatchGetExpectedCountAvoid( 78 random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64> batch, 79 MutableArraySlice<float> batch_expected_count, ArraySlice<int64> extras, 80 MutableArraySlice<float> extras_expected_count, 81 ArraySlice<int64> avoided_values) const { 82 const int batch_size = batch.size(); 83 int num_tries; 84 85 if (unique) { 86 CHECK_LE(batch_size + avoided_values.size(), range_); 87 std::unordered_set<int64> used(batch_size); 88 used.insert(avoided_values.begin(), avoided_values.end()); 89 int num_picked = 0; 90 num_tries = 0; 91 while (num_picked < batch_size) { 92 num_tries++; 93 CHECK_LT(num_tries, kint32max); 94 int64 value = Sample(rnd); 95 if (gtl::InsertIfNotPresent(&used, value)) { 96 batch[num_picked++] = value; 97 } 98 } 99 } else { 100 CHECK_EQ(avoided_values.size(), size_t{0}) 101 << "avoided_values only supported with unique=true"; 102 for (int i = 0; i < batch_size; i++) { 103 batch[i] = Sample(rnd); 104 } 105 num_tries = batch_size; 106 } 107 // Compute the expected counts of the batch and the extra values 108 if (!batch_expected_count.empty()) { 109 CHECK_EQ(batch_size, batch_expected_count.size()); 110 for (int i = 0; i < batch_size; i++) { 111 batch_expected_count[i] = 112 ExpectedCountHelper(Probability(batch[i]), batch_size, num_tries); 113 } 114 } 115 CHECK_EQ(extras.size(), extras_expected_count.size()); 116 for (size_t i = 0; i < extras.size(); i++) { 117 extras_expected_count[i] = 118 ExpectedCountHelper(Probability(extras[i]), batch_size, num_tries); 119 } 120 } 121 122 AllSampler::AllSampler(int64 range) : RangeSampler(range) {} 123 124 void AllSampler::SampleBatchGetExpectedCountAvoid( 125 random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64> batch, 126 MutableArraySlice<float> batch_expected_count, ArraySlice<int64> extras, 127 MutableArraySlice<float> extras_expected_count, 128 ArraySlice<int64> avoided_values) const { 129 const int batch_size = batch.size(); 130 CHECK_EQ(range_, batch_size); 131 for (int i = 0; i < batch_size; i++) { 132 batch[i] = i; 133 } 134 if (!batch_expected_count.empty()) { 135 CHECK_EQ(batch_size, batch_expected_count.size()); 136 for (int i = 0; i < batch_size; i++) { 137 batch_expected_count[i] = 1; 138 } 139 } 140 CHECK_EQ(size_t{0}, avoided_values.size()); 141 CHECK_EQ(extras.size(), extras_expected_count.size()); 142 for (size_t i = 0; i < extras.size(); i++) { 143 extras_expected_count[i] = 1; 144 } 145 } 146 147 UniformSampler::UniformSampler(int64 range) 148 : RangeSampler(range), inv_range_(1.0 / range) {} 149 150 int64 UniformSampler::Sample(random::SimplePhilox* rnd) const { 151 return rnd->Uniform64(range_); 152 } 153 154 float UniformSampler::Probability(int64 value) const { return inv_range_; } 155 156 LogUniformSampler::LogUniformSampler(int64 range) 157 : RangeSampler(range), log_range_(log(range + 1)) {} 158 159 int64 LogUniformSampler::Sample(random::SimplePhilox* rnd) const { 160 const int64 value = 161 static_cast<int64>(exp(rnd->RandDouble() * log_range_)) - 1; 162 CHECK_GE(value, 0); 163 // Mathematically, value should be <= range_, but might not be due to some 164 // floating point roundoff, so we mod by range_. 165 return value % range_; 166 } 167 168 float LogUniformSampler::Probability(int64 value) const { 169 // value is returned iff the call to UniformDouble(log_range_) in the 170 // Sample() function returns a value between log(value + 1) 171 // and log(value + 2). The probability of this is: 172 // (log(value + 2) - log(value + 1)) / log_range 173 // To avoid two calls to log(), we compute this as follows: 174 return (log((value + 2.0) / (value + 1.0))) / log_range_; 175 } 176 177 ThreadUnsafeUnigramSampler::ThreadUnsafeUnigramSampler(int64 range) 178 : RangeSampler(range), picker_(range) { 179 CHECK_LT(range, kint32max); 180 } 181 182 int64 ThreadUnsafeUnigramSampler::Sample(random::SimplePhilox* rnd) const { 183 return picker_.Pick(rnd); 184 } 185 186 float ThreadUnsafeUnigramSampler::Probability(int64 value) const { 187 return static_cast<float>(picker_.get_weight(value)) / picker_.total_weight(); 188 } 189 190 void ThreadUnsafeUnigramSampler::Update(ArraySlice<int64> values) { 191 int num_updates = std::min(static_cast<int>(values.size()), 192 kint32max - picker_.total_weight()); 193 for (int i = 0; i < num_updates; i++) { 194 const int64 value = values[i]; 195 picker_.set_weight(value, picker_.get_weight(value) + 1); 196 } 197 } 198 199 // Thread-safe unigram sampler 200 UnigramSampler::UnigramSampler(int64 range) 201 : RangeSampler(range), unsafe_sampler_(range) { 202 CHECK_LT(range, kint32max); 203 } 204 205 int64 UnigramSampler::Sample(random::SimplePhilox* rnd) const { 206 mutex_lock lock(mu_); // could use reader lock 207 return unsafe_sampler_.Sample(rnd); 208 } 209 210 float UnigramSampler::Probability(int64 value) const { 211 mutex_lock lock(mu_); // could use reader lock 212 return unsafe_sampler_.Probability(value); 213 } 214 215 // Overriding at a high level results in far fewer lock acquisitions. 216 void UnigramSampler::SampleBatchGetExpectedCountAvoid( 217 random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64> batch, 218 MutableArraySlice<float> batch_expected_count, ArraySlice<int64> extras, 219 MutableArraySlice<float> extras_expected_count, 220 ArraySlice<int64> avoided_values) const { 221 mutex_lock lock(mu_); // could use reader lock 222 unsafe_sampler_.SampleBatchGetExpectedCountAvoid( 223 rnd, unique, batch, batch_expected_count, extras, extras_expected_count, 224 avoided_values); 225 } 226 227 void UnigramSampler::Update(ArraySlice<int64> values) { 228 mutex_lock lock(mu_); 229 unsafe_sampler_.Update(values); 230 } 231 232 FixedUnigramSampler::FixedUnigramSampler(Env* env, int64 range, 233 const string& vocab_file, 234 float distortion, 235 int32 num_reserved_ids, 236 int32 num_shards, int32 shard) 237 : RangeSampler(range), 238 total_weight_(0.0), 239 num_shards_(num_shards), 240 shard_(shard) { 241 FillReservedIds(num_reserved_ids); 242 // TODO(vanhoucke): make this non-crashing. 243 TF_CHECK_OK(LoadFromFile(env, vocab_file, distortion)); 244 CHECK_EQ(range, weights_.size()); 245 dist_sampler_.reset(new random::DistributionSampler(weights_)); 246 } 247 248 FixedUnigramSampler::FixedUnigramSampler(int64 range, 249 const std::vector<float>& unigrams, 250 float distortion, 251 int32 num_reserved_ids, 252 int32 num_shards, int32 shard) 253 : RangeSampler(range), 254 total_weight_(0.0), 255 num_shards_(num_shards), 256 shard_(shard) { 257 FillReservedIds(num_reserved_ids); 258 LoadFromUnigrams(unigrams, distortion); 259 // TODO(vanhoucke): make this non-crashing. 260 CHECK_EQ(range, weights_.size()); 261 dist_sampler_.reset(new random::DistributionSampler(weights_)); 262 } 263 264 float FixedUnigramSampler::Probability(int64 value) const { 265 if (value < 0 || static_cast<size_t>(value) >= weights_.size()) { 266 return 0.0; 267 } 268 return weights_.at(value) / total_weight_; 269 } 270 271 int64 FixedUnigramSampler::Sample(random::SimplePhilox* rnd) const { 272 return dist_sampler_->Sample(rnd); 273 } 274 275 void FixedUnigramSampler::FillReservedIds(int32 num_reserved_ids) { 276 for (int32 word_id = 0; word_id < num_reserved_ids; ++word_id) { 277 if (word_id % num_shards_ == shard_) weights_.push_back(0.0); 278 } 279 } 280 281 Status FixedUnigramSampler::LoadFromFile(Env* env, const string& vocab_file, 282 float distortion) { 283 std::unique_ptr<RandomAccessFile> file; 284 TF_RETURN_IF_ERROR(env->NewRandomAccessFile(vocab_file, &file)); 285 286 io::InputBuffer in(file.get(), 262144 /*bytes*/); 287 string line; 288 int32 word_id = weights_.size(); 289 while (in.ReadLine(&line).ok()) { 290 // The vocabulary file should be in csv like format, with the last 291 // field the weight associated with the word. 292 std::vector<string> cols = str_util::Split(line, ','); 293 if (cols.empty()) continue; 294 // Skip entries that do not belong to this shard. 295 if (word_id % num_shards_ == shard_) { 296 float w = 0.0; 297 if (!strings::safe_strtof(cols.at(cols.size() - 1).c_str(), &w)) { 298 return errors::InvalidArgument("Wrong vocabulary format at line: ", 299 line); 300 } 301 w = pow(w, distortion); 302 total_weight_ += w; 303 weights_.push_back(w); 304 } 305 ++word_id; 306 } 307 return Status::OK(); 308 } 309 310 void FixedUnigramSampler::LoadFromUnigrams(const std::vector<float>& unigrams, 311 float distortion) { 312 int32 word_id = weights_.size(); 313 for (float w : unigrams) { 314 // Skip entries that do not belong to this shard. 315 if (word_id % num_shards_ == shard_) { 316 w = pow(w, distortion); 317 total_weight_ += w; 318 weights_.push_back(w); 319 } 320 ++word_id; 321 } 322 } 323 324 } // namespace tensorflow 325