Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CONTRIB_IMAGE_KERNELS_IMAGE_OPS_H_
     17 #define TENSORFLOW_CONTRIB_IMAGE_KERNELS_IMAGE_OPS_H_
     18 
     19 // See docs in ../ops/image_ops.cc.
     20 
     21 #define EIGEN_USE_THREADS
     22 
     23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     24 #include "tensorflow/core/framework/tensor_types.h"
     25 #include "tensorflow/core/platform/types.h"
     26 
     27 namespace tensorflow {
     28 
     29 namespace generator {
     30 
     31 enum Interpolation { INTERPOLATION_NEAREST, INTERPOLATION_BILINEAR };
     32 
     33 using Eigen::array;
     34 using Eigen::DenseIndex;
     35 
     36 template <typename Device, typename T>
     37 class ProjectiveGenerator {
     38  private:
     39   typename TTypes<T, 4>::ConstTensor input_;
     40   typename TTypes<float>::ConstMatrix transforms_;
     41   const Interpolation interpolation_;
     42 
     43  public:
     44   static const int kNumParameters = 8;
     45 
     46   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
     47   ProjectiveGenerator(typename TTypes<T, 4>::ConstTensor input,
     48                       typename TTypes<float>::ConstMatrix transforms,
     49                       const Interpolation interpolation)
     50       : input_(input), transforms_(transforms), interpolation_(interpolation) {}
     51 
     52   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
     53   operator()(const array<DenseIndex, 4>& coords) const {
     54     const int64 output_y = coords[1];
     55     const int64 output_x = coords[2];
     56     const float* transform =
     57         transforms_.dimension(0) == 1
     58             ? transforms_.data()
     59             : &transforms_.data()[transforms_.dimension(1) * coords[0]];
     60     float projection = transform[6] * output_x + transform[7] * output_y + 1.f;
     61     const float input_x =
     62         (transform[0] * output_x + transform[1] * output_y + transform[2]) /
     63         projection;
     64     const float input_y =
     65         (transform[3] * output_x + transform[4] * output_y + transform[5]) /
     66         projection;
     67 
     68     // TODO(ringwalt): Add a fill value input.
     69 #if (defined __CUDA_ARCH__) && (CUDART_VERSION < 8000)
     70     // On CUDA versions previous to 8.0, only __shared__ variables
     71     // could be declared as static in the device code.
     72     const T fill_value = T(0);
     73 #else
     74     static const T fill_value = T(0);
     75 #endif
     76     switch (interpolation_) {
     77       case INTERPOLATION_NEAREST:
     78         // Switch the order of x and y again for indexing into the image.
     79         return nearest_interpolation(coords[0], input_y, input_x, coords[3],
     80                                      fill_value);
     81       case INTERPOLATION_BILINEAR:
     82         return bilinear_interpolation(coords[0], input_y, input_x, coords[3],
     83                                       fill_value);
     84     }
     85     // Unreachable; ImageProjectiveTransform only uses INTERPOLATION_NEAREST
     86     // or INTERPOLATION_BILINEAR.
     87     return T(0);
     88   }
     89 
     90  private:
     91   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
     92   nearest_interpolation(const DenseIndex batch, const float y, const float x,
     93                         const DenseIndex channel, const T fill_value) const {
     94     return read_with_fill_value(batch, DenseIndex(std::round(y)),
     95                                 DenseIndex(std::round(x)), channel, fill_value);
     96   }
     97 
     98   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
     99   bilinear_interpolation(const DenseIndex batch, const float y, const float x,
    100                          const DenseIndex channel, const T fill_value) const {
    101     const float y_floor = std::floor(y);
    102     const float x_floor = std::floor(x);
    103     const float y_ceil = y_floor + 1;
    104     const float x_ceil = x_floor + 1;
    105     // f(x, y_floor) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_floor)
    106     //               + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_floor)
    107     const float value_yfloor =
    108         (x_ceil - x) * read_with_fill_value(batch, DenseIndex(y_floor),
    109                                             DenseIndex(x_floor), channel,
    110                                             fill_value) +
    111         (x - x_floor) * read_with_fill_value(batch, DenseIndex(y_floor),
    112                                              DenseIndex(x_ceil), channel,
    113                                              fill_value);
    114     // f(x, y_ceil) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_ceil)
    115     //              + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_ceil)
    116     const float value_yceil =
    117         (x_ceil - x) * read_with_fill_value(batch, DenseIndex(y_ceil),
    118                                             DenseIndex(x_floor), channel,
    119                                             fill_value) +
    120         (x - x_floor) * read_with_fill_value(batch, DenseIndex(y_ceil),
    121                                              DenseIndex(x_ceil), channel,
    122                                              fill_value);
    123     // f(x, y) = (y_ceil - y) / (y_ceil - y_floor) * f(x, y_floor)
    124     //         + (y - y_floor) / (y_ceil - y_floor) * f(x, y_ceil)
    125     return T((y_ceil - y) * value_yfloor + (y - y_floor) * value_yceil);
    126   }
    127 
    128   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T read_with_fill_value(
    129       const DenseIndex batch, const DenseIndex y, const DenseIndex x,
    130       const DenseIndex channel, const T fill_value) const {
    131     // batch and channel must be correct, because they are passed unchanged from
    132     // the input.
    133     return (0 <= y && y < input_.dimension(1) && 0 <= x &&
    134             x < input_.dimension(2))
    135                ? input_(array<DenseIndex, 4>{batch, y, x, channel})
    136                : fill_value;
    137   }
    138 };
    139 
    140 }  // end namespace generator
    141 
    142 // NOTE(ringwalt): We MUST wrap the generate() call in a functor and explicitly
    143 // instantiate the functor in image_ops_gpu.cu.cc. Otherwise, we will be missing
    144 // some Eigen device code.
    145 namespace functor {
    146 
    147 using generator::Interpolation;
    148 using generator::ProjectiveGenerator;
    149 
    150 template <typename Device, typename T>
    151 struct FillProjectiveTransform {
    152   typedef typename TTypes<T, 4>::Tensor OutputType;
    153   typedef typename TTypes<T, 4>::ConstTensor InputType;
    154   typedef typename TTypes<float, 2>::ConstTensor TransformsType;
    155   const Interpolation interpolation_;
    156 
    157   FillProjectiveTransform(Interpolation interpolation)
    158       : interpolation_(interpolation) {}
    159 
    160   EIGEN_ALWAYS_INLINE
    161   void operator()(const Device& device, OutputType* output,
    162                   const InputType& images,
    163                   const TransformsType& transform) const {
    164     output->device(device) = images.generate(
    165         ProjectiveGenerator<Device, T>(images, transform, interpolation_));
    166   }
    167 };
    168 
    169 }  // end namespace functor
    170 
    171 }  // end namespace tensorflow
    172 
    173 #endif  // TENSORFLOW_CONTRIB_IMAGE_KERNELS_IMAGE_OPS_H_
    174