Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/image_ops.cc
     17 
     18 #define EIGEN_USE_THREADS
     19 
     20 #include "tensorflow/core/kernels/crop_and_resize_op.h"
     21 
     22 #include <functional>
     23 #include <string>
     24 
     25 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     26 #include "tensorflow/core/framework/register_types.h"
     27 #include "tensorflow/core/framework/tensor.h"
     28 #include "tensorflow/core/framework/tensor_shape.h"
     29 #include "tensorflow/core/framework/types.h"
     30 #include "tensorflow/core/kernels/bounds_check.h"
     31 #include "tensorflow/core/lib/core/errors.h"
     32 #include "tensorflow/core/lib/core/status.h"
     33 #include "tensorflow/core/platform/logging.h"
     34 #include "tensorflow/core/platform/types.h"
     35 #include "tensorflow/core/util/work_sharder.h"
     36 
     37 #if GOOGLE_CUDA
     38 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
     39 #include "tensorflow/core/platform/cuda.h"
     40 #include "tensorflow/core/platform/stream_executor.h"
     41 
     42 using ::perftools::gputools::cuda::ScopedActivateExecutorContext;
     43 #endif  // GOOGLE_CUDA
     44 
     45 namespace tensorflow {
     46 
     47 typedef Eigen::ThreadPoolDevice CPUDevice;
     48 typedef Eigen::GpuDevice GPUDevice;
     49 using Callback = std::function<void()>;
     50 
     51 namespace {
     52 
     53 static inline Status ParseAndCheckBoxSizes(const Tensor& boxes,
     54                                            const Tensor& box_index,
     55                                            int* num_boxes) {
     56   if (boxes.NumElements() == 0 && box_index.NumElements() == 0) {
     57     *num_boxes = 0;
     58     return Status::OK();
     59   }
     60   // The shape of 'boxes' is [num_boxes, 4].
     61   if (boxes.dims() != 2) {
     62     return errors::InvalidArgument("boxes must be 2-D",
     63                                    boxes.shape().DebugString());
     64   }
     65   *num_boxes = boxes.dim_size(0);
     66   if (boxes.dim_size(1) != 4) {
     67     return errors::InvalidArgument("boxes must have 4 columns");
     68   }
     69   // The shape of 'box_index' is [num_boxes].
     70   if (box_index.dims() != 1) {
     71     return errors::InvalidArgument("box_index must be 1-D",
     72                                    box_index.shape().DebugString());
     73   }
     74   if (box_index.dim_size(0) != *num_boxes) {
     75     return errors::InvalidArgument("box_index has incompatible shape");
     76   }
     77   return Status::OK();
     78 }
     79 
// Conditionally calls the compute callback if all values in box_index are in
// [0, batch_size) then calls done.  Either callback may be null, in which
// case it is simply skipped (see the CPU specialization below).
template <typename Device>
inline void RunIfBoxIndexIsValid(
    OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
    int batch_size, const Callback& compute, const Callback& done);
     86 
// Specialization of RunIfBoxIndexIsValid for a CPUDevice.  Validation is
// synchronous here: on the first out-of-range index, OP_REQUIRES_ASYNC sets
// the error status, invokes `done`, and returns from this function without
// running `compute`.
template <>
inline void RunIfBoxIndexIsValid<CPUDevice>(
    OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
    int batch_size, const Callback& compute, const Callback& done) {
  const int num_boxes = box_index.dimension(0);
  for (int b = 0; b < num_boxes; ++b) {
    // OP_REQUIRES_ASYNC calls `done` and returns early if the check fails.
    OP_REQUIRES_ASYNC(
        context, FastBoundsCheck(box_index(b), batch_size),
        errors::OutOfRange("box_index has values outside [0, batch_size)"),
        done);
  }
  // Both callbacks are optional.
  if (compute) {
    compute();
  }
  if (done) {
    done();
  }
}
    106 
    107 }  // namespace
    108 
// Crops patches from `image` at the locations given by `boxes` and
// bilinearly resizes each patch to `crop_size`, writing float results to
// output 0 of shape [num_boxes, crop_height, crop_width, depth].  An
// AsyncOpKernel so that box-index validation (RunIfBoxIndexIsValid) can
// invoke `done` itself; on CPU everything runs synchronously.
template <typename Device, typename T>
class CropAndResizeOp : public AsyncOpKernel {
 public:
  explicit CropAndResizeOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {
    // Only the "bilinear" sampling method is supported by this kernel.
    string method;
    OP_REQUIRES_OK(context, context->GetAttr("method", &method));
    OP_REQUIRES(context, method == "bilinear",
                errors::InvalidArgument("method must be 'bilinear'", method));
    // Fill value for output cells whose sample point lies outside the image.
    OP_REQUIRES_OK(context, context->GetAttr("extrapolation_value",
                                             &extrapolation_value_));
  }

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    // The shape of 'image' is [batch_size, image_height, image_width,
    // channels].
    const Tensor& image = context->input(0);
    // The shape of 'boxes' is [num_boxes, 4].
    const Tensor& boxes = context->input(1);
    // The shape of 'box_index' is [num_boxes].
    const Tensor& box_index = context->input(2);
    // The shape of 'crop_size' is [2].
    const Tensor& crop_size = context->input(3);

    // Validate inputs dimensions.
    OP_REQUIRES_ASYNC(context, image.dims() == 4,
                      errors::InvalidArgument("input image must be 4-D",
                                              image.shape().DebugString()),
                      done);
    const int batch_size = image.dim_size(0);
    const int image_height = image.dim_size(1);
    const int image_width = image.dim_size(2);
    const int depth = image.dim_size(3);
    OP_REQUIRES_ASYNC(
        context, image_height > 0 && image_width > 0,
        errors::InvalidArgument("image dimensions must be positive"), done);
    int num_boxes = 0;
    OP_REQUIRES_OK_ASYNC(
        context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);

    OP_REQUIRES_ASYNC(context, crop_size.dims() == 1,
                      errors::InvalidArgument("crop_size must be 1-D",
                                              crop_size.shape().DebugString()),
                      done);
    OP_REQUIRES_ASYNC(
        context, crop_size.dim_size(0) == 2,
        errors::InvalidArgument("crop_size must have two elements",
                                crop_size.shape().DebugString()),
        done);

    // Copy and validate crop sizes.  SubtleMustCopy guards against the
    // (host) crop_size buffer changing under us.
    auto crop_size_vec = crop_size.vec<int32>();
    const int crop_height = internal::SubtleMustCopy(crop_size_vec(0));
    const int crop_width = internal::SubtleMustCopy(crop_size_vec(1));
    OP_REQUIRES_ASYNC(
        context, crop_height > 0 && crop_width > 0,
        errors::InvalidArgument("crop dimensions must be positive"), done);

    // Allocate output tensor.
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(
        context,
        context->allocate_output(
            0, TensorShape({num_boxes, crop_height, crop_width, depth}),
            &output),
        done);

    // The callback re-reads the inputs from `context` because on non-CPU
    // devices it may run asynchronously after ComputeAsync has returned.
    auto compute_callback = [this, context, output]() {
      const Tensor& image = context->input(0);
      const Tensor& boxes = context->input(1);
      const Tensor& box_index = context->input(2);
      const bool status = functor::CropAndResize<Device, T>()(
          context, image.tensor<T, 4>(), boxes.tensor<float, 2>(),
          box_index.tensor<int32, 1>(), extrapolation_value_,
          output->tensor<float, 4>());
      if (!status) {
        context->SetStatus(
            errors::Internal("Failed launch CropAndResizeKernel."));
      }
    };

    // Only runs `compute_callback` (then `done`) if every box index is in
    // [0, batch_size).
    RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
                                 batch_size, std::move(compute_callback),
                                 std::move(done));
  }

 private:
  float extrapolation_value_;  // From the "extrapolation_value" attr.
};
    198 
    199 // Partial specialization of CropAndResize functor for a CPUDevice.
    200 namespace functor {
    201 template <typename T>
    202 struct CropAndResize<CPUDevice, T> {
    203   bool operator()(const OpKernelContext* context,
    204                   typename TTypes<T, 4>::ConstTensor image,
    205                   typename TTypes<float, 2>::ConstTensor boxes,
    206                   typename TTypes<int32, 1>::ConstTensor box_index,
    207                   float extrapolation_value,
    208                   typename TTypes<float, 4>::Tensor crops) {
    209     const int batch_size = image.dimension(0);
    210     const int image_height = image.dimension(1);
    211     const int image_width = image.dimension(2);
    212 
    213     const int num_boxes = crops.dimension(0);
    214     const int crop_height = crops.dimension(1);
    215     const int crop_width = crops.dimension(2);
    216     const int depth = crops.dimension(3);
    217 
    218     // Sharding across boxes.
    219     auto CropAndResizePerBox = [&](int start_box, int limit_box) {
    220       for (int b = start_box; b < limit_box; ++b) {
    221         const float y1 = boxes(b, 0);
    222         const float x1 = boxes(b, 1);
    223         const float y2 = boxes(b, 2);
    224         const float x2 = boxes(b, 3);
    225 
    226         const int32 b_in = box_index(b);
    227         if (!FastBoundsCheck(b_in, batch_size)) {
    228           continue;
    229         }
    230 
    231         const float height_scale =
    232             (crop_height > 1)
    233                 ? (y2 - y1) * (image_height - 1) / (crop_height - 1)
    234                 : 0;
    235         const float width_scale =
    236             (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1)
    237                              : 0;
    238 
    239         for (int y = 0; y < crop_height; ++y) {
    240           const float in_y = (crop_height > 1)
    241                                  ? y1 * (image_height - 1) + y * height_scale
    242                                  : 0.5 * (y1 + y2) * (image_height - 1);
    243           if (in_y < 0 || in_y > image_height - 1) {
    244             for (int x = 0; x < crop_width; ++x) {
    245               for (int d = 0; d < depth; ++d) {
    246                 crops(b, y, x, d) = extrapolation_value;
    247               }
    248             }
    249             continue;
    250           }
    251           const int top_y_index = floorf(in_y);
    252           const int bottom_y_index = ceilf(in_y);
    253           const float y_lerp = in_y - top_y_index;
    254 
    255           for (int x = 0; x < crop_width; ++x) {
    256             const float in_x = (crop_width > 1)
    257                                    ? x1 * (image_width - 1) + x * width_scale
    258                                    : 0.5 * (x1 + x2) * (image_width - 1);
    259             if (in_x < 0 || in_x > image_width - 1) {
    260               for (int d = 0; d < depth; ++d) {
    261                 crops(b, y, x, d) = extrapolation_value;
    262               }
    263               continue;
    264             }
    265             const int left_x_index = floorf(in_x);
    266             const int right_x_index = ceilf(in_x);
    267             const float x_lerp = in_x - left_x_index;
    268 
    269             for (int d = 0; d < depth; ++d) {
    270               const float top_left(static_cast<float>(
    271                   image(b_in, top_y_index, left_x_index, d)));
    272               const float top_right(static_cast<float>(
    273                   image(b_in, top_y_index, right_x_index, d)));
    274               const float bottom_left(static_cast<float>(
    275                   image(b_in, bottom_y_index, left_x_index, d)));
    276               const float bottom_right(static_cast<float>(
    277                   image(b_in, bottom_y_index, right_x_index, d)));
    278               const float top = top_left + (top_right - top_left) * x_lerp;
    279               const float bottom =
    280                   bottom_left + (bottom_right - bottom_left) * x_lerp;
    281               crops(b, y, x, d) = top + (bottom - top) * y_lerp;
    282             }
    283           }
    284         }
    285       }
    286     };
    287 
    288     // A rough estimation of the cost for each cropped box.
    289     const double cost_per_pixel =
    290         depth * (Eigen::TensorOpCost::AddCost<float>() * 6 +
    291                  Eigen::TensorOpCost::MulCost<float>() * 3 +
    292                  Eigen::TensorOpCost::CastCost<T, float>() * 4) +
    293         (Eigen::TensorOpCost::AddCost<float>() * 2 +
    294          Eigen::TensorOpCost::AddCost<float>() * 3);
    295     const double cost_per_box = crop_height * crop_width * cost_per_pixel;
    296 
    297     const DeviceBase::CpuWorkerThreads& worker_threads =
    298         *(context->device()->tensorflow_cpu_worker_threads());
    299     Shard(worker_threads.num_threads, worker_threads.workers, num_boxes,
    300           cost_per_box, CropAndResizePerBox);
    301 
    302     return true;
    303   }
    304 };
    305 
    306 }  // namespace functor
    307 
