Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/contrib/nearest_neighbor/kernels/hyperplane_lsh_probes.h"
     17 
     18 #include <vector>
     19 
     20 #include "tensorflow/core/kernels/ops_testutil.h"
     21 
     22 namespace {
     23 
     24 using tensorflow::uint32;
     25 
     26 typedef tensorflow::nearest_neighbor::HyperplaneMultiprobe<float, uint32>
     27     Multiprobe;
     28 
     29 void CheckSequenceSingleTable(Multiprobe* multiprobe,
     30                               const std::vector<uint32>& expected_probes) {
     31   uint32 cur_probe;
     32   int_fast32_t cur_table;
     33   for (int ii = 0; ii < expected_probes.size(); ++ii) {
     34     ASSERT_TRUE(multiprobe->GetNextProbe(&cur_probe, &cur_table));
     35     EXPECT_EQ(expected_probes[ii], cur_probe);
     36     EXPECT_EQ(0, cur_table);
     37   }
     38 }
     39 
     40 void CheckSequenceMultipleTables(
     41     Multiprobe* multiprobe,
     42     const std::vector<std::pair<uint32, int_fast32_t>>& expected_result) {
     43   uint32 cur_probe;
     44   int_fast32_t cur_table;
     45   for (int ii = 0; ii < expected_result.size(); ++ii) {
     46     ASSERT_TRUE(multiprobe->GetNextProbe(&cur_probe, &cur_table));
     47     EXPECT_EQ(expected_result[ii].first, cur_probe);
     48     EXPECT_EQ(expected_result[ii].second, cur_table);
     49   }
     50 }
     51 
     52 // Just the first two probes for two tables and two hyperplanes pro table.
     53 TEST(HyperplaneMultiprobeTest, SimpleTest1) {
     54   Multiprobe multiprobe(2, 2);
     55   Multiprobe::Vector hash_vector(4);
     56   hash_vector << -1.0, 1.0, 1.0, -1.0;
     57   std::vector<std::pair<uint32, int_fast32_t>> expected_result = {{1, 0},
     58                                                                   {2, 1}};
     59   multiprobe.SetupProbing(hash_vector, expected_result.size());
     60   CheckSequenceMultipleTables(&multiprobe, expected_result);
     61 }
     62 
     63 // Checking that the beginning of a probing sequence for a single table is
     64 // correct.
     65 TEST(HyperplaneMultiprobeTest, SimpleTest2) {
     66   Multiprobe multiprobe(4, 1);
     67   Multiprobe::Vector hash_vector(4);
     68   hash_vector << -2.0, -0.9, -0.8, -0.7;
     69   std::vector<uint32> expected_result = {0, 1, 2, 4, 3};
     70   multiprobe.SetupProbing(hash_vector, expected_result.size());
     71   CheckSequenceSingleTable(&multiprobe, expected_result);
     72 }
     73 
     74 // Checking that the probing sequence for a single table is exhaustive.
     75 TEST(HyperplaneMultiprobeTest, SimpleTest3) {
     76   Multiprobe multiprobe(3, 1);
     77   Multiprobe::Vector hash_vector(3);
     78   hash_vector << -1.0, -10.0, -0.1;
     79   std::vector<uint32> expected_result = {0, 1, 4, 5, 2, 3, 6, 7};
     80   multiprobe.SetupProbing(hash_vector, expected_result.size());
     81   CheckSequenceSingleTable(&multiprobe, expected_result);
     82 }
     83 
     84 // Checking that the probing sequence is generated correctly across tables.
     85 TEST(HyperplaneMultiprobeTest, SimpleTest4) {
     86   Multiprobe multiprobe(2, 2);
     87   Multiprobe::Vector hash_vector(4);
     88   hash_vector << -0.2, 0.9, 0.1, -1.0;
     89   std::vector<std::pair<uint32, int_fast32_t>> expected_result = {
     90       {1, 0}, {2, 1}, {0, 1}, {3, 0}, {0, 0}, {2, 0}, {3, 1}, {1, 1}};
     91   multiprobe.SetupProbing(hash_vector, expected_result.size());
     92   CheckSequenceMultipleTables(&multiprobe, expected_result);
     93 }
     94 
     95 // Slightly larger test that checks whether we have an exhaustive probing
     96 // sequence (but this test does not check the order).
     97 TEST(HyperplaneMultiprobeTest, ExhaustiveTest1) {
     98   int dim = 8;
     99   int num_tables = 10;
    100   Multiprobe multiprobe(dim, num_tables);
    101   Multiprobe::Vector hash_vector(dim * num_tables);
    102 
    103   std::mt19937 random_generator(487344882);
    104   std::normal_distribution<float> distribution(0.0, 1.0);
    105   for (int ii = 0; ii < dim * num_tables; ++ii) {
    106     hash_vector[ii] = distribution(random_generator);
    107   }
    108 
    109   std::vector<std::vector<bool>> checked_cell(num_tables);
    110   for (int ii = 0; ii < num_tables; ++ii) {
    111     checked_cell[ii].resize(1 << dim);
    112     std::fill(checked_cell[ii].begin(), checked_cell[ii].end(), false);
    113   }
    114 
    115   int num_probes = (1 << dim) * num_tables;
    116   multiprobe.SetupProbing(hash_vector, num_probes);
    117   uint32 cur_probe;
    118   int_fast32_t cur_table;
    119   for (int ii = 0; ii < num_probes; ++ii) {
    120     ASSERT_TRUE(multiprobe.GetNextProbe(&cur_probe, &cur_table));
    121     ASSERT_LE(0, cur_probe);
    122     ASSERT_LT(cur_probe, 1 << dim);
    123     ASSERT_LE(0, cur_table);
    124     ASSERT_LT(cur_table, num_tables);
    125     EXPECT_FALSE(checked_cell[cur_table][cur_probe]);
    126     checked_cell[cur_table][cur_probe] = true;
    127   }
    128 }
    129 
    130 }  // namespace
    131