1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <array> 17 18 #include "tensorflow/core/framework/op_kernel.h" 19 #include "tensorflow/core/lib/core/threadpool.h" 20 21 #include "tensorflow/contrib/nearest_neighbor/kernels/hyperplane_lsh_probes.h" 22 23 namespace tensorflow { 24 25 using errors::Internal; 26 using errors::InvalidArgument; 27 28 using nearest_neighbor::HyperplaneMultiprobe; 29 30 // This class wraps the multiprobe LSH code in hyperplane_lsh_probes in a 31 // TensorFlow op implementation. 32 template <typename CoordinateType> 33 class HyperplaneLSHProbesOp : public OpKernel { 34 public: 35 using Matrix = Eigen::Matrix<CoordinateType, Eigen::Dynamic, Eigen::Dynamic, 36 Eigen::RowMajor>; 37 using ConstMatrixMap = Eigen::Map<const Matrix>; 38 using MatrixMap = Eigen::Map<Matrix>; 39 40 explicit HyperplaneLSHProbesOp(OpKernelConstruction* context) 41 : OpKernel(context) {} 42 43 void Compute(OpKernelContext* context) override { 44 // Get the input tensors and check their shapes. 45 const Tensor& products_tensor = context->input(0); 46 OP_REQUIRES(context, products_tensor.dims() == 2, 47 InvalidArgument("Need a two-dimensional products tensor, got ", 48 products_tensor.dims(), " dimensions.")); 49 50 const Tensor& num_tables_tensor = context->input(1); 51 OP_REQUIRES(context, num_tables_tensor.dims() == 0, 52 InvalidArgument("Need a scalar num_tables tensor, got ", 53 num_tables_tensor.dims(), " dimensions.")); 54 int num_tables = num_tables_tensor.scalar<int32>()(); 55 OP_REQUIRES(context, num_tables >= 1, 56 InvalidArgument("num_tables must be at least 1 but got ", 57 num_tables, ".")); 58 OP_REQUIRES(context, num_tables <= 1000, 59 InvalidArgument("Need num_tables <= 1000, got ", num_tables, 60 ". This is mostly to protect against incorrect " 61 "use of this Op. If you really need more tables" 62 ", change the code.")); 63 64 const Tensor& num_hyperplanes_per_table_tensor = context->input(2); 65 OP_REQUIRES(context, num_hyperplanes_per_table_tensor.dims() == 0, 66 InvalidArgument("Need a scalar num_hyperplanes_per_table " 67 "tensor, got ", 68 num_hyperplanes_per_table_tensor.dims(), 69 " dimensions.")); 70 int num_hyperplanes_per_table = 71 num_hyperplanes_per_table_tensor.scalar<int32>()(); 72 OP_REQUIRES(context, num_hyperplanes_per_table >= 1, 73 InvalidArgument("num_hyperplanes_per_table must be at least 1 " 74 "but got ", 75 num_hyperplanes_per_table, ".")); 76 OP_REQUIRES(context, num_hyperplanes_per_table <= 30, 77 InvalidArgument("Need num_hyperplanes_per_table <= 30, got ", 78 num_hyperplanes_per_table, 79 ". " 80 "If you need more hyperplanes, change this Op" 81 " to work for larger integer types (int64).")); 82 83 const Tensor& num_probes_tensor = context->input(3); 84 OP_REQUIRES(context, num_probes_tensor.dims() == 0, 85 InvalidArgument("Need a scalar num_probes tensor, got ", 86 num_probes_tensor.dims(), " dimensions.")); 87 int num_probes = num_probes_tensor.scalar<int32>()(); 88 OP_REQUIRES(context, num_probes >= 1, 89 InvalidArgument("num_probes must be at least 1.")); 90 91 int expected_num_hyperplanes = num_tables * num_hyperplanes_per_table; 92 OP_REQUIRES(context, 93 products_tensor.dim_size(1) == expected_num_hyperplanes, 94 InvalidArgument("Expected number of hyperplanes is ", 95 expected_num_hyperplanes, " but received ", 96 products_tensor.dim_size(1), 97 " inner products per " 98 "point.")); 99 100 auto products_eigen_tensor = products_tensor.matrix<CoordinateType>(); 101 ConstMatrixMap products_matrix(products_eigen_tensor.data(), 102 products_tensor.dim_size(0), 103 products_tensor.dim_size(1)); 104 105 int batch_size = products_tensor.dim_size(0); 106 107 Tensor* probes_tensor = nullptr; 108 Tensor* tables_tensor = nullptr; 109 TensorShape output_shape({batch_size, num_probes}); 110 OP_REQUIRES_OK(context, 111 context->allocate_output(0, output_shape, &probes_tensor)); 112 OP_REQUIRES_OK(context, 113 context->allocate_output(1, output_shape, &tables_tensor)); 114 auto probes_eigen_tensor = probes_tensor->matrix<int32>(); 115 auto tables_eigen_tensor = tables_tensor->matrix<int32>(); 116 117 // Constants (cycles per hyperplane and table) were measured on 118 // lschmidt's workstation. 119 int64 cost_per_unit = 21 * num_hyperplanes_per_table * num_tables; 120 if (num_probes > num_tables) { 121 cost_per_unit += 122 110 * num_hyperplanes_per_table * (num_probes - num_tables); 123 } 124 context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor( 125 batch_size, cost_per_unit, [&](int64 start, int64 end) { 126 HyperplaneMultiprobe<CoordinateType, int32> multiprobe( 127 num_hyperplanes_per_table, num_tables); 128 129 for (int point_index = start; point_index < end; ++point_index) { 130 multiprobe.SetupProbing(products_matrix.row(point_index), 131 num_probes); 132 for (int ii = 0; ii < num_probes; ++ii) { 133 int32 cur_probe; 134 int_fast32_t cur_table; 135 OP_REQUIRES(context, 136 multiprobe.GetNextProbe(&cur_probe, &cur_table), 137 Internal("Failed to get probe number ", ii, 138 " for point number ", point_index, ".")); 139 probes_eigen_tensor(point_index, ii) = cur_probe; 140 tables_eigen_tensor(point_index, ii) = cur_table; 141 } 142 } 143 }); 144 } 145 }; 146 147 REGISTER_KERNEL_BUILDER(Name("HyperplaneLSHProbes") 148 .Device(DEVICE_CPU) 149 .TypeConstraint<float>("CoordinateType"), 150 HyperplaneLSHProbesOp<float>); 151 152 REGISTER_KERNEL_BUILDER(Name("HyperplaneLSHProbes") 153 .Device(DEVICE_CPU) 154 .TypeConstraint<double>("CoordinateType"), 155 HyperplaneLSHProbesOp<double>); 156 157 } // namespace tensorflow 158