// Computes the gradient of CropAndResize with respect to the input image:
// scatters `grads` ([num_boxes, crop_height, crop_width, depth]) back into
// an image-shaped output whose shape is given by the `image_size` input.
template <typename Device, typename T>
class CropAndResizeGradImageOp : public AsyncOpKernel {
 public:
  explicit CropAndResizeGradImageOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {
    // Only the "bilinear" sampling method is supported by this kernel.
    string method;
    OP_REQUIRES_OK(context, context->GetAttr("method", &method));
    OP_REQUIRES(context, method == "bilinear",
                errors::InvalidArgument("method must be 'bilinear'", method));
  }

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    // The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
    const Tensor& grads = context->input(0);
    // The shape of 'boxes' is [num_boxes, 4].
    const Tensor& boxes = context->input(1);
    // The shape of 'box_index' is [num_boxes].
    const Tensor& box_index = context->input(2);
    // The shape of 'image_size' is [4].
    const Tensor& image_size = context->input(3);

    // Validate input shapes.
    OP_REQUIRES_ASYNC(context, grads.dims() == 4,
                      errors::InvalidArgument("grads image must be 4-D",
                                              grads.shape().DebugString()),
                      done);
    const int crop_height = grads.dim_size(1);
    const int crop_width = grads.dim_size(2);
    OP_REQUIRES_ASYNC(
        context, crop_height > 0 && crop_width > 0,
        errors::InvalidArgument("grads dimensions must be positive"), done);
    int num_boxes = 0;
    OP_REQUIRES_OK_ASYNC(
        context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);
    OP_REQUIRES_ASYNC(
        context, grads.dim_size(0) == num_boxes,
        errors::InvalidArgument("boxes and grads have incompatible shape"),
        done);

