Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_KERNELS_CONV_2D_H_
     17 #define TENSORFLOW_KERNELS_CONV_2D_H_
     18 
     19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     20 #include "tensorflow/core/framework/tensor_types.h"
     21 #include "tensorflow/core/kernels/eigen_backward_spatial_convolutions.h"
     22 #include "tensorflow/core/kernels/eigen_spatial_convolutions.h"
     23 #include "tensorflow/core/util/tensor_format.h"
     24 
     25 namespace tensorflow {
     26 namespace functor {
     27 
     28 // TODO(yangke): revisit these operations and in particular, see if we can
     29 // combine all of them into just one operation without causing nvcc to
     30 // timeout.
     31 template <typename Device, typename T, int Dims, typename IndexType>
     32 struct ShuffleAndReverse {
     33   void operator()(const Device& d,
     34                   typename TTypes<T, Dims, IndexType>::ConstTensor input,
     35                   const Eigen::DSizes<IndexType, Dims>& order,
     36                   const Eigen::array<bool, Dims>& reverse_dims,
     37                   typename TTypes<T, Dims, IndexType>::Tensor output) {
     38     output.device(d) = input.shuffle(order).reverse(reverse_dims);
     39   }
     40 };
     41 
     42 template <typename Device, typename T, int Dims, typename IndexType>
     43 struct InflatePadAndShuffle {
     44   void operator()(
     45       const Device& d, typename TTypes<T, Dims, IndexType>::ConstTensor input,
     46       const Eigen::DSizes<IndexType, Dims>& strides,
     47       const Eigen::array<Eigen::IndexPair<IndexType>, Dims>& pad_dims,
     48       const Eigen::DSizes<IndexType, Dims>& order,
     49       typename TTypes<T, Dims, IndexType>::Tensor output) {
     50     output.device(d) = input.inflate(strides).pad(pad_dims).shuffle(order);
     51   }
     52 };
     53 
     54 template <typename Device, typename Input, typename Filter, typename Output>
     55 void SpatialConvolutionFunc(const Device& d, Output output, Input input,
     56                             Filter filter, int row_stride, int col_stride,
     57                             int row_dilation, int col_dilation,
     58                             const Eigen::PaddingType& padding) {
     59   // Need to swap row/col when calling Eigen.
     60   output.device(d) =
     61       Eigen::SpatialConvolution(input, filter, col_stride, row_stride, padding,
     62                                 col_dilation, row_dilation);
     63 }
     64 
     65 template <typename Device, typename T>
     66 struct SpatialConvolution {
     67   void operator()(const Device& d, typename TTypes<T, 4>::Tensor output,
     68                   typename TTypes<T, 4>::ConstTensor input,
     69                   typename TTypes<T, 4>::ConstTensor filter, int row_stride,
     70                   int col_stride, int row_dilation, int col_dilation,
     71                   const Eigen::PaddingType& padding) {
     72     SpatialConvolutionFunc(d, output, input, filter, row_stride, col_stride,
     73                            row_dilation, col_dilation, padding);
     74   }
     75 };
     76 
     77 template <typename Device>
     78 struct SpatialConvolution<Device, Eigen::half> {
     79   void operator()(const Device& d,
     80                   typename TTypes<Eigen::half, 4>::Tensor output,
     81                   typename TTypes<Eigen::half, 4>::ConstTensor input,
     82                   typename TTypes<Eigen::half, 4>::ConstTensor filter,
     83                   int row_stride, int col_stride, int row_dilation,
     84                   int col_dilation, const Eigen::PaddingType& padding) {
     85     output.device(d) =
     86         Eigen::SpatialConvolution(input.cast<float>(), filter.cast<float>(),
     87                                   col_stride, row_stride, padding, col_dilation,
     88                                   row_dilation)
     89             .cast<Eigen::half>();
     90   }
     91 };
     92 
     93 template <typename Device, typename T>
     94 struct SpatialConvolutionBackwardInput {
     95   void operator()(const Device& d, typename TTypes<T, 4>::Tensor input_backward,
     96                   typename TTypes<T, 4>::ConstTensor kernel,
     97                   typename TTypes<T, 4>::ConstTensor output_backward,
     98                   int row_stride, int col_stride, int row_dilation,
     99                   int col_dilation) {
    100     // Need to swap row/col when calling Eigen.
    101     input_backward.device(d) = Eigen::SpatialConvolutionBackwardInput(
    102         kernel, output_backward, input_backward.dimension(2),
    103         input_backward.dimension(1), col_stride, row_stride, col_dilation,
    104         row_dilation);
    105   }
    106 };
    107 
    108 template <typename Device, typename T>
    109 struct SpatialConvolutionBackwardFilter {
    110   void operator()(const Device& d,
    111                   typename TTypes<T, 4>::Tensor kernel_backward,
    112                   typename TTypes<T, 4>::ConstTensor input,
    113                   typename TTypes<T, 4>::ConstTensor output_backward,
    114                   int row_stride, int col_stride, int row_dilation,
    115                   int col_dilation) {
    116     // Need to swap row/col when calling Eigen.
    117     kernel_backward.device(d) = Eigen::SpatialConvolutionBackwardKernel(
    118         input, output_backward, kernel_backward.dimension(1),
    119         kernel_backward.dimension(0), col_stride, row_stride, col_dilation,
    120         row_dilation);
    121   }
    122 };
    123 
    124 // TODO(vrv): Figure out how to use the MatMulFunctor in matmul_op.h.
    125 // My initial attempt to do this compiled but failed in the pytest
    126 // due to a swigdeps error.
    127 template <typename Device, typename T>
    128 struct MatMulConvFunctor {
    129   // Computes on device "d": out = in0 * in1, where * is matrix
    130   // multiplication.
    131   void operator()(
    132       const Device& d, typename TTypes<T, 2>::Tensor out,
    133       typename TTypes<T, 2>::ConstTensor in0,
    134       typename TTypes<T, 2>::ConstTensor in1,
    135       const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair) {
    136     out.device(d) = in0.contract(in1, dim_pair);
    137   }
    138 };
    139 
    140 // Shuffles a filter tensor from:
    141 //   [<spatial_dims>, in, out]
    142 // to:
    143 //   [out, in, <spatial_dims>]
    144 template <typename Device, typename T, typename IndexType, int NDIMS>
    145 struct TransformFilter {
    146   void operator()(const Device& d,
    147                   typename TTypes<T, NDIMS, IndexType>::ConstTensor in,
    148                   typename TTypes<T, NDIMS, IndexType>::Tensor out) {
    149     // We want a 3, 2, 0, 1 shuffle. Merge the spatial dimensions together
    150     // to speed up the shuffle operation.
    151     Eigen::DSizes<IndexType, 3> merged_dims;
    152     merged_dims[0] = in.dimension(0);  // spatial dimensions
    153     for (int i = 1; i < NDIMS - 2; ++i) {
    154       merged_dims[0] *= in.dimension(i);
    155     }
    156     merged_dims[1] = in.dimension(NDIMS - 2);  // input filters
    157     merged_dims[2] = in.dimension(NDIMS - 1);  // output filters
    158 
    159     Eigen::DSizes<IndexType, NDIMS> expanded_dims;
    160     expanded_dims[0] = in.dimension(NDIMS - 1);  // output filters
    161     expanded_dims[1] = in.dimension(NDIMS - 2);  // input filters
    162     for (int i = 0; i < NDIMS; ++i) {            // spatial dimensions
    163       expanded_dims[i + 2] = in.dimension(i);
    164     }
    165 
    166     out.device(d) = in.reshape(merged_dims)
    167                         .shuffle(Eigen::DSizes<IndexType, 3>(2, 1, 0))
    168                         .reshape(expanded_dims);
    169   }
    170 };
    171 
    172 template <typename Device, typename T, typename IndexType>
    173 struct TransformDepth {
    174   void operator()(const Device& d,
    175                   typename TTypes<T, 4, IndexType>::ConstTensor in,
    176                   const Eigen::DSizes<IndexType, 4>& shuffle,
    177                   typename TTypes<T, 4, IndexType>::Tensor out) {
    178     Eigen::DSizes<IndexType, 3> merged_dims;
    179     Eigen::DSizes<IndexType, 4> expanded_dims;
    180     Eigen::DSizes<IndexType, 3> new_shuffle;
    181 
    182     // Merge dimensions that won't be shuffled together to speed things up.
    183     if (shuffle[1] == 2 && shuffle[2] == 3) {
    184       merged_dims[0] = in.dimension(0);
    185       merged_dims[1] = in.dimension(1);
    186       merged_dims[2] = in.dimension(2) * in.dimension(3);
    187       new_shuffle[0] = shuffle[0];
    188       new_shuffle[1] = 2;
    189       new_shuffle[2] = shuffle[3];
    190       expanded_dims[0] = in.dimension(shuffle[0]);
    191       expanded_dims[1] = in.dimension(2);
    192       expanded_dims[2] = in.dimension(3);
    193       expanded_dims[3] = in.dimension(shuffle[3]);
    194     } else if (shuffle[0] == 2 && shuffle[1] == 3) {
    195       merged_dims[0] = in.dimension(0);
    196       merged_dims[1] = in.dimension(1);
    197       merged_dims[2] = in.dimension(2) * in.dimension(3);
    198       new_shuffle[0] = 2;
    199       new_shuffle[1] = shuffle[2];
    200       new_shuffle[2] = shuffle[3];
    201       expanded_dims[0] = in.dimension(2);
    202       expanded_dims[1] = in.dimension(3);
    203       expanded_dims[2] = in.dimension(shuffle[2]);
    204       expanded_dims[3] = in.dimension(shuffle[3]);
    205     } else if (shuffle[0] == 0 && shuffle[1] == 3 && shuffle[2] == 1 &&
    206                shuffle[3] == 2) {
    207       merged_dims[0] = in.dimension(0);
    208       merged_dims[1] = in.dimension(1) * in.dimension(2);
    209       merged_dims[2] = in.dimension(3);
    210       new_shuffle[0] = 0;
    211       new_shuffle[1] = 2;
    212       new_shuffle[2] = 1;
    213       expanded_dims[0] = in.dimension(0);
    214       expanded_dims[1] = in.dimension(3);
    215       expanded_dims[2] = in.dimension(1);
    216       expanded_dims[3] = in.dimension(2);
    217     } else {
    218       assert(false && "unexpected shuffle");
    219     }
    220 
    221     out.device(d) =
    222         in.reshape(merged_dims).shuffle(new_shuffle).reshape(expanded_dims);
    223   }
    224 };
    225 
    226 template <typename Device, typename T, typename IndexType, int NDIMS>
    227 struct PadInput {
    228   void operator()(const Device& d,
    229                   typename TTypes<T, NDIMS, IndexType>::ConstTensor in,
    230                   const std::array<int, NDIMS - 2>& padding_left,
    231                   const std::array<int, NDIMS - 2>& padding_right,
    232                   typename TTypes<T, NDIMS, IndexType>::Tensor out,
    233                   TensorFormat format) {
    234     Eigen::array<Eigen::IndexPair<IndexType>, NDIMS> padding;
    235     padding[GetTensorDimIndex<NDIMS - 2>(format, 'N')] = {0, 0};
    236     for (int i = 0; i < NDIMS - 2; ++i) {
    237       padding[GetTensorDimIndex<NDIMS - 2>(format, '0' + i)] = {
    238           padding_left[i], padding_right[i]};
    239     }
    240     padding[GetTensorDimIndex<NDIMS - 2>(format, 'C')] = {0, 0};
    241     out.device(d) = in.pad(padding);
    242   }
    243 };
    244 
    245 // Converts a tensor from:
    246 //   [batch, <spatial>, filters]
    247 // to:
    248 //   [batch, filters, <spatial>]
    249 template <typename Device, typename T, int NDIMS>
    250 struct NHWCToNCHW {
    251   void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
    252                   typename TTypes<T, NDIMS>::Tensor out);
    253 };
    254 
    255 // Converts a tensor from:
    256 //   [batch, filters, <spatial>]
    257 // to:
    258 //   [batch, <spatial>, filters]
    259 template <typename Device, typename T, int NDIMS>
    260 struct NCHWToNHWC {
    261   void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
    262                   typename TTypes<T, NDIMS>::Tensor out);
    263 };
    264 
    265 // Converts a tensor from:
    266 //   [dim0, dim1, dim2]
    267 // to:
    268 //   [dim0, dim2, dim1]
    269 template <typename Device, typename T, bool conjugate = false>
    270 struct SwapDimension1And2InTensor3 {
    271   void operator()(const Device& d, const T* in,
    272                   const gtl::ArraySlice<int64>& input_dims, T* out);
    273 };
    274 
    275 // Converts a tensor from:
    276 //   [dim0, dim1, dim2]
    277 // to:
    278 //   [dim2, dim1, dim0]
    279 template <typename Device, typename T, bool conjugate = false>
    280 struct SwapDimension0And2InTensor3 {
    281   void operator()(const Device& d, const T* in,
    282                   const gtl::ArraySlice<int64>& input_dims, T* out);
    283 };
    284 
    285 // Reverses the effect of TransformFilter above.
    286 template <typename Device, typename T, int NDIMS>
    287 struct ReverseTransformFilter {
    288   void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
    289                   typename TTypes<T, NDIMS>::Tensor out);
    290 };
    291 
    292 }  // namespace functor
    293 
    294 template <class T>
    295 class ConvAlgorithmMap;
    296 
    297 template <>
    298 class ConvAlgorithmMap<Eigen::ThreadPoolDevice> {};
    299 }  // namespace tensorflow
    300 
    301 #endif  // TENSORFLOW_KERNELS_CONV_2D_H_
    302