/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/conv_ops.h"
#include <string.h>
#include <map>
#include <vector>
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/deep_conv2d.h"
#include "tensorflow/core/kernels/ops_util.h"
#ifdef TENSORFLOW_USE_LIBXSMM
#include "tensorflow/core/kernels/xsmm_conv2d.h"
#endif
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"

#if GOOGLE_CUDA
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace {
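// Launches the generic (Eigen-based) convolution.  Two special cases are
// lowered to a single matrix multiplication: 1x1 filters with unit strides,
// and filters that span the entire input with VALID padding and no dilation.
// Everything else goes through Eigen's SpatialConvolution.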
template <typename Device, typename T>
struct LaunchGeneric {
  void operator()(OpKernelContext* ctx, const Tensor& input,
                  const Tensor& filter, int row_stride, int col_stride,
                  int row_dilation, int col_dilation, const Padding& padding,
                  Tensor* output, TensorFormat data_format) {
    CHECK(data_format == FORMAT_NHWC) << "Generic conv implementation only "
                                         "supports NHWC tensor format for now.";
    if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_stride == 1 &&
        col_stride == 1) {
      // For 1x1 kernel, the 2D convolution is reduced to matrix
      // multiplication.
      //
      // TODO(vrv): We should be able to call SpatialConvolution
      // and it will produce the same result, but doing so
      // led to NaNs during training.  Using matmul instead for now.
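      //
      // Concretely, an NHWC input of shape [N, H, W, C_in] convolved with a
      // 1x1 filter of shape [1, 1, C_in, C_out] is just an
      // (N*H*W, C_in) x (C_in, C_out) matrix product; conv_width below is
      // N*H*W, the product of the first three output dimensions.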
      int conv_width = 1;  // Width for the convolution step.
      for (int i = 0; i < 3; ++i) {
        conv_width *= output->dim_size(i);
      }

      Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
      dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
      functor::MatMulConvFunctor<Device, T>()(
          ctx->eigen_device<Device>(),
          output->shaped<T, 2>({conv_width, filter.dim_size(3)}),
          input.shaped<T, 2>({conv_width, filter.dim_size(2)}),
          filter.shaped<T, 2>({filter.dim_size(2), filter.dim_size(3)}),
          dim_pair);
    } else if (filter.dim_size(0) == input.dim_size(1) &&
               filter.dim_size(1) == input.dim_size(2) && row_dilation == 1 &&
               col_dilation == 1 && padding == VALID) {
      // If the input data and filter have the same height/width,
      // the 2D convolution is reduced to matrix multiplication.
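      //
      // In that case every batch element yields exactly one output position,
      // so the output is an [N, C_out] matrix and the reduction dimension k
      // below runs over all filter_rows * filter_cols * C_in values.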
      const int k =  // Length of reduction dimension.
          filter.dim_size(0) * filter.dim_size(1) * filter.dim_size(2);

      Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
      dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
      functor::MatMulConvFunctor<Device, T>()(
          ctx->eigen_device<Device>(),
          output->shaped<T, 2>({input.dim_size(0), filter.dim_size(3)}),
          input.shaped<T, 2>({input.dim_size(0), k}),
          filter.shaped<T, 2>({k, filter.dim_size(3)}), dim_pair);
    } else {
      functor::SpatialConvolution<Device, T>()(
          ctx->eigen_device<Device>(), output->tensor<T, 4>(),
          input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride, col_stride,
          row_dilation, col_dilation, BrainPadding2EigenPadding(padding));
    }
  }
};
}  // namespace

template <typename T>
struct LaunchConv2DOp<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
                  const Tensor& input, const Tensor& filter, int row_dilation,
                  int col_dilation, int row_stride, int col_stride,
                  const Padding& padding, Tensor* output,
                  TensorFormat data_format) {
    if (data_format != FORMAT_NHWC) {
      ctx->SetStatus(
          errors::Unimplemented("Generic conv implementation only supports "
                                "NHWC tensor format for now."));
      return;
    }
    LaunchGeneric<CPUDevice, T>()(ctx, input, filter, row_stride, col_stride,
                                  row_dilation, col_dilation, padding, output,
                                  data_format);
  }
};

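// Primary template: no specialized "deep conv" fast path (see deep_conv2d.h)
// exists for this device/type pair, so Run() always reports failure and the
// caller falls back to the generic or cuDNN launcher.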
template <typename Device, typename T>
class LaunchDeepConvOp {
 public:
  static bool Run(OpKernelContext* ctx, const Tensor& input,
                  const Tensor& filter, int batch, int input_rows,
                  int input_cols, int in_depth, int filter_rows,
                  int filter_cols, int pad_rows, int pad_cols, int out_rows,
                  int /*out_cols*/, int /*out_depth*/, int /*dilation_rows*/,
                  int /*dilation_cols*/, int /*stride_rows*/,
                  int /*stride_cols*/, Tensor* /*output*/,
                  TensorFormat /*data_format*/) {
    return false;
  }
};

// Conditionally launches DeepConv operation based on convolution parameters.
template <>
class LaunchDeepConvOp<CPUDevice, float> {
 public:
  static bool Run(OpKernelContext* ctx, const Tensor& input,
                  const Tensor& filter, int batch, int input_rows,
                  int input_cols, int in_depth, int filter_rows,
                  int filter_cols, int pad_rows, int pad_cols, int out_rows,
                  int out_cols, int out_depth, int dilation_rows,
                  int dilation_cols, int stride_rows, int stride_cols,
                  Tensor* output, TensorFormat data_format) {
    if (data_format != FORMAT_NHWC || dilation_rows != 1 ||
        dilation_cols != 1 ||
        !CanUseDeepConv2D(stride_rows, stride_cols, filter_rows, filter_cols,
                          in_depth, out_depth, out_rows, out_cols)) {
      return false;
    }

    Conv2DArgs args;
    args.batch = batch;
    args.in_rows = input_rows;
    args.in_cols = input_cols;
    args.in_depth = in_depth;
    args.filter_rows = filter_rows;
    args.filter_cols = filter_cols;
    args.pad_rows = pad_rows;
    args.pad_cols = pad_cols;
    args.out_rows = out_rows;
    args.out_cols = out_cols;
    args.out_depth = out_depth;

    auto input_ptr = input.template flat<float>().data();
    auto filter_ptr = filter.template flat<float>().data();
    auto output_ptr = output->template flat<float>().data();

    functor::DeepConv2D<CPUDevice, float>()(ctx, args, input_ptr, filter_ptr,
                                            output_ptr);
    return true;
  }
};

#ifdef TENSORFLOW_USE_LIBXSMM
template <typename Device, typename T>
class LaunchXsmmConvOp {
 public:
  static bool Run(OpKernelContext* ctx, const Tensor& input,
                  const Tensor& filter, int batch, int input_rows,
                  int input_cols, int in_depth, int filter_rows,
                  int filter_cols, int pad_rows, int pad_cols, int out_rows,
                  int out_cols, int out_depth, int dilation_rows,
                  int dilation_cols, int stride_rows, int stride_cols,
                  Tensor* output, TensorFormat data_format) {
    return false;
  }
};

template <>
class LaunchXsmmConvOp<CPUDevice, float> {
 public:
  static bool Run(OpKernelContext* ctx, const Tensor& input,
                  const Tensor& filter, int batch, int input_rows,
                  int input_cols, int in_depth, int filter_rows,
                  int filter_cols, int pad_rows, int pad_cols, int out_rows,
                  int out_cols, int out_depth, int dilation_rows,
                  int dilation_cols, int stride_rows, int stride_cols,
                  Tensor* output, TensorFormat data_format) {
    auto num_threads =
        ctx->device()->tensorflow_cpu_worker_threads()->num_threads;
    // See libxsmm_dnn.h for this struct definition.
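    // In libxsmm's naming: N = batch, C/K = input/output channels, H/W =
    // input spatial dims, R/S = filter spatial dims, u/v = strides.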
    libxsmm_dnn_conv_desc desc;
    desc.N = batch;
    desc.C = in_depth;
    desc.H = input_rows;
    desc.W = input_cols;
    desc.K = out_depth;
    desc.R = filter_rows;
    desc.S = filter_cols;
    desc.u = stride_rows;
    desc.v = stride_cols;
    desc.pad_h = pad_rows;
    desc.pad_w = pad_cols;
    desc.pad_h_in = 0;
    desc.pad_w_in = 0;
    desc.pad_h_out = 0;
    desc.pad_w_out = 0;
    desc.threads = num_threads;
    desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
    desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
    desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;
    desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
    desc.options = LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE_OVERWRITE;
    desc.datatype = LIBXSMM_DNN_DATATYPE_F32;

    if (dilation_rows != 1 || dilation_cols != 1 ||
        !CanUseXsmmConv2D(desc, data_format)) {
      return false;
    }

    auto input_ptr = input.template flat<float>().data();
    auto filter_ptr = filter.template flat<float>().data();
    auto output_ptr = output->template flat<float>().data();

    bool success = functor::XsmmFwdConv2D<CPUDevice, float>()(
        ctx, desc, input_ptr, filter_ptr, output_ptr);
    return success;
  }
};
#endif

template <typename Device, typename T>
class Conv2DOp : public BinaryOp<T> {
 public:
  explicit Conv2DOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
    use_cudnn_ &= CanUseCudnn();
    cudnn_use_autotune_ = CudnnUseAutotune();
    OP_REQUIRES(context, dilations_.size() == 4,
                errors::InvalidArgument("Sliding window dilations field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
    const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
    const int64 stride_h = GetTensorDim(strides_, data_format_, 'H');
    const int64 stride_w = GetTensorDim(strides_, data_format_, 'W');
    OP_REQUIRES(
        context, stride_n == 1 && stride_c == 1,
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES(context, stride_h > 0 && stride_w > 0,
                errors::InvalidArgument(
                    "Row and column strides should be larger than 0."));

    const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N');
    const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C');
    const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H');
    const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W');
    OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1,
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilations in the batch and depth dimensions."));
    OP_REQUIRES(
        context, dilation_h > 0 && dilation_w > 0,
        errors::InvalidArgument("Dilated rates should be larger than 0."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    // Input tensor is of the following dimensions:
    // [ batch, in_rows, in_cols, in_depth ]

    const Tensor& input = context->input(0);

    // Input filter is of the following dimensions:
    // [ filter_rows, filter_cols, in_depth, out_depth]
    const Tensor& filter = context->input(1);

    // For 2D convolution, there should be 4 dimensions.
    OP_REQUIRES(context, input.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional: ",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, filter.dims() == 4,
                errors::InvalidArgument("filter must be 4-dimensional: ",
                                        filter.shape().DebugString()));

    for (int i = 0; i < 3; i++) {
      OP_REQUIRES(
          context,
          FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
          errors::InvalidArgument("filter too large"));
    }

    // The last dimension for input is in_depth. It must be the same as the
    // filter's in_depth.
    const int64 in_depth = GetTensorDim(input, data_format_, 'C');
    OP_REQUIRES(context, in_depth == filter.dim_size(2),
                errors::InvalidArgument(
                    "input and filter must have the same depth: ", in_depth,
                    " vs ", filter.dim_size(2)));

    // The last dimension for filter is out_depth.
    const int out_depth = static_cast<int>(filter.dim_size(3));

    // The second dimension for input is rows/height.
    // The first dimension for filter is rows/height.
    const int64 input_rows_raw = GetTensorDim(input, data_format_, 'H');
    OP_REQUIRES(
        context,
        FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("Input rows too large"));
    const int input_rows = static_cast<int>(input_rows_raw);
    const int filter_rows = static_cast<int>(filter.dim_size(0));

    // The third dimension for input is columns/width.
    // The second dimension for filter is columns/width.
    const int64 input_cols_raw = GetTensorDim(input, data_format_, 'W');
    OP_REQUIRES(
        context,
        FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("Input cols too large"));
    const int input_cols = static_cast<int>(input_cols_raw);
    const int filter_cols = static_cast<int>(filter.dim_size(1));

    // The first dimension for input is batch.
    const int64 batch_raw = GetTensorDim(input, data_format_, 'N');
    OP_REQUIRES(context,
                FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
                errors::InvalidArgument("batch is too large"));
    const int batch = static_cast<int>(batch_raw);

    // For now we take the stride and dilation from the second and third
    // dimensions only (we do not support striding or dilation on the batch or
    // depth dimension).
    const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
    const int stride_cols = GetTensorDim(strides_, data_format_, 'W');

    const int dilation_rows = GetTensorDim(dilations_, data_format_, 'H');
    const int dilation_cols = GetTensorDim(dilations_, data_format_, 'W');

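    // GetWindowedOutputSizeV2 applies the usual TensorFlow rules with an
    // effective filter extent of (filter - 1) * dilation + 1:
    //   VALID: out = ceil((in - (filter - 1) * dilation) / stride), no padding
    //   SAME:  out = ceil(in / stride), padding split across both sides.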
    int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
    OP_REQUIRES_OK(context, GetWindowedOutputSizeV2(
                                input_rows, filter_rows, dilation_rows,
                                stride_rows, padding_, &out_rows, &pad_rows));
    OP_REQUIRES_OK(context, GetWindowedOutputSizeV2(
                                input_cols, filter_cols, dilation_cols,
                                stride_cols, padding_, &out_cols, &pad_cols));
    TensorShape out_shape =
        ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth);

    // Output tensor is of the following dimensions:
    // [ in_batch, out_rows, out_cols, out_depth ]
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));

    VLOG(2) << "Conv2D: in_depth = " << in_depth
            << ", input_cols = " << input_cols
            << ", filter_cols = " << filter_cols
            << ", input_rows = " << input_rows
            << ", filter_rows = " << filter_rows
            << ", stride_rows = " << stride_rows
            << ", stride_cols = " << stride_cols
            << ", dilation_rows = " << dilation_rows
            << ", dilation_cols = " << dilation_cols
            << ", out_depth = " << out_depth;

    // If there is nothing to compute, return.
    if (out_shape.num_elements() == 0) {
      return;
    }

#ifdef TENSORFLOW_USE_LIBXSMM
    if (LaunchXsmmConvOp<Device, T>::Run(
            context, input, filter, batch, input_rows, input_cols, in_depth,
            filter_rows, filter_cols, pad_rows, pad_cols, out_rows, out_cols,
            out_depth, dilation_rows, dilation_cols, stride_rows, stride_cols,
            output, data_format_)) {
      return;
    }
#endif

    if (LaunchDeepConvOp<Device, T>::Run(
            context, input, filter, batch, input_rows, input_cols, in_depth,
            filter_rows, filter_cols, pad_rows, pad_cols, out_rows, out_cols,
            out_depth, dilation_rows, dilation_cols, stride_rows, stride_cols,
            output, data_format_)) {
      return;
    }

    launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter,
              dilation_rows, dilation_cols, stride_rows, stride_cols, padding_,
              output, data_format_);
  }

 private:
  std::vector<int32> dilations_;
  std::vector<int32> strides_;
  bool use_cudnn_;
  Padding padding_;
  TensorFormat data_format_;
  LaunchConv2DOp<Device, T> launcher_;
  bool cudnn_use_autotune_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv2DOp);
};

#define REGISTER_CPU(T)                                         \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("Conv2D").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      Conv2DOp<CPUDevice, T>);

// If we're using the alternative GEMM-based implementation of Conv2D for the
// CPU implementation, don't register this EigenTensor-based version.
#if !defined(USE_GEMM_FOR_CONV)
TF_CALL_half(REGISTER_CPU);
TF_CALL_float(REGISTER_CPU);
#endif  // USE_GEMM_FOR_CONV

// To be used inside depthwise_conv_op.cc.
template class LaunchConv2DOp<CPUDevice, float>;

#if GOOGLE_CUDA
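// Returns the cuDNN scratch-space limit in bytes: the value of the
// environment variable named by `envvar_in_mb` (interpreted as megabytes)
// when it is set to a valid number, otherwise `default_value_in_bytes`.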
int64 GetCudnnWorkspaceLimit(const string& envvar_in_mb,
                             int64 default_value_in_bytes) {
  const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
  if (workspace_limit_in_mb_str != nullptr &&
      strcmp(workspace_limit_in_mb_str, "") != 0) {
    int64 scratch_limit_in_mb = -1;
    if (strings::safe_strto64(workspace_limit_in_mb_str,
                              &scratch_limit_in_mb)) {
      return scratch_limit_in_mb * (1 << 20);
    } else {
      LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
                   << workspace_limit_in_mb_str;
    }
  }
  return default_value_in_bytes;
}

// A dummy type to group forward convolution autotune results together.
struct ConvAutoTuneGroup {
  static string name() { return "Conv"; }
};
typedef AutoTuneSingleton<ConvAutoTuneGroup, ConvParameters,
                          perftools::gputools::dnn::AlgorithmConfig>
    AutoTuneConv;

template <typename T>
void LaunchConv2DOp<GPUDevice, T>::operator()(
    OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
    const Tensor& input_param, const Tensor& filter, int row_dilation,
    int col_dilation, int row_stride, int col_stride, const Padding& padding,
    Tensor* output, TensorFormat data_format) {
  using perftools::gputools::dnn::AlgorithmConfig;
  using perftools::gputools::dnn::AlgorithmDesc;
  using perftools::gputools::dnn::ProfileResult;
  auto* stream = ctx->op_device_context()->stream();
  OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

  if (!use_cudnn) {
    ctx->SetStatus(
        errors::Unimplemented("Conv2D for GPU is not currently supported "
                              "without cudnn"));
    return;
  }

  Tensor input = input_param;

  if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_dilation == 1 &&
      col_dilation == 1 && row_stride == 1 && col_stride == 1 &&
      data_format == FORMAT_NHWC) {
    // 1x1 filter, so call cublas directly.
    const uint64 m = input.dim_size(0) * input.dim_size(1) * input.dim_size(2);
    const uint64 k = filter.dim_size(2);
    const uint64 n = filter.dim_size(3);
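    // The output [m, n] is computed as a single GEMM of the [m, k] input with
    // the [k, n] filter.  The operands are handed to ThenBlasGemm swapped and
    // with m/n exchanged: cuBLAS assumes column-major storage, so computing
    // C^T = B^T * A^T on the row-major buffers yields the desired C.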

    auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                input.template flat<T>().size());
    auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                filter.template flat<T>().size());
    auto c_ptr = AsDeviceMemory(output->template flat<T>().data(),
                                output->template flat<T>().size());

    auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;
    bool blas_launch_status =
        stream
            ->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, b_ptr, n,
                           a_ptr, k, 0.0f, &c_ptr, n)
            .ok();
    if (!blas_launch_status) {
      ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                      ", n=", n, ", k=", k));
    }
    return;
  } else if (filter.dim_size(0) == input.dim_size(1) &&
             filter.dim_size(1) == input.dim_size(2) && row_dilation == 1 &&
             col_dilation == 1 && padding == VALID &&
             data_format == FORMAT_NHWC) {
    // The input data and filter have the same height/width, so call cublas
    // directly.
    const uint64 m = input.dim_size(0);
    const uint64 k =
        filter.dim_size(0) * filter.dim_size(1) * filter.dim_size(2);
    const uint64 n = filter.dim_size(3);

    auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                input.template flat<T>().size());
    auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                filter.template flat<T>().size());
    auto c_ptr = AsDeviceMemory(output->template flat<T>().data(),
                                output->template flat<T>().size());

    auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;
    bool blas_launch_status =
        stream
            ->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, b_ptr, n,
                           a_ptr, k, 0.0f, &c_ptr, n)
            .ok();
    if (!blas_launch_status) {
      ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                      ", n=", n, ", k=", k));
    }
    return;
  }

  int padding_rows = 0;
  int padding_cols = 0;
  const int64 in_batch = GetTensorDim(input, data_format, 'N');
  int64 in_rows = GetTensorDim(input, data_format, 'H');
  int64 in_cols = GetTensorDim(input, data_format, 'W');
  const int64 in_depths = GetTensorDim(input, data_format, 'C');
  const int64 out_batch = GetTensorDim(*output, data_format, 'N');
  const int64 out_rows = GetTensorDim(*output, data_format, 'H');
  const int64 out_cols = GetTensorDim(*output, data_format, 'W');
  const int64 out_depths = GetTensorDim(*output, data_format, 'C');
  const int64 patch_rows = filter.dim_size(0);
  const int64 patch_cols = filter.dim_size(1);
  if (padding == SAME) {
    // Total padding on rows and cols is
    // Pr = (R' - 1) * S + (Kr - 1) * Dr + 1 - R
    // Pc = (C' - 1) * S + (Kc - 1) * Dc + 1 - C
    // where (R', C') are output dimensions, (R, C) are input dimensions, S
    // is stride, (Dr, Dc) are dilations, (Kr, Kc) are filter dimensions.
    // We pad Pr/2 on the top and Pr - Pr/2 on the bottom, Pc/2 on the left
    // and Pc - Pc/2 on the right.  When Pr or Pc is odd, this means
    // we pad more on the bottom and right than on the top and left.
    padding_rows =
        std::max<int>(0, (out_rows - 1) * row_stride +
                             (patch_rows - 1) * row_dilation + 1 - in_rows);
    padding_cols =
        std::max<int>(0, (out_cols - 1) * col_stride +
                             (patch_cols - 1) * col_dilation + 1 - in_cols);
    const bool rows_odd = (padding_rows % 2 != 0);
    const bool cols_odd = (padding_cols % 2 != 0);
    if (rows_odd || cols_odd) {
      Tensor transformed_input;
      int64 new_in_rows = in_rows + rows_odd;
      int64 new_in_cols = in_cols + cols_odd;
      OP_REQUIRES_OK(
          ctx,
          ctx->allocate_temp(DataTypeToEnum<T>::value,
                             ShapeFromFormat(data_format, in_batch, new_in_rows,
                                             new_in_cols, in_depths),
                             &transformed_input));

      functor::PadInput<GPUDevice, T, int, 4>()(
          ctx->eigen_device<GPUDevice>(), To32Bit(input_param.tensor<T, 4>()),
          {{0, 0}}, {{rows_odd, cols_odd}},
          To32Bit(transformed_input.tensor<T, 4>()), data_format);

      input = transformed_input;
      in_rows = new_in_rows;
      in_cols = new_in_cols;
    }
  }

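  // The cuDNN descriptors below use the NCHW (kBatchDepthYX) layout, so NHWC
  // inputs are transposed to NCHW here and the result is transposed back at
  // the end.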
  if (data_format == FORMAT_NHWC) {
    // Convert the input tensor from NHWC to NCHW.
    TensorShape nchw_shape =
        ShapeFromFormat(FORMAT_NCHW, in_batch, in_rows, in_cols, in_depths);
    if (in_depths > 1) {
      Tensor transformed_input;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             nchw_shape, &transformed_input));
      functor::NHWCToNCHW<GPUDevice, T, 4>()(
          ctx->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(input).tensor<T, 4>(),
          transformed_input.tensor<T, 4>());
      input = transformed_input;
    } else {
      // If depth <= 1, then just reshape.
      CHECK(input.CopyFrom(input, nchw_shape));
    }
  }

  CHECK(padding_rows >= 0 && padding_cols >= 0)
      << "Negative row or col paddings: (" << padding_rows << ", "
      << padding_cols << ")";
  perftools::gputools::dnn::BatchDescriptor input_desc;
  input_desc.set_count(in_batch)
      .set_feature_map_count(in_depths)
      .set_height(in_rows)
      .set_width(in_cols)
      .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
  perftools::gputools::dnn::BatchDescriptor output_desc;
  output_desc.set_count(out_batch)
      .set_height(out_rows)
      .set_width(out_cols)
      .set_feature_map_count(out_depths)
      .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
  perftools::gputools::dnn::FilterDescriptor filter_desc;
  filter_desc.set_input_filter_height(filter.dim_size(0))
      .set_input_filter_width(filter.dim_size(1))
      .set_input_feature_map_count(filter.dim_size(2))
      .set_output_feature_map_count(filter.dim_size(3));
  perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
  conv_desc.set_vertical_dilation_rate(row_dilation)
      .set_horizontal_dilation_rate(col_dilation)
      .set_vertical_filter_stride(row_stride)
      .set_horizontal_filter_stride(col_stride)
      .set_zero_padding_height(padding_rows / 2)
      .set_zero_padding_width(padding_cols / 2);

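  // Reorder the filter from TensorFlow's HWIO layout into the OIHW layout
  // ([out_depth, in_depth, rows, cols]) expected by the filter descriptor
  // above.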
  Tensor transformed_filter;
  OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                          DataTypeToEnum<T>::value,
                          TensorShape({filter.dim_size(3), filter.dim_size(2),
                                       filter.dim_size(0), filter.dim_size(1)}),
                          &transformed_filter));

  functor::TransformFilter<GPUDevice, T, int, 4>()(
      ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()),
      To32Bit(transformed_filter.tensor<T, 4>()));

  Tensor transformed_output;
  OP_REQUIRES_OK(
      ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                              ShapeFromFormat(FORMAT_NCHW, out_batch, out_rows,
                                              out_cols, out_depths),
                              &transformed_output));

  auto input_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                  input.template flat<T>().size());
  auto filter_ptr =
      AsDeviceMemory(transformed_filter.template flat<T>().data(),
                     transformed_filter.template flat<T>().size());
  auto output_ptr =
      AsDeviceMemory(transformed_output.template flat<T>().data(),
                     transformed_output.template flat<T>().size());

  static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit(
      // default value is in bytes despite the name of the environment variable
      "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB
  );

  int device_id = stream->parent()->device_ordinal();
  DataType dtype = input.dtype();
  ConvParameters conv_parameters = {
      in_batch,          // batch
      in_depths,         // in_depths
      {{in_rows,         // in_rows
        in_cols}},       // in_cols
      out_depths,        // out_depths
      {{patch_rows,      // filter_rows
        patch_cols}},    // filter_cols
      {{row_dilation,    // dilation_rows
        col_dilation}},  // dilation_cols
      {{row_stride,      // stride_rows
        col_stride}},    // stride_cols
      {{padding_rows,    // padding_rows
        padding_cols}},  // padding_cols
      dtype,             // tensor datatype
      device_id,         // device_id
  };
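
  // If autotuning is enabled and no algorithm is cached yet for these
  // convolution parameters, profile every algorithm cuDNN offers once and
  // cache the fastest result (tracked both with and without scratch space);
  // otherwise reuse the cached choice.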
  AlgorithmConfig algorithm_config;
  if (cudnn_use_autotune &&
      !AutoTuneConv::GetInstance()->Find(conv_parameters, &algorithm_config)) {
    std::vector<AlgorithmDesc> algorithms;
    CHECK(stream->parent()->GetConvolveAlgorithms(
        conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
    ProfileResult best_result;
    ProfileResult best_result_no_scratch;
    for (auto profile_algorithm : algorithms) {
      // TODO(zhengxq): profile each algorithm multiple times for better
      // accuracy.
      CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
      ProfileResult profile_result;
      bool cudnn_launch_status =
          stream
              ->ThenConvolveWithAlgorithm(
                  input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
                  output_desc, &output_ptr, &scratch_allocator,
                  AlgorithmConfig(profile_algorithm), &profile_result)
              .ok();
      if (cudnn_launch_status) {
        if (profile_result.is_valid()) {
          if (profile_result.elapsed_time_in_ms() <
              best_result.elapsed_time_in_ms()) {
            best_result = profile_result;
          }
          if (scratch_allocator.TotalByteSize() == 0 &&
              profile_result.elapsed_time_in_ms() <
                  best_result_no_scratch.elapsed_time_in_ms()) {
            best_result_no_scratch = profile_result;
          }
        }
      }
    }
    // TODO(yangzihao): refactor the profile result checking code into a common
    // utility function.
    OP_REQUIRES(ctx,
                best_result.is_valid() || best_result_no_scratch.is_valid(),
                errors::NotFound("No algorithm worked!"));
    if (best_result.is_valid()) {
      algorithm_config.set_algorithm(best_result.algorithm());
    }
    if (best_result_no_scratch.is_valid()) {
      algorithm_config.set_algorithm_no_scratch(
          best_result_no_scratch.algorithm());
    }
    AutoTuneConv::GetInstance()->Insert(conv_parameters, algorithm_config);
  }

  CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
  bool cudnn_launch_status =
      stream
          ->ThenConvolveWithAlgorithm(input_desc, input_ptr, filter_desc,
                                      filter_ptr, conv_desc, output_desc,
                                      &output_ptr, &scratch_allocator,
                                      algorithm_config, nullptr)
          .ok();

  if (!cudnn_launch_status) {
    ctx->SetStatus(errors::Internal(
        "cuDNN launch failure : input shape(", input.shape().DebugString(),
        ") filter shape(", filter.shape().DebugString(), ")"));
  }

  // Convert the output tensor back from NCHW to NHWC if needed.
  if (data_format == FORMAT_NHWC) {
    functor::NCHWToNHWC<GPUDevice, T, 4>()(
        ctx->eigen_device<GPUDevice>(),
        const_cast<const Tensor&>(transformed_output).tensor<T, 4>(),
        output->tensor<T, 4>());
  } else {
    *output = transformed_output;
  }
}

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                                  \
  template <>                                                                \
  void SpatialConvolution<GPUDevice, T>::operator()(                         \
      const GPUDevice& d, typename TTypes<T, 4>::Tensor output,              \
      typename TTypes<T, 4>::ConstTensor input,                              \
      typename TTypes<T, 4>::ConstTensor filter, int row_stride,             \
      int col_stride, int row_dilation, int col_dilation,                    \
      const Eigen::PaddingType& padding);                                    \
  extern template struct SpatialConvolution<GPUDevice, T>;                   \
  template <>                                                                \
  void MatMulConvFunctor<GPUDevice, T>::operator()(                          \
      const GPUDevice& d, typename TTypes<T, 2>::Tensor out,                 \
      typename TTypes<T, 2>::ConstTensor in0,                                \
      typename TTypes<T, 2>::ConstTensor in1,                                \
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair); \
  extern template struct MatMulConvFunctor<GPUDevice, T>;                    \
  template <>                                                                \
  void TransformFilter<GPUDevice, T, int, 4>::operator()(                    \
      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,        \
      typename TTypes<T, 4, int>::Tensor out);                               \
  extern template struct TransformFilter<GPUDevice, T, int, 4>;              \
  template <>                                                                \
  void PadInput<GPUDevice, T, int, 4>::operator()(                           \
      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,        \
      const std::array<int, 2>& padding_left,                                \
      const std::array<int, 2>& padding_right,                               \
      typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format);     \
  extern template struct PadInput<GPUDevice, T, int, 4>

DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(Eigen::half);
#undef DECLARE_GPU_SPEC
}  // namespace functor

// Registration of the GPU implementations.
REGISTER_KERNEL_BUILDER(
    Name("Conv2D").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
    Conv2DOp<GPUDevice, Eigen::half>);
REGISTER_KERNEL_BUILDER(
    Name("Conv2D").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    Conv2DOp<GPUDevice, float>);

// To be used inside depthwise_conv_op.cc.
template class LaunchConv2DOp<GPUDevice, float>;

#endif  // GOOGLE_CUDA

}  // namespace tensorflow