    OP_REQUIRES_ASYNC(context, image_size.dims() == 1,
                      errors::InvalidArgument("image_size must be 1-D",
                                              image_size.shape().DebugString()),
                      done);
    OP_REQUIRES_ASYNC(context, image_size.dim_size(0) == 4,
                      errors::InvalidArgument("image_size must have 4 elements",
                                              image_size.shape().DebugString()),
                      done);
    // image_size holds [batch_size, image_height, image_width, depth];
    // SubtleMustCopy guards against the host buffer changing under us.
    auto image_size_vec = image_size.vec<int32>();
    const int batch_size = internal::SubtleMustCopy(image_size_vec(0));
    const int image_height = internal::SubtleMustCopy(image_size_vec(1));
    const int image_width = internal::SubtleMustCopy(image_size_vec(2));
    const int depth = internal::SubtleMustCopy(image_size_vec(3));
    OP_REQUIRES_ASYNC(
        context, image_height > 0 && image_width > 0,
        errors::InvalidArgument("image dimensions must be positive"), done);
    OP_REQUIRES_ASYNC(
        context, grads.dim_size(3) == depth,
        errors::InvalidArgument("image_size and grads are incompatible"), done);

    // Allocate output tensor.
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(
        context,
        context->allocate_output(
            0, TensorShape({batch_size, image_height, image_width, depth}),
            &output),
        done);

