1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/lib/random/weighted_picker.h" 17 18 #include <string.h> 19 #include <algorithm> 20 21 #include "tensorflow/core/lib/random/simple_philox.h" 22 23 namespace tensorflow { 24 namespace random { 25 26 WeightedPicker::WeightedPicker(int N) { 27 CHECK_GE(N, 0); 28 N_ = N; 29 30 // Find the number of levels 31 num_levels_ = 1; 32 while (LevelSize(num_levels_ - 1) < N) { 33 num_levels_++; 34 } 35 36 // Initialize the levels 37 level_ = new int32*[num_levels_]; 38 for (int l = 0; l < num_levels_; l++) { 39 level_[l] = new int32[LevelSize(l)]; 40 } 41 42 SetAllWeights(1); 43 } 44 45 WeightedPicker::~WeightedPicker() { 46 for (int l = 0; l < num_levels_; l++) { 47 delete[] level_[l]; 48 } 49 delete[] level_; 50 } 51 52 static int32 UnbiasedUniform(SimplePhilox* r, int32 n) { 53 CHECK_LE(0, n); 54 const uint32 range = ~static_cast<uint32>(0); 55 if (n == 0) { 56 return r->Rand32() * n; 57 } else if (0 == (n & (n - 1))) { 58 // N is a power of two, so just mask off the lower bits. 59 return r->Rand32() & (n - 1); 60 } else { 61 // Reject all numbers that skew the distribution towards 0. 62 63 // Rand32's output is uniform in the half-open interval [0, 2^{32}). 64 // For any interval [m,n), the number of elements in it is n-m. 65 66 uint32 rem = (range % n) + 1; 67 uint32 rnd; 68 69 // rem = ((2^{32}-1) \bmod n) + 1 70 // 1 <= rem <= n 71 72 // NB: rem == n is impossible, since n is not a power of 2 (from 73 // earlier check). 74 75 do { 76 rnd = r->Rand32(); // rnd uniform over [0, 2^{32}) 77 } while (rnd < rem); // reject [0, rem) 78 // rnd is uniform over [rem, 2^{32}) 79 // 80 // The number of elements in the half-open interval is 81 // 82 // 2^{32} - rem = 2^{32} - ((2^{32}-1) \bmod n) - 1 83 // = 2^{32}-1 - ((2^{32}-1) \bmod n) 84 // = n \cdot \lfloor (2^{32}-1)/n \rfloor 85 // 86 // therefore n evenly divides the number of integers in the 87 // interval. 88 // 89 // The function v \rightarrow v % n takes values from [bias, 90 // 2^{32}) to [0, n). Each integer in the range interval [0, n) 91 // will have exactly \lfloor (2^{32}-1)/n \rfloor preimages from 92 // the domain interval. 93 // 94 // Therefore, v % n is uniform over [0, n). QED. 95 96 return rnd % n; 97 } 98 } 99 100 int WeightedPicker::Pick(SimplePhilox* rnd) const { 101 if (total_weight() == 0) return -1; 102 103 // using unbiased uniform distribution to avoid bias 104 // toward low elements resulting from a possible use 105 // of big weights. 106 return PickAt(UnbiasedUniform(rnd, total_weight())); 107 } 108 109 int WeightedPicker::PickAt(int32 weight_index) const { 110 if (weight_index < 0 || weight_index >= total_weight()) return -1; 111 112 int32 position = weight_index; 113 int index = 0; 114 115 for (int l = 1; l < num_levels_; l++) { 116 // Pick left or right child of "level_[l-1][index]" 117 const int32 left_weight = level_[l][2 * index]; 118 if (position < left_weight) { 119 // Descend to left child 120 index = 2 * index; 121 } else { 122 // Descend to right child 123 index = 2 * index + 1; 124 position -= left_weight; 125 } 126 } 127 CHECK_GE(index, 0); 128 CHECK_LT(index, N_); 129 CHECK_LE(position, level_[num_levels_ - 1][index]); 130 return index; 131 } 132 133 void WeightedPicker::set_weight(int index, int32 weight) { 134 assert(index >= 0); 135 assert(index < N_); 136 137 // Adjust the sums all the way up to the root 138 const int32 delta = weight - get_weight(index); 139 for (int l = num_levels_ - 1; l >= 0; l--) { 140 level_[l][index] += delta; 141 index >>= 1; 142 } 143 } 144 145 void WeightedPicker::SetAllWeights(int32 weight) { 146 // Initialize leaves 147 int32* leaves = level_[num_levels_ - 1]; 148 for (int i = 0; i < N_; i++) leaves[i] = weight; 149 for (int i = N_; i < LevelSize(num_levels_ - 1); i++) leaves[i] = 0; 150 151 // Now sum up towards the root 152 RebuildTreeWeights(); 153 } 154 155 void WeightedPicker::SetWeightsFromArray(int N, const int32* weights) { 156 Resize(N); 157 158 // Initialize leaves 159 int32* leaves = level_[num_levels_ - 1]; 160 for (int i = 0; i < N_; i++) leaves[i] = weights[i]; 161 for (int i = N_; i < LevelSize(num_levels_ - 1); i++) leaves[i] = 0; 162 163 // Now sum up towards the root 164 RebuildTreeWeights(); 165 } 166 167 void WeightedPicker::RebuildTreeWeights() { 168 for (int l = num_levels_ - 2; l >= 0; l--) { 169 int32* level = level_[l]; 170 int32* children = level_[l + 1]; 171 for (int i = 0; i < LevelSize(l); i++) { 172 level[i] = children[2 * i] + children[2 * i + 1]; 173 } 174 } 175 } 176 177 void WeightedPicker::Append(int32 weight) { 178 Resize(num_elements() + 1); 179 set_weight(num_elements() - 1, weight); 180 } 181 182 void WeightedPicker::Resize(int new_size) { 183 CHECK_GE(new_size, 0); 184 if (new_size <= LevelSize(num_levels_ - 1)) { 185 // The new picker fits in the existing levels. 186 187 // First zero out any of the weights that are being dropped so 188 // that the levels are correct (only needed when shrinking) 189 for (int i = new_size; i < N_; i++) { 190 set_weight(i, 0); 191 } 192 193 // We do not need to set any new weights when enlarging because 194 // the unneeded entries always have weight zero. 195 N_ = new_size; 196 return; 197 } 198 199 // We follow the simple strategy of just copying the old 200 // WeightedPicker into a new WeightedPicker. The cost is 201 // O(N) regardless. 202 assert(new_size > N_); 203 WeightedPicker new_picker(new_size); 204 int32* dst = new_picker.level_[new_picker.num_levels_ - 1]; 205 int32* src = this->level_[this->num_levels_ - 1]; 206 memcpy(dst, src, sizeof(dst[0]) * N_); 207 memset(dst + N_, 0, sizeof(dst[0]) * (new_size - N_)); 208 new_picker.RebuildTreeWeights(); 209 210 // Now swap the two pickers 211 std::swap(new_picker.N_, this->N_); 212 std::swap(new_picker.num_levels_, this->num_levels_); 213 std::swap(new_picker.level_, this->level_); 214 assert(this->N_ == new_size); 215 } 216 217 } // namespace random 218 } // namespace tensorflow 219