Home | History | Annotate | Download | only in random
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/lib/random/weighted_picker.h"
     17 
     18 #include <string.h>
     19 #include <algorithm>
     20 
     21 #include "tensorflow/core/lib/random/simple_philox.h"
     22 
     23 namespace tensorflow {
     24 namespace random {
     25 
     26 WeightedPicker::WeightedPicker(int N) {
     27   CHECK_GE(N, 0);
     28   N_ = N;
     29 
     30   // Find the number of levels
     31   num_levels_ = 1;
     32   while (LevelSize(num_levels_ - 1) < N) {
     33     num_levels_++;
     34   }
     35 
     36   // Initialize the levels
     37   level_ = new int32*[num_levels_];
     38   for (int l = 0; l < num_levels_; l++) {
     39     level_[l] = new int32[LevelSize(l)];
     40   }
     41 
     42   SetAllWeights(1);
     43 }
     44 
     45 WeightedPicker::~WeightedPicker() {
     46   for (int l = 0; l < num_levels_; l++) {
     47     delete[] level_[l];
     48   }
     49   delete[] level_;
     50 }
     51 
     52 static int32 UnbiasedUniform(SimplePhilox* r, int32 n) {
     53   CHECK_LE(0, n);
     54   const uint32 range = ~static_cast<uint32>(0);
     55   if (n == 0) {
     56     return r->Rand32() * n;
     57   } else if (0 == (n & (n - 1))) {
     58     // N is a power of two, so just mask off the lower bits.
     59     return r->Rand32() & (n - 1);
     60   } else {
     61     // Reject all numbers that skew the distribution towards 0.
     62 
     63     // Rand32's output is uniform in the half-open interval [0, 2^{32}).
     64     // For any interval [m,n), the number of elements in it is n-m.
     65 
     66     uint32 rem = (range % n) + 1;
     67     uint32 rnd;
     68 
     69     // rem = ((2^{32}-1) \bmod n) + 1
     70     // 1 <= rem <= n
     71 
     72     // NB: rem == n is impossible, since n is not a power of 2 (from
     73     // earlier check).
     74 
     75     do {
     76       rnd = r->Rand32();  // rnd uniform over [0, 2^{32})
     77     } while (rnd < rem);  // reject [0, rem)
     78     // rnd is uniform over [rem, 2^{32})
     79     //
     80     // The number of elements in the half-open interval is
     81     //
     82     //  2^{32} - rem = 2^{32} - ((2^{32}-1) \bmod n) - 1
     83     //               = 2^{32}-1 - ((2^{32}-1) \bmod n)
     84     //               = n \cdot \lfloor (2^{32}-1)/n \rfloor
     85     //
     86     // therefore n evenly divides the number of integers in the
     87     // interval.
     88     //
     89     // The function v \rightarrow v % n takes values from [bias,
     90     // 2^{32}) to [0, n).  Each integer in the range interval [0, n)
     91     // will have exactly \lfloor (2^{32}-1)/n \rfloor preimages from
     92     // the domain interval.
     93     //
     94     // Therefore, v % n is uniform over [0, n).  QED.
     95 
     96     return rnd % n;
     97   }
     98 }
     99 
    100 int WeightedPicker::Pick(SimplePhilox* rnd) const {
    101   if (total_weight() == 0) return -1;
    102 
    103   // using unbiased uniform distribution to avoid bias
    104   // toward low elements resulting from a possible use
    105   // of big weights.
    106   return PickAt(UnbiasedUniform(rnd, total_weight()));
    107 }
    108 
    109 int WeightedPicker::PickAt(int32 weight_index) const {
    110   if (weight_index < 0 || weight_index >= total_weight()) return -1;
    111 
    112   int32 position = weight_index;
    113   int index = 0;
    114 
    115   for (int l = 1; l < num_levels_; l++) {
    116     // Pick left or right child of "level_[l-1][index]"
    117     const int32 left_weight = level_[l][2 * index];
    118     if (position < left_weight) {
    119       // Descend to left child
    120       index = 2 * index;
    121     } else {
    122       // Descend to right child
    123       index = 2 * index + 1;
    124       position -= left_weight;
    125     }
    126   }
    127   CHECK_GE(index, 0);
    128   CHECK_LT(index, N_);
    129   CHECK_LE(position, level_[num_levels_ - 1][index]);
    130   return index;
    131 }
    132 
    133 void WeightedPicker::set_weight(int index, int32 weight) {
    134   assert(index >= 0);
    135   assert(index < N_);
    136 
    137   // Adjust the sums all the way up to the root
    138   const int32 delta = weight - get_weight(index);
    139   for (int l = num_levels_ - 1; l >= 0; l--) {
    140     level_[l][index] += delta;
    141     index >>= 1;
    142   }
    143 }
    144 
    145 void WeightedPicker::SetAllWeights(int32 weight) {
    146   // Initialize leaves
    147   int32* leaves = level_[num_levels_ - 1];
    148   for (int i = 0; i < N_; i++) leaves[i] = weight;
    149   for (int i = N_; i < LevelSize(num_levels_ - 1); i++) leaves[i] = 0;
    150 
    151   // Now sum up towards the root
    152   RebuildTreeWeights();
    153 }
    154 
    155 void WeightedPicker::SetWeightsFromArray(int N, const int32* weights) {
    156   Resize(N);
    157 
    158   // Initialize leaves
    159   int32* leaves = level_[num_levels_ - 1];
    160   for (int i = 0; i < N_; i++) leaves[i] = weights[i];
    161   for (int i = N_; i < LevelSize(num_levels_ - 1); i++) leaves[i] = 0;
    162 
    163   // Now sum up towards the root
    164   RebuildTreeWeights();
    165 }
    166 
    167 void WeightedPicker::RebuildTreeWeights() {
    168   for (int l = num_levels_ - 2; l >= 0; l--) {
    169     int32* level = level_[l];
    170     int32* children = level_[l + 1];
    171     for (int i = 0; i < LevelSize(l); i++) {
    172       level[i] = children[2 * i] + children[2 * i + 1];
    173     }
    174   }
    175 }
    176 
    177 void WeightedPicker::Append(int32 weight) {
    178   Resize(num_elements() + 1);
    179   set_weight(num_elements() - 1, weight);
    180 }
    181 
    182 void WeightedPicker::Resize(int new_size) {
    183   CHECK_GE(new_size, 0);
    184   if (new_size <= LevelSize(num_levels_ - 1)) {
    185     // The new picker fits in the existing levels.
    186 
    187     // First zero out any of the weights that are being dropped so
    188     // that the levels are correct (only needed when shrinking)
    189     for (int i = new_size; i < N_; i++) {
    190       set_weight(i, 0);
    191     }
    192 
    193     // We do not need to set any new weights when enlarging because
    194     // the unneeded entries always have weight zero.
    195     N_ = new_size;
    196     return;
    197   }
    198 
    199   // We follow the simple strategy of just copying the old
    200   // WeightedPicker into a new WeightedPicker.  The cost is
    201   // O(N) regardless.
    202   assert(new_size > N_);
    203   WeightedPicker new_picker(new_size);
    204   int32* dst = new_picker.level_[new_picker.num_levels_ - 1];
    205   int32* src = this->level_[this->num_levels_ - 1];
    206   memcpy(dst, src, sizeof(dst[0]) * N_);
    207   memset(dst + N_, 0, sizeof(dst[0]) * (new_size - N_));
    208   new_picker.RebuildTreeWeights();
    209 
    210   // Now swap the two pickers
    211   std::swap(new_picker.N_, this->N_);
    212   std::swap(new_picker.num_levels_, this->num_levels_);
    213   std::swap(new_picker.level_, this->level_);
    214   assert(this->N_ == new_size);
    215 }
    216 
    217 }  // namespace random
    218 }  // namespace tensorflow
    219