    // The callback re-reads the inputs from `context` because on non-CPU
    // devices it may run asynchronously after ComputeAsync has returned.
    auto compute_callback = [context, output]() {
      const Tensor& grads = context->input(0);
      const Tensor& boxes = context->input(1);
      const Tensor& box_index = context->input(2);
      const bool status = functor::CropAndResizeBackpropImage<Device, T>()(
          context->eigen_device<Device>(), grads.tensor<float, 4>(),
          boxes.tensor<float, 2>(), box_index.tensor<int32, 1>(),
          output->tensor<T, 4>());
      if (!status) {
        context->SetStatus(errors::Internal(
            "Failed launch CropAndResizeBackpropImage kernel."));
      }
    };

    // Only runs `compute_callback` (then `done`) if every box index is in
    // [0, batch_size).
    RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
                                 batch_size, std::move(compute_callback),
                                 std::move(done));
  }
};
    395 
    396 // Partial specialization of CropAndResizeBackpropImage functor for a CPUDevice.
    397 namespace functor {
    398 template <typename T>
    399 struct CropAndResizeBackpropImage<CPUDevice, T> {
    400   bool operator()(const CPUDevice& d,
    401                   typename TTypes<float, 4>::ConstTensor grads,
    402                   typename TTypes<float, 2>::ConstTensor boxes,
    403                   typename TTypes<int32, 1>::ConstTensor box_index,
    404                   typename TTypes<T, 4>::Tensor grads_image) {
    405     const int batch_size = grads_image.dimension(0);
    406     const int image_height = grads_image.dimension(1);
    407     const int image_width = grads_image.dimension(2);
    408 
    409     const int num_boxes = grads.dimension(0);
    410     const int crop_height = grads.dimension(1);
    411     const int crop_width = grads.dimension(2);
    412     const int depth = grads.dimension(3);
    413 
    414     grads_image.setZero();
    415 
    416     for (int b = 0; b < num_boxes; ++b) {
    417       const float y1 = boxes(b, 0);
    418       const float x1 = boxes(b, 1);
    419       const float y2 = boxes(b, 2);
    420       const float x2 = boxes(b, 3);
    421 
    422       const int32 b_in = box_index(b);
    423       if (!FastBoundsCheck(b_in, batch_size)) {
    424         continue;
    425       }
    426 
    427       const float height_scale =
    428           (crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1)
    429                             : 0;
    430       const float width_scale =
    431           (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1)
    432                            : 0;
    433 
    434       for (int y = 0; y < crop_height; ++y) {
    435         const float in_y = (crop_height > 1)
    436                                ? y1 * (image_height - 1) + y * height_scale
    437                                : 0.5 * (y1 + y2) * (image_height - 1);
    438         if (in_y < 0 || in_y > image_height - 1) {
    439           continue;
    440         }
    441         const int top_y_index = floorf(in_y);
    442         const int bottom_y_index = ceilf(in_y);
    443         const float y_lerp = in_y - top_y_index;
    444 
    445         for (int x = 0; x < crop_width; ++x) {
    446           const float in_x = (crop_width > 1)
    447                                  ? x1 * (image_width - 1) + x * width_scale
    448                                  : 0.5 * (x1 + x2) * (image_width - 1);
    449           if (in_x < 0 || in_x > image_width - 1) {
    450             continue;
    451           }
    452           const int left_x_index = floorf(in_x);
    453           const int right_x_index = ceilf(in_x);
    454           const float x_lerp = in_x - left_x_index;
    455 
    456           for (int d = 0; d < depth; ++d) {
    457             const float dtop = (1 - y_lerp) * grads(b, y, x, d);
    458             grads_image(b_in, top_y_index, left_x_index, d) +=
    459                 static_cast<T>((1 - x_lerp) * dtop);
    460             grads_image(b_in, top_y_index, right_x_index, d) +=
    461                 static_cast<T>(x_lerp * dtop);
    462             const float dbottom = y_lerp * grads(b, y, x, d);
    463             grads_image(b_in, bottom_y_index, left_x_index, d) +=
    464                 static_cast<T>((1 - x_lerp) * dbottom);
    465             grads_image(b_in, bottom_y_index, right_x_index, d) +=
    466                 static_cast<T>(x_lerp * dbottom);
    467           }
    468         }
    469       }
    470     }
    471     return true;
    472   }
    473 };
    474 
    475 }  // namespace functor
    476 
