/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include <cmath>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "external/cub_archive/cub/device/device_segmented_radix_sort.cuh"
#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/topk_op.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

// Required for sorting Eigen::half.
namespace cub {
template <>
struct NumericTraits<Eigen::half>
    : BaseTraits<FLOATING_POINT, true, false, unsigned short int, Eigen::half> {
};
}  // namespace cub

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace impl {

enum class HeapType { kMinHeap, kMaxHeap };
enum class PreferIndices { kLower, kHigher };

template <typename T>
struct Entry {
  int index;
  T value;

  // Test-only.
  static bool greater(const Entry<T>& a, const Entry<T>& b) {
    if (a.value == b.value) {
      return a.index < b.index;
    }
    return a.value > b.value;
  }
};

template <typename T>
struct LinearData {
  typedef impl::Entry<T> Entry;

  __device__ Entry& operator[](std::size_t index) const { return data[index]; }

  __device__ int get_index(int i) const { return data[i].index; }
  __device__ T get_value(int i) const { return data[i].value; }

  Entry* const data;
};

template <typename T>
struct IndirectLinearData {
  typedef impl::Entry<T> Entry;

  __device__ Entry& operator[](std::size_t index) const { return data[index]; }

  __device__ int get_index(int i) const {
    return backing_data[data[i].index].index;
  }
  __device__ T get_value(int i) const { return data[i].value; }

  Entry* const data;
  Entry* const backing_data;
};

#if GOOGLE_CUDA
template <typename T>
struct StridedData {
  typedef impl::Entry<T> Entry;

  __device__ Entry& operator[](std::size_t index) const {
    return data[index * blockDim.x + threadIdx.x];
  }

  __device__ int get_index(int i) const { return (*this)[i].index; }
  __device__ T get_value(int i) const { return (*this)[i].value; }

  Entry* const data;
};
#endif
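
// For example, with blockDim.x == 4 and k == 3, StridedData places entry i of
// thread t at data[i * 4 + t], so the twelve heap entries are interleaved as
//   slot:   0  1  2  3 | 4  5  6  7 | 8  9 10 11
//   thread: 0  1  2  3 | 0  1  2  3 | 0  1  2  3
//   entry:  0  0  0  0 | 1  1  1  1 | 2  2  2  2
// Consecutive threads therefore touch consecutive shared-memory slots, which
// keeps the per-thread heap accesses coalesced.
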
// A heap of Entry<T> that can either work as a min-heap or as a max-heap.
template <HeapType heapType, PreferIndices preferIndices,
          template <typename> class Data, typename T>
struct IndexedHeap {
  typedef typename Data<T>::Entry Entry;
  const Data<T> data;

  // Returns true if the entry at `left` belongs closer to the root of the
  // heap than the entry at `right`; value ties are broken by the preferred
  // index order.
  __device__ bool is_above(int left, int right) {
    T left_value = data.get_value(left);
    T right_value = data.get_value(right);
    if (left_value == right_value) {
      if (preferIndices == PreferIndices::kLower) {
        return data.get_index(left) < data.get_index(right);
      } else {
        return data.get_index(left) > data.get_index(right);
      }
    }
    if (heapType == HeapType::kMinHeap) {
      return left_value < right_value;
    } else {
      return left_value > right_value;
    }
  }

  __device__ void assign(int i, const Entry& entry) { data[i] = entry; }

  __device__ void push_up(int i) {
    int child = i;
    int parent;
    for (; child > 0; child = parent) {
      parent = (child - 1) / 2;
      if (!is_above(child, parent)) {
        // Heap property satisfied.
        break;
      }
      swap(child, parent);
    }
  }

  __device__ void swap(int a, int b) {
    auto tmp = data[b];
    data[b] = data[a];
    data[a] = tmp;
  }

  __device__ void push_root_down(int k) { push_down(0, k); }

  // MAX-HEAPIFY in Cormen et al.
  __device__ void push_down(int node, int k) {
    while (true) {
      const int left = 2 * node + 1;
      const int right = left + 1;
      // The child (if any) that belongs above `node`: the smallest for a
      // min-heap, the largest for a max-heap.
      int best = node;
      if (left < k && is_above(left, best)) {
        best = left;
      }
      if (right < k && is_above(right, best)) {
        best = right;
      }
      if (best == node) {
        break;
      }
      swap(best, node);
      node = best;
    }
  }

  // BUILD-MAX-HEAP in Cormen et al.
  __device__ void build(int k) {
    for (int node = (k - 1) / 2; node >= 0; node--) {
      push_down(node, k);
    }
  }

  // HEAP-EXTRACT-MAX in Cormen et al.
  __device__ void remove_root(int k) {
    data[0] = data[k - 1];
    push_root_down(k - 1);
  }

  // In-place HEAPSORT in Cormen et al. This method destroys the heap
  // property: for a min-heap it leaves the entries sorted in descending
  // value order.
  __device__ void sort(int k) {
    for (int slot = k - 1; slot > 0; slot--) {
      // This is like remove_root, but we insert the removed root at the end.
      swap(slot, 0);
      // The heap is now one element smaller.
      push_root_down(/*k=*/slot);
    }
  }

  __device__ void replace_root(const Entry& entry, int k) {
    data[0] = entry;
    push_root_down(k);
  }

  __device__ const Entry& root() { return data[0]; }
};

template <HeapType heapType, PreferIndices preferIndices,
          template <typename> class Data, typename T>
__device__ IndexedHeap<heapType, preferIndices, Data, T> make_indexed_heap(
    typename Data<T>::Entry* data) {
  return IndexedHeap<heapType, preferIndices, Data, T>{Data<T>{data}};
}
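
// A small worked example of the top-k scan implemented by heapTopK below:
// input = {3, 1, 4, 1, 5}, k = 2, start_index = 0, step_size = 1.
//   - Initialize with (0,3), (1,1); build() makes (1,1) the min-heap root.
//   - index 2: 4 > 1, so replace the root; the heap holds (0,3), (2,4).
//   - index 3: 1 > 3 is false, so skip (equal or smaller values never evict
//     an entry, which is how lower indices win ties).
//   - index 4: 5 > 3, so replace the root; the heap holds (2,4), (4,5).
// sort() then leaves the entries in descending value order: (4,5), (2,4).
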
// heapTopK walks over [input, input+length) with `step_size` stride starting
// at `start_index`. It builds a top-`k` min-heap that is stored in
// `heap_entries`, using `Data` to access elements in `heap_entries`. If
// sorted=true, the elements will be sorted at the end in descending order.
template <typename T, template <typename> class Data = LinearData>
__device__ void heapTopK(const T* __restrict__ input, int length, int k,
                         Entry<T>* __restrict__ heap_entries,
                         bool sorted = false, int start_index = 0,
                         int step_size = 1) {
  assert(k <= length);

  auto heap =
      make_indexed_heap<HeapType::kMinHeap, PreferIndices::kHigher, Data, T>(
          heap_entries);

  int heap_end_index = start_index + k * step_size;
  if (heap_end_index > length) {
    heap_end_index = length;
  }
  // Initialize the min-heap with the first k strided elements.
  for (int index = start_index, slot = 0; index < heap_end_index;
       index += step_size, slot++) {
    heap.assign(slot, {index, input[index]});
  }

  heap.build(k);

  // Now iterate over the remaining items.
  // If an item is smaller than the min element, it is not amongst the top k.
  // Otherwise, replace the min element with it and push upwards.
  for (int index = heap_end_index; index < length; index += step_size) {
    // We prefer elements with lower indices: the strict comparison never
    // evicts an entry for an equal value, and later elements automatically
    // have higher indices, so they are discarded on ties.
    if (input[index] > heap.root().value) {
      // This element should replace the min.
      heap.replace_root({index, input[index]}, k);
    }
  }

  // Sort if wanted.
  if (sorted) {
    heap.sort(k);
  }
}

// mergeShards performs a top-k merge on `num_shards` many sorted streams that
// are stored in `entries` in a strided way:
// |s_1 1st|s_2 1st|...|s_{num_shards} 1st|s_1 2nd|s_2 2nd|...
// The overall top k elements are written to `top_k_values` and their indices
// to `top_k_indices`.
// `top_k_heap` is used as temporary storage for the merge heap.
template <typename T>
__device__ void mergeShards(int num_shards, int k,
                            Entry<T>* __restrict__ entries,
                            Entry<T>* __restrict__ top_k_heap, T* top_k_values,
                            int* top_k_indices) {
  // If k < num_shards, we can use a min-heap with k elements to get the top k
  // of the sorted blocks.
  // If k >= num_shards, we initialize the min-heap with the top element from
  // each sorted block.
  const int heap_size = k < num_shards ? k : num_shards;

  // Min-heap part.
  {
    auto min_heap = IndexedHeap<HeapType::kMinHeap, PreferIndices::kHigher,
                                IndirectLinearData, T>{
        IndirectLinearData<T>{top_k_heap, entries}};
    // Initialize the heap as a min-heap.
    for (int slot = 0; slot < heap_size; slot++) {
      min_heap.assign(slot, {slot, entries[slot].value});
    }
    min_heap.build(heap_size);

    // Now perform top k with the remaining shards (if num_shards > heap_size).
    for (int shard = heap_size; shard < num_shards; shard++) {
      const auto entry = entries[shard];
      const auto root = min_heap.root();
      if (entry.value < root.value) {
        continue;
      }
      if (entry.value == root.value &&
          entry.index > entries[root.index].index) {
        continue;
      }
      // This element should replace the min.
      min_heap.replace_root({shard, entry.value}, heap_size);
    }
  }

  // Max-part.
  {
    // Turn the min-heap into a max-heap in-place.
    auto max_heap = IndexedHeap<HeapType::kMaxHeap, PreferIndices::kLower,
                                IndirectLinearData, T>{
        IndirectLinearData<T>{top_k_heap, entries}};
    // Heapify into a max-heap.
    max_heap.build(heap_size);

    // Now extract the maximum k - 1 times; the last rank is handled
    // separately below, since no replacement element needs to be drawn
    // for it.
    const int last_k = k - 1;
    for (int rank = 0; rank < last_k; rank++) {
      const Entry<T>& max_element = max_heap.root();
      top_k_values[rank] = max_element.value;
      int shard_index = max_element.index;
      top_k_indices[rank] = entries[shard_index].index;
      int next_shard_index = shard_index + num_shards;
      // For rank < k-1, each top-k heap still contains at least one element,
      // so we can draw a replacement.
      max_heap.replace_root(
          {next_shard_index, entries[next_shard_index].value}, heap_size);
    }

    // rank == last_k.
    const Entry<T>& max_element = max_heap.root();
    top_k_values[last_k] = max_element.value;
    int shard_index = max_element.index;
    top_k_indices[last_k] = entries[shard_index].index;
  }
}
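
// To make the strided layout concrete: with num_shards = 3 and k = 2, the
// `entries` array consumed by mergeShards above is
//   entries[0..2] = best element of shards 0, 1, 2
//   entries[3..5] = second-best element of shards 0, 1, 2
// so the rank-r element of shard s lives at entries[r * num_shards + s], and
// the next element of the same shard is found by adding num_shards (the
// next_shard_index computation above).
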
extern __shared__ char shared_memory[];

template <typename T>
__global__ void TopKKernel(const T* input, int length, int k, bool sorted,
                           T* output, int* indices) {
  const int batch_index = blockIdx.x;
  const T* batch_input = input + batch_index * length;

  const int thread_index = threadIdx.x;
  const int thread_count = blockDim.x;

  Entry<T>* shared_entries = (Entry<T>*)shared_memory;

  heapTopK<T, StridedData>(batch_input, length, k, shared_entries, true,
                           thread_index, thread_count);

  __syncthreads();
  if (thread_index == 0) {
    const int offset = batch_index * k;
    auto batch_output = output + offset;
    auto batch_indices = indices + offset;
    Entry<T>* top_k_heap = shared_entries + thread_count * k;

    // TODO(blackhc): Erich says: Performance can likely be improved
    // significantly by having the merge be done by multiple threads rather
    // than just one. ModernGPU has some nice primitives that could help with
    // this.
    mergeShards(thread_count, k, shared_entries, top_k_heap, batch_output,
                batch_indices);
  }
}
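
// Shared memory layout for TopKKernel: the first thread_count * k entries are
// the per-thread heaps (interleaved via StridedData), followed by k scratch
// entries for the merge heap, i.e. (thread_count + 1) * k entries in total.
// As a sizing example: for T = float, sizeof(Entry<float>) == 8, so with
// k = 256 each heap needs 2048 bytes, and at most 49152 / 2048 - 1 == 23
// shards fit into the 48KB shared-memory budget used below.
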
template <typename T>
cudaError LaunchTopKKernel(const cudaStream_t& stream, int num_shards,
                           const T* input, int batch_size, int length, int k,
                           bool sorted, T* output, int* indices) {
  // This code assumes that k is small enough for the computation to fit
  // inside shared memory (hard-coded to 48KB). With the minimum of one shard
  // plus the merge heap, the requirement is
  //   k <= shared_memory_size / (2 * sizeof(Entry<T>)),
  // i.e. k <= 3072 for T=float/int32 and k <= 1536 for T=double/int64
  // (Entry<double> and Entry<int64> are padded to 16 bytes).

  // Use as many shards as possible.
  if (num_shards <= 0) {
    constexpr auto shared_memory_size = 48 << 10;  // 48 KB
    const auto heap_size = k * sizeof(Entry<T>);
    // Solve shared_memory_size = (num_shards + 1) * heap_size for num_shards.
    num_shards = shared_memory_size / heap_size - 1;
    if (num_shards <= 0) {
      num_shards = 1;
    }
    auto shard_size = length / num_shards;
    auto min_shard_size = 2 * k;
    if (shard_size < min_shard_size) {
      num_shards = length / min_shard_size;
    }
    if (num_shards <= 0) {
      num_shards = 1;
    } else if (num_shards > 1024) {
      num_shards = 1024;
    }
  }
  // We are limited by the amount of shared memory we have per block.
  auto shared_memory_size = (num_shards + 1) * k * sizeof(Entry<T>);

  TopKKernel<<<batch_size, num_shards, shared_memory_size, stream>>>(
      input, length, k, sorted, output, indices);
  return cudaGetLastError();
}

struct SegmentOffsetCreator {
  EIGEN_DEVICE_FUNC
  SegmentOffsetCreator(int num_cols) : num_cols_(num_cols) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const {
    return idx * num_cols_;
  }

  int num_cols_;
};

struct ColumnIndexCreator {
  ColumnIndexCreator(int num_cols) : num_cols_(num_cols) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(
      const Eigen::array<int, 1>& ix) const {
    return ix[0] % num_cols_;
  }

  int num_cols_;
};
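
// The two functors above drive the segmented sort in LaunchSortKernel below.
// Composed with a counting iterator, SegmentOffsetCreator lazily yields
// 0, num_cols, 2 * num_cols, ...; e.g. for num_rows = 3 and num_cols = 4 the
// begin offsets are {0, 4, 8} and the end offsets (the same iterator shifted
// by one) are {4, 8, 12}, so row i is the segment
// [i * num_cols, (i + 1) * num_cols). ColumnIndexCreator fills the value
// tensor with each element's column index, which the sort permutes into the
// index output.
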
template <typename T>
Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows,
                        int num_cols, int k,
                        typename TTypes<T, 2>::Tensor values,
                        TTypes<int, 2>::Tensor indices) {
  const GPUDevice& d = ctx->eigen_device<GPUDevice>();
  const cudaStream_t& cu_stream = GetCudaStream(ctx);
  size_t temp_storage_bytes = -1;

  // TODO(ebrevdo): Once cub supports iterators for ValueT, replace this
  // tensor with an iterator that directly returns the correct value.
  Tensor input_indices;
  TF_RETURN_IF_ERROR(ctx->allocate_temp(
      DT_INT32, TensorShape({num_rows, num_cols}), &input_indices));
  auto input_indices_t = To32Bit(input_indices.flat<int32>());
  input_indices_t.device(d) =
      input_indices_t.generate(ColumnIndexCreator(num_cols));

  cub::CountingInputIterator<int> counting_iter(0);
  cub::TransformInputIterator<int, SegmentOffsetCreator,
                              cub::CountingInputIterator<int>>
      segment_offsets_t(counting_iter, SegmentOffsetCreator(num_cols));

  Tensor temp_values;
  Tensor temp_indices;
  T* sorted_values_ptr;
  int* sorted_indices_ptr;
  if (k == num_cols) {
    // Doing a full sort; no intermediate values needed.
    sorted_values_ptr = values.data();
    sorted_indices_ptr = indices.data();
  } else {
    // Need to create intermediate values for sorting.
    TF_RETURN_IF_ERROR(ctx->allocate_temp(
        DT_INT32, TensorShape({num_rows, num_cols}), &temp_indices));
    TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value,
                                          TensorShape({num_rows, num_cols}),
                                          &temp_values));
    sorted_indices_ptr = temp_indices.flat<int32>().data();
    sorted_values_ptr = temp_values.flat<T>().data();
  }

  // The first call, with d_temp_storage == nullptr, only computes the
  // required temporary storage size.
  auto err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
      /* d_temp_storage */ nullptr,
      /* temp_storage_bytes */ temp_storage_bytes,
      /* d_keys_in */ input,
      /* d_keys_out */ sorted_values_ptr,
      /* d_values_in */ input_indices_t.data(),
      /* d_values_out */ sorted_indices_ptr,
      /* num_items */ num_cols * num_rows,
      /* num_segments */ num_rows,
      /* d_begin_offsets */ segment_offsets_t,
      /* d_end_offsets */ segment_offsets_t + 1,
      /* begin_bit */ 0,
      /* end_bit */ sizeof(T) * 8,
      /* stream */ cu_stream);
  if (err != cudaSuccess) {
    return errors::Internal(
        "TopKOp: Could not launch "
        "cub::DeviceSegmentedRadixSort::SortPairsDescending to calculate "
        "temp_storage_bytes, status: ",
        cudaGetErrorString(err));
  }
  Tensor temp_storage;
  TF_RETURN_IF_ERROR(ctx->allocate_temp(
      DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
      &temp_storage));
  err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
      /* d_temp_storage */ temp_storage.flat<int8>().data(),
      /* temp_storage_bytes */ temp_storage_bytes,
      /* d_keys_in */ input,
      /* d_keys_out */ sorted_values_ptr,
      /* d_values_in */ input_indices_t.data(),
      /* d_values_out */ sorted_indices_ptr,
      /* num_items */ num_cols * num_rows,
      /* num_segments */ num_rows,
      /* d_begin_offsets */ segment_offsets_t,
      /* d_end_offsets */ segment_offsets_t + 1,
      /* begin_bit */ 0,
      /* end_bit */ sizeof(T) * 8,
      /* stream */ cu_stream);
  if (err != cudaSuccess) {
    return errors::Internal(
        "TopKOp: Could not launch "
        "cub::DeviceSegmentedRadixSort::SortPairsDescending to sort input, "
        "temp_storage_bytes: ",
        temp_storage_bytes, ", status: ", cudaGetErrorString(err));
  }
  if (k < num_cols) {
    // Need to copy subsets of sorted_indices and sorted_values to
    // indices and values.
    const Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{0, 0};
    const Eigen::DSizes<Eigen::DenseIndex, 2> slice_sizes{num_rows, k};
    To32Bit(indices).device(d) =
        To32Bit(temp_indices.matrix<int32>()).slice(slice_indices, slice_sizes);
    To32Bit(values).device(d) =
        To32Bit(temp_values.matrix<T>()).slice(slice_indices, slice_sizes);
  }
  return Status::OK();
}
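
// A small end-to-end example of the slice above: for num_rows = 2,
// num_cols = 4, k = 2 and input {{1, 3, 2, 4}, {8, 5, 7, 6}}, the segmented
// sort produces temp_values = {{4, 3, 2, 1}, {8, 7, 6, 5}} and
// temp_indices = {{3, 1, 2, 0}, {0, 2, 3, 1}}; slicing to {num_rows, k} then
// keeps values {{4, 3}, {8, 7}} and indices {{3, 1}, {0, 2}}.
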
}  // end namespace impl

namespace functor {

template <typename T>
struct TopKFunctor<GPUDevice, T> {
  static EIGEN_ALWAYS_INLINE Status
  Compute(OpKernelContext* context, bool sorted, int k,
          const typename TTypes<T, 2>::ConstTensor& input,
          const int64 num_rows, const int64 num_cols,
          typename TTypes<T, 2>::Tensor values,
          typename TTypes<int, 2>::Tensor indices) {
    // For small k, use the heap implementation. For larger k, use the
    // in-place cub sort. For k == num_cols, always use the in-place cub
    // sort. The thresholds for num_cols and k were determined empirically.
    if (num_cols <= 1000 || k == num_cols || k >= 100) {
      return impl::LaunchSortKernel(context, input.data(), num_rows, num_cols,
                                    k, values, indices);
    } else {
      const cudaStream_t& cu_stream = GetCudaStream(context);
      auto err = impl::LaunchTopKKernel(cu_stream, /* num_shards */ 0,
                                        input.data(), num_rows, num_cols, k,
                                        sorted, values.data(), indices.data());
      if (err != cudaSuccess) {
        return errors::Internal(
            "Could not launch TopKKernel: ", cudaGetErrorString(err), ".");
      } else {
        return Status::OK();
      }
    }
  }
};

}  // end namespace functor

#define INSTANTIATE_TEMPLATE(type) \
  template struct functor::TopKFunctor<GPUDevice, type>;

TF_CALL_GPU_NUMBER_TYPES(INSTANTIATE_TEMPLATE);
TF_CALL_INTEGRAL_TYPES(INSTANTIATE_TEMPLATE);
#undef INSTANTIATE_TEMPLATE

}  // namespace tensorflow

#endif  // GOOGLE_CUDA