1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/contrib/nearest_neighbor/kernels/hyperplane_lsh_probes.h" 17 18 #include <vector> 19 20 #include "tensorflow/core/kernels/ops_testutil.h" 21 22 namespace { 23 24 using tensorflow::uint32; 25 26 typedef tensorflow::nearest_neighbor::HyperplaneMultiprobe<float, uint32> 27 Multiprobe; 28 29 void CheckSequenceSingleTable(Multiprobe* multiprobe, 30 const std::vector<uint32>& expected_probes) { 31 uint32 cur_probe; 32 int_fast32_t cur_table; 33 for (int ii = 0; ii < expected_probes.size(); ++ii) { 34 ASSERT_TRUE(multiprobe->GetNextProbe(&cur_probe, &cur_table)); 35 EXPECT_EQ(expected_probes[ii], cur_probe); 36 EXPECT_EQ(0, cur_table); 37 } 38 } 39 40 void CheckSequenceMultipleTables( 41 Multiprobe* multiprobe, 42 const std::vector<std::pair<uint32, int_fast32_t>>& expected_result) { 43 uint32 cur_probe; 44 int_fast32_t cur_table; 45 for (int ii = 0; ii < expected_result.size(); ++ii) { 46 ASSERT_TRUE(multiprobe->GetNextProbe(&cur_probe, &cur_table)); 47 EXPECT_EQ(expected_result[ii].first, cur_probe); 48 EXPECT_EQ(expected_result[ii].second, cur_table); 49 } 50 } 51 52 // Just the first two probes for two tables and two hyperplanes pro table. 53 TEST(HyperplaneMultiprobeTest, SimpleTest1) { 54 Multiprobe multiprobe(2, 2); 55 Multiprobe::Vector hash_vector(4); 56 hash_vector << -1.0, 1.0, 1.0, -1.0; 57 std::vector<std::pair<uint32, int_fast32_t>> expected_result = {{1, 0}, 58 {2, 1}}; 59 multiprobe.SetupProbing(hash_vector, expected_result.size()); 60 CheckSequenceMultipleTables(&multiprobe, expected_result); 61 } 62 63 // Checking that the beginning of a probing sequence for a single table is 64 // correct. 65 TEST(HyperplaneMultiprobeTest, SimpleTest2) { 66 Multiprobe multiprobe(4, 1); 67 Multiprobe::Vector hash_vector(4); 68 hash_vector << -2.0, -0.9, -0.8, -0.7; 69 std::vector<uint32> expected_result = {0, 1, 2, 4, 3}; 70 multiprobe.SetupProbing(hash_vector, expected_result.size()); 71 CheckSequenceSingleTable(&multiprobe, expected_result); 72 } 73 74 // Checking that the probing sequence for a single table is exhaustive. 75 TEST(HyperplaneMultiprobeTest, SimpleTest3) { 76 Multiprobe multiprobe(3, 1); 77 Multiprobe::Vector hash_vector(3); 78 hash_vector << -1.0, -10.0, -0.1; 79 std::vector<uint32> expected_result = {0, 1, 4, 5, 2, 3, 6, 7}; 80 multiprobe.SetupProbing(hash_vector, expected_result.size()); 81 CheckSequenceSingleTable(&multiprobe, expected_result); 82 } 83 84 // Checking that the probing sequence is generated correctly across tables. 85 TEST(HyperplaneMultiprobeTest, SimpleTest4) { 86 Multiprobe multiprobe(2, 2); 87 Multiprobe::Vector hash_vector(4); 88 hash_vector << -0.2, 0.9, 0.1, -1.0; 89 std::vector<std::pair<uint32, int_fast32_t>> expected_result = { 90 {1, 0}, {2, 1}, {0, 1}, {3, 0}, {0, 0}, {2, 0}, {3, 1}, {1, 1}}; 91 multiprobe.SetupProbing(hash_vector, expected_result.size()); 92 CheckSequenceMultipleTables(&multiprobe, expected_result); 93 } 94 95 // Slightly larger test that checks whether we have an exhaustive probing 96 // sequence (but this test does not check the order). 97 TEST(HyperplaneMultiprobeTest, ExhaustiveTest1) { 98 int dim = 8; 99 int num_tables = 10; 100 Multiprobe multiprobe(dim, num_tables); 101 Multiprobe::Vector hash_vector(dim * num_tables); 102 103 std::mt19937 random_generator(487344882); 104 std::normal_distribution<float> distribution(0.0, 1.0); 105 for (int ii = 0; ii < dim * num_tables; ++ii) { 106 hash_vector[ii] = distribution(random_generator); 107 } 108 109 std::vector<std::vector<bool>> checked_cell(num_tables); 110 for (int ii = 0; ii < num_tables; ++ii) { 111 checked_cell[ii].resize(1 << dim); 112 std::fill(checked_cell[ii].begin(), checked_cell[ii].end(), false); 113 } 114 115 int num_probes = (1 << dim) * num_tables; 116 multiprobe.SetupProbing(hash_vector, num_probes); 117 uint32 cur_probe; 118 int_fast32_t cur_table; 119 for (int ii = 0; ii < num_probes; ++ii) { 120 ASSERT_TRUE(multiprobe.GetNextProbe(&cur_probe, &cur_table)); 121 ASSERT_LE(0, cur_probe); 122 ASSERT_LT(cur_probe, 1 << dim); 123 ASSERT_LE(0, cur_table); 124 ASSERT_LT(cur_table, num_tables); 125 EXPECT_FALSE(checked_cell[cur_table][cur_probe]); 126 checked_cell[cur_table][cur_probe] = true; 127 } 128 } 129 130 } // namespace 131