Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/kernels/range_sampler.h"
     17 
     18 #include <unordered_set>
     19 #include <vector>
     20 
     21 #include "tensorflow/core/lib/core/errors.h"
     22 #include "tensorflow/core/lib/gtl/map_util.h"
     23 #include "tensorflow/core/lib/io/inputbuffer.h"
     24 #include "tensorflow/core/lib/strings/numbers.h"
     25 #include "tensorflow/core/lib/strings/str_util.h"
     26 #include "tensorflow/core/platform/logging.h"
     27 #include "tensorflow/core/platform/mutex.h"
     28 #include "tensorflow/core/platform/types.h"
     29 
     30 namespace tensorflow {
     31 
     32 using gtl::ArraySlice;
     33 using gtl::MutableArraySlice;
     34 
     35 RangeSampler::~RangeSampler() {}
     36 
     37 void RangeSampler::SampleBatch(random::SimplePhilox* rnd, bool unique,
     38                                gtl::MutableArraySlice<int64> batch) const {
     39   SampleBatchGetExpectedCount(
     40       rnd, unique, batch, gtl::MutableArraySlice<float>(),
     41       gtl::ArraySlice<int64>(), gtl::MutableArraySlice<float>());
     42 }
     43 
     44 void RangeSampler::SampleBatchGetExpectedCount(
     45     random::SimplePhilox* rnd, bool unique, gtl::MutableArraySlice<int64> batch,
     46     gtl::MutableArraySlice<float> batch_expected_count,
     47     gtl::ArraySlice<int64> extras,
     48     gtl::MutableArraySlice<float> extras_expected_count) const {
     49   SampleBatchGetExpectedCountAvoid(rnd, unique, batch, batch_expected_count,
     50                                    extras, extras_expected_count,
     51                                    gtl::ArraySlice<int64>());
     52 }
     53 
     54 namespace {
     55 
     56 // Approximates the expected count of a value in the output of SampleBatch.
     57 //
     58 // If unique=false, then this is (Probability(value) * batch_size)
     59 //
     60 // We use batch_size and num_tries, where num_tries is the observed number of
     61 // tries it took to get batch_size unique values.
     62 //
     63 // Assuming (falsely) that the number of tries to get a batch of batch_size
     64 // distinct values is _always_ num_tries, the probability that the value
     65 // is in a batch is (1 - (1-p)^num_tries)
     66 static float ExpectedCountHelper(float p, int batch_size, int num_tries) {
     67   if (num_tries == batch_size) {
     68     // This shortcut will always be taken if unique=false
     69     return p * batch_size;
     70   }
     71   // numerically stable version of (1 - (1-p)^num_tries)
     72   return -expm1(num_tries * log1p(-p));
     73 }
     74 
     75 }  // namespace
     76 
     77 void RangeSampler::SampleBatchGetExpectedCountAvoid(
     78     random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64> batch,
     79     MutableArraySlice<float> batch_expected_count, ArraySlice<int64> extras,
     80     MutableArraySlice<float> extras_expected_count,
     81     ArraySlice<int64> avoided_values) const {
     82   const int batch_size = batch.size();
     83   int num_tries;
     84 
     85   if (unique) {
     86     CHECK_LE(batch_size + avoided_values.size(), range_);
     87     std::unordered_set<int64> used(batch_size);
     88     used.insert(avoided_values.begin(), avoided_values.end());
     89     int num_picked = 0;
     90     num_tries = 0;
     91     while (num_picked < batch_size) {
     92       num_tries++;
     93       CHECK_LT(num_tries, kint32max);
     94       int64 value = Sample(rnd);
     95       if (gtl::InsertIfNotPresent(&used, value)) {
     96         batch[num_picked++] = value;
     97       }
     98     }
     99   } else {
    100     CHECK_EQ(avoided_values.size(), size_t{0})
    101         << "avoided_values only supported with unique=true";
    102     for (int i = 0; i < batch_size; i++) {
    103       batch[i] = Sample(rnd);
    104     }
    105     num_tries = batch_size;
    106   }
    107   // Compute the expected counts of the batch and the extra values
    108   if (!batch_expected_count.empty()) {
    109     CHECK_EQ(batch_size, batch_expected_count.size());
    110     for (int i = 0; i < batch_size; i++) {
    111       batch_expected_count[i] =
    112           ExpectedCountHelper(Probability(batch[i]), batch_size, num_tries);
    113     }
    114   }
    115   CHECK_EQ(extras.size(), extras_expected_count.size());
    116   for (size_t i = 0; i < extras.size(); i++) {
    117     extras_expected_count[i] =
    118         ExpectedCountHelper(Probability(extras[i]), batch_size, num_tries);
    119   }
    120 }
    121 
    122 AllSampler::AllSampler(int64 range) : RangeSampler(range) {}
    123 
    124 void AllSampler::SampleBatchGetExpectedCountAvoid(
    125     random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64> batch,
    126     MutableArraySlice<float> batch_expected_count, ArraySlice<int64> extras,
    127     MutableArraySlice<float> extras_expected_count,
    128     ArraySlice<int64> avoided_values) const {
    129   const int batch_size = batch.size();
    130   CHECK_EQ(range_, batch_size);
    131   for (int i = 0; i < batch_size; i++) {
    132     batch[i] = i;
    133   }
    134   if (!batch_expected_count.empty()) {
    135     CHECK_EQ(batch_size, batch_expected_count.size());
    136     for (int i = 0; i < batch_size; i++) {
    137       batch_expected_count[i] = 1;
    138     }
    139   }
    140   CHECK_EQ(size_t{0}, avoided_values.size());
    141   CHECK_EQ(extras.size(), extras_expected_count.size());
    142   for (size_t i = 0; i < extras.size(); i++) {
    143     extras_expected_count[i] = 1;
    144   }
    145 }
    146 
    147 UniformSampler::UniformSampler(int64 range)
    148     : RangeSampler(range), inv_range_(1.0 / range) {}
    149 
    150 int64 UniformSampler::Sample(random::SimplePhilox* rnd) const {
    151   return rnd->Uniform64(range_);
    152 }
    153 
    154 float UniformSampler::Probability(int64 value) const { return inv_range_; }
    155 
    156 LogUniformSampler::LogUniformSampler(int64 range)
    157     : RangeSampler(range), log_range_(log(range + 1)) {}
    158 
    159 int64 LogUniformSampler::Sample(random::SimplePhilox* rnd) const {
    160   const int64 value =
    161       static_cast<int64>(exp(rnd->RandDouble() * log_range_)) - 1;
    162   CHECK_GE(value, 0);
    163   // Mathematically, value should be <= range_, but might not be due to some
    164   // floating point roundoff, so we mod by range_.
    165   return value % range_;
    166 }
    167 
    168 float LogUniformSampler::Probability(int64 value) const {
    169   // value is returned iff the call to UniformDouble(log_range_) in the
    170   // Sample() function returns a value between log(value + 1)
    171   // and log(value + 2).   The probability of this is:
    172   // (log(value + 2) - log(value + 1)) / log_range
    173   // To avoid two calls to log(), we compute this as follows:
    174   return (log((value + 2.0) / (value + 1.0))) / log_range_;
    175 }
    176 
    177 ThreadUnsafeUnigramSampler::ThreadUnsafeUnigramSampler(int64 range)
    178     : RangeSampler(range), picker_(range) {
    179   CHECK_LT(range, kint32max);
    180 }
    181 
    182 int64 ThreadUnsafeUnigramSampler::Sample(random::SimplePhilox* rnd) const {
    183   return picker_.Pick(rnd);
    184 }
    185 
    186 float ThreadUnsafeUnigramSampler::Probability(int64 value) const {
    187   return static_cast<float>(picker_.get_weight(value)) / picker_.total_weight();
    188 }
    189 
    190 void ThreadUnsafeUnigramSampler::Update(ArraySlice<int64> values) {
    191   int num_updates = std::min(static_cast<int>(values.size()),
    192                              kint32max - picker_.total_weight());
    193   for (int i = 0; i < num_updates; i++) {
    194     const int64 value = values[i];
    195     picker_.set_weight(value, picker_.get_weight(value) + 1);
    196   }
    197 }
    198 
    199 // Thread-safe unigram sampler
    200 UnigramSampler::UnigramSampler(int64 range)
    201     : RangeSampler(range), unsafe_sampler_(range) {
    202   CHECK_LT(range, kint32max);
    203 }
    204 
    205 int64 UnigramSampler::Sample(random::SimplePhilox* rnd) const {
    206   mutex_lock lock(mu_);  // could use reader lock
    207   return unsafe_sampler_.Sample(rnd);
    208 }
    209 
    210 float UnigramSampler::Probability(int64 value) const {
    211   mutex_lock lock(mu_);  // could use reader lock
    212   return unsafe_sampler_.Probability(value);
    213 }
    214 
    215 // Overriding at a high level results in far fewer lock acquisitions.
    216 void UnigramSampler::SampleBatchGetExpectedCountAvoid(
    217     random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64> batch,
    218     MutableArraySlice<float> batch_expected_count, ArraySlice<int64> extras,
    219     MutableArraySlice<float> extras_expected_count,
    220     ArraySlice<int64> avoided_values) const {
    221   mutex_lock lock(mu_);  // could use reader lock
    222   unsafe_sampler_.SampleBatchGetExpectedCountAvoid(
    223       rnd, unique, batch, batch_expected_count, extras, extras_expected_count,
    224       avoided_values);
    225 }
    226 
    227 void UnigramSampler::Update(ArraySlice<int64> values) {
    228   mutex_lock lock(mu_);
    229   unsafe_sampler_.Update(values);
    230 }
    231 
    232 FixedUnigramSampler::FixedUnigramSampler(Env* env, int64 range,
    233                                          const string& vocab_file,
    234                                          float distortion,
    235                                          int32 num_reserved_ids,
    236                                          int32 num_shards, int32 shard)
    237     : RangeSampler(range),
    238       total_weight_(0.0),
    239       num_shards_(num_shards),
    240       shard_(shard) {
    241   FillReservedIds(num_reserved_ids);
    242   // TODO(vanhoucke): make this non-crashing.
    243   TF_CHECK_OK(LoadFromFile(env, vocab_file, distortion));
    244   CHECK_EQ(range, weights_.size());
    245   dist_sampler_.reset(new random::DistributionSampler(weights_));
    246 }
    247 
    248 FixedUnigramSampler::FixedUnigramSampler(int64 range,
    249                                          const std::vector<float>& unigrams,
    250                                          float distortion,
    251                                          int32 num_reserved_ids,
    252                                          int32 num_shards, int32 shard)
    253     : RangeSampler(range),
    254       total_weight_(0.0),
    255       num_shards_(num_shards),
    256       shard_(shard) {
    257   FillReservedIds(num_reserved_ids);
    258   LoadFromUnigrams(unigrams, distortion);
    259   // TODO(vanhoucke): make this non-crashing.
    260   CHECK_EQ(range, weights_.size());
    261   dist_sampler_.reset(new random::DistributionSampler(weights_));
    262 }
    263 
    264 float FixedUnigramSampler::Probability(int64 value) const {
    265   if (value < 0 || static_cast<size_t>(value) >= weights_.size()) {
    266     return 0.0;
    267   }
    268   return weights_.at(value) / total_weight_;
    269 }
    270 
    271 int64 FixedUnigramSampler::Sample(random::SimplePhilox* rnd) const {
    272   return dist_sampler_->Sample(rnd);
    273 }
    274 
    275 void FixedUnigramSampler::FillReservedIds(int32 num_reserved_ids) {
    276   for (int32 word_id = 0; word_id < num_reserved_ids; ++word_id) {
    277     if (word_id % num_shards_ == shard_) weights_.push_back(0.0);
    278   }
    279 }
    280 
    281 Status FixedUnigramSampler::LoadFromFile(Env* env, const string& vocab_file,
    282                                          float distortion) {
    283   std::unique_ptr<RandomAccessFile> file;
    284   TF_RETURN_IF_ERROR(env->NewRandomAccessFile(vocab_file, &file));
    285 
    286   io::InputBuffer in(file.get(), 262144 /*bytes*/);
    287   string line;
    288   int32 word_id = weights_.size();
    289   while (in.ReadLine(&line).ok()) {
    290     // The vocabulary file should be in csv like format, with the last
    291     // field the weight associated with the word.
    292     std::vector<string> cols = str_util::Split(line, ',');
    293     if (cols.empty()) continue;
    294     // Skip entries that do not belong to this shard.
    295     if (word_id % num_shards_ == shard_) {
    296       float w = 0.0;
    297       if (!strings::safe_strtof(cols.at(cols.size() - 1).c_str(), &w)) {
    298         return errors::InvalidArgument("Wrong vocabulary format at line: ",
    299                                        line);
    300       }
    301       w = pow(w, distortion);
    302       total_weight_ += w;
    303       weights_.push_back(w);
    304     }
    305     ++word_id;
    306   }
    307   return Status::OK();
    308 }
    309 
    310 void FixedUnigramSampler::LoadFromUnigrams(const std::vector<float>& unigrams,
    311                                            float distortion) {
    312   int32 word_id = weights_.size();
    313   for (float w : unigrams) {
    314     // Skip entries that do not belong to this shard.
    315     if (word_id % num_shards_ == shard_) {
    316       w = pow(w, distortion);
    317       total_weight_ += w;
    318       weights_.push_back(w);
    319     }
    320     ++word_id;
    321   }
    322 }
    323 
    324 }  // namespace tensorflow
    325