/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/conv_3d.h"

#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
using stream_executor::dnn::DimIndex;
#endif

namespace {

// TODO(ezhulenev): Split this file into conv_grad_filter_ops_3d.cc and
// conv_grad_input_ops_3d.cc.

// TODO(ezhulenev): Generalize Col2im and Im2col for 2-d and 3-d kernels.

// "Depth" is already used for the channel dimension, so for the third spatial
// dimension in this file we use "plane", although in NDHWC layout it's
// indicated with a "D".
// Returns, in 'im_data' (assumed to be zero-initialized), the image patch in
// storage order (planes, height, width, depth), constructed from patches in
// 'col_data', which is required to be in storage order (out_planes *
// out_height * out_width, filter_planes, filter_height, filter_width,
// in_depth).
//
// Based on the 2-dimensional implementation written by Yangqing Jia (jiayq).
template <typename T>
void Col2im(const T* col_data, const int depth, const int planes,
            const int height, const int width, const int filter_p,
            const int filter_h, const int filter_w, const int pad_pt,
            const int pad_t, const int pad_l, const int pad_pb, const int pad_b,
            const int pad_r, const int stride_p, const int stride_h,
            const int stride_w, T* im_data) {
  const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1;
  const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
  const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
  int p_pad = -pad_pt;
  for (int p = 0; p < planes_col; ++p) {
    int h_pad = -pad_t;
    for (int h = 0; h < height_col; ++h) {
      int w_pad = -pad_l;
      for (int w = 0; w < width_col; ++w) {
        T* im_patch_data =
            im_data + (p_pad * height * width + h_pad * width + w_pad) * depth;
        for (int ip = p_pad; ip < p_pad + filter_p; ++ip) {
          for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
            for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
              if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 &&
                  iw < width) {
                for (int i = 0; i < depth; ++i) {
                  im_patch_data[i] += col_data[i];
                }
              }
              im_patch_data += depth;
              col_data += depth;
            }
            // Skip the remaining (width - filter_w) columns of 'depth'
            // elements in this row.
            im_patch_data += depth * (width - filter_w);
          }
          // Skip the remaining (height - filter_h) rows of (depth * width)
          // elements in this plane.
          im_patch_data += (depth * width) * (height - filter_h);
        }
        w_pad += stride_w;
      }
      h_pad += stride_h;
    }
    p_pad += stride_p;
  }
}
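
// Illustrative example (added, not in the original): with planes = 5,
// filter_p = 3, pad_pt = pad_pb = 0, and stride_p = 2, the formula above
// gives planes_col = (5 + 0 + 0 - 3) / 2 + 1 = 2 patch positions along the
// plane dimension. Overlapping patch contributions are accumulated with '+=',
// which is why 'im_data' must be zero-initialized.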

// Returns, in 'col_data', the image patches in storage order (planes, height,
// width, depth), extracted from the image at 'input_data', which is required
// to be in storage order (batch, planes, height, width, depth).
//
// Based on the 2-dimensional implementation written by Yangqing Jia (jiayq).
template <typename T>
void Im2col(const T* input_data, const int depth, const int planes,
            const int height, const int width, const int filter_p,
            const int filter_h, const int filter_w, const int pad_pt,
            const int pad_t, const int pad_l, const int pad_pb, const int pad_b,
            const int pad_r, const int stride_p, const int stride_h,
            const int stride_w, T* col_data) {
  const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1;
  const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
  const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;

  int p_pad = -pad_pt;
  for (int p = 0; p < planes_col; ++p) {
    int h_pad = -pad_t;
    for (int h = 0; h < height_col; ++h) {
      int w_pad = -pad_l;
      for (int w = 0; w < width_col; ++w) {
        for (int ip = p_pad; ip < p_pad + filter_p; ++ip) {
          for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
            for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
              if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 &&
                  iw < width) {
                memcpy(col_data,
                       input_data +
                           (ip * height * width + ih * width + iw) * depth,
                       sizeof(T) * depth);
              } else {
                // Positions outside the input are zero padding.
                memset(col_data, 0, sizeof(T) * depth);
              }
              col_data += depth;
            }
          }
        }
        w_pad += stride_w;
      }
      h_pad += stride_h;
    }
    p_pad += stride_p;
  }
}
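
// How these helpers are used below: Im2col lowers an input volume into a
// patch matrix of shape [out_planes * out_rows * out_cols, filter_total_size]
// so that a single matmul against a filter (or out_backprop) matrix computes
// the convolution gradient; Col2im scatters such a patch matrix back into an
// image volume.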

}  // namespace

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Backprop for input that offloads computation to
// Eigen::CuboidConvolutionBackwardInput.
template <typename Device, class T>
class Conv3DBackpropInputOp : public OpKernel {
 public:
  explicit Conv3DBackpropInputOp(OpKernelConstruction* context)
      : OpKernel(context),
        data_format_(FORMAT_NHWC),
        takes_shape_(type_string().find("V2") != std::string::npos) {
    // data_format is only available in V2.
    if (takes_shape_) {
      string data_format;
      OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
      OP_REQUIRES(
          context, data_format_ == FORMAT_NHWC,
          errors::InvalidArgument(
              "Conv3DBackpropInputOpV2 only supports NDHWC on the CPU."));
    }

    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
    OP_REQUIRES(context, dilation_.size() == 5,
                errors::InvalidArgument("Dilation rates field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(context,
                (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
                 GetTensorDim(dilation_, data_format_, 'N') == 1),
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilation rates in the batch and depth dimensions."));

    // TODO(yangzihao): Add CPU version of dilated conv 3D.
    OP_REQUIRES(context,
                (GetTensorDim(dilation_, data_format_, '0') == 1 &&
                 GetTensorDim(dilation_, data_format_, '1') == 1 &&
                 GetTensorDim(dilation_, data_format_, '2') == 1),
                errors::InvalidArgument(
                    "Current CPU implementation does not yet support "
                    "dilation rates larger than 1."));

    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, 'C') == 1 &&
         GetTensorDim(stride_, data_format_, 'N') == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& filter = context->input(1);
    const TensorShape& filter_shape = filter.shape();

    const Tensor& out_backprop = context->input(2);
    const TensorShape& out_backprop_shape = out_backprop.shape();

    TensorShape input_shape;
    if (takes_shape_) {
      const Tensor& input_sizes = context->input(0);
      // MakeShape is able to handle both DT_INT32 and DT_INT64 for input_sizes.
      OP_REQUIRES_OK(context, MakeShape(input_sizes, &input_shape));
    } else {
      input_shape = context->input(0).shape();
    }

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(context, ConvBackpropComputeDimensions(
                                "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
                                input_shape, filter_shape, out_backprop_shape,
                                stride_, padding_, data_format_, &dims));

    Tensor* in_backprop;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_shape, &in_backprop));

    functor::CuboidConvolutionBackwardInput<Device, T>()(
        context->eigen_device<Device>(),
        in_backprop->tensor<T, 5>(),                     // input_backward
        filter.tensor<T, 5>(),                           // filter
        out_backprop.tensor<T, 5>(),                     // output_backward
        static_cast<int>(dims.spatial_dims[0].stride),   // stride_planes
        static_cast<int>(dims.spatial_dims[1].stride),   // stride_rows
        static_cast<int>(dims.spatial_dims[2].stride));  // stride_cols
  }

 private:
  std::vector<int32> dilation_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
  bool takes_shape_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropInputOp);
};

// Custom backprop for input that explicitly does the work sharding and calls
// Eigen only to multiply matrices.
template <typename Device, class T>
class Conv3DCustomBackpropInputOp : public OpKernel {
  // Limit the maximum size of the allocated temporary buffer to
  // kMaxTempAllocationOverhead times the size of the input tensors (input,
  // filter, out_backprop). If the size of the temporary buffer exceeds this
  // limit, fall back on the Eigen implementation.
  static constexpr int kMaxTempAllocationOverhead = 25;
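
  // Illustrative example (added, not in the original): if input, filter, and
  // out_backprop together hold 1M elements, col_buffer below may use at most
  // 25M elements before this kernel falls back on the Eigen path.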

 public:
  explicit Conv3DCustomBackpropInputOp(OpKernelConstruction* context)
      : OpKernel(context),
        data_format_(FORMAT_NHWC),
        takes_shape_(type_string().find("V2") != std::string::npos) {
    // data_format is only available in V2.
    if (takes_shape_) {
      string data_format;
      OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
      OP_REQUIRES(
          context, data_format_ == FORMAT_NHWC,
          errors::InvalidArgument(
              "Conv3DBackpropInputOpV2 only supports NDHWC on the CPU."));
    }

    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
    OP_REQUIRES(context, dilation_.size() == 5,
                errors::InvalidArgument("Dilation rates field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(context,
                (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
                 GetTensorDim(dilation_, data_format_, 'N') == 1),
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilation rates in the batch and depth dimensions."));

    // TODO(yangzihao): Add CPU version of dilated conv 3D.
    OP_REQUIRES(context,
                (GetTensorDim(dilation_, data_format_, '0') == 1 &&
                 GetTensorDim(dilation_, data_format_, '1') == 1 &&
                 GetTensorDim(dilation_, data_format_, '2') == 1),
                errors::InvalidArgument(
                    "Current CPU implementation does not yet support "
                    "dilation rates larger than 1."));

    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, 'C') == 1 &&
         GetTensorDim(stride_, data_format_, 'N') == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& filter = context->input(1);
    const TensorShape& filter_shape = filter.shape();

    const Tensor& out_backprop = context->input(2);
    const TensorShape& out_backprop_shape = out_backprop.shape();

    TensorShape input_shape;
    if (takes_shape_) {
      const Tensor& input_sizes = context->input(0);
      // MakeShape is able to handle both DT_INT32 and DT_INT64 for input_sizes.
      OP_REQUIRES_OK(context, MakeShape(input_sizes, &input_shape));
    } else {
      input_shape = context->input(0).shape();
    }

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(context, ConvBackpropComputeDimensions(
                                "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
                                input_shape, filter_shape, out_backprop_shape,
                                stride_, padding_, data_format_, &dims));

    Tensor* in_backprop;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_shape, &in_backprop));

    int64 top_pad_planes, bottom_pad_planes;
    int64 top_pad_rows, bottom_pad_rows;
    int64 left_pad_cols, right_pad_cols;

    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                dims.spatial_dims[0].input_size,
                                dims.spatial_dims[0].filter_size,
                                dims.spatial_dims[0].stride, padding_,
                                &dims.spatial_dims[0].output_size,
                                &top_pad_planes, &bottom_pad_planes));
    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                dims.spatial_dims[1].input_size,
                                dims.spatial_dims[1].filter_size,
                                dims.spatial_dims[1].stride, padding_,
                                &dims.spatial_dims[1].output_size,
                                &top_pad_rows, &bottom_pad_rows));
    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                dims.spatial_dims[2].input_size,
                                dims.spatial_dims[2].filter_size,
                                dims.spatial_dims[2].stride, padding_,
                                &dims.spatial_dims[2].output_size,
                                &left_pad_cols, &right_pad_cols));

    // TODO(ezhulenev): Extract work size and shard estimation to shared
    // functions in conv_grad_ops, and update 2d convolution backprop.

    // The total dimension size of each kernel.
    const int64 filter_total_size =
        dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size *
        dims.spatial_dims[2].filter_size * dims.in_depth;

    // The output image size is the spatial size of the output.
    const int64 output_image_size = dims.spatial_dims[0].output_size *
                                    dims.spatial_dims[1].output_size *
                                    dims.spatial_dims[2].output_size;

    const auto cache_sizes = Eigen::internal::CacheSizes();
    const ptrdiff_t l3_cache_size = cache_sizes.m_l3;

    // Use L3 cache size as target working set size.
    const size_t target_working_set_size = l3_cache_size / sizeof(T);

    // Calculate size of matrices involved in MatMul: C = A x B.
    const int64 size_A = output_image_size * dims.out_depth;

    const int64 size_B = filter_total_size * dims.out_depth;

    const int64 size_C = output_image_size * filter_total_size;

    const int64 work_unit_size = size_A + size_B + size_C;

    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());

    // Use parallel tensor contractions if there is no batching.
    //
    // Compared to Conv2D code, this version is missing work size estimation. In
    // benchmarks I didn't find a case when it's beneficial to run parallel
    // contraction compared to sharding and matmuls.
    const bool use_parallel_contraction = dims.batch_size == 1;

    const size_t shard_size =
        use_parallel_contraction
            ? 1
            : (target_working_set_size + work_unit_size - 1) / work_unit_size;
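
    // Illustrative example (added, not in the original): with an 8 MiB L3
    // cache and T = float, target_working_set_size is 2M elements; if one
    // image's A + B + C matrices total 512K elements, shard_size =
    // ceil(2M / 512K) = 4 images per shard.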

    // Total number of elements in all the tensors used by this kernel.
    int64 total_tensor_elements = input_shape.num_elements() +
                                  filter_shape.num_elements() +
                                  out_backprop_shape.num_elements();

    // Shape of the temporary workspace buffer.
    TensorShape col_buffer_shape = {static_cast<int64>(shard_size),
                                    static_cast<int64>(output_image_size),
                                    static_cast<int64>(filter_total_size)};
    int64 col_buffer_elements = col_buffer_shape.num_elements();

    // If the temporary allocation overhead is too large, fall back on the
    // Eigen implementation, which requires much less memory.
    int64 col_buffer_overhead = col_buffer_elements / total_tensor_elements;
    if (col_buffer_overhead > kMaxTempAllocationOverhead) {
      VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropInputOp: "
                 "col_buffer_overhead="
              << col_buffer_overhead;

      functor::CuboidConvolutionBackwardInput<Device, T>()(
          context->eigen_device<Device>(),
          in_backprop->tensor<T, 5>(),                     // input_backward
          filter.tensor<T, 5>(),                           // filter
          out_backprop.tensor<T, 5>(),                     // output_backward
          static_cast<int>(dims.spatial_dims[0].stride),   // stride_planes
          static_cast<int>(dims.spatial_dims[1].stride),   // stride_rows
          static_cast<int>(dims.spatial_dims[2].stride));  // stride_cols

      return;
    }

    Tensor col_buffer;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DataTypeToEnum<T>::value,
                                          col_buffer_shape, &col_buffer));

    // The input offset corresponding to a single input image.
    const int64 input_offset = dims.spatial_dims[0].input_size *
                               dims.spatial_dims[1].input_size *
                               dims.spatial_dims[2].input_size * dims.in_depth;

    // The output offset corresponding to a single output image.
    const int64 output_offset =
        dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size *
        dims.spatial_dims[2].output_size * dims.out_depth;

    const T* filter_data = filter.template flat<T>().data();
    T* col_buffer_data = col_buffer.template flat<T>().data();
    const T* out_backprop_data = out_backprop.template flat<T>().data();

    auto in_backprop_flat = in_backprop->template flat<T>();
    T* input_backprop_data = in_backprop_flat.data();
    in_backprop_flat.device(context->eigen_device<Device>()) =
        in_backprop_flat.constant(T(0));

    if (use_parallel_contraction) {
      typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
                               Eigen::Unaligned>
          TensorMap;
      typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
                               Eigen::Unaligned>
          ConstTensorMap;

      // Initialize contraction dims (we need to transpose 'B' below).
      Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
      contract_dims[0].first = 1;
      contract_dims[0].second = 1;
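
      // Note (added for clarity): contracting dimension 1 of A
      // [output_image_size, out_depth] with dimension 1 of B
      // [filter_total_size, out_depth] yields C = A * B^T, the same matmul as
      // the sharded branch below.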

      for (int image_id = 0; image_id < dims.batch_size; ++image_id) {
        // Compute gradient into col_buffer.
        TensorMap C(col_buffer_data, output_image_size, filter_total_size);

        ConstTensorMap A(out_backprop_data + output_offset * image_id,
                         output_image_size, dims.out_depth);
        ConstTensorMap B(filter_data, filter_total_size, dims.out_depth);

        C.device(context->eigen_cpu_device()) = A.contract(B, contract_dims);

        Col2im<T>(col_buffer_data, dims.in_depth,
                  // Input spatial dimensions.
                  dims.spatial_dims[0].input_size,  // input planes
                  dims.spatial_dims[1].input_size,  // input rows
                  dims.spatial_dims[2].input_size,  // input cols
                  // Filter spatial dimensions.
                  dims.spatial_dims[0].filter_size,  // filter planes
                  dims.spatial_dims[1].filter_size,  // filter rows
                  dims.spatial_dims[2].filter_size,  // filter cols
                  // Spatial padding.
                  top_pad_planes, top_pad_rows, left_pad_cols,
                  bottom_pad_planes, bottom_pad_rows, right_pad_cols,
                  // Spatial striding.
                  dims.spatial_dims[0].stride,  // stride planes
                  dims.spatial_dims[1].stride,  // stride rows
                  dims.spatial_dims[2].stride,  // stride cols
                  input_backprop_data);

        input_backprop_data += input_offset;
      }
    } else {
      typedef Eigen::Map<
          Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
          MatrixMap;
      typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic,
                                             Eigen::RowMajor>>
          ConstMatrixMap;

      for (int image_id = 0; image_id < dims.batch_size;
           image_id += shard_size) {
        const int shard_limit =
            std::min(static_cast<int>(shard_size),
                     static_cast<int>(dims.batch_size) - image_id);

        auto shard = [&dims, &top_pad_planes, &top_pad_rows, &left_pad_cols,
                      &bottom_pad_planes, &bottom_pad_rows, &right_pad_cols,
                      &output_image_size, &filter_total_size,
                      &input_backprop_data, &col_buffer_data,
                      &out_backprop_data, &filter_data, &input_offset,
                      &output_offset, &size_C](int64 start, int64 limit) {
          for (int shard_id = start; shard_id < limit; ++shard_id) {
            T* im2col_buf = col_buffer_data + shard_id * size_C;
            T* input_data = input_backprop_data + shard_id * input_offset;
            const T* out_data = out_backprop_data + shard_id * output_offset;

            // Compute gradient into 'im2col_buf'.
            MatrixMap C(im2col_buf, output_image_size, filter_total_size);

            ConstMatrixMap A(out_data, output_image_size, dims.out_depth);
            ConstMatrixMap B(filter_data, filter_total_size, dims.out_depth);

            C.noalias() = A * B.transpose();

            Col2im<T>(im2col_buf, dims.in_depth,
                      // Input spatial dimensions.
                      dims.spatial_dims[0].input_size,  // input planes
                      dims.spatial_dims[1].input_size,  // input rows
                      dims.spatial_dims[2].input_size,  // input cols
                      // Filter spatial dimensions.
                      dims.spatial_dims[0].filter_size,  // filter planes
                      dims.spatial_dims[1].filter_size,  // filter rows
                      dims.spatial_dims[2].filter_size,  // filter cols
                      // Spatial padding.
                      top_pad_planes, top_pad_rows, left_pad_cols,
                      bottom_pad_planes, bottom_pad_rows, right_pad_cols,
                      // Spatial striding.
                      dims.spatial_dims[0].stride,  // stride planes
                      dims.spatial_dims[1].stride,  // stride rows
                      dims.spatial_dims[2].stride,  // stride cols
                      input_data);
          }
        };
        Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
              work_unit_size, shard);

        input_backprop_data += input_offset * shard_limit;
        out_backprop_data += output_offset * shard_limit;
      }
    }
  }

 private:
  std::vector<int32> dilation_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
  bool takes_shape_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropInputOp);
};

// Custom backprop input kernel is 30% - 4x faster when compiled with AVX2
// than the default Eigen implementation (at the cost of ~2x-8x peak memory
// usage).
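//
// The unlabeled registrations below default to the custom kernel, while the
// "custom" and "eigen_tensor" labels let a caller pick an implementation
// explicitly (kernel labels are matched against a node's "_kernel" attr; this
// note is an addition, not from the original source).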

#define REGISTER_CPU_KERNEL(T)                                                 \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("Conv3DBackpropInput").Device(DEVICE_CPU).TypeConstraint<T>("T"),   \
      Conv3DCustomBackpropInputOp<CPUDevice, T>);                              \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("Conv3DBackpropInputV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      Conv3DCustomBackpropInputOp<CPUDevice, T>);                              \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput")                          \
                              .Device(DEVICE_CPU)                              \
                              .Label("custom")                                 \
                              .TypeConstraint<T>("T"),                         \
                          Conv3DCustomBackpropInputOp<CPUDevice, T>);          \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2")                        \
                              .Device(DEVICE_CPU)                              \
                              .Label("custom")                                 \
                              .TypeConstraint<T>("T"),                         \
                          Conv3DCustomBackpropInputOp<CPUDevice, T>);          \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput")                          \
                              .Device(DEVICE_CPU)                              \
                              .Label("eigen_tensor")                           \
                              .TypeConstraint<T>("T"),                         \
                          Conv3DBackpropInputOp<CPUDevice, T>);                \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2")                        \
                              .Device(DEVICE_CPU)                              \
                              .Label("eigen_tensor")                           \
                              .TypeConstraint<T>("T"),                         \
                          Conv3DBackpropInputOp<CPUDevice, T>);

TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL

// Backprop for filter that offloads computation to
// Eigen::CuboidConvolutionBackwardFilter.
template <typename Device, class T>
class Conv3DBackpropFilterOp : public OpKernel {
 public:
  explicit Conv3DBackpropFilterOp(OpKernelConstruction* context)
      : OpKernel(context),
        data_format_(FORMAT_NHWC),
        takes_shape_(type_string().find("V2") != std::string::npos) {
    // data_format is only available in V2.
    if (takes_shape_) {
      string data_format;
      OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
      OP_REQUIRES(
          context, data_format_ == FORMAT_NHWC,
          errors::InvalidArgument(
              "Conv3DBackpropFilterOpV2 only supports NDHWC on the CPU."));
    }

    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
    OP_REQUIRES(context, dilation_.size() == 5,
                errors::InvalidArgument("Dilation rates field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(context,
                (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
                 GetTensorDim(dilation_, data_format_, 'N') == 1),
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilation rates in the batch and depth dimensions."));

    // TODO(yangzihao): Add CPU version of dilated conv 3D.
    OP_REQUIRES(context,
                (GetTensorDim(dilation_, data_format_, '0') == 1 &&
                 GetTensorDim(dilation_, data_format_, '1') == 1 &&
                 GetTensorDim(dilation_, data_format_, '2') == 1),
                errors::InvalidArgument(
                    "Current CPU implementation does not yet support "
                    "dilation rates larger than 1."));

    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, 'C') == 1 &&
         GetTensorDim(stride_, data_format_, 'N') == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();

    const Tensor& out_backprop = context->input(2);
    const TensorShape& out_backprop_shape = out_backprop.shape();

    TensorShape filter_shape;
    if (takes_shape_) {
      const Tensor& filter_sizes = context->input(1);
      OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                  filter_sizes.vec<int32>(), &filter_shape));
    } else {
      filter_shape = context->input(1).shape();
    }

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(context,
                   ConvBackpropComputeDimensions(
                       "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
                       input_shape, filter_shape, out_backprop_shape, stride_,
                       padding_, data_format_, &dims));

    Tensor* filter_backprop;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, filter_shape, &filter_backprop));

    if (input_shape.num_elements() == 0) {
      filter_backprop->template flat<T>().setZero();
      return;
    }

    functor::CuboidConvolutionBackwardFilter<Device, T>()(
        context->eigen_device<Device>(),
        filter_backprop->tensor<T, 5>(),                 // filter_backward
        input.tensor<T, 5>(),                            // input
        out_backprop.tensor<T, 5>(),                     // output_backward
        static_cast<int>(dims.spatial_dims[0].stride),   // stride_planes
        static_cast<int>(dims.spatial_dims[1].stride),   // stride_rows
        static_cast<int>(dims.spatial_dims[2].stride));  // stride_cols
  }

 private:
  std::vector<int32> dilation_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
  bool takes_shape_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropFilterOp);
};

// Custom backprop for filter that explicitly does the work sharding and calls
// Eigen only to multiply matrices.
template <typename Device, class T>
class Conv3DCustomBackpropFilterOp : public OpKernel {
  // Limit the maximum size of the allocated temporary buffer to
  // kMaxTempAllocationOverhead times the size of the input tensors (input,
  // filter, out_backprop). If the size of the temporary buffer exceeds this
  // limit, fall back on the Eigen implementation.
  static constexpr int kMaxTempAllocationOverhead = 25;

 public:
  explicit Conv3DCustomBackpropFilterOp(OpKernelConstruction* context)
      : OpKernel(context),
        data_format_(FORMAT_NHWC),
        takes_shape_(type_string().find("V2") != std::string::npos) {
    // data_format is only available in V2.
    if (takes_shape_) {
      string data_format;
      OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
      OP_REQUIRES(
          context, data_format_ == FORMAT_NHWC,
          errors::InvalidArgument(
              "Conv3DBackpropFilterOpV2 only supports NDHWC on the CPU."));
    }

    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
    OP_REQUIRES(context, dilation_.size() == 5,
                errors::InvalidArgument("Dilation rates field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(context,
                (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
                 GetTensorDim(dilation_, data_format_, 'N') == 1),
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilation rates in the batch and depth dimensions."));

    // TODO(yangzihao): Add CPU version of dilated conv 3D.
    OP_REQUIRES(context,
                (GetTensorDim(dilation_, data_format_, '0') == 1 &&
                 GetTensorDim(dilation_, data_format_, '1') == 1 &&
                 GetTensorDim(dilation_, data_format_, '2') == 1),
                errors::InvalidArgument(
                    "Current CPU implementation does not yet support "
                    "dilation rates larger than 1."));

    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, 'C') == 1 &&
         GetTensorDim(stride_, data_format_, 'N') == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();

    const Tensor& out_backprop = context->input(2);
    const TensorShape& out_backprop_shape = out_backprop.shape();

    TensorShape filter_shape;
    if (takes_shape_) {
      const Tensor& filter_sizes = context->input(1);
      OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                  filter_sizes.vec<int32>(), &filter_shape));
    } else {
      filter_shape = context->input(1).shape();
    }

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(context,
                   ConvBackpropComputeDimensions(
                       "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
                       input_shape, filter_shape, out_backprop_shape, stride_,
                       padding_, data_format_, &dims));

    Tensor* filter_backprop;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, filter_shape, &filter_backprop));

    if (input_shape.num_elements() == 0) {
      filter_backprop->template flat<T>().setZero();
      return;
    }

    int64 top_pad_planes, bottom_pad_planes;
    int64 top_pad_rows, bottom_pad_rows;
    int64 left_pad_cols, right_pad_cols;

    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                dims.spatial_dims[0].input_size,
                                dims.spatial_dims[0].filter_size,
                                dims.spatial_dims[0].stride, padding_,
                                &dims.spatial_dims[0].output_size,
                                &top_pad_planes, &bottom_pad_planes));
    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                dims.spatial_dims[1].input_size,
                                dims.spatial_dims[1].filter_size,
                                dims.spatial_dims[1].stride, padding_,
                                &dims.spatial_dims[1].output_size,
                                &top_pad_rows, &bottom_pad_rows));
    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                dims.spatial_dims[2].input_size,
                                dims.spatial_dims[2].filter_size,
                                dims.spatial_dims[2].stride, padding_,
                                &dims.spatial_dims[2].output_size,
                                &left_pad_cols, &right_pad_cols));

    // TODO(ezhulenev): Extract work size and shard estimation to shared
    // functions in conv_grad_ops, and update 2d convolution backprop.

    // The total dimension size of each kernel.
    const int64 filter_total_size =
        dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size *
        dims.spatial_dims[2].filter_size * dims.in_depth;
    // The output image size is the spatial size of the output.
    const int64 output_image_size = dims.spatial_dims[0].output_size *
                                    dims.spatial_dims[1].output_size *
                                    dims.spatial_dims[2].output_size;

    // Shard 'batch' images (volumes) into 'shard_size' groups of images
    // (volumes) to be fed into the parallel matmul. Calculate 'shard_size' by
    // dividing the L3 cache size ('target_working_set_size') by the matmul size
    // of an individual image ('work_unit_size').

    const auto cache_sizes = Eigen::internal::CacheSizes();
    const ptrdiff_t l3_cache_size = cache_sizes.m_l3;

    // TODO(andydavis)
    // *) Consider reducing 'target_working_set_size' if L3 is shared by
    //    other concurrently running tensorflow ops.
    const size_t target_working_set_size = l3_cache_size / sizeof(T);

    const int64 size_A = output_image_size * filter_total_size;

    const int64 size_B = output_image_size * dims.out_depth;

    const int64 size_C = filter_total_size * dims.out_depth;

    const int64 work_unit_size = size_A + size_B + size_C;

    const size_t shard_size =
        (target_working_set_size + work_unit_size - 1) / work_unit_size;

    // Total number of elements in all the tensors used by this kernel.
    int64 total_tensor_elements = input_shape.num_elements() +
                                  filter_shape.num_elements() +
                                  out_backprop_shape.num_elements();

    // Shape of the temporary workspace buffer.
    TensorShape col_buffer_shape = {static_cast<int64>(shard_size),
                                    static_cast<int64>(output_image_size),
                                    static_cast<int64>(filter_total_size)};
    int64 col_buffer_elements = col_buffer_shape.num_elements();

    // If the temporary allocation overhead is too large, fall back on the
    // Eigen implementation, which requires much less memory.
    int64 col_buffer_overhead = col_buffer_elements / total_tensor_elements;
    if (col_buffer_overhead > kMaxTempAllocationOverhead) {
      VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropFilterOp: "
                 "col_buffer_overhead="
              << col_buffer_overhead;

      functor::CuboidConvolutionBackwardFilter<Device, T>()(
          context->eigen_device<Device>(),
          filter_backprop->tensor<T, 5>(),                 // filter_backward
          input.tensor<T, 5>(),                            // input
          out_backprop.tensor<T, 5>(),                     // output_backward
          static_cast<int>(dims.spatial_dims[0].stride),   // stride_planes
          static_cast<int>(dims.spatial_dims[1].stride),   // stride_rows
          static_cast<int>(dims.spatial_dims[2].stride));  // stride_cols

      return;
    }

    Tensor col_buffer;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DataTypeToEnum<T>::value,
                                          col_buffer_shape, &col_buffer));

    // The input offset corresponding to a single input image.
    const int64 input_offset = dims.spatial_dims[0].input_size *
                               dims.spatial_dims[1].input_size *
                               dims.spatial_dims[2].input_size * dims.in_depth;
    // The output offset corresponding to a single output image.
    const int64 output_offset =
        dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size *
        dims.spatial_dims[2].output_size * dims.out_depth;

    const T* input_data = input.template flat<T>().data();
    T* col_buffer_data = col_buffer.template flat<T>().data();
    const T* out_backprop_data = out_backprop.template flat<T>().data();
    T* filter_backprop_data = filter_backprop->template flat<T>().data();

    typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
                             Eigen::Unaligned>
        TensorMap;
    typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
                             Eigen::Unaligned>
        ConstTensorMap;

    TensorMap C(filter_backprop_data, filter_total_size, dims.out_depth);
    C.setZero();

    // Initialize contraction dims (we need to transpose 'A' below).
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
    contract_dims[0].first = 0;
    contract_dims[0].second = 0;
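
    // Note (added for clarity): contracting dimension 0 of A
    // [output_image_size * shard_limit, filter_total_size] with dimension 0
    // of B [output_image_size * shard_limit, out_depth] yields A^T * B, which
    // accumulates the filter gradient over every output position in a shard.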

    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());

    for (int image_id = 0; image_id < dims.batch_size; image_id += shard_size) {
      const int shard_limit =
          std::min(static_cast<int>(shard_size),
                   static_cast<int>(dims.batch_size) - image_id);

      auto shard = [&input_data, &col_buffer_data, &dims, &top_pad_planes,
                    &top_pad_rows, &left_pad_cols, &bottom_pad_planes,
                    &bottom_pad_rows, &right_pad_cols, &input_offset,
                    &size_A](int64 start, int64 limit) {
        for (int shard_id = start; shard_id < limit; ++shard_id) {
          const T* input_data_shard = input_data + shard_id * input_offset;
          T* col_data_shard = col_buffer_data + shard_id * size_A;

          // When we compute the gradient with respect to the filters, we need
          // to do im2col to allow gemm-type computation.
          Im2col<T>(input_data_shard, dims.in_depth,
                    // Input spatial dimensions.
                    dims.spatial_dims[0].input_size,  // input planes
                    dims.spatial_dims[1].input_size,  // input rows
                    dims.spatial_dims[2].input_size,  // input cols
                    // Filter spatial dimensions.
                    dims.spatial_dims[0].filter_size,  // filter planes
                    dims.spatial_dims[1].filter_size,  // filter rows
                    dims.spatial_dims[2].filter_size,  // filter cols
                    // Spatial padding.
                    top_pad_planes, top_pad_rows, left_pad_cols,
                    bottom_pad_planes, bottom_pad_rows, right_pad_cols,
                    // Spatial striding.
                    dims.spatial_dims[0].stride,  // stride planes
                    dims.spatial_dims[1].stride,  // stride rows
                    dims.spatial_dims[2].stride,  // stride cols
                    col_data_shard);
        }
      };
      Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
            size_A, shard);

      ConstTensorMap A(col_buffer_data, output_image_size * shard_limit,
                       filter_total_size);
      ConstTensorMap B(out_backprop_data, output_image_size * shard_limit,
                       dims.out_depth);

      // Gradient with respect to filter.
      C.device(context->eigen_cpu_device()) += A.contract(B, contract_dims);

      input_data += input_offset * shard_limit;
      out_backprop_data += output_offset * shard_limit;
    }
  }

 private:
  std::vector<int32> dilation_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
  bool takes_shape_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropFilterOp);
};

// Custom backprop filter kernel is 30% - 4x faster when compiled with AVX2
// than the default Eigen implementation (at the cost of ~2x-8x peak memory
// usage).

#define REGISTER_CPU_KERNEL(T)                                                \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      Conv3DCustomBackpropFilterOp<CPUDevice, T>);                            \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
                              .Device(DEVICE_CPU)                             \
                              .TypeConstraint<T>("T"),                        \
                          Conv3DCustomBackpropFilterOp<CPUDevice, T>);        \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter")                        \
                              .Device(DEVICE_CPU)                             \
                              .Label("custom")                                \
                              .TypeConstraint<T>("T"),                        \
                          Conv3DCustomBackpropFilterOp<CPUDevice, T>);        \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
                              .Device(DEVICE_CPU)                             \
                              .Label("custom")                                \
                              .TypeConstraint<T>("T"),                        \
                          Conv3DCustomBackpropFilterOp<CPUDevice, T>);        \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter")                        \
                              .Device(DEVICE_CPU)                             \
                              .Label("eigen_tensor")                          \
                              .TypeConstraint<T>("T"),                        \
                          Conv3DBackpropFilterOp<CPUDevice, T>);              \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
                              .Device(DEVICE_CPU)                             \
                              .Label("eigen_tensor")                          \
                              .TypeConstraint<T>("T"),                        \
                          Conv3DBackpropFilterOp<CPUDevice, T>);

TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL

// WARNING: Eigen::half is not trivially copyable, so it can't be used in the
// custom backprop filter kernel, which relies on memcpy and memset in Im2col.
#define REGISTER_CPU_KERNEL(T)                                                \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      Conv3DBackpropFilterOp<CPUDevice, T>);                                  \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
                              .Device(DEVICE_CPU)                             \
                              .TypeConstraint<T>("T"),                        \
                          Conv3DBackpropFilterOp<CPUDevice, T>);

TF_CALL_half(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL
   1051 
   1052 // GPU definitions of both ops.
   1053 #if GOOGLE_CUDA
    1054 // Forward declarations of the functor specializations for GPU.
    1055 // These prevent implicit instantiation of the generic Eigen implementation
    1056 // (used on CPU), so the GPU specializations defined elsewhere are linked in.
   1057 namespace functor {
   1058 #define DECLARE_GPU_SPEC(T)                                           \
   1059   template <>                                                         \
   1060   void TransformFilter<GPUDevice, T, int, 5>::operator()(             \
   1061       const GPUDevice& d, FilterTensorFormat dst_filter_format,       \
   1062       typename TTypes<T, 5, int>::ConstTensor in,                     \
   1063       typename TTypes<T, 5, int>::Tensor out);                        \
   1064   template <>                                                         \
   1065   void ReverseTransformFilter<GPUDevice, T, 5>::operator()(           \
   1066       const GPUDevice& d, typename TTypes<T, 5>::ConstTensor in,      \
   1067       typename TTypes<T, 5>::Tensor out);                             \
   1068   template <>                                                         \
   1069   void PadInput<GPUDevice, T, int, 5>::operator()(                    \
   1070       const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
   1071       const std::array<int, 3>& padding_left,                         \
   1072       const std::array<int, 3>& padding_right,                        \
   1073       typename TTypes<T, 5, int>::Tensor out, TensorFormat format);
   1074 
   1075 DECLARE_GPU_SPEC(Eigen::half);
   1076 DECLARE_GPU_SPEC(float);
   1077 DECLARE_GPU_SPEC(double);
   1078 #undef DECLARE_GPU_SPEC
   1079 }  // namespace functor
   1080 
   1081 // A dummy type to group backward data autotune results together.
   1082 struct Conv3dBackwardDataAutoTuneGroup {
   1083   static string name() { return "Conv3dBwdData"; }
   1084 };
    1085 typedef AutoTuneSingleton<Conv3dBackwardDataAutoTuneGroup, ConvParameters,
    1086                           se::dnn::AlgorithmConfig>
    1087     AutoTuneConv3dBwdData;
    1088
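         // Autotune results are memoized per ConvParameters key: the first call
         // that sees a given combination of shapes, strides, dilations, padding,
         // data type, and device profiles the available cuDNN algorithms (the
         // Find/Insert calls in Compute below), and later calls reuse the cached
         // AlgorithmConfig without re-profiling.
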
   1089 template <typename T>
   1090 class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
   1091  public:
   1092   explicit Conv3DBackpropInputOp(OpKernelConstruction* context)
   1093       : OpKernel(context),
   1094         data_format_(FORMAT_NHWC),
   1095         takes_shape_(type_string().find("V2") != std::string::npos) {
   1096     // data_format is only available in V2.
   1097     if (takes_shape_) {
   1098       string data_format;
   1099       OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
   1100       OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
   1101                   errors::InvalidArgument("Invalid data format"));
   1102     }
   1103     OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
   1104     OP_REQUIRES(context, dilation_.size() == 5,
   1105                 errors::InvalidArgument("Dilation rates field must "
   1106                                         "specify 5 dimensions"));
   1107     OP_REQUIRES(context,
   1108                 (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
   1109                  GetTensorDim(dilation_, data_format_, 'N') == 1),
   1110                 errors::InvalidArgument(
   1111                     "Current implementation does not yet support "
   1112                     "dilation rates in the batch and depth dimensions."));
   1113     OP_REQUIRES(
   1114         context,
   1115         (GetTensorDim(dilation_, data_format_, '0') > 0 &&
   1116          GetTensorDim(dilation_, data_format_, '1') > 0 &&
   1117          GetTensorDim(dilation_, data_format_, '2') > 0),
    1118         errors::InvalidArgument("Dilation rates should be larger than 0."));
   1119     OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
   1120     OP_REQUIRES(context, stride_.size() == 5,
   1121                 errors::InvalidArgument("Sliding window strides field must "
   1122                                         "specify 5 dimensions"));
   1123     OP_REQUIRES(
   1124         context,
   1125         (GetTensorDim(stride_, data_format_, 'C') == 1 &&
   1126          GetTensorDim(stride_, data_format_, 'N') == 1),
   1127         errors::InvalidArgument("Current implementation does not yet support "
   1128                                 "strides in the batch and depth dimensions."));
   1129     OP_REQUIRES(
   1130         context,
   1131         (GetTensorDim(stride_, data_format_, '0') > 0 &&
   1132          GetTensorDim(stride_, data_format_, '1') > 0 &&
   1133          GetTensorDim(stride_, data_format_, '2') > 0),
   1134         errors::InvalidArgument("Spatial strides should be larger than 0."));
   1135     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
   1136     cudnn_use_autotune_ = CudnnUseAutotune();
   1137   }
   1138   void Compute(OpKernelContext* context) override {
   1139     const Tensor& filter = context->input(1);
   1140     const TensorShape& filter_shape = filter.shape();
   1141 
   1142     const Tensor& out_backprop = context->input(2);
   1143     const TensorShape& out_backprop_shape = out_backprop.shape();
   1144 
   1145     TensorShape input_shape;
   1146     if (takes_shape_) {
   1147       const Tensor& input_sizes = context->input(0);
   1148       OP_REQUIRES_OK(context, MakeShape(input_sizes, &input_shape));
   1149     } else {
   1150       input_shape = context->input(0).shape();
   1151     }
   1152 
   1153     ConvBackpropDimensions dims;
   1154     OP_REQUIRES_OK(context, ConvBackpropComputeDimensionsV2(
   1155                                 "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
   1156                                 input_shape, filter_shape, out_backprop_shape,
   1157                                 dilation_, stride_, padding_,
   1158                                 /*explicit_paddings=*/{}, data_format_, &dims));
   1159 
   1160     Tensor* in_backprop;
   1161     OP_REQUIRES_OK(context,
   1162                    context->allocate_output(0, input_shape, &in_backprop));
   1163 
   1164     auto* stream = context->op_device_context()->stream();
   1165     OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
   1166 
   1167     if (dims.filter_size(0) == 1 && dims.filter_size(1) == 1 &&
   1168         dims.filter_size(2) == 1 && dims.dilation(0) == 1 &&
   1169         dims.dilation(1) == 1 && dims.dilation(2) == 1 && dims.stride(0) == 1 &&
   1170         dims.stride(1) == 1 && dims.stride(2) == 1 &&
   1171         data_format_ == FORMAT_NHWC) {
   1172       const uint64 m = dims.batch_size * dims.input_size(0) *
   1173                        dims.input_size(1) * dims.input_size(2);
   1174       const uint64 k = dims.out_depth;
   1175       const uint64 n = dims.in_depth;
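
               // In row-major terms this computes
               //
               //   in_backprop[m x n] = out_backprop[m x k] * filter^T[k x n]
               //
               // i.e. with a 1x1x1 filter and unit strides/dilations the
               // convolution degenerates to a matmul: out_backprop is viewed as
               // [batch * planes * rows * cols, out_depth] and the filter as an
               // [in_depth, out_depth] matrix. The swapped operands below are the
               // usual column-major cuBLAS encoding of this row-major product.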
   1176 
   1177       auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
   1178                                   out_backprop.template flat<T>().size());
   1179       auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
   1180                                   filter.template flat<T>().size());
   1181       auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
   1182                                   in_backprop->template flat<T>().size());
   1183 
   1184       auto transpose = se::blas::Transpose::kTranspose;
   1185       auto no_transpose = se::blas::Transpose::kNoTranspose;
   1186 
   1187       bool blas_launch_status =
   1188           stream
   1189               ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
   1190                              a_ptr, k, 0.0f, &c_ptr, n)
   1191               .ok();
   1192       if (!blas_launch_status) {
    1193         context->SetStatus(errors::Internal("Blas SGEMM launch failed: m=", m,
   1194                                             ", n=", n, ", k=", k));
   1195       }
   1196       return;
   1197     } else if (dims.filter_size(0) == dims.input_size(0) &&
   1198                dims.filter_size(1) == dims.input_size(1) &&
   1199                dims.filter_size(2) == dims.input_size(2) &&
   1200                padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
   1201       const uint64 m = dims.batch_size;
   1202       const uint64 k = dims.out_depth;
   1203       const uint64 n = dims.input_size(0) * dims.input_size(1) *
   1204                        dims.input_size(2) * dims.in_depth;
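
               // With VALID padding and a filter that covers the whole input,
               // every batch element produces exactly one output voxel, so this
               // path is the same matmul as above:
               //
               //   in_backprop[m x n] = out_backprop[m x k] * filter^T[k x n]
               //
               // with m = batch_size, k = out_depth, and the filter viewed as an
               // [n, k] = [planes * rows * cols * in_depth, out_depth] matrix.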
   1205 
   1206       auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
   1207                                   out_backprop.template flat<T>().size());
   1208       auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
   1209                                   filter.template flat<T>().size());
   1210       auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
   1211                                   in_backprop->template flat<T>().size());
   1212 
   1213       auto transpose = se::blas::Transpose::kTranspose;
   1214       auto no_transpose = se::blas::Transpose::kNoTranspose;
   1215 
   1216       bool blas_launch_status =
   1217           stream
   1218               ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
   1219                              a_ptr, k, 0.0f, &c_ptr, n)
   1220               .ok();
   1221       if (!blas_launch_status) {
    1222         context->SetStatus(errors::Internal("Blas SGEMM launch failed: m=", m,
   1223                                             ", n=", n, ", k=", k));
   1224       }
   1225       return;
   1226     }
   1227 
   1228     int padding_planes = dims.SpatialPadding(padding_, 0);
   1229     int padding_rows = dims.SpatialPadding(padding_, 1);
   1230     int padding_cols = dims.SpatialPadding(padding_, 2);
   1231     const bool planes_odd = (padding_planes % 2 != 0);
   1232     const bool rows_odd = (padding_rows % 2 != 0);
   1233     const bool cols_odd = (padding_cols % 2 != 0);
   1234 
   1235     TensorShape compatible_input_shape;
   1236     if (rows_odd || cols_odd || planes_odd) {
   1237       // cuDNN only supports the same amount of padding on both sides.
   1238       compatible_input_shape = {
   1239           dims.batch_size,
   1240           dims.in_depth,
   1241           dims.input_size(0) + planes_odd,
   1242           dims.input_size(1) + rows_odd,
   1243           dims.input_size(2) + cols_odd,
   1244       };
   1245     } else {
   1246       compatible_input_shape = {dims.batch_size, dims.in_depth,
   1247                                 dims.input_size(0), dims.input_size(1),
   1248                                 dims.input_size(2)};
   1249     }
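
             // Illustrative example of the odd-padding case: input extent 5,
             // filter 2, stride 1 with SAME padding gives output extent 5 and a
             // total padding of (5 - 1) * 1 + 2 - 5 = 1, which is odd. The input
             // is then treated as extent 5 + 1 = 6 with a symmetric zero padding
             // of 1 / 2 = 0 per side, and the extra plane/row/column is sliced
             // off the result near the end of Compute.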
   1250 
   1251     CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
   1252         << "Negative paddings: (" << padding_rows << ", " << padding_cols
   1253         << ", " << padding_planes << ")";
   1254     se::dnn::BatchDescriptor input_desc(3);
   1255     input_desc.set_count(dims.batch_size)
   1256         .set_spatial_dim(DimIndex::X, compatible_input_shape.dim_size(4))
   1257         .set_spatial_dim(DimIndex::Y, compatible_input_shape.dim_size(3))
   1258         .set_spatial_dim(DimIndex::Z, compatible_input_shape.dim_size(2))
   1259         .set_feature_map_count(dims.in_depth)
   1260         .set_layout(se::dnn::DataLayout::kBatchDepthYX);
   1261     se::dnn::BatchDescriptor output_desc(3);
   1262     output_desc.set_count(dims.batch_size)
   1263         .set_spatial_dim(DimIndex::X, dims.output_size(2))
   1264         .set_spatial_dim(DimIndex::Y, dims.output_size(1))
   1265         .set_spatial_dim(DimIndex::Z, dims.output_size(0))
   1266         .set_feature_map_count(dims.out_depth)
   1267         .set_layout(se::dnn::DataLayout::kBatchDepthYX);
   1268     se::dnn::FilterDescriptor filter_desc(3);
   1269     filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
   1270         .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
   1271         .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
   1272         .set_input_feature_map_count(dims.in_depth)
   1273         .set_output_feature_map_count(dims.out_depth);
   1274     se::dnn::ConvolutionDescriptor conv_desc(3);
   1275     conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
   1276         .set_dilation_rate(DimIndex::Y, dims.dilation(1))
   1277         .set_dilation_rate(DimIndex::Z, dims.dilation(0))
   1278         .set_filter_stride(DimIndex::X, dims.stride(2))
   1279         .set_filter_stride(DimIndex::Y, dims.stride(1))
   1280         .set_filter_stride(DimIndex::Z, dims.stride(0))
   1281         .set_zero_padding(DimIndex::X, padding_cols / 2)
   1282         .set_zero_padding(DimIndex::Y, padding_rows / 2)
   1283         .set_zero_padding(DimIndex::Z, padding_planes / 2);
   1284 
   1285     // Shape: out, in, z, y, x.
   1286     Tensor transformed_filter;
   1287     OP_REQUIRES_OK(
   1288         context,
   1289         context->allocate_temp(
   1290             DataTypeToEnum<T>::value,
   1291             TensorShape({dims.out_depth, dims.in_depth, dims.filter_size(0),
   1292                          dims.filter_size(1), dims.filter_size(2)}),
   1293             &transformed_filter));
   1294     functor::TransformFilter<GPUDevice, T, int, 5>()(
   1295         context->eigen_device<GPUDevice>(), FORMAT_OIHW,
   1296         To32Bit(filter.tensor<T, 5>()),
   1297         To32Bit(transformed_filter.tensor<T, 5>()));
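
             // For example, a TF-layout filter of shape [3, 3, 3, 64, 128]
             // ([z, y, x, in, out]) is rearranged into [128, 64, 3, 3, 3]
             // ([out, in, z, y, x]), the order cuDNN expects here.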
   1298 
   1299     // Shape: batch, filters, z, y, x.
   1300     Tensor transformed_out_backprop;
   1301     if (data_format_ == FORMAT_NHWC) {
   1302       TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
   1303                                 dims.output_size(0), dims.output_size(1),
   1304                                 dims.output_size(2)};
   1305       if (dims.out_depth > 1) {
   1306         OP_REQUIRES_OK(context, context->allocate_temp(
   1307                                     DataTypeToEnum<T>::value, nchw_shape,
   1308                                     &transformed_out_backprop));
   1309         functor::NHWCToNCHW<GPUDevice, T, 5>()(
   1310             context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
   1311             transformed_out_backprop.tensor<T, 5>());
   1312       } else {
   1313         CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
   1314       }
   1315     } else {
   1316       transformed_out_backprop = out_backprop;
   1317     }
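             // Note: when out_depth == 1 the NDHWC and NCDHW layouts have the
             // same linear element order, so the CopyFrom above is effectively a
             // reshape and no transposition kernel needs to run.
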
   1318     // Shape: batch, filters, z, y, x.
   1319     Tensor pre_transformed_in_backprop;
   1320     OP_REQUIRES_OK(
   1321         context,
   1322         context->allocate_temp(DataTypeToEnum<T>::value, compatible_input_shape,
   1323                                &pre_transformed_in_backprop));
   1324 
   1325     auto out_backprop_ptr =
   1326         AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
   1327                        transformed_out_backprop.template flat<T>().size());
   1328     auto filter_ptr =
   1329         AsDeviceMemory(transformed_filter.template flat<T>().data(),
   1330                        transformed_filter.template flat<T>().size());
   1331     auto in_backprop_ptr =
   1332         AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(),
   1333                        pre_transformed_in_backprop.template flat<T>().size());
   1334 
   1335     static int64 ConvolveBackwardDataScratchSize = GetDnnWorkspaceLimit(
   1336         "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32);  // 4GB by default
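             // For example, assuming GetDnnWorkspaceLimit parses the variable as
             // a megabyte count (as its name suggests), running with
             //   TF_CUDNN_WORKSPACE_LIMIT_IN_MB=1024
             // would cap the cuDNN scratch space at 1 GiB instead of the 4GB
             // default above.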
   1337 
   1338     const int device_id = stream->parent()->device_ordinal();
   1339     DataType dtype = context->input(0).dtype();
   1340     const ConvParameters conv_parameters = {
   1341         dims.batch_size,
   1342         dims.in_depth,
   1343         {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
   1344         FORMAT_NCHW,
   1345         dims.out_depth,
   1346         {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
   1347         {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
   1348         {{dims.stride(0), dims.stride(1), dims.stride(2)}},
   1349         {{padding_planes, padding_rows, padding_cols}},
   1350         dtype,
   1351         device_id,
   1352     };
   1353 
   1354     using se::dnn::AlgorithmConfig;
   1355     using se::dnn::AlgorithmDesc;
   1356     using se::dnn::ProfileResult;
   1357     AlgorithmConfig algorithm_config;
   1358     if (cudnn_use_autotune_ && !AutoTuneConv3dBwdData::GetInstance()->Find(
   1359                                    conv_parameters, &algorithm_config)) {
   1360       std::vector<AlgorithmDesc> algorithms;
   1361       CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
   1362           conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
   1363               stream->parent()),
   1364           &algorithms));
   1365       ProfileResult best_result;
   1366       ProfileResult best_result_no_scratch;
   1367       for (auto profile_algorithm : algorithms) {
    1368         // TODO(zhengxq): profile each algorithm multiple times for better
    1369         // accuracy.
   1370         DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
   1371                                               context);
   1372         ProfileResult profile_result;
   1373         bool cudnn_launch_status =
   1374             stream
   1375                 ->ThenConvolveBackwardDataWithAlgorithm(
   1376                     filter_desc, filter_ptr, output_desc, out_backprop_ptr,
   1377                     conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
   1378                     AlgorithmConfig(profile_algorithm), &profile_result)
   1379                 .ok();
   1380         if (cudnn_launch_status) {
   1381           if (profile_result.is_valid()) {
   1382             if (profile_result.elapsed_time_in_ms() <
   1383                 best_result.elapsed_time_in_ms()) {
   1384               best_result = profile_result;
   1385             }
   1386             if (scratch_allocator.TotalByteSize() == 0 &&
   1387                 profile_result.elapsed_time_in_ms() <
   1388                     best_result_no_scratch.elapsed_time_in_ms()) {
   1389               best_result_no_scratch = profile_result;
   1390             }
   1391           }
   1392         }
   1393       }
   1394       OP_REQUIRES(context,
   1395                   best_result.is_valid() || best_result_no_scratch.is_valid(),
   1396                   errors::NotFound("No algorithm worked!"));
   1397       if (best_result.is_valid()) {
   1398         algorithm_config.set_algorithm(best_result.algorithm());
   1399       }
   1400       if (best_result_no_scratch.is_valid()) {
   1401         algorithm_config.set_algorithm_no_scratch(
   1402             best_result_no_scratch.algorithm());
   1403       }
   1404       AutoTuneConv3dBwdData::GetInstance()->Insert(conv_parameters,
   1405                                                    algorithm_config);
   1406     }
   1407     DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
   1408                                           context);
   1409     bool cudnn_launch_status =
   1410         stream
   1411             ->ThenConvolveBackwardDataWithAlgorithm(
   1412                 filter_desc, filter_ptr, output_desc, out_backprop_ptr,
   1413                 conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
   1414                 algorithm_config, nullptr)
   1415             .ok();
   1416 
   1417     if (!cudnn_launch_status) {
   1418       context->SetStatus(errors::Internal(
    1419           "cuDNN Backward Data function launch failure: input shape(",
   1420           input_shape.DebugString(), ") filter shape(",
   1421           filter_shape.DebugString(), ")"));
   1422     }
   1423 
   1424     if (rows_odd || cols_odd || planes_odd) {
   1425       Tensor in_backprop_remove_padding;
   1426       OP_REQUIRES_OK(context,
   1427                      context->allocate_temp(
   1428                          DataTypeToEnum<T>::value,
   1429                          {dims.batch_size, dims.in_depth, dims.input_size(0),
   1430                           dims.input_size(1), dims.input_size(2)},
   1431                          &in_backprop_remove_padding));
   1432 
   1433       // Remove the padding for odd spatial dimensions.
   1434       functor::PadInput<GPUDevice, T, int, 5>()(
   1435           context->eigen_device<GPUDevice>(),
   1436           To32Bit(const_cast<const Tensor&>(pre_transformed_in_backprop)
   1437                       .tensor<T, 5>()),
   1438           {{0, 0, 0}}, {{-planes_odd, -rows_odd, -cols_odd}},
   1439           To32Bit(in_backprop_remove_padding.tensor<T, 5>()), FORMAT_NCHW);
   1440 
   1441       pre_transformed_in_backprop = in_backprop_remove_padding;
   1442     }
   1443 
   1444     if (data_format_ == FORMAT_NHWC) {
   1445       auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
   1446       functor::NCHWToNHWC<GPUDevice, T, 5>()(
   1447           context->eigen_device<GPUDevice>(),
   1448           toConstTensor(pre_transformed_in_backprop).template tensor<T, 5>(),
   1449           in_backprop->tensor<T, 5>());
   1450     } else {
   1451       *in_backprop = pre_transformed_in_backprop;
   1452     }
   1453   }
   1454 
   1455  private:
   1456   std::vector<int32> dilation_;
   1457   std::vector<int32> stride_;
   1458   Padding padding_;
   1459   TensorFormat data_format_;
   1460   bool takes_shape_;
   1461   bool cudnn_use_autotune_;
   1462 };
   1463 
   1464 // A dummy type to group backward filter autotune results together.
   1465 struct Conv3dBackwardFilterAutoTuneGroup {
   1466   static string name() { return "Conv3dBwdFilter"; }
   1467 };
   1468 typedef AutoTuneSingleton<Conv3dBackwardFilterAutoTuneGroup, ConvParameters,
   1469                           se::dnn::AlgorithmConfig>
   1470     AutoTuneConv3dBwdFilter;
   1471 
   1472 template <typename T>
   1473 class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
   1474  public:
   1475   explicit Conv3DBackpropFilterOp(OpKernelConstruction* context)
   1476       : OpKernel(context),
   1477         data_format_(FORMAT_NHWC),
   1478         takes_shape_(type_string().find("V2") != std::string::npos) {
   1479     // data_format is only available in V2.
   1480     if (takes_shape_) {
   1481       string data_format;
   1482       OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
   1483       OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
   1484                   errors::InvalidArgument("Invalid data format"));
   1485     }
   1486     OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
   1487     OP_REQUIRES(context, dilation_.size() == 5,
   1488                 errors::InvalidArgument("Dilation rates field must "
   1489                                         "specify 5 dimensions"));
   1490     OP_REQUIRES(context,
   1491                 (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
   1492                  GetTensorDim(dilation_, data_format_, 'N') == 1),
   1493                 errors::InvalidArgument(
   1494                     "Current implementation does not yet support "
   1495                     "dilation rates in the batch and depth dimensions."));
   1496     OP_REQUIRES(
   1497         context,
   1498         (GetTensorDim(dilation_, data_format_, '0') > 0 &&
   1499          GetTensorDim(dilation_, data_format_, '1') > 0 &&
   1500          GetTensorDim(dilation_, data_format_, '2') > 0),
    1501         errors::InvalidArgument("Dilation rates should be larger than 0."));
   1502     OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
   1503     OP_REQUIRES(context, stride_.size() == 5,
   1504                 errors::InvalidArgument("Sliding window strides field must "
   1505                                         "specify 5 dimensions"));
   1506     OP_REQUIRES(
   1507         context,
   1508         (GetTensorDim(stride_, data_format_, 'C') == 1 &&
   1509          GetTensorDim(stride_, data_format_, 'N') == 1),
   1510         errors::InvalidArgument("Current implementation does not yet support "
   1511                                 "strides in the batch and depth dimensions."));
   1512     OP_REQUIRES(
   1513         context,
   1514         (GetTensorDim(stride_, data_format_, '0') > 0 &&
   1515          GetTensorDim(stride_, data_format_, '1') > 0 &&
   1516          GetTensorDim(stride_, data_format_, '2') > 0),
   1517         errors::InvalidArgument("Spatial strides should be larger than 0."));
   1518     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
   1519     cudnn_use_autotune_ = CudnnUseAutotune();
   1520   }
   1521 
   1522   void Compute(OpKernelContext* context) override {
   1523     const Tensor& input = context->input(0);
   1524     const TensorShape& input_shape = input.shape();
   1525 
   1526     const Tensor& out_backprop = context->input(2);
   1527     const TensorShape& out_backprop_shape = out_backprop.shape();
   1528 
   1529     TensorShape filter_shape;
   1530     if (takes_shape_) {
   1531       const Tensor& filter_sizes = context->input(1);
   1532       OP_REQUIRES_OK(context, MakeShape(filter_sizes, &filter_shape));
   1533     } else {
   1534       filter_shape = context->input(1).shape();
   1535     }
   1536 
   1537     ConvBackpropDimensions dims;
   1538     OP_REQUIRES_OK(
   1539         context,
   1540         ConvBackpropComputeDimensionsV2(
   1541             "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3, input_shape,
   1542             filter_shape, out_backprop_shape, dilation_, stride_, padding_,
   1543             /*explicit_paddings=*/{}, data_format_, &dims));
   1544 
   1545     Tensor* filter_backprop;
   1546     OP_REQUIRES_OK(context,
   1547                    context->allocate_output(0, filter_shape, &filter_backprop));
   1548 
   1549     auto* stream = context->op_device_context()->stream();
   1550     OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
   1551 
   1552     if (dims.filter_size(1) == 1 && dims.filter_size(2) == 1 &&
   1553         dims.filter_size(0) == 1 && dims.dilation(2) == 1 &&
   1554         dims.dilation(1) == 1 && dims.dilation(0) == 1 && dims.stride(2) == 1 &&
   1555         dims.stride(1) == 1 && dims.stride(0) == 1 &&
   1556         data_format_ == FORMAT_NHWC) {
   1557       const uint64 m = dims.in_depth;
   1558       const uint64 k = dims.batch_size * dims.input_size(1) *
   1559                        dims.input_size(2) * dims.input_size(0);
   1560       const uint64 n = dims.out_depth;
   1561 
   1562       // The shape of output backprop is
   1563       //   [batch, out_z, out_y, out_x, out_depth]
   1564       // From cublas's perspective, it is: n x k
   1565       auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
   1566                                   out_backprop.template flat<T>().size());
   1567 
   1568       // The shape of input is:
   1569       //   [batch, in_z, in_y, in_x, in_depth],
   1570       // From cublas's perspective, it is: m x k
   1571       auto b_ptr = AsDeviceMemory(input.template flat<T>().data(),
   1572                                   input.template flat<T>().size());
   1573 
   1574       // The shape of the filter backprop is:
   1575       //   [1, 1, 1, in_depth, out_depth]
   1576       // From cublas's perspective, it is: n x m
   1577       auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
   1578                                   filter_backprop->template flat<T>().size());
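
               // Putting these together, in row-major terms this computes
               //
               //   filter_backprop[m x n] = input^T[m x k] * out_backprop[k x n]
               //
               // i.e. for a 1x1x1 filter the gradient is the flattened input
               // ([batch * planes * rows * cols, in_depth], transposed) times the
               // similarly flattened output gradient.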
   1579 
   1580       bool blas_launch_status =
   1581           stream
   1582               ->ThenBlasGemm(se::blas::Transpose::kNoTranspose,
   1583                              se::blas::Transpose::kTranspose, n, m, k, 1.0f,
   1584                              a_ptr, n, b_ptr, m, 0.0f, &c_ptr, n)
   1585               .ok();
   1586       if (!blas_launch_status) {
    1587         context->SetStatus(errors::Internal("Blas SGEMM launch failed: m=", m,
   1588                                             ", n=", n, ", k=", k));
   1589       }
   1590       return;
   1591     } else if (dims.filter_size(0) == dims.input_size(0) &&
   1592                dims.filter_size(1) == dims.input_size(1) &&
   1593                dims.filter_size(2) == dims.input_size(2) &&
   1594                padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
   1595       const uint64 m = dims.input_size(0) * dims.input_size(1) *
   1596                        dims.input_size(2) * dims.in_depth;
   1597       const uint64 k = dims.batch_size;
   1598       const uint64 n = dims.out_depth;
   1599 
   1600       auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
   1601                                   input.template flat<T>().size());
   1602       auto b_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
   1603                                   out_backprop.template flat<T>().size());
   1604       auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
   1605                                   filter_backprop->template flat<T>().size());
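
               // As in the matching input-backprop fast path, a full-input VALID
               // filter reduces the gradient to a single matmul:
               //
               //   filter_backprop[m x n] = input^T[m x k] * out_backprop[k x n]
               //
               // with m = planes * rows * cols * in_depth, k = batch_size, and
               // n = out_depth.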
   1606 
   1607       bool blas_launch_status =
   1608           stream
   1609               ->ThenBlasGemm(se::blas::Transpose::kNoTranspose,
   1610                              se::blas::Transpose::kTranspose, n, m, k, 1.0f,
   1611                              b_ptr, n, a_ptr, m, 0.0f, &c_ptr, n)
   1612               .ok();
   1613       if (!blas_launch_status) {
    1614         context->SetStatus(errors::Internal("Blas SGEMM launch failed: m=", m,
   1615                                             ", n=", n, ", k=", k));
   1616       }
   1617       return;
   1618     }
   1619 
   1620     int padding_planes = dims.SpatialPadding(padding_, 0);
   1621     int padding_rows = dims.SpatialPadding(padding_, 1);
   1622     int padding_cols = dims.SpatialPadding(padding_, 2);
   1623     const bool planes_odd = (padding_planes % 2 != 0);
   1624     const bool rows_odd = (padding_rows % 2 != 0);
   1625     const bool cols_odd = (padding_cols % 2 != 0);
   1626 
   1627     Tensor compatible_input;
   1628     if (rows_odd || cols_odd || planes_odd) {
   1629       OP_REQUIRES_OK(context,
   1630                      context->allocate_temp(
   1631                          DataTypeToEnum<T>::value,
   1632                          ShapeFromFormat(data_format_, dims.batch_size,
   1633                                          {{dims.input_size(0) + planes_odd,
   1634                                            dims.input_size(1) + rows_odd,
   1635                                            dims.input_size(2) + cols_odd}},
   1636                                          dims.in_depth),
   1637                          &compatible_input));
   1638       functor::PadInput<GPUDevice, T, int, 5>()(
   1639           context->template eigen_device<GPUDevice>(),
   1640           To32Bit(input.tensor<T, 5>()), {{0, 0, 0}},
   1641           {{planes_odd, rows_odd, cols_odd}},
   1642           To32Bit(compatible_input.tensor<T, 5>()), data_format_);
   1643     } else {
   1644       compatible_input = input;
   1645     }
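
             // Note the asymmetry with the input-backprop kernel above: there
             // the extra plane/row/column is sliced off the result afterwards,
             // whereas here the forward input itself is padded up front, so the
             // filter gradient already has the requested shape.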
   1646 
   1647     CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
   1648         << "Negative paddings: (" << padding_rows << ", " << padding_cols
   1649         << ", " << padding_planes << ")";
   1650     se::dnn::BatchDescriptor input_desc(3);
   1651     input_desc.set_count(dims.batch_size)
   1652         .set_spatial_dim(DimIndex::X,
   1653                          GetTensorDim(compatible_input, data_format_, '2'))
   1654         .set_spatial_dim(DimIndex::Y,
   1655                          GetTensorDim(compatible_input, data_format_, '1'))
   1656         .set_spatial_dim(DimIndex::Z,
   1657                          GetTensorDim(compatible_input, data_format_, '0'))
   1658         .set_feature_map_count(dims.in_depth)
   1659         .set_layout(se::dnn::DataLayout::kBatchDepthYX);
   1660     se::dnn::BatchDescriptor output_desc(3);
   1661     output_desc.set_count(dims.batch_size)
   1662         .set_spatial_dim(DimIndex::X, dims.output_size(2))
   1663         .set_spatial_dim(DimIndex::Y, dims.output_size(1))
   1664         .set_spatial_dim(DimIndex::Z, dims.output_size(0))
   1665         .set_feature_map_count(dims.out_depth)
   1666         .set_layout(se::dnn::DataLayout::kBatchDepthYX);
   1667     se::dnn::FilterDescriptor filter_desc(3);
   1668     filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
   1669         .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
   1670         .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
   1671         .set_input_feature_map_count(dims.in_depth)
   1672         .set_output_feature_map_count(dims.out_depth);
   1673     se::dnn::ConvolutionDescriptor conv_desc(3);
   1674     conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
   1675         .set_dilation_rate(DimIndex::Y, dims.dilation(1))
   1676         .set_dilation_rate(DimIndex::Z, dims.dilation(0))
   1677         .set_filter_stride(DimIndex::X, dims.stride(2))
   1678         .set_filter_stride(DimIndex::Y, dims.stride(1))
   1679         .set_filter_stride(DimIndex::Z, dims.stride(0))
   1680         .set_zero_padding(DimIndex::X, padding_cols / 2)
   1681         .set_zero_padding(DimIndex::Y, padding_rows / 2)
   1682         .set_zero_padding(DimIndex::Z, padding_planes / 2);
   1683 
   1684     Tensor pre_transformed_filter_backprop;
   1685     OP_REQUIRES_OK(
   1686         context,
   1687         context->allocate_temp(
   1688             DataTypeToEnum<T>::value,
   1689             TensorShape({dims.out_depth, dims.in_depth, dims.filter_size(0),
   1690                          dims.filter_size(1), dims.filter_size(2)}),
   1691             &pre_transformed_filter_backprop));
   1692 
   1693     Tensor transformed_out_backprop;
   1694     if (data_format_ == FORMAT_NHWC) {
   1695       TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
   1696                                 dims.output_size(0), dims.output_size(1),
   1697                                 dims.output_size(2)};
   1698       OP_REQUIRES_OK(
   1699           context, context->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
   1700                                           &transformed_out_backprop));
   1701       if (dims.out_depth > 1) {
   1702         functor::NHWCToNCHW<GPUDevice, T, 5>()(
   1703             context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
   1704             transformed_out_backprop.tensor<T, 5>());
   1705       } else {
   1706         CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
   1707       }
   1708     } else {
   1709       transformed_out_backprop = out_backprop;
   1710     }
   1711     Tensor transformed_input;
   1712     if (data_format_ == FORMAT_NHWC) {
   1713       TensorShape nchw_shape = {
   1714           dims.batch_size, dims.in_depth, compatible_input.dim_size(1),
   1715           compatible_input.dim_size(2), compatible_input.dim_size(3)};
   1716       if (dims.in_depth > 1) {
   1717         OP_REQUIRES_OK(context,
   1718                        context->allocate_temp(DataTypeToEnum<T>::value,
   1719                                               nchw_shape, &transformed_input));
   1720         functor::NHWCToNCHW<GPUDevice, T, 5>()(
   1721             context->eigen_device<GPUDevice>(),
   1722             const_cast<const Tensor&>(compatible_input).tensor<T, 5>(),
   1723             transformed_input.tensor<T, 5>());
   1724       } else {
   1725         CHECK(transformed_input.CopyFrom(compatible_input, nchw_shape));
   1726       }
   1727     } else {
   1728       transformed_input = compatible_input;
   1729     }
   1730 
   1731     auto out_backprop_ptr =
   1732         AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
   1733                        transformed_out_backprop.template flat<T>().size());
   1734     auto filter_backprop_ptr = AsDeviceMemory(
   1735         pre_transformed_filter_backprop.template flat<T>().data(),
   1736         pre_transformed_filter_backprop.template flat<T>().size());
   1737     auto input_ptr =
   1738         AsDeviceMemory(transformed_input.template flat<T>().data(),
   1739                        transformed_input.template flat<T>().size());
   1740 
   1741     static int64 ConvolveBackwardFilterScratchSize = GetDnnWorkspaceLimit(
   1742         "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32);  // 4GB by default
   1743 
   1744     const int device_id = stream->parent()->device_ordinal();
   1745     DataType dtype = input.dtype();
   1746     const ConvParameters conv_parameters = {
   1747         dims.batch_size,
   1748         dims.in_depth,
   1749         {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
   1750         FORMAT_NCHW,
   1751         dims.out_depth,
   1752         {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
   1753         {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
   1754         {{dims.stride(0), dims.stride(1), dims.stride(2)}},
   1755         {{padding_planes, padding_rows, padding_cols}},
   1756         dtype,
   1757         device_id,
   1758     };
   1759 
   1760     using se::dnn::AlgorithmConfig;
   1761     using se::dnn::AlgorithmDesc;
   1762     using se::dnn::ProfileResult;
   1763     AlgorithmConfig algorithm_config;
   1764     if (cudnn_use_autotune_ && !AutoTuneConv3dBwdFilter::GetInstance()->Find(
   1765                                    conv_parameters, &algorithm_config)) {
   1766       std::vector<AlgorithmDesc> algorithms;
   1767       CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(
   1768           conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
   1769               stream->parent()),
   1770           &algorithms));
   1771       ProfileResult best_result;
   1772       ProfileResult best_result_no_scratch;
   1773       for (auto profile_algorithm : algorithms) {
    1774         // TODO(zhengxq): profile each algorithm multiple times for better
    1775         // accuracy.
   1776         DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
   1777                                               context);
   1778         ProfileResult profile_result;
   1779         bool cudnn_launch_status =
   1780             stream
   1781                 ->ThenConvolveBackwardFilterWithAlgorithm(
   1782                     input_desc, input_ptr, output_desc, out_backprop_ptr,
   1783                     conv_desc, filter_desc, &filter_backprop_ptr,
   1784                     &scratch_allocator, AlgorithmConfig(profile_algorithm),
   1785                     &profile_result)
   1786                 .ok();
   1787         if (cudnn_launch_status) {
   1788           if (profile_result.is_valid()) {
   1789             if (profile_result.elapsed_time_in_ms() <
   1790                 best_result.elapsed_time_in_ms()) {
   1791               best_result = profile_result;
   1792             }
   1793             if (scratch_allocator.TotalByteSize() == 0 &&
   1794                 profile_result.elapsed_time_in_ms() <
   1795                     best_result_no_scratch.elapsed_time_in_ms()) {
   1796               best_result_no_scratch = profile_result;
   1797             }
   1798           }
   1799         }
   1800       }
   1801       OP_REQUIRES(context,
   1802                   best_result.is_valid() || best_result_no_scratch.is_valid(),
   1803                   errors::NotFound("No algorithm worked!"));
   1804       if (best_result.is_valid()) {
   1805         algorithm_config.set_algorithm(best_result.algorithm());
   1806       }
   1807       if (best_result_no_scratch.is_valid()) {
   1808         algorithm_config.set_algorithm_no_scratch(
   1809             best_result_no_scratch.algorithm());
   1810       }
   1811       AutoTuneConv3dBwdFilter::GetInstance()->Insert(conv_parameters,
   1812                                                      algorithm_config);
   1813     }
   1814     DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
   1815                                           context);
   1816     bool cudnn_launch_status =
   1817         stream
   1818             ->ThenConvolveBackwardFilterWithAlgorithm(
   1819                 input_desc, input_ptr, output_desc, out_backprop_ptr, conv_desc,
   1820                 filter_desc, &filter_backprop_ptr, &scratch_allocator,
   1821                 algorithm_config, nullptr)
   1822             .ok();
   1823 
   1824     if (!cudnn_launch_status) {
   1825       context->SetStatus(errors::Internal(
    1826           "cuDNN Backward Filter function launch failure: input shape(",
   1827           input_shape.DebugString(), ") filter shape(",
   1828           filter_shape.DebugString(), ")"));
   1829     }
   1830 
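             // The cuDNN result is in [out, in, z, y, x] order; reverse the
             // filter transform to return it to TF's [z, y, x, in, out] layout
             // (e.g. [128, 64, 3, 3, 3] back to [3, 3, 3, 64, 128]).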
   1831     auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
   1832     functor::ReverseTransformFilter<GPUDevice, T, 5>()(
   1833         context->eigen_device<GPUDevice>(),
   1834         toConstTensor(pre_transformed_filter_backprop).template tensor<T, 5>(),
   1835         filter_backprop->tensor<T, 5>());
   1836   }
   1837 
   1838  private:
   1839   std::vector<int32> dilation_;
   1840   std::vector<int32> stride_;
   1841   Padding padding_;
   1842   TensorFormat data_format_;
   1843   bool takes_shape_;
   1844   bool cudnn_use_autotune_;
   1845 };
   1846 
   1847 #define REGISTER_GPU_KERNEL(T)                                                \
   1848   REGISTER_KERNEL_BUILDER(                                                    \
   1849       Name("Conv3DBackpropInput").Device(DEVICE_GPU).TypeConstraint<T>("T"),  \
   1850       Conv3DBackpropInputOp<GPUDevice, T>);                                   \
   1851   REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2")                       \
   1852                               .Device(DEVICE_GPU)                             \
   1853                               .TypeConstraint<T>("T")                         \
   1854                               .HostMemory("input_sizes"),                     \
   1855                           Conv3DBackpropInputOp<GPUDevice, T>);               \
   1856   REGISTER_KERNEL_BUILDER(                                                    \
   1857       Name("Conv3DBackpropFilter").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
   1858       Conv3DBackpropFilterOp<GPUDevice, T>);                                  \
   1859   REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
   1860                               .Device(DEVICE_GPU)                             \
   1861                               .TypeConstraint<T>("T")                         \
   1862                               .HostMemory("filter_sizes"),                    \
   1863                           Conv3DBackpropFilterOp<GPUDevice, T>);
   1864 TF_CALL_half(REGISTER_GPU_KERNEL);
   1865 TF_CALL_float(REGISTER_GPU_KERNEL);
   1866 TF_CALL_double(REGISTER_GPU_KERNEL);
   1867 #undef REGISTER_GPU_KERNEL
   1868 
   1869 #endif  // GOOGLE_CUDA
   1870 
   1871 }  // namespace tensorflow
   1872