/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <array>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/threadpool.h"

#include "tensorflow/contrib/nearest_neighbor/kernels/hyperplane_lsh_probes.h"

namespace tensorflow {

using errors::Internal;
using errors::InvalidArgument;

using nearest_neighbor::HyperplaneMultiprobe;

// This class wraps the multiprobe LSH code from hyperplane_lsh_probes.h in a
// TensorFlow op implementation.
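//
// A quick summary of the op's interface, as enforced by the shape checks in
// Compute() below (indices and limits are taken from that code; the prose is
// only a convenience):
//   input 0: products  - [batch_size, num_tables * num_hyperplanes_per_table]
//                        matrix of point-hyperplane inner products.
//   input 1: num_tables                - scalar int32 in [1, 1000].
//   input 2: num_hyperplanes_per_table - scalar int32 in [1, 30], so that each
//                                        per-table hash fits into an int32.
//   input 3: num_probes                - scalar int32, at least 1.
//   output 0: probes - [batch_size, num_probes] int32 bucket to probe.
//   output 1: tables - [batch_size, num_probes] int32 table index for each
//                      probe.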
template <typename CoordinateType>
class HyperplaneLSHProbesOp : public OpKernel {
 public:
  using Matrix = Eigen::Matrix<CoordinateType, Eigen::Dynamic, Eigen::Dynamic,
                               Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;

  explicit HyperplaneLSHProbesOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Get the input tensors and check their shapes.
    const Tensor& products_tensor = context->input(0);
    OP_REQUIRES(context, products_tensor.dims() == 2,
                InvalidArgument("Need a two-dimensional products tensor, got ",
                                products_tensor.dims(), " dimensions."));

    const Tensor& num_tables_tensor = context->input(1);
    OP_REQUIRES(context, num_tables_tensor.dims() == 0,
                InvalidArgument("Need a scalar num_tables tensor, got ",
                                num_tables_tensor.dims(), " dimensions."));
    int num_tables = num_tables_tensor.scalar<int32>()();
    OP_REQUIRES(context, num_tables >= 1,
                InvalidArgument("num_tables must be at least 1 but got ",
                                num_tables, "."));
    OP_REQUIRES(context, num_tables <= 1000,
                InvalidArgument("Need num_tables <= 1000, got ", num_tables,
                                ". This is mostly to protect against incorrect "
                                "use of this Op. If you really need more tables"
                                ", change the code."));

    const Tensor& num_hyperplanes_per_table_tensor = context->input(2);
    OP_REQUIRES(context, num_hyperplanes_per_table_tensor.dims() == 0,
                InvalidArgument("Need a scalar num_hyperplanes_per_table "
                                "tensor, got ",
                                num_hyperplanes_per_table_tensor.dims(),
                                " dimensions."));
    int num_hyperplanes_per_table =
        num_hyperplanes_per_table_tensor.scalar<int32>()();
    OP_REQUIRES(context, num_hyperplanes_per_table >= 1,
                InvalidArgument("num_hyperplanes_per_table must be at least 1 "
                                "but got ",
                                num_hyperplanes_per_table, "."));
    OP_REQUIRES(context, num_hyperplanes_per_table <= 30,
                InvalidArgument("Need num_hyperplanes_per_table <= 30, got ",
                                num_hyperplanes_per_table,
                                ". "
                                "If you need more hyperplanes, change this Op"
                                " to work for larger integer types (int64)."));

    const Tensor& num_probes_tensor = context->input(3);
    OP_REQUIRES(context, num_probes_tensor.dims() == 0,
                InvalidArgument("Need a scalar num_probes tensor, got ",
                                num_probes_tensor.dims(), " dimensions."));
    int num_probes = num_probes_tensor.scalar<int32>()();
    OP_REQUIRES(context, num_probes >= 1,
                InvalidArgument("num_probes must be at least 1."));

    int expected_num_hyperplanes = num_tables * num_hyperplanes_per_table;
    OP_REQUIRES(context,
                products_tensor.dim_size(1) == expected_num_hyperplanes,
                InvalidArgument("Expected number of hyperplanes is ",
                                expected_num_hyperplanes, " but received ",
                                products_tensor.dim_size(1),
                                " inner products per "
                                "point."));

    auto products_eigen_tensor = products_tensor.matrix<CoordinateType>();
    ConstMatrixMap products_matrix(products_eigen_tensor.data(),
                                   products_tensor.dim_size(0),
                                   products_tensor.dim_size(1));

    int batch_size = products_tensor.dim_size(0);

    Tensor* probes_tensor = nullptr;
    Tensor* tables_tensor = nullptr;
    TensorShape output_shape({batch_size, num_probes});
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &probes_tensor));
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, output_shape, &tables_tensor));
    auto probes_eigen_tensor = probes_tensor->matrix<int32>();
    auto tables_eigen_tensor = tables_tensor->matrix<int32>();
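    // Each output row corresponds to one input point: for point i and probe
    // index k, probes_eigen_tensor(i, k) is the bucket to look up and
    // tables_eigen_tensor(i, k) is the table that bucket belongs to. Both are
    // filled in by the parallel loop below.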

    // Constants (cycles per hyperplane and table) were measured on
    // lschmidt's workstation.
    int64 cost_per_unit = 21 * num_hyperplanes_per_table * num_tables;
    if (num_probes > num_tables) {
      cost_per_unit +=
          110 * num_hyperplanes_per_table * (num_probes - num_tables);
    }
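    // cost_per_unit is the estimated number of cycles needed to process one
    // point; ParallelFor uses it to decide how finely to shard the batch
    // across the CPU worker threads. The base term covers hashing into all
    // num_tables tables; only probes beyond the first num_tables incur the
    // extra multiprobe term. As an illustrative example (numbers chosen here,
    // not taken from any measurement): num_tables = 10,
    // num_hyperplanes_per_table = 16, num_probes = 20 gives
    // 21 * 16 * 10 + 110 * 16 * (20 - 10) = 3360 + 17600 = 20960 cycles/point.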
    context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor(
        batch_size, cost_per_unit, [&](int64 start, int64 end) {
          HyperplaneMultiprobe<CoordinateType, int32> multiprobe(
              num_hyperplanes_per_table, num_tables);

          for (int point_index = start; point_index < end; ++point_index) {
            multiprobe.SetupProbing(products_matrix.row(point_index),
                                    num_probes);
            for (int ii = 0; ii < num_probes; ++ii) {
              int32 cur_probe;
              int_fast32_t cur_table;
              OP_REQUIRES(context,
                          multiprobe.GetNextProbe(&cur_probe, &cur_table),
                          Internal("Failed to get probe number ", ii,
                                   " for point number ", point_index, "."));
              probes_eigen_tensor(point_index, ii) = cur_probe;
              tables_eigen_tensor(point_index, ii) = cur_table;
            }
          }
        });
  }
};

REGISTER_KERNEL_BUILDER(Name("HyperplaneLSHProbes")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("CoordinateType"),
                        HyperplaneLSHProbesOp<float>);

REGISTER_KERNEL_BUILDER(Name("HyperplaneLSHProbes")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<double>("CoordinateType"),
                        HyperplaneLSHProbesOp<double>);

}  // namespace tensorflow