// Computes the gradient of CropAndResize with respect to the box
// coordinates, producing output 0 of shape [num_boxes, 4].
template <typename Device, typename T>
class CropAndResizeGradBoxesOp : public AsyncOpKernel {
 public:
  explicit CropAndResizeGradBoxesOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {
    // Only the "bilinear" sampling method is supported by this kernel.
    string method;
    OP_REQUIRES_OK(context, context->GetAttr("method", &method));
    OP_REQUIRES(context, method == "bilinear",
                errors::InvalidArgument("method must be 'bilinear'", method));
  }

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    // Note the input order: grads(0), image(1), boxes(2), box_index(3).
    // The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
    const Tensor& grads = context->input(0);
    // The shape of 'boxes' is [num_boxes, 4].
    const Tensor& boxes = context->input(2);
    // The shape of 'box_index' is [num_boxes].
    const Tensor& box_index = context->input(3);
    // The shape of 'image' is [batch_size, image_height, image_width, depth].
    const Tensor& image = context->input(1);

    // Validate input shapes.
    OP_REQUIRES_ASYNC(context, grads.dims() == 4,
                      errors::InvalidArgument("grads image must be 4-D",
                                              grads.shape().DebugString()),
                      done);
    const int crop_height = grads.dim_size(1);
    const int crop_width = grads.dim_size(2);
    const int depth = grads.dim_size(3);
    OP_REQUIRES_ASYNC(
        context, crop_height > 0 && crop_width > 0,
        errors::InvalidArgument("grads dimensions must be positive"), done);

    OP_REQUIRES_ASYNC(context, image.dims() == 4,
                      errors::InvalidArgument("input image must be 4-D",
                                              image.shape().DebugString()),
                      done);
    const int batch_size = image.dim_size(0);
    const int image_height = image.dim_size(1);
    const int image_width = image.dim_size(2);
    OP_REQUIRES_ASYNC(
        context, image_height > 0 && image_width > 0,
        errors::InvalidArgument("image dimensions must be positive"), done);
    OP_REQUIRES_ASYNC(context, image.dim_size(3) == depth,
                      errors::InvalidArgument("image, grads depth differ"),
                      done);

    int num_boxes = 0;
    OP_REQUIRES_OK_ASYNC(
        context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);

    OP_REQUIRES_ASYNC(
        context, grads.dim_size(0) == num_boxes,
        errors::InvalidArgument("boxes and grads have incompatible shape"),
        done);

