      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #if GOOGLE_CUDA
     17 
     18 #define EIGEN_USE_GPU
     19 
     20 #include <algorithm>
     21 #include <array>
     22 #include <limits>
     23 #include <utility>
     24 
     25 #include "cuda/include/cuda.h"
     26 #include "tensorflow/core/framework/register_types.h"
     27 #include "tensorflow/core/kernels/conv_2d.h"
     28 #include "tensorflow/core/lib/math/math_util.h"
     29 #include "tensorflow/core/util/cuda_kernel_helper.h"
     30 #include "tensorflow/core/util/tensor_format.h"
     31 
     32 namespace tensorflow {
     33 
     34 typedef Eigen::GpuDevice GPUDevice;
     35 
     36 namespace functor {
     37 namespace {
     38 template <typename T, bool conjugate>
     39 struct maybe_conj {
     40   __device__ static __inline__ T run(T x) {
     41     if (conjugate) {
     42       return Eigen::numext::conj(x);
     43     } else {
     44       return x;
     45     }
     46   }
     47 };
     48 
     49 // Partial specializations for Cuda types used to store complex numbers.
     50 template <bool conjugate>
     51 struct maybe_conj<float2, conjugate> {
     52   __device__ static __inline__ float2 run(float2 c) {
     53     if (conjugate) {
     54       float2 c_conj;
     55       c_conj.x = c.x;
     56       c_conj.y = -c.y;
     57       return c_conj;
     58     } else {
     59       return c;
     60     }
     61   }
     62 };
     63 
     64 template <bool conjugate>
     65 struct maybe_conj<double2, conjugate> {
     66   __device__ static __inline__ double2 run(double2 c) {
     67     if (conjugate) {
     68       double2 c_conj;
     69       c_conj.x = c.x;
     70       c_conj.y = -c.y;
     71       return c_conj;
     72     } else {
     73       return c;
     74     }
     75   }
     76 };
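        // Illustrative example (added comment, not in the original source):
        //   float2 z = make_float2(1.f, 2.f);
        //   maybe_conj<float2, /*conjugate=*/true>::run(z);   // yields {1.f, -2.f}
        //   maybe_conj<float2, /*conjugate=*/false>::run(z);  // passes z through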
     77 
     78 }  // namespace
     79 
     80 // TODO(mjanusz): Move this to a shared util file.
     81 // A simple array that contains data that can be passed between CPU and GPU.
     82 template <typename T, int IndexCount, T DefaultValue>
     83 struct Array {
     84   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T& operator[](int index) const {
     85     return data[index];
     86   }
     87   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T& operator[](int index) {
     88     return data[index];
     89   }
     90   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array() {
     91     for (int i = 0; i < IndexCount; i++) {
     92       data[i] = DefaultValue;
     93     }
     94   }
     95   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0) {
     96     data[0] = a0;
     97     for (int i = 1; i < IndexCount; i++) {
     98       data[i] = DefaultValue;
     99     }
    100   }
    101   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0, T a1) {
    102     data[0] = a0;
    103     data[1] = a1;
    104     for (int i = 2; i < IndexCount; i++) {
    105       data[i] = DefaultValue;
    106     }
    107   }
    108   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0, T a1, T a2) {
    109     data[0] = a0;
    110     data[1] = a1;
    111     data[2] = a2;
    112     for (int i = 3; i < IndexCount; i++) {
    113       data[i] = DefaultValue;
    114     }
    115   }
    116   EIGEN_STRONG_INLINE Array(const std::array<T, IndexCount>& array) {
    117     for (int i = 0; i < IndexCount; i++) {
    118       data[i] = array[i];
    119     }
    120   }
    121   T data[IndexCount];
    122 };
    123 
    124 // A dimension type with compile-time known size.
    125 template <int IndexCount>
    126 struct Dimension : Array<int, IndexCount, 1> {
    127   typedef Array<int, IndexCount, 1> Base;
    128   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension() : Base() {}
    129   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0) : Base(a0) {}
    130   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0, int a1)
    131       : Base(a0, a1) {}
    132   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0, int a1, int a2)
    133       : Base(a0, a1, a2) {}
    134   EIGEN_STRONG_INLINE Dimension(const std::array<int, IndexCount>& array)
    135       : Base(array) {}
    136 };
    137 
    138 // An index type with compile-time known size.
    139 template <int IndexCount>
    140 struct Index : Array<int, IndexCount, 0> {
    141   typedef Array<int, IndexCount, 0> Base;
    142   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index() : Base() {}
    143   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0) : Base(a0) {}
    144   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0, int a1) : Base(a0, a1) {}
    145   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0, int a1, int a2)
    146       : Base(a0, a1, a2) {}
    147 };
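        // Illustrative example (added comment): Dimension<3> dims(4, 8, 16) describes
        // a 4 x 8 x 16 tensor, and Index<3> idx(1, 2, 3) addresses element [1][2][3].
        // Trailing entries that are not specified default to 1 for Dimension and 0
        // for Index, per the DefaultValue template argument of Array.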
    148 
    149 // A helper function that converts a tensor index into a flat array index.
    150 template <int IndexCount>
    151 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int TensorIndexToFlat(
    152     const Index<IndexCount>& index, const Dimension<IndexCount>& dims) {
    153   int flat_index = index[0];
    154   for (int i = 1; i < IndexCount; i++) {
    155     flat_index = flat_index * dims[i] + index[i];
    156   }
    157   return flat_index;
    158 }
    159 
    160 // A helper function that converts a flat array index into a tensor index.
    161 template <int IndexCount>
    162 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index<IndexCount> FlatToTensorIndex(
    163     int index, const Dimension<IndexCount>& dims) {
    164   Index<IndexCount> tensor_index;
    165   for (int i = IndexCount - 1; i >= 0; i--) {
    166     int new_index = index / dims[i];
    167     tensor_index[i] = index - dims[i] * new_index;
    168     index = new_index;
    169   }
    170   return tensor_index;
    171 }
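        // Worked example (added comment): with dims = {2, 3, 4}, the index {1, 2, 3}
        // maps to the flat offset (1 * 3 + 2) * 4 + 3 = 23, and
        // FlatToTensorIndex(23, {2, 3, 4}) recovers {1, 2, 3}, so the two helpers are
        // inverses of each other for in-range indices.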
    172 
    173 // A Cuda custom kernel that swaps dimension-0 and dimension-2 of a 3D tensor.
    174 template <typename T, bool conjugate = false>
    175 __global__ void SwapDimension0And2InTensor3Simple(int nthreads, const T* input,
    176                                                   Dimension<3> input_dims,
    177                                                   T* output) {
    178   Dimension<3> output_dims;
    179   output_dims[0] = input_dims[2];
    180   output_dims[1] = input_dims[1];
    181   output_dims[2] = input_dims[0];
    182 
    183   CUDA_1D_KERNEL_LOOP(index, nthreads) {
    184     int output_index = index;
    185 
    186     Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims);
    187 
    188     Index<3> input_tensor_index;
    189     input_tensor_index[0] = output_tensor_index[2];
    190     input_tensor_index[1] = output_tensor_index[1];
    191     input_tensor_index[2] = output_tensor_index[0];
    192 
    193     int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
    194 
    195     output[output_index] =
    196         maybe_conj<T, conjugate>::run(ldg(input + input_index));
    197   }
    198 }
    199 
    200 // A Cuda custom kernel that swaps dimension-1 and dimension-2 of a 3D tensor.
    201 template <typename T, bool conjugate = false>
    202 __global__ void SwapDimension1And2InTensor3Simple(int nthreads, const T* input,
    203                                                   Dimension<3> input_dims,
    204                                                   T* output) {
    205   Dimension<3> output_dims;
    206   output_dims[0] = input_dims[0];
    207   output_dims[1] = input_dims[2];
    208   output_dims[2] = input_dims[1];
    209 
    210   CUDA_1D_KERNEL_LOOP(index, nthreads) {
    211     int output_index = index;
    212     Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims);
    213 
    214     Index<3> input_tensor_index;
    215     input_tensor_index[0] = output_tensor_index[0];
    216     input_tensor_index[1] = output_tensor_index[2];
    217     input_tensor_index[2] = output_tensor_index[1];
    218 
    219     int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
    220 
    221     output[output_index] =
    222         maybe_conj<T, conjugate>::run(ldg(input + input_index));
    223   }
    224 }
    225 
    226 // Use shared memory tiles to swap dimension-1 and dimension-2 of a 3D tensor,
    227 // where dimensions are zero-based: output[i][j][k] = input[i][k][j].
    228 //
    229 // Each thread block operates on a single tile, a rectangle of dimensions
    230 // TileSizeI x TileSizeJ.
    231 //
    232 // In general, for best performance, set TileSizeI and TileSizeJ equal to the
    233 // number of threads in a warp (32 on NVIDIA GPUs). With a TileSizeI and
    234 // TileSizeJ of 32, a NumThreads of 128 or 256 seems to give the best
    235 // performance on K40 GPUs.
    236 template <typename T, int NumThreads, int TileSizeI, int TileSizeJ,
    237           bool conjugate = false>
    238 __global__ void SwapDimension1And2InTensor3UsingTiles(
    239     const T* __restrict__ input, Dimension<3> input_dims,
    240     T* __restrict__ output) {
    241   eigen_assert(blockDim.x == NumThreads);
    242   eigen_assert(blockDim.y == 1);
    243   eigen_assert(blockDim.z == 1);
    244   eigen_assert(gridDim.y == 1);
    245   eigen_assert(gridDim.z == 1);
    246 
    247   constexpr int ReadRowPerPass = NumThreads / TileSizeJ;
    248   constexpr int WriteRowPerPass = NumThreads / TileSizeI;
    249   // One extra line in the inner dimension to avoid shared memory bank conflicts.
    250   __shared__ T shared_memory_tile[TileSizeI][TileSizeJ + 1];
    251 
    252   int x = threadIdx.x;
    253 
    254   Dimension<3> output_dims = {
    255       input_dims[0],
    256       input_dims[2],
    257       input_dims[1],
    258   };
    259 
    260   Dimension<3> input_dims_in_tiles = {
    261       input_dims[0],
    262       (input_dims[1] + TileSizeI - 1) / TileSizeI,
    263       (input_dims[2] + TileSizeJ - 1) / TileSizeJ,
    264   };
    265 
    266   Index<3> input_tile_index =
    267       FlatToTensorIndex(blockIdx.x, input_dims_in_tiles);
    268 
    269   Index<3> input_tile_origin = {
    270       input_tile_index[0],
    271       input_tile_index[1] * TileSizeI,
    272       input_tile_index[2] * TileSizeJ,
    273   };
    274 
    275   int input_origin_flat_index =
    276       TensorIndexToFlat(input_tile_origin, input_dims);
    277 
    278   bool full_tile = true;
    279   int tile_width = TileSizeJ;
    280 
    281   // Only the last row or column may not have the full size.
    282   if (input_tile_index[2] == input_dims_in_tiles[2] - 1) {
    283     tile_width = input_dims[2] - (input_dims_in_tiles[2] - 1) * TileSizeJ;
    284     full_tile &= false;
    285   }
    286 
    287   int tile_height = TileSizeI;
    288 
    289   if (input_tile_index[1] == input_dims_in_tiles[1] - 1) {
    290     tile_height = input_dims[1] - (input_dims_in_tiles[1] - 1) * TileSizeI;
    291     full_tile &= false;
    292   }
    293 
    294   // Calculate effective thread number. This ensures that we use the largest
    295   // number of threads available to form a regular thread block with no
    296   // trailing incomplete lines.
    297   constexpr int in_effective_thread_num = NumThreads / TileSizeJ * TileSizeJ;
    298 
    299   if (x < in_effective_thread_num) {
    300     // Orient the logical thread block with respect to the input array.
    301     // i.e., align the contiguous dimension of thread blocks with the
    302     // contiguous dimension of the input array.
    303     int ti = x / TileSizeJ;
    304     int tj = x % TileSizeJ;
    305     int input_index = input_origin_flat_index + ti * input_dims[2] + tj;
    306     int input_increment = ReadRowPerPass * input_dims[2];
    307 
    308     if (full_tile) {
    309 #pragma unroll
    310       for (int i_loc = ti; i_loc < (TileSizeI); i_loc += ReadRowPerPass) {
    311         shared_memory_tile[i_loc][tj] =
    312             maybe_conj<T, conjugate>::run(input[input_index]);
    313         input_index += input_increment;
    314       }
    315     } else {
    316       if (tj < tile_width) {
    317         for (int i_loc = ti; i_loc < (tile_height); i_loc += ReadRowPerPass) {
    318           shared_memory_tile[i_loc][tj] =
    319               maybe_conj<T, conjugate>::run(input[input_index]);
    320           input_index += input_increment;
    321         }
    322       }
    323     }
    324   }
    325 
    326   __syncthreads();
    327 
    328   Index<3> output_tile_index = {
    329       input_tile_index[0],
    330       input_tile_index[2],
    331       input_tile_index[1],
    332   };
    333 
    334   Index<3> output_tile_origin = {
    335       output_tile_index[0],
    336       output_tile_index[1] * TileSizeJ,
    337       output_tile_index[2] * TileSizeI,
    338   };
    339 
    340   int output_origin_flat_index =
    341       TensorIndexToFlat(output_tile_origin, output_dims);
    342 
    343   constexpr int out_effective_thread_num = NumThreads / TileSizeI * TileSizeI;
    344 
    345   if (x < out_effective_thread_num) {
    346     // Re-orient the logical thread block with respect to the output array.
    347     // i.e., align the contiguous dimension of thread blocks with the
    348     // contiguous dimension of the output array.
    349     int ti = x / TileSizeI;
    350     int tj = x % TileSizeI;
    351     int output_index = output_origin_flat_index + ti * output_dims[2] + tj;
    352     int output_increment = WriteRowPerPass * output_dims[2];
    353 
    354     if (full_tile) {
    355 #pragma unroll
    356       for (int i_loc = ti; i_loc < (TileSizeJ); i_loc += WriteRowPerPass) {
    357         output[output_index] = shared_memory_tile[tj][i_loc];
    358         output_index += output_increment;
    359       }
    360     } else {
    361       if (tj < tile_height) {
    362         for (int i_loc = ti; i_loc < (tile_width); i_loc += WriteRowPerPass) {
    363           output[output_index] = shared_memory_tile[tj][i_loc];
    364           output_index += output_increment;
    365         }
    366       }
    367     }
    368   }
    369 }
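        // Illustrative host-side launch (a sketch mirroring the large-matrix path in
        // RunSwapDimension1And2InTensor3 below; `dims`, `input`, `output` and `d` are
        // placeholders): with full 32 x 32 tiles and 256 threads per block, one block
        // is launched per tile:
        //
        //   Dimension<3> dims_in_tiles = {
        //       dims[0], MathUtil::CeilOfRatio<int>(dims[1], 32),
        //       MathUtil::CeilOfRatio<int>(dims[2], 32)};
        //   int blocks = dims_in_tiles[0] * dims_in_tiles[1] * dims_in_tiles[2];
        //   SwapDimension1And2InTensor3UsingTiles<float, 256, 32, 32>
        //       <<<blocks, 256, 0, d.stream()>>>(input, dims, output);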
    370 
    371 // A Cuda custom kernel that converts the input to the output, given proper
    372 // padding on the left and the top. The padded value is zero.
    373 template <typename T, int NDIMS>
    374 __global__ void PadInputCustomKernelNHWC(int nthreads, const T* input,
    375                                          Dimension<NDIMS> input_dims, T* output,
    376                                          Dimension<NDIMS> output_dims,
    377                                          Dimension<NDIMS - 2> padding_left) {
    378   CUDA_1D_KERNEL_LOOP(index, nthreads) {
    379     int output_index = index;
    380     Index<NDIMS> output_tensor_index =
    381         FlatToTensorIndex(output_index, output_dims);
    382 
    383     Index<NDIMS> input_tensor_index;
    384     input_tensor_index[0] = output_tensor_index[0];  // batch
    385     bool ok = true;
    386     for (int i = 1; i < NDIMS - 1; i++) {
    387       input_tensor_index[i] = output_tensor_index[i] - padding_left[i - 1];
    388       ok &=
    389           (input_tensor_index[i] >= 0 && input_tensor_index[i] < input_dims[i]);
    390     }
    391     input_tensor_index[NDIMS - 1] = output_tensor_index[NDIMS - 1];  // channels
    392 
    393     if (ok) {
    394       const int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
    395       output[output_index] = input[input_index];
    396     } else {
    397       output[output_index] = T(0);
    398     }
    399   }
    400 }
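        // Example (added comment): with NDIMS = 4, input_dims = {N, H, W, C},
        // padding_left = {pad_top, pad_left} and output_dims =
        // {N, H + pad_top + pad_bottom, W + pad_left + pad_right, C}, the kernel
        // computes output[n][y][x][c] = input[n][y - pad_top][x - pad_left][c] when
        // the shifted coordinates are in range, and 0 otherwise.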
    401 
    402 template <typename T, int NDIMS>
    403 __global__ void PadInputCustomKernelNCHW(int nthreads, const T* input,
    404                                          Dimension<NDIMS> input_dims, T* output,
    405                                          Dimension<NDIMS> output_dims,
    406                                          Dimension<NDIMS - 2> padding_left) {
    407   CUDA_1D_KERNEL_LOOP(index, nthreads) {
    408     int output_index = index;
    409     Index<NDIMS> output_tensor_index =
    410         FlatToTensorIndex(output_index, output_dims);
    411 
    412     Index<NDIMS> input_tensor_index;
    413     input_tensor_index[0] = output_tensor_index[0];  // batch
    414     input_tensor_index[1] = output_tensor_index[1];  // channels
    415     bool ok = true;
    416     for (int i = 2; i < NDIMS; i++) {
    417       input_tensor_index[i] = output_tensor_index[i] - padding_left[i - 2];
    418       ok &=
    419           (input_tensor_index[i] >= 0 && input_tensor_index[i] < input_dims[i]);
    420     }
    421 
    422     if (ok) {
    423       const int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
    424       output[output_index] = input[input_index];
    425     } else {
    426       output[output_index] = T(0);
    427     }
    428   }
    429 }
    430 
    431 // A GPU helper functor that converts the TensorFlow filter format to the
    432 // Cudnn filter format.
    433 template <typename T, int NDIMS>
    434 struct TransformFilter<GPUDevice, T, int, NDIMS> {
    435   typedef GPUDevice Device;
    436   void operator()(const Device& d,
    437                   typename TTypes<T, NDIMS, int>::ConstTensor in,
    438                   typename TTypes<T, NDIMS, int>::Tensor out) {
    439     Dimension<3> combined_dims;
    440     combined_dims[0] = in.dimension(0);  // spatial dimensions
    441     for (int i = 1; i < NDIMS - 2; i++) {
    442       combined_dims[0] *= in.dimension(i);
    443     }
    444     combined_dims[1] = in.dimension(NDIMS - 2);  // input filters
    445     combined_dims[2] = in.dimension(NDIMS - 1);  // output filters
    446     CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
    447     SwapDimension0And2InTensor3Simple<T>
    448         <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
    449             config.virtual_thread_count, in.data(), combined_dims, out.data());
    450   }
    451 };
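        // Example (added comment): for a 4-D filter of shape [H, W, I, O] (NDIMS = 4),
        // combined_dims = {H * W, I, O}; swapping dimensions 0 and 2 produces an
        // [O, I, H * W] layout, i.e. the OIHW filter ordering expected by Cudnn.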
    452 
    453 // Converts Cudnn filter format back to TensorFlow filter format.
    454 template <typename T, int NDIMS>
    455 struct ReverseTransformFilter<GPUDevice, T, NDIMS> {
    456   typedef GPUDevice Device;
    457   void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
    458                   typename TTypes<T, NDIMS>::Tensor out) {
    459     Dimension<3> combined_dims;
    460     combined_dims[0] = in.dimension(0);  // output filters
    461     combined_dims[1] = in.dimension(1);  // input filters
    462     combined_dims[2] = in.dimension(2);  // spatial dimensions
    463     for (int i = 3; i < NDIMS; ++i) {
    464       combined_dims[2] *= in.dimension(i);
    465     }
    466     CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
    467     SwapDimension0And2InTensor3Simple<T>
    468         <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
    469             config.virtual_thread_count, in.data(), combined_dims, out.data());
    470   }
    471 };
    472 
    473 // A GPU helper functor that converts the input tensor to a larger output
    474 // tensor, given proper padding values. The padded value is zero.
    475 template <typename T, int NDIMS>
    476 struct PadInput<GPUDevice, T, int, NDIMS> {
    477   typedef GPUDevice Device;
    478   void operator()(const Device& d,
    479                   typename TTypes<T, NDIMS, int>::ConstTensor in,
    480                   const std::array<int, NDIMS - 2>& padding_left,
    481                   const std::array<int, NDIMS - 2>& padding_right,
    482                   typename TTypes<T, NDIMS, int>::Tensor out,
    483                   TensorFormat format) {
    484     CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
    485     Dimension<NDIMS> input_dims;
    486     for (int i = 0; i < NDIMS; ++i) {
    487       input_dims[i] = in.dimension(i);
    488     }
    489     Dimension<NDIMS> output_dims;
    490     for (int i = 0; i < NDIMS; ++i) {
    491       output_dims[i] = out.dimension(i);
    492     }
    493 
    494     const Dimension<NDIMS - 2> padding_left_dim(padding_left);
    495 
    496     if (format == FORMAT_NHWC) {
    497       PadInputCustomKernelNHWC<T, NDIMS>
    498           <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
    499               config.virtual_thread_count, in.data(), input_dims, out.data(),
    500               output_dims, padding_left_dim);
    501     } else if (format == FORMAT_NCHW) {
    502       PadInputCustomKernelNCHW<T, NDIMS>
    503           <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
    504               config.virtual_thread_count, in.data(), input_dims, out.data(),
    505               output_dims, padding_left_dim);
    506     } else {
    507       LOG(FATAL) << "Invalid data format: " << format;
    508     }
    509   }
    510 };
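        // Usage sketch (added comment; `device`, `in_tensor` and `out_tensor` are
        // assumed to be an Eigen GPU device and suitably shaped float tensors):
        //
        //   std::array<int, 2> pad_lo{{1, 1}}, pad_hi{{1, 1}};  // pad H and W by 1
        //   functor::PadInput<GPUDevice, float, int, 4>()(
        //       device, To32Bit(in_tensor), pad_lo, pad_hi, To32Bit(out_tensor),
        //       FORMAT_NHWC);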
    511 
    512 // We want std::equal_to and std::greater, but they're not constexpr until
    513 // C++14.
    514 struct EqualTo {
    515   constexpr bool operator()(int a, int b) const { return a == b; }
    516 };
    517 
    518 struct GreaterThan {
    519   constexpr bool operator()(int a, int b) const { return a > b; }
    520 };
    521 
    522 // For each data type, the tile size possibility frontier denotes the tile size
    523 // combinations that consume the most computational resources constrained by
    524 // - number of threads per SM limit,
    525 // - limit on size of the short dimension (<=15) due to the definition of
    526 //   narrow matrix,
    527 // - shared memory limit and
    528 // - some experimentally determined, type-specific constraint on the product of
    529 //   two side lengths to increase grid-level parallelism.
    530 //
    531 // A tile size combination lies on the frontier if and only if one or more of
    532 // the constraints mentioned above is hit. Tile size combinations lying outside
    533 // this frontier are either not possible or are slower than the alternatives.
    534 //
    535 // It is instructive to consider, for each data type, two subsets of the
    536 // corresponding frontier:
    537 // - long side frontier: the union of the biggest tile size combination for
    538 //   each legal long side len.
    539 // - non long side frontier: the frontier set minus the long side frontier.
    540 //
    541 // TileSizePossibilityFrontierCheck defines the frontier using only the long
    542 // side frontier tile size combinations (since one can easily extrapolate
    543 // the entire frontier from this subset). It serves as a utility function
    544 // to help us determine where a tile size combination of interest lies with
    545 // respect to the frontier.
    546 template <typename Op>
    547 constexpr bool TileSizePossibilityFrontierCheck(int TileLongSide,
    548                                                 int TileShortSide,
    549                                                 int size_of_t, Op op) {
    550   // clang-format off
    551 
    552   return (size_of_t == 16 && ((TileLongSide == 32   && op(TileShortSide, 4))  ||
    553                              (TileLongSide == 64   && op(TileShortSide, 4))  ||
    554                              (TileLongSide == 128  && op(TileShortSide, 4))  ||
    555                              (TileLongSide == 256  && op(TileShortSide, 2)))) ||
    556           (size_of_t == 8 && ((TileLongSide == 32   && op(TileShortSide, 15)) ||
    557                              (TileLongSide == 64   && op(TileShortSide, 15)) ||
    558                              (TileLongSide == 128  && op(TileShortSide, 8))  ||
    559                              (TileLongSide == 256  && op(TileShortSide, 4))  ||
    560                              (TileLongSide == 512  && op(TileShortSide, 2)))) ||
    561           (size_of_t == 4 && ((TileLongSide == 32   && op(TileShortSide, 15)) ||
    562                              (TileLongSide == 64   && op(TileShortSide, 15)) ||
    563                              (TileLongSide == 128  && op(TileShortSide, 15)) ||
    564                              (TileLongSide == 256  && op(TileShortSide, 8))  ||
    565                              (TileLongSide == 512  && op(TileShortSide, 4))  ||
    566                              (TileLongSide == 1024 && op(TileShortSide, 2)))) ||
    567           (size_of_t == 2 && ((TileLongSide == 32   && op(TileShortSide, 15)) ||
    568                              (TileLongSide == 64   && op(TileShortSide, 15)) ||
    569                              (TileLongSide == 128  && op(TileShortSide, 15)) ||
    570                              (TileLongSide == 256  && op(TileShortSide, 8))  ||
    571                              (TileLongSide == 512  && op(TileShortSide, 4))  ||
    572                              (TileLongSide == 1024 && op(TileShortSide, 2)))) ||
    573           (size_of_t == 1 && ((TileLongSide == 32   && op(TileShortSide, 15)) ||
    574                              (TileLongSide == 64   && op(TileShortSide, 15)) ||
    575                              (TileLongSide == 128  && op(TileShortSide, 15)) ||
    576                              (TileLongSide == 256  && op(TileShortSide, 8))  ||
    577                              (TileLongSide == 512  && op(TileShortSide, 4))  ||
    578                              (TileLongSide == 1024 && op(TileShortSide, 2))));
    579 
    580   // clang-format on
    581 }
    582 
    583 constexpr bool TileSizeOnLongSideFrontier(int TileLongSide, int TileShortSide,
    584                                           int size_of_t) {
    585   return TileSizePossibilityFrontierCheck(TileLongSide, TileShortSide,
    586                                           size_of_t, EqualTo());
    587 }
    588 constexpr bool TileSizeOutsideFrontier(int TileLongSide, int TileShortSide,
    589                                        int size_of_t) {
    590   return TileSizePossibilityFrontierCheck(TileLongSide, TileShortSide,
    591                                           size_of_t, GreaterThan());
    592 }
    593 constexpr bool TileSizeOnNonLongSideFrontier(int TileLongSide,
    594                                              int TileShortSide, int size_of_t) {
    595   // For a tile size combination (longside, shortside), lying on the frontier
    596   // implies that (longside, shortside) is on or within the frontier but
    597   // (longside*2, shortside) or (longside, shortside+1) is not. With the above
    598   // criterion, we simply need to use !TileSizeOnLongSideFrontier to ensure that
    599   // it is not on the long side frontier.
    600   return !TileSizeOutsideFrontier(TileLongSide, TileShortSide, size_of_t) &&
    601          (TileSizeOutsideFrontier(TileLongSide * 2, TileShortSide, size_of_t) ||
    602           TileSizeOutsideFrontier(TileLongSide, TileShortSide + 1,
    603                                   size_of_t)) &&
    604          !TileSizeOnLongSideFrontier(TileLongSide, TileShortSide, size_of_t);
    605 }
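        // Illustrative compile-time checks (added for exposition; not part of the
        // original file): for 4-byte element types, (256, 8) lies on the long side
        // frontier, (256, 9) lies outside the frontier, and (128, 10) lies on the
        // non long side frontier.
        static_assert(TileSizeOnLongSideFrontier(256, 8, 4),
                      "(256, 8) should lie on the long side frontier for 4-byte "
                      "types.");
        static_assert(TileSizeOutsideFrontier(256, 9, 4),
                      "(256, 9) should lie outside the frontier for 4-byte types.");
        static_assert(TileSizeOnNonLongSideFrontier(128, 10, 4),
                      "(128, 10) should lie on the non long side frontier for "
                      "4-byte types.");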
    606 
    607 // Helper function to launch a batch narrow matrix transpose kernel.
    608 template <typename T, int TileLongSide, int TileShortSide>
    609 void LaunchBatchNarrowMatrixTransposeKernel(
    610     const GPUDevice& d, int tile_size_i, int tile_size_j, int total_tiles_count,
    611     const T* input, const Dimension<3>& input_dims, T* output) {
    612   constexpr int NumThreads = TileLongSide;
    613   if (tile_size_i <= TileLongSide && tile_size_j <= TileShortSide) {
    614     SwapDimension1And2InTensor3UsingTiles<T, NumThreads, TileLongSide,
    615                                           TileShortSide>
    616         <<<total_tiles_count, NumThreads, 0, d.stream()>>>(input, input_dims,
    617                                                            output);
    618   } else {
    619     SwapDimension1And2InTensor3UsingTiles<T, NumThreads, TileShortSide,
    620                                           TileLongSide>
    621         <<<total_tiles_count, NumThreads, 0, d.stream()>>>(input, input_dims,
    622                                                            output);
    623   }
    624 }
    625 
    626 // Recursive template function to search, in a trial-and-error manner, for the
    627 // minimum tile size configuration satisfying the requested tile side lengths.
    628 // An important invariant of this search procedure is that for an unsatisfied
    629 // request, we always try doubling the long side len first, and only after
    630 // the request is satisfied for the long side len do we begin incrementing
    631 // the short side len.
    632 //
    633 // We have three specializations of this search function depending on where the
    634 // current tile size combination lies with respect to the frontier.
    635 // - It lies within the frontier. If the request is not satisfied, for the next
    636 // tile size combination we first try doubling the long side len and, if that
    637 // does not work, we then increment the short side len.
    638 // - It lies on the non long side frontier. If the request is not satisfied, we
    639 // can only increment the short side len.
    640 // - It lies on the long side frontier. We launch the kernel without checking if
    641 // the request is satisfied or not.
    642 template <typename T, int TileLongSide, int TileShortSide,
    643           typename dummy = void>
    644 struct BatchNarrowMatrixTransposeDispatcher {
    645   static void DoIt(const GPUDevice& d, int tile_size_i, int tile_size_j,
    646                    int total_tiles_count, const T* input,
    647                    const Dimension<3>& input_dims, T* output) {
    648     static_assert(
    649         (TileLongSide & (TileLongSide - 1)) == 0,
    650         "The length of the longer side of the tile is always a power of 2.");
    651     bool request_satisfied =
    652         std::max(tile_size_i, tile_size_j) <= TileLongSide &&
    653         std::min(tile_size_i, tile_size_j) <= TileShortSide;
    654 
    655     if (request_satisfied) {
    656       LaunchBatchNarrowMatrixTransposeKernel<T, TileLongSide, TileShortSide>(
    657           d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims,
    658           output);
    659       return;
    660     }
    661 
    662     // If the execution reaches here, then the kernel was not launched; we then
    663     // determine whether it is the long side or the short side that falls short
    664     // of the request and increase that parameter accordingly.
    665     const bool long_side_request_not_satisfied =
    666         std::max(tile_size_i, tile_size_j) > TileLongSide;
    667 
    668     if (long_side_request_not_satisfied) {
    669       BatchNarrowMatrixTransposeDispatcher<
    670           T, TileLongSide * 2, TileShortSide>::DoIt(d, tile_size_i, tile_size_j,
    671                                                     total_tiles_count, input,
    672                                                     input_dims, output);
    673     } else {
    674       BatchNarrowMatrixTransposeDispatcher<
    675           T, TileLongSide, TileShortSide + 1>::DoIt(d, tile_size_i, tile_size_j,
    676                                                     total_tiles_count, input,
    677                                                     input_dims, output);
    678     }
    679   }
    680 };
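        // Illustrative dispatch trace (added comment): with 4-byte elements and a
        // request of tile_size_i = 96, tile_size_j = 6, the chain starting at
        // BatchNarrowMatrixTransposeDispatcher<uint32, 32, 2> first doubles the long
        // side (32 -> 64 -> 128), then increments the short side (2 -> ... -> 6),
        // and finally launches SwapDimension1And2InTensor3UsingTiles<uint32, 128,
        // 128, 6> with 128 threads per block.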
    681 
    682 template <typename T, int TileLongSide, int TileShortSide>
    683 struct BatchNarrowMatrixTransposeDispatcher<
    684     T, TileLongSide, TileShortSide,
    685     typename std::enable_if<TileSizeOnNonLongSideFrontier(
    686                                 TileLongSide, TileShortSide, sizeof(T)),
    687                             void>::type> {
    688   static void DoIt(const GPUDevice& d, int tile_size_i, int tile_size_j,
    689                    int total_tiles_count, const T* input,
    690                    const Dimension<3>& input_dims, T* output) {
    691     static_assert(
    692         (TileLongSide & (TileLongSide - 1)) == 0,
    693         "The length of the longer side of the tile is always a power of 2.");
    694     bool request_satisfied =
    695         std::max(tile_size_i, tile_size_j) <= TileLongSide &&
    696         std::min(tile_size_i, tile_size_j) <= TileShortSide;
    697 
    698     if (request_satisfied) {
    699       LaunchBatchNarrowMatrixTransposeKernel<T, TileLongSide, TileShortSide>(
    700           d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims,
    701           output);
    702       return;
    703     }
    704 
    705     // If the execution reaches here, then the kernel was not launched; since
    706     // we are on the non long side frontier, we increment the short dimension
    707     // and try again.
    708     BatchNarrowMatrixTransposeDispatcher<
    709         T, TileLongSide, TileShortSide + 1>::DoIt(d, tile_size_i, tile_size_j,
    710                                                   total_tiles_count, input,
    711                                                   input_dims, output);
    712   }
    713 };
    714 
    715 template <typename T, int TileLongSide, int TileShortSide>
    716 struct BatchNarrowMatrixTransposeDispatcher<
    717     T, TileLongSide, TileShortSide,
    718     typename std::enable_if<TileSizeOnLongSideFrontier(
    719                                 TileLongSide, TileShortSide, sizeof(T)),
    720                             void>::type> {
    721   static void DoIt(const GPUDevice& d, int tile_size_i, int tile_size_j,
    722                    int total_tiles_count, const T* input,
    723                    const Dimension<3>& input_dims, T* output) {
    724     static_assert(
    725         (TileLongSide & (TileLongSide - 1)) == 0,
    726         "The length of the longer side of the tile is always a power of 2.");
    727 
    728     LaunchBatchNarrowMatrixTransposeKernel<T, TileLongSide, TileShortSide>(
    729         d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims,
    730         output);
    731   }
    732 };
    733 
    734 // This function tries to recover, in a brute force way, the frontier defined in
    735 // TileSizePossibilityFrontierCheck as a vector of tile size combinations lying
    736 // on the long side frontier. This vector is sufficient to determine the entire
    737 // frontier.
    738 //
    739 // Note that if one changes the frontier definition in
    740 // TileSizePossibilityFrontierCheck and forgets to set the largest short
    741 // side len of the largest legal long side len to 2, this function will fail
    742 // and crash the program.
    743 template <int SizeOfT>
    744 const std::vector<std::pair<int, int>>& GetTileSizesFrontier() {
    745   static_assert(
    746       SizeOfT <= 16,
    747       "Currently, only data types of sizes 16 bytes or less are supported.");
    748   static_assert((SizeOfT & (SizeOfT - 1)) == 0,
    749                 "Data types must have sizes that are powers of 2.");
    750 
    751   // Expensive work to populate sizes, lazily run in a thread-safe
    752   // manner the first time GetTileSizesFrontier<N> is called.
    753   static auto* frontier = [] {
    754     auto* frontier = new std::vector<std::pair<int, int>>();
    755     const int kMaxLongSideLen = 1024;
    756     const int kMaxShortSideLen = 15;
    757     for (int long_side = 32; long_side <= kMaxLongSideLen; long_side *= 2) {
    758       for (int short_side = 2; short_side <= kMaxShortSideLen;
    759            short_side += 1) {
    760         if (TileSizeOnLongSideFrontier(long_side, short_side, SizeOfT)) {
    761           // The current combination lies on the frontier, thus we
    762           // add it to the frontier definition.
    763           frontier->push_back(std::make_pair(long_side, short_side));
    764 
    765           // The long side length is the largest one allowed iff its
    766           // corresponding short side length is 2.
    767           if (short_side == 2) return frontier;
    768 
    769           // We have exhausted all the possibilities in the frontier
    770           // with the given long side length.
    771           break;
    772         }
    773       }
    774     }
    775     LOG(FATAL)
    776         << "The corresponding short side length of the largest long side "
    777            "length has to be 2.";
    778   }();
    779   return *frontier;
    780 }
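        // Example (added comment): for 4-byte element types, GetTileSizesFrontier<4>()
        // yields {(32, 15), (64, 15), (128, 15), (256, 8), (512, 4), (1024, 2)},
        // matching the size_of_t == 4 clauses in TileSizePossibilityFrontierCheck.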
    781 
    782 // Helper structs to help determine which data type to use given the size of
    783 // the matrix data type. A transpose of elements of size N will use a kernel
    784 // which operates on an array of TransposeElemType<N>::type.
    785 template <int ElemBytes>
    786 struct TransposeElemType;
    787 template <>
    788 struct TransposeElemType<1> {
    789   using type = uint8;
    790 };
    791 template <>
    792 struct TransposeElemType<2> {
    793   using type = uint16;
    794 };
    795 template <>
    796 struct TransposeElemType<4> {
    797   using type = uint32;
    798 };
    799 template <>
    800 struct TransposeElemType<8> {
    801   using type = uint64;
    802 };
    803 template <>
    804 struct TransposeElemType<16> {
    805   using type = float4;
    806 };
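        // Example (added comment): a transpose of 4-byte elements such as float or
        // int32 operates on uint32 arrays, and 16-byte elements are moved as float4;
        // only the element size matters, since the kernel merely relocates the bytes.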
    807 
    808 // A helper function to make RunSwapDimension1And2InTensor3 concise. This
    809 // helper function looks at the data type and input matrix sizes and decides
    810 // the thread numbers and tile sizes to use.
    811 template <typename T, bool conjugate = false>
    812 void SwapDimension1And2InTensor3WithNarrowMatrices(
    813     const GPUDevice& d, const T* input, const Dimension<3>& input_dims,
    814     T* output, const int kMinDimensionToUseTiles) {
    815   // Get available tile sizes here for the data type requested:
    816   const auto& tile_spec = GetTileSizesFrontier<sizeof(T)>();
    817 
    818   int tile_long_side_len = 0;
    819   int tile_short_side_len = 0;
    820   float lowest_cost = std::numeric_limits<float>::max();
    821   int data_long_side = std::max(input_dims[1], input_dims[2]);
    822 
    823   for (auto tile_size_pair : tile_spec) {
    824     int proposed_tile_long_side_len = tile_size_pair.first;
    825 
    826     // Number of threads that will not be doing anything useful when reading
    827     // the matrix because the thread block size is bigger than the data block
    828     // size.
    829     int num_wasted_threads =
    830         data_long_side - MathUtil::FloorOfRatio<int>(
    831                              data_long_side, proposed_tile_long_side_len) *
    832                              proposed_tile_long_side_len;
    833 
    834     int num_full_tiles = MathUtil::FloorOfRatio<int>(
    835         data_long_side, proposed_tile_long_side_len);
    836 
    837     float cost = 0;
    838 
    839     // However, if we can execute two or more full tiles, then we gladly
    840     // accept any number of wasted threads and ignore their cost.
    841     if (num_full_tiles <= 1) cost = num_wasted_threads;
    842 
    843     // Using less than or equal to here because given the same cost, we
    844     // would like to launch as many threads as possible.
    845     if (cost <= lowest_cost) {
    846       tile_long_side_len = proposed_tile_long_side_len;
    847       tile_short_side_len = tile_size_pair.second;
    848       lowest_cost = cost;
    849     }
    850   }
    851 
    852   // Request tile sizes such that the longer side of the thread block aligns
    853   // with the longer side of the input data block to maximize read throughput.
    854   // The ideal tile shape is one where the length of the shorter side of the
    855   // tile is equal to the length of the shorter side of the input matrix.
    856   int requested_tile_size_i = input_dims[1] >= kMinDimensionToUseTiles
    857                                   ? tile_long_side_len
    858                                   : input_dims[1];
    859   int requested_tile_size_j = input_dims[1] >= kMinDimensionToUseTiles
    860                                   ? input_dims[2]
    861                                   : tile_long_side_len;
    862 
    863   // Truncate the shorter size requested according to the manual limit set in
    864   // tile_spec to make sure that we do not launch configurations violating
    865   // hardware limits.
    866   requested_tile_size_i =
    867       requested_tile_size_i == tile_long_side_len
    868           ? tile_long_side_len
    869           : std::min(requested_tile_size_i, tile_short_side_len);
    870   requested_tile_size_j =
    871       requested_tile_size_j == tile_long_side_len
    872           ? tile_long_side_len
    873           : std::min(requested_tile_size_j, tile_short_side_len);
    874 
    875   Dimension<3> input_dims_in_tiles = {
    876       input_dims[0],
    877       MathUtil::CeilOfRatio<int>(input_dims[1], requested_tile_size_i),
    878       MathUtil::CeilOfRatio<int>(input_dims[2], requested_tile_size_j),
    879   };
    880 
    881   int total_tiles_count =
    882       input_dims_in_tiles[0] * input_dims_in_tiles[1] * input_dims_in_tiles[2];
    883 
    884   using ElemType = typename TransposeElemType<sizeof(T)>::type;
    885   static_assert(alignof(T) >= alignof(ElemType), "Unexpected data alignment.");
    886   BatchNarrowMatrixTransposeDispatcher<ElemType, 32, 2>::DoIt(
    887       d, requested_tile_size_i, requested_tile_size_j, total_tiles_count,
    888       reinterpret_cast<const ElemType*>(input), input_dims,
    889       reinterpret_cast<ElemType*>(output));
    890 }
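        // Worked example (added comment): for a float tensor with
        // input_dims = {N, 128, 8}, the frontier scan above picks
        // tile_long_side_len = 128 (the largest long side that still yields a full
        // tile) with tile_short_side_len = 15, so the requested tile is 128 x 8 and
        // the dispatcher ends up launching a 128 x 8 tile kernel with 128 threads
        // per block.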
    891 
    892 // Launch the GPU kernel that would swap dimension-1 and dimension-2 in a
    893 // 3D tensor. It looks at the shape of the incoming data, and decides the best
    894 // strategy to launch.
    895 template <typename T, bool conjugate = false>
    896 void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input,
    897                                     const Dimension<3>& input_dims, T* output) {
    898   // If both dimensions are large, use the tiled kernel for the actual swap.
    899   // If only one dimension is large (a narrow matrix), use the narrow-matrix
    900   // kernel. Otherwise, the simple swap using the ldg cache is more efficient.
    901   static const int kMinDimensionToUseTiles = 16;
    902   static const int kMinDimensionToUseRectTiles = 96;
    903 
    904   bool large_matrix = input_dims[1] >= kMinDimensionToUseTiles &&
    905                       input_dims[2] >= kMinDimensionToUseTiles;
    906   bool narrow_matrix = input_dims[1] >= kMinDimensionToUseRectTiles ||
    907                        input_dims[2] >= kMinDimensionToUseRectTiles;
    908   if (large_matrix) {
    909     // We get the best performance when kTileSize is the number of threads in
    910     // a warp (32 on our GPUs) and each block reads 8 rows per pass, so the
    911     // block size is kNumThreads = 8 * 32 = 256 threads.
    912     constexpr int kTileSize = 32;
    913     constexpr int kNumThreads = 256;
    914 
    915     Dimension<3> input_dims_in_tiles = {
    916         input_dims[0],
    917         MathUtil::CeilOfRatio<int>(input_dims[1], kTileSize),
    918         MathUtil::CeilOfRatio<int>(input_dims[2], kTileSize),
    919     };
    920 
    921     int total_tiles_count = input_dims_in_tiles[0] * input_dims_in_tiles[1] *
    922                             input_dims_in_tiles[2];
    923     SwapDimension1And2InTensor3UsingTiles<T, kNumThreads, kTileSize, kTileSize,
    924                                           conjugate>
    925         <<<total_tiles_count, kNumThreads, 0, d.stream()>>>(input, input_dims,
    926                                                             output);
    927 
    928   } else if (narrow_matrix) {
    929     SwapDimension1And2InTensor3WithNarrowMatrices<T, conjugate>(
    930         d, input, input_dims, output, kMinDimensionToUseTiles);
    931   } else {
    932     int total_element_count = input_dims[0] * input_dims[1] * input_dims[2];
    933     CudaLaunchConfig config = GetCudaLaunchConfig(total_element_count, d);
    934     SwapDimension1And2InTensor3Simple<T, conjugate>
    935         <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
    936             config.virtual_thread_count, input, input_dims, output);
    937   }
    938 }
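        // Examples (added comment): input_dims = {32, 64, 64} takes the 32 x 32 tiled
        // path, {32, 4, 128} is treated as a narrow matrix, and {32, 4, 8} falls back
        // to the simple element-wise kernel.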
    939 
    940 // A GPU helper functor that swaps dimension 1 and dimension 2 of a 3D
    941 // tensor.
    942 template <typename T, bool conjugate>
    943 struct SwapDimension1And2InTensor3<GPUDevice, T, conjugate> {
    944   typedef GPUDevice Device;
    945   void operator()(const Device& d, const T* in,
    946                   const gtl::ArraySlice<int64>& combined_dims, T* out) {
    947     Dimension<3> input_dims = {static_cast<int>(combined_dims[0]),
    948                                static_cast<int>(combined_dims[1]),
    949                                static_cast<int>(combined_dims[2])};
    950     RunSwapDimension1And2InTensor3<T, conjugate>(d, in, input_dims, out);
    951   }
    952 };
    953 
    954 // A GPU helper functor that swaps dimension 0 and dimension 2 of a 3D
    955 // tensor.
    956 template <typename T, bool conjugate>
    957 struct SwapDimension0And2InTensor3<GPUDevice, T, conjugate> {
    958   typedef GPUDevice Device;
    959   void operator()(const Device& d, const T* in,
    960                   const gtl::ArraySlice<int64>& combined_dims, T* out) {
    961     Dimension<3> input_dims = {static_cast<int>(combined_dims[0]),
    962                                static_cast<int>(combined_dims[1]),
    963                                static_cast<int>(combined_dims[2])};
    964     size_t total_size = combined_dims[0] * combined_dims[1] * combined_dims[2];
    965     CudaLaunchConfig config = GetCudaLaunchConfig(total_size, d);
    966     SwapDimension0And2InTensor3Simple<T, conjugate>
    967         <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
    968             config.virtual_thread_count, in, input_dims, out);
    969   }
    970 };
    971 
    972 // A GPU helper functor that converts NHWC TensorFlow data format to
    973 // NCHW format that is accepted by Cudnn.
    974 template <typename T, int NDIMS>
    975 struct NHWCToNCHW<GPUDevice, T, NDIMS> {
    976   typedef GPUDevice Device;
    977   void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
    978                   typename TTypes<T, NDIMS>::Tensor out) {
    979     Dimension<3> combined_dims;
    980     combined_dims[0] = in.dimension(0);  // N (batch)
    981     combined_dims[1] = in.dimension(1);  // spatial dimensions (HW)
    982     for (int i = 2; i < NDIMS - 1; ++i) {
    983       combined_dims[1] *= in.dimension(i);
    984     }
    985     combined_dims[2] = in.dimension(NDIMS - 1);  // C (channels)
    986     RunSwapDimension1And2InTensor3(d, in.data(), combined_dims, out.data());
    987   }
    988 };
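        // Example (added comment): for a 4-D NHWC tensor of shape [N, H, W, C],
        // combined_dims = {N, H * W, C}, and swapping dimensions 1 and 2 yields an
        // [N, C, H * W] layout, i.e. NCHW.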
    989 
    990 // A GPU helper functor that converts NCHW Cudnn data format to NHWC TensorFlow
    991 // format.
    992 template <typename T, int NDIMS>
    993 struct NCHWToNHWC<GPUDevice, T, NDIMS> {
    994   typedef GPUDevice Device;
    995   void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
    996                   typename TTypes<T, NDIMS>::Tensor out) {
    997     Dimension<3> combined_dims;
    998     combined_dims[0] = in.dimension(0);  // N (batch)
    999     combined_dims[1] = in.dimension(1);  // C (channel)
   1000     combined_dims[2] = in.dimension(2);  // spatial dimensions (HW)
   1001     for (int i = 3; i < NDIMS; ++i) {
   1002       combined_dims[2] *= in.dimension(i);
   1003     }
   1004     RunSwapDimension1And2InTensor3(d, in.data(), combined_dims, out.data());
   1005   }
   1006 };
   1007 
   1008 }  // namespace functor
   1009 
   1010 template struct functor::ShuffleAndReverse<GPUDevice, float, 4, int>;
   1011 template struct functor::ShuffleAndReverse<GPUDevice, Eigen::half, 4, int>;
   1012 
   1013 template struct functor::ShuffleAndReverse<GPUDevice, float, 4,
   1014                                            Eigen::DenseIndex>;
   1015 template struct functor::ShuffleAndReverse<GPUDevice, Eigen::half, 4,
   1016                                            Eigen::DenseIndex>;
   1017 
   1018 template struct functor::TransformDepth<GPUDevice, float, int>;
   1019 template struct functor::TransformDepth<GPUDevice, Eigen::half, int>;
   1020 
   1021 template struct functor::SwapDimension1And2InTensor3<GPUDevice, uint8>;
   1022 template struct functor::SwapDimension1And2InTensor3<GPUDevice, uint16>;
   1023 template struct functor::SwapDimension1And2InTensor3<GPUDevice, uint32>;
   1024 template struct functor::SwapDimension1And2InTensor3<GPUDevice, uint64>;
   1025 template struct functor::SwapDimension1And2InTensor3<GPUDevice, float4>;
   1026 template struct functor::SwapDimension1And2InTensor3<GPUDevice, float2,
   1027                                                      /*conjugate=*/true>;
   1028 template struct functor::SwapDimension1And2InTensor3<GPUDevice, double2,
   1029                                                      /*conjugate=*/true>;
   1030 
   1031 template struct functor::SwapDimension0And2InTensor3<GPUDevice, uint8>;
   1032 template struct functor::SwapDimension0And2InTensor3<GPUDevice, uint16>;
   1033 template struct functor::SwapDimension0And2InTensor3<GPUDevice, uint32>;
   1034 template struct functor::SwapDimension0And2InTensor3<GPUDevice, uint64>;
   1035 template struct functor::SwapDimension0And2InTensor3<GPUDevice, float4>;
   1036 template struct functor::SwapDimension0And2InTensor3<GPUDevice, float2,
   1037                                                      /*conjugate=*/true>;
   1038 template struct functor::SwapDimension0And2InTensor3<GPUDevice, double2,
   1039                                                      /*conjugate=*/true>;
   1040 
   1041 // For 2d ops.
   1042 template struct functor::TransformFilter<GPUDevice, float, int, 4>;
   1043 template struct functor::TransformFilter<GPUDevice, Eigen::half, int, 4>;
   1044 
   1045 template struct functor::ReverseTransformFilter<GPUDevice, float, 4>;
   1046 template struct functor::ReverseTransformFilter<GPUDevice, Eigen::half, 4>;
   1047 
   1048 template struct functor::NHWCToNCHW<GPUDevice, double, 4>;
   1049 template struct functor::NHWCToNCHW<GPUDevice, float, 4>;
   1050 template struct functor::NHWCToNCHW<GPUDevice, Eigen::half, 4>;
   1051 
   1052 template struct functor::NCHWToNHWC<GPUDevice, double, 4>;
   1053 template struct functor::NCHWToNHWC<GPUDevice, float, 4>;
   1054 template struct functor::NCHWToNHWC<GPUDevice, Eigen::half, 4>;
   1055 
   1056 template struct functor::PadInput<GPUDevice, int, int, 4>;
   1057 template struct functor::PadInput<GPUDevice, float, int, 4>;
   1058 template struct functor::PadInput<GPUDevice, Eigen::half, int, 4>;
   1059 
   1060 // For 3d ops.
   1061 template struct functor::TransformFilter<GPUDevice, float, int, 5>;
   1062 template struct functor::TransformFilter<GPUDevice, Eigen::half, int, 5>;
   1063 
   1064 template struct functor::ReverseTransformFilter<GPUDevice, float, 5>;
   1065 template struct functor::ReverseTransformFilter<GPUDevice, Eigen::half, 5>;
   1066 
   1067 template struct functor::NHWCToNCHW<GPUDevice, float, 5>;
   1068 template struct functor::NHWCToNCHW<GPUDevice, Eigen::half, 5>;
   1069 
   1070 template struct functor::NCHWToNHWC<GPUDevice, float, 5>;
   1071 template struct functor::NCHWToNHWC<GPUDevice, Eigen::half, 5>;
   1072 
   1073 template struct functor::PadInput<GPUDevice, float, int, 5>;
   1074 template struct functor::PadInput<GPUDevice, Eigen::half, int, 5>;
   1075 
   1076 }  // namespace tensorflow
   1077 
   1078 #endif  // GOOGLE_CUDA
   1079