1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HYPERPLANE_LSH_PROBES_H_ 17 #define TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HYPERPLANE_LSH_PROBES_H_ 18 19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 20 21 #include "tensorflow/contrib/nearest_neighbor/kernels/heap.h" 22 23 namespace tensorflow { 24 namespace nearest_neighbor { 25 26 // This class implements hyperplane multiprobe LSH as described in the 27 // following paper: 28 // 29 // Multi-probe LSH: efficient indexing for high-dimensional similarity search 30 // Qin Lv, William Josephson, Zhe Wang, Moses Charikar, Kai Li 31 // 32 // The class is only responsible for generating the probing sequence of given 33 // length for a given batch of points. The actual hash table lookups are 34 // implemented in other classes. 35 template <typename CoordinateType, typename HashType> 36 class HyperplaneMultiprobe { 37 public: 38 using Matrix = Eigen::Matrix<CoordinateType, Eigen::Dynamic, Eigen::Dynamic, 39 Eigen::RowMajor>; 40 using ConstMatrixMap = Eigen::Map<const Matrix>; 41 using MatrixMap = Eigen::Map<Matrix>; 42 using Vector = 43 Eigen::Matrix<CoordinateType, Eigen::Dynamic, 1, Eigen::ColMajor>; 44 45 HyperplaneMultiprobe(int num_hyperplanes_per_table, int num_tables) 46 : num_hyperplanes_per_table_(num_hyperplanes_per_table), 47 num_tables_(num_tables), 48 num_probes_(0), 49 cur_probe_counter_(0), 50 sorted_hyperplane_indices_(0), 51 main_table_probe_(num_tables) {} 52 53 // The first input hash_vector is the matrix-vector product between the 54 // hyperplane matrix and the vector for which we want to generate a probing 55 // sequence. We assume that each index in hash_vector is proportional to the 56 // distance between vector and hyperplane (i.e., the hyperplane vectors should 57 // all have the same norm). 58 // 59 // The second input is the number of probes we want to retrieve. If this 60 // number is fixed in advance, it should be passed in here in order to enable 61 // some (minor) internal optimizations. If the number of probes it not known 62 // in advance, the multiprobe sequence can still produce an arbitrary length 63 // probing sequence (up to the maximum number of probes) by calling 64 // get_next_probe multiple times. 65 // 66 // If num_probes is at most num_tables, it is not necessary to generate an 67 // actual multiprobe sequence and the multiprobe object will simply return 68 // the "standard" LSH probes without incurring any multiprobe overhead. 69 void SetupProbing(const Vector& hash_vector, int_fast64_t num_probes) { 70 // We accept a copy here for now. 71 hash_vector_ = hash_vector; 72 num_probes_ = num_probes; 73 cur_probe_counter_ = -1; 74 75 // Compute the initial probes for each table, i.e., the "true" hash 76 // locations LSH without multiprobe would give. 77 for (int_fast32_t ii = 0; ii < num_tables_; ++ii) { 78 main_table_probe_[ii] = 0; 79 for (int_fast32_t jj = 0; jj < num_hyperplanes_per_table_; ++jj) { 80 main_table_probe_[ii] = main_table_probe_[ii] << 1; 81 main_table_probe_[ii] = 82 main_table_probe_[ii] | 83 (hash_vector_[ii * num_hyperplanes_per_table_ + jj] >= 0.0); 84 } 85 } 86 87 if (num_probes_ >= 0 && num_probes_ <= num_tables_) { 88 return; 89 } 90 91 if (sorted_hyperplane_indices_.size() == 0) { 92 sorted_hyperplane_indices_.resize(num_tables_); 93 for (int_fast32_t ii = 0; ii < num_tables_; ++ii) { 94 sorted_hyperplane_indices_[ii].resize(num_hyperplanes_per_table_); 95 for (int_fast32_t jj = 0; jj < num_hyperplanes_per_table_; ++jj) { 96 sorted_hyperplane_indices_[ii][jj] = jj; 97 } 98 } 99 } 100 101 for (int_fast32_t ii = 0; ii < num_tables_; ++ii) { 102 HyperplaneComparator comp(hash_vector_, ii * num_hyperplanes_per_table_); 103 std::sort(sorted_hyperplane_indices_[ii].begin(), 104 sorted_hyperplane_indices_[ii].end(), comp); 105 } 106 107 if (num_probes_ >= 0) { 108 heap_.Resize(2 * num_probes_); 109 } 110 heap_.Reset(); 111 for (int_fast32_t ii = 0; ii < num_tables_; ++ii) { 112 int_fast32_t best_index = sorted_hyperplane_indices_[ii][0]; 113 CoordinateType score = 114 hash_vector_[ii * num_hyperplanes_per_table_ + best_index]; 115 score = score * score; 116 HashType hash_mask = 1; 117 hash_mask = hash_mask << (num_hyperplanes_per_table_ - best_index - 1); 118 heap_.InsertUnsorted(score, ProbeCandidate(ii, hash_mask, 0)); 119 } 120 heap_.Heapify(); 121 } 122 123 // This method stores the current probe (= hash table location) and 124 // corresponding table in the output parameters. The return value indicates 125 // whether this succeeded (true) or the current probing sequence is exhausted 126 // (false). Here, we say a probing sequence is exhausted if one of the 127 // following two conditions occurs: 128 // - We have used a non-negative value for num_probes in setup_probing, and 129 // we have produced this many number of probes in the current sequence. 130 // - We have used a negative value for num_probes in setup_probing, and we 131 // have produced all possible probes in the probing sequence. 132 bool GetNextProbe(HashType* cur_probe, int_fast32_t* cur_table) { 133 cur_probe_counter_ += 1; 134 135 if (num_probes_ >= 0 && cur_probe_counter_ >= num_probes_) { 136 // We are out of probes in the current probing sequence. 137 return false; 138 } 139 140 // For the first num_tables_ probes, we directly return the "standard LSH" 141 // probes to guarantee that they always come first and we avoid any 142 // multiprobe overhead. 143 if (cur_probe_counter_ < num_tables_) { 144 *cur_probe = main_table_probe_[cur_probe_counter_]; 145 *cur_table = cur_probe_counter_; 146 return true; 147 } 148 149 // If the heap is empty, the current probing sequence is exhausted. 150 if (heap_.IsEmpty()) { 151 return false; 152 } 153 154 CoordinateType cur_score; 155 ProbeCandidate cur_candidate; 156 heap_.ExtractMin(&cur_score, &cur_candidate); 157 *cur_table = cur_candidate.table_; 158 int_fast32_t cur_index = 159 sorted_hyperplane_indices_[*cur_table][cur_candidate.last_index_]; 160 *cur_probe = main_table_probe_[*cur_table] ^ cur_candidate.hash_mask_; 161 162 if (cur_candidate.last_index_ != num_hyperplanes_per_table_ - 1) { 163 // swapping out the last flipped index 164 int_fast32_t next_index = 165 sorted_hyperplane_indices_[*cur_table][cur_candidate.last_index_ + 1]; 166 167 // xor out previous bit, xor in new bit. 168 HashType next_mask = 169 cur_candidate.hash_mask_ ^ 170 (HashType(1) << (num_hyperplanes_per_table_ - cur_index - 1)) ^ 171 (HashType(1) << (num_hyperplanes_per_table_ - next_index - 1)); 172 173 CoordinateType cur_coord = 174 hash_vector_[*cur_table * num_hyperplanes_per_table_ + cur_index]; 175 CoordinateType next_coord = 176 hash_vector_[*cur_table * num_hyperplanes_per_table_ + next_index]; 177 CoordinateType next_score = 178 cur_score - cur_coord * cur_coord + next_coord * next_coord; 179 180 heap_.Insert(next_score, ProbeCandidate(*cur_table, next_mask, 181 cur_candidate.last_index_ + 1)); 182 183 // adding a new flipped index 184 next_mask = 185 cur_candidate.hash_mask_ ^ 186 (HashType(1) << (num_hyperplanes_per_table_ - next_index - 1)); 187 next_score = cur_score + next_coord * next_coord; 188 189 heap_.Insert(next_score, ProbeCandidate(*cur_table, next_mask, 190 cur_candidate.last_index_ + 1)); 191 } 192 193 return true; 194 } 195 196 private: 197 class ProbeCandidate { 198 public: 199 ProbeCandidate(int_fast32_t table = 0, HashType hash_mask = 0, 200 int_fast32_t last_index = 0) 201 : table_(table), hash_mask_(hash_mask), last_index_(last_index) {} 202 203 int_fast32_t table_; 204 HashType hash_mask_; 205 int_fast32_t last_index_; 206 }; 207 208 class HyperplaneComparator { 209 public: 210 HyperplaneComparator(const Vector& values, int_fast32_t offset) 211 : values_(values), offset_(offset) {} 212 213 bool operator()(int_fast32_t ii, int_fast32_t jj) const { 214 return std::abs(values_[offset_ + ii]) < std::abs(values_[offset_ + jj]); 215 } 216 217 private: 218 const Vector& values_; 219 int_fast32_t offset_; 220 }; 221 222 int_fast32_t num_hyperplanes_per_table_; 223 int_fast32_t num_tables_; 224 int_fast64_t num_probes_; 225 int_fast64_t cur_probe_counter_; 226 std::vector<std::vector<int_fast32_t>> sorted_hyperplane_indices_; 227 std::vector<HashType> main_table_probe_; 228 SimpleHeap<CoordinateType, ProbeCandidate> heap_; 229 Vector hash_vector_; 230 }; 231 232 } // namespace nearest_neighbor 233 } // namespace tensorflow 234 235 #endif // TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HYPERPLANE_LSH_PROBES_H_ 236