Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 // See docs in ../ops/image_ops.cc.
     16 #include <math.h>
     17 #include "tensorflow/core/framework/op_kernel.h"
     18 #include "tensorflow/core/framework/register_types.h"
     19 #include "tensorflow/core/framework/tensor.h"
     20 #include "tensorflow/core/framework/types.h"
     21 #include "tensorflow/core/kernels/bounds_check.h"
     22 #include "tensorflow/core/lib/random/simple_philox.h"
     23 #include "tensorflow/core/util/guarded_philox_random.h"
     24 
     25 using tensorflow::random::SimplePhilox;
     26 
     27 namespace tensorflow {
     28 namespace {
     29 
     30 // A simple Rectangle class that supplies intersection.
     31 class Rectangle {
     32  public:
     33   Rectangle() { Set(0, 0, 0, 0); }
     34   Rectangle(int xmin, int ymin, int xmax, int ymax) {
     35     Set(xmin, ymin, xmax, ymax);
     36   }
     37 
     38   void Set(int xmin, int ymin, int xmax, int ymax) {
     39     min_x_ = xmin;
     40     min_y_ = ymin;
     41     max_x_ = xmax;
     42     max_y_ = ymax;
     43   }
     44 
     45   bool IsEmpty() const { return min_x_ > max_x_ || min_y_ > max_y_; }
     46   float Area() const {
     47     return static_cast<float>((max_x_ - min_x_) * (max_y_ - min_y_));
     48   }
     49 
     50   Rectangle Intersect(const Rectangle& r) const {
     51     const int pmin_x = std::max(min_x_, r.min_x_);
     52     const int pmin_y = std::max(min_y_, r.min_y_);
     53     const int pmax_x = std::min(max_x_, r.max_x_);
     54     const int pmax_y = std::min(max_y_, r.max_y_);
     55 
     56     if (pmin_x > pmax_x || pmin_y > pmax_y) {
     57       return Rectangle();
     58     } else {
     59       return Rectangle(pmin_x, pmin_y, pmax_x, pmax_y);
     60     }
     61   }
     62 
     63   int min_x_;
     64   int min_y_;
     65   int max_x_;
     66   int max_y_;
     67 };
     68 
     69 // Determine if the supplied cropping box covers a sufficient fraction of the
     70 // the supplied bounding boxes.
     71 bool SatisfiesOverlapConstraints(const Rectangle& crop,
     72                                  float minimum_object_covered,
     73                                  const std::vector<Rectangle>& bounding_boxes) {
     74   // Reject any bounding box which contains no pixels.
     75   const float kMinArea = 1.0;
     76   if (crop.Area() < kMinArea) {
     77     return false;
     78   }
     79 
     80   // Loop through all objects and determine if the proposed cropping box covers
     81   // a sufficient fraction of one of the supplied bounding boxes.
     82   bool is_object_covered = false;
     83   for (const auto& bbox : bounding_boxes) {
     84     const float object_area = bbox.Area();
     85     if (object_area < kMinArea) {
     86       continue;
     87     }
     88 
     89     const float object_covered = crop.Intersect(bbox).Area() / object_area;
     90 
     91     if (object_covered >= minimum_object_covered) {
     92       is_object_covered = true;
     93       break;
     94     }
     95   }
     96   return is_object_covered;
     97 }
     98 
     99 // Generate a random crop within the rectangle
    100 // (0, 0, original_width, original_height).
    101 // The minimum area of the crop will be between
    102 //   min_relative_crop_area * orig_width * orig_height
    103 // and
    104 //   max_relative_crop_area * orig_width * orig_height
    105 // such that its width = round(aspect_ratio * height).
    106 // The diameter of the generated rectangle will be uniformly distributed between
    107 // its minimum and maximum size. The center of the rectangle will be distributed
    108 // uniformly within the source rectangle. The function returns false if the
    109 // rectangle could not be generated with the given constraints.
    110 bool GenerateRandomCrop(int original_width, int original_height,
    111                         float min_relative_crop_area,
    112                         float max_relative_crop_area, float aspect_ratio,
    113                         SimplePhilox* random, Rectangle* crop_rect) {
    114   if (max_relative_crop_area <= 0.0 || aspect_ratio <= 0.0 ||
    115       original_width <= 0 || original_height <= 0 ||
    116       min_relative_crop_area > max_relative_crop_area) {
    117     return false;
    118   }
    119 
    120   const float min_area =
    121       min_relative_crop_area * original_width * original_height;
    122   const float max_area =
    123       max_relative_crop_area * original_width * original_height;
    124 
    125   int height = static_cast<int>(lrintf(sqrt(min_area / aspect_ratio)));
    126   int max_height = static_cast<int>(lrintf(sqrt(max_area / aspect_ratio)));
    127 
    128   if (lrintf(max_height * aspect_ratio) > original_width) {
    129     // We must find the smallest max_height satisfying
    130     // round(max_height * aspect_ratio) <= original_width:
    131     const float kEps = 0.0000001;
    132     max_height = static_cast<int>((original_width + 0.5 - kEps) / aspect_ratio);
    133   }
    134 
    135   if (max_height > original_height) {
    136     max_height = original_height;
    137   }
    138 
    139   if (height >= max_height) {
    140     height = max_height;
    141   }
    142 
    143   if (height < max_height) {
    144     // We need to generate a random number in the closed range
    145     // [0, max_height - height].
    146     height += random->Uniform(max_height - height + 1);
    147   }
    148   int width = static_cast<int>(lrintf(height * aspect_ratio));
    149   DCHECK_LE(width, original_width);
    150 
    151   // Let us not fail if rounding error causes the area to be
    152   // outside the constraints.
    153   // Try first with a slightly bigger rectangle first.
    154   float area = static_cast<float>(width * height);
    155   if (area < min_area) {
    156     height += 1;
    157     width = static_cast<int>(lrintf(height * aspect_ratio));
    158     area = width * height;
    159   }
    160 
    161   // Let us not fail if rounding error causes the area to be
    162   // outside the constraints.
    163   // Try first with a slightly smaller rectangle first.
    164   if (area > max_area) {
    165     height -= 1;
    166     width = static_cast<int>(lrintf(height * aspect_ratio));
    167     area = width * height;
    168   }
    169 
    170   // Now, we explored all options to rectify small rounding errors.
    171   // It seems the constraints can't be satisfied: return false.
    172   if (area < min_area || area > max_area || width > original_width ||
    173       height > original_height || width <= 0 || height <= 0) {
    174     return false;
    175   }
    176 
    177   int y = 0;
    178   if (height < original_height) {
    179     y = random->Uniform(original_height - height);
    180   }
    181   int x = 0;
    182   if (width < original_width) {
    183     x = random->Uniform(original_width - width);
    184   }
    185 
    186   crop_rect->min_x_ = x;
    187   crop_rect->min_y_ = y;
    188   crop_rect->max_x_ = x + width;
    189   crop_rect->max_y_ = y + height;
    190   return true;
    191 }
    192 }  // namespace
    193 
    194 template <typename T>
    195 class SampleDistortedBoundingBoxV2Op : public OpKernel {
    196  public:
    197   explicit SampleDistortedBoundingBoxV2Op(OpKernelConstruction* context)
    198       : OpKernel(context) {
    199     OP_REQUIRES_OK(context, generator_.Init(context));
    200 
    201     if (context->num_inputs() == 2) {
    202       OP_REQUIRES_OK(context, context->GetAttr("min_object_covered",
    203                                                &min_object_covered_));
    204       OP_REQUIRES(
    205           context, min_object_covered_ >= 0,
    206           errors::InvalidArgument("Min object covered must be non-negative: ",
    207                                   min_object_covered_));
    208     }
    209 
    210     OP_REQUIRES_OK(context, context->GetAttr("use_image_if_no_bounding_boxes",
    211                                              &use_image_if_no_bounding_boxes_));
    212 
    213     OP_REQUIRES_OK(
    214         context, context->GetAttr("aspect_ratio_range", &aspect_ratio_range_));
    215     OP_REQUIRES(context, aspect_ratio_range_.size() == 2,
    216                 errors::InvalidArgument(
    217                     "Aspect ratio range field must specify 2 dimensions"));
    218 
    219     OP_REQUIRES(
    220         context, aspect_ratio_range_[0] > 0 && aspect_ratio_range_[1] > 0,
    221         errors::InvalidArgument("Aspect ratio range must be non-negative: [",
    222                                 aspect_ratio_range_[0], ", ",
    223                                 aspect_ratio_range_[1], "]"));
    224 
    225     OP_REQUIRES_OK(context, context->GetAttr("area_range", &area_range_));
    226     OP_REQUIRES(
    227         context, area_range_.size() == 2,
    228         errors::InvalidArgument("Area range field must specify 2 dimensions"));
    229 
    230     OP_REQUIRES(
    231         context, area_range_[0] > 0 && area_range_[1] > 0,
    232         errors::InvalidArgument("Area range must be non-negative: [",
    233                                 area_range_[0], ", ", area_range_[1], "]"));
    234 
    235     OP_REQUIRES(context, area_range_[0] <= 1 && area_range_[1] <= 1,
    236                 errors::InvalidArgument(
    237                     "Area range must be less then or equal to 1.0: [",
    238                     area_range_[0], ", ", area_range_[1], "]"));
    239 
    240     OP_REQUIRES_OK(context, context->GetAttr("max_attempts", &max_attempts_));
    241     OP_REQUIRES(context, max_attempts_ > 0,
    242                 errors::InvalidArgument("Max attempts must be non-negative: ",
    243                                         max_attempts_));
    244   }
    245 
    246   void Compute(OpKernelContext* context) override {
    247     const Tensor& image_size = context->input(0);
    248 
    249     OP_REQUIRES(context, image_size.dims() == 1,
    250                 errors::InvalidArgument("image_size must be 1-dimensional",
    251                                         image_size.shape().DebugString()));
    252     OP_REQUIRES(context, image_size.dim_size(0) == 3,
    253                 errors::InvalidArgument("image_size must contain 3 elements",
    254                                         image_size.shape().DebugString()));
    255 
    256     // Note image_size_data(2) is the depth and unused.
    257     const uint64 height_raw = internal::SubtleMustCopy(image_size.flat<T>()(0));
    258     const uint64 width_raw = internal::SubtleMustCopy(image_size.flat<T>()(1));
    259     OP_REQUIRES(context,
    260                 FastBoundsCheck(height_raw, std::numeric_limits<int32>::max()),
    261                 errors::InvalidArgument("image height cannot be >= int32 max"));
    262     OP_REQUIRES(context,
    263                 FastBoundsCheck(width_raw, std::numeric_limits<int32>::max()),
    264                 errors::InvalidArgument("image width cannot be >= int32 max"));
    265     const int32 height = static_cast<int32>(height_raw);
    266     const int32 width = static_cast<int32>(width_raw);
    267 
    268     // Ensure that the supplied bounding boxes are sane and convert them to
    269     // Rectangles.
    270     const Tensor& input_boxes = context->input(1);
    271     OP_REQUIRES(context, input_boxes.dims() == 3,
    272                 errors::InvalidArgument("input boxes must be 3-dimensional "
    273                                         "[batch, num_boxes, coords]: ",
    274                                         input_boxes.shape().DebugString()));
    275     OP_REQUIRES(context, input_boxes.dim_size(input_boxes.dims() - 1) == 4,
    276                 errors::InvalidArgument(
    277                     "bounding boxes must have shape [4] or [*, 4], got ",
    278                     input_boxes.shape().DebugString()));
    279 
    280     float min_object_covered_val = 0.0;
    281     if (context->num_inputs() == 3) {
    282       const Tensor& min_object_covered = context->input(2);
    283 
    284       OP_REQUIRES(
    285           context, TensorShapeUtils::IsScalar(min_object_covered.shape()),
    286           errors::InvalidArgument("min_object_covered must be 0-D, got shape ",
    287                                   min_object_covered.shape().DebugString()));
    288 
    289       min_object_covered_val = min_object_covered.scalar<float>()();
    290 
    291       OP_REQUIRES(
    292           context, min_object_covered_val >= 0,
    293           errors::InvalidArgument("Min object covered must be non-negative: ",
    294                                   min_object_covered_val));
    295     } else {
    296       min_object_covered_val = min_object_covered_;
    297     }
    298 
    299     std::vector<Rectangle> bounding_boxes;
    300     if (input_boxes.NumElements() > 0) {
    301       TTypes<float>::ConstMatrix boxes = input_boxes.flat_inner_dims<float>();
    302       for (int b = 0; b < boxes.dimension(0); ++b) {
    303         for (int i = 0; i < 4; ++i) {
    304           OP_REQUIRES(
    305               context, boxes(b, i) >= 0.0 && boxes(b, i) <= 1.0,
    306               errors::InvalidArgument("All bounding box coordinates must "
    307                                       "be in [0.0, 1.0]: ",
    308                                       boxes(b, i)));
    309         }
    310 
    311         const int32 x_min = static_cast<int32>(boxes(b, 1) * width);
    312         const int32 y_min = static_cast<int32>(boxes(b, 0) * height);
    313         const int32 x_max = static_cast<int32>(boxes(b, 3) * width);
    314         const int32 y_max = static_cast<int32>(boxes(b, 2) * height);
    315 
    316         bounding_boxes.push_back(Rectangle(x_min, y_min, x_max, y_max));
    317       }
    318     }
    319 
    320     // Insert the entire image if no bounding boxes are supplied.
    321     const Rectangle image_rect(0, 0, width, height);
    322     if (bounding_boxes.empty()) {
    323       OP_REQUIRES(context, use_image_if_no_bounding_boxes_,
    324                   errors::InvalidArgument(
    325                       "No bounding boxes provided as input. One must "
    326                       "enable use_image_if_no_bounding_boxes if you wish "
    327                       "to not provide any bounding boxes."));
    328       bounding_boxes.push_back(image_rect);
    329     }
    330 
    331     const float min_sample_area = area_range_[0];
    332     const float max_sample_area = area_range_[1];
    333     const float min_sample_aspect_ratio = aspect_ratio_range_[0];
    334     const float max_sample_aspect_ratio = aspect_ratio_range_[1];
    335 
    336     auto local_gen = generator_.ReserveSamples32(4 * max_attempts_);
    337     random::SimplePhilox random(&local_gen);
    338 
    339     Rectangle crop_rect;
    340     bool sample_generated = false;
    341     for (int i = 0; i < max_attempts_; ++i) {
    342       const float sample_aspect_ratio =
    343           random.RandFloat() *
    344               (max_sample_aspect_ratio - min_sample_aspect_ratio) +
    345           min_sample_aspect_ratio;
    346 
    347       if (GenerateRandomCrop(width, height, min_sample_area, max_sample_area,
    348                              sample_aspect_ratio, &random, &crop_rect)) {
    349         if (SatisfiesOverlapConstraints(crop_rect, min_object_covered_val,
    350                                         bounding_boxes)) {
    351           sample_generated = true;
    352           break;
    353         }
    354       }
    355     }
    356 
    357     if (!sample_generated) {
    358       crop_rect = image_rect;
    359     }
    360 
    361     // Determine the cropping parameters from the bounding box.
    362     const int target_width = crop_rect.max_x_ - crop_rect.min_x_;
    363     const int target_height = crop_rect.max_y_ - crop_rect.min_y_;
    364 
    365     const int offset_width = crop_rect.min_x_;
    366     const int offset_height = crop_rect.min_y_;
    367 
    368     // Ensure that the bounding box fits in the image dimensions.
    369     OP_REQUIRES(context, width >= target_width + offset_width,
    370                 errors::FailedPrecondition(
    371                     "width must be > target_width + offset_width: ", width,
    372                     "vs ", target_width, " + ", offset_width));
    373     OP_REQUIRES(context, height >= target_height + offset_height,
    374                 errors::FailedPrecondition(
    375                     "height must be >= target_height: height = ", height, "vs ",
    376                     target_height, " + ", offset_height));
    377 
    378     // Create two vectors, each 3 elements, to provide as arguments to Slice.
    379     // See Slice() operation for details.
    380     Tensor* begin = nullptr;
    381     OP_REQUIRES_OK(context,
    382                    context->allocate_output(0, TensorShape({3}), &begin));
    383     Tensor* size = nullptr;
    384     OP_REQUIRES_OK(context,
    385                    context->allocate_output(1, TensorShape({3}), &size));
    386     Tensor* bboxes = nullptr;
    387     OP_REQUIRES_OK(
    388         context, context->allocate_output(2, TensorShape({1, 1, 4}), &bboxes));
    389 
    390     typename TTypes<T, 1>::Tensor begin_data(begin->tensor<T, 1>());
    391     typename TTypes<T, 1>::Tensor size_data(size->tensor<T, 1>());
    392     TTypes<float, 3>::Tensor bboxes_data = bboxes->tensor<float, 3>();
    393 
    394     begin_data(0) = T(offset_height);
    395     size_data(0) = T(target_height);
    396 
    397     begin_data(1) = T(offset_width);
    398     size_data(1) = T(target_width);
    399 
    400     bboxes_data(0, 0, 0) =
    401         static_cast<float>(crop_rect.min_y_) / static_cast<float>(height);
    402     bboxes_data(0, 0, 1) =
    403         static_cast<float>(crop_rect.min_x_) / static_cast<float>(width);
    404     bboxes_data(0, 0, 2) =
    405         static_cast<float>(crop_rect.max_y_) / static_cast<float>(height);
    406     bboxes_data(0, 0, 3) =
    407         static_cast<float>(crop_rect.max_x_) / static_cast<float>(width);
    408 
    409     // Retain all of the channels.
    410     begin_data(2) = T(0);
    411     size_data(2) = T(-1);
    412   }
    413 
    414  private:
    415   GuardedPhiloxRandom generator_;
    416   int32 max_attempts_;
    417   std::vector<float> area_range_;
    418   std::vector<float> aspect_ratio_range_;
    419   float min_object_covered_;
    420   bool use_image_if_no_bounding_boxes_;
    421 };
    422 
    423 #define REGISTER_KERNELS(type)                                  \
    424   REGISTER_KERNEL_BUILDER(Name("SampleDistortedBoundingBox")    \
    425                               .Device(DEVICE_CPU)               \
    426                               .TypeConstraint<type>("T"),       \
    427                           SampleDistortedBoundingBoxV2Op<type>) \
    428   REGISTER_KERNEL_BUILDER(Name("SampleDistortedBoundingBoxV2")  \
    429                               .Device(DEVICE_CPU)               \
    430                               .TypeConstraint<type>("T"),       \
    431                           SampleDistortedBoundingBoxV2Op<type>)
    432 
    433 TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS);
    434 #undef REGISTER_KERNELS
    435 
    436 }  // namespace tensorflow
    437