    // Allocate output tensor: one [y1, x1, y2, x2] gradient row per box.
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(
        context,
        context->allocate_output(0, TensorShape({num_boxes, 4}), &output),
        done);

    // The callback re-reads the inputs from `context` because on non-CPU
    // devices it may run asynchronously after ComputeAsync has returned.
    auto compute_callback = [context, output]() {
      const Tensor& grads = context->input(0);
      const Tensor& image = context->input(1);
      const Tensor& boxes = context->input(2);
      const Tensor& box_index = context->input(3);
      const bool status = functor::CropAndResizeBackpropBoxes<Device, T>()(
          context->eigen_device<Device>(), grads.tensor<float, 4>(),
          image.tensor<T, 4>(), boxes.tensor<float, 2>(),
          box_index.tensor<int32, 1>(), output->tensor<float, 2>());
      if (!status) {
        context->SetStatus(errors::Internal(
            "Failed launch CropAndResizeBackpropBoxes kernel."));
      }
    };

    // Only runs `compute_callback` (then `done`) if every box index is in
    // [0, batch_size).
    RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
                                 batch_size, std::move(compute_callback),
                                 std::move(done));
  }
};
    560 
    561 // Partial specialization of CropAndResizeBackpropBoxes functor for a CPUDevice.
    562 namespace functor {
    563 template <typename T>
    564 struct CropAndResizeBackpropBoxes<CPUDevice, T> {
    565   bool operator()(const CPUDevice& d,
    566                   typename TTypes<float, 4>::ConstTensor grads,
    567                   typename TTypes<T, 4>::ConstTensor image,
    568                   typename TTypes<float, 2>::ConstTensor boxes,
    569                   typename TTypes<int32, 1>::ConstTensor box_index,
    570                   typename TTypes<float, 2>::Tensor grads_boxes) {
    571     const int batch_size = image.dimension(0);
    572     const int image_height = image.dimension(1);
    573     const int image_width = image.dimension(2);
    574 
    575     const int num_boxes = grads.dimension(0);
    576     const int crop_height = grads.dimension(1);
    577     const int crop_width = grads.dimension(2);
    578     const int depth = grads.dimension(3);
    579 
    580     grads_boxes.setZero();
    581 
    582     for (int b = 0; b < num_boxes; ++b) {
    583       const float y1 = boxes(b, 0);
    584       const float x1 = boxes(b, 1);
    585       const float y2 = boxes(b, 2);
    586       const float x2 = boxes(b, 3);
    587 
    588       const int32 b_in = box_index(b);
    589       if (!FastBoundsCheck(b_in, batch_size)) {
    590         continue;
    591       }
    592 
    593       const float height_ratio =
    594           (crop_height > 1)
    595               ? static_cast<float>(image_height - 1) / (crop_height - 1)
    596               : 0;
    597       const float width_ratio =
    598           (crop_width > 1)
    599               ? static_cast<float>(image_width - 1) / (crop_width - 1)
    600               : 0;
    601 
    602       const float height_scale =
    603           (crop_height > 1) ? (y2 - y1) * height_ratio : 0;
    604       const float width_scale = (crop_width > 1) ? (x2 - x1) * width_ratio : 0;
    605 
    606       for (int y = 0; y < crop_height; ++y) {
    607         const float in_y = (crop_height > 1)
    608                                ? y1 * (image_height - 1) + y * height_scale
    609                                : 0.5 * (y1 + y2) * (image_height - 1);
    610         if (in_y < 0 || in_y > image_height - 1) {
    611           continue;
    612         }
    613         const int top_y_index = floorf(in_y);
    614         const int bottom_y_index = ceilf(in_y);
    615         const float y_lerp = in_y - top_y_index;
    616 
    617         for (int x = 0; x < crop_width; ++x) {
    618           const float in_x = (crop_width > 1)
    619                                  ? x1 * (image_width - 1) + x * width_scale
    620                                  : 0.5 * (x1 + x2) * (image_width - 1);
    621           if (in_x < 0 || in_x > image_width - 1) {
    622             continue;
    623           }
    624           const int left_x_index = floorf(in_x);
    625           const int right_x_index = ceilf(in_x);
    626           const float x_lerp = in_x - left_x_index;
    627 
    628           for (int d = 0; d < depth; ++d) {
    629             const float top_left(
    630                 static_cast<float>(image(b_in, top_y_index, left_x_index, d)));
    631             const float top_right(
    632                 static_cast<float>(image(b_in, top_y_index, right_x_index, d)));
    633             const float bottom_left(static_cast<float>(
    634                 image(b_in, bottom_y_index, left_x_index, d)));
    635             const float bottom_right(static_cast<float>(
    636                 image(b_in, bottom_y_index, right_x_index, d)));
    637             // Compute the image gradient.
    638             float image_grad_y = (1 - x_lerp) * (bottom_left - top_left) +
    639                                  x_lerp * (bottom_right - top_right);
    640             float image_grad_x = (1 - y_lerp) * (top_right - top_left) +
    641                                  y_lerp * (bottom_right - bottom_left);
    642             // Modulate the image gradient with the incoming gradient.
    643             const float top_grad = grads(b, y, x, d);
    644             image_grad_y *= top_grad;
    645             image_grad_x *= top_grad;
    646             // dy1, dy2
    647             if (crop_height > 1) {
    648               grads_boxes(b, 0) +=
    649                   image_grad_y * (image_height - 1 - y * height_ratio);
    650               grads_boxes(b, 2) += image_grad_y * (y * height_ratio);
    651             } else {
    652               grads_boxes(b, 0) += image_grad_y * 0.5 * (image_height - 1);
    653               grads_boxes(b, 2) += image_grad_y * 0.5 * (image_height - 1);
    654             }
    655             // dx1, dx2
    656             if (crop_width > 1) {
    657               grads_boxes(b, 1) +=
    658                   image_grad_x * (image_width - 1 - x * width_ratio);
    659               grads_boxes(b, 3) += image_grad_x * (x * width_ratio);
    660             } else {
    661               grads_boxes(b, 1) += image_grad_x * 0.5 * (image_width - 1);
    662               grads_boxes(b, 3) += image_grad_x * 0.5 * (image_width - 1);
    663             }
    664           }
    665         }
    666       }
    667     }
    668     return true;
    669   }
    670 };
    671 
    672 }  // namespace functor
    673 
