/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Specialization of SpaceToBatchFunctor for a GPUDevice.

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/kernels/spacetobatch_functor.h"

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

// Shape and padding parameters for the space-to-batch and batch-to-space
// conversion GPU kernel.
template <int NUM_BLOCK_DIMS>
struct S2BParameters {
  int32 space_tensor_batch;
  int32 batch_tensor_shape[NUM_BLOCK_DIMS + 2];
  int32 space_tensor_spatial_shape[NUM_BLOCK_DIMS];
  int32 pad_start[NUM_BLOCK_DIMS];
  int32 block_shape[NUM_BLOCK_DIMS];
};
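
// For example, with NUM_BLOCK_DIMS = 2, an NHWC space tensor of shape
// [B, H, W, C], block_shape [bh, bw], and paddings [[pt, pb], [pl, pr]],
// a space-to-batch call fills the fields above as:
//   space_tensor_batch         = B
//   batch_tensor_shape         = [B * bh * bw, (H + pt + pb) / bh,
//                                 (W + pl + pr) / bw, C]
//   space_tensor_spatial_shape = [H, W]
//   pad_start                  = [pt, pl]
//   block_shape                = [bh, bw]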

// GPU kernel for space-to-batch (if B2S = false) and batch-to-space conversion
// (if B2S = true).
//
// To simplify the template implementation given the lack of constexpr if,
// both the input and output pointers are non-const.
template <typename T, int NUM_BLOCK_DIMS, bool B2S>
__global__ void S2B(const int32 nthreads, T* space_tensor_ptr,
                    S2BParameters<NUM_BLOCK_DIMS> args, T* batch_tensor_ptr) {
  CUDA_1D_KERNEL_LOOP(batch_tensor_idx, nthreads) {
    int32 remaining_batch_tensor_idx = batch_tensor_idx;

    int32 batch_tensor_pos[NUM_BLOCK_DIMS + 2];

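    // Decompose the flat batch-tensor index into per-dimension coordinates,
    // innermost dimension first; whatever remains is the batch coordinate.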
    for (int dim = NUM_BLOCK_DIMS + 1; dim >= 1; --dim) {
      batch_tensor_pos[dim] =
          remaining_batch_tensor_idx % args.batch_tensor_shape[dim];
      remaining_batch_tensor_idx /= args.batch_tensor_shape[dim];
    }
    batch_tensor_pos[0] = remaining_batch_tensor_idx;

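    // The leading batch dimension of the batch tensor combines the block
    // offsets with the original batch index:
    //   batch_tensor_pos[0] = flat_block_offset * space_tensor_batch
    //                         + space_tensor_batch_pos,
    // where flat_block_offset is row-major over block_shape (last block
    // dimension varying fastest).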
    int32 remaining_block_idx = batch_tensor_pos[0] / args.space_tensor_batch;
    int32 space_tensor_idx = batch_tensor_pos[NUM_BLOCK_DIMS + 1];
    int32 space_tensor_stride = args.batch_tensor_shape[NUM_BLOCK_DIMS + 1];
    const int32 space_tensor_batch_pos =
        batch_tensor_pos[0] % args.space_tensor_batch;
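    // Walk the block dimensions from innermost to outermost, accumulating the
    // corresponding space-tensor position. Positions that fall in the padding
    // region are zero-filled (space-to-batch) or skipped (batch-to-space); the
    // element copy happens once the outermost dimension (block_dim == 0) has
    // been folded in.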
    for (int block_dim = NUM_BLOCK_DIMS - 1; block_dim >= 0; --block_dim) {
      int32 offset = remaining_block_idx;
      if (block_dim > 0) {
        offset %= args.block_shape[block_dim];
      }
      int32 space_tensor_pos =
          batch_tensor_pos[block_dim + 1] * args.block_shape[block_dim] +
          offset - args.pad_start[block_dim];
      if (space_tensor_pos < 0 ||
          space_tensor_pos >= args.space_tensor_spatial_shape[block_dim]) {
        if (B2S == false) {
          // In the space-to-batch case, write zero padding.
          batch_tensor_ptr[batch_tensor_idx] = static_cast<T>(0);
        }
        break;
      }
      space_tensor_idx += space_tensor_stride * space_tensor_pos;
      space_tensor_stride *= args.space_tensor_spatial_shape[block_dim];
      if (block_dim == 0) {
        space_tensor_idx += space_tensor_stride * space_tensor_batch_pos;
        if (B2S == false) {
          batch_tensor_ptr[batch_tensor_idx] =
              ldg(space_tensor_ptr + space_tensor_idx);
        } else {
          space_tensor_ptr[space_tensor_idx] =
              ldg(batch_tensor_ptr + batch_tensor_idx);
        }
      }
      remaining_block_idx /= args.block_shape[block_dim];
    }
  }
}

namespace functor {
template <typename T, int NUM_BLOCK_DIMS, bool B2S>
struct SpaceToBatchFunctor<GPUDevice, T, NUM_BLOCK_DIMS, B2S> {
  using SpaceT = typename std::conditional<B2S, T, const T>::type;
  using BatchT = typename std::conditional<B2S, const T, T>::type;
  Status operator()(
      const GPUDevice& d,
      typename TTypes<SpaceT, NUM_BLOCK_DIMS + 2>::Tensor space_tensor,
      const int64 block_shape[NUM_BLOCK_DIMS],
      const int64 paddings[NUM_BLOCK_DIMS * 2],
      typename TTypes<BatchT, NUM_BLOCK_DIMS + 2>::Tensor batch_tensor) {
    // Kernel execution fails if the number of elements is zero.
    if (batch_tensor.size() == 0) {
      return Status::OK();
    }
    S2BParameters<NUM_BLOCK_DIMS> args;
    args.space_tensor_batch = space_tensor.dimension(0);
    for (int block_dim = 0; block_dim < NUM_BLOCK_DIMS; ++block_dim) {
      if (block_shape[block_dim] > std::numeric_limits<int32>::max()) {
        return errors::InvalidArgument("block_shape value exceeds 2^31-1");
      }
      args.block_shape[block_dim] = block_shape[block_dim];
      if (space_tensor.dimension(block_dim + 1) >
          std::numeric_limits<int32>::max()) {
        return errors::InvalidArgument("space_tensor dimension exceeds 2^31-1");
      }
      args.space_tensor_spatial_shape[block_dim] =
          space_tensor.dimension(block_dim + 1);
      if (paddings[block_dim * 2] > std::numeric_limits<int32>::max()) {
        return errors::InvalidArgument("paddings/crops value exceeds 2^31-1");
      }
      args.pad_start[block_dim] = paddings[block_dim * 2];
    }
    int64 total_count = 1;
    for (int dim = 0; dim < NUM_BLOCK_DIMS + 2; ++dim) {
      args.batch_tensor_shape[dim] = batch_tensor.dimension(dim);
      total_count *= args.batch_tensor_shape[dim];
    }
    if (total_count > std::numeric_limits<int32>::max()) {
      return errors::InvalidArgument(
          "number of batch_tensor elements exceeds 2^31-1");
    }
    CudaLaunchConfig config =
        GetCudaLaunchConfig(static_cast<int32>(total_count), d);
    S2B<T, NUM_BLOCK_DIMS, B2S>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            config.virtual_thread_count, const_cast<T*>(space_tensor.data()),
            args, const_cast<T*>(batch_tensor.data()));
    return Status::OK();
  }
};
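
// Rough usage sketch (the real callers are the SpaceToBatchND /
// BatchToSpaceND op kernels; the variable names below are illustrative only):
//
//   functor::SpaceToBatchFunctor<GPUDevice, float, 2, /*B2S=*/false> functor;
//   OP_REQUIRES_OK(context,
//                  functor(context->eigen_device<GPUDevice>(),
//                          space_tensor.tensor<float, 4>(), block_shape,
//                          paddings, batch_tensor->tensor<float, 4>()));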

// Instantiate.
#define INSTANTIATE(NUM_BLOCK_DIMS, T)                                      \
  template struct SpaceToBatchFunctor<GPUDevice, T, NUM_BLOCK_DIMS, false>; \
  template struct SpaceToBatchFunctor<GPUDevice, T, NUM_BLOCK_DIMS, true>;  \
  /**/

#define INSTANTIATE_FOR_T(T) \
  TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS(INSTANTIATE, T)
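
// TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS (from spacetobatch_functor.h)
// applies INSTANTIATE for every supported number of block dimensions, and
// TF_CALL_GPU_NUMBER_TYPES below repeats that for each GPU numeric type, so
// both the space-to-batch and batch-to-space specializations are emitted.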

TF_CALL_GPU_NUMBER_TYPES(INSTANTIATE_FOR_T)

#undef INSTANTIATE_FOR_T
#undef INSTANTIATE

}  // end namespace functor
}  // end namespace tensorflow

#endif  // GOOGLE_CUDA