/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The algorithm for dynamic partition has the following steps:
// 1. Let N be the size of partitions. We initialize a new vector indices_in
//    with the values 0, 1, 2, ..., N-1.
// 2. We apply cub::DeviceRadixSort::SortPairs to the key-value pairs given
//    by partitions and indices_in. This results in two new vectors,
//    partitions_out and indices_out, with partitions_out sorted.
// 3. The first dimension of outputs[i] is equal to the number of i-values in
//    partitions_out. We determine it in two steps:
//    - apply cub::DeviceReduce::ReduceByKey to count how many times each value
//      appears in partitions_out,
//    - move the results to partition_count. This handles missing values
//      (corresponding to empty parts).
// 4. Because partition_count is on the GPU, we bring it asynchronously to
//    the CPU. Then we can allocate the output tensors.
// 5. Finally, we use indices_out and the gather functor to collect the output.
//    This works because, for each interval of i-values, indices_out points
//    to the slices which should form output[i].
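//
// As an illustrative example (a sketch, not part of the kernel itself),
// suppose num_partitions = 3 and partitions = [1, 0, 2, 0]. Then:
//   indices_in      = [0, 1, 2, 3]
//   partitions_out  = [0, 0, 1, 2],  indices_out = [1, 3, 0, 2]
//   ReduceByKey     -> unique keys [0, 1, 2], counts [2, 1, 1]
//   partition_count = [2, 1, 1]
// so outputs[0] gathers slices 1 and 3 of data, outputs[1] gathers slice 0,
// and outputs[2] gathers slice 2.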

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "external/cub_archive/cub/device/device_radix_sort.cuh"
#include "external/cub_archive/cub/device/device_reduce.cuh"
#include "external/cub_archive/cub/iterator/constant_input_iterator.cuh"
#include "external/cub_archive/cub/thread/thread_operators.cuh"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/gather_functor_gpu.cu.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/transform_output_iterator.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace {

template <typename T>
__global__ void RangeInitKernel(const T start, const T delta, const int32 size,
                                T* out) {
  CUDA_1D_KERNEL_LOOP(i, size) { out[i] = start + i * delta; }
}

__global__ void MoveValuesKernel(const int32* keys, const int32* values,
                                 const int32* size, int32 out_size,
                                 int32* out) {
  int32 N = min(ldg(size), out_size);
  CUDA_1D_KERNEL_LOOP(i, N) {
    int32 key = ldg(keys + i);
    int32 value = ldg(values + i);
    if (FastBoundsCheck(key, out_size)) out[key] = value;
  }
}

// Initialize out with range start, start + delta, start + 2 * delta, ...
// This is needed because tf.range has no GPU implementation.
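// For example (illustrative only), RangeInit(d, 0, 1, 4, out) fills out with
// [0, 1, 2, 3].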
template <typename T>
void RangeInit(const GPUDevice& d, const T start, const T delta,
               const int32 size, typename TTypes<T>::Flat out) {
  CudaLaunchConfig config = GetCudaLaunchConfig(size, d);
  RangeInitKernel<T>
      <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
          start, delta, size, out.data());
}

// Given *num_runs pairs (key, value), this function moves the value
// corresponding to key i to position i in the array out.
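// For example (illustrative only), with *num_runs = 2, keys = [0, 2],
// values = [5, 7] and out_size = 3, out[0] becomes 5 and out[2] becomes 7,
// while out[1] is left untouched (the caller zero-initializes out first).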
void MoveValues(const GPUDevice& d, int32* keys, int32* values, int32* num_runs,
                int32 out_size, int32* out) {
  // Because num_runs is located on the GPU, we cannot access it directly.
  // So we launch the kernel with size = out_size.
  // This is valid for correct inputs, because then out_size >= *num_runs.
  // For wrong inputs, we may have out_size < *num_runs. In this case we will
  // only handle the first out_size values.
  CudaLaunchConfig config = GetCudaLaunchConfig(out_size, d);
  MoveValuesKernel<<<config.block_count, config.thread_per_block, 0,
                     d.stream()>>>(keys, values, num_runs, out_size, out);
}

template <typename T>
void CallGatherKernel(const GPUDevice& d, const T* params, const int32* indices,
                      T* out, int64 gather_dim_size, int64 indices_size,
                      int64 slice_size, int64 out_size) {
  CudaLaunchConfig config = GetCudaLaunchConfig(out_size, d);
  GatherOpKernel<T, int32, true>
      <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
          params, indices, out, gather_dim_size, indices_size, slice_size,
          out_size);
}

struct IdentityOp {
  __device__ int32 __forceinline__ operator()(const int32& a) const {
    return a;
  }
};

// Define an output iterator that only allows assignment to
// positions between [base, base + limit).
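// For example (an illustrative sketch, not used elsewhere in this file), with
//   int32 buf[2];
//   BoundedOutputIterator it(buf, IdentityOp(), 2);
// the writes it[0] = 7 and it[1] = 8 are stored in buf, while it[2] = 9 is
// silently dropped instead of writing out of bounds.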
class BoundedOutputIterator
    : public TransformOutputIterator<int32, int32, IdentityOp> {
 private:
  int32 limit;
  int32* base;

  struct BoundedReference : Reference {
    int32 limit;
    int32* base;
    // Constructor
    __host__ __device__ __forceinline__
    BoundedReference(int32* ptr, int32* base, IdentityOp op, int32 limit)
        : Reference(ptr, op), limit(limit), base(base) {}

    // Assignment
    __host__ __device__ __forceinline__ int32 operator=(int32 val) {
      if (ptr - base < limit && ptr - base >= 0) *ptr = val;
      return val;
    }
  };

 public:
  typedef BoundedOutputIterator self_type;
  typedef BoundedReference reference;

  __host__ __device__ __forceinline__ BoundedOutputIterator(int32* ptr,
                                                            IdentityOp op,
                                                            int32 size)
      : TransformOutputIterator(ptr, op), limit(size), base(ptr) {}

  __host__ __device__ __forceinline__
  BoundedOutputIterator(int32* ptr, int32* base, IdentityOp op, int32 size)
      : TransformOutputIterator(ptr, op), limit(size), base(base) {}

  // Indirection
  __host__ __device__ __forceinline__ reference operator*() const {
    return BoundedReference(ptr, base, conversion_op, limit);
  }

  // Array subscript
  __host__ __device__ __forceinline__ reference operator[](int32 n) const {
    return BoundedReference(ptr + n, base, conversion_op, limit);
  }

  // Addition
  __host__ __device__ __forceinline__ self_type operator+(int32 n) const {
    self_type retval(ptr + n, base, conversion_op, limit);
    return retval;
  }

  // Subtraction
  __host__ __device__ __forceinline__ self_type operator-(int32 n) const {
    self_type retval(ptr - n, base, conversion_op, limit);
    return retval;
  }
};

}  // namespace

// The current implementation has a memory cost on GPU of
// I + P + max(3N + R + P, O + N), where:
// I - the size of the input
// N - the size of the partitions tensor
// R - the temporary storage used by cub::RadixSort, about 2N
// P - the number of partitions
// O - the size of the output
// So roughly the cost is I + P + max(5N + P, O + N).
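// As a rough illustrative example (assumed numbers, counting elements rather
// than bytes): for data with 10^6 rows of 4 values (I = O = 4 * 10^6),
// N = 10^6 and P = 10, the cost is about
// 4 * 10^6 + 10 + max(5 * 10^6 + 10, 5 * 10^6) ~= 9 * 10^6 elements.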
template <typename T>
class DynamicPartitionOpGPU : public AsyncOpKernel {
 public:
  explicit DynamicPartitionOpGPU(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("num_partitions", &num_partitions_));
    OP_REQUIRES(c, num_partitions_ >= 1,
                errors::InvalidArgument("num_partitions must be at least 1"));
  }

  void AllocateTempSpace(OpKernelContext* c, int32 N, Tensor* indices_in,
                         Tensor* partitions_out, Tensor* indices_out,
                         DoneCallback done) {
    int32 M = std::max(N, num_partitions_);
    // indices_in will be made slightly larger to accommodate
    // later computations.
    OP_REQUIRES_OK_ASYNC(
        c, c->allocate_temp(DT_INT32, TensorShape({M}), indices_in), done);
    OP_REQUIRES_OK_ASYNC(
        c, c->allocate_temp(DT_INT32, TensorShape({N}), partitions_out), done);
    OP_REQUIRES_OK_ASYNC(
        c, c->allocate_temp(DT_INT32, TensorShape({N}), indices_out), done);
  }

  void AllocateOutputs(OpKernelContext* c, const Tensor* data,
                       const Tensor* partitions, const Tensor* partition_count,
                       OpOutputList* Tout, DoneCallback done) {
    auto e_part_count = partition_count->flat<int32>();
    // Allocate output tensors of the right size
    OP_REQUIRES_OK_ASYNC(c, c->output_list("outputs", Tout), done);
    for (int p = 0; p < num_partitions_; p++) {
      TensorShape shape;
      shape.AddDim(e_part_count(p));
      for (int i = partitions->dims(); i < data->dims(); i++) {
        shape.AddDim(data->dim_size(i));
      }
      Tensor* out;
      OP_REQUIRES_OK_ASYNC(c, Tout->allocate(p, shape, &out), done);
    }
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) {
    const Tensor& data = c->input(0);
    const Tensor& partitions = c->input(1);

    OP_REQUIRES_ASYNC(
        c, TensorShapeUtils::StartsWith(data.shape(), partitions.shape()),
        errors::InvalidArgument(
            "data.shape must start with partitions.shape, ",
            "got data.shape = ", data.shape().DebugString(),
            ", partitions.shape = ", partitions.shape().DebugString()),
        done);

    Tensor partition_count;

    // We must handle the case of empty partitions separately,
    // because kernels don't work with 0-sized tensors.
    if (partitions.NumElements() == 0) {
      AllocatorAttributes alloc_attr;
      alloc_attr.set_on_host(true);
      OP_REQUIRES_OK_ASYNC(
          c,
          c->allocate_temp(DT_INT32, TensorShape({num_partitions_}),
                           &partition_count, alloc_attr),
          done);
      auto e_part_count = partition_count.flat<int32>();
      for (int i = 0; i < num_partitions_; i++) e_part_count(i) = 0;
      OpOutputList outputs;
      this->AllocateOutputs(c, &data, &partitions, &partition_count, &outputs,
                            done);
      if (c->status().ok()) done();
      return;
    }

    // Prepare for counting.
    OP_REQUIRES_OK_ASYNC(
        c,
        c->allocate_temp(DT_INT32, TensorShape({num_partitions_}),
                         &partition_count),
        done);
    Tensor indices_out;
    // Count how many times each partition index occurs.
    // Also sort the indices by their partition value and output them in
    // indices_out, in preparation for the next step.
    this->CountAndSortParts(c, &partitions, &partition_count, &indices_out,
                            done);
    if (!c->status().ok()) return;

    // In order to allocate the output tensors we have to move partition_count
    // to the CPU.
    auto* stream = c->op_device_context()->stream();
    OP_REQUIRES_ASYNC(c, stream, errors::Internal("No GPU stream available."),
                      done);
    Tensor cpu_tensor;
    AllocatorAttributes alloc_attr;
    alloc_attr.set_on_host(true);
    alloc_attr.set_gpu_compatible(true);
    OP_REQUIRES_OK_ASYNC(
        c,
        c->allocate_temp(partition_count.dtype(), partition_count.shape(),
                         &cpu_tensor, alloc_attr),
        done);
    perftools::gputools::DeviceMemoryBase wrapped(
        partition_count.flat<int32>().data(), num_partitions_ * sizeof(int32));
    const bool status =
        stream
            ->ThenMemcpy(cpu_tensor.flat<int32>().data(), wrapped,
                         num_partitions_ * sizeof(int32))
            .ok();
    OP_REQUIRES_ASYNC(
        c, status,
        errors::Internal("Failed to launch copy from device to host."), done);

    // Keep a reference to partition_count so that the buffer
    // is not deallocated at the end of the function, before
    // memcpy is completed.
    TensorReference partition_ref(partition_count);
    auto wrapped_callback = [this, c, &data, &partitions, indices_out,
                             partition_ref, cpu_tensor, done]() {
      OpOutputList outputs;
      this->AllocateOutputs(c, &data, &partitions, &cpu_tensor, &outputs, done);
      if (!c->status().ok()) {
        partition_ref.Unref();
        return;
      }
      int32 N = partitions.NumElements();
      int64 slice_size = data.NumElements() / N;
      this->GatherSlices(c, &data, &indices_out, N, slice_size, outputs);
      partition_ref.Unref();
      done();
    };

    c->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
        stream, wrapped_callback);
  }

 protected:
  void RadixSort(OpKernelContext* c, const Tensor* partitions,
                 Tensor* indices_in, Tensor* partitions_out,
                 Tensor* indices_out, DoneCallback done) {
    int32 N = partitions->NumElements();
    const GPUDevice& device = c->eigen_device<GPUDevice>();
    const cudaStream_t& cu_stream = GetCudaStream(c);

    // Initialize the indices_in tensor using the Range GPU kernel.
    RangeInit(device, 0, 1, N, indices_in->flat<int32>());
    // Obtain the pointers to inner buffers.
    const int32* partitions_ptr = partitions->flat<int32>().data();
    int32* partitions_out_ptr = partitions_out->flat<int32>().data();
    int32* indices_in_ptr = indices_in->flat<int32>().data();
    int32* indices_out_ptr = indices_out->flat<int32>().data();
    // Determine temporary device storage requirements.
    Tensor cub_temp_storage;
    size_t temp_storage_bytes = 0;
    cub::DeviceRadixSort::SortPairs(
        NULL, temp_storage_bytes, partitions_ptr, partitions_out_ptr,
        indices_in_ptr, indices_out_ptr, N, 0, sizeof(int32) * 8, cu_stream);
    // Allocate temporary storage.
    OP_REQUIRES_OK_ASYNC(
        c,
        c->allocate_temp(DT_INT8,
                         TensorShape({static_cast<int64>(temp_storage_bytes)}),
                         &cub_temp_storage),
        done);
    // Radix-sort the partition information.
    cub::DeviceRadixSort::SortPairs(
        cub_temp_storage.flat<int8>().data(), temp_storage_bytes,
        partitions_ptr, partitions_out_ptr, indices_in_ptr, indices_out_ptr, N,
        0, sizeof(int32) * 8, cu_stream);
  }  // At this point cub_temp_storage will be marked for deallocation.

  void CountAndSortParts(OpKernelContext* c, const Tensor* partitions,
                         Tensor* partition_count, Tensor* indices_out,
                         DoneCallback done) {
    const GPUDevice& device = c->eigen_device<GPUDevice>();
    const cudaStream_t& cu_stream = GetCudaStream(c);
    int32 N = partitions->NumElements();
    Tensor indices_in;
    Tensor partitions_out;
    Tensor aggregates_out;

    // Allocate memory for Radix-Sort.
    this->AllocateTempSpace(c, N, &indices_in, &partitions_out, indices_out,
                            done);
    if (!c->status().ok()) return;
    this->RadixSort(c, partitions, &indices_in, &partitions_out, indices_out,
                    done);
    if (!c->status().ok()) return;
    // We will now apply a reduce operation to count how many times
    // each index appears in partitions.

    // Zero-out the partition_count tensor.
    functor::SetZeroFunctor<GPUDevice, int32> zero_functor;
    zero_functor(device, partition_count->flat<int32>());
    // Allocate memory for aggregates_out.
    OP_REQUIRES_OK_ASYNC(
        c,
        c->allocate_temp(DT_INT32, TensorShape({num_partitions_}),
                         &aggregates_out),
        done);
    // Obtain the pointers to inner buffers.
    int32* keys_in_ptr = partitions_out.flat<int32>().data();
    // Here we reuse the indices_in tensor for the unique keys output.
    int32* unique_out_ptr = indices_in.flat<int32>().data();
    int32* aggregates_out_ptr = aggregates_out.flat<int32>().data();
    // We wrap the pointers in bounded output iterators to guard against
    // wrong inputs (more than num_partitions distinct indices).
    IdentityOp id_op;
    BoundedOutputIterator unique_out_it(unique_out_ptr, id_op, num_partitions_);
    BoundedOutputIterator aggregates_out_it(aggregates_out_ptr, id_op,
                                            num_partitions_);

    cub::ConstantInputIterator<int32> values_in(1);
    cub::Sum reduction_op;

    // Allocate space on GPU for the number of runs. This is required by CUB.
    Tensor num_runs;
    OP_REQUIRES_OK_ASYNC(
        c, c->allocate_temp(DT_INT32, TensorShape({1}), &num_runs), done);
    int32* num_runs_ptr = num_runs.flat<int32>().data();

    // Determine temporary device storage requirements
    Tensor cub_temp_storage;
    size_t temp_storage_bytes = 0;
    cub::DeviceReduce::ReduceByKey(NULL, temp_storage_bytes, keys_in_ptr,
                                   unique_out_it, values_in, aggregates_out_it,
                                   num_runs_ptr, reduction_op, N, cu_stream);
    // Allocate temporary storage.
    OP_REQUIRES_OK_ASYNC(
        c,
        c->allocate_temp(DT_INT8,
                         TensorShape({static_cast<int64>(temp_storage_bytes)}),
                         &cub_temp_storage),
        done);
    // Run reduce-by-key. The effect is that we count how many times
    // each index appears in partitions. The distinct indices are stored
    // in unique_out, while the count is stored in aggregates_out.
    // The total number of distinct indices is stored in num_runs.
    cub::DeviceReduce::ReduceByKey(cub_temp_storage.flat<int8>().data(),
                                   temp_storage_bytes, keys_in_ptr,
                                   unique_out_it, values_in, aggregates_out_it,
                                   num_runs_ptr, reduction_op, N, cu_stream);
    // We are not done yet. unique_out only contains the indices that appeared
    // at least once in partitions. We move each value from aggregates_out
    // to the corresponding position in partition_count. This will handle
    // possibly empty parts.
    MoveValues(device, unique_out_ptr, aggregates_out_ptr, num_runs_ptr,
               num_partitions_, partition_count->flat<int32>().data());
  }  // At this point indices_in, partitions_out, aggregates_out
     // and cub_temp_storage will be marked for deallocation.

  void GatherSlices(OpKernelContext* c, const Tensor* data,
                    const Tensor* indices, int32 N, int64 slice_size,
                    OpOutputList& outs) {
    const GPUDevice& device = c->eigen_device<GPUDevice>();
    const int32* ind_base = indices->flat<int32>().data();
    const T* data_base = data->flat<T>().data();

    for (int p = 0; p < num_partitions_; p++) {
      int32 indices_size = outs[p]->dim_size(0);
      int64 out_size = outs[p]->NumElements();
      T* out_base = outs[p]->flat<T>().data();
      if (out_size > 0)
        CallGatherKernel<T>(device, data_base, ind_base, out_base, N,
                            indices_size, slice_size, out_size);
      ind_base += indices_size;
    }
  }

  int32 num_partitions_;
};

#define REGISTER_DYNAMIC_PARTITION_GPU(T)                                 \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("DynamicPartition").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      DynamicPartitionOpGPU<T>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_DYNAMIC_PARTITION_GPU);
TF_CALL_complex64(REGISTER_DYNAMIC_PARTITION_GPU);
TF_CALL_complex128(REGISTER_DYNAMIC_PARTITION_GPU);
#undef REGISTER_DYNAMIC_PARTITION_GPU

}  // namespace tensorflow

#endif  // GOOGLE_CUDA