// CPU kernel registrations for the forward op (CropAndResize) and the
// gradient w.r.t. the boxes input (CropAndResizeGradBoxes).
// "crop_size" is pinned to host memory since its contents are consumed on
// the host (presumably to compute the output shape — confirm in the op's
// Compute method, which is outside this chunk).
#define REGISTER_KERNEL(T)                                \
  REGISTER_KERNEL_BUILDER(Name("CropAndResize")           \
                              .Device(DEVICE_CPU)         \
                              .TypeConstraint<T>("T")     \
                              .HostMemory("crop_size"),   \
                          CropAndResizeOp<CPUDevice, T>); \
                                                          \
  REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes")  \
                              .Device(DEVICE_CPU)         \
                              .TypeConstraint<T>("T"),    \
                          CropAndResizeGradBoxesOp<CPUDevice, T>);

// Instantiate the registrations above for every real numeric type.
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);

#undef REGISTER_KERNEL
    689 
// CPU kernel registration for the gradient w.r.t. the image input
// (CropAndResizeGradImage). "image_size" lives in host memory since its
// contents are consumed on the host.
#define REGISTER_KERNEL(T)                               \
  REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradImage") \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<T>("T")    \
                              .HostMemory("image_size"), \
                          CropAndResizeGradImageOp<CPUDevice, T>);

// Unlike the forward/grad-boxes kernels, the image gradient is only
// registered for floating-point types (half, float, double).
TF_CALL_half(REGISTER_KERNEL);
TF_CALL_float(REGISTER_KERNEL);
TF_CALL_double(REGISTER_KERNEL);

#undef REGISTER_KERNEL
    702 
    703 #if GOOGLE_CUDA
    704 
// Forward declaration of the CheckValidBoxIndexHelper specialization for GPU.
// The definition is compiled in a separate CUDA translation unit (presumably
// crop_and_resize_op_gpu.cu.cc — confirm against the build target); the
// extern template declaration below suppresses implicit instantiation here.
namespace functor {
template <>
void CheckValidBoxIndexHelper<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<int32, 1>::ConstTensor box_index,
    int batch_size, typename TTypes<bool, 0>::Tensor isvalid);
extern template struct CheckValidBoxIndexHelper<GPUDevice>;
}  // namespace functor
    713 
