/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/conv_3d.h"

#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"

#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
using perftools::gputools::dnn::DimIndex;
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// TODO(mjanusz): Get rid of the macro and return shapes directly.
#define EXTRACT_AND_VERIFY_DIMENSIONS(label)                                   \
  const Tensor& out_backprop = context->input(2);                              \
  OP_REQUIRES(                                                                 \
      context, input_shape.dims() == 5,                                        \
      errors::InvalidArgument(label, ": input must be 5-dimensional"));        \
  OP_REQUIRES(                                                                 \
      context, filter_shape.dims() == 5,                                       \
      errors::InvalidArgument(label, ": filter must be 5-dimensional"));       \
  OP_REQUIRES(                                                                 \
      context, out_backprop.dims() == 5,                                       \
      errors::InvalidArgument(label, ": out_backprop must be 5-dimensional")); \
  const int64 batch = input_shape.dim_size(0);                                 \
  OP_REQUIRES(                                                                 \
      context, batch == out_backprop.dim_size(0),                              \
      errors::InvalidArgument(                                                 \
          label, ": input and out_backprop must have the same batch size"));   \
  const std::array<int64, 3> input_size = {                                    \
      {GetTensorDim(input_shape, data_format_, '0'),                           \
       GetTensorDim(input_shape, data_format_, '1'),                           \
       GetTensorDim(input_shape, data_format_, '2')}};                         \
  const int64 in_depth = GetTensorDim(input_shape, data_format_, 'C');         \
  const std::array<int64, 3> filter_size = {{filter_shape.dim_size(0),         \
                                             filter_shape.dim_size(1),         \
                                             filter_shape.dim_size(2)}};       \
  const int64 output_cols = GetTensorDim(out_backprop, data_format_, '2');     \
  const int64 output_rows = GetTensorDim(out_backprop, data_format_, '1');     \
  const int64 output_planes = GetTensorDim(out_backprop, data_format_, '0');   \
  OP_REQUIRES(context, in_depth == filter_shape.dim_size(3),                   \
              errors::InvalidArgument(                                         \
                  label, ": input and filter must have the same depth"));      \
  const int64 out_depth = filter_shape.dim_size(4);                            \
  OP_REQUIRES(                                                                 \
      context, out_depth == GetTensorDim(out_backprop, data_format_, 'C'),     \
      errors::InvalidArgument(                                                 \
          label, ": filter and out_backprop must have the same out_depth"));   \
  const std::array<int64, 3> strides = {                                       \
      {GetTensorDim(stride_, data_format_, '0'),                               \
       GetTensorDim(stride_, data_format_, '1'),                               \
       GetTensorDim(stride_, data_format_, '2')}};                             \
  std::array<int64, 3> out, padding;                                           \
  OP_REQUIRES_OK(context, Get3dOutputSize(input_size, filter_size, strides,    \
                                          padding_, &out, &padding));          \
  OP_REQUIRES(context, output_planes == out[0],                                \
              errors::InvalidArgument(                                         \
                  label,                                                       \
                  ": Number of planes of out_backprop doesn't match "          \
                  "computed: actual = ",                                       \
                  output_planes, ", computed = ", out[0]));                    \
  OP_REQUIRES(                                                                 \
      context, output_rows == out[1],                                          \
      errors::InvalidArgument(                                                 \
          label, ": Number of rows of out_backprop doesn't match computed: ",  \
          "actual = ", output_rows, ", computed = ", out[1]));                 \
  OP_REQUIRES(                                                                 \
      context, output_cols == out[2],                                          \
      errors::InvalidArgument(                                                 \
          label, ": Number of cols of out_backprop doesn't match computed: ",  \
          "actual = ", output_cols, ", computed = ", out[2]));                 \
  const auto expanded_out_planes = (output_planes - 1) * strides[0] + 1;       \
  const auto expanded_out_rows = (output_rows - 1) * strides[1] + 1;           \
  const auto expanded_out_cols = (output_cols - 1) * strides[2] + 1;           \
  const auto padded_out_planes = input_size[0] + filter_size[0] - 1;           \
  const auto padded_out_rows = input_size[1] + filter_size[1] - 1;             \
  const auto padded_out_cols = input_size[2] + filter_size[2] - 1;             \
  const auto top_pad_planes = filter_size[0] - 1 - padding[0];                 \
  const auto top_pad_rows = filter_size[1] - 1 - padding[1];                   \
  const auto left_pad_cols = filter_size[2] - 1 - padding[2];                  \
  const auto bottom_pad_planes =                                               \
      padded_out_planes - expanded_out_planes - top_pad_planes;                \
  const auto bottom_pad_rows =                                                 \
      padded_out_rows - expanded_out_rows - top_pad_rows;                      \
  const auto right_pad_cols =                                                  \
      padded_out_cols - expanded_out_cols - left_pad_cols;                     \
  VLOG(2) << "Conv3d: " << label                                               \
          << ": expanded_out_planes = " << expanded_out_planes                 \
          << ", expanded_out_rows = " << expanded_out_rows                     \
          << ", expanded_out_cols = " << expanded_out_cols                     \
          << ", padded_out_planes = " << padded_out_planes                     \
          << ", padded_out_rows = " << padded_out_rows                         \
          << ", padded_out_cols = " << padded_out_cols                         \
          << ", top_pad_planes = " << top_pad_planes                           \
          << ", top_pad_rows = " << top_pad_rows                               \
          << ", left_pad_cols = " << left_pad_cols                             \
          << ", bottom_pad_planes = " << bottom_pad_planes                     \
          << ", bottom_pad_rows = " << bottom_pad_rows                         \
          << ", right_pad_cols = " << right_pad_cols
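
// The quantities above set up the standard trick of computing the input
// gradient of a strided convolution as a forward VALID convolution:
// out_backprop is "expanded" by inserting stride - 1 zeros between its
// elements, padded so that every input element is covered, and then
// convolved with the spatially reversed filter at unit stride.
//
// A rough 1-D check of the arithmetic (illustrative only): with input size 5,
// filter size 3, stride 2, and SAME padding, the output has 3 elements and
// Get3dOutputSize reports a left padding of 1. Then
//   expanded_out = (3 - 1) * 2 + 1 = 5,
//   padded_out   = 5 + 3 - 1      = 7,
//   top_pad      = 3 - 1 - 1      = 1,  bottom_pad = 7 - 5 - 1 = 1,
// and a VALID convolution of a length-7 signal with a length-3 filter indeed
// produces the 5 input gradients.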

// Backprop for input.
template <typename Device, class T>
class Conv3DBackpropInputOp : public OpKernel {
 public:
  explicit Conv3DBackpropInputOp(OpKernelConstruction* context)
      : OpKernel(context),
        data_format_(FORMAT_NHWC),
        takes_shape_(type_string().find("V2") != std::string::npos) {
    // data_format is only available in V2.
    if (takes_shape_) {
      string data_format;
      OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
      OP_REQUIRES(
          context, data_format_ == FORMAT_NHWC,
          errors::InvalidArgument(
              "Conv3DBackpropInputOpV2 only supports NDHWC on the CPU."));
    }

    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, 'C') == 1 &&
         GetTensorDim(stride_, data_format_, 'N') == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& filter = context->input(1);
    const TensorShape& filter_shape = filter.shape();
    TensorShape input_shape;
    if (takes_shape_) {
      const Tensor& input_sizes = context->input(0);
      OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                  input_sizes.vec<int32>(), &input_shape));
    } else {
      input_shape = context->input(0).shape();
    }
    EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropInput");
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 5> pad_dims{
        {0, 0},
        {top_pad_planes, bottom_pad_planes},
        {top_pad_rows, bottom_pad_rows},
        {left_pad_cols, right_pad_cols},
        {0, 0}};
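    // The gradient is then computed in three steps, following the geometry
    // derived above:
    //   1. Inflate out_backprop by the strides and pad it
    //      (InflatePadAndShuffle).
    //   2. Reverse the filter spatially and swap its in_depth/out_depth
    //      dimensions (ShuffleAndReverse).
    //   3. Run a unit-stride VALID CuboidConvolution of (1) with (2).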
    Tensor* in_backprop;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_shape, &in_backprop));

    // Fill out a padded out_backprop.
    TensorShape padded_out_shape({batch, padded_out_planes, padded_out_rows,
                                  padded_out_cols, out_depth});
    Tensor padded_output;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DataTypeToEnum<T>::v(),
                                          padded_out_shape, &padded_output));
    Eigen::DSizes<Eigen::DenseIndex, 5> no_op_shuffle{0, 1, 2, 3, 4};
    Eigen::DSizes<Eigen::DenseIndex, 5> eigen_strides{1, strides[0], strides[1],
                                                      strides[2], 1};
    functor::InflatePadAndShuffle<Device, T, 5, Eigen::DenseIndex>()(
        context->eigen_device<Device>(), out_backprop.tensor<T, 5>(),
        eigen_strides, pad_dims, no_op_shuffle, padded_output.tensor<T, 5>());
    const Tensor& padded_output_cref = padded_output;
    // Fill a new "reversed" filter: transpose its in_depth and out_depth
    // dimensions and reverse it along the planes, rows, and cols.
    TensorShape r_filter_shape(
        {filter_size[0], filter_size[1], filter_size[2], out_depth, in_depth});
    Tensor r_filter;
    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
                                                   r_filter_shape, &r_filter));
    Eigen::DSizes<Eigen::DenseIndex, 5> filter_order{0, 1, 2, 4, 3};
    Eigen::array<bool, 5> filter_rev_dims{true, true, true, false, false};
    functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
        context->eigen_device<Device>(), filter.tensor<T, 5>(), filter_order,
        filter_rev_dims, r_filter.tensor<T, 5>());
    const Tensor& r_filter_cref = r_filter;

    // Now we can call conv_3d directly.
    functor::CuboidConvolution<Device, T>()(
        context->eigen_device<Device>(), in_backprop->tensor<T, 5>(),
        padded_output_cref.tensor<T, 5>(), r_filter_cref.tensor<T, 5>(), 1, 1,
        1, BrainPadding2EigenPadding(VALID));
  }

 private:
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
  bool takes_shape_;
};

#define REGISTER_CPU_KERNEL(T)                                                 \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("Conv3DBackpropInput").Device(DEVICE_CPU).TypeConstraint<T>("T"),   \
      Conv3DBackpropInputOp<CPUDevice, T>);                                    \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("Conv3DBackpropInputV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      Conv3DBackpropInputOp<CPUDevice, T>);
TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL

// Backprop for filter.
template <typename Device, class T>
class Conv3DBackpropFilterOp : public OpKernel {
 public:
  explicit Conv3DBackpropFilterOp(OpKernelConstruction* context)
      : OpKernel(context),
        data_format_(FORMAT_NHWC),
        takes_shape_(type_string().find("V2") != std::string::npos) {
    // data_format is only available in V2.
    if (takes_shape_) {
      string data_format;
      OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
      OP_REQUIRES(
          context, data_format_ == FORMAT_NHWC,
          errors::InvalidArgument(
              "Conv3DBackpropFilterOpV2 only supports NDHWC on the CPU."));
    }

    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, 'C') == 1 &&
         GetTensorDim(stride_, data_format_, 'N') == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();
    TensorShape filter_shape;

    if (takes_shape_) {
      const Tensor& filter_sizes = context->input(1);
      OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                  filter_sizes.vec<int32>(), &filter_shape));
    } else {
      filter_shape = context->input(1).shape();
    }

    EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropFilter");
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 5> pad_dims{
        {0, 0},
        {top_pad_planes, bottom_pad_planes},
        {top_pad_rows, bottom_pad_rows},
        {left_pad_cols, right_pad_cols},
        {0, 0}};
    Tensor* filter_backprop;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, filter_shape, &filter_backprop));

    if (input_shape.num_elements() == 0) {
      filter_backprop->template flat<T>().setZero();
      return;
    }
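
    // The filter gradient is computed as one more convolution: the inflated
    // and padded out_backprop (with out_depth moved into the batch slot) is
    // convolved with the shuffled input, and the batch dimension plays, in
    // effect, the role of the contracted channel. The shuffles below set
    // this up.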

    // For the backprop of the filter, we need to also transpose the
    // out_backprop.
    // The shape of backprop is
    //   [batch, out_z, out_y, out_x, out_depth]
    // And we need to change it to
    //   [out_depth, out_z, out_y, out_x, batch]
    Eigen::DSizes<Eigen::DenseIndex, 5> out_order{4, 1, 2, 3, 0};
    TensorShape padded_out_shape({out_depth, padded_out_planes, padded_out_rows,
                                  padded_out_cols, batch});
    Tensor padded_output;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DataTypeToEnum<T>::v(),
                                          padded_out_shape, &padded_output));
    Eigen::DSizes<Eigen::DenseIndex, 5> eigen_strides{1, strides[0], strides[1],
                                                      strides[2], 1};
    functor::InflatePadAndShuffle<Device, T, 5, Eigen::DenseIndex>()(
        context->eigen_device<Device>(), out_backprop.tensor<T, 5>(),
        eigen_strides, pad_dims, out_order, padded_output.tensor<T, 5>());
    const Tensor& padded_output_cref = padded_output;

    // For the backprop of the filter, we need to transpose the input.
    // The shape of input is
    //   [batch, in_z, in_y, in_x, in_depth]
    // And we need to change it to
    //   [in_z, in_y, in_x, batch, in_depth]
    Eigen::DSizes<Eigen::DenseIndex, 5> in_order{1, 2, 3, 0, 4};
    TensorShape in_shuffle_shape(
        {input_size[0], input_size[1], input_size[2], batch, in_depth});
    Tensor in_shuffle;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DataTypeToEnum<T>::v(),
                                          in_shuffle_shape, &in_shuffle));
    // No need for reversing this time.
    Eigen::array<bool, 5> no_reverse{false, false, false, false, false};
    functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
        context->eigen_device<Device>(), input.tensor<T, 5>(), in_order,
        no_reverse, in_shuffle.tensor<T, 5>());
    const Tensor& in_shuffle_cref = in_shuffle;

    // The output of the conv_3d would be
    //   [out_depth, filter_size[0], filter_size[1], filter_size[2], in_depth]
    // and we need to shuffle it back to
    //   [filter_size[0], filter_size[1], filter_size[2], in_depth, out_depth]
    // and reverse the filter backprops. So we need to allocate (sigh) yet
    // another piece of memory to hold the output.
    TensorShape filter_shuffle_shape(
        {out_depth, filter_size[0], filter_size[1], filter_size[2], in_depth});
    Tensor filter_shuffle;
    OP_REQUIRES_OK(
        context, context->allocate_temp(DataTypeToEnum<T>::v(),
                                        filter_shuffle_shape, &filter_shuffle));
    functor::CuboidConvolution<Device, T>()(
        context->eigen_device<Device>(), filter_shuffle.tensor<T, 5>(),
        padded_output_cref.tensor<T, 5>(), in_shuffle_cref.tensor<T, 5>(), 1, 1,
        1, BrainPadding2EigenPadding(VALID));

    // Now copy the filter_backprop back to the destination.
    Eigen::DSizes<Eigen::DenseIndex, 5> filter_order{1, 2, 3, 4, 0};
    Eigen::array<bool, 5> filter_rev_dims{true, true, true, false, false};
    const Tensor& filter_shuffle_cref = filter_shuffle;
    functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
        context->eigen_device<Device>(), filter_shuffle_cref.tensor<T, 5>(),
        filter_order, filter_rev_dims, filter_backprop->tensor<T, 5>());
  }

 private:
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
  bool takes_shape_;
};

#define REGISTER_CPU_KERNEL(T)                                                \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      Conv3DBackpropFilterOp<CPUDevice, T>);                                  \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
                              .Device(DEVICE_CPU)                             \
                              .TypeConstraint<T>("T"),                        \
                          Conv3DBackpropFilterOp<CPUDevice, T>);
TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL

// GPU definitions of both ops.
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
// This ensures that the custom implementation is used instead of the default
// Eigen one (which is used for CPU).
namespace functor {
#define DECLARE_GPU_SPEC(T)                                           \
  template <>                                                         \
  void TransformFilter<GPUDevice, T, int, 5>::operator()(             \
      const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
      typename TTypes<T, 5, int>::Tensor out);                        \
  template <>                                                         \
  void ReverseTransformFilter<GPUDevice, T, 5>::operator()(           \
      const GPUDevice& d, typename TTypes<T, 5>::ConstTensor in,      \
      typename TTypes<T, 5>::Tensor out);                             \
  template <>                                                         \
  void PadInput<GPUDevice, T, int, 5>::operator()(                    \
      const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
      const std::array<int, 3>& padding_left,                         \
      const std::array<int, 3>& padding_right,                        \
      typename TTypes<T, 5, int>::Tensor out, TensorFormat format);

DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
#undef DECLARE_GPU_SPEC
}  // namespace functor

// A dummy type to group backward data autotune results together.
struct Conv3dBackwardDataAutoTuneGroup {
  static string name() { return "Conv3dBwdData"; }
};
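// Cached autotune results are keyed on ConvParameters (shapes, strides,
// padding, dtype, device), so each distinct convolution configuration is
// profiled only once per process and the winning cuDNN algorithm is reused
// on subsequent calls.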
typedef AutoTuneSingleton<Conv3dBackwardDataAutoTuneGroup, ConvParameters,
                          perftools::gputools::dnn::AlgorithmConfig>
    AutoTuneConv3dBwdData;

template <typename T>
class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
 public:
  explicit Conv3DBackpropInputOp(OpKernelConstruction* context)
      : OpKernel(context),
        data_format_(FORMAT_NHWC),
        takes_shape_(type_string().find("V2") != std::string::npos) {
    // data_format is only available in V2.
    if (takes_shape_) {
      string data_format;
      OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    }
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, 'C') == 1 &&
         GetTensorDim(stride_, data_format_, 'N') == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    cudnn_use_autotune_ = CudnnUseAutotune();
  }
  void Compute(OpKernelContext* context) override {
    const Tensor& filter = context->input(1);
    const TensorShape& filter_shape = filter.shape();
    TensorShape input_shape;
    if (takes_shape_) {
      const Tensor& input_sizes = context->input(0);
      OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                  input_sizes.vec<int32>(), &input_shape));
    } else {
      input_shape = context->input(0).shape();
    }
    EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropInput");
    Tensor* in_backprop;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_shape, &in_backprop));

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

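    // Two special cases reduce to a single GEMM and skip cuDNN entirely.
    // For a 1x1x1 filter at unit stride, the convolution is a pointwise
    // matmul, so in row-major terms the input gradient is roughly
    //   in_backprop[b*P*R*C, in_depth] =
    //       out_backprop[b*P*R*C, out_depth] * filter[in_depth, out_depth]^T.
    // When the filter covers the whole input under VALID padding, each output
    // element is a dot product over the entire input volume, which likewise
    // transposes to one GEMM over the batch dimension.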
    if (filter_size[0] == 1 && filter_size[1] == 1 && filter_size[2] == 1 &&
        strides[0] == 1 && strides[1] == 1 && strides[2] == 1 &&
        data_format_ == FORMAT_NHWC) {
      const uint64 m = batch * input_size[0] * input_size[1] * input_size[2];
      const uint64 k = out_depth;
      const uint64 n = in_depth;

      auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                  out_backprop.template flat<T>().size());
      auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                  filter.template flat<T>().size());
      auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
                                  in_backprop->template flat<T>().size());

      auto transpose = perftools::gputools::blas::Transpose::kTranspose;
      auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;

      bool blas_launch_status =
          stream
              ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
                             a_ptr, k, 0.0f, &c_ptr, n)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                            ", n=", n, ", k=", k));
      }
      return;
    } else if (filter_size[0] == input_size[0] &&
               filter_size[1] == input_size[1] &&
               filter_size[2] == input_size[2] && padding_ == Padding::VALID &&
               data_format_ == FORMAT_NHWC) {
      const uint64 m = batch;
      const uint64 k = out_depth;
      const uint64 n = input_size[0] * input_size[1] * input_size[2] * in_depth;

      auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                  out_backprop.template flat<T>().size());
      auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                  filter.template flat<T>().size());
      auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
                                  in_backprop->template flat<T>().size());

      auto transpose = perftools::gputools::blas::Transpose::kTranspose;
      auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;

      bool blas_launch_status =
          stream
              ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
                             a_ptr, k, 0.0f, &c_ptr, n)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                            ", n=", n, ", k=", k));
      }
      return;
    }

    int padding_rows = 0, padding_cols = 0, padding_planes = 0;

    if (padding_ == Padding::SAME) {
      padding_planes = std::max<int>(
          0, (output_planes - 1) * strides[0] + filter_size[0] - input_size[0]);
      padding_cols = std::max<int>(
          0, (output_cols - 1) * strides[2] + filter_size[2] - input_size[2]);
      padding_rows = std::max<int>(
          0, (output_rows - 1) * strides[1] + filter_size[1] - input_size[1]);
    }
    const bool rows_odd = (padding_rows % 2 != 0);
    const bool cols_odd = (padding_cols % 2 != 0);
    const bool planes_odd = (padding_planes % 2 != 0);
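
    // cuDNN pads symmetrically, but TensorFlow's SAME padding can require one
    // extra element on the high side of a dimension. In that case the input
    // gradient is computed into a one-element-larger temporary, and the extra
    // plane/row/column is sliced off afterwards (see the PadInput call with
    // negative padding near the end of Compute).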

    TensorShape compatible_input_shape;
    if (rows_odd || cols_odd || planes_odd) {
      // cuDNN only supports the same amount of padding on both sides.
      compatible_input_shape = {
          batch,
          in_depth,
          input_size[0] + planes_odd,
          input_size[1] + rows_odd,
          input_size[2] + cols_odd,
      };
    } else {
      compatible_input_shape = {batch, in_depth, input_size[0], input_size[1],
                                input_size[2]};
    }

    CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
        << "Negative paddings: (" << padding_rows << ", " << padding_cols
        << ", " << padding_planes << ")";
    perftools::gputools::dnn::BatchDescriptor input_desc(3);
    input_desc.set_count(batch)
        .set_spatial_dim(DimIndex::X, compatible_input_shape.dim_size(4))
        .set_spatial_dim(DimIndex::Y, compatible_input_shape.dim_size(3))
        .set_spatial_dim(DimIndex::Z, compatible_input_shape.dim_size(2))
        .set_feature_map_count(in_depth)
        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
    perftools::gputools::dnn::BatchDescriptor output_desc(3);
    output_desc.set_count(batch)
        .set_spatial_dim(DimIndex::X, output_cols)
        .set_spatial_dim(DimIndex::Y, output_rows)
        .set_spatial_dim(DimIndex::Z, output_planes)
        .set_feature_map_count(out_depth)
        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
    perftools::gputools::dnn::FilterDescriptor filter_desc(3);
    filter_desc.set_spatial_dim(DimIndex::X, filter_size[2])
        .set_spatial_dim(DimIndex::Y, filter_size[1])
        .set_spatial_dim(DimIndex::Z, filter_size[0])
        .set_input_feature_map_count(in_depth)
        .set_output_feature_map_count(out_depth);
    perftools::gputools::dnn::ConvolutionDescriptor conv_desc(3);
    conv_desc.set_filter_stride(DimIndex::X, strides[2])
        .set_filter_stride(DimIndex::Y, strides[1])
        .set_filter_stride(DimIndex::Z, strides[0])
        .set_zero_padding(DimIndex::X, padding_cols / 2)
        .set_zero_padding(DimIndex::Y, padding_rows / 2)
        .set_zero_padding(DimIndex::Z, padding_planes / 2);

    // Shape: out, in, z, y, x.
    Tensor transformed_filter;
    OP_REQUIRES_OK(
        context,
        context->allocate_temp(DataTypeToEnum<T>::value,
                               TensorShape({out_depth, in_depth, filter_size[0],
                                            filter_size[1], filter_size[2]}),
                               &transformed_filter));
    functor::TransformFilter<GPUDevice, T, int, 5>()(
        context->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 5>()),
        To32Bit(transformed_filter.tensor<T, 5>()));

    // Shape: batch, filters, z, y, x.
    Tensor transformed_out_backprop;
    if (data_format_ == FORMAT_NHWC) {
      TensorShape nchw_shape = {batch, out_depth, output_planes, output_rows,
                                output_cols};
      if (out_depth > 1) {
        OP_REQUIRES_OK(context, context->allocate_temp(
                                    DataTypeToEnum<T>::value, nchw_shape,
                                    &transformed_out_backprop));
        functor::NHWCToNCHW<GPUDevice, T, 5>()(
            context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
            transformed_out_backprop.tensor<T, 5>());
      } else {
        CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
      }
    } else {
      transformed_out_backprop = out_backprop;
    }
    // Shape: batch, filters, z, y, x.
    Tensor pre_transformed_in_backprop;
    OP_REQUIRES_OK(
        context,
        context->allocate_temp(DataTypeToEnum<T>::value, compatible_input_shape,
                               &pre_transformed_in_backprop));

    auto out_backprop_ptr =
        AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
                       transformed_out_backprop.template flat<T>().size());
    auto filter_ptr =
        AsDeviceMemory(transformed_filter.template flat<T>().data(),
                       transformed_filter.template flat<T>().size());
    auto in_backprop_ptr =
        AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(),
                       pre_transformed_in_backprop.template flat<T>().size());

    static int64 ConvolveBackwardDataScratchSize = GetCudnnWorkspaceLimit(
        "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32);  // 4GB by default

    const int device_id = stream->parent()->device_ordinal();
    DataType dtype = context->input(0).dtype();
    const ConvParameters conv_parameters = {
        batch,
        in_depth,
        {{input_size[0], input_size[1], input_size[2]}},
        out_depth,
        {{filter_size[0], filter_size[1], filter_size[2]}},
        // TODO(yangzihao): Send in arbitrary dilation rates after the dilated
        // conv is supported.
        /*dilation=*/{{1, 1, 1}},
        {{strides[0], strides[1], strides[2]}},
        {{padding_planes, padding_rows, padding_cols}},
        dtype,
        device_id,
    };

    using perftools::gputools::dnn::AlgorithmConfig;
    using perftools::gputools::dnn::AlgorithmDesc;
    using perftools::gputools::dnn::ProfileResult;
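    // Autotuning, sketched: on the first sighting of this configuration,
    // every available backward-data algorithm is launched once and timed; the
    // fastest overall and the fastest that needs no scratch space are recorded
    // in the AutoTuneConv3dBwdData cache, and later calls reuse the cached
    // choice.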
    AlgorithmConfig algorithm_config;
    if (cudnn_use_autotune_ && !AutoTuneConv3dBwdData::GetInstance()->Find(
                                   conv_parameters, &algorithm_config)) {
      std::vector<AlgorithmDesc> algorithms;
      CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
          conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
      ProfileResult best_result;
      ProfileResult best_result_no_scratch;
      for (auto profile_algorithm : algorithms) {
        // TODO(zhengxq): profile each algorithm multiple times for better
        // accuracy.
        CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
                                                context);
        ProfileResult profile_result;
        bool cudnn_launch_status =
            stream
                ->ThenConvolveBackwardDataWithAlgorithm(
                    filter_desc, filter_ptr, output_desc, out_backprop_ptr,
                    conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
                    AlgorithmConfig(profile_algorithm), &profile_result)
                .ok();
        if (cudnn_launch_status) {
          if (profile_result.is_valid()) {
            if (profile_result.elapsed_time_in_ms() <
                best_result.elapsed_time_in_ms()) {
              best_result = profile_result;
            }
            if (scratch_allocator.TotalByteSize() == 0 &&
                profile_result.elapsed_time_in_ms() <
                    best_result_no_scratch.elapsed_time_in_ms()) {
              best_result_no_scratch = profile_result;
            }
          }
        }
      }
      OP_REQUIRES(context,
                  best_result.is_valid() || best_result_no_scratch.is_valid(),
                  errors::NotFound("No algorithm worked!"));
      if (best_result.is_valid()) {
        algorithm_config.set_algorithm(best_result.algorithm());
      }
      if (best_result_no_scratch.is_valid()) {
        algorithm_config.set_algorithm_no_scratch(
            best_result_no_scratch.algorithm());
      }
      AutoTuneConv3dBwdData::GetInstance()->Insert(conv_parameters,
                                                   algorithm_config);
    }
    CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
                                            context);
    bool cudnn_launch_status =
        stream
            ->ThenConvolveBackwardDataWithAlgorithm(
                filter_desc, filter_ptr, output_desc, out_backprop_ptr,
                conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
                algorithm_config, nullptr)
            .ok();

    if (!cudnn_launch_status) {
      context->SetStatus(errors::Internal(
          "cuDNN Backward Data function launch failure : input shape(",
          input_shape.DebugString(), ") filter shape(",
          filter_shape.DebugString(), ")"));
    }

    if (rows_odd || cols_odd || planes_odd) {
      Tensor in_backprop_remove_padding;
      OP_REQUIRES_OK(context,
                     context->allocate_temp(DataTypeToEnum<T>::value,
                                            {batch, in_depth, input_size[0],
                                             input_size[1], input_size[2]},
                                            &in_backprop_remove_padding));

      // Remove the padding for odd spatial dimensions.
      functor::PadInput<GPUDevice, T, int, 5>()(
          context->eigen_device<GPUDevice>(),
          To32Bit(const_cast<const Tensor&>(pre_transformed_in_backprop)
                      .tensor<T, 5>()),
          {{0, 0, 0}}, {{-planes_odd, -rows_odd, -cols_odd}},
          To32Bit(in_backprop_remove_padding.tensor<T, 5>()), FORMAT_NCHW);

      pre_transformed_in_backprop = in_backprop_remove_padding;
    }

    if (data_format_ == FORMAT_NHWC) {
      auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
      functor::NCHWToNHWC<GPUDevice, T, 5>()(
          context->eigen_device<GPUDevice>(),
          toConstTensor(pre_transformed_in_backprop).template tensor<T, 5>(),
          in_backprop->tensor<T, 5>());
    } else {
      *in_backprop = pre_transformed_in_backprop;
    }
  }

 private:
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
  bool takes_shape_;
  bool cudnn_use_autotune_;
};

// A dummy type to group backward filter autotune results together.
struct Conv3dBackwardFilterAutoTuneGroup {
  static string name() { return "Conv3dBwdFilter"; }
};
typedef AutoTuneSingleton<Conv3dBackwardFilterAutoTuneGroup, ConvParameters,
                          perftools::gputools::dnn::AlgorithmConfig>
    AutoTuneConv3dBwdFilter;

template <typename T>
class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
 public:
  explicit Conv3DBackpropFilterOp(OpKernelConstruction* context)
      : OpKernel(context),
        data_format_(FORMAT_NHWC),
        takes_shape_(type_string().find("V2") != std::string::npos) {
    // data_format is only available in V2.
    if (takes_shape_) {
      string data_format;
      OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    }
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, 'C') == 1 &&
         GetTensorDim(stride_, data_format_, 'N') == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    cudnn_use_autotune_ = CudnnUseAutotune();
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();
    TensorShape filter_shape;
    if (takes_shape_) {
      const Tensor& filter_sizes = context->input(1);
      OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                  filter_sizes.vec<int32>(), &filter_shape));
    } else {
      filter_shape = context->input(1).shape();
    }

    EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropFilter");

    Tensor* filter_backprop;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, filter_shape, &filter_backprop));

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

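    // As with the input gradient, two shapes degenerate to a single GEMM.
    // For a 1x1x1 filter at unit stride the filter gradient is, in row-major
    // terms, roughly
    //   filter_backprop[in_depth, out_depth] =
    //       input[b*P*R*C, in_depth]^T * out_backprop[b*P*R*C, out_depth],
    // and a filter spanning the whole input under VALID padding contracts
    // over the batch dimension instead.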
    if (filter_size[1] == 1 && filter_size[2] == 1 && filter_size[0] == 1 &&
        strides[2] == 1 && strides[1] == 1 && strides[0] == 1 &&
        data_format_ == FORMAT_NHWC) {
      const uint64 m = in_depth;
      const uint64 k = batch * input_size[1] * input_size[2] * input_size[0];
      const uint64 n = out_depth;

      // The shape of output backprop is
      //   [batch, out_z, out_y, out_x, out_depth]
      // From cublas's perspective, it is: n x k
      auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                  out_backprop.template flat<T>().size());

      // The shape of input is:
      //   [batch, in_z, in_y, in_x, in_depth],
      // From cublas's perspective, it is: m x k
      auto b_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                  input.template flat<T>().size());

      // The shape of the filter backprop is:
      //   [1, 1, 1, in_depth, out_depth]
      // From cublas's perspective, it is: n x m
      auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
                                  filter_backprop->template flat<T>().size());

      bool blas_launch_status =
          stream
              ->ThenBlasGemm(perftools::gputools::blas::Transpose::kNoTranspose,
                             perftools::gputools::blas::Transpose::kTranspose,
                             n, m, k, 1.0f, a_ptr, n, b_ptr, m, 0.0f, &c_ptr, n)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                            ", n=", n, ", k=", k));
      }
      return;
    } else if (filter_size[0] == input_size[0] &&
               filter_size[1] == input_size[1] &&
               filter_size[2] == input_size[2] && padding_ == Padding::VALID &&
               data_format_ == FORMAT_NHWC) {
      const uint64 m = input_size[0] * input_size[1] * input_size[2] * in_depth;
      const uint64 k = batch;
      const uint64 n = out_depth;

      auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                  input.template flat<T>().size());
      auto b_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                  out_backprop.template flat<T>().size());
      auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
                                  filter_backprop->template flat<T>().size());

      bool blas_launch_status =
          stream
              ->ThenBlasGemm(perftools::gputools::blas::Transpose::kNoTranspose,
                             perftools::gputools::blas::Transpose::kTranspose,
                             n, m, k, 1.0f, b_ptr, n, a_ptr, m, 0.0f, &c_ptr, n)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                            ", n=", n, ", k=", k));
      }
      return;
    }

    int padding_rows = 0, padding_cols = 0, padding_planes = 0;

    if (padding_ == Padding::SAME) {
      padding_planes = std::max<int>(
          0, (output_planes - 1) * strides[0] + filter_size[0] - input_size[0]);
      padding_cols = std::max<int>(
          0, (output_cols - 1) * strides[2] + filter_size[2] - input_size[2]);
      padding_rows = std::max<int>(
          0, (output_rows - 1) * strides[1] + filter_size[1] - input_size[1]);
    }
    bool rows_odd = (padding_rows % 2 != 0);
    bool cols_odd = (padding_cols % 2 != 0);
    bool planes_odd = (padding_planes % 2 != 0);
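
    // Unlike the input-gradient path, odd SAME padding is handled here by
    // padding the input itself with one extra element on the high side before
    // calling cuDNN; the result is a filter gradient, so no cropping is
    // needed afterwards.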

    Tensor compatible_input;
    if (rows_odd || cols_odd || planes_odd) {
      OP_REQUIRES_OK(context, context->allocate_temp(
                                  DataTypeToEnum<T>::value,
                                  ShapeFromFormat(data_format_, batch,
                                                  {{input_size[0] + planes_odd,
                                                    input_size[1] + rows_odd,
                                                    input_size[2] + cols_odd}},
                                                  in_depth),
                                  &compatible_input));
      functor::PadInput<GPUDevice, T, int, 5>()(
          context->template eigen_device<GPUDevice>(),
          To32Bit(input.tensor<T, 5>()), {{0, 0, 0}},
          {{planes_odd, rows_odd, cols_odd}},
          To32Bit(compatible_input.tensor<T, 5>()), data_format_);
    } else {
      compatible_input = input;
    }

    CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
        << "Negative paddings: (" << padding_rows << ", " << padding_cols
        << ", " << padding_planes << ")";
    perftools::gputools::dnn::BatchDescriptor input_desc(3);
    input_desc.set_count(batch)
        .set_spatial_dim(DimIndex::X,
                         GetTensorDim(compatible_input, data_format_, '2'))
        .set_spatial_dim(DimIndex::Y,
                         GetTensorDim(compatible_input, data_format_, '1'))
        .set_spatial_dim(DimIndex::Z,
                         GetTensorDim(compatible_input, data_format_, '0'))
        .set_feature_map_count(in_depth)
        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
    perftools::gputools::dnn::BatchDescriptor output_desc(3);
    output_desc.set_count(batch)
        .set_spatial_dim(DimIndex::X, output_cols)
        .set_spatial_dim(DimIndex::Y, output_rows)
        .set_spatial_dim(DimIndex::Z, output_planes)
        .set_feature_map_count(out_depth)
        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
    perftools::gputools::dnn::FilterDescriptor filter_desc(3);
    filter_desc.set_spatial_dim(DimIndex::X, filter_size[2])
        .set_spatial_dim(DimIndex::Y, filter_size[1])
        .set_spatial_dim(DimIndex::Z, filter_size[0])
        .set_input_feature_map_count(in_depth)
        .set_output_feature_map_count(out_depth);
    perftools::gputools::dnn::ConvolutionDescriptor conv_desc(3);
    conv_desc.set_filter_stride(DimIndex::X, strides[2])
        .set_filter_stride(DimIndex::Y, strides[1])
        .set_filter_stride(DimIndex::Z, strides[0])
        .set_zero_padding(DimIndex::X, padding_cols / 2)
        .set_zero_padding(DimIndex::Y, padding_rows / 2)
        .set_zero_padding(DimIndex::Z, padding_planes / 2);

    Tensor pre_transformed_filter_backprop;
    OP_REQUIRES_OK(
        context,
        context->allocate_temp(DataTypeToEnum<T>::value,
                               TensorShape({out_depth, in_depth, filter_size[0],
                                            filter_size[1], filter_size[2]}),
                               &pre_transformed_filter_backprop));

    Tensor transformed_out_backprop;
    if (data_format_ == FORMAT_NHWC) {
      TensorShape nchw_shape = {batch, out_depth, output_planes, output_rows,
                                output_cols};
      OP_REQUIRES_OK(
          context, context->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
                                          &transformed_out_backprop));
      if (out_depth > 1) {
        functor::NHWCToNCHW<GPUDevice, T, 5>()(
            context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
            transformed_out_backprop.tensor<T, 5>());
      } else {
        CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
      }
    } else {
      transformed_out_backprop = out_backprop;
    }
    Tensor transformed_input;
    if (data_format_ == FORMAT_NHWC) {
      TensorShape nchw_shape = {batch, in_depth, compatible_input.dim_size(1),
                                compatible_input.dim_size(2),
                                compatible_input.dim_size(3)};
      if (in_depth > 1) {
        OP_REQUIRES_OK(context,
                       context->allocate_temp(DataTypeToEnum<T>::value,
                                              nchw_shape, &transformed_input));
        functor::NHWCToNCHW<GPUDevice, T, 5>()(
            context->eigen_device<GPUDevice>(),
            const_cast<const Tensor&>(compatible_input).tensor<T, 5>(),
            transformed_input.tensor<T, 5>());
      } else {
        CHECK(transformed_input.CopyFrom(compatible_input, nchw_shape));
      }
    } else {
      transformed_input = compatible_input;
    }

    auto out_backprop_ptr =
        AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
                       transformed_out_backprop.template flat<T>().size());
    auto filter_backprop_ptr = AsDeviceMemory(
        pre_transformed_filter_backprop.template flat<T>().data(),
        pre_transformed_filter_backprop.template flat<T>().size());
    auto input_ptr =
        AsDeviceMemory(transformed_input.template flat<T>().data(),
                       transformed_input.template flat<T>().size());

    static int64 ConvolveBackwardFilterScratchSize = GetCudnnWorkspaceLimit(
        "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32);  // 4GB by default

    const int device_id = stream->parent()->device_ordinal();
    DataType dtype = input.dtype();
    const ConvParameters conv_parameters = {
        batch,
        in_depth,
        {{input_size[0], input_size[1], input_size[2]}},
        out_depth,
        {{filter_size[0], filter_size[1], filter_size[2]}},
        {{1, 1, 1}},
        {{strides[0], strides[1], strides[2]}},
        {{padding_planes, padding_rows, padding_cols}},
        dtype,
        device_id,
    };

    using perftools::gputools::dnn::AlgorithmConfig;
    using perftools::gputools::dnn::AlgorithmDesc;
    using perftools::gputools::dnn::ProfileResult;
    AlgorithmConfig algorithm_config;
    if (cudnn_use_autotune_ && !AutoTuneConv3dBwdFilter::GetInstance()->Find(
                                   conv_parameters, &algorithm_config)) {
      std::vector<AlgorithmDesc> algorithms;
      CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(
          conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
      ProfileResult best_result;
      ProfileResult best_result_no_scratch;
      for (auto profile_algorithm : algorithms) {
        // TODO(zhengxq): profile each algorithm multiple times for better
        // accuracy.
        CudnnScratchAllocator scratch_allocator(
            ConvolveBackwardFilterScratchSize, context);
        ProfileResult profile_result;
        bool cudnn_launch_status =
            stream
                ->ThenConvolveBackwardFilterWithAlgorithm(
                    input_desc, input_ptr, output_desc, out_backprop_ptr,
                    conv_desc, filter_desc, &filter_backprop_ptr,
                    &scratch_allocator, AlgorithmConfig(profile_algorithm),
                    &profile_result)
                .ok();
        if (cudnn_launch_status) {
          if (profile_result.is_valid()) {
            if (profile_result.elapsed_time_in_ms() <
                best_result.elapsed_time_in_ms()) {
              best_result = profile_result;
            }
            if (scratch_allocator.TotalByteSize() == 0 &&
                profile_result.elapsed_time_in_ms() <
                    best_result_no_scratch.elapsed_time_in_ms()) {
              best_result_no_scratch = profile_result;
            }
          }
        }
      }
      OP_REQUIRES(context,
                  best_result.is_valid() || best_result_no_scratch.is_valid(),
                  errors::NotFound("No algorithm worked!"));
      if (best_result.is_valid()) {
        algorithm_config.set_algorithm(best_result.algorithm());
      }
      if (best_result_no_scratch.is_valid()) {
        algorithm_config.set_algorithm_no_scratch(
            best_result_no_scratch.algorithm());
      }
      AutoTuneConv3dBwdFilter::GetInstance()->Insert(conv_parameters,
                                                     algorithm_config);
    }
    CudnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
                                            context);
    bool cudnn_launch_status =
        stream
            ->ThenConvolveBackwardFilterWithAlgorithm(
                input_desc, input_ptr, output_desc, out_backprop_ptr, conv_desc,
                filter_desc, &filter_backprop_ptr, &scratch_allocator,
                algorithm_config, nullptr)
            .ok();

    if (!cudnn_launch_status) {
      context->SetStatus(errors::Internal(
          "cuDNN Backward Filter function launch failure : input shape(",
          input_shape.DebugString(), ") filter shape(",
          filter_shape.DebugString(), ")"));
    }

    auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
    functor::ReverseTransformFilter<GPUDevice, T, 5>()(
        context->eigen_device<GPUDevice>(),
        toConstTensor(pre_transformed_filter_backprop).template tensor<T, 5>(),
        filter_backprop->tensor<T, 5>());
  }

 private:
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
  bool takes_shape_;
  bool cudnn_use_autotune_;
};

#define REGISTER_GPU_KERNEL(T)                                                \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("Conv3DBackpropInput").Device(DEVICE_GPU).TypeConstraint<T>("T"),  \
      Conv3DBackpropInputOp<GPUDevice, T>);                                   \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2")                       \
                              .Device(DEVICE_GPU)                             \
                              .TypeConstraint<T>("T")                         \
                              .HostMemory("input_sizes"),                     \
                          Conv3DBackpropInputOp<GPUDevice, T>);               \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("Conv3DBackpropFilter").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      Conv3DBackpropFilterOp<GPUDevice, T>);                                  \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
                              .Device(DEVICE_GPU)                             \
                              .TypeConstraint<T>("T")                         \
                              .HostMemory("filter_sizes"),                    \
                          Conv3DBackpropFilterOp<GPUDevice, T>);
TF_CALL_half(REGISTER_GPU_KERNEL);
TF_CALL_float(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

#endif  // GOOGLE_CUDA

}  // namespace tensorflow