      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/nn_ops.cc.
     17 
     18 #define USE_EIGEN_TENSOR
     19 #define EIGEN_USE_THREADS
     20 
     21 #include "tensorflow/core/kernels/conv_grad_ops.h"
     22 
     23 #include <algorithm>
     24 #include <vector>
     25 
     26 #include "tensorflow/core/framework/numeric_op.h"
     27 #include "tensorflow/core/framework/op_kernel.h"
     28 #include "tensorflow/core/framework/register_types.h"
     29 #include "tensorflow/core/framework/tensor.h"
     30 #include "tensorflow/core/framework/tensor_shape.h"
     31 #include "tensorflow/core/framework/tensor_slice.h"
     32 #include "tensorflow/core/kernels/conv_2d.h"
     33 #ifdef TENSORFLOW_USE_LIBXSMM
     34 #include "tensorflow/core/kernels/xsmm_conv2d.h"
     35 #endif
     36 #include "tensorflow/core/kernels/ops_util.h"
     37 #include "tensorflow/core/lib/core/errors.h"
     38 #include "tensorflow/core/lib/gtl/array_slice.h"
     39 #include "tensorflow/core/platform/logging.h"
     40 #include "tensorflow/core/platform/macros.h"
     41 #include "tensorflow/core/util/padding.h"
     42 #include "tensorflow/core/util/tensor_format.h"
     43 #include "tensorflow/core/util/use_cudnn.h"
     44 #include "tensorflow/core/util/work_sharder.h"
     45 
     46 #if GOOGLE_CUDA
     47 #include "tensorflow/core/kernels/conv_ops_gpu.h"
     48 #include "tensorflow/core/platform/stream_executor.h"
     49 #endif  // GOOGLE_CUDA
     50 
     51 namespace {
     52 
      53 // Returns, in 'im_data' (assumed to be zero-initialized), the image patch in
      54 // storage order (height, width, depth), reconstructed from the patches in
      55 // 'col_data', which must be in storage order (out_height * out_width,
      56 // filter_height, filter_width, in_depth). Implementation by Yangqing Jia (jiayq).
     57 template <typename T>
     58 void Col2im(const T* col_data, const int depth, const int height,
     59             const int width, const int filter_h, const int filter_w,
     60             const int pad_t, const int pad_l, const int pad_b, const int pad_r,
     61             const int stride_h, const int stride_w, T* im_data) {
     62   int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
     63   int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
     64   int h_pad = -pad_t;
     65   for (int h = 0; h < height_col; ++h) {
     66     int w_pad = -pad_l;
     67     for (int w = 0; w < width_col; ++w) {
     68       T* im_patch_data = im_data + (h_pad * width + w_pad) * depth;
     69       for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
     70         for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
     71           if (ih >= 0 && ih < height && iw >= 0 && iw < width) {
     72             // TODO(andydavis) Vectorize this loop (if compiler does not).
     73             for (int i = 0; i < depth; ++i) {
     74               im_patch_data[i] += col_data[i];
     75             }
     76           }
     77           im_patch_data += depth;
     78           col_data += depth;
     79         }
      80         // Skip the remaining (width - filter_w) columns to reach the next image row.
     81         im_patch_data += depth * (width - filter_w);
     82       }
     83       w_pad += stride_w;
     84     }
     85     h_pad += stride_h;
     86   }
     87 }
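// Illustrative example (added note, not from the original source): with
// height = width = 4, depth = 1, a 3x3 filter, stride 1 and zero padding,
// height_col = width_col = (4 - 3) / 1 + 1 = 2, so 'col_data' holds
// 2 * 2 = 4 patches of 3 * 3 * 1 = 9 values each.  Col2im scatters every patch
// back onto the 4x4 image, summing the contributions where patches overlap.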
     88 
     89 }  // namespace
     90 
     91 namespace tensorflow {
     92 
     93 typedef Eigen::ThreadPoolDevice CPUDevice;
     94 typedef Eigen::GpuDevice GPUDevice;
     95 
      96 // The fast versions use Eigen computations directly. They are enabled only on
      97 // CPU for now, since nvcc times out when trying to compile them.
     98 // TODO(yangke): enable them for GPUs when we have a faster compiler.
     99 
    100 template <typename T>
    101 struct LaunchConv2DBackpropInputOp<CPUDevice, T> {
    102   void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
    103                   const Tensor& out_backprop, const Tensor& filter,
    104                   int row_stride, int col_stride, const Padding& padding,
    105                   Tensor* in_backprop, TensorFormat data_format) {
    106     const CPUDevice& d = ctx->eigen_device<CPUDevice>();
    107     functor::SpatialConvolutionBackwardInput<CPUDevice, T>()(
    108         d, in_backprop->tensor<T, 4>(), filter.tensor<T, 4>(),
    109         out_backprop.tensor<T, 4>(), row_stride, col_stride,
    110         /*row_dilation=*/1, /*col_dilation=*/1);
    111   }
    112 };
    113 
    114 #ifdef TENSORFLOW_USE_LIBXSMM
    115 template <typename Device, class T>
    116 struct LaunchXsmmBackwardInputConvolution {
    117   bool operator()(OpKernelContext* context, const Device& d,
    118                   typename TTypes<T, 4>::Tensor input_backward,
    119                   typename TTypes<T, 4>::ConstTensor kernel,
    120                   typename TTypes<T, 4>::ConstTensor output_backward,
    121                   int input_rows, int input_cols, int row_stride,
    122                   int col_stride, int pad_h, int pad_w,
    123                   TensorFormat data_format) const {
    124     return false;
    125   }
    126 };
    127 
    128 template <>
    129 struct LaunchXsmmBackwardInputConvolution<CPUDevice, float> {
    130   bool operator()(OpKernelContext* context, const CPUDevice& d,
    131                   typename TTypes<float, 4>::Tensor input_backward,
    132                   typename TTypes<float, 4>::ConstTensor kernel,
    133                   typename TTypes<float, 4>::ConstTensor output_backward,
    134                   int input_rows, int input_cols, int row_stride,
    135                   int col_stride, int pad_h, int pad_w,
    136                   TensorFormat data_format) const {
    137     auto batch = input_backward.dimension(0);
    138     auto in_depth = input_backward.dimension(3);
    139     auto out_depth = output_backward.dimension(3);
    140     auto filter_rows = kernel.dimension(0);
    141     auto filter_cols = kernel.dimension(1);
    142     auto num_threads =
    143         context->device()->tensorflow_cpu_worker_threads()->num_threads;
    144     // See libxsmm_dnn.h for this struct definition.
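    // LIBXSMM field naming, as used by the assignments below: N = batch,
    // C = input channels, K = output channels, H/W = input spatial dims,
    // R/S = filter spatial dims, u/v = row/column strides.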
    145     libxsmm_dnn_conv_desc desc;
    146     desc.N = batch;
    147     desc.C = in_depth;
    148     desc.H = input_rows;
    149     desc.W = input_cols;
    150     desc.K = out_depth;
    151     desc.R = filter_rows;
    152     desc.S = filter_cols;
    153     desc.u = row_stride;
    154     desc.v = col_stride;
    155     desc.pad_h = pad_h;
    156     desc.pad_w = pad_w;
    157     desc.pad_h_in = 0;
    158     desc.pad_w_in = 0;
    159     desc.pad_h_out = 0;
    160     desc.pad_w_out = 0;
    161     desc.threads = num_threads;
    162     desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
    163     desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
    164     desc.filter_format =
    165         LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;  // LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
    166     desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
    167     desc.options = LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE_OVERWRITE;
    168     desc.datatype = LIBXSMM_DNN_DATATYPE_F32;
    169 
    170     auto input_ptr = input_backward.data();
    171     auto filter_ptr = kernel.data();
    172     auto output_ptr = output_backward.data();
    173 
    174     bool success = functor::XsmmBkwInputConv2D<CPUDevice, float>()(
    175         context, desc, input_ptr, filter_ptr, output_ptr);
    176     return success;
    177   }
    178 };
    179 #endif
    180 
    181 template <typename Device, class T>
    182 class Conv2DFastBackpropInputOp : public OpKernel {
    183  public:
    184   explicit Conv2DFastBackpropInputOp(OpKernelConstruction* context)
    185       : OpKernel(context) {
    186     string data_format;
    187     OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    188     OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
    189                 errors::InvalidArgument("Invalid data format"));
    190     OP_REQUIRES(context, data_format_ == FORMAT_NHWC,
    191                 errors::InvalidArgument(
    192                     "Eigen Conv2DFastBackpropInputOp only supports NHWC."));
    193     OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    194     OP_REQUIRES(context, strides_.size() == 4,
    195                 errors::InvalidArgument("Sliding window strides field must "
    196                                         "specify 4 dimensions"));
    197     OP_REQUIRES(
    198         context, (strides_[0] == 1 && strides_[3] == 1),
    199         errors::InvalidArgument("Current implementation does not yet support "
    200                                 "strides in the batch and depth dimensions."));
    201     OP_REQUIRES(context, strides_[1] > 0 && strides_[2] > 0,
    202                 errors::InvalidArgument(
    203                     "Row and column strides should be larger than 0."));
    204     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    205     OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
    206     OP_REQUIRES(context, dilations_.size() == 4,
    207                 errors::InvalidArgument("Sliding window dilations field must "
    208                                         "specify 4 dimensions"));
     209     OP_REQUIRES(context, (dilations_[0] == 1 && dilations_[3] == 1),
    210                 errors::InvalidArgument(
    211                     "Current implementation does not yet support "
    212                     "dilations in the batch and depth dimensions."));
    213     // TODO(yangzihao): Add a CPU implementation for dilated convolution.
    214     OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1),
    215                 errors::InvalidArgument(
    216                     "Current Eigen and libxsmm implementations do not "
    217                     "yet support dilation rates larger than 1."));
    218   }
    219 
    220   void Compute(OpKernelContext* context) override {
    221     const Tensor& input_sizes = context->input(0);
    222     const Tensor& filter = context->input(1);
    223     const Tensor& out_backprop = context->input(2);
    224     OP_REQUIRES(
    225         context, TensorShapeUtils::IsVector(input_sizes.shape()),
    226         errors::InvalidArgument(
    227             "Conv2DBackpropInput: input_sizes input must be 1-dim, not ",
    228             input_sizes.dims()));
    229     TensorShape input_shape;
    230     OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
    231                                 input_sizes.vec<int32>(), &input_shape));
    232 
    233     ConvBackpropDimensions dims;
    234     OP_REQUIRES_OK(context,
    235                    ConvBackpropComputeDimensions(
    236                        "Conv2DFastBackpropInput", /*num_spatial_dims=*/2,
    237                        input_shape, filter.shape(), out_backprop.shape(),
    238                        strides_, padding_, data_format_, &dims));
    239 
    240     Tensor* in_backprop = nullptr;
    241     OP_REQUIRES_OK(context,
    242                    context->allocate_output(0, input_shape, &in_backprop));
    243 
    244     // If there is nothing to compute, return.
    245     if (input_shape.num_elements() == 0) {
    246       return;
    247     }
    248 
    249 #if defined TENSORFLOW_USE_LIBXSMM && defined TENSORFLOW_USE_LIBXSMM_BACKWARD
    250     int64 pad_top, pad_bottom;
    251     int64 pad_left, pad_right;
    252     OP_REQUIRES_OK(
    253         context,
    254         GetWindowedOutputSizeVerbose(
    255             dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
    256             dims.spatial_dims[0].stride, padding_,
    257             &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom));
    258     OP_REQUIRES_OK(
    259         context,
    260         GetWindowedOutputSizeVerbose(
    261             dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
    262             dims.spatial_dims[1].stride, padding_,
    263             &dims.spatial_dims[1].output_size, &pad_left, &pad_right));
    264 
    265     if (pad_left == pad_right && pad_top == pad_bottom) {
    266       if (LaunchXsmmBackwardInputConvolution<Device, T>()(
    267               context, context->eigen_device<Device>(),
    268               in_backprop->tensor<T, 4>(), filter.tensor<T, 4>(),
    269               out_backprop.tensor<T, 4>(), dims.spatial_dims[0].input_size,
    270               dims.spatial_dims[1].input_size,
    271               static_cast<int>(dims.spatial_dims[0].stride),
    272               static_cast<int>(dims.spatial_dims[1].stride),
    273               static_cast<int>(pad_top), static_cast<int>(pad_left),
    274               data_format_)) {
    275         return;
    276       }
    277     }
    278 #endif
    279 
    280     LaunchConv2DBackpropInputOp<Device, T>()(
    281         context, false, false, out_backprop, filter,
    282         dims.spatial_dims[0].stride, dims.spatial_dims[1].stride, padding_,
    283         in_backprop, data_format_);
    284   }
    285 
    286  private:
    287   std::vector<int32> dilations_;
    288   std::vector<int32> strides_;
    289   Padding padding_;
    290   TensorFormat data_format_;
    291 
    292   TF_DISALLOW_COPY_AND_ASSIGN(Conv2DFastBackpropInputOp);
    293 };
    294 
    295 // Based on implementation written by Yangqing Jia (jiayq).
    296 template <typename Device, class T>
    297 class Conv2DCustomBackpropInputOp : public OpKernel {
    298  public:
    299   explicit Conv2DCustomBackpropInputOp(OpKernelConstruction* context)
    300       : OpKernel(context) {
    301     string data_format;
    302     OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    303     OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
    304                 errors::InvalidArgument("Invalid data format"));
    305     OP_REQUIRES(context, data_format_ == FORMAT_NHWC,
    306                 errors::InvalidArgument(
    307                     "Conv2DCustomBackpropInputOp only supports NHWC."));
    308     OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    309     OP_REQUIRES(context, strides_.size() == 4,
    310                 errors::InvalidArgument("Sliding window strides field must "
    311                                         "specify 4 dimensions"));
    312     OP_REQUIRES(
    313         context, (strides_[0] == 1 && strides_[3] == 1),
    314         errors::InvalidArgument("Current implementation does not yet support "
    315                                 "strides in the batch and depth dimensions."));
    316     OP_REQUIRES(context, strides_[1] > 0 && strides_[2] > 0,
    317                 errors::InvalidArgument(
    318                     "Row and column strides should be larger than 0."));
    319     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    320     OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
    321     OP_REQUIRES(context, dilations_.size() == 4,
    322                 errors::InvalidArgument("Sliding window dilations field must "
    323                                         "specify 4 dimensions"));
    324     OP_REQUIRES(context, (dilations_[0] == 1 && dilations_[3] == 1),
    325                 errors::InvalidArgument(
    326                     "Current implementation does not yet support "
    327                     "dilations in the batch and depth dimensions."));
    328     // TODO(yangzihao): Add a CPU implementation for dilated convolution.
    329     OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1),
    330                 errors::InvalidArgument(
    331                     "Current libxsmm and customized CPU implementations do "
    332                     "not yet support dilation rates larger than 1."));
    333   }
    334 
    335   void Compute(OpKernelContext* context) override {
    336     const Tensor& input_sizes = context->input(0);
    337     const Tensor& filter = context->input(1);
    338     const Tensor& out_backprop = context->input(2);
    339     OP_REQUIRES(
    340         context, TensorShapeUtils::IsVector(input_sizes.shape()),
    341         errors::InvalidArgument(
    342             "Conv2DBackpropInput: input_sizes input must be 1-dim, not ",
    343             input_sizes.dims()));
    344     TensorShape input_shape;
    345     OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
    346                                 input_sizes.vec<int32>(), &input_shape));
    347 
    348     ConvBackpropDimensions dims;
    349     OP_REQUIRES_OK(context,
    350                    ConvBackpropComputeDimensions(
    351                        "Conv2DCustomBackpropInput", /*num_spatial_dims=*/2,
    352                        input_shape, filter.shape(), out_backprop.shape(),
    353                        strides_, padding_, data_format_, &dims));
    354 
    355     Tensor* in_backprop = nullptr;
    356     OP_REQUIRES_OK(context,
    357                    context->allocate_output(0, input_shape, &in_backprop));
    358 
    359     // If there is nothing to compute, return.
    360     if (input_shape.num_elements() == 0) {
    361       return;
    362     }
    363 
    364 // TODO(andydavis) Consider moving code shared with
    365 // Conv2DCustomBackpropFilterOp into a shared helper function.
    366 #if defined TENSORFLOW_USE_LIBXSMM && defined TENSORFLOW_USE_LIBXSMM_BACKWARD
    367     int64 pad_top, pad_bottom;
    368     int64 pad_left, pad_right;
    369     OP_REQUIRES_OK(
    370         context,
    371         GetWindowedOutputSizeVerbose(
    372             dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
    373             dims.spatial_dims[0].stride, padding_,
    374             &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom));
    375     OP_REQUIRES_OK(
    376         context,
    377         GetWindowedOutputSizeVerbose(
    378             dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
    379             dims.spatial_dims[1].stride, padding_,
    380             &dims.spatial_dims[1].output_size, &pad_left, &pad_right));
    381 
    382     if (pad_left == pad_right && pad_top == pad_bottom) {
    383       if (LaunchXsmmBackwardInputConvolution<Device, T>()(
    384               context, context->eigen_device<Device>(),
    385               in_backprop->tensor<T, 4>(), filter.tensor<T, 4>(),
    386               out_backprop.tensor<T, 4>(), dims.spatial_dims[0].input_size,
    387               dims.spatial_dims[1].input_size,
    388               static_cast<int>(dims.spatial_dims[0].stride),
    389               static_cast<int>(dims.spatial_dims[1].stride),
    390               static_cast<int>(pad_top), static_cast<int>(pad_left),
    391               data_format_)) {
    392         return;
    393       }
    394     }
    395 #else
    396     int64 pad_top, pad_bottom;
    397     int64 pad_left, pad_right;
    398 #endif
    399     OP_REQUIRES_OK(
    400         context,
    401         GetWindowedOutputSizeVerbose(
    402             dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
    403             dims.spatial_dims[0].stride, padding_,
    404             &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom));
    405     OP_REQUIRES_OK(
    406         context,
    407         GetWindowedOutputSizeVerbose(
    408             dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
    409             dims.spatial_dims[1].stride, padding_,
    410             &dims.spatial_dims[1].output_size, &pad_left, &pad_right));
    411 
     412     // The number of elements in a single filter: rows * cols * in_depth.
    413     const int filter_total_size = dims.spatial_dims[0].filter_size *
    414                                   dims.spatial_dims[1].filter_size *
    415                                   dims.in_depth;
     416     // The spatial size (rows * cols) of a single output image.
    417     const int output_image_size =
    418         dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size;
    419 
    420     // TODO(andydavis) Get L2/L3 cache sizes from device.
    421     const size_t l2_cache_size = 256LL << 10;
    422     const size_t l3_cache_size = 30LL << 20;
    423 
    424     // Use L3 cache size as target working set size.
    425     const size_t target_working_set_size = l3_cache_size / sizeof(T);
    426 
    427     // Calculate size of matrices involved in MatMul: C = A x B.
    428     const size_t size_A = output_image_size * dims.out_depth;
    429 
    430     const size_t size_B = filter_total_size * dims.out_depth;
    431 
    432     const size_t size_C = output_image_size * filter_total_size;
    433 
    434     const size_t work_unit_size = size_A + size_B + size_C;
    435 
    436     auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    437 
    438     // Calculate per-thread work unit size.
    439     const size_t thread_work_unit_size =
    440         work_unit_size / worker_threads.num_threads;
    441 
    442     // Set minimum per-thread work unit size to size of L2 cache.
    443     const size_t min_thread_work_unit_size = l2_cache_size / sizeof(T);
    444 
    445     // Use parallel tensor contractions if there is no batching, or if the
    446     // minimum per-thread work unit size threshold has been exceeded.
    447     // Otherwise, revert to multiple single-threaded matmul ops running in
    448     // parallel to keep all threads busy.
    449     // TODO(andydavis) Explore alternatives to branching the code in this way
    450     // (i.e. run multiple, parallel tensor contractions in another thread pool).
    451     const bool use_parallel_contraction =
    452         dims.batch_size == 1 ||
    453         thread_work_unit_size >= min_thread_work_unit_size;
    454 
    455     const size_t shard_size =
    456         use_parallel_contraction
    457             ? 1
    458             : (target_working_set_size + work_unit_size - 1) / work_unit_size;
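    // Illustrative sizing example (assumed numbers, not from the original
    // source): for a 3x3 filter with in_depth = 64, out_depth = 128 and a
    // 32x32 output, filter_total_size = 576 and output_image_size = 1024, so
    // size_A = 131072, size_B = 73728, size_C = 589824 and
    // work_unit_size = 794624 elements.  With T = float,
    // target_working_set_size = (30 << 20) / 4 = 7864320 and
    // min_thread_work_unit_size = (256 << 10) / 4 = 65536 elements.  With 16
    // worker threads, thread_work_unit_size = 794624 / 16 = 49664 < 65536, so
    // for batch_size > 1 the sharded matmul path below is taken with
    // shard_size = ceil(7864320 / 794624) = 10 images per shard.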
    459 
    460     Tensor col_buffer;
    461     OP_REQUIRES_OK(context,
    462                    context->allocate_temp(
    463                        DataTypeToEnum<T>::value,
    464                        TensorShape({static_cast<int64>(shard_size),
    465                                     static_cast<int64>(output_image_size),
    466                                     static_cast<int64>(filter_total_size)}),
    467                        &col_buffer));
    468 
    469     // The input offset corresponding to a single input image.
    470     const int input_offset = dims.spatial_dims[0].input_size *
    471                              dims.spatial_dims[1].input_size * dims.in_depth;
    472     // The output offset corresponding to a single output image.
    473     const int output_offset = dims.spatial_dims[0].output_size *
    474                               dims.spatial_dims[1].output_size * dims.out_depth;
    475 
    476     const T* filter_data = filter.template flat<T>().data();
    477     T* col_buffer_data = col_buffer.template flat<T>().data();
    478     const T* out_backprop_data = out_backprop.template flat<T>().data();
    479 
    480     auto in_backprop_flat = in_backprop->template flat<T>();
    481     T* input_backprop_data = in_backprop_flat.data();
    482     in_backprop_flat.device(context->eigen_device<Device>()) =
    483         in_backprop_flat.constant(T(0));
    484 
    485     if (use_parallel_contraction) {
    486       typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
    487                                Eigen::Unaligned>
    488           TensorMap;
    489       typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
    490                                Eigen::Unaligned>
    491           ConstTensorMap;
    492 
    493       // Initialize contraction dims (we need to transpose 'B' below).
    494       Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
    495       contract_dims[0].first = 1;
    496       contract_dims[0].second = 1;
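      // With contract_dims[0] = {1, 1}, A (output_image_size x out_depth) is
      // contracted with B (filter_total_size x out_depth) along their
      // out_depth dimensions, i.e. C = A * B^T, which matches the explicit
      // "A * B.transpose()" in the non-parallel branch below.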
    497 
    498       for (int image_id = 0; image_id < dims.batch_size; ++image_id) {
    499         // Compute gradient into col_buffer.
    500         TensorMap C(col_buffer_data, output_image_size, filter_total_size);
    501 
    502         ConstTensorMap A(out_backprop_data + output_offset * image_id,
    503                          output_image_size, dims.out_depth);
    504         ConstTensorMap B(filter_data, filter_total_size, dims.out_depth);
    505 
    506         C.device(context->eigen_cpu_device()) = A.contract(B, contract_dims);
    507 
    508         Col2im<T>(
    509             col_buffer_data, dims.in_depth, dims.spatial_dims[0].input_size,
    510             dims.spatial_dims[1].input_size, dims.spatial_dims[0].filter_size,
    511             dims.spatial_dims[1].filter_size, pad_top, pad_left, pad_bottom,
    512             pad_right, dims.spatial_dims[0].stride, dims.spatial_dims[1].stride,
    513             input_backprop_data);
    514 
    515         input_backprop_data += input_offset;
    516       }
    517     } else {
    518       typedef Eigen::Map<
    519           Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
    520           MatrixMap;
    521       typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic,
    522                                              Eigen::RowMajor>>
    523           ConstMatrixMap;
    524 
    525       for (int image_id = 0; image_id < dims.batch_size;
    526            image_id += shard_size) {
    527         const int shard_limit =
    528             std::min(static_cast<int>(shard_size),
    529                      static_cast<int>(dims.batch_size) - image_id);
    530 
    531         auto shard = [&dims, &pad_top, &pad_left, &pad_bottom, &pad_right,
    532                       &output_image_size, &filter_total_size,
    533                       &input_backprop_data, &col_buffer_data,
    534                       &out_backprop_data, &filter_data, &input_offset,
    535                       &output_offset, &size_C](int64 start, int64 limit) {
    536           for (int shard_id = start; shard_id < limit; ++shard_id) {
    537             T* im2col_buf = col_buffer_data + shard_id * size_C;
    538             T* input_data = input_backprop_data + shard_id * input_offset;
    539             const T* out_data = out_backprop_data + shard_id * output_offset;
    540 
    541             // Compute gradient into 'im2col_buf'.
    542             MatrixMap C(im2col_buf, output_image_size, filter_total_size);
    543 
    544             ConstMatrixMap A(out_data, output_image_size, dims.out_depth);
    545             ConstMatrixMap B(filter_data, filter_total_size, dims.out_depth);
    546 
    547             C.noalias() = A * B.transpose();
    548 
    549             Col2im<T>(im2col_buf, dims.in_depth,
    550                       dims.spatial_dims[0].input_size,
    551                       dims.spatial_dims[1].input_size,
    552                       dims.spatial_dims[0].filter_size,
    553                       dims.spatial_dims[1].filter_size, pad_top, pad_left,
    554                       pad_bottom, pad_right, dims.spatial_dims[0].stride,
    555                       dims.spatial_dims[1].stride, input_data);
    556           }
    557         };
    558         Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
    559               work_unit_size, shard);
    560 
    561         input_backprop_data += input_offset * shard_limit;
    562         out_backprop_data += output_offset * shard_limit;
    563       }
    564     }
    565   }
    566 
    567  private:
    568   std::vector<int32> dilations_;
    569   std::vector<int32> strides_;
    570   Padding padding_;
    571   TensorFormat data_format_;
    572 
    573   TF_DISALLOW_COPY_AND_ASSIGN(Conv2DCustomBackpropInputOp);
    574 };
    575 
    576 #define REGISTER_CPU_KERNELS(T)                                              \
    577   REGISTER_KERNEL_BUILDER(                                                   \
    578       Name("Conv2DBackpropInput").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
    579       Conv2DCustomBackpropInputOp<CPUDevice, T>);                            \
    580   REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")                        \
    581                               .Device(DEVICE_CPU)                            \
    582                               .Label("custom")                               \
    583                               .TypeConstraint<T>("T"),                       \
    584                           Conv2DCustomBackpropInputOp<CPUDevice, T>);        \
    585   REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")                        \
    586                               .Device(DEVICE_CPU)                            \
    587                               .Label("eigen_tensor")                         \
    588                               .TypeConstraint<T>("T"),                       \
    589                           Conv2DFastBackpropInputOp<CPUDevice, T>);
    590 
    591 TF_CALL_half(REGISTER_CPU_KERNELS);
    592 TF_CALL_float(REGISTER_CPU_KERNELS);
    593 #undef REGISTER_CPU_KERNELS
    594 
    595 // GPU definitions.
    596 #if GOOGLE_CUDA
    597 // The slow version (but compiles for GPU)
    598 
     599 // A dummy type to group backward-data convolution autotune results together.
    600 struct ConvBackwardDataAutoTuneGroup {
    601   static string name() { return "ConvBwdData"; }
    602 };
    603 typedef AutoTuneSingleton<ConvBackwardDataAutoTuneGroup, ConvParameters,
    604                           perftools::gputools::dnn::AlgorithmConfig>
    605     AutoTuneConvBwdData;
    606 
    607 // Backprop for input.
    608 template <typename Device, class T>
    609 class Conv2DSlowBackpropInputOp : public OpKernel {
    610  public:
    611   explicit Conv2DSlowBackpropInputOp(OpKernelConstruction* context)
    612       : OpKernel(context) {
    613     string data_format;
    614     OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    615     OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
    616                 errors::InvalidArgument("Invalid data format"));
    617     OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    618     OP_REQUIRES(context, strides_.size() == 4,
    619                 errors::InvalidArgument("Sliding window strides field must "
    620                                         "specify 4 dimensions"));
    621     int stride_n = GetTensorDim(strides_, data_format_, 'N');
    622     int stride_c = GetTensorDim(strides_, data_format_, 'C');
    623     int stride_h = GetTensorDim(strides_, data_format_, 'H');
    624     int stride_w = GetTensorDim(strides_, data_format_, 'W');
    625     OP_REQUIRES(
    626         context, (stride_n == 1 && stride_c == 1),
    627         errors::InvalidArgument("Current implementation does not yet support "
    628                                 "strides in the batch and depth dimensions."));
    629     OP_REQUIRES(context, stride_h > 0 && stride_w > 0,
    630                 errors::InvalidArgument(
    631                     "Row and column strides should be larger than 0."));
    632     OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
    633     OP_REQUIRES(context, dilations_.size() == 4,
    634                 errors::InvalidArgument("Sliding window dilations field must "
    635                                         "specify 4 dimensions"));
    636     int dilation_n = GetTensorDim(dilations_, data_format_, 'N');
    637     int dilation_c = GetTensorDim(dilations_, data_format_, 'C');
    638     int dilation_h = GetTensorDim(dilations_, data_format_, 'H');
    639     int dilation_w = GetTensorDim(dilations_, data_format_, 'W');
    640     OP_REQUIRES(context, (dilation_n == 1 && dilation_c == 1),
    641                 errors::InvalidArgument(
    642                     "Current implementation does not yet support "
    643                     "dilations in the batch and depth dimensions."));
    644     OP_REQUIRES(
    645         context, dilation_h > 0 && dilation_w > 0,
     646         errors::InvalidArgument("Dilation rates should be larger than 0."));
    647     OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
    648     use_cudnn_ &= CanUseCudnn();
    649     cudnn_use_autotune_ = CudnnUseAutotune();
    650     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    651   }
    652 
    653   void Compute(OpKernelContext* context) override {
    654     const Tensor& input_sizes = context->input(0);
    655     const Tensor& filter = context->input(1);
    656     const Tensor& out_backprop = context->input(2);
    657     OP_REQUIRES(
    658         context, TensorShapeUtils::IsVector(input_sizes.shape()),
    659         errors::InvalidArgument(
    660             "Conv2DBackpropInput: input_sizes input must be 1-dim, not ",
    661             input_sizes.dims()));
    662     TensorShape input_shape;
    663     OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
    664                                 input_sizes.vec<int32>(), &input_shape));
    665 
    666     Tensor* in_backprop = nullptr;
    667     OP_REQUIRES_OK(context,
    668                    context->allocate_output(0, input_shape, &in_backprop));
    669 
    670     // If there is nothing to compute, return.
    671     if (input_shape.num_elements() == 0) {
    672       return;
    673     }
    674 
    675     // For now we take the stride from the second and third dimensions only (we
    676     // do not support striding on the batch or depth dimension).
    677     const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
    678     const int stride_cols = GetTensorDim(strides_, data_format_, 'W');
    679     const int dilation_rows = GetTensorDim(dilations_, data_format_, 'H');
    680     const int dilation_cols = GetTensorDim(dilations_, data_format_, 'W');
    681 
    682     launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop, filter,
    683               dilation_rows, dilation_cols, stride_rows, stride_cols, padding_,
    684               in_backprop, data_format_);
    685   }
    686 
    687  private:
    688   std::vector<int32> dilations_;
    689   std::vector<int32> strides_;
    690   Padding padding_;
    691   bool use_cudnn_;
    692   TensorFormat data_format_;
    693   LaunchConv2DBackpropInputOp<Device, T> launcher_;
    694   bool cudnn_use_autotune_;
    695 
    696   TF_DISALLOW_COPY_AND_ASSIGN(Conv2DSlowBackpropInputOp);
    697 };
    698 
    699 template <typename T>
    700 void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
    701     OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
    702     const Tensor& out_backprop, const Tensor& filter, int row_dilation,
    703     int col_dilation, int row_stride, int col_stride, const Padding& padding,
    704     Tensor* in_backprop, TensorFormat data_format) {
    705   using perftools::gputools::dnn::AlgorithmConfig;
    706   using perftools::gputools::dnn::AlgorithmDesc;
    707   using perftools::gputools::dnn::ProfileResult;
    708 
    709   std::vector<int32> strides(4, 1);
    710   std::vector<int32> dilations(4, 1);
    711   auto input_h = GetTensorDimIndex(data_format, 'H');
    712   auto input_w = GetTensorDimIndex(data_format, 'W');
    713   strides[input_h] = row_stride;
    714   strides[input_w] = col_stride;
    715   dilations[input_h] = row_dilation;
    716   dilations[input_w] = col_dilation;
    717   TensorShape input_shape = in_backprop->shape();
    718 
    719   const TensorShape& filter_shape = filter.shape();
    720   ConvBackpropDimensions dims;
    721   OP_REQUIRES_OK(ctx, ConvBackpropComputeDimensionsV2(
    722                           "Conv2DSlowBackpropInput", /*num_spatial_dims=*/2,
    723                           input_shape, filter_shape, out_backprop.shape(),
    724                           dilations, strides, padding, data_format, &dims));
    725 
    726   // TODO(yangzihao): The padding computations should be done in
    727   // GetWindowedOutputSize() functions.
    728   const int padding_rows =
    729       (padding == VALID)
    730           ? 0
    731           : std::max<int>(0, (dims.spatial_dims[0].output_size - 1) *
    732                                      dims.spatial_dims[0].stride +
    733                                  (dims.spatial_dims[0].filter_size - 1) *
    734                                      dims.spatial_dims[0].dilation +
    735                                  1 - dims.spatial_dims[0].input_size);
    736   const int padding_cols =
    737       (padding == VALID)
    738           ? 0
    739           : std::max<int>(0, (dims.spatial_dims[1].output_size - 1) *
    740                                      dims.spatial_dims[1].stride +
    741                                  (dims.spatial_dims[1].filter_size - 1) *
    742                                      dims.spatial_dims[1].dilation +
    743                                  1 - dims.spatial_dims[1].input_size);
    744 
     745   // TODO(keveman): cuDNN only supports equal padding on both sides, so we only
     746   // call it when that is true. Remove this check when (if?) cuDNN starts
    747   // supporting different padding.
    748   bool rows_odd = (padding_rows % 2 != 0);
    749   bool cols_odd = (padding_cols % 2 != 0);
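  // Illustrative example (assumed numbers, not from the original source): for
  // SAME padding with input_size = 224, filter_size = 7, stride = 2 and
  // dilation = 1, output_size = 112 and
  // padding_rows = (112 - 1) * 2 + (7 - 1) * 1 + 1 - 224 = 5, which is odd.
  // cuDNN is then given padding_rows / 2 = 2 on each side and the leftover row
  // is absorbed by padding the input shape to a compatible size further below.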
    750 
    751   auto* stream = ctx->op_device_context()->stream();
    752   OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
    753 
    754   if (!use_cudnn) {
    755     ctx->SetStatus(errors::Unimplemented(
    756         "Conv2DBackpropInput for GPU is not currently supported "
    757         "without cudnn"));
    758     return;
    759   }
    760 
    761   if (dims.spatial_dims[0].filter_size == 1 &&
    762       dims.spatial_dims[1].filter_size == 1 &&
    763       dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 &&
    764       data_format == FORMAT_NHWC) {
    765     // 1x1 filter, so call cublas directly.
    766     const uint64 m = dims.batch_size * dims.spatial_dims[0].input_size *
    767                      dims.spatial_dims[1].input_size;
    768     const uint64 k = dims.out_depth;
    769     const uint64 n = dims.in_depth;
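    // For a 1x1, stride-1 filter in NHWC, every output pixel is a dense
    // in_depth -> out_depth product, so the input gradient reduces to a single
    // GEMM: in_backprop[m x n] = out_backprop[m x k] * filter^T, where the
    // filter is viewed as an [n x k] matrix and m, k, n are defined above.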
    770 
    771     auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
    772                                 out_backprop.template flat<T>().size());
    773     auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
    774                                 filter.template flat<T>().size());
    775     auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
    776                                 in_backprop->template flat<T>().size());
    777 
    778     auto transpose = perftools::gputools::blas::Transpose::kTranspose;
    779     auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;
    780 
    781     bool blas_launch_status =
    782         stream
    783             ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
    784                            a_ptr, k, 0.0f, &c_ptr, n)
    785             .ok();
    786     if (!blas_launch_status) {
    787       ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
    788                                       ", n=", n, ", k=", k));
    789     }
    790     return;
    791   } else if (dims.spatial_dims[0].filter_size ==
    792                  dims.spatial_dims[0].input_size &&
    793              dims.spatial_dims[1].filter_size ==
    794                  dims.spatial_dims[1].input_size &&
    795              padding == VALID && data_format == FORMAT_NHWC) {
    796     // The input data and filter have the same height/width, so call cublas
    797     // directly.
    798     const uint64 m = dims.batch_size;
    799     const uint64 k = dims.out_depth;
    800     const uint64 n = dims.spatial_dims[0].input_size *
    801                      dims.spatial_dims[1].input_size * dims.in_depth;
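    // Here the filter covers the entire input with VALID padding, so the
    // output is spatially 1x1 and the gradient is again a single GEMM over
    // flattened images: in_backprop[m x n] = out_backprop[m x k] * filter^T,
    // with the filter viewed as an [n x k] matrix and m, k, n defined above.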
    802 
    803     auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
    804                                 out_backprop.template flat<T>().size());
    805     auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
    806                                 filter.template flat<T>().size());
    807     auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
    808                                 in_backprop->template flat<T>().size());
    809 
    810     auto transpose = perftools::gputools::blas::Transpose::kTranspose;
    811     auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;
    812 
    813     bool blas_launch_status =
    814         stream
    815             ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
    816                            a_ptr, k, 0.0f, &c_ptr, n)
    817             .ok();
    818     if (!blas_launch_status) {
    819       ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
    820                                       ", n=", n, ", k=", k));
    821     }
    822     return;
    823   }
    824 
    825   TensorShape compatible_input_shape;
    826   if (rows_odd || cols_odd) {
     827     // If the total padding along a dimension is odd, there is one extra padded
     828     // element on the right or bottom side. cuDNN does not support asymmetric
     829     // padding, so we enlarge the input by that extra element to stay compatible.
    830     compatible_input_shape = ShapeFromFormat(
    831         data_format, dims.batch_size,
    832         dims.spatial_dims[0].input_size + rows_odd,
    833         dims.spatial_dims[1].input_size + cols_odd, dims.in_depth);
    834   } else {
    835     compatible_input_shape = input_shape;
    836   }
    837 
    838   CHECK(padding_rows >= 0 && padding_cols >= 0)
    839       << "Negative row or col paddings: (" << padding_rows << ", "
    840       << padding_cols << ")";
    841   perftools::gputools::dnn::BatchDescriptor input_desc;
    842   input_desc.set_count(dims.batch_size)
    843       .set_height(GetTensorDim(compatible_input_shape, data_format, 'H'))
    844       .set_width(GetTensorDim(compatible_input_shape, data_format, 'W'))
    845       .set_feature_map_count(dims.in_depth)
    846       .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
    847   perftools::gputools::dnn::BatchDescriptor output_desc;
    848   output_desc.set_count(dims.batch_size)
    849       .set_height(dims.spatial_dims[0].output_size)
    850       .set_width(dims.spatial_dims[1].output_size)
    851       .set_feature_map_count(dims.out_depth)
    852       .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
    853   perftools::gputools::dnn::FilterDescriptor filter_desc;
    854   filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size)
    855       .set_input_filter_width(dims.spatial_dims[1].filter_size)
    856       .set_input_feature_map_count(dims.in_depth)
    857       .set_output_feature_map_count(dims.out_depth);
    858   perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
    859   conv_desc.set_vertical_dilation_rate(dims.spatial_dims[0].dilation)
    860       .set_horizontal_dilation_rate(dims.spatial_dims[1].dilation)
    861       .set_vertical_filter_stride(dims.spatial_dims[0].stride)
    862       .set_horizontal_filter_stride(dims.spatial_dims[1].stride)
    863       .set_zero_padding_height(padding_rows / 2)
    864       .set_zero_padding_width(padding_cols / 2);
    865 
    866   // NOTE(keveman):
    867   // cuDNN only supports the following layouts :
    868   // Input  : B x D x R x C
    869   // Filter : OD x ID x R x C
    870   // Whereas, we have
    871   // Input  : B x R x C x D
    872   // Filter : R x C x ID x OD
    873   // TransformFilter performs (R x C x ID x OD) => (OD x ID x R x C)
     874   // The NHWCToNCHW transform below performs
     875   // (B x R x C x D) => (B x D x R x C).
     876   // Since the tensor returned from cuDNN is B x D x R x C also,
     877   // the final NCHWToNHWC transform performs
     878   // (B x D x R x C) => (B x R x C x D).
    879   Tensor transformed_filter;
    880   OP_REQUIRES_OK(
    881       ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
    882                               TensorShape({dims.out_depth, dims.in_depth,
    883                                            dims.spatial_dims[0].filter_size,
    884                                            dims.spatial_dims[1].filter_size}),
    885                               &transformed_filter));
    886 
    887   functor::TransformFilter<GPUDevice, T, int, 4>()(
    888       ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()),
    889       To32Bit(transformed_filter.tensor<T, 4>()));
    890 
    891   Tensor transformed_out_backprop;
    892   if (data_format == FORMAT_NHWC) {
    893     TensorShape nchw_shape = ShapeFromFormat(
    894         FORMAT_NCHW, dims.batch_size, dims.spatial_dims[0].output_size,
    895         dims.spatial_dims[1].output_size, dims.out_depth);
    896     if (dims.out_depth > 1) {
    897       OP_REQUIRES_OK(ctx,
    898                      ctx->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
    899                                         &transformed_out_backprop));
    900       functor::NHWCToNCHW<GPUDevice, T, 4>()(
    901           ctx->eigen_device<GPUDevice>(), out_backprop.tensor<T, 4>(),
    902           transformed_out_backprop.tensor<T, 4>());
    903     } else {
    904       // If depth <= 1, then just reshape.
    905       CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
    906     }
    907   } else {
    908     transformed_out_backprop = out_backprop;
    909   }
    910 
    911   Tensor pre_transformed_in_backprop;
    912   OP_REQUIRES_OK(
    913       ctx, ctx->allocate_temp(
    914                DataTypeToEnum<T>::value,
    915                ShapeFromFormat(
    916                    FORMAT_NCHW,
    917                    GetTensorDim(compatible_input_shape, data_format, 'N'),
    918                    GetTensorDim(compatible_input_shape, data_format, 'H'),
    919                    GetTensorDim(compatible_input_shape, data_format, 'W'),
    920                    GetTensorDim(compatible_input_shape, data_format, 'C')),
    921                &pre_transformed_in_backprop));
    922 
    923   auto out_backprop_ptr =
    924       AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
    925                      transformed_out_backprop.template flat<T>().size());
    926   auto filter_ptr =
    927       AsDeviceMemory(transformed_filter.template flat<T>().data(),
    928                      transformed_filter.template flat<T>().size());
    929   auto in_backprop_ptr =
    930       AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(),
    931                      pre_transformed_in_backprop.template flat<T>().size());
    932 
    933   static int64 ConvolveBackwardDataScratchSize = GetCudnnWorkspaceLimit(
    934       "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB by default
    935   );
    936   CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize, ctx);
    937   int device_id = stream->parent()->device_ordinal();
    938   DataType dtype = out_backprop.dtype();
    939   ConvParameters conv_parameters = {
    940       dims.batch_size,                       // batch
    941       dims.in_depth,                         // in_depths
    942       {{input_desc.height(),                 // in_rows
    943         input_desc.width()}},                // in_cols
    944       dims.out_depth,                        // out_depths
    945       {{dims.spatial_dims[0].filter_size,    // filter_rows
    946         dims.spatial_dims[1].filter_size}},  // filter_cols
    947       {{dims.spatial_dims[0].dilation,       // dilation_rows
    948         dims.spatial_dims[1].dilation}},     // dilation_cols
    949       {{dims.spatial_dims[0].stride,         // stride_rows
    950         dims.spatial_dims[1].stride}},       // stride_cols
    951       {{padding_rows,                        // padding_rows
    952         padding_cols}},                      // padding_cols
    953       dtype,                                 // tensor data type
    954       device_id,                             // device_id
    955   };
    956   AlgorithmConfig algorithm_config;
    957   if (cudnn_use_autotune && !AutoTuneConvBwdData::GetInstance()->Find(
    958                                 conv_parameters, &algorithm_config)) {
    959     std::vector<AlgorithmDesc> algorithms;
    960     CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
    961         conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
    962     ProfileResult best_result;
    963     ProfileResult best_result_no_scratch;
    964     for (auto profile_algorithm : algorithms) {
     965       // TODO(zhengxq): profile each algorithm multiple times to get better
     966       // accuracy.
    967       CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
    968                                               ctx);
    969       ProfileResult profile_result;
    970       bool cudnn_launch_status =
    971           stream
    972               ->ThenConvolveBackwardDataWithAlgorithm(
    973                   filter_desc, filter_ptr, output_desc, out_backprop_ptr,
    974                   conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
    975                   AlgorithmConfig(profile_algorithm), &profile_result)
    976               .ok();
    977       if (cudnn_launch_status) {
    978         if (profile_result.is_valid()) {
    979           if (profile_result.elapsed_time_in_ms() <
    980               best_result.elapsed_time_in_ms()) {
    981             best_result = profile_result;
    982           }
    983           if (scratch_allocator.TotalByteSize() == 0 &&
    984               profile_result.elapsed_time_in_ms() <
    985                   best_result_no_scratch.elapsed_time_in_ms()) {
    986             best_result_no_scratch = profile_result;
    987           }
    988         }
    989       }
    990     }
    991     OP_REQUIRES(ctx,
    992                 best_result.is_valid() || best_result_no_scratch.is_valid(),
    993                 errors::NotFound("No algorithm worked!"));
    994     if (best_result.is_valid()) {
    995       algorithm_config.set_algorithm(best_result.algorithm());
    996     }
    997     if (best_result_no_scratch.is_valid()) {
    998       algorithm_config.set_algorithm_no_scratch(
    999           best_result_no_scratch.algorithm());
   1000     }
   1001     AutoTuneConvBwdData::GetInstance()->Insert(conv_parameters,
   1002                                                algorithm_config);
   1003   }
   1004   bool cudnn_launch_status =
   1005       stream
   1006           ->ThenConvolveBackwardDataWithAlgorithm(
   1007               filter_desc, filter_ptr, output_desc, out_backprop_ptr, conv_desc,
   1008               input_desc, &in_backprop_ptr, &scratch_allocator,
   1009               algorithm_config, nullptr)
   1010           .ok();
   1011 
   1012   if (!cudnn_launch_status) {
   1013     ctx->SetStatus(errors::Internal(
   1014         "cuDNN Backward Data function launch failure : input shape(",
   1015         input_shape.DebugString(), ") filter shape(",
   1016         filter_shape.DebugString(), ")"));
   1017     return;
   1018   }
   1019 
   1020   if (rows_odd || cols_odd) {
   1021     Tensor in_backprop_remove_padding;
   1022     OP_REQUIRES_OK(
   1023         ctx, ctx->allocate_temp(
   1024                  DataTypeToEnum<T>::value,
   1025                  ShapeFromFormat(FORMAT_NCHW,
   1026                                  GetTensorDim(input_shape, data_format, 'N'),
   1027                                  GetTensorDim(input_shape, data_format, 'H'),
   1028                                  GetTensorDim(input_shape, data_format, 'W'),
   1029                                  GetTensorDim(input_shape, data_format, 'C')),
   1030                  &in_backprop_remove_padding));
   1031 
   1032     // Remove the padding for odd rows or cols.
   1033     functor::PadInput<GPUDevice, T, int, 4>()(
   1034         ctx->template eigen_device<GPUDevice>(),
   1035         To32Bit(const_cast<const Tensor&>(pre_transformed_in_backprop)
   1036                     .tensor<T, 4>()),
   1037         {{0, 0}}, {{-rows_odd, -cols_odd}},
   1038         To32Bit(in_backprop_remove_padding.tensor<T, 4>()), FORMAT_NCHW);
   1039 
   1040     pre_transformed_in_backprop = in_backprop_remove_padding;
   1041   }
   1042 
   1043   if (data_format == FORMAT_NHWC) {
   1044     auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
   1045     functor::NCHWToNHWC<GPUDevice, T, 4>()(
   1046         ctx->eigen_device<GPUDevice>(),
   1047         toConstTensor(pre_transformed_in_backprop).template tensor<T, 4>(),
   1048         in_backprop->tensor<T, 4>());
   1049   } else {
   1050     *in_backprop = pre_transformed_in_backprop;
   1051   }
   1052 }
   1053 
   1054 // Forward declarations of the functor specializations for GPU.
   1055 namespace functor {
   1056 #define DECLARE_GPU_SPEC(T)                                              \
   1057   template <>                                                            \
   1058   void ShuffleAndReverse<GPUDevice, T, 4, int>::operator()(              \
   1059       const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor input, \
   1060       const Eigen::DSizes<int, 4>& order,                                \
   1061       const Eigen::array<bool, 4>& reverse_dims,                         \
   1062       typename TTypes<T, 4, int>::Tensor output);                        \
   1063   extern template struct ShuffleAndReverse<GPUDevice, T, 4, int>;        \
   1064   template <>                                                            \
   1065   void InflatePadAndShuffle<GPUDevice, T, 4, int>::operator()(           \
   1066       const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor input, \
   1067       const Eigen::DSizes<int, 4>& strides,                              \
   1068       const Eigen::array<Eigen::IndexPair<int>, 4>& pad_dims,            \
   1069       const Eigen::DSizes<int, 4>& order,                                \
   1070       typename TTypes<T, 4, int>::Tensor output);                        \
   1071   extern template struct InflatePadAndShuffle<GPUDevice, T, 4, int>;     \
   1072   template <>                                                            \
   1073   void TransformFilter<GPUDevice, T, int, 4>::operator()(                \
   1074       const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,    \
   1075       typename TTypes<T, 4, int>::Tensor out);                           \
   1076   extern template struct TransformFilter<GPUDevice, T, int, 4>;          \
   1077   template <>                                                            \
   1078   void TransformDepth<GPUDevice, T, int>::operator()(                    \
   1079       const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,    \
   1080       const Eigen::DSizes<int, 4>& shuffle,                              \
   1081       typename TTypes<T, 4, int>::Tensor out);                           \
   1082   extern template struct TransformDepth<GPUDevice, T, int>;              \
   1083   template <>                                                            \
   1084   void PadInput<GPUDevice, T, int, 4>::operator()(                       \
   1085       const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,    \
   1086       const std::array<int, 2>& padding_left,                            \
   1087       const std::array<int, 2>& padding_right,                           \
   1088       typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
   1089   extern template struct PadInput<GPUDevice, T, int, 4>;
   1090 
   1091 DECLARE_GPU_SPEC(float);
   1092 DECLARE_GPU_SPEC(Eigen::half);
   1093 #undef DECLARE_GPU_SPEC
   1094 }  // namespace functor
   1095 
   1096 REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
   1097                             .Device(DEVICE_GPU)
   1098                             .TypeConstraint<float>("T")
   1099                             .HostMemory("input_sizes"),
   1100                         Conv2DSlowBackpropInputOp<GPUDevice, float>);
   1101 REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
   1102                             .Device(DEVICE_GPU)
   1103                             .TypeConstraint<Eigen::half>("T")
   1104                             .HostMemory("input_sizes"),
   1105                         Conv2DSlowBackpropInputOp<GPUDevice, Eigen::half>);
   1106 #endif  // GOOGLE_CUDA
   1107 
   1108 }  // namespace tensorflow
   1109