/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/image_ops.cc
#define EIGEN_USE_THREADS

#include <algorithm>
#include <memory>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/image_resizer_state.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

namespace {
// Precomputed per-output-column interpolation data: which source columns
// contribute to a given output column, and with what weights.
struct CachedInterpolation {
  int64 start;                // First contributing source column (inclusive).
  int64 end;                  // One past the last contributing source column.
  float start_scale;          // Weight of column <start>.
  float end_minus_one_scale;  // Weight of column <end - 1>.
  bool needs_bounding;        // True if start or end - 1 falls outside
                              // [0, in_width) and must be clamped.
};
}  // namespace

template <typename Device, typename T>
class ResizeAreaOp : public OpKernel {
 public:
  explicit ResizeAreaOp(OpKernelConstruction* context) : OpKernel(context) {
    // With align_corners=true, the resize scales are chosen so that the
    // corner pixels of the input and output images line up.
    OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_));
  }

  // Computes the sum of all x values defined by <x_interp> taken across
  // the y offsets and scales defined by y_ptrs and y_scales, for all three
  // channels of an RGB-like image.
  //
  // Note that <NeedsXBounding> is a template parameter to avoid a performance
  // penalty from dynamically checking it.
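  //
  // In effect, for each channel c in {0, 1, 2} this computes
  //
  //   output_ptr[c] = scale * sum_i y_scales[i] *
  //       sum_{x = x_interp.start}^{x_interp.end - 1} w(x) * y_ptrs[i][3*x + c]
  //
  // where w(x_interp.start) = x_interp.start_scale, w(x_interp.end - 1) =
  // x_interp.end_minus_one_scale, and w(x) = 1 for interior columns. (When
  // the window is a single column, only start_scale is applied.) x is
  // clamped to [0, in_width) when NeedsXBounding is true.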
  template <bool NeedsXBounding>
  static void ComputePatchSumOf3Channels(float scale,
                                         const ImageResizerState& st,
                                         const std::vector<const T*>& y_ptrs,
                                         const std::vector<float>& y_scales,
                                         const CachedInterpolation& x_interp,
                                         float* output_ptr) {
#define BOUND_IF_NEEDED(x, y) (NeedsXBounding ? Bound(x, y) : (x))

    float sum_0 = 0;
    float sum_1 = 0;
    float sum_2 = 0;
    for (int i = 0; i < y_ptrs.size(); ++i) {
      const T* ptr = y_ptrs[i];
      float scale_x = x_interp.start_scale;
      int64 offset = 3 * BOUND_IF_NEEDED(x_interp.start, st.in_width);
      float sum_y_0 = static_cast<float>(ptr[offset + 0]) * scale_x;
      float sum_y_1 = static_cast<float>(ptr[offset + 1]) * scale_x;
      float sum_y_2 = static_cast<float>(ptr[offset + 2]) * scale_x;

      if (x_interp.start + 1 != x_interp.end) {
        for (int64 x = x_interp.start + 1; x < x_interp.end - 1; ++x) {
          int64 offset = 3 * BOUND_IF_NEEDED(x, st.in_width);
          sum_y_0 += static_cast<float>(ptr[offset + 0]);
          sum_y_1 += static_cast<float>(ptr[offset + 1]);
          sum_y_2 += static_cast<float>(ptr[offset + 2]);
        }
        scale_x = x_interp.end_minus_one_scale;
        offset = 3 * BOUND_IF_NEEDED(x_interp.end - 1, st.in_width);
        sum_y_0 += static_cast<float>(ptr[offset + 0]) * scale_x;
        sum_y_1 += static_cast<float>(ptr[offset + 1]) * scale_x;
        sum_y_2 += static_cast<float>(ptr[offset + 2]) * scale_x;
      }
      sum_0 += sum_y_0 * y_scales[i];
      sum_1 += sum_y_1 * y_scales[i];
      sum_2 += sum_y_2 * y_scales[i];
    }

    output_ptr[0] = sum_0 * scale;
    output_ptr[1] = sum_1 * scale;
    output_ptr[2] = sum_2 * scale;

#undef BOUND_IF_NEEDED
  }

  // Computes the sum of all x values defined by <x_interp> taken across
  // the y offsets and scales defined by y_ptrs and y_scales, for each
  // channel c in [0, st.channels).
  //
  // Note that <NeedsXBounding> is a template parameter to avoid a performance
  // penalty from dynamically checking it.
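  //
  // This is the generic analogue of ComputePatchSumOf3Channels above: the
  // same weighted sum, but with a pixel stride of st.channels instead of a
  // compile-time 3, and a loop over channels in place of three unrolled
  // accumulators.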
  template <bool NeedsXBounding>
  static void ComputePatchSum(float scale, const ImageResizerState& st,
                              const std::vector<const T*>& y_ptrs,
                              const std::vector<float>& y_scales,
                              const CachedInterpolation& x_interp,
                              float* output_ptr) {
#define BOUND_IF_NEEDED(x, y) (NeedsXBounding ? Bound(x, y) : (x))

    const auto num_channels = st.channels;
    for (int64 c = 0; c < num_channels; ++c) {
      float sum = 0;
      for (int i = 0; i < y_ptrs.size(); ++i) {
        const T* ptr = y_ptrs[i];
        float scale_x = x_interp.start_scale;
        float sum_y = static_cast<float>(
                          ptr[num_channels *
                                  BOUND_IF_NEEDED(x_interp.start, st.in_width) +
                              c]) *
                      scale_x;
        if (x_interp.start + 1 != x_interp.end) {
          for (int64 x = x_interp.start + 1; x < x_interp.end - 1; ++x) {
            sum_y += static_cast<float>(
                ptr[num_channels * BOUND_IF_NEEDED(x, st.in_width) + c]);
          }
          scale_x = x_interp.end_minus_one_scale;
          sum_y += static_cast<float>(
                       ptr[num_channels *
                               BOUND_IF_NEEDED(x_interp.end - 1, st.in_width) +
                           c]) *
                   scale_x;
        }
        sum += sum_y * y_scales[i];
      }
      output_ptr[c] = sum * scale;
    }
#undef BOUND_IF_NEEDED
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    ImageResizerState st(align_corners_);
    st.ValidateAndCreateOutput(context, input);

    if (!context->status().ok()) return;

    typename TTypes<T, 4>::ConstTensor input_data(input.tensor<T, 4>());

    // Precompute values used when iterating over x coordinates within a row.
    // Note that it may be useful to cache x_interps for a given
    // ImageResizerState.
    std::vector<CachedInterpolation> x_interps(st.out_width);
    for (int64 x = 0; x < st.out_width; ++x) {
      auto& x_interp = x_interps[x];
      const float in_x = x * st.width_scale;
      const float in_x1 = (x + 1) * st.width_scale;
      // The start and end width indices of all the cells that could
      // contribute to the target cell.
      int64 v = floor(in_x);
      x_interp.start = v;
      // TODO(cwhipkey): simplify this logic.
      x_interp.start_scale =
          v < in_x ? (v + 1 > in_x1 ? st.width_scale : v + 1 - in_x)
                   : (v + 1 > in_x1 ? in_x1 - v : 1.0);

      v = ceil(in_x1);
      x_interp.end = v;
      v = x_interp.end - 1;
      x_interp.end_minus_one_scale =
          v < in_x ? (v + 1 > in_x1 ? st.width_scale : v + 1 - in_x)
                   : (v + 1 > in_x1 ? in_x1 - v : 1.0);
      x_interp.needs_bounding =
          Bound(x_interp.start, st.in_width) != x_interp.start ||
          Bound(x_interp.end - 1, st.in_width) != (x_interp.end - 1);
    }
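
    // A worked example of the precomputation above, assuming in_width = 4 and
    // out_width = 3 (so width_scale = 4/3):
    //
    //   x = 0: in_x = 0,   in_x1 = 4/3 -> start = 0, end = 2,
    //          start_scale = 1.0, end_minus_one_scale = 1/3
    //   x = 1: in_x = 4/3, in_x1 = 8/3 -> start = 1, end = 3,
    //          start_scale = 2/3, end_minus_one_scale = 2/3
    //   x = 2: in_x = 8/3, in_x1 = 4   -> start = 2, end = 4,
    //          start_scale = 1/3, end_minus_one_scale = 1.0
    //
    // This matches the in[4] -> out[3] walkthrough in ComputeLoop below;
    // needs_bounding is false for every column here since all indices fall
    // inside [0, in_width).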

    // Special-case the common 3-channel (e.g. RGB) layout, which lets
    // ComputePatchSumOf3Channels use a compile-time stride and unrolled
    // accumulators; fall back to the generic per-channel loop otherwise.
    if (st.channels == 3) {
      ComputeLoop<3>(st, x_interps, input_data);
    } else {
      ComputeLoop<-1>(st, x_interps, input_data);
    }
  }

  template <int64 kKnownNumChannels>
  void ComputeLoop(const ImageResizerState& st,
                   const std::vector<CachedInterpolation>& x_interps,
                   typename TTypes<T, 4>::ConstTensor input_data) {
    // Note: the output of area resizing is always float, regardless of the
    // input type T.
    TTypes<float, 4>::Tensor output_data = st.output->tensor<float, 4>();

    // When using this algorithm for downsizing, the target pixel value is the
    // weighted average of all the source pixels. The weight is determined by
    // the contribution percentage of the source pixel.
    //
    // Let "scale" be "target_image_size/source_image_size". If 1/n of the
    // source pixel contributes to the target pixel, then the weight is (1/n *
    // scale); if the complete source pixel contributes to the target pixel,
    // then the weight is scale.
    //
    // To visualize the implementation, use one dimension as an example:
    // Resize in[4] to out[3].
    //   scale = 3/4 = 0.75
    //   out[0]: in[0] and 1/3 of in[1]
    //   out[1]: 2/3 of in[1] and 2/3 of in[2]
    //   out[2]: 1/3 of in[2] and in[3]
    // Hence, the output pixel values are:
    //   out[0] = (in[0] * 1.0 + in[1] * 1/3) * scale
    //   out[1] = (in[1] * 2/3 + in[2] * 2/3) * scale
    //   out[2] = (in[2] * 1/3 + in[3] * 1.0) * scale
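    //
    // As a sanity check, the weights for each output pixel sum to 1/scale
    // (e.g. 1.0 + 1/3 = 4/3 for out[0]), so resizing a constant image
    // reproduces the same constant. In 2D, a source pixel's weight is the
    // product of its x and y overlap fractions, and the final multiplier is
    // scale = 1 / (height_scale * width_scale) as computed below.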
    const T* const input_ptr = input_data.data();
    std::vector<float> y_scales;
    std::vector<const T*> y_ptrs;
    float scale = 1.0 / (st.height_scale * st.width_scale);
    float* output_ptr = output_data.data();
    for (int64 b = 0; b < st.batch_size; ++b) {
      for (int64 y = 0; y < st.out_height; ++y) {
        const float in_y = y * st.height_scale;
        const float in_y1 = (y + 1) * st.height_scale;
        // The start and end height indices of all the cells that could
        // contribute to the target cell.
        const int64 y_start = floor(in_y);
        const int64 y_end = ceil(in_y1);
        y_scales.clear();
        y_ptrs.clear();
        for (int64 i = y_start; i < y_end; ++i) {
          float scale_y;
          if (i < in_y) {
            scale_y = (i + 1 > in_y1 ? st.height_scale : i + 1 - in_y);
          } else {
            scale_y = (i + 1 > in_y1 ? in_y1 - i : 1.0);
          }
          // TODO(cwhipkey): can this data be unified with CachedInterpolation?
          y_scales.push_back(scale_y);
          y_ptrs.push_back(
              input_ptr + (b * st.in_height * st.in_width * st.channels +
                           Bound(i, st.in_height) * st.in_width * st.channels));
        }

        if (kKnownNumChannels == 3) {
          for (int64 x = 0; x < st.out_width; ++x) {
            const CachedInterpolation& x_interp = x_interps[x];
            if (x_interp.needs_bounding) {
              ComputePatchSumOf3Channels<true>(scale, st, y_ptrs, y_scales,
                                               x_interp, output_ptr);
            } else {
              ComputePatchSumOf3Channels<false>(scale, st, y_ptrs, y_scales,
                                                x_interp, output_ptr);
            }
            output_ptr += 3;
          }
        } else {
          for (int64 x = 0; x < st.out_width; ++x) {
            const CachedInterpolation& x_interp = x_interps[x];
            if (x_interp.needs_bounding) {
              ComputePatchSum<true>(scale, st, y_ptrs, y_scales, x_interp,
                                    output_ptr);
            } else {
              ComputePatchSum<false>(scale, st, y_ptrs, y_scales, x_interp,
                                     output_ptr);
            }
            output_ptr += st.channels;
          }
        }
      }
    }
  }

 private:
  // Clamps val to the closed interval [0, limit - 1]; for example,
  // Bound(-2, 10) == 0 and Bound(11, 10) == 9.
  static EIGEN_ALWAYS_INLINE int64 Bound(int64 val, int64 limit) {
    return std::min(limit - 1ll, std::max(0ll, val));
  }

  bool align_corners_;
};

#define REGISTER_KERNEL(T)                            \
  REGISTER_KERNEL_BUILDER(Name("ResizeArea")          \
                              .Device(DEVICE_CPU)     \
                              .TypeConstraint<T>("T") \
                              .HostMemory("size"),    \
                          ResizeAreaOp<CPUDevice, T>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);

#undef REGISTER_KERNEL
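
// For reference, a minimal Python-level usage sketch (TF 1.x graph API;
// assuming this kernel backs the public tf.image.resize_area symbol):
//
//   images = tf.placeholder(tf.float32, [1, 4, 4, 3])
//   out = tf.image.resize_area(images, size=[3, 3], align_corners=False)
//
// The "size" input lives in host memory (hence HostMemory("size") above), and
// the output is float32 for every registered input type T.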

}  // namespace tensorflow