Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/image_ops.cc
     17 
     18 #define EIGEN_USE_THREADS
     19 
#include "tensorflow/core/kernels/non_max_suppression_op.h"

#include <algorithm>
#include <numeric>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/platform/logging.h"
     34 
     35 namespace tensorflow {
     36 namespace {
     37 
     38 typedef Eigen::ThreadPoolDevice CPUDevice;
     39 
     40 static inline void ParseAndCheckBoxSizes(OpKernelContext* context,
     41                                          const Tensor& boxes,
     42                                          const Tensor& scores, int* num_boxes) {
     43   // The shape of 'boxes' is [num_boxes, 4]
     44   OP_REQUIRES(context, boxes.dims() == 2,
     45               errors::InvalidArgument("boxes must be 2-D",
     46                                       boxes.shape().DebugString()));
     47   *num_boxes = boxes.dim_size(0);
     48   OP_REQUIRES(context, boxes.dim_size(1) == 4,
     49               errors::InvalidArgument("boxes must have 4 columns"));
     50 
     51   // The shape of 'scores' is [num_boxes]
     52   OP_REQUIRES(context, scores.dims() == 1,
     53               errors::InvalidArgument("scores must be 1-D",
     54                                       scores.shape().DebugString()));
     55   OP_REQUIRES(context, scores.dim_size(0) == *num_boxes,
     56               errors::InvalidArgument("scores has incompatible shape"));
     57 }
     58 
     59 static inline void DecreasingArgSort(const std::vector<float>& values,
     60                                      std::vector<int>* indices) {
     61   indices->resize(values.size());
     62   for (int i = 0; i < values.size(); ++i) (*indices)[i] = i;
     63   std::sort(
     64       indices->begin(), indices->end(),
     65       [&values](const int i, const int j) { return values[i] > values[j]; });
     66 }
     67 
     68 // Return true if intersection-over-union overlap between boxes i and j
     69 // is greater than iou_threshold.
     70 static inline bool IOUGreaterThanThreshold(
     71     typename TTypes<float, 2>::ConstTensor boxes, int i, int j,
     72     float iou_threshold) {
     73   const float ymin_i = std::min<float>(boxes(i, 0), boxes(i, 2));
     74   const float xmin_i = std::min<float>(boxes(i, 1), boxes(i, 3));
     75   const float ymax_i = std::max<float>(boxes(i, 0), boxes(i, 2));
     76   const float xmax_i = std::max<float>(boxes(i, 1), boxes(i, 3));
     77   const float ymin_j = std::min<float>(boxes(j, 0), boxes(j, 2));
     78   const float xmin_j = std::min<float>(boxes(j, 1), boxes(j, 3));
     79   const float ymax_j = std::max<float>(boxes(j, 0), boxes(j, 2));
     80   const float xmax_j = std::max<float>(boxes(j, 1), boxes(j, 3));
     81   const float area_i = (ymax_i - ymin_i) * (xmax_i - xmin_i);
     82   const float area_j = (ymax_j - ymin_j) * (xmax_j - xmin_j);
     83   if (area_i <= 0 || area_j <= 0) return 0.0;
     84   const float intersection_ymin = std::max<float>(ymin_i, ymin_j);
     85   const float intersection_xmin = std::max<float>(xmin_i, xmin_j);
     86   const float intersection_ymax = std::min<float>(ymax_i, ymax_j);
     87   const float intersection_xmax = std::min<float>(xmax_i, xmax_j);
     88   const float intersection_area =
     89       std::max<float>(intersection_ymax - intersection_ymin, 0.0) *
     90       std::max<float>(intersection_xmax - intersection_xmin, 0.0);
     91   const float iou = intersection_area / (area_i + area_j - intersection_area);
     92   return iou > iou_threshold;
     93 }
     94 
     95 void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& boxes,
     96                            const Tensor& scores, const Tensor& max_output_size,
     97                            const float iou_threshold) {
     98   OP_REQUIRES(context, iou_threshold >= 0 && iou_threshold <= 1,
     99               errors::InvalidArgument("iou_threshold must be in [0, 1]"));
    100 
    101   int num_boxes = 0;
    102   ParseAndCheckBoxSizes(context, boxes, scores, &num_boxes);
    103   if (!context->status().ok()) {
    104     return;
    105   }
    106 
    107   const int output_size = std::min(max_output_size.scalar<int>()(), num_boxes);
    108   TTypes<float, 2>::ConstTensor boxes_data = boxes.tensor<float, 2>();
    109 
    110   std::vector<float> scores_data(num_boxes);
    111   std::copy_n(scores.flat<float>().data(), num_boxes, scores_data.begin());
    112   std::vector<int> sorted_indices;
    113   DecreasingArgSort(scores_data, &sorted_indices);
    114 
    115   std::vector<int> selected;
    116   std::vector<int> selected_indices(output_size, 0);
    117   int num_selected = 0;
    118   for (int i = 0; i < num_boxes; ++i) {
    119     if (selected.size() >= output_size) break;
    120     bool should_select = true;
    121     // Overlapping boxes are likely to have similar scores,
    122     // therefore we iterate through the selected boxes backwards.
    123     for (int j = num_selected - 1; j >= 0; --j) {
    124       if (IOUGreaterThanThreshold(boxes_data, sorted_indices[i],
    125                                   sorted_indices[selected_indices[j]],
    126                                   iou_threshold)) {
    127         should_select = false;
    128         break;
    129       }
    130     }
    131     if (should_select) {
    132       selected.push_back(sorted_indices[i]);
    133       selected_indices[num_selected++] = i;
    134     }
    135   }
    136 
    137   // Allocate output tensor
    138   Tensor* output = nullptr;
    139   TensorShape output_shape({static_cast<int>(selected.size())});
    140   OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    141   TTypes<int, 1>::Tensor selected_indices_data = output->tensor<int, 1>();
    142   std::copy_n(selected.begin(), selected.size(), selected_indices_data.data());
    143 }
    144 
    145 }  // namespace
    146 
    147 template <typename Device>
    148 class NonMaxSuppressionOp : public OpKernel {
    149  public:
    150   explicit NonMaxSuppressionOp(OpKernelConstruction* context)
    151       : OpKernel(context) {
    152     OP_REQUIRES_OK(context, context->GetAttr("iou_threshold", &iou_threshold_));
    153   }
    154 
    155   void Compute(OpKernelContext* context) override {
    156     // boxes: [num_boxes, 4]
    157     const Tensor& boxes = context->input(0);
    158     // scores: [num_boxes]
    159     const Tensor& scores = context->input(1);
    160     // max_output_size: scalar
    161     const Tensor& max_output_size = context->input(2);
    162     OP_REQUIRES(
    163         context, TensorShapeUtils::IsScalar(max_output_size.shape()),
    164         errors::InvalidArgument("max_output_size must be 0-D, got shape ",
    165                                 max_output_size.shape().DebugString()));
    166 
    167     DoNonMaxSuppressionOp(context, boxes, scores, max_output_size,
    168                           iou_threshold_);
    169   }
    170 
    171  private:
    172   float iou_threshold_;
    173 };
    174 
    175 template <typename Device>
    176 class NonMaxSuppressionV2Op : public OpKernel {
    177  public:
    178   explicit NonMaxSuppressionV2Op(OpKernelConstruction* context)
    179       : OpKernel(context) {}
    180 
    181   void Compute(OpKernelContext* context) override {
    182     // boxes: [num_boxes, 4]
    183     const Tensor& boxes = context->input(0);
    184     // scores: [num_boxes]
    185     const Tensor& scores = context->input(1);
    186     // max_output_size: scalar
    187     const Tensor& max_output_size = context->input(2);
    188     OP_REQUIRES(
    189         context, TensorShapeUtils::IsScalar(max_output_size.shape()),
    190         errors::InvalidArgument("max_output_size must be 0-D, got shape ",
    191                                 max_output_size.shape().DebugString()));
    192     // iou_threshold: scalar
    193     const Tensor& iou_threshold = context->input(3);
    194     OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
    195                 errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
    196                                         iou_threshold.shape().DebugString()));
    197 
    198     const float iou_threshold_val = iou_threshold.scalar<float>()();
    199 
    200     DoNonMaxSuppressionOp(context, boxes, scores, max_output_size,
    201                           iou_threshold_val);
    202   }
    203 };
    204 
// Register both kernels for the CPU device, instantiated with CPUDevice.
// "NonMaxSuppression" takes its IoU threshold from an attribute.
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppression").Device(DEVICE_CPU),
                        NonMaxSuppressionOp<CPUDevice>);

// "NonMaxSuppressionV2" takes its IoU threshold from a scalar input.
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2").Device(DEVICE_CPU),
                        NonMaxSuppressionV2Op<CPUDevice>);
    210 
    211 }  // namespace tensorflow
    212