Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HYPERPLANE_LSH_PROBES_H_
     17 #define TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HYPERPLANE_LSH_PROBES_H_
     18 
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

#include "tensorflow/contrib/nearest_neighbor/kernels/heap.h"
     22 
     23 namespace tensorflow {
     24 namespace nearest_neighbor {
     25 
     26 // This class implements hyperplane multiprobe LSH as described in the
     27 // following paper:
     28 //
     29 //   Multi-probe LSH: efficient indexing for high-dimensional similarity search
     30 //   Qin Lv, William Josephson, Zhe Wang, Moses Charikar, Kai Li
     31 //
     32 // The class is only responsible for generating the probing sequence of given
     33 // length for a given batch of points. The actual hash table lookups are
     34 // implemented in other classes.
     35 template <typename CoordinateType, typename HashType>
     36 class HyperplaneMultiprobe {
     37  public:
     38   using Matrix = Eigen::Matrix<CoordinateType, Eigen::Dynamic, Eigen::Dynamic,
     39                                Eigen::RowMajor>;
     40   using ConstMatrixMap = Eigen::Map<const Matrix>;
     41   using MatrixMap = Eigen::Map<Matrix>;
     42   using Vector =
     43       Eigen::Matrix<CoordinateType, Eigen::Dynamic, 1, Eigen::ColMajor>;
     44 
     45   HyperplaneMultiprobe(int num_hyperplanes_per_table, int num_tables)
     46       : num_hyperplanes_per_table_(num_hyperplanes_per_table),
     47         num_tables_(num_tables),
     48         num_probes_(0),
     49         cur_probe_counter_(0),
     50         sorted_hyperplane_indices_(0),
     51         main_table_probe_(num_tables) {}
     52 
     53   // The first input hash_vector is the matrix-vector product between the
     54   // hyperplane matrix and the vector for which we want to generate a probing
     55   // sequence. We assume that each index in hash_vector is proportional to the
     56   // distance between vector and hyperplane (i.e., the hyperplane vectors should
     57   // all have the same norm).
     58   //
     59   // The second input is the number of probes we want to retrieve. If this
     60   // number is fixed in advance, it should be passed in here in order to enable
     61   // some (minor) internal optimizations. If the number of probes it not known
     62   // in advance, the multiprobe sequence can still produce an arbitrary length
     63   // probing sequence (up to the maximum number of probes) by calling
     64   // get_next_probe multiple times.
     65   //
     66   // If num_probes is at most num_tables, it is not necessary to generate an
     67   // actual multiprobe sequence and the multiprobe object will simply return
     68   // the "standard" LSH probes without incurring any multiprobe overhead.
     69   void SetupProbing(const Vector& hash_vector, int_fast64_t num_probes) {
     70     // We accept a copy here for now.
     71     hash_vector_ = hash_vector;
     72     num_probes_ = num_probes;
     73     cur_probe_counter_ = -1;
     74 
     75     // Compute the initial probes for each table, i.e., the "true" hash
     76     // locations LSH without multiprobe would give.
     77     for (int_fast32_t ii = 0; ii < num_tables_; ++ii) {
     78       main_table_probe_[ii] = 0;
     79       for (int_fast32_t jj = 0; jj < num_hyperplanes_per_table_; ++jj) {
     80         main_table_probe_[ii] = main_table_probe_[ii] << 1;
     81         main_table_probe_[ii] =
     82             main_table_probe_[ii] |
     83             (hash_vector_[ii * num_hyperplanes_per_table_ + jj] >= 0.0);
     84       }
     85     }
     86 
     87     if (num_probes_ >= 0 && num_probes_ <= num_tables_) {
     88       return;
     89     }
     90 
     91     if (sorted_hyperplane_indices_.size() == 0) {
     92       sorted_hyperplane_indices_.resize(num_tables_);
     93       for (int_fast32_t ii = 0; ii < num_tables_; ++ii) {
     94         sorted_hyperplane_indices_[ii].resize(num_hyperplanes_per_table_);
     95         for (int_fast32_t jj = 0; jj < num_hyperplanes_per_table_; ++jj) {
     96           sorted_hyperplane_indices_[ii][jj] = jj;
     97         }
     98       }
     99     }
    100 
    101     for (int_fast32_t ii = 0; ii < num_tables_; ++ii) {
    102       HyperplaneComparator comp(hash_vector_, ii * num_hyperplanes_per_table_);
    103       std::sort(sorted_hyperplane_indices_[ii].begin(),
    104                 sorted_hyperplane_indices_[ii].end(), comp);
    105     }
    106 
    107     if (num_probes_ >= 0) {
    108       heap_.Resize(2 * num_probes_);
    109     }
    110     heap_.Reset();
    111     for (int_fast32_t ii = 0; ii < num_tables_; ++ii) {
    112       int_fast32_t best_index = sorted_hyperplane_indices_[ii][0];
    113       CoordinateType score =
    114           hash_vector_[ii * num_hyperplanes_per_table_ + best_index];
    115       score = score * score;
    116       HashType hash_mask = 1;
    117       hash_mask = hash_mask << (num_hyperplanes_per_table_ - best_index - 1);
    118       heap_.InsertUnsorted(score, ProbeCandidate(ii, hash_mask, 0));
    119     }
    120     heap_.Heapify();
    121   }
    122 
    123   // This method stores the current probe (= hash table location) and
    124   // corresponding table in the output parameters. The return value indicates
    125   // whether this succeeded (true) or the current probing sequence is exhausted
    126   // (false). Here, we say a probing sequence is exhausted if one of the
    127   // following two conditions occurs:
    128   // - We have used a non-negative value for num_probes in setup_probing, and
    129   //   we have produced this many number of probes in the current sequence.
    130   // - We have used a negative value for num_probes in setup_probing, and we
    131   //   have produced all possible probes in the probing sequence.
    132   bool GetNextProbe(HashType* cur_probe, int_fast32_t* cur_table) {
    133     cur_probe_counter_ += 1;
    134 
    135     if (num_probes_ >= 0 && cur_probe_counter_ >= num_probes_) {
    136       // We are out of probes in the current probing sequence.
    137       return false;
    138     }
    139 
    140     // For the first num_tables_ probes, we directly return the "standard LSH"
    141     // probes to guarantee that they always come first and we avoid any
    142     // multiprobe overhead.
    143     if (cur_probe_counter_ < num_tables_) {
    144       *cur_probe = main_table_probe_[cur_probe_counter_];
    145       *cur_table = cur_probe_counter_;
    146       return true;
    147     }
    148 
    149     // If the heap is empty, the current probing sequence is exhausted.
    150     if (heap_.IsEmpty()) {
    151       return false;
    152     }
    153 
    154     CoordinateType cur_score;
    155     ProbeCandidate cur_candidate;
    156     heap_.ExtractMin(&cur_score, &cur_candidate);
    157     *cur_table = cur_candidate.table_;
    158     int_fast32_t cur_index =
    159         sorted_hyperplane_indices_[*cur_table][cur_candidate.last_index_];
    160     *cur_probe = main_table_probe_[*cur_table] ^ cur_candidate.hash_mask_;
    161 
    162     if (cur_candidate.last_index_ != num_hyperplanes_per_table_ - 1) {
    163       // swapping out the last flipped index
    164       int_fast32_t next_index =
    165           sorted_hyperplane_indices_[*cur_table][cur_candidate.last_index_ + 1];
    166 
    167       // xor out previous bit, xor in new bit.
    168       HashType next_mask =
    169           cur_candidate.hash_mask_ ^
    170           (HashType(1) << (num_hyperplanes_per_table_ - cur_index - 1)) ^
    171           (HashType(1) << (num_hyperplanes_per_table_ - next_index - 1));
    172 
    173       CoordinateType cur_coord =
    174           hash_vector_[*cur_table * num_hyperplanes_per_table_ + cur_index];
    175       CoordinateType next_coord =
    176           hash_vector_[*cur_table * num_hyperplanes_per_table_ + next_index];
    177       CoordinateType next_score =
    178           cur_score - cur_coord * cur_coord + next_coord * next_coord;
    179 
    180       heap_.Insert(next_score, ProbeCandidate(*cur_table, next_mask,
    181                                               cur_candidate.last_index_ + 1));
    182 
    183       // adding a new flipped index
    184       next_mask =
    185           cur_candidate.hash_mask_ ^
    186           (HashType(1) << (num_hyperplanes_per_table_ - next_index - 1));
    187       next_score = cur_score + next_coord * next_coord;
    188 
    189       heap_.Insert(next_score, ProbeCandidate(*cur_table, next_mask,
    190                                               cur_candidate.last_index_ + 1));
    191     }
    192 
    193     return true;
    194   }
    195 
    196  private:
    197   class ProbeCandidate {
    198    public:
    199     ProbeCandidate(int_fast32_t table = 0, HashType hash_mask = 0,
    200                    int_fast32_t last_index = 0)
    201         : table_(table), hash_mask_(hash_mask), last_index_(last_index) {}
    202 
    203     int_fast32_t table_;
    204     HashType hash_mask_;
    205     int_fast32_t last_index_;
    206   };
    207 
    208   class HyperplaneComparator {
    209    public:
    210     HyperplaneComparator(const Vector& values, int_fast32_t offset)
    211         : values_(values), offset_(offset) {}
    212 
    213     bool operator()(int_fast32_t ii, int_fast32_t jj) const {
    214       return std::abs(values_[offset_ + ii]) < std::abs(values_[offset_ + jj]);
    215     }
    216 
    217    private:
    218     const Vector& values_;
    219     int_fast32_t offset_;
    220   };
    221 
    222   int_fast32_t num_hyperplanes_per_table_;
    223   int_fast32_t num_tables_;
    224   int_fast64_t num_probes_;
    225   int_fast64_t cur_probe_counter_;
    226   std::vector<std::vector<int_fast32_t>> sorted_hyperplane_indices_;
    227   std::vector<HashType> main_table_probe_;
    228   SimpleHeap<CoordinateType, ProbeCandidate> heap_;
    229   Vector hash_vector_;
    230 };
    231 
    232 }  // namespace nearest_neighbor
    233 }  // namespace tensorflow
    234 
    235 #endif  // TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HYPERPLANE_LSH_PROBES_H_
    236