Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 // See ../ops/image_ops.cc for details.
     16 #define EIGEN_USE_THREADS
     17 
     18 #include "tensorflow/core/framework/op_kernel.h"
     19 #include "tensorflow/core/framework/register_types.h"
     20 #include "tensorflow/core/framework/tensor.h"
     21 #include "tensorflow/core/framework/tensor_shape.h"
     22 #include "tensorflow/core/framework/types.h"
     23 #include "tensorflow/core/lib/core/errors.h"
     24 #include "tensorflow/core/platform/logging.h"
     25 
     26 namespace tensorflow {
     27 
     28 template <class T>
     29 class DrawBoundingBoxesOp : public OpKernel {
     30  public:
     31   explicit DrawBoundingBoxesOp(OpKernelConstruction* context)
     32       : OpKernel(context) {}
     33 
     34   void Compute(OpKernelContext* context) override {
     35     const Tensor& images = context->input(0);
     36     const Tensor& boxes = context->input(1);
     37     const int64 depth = images.dim_size(3);
     38 
     39     OP_REQUIRES(context, images.dims() == 4,
     40                 errors::InvalidArgument("The rank of the images should be 4"));
     41     OP_REQUIRES(
     42         context, boxes.dims() == 3,
     43         errors::InvalidArgument("The rank of the boxes tensor should be 3"));
     44     OP_REQUIRES(context, images.dim_size(0) == boxes.dim_size(0),
     45                 errors::InvalidArgument("The batch sizes should be the same"));
     46 
     47     OP_REQUIRES(
     48         context, depth == 4 || depth == 1 || depth == 3,
     49         errors::InvalidArgument("Channel depth should be either 1 (GRY), "
     50                                 "3 (RGB), or 4 (RGBA)"));
     51 
     52     const int64 batch_size = images.dim_size(0);
     53     const int64 height = images.dim_size(1);
     54     const int64 width = images.dim_size(2);
     55     const int64 color_table_length = 10;
     56 
     57     // 0: yellow
     58     // 1: blue
     59     // 2: red
     60     // 3: lime
     61     // 4: purple
     62     // 5: olive
     63     // 6: maroon
     64     // 7: navy blue
     65     // 8: aqua
     66     // 9: fuchsia
     67     float color_table[color_table_length][4] = {
     68         {1, 1, 0, 1},     {0, 0, 1, 1},     {1, 0, 0, 1},   {0, 1, 0, 1},
     69         {0.5, 0, 0.5, 1}, {0.5, 0.5, 0, 1}, {0.5, 0, 0, 1}, {0, 0, 0.5, 1},
     70         {0, 1, 1, 1},     {1, 0, 1, 1},
     71     };
     72 
     73     // Reset first color channel to 1 if image is GRY.
     74     // For GRY images, this means all bounding boxes will be white.
     75     if (depth == 1) {
     76       for (int64 i = 0; i < color_table_length; i++) {
     77         color_table[i][0] = 1;
     78       }
     79     }
     80     Tensor* output;
     81     OP_REQUIRES_OK(
     82         context,
     83         context->allocate_output(
     84             0, TensorShape({batch_size, height, width, depth}), &output));
     85 
     86     output->tensor<T, 4>() = images.tensor<T, 4>();
     87     auto canvas = output->tensor<T, 4>();
     88 
     89     for (int64 b = 0; b < batch_size; ++b) {
     90       const int64 num_boxes = boxes.dim_size(1);
     91       const auto tboxes = boxes.tensor<T, 3>();
     92       for (int64 bb = 0; bb < num_boxes; ++bb) {
     93         int64 color_index = bb % color_table_length;
     94         const int64 min_box_row =
     95             static_cast<float>(tboxes(b, bb, 0)) * (height - 1);
     96         const int64 min_box_row_clamp = std::max<int64>(min_box_row, 0);
     97         const int64 max_box_row =
     98             static_cast<float>(tboxes(b, bb, 2)) * (height - 1);
     99         const int64 max_box_row_clamp =
    100             std::min<int64>(max_box_row, height - 1);
    101         const int64 min_box_col =
    102             static_cast<float>(tboxes(b, bb, 1)) * (width - 1);
    103         const int64 min_box_col_clamp = std::max<int64>(min_box_col, 0);
    104         const int64 max_box_col =
    105             static_cast<float>(tboxes(b, bb, 3)) * (width - 1);
    106         const int64 max_box_col_clamp = std::min<int64>(max_box_col, width - 1);
    107 
    108         if (min_box_row > max_box_row || min_box_col > max_box_col) {
    109           LOG(WARNING) << "Bounding box (" << min_box_row << "," << min_box_col
    110                        << "," << max_box_row << "," << max_box_col
    111                        << ") is inverted and will not be drawn.";
    112           continue;
    113         }
    114         if (min_box_row >= height || max_box_row < 0 || min_box_col >= width ||
    115             max_box_col < 0) {
    116           LOG(WARNING) << "Bounding box (" << min_box_row << "," << min_box_col
    117                        << "," << max_box_row << "," << max_box_col
    118                        << ") is completely outside the image"
    119                        << " and will not be drawn.";
    120           continue;
    121         }
    122 
    123         // At this point, {min,max}_box_{row,col}_clamp are inside the
    124         // image.
    125         CHECK_GE(min_box_row_clamp, 0);
    126         CHECK_GE(max_box_row_clamp, 0);
    127         CHECK_LT(min_box_row_clamp, height);
    128         CHECK_LT(max_box_row_clamp, height);
    129         CHECK_GE(min_box_col_clamp, 0);
    130         CHECK_GE(max_box_col_clamp, 0);
    131         CHECK_LT(min_box_col_clamp, width);
    132         CHECK_LT(max_box_col_clamp, width);
    133 
    134         // At this point, the min_box_row and min_box_col are either
    135         // in the image or above/left of it, and max_box_row and
    136         // max_box_col are either in the image or below/right or it.
    137         CHECK_LT(min_box_row, height);
    138         CHECK_GE(max_box_row, 0);
    139         CHECK_LT(min_box_col, width);
    140         CHECK_GE(max_box_col, 0);
    141 
    142         // Draw top line.
    143         if (min_box_row >= 0) {
    144           for (int64 j = min_box_col_clamp; j <= max_box_col_clamp; ++j)
    145             for (int64 c = 0; c < depth; c++) {
    146               canvas(b, min_box_row, j, c) =
    147                   static_cast<T>(color_table[color_index][c]);
    148             }
    149         }
    150         // Draw bottom line.
    151         if (max_box_row < height) {
    152           for (int64 j = min_box_col_clamp; j <= max_box_col_clamp; ++j)
    153             for (int64 c = 0; c < depth; c++) {
    154               canvas(b, max_box_row, j, c) =
    155                   static_cast<T>(color_table[color_index][c]);
    156             }
    157         }
    158         // Draw left line.
    159         if (min_box_col >= 0) {
    160           for (int64 i = min_box_row_clamp; i <= max_box_row_clamp; ++i)
    161             for (int64 c = 0; c < depth; c++) {
    162               canvas(b, i, min_box_col, c) =
    163                   static_cast<T>(color_table[color_index][c]);
    164             }
    165         }
    166         // Draw right line.
    167         if (max_box_col < width) {
    168           for (int64 i = min_box_row_clamp; i <= max_box_row_clamp; ++i)
    169             for (int64 c = 0; c < depth; c++) {
    170               canvas(b, i, max_box_col, c) =
    171                   static_cast<T>(color_table[color_index][c]);
    172             }
    173         }
    174       }
    175     }
    176   }
    177 };
    178 
    179 #define REGISTER_CPU_KERNEL(T)                                             \
    180   REGISTER_KERNEL_BUILDER(                                                 \
    181       Name("DrawBoundingBoxes").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
    182       DrawBoundingBoxesOp<T>);
    183 TF_CALL_half(REGISTER_CPU_KERNEL);
    184 TF_CALL_float(REGISTER_CPU_KERNEL);
    185 
    186 }  // namespace tensorflow
    187