Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_KERNELS_RANGE_SAMPLER_H_
     17 #define TENSORFLOW_KERNELS_RANGE_SAMPLER_H_
     18 
     19 #include <vector>
     20 
     21 #include "tensorflow/core/lib/core/status.h"
     22 #include "tensorflow/core/lib/gtl/array_slice.h"
     23 #include "tensorflow/core/lib/random/distribution_sampler.h"
     24 #include "tensorflow/core/lib/random/random_distributions.h"
     25 #include "tensorflow/core/lib/random/weighted_picker.h"
     26 #include "tensorflow/core/platform/logging.h"
     27 #include "tensorflow/core/platform/mutex.h"
     28 #include "tensorflow/core/platform/thread_annotations.h"
     29 #include "tensorflow/core/platform/types.h"
     30 
     31 namespace tensorflow {
     32 
     33 class Env;
     34 
     35 // Abstract subclass for sampling from the set of non-negative integers
     36 // [0, range)
     37 class RangeSampler {
     38  public:
     39   explicit RangeSampler(int64 range) : range_(range) { CHECK_GT(range_, 0); }
     40   virtual ~RangeSampler();
     41 
     42   // Sample a single value
     43   virtual int64 Sample(random::SimplePhilox* rnd) const = 0;
     44 
     45   // The probability that a single call to Sample() returns the given value.
     46   // Assumes that value is in [0, range).  No range checking is done.
     47   virtual float Probability(int64 value) const = 0;
     48 
     49   // Fill "batch" with samples from the distribution.
     50   // If unique=true, then we re-pick each element until we get a
     51   // value distinct from all previously picked values in the batch.
     52   void SampleBatch(random::SimplePhilox* rnd, bool unique,
     53                    gtl::MutableArraySlice<int64> batch) const;
     54 
     55   // Fill "batch" with samples from the distribution, and report
     56   // "expected counts".
     57   //
     58   // The "expected count" of a value is an estimate of the expected
     59   // number of occurrences of the value in the batch returned by a
     60   // call to this function with the given parameters.  If unique=true,
     61   // the expected count is an inclusion probability.  For details on
     62   // this estimation, see the comment to "ExpectedCountHelper" in the
     63   // .cc file.
     64   //
     65   // Expected counts for the elements of the returned "batch" are reported
     66   // in the aligned array "batch_expected_count".
     67   //
     68   // The user can optionally provide "extras", containing values in the range.
     69   // The expected counts for the extras are reported in the aligned array
     70   // "extras_expected_count".
     71   //
     72   // "batch_expected_count" must have size equal to 0 or to the size of "batch".
     73   // "extras" and "extras_expected_count" must have equal size.
     74   void SampleBatchGetExpectedCount(
     75       random::SimplePhilox* rnd, bool unique,
     76       gtl::MutableArraySlice<int64> batch,
     77       gtl::MutableArraySlice<float> batch_expected_count,
     78       gtl::ArraySlice<int64> extras,
     79       gtl::MutableArraySlice<float> extras_expected_count) const;
     80 
     81   // Same as SampleBatchGetExpectedCount (see above), but with avoided values.
     82   // We repick to avoid all of the values in "avoided_values".
     83   // "avoided_values" is only supported with unique=true.  If
     84   // unique=false, then avoided_values must be empty.
     85   virtual void SampleBatchGetExpectedCountAvoid(
     86       random::SimplePhilox* rnd, bool unique,
     87       gtl::MutableArraySlice<int64> batch,
     88       gtl::MutableArraySlice<float> batch_expected_count,
     89       gtl::ArraySlice<int64> extras,
     90       gtl::MutableArraySlice<float> extras_expected_count,
     91       gtl::ArraySlice<int64> avoided_values) const;
     92 
     93   // Does this sampler need to be updated with values, e.g. UnigramSampler
     94   virtual bool NeedsUpdates() const { return false; }
     95 
     96   // Updates the underlying distribution
     97   virtual void Update(gtl::ArraySlice<int64> values) {
     98     LOG(FATAL) << "Update not supported for this sampler type.";
     99   }
    100 
    101   int64 range() { return range_; }
    102 
    103  protected:
    104   const int64 range_;
    105 };
    106 
    107 // An AllSampler only samples batches of size equal to range.
    108 // It returns the entire range.
    109 // It cannot sample single values.
    110 class AllSampler : public RangeSampler {
    111  public:
    112   explicit AllSampler(int64 range);
    113 
    114   ~AllSampler() override {}
    115 
    116   int64 Sample(random::SimplePhilox* rnd) const override {
    117     LOG(FATAL) << "Should not be called";
    118     return 0;
    119   }
    120 
    121   float Probability(int64 value) const override {
    122     LOG(FATAL) << "Should not be called";
    123     return 0;
    124   }
    125 
    126   void SampleBatchGetExpectedCountAvoid(
    127       random::SimplePhilox* rnd, bool unique,
    128       gtl::MutableArraySlice<int64> batch,
    129       gtl::MutableArraySlice<float> batch_expected_count,
    130       gtl::ArraySlice<int64> extras,
    131       gtl::MutableArraySlice<float> extras_expected_count,
    132       gtl::ArraySlice<int64> avoided_values) const override;
    133 };
    134 
    135 class UniformSampler : public RangeSampler {
    136  public:
    137   explicit UniformSampler(int64 range);
    138 
    139   ~UniformSampler() override {}
    140 
    141   int64 Sample(random::SimplePhilox* rnd) const override;
    142 
    143   float Probability(int64 value) const override;
    144 
    145  private:
    146   const float inv_range_;
    147 };
    148 
    149 class LogUniformSampler : public RangeSampler {
    150  public:
    151   explicit LogUniformSampler(int64 range);
    152 
    153   ~LogUniformSampler() override {}
    154 
    155   int64 Sample(random::SimplePhilox* rnd) const override;
    156 
    157   float Probability(int64 value) const override;
    158 
    159  private:
    160   const double log_range_;
    161 };
    162 
    163 // Thread-unsafe unigram sampler
    164 class ThreadUnsafeUnigramSampler : public RangeSampler {
    165  public:
    166   explicit ThreadUnsafeUnigramSampler(int64 range);
    167   ~ThreadUnsafeUnigramSampler() override {}
    168 
    169   int64 Sample(random::SimplePhilox* rnd) const override;
    170 
    171   float Probability(int64 value) const override;
    172 
    173   bool NeedsUpdates() const override { return true; }
    174   void Update(gtl::ArraySlice<int64> values) override;
    175 
    176  private:
    177   random::WeightedPicker picker_;
    178 };
    179 
    180 // Thread-safe unigram sampler
    181 class UnigramSampler : public RangeSampler {
    182  public:
    183   explicit UnigramSampler(int64 range);
    184   ~UnigramSampler() override {}
    185 
    186   int64 Sample(random::SimplePhilox* rnd) const override;
    187 
    188   float Probability(int64 value) const override;
    189 
    190   // Overriding at a high level results in far fewer lock acquisitions.
    191   void SampleBatchGetExpectedCountAvoid(
    192       random::SimplePhilox* rnd, bool unique,
    193       gtl::MutableArraySlice<int64> batch,
    194       gtl::MutableArraySlice<float> batch_expected_count,
    195       gtl::ArraySlice<int64> extras,
    196       gtl::MutableArraySlice<float> extras_expected_count,
    197       gtl::ArraySlice<int64> avoided_values) const override;
    198 
    199   bool NeedsUpdates() const override { return true; }
    200   void Update(gtl::ArraySlice<int64> values) override;
    201 
    202  private:
    203   ThreadUnsafeUnigramSampler unsafe_sampler_ GUARDED_BY(mu_);
    204   mutable mutex mu_;
    205 };
    206 
    207 // A unigram sampler that uses a fixed unigram distribution read from a
    208 // file or passed in as an in-memory array instead of building up the
    209 // distribution from data on the fly. There is also an option to skew the
    210 // distribution by applying a distortion power to the weights.
    211 class FixedUnigramSampler : public RangeSampler {
    212  public:
    213   // The vocab_file is assumed to be a CSV, with the last entry of each row a
    214   // value representing the counts or probabilities for the corresponding ID.
    215   FixedUnigramSampler(Env* env, int64 range, const string& vocab_file,
    216                       float distortion, int32 num_reserved_ids,
    217                       int32 num_shards, int32 shard);
    218 
    219   FixedUnigramSampler(int64 range, const std::vector<float>& unigrams,
    220                       float distortion, int32 num_reserved_ids,
    221                       int32 num_shards, int32 shard);
    222 
    223   float Probability(int64 value) const override;
    224 
    225   int64 Sample(random::SimplePhilox* rnd) const override;
    226 
    227  private:
    228   // Underlying distribution sampler.
    229   std::unique_ptr<random::DistributionSampler> dist_sampler_;
    230   // Weights for individual samples. The probability of a sample i is defined
    231   // as weights_.at(i) / total_weight_.
    232   std::vector<float> weights_;
    233   // The total weights of all samples.
    234   float total_weight_;
    235   // Sharding information of the sampler. The whole vocabulary is sharded
    236   // into num_shards_ smaller ranges and each sampler is responsible for one
    237   // such smaller range, identified by the shard number.
    238   int32 num_shards_;
    239   int32 shard_;
    240 
    241   // Fill the sampler with the appropriate number of reserved IDs.
    242   void FillReservedIds(int32 num_reserved_ids);
    243   // Load IDs to sample from a CSV file. It is assumed that the last item of
    244   // each row contains a count or probability for the corresponding ID.
    245   Status LoadFromFile(Env* env, const string& vocab_file, float distortion);
    246   // Load from an in-memory array.
    247   void LoadFromUnigrams(const std::vector<float>& unigrams, float distortion);
    248 };
    249 
    250 }  // namespace tensorflow
    251 
    252 #endif  // TENSORFLOW_KERNELS_RANGE_SAMPLER_H_
    253