/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include <cmath>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "external/cub_archive/cub/device/device_segmented_radix_sort.cuh"
#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/topk_op.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

// Required for sorting Eigen::half: CUB's radix sort needs a NumericTraits
// specialization to know the bit representation of the key type, so register
// Eigen::half as a 16-bit floating-point key.
namespace cub {
template <>
struct NumericTraits<Eigen::half>
    : BaseTraits<FLOATING_POINT, true, false, unsigned short int, Eigen::half> {
};
}  // namespace cub

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace impl {

enum class HeapType { kMinHeap, kMaxHeap };
enum class PreferIndices { kLower, kHigher };

template <typename T>
struct Entry {
  int index;
  T value;

  // Test-only.
  static bool greater(const Entry<T>& a, const Entry<T>& b) {
    if (a.value == b.value) {
      return a.index < b.index;
    }
    return a.value > b.value;
  }
};

template <typename T>
struct LinearData {
  typedef impl::Entry<T> Entry;

  __device__ Entry& operator[](std::size_t index) const { return data[index]; }

  __device__ int get_index(int i) const { return data[i].index; }
  __device__ T get_value(int i) const { return data[i].value; }

  Entry* const data;
};

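// IndirectLinearData adds one level of indirection: `data[i].index` is a slot
// into `backing_data`, and get_index() follows it to recover the original
// input index (see mergeShards below, which stores {shard_slot, value} pairs
// in the merge heap).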
template <typename T>
struct IndirectLinearData {
  typedef impl::Entry<T> Entry;

  __device__ Entry& operator[](std::size_t index) const { return data[index]; }

  __device__ int get_index(int i) const {
    return backing_data[data[i].index].index;
  }
  __device__ T get_value(int i) const { return data[i].value; }

  Entry* const data;
  Entry* const backing_data;
};

#if GOOGLE_CUDA
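// StridedData interleaves the per-thread heaps in shared memory: thread t's
// i-th heap entry lives at data[i * blockDim.x + t]. For example, with
// blockDim.x == 4, thread 1's entries occupy slots 1, 5, 9, ...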
template <typename T>
struct StridedData {
  typedef impl::Entry<T> Entry;

  __device__ Entry& operator[](std::size_t index) const {
    return data[index * blockDim.x + threadIdx.x];
  }

  __device__ int get_index(int i) const { return (*this)[i].index; }
  __device__ T get_value(int i) const { return (*this)[i].value; }

  Entry* const data;
};
#endif

// A heap of Entry<T> that can either work as a min-heap or as a max-heap.
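// Ties are broken by index according to `preferIndices`: for the min-heap used
// in heapTopK (PreferIndices::kHigher), an equal value with a higher index is
// considered "above" and is therefore evicted first, which is how lower
// indices end up preferred in the final top-k.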
template <HeapType heapType, PreferIndices preferIndices,
          template <typename> class Data, typename T>
struct IndexedHeap {
  typedef typename Data<T>::Entry Entry;
  const Data<T> data;

  __device__ bool is_above(int left, int right) {
    T left_value = data.get_value(left);
    T right_value = data.get_value(right);
    if (left_value == right_value) {
      if (preferIndices == PreferIndices::kLower) {
        return data.get_index(left) < data.get_index(right);
      } else {
        return data.get_index(left) > data.get_index(right);
      }
    }
    if (heapType == HeapType::kMinHeap) {
      return left_value < right_value;
    } else {
      return left_value > right_value;
    }
  }

  __device__ void assign(int i, const Entry& entry) { data[i] = entry; }

  __device__ void push_up(int i) {
    int child = i;
    int parent;
    for (; child > 0; child = parent) {
      parent = (child - 1) / 2;
      if (!is_above(child, parent)) {
        // Heap property satisfied.
        break;
      }
      swap(child, parent);
    }
  }

  __device__ void swap(int a, int b) {
    auto tmp = data[b];
    data[b] = data[a];
    data[a] = tmp;
  }

  __device__ void push_root_down(int k) { push_down(0, k); }

  // MAX-HEAPIFY in Cormen
  __device__ void push_down(int node, int k) {
    while (true) {
      const int left = 2 * node + 1;
      const int right = left + 1;
      int smallest = node;
      if (left < k && is_above(left, smallest)) {
        smallest = left;
      }
      if (right < k && is_above(right, smallest)) {
        smallest = right;
      }
      if (smallest == node) {
        break;
      }
      swap(smallest, node);
      node = smallest;
    }
  }

  // BUILD-MAX-HEAP in Cormen
  __device__ void build(int k) {
    for (int node = (k - 1) / 2; node >= 0; node--) {
      push_down(node, k);
    }
  }

  // HEAP-EXTRACT-MAX in Cormen
  __device__ void remove_root(int k) {
    data[0] = data[k - 1];
    push_root_down(k - 1);
  }

  // in-place HEAPSORT in Cormen
  // This method destroys the heap property.
  __device__ void sort(int k) {
    for (int slot = k - 1; slot > 0; slot--) {
      // This is like remove_root but we insert the element at the end.
      swap(slot, 0);
      // The heap is now one element smaller.
      push_root_down(/*k=*/slot);
    }
  }

  __device__ void replace_root(const Entry& entry, int k) {
    data[0] = entry;
    push_root_down(k);
  }

  __device__ const Entry& root() { return data[0]; }
};

template <HeapType heapType, PreferIndices preferIndices,
          template <typename> class Data, typename T>
__device__ IndexedHeap<heapType, preferIndices, Data, T> make_indexed_heap(
    typename Data<T>::Entry* data) {
  return IndexedHeap<heapType, preferIndices, Data, T>{Data<T>{data}};
}

// heapTopK walks over [input, input+length) with `step_size` stride starting at
// `start_index`.
// It builds a top-`k` heap that is stored in `heap_entries`, using the `Data`
// accessor template to access elements in `heap_entries`. If sorted=true, the
// elements are sorted in descending order at the end.
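// Illustrative example (single thread, Data = LinearData):
//   input = {3, 1, 4, 1, 5, 9, 2, 6}, length = 8, k = 3, sorted = true
//   => heap_entries = {{5, 9}, {7, 6}, {4, 5}}  (index/value pairs, descending).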
template <typename T, template <typename> class Data = LinearData>
__device__ void heapTopK(const T* __restrict__ input, int length, int k,
                         Entry<T>* __restrict__ heap_entries,
                         bool sorted = false, int start_index = 0,
                         int step_size = 1) {
  assert(k <= length);

  auto heap =
      make_indexed_heap<HeapType::kMinHeap, PreferIndices::kHigher, Data, T>(
          heap_entries);

  int heap_end_index = start_index + k * step_size;
  if (heap_end_index > length) {
    heap_end_index = length;
  }
  // Initialize the min-heap.
  for (int index = start_index, slot = 0; index < heap_end_index;
       index += step_size, slot++) {
    heap.assign(slot, {index, input[index]});
  }

  heap.build(k);

  // Now iterate over the remaining items.
  // If an item is smaller than the min element, it is not amongst the top k.
  // Otherwise, replace the min element with it and push upwards.
  for (int index = heap_end_index; index < length; index += step_size) {
    // We prefer elements with lower indices: later elements automatically have
    // higher indices, so on a tie with the current min they can be discarded.
    if (input[index] > heap.root().value) {
      // This element should replace the min.
      heap.replace_root({index, input[index]}, k);
    }
  }

  // Sort if wanted.
  if (sorted) {
    heap.sort(k);
  }
}

// mergeShards performs a top-k merge on `num_shards` sorted streams that are
// stored in `entries` in a strided way:
// |s_1 1st|s_2 1st|...s_{num_shards} 1st|s_1 2nd|s_2 2nd|...
// The overall top k elements are written to `top_k_values` and their indices
// to `top_k_indices`.
// `top_k_heap` is used as temporary storage for the merge heap.
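// Each per-shard stream is assumed to be sorted in descending order, which is
// what heapTopK produces when called with sorted = true. For example, with
// num_shards = 2 and k = 2, `entries` is laid out as
// |s_1 best|s_2 best|s_1 2nd best|s_2 2nd best|.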
template <typename T>
__device__ void mergeShards(int num_shards, int k,
                            Entry<T>* __restrict__ entries,
                            Entry<T>* __restrict__ top_k_heap, T* top_k_values,
                            int* top_k_indices) {
  // If k < num_shards, we can use a min-heap with k elements to get the top k
  // of the sorted blocks.
  // If k > num_shards, we can initialize a min-heap with the top element from
  // each sorted block.
  const int heap_size = k < num_shards ? k : num_shards;

  // Min-heap part.
  {
    auto min_heap = IndexedHeap<HeapType::kMinHeap, PreferIndices::kHigher,
                                IndirectLinearData, T>{
        IndirectLinearData<T>{top_k_heap, entries}};
    // Initialize the heap as a min-heap.
    for (int slot = 0; slot < heap_size; slot++) {
      min_heap.assign(slot, {slot, entries[slot].value});
    }
    min_heap.build(heap_size);

    // Now perform top k with the remaining shards (if num_shards > heap_size).
    for (int shard = heap_size; shard < num_shards; shard++) {
      const auto entry = entries[shard];
      const auto root = min_heap.root();
      if (entry.value < root.value) {
        continue;
      }
      if (entry.value == root.value &&
          entry.index > entries[root.index].index) {
        continue;
      }
      // This element should replace the min.
      min_heap.replace_root({shard, entry.value}, heap_size);
    }
  }

  // Max-heap part.
  {
    // Turn the min-heap into a max-heap in-place.
    auto max_heap = IndexedHeap<HeapType::kMaxHeap, PreferIndices::kLower,
                                IndirectLinearData, T>{
        IndirectLinearData<T>{top_k_heap, entries}};
    // Heapify into a max heap.
    max_heap.build(heap_size);

    // Now extract the maximum (the heap root) k-1 times; the last element
    // (rank k-1) is handled separately below.
    const int last_k = k - 1;
    for (int rank = 0; rank < last_k; rank++) {
      const Entry<T>& max_element = max_heap.root();
      top_k_values[rank] = max_element.value;
      int shard_index = max_element.index;
      top_k_indices[rank] = entries[shard_index].index;
      int next_shard_index = shard_index + num_shards;
      // For rank < k-1, each top k heap still contains at least 1 element,
      // so we can draw a replacement.
      max_heap.replace_root({next_shard_index, entries[next_shard_index].value},
                            heap_size);
    }

    // rank == last_k.
    const Entry<T>& max_element = max_heap.root();
    top_k_values[last_k] = max_element.value;
    int shard_index = max_element.index;
    top_k_indices[last_k] = entries[shard_index].index;
  }
}

extern __shared__ char shared_memory[];

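// Shared-memory layout used by TopKKernel: the first blockDim.x * k entries
// hold the per-thread heaps (strided, see StridedData), and the final k
// entries serve as the scratch heap for mergeShards. The launcher below
// therefore requests (num_shards + 1) * k * sizeof(Entry<T>) bytes.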
template <typename T>
__global__ void TopKKernel(const T* input, int length, int k, bool sorted,
                           T* output, int* indices) {
  const int batch_index = blockIdx.x;
  const T* batch_input = input + batch_index * length;

  const int thread_index = threadIdx.x;
  const int thread_count = blockDim.x;

  Entry<T>* shared_entries = (Entry<T>*)shared_memory;

  heapTopK<T, StridedData>(batch_input, length, k, shared_entries, true,
                           thread_index, thread_count);

  __syncthreads();
  if (thread_index == 0) {
    const int offset = batch_index * k;
    auto batch_output = output + offset;
    auto batch_indices = indices + offset;
    Entry<T>* top_k_heap = shared_entries + thread_count * k;

    // TODO(blackhc): Erich says: Performance can likely be improved
    // significantly by having the merge be done by multiple threads rather than
    // just one.  ModernGPU has some nice primitives that could help with this.
    mergeShards(thread_count, k, shared_entries, top_k_heap, batch_output,
                batch_indices);
  }
}

template <typename T>
cudaError LaunchTopKKernel(const cudaStream_t& stream, int num_shards,
                           const T* input, int batch_size, int length, int k,
                           bool sorted, T* output, int* indices) {
  // This code assumes that k is small enough that the computation
  // fits inside shared memory (hard coded to 48KB).  In practice this
  // means k <= 3072 for T=float/int32 and k <= 2048 for T=double/int64.
  // The constraint is:
  //   k <= shared_memory_size / (2 * (sizeof(int) + sizeof(T))).

  // Use as many shards as possible.
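  // Illustrative example: T = float (sizeof(Entry<float>) == 8), k = 128,
  // length = 100000 => heap_size = 1024 bytes, num_shards = 49152 / 1024 - 1
  // = 47; shard_size = 100000 / 47 = 2127 (integer division), which is
  // >= 2 * k, so 47 shards are used and shared_memory_size below comes to
  // exactly 48 * 1024 bytes.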
  if (num_shards <= 0) {
    constexpr auto shared_memory_size = 48 << 10;  // 48 KB
    const auto heap_size = k * sizeof(Entry<T>);
    // shared_memory_size = (num_shards + 1) * heap_size <=>
    num_shards = shared_memory_size / heap_size - 1;
    if (num_shards <= 0) {
      num_shards = 1;
    }
    auto shard_size = length / num_shards;
    auto min_shard_size = 2 * k;
    if (shard_size < min_shard_size) {
      num_shards = length / min_shard_size;
    }
    if (num_shards <= 0) {
      num_shards = 1;
    } else if (num_shards > 1024) {
      num_shards = 1024;
    }
  }
  // We are limited by the amount of shared memory we have per block.
  auto shared_memory_size = (num_shards + 1) * k * sizeof(Entry<T>);

  TopKKernel<<<batch_size, num_shards, shared_memory_size, stream>>>(
      input, length, k, sorted, output, indices);
  return cudaGetLastError();
}

struct SegmentOffsetCreator {
  EIGEN_DEVICE_FUNC
  SegmentOffsetCreator(int num_cols) : num_cols_(num_cols) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const {
    return idx * num_cols_;
  }

  int num_cols_;
};

struct ColumnIndexCreator {
  ColumnIndexCreator(int num_cols) : num_cols_(num_cols) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(
      const Eigen::array<int, 1>& ix) const {
    return ix[0] % num_cols_;
  }

  int num_cols_;
};
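
// Illustrative example: for num_cols = 4, SegmentOffsetCreator maps
// 0, 1, 2, ... to the row offsets 0, 4, 8, ... (the segment boundaries handed
// to the segmented radix sort), and ColumnIndexCreator maps a flat element
// index to its column index 0, 1, 2, 3, 0, 1, 2, 3, ..., which becomes the
// value payload that is sorted along with the keys.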

template <typename T>
Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows,
                        int num_cols, int k,
                        typename TTypes<T, 2>::Tensor values,
                        TTypes<int, 2>::Tensor indices) {
  const GPUDevice& d = ctx->eigen_device<GPUDevice>();
  const cudaStream_t& cu_stream = GetCudaStream(ctx);
  size_t temp_storage_bytes = -1;

  // TODO(ebrevdo): Once cub supports iterators for ValueT replace that tensor
  // with an iterator that directly returns the correct value.
  Tensor input_indices;
  TF_RETURN_IF_ERROR(ctx->allocate_temp(
      DT_INT32, TensorShape({num_rows, num_cols}), &input_indices));
  auto input_indices_t = To32Bit(input_indices.flat<int32>());
  input_indices_t.device(d) =
      input_indices_t.generate(ColumnIndexCreator(num_cols));

  cub::CountingInputIterator<int> counting_iter(0);
  cub::TransformInputIterator<int, SegmentOffsetCreator,
                              cub::CountingInputIterator<int>>
      segment_offsets_t(counting_iter, SegmentOffsetCreator(num_cols));

  Tensor temp_values;
  Tensor temp_indices;
  T* sorted_values_ptr;
  int* sorted_indices_ptr;
  if (k == num_cols) {
    // Doing a full sort, no intermediate values needed.
    sorted_values_ptr = values.data();
    sorted_indices_ptr = indices.data();
  } else {
    // Need to create intermediate values for sorting.
    TF_RETURN_IF_ERROR(ctx->allocate_temp(
        DT_INT32, TensorShape({num_rows, num_cols}), &temp_indices));
    TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value,
                                          TensorShape({num_rows, num_cols}),
                                          &temp_values));
    sorted_indices_ptr = temp_indices.flat<int32>().data();
    sorted_values_ptr = temp_values.flat<T>().data();
  }

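  // Standard CUB two-pass pattern: the first call, with d_temp_storage ==
  // nullptr, only computes temp_storage_bytes; the second call below, after
  // allocating temp_storage, performs the actual segmented sort.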
  auto err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
      /* d_temp_storage */ nullptr,
      /* temp_storage_bytes */ temp_storage_bytes,
      /* d_keys_in */ input,
      /* d_keys_out */ sorted_values_ptr,
      /* d_values_in */ input_indices_t.data(),
      /* d_values_out */ sorted_indices_ptr,
      /* num_items */ num_cols * num_rows,
      /* num_segments */ num_rows,
      /* d_begin_offsets */ segment_offsets_t,
      /* d_end_offsets */ segment_offsets_t + 1,
      /* begin_bit */ 0,
      /* end_bit */ sizeof(T) * 8,
      /* stream */ cu_stream);
  if (err != cudaSuccess) {
    return errors::Internal(
        "TopKOp: Could not launch "
        "cub::DeviceSegmentedRadixSort::SortPairsDescending to calculate "
        "temp_storage_bytes, status: ",
        cudaGetErrorString(err));
  }
  Tensor temp_storage;
  TF_RETURN_IF_ERROR(ctx->allocate_temp(
      DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
      &temp_storage));
  err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
      /* d_temp_storage */ temp_storage.flat<int8>().data(),
      /* temp_storage_bytes */ temp_storage_bytes,
      /* d_keys_in */ input,
      /* d_keys_out */ sorted_values_ptr,
      /* d_values_in */ input_indices_t.data(),
      /* d_values_out */ sorted_indices_ptr,
      /* num_items */ num_cols * num_rows,
      /* num_segments */ num_rows,
      /* d_begin_offsets */ segment_offsets_t,
      /* d_end_offsets */ segment_offsets_t + 1,
      /* begin_bit */ 0,
      /* end_bit */ sizeof(T) * 8,
      /* stream */ cu_stream);
  if (err != cudaSuccess) {
    return errors::Internal(
        "TopKOp: Could not launch "
        "cub::DeviceSegmentedRadixSort::SortPairsDescending to sort input, "
        "temp_storage_bytes: ",
        temp_storage_bytes, ", status: ", cudaGetErrorString(err));
  }
  if (k < num_cols) {
    // Need to copy subsets of sorted_indices and sorted_outputs to
    // indices and outputs.
    const Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{0, 0};
    const Eigen::DSizes<Eigen::DenseIndex, 2> slice_sizes{num_rows, k};
    To32Bit(indices).device(d) =
        To32Bit(temp_indices.matrix<int32>()).slice(slice_indices, slice_sizes);
    To32Bit(values).device(d) =
        To32Bit(temp_values.matrix<T>()).slice(slice_indices, slice_sizes);
  }
  return Status::OK();
}

}  // end namespace impl

namespace functor {

template <typename T>
struct TopKFunctor<GPUDevice, T> {
  static EIGEN_ALWAYS_INLINE Status
  Compute(OpKernelContext* context, bool sorted, int k,
          const typename TTypes<T, 2>::ConstTensor& input, const int64 num_rows,
          const int64 num_cols, typename TTypes<T, 2>::Tensor values,
          typename TTypes<int, 2>::Tensor indices) {
    // For small k, use the heap implementation.  For larger k, use
    // the in-place cub sort.  For k == num_cols, always use the
    // in-place cub sort.  The thresholds for n and k were determined
    // empirically.
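    // Illustrative example: with num_cols = 10000, k = 16 takes the heap
    // kernel below, while k = 256 (or k == num_cols) takes the cub sort path.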
    if (num_cols <= 1000 || k == num_cols || k >= 100) {
      return impl::LaunchSortKernel(context, input.data(), num_rows, num_cols,
                                    k, values, indices);
    } else {
      const cudaStream_t& cu_stream = GetCudaStream(context);
      auto err = impl::LaunchTopKKernel(cu_stream, /* num_shards */ 0,
                                        input.data(), num_rows, num_cols, k,
                                        sorted, values.data(), indices.data());
      if (err != cudaSuccess) {
        return errors::Internal(
            "Could not launch TopKKernel: ", cudaGetErrorString(err), ".");
      } else {
        return Status::OK();
      }
    }
  }
};

}  // end namespace functor

#define INSTANTIATE_TEMPLATE(type) \
  template struct functor::TopKFunctor<GPUDevice, type>;

TF_CALL_GPU_NUMBER_TYPES(INSTANTIATE_TEMPLATE);
TF_CALL_INTEGRAL_TYPES(INSTANTIATE_TEMPLATE);
#undef INSTANTIATE_TEMPLATE

}  // namespace tensorflow

#endif  // GOOGLE_CUDA