1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 // See docs in ../ops/image_ops.cc. 16 #include <math.h> 17 #include "tensorflow/core/framework/op_kernel.h" 18 #include "tensorflow/core/framework/register_types.h" 19 #include "tensorflow/core/framework/tensor.h" 20 #include "tensorflow/core/framework/types.h" 21 #include "tensorflow/core/kernels/bounds_check.h" 22 #include "tensorflow/core/lib/random/simple_philox.h" 23 #include "tensorflow/core/util/guarded_philox_random.h" 24 25 using tensorflow::random::SimplePhilox; 26 27 namespace tensorflow { 28 namespace { 29 30 // A simple Rectangle class that supplies intersection. 
// Corners are stored as (min_x_, min_y_) and (max_x_, max_y_); the members are
// public because the crop generator and the kernel read and write them
// directly.
class Rectangle {
 public:
  Rectangle() { Set(0, 0, 0, 0); }
  Rectangle(int xmin, int ymin, int xmax, int ymax) {
    Set(xmin, ymin, xmax, ymax);
  }

  void Set(int xmin, int ymin, int xmax, int ymax) {
    min_x_ = xmin;
    min_y_ = ymin;
    max_x_ = xmax;
    max_y_ = ymax;
  }

  // True when either axis has a negative extent. A rectangle with zero width
  // or height is NOT reported as empty by this predicate.
  bool IsEmpty() const { return min_x_ > max_x_ || min_y_ > max_y_; }

  // Returns (max_x_ - min_x_) * (max_y_ - min_y_) as a float.
  // The product is formed in 64 bits: image dimensions may approach the int32
  // limit, and a 32-bit product of two such extents overflows (undefined
  // behavior). Assumes non-negative coordinates, as produced in this file.
  float Area() const {
    const int64_t extent_x = static_cast<int64_t>(max_x_) - min_x_;
    const int64_t extent_y = static_cast<int64_t>(max_y_) - min_y_;
    return static_cast<float>(extent_x * extent_y);
  }

  // Returns the overlap of *this and r, or a default (all-zero) rectangle
  // when they do not overlap. Rectangles that merely touch yield a
  // zero-area result.
  Rectangle Intersect(const Rectangle& r) const {
    const int pmin_x = std::max(min_x_, r.min_x_);
    const int pmin_y = std::max(min_y_, r.min_y_);
    const int pmax_x = std::min(max_x_, r.max_x_);
    const int pmax_y = std::min(max_y_, r.max_y_);

    if (pmin_x > pmax_x || pmin_y > pmax_y) {
      return Rectangle();
    } else {
      return Rectangle(pmin_x, pmin_y, pmax_x, pmax_y);
    }
  }

  int min_x_;
  int min_y_;
  int max_x_;
  int max_y_;
};

// Determine if the supplied cropping box covers a sufficient fraction of
// at least one of the supplied bounding boxes.
//
// Returns false when the crop itself has an area below one pixel. Bounding
// boxes with an area below one pixel are ignored. Otherwise returns true iff
// area(crop ∩ bbox) / area(bbox) >= minimum_object_covered for some bbox.
bool SatisfiesOverlapConstraints(const Rectangle& crop,
                                 float minimum_object_covered,
                                 const std::vector<Rectangle>& bounding_boxes) {
  // Reject any crop which contains no pixels.
  const float kMinArea = 1.0;
  if (crop.Area() < kMinArea) {
    return false;
  }

  // Accept the crop as soon as it covers enough of any non-degenerate box.
  return std::any_of(
      bounding_boxes.begin(), bounding_boxes.end(),
      [&crop, kMinArea, minimum_object_covered](const Rectangle& bbox) {
        const float object_area = bbox.Area();
        if (object_area < kMinArea) {
          return false;
        }
        const float object_covered = crop.Intersect(bbox).Area() / object_area;
        return object_covered >= minimum_object_covered;
      });
}

