      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #if GOOGLE_CUDA
     17 #define EIGEN_USE_GPU
     18 
     19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     20 #include "external/cub_archive/cub/util_ptx.cuh"
     21 #include "tensorflow/core/framework/op_kernel.h"
     22 #include "tensorflow/core/kernels/depthwise_conv_op.h"
     23 #include "tensorflow/core/platform/types.h"
     24 #include "tensorflow/core/util/cuda_kernel_helper.h"
     25 #include "tensorflow/core/util/tensor_format.h"
     26 
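// UNROLL / NOUNROLL ask the compiler to fully unroll (or keep rolled) the loop
// that follows. MSVC (when not using clang) lacks the C99 _Pragma operator
// used below, so there the macros expand to nothing and compiler defaults
// apply.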
     27 #if defined(_MSC_VER) && !defined(__clang__)
     28 #define UNROLL
     29 #define NOUNROLL
     30 #else
     31 #define UNROLL _Pragma("unroll")
     32 #define NOUNROLL _Pragma("nounroll")
     33 #endif
     34 
     35 namespace tensorflow {
     36 
     37 using Eigen::GpuDevice;
     38 
// Returns whether the depthwise convolution forward or backward input pass can
// be performed using the faster ('Small') variant of the kernel.
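// Example (illustrative): a 16x16 input with a 3x3 filter, stride 1,
// depth_multiplier 1, and 'SAME' padding (pad_rows = pad_cols = 1) qualifies,
// since 3 * 3 = 9 <= (16 + 1) / 2 * 16 = 128.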
     41 EIGEN_DEVICE_FUNC bool CanLaunchDepthwiseConv2dGPUSmall(
     42     const DepthwiseArgs& args) {
     43   return args.depth_multiplier == 1 && args.stride == 1 && args.in_rows <= 32 &&
     44          args.in_cols <= 32 && args.in_rows == args.out_rows &&
     45          args.in_cols == args.out_cols && args.pad_rows >= 0 &&
     46          args.pad_rows < args.filter_rows && args.pad_cols >= 0 &&
     47          args.pad_cols < args.filter_cols &&
     48          args.filter_rows * args.filter_cols <=
     49              (args.in_rows + 1) / 2 * args.in_cols;
     50 }
     51 
// Returns whether the depthwise convolution backward filter pass can be
// performed using the faster ('Small') variant of the kernel.
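// Here block_height is the number of input rows processed per thread block.
// The last condition mirrors the forward 'Small' kernels: there must be at
// least as many pixel positions per block (args.in_cols * block_height) as
// filter elements, so the whole filter can be staged into shared memory.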
     54 EIGEN_DEVICE_FUNC bool CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(
     55     const DepthwiseArgs& args, const int block_height) {
     56   return args.depth_multiplier == 1 && args.stride == 1 && args.in_rows <= 32 &&
     57          args.in_cols <= 32 && args.in_rows == args.out_rows &&
     58          args.in_cols == args.out_cols && args.pad_rows >= 0 &&
     59          args.pad_rows < args.filter_rows && args.pad_cols >= 0 &&
     60          args.pad_cols < args.filter_cols && block_height <= args.in_rows &&
     61          args.filter_rows * args.filter_cols <= args.in_cols * block_height;
     62 }
     63 
     64 // The DepthwiseConv2dGPUKernels perform either forward or backprop input
     65 // convolution depending on a template argument of this enum.
     66 enum DepthwiseConv2dDirection { DIRECTION_FORWARD, DIRECTION_BACKWARD };
     67 
     68 // A Cuda kernel to compute the depthwise convolution forward pass
     69 // in NHWC format.
     70 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
     71           int kKnownDepthMultiplier>
     72 __global__ void __launch_bounds__(1024, 2)
     73     DepthwiseConv2dGPUKernelNHWC(const DepthwiseArgs args, const T* input,
     74                                  const T* filter, T* output, int num_outputs) {
     75   const int in_height = args.in_rows;
     76   const int in_width = args.in_cols;
     77   const int in_depth = args.in_depth;
     78   const int filter_height =
     79       kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
     80   const int filter_width =
     81       kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
     82   const int depth_multiplier =
     83       kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier;
     84   const int stride = args.stride;
     85   const int pad_height = args.pad_rows;
     86   const int pad_width = args.pad_cols;
     87   const int out_height = args.out_rows;
     88   const int out_width = args.out_cols;
     89   const int out_depth = args.out_depth;
     90 
     91   CUDA_1D_KERNEL_LOOP(thread_id, num_outputs) {
     92     // Compute the indexes of this thread in the output.
     93     const int out_channel = thread_id % out_depth;
     94     const int out_col = (thread_id / out_depth) % out_width;
     95     const int out_row = (thread_id / out_depth / out_width) % out_height;
     96     const int batch = thread_id / out_depth / out_width / out_height;
     97     // Compute the input depth and the index of depth multiplier.
     98     const int in_channel = out_channel / depth_multiplier;
     99     const int multiplier = out_channel % depth_multiplier;
    100 
    // Decide if all of the input is valid; if so, we can skip the boundary
    // checks for each input access.
    103     const int input_row_start = out_row * stride - pad_height;
    104     const int input_col_start = out_col * stride - pad_width;
    105     const int input_row_end = input_row_start + filter_height;
    106     const int input_col_end = input_col_start + filter_width;
    107 
    108     T sum = static_cast<T>(0);
    109 
    110     const int input_offset_temp = in_height * batch;
    111     if (input_row_start >= 0 && input_col_start >= 0 &&
    112         input_row_end < in_height && input_col_end < in_width) {
    113       UNROLL for (int filter_row = 0; filter_row < filter_height;
    114                   ++filter_row) {
    115         const int in_row = input_row_start + filter_row;
    116         const int filter_offset_temp = filter_width * filter_row;
    117         UNROLL for (int filter_col = 0; filter_col < filter_width;
    118                     ++filter_col) {
    119           const int in_col = input_col_start + filter_col;
    120 
    121           const int input_offset =
    122               in_channel +
    123               in_depth * (in_col + in_width * (in_row + input_offset_temp));
    124           const int filter_offset =
    125               multiplier +
    126               depth_multiplier *
    127                   (in_channel + in_depth * (filter_col + filter_offset_temp));
    128           sum += ldg(input + input_offset) * ldg(filter + filter_offset);
    129         }
    130       }
    131     } else {
    132       UNROLL for (int filter_row = 0; filter_row < filter_height;
    133                   ++filter_row) {
    134         const int in_row = input_row_start + filter_row;
    135         const int filter_offset_temp = filter_width * filter_row;
    136         UNROLL for (int filter_col = 0; filter_col < filter_width;
    137                     ++filter_col) {
    138           const int in_col = input_col_start + filter_col;
    139           if (in_row >= 0 && in_row < in_height && in_col >= 0 &&
    140               in_col < in_width) {
    143             const int input_offset =
    144                 in_channel +
    145                 in_depth * (in_col + in_width * (in_row + input_offset_temp));
    146             const int filter_offset =
    147                 multiplier +
    148                 depth_multiplier *
    149                     (in_channel + in_depth * (filter_col + filter_offset_temp));
    150             sum += ldg(input + input_offset) * ldg(filter + filter_offset);
    151           }
    152         }
    153       }
    154     }
    155     output[thread_id] = sum;
    156   }
    157 }
    158 
    159 // CUDA kernel to compute the depthwise convolution forward pass in NHWC format,
    160 // tailored for small images up to 32x32. Stride and depth multiplier must be 1.
// Padding must be 'SAME', which allows reusing the index computation. Only
    162 // use this kernel if CanLaunchDepthwiseConv2dGPUSmall(args) returns true.
    163 // Tiles of the input and filter tensors are loaded into shared memory before
    164 // performing the convolution. Each thread handles two elements per iteration,
    165 // one each in the lower and upper half of a tile.
    166 // Backprop input direction is the same as forward direction with the filter
// rotated by 180 degrees.
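// Shared memory layout (sketch): the first tile_size elements hold the
// zero-initialized, padded input tile (tile_height x tile_width pixels with
// kBlockDepth values per pixel), followed by filter_pixels * kBlockDepth
// elements holding the filter slice for the block's depths.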
    168 template <typename T, DepthwiseConv2dDirection kDirection,
    169           int kKnownFilterWidth, int kKnownFilterHeight, int kBlockDepth,
    170           bool kKnownEvenHeight>
    171 __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall(
    172     const DepthwiseArgs args, const T* input, const T* filter, T* output) {
    173   assert(CanLaunchDepthwiseConv2dGPUSmall(args));
    174   // Holds block plus halo and filter data for blockDim.x depths.
    175   extern __shared__ __align__(sizeof(T)) unsigned char shared_memory[];
    176   T* const shared_data = reinterpret_cast<T*>(shared_memory);
    177 
    178   const int num_batches = args.batch;
    179   const int in_height = args.in_rows;
    180   const int in_width = args.in_cols;
    181   const int in_depth = args.in_depth;
    182   const int filter_height =
    183       kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
    184   const int filter_width =
    185       kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
    186   const int pad_height = args.pad_rows;
    187   const int pad_width = args.pad_cols;
    188 
    189   assert(blockDim.x == kBlockDepth);
    190   assert(blockDim.y == args.in_cols);
    191   const int block_height = blockDim.z;
    192 
    193   // These values are the same for all threads and could
    194   // be precomputed on the CPU.
    195   const int block_size = block_height * in_width * kBlockDepth;
    196   const int in_row_size = in_width * in_depth;
    197   const int in_size = in_height * in_row_size;
    198   const int in_increment = (in_width - 1) * kBlockDepth;
    199   const int filter_pixels = filter_height * filter_width;
    200   const int tile_width = in_width + filter_width - 1;
    201   const int even_height = kKnownEvenHeight || (1 & ~in_height);
    202   const int tile_height = in_height + filter_height - even_height;
    203   const int tile_row_size = tile_width * kBlockDepth;
    204   const int tile_size = tile_height * tile_row_size;
    205   const int tile_offset = block_height * tile_row_size;
    206   const int pad_offset = pad_height * tile_width + pad_width;
    207   const int batch_blocks = (in_depth + kBlockDepth - 1) / kBlockDepth;
    208   const int in_blocks = batch_blocks * num_batches;
    209   const int tensor_offset =
    210       kKnownEvenHeight ? in_size / 2 : block_height * in_row_size;
    211 
    212   const int thread_depth = threadIdx.x;
    213   const int thread_col = threadIdx.y;
    214   const int thread_row = threadIdx.z;
    215 
    216   // Position in block.
    217   const int thread_pix = thread_row * in_width + thread_col;
    218   const int thread_idx = thread_pix * kBlockDepth + thread_depth;
    219 
    220   // Initialize tile, in particular the padding.
    221   for (int i = thread_idx; i < tile_size; i += block_size) {
    222     shared_data[i] = T(0);
    223   }
    224   __syncthreads();
    225 
    226   // Position in tensors.
    227   const int tensor_idx = thread_pix * in_depth + thread_depth;
    228 
    229   // Position in (padded) shared memory.
    230   const int data_pix = thread_row * tile_width + thread_col;
    231   const int data_idx = data_pix * kBlockDepth + thread_depth;
    232 
    233   // Position in shared memory, offset by pad_height / pad_width.
    234   const int tile_pix = data_pix + pad_offset;
    235   const int tile_idx = tile_pix * kBlockDepth + thread_depth;
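  // Each thread stores its input element at tile_idx, which is data_idx
  // shifted by pad_offset so that the zero-initialized border acts as the
  // 'SAME' padding; the convolution below reads the tile starting at data_idx.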
    236 
    237   const int max_channel = in_depth - thread_depth;
    238   const int filter_write_offset =
    239       thread_pix < filter_pixels ? tile_size + thread_idx : 0;
    240   const int filter_read_offset =
    241       tile_size + thread_depth +
    242       (kDirection == DIRECTION_FORWARD ? 0 : filter_pixels * kBlockDepth);
    243   const bool skip_second =
    244       !kKnownEvenHeight && thread_row + (in_height & 1) == block_height;
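  // For odd in_height, the two halves of the block cover in_height + 1 rows,
  // so the last row of threads skips its second element (the load/store at
  // offset tensor_offset below).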
    245 
    246   for (int b = blockIdx.x; b < in_blocks; b += gridDim.x) {
    247     const int batch = b / batch_blocks;
    248     const int block = b - batch * batch_blocks;
    249 
    250     const int start_channel = block * kBlockDepth;
    251     const int filter_offset = tensor_idx + start_channel;
    252     const int inout_offset = batch * in_size + filter_offset;
    253     const bool channel_in_range = start_channel < max_channel;
    254 
    255     if (channel_in_range) {
    256       const T* const in_ptr = inout_offset + input;
    257       T* const tile_ptr = tile_idx + shared_data;
    258       tile_ptr[0] = ldg(in_ptr);
    259       if (!skip_second) {
    260         tile_ptr[tile_offset] = ldg(tensor_offset + in_ptr);
    261       }
    262 
    263       if (filter_write_offset != 0) {
    264         shared_data[filter_write_offset] = ldg(filter_offset + filter);
    265       }
    266     }
    267 
    268     // Note: the condition to reach this is uniform across the entire block.
    269     __syncthreads();
    270 
    271     if (channel_in_range) {
    272       T sum1 = static_cast<T>(0);
    273       T sum2 = static_cast<T>(0);
    274       int shared_offset = data_idx;
    275       const T* filter_ptr = filter_read_offset + shared_data;
    276       UNROLL for (int r = 0; r < filter_height; ++r) {
    277         UNROLL for (int c = 0; c < filter_width; ++c) {
    278           if (kDirection == DIRECTION_BACKWARD) {
    279             filter_ptr -= kBlockDepth;
    280           }
    281           const T filter_value = *filter_ptr;
    282           const T* const tile_ptr = shared_offset + shared_data;
    283           sum1 += filter_value * tile_ptr[0];
    284           sum2 += filter_value * tile_ptr[tile_offset];
    285           shared_offset += kBlockDepth;
    286           if (kDirection == DIRECTION_FORWARD) {
    287             filter_ptr += kBlockDepth;
    288           }
    289         }
    290         shared_offset += in_increment;
    291       }
    292       T* const out_ptr = inout_offset + output;
    293       out_ptr[0] = sum1;
    294       if (!skip_second) {
    295         out_ptr[tensor_offset] = sum2;
    296       }
    297     }
    298 
    299     // Note: the condition to reach this is uniform across the entire block.
    300     __syncthreads();
    301   }
    302 }
    303 
    304 // A Cuda kernel to compute the depthwise convolution forward pass
    305 // in NCHW format.
    306 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
    307           int kKnownDepthMultiplier>
    308 __global__ void __launch_bounds__(1024, 2)
    309     DepthwiseConv2dGPUKernelNCHW(const DepthwiseArgs args, const T* input,
    310                                  const T* filter, T* output, int num_outputs) {
    311   const int in_height = args.in_rows;
    312   const int in_width = args.in_cols;
    313   const int in_depth = args.in_depth;
    314   const int filter_height =
    315       kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
    316   const int filter_width =
    317       kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
    318   const int depth_multiplier =
    319       kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier;
    320   const int stride = args.stride;
    321   const int pad_height = args.pad_rows;
    322   const int pad_width = args.pad_cols;
    323   const int out_height = args.out_rows;
    324   const int out_width = args.out_cols;
    325   const int out_depth = args.out_depth;
    326 
    327   CUDA_1D_KERNEL_LOOP(thread_id, num_outputs) {
    328     // Compute the indexes of this thread in the output.
    329     //
    330     // We want coalesced reads so we make sure that each warp reads
    331     // a contiguous chunk of memory.
    332     //
    // THIS IS PROBABLY WRONG: we are not doing coalesced reads into the input,
    // because of the depth multiplier division...
    335     const int out_col = thread_id % out_width;
    336     const int out_row = (thread_id / out_width) % out_height;
    337     const int out_channel = (thread_id / out_width / out_height) % out_depth;
    338     const int batch = thread_id / out_width / out_height / out_depth;
    339 
    // Compute the input depth and the depth multiplier index based on the
    // output depth index that this thread is computing.
    343     const int in_channel = out_channel / depth_multiplier;
    344     const int multiplier = out_channel % depth_multiplier;
    345 
    // Data is stored in the following format (let's assume we
    // flatten the height and width into one contiguous dimension
    // called "P").
    349     //
    350     // B1C1P1 B1C1P2 ..... B1C2P1 B1C2P2 ....
    351     // B2C1P1 B2C1P2 ..... B2C2P1 B2C2P2 ....
    352     //
    353     // Each row contains in_depth * in_height * in_width values
    354     // for each sample in the batch.
    355     //
    356     // We can further flatten it into:
    357     //
    358     // B1C1P1 B1C1P2 .....
    359     // B1C2P1 B1C2P2 ....
    360     // B2C1P1 B2C1P2 .....
    361     // B2C2P1 B2C2P2 ....
    362     //
    363     // where each row is a contiguous array of all of the spatial
    364     // pixels for a given batch and input depth.  The following
    365     // loop unrolls across the filter dimensions for a given thread,
    366     // indexing into the filter value and the corresponding input
    367     // patch.
    368     //
    369     // We can compute the index into the patch once right here.
    370     const int input_offset_temp =
    371         (batch * in_depth + in_channel) * (in_height * in_width);
    372 
    373     // Finally, we can iterate over the spatial dimensions and perform the
    374     // convolution, writing into the output at the end.
    375     //
    376     // We perform an additional optimization, where we can determine
    377     // whether the patch fits within the image indices statically, and
    378     // avoid boundary checking within the loop.
    379     const int input_row_start = out_row * stride - pad_height;
    380     const int input_col_start = out_col * stride - pad_width;
    381     const int input_row_end = input_row_start + filter_height;
    382     const int input_col_end = input_col_start + filter_width;
    383 
    384     T sum = static_cast<T>(0);
    385     if (input_row_start >= 0 && input_col_start >= 0 &&
    386         input_row_end < in_height && input_col_end < in_width) {
    387       // Loop that doesn't need to check for boundary conditions.
    388       UNROLL for (int filter_row = 0; filter_row < filter_height;
    389                   ++filter_row) {
    390         const int in_row = input_row_start + filter_row;
    391         const int filter_offset_temp = filter_width * filter_row;
    392         UNROLL for (int filter_col = 0; filter_col < filter_width;
    393                     ++filter_col) {
    394           const int in_col = input_col_start + filter_col;
    395 
    396           const int input_offset =
    397               (input_offset_temp) + (in_row * in_width) + in_col;
    398           const int filter_offset =
    399               multiplier +
    400               depth_multiplier *
    401                   (in_channel + in_depth * (filter_col + filter_offset_temp));
    402           sum += ldg(input + input_offset) * ldg(filter + filter_offset);
    403         }
    404       }
    405     } else {
    406       // Loop that needs to check for boundary conditions.
    407       UNROLL for (int filter_row = 0; filter_row < filter_height;
    408                   ++filter_row) {
    409         const int in_row = input_row_start + filter_row;
    410         const int filter_offset_temp = filter_width * filter_row;
    411         UNROLL for (int filter_col = 0; filter_col < filter_width;
    412                     ++filter_col) {
    413           const int in_col = input_col_start + filter_col;
    414           // TODO(vrv): the in_row check can be done outside of this loop;
    415           // benchmark both methods to determine the better decision.
    416           if (in_row >= 0 && in_row < in_height && in_col >= 0 &&
    417               in_col < in_width) {
    420             // input_offset_temp indexes into the start of memory
    421             // where the spatial data starts.
    422             const int input_offset =
    423                 (input_offset_temp) + (in_row * in_width) + in_col;
    424 
    425             const int filter_offset =
    426                 multiplier +
    427                 depth_multiplier *
    428                     (in_channel + in_depth * (filter_col + filter_offset_temp));
    429             sum += ldg(input + input_offset) * ldg(filter + filter_offset);
    430           }
    431         }
    432       }
    433     }
    434 
    435     output[thread_id] = sum;
    436   }
    437 }
    438 
    439 // CUDA kernel to compute the depthwise convolution forward pass in NCHW format,
    440 // tailored for small images up to 32x32. Stride and depth multiplier must be 1.
// Padding must be 'SAME', which allows reusing the index computation. Only
    442 // use this kernel if CanLaunchDepthwiseConv2dGPUSmall(args) returns true.
    443 // Tiles of the input and filter tensors are loaded into shared memory before
    444 // performing the convolution. Each thread handles two elements per iteration,
    445 // one each in the lower and upper half of a tile.
    446 // Backprop input direction is the same as forward direction with the filter
// rotated by 180 degrees.
    448 template <typename T, DepthwiseConv2dDirection kDirection,
    449           int kKnownFilterWidth, int kKnownFilterHeight, int kBlockDepth,
    450           bool kKnownEvenHeight>
    451 __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall(
    452     const DepthwiseArgs args, const T* input, const T* filter, T* output) {
    453   assert(CanLaunchDepthwiseConv2dGPUSmall(args));
    454   // Holds block plus halo and filter data for blockDim.z depths.
    455   extern __shared__ __align__(sizeof(T)) unsigned char shared_memory[];
    456   T* const shared_data = reinterpret_cast<T*>(shared_memory);
    457 
    458   const int num_batches = args.batch;
    459   const int in_height = args.in_rows;
    460   const int in_width = args.in_cols;
    461   const int in_depth = args.in_depth;
    462   const int filter_height =
    463       kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
    464   const int filter_width =
    465       kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
    466   const int pad_height = args.pad_rows;
    467   const int pad_width = args.pad_cols;
    468 
    469   // Fixed blockDim.z, tailored for maximum grid size for images of size 16x16.
    470   assert(blockDim.x == args.in_cols);
    471   assert(blockDim.z == kBlockDepth);
    472   const int block_height = blockDim.y;
    473 
    474   // These values are the same for all threads and could
    475   // be precomputed on the CPU.
    476   const int block_pixels = in_width * block_height;
    477   const int block_size = block_pixels * kBlockDepth;
    478   const int in_pixels = in_width * in_height;
    479   const int in_increment = in_width - 1;
    480   const int filter_pixels = filter_height * filter_width;
    481   const int tile_width = in_width + filter_width - 1;
    482   const int even_height = kKnownEvenHeight || (1 & ~in_height);
    483   const int tile_height = in_height + filter_height - even_height;
    484   const int tile_pixels = tile_width * tile_height;
    485   const int tile_size = tile_pixels * kBlockDepth;
    486   const int tile_offset = block_height * tile_width;
    487   const int pad_offset = pad_height * tile_width + pad_width;
    488   const int in_total_depth = in_depth * num_batches;
    489   const int in_blocks = (in_total_depth + kBlockDepth - 1) / kBlockDepth;
    490 
    491   const int thread_col = threadIdx.x;
    492   const int thread_row = threadIdx.y;
    493   const int thread_depth = threadIdx.z;
    494 
    495   // Position in block.
    496   const int thread_pix = thread_row * in_width + thread_col;
    497   const int thread_idx = thread_depth * block_pixels + thread_pix;
    498 
    499   // Initialize tile, in particular the padding.
    500   for (int i = thread_idx; i < tile_size; i += block_size) {
    501     shared_data[i] = T(0);
    502   }
    503   __syncthreads();
    504 
    505   // Position in tensors.
    506   const int tensor_idx = thread_depth * in_pixels + thread_pix;
    507 
    508   // Position in (padded) shared memory.
    509   const int data_pix = thread_row * tile_width + thread_col;
    510   const int data_idx = thread_depth * tile_pixels + data_pix;
    511 
    512   // Position in shared memory, offset by pad_height / pad_width.
    513   const int tile_idx = data_idx + pad_offset;
    514 
    515   // Filter is always in HWCK format, irrespective of the input/output format.
    516   const int filter_pix = thread_idx / kBlockDepth;
    517   const int filter_channel = thread_idx % kBlockDepth;
    518   const int filter_idx = filter_pix * in_depth;
    519 
    520   const int max_channel = in_total_depth - thread_depth;
    521   const int filter_write_offset =
    522       filter_pix < filter_pixels ? tile_size + thread_idx : 0;
    523   const int filter_read_offset =
    524       tile_size + thread_depth +
    525       (kDirection == DIRECTION_FORWARD ? 0 : filter_pixels * kBlockDepth);
    526   const bool skip_second =
    527       !kKnownEvenHeight && thread_row + (in_height & 1) == block_height;
    528 
    529   for (int b = blockIdx.x; b < in_blocks; b += gridDim.x) {
    530     const int channel = b * kBlockDepth;
    531 
    532     const int inout_offset = channel * in_pixels + tensor_idx;
    533     const bool channel_in_range = channel < max_channel;
    534 
    535     if (channel_in_range) {
    536       const T* const in_ptr = inout_offset + input;
    537       T* const tile_ptr = tile_idx + shared_data;
    538       tile_ptr[0] = ldg(in_ptr);
    539       if (!skip_second) {
    540         tile_ptr[tile_offset] = ldg(block_pixels + in_ptr);
    541       }
    542     }
    543 
    544     if (filter_write_offset != 0) {
    545       const int filter_offset =
    546           filter_idx + (channel + filter_channel) % in_depth;
    547       shared_data[filter_write_offset] = ldg(filter_offset + filter);
    548     }
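    // Note: 'channel' counts depth across the whole batch here (NCHW), so the
    // modulo above wraps it back into [0, in_depth) when indexing the filter.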
    549 
    550     // Note: the condition to reach this is uniform across the entire block.
    551     __syncthreads();
    552 
    553     if (channel_in_range) {
    554       T sum1 = static_cast<T>(0);
    555       T sum2 = static_cast<T>(0);
    556       int shared_offset = data_idx;
    557       const T* filter_ptr = filter_read_offset + shared_data;
    558       UNROLL for (int r = 0; r < filter_height; ++r) {
    559         UNROLL for (int c = 0; c < filter_width; ++c) {
    560           if (kDirection == DIRECTION_BACKWARD) {
    561             filter_ptr -= kBlockDepth;
    562           }
    563           const T filter_value = *filter_ptr;
    564           const T* const tile_ptr = shared_offset + shared_data;
    565           sum1 += filter_value * tile_ptr[0];
    566           sum2 += filter_value * tile_ptr[tile_offset];
    567           ++shared_offset;
    568           if (kDirection == DIRECTION_FORWARD) {
    569             filter_ptr += kBlockDepth;
    570           }
    571         }
    572         shared_offset += in_increment;
    573       }
    574       T* const out_ptr = inout_offset + output;
    575       out_ptr[0] = sum1;
    576       if (!skip_second) {
    577         out_ptr[block_pixels] = sum2;
    578       }
    579     }
    580 
    581     // Note: the condition to reach this is uniform across the entire block.
    582     __syncthreads();
    583   }
    584 }
    585 
    586 template <typename T, DepthwiseConv2dDirection kDirection,
    587           int kKnownFilterWidth, int kKnownFilterHeight, int kBlockDepth,
    588           bool kKnownEvenHeight>
    589 void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& device,
    590                                    const DepthwiseArgs& args, const T* input,
    591                                    const T* filter, T* output,
    592                                    TensorFormat data_format) {
    593   const int block_height = (args.in_rows + 1) / 2;
    594   dim3 block_dim;
    595   int block_count;
    596   void (*kernel)(const DepthwiseArgs, const T*, const T*, T*);
    597   switch (data_format) {
    598     case FORMAT_NHWC:
    599       block_dim = dim3(kBlockDepth, args.in_cols, block_height);
    600       block_count =
    601           args.batch * DivUp(args.out_depth, kBlockDepth) * kBlockDepth;
    602       kernel =
    603           DepthwiseConv2dGPUKernelNHWCSmall<T, kDirection, kKnownFilterWidth,
    604                                             kKnownFilterHeight, kBlockDepth,
    605                                             kKnownEvenHeight>;
    606       break;
    607     case FORMAT_NCHW:
    608       block_dim = dim3(args.in_cols, block_height, kBlockDepth);
    609       block_count =
    610           DivUp(args.batch * args.out_depth, kBlockDepth) * kBlockDepth;
    611       kernel =
    612           DepthwiseConv2dGPUKernelNCHWSmall<T, kDirection, kKnownFilterWidth,
    613                                             kKnownFilterHeight, kBlockDepth,
    614                                             kKnownEvenHeight>;
    615       break;
    616     case FORMAT_NCHW_VECT_C:
    617       LOG(ERROR) << "FORMAT_NCHW_VECT_C is not supported";
    618       return;
    619   }
    620   const int tile_width = args.in_cols + args.filter_cols - 1;
    621   const int tile_height = block_height * 2 + args.filter_rows - 1;
    622   const int tile_pixels = tile_height * tile_width;
    623   const int filter_pixels = args.filter_rows * args.filter_cols;
    624   const int shared_memory_size =
    625       kBlockDepth * (tile_pixels + filter_pixels) * sizeof(T);
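  // Illustrative sizing: for T = float, a 32x32 input, a 3x3 filter and
  // kBlockDepth = 2, this is 2 * ((32 + 2) * (32 + 2) + 3 * 3) * 4 = 9320
  // bytes of shared memory per block.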
    626   const int num_outputs = args.out_rows * args.out_cols * block_count;
    627   CudaLaunchConfig config = GetCudaLaunchConfigFixedBlockSize(
    628       num_outputs, device, kernel, shared_memory_size,
    629       block_dim.x * block_dim.y * block_dim.z);
    630   kernel<<<config.block_count, block_dim, shared_memory_size,
    631            device.stream()>>>(args, input, filter, output);
    632 }
    633 
    634 template <typename T, DepthwiseConv2dDirection kDirection,
    635           int kKnownFilterWidth, int kKnownFilterHeight, int kBlockDepth>
    636 void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& device,
    637                                    const DepthwiseArgs& args, const T* input,
    638                                    const T* filter, T* output,
    639                                    TensorFormat data_format) {
    640   if (args.in_rows & 1) {
    641     LaunchDepthwiseConv2dGPUSmall<T, kDirection, kKnownFilterWidth,
    642                                   kKnownFilterHeight, kBlockDepth, false>(
    643         device, args, input, filter, output, data_format);
    644   } else {
    645     LaunchDepthwiseConv2dGPUSmall<T, kDirection, kKnownFilterWidth,
    646                                   kKnownFilterHeight, kBlockDepth, true>(
    647         device, args, input, filter, output, data_format);
    648   }
    649 }
    650 
    651 template <typename T, DepthwiseConv2dDirection kDirection,
    652           int kKnownFilterWidth, int kKnownFilterHeight>
    653 void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& device,
    654                                    const DepthwiseArgs& args, const T* input,
    655                                    const T* filter, T* output,
    656                                    TensorFormat data_format) {
    657   // Maximize (power of two) kBlockDepth while keeping a block within 1024
    658   // threads (2 pixels per thread).
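  // Illustrative: a 32x32 input gives block_pixels = 16 * 32 = 512, so
  // kBlockDepth = 2 (2 * 32 * 16 = 1024 threads); a 16x16 input gives
  // block_pixels = 8 * 16 = 128, so kBlockDepth = 8 (16 * 8 * 8 = 1024
  // threads).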
    659   const int block_pixels = (args.in_rows + 1) / 2 * args.in_cols;
    660   if (block_pixels > 256) {
    661     LaunchDepthwiseConv2dGPUSmall<T, kDirection, kKnownFilterWidth,
    662                                   kKnownFilterHeight, 2>(
    663         device, args, input, filter, output, data_format);
    664   } else if (block_pixels > 128) {
    665     LaunchDepthwiseConv2dGPUSmall<T, kDirection, kKnownFilterWidth,
    666                                   kKnownFilterHeight, 4>(
    667         device, args, input, filter, output, data_format);
    668   } else {
    669     LaunchDepthwiseConv2dGPUSmall<T, kDirection, kKnownFilterWidth,
    670                                   kKnownFilterHeight, 8>(
    671         device, args, input, filter, output, data_format);
    672   }
    673 }
    674 
    675 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
    676           int kKnownDepthMultiplier>
    677 void LaunchDepthwiseConv2dGPU(const GpuDevice& device,
    678                               const DepthwiseArgs& args, const T* input,
    679                               const T* filter, T* output,
    680                               TensorFormat data_format) {
    681   void (*kernel)(const DepthwiseArgs, const T*, const T*, T*, int);
    682   switch (data_format) {
    683     case FORMAT_NHWC:
    684       kernel =
    685           DepthwiseConv2dGPUKernelNHWC<T, kKnownFilterWidth, kKnownFilterHeight,
    686                                        kKnownDepthMultiplier>;
    687       break;
    688     case FORMAT_NCHW:
    689       kernel =
    690           DepthwiseConv2dGPUKernelNCHW<T, kKnownFilterWidth, kKnownFilterHeight,
    691                                        kKnownDepthMultiplier>;
    692       break;
    693     case FORMAT_NCHW_VECT_C:
    694       LOG(ERROR) << "FORMAT_NCHW_VECT_C is not supported";
    695       return;
    696   }
    697   const int num_outputs =
    698       args.batch * args.out_rows * args.out_cols * args.out_depth;
    699   CudaLaunchConfig config =
    700       GetCudaLaunchConfig(num_outputs, device, kernel, 0, 0);
    701   // The compile-time constant version runs faster with a single block.
    702   const int max_block_count = kKnownFilterWidth < 0 || kKnownFilterHeight < 0 ||
    703                                       kKnownDepthMultiplier < 0
    704                                   ? std::numeric_limits<int>::max()
    705                                   : device.getNumCudaMultiProcessors();
    706   kernel<<<std::min(max_block_count, config.block_count),
    707            config.thread_per_block, 0, device.stream()>>>(args, input, filter,
    708                                                           output, num_outputs);
    709 }
    710 
    711 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
    712 void LaunchDepthwiseConv2dGPU(const GpuDevice& device,
    713                               const DepthwiseArgs& args, const T* input,
    714                               const T* filter, T* output,
    715                               TensorFormat data_format) {
    716   if (args.depth_multiplier == 1) {
    717     if (CanLaunchDepthwiseConv2dGPUSmall(args)) {
    718       LaunchDepthwiseConv2dGPUSmall<T, DIRECTION_FORWARD, kKnownFilterWidth,
    719                                     kKnownFilterHeight>(
    720           device, args, input, filter, output, data_format);
    721       return;
    722     }
    723 
    724     LaunchDepthwiseConv2dGPU<T, kKnownFilterWidth, kKnownFilterHeight, 1>(
    725         device, args, input, filter, output, data_format);
    726   } else {
    727     LaunchDepthwiseConv2dGPU<T, kKnownFilterWidth, kKnownFilterHeight, -1>(
    728         device, args, input, filter, output, data_format);
    729   }
    730 }
    731 
// A simple launch pad for the CUDA depthwise convolution forward kernel.
    733 template <typename T>
    734 void LaunchDepthwiseConvOp<GpuDevice, T>::operator()(OpKernelContext* ctx,
    735                                                      const DepthwiseArgs& args,
    736                                                      const T* input,
    737                                                      const T* filter, T* output,
    738                                                      TensorFormat data_format) {
    739   const GpuDevice& device = ctx->eigen_device<GpuDevice>();
    740   if (args.filter_rows == 3 && args.filter_cols == 3) {
    741     LaunchDepthwiseConv2dGPU<T, 3, 3>(device, args, input, filter, output,
    742                                       data_format);
    743   } else {
    744     LaunchDepthwiseConv2dGPU<T, -1, -1>(device, args, input, filter, output,
    745                                         data_format);
    746   }
    747   auto stream = ctx->op_device_context()->stream();
    748   OP_REQUIRES(ctx, stream->ok(),
    749               errors::Internal(
    750                   "Launch of gpu kernel for DepthwiseConv2dGPULaunch failed"));
    751 }
    752 
    753 template struct LaunchDepthwiseConvOp<GpuDevice, Eigen::half>;
    754 template struct LaunchDepthwiseConvOp<GpuDevice, float>;
    755 template struct LaunchDepthwiseConvOp<GpuDevice, double>;
    756 
    757 // A Cuda kernel to compute the depthwise convolution backprop w.r.t. input.
    758 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
    759           int kKnownDepthMultiplier>
    760 __global__ void __launch_bounds__(640, 2)
    761     DepthwiseConv2dBackpropInputGPUKernelNHWC(const DepthwiseArgs args,
    762                                               const T* out_backprop,
    763                                               const T* filter, T* in_backprop,
    764                                               int num_in_backprop) {
    765   const int in_height = args.in_rows;
    766   const int in_width = args.in_cols;
    767   const int in_depth = args.in_depth;
    768   const int filter_height =
    769       kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
    770   const int filter_width =
    771       kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
    772   const int depth_multiplier =
    773       kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier;
    774   const int stride = args.stride;
    775   const int pad_height = args.pad_rows;
    776   const int pad_width = args.pad_cols;
    777   const int out_height = args.out_rows;
    778   const int out_width = args.out_cols;
    779   const int out_depth = args.out_depth;
    780 
    781   CUDA_1D_KERNEL_LOOP(thread_id, num_in_backprop) {
    // Compute the indexes of this thread in the input.
    783     const int in_channel = thread_id % in_depth;
    784     const int in_col = (thread_id / in_depth) % in_width;
    785     const int in_row = (thread_id / in_depth / in_width) % in_height;
    786     const int batch = thread_id / in_depth / in_width / in_height;
    787 
    788     T sum = static_cast<T>(0);
    789 
    790     const int out_row_start =
    791         tf_max<int>(0, (in_row - filter_height + pad_height + stride) / stride);
    792     const int out_row_end =
    793         tf_min(out_height - 1, (in_row + pad_height) / stride);
    794     const int out_col_start =
    795         tf_max(0, (in_col - filter_width + pad_width + stride) / stride);
    796     const int out_col_end =
    797         tf_min(out_width - 1, (in_col + pad_width) / stride);
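    // [out_row_start, out_row_end] x [out_col_start, out_col_end] is the range
    // of output positions whose filter window covers (in_row, in_col).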
    798 
    799     NOUNROLL for (int out_row = out_row_start; out_row <= out_row_end;
    800                   ++out_row) {
    801       const int filter_row = in_row + pad_height - out_row * stride;
    802       const int temp_out_backprop_offset =
    803           out_depth * out_width * (out_row + out_height * batch);
    804       const int temp_filter_offset = filter_width * filter_row;
    805       NOUNROLL for (int out_col = out_col_start; out_col <= out_col_end;
    806                     ++out_col) {
    807         const int filter_col = in_col + pad_width - out_col * stride;
    808         int filter_offset =
    809             depth_multiplier *
    810             (in_channel + in_depth * (filter_col + temp_filter_offset));
    811         const int out_backprop_offset =
    812             out_depth * out_col + temp_out_backprop_offset;
    813 #pragma unroll 6
    814         for (int i = 0; i < depth_multiplier; ++i) {
    815           sum += ldg(out_backprop + out_backprop_offset +
    816                      in_channel * depth_multiplier + i) *
    817                  ldg(filter + filter_offset + i);
    818         }
    819       }
    820     }
    821     const int in_backprop_offset =
    822         in_channel +
    823         in_depth * (in_col + in_width * (in_row + in_height * batch));
    824     in_backprop[in_backprop_offset] = sum;
    825   }
    826 }
    827 
    828 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
    829           int kKnownDepthMultiplier>
    830 __global__ void __launch_bounds__(640, 2)
    831     DepthwiseConv2dBackpropInputGPUKernelNCHW(const DepthwiseArgs args,
    832                                               const T* out_backprop,
    833                                               const T* filter, T* in_backprop,
    834                                               int num_in_backprop) {
    835   const int in_height = args.in_rows;
    836   const int in_width = args.in_cols;
    837   const int in_depth = args.in_depth;
    838   const int filter_height =
    839       kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
    840   const int filter_width =
    841       kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
    842   const int depth_multiplier =
    843       kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier;
    844   const int stride = args.stride;
    845   const int pad_height = args.pad_rows;
    846   const int pad_width = args.pad_cols;
    847   const int out_height = args.out_rows;
    848   const int out_width = args.out_cols;
    849   const int out_depth = args.out_depth;
    850 
    851   // TODO(vrv): Consider assigning threads to output and using
    852   // atomics for accumulation, similar to the filter case.
    853   CUDA_1D_KERNEL_LOOP(thread_id, num_in_backprop) {
    854     // Compute the indexes of this thread in the input.
    855     const int in_col = thread_id % in_width;
    856     const int in_row = (thread_id / in_width) % in_height;
    857     const int in_channel = (thread_id / in_width / in_height) % in_depth;
    858     const int batch = thread_id / in_depth / in_width / in_height;
    859 
    860     T sum = static_cast<T>(0);
    861     const int out_channel_start = in_channel * depth_multiplier;
    862     const int out_channel_end = out_channel_start + depth_multiplier;
    863 
    864     const int out_row_start =
    865         tf_max<int>(0, (in_row - filter_height + pad_height + stride) / stride);
    866     const int out_row_end =
    867         tf_min(out_height - 1, (in_row + pad_height) / stride);
    868     const int out_col_start =
    869         tf_max(0, (in_col - filter_width + pad_width + stride) / stride);
    870     const int out_col_end =
    871         tf_min(out_width - 1, (in_col + pad_width) / stride);
    872 
    873     UNROLL for (int out_channel = out_channel_start;
    874                 out_channel < out_channel_end; ++out_channel) {
    875       UNROLL for (int out_row = out_row_start; out_row <= out_row_end;
    876                   ++out_row) {
    877         const int filter_row = in_row + pad_height - out_row * stride;
    878         const int filter_dm = out_channel - out_channel_start;
    879 
    880         const int temp_filter_offset = filter_width * filter_row;
    881         for (int out_col = out_col_start; out_col <= out_col_end; ++out_col) {
    882           const int filter_col = in_col + pad_width - out_col * stride;
    883           const int filter_offset =
    884               filter_dm +
    885               args.depth_multiplier *
    886                   (in_channel + in_depth * (filter_col + temp_filter_offset));
    887 
    888           const int out_backprop_offset =
    889               (batch * out_depth * out_height * out_width) +
    890               (out_channel * out_height * out_width) + (out_row * out_width) +
    891               (out_col);
    892 
    893           sum += ldg(out_backprop + out_backprop_offset) *
    894                  ldg(filter + filter_offset);
    895         }
    896       }
    897     }
    898     const int in_backprop_offset = (batch * in_height * in_width * in_depth) +
    899                                    (in_channel * in_height * in_width) +
    900                                    (in_row * in_width) + (in_col);
    901     in_backprop[in_backprop_offset] = sum;
    902   }
    903 }
    904 
    905 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
    906           int kKnownDepthMultiplier>
    907 void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& device,
    908                                            const DepthwiseArgs& args,
    909                                            const T* out_backprop,
    910                                            const T* filter, T* in_backprop,
    911                                            TensorFormat data_format) {
    912   void (*kernel)(const DepthwiseArgs, const T*, const T*, T*, int);
    913   switch (data_format) {
    914     case FORMAT_NHWC:
    915       kernel = DepthwiseConv2dBackpropInputGPUKernelNHWC<
    916           T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>;
    917       break;
    918     case FORMAT_NCHW:
    919       kernel = DepthwiseConv2dBackpropInputGPUKernelNCHW<
    920           T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>;
    921       break;
    922     case FORMAT_NCHW_VECT_C:
    923       LOG(ERROR) << "FORMAT_NCHW_VECT_C is not supported";
    924       return;
    925   }
    926   const int num_in_backprop =
    927       args.batch * args.in_rows * args.in_cols * args.in_depth;
    928   CudaLaunchConfig config =
    929       GetCudaLaunchConfig(num_in_backprop, device, kernel, 0, 0);
    930   kernel<<<config.block_count, config.thread_per_block, 0, device.stream()>>>(
    931       args, out_backprop, filter, in_backprop, num_in_backprop);
    932 }
    933 
    934 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
    935 void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& device,
    936                                            const DepthwiseArgs& args,
    937                                            const T* out_backprop,
    938                                            const T* filter, T* in_backprop,
    939                                            TensorFormat data_format) {
    940   if (args.depth_multiplier == 1) {
    941     if (CanLaunchDepthwiseConv2dGPUSmall(args)) {
    942       LaunchDepthwiseConv2dGPUSmall<T, DIRECTION_BACKWARD, kKnownFilterWidth,
    943                                     kKnownFilterHeight>(
    944           device, args, out_backprop, filter, in_backprop, data_format);
    945       return;
    946     }
    947 
    948     LaunchDepthwiseConv2dBackpropInputGPU<T, kKnownFilterWidth,
    949                                           kKnownFilterHeight, 1>(
    950         device, args, out_backprop, filter, in_backprop, data_format);
    951   } else {
    952     LaunchDepthwiseConv2dBackpropInputGPU<T, kKnownFilterWidth,
    953                                           kKnownFilterHeight, -1>(
    954         device, args, out_backprop, filter, in_backprop, data_format);
    955   }
    956 }
    957 
// A simple launch pad for the CUDA kernels that compute the depthwise
// convolution backprop w.r.t. the input.
    959 template <typename T>
    960 void LaunchDepthwiseConvBackpropInputOp<GpuDevice, T>::operator()(
    961     OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop,
    962     const T* filter, T* in_backprop, TensorFormat data_format) {
    963   const GpuDevice& device = ctx->eigen_device<GpuDevice>();
    964   if (args.filter_rows == 3 && args.filter_cols == 3) {
    965     LaunchDepthwiseConv2dBackpropInputGPU<T, 3, 3>(
    966         device, args, out_backprop, filter, in_backprop, data_format);
    967   } else {
    968     LaunchDepthwiseConv2dBackpropInputGPU<T, -1, -1>(
    969         device, args, out_backprop, filter, in_backprop, data_format);
    970   }
    971   auto stream = ctx->op_device_context()->stream();
    972   OP_REQUIRES(ctx, stream->ok(),
    973               errors::Internal("Launch of gpu kernel for "
    974                                "DepthwiseConv2dBackpropInp"
    975                                "utGPULaunch failed"));
    976 }
    977 
    978 template struct LaunchDepthwiseConvBackpropInputOp<GpuDevice, Eigen::half>;
    979 template struct LaunchDepthwiseConvBackpropInputOp<GpuDevice, float>;
    980 template struct LaunchDepthwiseConvBackpropInputOp<GpuDevice, double>;
    981 
    982 // A Cuda kernel to compute the depthwise convolution backprop w.r.t. filter.
    983 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
    984           int kKnownDepthMultiplier>
    985 __global__ void __launch_bounds__(640, 2)
    986     DepthwiseConv2dBackpropFilterGPUKernelNHWC(const DepthwiseArgs args,
    987                                                const T* out_backprop,
    988                                                const T* input,
    989                                                T* filter_backprop,
    990                                                int num_out_backprop) {
    991   const int in_height = args.in_rows;
    992   const int in_width = args.in_cols;
    993   const int in_depth = args.in_depth;
    994   const int filter_height =
    995       kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
    996   const int filter_width =
    997       kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
    998   const int depth_multiplier =
    999       kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier;
   1000   const int stride = args.stride;
   1001   const int pad_height = args.pad_rows;
   1002   const int pad_width = args.pad_cols;
   1003   const int out_height = args.out_rows;
   1004   const int out_width = args.out_cols;
   1005   const int out_depth = args.out_depth;
   1006 
   1007   CUDA_1D_KERNEL_LOOP(thread_id, num_out_backprop) {
   1008     // Compute the indexes of this thread in the output.
   1009     const int out_channel = thread_id % out_depth;
   1010     const int out_col = (thread_id / out_depth) % out_width;
   1011     const int out_row = (thread_id / out_depth / out_width) % out_height;
   1012     const int batch = thread_id / out_depth / out_width / out_height;
   1013     // Compute the input depth and the index of depth multiplier.
   1014     const int in_channel = out_channel / depth_multiplier;
   1015     const int dm = out_channel % depth_multiplier;
   1016 
    // Decide if all of the input is valid; if so, we can skip the boundary
    // checks for each input access.
   1019     const int in_row_start = out_row * stride - pad_height;
   1020     const int in_col_start = out_col * stride - pad_width;
   1021     const int in_row_end = in_row_start + filter_height;
   1022     const int in_col_end = in_col_start + filter_width;
   1023 
   1024     const int out_backprop_offset =
   1025         out_channel +
   1026         out_depth * (out_col + out_width * (out_row + out_height * batch));
   1027     const T out_bp = ldg(out_backprop + out_backprop_offset);
   1028     if (in_row_start >= 0 && in_col_start >= 0 && in_row_end < in_height &&
   1029         in_col_end < in_width) {
   1030       UNROLL for (int filter_row = 0; filter_row < filter_height;
   1031                   ++filter_row) {
   1032         const int in_row = in_row_start + filter_row;
   1033         // Avoid repeated computation.
   1034         const int input_offset_temp = in_width * (in_row + in_height * batch);
   1035         UNROLL for (int filter_col = 0; filter_col < filter_width;
   1036                     ++filter_col) {
   1037           const int in_col = in_col_start + filter_col;
   1038 
   1039           const int input_offset =
   1040               in_channel + in_depth * (in_col + input_offset_temp);
   1041           T partial_sum = ldg(input + input_offset) * out_bp;
   1042           T* addr =
   1043               filter_backprop +
   1044               (dm + depth_multiplier *
   1045                         (in_channel +
   1046                          in_depth * (filter_col + filter_width * filter_row)));
   1047           CudaAtomicAdd(addr, partial_sum);
   1048         }
   1049       }
   1050     } else {
   1051       UNROLL for (int filter_row = 0; filter_row < filter_height;
   1052                   ++filter_row) {
   1053         const int in_row = in_row_start + filter_row;
   1054         // Avoid repeated computation.
   1055         const int input_offset_temp = in_width * (in_row + in_height * batch);
   1056         UNROLL for (int filter_col = 0; filter_col < filter_width;
   1057                     ++filter_col) {
   1058           const int in_col = in_col_start + filter_col;
   1059           const int addr_temp = filter_width * filter_row;
   1060 
   1061           if (in_row >= 0 && in_row < in_height && in_col >= 0 &&
   1062               in_col < in_width) {
   1063             const int input_offset =
   1064                 in_channel + in_depth * (in_col + input_offset_temp);
   1065             T partial_sum = ldg(input + input_offset) * out_bp;
   1066             T* addr =
   1067                 filter_backprop +
   1068                 (dm + depth_multiplier *
   1069                           (in_channel + in_depth * (filter_col + addr_temp)));
   1070             // Potentially many threads can add to the same address so we have
   1071             // to use atomic add here.
   1072             // TODO(jmchen): If atomic add turns out to be slow, we can:
   1073             // 1. allocate multiple buffers for the gradients (one for each
   1074             // example in a batch, for example). This can reduce the
   1075             // contention on the destination; 2. Have each thread compute one
   1076             // gradient for an element in the filters. This should work well
   1077             // when the input depth is big and filter size is not too small.
   1078             CudaAtomicAdd(addr, partial_sum);
   1079           }
   1080         }
   1081       }
   1082     }
   1083   }
   1084 }
   1085 
   1086 // Device function to compute sub-warp sum reduction for a power-of-two group of
   1087 // neighboring threads.
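// Illustrative: WarpSumReduce<8>(val) sums 'val' across each aligned group of
// 8 consecutive lanes; after the butterfly exchange, every lane in the group
// holds the group's total.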
   1088 template <int kWidth, typename T>
   1089 __device__ __forceinline__ T WarpSumReduce(T val) {
  // Supports only power-of-two widths.
   1091   assert(__popc(kWidth) == 1);
   1092   int sub_warp = cub::LaneId() / kWidth;
   1093   int zeros = sub_warp * kWidth;
   1094   unsigned mask = ((1UL << kWidth) - 1) << zeros;
   1095   for (int delta = kWidth / 2; delta > 0; delta /= 2) {
   1096     val += CudaShuffleXorSync(mask, val, delta);
   1097   }
   1098   return val;
   1099 }
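        // Example (illustrative): with kWidth == 8, lane 11 is in sub_warp 1, so
        // mask == 0x0000ff00 and the xor deltas 4, 2, 1 exchange values with lanes
        // 15, 9 and 10; after the loop, every lane in [8, 15] holds the sum of the
        // original values of lanes 8..15.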
   1100 
   1101 // CUDA kernel to compute the depthwise convolution backward w.r.t. filter in
   1102 // NHWC format, tailored for small images up to 32x32. Stride and depth
   1103 // multiplier must be 1. Padding must be 'SAME'. Only use this kernel if
   1104 // CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, block_height) returns true.
   1105 // Tiles of the input tensor are loaded into shared memory before performing the
   1106 // convolution. Per iteration and filter element, each thread first performs
   1107 // a partial convolution for two elements, one each in the lower and upper half
   1108 // of a tile. The intermediate results of all pixels in a warp are then
   1109 // accumulated and written to shared memory. Finally, the values in shared
   1110 // memory are warp-accumulated (in chunks of kAccumPixels elements) and summed
   1111 // up in global memory using atomics.
   1112 // Requirements: threads per block must be a multiple of 32 and <= launch_bounds,
   1113 // kAccumPixels * 64 >= args.in_rows * args.in_cols * kBlockDepth.
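        // Shared memory layout (a sketch derived from the index math below):
        //   [0, tile_size): the input tile, tile_height x tile_width pixels with
        //     kBlockDepth values per pixel (depth minor), zero-initialized so the
        //     'SAME' padding needs no special casing;
        //   [tile_size, tile_size + accum_size): per-warp partial sums,
        //     kAccumPixels values per (filter pixel, depth) pair.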
   1114 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
   1115           int kBlockDepth, int kAccumPixels>
   1116 __global__
   1117 __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
   1118     const DepthwiseArgs args, const T* output, const T* input, T* filter) {
   1119   assert(CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, blockDim.z));
   1120   // Holds block plus halo and filter data for blockDim.x depths.
   1121   extern __shared__ __align__(sizeof(T)) unsigned char shared_memory[];
   1122   T* const shared_data = reinterpret_cast<T*>(shared_memory);
   1123 
   1124   const int num_batches = args.batch;
   1125   const int in_height = args.in_rows;
   1126   const int in_width = blockDim.y;  // slower (see b/62280718): args.in_cols;
   1127   const int in_depth = args.in_depth;
   1128   const int filter_height =
   1129       kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
   1130   const int filter_width =
   1131       kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
   1132   const int pad_height = args.pad_rows;
   1133   const int pad_width = args.pad_cols;
   1134 
   1135   assert(blockDim.x == kBlockDepth);
   1136   assert(blockDim.y == args.in_cols);
   1137   const int block_height = blockDim.z;
   1138 
   1139   // These values are the same for all threads and could
   1140   // be precomputed on the CPU.
   1141   const int block_size = block_height * in_width * kBlockDepth;
   1142   assert((block_size & 31) == 0);
   1143   const int in_row_size = in_width * in_depth;
   1144   const int in_size = in_height * in_row_size;
   1145   const int in_increment = (in_width - 1) * kBlockDepth;
   1146   const int filter_pixels = filter_height * filter_width;
   1147   const int tile_width = in_width + filter_width - 1;
   1148   const int tile_height = 2 * block_height + filter_height - 1;
   1149   const int tile_row_size = tile_width * kBlockDepth;
   1150   const int tile_size = tile_height * tile_row_size;
   1151   const int tile_offset = block_height * tile_row_size;
   1152   const int pad_offset = pad_height * tile_width + pad_width;
   1153   const int batch_blocks = (in_depth + kBlockDepth - 1) / kBlockDepth;
   1154   const int in_blocks = batch_blocks * num_batches;
   1155   const int tensor_offset = block_height * in_row_size;
   1156   // The accumulator has a fixed number of pixels that can be reduced by one
   1157   // warp. Pixels beyond ceil(in_pixels * kBlockDepth / 64) are never written.
   1158   assert(kAccumPixels * 64 >= in_height * in_width * kBlockDepth);
   1159   const int accum_increment = kAccumPixels * kBlockDepth;
   1160   const int accum_size = filter_pixels * accum_increment;
   1161 
   1162   const int thread_depth = threadIdx.x;
   1163   const int thread_col = threadIdx.y;
   1164   const int thread_row = threadIdx.z;
   1165 
   1166   // Position in block.
   1167   const int thread_pix = thread_row * in_width + thread_col;
   1168   const int thread_idx = thread_pix * kBlockDepth + thread_depth;
   1169 
   1170   // Initialize tile, in particular the padding and accumulator.
   1171   for (int i = thread_idx; i < tile_size + accum_size; i += block_size) {
   1172     shared_data[i] = T(0);
   1173   }
   1174   __syncthreads();
   1175 
   1176   // Position in tensors.
   1177   const int tensor_idx = thread_pix * in_depth + thread_depth;
   1178 
   1179   // Position in (padded) shared memory.
   1180   const int data_pix = thread_row * tile_width + thread_col;
   1181   const int data_idx = data_pix * kBlockDepth + thread_depth;
   1182 
   1183   // Position in shared memory, offset by pad_height / pad_width.
   1184   const int tile_pix = data_pix + pad_offset;
   1185   const int tile_idx = tile_pix * kBlockDepth + thread_depth;
   1186 
   1187   // Position in accumulator (kBlockDepth per warp, depth major).
   1188   const int accum_pix = thread_pix / (32 / kBlockDepth);
   1189   const int accum_idx = thread_depth * kAccumPixels + accum_pix;
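          // Each warp covers 32 / kBlockDepth consecutive pixels at kBlockDepth
          // depths; accum_pix is the warp's slot and accum_idx is depth-major, so
          // the kAccumPixels partial sums of one depth are contiguous for the
          // WarpSumReduce call further below.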
   1190 
   1191   const int max_channel = in_depth - thread_depth;
   1192   const int accum_offset = tile_size + accum_idx;
   1193   const bool skip_second = block_height + thread_row >= in_height;
   1194 
   1195   for (int b = blockIdx.x; b < in_blocks; b += gridDim.x) {
   1196     const int batch = b / batch_blocks;
   1197     const int block = b - batch * batch_blocks;
   1198 
   1199     const int start_channel = block * kBlockDepth;
   1200     const int filter_offset = tensor_idx + start_channel;
   1201     const int inout_offset = batch * in_size + filter_offset;
   1202     const bool channel_in_range = start_channel < max_channel;
   1203 
   1204     if (channel_in_range) {
   1205       const T* const in_ptr = inout_offset + input;
   1206       T* const tile_ptr = tile_idx + shared_data;
   1207       tile_ptr[0] = ldg(in_ptr);
   1208       if (!skip_second) {
   1209         tile_ptr[tile_offset] = ldg(tensor_offset + in_ptr);
   1210       }
   1211     }
   1212 
   1213     // Note: the condition to reach this is uniform across the entire block.
   1214     __syncthreads();
   1215     unsigned active_threads = CudaBallotSync(kCudaWarpAll, channel_in_range);
   1216 
   1217     if (channel_in_range) {
   1218       const T* const out_ptr = inout_offset + output;
   1219       const T out1 = ldg(out_ptr);
   1220       const T out2 = skip_second ? T(0) : ldg(tensor_offset + out_ptr);
   1221       int shared_offset = data_idx;
   1222       T* accum_ptr = accum_offset + shared_data;
   1223       UNROLL for (int r = 0; r < filter_height; ++r) {
   1224         UNROLL for (int c = 0; c < filter_width; ++c) {
   1225           const T* const tile_ptr = shared_offset + shared_data;
   1226           T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
   1227           // Warp-accumulate pixels of the same depth and write to accumulator.
   1228           for (int delta = 16; delta >= kBlockDepth; delta /= 2) {
   1229             val += CudaShuffleXorSync(active_threads, val, delta);
   1230           }
   1231           if (!(thread_idx & (32 - kBlockDepth)) /* lane_idx < kBlockDepth */) {
   1232             *accum_ptr = val;
   1233           }
   1234           shared_offset += kBlockDepth;
   1235           accum_ptr += accum_increment;
   1236         }
   1237         shared_offset += in_increment;
   1238       }
   1239     }
   1240 
   1241     // Note: the condition to reach this is uniform across the entire block.
   1242     __syncthreads();
   1243 
   1244     const T* const accum_data = tile_size + shared_data;
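            // Drain the accumulator: each kAccumPixels-sized group of indices i
            // holds the partial sums of one (filter pixel, depth) pair, which
            // WarpSumReduce collapses; a single lane per group then adds the result
            // into the filter gradient at [filter_pix, filter_channel]
            // (depth_multiplier == 1 for this kernel).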
   1245     for (int i = thread_idx; i < accum_size; i += block_size) {
   1246       const int filter_idx = i / kAccumPixels;
   1247       const int filter_pix = filter_idx / kBlockDepth;
   1248       const int filter_channel = filter_idx % kBlockDepth + start_channel;
   1249       const int filter_offset = filter_pix * in_depth + filter_channel;
   1250       if (filter_channel < in_depth) {
   1251         T val = accum_data[i];
   1252         // Warp-accumulate the pixels of the same depth from the accumulator.
   1253         val = WarpSumReduce<kAccumPixels>(val);
   1254         if (!(thread_idx & (kAccumPixels - 1))) {
   1255           CudaAtomicAdd(filter_offset + filter, val);
   1256         }
   1257       }
   1258     }
   1259   }
   1260 }
   1261 
   1262 // A CUDA kernel to compute the NCHW depthwise convolution backprop w.r.t. filter.
   1263 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
   1264           int kKnownDepthMultiplier>
   1265 __global__ void __launch_bounds__(640, 2)
   1266     DepthwiseConv2dBackpropFilterGPUKernelNCHW(const DepthwiseArgs args,
   1267                                                const T* out_backprop,
   1268                                                const T* input,
   1269                                                T* filter_backprop,
   1270                                                int num_out_backprop) {
   1271   const int in_height = args.in_rows;
   1272   const int in_width = args.in_cols;
   1273   const int in_depth = args.in_depth;
   1274   const int filter_height =
   1275       kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
   1276   const int filter_width =
   1277       kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
   1278   const int depth_multiplier =
   1279       kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier;
   1280   const int stride = args.stride;
   1281   const int pad_height = args.pad_rows;
   1282   const int pad_width = args.pad_cols;
   1283   const int out_height = args.out_rows;
   1284   const int out_width = args.out_cols;
   1285   const int out_depth = args.out_depth;
   1286 
   1287   CUDA_1D_KERNEL_LOOP(thread_id, num_out_backprop) {
   1288     // Compute the indexes of this thread in the output.
   1289     const int out_col = thread_id % out_width;
   1290     const int out_row = (thread_id / out_width) % out_height;
   1291     const int out_channel = (thread_id / out_width / out_height) % out_depth;
   1292 
   1293     const int batch = thread_id / out_depth / out_width / out_height;
   1294     // Compute the input depth and the index of depth multiplier.
   1295     const int in_channel = out_channel / depth_multiplier;
   1296     const int dm = out_channel % depth_multiplier;
   1297 
   1298     // Decide whether all of the input is in bounds; if so, we can skip the
   1299     // boundary checks for each input element.
   1300     const int in_row_start = out_row * stride - pad_height;
   1301     const int in_col_start = out_col * stride - pad_width;
   1302     const int in_row_end = in_row_start + filter_height;
   1303     const int in_col_end = in_col_start + filter_width;
   1304 
   1305     const int out_backprop_offset =
   1306         (batch * out_depth * out_height * out_width) +
   1307         (out_channel * out_height * out_width) + (out_row * out_width) +
   1308         (out_col);
   1309 
   1310     const T out_bp = ldg(out_backprop + out_backprop_offset);
   1311     if (in_row_start >= 0 && in_col_start >= 0 && in_row_end < in_height &&
   1312         in_col_end < in_width) {
   1313       UNROLL for (int filter_row = 0; filter_row < filter_height;
   1314                   ++filter_row) {
   1315         const int in_row = in_row_start + filter_row;
   1316         // Avoid repeated computation.
   1317         const int input_offset_temp =
   1318             (batch * in_depth * in_height * in_width) +
   1319             (in_channel * in_height * in_width) + (in_row * in_width);
   1320 
   1321         UNROLL for (int filter_col = 0; filter_col < filter_width;
   1322                     ++filter_col) {
   1323           const int in_col = in_col_start + filter_col;
   1324           const int input_offset = input_offset_temp + in_col;
   1325           T partial_sum = ldg(input + input_offset) * out_bp;
   1326           T* addr =
   1327               filter_backprop +
   1328               (dm + depth_multiplier *
   1329                         (in_channel +
   1330                          in_depth * (filter_col + filter_width * filter_row)));
   1331           CudaAtomicAdd(addr, partial_sum);
   1332         }
   1333       }
   1334     } else {
   1335       UNROLL for (int filter_row = 0; filter_row < filter_height;
   1336                   ++filter_row) {
   1337         const int in_row = in_row_start + filter_row;
   1338         // Avoid repeated computation.
   1339         const int input_offset_temp =
   1340             (batch * in_depth * in_height * in_width) +
   1341             (in_channel * in_height * in_width) + (in_row * in_width);
   1342         UNROLL for (int filter_col = 0; filter_col < filter_width;
   1343                     ++filter_col) {
   1344           const int in_col = in_col_start + filter_col;
   1345           const int addr_temp = filter_width * filter_row;
   1346 
   1347           if (in_row >= 0 && in_row < in_height && in_col >= 0 &&
   1348               in_col < in_width) {
   1349             const int input_offset = input_offset_temp + in_col;
   1350             T partial_sum = ldg(input + input_offset) * out_bp;
   1351             T* addr =
   1352                 filter_backprop +
   1353                 (dm + depth_multiplier *
   1354                           (in_channel + in_depth * (filter_col + addr_temp)));
   1355             // Potentially many threads can add to the same address so we have
   1356             // to use atomic add here.
   1357             // TODO(jmchen): If atomic add turns out to be slow, we can:
   1358             // 1. allocate multiple buffers for the gradients (one for each
   1359             // example in a batch, for example). This can reduce the
   1360             // contention on the destination; 2. Have each thread compute one
   1361             // gradient for an element in the filters. This should work well
   1362             // when the input depth is big and filter size is not too small.
   1363             CudaAtomicAdd(addr, partial_sum);
   1364           }
   1365         }
   1366       }
   1367     }
   1368   }
   1369 }
   1370 
   1371 // CUDA kernel to compute the depthwise convolution backward w.r.t. filter in
   1372 // NCHW format, tailored for small images up to 32x32. Stride and depth
   1373 // multiplier must be 1. Padding must be 'SAME'. Only use this kernel if
   1374 // CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, block_height) returns true.
   1375 // Tiles of the input tensor are loaded into shared memory before performing the
   1376 // convolution. Per iteration and filter element, each thread first performs
   1377 // a partial convolution for two elements, one each in the lower and upper half
   1378 // of a tile. The intermediate results of all pixels in a warp are then
   1379 // accumulated and written to shared memory. Finally, the values in shared
   1380 // memory are warp-accumulated (in chunks of kAccumPixels elements) and summed
   1381 // up in global memory using atomics.
   1382 // Requirements: threads per block must be a multiple of 32 and <= launch_bounds,
   1383 // kAccumPixels * 64 >= args.in_rows * args.in_cols * kBlockDepth.
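        // In this NCHW variant threadIdx.x indexes columns, threadIdx.y rows and
        // threadIdx.z depths (whereas the NHWC variant above puts depth in
        // threadIdx.x). The shared tile is stored as kBlockDepth planes of
        // tile_pixels values (see data_idx), and each group of 32 / kBlockDepth
        // consecutive lanes works on pixels of a single depth, so the shuffle
        // reduction below only needs xor strides of 16 / kBlockDepth down to 1.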
   1384 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
   1385           int kBlockDepth, int kAccumPixels>
   1386 __global__
   1387 __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
   1388     const DepthwiseArgs args, const T* output, const T* input, T* filter) {
   1389   assert(CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, blockDim.x));
   1390   // Holds block plus halo and filter data for blockDim.z depths.
   1391   extern __shared__ __align__(sizeof(T)) unsigned char shared_memory[];
   1392   T* const shared_data = reinterpret_cast<T*>(shared_memory);
   1393 
   1394   const int num_batches = args.batch;
   1395   const int in_height = args.in_rows;
   1396   const int in_width = blockDim.x;  // slower (see b/62280718): args.in_cols;
   1397   const int in_depth = args.in_depth;
   1398   const int filter_height =
   1399       kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
   1400   const int filter_width =
   1401       kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
   1402   const int pad_height = args.pad_rows;
   1403   const int pad_width = args.pad_cols;
   1404 
   1405   assert(blockDim.x == args.in_cols);
   1406   assert(blockDim.z == kBlockDepth);
   1407   const int block_height = blockDim.y;
   1408 
   1409   // These values are the same for all threads and could
   1410   // be precomputed on the CPU.
   1411   const int block_pixels = in_width * block_height;
   1412   const int block_size = block_pixels * kBlockDepth;
   1413   assert((block_size & 31) == 0);
   1414   const int in_pixels = in_width * in_height;
   1415   const int in_increment = in_width - 1;
   1416   const int filter_pixels = filter_height * filter_width;
   1417   const int tile_width = in_width + filter_width - 1;
   1418   const int tile_height = 2 * block_height + filter_height - 1;
   1419   const int tile_pixels = tile_width * tile_height;
   1420   const int tile_size = tile_pixels * kBlockDepth;
   1421   const int tile_offset = block_height * tile_width;
   1422   const int pad_offset = pad_height * tile_width + pad_width;
   1423   const int in_total_depth = in_depth * num_batches;
   1424   const int in_blocks = (in_total_depth + kBlockDepth - 1) / kBlockDepth;
   1425   // The accumulator has a fixed number of pixels that can be reduced by one
   1426   // warp. Pixels beyond ceil(in_pixels * kBlockDepth / 64) are never written.
   1427   assert(kAccumPixels * 64 >= in_height * in_width * kBlockDepth);
   1428   const int accum_increment = kAccumPixels * kBlockDepth;
   1429   const int accum_size = filter_pixels * accum_increment;
   1430 
   1431   const int thread_col = threadIdx.x;
   1432   const int thread_row = threadIdx.y;
   1433   const int thread_depth = threadIdx.z;
   1434 
   1435   // Position in block.
   1436   const int thread_pix = thread_row * in_width + thread_col;
   1437   const int thread_idx = thread_depth * block_pixels + thread_pix;
   1438 
   1439   // Initialize tile, in particular the padding and accumulator.
   1440   for (int i = thread_idx; i < tile_size + accum_size; i += block_size) {
   1441     shared_data[i] = T(0);
   1442   }
   1443   __syncthreads();
   1444 
   1445   // Position in tensors.
   1446   const int tensor_idx = thread_depth * in_pixels + thread_pix;
   1447 
   1448   // Position in (padded) shared memory.
   1449   const int data_pix = thread_row * tile_width + thread_col;
   1450   const int data_idx = thread_depth * tile_pixels + data_pix;
   1451 
   1452   // Position in shared memory, offset by pad_height / pad_width.
   1453   const int tile_idx = data_idx + pad_offset;
   1454 
   1455   // Position in accumulator (kBlockDepth per warp, depth major).
   1456   const int accum_pix = thread_pix / (32 / kBlockDepth);
   1457   const int accum_idx = thread_depth * kAccumPixels + accum_pix;
   1458 
   1459   const int max_channel = in_total_depth - thread_depth;
   1460   const int accum_offset = tile_size + accum_idx;
   1461   const bool skip_second = block_height + thread_row >= in_height;
   1462 
   1463   for (int b = blockIdx.x; b < in_blocks; b += gridDim.x) {
   1464     const int channel = b * kBlockDepth;
   1465 
   1466     const int inout_offset = channel * in_pixels + tensor_idx;
   1467     const bool channel_in_range = channel < max_channel;
   1468 
   1469     if (channel_in_range) {
   1470       const T* const in_ptr = inout_offset + input;
   1471       T* const tile_ptr = tile_idx + shared_data;
   1472       tile_ptr[0] = ldg(in_ptr);
   1473       if (!skip_second) {
   1474         tile_ptr[tile_offset] = ldg(block_pixels + in_ptr);
   1475       }
   1476     }
   1477 
   1478     // Note: the condition to reach this is uniform across the entire block.
   1479     __syncthreads();
   1480     unsigned active_threads = CudaBallotSync(kCudaWarpAll, channel_in_range);
   1481 
   1482     if (channel_in_range) {
   1483       const T* const out_ptr = inout_offset + output;
   1484       const T out1 = ldg(out_ptr);
   1485       const T out2 = skip_second ? T(0) : ldg(block_pixels + out_ptr);
   1486       int shared_offset = data_idx;
   1487       T* accum_ptr = accum_offset + shared_data;
   1488       UNROLL for (int r = 0; r < filter_height; ++r) {
   1489         UNROLL for (int c = 0; c < filter_width; ++c) {
   1490           const T* const tile_ptr = shared_offset + shared_data;
   1491           T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
   1492           // Warp-accumulate pixels of the same depth and write to accumulator.
   1493           for (int delta = 16 / kBlockDepth; delta > 0; delta /= 2) {
   1494             val += CudaShuffleXorSync(active_threads, val, delta);
   1495           }
   1496           if (!(thread_idx & (32 / kBlockDepth - 1))) {
   1497             *accum_ptr = val;  // kBlockDepth threads per warp.
   1498           }
   1499           ++shared_offset;
   1500           accum_ptr += accum_increment;
   1501         }
   1502         shared_offset += in_increment;
   1503       }
   1504     }
   1505 
   1506     // Note: the condition to reach this is uniform across the entire block.
   1507     __syncthreads();
   1508 
   1509     const T* const accum_data = tile_size + shared_data;
   1510     for (int i = thread_idx; i < accum_size; i += block_size) {
   1511       const int filter_idx = i / kAccumPixels;
   1512       const int filter_pix = filter_idx / kBlockDepth;
   1513       const int filter_channel =
   1514           (channel + filter_idx % kBlockDepth) % in_depth;
   1515       const int filter_offset = filter_pix * in_depth + filter_channel;
   1516       if (filter_channel < in_depth) {
   1517         T val = accum_data[i];
   1518         // Warp-accumulate pixels of the same depth from the accumulator.
   1519         val = WarpSumReduce<kAccumPixels>(val);
   1520         if (!(thread_idx & (kAccumPixels - 1))) {
   1521           CudaAtomicAdd(filter_offset + filter, val);
   1522         }
   1523       }
   1524     }
   1525   }
   1526 }
   1527 
   1528 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
   1529           int kBlockDepth, int kAccumPixels>
   1530 bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
   1531     const GpuDevice& device, const DepthwiseArgs& args, const int block_height,
   1532     const T* out_backprop, const T* input, T* filter_backprop,
   1533     TensorFormat data_format) {
   1534   const int tile_width = args.in_cols + args.filter_cols - 1;
   1535   const int tile_height = block_height * 2 + args.filter_rows - 1;
   1536   const int tile_pixels = tile_height * tile_width;
   1537   const int filter_pixels = args.filter_rows * args.filter_cols;
   1538   const int shared_memory_size =
   1539       kBlockDepth * (tile_pixels + filter_pixels * kAccumPixels) * sizeof(T);
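          // Illustrative sizing (example values, not from the surrounding code):
          // for float data, a 3x3 filter, args.in_cols == 32, block_height == 16,
          // kBlockDepth == 2 and kAccumPixels == 32, the tile is 34 x 34 pixels and
          // shared_memory_size == 2 * (34 * 34 + 9 * 32) * 4 == 11552 bytes.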
   1540   if (shared_memory_size > device.sharedMemPerBlock()) {
   1541     return false;
   1542   }
   1543 
   1544   dim3 block_dim;
   1545   int block_count;
   1546   void (*kernel)(const DepthwiseArgs, const T*, const T*, T*);
   1547   switch (data_format) {
   1548     case FORMAT_NHWC:
   1549       block_dim = dim3(kBlockDepth, args.in_cols, block_height);
   1550       block_count =
   1551           args.batch * DivUp(args.out_depth, kBlockDepth) * kBlockDepth;
   1552       kernel = DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall<
   1553           T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, kAccumPixels>;
   1554       break;
   1555     case FORMAT_NCHW:
   1556       block_dim = dim3(args.in_cols, block_height, kBlockDepth);
   1557       block_count =
   1558           DivUp(args.batch * args.out_depth, kBlockDepth) * kBlockDepth;
   1559       kernel = DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall<
   1560           T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, kAccumPixels>;
   1561       break;
   1562     case FORMAT_NCHW_VECT_C:
   1563       LOG(ERROR) << "FORMAT_NCHW_VECT_C is not supported";
   1564       return false;
   1565   }
   1566   const int num_out_backprop = args.out_rows * args.out_cols * block_count;
   1567   CudaLaunchConfig config = GetCudaLaunchConfigFixedBlockSize(
   1568       num_out_backprop, device, kernel, shared_memory_size,
   1569       block_dim.x * block_dim.y * block_dim.z);
   1570   kernel<<<config.block_count, block_dim, shared_memory_size,
   1571            device.stream()>>>(args, out_backprop, input, filter_backprop);
   1572   return true;
   1573 }
   1574 
   1575 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
   1576           int kBlockDepth>
   1577 bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
   1578     const GpuDevice& device, const DepthwiseArgs& args, const int block_height,
   1579     const T* out_backprop, const T* input, T* filter_backprop,
   1580     TensorFormat data_format) {
   1581   // Minimize (power of two) kAccumPixels, while satisfying
   1582   // kAccumPixels * 32 >= block_height * in_width * kBlockDepth.
   1583   const int block_pixels = block_height * args.in_cols * kBlockDepth;
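          // Example (illustrative): block_height == 16, args.in_cols == 32 and
          // kBlockDepth == 2 give block_pixels == 1024 > 512, so kAccumPixels == 32
          // and 32 * 32 == 1024 >= block_pixels, as required.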
   1584   if (block_pixels > 512) {
   1585     return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<
   1586         T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, 32>(
   1587         device, args, block_height, out_backprop, input, filter_backprop,
   1588         data_format);
   1589   } else if (block_pixels > 256) {
   1590     return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<
   1591         T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, 16>(
   1592         device, args, block_height, out_backprop, input, filter_backprop,
   1593         data_format);
   1594   } else {
   1595     return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<
   1596         T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, 8>(
   1597         device, args, block_height, out_backprop, input, filter_backprop,
   1598         data_format);
   1599   }
   1600 }
   1601 
   1602 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
   1603 bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
   1604     const GpuDevice& device, const DepthwiseArgs& args, const T* out_backprop,
   1605     const T* input, T* filter_backprop, TensorFormat data_format) {
   1606   // Maximize (power of two) block_depth while keeping a block within 1024
   1607   // threads (2 pixels per thread).
   1608   int block_depth = 8;
   1609   int block_height = (args.in_rows + 1) / 2;
   1610   int round_mask = 1;
   1611   for (; block_depth > 1; block_depth /= 2) {
   1612     // block_height * args.in_cols * block_depth must be a multiple of 32.
   1613     for (; block_height * args.in_cols * block_depth & 31;
   1614          round_mask = round_mask * 2 + 1) {
   1615       block_height = block_height + round_mask & ~round_mask;
   1616     }
   1617     int block_size = block_height * args.in_cols * block_depth;
   1618     if (block_size <= 1024) {
   1619       break;
   1620     }
   1621   }
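          // Illustrative walk-through: for args.in_rows == 17 and args.in_cols == 10,
          // block_height starts at (17 + 1) / 2 == 9; with block_depth == 8,
          // 9 * 10 * 8 == 720 is not a multiple of 32, so block_height is rounded up
          // to 10 (round_mask == 1), giving a block of 10 * 10 * 8 == 800 <= 1024
          // threads.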
   1622 
   1623   if (!CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, block_height)) {
   1624     return false;
   1625   }
   1626 
   1627   switch (block_depth) {
   1628     case 8:
   1629       return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<
   1630           T, kKnownFilterWidth, kKnownFilterHeight, 8>(
   1631           device, args, block_height, out_backprop, input, filter_backprop,
   1632           data_format);
   1633     case 4:
   1634       return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<
   1635           T, kKnownFilterWidth, kKnownFilterHeight, 4>(
   1636           device, args, block_height, out_backprop, input, filter_backprop,
   1637           data_format);
   1638     case 2:
   1639       return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<
   1640           T, kKnownFilterWidth, kKnownFilterHeight, 2>(
   1641           device, args, block_height, out_backprop, input, filter_backprop,
   1642           data_format);
   1643     default:
   1644       return false;
   1645   }
   1646 }
   1647 
   1648 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
   1649           int kKnownDepthMultiplier>
   1650 void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& device,
   1651                                             const DepthwiseArgs& args,
   1652                                             const T* out_backprop,
   1653                                             const T* input, T* filter_backprop,
   1654                                             TensorFormat data_format) {
   1655   void (*kernel)(const DepthwiseArgs, const T*, const T*, T*, int);
   1656   switch (data_format) {
   1657     case FORMAT_NHWC:
   1658       kernel = DepthwiseConv2dBackpropFilterGPUKernelNHWC<
   1659           T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>;
   1660       break;
   1661     case FORMAT_NCHW:
   1662       kernel = DepthwiseConv2dBackpropFilterGPUKernelNCHW<
   1663           T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>;
   1664       break;
   1665     case FORMAT_NCHW_VECT_C:
   1666       LOG(ERROR) << "FORMAT_NCHW_VECT_C is not supported";
   1667       return;
   1668   }
   1669   const int num_out_backprop =
   1670       args.batch * args.out_rows * args.out_cols * args.out_depth;
   1671   CudaLaunchConfig config =
   1672       GetCudaLaunchConfig(num_out_backprop, device, kernel, 0, 0);
   1673   kernel<<<config.block_count, config.thread_per_block, 0, device.stream()>>>(
   1674       args, out_backprop, input, filter_backprop, num_out_backprop);
   1675 }
   1676 
   1677 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
   1678 void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& device,
   1679                                             const DepthwiseArgs& args,
   1680                                             const T* out_backprop,
   1681                                             const T* input, T* filter_backprop,
   1682                                             TensorFormat data_format) {
   1683   if (args.depth_multiplier == 1) {
   1684     if (TryLaunchDepthwiseConv2dBackpropFilterGPUSmall<T, kKnownFilterWidth,
   1685                                                        kKnownFilterHeight>(
   1686             device, args, out_backprop, input, filter_backprop, data_format)) {
   1687       return;
   1688     }
   1689 
   1690     LaunchDepthwiseConv2dBackpropFilterGPU<T, kKnownFilterWidth,
   1691                                            kKnownFilterHeight, 1>(
   1692         device, args, out_backprop, input, filter_backprop, data_format);
   1693   } else {
   1694     LaunchDepthwiseConv2dBackpropFilterGPU<T, kKnownFilterWidth,
   1695                                            kKnownFilterHeight, -1>(
   1696         device, args, out_backprop, input, filter_backprop, data_format);
   1697   }
   1698 }
   1699 
   1700 // A simple launch pad for the CUDA depthwise convolution backward filter kernel.
   1701 template <typename T>
   1702 void LaunchDepthwiseConvBackpropFilterOp<GpuDevice, T>::operator()(
   1703     OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop,
   1704     const T* input, T* filter_backprop, TensorFormat data_format) {
   1705   const GpuDevice& device = ctx->eigen_device<GpuDevice>();
   1706   auto stream = ctx->op_device_context()->stream();
   1707 
   1708   // Initialize the results to 0.
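          // The kernels above only accumulate into filter_backprop via CudaAtomicAdd
          // and never overwrite it, so the buffer has to be cleared before launch.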
   1709   int num_filter_backprop =
   1710       args.filter_rows * args.filter_cols * args.out_depth;
   1711   perftools::gputools::DeviceMemoryBase filter_bp_ptr(filter_backprop,
   1712                                                       num_filter_backprop);
   1713   stream->ThenMemset32(&filter_bp_ptr, 0, num_filter_backprop * sizeof(T));
   1714 
   1715   if (args.filter_rows == 3 && args.filter_cols == 3) {
   1716     LaunchDepthwiseConv2dBackpropFilterGPU<T, 3, 3>(
   1717         device, args, out_backprop, input, filter_backprop, data_format);
   1718   } else {
   1719     LaunchDepthwiseConv2dBackpropFilterGPU<T, -1, -1>(
   1720         device, args, out_backprop, input, filter_backprop, data_format);
   1721   }
   1722   OP_REQUIRES(ctx, stream->ok(),
   1723               errors::Internal("Launch of gpu kernel for "
   1724                                "DepthwiseConv2dBackpropFil"
   1725                                "terGPULaunch failed"));
   1726 }
   1727 
   1728 template struct LaunchDepthwiseConvBackpropFilterOp<GpuDevice, Eigen::half>;
   1729 template struct LaunchDepthwiseConvBackpropFilterOp<GpuDevice, float>;
   1730 template struct LaunchDepthwiseConvBackpropFilterOp<GpuDevice, double>;
   1731 }  // namespace tensorflow
   1732 #endif  // GOOGLE_CUDA
   1733