1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/image_ops.cc 17 18 #define EIGEN_USE_THREADS 19 20 #include "tensorflow/core/kernels/non_max_suppression_op.h" 21 22 #include <vector> 23 24 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 25 #include "tensorflow/core/framework/op_kernel.h" 26 #include "tensorflow/core/framework/register_types.h" 27 #include "tensorflow/core/framework/tensor.h" 28 #include "tensorflow/core/framework/tensor_shape.h" 29 #include "tensorflow/core/framework/types.h" 30 #include "tensorflow/core/kernels/bounds_check.h" 31 #include "tensorflow/core/lib/core/status.h" 32 #include "tensorflow/core/lib/gtl/stl_util.h" 33 #include "tensorflow/core/platform/logging.h" 34 35 namespace tensorflow { 36 namespace { 37 38 typedef Eigen::ThreadPoolDevice CPUDevice; 39 40 static inline void ParseAndCheckBoxSizes(OpKernelContext* context, 41 const Tensor& boxes, 42 const Tensor& scores, int* num_boxes) { 43 // The shape of 'boxes' is [num_boxes, 4] 44 OP_REQUIRES(context, boxes.dims() == 2, 45 errors::InvalidArgument("boxes must be 2-D", 46 boxes.shape().DebugString())); 47 *num_boxes = boxes.dim_size(0); 48 OP_REQUIRES(context, boxes.dim_size(1) == 4, 49 errors::InvalidArgument("boxes must have 4 columns")); 50 51 // The shape of 'scores' is [num_boxes] 52 OP_REQUIRES(context, scores.dims() == 1, 53 errors::InvalidArgument("scores must be 1-D", 54 scores.shape().DebugString())); 55 OP_REQUIRES(context, scores.dim_size(0) == *num_boxes, 56 errors::InvalidArgument("scores has incompatible shape")); 57 } 58 59 static inline void DecreasingArgSort(const std::vector<float>& values, 60 std::vector<int>* indices) { 61 indices->resize(values.size()); 62 for (int i = 0; i < values.size(); ++i) (*indices)[i] = i; 63 std::sort( 64 indices->begin(), indices->end(), 65 [&values](const int i, const int j) { return values[i] > values[j]; }); 66 } 67 68 // Return true if intersection-over-union overlap between boxes i and j 69 // is greater than iou_threshold. 70 static inline bool IOUGreaterThanThreshold( 71 typename TTypes<float, 2>::ConstTensor boxes, int i, int j, 72 float iou_threshold) { 73 const float ymin_i = std::min<float>(boxes(i, 0), boxes(i, 2)); 74 const float xmin_i = std::min<float>(boxes(i, 1), boxes(i, 3)); 75 const float ymax_i = std::max<float>(boxes(i, 0), boxes(i, 2)); 76 const float xmax_i = std::max<float>(boxes(i, 1), boxes(i, 3)); 77 const float ymin_j = std::min<float>(boxes(j, 0), boxes(j, 2)); 78 const float xmin_j = std::min<float>(boxes(j, 1), boxes(j, 3)); 79 const float ymax_j = std::max<float>(boxes(j, 0), boxes(j, 2)); 80 const float xmax_j = std::max<float>(boxes(j, 1), boxes(j, 3)); 81 const float area_i = (ymax_i - ymin_i) * (xmax_i - xmin_i); 82 const float area_j = (ymax_j - ymin_j) * (xmax_j - xmin_j); 83 if (area_i <= 0 || area_j <= 0) return 0.0; 84 const float intersection_ymin = std::max<float>(ymin_i, ymin_j); 85 const float intersection_xmin = std::max<float>(xmin_i, xmin_j); 86 const float intersection_ymax = std::min<float>(ymax_i, ymax_j); 87 const float intersection_xmax = std::min<float>(xmax_i, xmax_j); 88 const float intersection_area = 89 std::max<float>(intersection_ymax - intersection_ymin, 0.0) * 90 std::max<float>(intersection_xmax - intersection_xmin, 0.0); 91 const float iou = intersection_area / (area_i + area_j - intersection_area); 92 return iou > iou_threshold; 93 } 94 95 void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& boxes, 96 const Tensor& scores, const Tensor& max_output_size, 97 const float iou_threshold) { 98 OP_REQUIRES(context, iou_threshold >= 0 && iou_threshold <= 1, 99 errors::InvalidArgument("iou_threshold must be in [0, 1]")); 100 101 int num_boxes = 0; 102 ParseAndCheckBoxSizes(context, boxes, scores, &num_boxes); 103 if (!context->status().ok()) { 104 return; 105 } 106 107 const int output_size = std::min(max_output_size.scalar<int>()(), num_boxes); 108 TTypes<float, 2>::ConstTensor boxes_data = boxes.tensor<float, 2>(); 109 110 std::vector<float> scores_data(num_boxes); 111 std::copy_n(scores.flat<float>().data(), num_boxes, scores_data.begin()); 112 std::vector<int> sorted_indices; 113 DecreasingArgSort(scores_data, &sorted_indices); 114 115 std::vector<int> selected; 116 std::vector<int> selected_indices(output_size, 0); 117 int num_selected = 0; 118 for (int i = 0; i < num_boxes; ++i) { 119 if (selected.size() >= output_size) break; 120 bool should_select = true; 121 // Overlapping boxes are likely to have similar scores, 122 // therefore we iterate through the selected boxes backwards. 123 for (int j = num_selected - 1; j >= 0; --j) { 124 if (IOUGreaterThanThreshold(boxes_data, sorted_indices[i], 125 sorted_indices[selected_indices[j]], 126 iou_threshold)) { 127 should_select = false; 128 break; 129 } 130 } 131 if (should_select) { 132 selected.push_back(sorted_indices[i]); 133 selected_indices[num_selected++] = i; 134 } 135 } 136 137 // Allocate output tensor 138 Tensor* output = nullptr; 139 TensorShape output_shape({static_cast<int>(selected.size())}); 140 OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); 141 TTypes<int, 1>::Tensor selected_indices_data = output->tensor<int, 1>(); 142 std::copy_n(selected.begin(), selected.size(), selected_indices_data.data()); 143 } 144 145 } // namespace 146 147 template <typename Device> 148 class NonMaxSuppressionOp : public OpKernel { 149 public: 150 explicit NonMaxSuppressionOp(OpKernelConstruction* context) 151 : OpKernel(context) { 152 OP_REQUIRES_OK(context, context->GetAttr("iou_threshold", &iou_threshold_)); 153 } 154 155 void Compute(OpKernelContext* context) override { 156 // boxes: [num_boxes, 4] 157 const Tensor& boxes = context->input(0); 158 // scores: [num_boxes] 159 const Tensor& scores = context->input(1); 160 // max_output_size: scalar 161 const Tensor& max_output_size = context->input(2); 162 OP_REQUIRES( 163 context, TensorShapeUtils::IsScalar(max_output_size.shape()), 164 errors::InvalidArgument("max_output_size must be 0-D, got shape ", 165 max_output_size.shape().DebugString())); 166 167 DoNonMaxSuppressionOp(context, boxes, scores, max_output_size, 168 iou_threshold_); 169 } 170 171 private: 172 float iou_threshold_; 173 }; 174 175 template <typename Device> 176 class NonMaxSuppressionV2Op : public OpKernel { 177 public: 178 explicit NonMaxSuppressionV2Op(OpKernelConstruction* context) 179 : OpKernel(context) {} 180 181 void Compute(OpKernelContext* context) override { 182 // boxes: [num_boxes, 4] 183 const Tensor& boxes = context->input(0); 184 // scores: [num_boxes] 185 const Tensor& scores = context->input(1); 186 // max_output_size: scalar 187 const Tensor& max_output_size = context->input(2); 188 OP_REQUIRES( 189 context, TensorShapeUtils::IsScalar(max_output_size.shape()), 190 errors::InvalidArgument("max_output_size must be 0-D, got shape ", 191 max_output_size.shape().DebugString())); 192 // iou_threshold: scalar 193 const Tensor& iou_threshold = context->input(3); 194 OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()), 195 errors::InvalidArgument("iou_threshold must be 0-D, got shape ", 196 iou_threshold.shape().DebugString())); 197 198 const float iou_threshold_val = iou_threshold.scalar<float>()(); 199 200 DoNonMaxSuppressionOp(context, boxes, scores, max_output_size, 201 iou_threshold_val); 202 } 203 }; 204 205 REGISTER_KERNEL_BUILDER(Name("NonMaxSuppression").Device(DEVICE_CPU), 206 NonMaxSuppressionOp<CPUDevice>); 207 208 REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2").Device(DEVICE_CPU), 209 NonMaxSuppressionV2Op<CPUDevice>); 210 211 } // namespace tensorflow 212