// Generate a random crop within the rectangle
// (0, 0, original_width, original_height).
101 // The minimum area of the crop will be between 102 // min_relative_crop_area * orig_width * orig_height 103 // and 104 // max_relative_crop_area * orig_width * orig_height 105 // such that its width = round(aspect_ratio * height). 106 // The diameter of the generated rectangle will be uniformly distributed between 107 // its minimum and maximum size. The center of the rectangle will be distributed 108 // uniformly within the source rectangle. The function returns false if the 109 // rectangle could not be generated with the given constraints. 110 bool GenerateRandomCrop(int original_width, int original_height, 111 float min_relative_crop_area, 112 float max_relative_crop_area, float aspect_ratio, 113 SimplePhilox* random, Rectangle* crop_rect) { 114 if (max_relative_crop_area <= 0.0 || aspect_ratio <= 0.0 || 115 original_width <= 0 || original_height <= 0 || 116 min_relative_crop_area > max_relative_crop_area) { 117 return false; 118 } 119 120 const float min_area = 121 min_relative_crop_area * original_width * original_height; 122 const float max_area = 123 max_relative_crop_area * original_width * original_height; 124 125 int height = static_cast<int>(lrintf(sqrt(min_area / aspect_ratio))); 126 int max_height = static_cast<int>(lrintf(sqrt(max_area / aspect_ratio))); 127 128 if (lrintf(max_height * aspect_ratio) > original_width) { 129 // We must find the smallest max_height satisfying 130 // round(max_height * aspect_ratio) <= original_width: 131 const float kEps = 0.0000001; 132 max_height = static_cast<int>((original_width + 0.5 - kEps) / aspect_ratio); 133 } 134 135 if (max_height > original_height) { 136 max_height = original_height; 137 } 138 139 if (height >= max_height) { 140 height = max_height; 141 } 142 143 if (height < max_height) { 144 // We need to generate a random number in the closed range 145 // [0, max_height - height]. 
146 height += random->Uniform(max_height - height + 1); 147 } 148 int width = static_cast<int>(lrintf(height * aspect_ratio)); 149 DCHECK_LE(width, original_width); 150 151 // Let us not fail if rounding error causes the area to be 152 // outside the constraints. 153 // Try first with a slightly bigger rectangle first. 154 float area = static_cast<float>(width * height); 155 if (area < min_area) { 156 height += 1; 157 width = static_cast<int>(lrintf(height * aspect_ratio)); 158 area = width * height; 159 } 160 161 // Let us not fail if rounding error causes the area to be 162 // outside the constraints. 163 // Try first with a slightly smaller rectangle first. 164 if (area > max_area) { 165 height -= 1; 166 width = static_cast<int>(lrintf(height * aspect_ratio)); 167 area = width * height; 168 } 169 170 // Now, we explored all options to rectify small rounding errors. 171 // It seems the constraints can't be satisfied: return false. 172 if (area < min_area || area > max_area || width > original_width || 173 height > original_height || width <= 0 || height <= 0) { 174 return false; 175 } 176 177 int y = 0; 178 if (height < original_height) { 179 y = random->Uniform(original_height - height); 180 } 181 int x = 0; 182 if (width < original_width) { 183 x = random->Uniform(original_width - width); 184 } 185 186 crop_rect->min_x_ = x; 187 crop_rect->min_y_ = y; 188 crop_rect->max_x_ = x + width; 189 crop_rect->max_y_ = y + height; 190 return true; 191 } 192 } // namespace 193 194 template <typename T> 195 class SampleDistortedBoundingBoxV2Op : public OpKernel { 196 public: 197 explicit SampleDistortedBoundingBoxV2Op(OpKernelConstruction* context) 198 : OpKernel(context) { 199 OP_REQUIRES_OK(context, generator_.Init(context)); 200 201 if (context->num_inputs() == 2) { 202 OP_REQUIRES_OK(context, context->GetAttr("min_object_covered", 203 &min_object_covered_)); 204 OP_REQUIRES( 205 context, min_object_covered_ >= 0, 206 errors::InvalidArgument("Min object covered must be 
non-negative: ", 207 min_object_covered_)); 208 } 209 210 OP_REQUIRES_OK(context, context->GetAttr("use_image_if_no_bounding_boxes", 211 &use_image_if_no_bounding_boxes_)); 212 213 OP_REQUIRES_OK( 214 context, context->GetAttr("aspect_ratio_range", &aspect_ratio_range_)); 215 OP_REQUIRES(context, aspect_ratio_range_.size() == 2, 216 errors::InvalidArgument( 217 "Aspect ratio range field must specify 2 dimensions")); 218 219 OP_REQUIRES( 220 context, aspect_ratio_range_[0] > 0 && aspect_ratio_range_[1] > 0, 221 errors::InvalidArgument("Aspect ratio range must be non-negative: [", 222 aspect_ratio_range_[0], ", ", 223 aspect_ratio_range_[1], "]")); 224 225 OP_REQUIRES_OK(context, context->GetAttr("area_range", &area_range_)); 226 OP_REQUIRES( 227 context, area_range_.size() == 2, 228 errors::InvalidArgument("Area range field must specify 2 dimensions")); 229 230 OP_REQUIRES( 231 context, area_range_[0] > 0 && area_range_[1] > 0, 232 errors::InvalidArgument("Area range must be non-negative: [", 233 area_range_[0], ", ", area_range_[1], "]")); 234 235 OP_REQUIRES(context, area_range_[0] <= 1 && area_range_[1] <= 1, 236 errors::InvalidArgument( 237 "Area range must be less then or equal to 1.0: [", 238 area_range_[0], ", ", area_range_[1], "]")); 239 240 OP_REQUIRES_OK(context, context->GetAttr("max_attempts", &max_attempts_)); 241 OP_REQUIRES(context, max_attempts_ > 0, 242 errors::InvalidArgument("Max attempts must be non-negative: ", 243 max_attempts_)); 244 } 245 246 void Compute(OpKernelContext* context) override { 247 const Tensor& image_size = context->input(0); 248 249 OP_REQUIRES(context, image_size.dims() == 1, 250 errors::InvalidArgument("image_size must be 1-dimensional", 251 image_size.shape().DebugString())); 252 OP_REQUIRES(context, image_size.dim_size(0) == 3, 253 errors::InvalidArgument("image_size must contain 3 elements", 254 image_size.shape().DebugString())); 255 256 // Note image_size_data(2) is the depth and unused. 
257 const uint64 height_raw = internal::SubtleMustCopy(image_size.flat<T>()(0)); 258 const uint64 width_raw = internal::SubtleMustCopy(image_size.flat<T>()(1)); 259 OP_REQUIRES(context, 260 FastBoundsCheck(height_raw, std::numeric_limits<int32>::max()), 261 errors::InvalidArgument("image height cannot be >= int32 max")); 262 OP_REQUIRES(context, 263 FastBoundsCheck(width_raw, std::numeric_limits<int32>::max()), 264 errors::InvalidArgument("image width cannot be >= int32 max")); 265 const int32 height = static_cast<int32>(height_raw); 266 const int32 width = static_cast<int32>(width_raw); 267 268 // Ensure that the supplied bounding boxes are sane and convert them to 269 // Rectangles. 270 const Tensor& input_boxes = context->input(1); 271 OP_REQUIRES(context, input_boxes.dims() == 3, 272 errors::InvalidArgument("input boxes must be 3-dimensional " 273 "[batch, num_boxes, coords]: ", 274 input_boxes.shape().DebugString())); 275 OP_REQUIRES(context, input_boxes.dim_size(input_boxes.dims() - 1) == 4, 276 errors::InvalidArgument( 277 "bounding boxes must have shape [4] or [*, 4], got ", 278 input_boxes.shape().DebugString())); 279 280 float min_object_covered_val = 0.0; 281 if (context->num_inputs() == 3) { 282 const Tensor& min_object_covered = context->input(2); 283 284 OP_REQUIRES( 285 context, TensorShapeUtils::IsScalar(min_object_covered.shape()), 286 errors::InvalidArgument("min_object_covered must be 0-D, got shape ", 287 min_object_covered.shape().DebugString())); 288 289 min_object_covered_val = min_object_covered.scalar<float>()(); 290 291 OP_REQUIRES( 292 context, min_object_covered_val >= 0, 293 errors::InvalidArgument("Min object covered must be non-negative: ", 294 min_object_covered_val)); 295 } else { 296 min_object_covered_val = min_object_covered_; 297 } 298 299 std::vector<Rectangle> bounding_boxes; 300 if (input_boxes.NumElements() > 0) { 301 TTypes<float>::ConstMatrix boxes = input_boxes.flat_inner_dims<float>(); 302 for (int b = 0; b < 
boxes.dimension(0); ++b) { 303 for (int i = 0; i < 4; ++i) { 304 OP_REQUIRES( 305 context, boxes(b, i) >= 0.0 && boxes(b, i) <= 1.0, 306 errors::InvalidArgument("All bounding box coordinates must " 307 "be in [0.0, 1.0]: ", 308 boxes(b, i))); 309 } 310 311 const int32 x_min = static_cast<int32>(boxes(b, 1) * width); 312 const int32 y_min = static_cast<int32>(boxes(b, 0) * height); 313 const int32 x_max = static_cast<int32>(boxes(b, 3) * width); 314 const int32 y_max = static_cast<int32>(boxes(b, 2) * height); 315 316 bounding_boxes.push_back(Rectangle(x_min, y_min, x_max, y_max)); 317 } 318 } 319 320 // Insert the entire image if no bounding boxes are supplied. 321 const Rectangle image_rect(0, 0, width, height); 322 if (bounding_boxes.empty()) { 323 OP_REQUIRES(context, use_image_if_no_bounding_boxes_, 324 errors::InvalidArgument( 325 "No bounding boxes provided as input. One must " 326 "enable use_image_if_no_bounding_boxes if you wish " 327 "to not provide any bounding boxes.")); 328 bounding_boxes.push_back(image_rect); 329 } 330 331 const float min_sample_area = area_range_[0]; 332 const float max_sample_area = area_range_[1]; 333 const float min_sample_aspect_ratio = aspect_ratio_range_[0]; 334 const float max_sample_aspect_ratio = aspect_ratio_range_[1]; 335 336 auto local_gen = generator_.ReserveSamples32(4 * max_attempts_); 337 random::SimplePhilox random(&local_gen); 338 339 Rectangle crop_rect; 340 bool sample_generated = false; 341 for (int i = 0; i < max_attempts_; ++i) { 342 const float sample_aspect_ratio = 343 random.RandFloat() * 344 (max_sample_aspect_ratio - min_sample_aspect_ratio) + 345 min_sample_aspect_ratio; 346 347 if (GenerateRandomCrop(width, height, min_sample_area, max_sample_area, 348 sample_aspect_ratio, &random, &crop_rect)) { 349 if (SatisfiesOverlapConstraints(crop_rect, min_object_covered_val, 350 bounding_boxes)) { 351 sample_generated = true; 352 break; 353 } 354 } 355 } 356 357 if (!sample_generated) { 358 crop_rect = 
image_rect; 359 } 360 361 // Determine the cropping parameters from the bounding box. 362 const int target_width = crop_rect.max_x_ - crop_rect.min_x_; 363 const int target_height = crop_rect.max_y_ - crop_rect.min_y_; 364 365 const int offset_width = crop_rect.min_x_; 366 const int offset_height = crop_rect.min_y_; 367 368 // Ensure that the bounding box fits in the image dimensions. 369 OP_REQUIRES(context, width >= target_width + offset_width, 370 errors::FailedPrecondition( 371 "width must be > target_width + offset_width: ", width, 372 "vs ", target_width, " + ", offset_width)); 373 OP_REQUIRES(context, height >= target_height + offset_height, 374 errors::FailedPrecondition( 375 "height must be >= target_height: height = ", height, "vs ", 376 target_height, " + ", offset_height)); 377 378 // Create two vectors, each 3 elements, to provide as arguments to Slice. 379 // See Slice() operation for details. 380 Tensor* begin = nullptr; 381 OP_REQUIRES_OK(context, 382 context->allocate_output(0, TensorShape({3}), &begin)); 383 Tensor* size = nullptr; 384 OP_REQUIRES_OK(context, 385 context->allocate_output(1, TensorShape({3}), &size)); 386 Tensor* bboxes = nullptr; 387 OP_REQUIRES_OK( 388 context, context->allocate_output(2, TensorShape({1, 1, 4}), &bboxes)); 389 390 typename TTypes<T, 1>::Tensor begin_data(begin->tensor<T, 1>()); 391 typename TTypes<T, 1>::Tensor size_data(size->tensor<T, 1>()); 392 TTypes<float, 3>::Tensor bboxes_data = bboxes->tensor<float, 3>(); 393 394 begin_data(0) = T(offset_height); 395 size_data(0) = T(target_height); 396 397 begin_data(1) = T(offset_width); 398 size_data(1) = T(target_width); 399 400 bboxes_data(0, 0, 0) = 401 static_cast<float>(crop_rect.min_y_) / static_cast<float>(height); 402 bboxes_data(0, 0, 1) = 403 static_cast<float>(crop_rect.min_x_) / static_cast<float>(width); 404 bboxes_data(0, 0, 2) = 405 static_cast<float>(crop_rect.max_y_) / static_cast<float>(height); 406 bboxes_data(0, 0, 3) = 407 
static_cast<float>(crop_rect.max_x_) / static_cast<float>(width); 408 409 // Retain all of the channels. 410 begin_data(2) = T(0); 411 size_data(2) = T(-1); 412 } 413 414 private: 415 GuardedPhiloxRandom generator_; 416 int32 max_attempts_; 417 std::vector<float> area_range_; 418 std::vector<float> aspect_ratio_range_; 419 float min_object_covered_; 420 bool use_image_if_no_bounding_boxes_; 421 }; 422 423 #define REGISTER_KERNELS(type) \ 424 REGISTER_KERNEL_BUILDER(Name("SampleDistortedBoundingBox") \ 425 .Device(DEVICE_CPU) \ 426 .TypeConstraint<type>("T"), \ 427 SampleDistortedBoundingBoxV2Op<type>) \ 428 REGISTER_KERNEL_BUILDER(Name("SampleDistortedBoundingBoxV2") \ 429 .Device(DEVICE_CPU) \ 430 .TypeConstraint<type>("T"), \ 431 SampleDistortedBoundingBoxV2Op<type>) 432 433 TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS); 434 #undef REGISTER_KERNELS 435 436 } // namespace tensorflow 437