      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/nn_ops.cc.
     17 
     18 #define USE_EIGEN_TENSOR
     19 #define EIGEN_USE_THREADS
     20 
     21 #include "tensorflow/core/kernels/conv_grad_ops.h"
     22 
     23 #include <algorithm>
     24 #include <vector>
     25 
     26 #include "tensorflow/core/framework/numeric_op.h"
     27 #include "tensorflow/core/framework/op_kernel.h"
     28 #include "tensorflow/core/framework/register_types.h"
     29 #include "tensorflow/core/framework/tensor.h"
     30 #include "tensorflow/core/framework/tensor_shape.h"
     31 #include "tensorflow/core/framework/tensor_slice.h"
     32 #include "tensorflow/core/kernels/conv_2d.h"
     33 #include "tensorflow/core/kernels/fill_functor.h"
     34 #ifdef TENSORFLOW_USE_LIBXSMM
     35 #include "tensorflow/core/kernels/xsmm_conv2d.h"
     36 #endif
     37 #include "tensorflow/core/kernels/ops_util.h"
     38 #include "tensorflow/core/lib/core/errors.h"
     39 #include "tensorflow/core/lib/gtl/array_slice.h"
     40 #include "tensorflow/core/platform/logging.h"
     41 #include "tensorflow/core/platform/macros.h"
     42 #include "tensorflow/core/util/padding.h"
     43 #include "tensorflow/core/util/tensor_format.h"
     44 #include "tensorflow/core/util/use_cudnn.h"
     45 #include "tensorflow/core/util/work_sharder.h"
     46 
     47 #if GOOGLE_CUDA
     48 #include "tensorflow/core/kernels/conv_ops_gpu.h"
     49 #include "tensorflow/core/platform/stream_executor.h"
     50 #endif  // GOOGLE_CUDA
     51 
     52 namespace {
     53 
      54 // Returns in 'col_data' the image patches, in storage order (height, width,
      55 // depth), extracted from a single image at 'input_data', which is required
      56 // to be in storage order (height, width, depth).
      57 // Implementation written by Yangqing Jia (jiayq).
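         // For example (illustrative numbers only): with height = width = 5,
         // depth = 2, a 3x3 filter, stride 1, and one pixel of padding on each
         // side, height_col = width_col = 5, so 'col_data' must provide room for
         // height_col * width_col * filter_h * filter_w * depth = 450 elements,
         // one contiguous (filter_h, filter_w, depth) patch per output position.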
     58 template <typename T>
     59 void Im2col(const T* input_data, const int depth, const int height,
     60             const int width, const int filter_h, const int filter_w,
     61             const int pad_t, const int pad_l, const int pad_b, const int pad_r,
     62             const int stride_h, const int stride_w, T* col_data) {
     63   int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
     64   int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
     65 
     66   int h_pad = -pad_t;
     67   for (int h = 0; h < height_col; ++h) {
     68     int w_pad = -pad_l;
     69     for (int w = 0; w < width_col; ++w) {
     70       for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
     71         for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
     72           if (ih >= 0 && ih < height && iw >= 0 && iw < width) {
     73             memcpy(col_data, input_data + (ih * width + iw) * depth,
     74                    sizeof(T) * depth);
     75           } else {
      76             // Out-of-bounds input locations are simply padded with zeros.
     77             memset(col_data, 0, sizeof(T) * depth);
     78           }
     79           col_data += depth;
     80         }
     81       }
     82       w_pad += stride_w;
     83     }
     84     h_pad += stride_h;
     85   }
     86 }
     87 
     88 }  // namespace
     89 
     90 namespace tensorflow {
     91 
     92 typedef Eigen::ThreadPoolDevice CPUDevice;
     93 typedef Eigen::GpuDevice GPUDevice;
     94 
     95 template <typename T>
     96 struct LaunchConv2DBackpropFilterOp<CPUDevice, T> {
     97   void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
     98                   const Tensor& out_backprop, const Tensor& input,
     99                   int row_stride, int col_stride, const Padding& padding,
    100                   Tensor* filter_backprop, TensorFormat data_format) {
    101     const CPUDevice& d = ctx->eigen_device<CPUDevice>();
    102     functor::SpatialConvolutionBackwardFilter<CPUDevice, T>()(
    103         d, filter_backprop->tensor<T, 4>(), input.tensor<T, 4>(),
    104         out_backprop.tensor<T, 4>(), row_stride, col_stride,
    105         /*row_dilation=*/1, /*col_dilation=*/1);
    106   }
    107 };
    108 
    109 #ifdef TENSORFLOW_USE_LIBXSMM
    110 template <typename Device, class T>
    111 struct LaunchXsmmBackwardFilter {
    112   bool operator()(OpKernelContext* context, const Device& d,
    113                   typename TTypes<T, 4>::ConstTensor input_backward,
    114                   typename TTypes<T, 4>::Tensor kernel,
    115                   typename TTypes<T, 4>::ConstTensor output_backward,
    116                   int input_rows, int input_cols, int row_stride,
    117                   int col_stride, int pad_h, int pad_w,
    118                   TensorFormat data_format) const {
    119     return false;
    120   }
    121 };
    122 
    123 template <>
    124 struct LaunchXsmmBackwardFilter<CPUDevice, float> {
    125   bool operator()(OpKernelContext* context, const CPUDevice& d,
    126                   typename TTypes<float, 4>::ConstTensor input,
    127                   typename TTypes<float, 4>::Tensor filter,
    128                   typename TTypes<float, 4>::ConstTensor output, int input_rows,
    129                   int input_cols, int row_stride, int col_stride, int pad_h,
    130                   int pad_w, TensorFormat data_format) const {
    131     auto batch = input.dimension(0);
    132     auto in_depth = input.dimension(3);
    133     auto out_depth = output.dimension(3);
    134     auto filter_rows = filter.dimension(0);
    135     auto filter_cols = filter.dimension(1);
    136 
    137     auto num_threads =
    138         context->device()->tensorflow_cpu_worker_threads()->num_threads;
    139     // See libxsmm_dnn.h for this struct definition.
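             // The field names follow the usual convolution shorthand used below:
             // N = batch, C = input channels, H/W = input height/width, K = output
             // channels, R/S = filter height/width, u/v = vertical/horizontal stride.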
    140     libxsmm_dnn_conv_desc desc;
    141     desc.N = batch;
    142     desc.C = in_depth;
    143     desc.H = input_rows;
    144     desc.W = input_cols;
    145     desc.K = out_depth;
    146     desc.R = filter_rows;
    147     desc.S = filter_cols;
    148     desc.u = row_stride;
    149     desc.v = col_stride;
    150     desc.pad_h = pad_h;
    151     desc.pad_w = pad_w;
    152     desc.pad_h_in = 0;  // pad_rows;  // ignored by libxsmm for now.
    153     desc.pad_w_in = 0;  // pad_cols;  // ignored by libxsmm for now.
    154     desc.pad_h_out = 0;
    155     desc.pad_w_out = 0;
    156     desc.threads = num_threads;
    157     desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
    158     desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
    159     desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
    160     desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
    161     desc.options = LIBXSMM_DNN_CONV_OPTION_NONE;
    162     desc.datatype = LIBXSMM_DNN_DATATYPE_F32;
    163 
    164     if (!CanUseXsmmConv2D(desc, data_format)) {
    165       return false;
    166     }
    167 
    168     auto input_ptr = input.data();
    169     auto filter_ptr = filter.data();
    170     auto output_ptr = output.data();
    171     bool success = functor::XsmmBkwFilterConv2D<CPUDevice, float>()(
    172         context, desc, input_ptr, filter_ptr, output_ptr);
    173     return success;
    174   }
    175 };
    176 #endif
    177 
    178 template <typename Device, class T>
    179 class Conv2DFastBackpropFilterOp : public OpKernel {
    180  public:
    181   explicit Conv2DFastBackpropFilterOp(OpKernelConstruction* context)
    182       : OpKernel(context) {
    183     string data_format;
    184     OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    185     OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
    186                 errors::InvalidArgument("Invalid data format"));
    187     OP_REQUIRES(context, data_format_ == FORMAT_NHWC,
    188                 errors::InvalidArgument(
    189                     "Conv2DFastBackpropFilterOp only supports NHWC."));
    190     OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    191     OP_REQUIRES(context, strides_.size() == 4,
    192                 errors::InvalidArgument("Sliding window strides field must "
    193                                         "specify 4 dimensions"));
    194     OP_REQUIRES(
    195         context, (strides_[0] == 1 && strides_[3] == 1),
    196         errors::InvalidArgument("Current implementation does not yet support "
    197                                 "strides in the batch and depth dimensions."));
    198     OP_REQUIRES(context, strides_[1] > 0 && strides_[2] > 0,
    199                 errors::InvalidArgument(
    200                     "Row and column strides should be larger than 0."));
    201     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    202     OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
    203     OP_REQUIRES(context, dilations_.size() == 4,
    204                 errors::InvalidArgument("Sliding window dilations field must "
    205                                         "specify 4 dimensions"));
    206     OP_REQUIRES(context, (dilations_[0] == 1 && dilations_[3] == 1),
    207                 errors::InvalidArgument(
    208                     "Current implementation does not yet support "
    209                     "dilations in the batch and depth dimensions."));
    210     // TODO(yangzihao): Add a CPU implementation for dilated convolution.
    211     OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1),
    212                 errors::InvalidArgument(
    213                     "Current Eigen and libxsmm implementations do not "
    214                     "yet support dilation rates larger than 1."));
    215   }
    216 
    217   void Compute(OpKernelContext* context) override {
    218     const Tensor& input = context->input(0);
    219     const Tensor& filter_sizes = context->input(1);
    220     const Tensor& out_backprop = context->input(2);
    221     OP_REQUIRES(
    222         context, TensorShapeUtils::IsVector(filter_sizes.shape()),
    223         errors::InvalidArgument(
    224             "Conv2DBackpropFilter: filter_sizes input must be 1-dim, not ",
    225             filter_sizes.dims()));
    226     TensorShape filter_shape;
    227     OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
    228                                 filter_sizes.vec<int32>(), &filter_shape));
    229 
    230     ConvBackpropDimensions dims;
    231     OP_REQUIRES_OK(
    232         context,
    233         ConvBackpropComputeDimensions(
    234             type_string(), /*num_spatial_dims=*/2, input.shape(), filter_shape,
    235             out_backprop.shape(), strides_, padding_, data_format_, &dims));
    236 
    237     Tensor* filter_backprop = nullptr;
    238     OP_REQUIRES_OK(context,
    239                    context->allocate_output(0, filter_shape, &filter_backprop));
    240 
    241     // If there is nothing to compute, return.
    242     if (filter_shape.num_elements() == 0) {
    243       return;
    244     }
    245 
    246 #if defined TENSORFLOW_USE_LIBXSMM && defined TENSORFLOW_USE_LIBXSMM_BACKWARD
    247     int64 pad_top, pad_bottom;
    248     int64 pad_left, pad_right;
    249     OP_REQUIRES_OK(
    250         context,
    251         GetWindowedOutputSizeVerbose(
    252             dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
    253             dims.spatial_dims[0].stride, padding_,
    254             &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom));
    255     OP_REQUIRES_OK(
    256         context,
    257         GetWindowedOutputSizeVerbose(
    258             dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
    259             dims.spatial_dims[1].stride, padding_,
    260             &dims.spatial_dims[1].output_size, &pad_left, &pad_right));
    261 
    262     if (pad_left == pad_right && pad_top == pad_bottom) {
    263       if (LaunchXsmmBackwardFilter<Device, T>()(
    264               context, context->eigen_device<Device>(), input.tensor<T, 4>(),
    265               filter_backprop->tensor<T, 4>(), out_backprop.tensor<T, 4>(),
    266               dims.spatial_dims[0].input_size, dims.spatial_dims[1].input_size,
    267               static_cast<int>(dims.spatial_dims[0].stride),
    268               static_cast<int>(dims.spatial_dims[1].stride),
    269               static_cast<int>(pad_top), static_cast<int>(pad_left),
    270               data_format_)) {
    271         return;
    272       }
    273     }
    274 #endif
    275 
    276     LaunchConv2DBackpropFilterOp<Device, T>()(
    277         context, false, false, out_backprop, input, dims.spatial_dims[0].stride,
    278         dims.spatial_dims[1].stride, padding_, filter_backprop, data_format_);
    279   }
    280 
    281  private:
    282   std::vector<int32> dilations_;
    283   std::vector<int32> strides_;
    284   Padding padding_;
    285   TensorFormat data_format_;
    286 
    287   TF_DISALLOW_COPY_AND_ASSIGN(Conv2DFastBackpropFilterOp);
    288 };
    289 
    290 // Based on implementation written by Yangqing Jia (jiayq).
    291 template <typename Device, class T>
    292 class Conv2DCustomBackpropFilterOp : public OpKernel {
    293  public:
    294   explicit Conv2DCustomBackpropFilterOp(OpKernelConstruction* context)
    295       : OpKernel(context) {
    296     string data_format;
    297     OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    298     OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
    299                 errors::InvalidArgument("Invalid data format"));
    300     OP_REQUIRES(context, data_format_ == FORMAT_NHWC,
    301                 errors::InvalidArgument(
    302                     "Conv2DCustomBackpropFilterOp only supports NHWC."));
    303     OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    304     OP_REQUIRES(context, strides_.size() == 4,
    305                 errors::InvalidArgument("Sliding window strides field must "
    306                                         "specify 4 dimensions"));
    307     OP_REQUIRES(
    308         context, (strides_[0] == 1 && strides_[3] == 1),
    309         errors::InvalidArgument("Current implementation does not yet support "
    310                                 "strides in the batch and depth dimensions."));
    311     OP_REQUIRES(context, strides_[1] > 0 && strides_[2] > 0,
    312                 errors::InvalidArgument(
    313                     "Row and column strides should be larger than 0."));
    314     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    315     OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
    316     OP_REQUIRES(context, dilations_.size() == 4,
    317                 errors::InvalidArgument("Sliding window dilations field must "
    318                                         "specify 4 dimensions"));
    319     OP_REQUIRES(context, (dilations_[0] == 1 && dilations_[3] == 1),
    320                 errors::InvalidArgument(
    321                     "Current implementation does not yet support "
    322                     "dilations in the batch and depth dimensions."));
    323     // TODO(yangzihao): Add a CPU implementation for dilated convolution.
    324     OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1),
    325                 errors::InvalidArgument(
    326                     "Current libxsmm and customized CPU implementations do "
    327                     "not yet support dilation rates larger than 1."));
    328   }
    329 
    330   void Compute(OpKernelContext* context) override {
    331     const Tensor& input = context->input(0);
    332     const Tensor& filter_sizes = context->input(1);
    333     const Tensor& out_backprop = context->input(2);
    334     OP_REQUIRES(
    335         context, TensorShapeUtils::IsVector(filter_sizes.shape()),
    336         errors::InvalidArgument(
    337             "Conv2DCustomBackpropFilter: filter_sizes input must be 1-dim, "
    338             "not ",
    339             filter_sizes.dims()));
    340     TensorShape filter_shape;
    341     OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
    342                                 filter_sizes.vec<int32>(), &filter_shape));
    343 
    344     ConvBackpropDimensions dims;
    345     OP_REQUIRES_OK(context,
    346                    ConvBackpropComputeDimensions(
    347                        "Conv2DCustomBackpropFilter", /*num_spatial_dims=*/2,
    348                        input.shape(), filter_shape, out_backprop.shape(),
    349                        strides_, padding_, data_format_, &dims));
    350 
    351     Tensor* filter_backprop;
    352     OP_REQUIRES_OK(context,
    353                    context->allocate_output(0, filter_shape, &filter_backprop));
    354 
    355     // If there is nothing to compute, return.
    356     if (filter_shape.num_elements() == 0) {
    357       return;
    358     }
    359 
    360     int64 pad_top, pad_bottom;
    361     int64 pad_left, pad_right;
    362     OP_REQUIRES_OK(
    363         context,
    364         GetWindowedOutputSizeVerbose(
    365             dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
    366             dims.spatial_dims[0].stride, padding_,
    367             &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom));
    368     OP_REQUIRES_OK(
    369         context,
    370         GetWindowedOutputSizeVerbose(
    371             dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
    372             dims.spatial_dims[1].stride, padding_,
    373             &dims.spatial_dims[1].output_size, &pad_left, &pad_right));
    374 #if defined TENSORFLOW_USE_LIBXSMM && defined TENSORFLOW_USE_LIBXSMM_BACKWARD
    375     if (pad_left == pad_right && pad_top == pad_bottom) {
    376       if (LaunchXsmmBackwardFilter<Device, T>()(
    377               context, context->eigen_device<Device>(), input.tensor<T, 4>(),
    378               filter_backprop->tensor<T, 4>(), out_backprop.tensor<T, 4>(),
    379               dims.spatial_dims[0].input_size, dims.spatial_dims[1].input_size,
    380               static_cast<int>(dims.spatial_dims[0].stride),
    381               static_cast<int>(dims.spatial_dims[1].stride),
    382               static_cast<int>(pad_top), static_cast<int>(pad_left),
    383               data_format_)) {
    384         return;
    385       }
    386     }
    387 #endif
    388 
    389     // The total dimension size of each kernel.
    390     const int filter_total_size = dims.spatial_dims[0].filter_size *
    391                                   dims.spatial_dims[1].filter_size *
    392                                   dims.in_depth;
    393     // The output image size is the spatial size of the output.
    394     const int output_image_size =
    395         dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size;
    396 
    397     // Shard 'batch' images into 'shard_size' groups of images to be fed
    398     // into the parallel matmul. Calculate 'shard_size' by dividing the L3 cache
    399     // size ('target_working_set_size') by the matmul size of an individual
    400     // image ('work_unit_size').
    401 
    402     // TODO(andydavis)
     403     // *) Get L3 cache size from device at runtime (30 MB is the Ivy Bridge L3 size).
    404     // *) Consider reducing 'target_working_set_size' if L3 is shared by
    405     //    other concurrently running tensorflow ops.
    406     const size_t target_working_set_size = (30LL << 20) / sizeof(T);
    407 
    408     const size_t size_A = output_image_size * filter_total_size;
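             // Per-image element counts of the three matrices in the matmul below:
             // A is the im2col patch matrix (output positions x filter elements),
             // B is the backpropagated output (output positions x out_depth), and
             // C is the filter gradient (filter elements x out_depth).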
    409 
    410     const size_t size_B = output_image_size * dims.out_depth;
    411 
    412     const size_t size_C = filter_total_size * dims.out_depth;
    413 
    414     const size_t work_unit_size = size_A + size_B + size_C;
    415 
    416     const size_t shard_size =
    417         (target_working_set_size + work_unit_size - 1) / work_unit_size;
    418 
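             // Rough illustration (numbers are hypothetical): for float data with
             // output_image_size = 32 * 32, filter_total_size = 3 * 3 * 64 and
             // out_depth = 128, work_unit_size is ~0.8M elements, while the 30 MB
             // working set holds ~7.9M floats, so about 10 images go in each shard.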
    419     Tensor col_buffer;
    420     OP_REQUIRES_OK(context,
    421                    context->allocate_temp(
    422                        DataTypeToEnum<T>::value,
    423                        TensorShape({static_cast<int64>(shard_size),
    424                                     static_cast<int64>(output_image_size),
    425                                     static_cast<int64>(filter_total_size)}),
    426                        &col_buffer));
    427 
    428     // The input offset corresponding to a single input image.
    429     const int input_offset = dims.spatial_dims[0].input_size *
    430                              dims.spatial_dims[1].input_size * dims.in_depth;
    431     // The output offset corresponding to a single output image.
    432     const int output_offset = dims.spatial_dims[0].output_size *
    433                               dims.spatial_dims[1].output_size * dims.out_depth;
    434 
    435     const T* input_data = input.template flat<T>().data();
    436     T* col_buffer_data = col_buffer.template flat<T>().data();
    437     const T* out_backprop_data = out_backprop.template flat<T>().data();
    438     T* filter_backprop_data = filter_backprop->template flat<T>().data();
    439 
    440     typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
    441                              Eigen::Unaligned>
    442         TensorMap;
    443     typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
    444                              Eigen::Unaligned>
    445         ConstTensorMap;
    446 
    447     TensorMap C(filter_backprop_data, filter_total_size, dims.out_depth);
    448     C.setZero();
    449 
    450     // Initialize contraction dims (we need to transpose 'A' below).
    451     Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
    452     contract_dims[0].first = 0;
    453     contract_dims[0].second = 0;
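             // Contracting dimension 0 of both operands computes C += A^T * B:
             // the sum runs over every (image, output position) pair in the shard,
             // which is exactly the accumulation needed for the filter gradient.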
    454 
    455     auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    456 
    457     for (int image_id = 0; image_id < dims.batch_size; image_id += shard_size) {
    458       const int shard_limit =
    459           std::min(static_cast<int>(shard_size),
    460                    static_cast<int>(dims.batch_size) - image_id);
    461 
    462       auto shard = [&input_data, &col_buffer_data, &dims, &pad_top, &pad_left,
    463                     &pad_bottom, &pad_right, &input_offset,
    464                     &size_A](int64 start, int64 limit) {
    465         for (int shard_id = start; shard_id < limit; ++shard_id) {
    466           const T* input_data_shard = input_data + shard_id * input_offset;
    467           T* col_data_shard = col_buffer_data + shard_id * size_A;
    468 
    469           // When we compute the gradient with respect to the filters, we need
    470           // to do im2col to allow gemm-type computation.
    471           Im2col<T>(
    472               input_data_shard, dims.in_depth, dims.spatial_dims[0].input_size,
    473               dims.spatial_dims[1].input_size, dims.spatial_dims[0].filter_size,
    474               dims.spatial_dims[1].filter_size, pad_top, pad_left, pad_bottom,
    475               pad_right, dims.spatial_dims[0].stride,
    476               dims.spatial_dims[1].stride, col_data_shard);
    477         }
    478       };
    479       Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
    480             size_A, shard);
    481 
    482       ConstTensorMap A(col_buffer_data, output_image_size * shard_limit,
    483                        filter_total_size);
    484       ConstTensorMap B(out_backprop_data, output_image_size * shard_limit,
    485                        dims.out_depth);
    486 
    487       // Gradient with respect to filter.
    488       C.device(context->eigen_cpu_device()) += A.contract(B, contract_dims);
    489 
    490       input_data += input_offset * shard_limit;
    491       out_backprop_data += output_offset * shard_limit;
    492     }
    493   }
    494 
    495  private:
    496   std::vector<int32> dilations_;
    497   std::vector<int32> strides_;
    498   Padding padding_;
    499   TensorFormat data_format_;
    500 
    501   TF_DISALLOW_COPY_AND_ASSIGN(Conv2DCustomBackpropFilterOp);
    502 };
    503 
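         // The im2col + GEMM kernel above is registered as the default CPU
         // implementation; the Eigen-based kernel remains selectable via the
         // "eigen_tensor" label.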
    504 #define REGISTER_CPU_KERNELS(T)                                               \
    505   REGISTER_KERNEL_BUILDER(                                                    \
    506       Name("Conv2DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
    507       Conv2DCustomBackpropFilterOp<CPUDevice, T>);                            \
    508   REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter")                        \
    509                               .Device(DEVICE_CPU)                             \
    510                               .Label("custom")                                \
    511                               .TypeConstraint<T>("T"),                        \
    512                           Conv2DCustomBackpropFilterOp<CPUDevice, T>);        \
    513   REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter")                        \
    514                               .Device(DEVICE_CPU)                             \
    515                               .Label("eigen_tensor")                          \
    516                               .TypeConstraint<T>("T"),                        \
    517                           Conv2DFastBackpropFilterOp<CPUDevice, T>);
    518 
    519 TF_CALL_half(REGISTER_CPU_KERNELS);
    520 TF_CALL_float(REGISTER_CPU_KERNELS);
    521 #undef REGISTER_CPU_KERNELS
    522 
    523 // GPU definitions.
    524 #if GOOGLE_CUDA
    525 // The slow version (but compiles for GPU)
    526 
     527 // A dummy type to group the backward-filter convolution autotune results together.
    528 struct ConvBackwardFilterAutoTuneGroup {
    529   static string name() { return "ConvBwdFilter"; }
    530 };
    531 typedef AutoTuneSingleton<ConvBackwardFilterAutoTuneGroup, ConvParameters,
    532                           perftools::gputools::dnn::AlgorithmConfig>
    533     AutoTuneConvBwdFilter;
    534 
    535 // Backprop for filter.
    536 template <typename Device, class T>
    537 class Conv2DSlowBackpropFilterOp : public OpKernel {
    538  public:
    539   explicit Conv2DSlowBackpropFilterOp(OpKernelConstruction* context)
    540       : OpKernel(context) {
    541     string data_format;
    542     OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    543     OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
    544                 errors::InvalidArgument("Invalid data format"));
    545     OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    546     int stride_n = GetTensorDim(strides_, data_format_, 'N');
    547     int stride_c = GetTensorDim(strides_, data_format_, 'C');
    548     int stride_h = GetTensorDim(strides_, data_format_, 'H');
    549     int stride_w = GetTensorDim(strides_, data_format_, 'W');
    550     OP_REQUIRES(
    551         context, (stride_n == 1 && stride_c == 1),
    552         errors::InvalidArgument("Current implementation does not yet support "
    553                                 "strides in the batch and depth dimensions."));
    554     OP_REQUIRES(context, stride_h > 0 && stride_w > 0,
    555                 errors::InvalidArgument(
    556                     "Row and column strides should be larger than 0."));
    557     OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
    558     OP_REQUIRES(context, dilations_.size() == 4,
    559                 errors::InvalidArgument("Sliding window dilations field must "
    560                                         "specify 4 dimensions"));
    561     int dilation_n = GetTensorDim(dilations_, data_format_, 'N');
    562     int dilation_c = GetTensorDim(dilations_, data_format_, 'C');
    563     int dilation_h = GetTensorDim(dilations_, data_format_, 'H');
    564     int dilation_w = GetTensorDim(dilations_, data_format_, 'W');
    565     OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1,
    566                 errors::InvalidArgument(
    567                     "Current implementation does not yet support "
    568                     "dilations in the batch and depth dimensions."));
    569     OP_REQUIRES(
    570         context, dilation_h > 0 && dilation_w > 0,
    571         errors::InvalidArgument("Dilated rates should be larger than 0."));
    572     OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
    573     use_cudnn_ &= CanUseCudnn();
    574     cudnn_use_autotune_ = CudnnUseAutotune();
    575     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    576   }
    577 
    578   void Compute(OpKernelContext* context) override {
    579     const Tensor& input = context->input(0);
    580     const Tensor& filter_sizes = context->input(1);
    581     const Tensor& out_backprop = context->input(2);
    582     OP_REQUIRES(
    583         context, TensorShapeUtils::IsVector(filter_sizes.shape()),
    584         errors::InvalidArgument(
    585             "Conv2DBackpropFilter: filter_sizes input must be 1-dim, not ",
    586             filter_sizes.dims()));
    587     TensorShape filter_shape;
    588     OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
    589                                 filter_sizes.vec<int32>(), &filter_shape));
    590 
    591     Tensor* filter_backprop = nullptr;
    592     OP_REQUIRES_OK(context,
    593                    context->allocate_output(0, filter_shape, &filter_backprop));
    594 
    595     // If there is nothing to compute, return.
    596     if (filter_shape.num_elements() == 0) {
    597       return;
    598     }
    599     // If input is empty, set gradients to zero.
    600     if (input.shape().num_elements() == 0) {
    601       functor::SetZeroFunctor<Device, T> f;
    602       f(context->eigen_device<Device>(), filter_backprop->flat<T>());
    603       return;
    604     }
    605 
    606     // For now we take the stride from the second and third dimensions only (we
    607     // do not support striding on the batch or depth dimension).
    608     const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
    609     const int stride_cols = GetTensorDim(strides_, data_format_, 'W');
    610     const int dilation_rows = GetTensorDim(dilations_, data_format_, 'H');
    611     const int dilation_cols = GetTensorDim(dilations_, data_format_, 'W');
    612 
    613     launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop, input,
    614               dilation_rows, dilation_cols, stride_rows, stride_cols, padding_,
    615               filter_backprop, data_format_);
    616   }
    617 
    618  private:
    619   std::vector<int32> dilations_;
    620   std::vector<int32> strides_;
    621   Padding padding_;
    622   bool use_cudnn_;
    623   TensorFormat data_format_;
    624   LaunchConv2DBackpropFilterOp<Device, T> launcher_;
    625   bool cudnn_use_autotune_;
    626 
    627   TF_DISALLOW_COPY_AND_ASSIGN(Conv2DSlowBackpropFilterOp);
    628 };
    629 
    630 template <typename T>
    631 void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
    632     OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
    633     const Tensor& out_backprop, const Tensor& input, int row_dilation,
    634     int col_dilation, int row_stride, int col_stride, const Padding& padding,
    635     Tensor* filter_backprop, TensorFormat data_format) {
    636   using perftools::gputools::dnn::AlgorithmConfig;
    637   using perftools::gputools::dnn::AlgorithmDesc;
    638   using perftools::gputools::dnn::ProfileResult;
    639 
    640   std::vector<int32> dilations(4, 1);
    641   dilations[GetTensorDimIndex(data_format, 'H')] = row_dilation;
    642   dilations[GetTensorDimIndex(data_format, 'W')] = col_dilation;
    643 
    644   std::vector<int32> strides(4, 1);
    645   strides[GetTensorDimIndex(data_format, 'H')] = row_stride;
    646   strides[GetTensorDimIndex(data_format, 'W')] = col_stride;
    647   TensorShape filter_shape = filter_backprop->shape();
    648 
    649   ConvBackpropDimensions dims;
    650   OP_REQUIRES_OK(ctx, ConvBackpropComputeDimensionsV2(
    651                           "Conv2DSlowBackpropFilter", /*num_spatial_dims=*/2,
    652                           input.shape(), filter_shape, out_backprop.shape(),
    653                           dilations, strides, padding, data_format, &dims));
    654 
    655   // TODO(yangzihao): The padding computations should be done in
    656   // GetWindowedOutputSize() functions.
    657   const int padding_rows =
    658       (padding == VALID)
    659           ? 0
    660           : std::max<int>(0, (dims.spatial_dims[0].output_size - 1) *
    661                                      dims.spatial_dims[0].stride +
    662                                  (dims.spatial_dims[0].filter_size - 1) *
    663                                      dims.spatial_dims[0].dilation +
    664                                  1 - dims.spatial_dims[0].input_size);
    665   const int padding_cols =
    666       (padding == VALID)
    667           ? 0
    668           : std::max<int>(0, (dims.spatial_dims[1].output_size - 1) *
    669                                      dims.spatial_dims[1].stride +
    670                                  (dims.spatial_dims[1].filter_size - 1) *
    671                                      dims.spatial_dims[1].dilation +
    672                                  1 - dims.spatial_dims[1].input_size);
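           // For 'SAME' padding this is the standard total-padding formula. As a
           // hypothetical example: input_size 7, filter_size 3, stride 2, dilation 1
           // gives output_size 4 and total padding 3 * 2 + 2 + 1 - 7 = 2, i.e. one
           // row/column of zeros on each side.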
    673 
     674   // TODO(zhengxq): cuDNN only supports equal padding on both sides, so we only
     675   // call it when that is true. Remove this check when (if?) cuDNN starts
     676   // supporting different padding.
    677   bool rows_odd = (padding_rows % 2 != 0);
    678   bool cols_odd = (padding_cols % 2 != 0);
    679 
    680   auto* stream = ctx->op_device_context()->stream();
    681   OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
    682 
    683   if (!use_cudnn) {
    684     ctx->SetStatus(errors::Unimplemented(
    685         "Conv2DBackprop for GPU is not currently supported "
    686         "without cudnn"));
    687     return;
    688   }
    689 
    690   bool cudnn_disable_conv_1x1_optimization_ = CudnnDisableConv1x1Optimization();
    691   if (!cudnn_disable_conv_1x1_optimization_ &&
    692       dims.spatial_dims[0].filter_size == 1 &&
    693       dims.spatial_dims[1].filter_size == 1 &&
    694       dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 &&
    695       data_format == FORMAT_NHWC) {
    696     const uint64 m = dims.in_depth;
    697     const uint64 k = dims.batch_size * dims.spatial_dims[0].input_size *
    698                      dims.spatial_dims[1].input_size;
    699     const uint64 n = dims.out_depth;
    700 
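             // For a 1x1, stride-1 NHWC convolution the filter gradient is just a
             // matrix product: conceptually filter_backprop[in_depth, out_depth] =
             // input^T[in_depth, batch * H * W] * out_backprop[batch * H * W,
             // out_depth], so a single GEMM below does all the work.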
    701     // The shape of output backprop is
    702     //   [batch, out_rows, out_cols, out_depth]
    703     //   From cublas's perspective, it is: n x k
    704     auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
    705                                 out_backprop.template flat<T>().size());
    706 
    707     // The shape of input is
    708     //   [batch, in_rows, in_cols, in_depth],
    709     //   From cublas's perspective, it is: m x k
    710     auto b_ptr = AsDeviceMemory(input.template flat<T>().data(),
    711                                 input.template flat<T>().size());
    712 
     713     // The shape of the filter backprop from the conv_2d should be
    714     //   [1, 1, in_depth, out_depth]
    715     //   From cublas's perspective, it is: n x m
    716     auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
    717                                 filter_backprop->template flat<T>().size());
    718 
    719     bool blas_launch_status =
    720         stream
    721             ->ThenBlasGemm(perftools::gputools::blas::Transpose::kNoTranspose,
    722                            perftools::gputools::blas::Transpose::kTranspose, n,
    723                            m, k, 1.0f, a_ptr, n, b_ptr, m, 0.0f, &c_ptr, n)
    724             .ok();
    725     if (!blas_launch_status) {
    726       ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
    727                                       ", n=", n, ", k=", k));
    728     }
    729     return;
    730   } else if (dims.spatial_dims[0].filter_size ==
    731                  dims.spatial_dims[0].input_size &&
    732              dims.spatial_dims[1].filter_size ==
    733                  dims.spatial_dims[1].input_size &&
    734              padding == VALID && data_format == FORMAT_NHWC) {
    735     // The input data and filter have the same height/width, so call cublas
    736     // directly.
    737     const uint64 m = dims.spatial_dims[0].input_size *
    738                      dims.spatial_dims[1].input_size * dims.in_depth;
    739     const uint64 k = dims.batch_size;
    740     const uint64 n = dims.out_depth;
    741 
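             // With the filter covering the whole image under VALID padding, each
             // image contributes one outer product: conceptually
             // filter_backprop[m, n] = input^T[m, k] * out_backprop[k, n], summed
             // over the batch, so again a single GEMM suffices.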
    742     auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
    743                                 input.template flat<T>().size());
    744     auto b_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
    745                                 out_backprop.template flat<T>().size());
    746     auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
    747                                 filter_backprop->template flat<T>().size());
    748 
    749     bool blas_launch_status =
    750         stream
    751             ->ThenBlasGemm(perftools::gputools::blas::Transpose::kNoTranspose,
    752                            perftools::gputools::blas::Transpose::kTranspose, n,
    753                            m, k, 1.0f, b_ptr, n, a_ptr, m, 0.0f, &c_ptr, n)
    754             .ok();
    755     if (!blas_launch_status) {
    756       ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
    757                                       ", n=", n, ", k=", k));
    758     }
    759     return;
    760   }
    761 
    762   Tensor compatible_input;
    763   if (rows_odd || cols_odd) {
     764     // If a padding dimension is odd, we have one more element on the right
     765     // side or the bottom side. This is unsupported in cuDNN. Therefore, we pad
     766     // the input with one extra element on that side to make it compatible.
    767     OP_REQUIRES_OK(
    768         ctx, ctx->allocate_temp(
    769                  DataTypeToEnum<T>::value,
    770                  ShapeFromFormat(data_format, dims.batch_size,
    771                                  dims.spatial_dims[0].input_size + rows_odd,
    772                                  dims.spatial_dims[1].input_size + cols_odd,
    773                                  dims.in_depth),
    774                  &compatible_input));
    775 
    776     functor::PadInput<GPUDevice, T, int, 4>()(
    777         ctx->template eigen_device<GPUDevice>(), To32Bit(input.tensor<T, 4>()),
    778         {{0, 0}}, {{rows_odd, cols_odd}},
    779         To32Bit(compatible_input.tensor<T, 4>()), data_format);
    780   } else {
    781     compatible_input = input;
    782   }
    783 
    784   CHECK(padding_rows >= 0 && padding_cols >= 0)
    785       << "Negative row or col paddings: (" << padding_rows << ", "
    786       << padding_cols << ")";
    787   perftools::gputools::dnn::BatchDescriptor input_desc;
    788   input_desc.set_count(dims.batch_size)
    789       .set_height(GetTensorDim(compatible_input, data_format, 'H'))
    790       .set_width(GetTensorDim(compatible_input, data_format, 'W'))
    791       .set_feature_map_count(dims.in_depth)
    792       .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
    793   perftools::gputools::dnn::BatchDescriptor output_desc;
    794   output_desc.set_count(dims.batch_size)
    795       .set_height(dims.spatial_dims[0].output_size)
    796       .set_width(dims.spatial_dims[1].output_size)
    797       .set_feature_map_count(dims.out_depth)
    798       .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
    799   perftools::gputools::dnn::FilterDescriptor filter_desc;
    800   filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size)
    801       .set_input_filter_width(dims.spatial_dims[1].filter_size)
    802       .set_input_feature_map_count(dims.in_depth)
    803       .set_output_feature_map_count(dims.out_depth);
    804   perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
    805   conv_desc.set_vertical_dilation_rate(dims.spatial_dims[0].dilation)
    806       .set_horizontal_dilation_rate(dims.spatial_dims[1].dilation)
    807       .set_vertical_filter_stride(dims.spatial_dims[0].stride)
    808       .set_horizontal_filter_stride(dims.spatial_dims[1].stride)
    809       .set_zero_padding_height(padding_rows / 2)
    810       .set_zero_padding_width(padding_cols / 2);
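           // Any odd total padding was already absorbed by building the padded
           // 'compatible_input' above, so halving padding_rows/padding_cols here
           // matches cuDNN's requirement of equal padding on both sides.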
    811 
    812   // NOTE(zhengxq):
     813   // cuDNN only supports the following layouts:
     814   // Input  : B x D x R x C
     815   // Filter : OD x ID x R x C
     816   // Whereas we have
    817   // Input  : B x R x C x D
    818   // Filter : R x C x ID x OD
    819   // TransformFilter performs (R x C x ID x OD) => (OD x ID x R x C)
    820   // The first TransformDepth performs
    821   // (B x R x C x D) => (B x D x R x C).
    822   // Since the tensor returned from cuDNN is B x D x R x C also,
    823   // the second TransformDepth performs
    824   // (B x D x R x C) => (B x R x C x D).
    825 
    826   Tensor pre_transformed_filter_backprop;
    827   OP_REQUIRES_OK(
    828       ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
    829                               TensorShape({dims.out_depth, dims.in_depth,
    830                                            dims.spatial_dims[0].filter_size,
    831                                            dims.spatial_dims[1].filter_size}),
    832                               &pre_transformed_filter_backprop));
    833 
    834   Tensor transformed_out_backprop;
    835   if (data_format == FORMAT_NHWC) {
    836     TensorShape nchw_shape = ShapeFromFormat(
    837         FORMAT_NCHW, dims.batch_size, dims.spatial_dims[0].output_size,
    838         dims.spatial_dims[1].output_size, dims.out_depth);
    839     if (dims.out_depth > 1) {
    840       OP_REQUIRES_OK(ctx,
    841                      ctx->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
    842                                         &transformed_out_backprop));
    843       functor::NHWCToNCHW<GPUDevice, T, 4>()(
    844           ctx->eigen_device<GPUDevice>(), out_backprop.tensor<T, 4>(),
    845           transformed_out_backprop.tensor<T, 4>());
    846     } else {
    847       // If depth <= 1, just reshape.
    848       CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
    849     }
    850   } else {
    851     transformed_out_backprop = out_backprop;
    852   }
    853 
    854   Tensor transformed_input;
    855   if (data_format == FORMAT_NHWC) {
    856     TensorShape nchw_shape = ShapeFromFormat(
    857         FORMAT_NCHW, GetTensorDim(compatible_input, data_format, 'N'),
    858         GetTensorDim(compatible_input, data_format, 'H'),
    859         GetTensorDim(compatible_input, data_format, 'W'),
    860         GetTensorDim(compatible_input, data_format, 'C'));
    861     if (nchw_shape.dim_size(1) > 1) {
    862       OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
    863                                              nchw_shape, &transformed_input));
    864       functor::NHWCToNCHW<GPUDevice, T, 4>()(
    865           ctx->eigen_device<GPUDevice>(),
    866           const_cast<const Tensor&>(compatible_input).tensor<T, 4>(),
    867           transformed_input.tensor<T, 4>());
    868     } else {
    869       // If depth <= 1, just reshape.
    870       CHECK(transformed_input.CopyFrom(compatible_input, nchw_shape));
    871     }
    872   } else {
    873     transformed_input = compatible_input;
    874   }
    875 
    876   auto out_backprop_ptr =
    877       AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
    878                      transformed_out_backprop.template flat<T>().size());
    879   auto filter_backprop_ptr =
    880       AsDeviceMemory(pre_transformed_filter_backprop.template flat<T>().data(),
    881                      pre_transformed_filter_backprop.template flat<T>().size());
    882   auto input_ptr = AsDeviceMemory(transformed_input.template flat<T>().data(),
    883                                   transformed_input.template flat<T>().size());
    884 
    885   static int64 ConvolveBackwardFilterScratchSize = GetCudnnWorkspaceLimit(
    886       "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB by default
    887   );
    888   int device_id = stream->parent()->device_ordinal();
    889   DataType dtype = input.dtype();
    890   ConvParameters conv_parameters = {
    891       dims.batch_size,                       // batch
    892       dims.in_depth,                         // in_depths
    893       {{input_desc.height(),                 // in_rows
    894         input_desc.width()}},                // in_cols
    895       dims.out_depth,                        // out_depths
    896       {{dims.spatial_dims[0].filter_size,    // filter_rows
    897         dims.spatial_dims[1].filter_size}},  // filter_cols
    898       {{dims.spatial_dims[0].dilation,       // dilation_rows
    899         dims.spatial_dims[1].dilation}},     // dilation_cols
    900       {{dims.spatial_dims[0].stride,         // stride_rows
    901         dims.spatial_dims[1].stride}},       // stride_cols
    902       {{padding_rows,                        // padding_rows
    903         padding_cols}},                      // padding_cols
    904       dtype,                                 // tensor datatype
    905       device_id,                             // device_id
    906   };
    907   AlgorithmConfig algorithm_config;
    908   if (cudnn_use_autotune && !AutoTuneConvBwdFilter::GetInstance()->Find(
    909                                 conv_parameters, &algorithm_config)) {
    910     std::vector<AlgorithmDesc> algorithms;
    911     CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(
    912         conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
    913     ProfileResult best_result;
    914     ProfileResult best_result_no_scratch;
    915     for (auto profile_algorithm : algorithms) {
     916       // TODO(zhengxq): profile each algorithm multiple times for better
     917       // accuracy.
    918       CudnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
    919                                               ctx);
    920       ProfileResult profile_result;
    921       bool cudnn_launch_status =
    922           stream
    923               ->ThenConvolveBackwardFilterWithAlgorithm(
    924                   input_desc, input_ptr, output_desc, out_backprop_ptr,
    925                   conv_desc, filter_desc, &filter_backprop_ptr,
    926                   &scratch_allocator, AlgorithmConfig(profile_algorithm),
    927                   &profile_result)
    928               .ok();
    929       if (cudnn_launch_status) {
    930         if (profile_result.is_valid()) {
    931           if (profile_result.elapsed_time_in_ms() <
    932               best_result.elapsed_time_in_ms()) {
    933             best_result = profile_result;
    934           }
    935           if (scratch_allocator.TotalByteSize() == 0 &&
    936               profile_result.elapsed_time_in_ms() <
    937                   best_result_no_scratch.elapsed_time_in_ms()) {
    938             best_result_no_scratch = profile_result;
    939           }
    940         }
    941       }
    942     }
    943     OP_REQUIRES(ctx,
    944                 best_result.is_valid() || best_result_no_scratch.is_valid(),
    945                 errors::NotFound("No algorithm worked!"));
    946     if (best_result.is_valid()) {
    947       algorithm_config.set_algorithm(best_result.algorithm());
    948     }
    949     if (best_result_no_scratch.is_valid()) {
    950       algorithm_config.set_algorithm_no_scratch(
    951           best_result_no_scratch.algorithm());
    952     }
    953     AutoTuneConvBwdFilter::GetInstance()->Insert(conv_parameters,
    954                                                  algorithm_config);
    955   }
    956   CudnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
    957                                           ctx);
    958   bool cudnn_launch_status =
    959       stream
    960           ->ThenConvolveBackwardFilterWithAlgorithm(
    961               input_desc, input_ptr, output_desc, out_backprop_ptr, conv_desc,
    962               filter_desc, &filter_backprop_ptr, &scratch_allocator,
    963               algorithm_config, nullptr)
    964           .ok();
    965 
    966   if (!cudnn_launch_status) {
    967     ctx->SetStatus(errors::Internal(
    968         "cuDNN Backward Filter function launch failure : input shape(",
    969         input.shape().DebugString(), ") filter shape(",
    970         filter_shape.DebugString(), ")"));
    971     return;
    972   }
    973 
    974   auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
    975   functor::ReverseTransformFilter<GPUDevice, T, 4>()(
    976       ctx->eigen_device<GPUDevice>(),
    977       toConstTensor(pre_transformed_filter_backprop).template tensor<T, 4>(),
    978       filter_backprop->tensor<T, 4>());
    979 }
    980 
    981 // Forward declarations of the functor specializations for GPU.
    982 namespace functor {
    983 #define DECLARE_GPU_SPEC(T)                                              \
    984   template <>                                                            \
    985   void ShuffleAndReverse<GPUDevice, T, 4, int>::operator()(              \
    986       const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor input, \
    987       const Eigen::DSizes<int, 4>& order,                                \
    988       const Eigen::array<bool, 4>& reverse_dims,                         \
    989       typename TTypes<T, 4, int>::Tensor output);                        \
    990   extern template struct ShuffleAndReverse<GPUDevice, T, 4, int>;        \
    991   template <>                                                            \
    992   void InflatePadAndShuffle<GPUDevice, T, 4, int>::operator()(           \
    993       const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor input, \
    994       const Eigen::DSizes<int, 4>& strides,                              \
    995       const Eigen::array<Eigen::IndexPair<int>, 4>& pad_dims,            \
    996       const Eigen::DSizes<int, 4>& order,                                \
    997       typename TTypes<T, 4, int>::Tensor output);                        \
    998   extern template struct InflatePadAndShuffle<GPUDevice, T, 4, int>;     \
    999   template <>                                                            \
   1000   void TransformFilter<GPUDevice, T, int, 4>::operator()(                \
   1001       const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,    \
   1002       typename TTypes<T, 4, int>::Tensor out);                           \
   1003   extern template struct TransformFilter<GPUDevice, T, int, 4>;          \
   1004   template <>                                                            \
   1005   void TransformDepth<GPUDevice, T, int>::operator()(                    \
   1006       const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,    \
   1007       const Eigen::DSizes<int, 4>& shuffle,                              \
   1008       typename TTypes<T, 4, int>::Tensor out);                           \
   1009   extern template struct TransformDepth<GPUDevice, T, int>;              \
   1010   template <>                                                            \
   1011   void PadInput<GPUDevice, T, int, 4>::operator()(                       \
   1012       const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,    \
   1013       const std::array<int, 2>& padding_left,                            \
   1014       const std::array<int, 2>& padding_right,                           \
   1015       typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
   1016   extern template struct PadInput<GPUDevice, T, int, 4>;
   1017 
   1018 DECLARE_GPU_SPEC(float);
   1019 DECLARE_GPU_SPEC(Eigen::half);
   1020 #undef DECLARE_GPU_SPEC
   1021 }  // namespace functor
   1022 
   1023 REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter")
   1024                             .Device(DEVICE_GPU)
   1025                             .TypeConstraint<float>("T")
   1026                             .HostMemory("filter_sizes"),
   1027                         Conv2DSlowBackpropFilterOp<GPUDevice, float>);
   1028 REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter")
   1029                             .Device(DEVICE_GPU)
   1030                             .TypeConstraint<Eigen::half>("T")
   1031                             .HostMemory("filter_sizes"),
   1032                         Conv2DSlowBackpropFilterOp<GPUDevice, Eigen::half>);
   1033 #endif  // GOOGLE_CUDA
   1034 
   1035 }  // namespace tensorflow
   1036