1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 // See ../ops/image_ops.cc for details. 16 #define EIGEN_USE_THREADS 17 18 #include "tensorflow/core/framework/op_kernel.h" 19 #include "tensorflow/core/framework/register_types.h" 20 #include "tensorflow/core/framework/tensor.h" 21 #include "tensorflow/core/framework/tensor_shape.h" 22 #include "tensorflow/core/framework/types.h" 23 #include "tensorflow/core/lib/core/errors.h" 24 #include "tensorflow/core/platform/logging.h" 25 26 namespace tensorflow { 27 28 template <class T> 29 class DrawBoundingBoxesOp : public OpKernel { 30 public: 31 explicit DrawBoundingBoxesOp(OpKernelConstruction* context) 32 : OpKernel(context) {} 33 34 void Compute(OpKernelContext* context) override { 35 const Tensor& images = context->input(0); 36 const Tensor& boxes = context->input(1); 37 const int64 depth = images.dim_size(3); 38 39 OP_REQUIRES(context, images.dims() == 4, 40 errors::InvalidArgument("The rank of the images should be 4")); 41 OP_REQUIRES( 42 context, boxes.dims() == 3, 43 errors::InvalidArgument("The rank of the boxes tensor should be 3")); 44 OP_REQUIRES(context, images.dim_size(0) == boxes.dim_size(0), 45 errors::InvalidArgument("The batch sizes should be the same")); 46 47 OP_REQUIRES( 48 context, depth == 4 || depth == 1 || depth == 3, 49 errors::InvalidArgument("Channel depth should be either 1 (GRY), " 50 "3 (RGB), or 4 (RGBA)")); 51 52 const int64 batch_size = images.dim_size(0); 53 const int64 height = images.dim_size(1); 54 const int64 width = images.dim_size(2); 55 const int64 color_table_length = 10; 56 57 // 0: yellow 58 // 1: blue 59 // 2: red 60 // 3: lime 61 // 4: purple 62 // 5: olive 63 // 6: maroon 64 // 7: navy blue 65 // 8: aqua 66 // 9: fuchsia 67 float color_table[color_table_length][4] = { 68 {1, 1, 0, 1}, {0, 0, 1, 1}, {1, 0, 0, 1}, {0, 1, 0, 1}, 69 {0.5, 0, 0.5, 1}, {0.5, 0.5, 0, 1}, {0.5, 0, 0, 1}, {0, 0, 0.5, 1}, 70 {0, 1, 1, 1}, {1, 0, 1, 1}, 71 }; 72 73 // Reset first color channel to 1 if image is GRY. 74 // For GRY images, this means all bounding boxes will be white. 75 if (depth == 1) { 76 for (int64 i = 0; i < color_table_length; i++) { 77 color_table[i][0] = 1; 78 } 79 } 80 Tensor* output; 81 OP_REQUIRES_OK( 82 context, 83 context->allocate_output( 84 0, TensorShape({batch_size, height, width, depth}), &output)); 85 86 output->tensor<T, 4>() = images.tensor<T, 4>(); 87 auto canvas = output->tensor<T, 4>(); 88 89 for (int64 b = 0; b < batch_size; ++b) { 90 const int64 num_boxes = boxes.dim_size(1); 91 const auto tboxes = boxes.tensor<T, 3>(); 92 for (int64 bb = 0; bb < num_boxes; ++bb) { 93 int64 color_index = bb % color_table_length; 94 const int64 min_box_row = 95 static_cast<float>(tboxes(b, bb, 0)) * (height - 1); 96 const int64 min_box_row_clamp = std::max<int64>(min_box_row, 0); 97 const int64 max_box_row = 98 static_cast<float>(tboxes(b, bb, 2)) * (height - 1); 99 const int64 max_box_row_clamp = 100 std::min<int64>(max_box_row, height - 1); 101 const int64 min_box_col = 102 static_cast<float>(tboxes(b, bb, 1)) * (width - 1); 103 const int64 min_box_col_clamp = std::max<int64>(min_box_col, 0); 104 const int64 max_box_col = 105 static_cast<float>(tboxes(b, bb, 3)) * (width - 1); 106 const int64 max_box_col_clamp = std::min<int64>(max_box_col, width - 1); 107 108 if (min_box_row > max_box_row || min_box_col > max_box_col) { 109 LOG(WARNING) << "Bounding box (" << min_box_row << "," << min_box_col 110 << "," << max_box_row << "," << max_box_col 111 << ") is inverted and will not be drawn."; 112 continue; 113 } 114 if (min_box_row >= height || max_box_row < 0 || min_box_col >= width || 115 max_box_col < 0) { 116 LOG(WARNING) << "Bounding box (" << min_box_row << "," << min_box_col 117 << "," << max_box_row << "," << max_box_col 118 << ") is completely outside the image" 119 << " and will not be drawn."; 120 continue; 121 } 122 123 // At this point, {min,max}_box_{row,col}_clamp are inside the 124 // image. 125 CHECK_GE(min_box_row_clamp, 0); 126 CHECK_GE(max_box_row_clamp, 0); 127 CHECK_LT(min_box_row_clamp, height); 128 CHECK_LT(max_box_row_clamp, height); 129 CHECK_GE(min_box_col_clamp, 0); 130 CHECK_GE(max_box_col_clamp, 0); 131 CHECK_LT(min_box_col_clamp, width); 132 CHECK_LT(max_box_col_clamp, width); 133 134 // At this point, the min_box_row and min_box_col are either 135 // in the image or above/left of it, and max_box_row and 136 // max_box_col are either in the image or below/right or it. 137 CHECK_LT(min_box_row, height); 138 CHECK_GE(max_box_row, 0); 139 CHECK_LT(min_box_col, width); 140 CHECK_GE(max_box_col, 0); 141 142 // Draw top line. 143 if (min_box_row >= 0) { 144 for (int64 j = min_box_col_clamp; j <= max_box_col_clamp; ++j) 145 for (int64 c = 0; c < depth; c++) { 146 canvas(b, min_box_row, j, c) = 147 static_cast<T>(color_table[color_index][c]); 148 } 149 } 150 // Draw bottom line. 151 if (max_box_row < height) { 152 for (int64 j = min_box_col_clamp; j <= max_box_col_clamp; ++j) 153 for (int64 c = 0; c < depth; c++) { 154 canvas(b, max_box_row, j, c) = 155 static_cast<T>(color_table[color_index][c]); 156 } 157 } 158 // Draw left line. 159 if (min_box_col >= 0) { 160 for (int64 i = min_box_row_clamp; i <= max_box_row_clamp; ++i) 161 for (int64 c = 0; c < depth; c++) { 162 canvas(b, i, min_box_col, c) = 163 static_cast<T>(color_table[color_index][c]); 164 } 165 } 166 // Draw right line. 167 if (max_box_col < width) { 168 for (int64 i = min_box_row_clamp; i <= max_box_row_clamp; ++i) 169 for (int64 c = 0; c < depth; c++) { 170 canvas(b, i, max_box_col, c) = 171 static_cast<T>(color_table[color_index][c]); 172 } 173 } 174 } 175 } 176 } 177 }; 178 179 #define REGISTER_CPU_KERNEL(T) \ 180 REGISTER_KERNEL_BUILDER( \ 181 Name("DrawBoundingBoxes").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ 182 DrawBoundingBoxesOp<T>); 183 TF_CALL_half(REGISTER_CPU_KERNEL); 184 TF_CALL_float(REGISTER_CPU_KERNEL); 185 186 } // namespace tensorflow 187