namespace {

// Specialization of RunIfBoxIndexIsValid for a GPUDevice: validates on the
// device that every entry of `box_index` lies in [0, batch_size), then runs
// `compute` followed by `done` asynchronously once the result has been
// copied back to the host. On any failure (allocation, missing stream,
// memcpy launch, or invalid indices) only `done` is invoked, via the
// OP_REQUIRES*_ASYNC macros.
template <>
inline void RunIfBoxIndexIsValid<GPUDevice>(
    OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
    int batch_size, const Callback& compute, const Callback& done) {
  const int num_boxes = box_index.dimension(0);
  // With no boxes there is nothing to validate; run the continuation
  // synchronously and return.
  if (num_boxes == 0) {
    compute();
    done();
    return;
  }

  // Scalar bool on the device that the validation kernel will write into.
  Tensor isvalid_dev_tensor;
  OP_REQUIRES_OK_ASYNC(
      context,
      context->allocate_temp(DataTypeToEnum<bool>::value, TensorShape({}),
                             &isvalid_dev_tensor),
      done);
  typename TTypes<bool, 0>::Tensor isvalid_dev =
      isvalid_dev_tensor.tensor<bool, 0>();

  // Run the actual box check on the device.
  functor::CheckValidBoxIndexHelper<GPUDevice>()(
      context->eigen_device<GPUDevice>(), box_index, batch_size, isvalid_dev);

  // Copy the result back to the host.
  auto* stream = context->op_device_context()->stream();
  OP_REQUIRES_ASYNC(context, stream,
                    errors::Internal("No GPU stream available."), done);
  Tensor isvalid_host_tensor;
  // Use pinned host memory on the host to avoid unnecessary
  // synchronization.
  AllocatorAttributes alloc_attr;
  alloc_attr.set_on_host(true);
  alloc_attr.set_gpu_compatible(true);
  OP_REQUIRES_OK_ASYNC(
      context,
      context->allocate_temp(DataTypeToEnum<bool>::value, TensorShape({}),
                             &isvalid_host_tensor, alloc_attr),
      done);
  // Enqueue an asynchronous device-to-host copy of the single bool on the
  // same stream the check kernel ran on.
  perftools::gputools::DeviceMemoryBase wrapped(isvalid_dev.data(),
                                                sizeof(bool));
  const bool status =
      stream
          ->ThenMemcpy(
              isvalid_host_tensor.scalar<bool>().data() /* destination */,
              wrapped /* source */, sizeof(bool))
          .ok();
  OP_REQUIRES_ASYNC(
      context, status,
      errors::Internal("Failed to launch copy of isvalid from device to host."),
      done);

  // We capture both temporary tensors to prevent them from being deallocated
  // when ComputeAsync returns and before the closure runs.
  // The device tensor needs an explicit TensorReference (unref'd inside the
  // callback); the host tensor is kept alive by the by-value Tensor capture
  // in the lambda below.
  TensorReference isvalid_dev_ref(isvalid_dev_tensor);
  auto wrapped_callback = [context, isvalid_host_tensor, isvalid_dev_ref,
                           compute, done]() {
    // Re-activate the op's CUDA context: `compute` may launch more kernels.
    auto stream = context->op_device_context()->stream();
    ScopedActivateExecutorContext scoped_activation{stream->parent()};
    const bool isvalid = isvalid_host_tensor.scalar<bool>()();
    // The device-side scalar has been consumed; release our reference.
    isvalid_dev_ref.Unref();
    OP_REQUIRES_ASYNC(
        context, isvalid,
        errors::OutOfRange("box_index has values outside [0, batch_size)"),
        done);
    compute();
    done();
  };

  // Defer the callback until the stream's pending work (the check kernel and
  // the memcpy above) has completed.
  context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
      stream, wrapped_callback);
}

}  // namespace
    791 
// GPU kernel registrations for all three ops. The shape-carrying inputs
// ("crop_size", "image_size") stay in host memory even on GPU, since their
// contents are consumed on the host.
#define REGISTER_KERNEL(T)                                         \
  REGISTER_KERNEL_BUILDER(Name("CropAndResize")                    \
                              .Device(DEVICE_GPU)                  \
                              .TypeConstraint<T>("T")              \
                              .HostMemory("crop_size"),            \
                          CropAndResizeOp<GPUDevice, T>);          \
                                                                   \
  REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradImage")           \
                              .Device(DEVICE_GPU)                  \
                              .TypeConstraint<T>("T")              \
                              .HostMemory("image_size"),           \
                          CropAndResizeGradImageOp<GPUDevice, T>); \
                                                                   \
  REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes")           \
                              .Device(DEVICE_GPU)                  \
                              .TypeConstraint<T>("T"),             \
                          CropAndResizeGradBoxesOp<GPUDevice, T>);

// Instantiate the registrations above for the GPU-supported numeric types.
TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNEL);

#undef REGISTER_KERNEL
    813 
    814 #endif  // GOOGLE_CUDA
    815 
    816 }  // namespace tensorflow
    817