/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

// TODO(yangzihao): Remove the dependency on conv_2d.h once all
// GPU util functions and transpose kernels are moved into separate files.
#include "tensorflow/core/kernels/conv_2d.h"

typedef Eigen::GpuDevice GPUDevice;

namespace tensorflow {
namespace internal {

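// TransposeKernel maps each output linear index o_idx to the corresponding
// input linear index i_idx by decomposing o_idx with the output strides and
// recombining the coordinates with the permuted input strides. For instance
// (assuming a row-major 2x3 input transposed with perm = {1, 0}, so
// in_strides = {3, 1} and out_strides = {2, 1}), for o_idx = 3, i.e. output
// coordinate (1, 1):
//   i = 0: ratio = 3 / 2 = 1, t = 1, i_idx += 1 * in_strides[perm[0]] = 1
//   i = 1: ratio = 1 / 1 = 1, t = 0, i_idx += 1 * in_strides[perm[1]] = 3
// giving i_idx = 4, so dst[3] = src[4], i.e. input coordinate (1, 1).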
template <typename T, bool conjugate>
__global__ void TransposeKernel(int nthreads, const T* src, const int32* buf,
                                const int32 ndims, T* dst) {
  const int32* in_strides = buf;
  const int32* out_strides = buf + ndims;
  const int32* perm = buf + ndims * 2;
  CUDA_1D_KERNEL_LOOP(o_idx, nthreads) {
    int32 i_idx = 0;
    int32 t = o_idx;
    for (int32 i = 0; i < ndims; ++i) {
      const int32 ratio = t / out_strides[i];
      t -= ratio * out_strides[i];
      i_idx += ratio * in_strides[perm[i]];
    }
    if (conjugate) {
      dst[o_idx] = Eigen::numext::conj(ldg(src + i_idx));
    } else {
      dst[o_idx] = ldg(src + i_idx);
    }
  }
}

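// TransposeSimple packs the strides and permutation consumed by
// TransposeKernel into a single buffer laid out as
// [in_strides (ndims) | out_strides (ndims) | perm (ndims)]. As a sketch, a
// 2x3x4 input transposed with perm = {2, 0, 1} (output shape 4x2x3) packs
//   host_buf = {12, 4, 1,  6, 3, 1,  2, 0, 1};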
template <typename T, bool conjugate>
void TransposeSimple(const GPUDevice& d, const Tensor& in,
                     const gtl::ArraySlice<int32> perm, Tensor* out) {
  // Ensure a 32-bit index is sufficient.
  const int64 nelem = in.NumElements();
  CHECK_LT(nelem, kint32max) << "Tensor too large to transpose on GPU";
  // Pack the input strides, output strides and permutation into one buffer,
  // laid out as [in_strides | out_strides | perm], matching how
  // TransposeKernel slices it.
  const int32 ndims = in.dims();
  gtl::InlinedVector<int32, 24> host_buf(ndims * 3);
  gtl::InlinedVector<int32, 8> in_strides = ComputeStride<int32>(in.shape());
  gtl::InlinedVector<int32, 8> out_strides = ComputeStride<int32>(out->shape());
  // Dimension permutation.
  for (int i = 0; i < ndims; ++i) {
    host_buf[i] = in_strides[i];
    host_buf[ndims + i] = out_strides[i];
    host_buf[ndims * 2 + i] = perm[i];
  }
  // Copy the input strides, output strides and permutation to the device.
  // host_buf holds int32 values, so size the copy accordingly.
  auto num_bytes = sizeof(int32) * host_buf.size();
  auto dev_buf = d.allocate(num_bytes);
  // NOTE: host_buf is not allocated by CudaHostAllocator, so this is
  // effectively a synchronous copy.
  d.memcpyHostToDevice(dev_buf, host_buf.data(), num_bytes);
  // Launch the kernel that computes q[...] = p[...].
  const T* p = reinterpret_cast<const T*>(in.tensor_data().data());
  T* q = reinterpret_cast<T*>(const_cast<char*>((out->tensor_data().data())));
  CudaLaunchConfig cfg = GetCudaLaunchConfig(nelem, d);
  TransposeKernel<T, conjugate>
      <<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>>(
          cfg.virtual_thread_count, p, reinterpret_cast<const int32*>(dev_buf),
          ndims, q);
  // Safe to deallocate immediately after the kernel launch.
  d.deallocate(dev_buf);
}

// TransposeUsingTile tries to reduce the rank of the input tensor to 3 and
// then calls specialized kernels to swap either dimensions 1 and 2 or
// dimensions 0 and 2. It returns true if the operation succeeds, false
// otherwise.
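// For instance (a sketch, assuming ReduceTransposeDimensions merges adjacent
// dimensions that stay adjacent and in order under the permutation): a
// 2x3x4x5 input with perm = {0, 1, 3, 2} reduces to a 6x4x5 input with
// perm = {0, 2, 1}, which SwapDimension1And2InTensor3 handles directly.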
template <typename T, bool conjugate = false>
struct TransposeUsingTile {
  static bool run(const Eigen::GpuDevice& d, const Tensor& in,
                  const gtl::ArraySlice<int32> perm, Tensor* out) {
    // First try to reduce the dimensions of the input tensor.
    TransposePermsVec new_perm;
    TransposeDimsVec new_dims;
    ReduceTransposeDimensions(in.shape(), perm, &new_perm, &new_dims);

    // Only use the specialized GPU kernels when the reduced rank is 2 or 3.
    int dims = new_dims.size();
    if (dims < 2 || dims > 3) return false;
    auto in_data = reinterpret_cast<const T*>(in.tensor_data().data());
    auto out_data =
        reinterpret_cast<T*>(const_cast<char*>(out->tensor_data().data()));
    switch (dims) {
      case 2:
        if (new_perm[0] == 1 && new_perm[1] == 0) {
          // Prepend a dimension of size 1 so the 3-D kernel can be used.
          new_dims.insert(new_dims.begin(), 1);
          tensorflow::functor::SwapDimension1And2InTensor3<GPUDevice, T,
                                                           conjugate>()(
              d, in_data, new_dims, out_data);
          return true;
        }
        break;
      case 3:
        if (new_perm == TransposePermsVec({0, 2, 1})) {
          tensorflow::functor::SwapDimension1And2InTensor3<GPUDevice, T,
                                                           conjugate>()(
              d, in_data, new_dims, out_data);
          return true;
        } else if (new_perm == TransposePermsVec({2, 1, 0})) {
          tensorflow::functor::SwapDimension0And2InTensor3<GPUDevice, T,
                                                           conjugate>()(
              d, in_data, new_dims, out_data);
          return true;
        } else {
          // Other 3-D permutations are not handled here.
          return false;
        }
        break;
      default:
        return false;
    }
    return false;
  }
};

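// complex64 is the same size as uint64 (8 bytes) and complex128 the same size
// as float4 (16 bytes), so a non-conjugating transpose can move complex
// elements as opaque fixed-size values. The conjugating paths use float2 and
// double2 instead, so the imaginary component can be negated per element.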
template <bool conjugate>
struct TransposeUsingTile<complex64, conjugate> {
  static bool run(const Eigen::GpuDevice& d, const Tensor& in,
                  const gtl::ArraySlice<int32> perm, Tensor* out) {
    if (!conjugate) {
      return TransposeUsingTile<uint64>::run(d, in, perm, out);
    } else {
      return TransposeUsingTile<float2, true>::run(d, in, perm, out);
    }
  }
};

template <bool conjugate>
struct TransposeUsingTile<complex128, conjugate> {
  static bool run(const Eigen::GpuDevice& d, const Tensor& in,
                  const gtl::ArraySlice<int32> perm, Tensor* out) {
    if (!conjugate) {
      return TransposeUsingTile<float4>::run(d, in, perm, out);
    } else {
      return TransposeUsingTile<double2, true>::run(d, in, perm, out);
    }
  }
};

}  // namespace internal

// Transpose functor specialized for the GPU device.
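// For ranks 2 through 8 it first tries the tiled kernels in
// internal::TransposeUsingTile and falls back to an Eigen shuffle via
// internal::TransposeUsingEigen; higher ranks use the generic
// internal::TransposeSimple kernel.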
template <typename T, bool conjugate>
struct Transpose<GPUDevice, T, conjugate> {
  static void run(const GPUDevice& d, const Tensor& in,
                  const gtl::ArraySlice<int32> perm, Tensor* out) {
    switch (in.dims()) {
      case 2:
        if (!internal::TransposeUsingTile<T, conjugate>::run(d, in, perm,
                                                             out)) {
          internal::TransposeUsingEigen<GPUDevice, T, 2>(d, in, perm, conjugate,
                                                         out);
        }
        break;
      case 3:
        if (!internal::TransposeUsingTile<T, conjugate>::run(d, in, perm,
                                                             out)) {
          internal::TransposeUsingEigen<GPUDevice, T, 3>(d, in, perm, conjugate,
                                                         out);
        }
        break;
      case 4:
        if (!internal::TransposeUsingTile<T, conjugate>::run(d, in, perm,
                                                             out)) {
          internal::TransposeUsingEigen<GPUDevice, T, 4>(d, in, perm, conjugate,
                                                         out);
        }
        break;
      case 5:
        if (!internal::TransposeUsingTile<T, conjugate>::run(d, in, perm,
                                                             out)) {
          internal::TransposeUsingEigen<GPUDevice, T, 5>(d, in, perm, conjugate,
                                                         out);
        }
        break;
      case 6:
        if (!internal::TransposeUsingTile<T, conjugate>::run(d, in, perm,
                                                             out)) {
          internal::TransposeUsingEigen<GPUDevice, T, 6>(d, in, perm, conjugate,
                                                         out);
        }
        break;
      case 7:
        if (!internal::TransposeUsingTile<T, conjugate>::run(d, in, perm,
                                                             out)) {
          internal::TransposeUsingEigen<GPUDevice, T, 7>(d, in, perm, conjugate,
                                                         out);
        }
        break;
      case 8:
        if (!internal::TransposeUsingTile<T, conjugate>::run(d, in, perm,
                                                             out)) {
          internal::TransposeUsingEigen<GPUDevice, T, 8>(d, in, perm, conjugate,
                                                         out);
        }
        break;
      default:
        internal::TransposeSimple<T, conjugate>(d, in, perm, out);
        break;
    }
  }
};

template <bool conjugate>
struct Transpose<GPUDevice, string, conjugate> {
  static void run(const GPUDevice& d, const Tensor& in,
                  const gtl::ArraySlice<int32> perm, Tensor* out) {
    LOG(FATAL) << "Transpose of DT_STRING tensor not supported on GPU.";
  }
};

// Explicit instantiation.
template struct Transpose<GPUDevice, string, false>;

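// Example usage (a sketch; the tensors and device handle are illustrative):
//   Tensor in(DT_FLOAT, TensorShape({2, 3, 4}));
//   Tensor out(DT_FLOAT, TensorShape({4, 2, 3}));
//   TF_RETURN_IF_ERROR(DoTranspose(gpu_device, in, {2, 0, 1}, &out));
// DoConjugateTranspose additionally conjugates complex element types; the
// matrix variants transpose only the two innermost dimensions.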
template <>
Status DoTranspose(const GPUDevice& device, const Tensor& in,
                   const gtl::ArraySlice<int32> perm, Tensor* out) {
  return internal::DoTransposeImpl(device, in, perm, /*conjugate=*/false, out);
}
template <>
Status DoConjugateTranspose(const GPUDevice& device, const Tensor& in,
                            const gtl::ArraySlice<int32> perm, Tensor* out) {
  return internal::DoTransposeImpl(device, in, perm, /*conjugate=*/true, out);
}
template <>
Status DoMatrixTranspose(const GPUDevice& device, const Tensor& in,
                         Tensor* out) {
  return internal::DoMatrixTransposeImpl(device, in, /*conjugate=*/false, out);
}
template <>
Status DoConjugateMatrixTranspose(const GPUDevice& device, const Tensor& in,
                                  Tensor* out) {
  return internal::DoMatrixTransposeImpl(device, in, /*conjugate=*/true, out);
}

}  // namespace tensorflow
#endif  // GOOGLE_CUDA