/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "external/cub_archive/cub/device/device_reduce.cuh"
#include "external/cub_archive/cub/device/device_segmented_reduce.cuh"
#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh"
#include "external/cub_archive/cub/warp/warp_reduce.cuh"
#include "cuda/include/cuComplex.h"
#include "tensorflow/core/kernels/reduction_ops.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/permutation_input_iterator.h"
#include "tensorflow/core/util/transform_output_iterator.h"

#include <sstream>

namespace tensorflow {
namespace functor {

typedef Eigen::GpuDevice GPUDevice;

template <typename T>
struct Sum {
  __host__ __device__ T operator()(const T& a, const T& b) const {
    return a + b;
  }
};

// needed to work around a compiler bug in nvcc - it doesn't seem to like
// the overloaded addition op for std::complex
template <>
struct Sum<std::complex<float>> {
  __host__ __device__ std::complex<float> operator()(
      const std::complex<float>& a, const std::complex<float>& b) const {
    auto result = cuCaddf(make_cuComplex(a.real(), a.imag()),
                          make_cuComplex(b.real(), b.imag()));
    return std::complex<float>(result.x, result.y);
  }
};

template <>
struct Sum<std::complex<double>> {
  __host__ __device__ std::complex<double> operator()(
      const std::complex<double>& a, const std::complex<double>& b) const {
    auto result = cuCadd(make_cuDoubleComplex(a.real(), a.imag()),
                         make_cuDoubleComplex(b.real(), b.imag()));
    return std::complex<double>(result.x, result.y);
  }
};

template <typename T>
struct Prod {
  __host__ __device__ T operator()(const T& a, const T& b) const {
    return a * b;
  }
};

// needed to work around a compiler bug in nvcc - it doesn't seem to like
// the overloaded multiply op for std::complex
template <>
struct Prod<std::complex<float>> {
  __host__ __device__ std::complex<float> operator()(
      const std::complex<float>& a, const std::complex<float>& b) const {
    auto result = cuCmulf(make_cuComplex(a.real(), a.imag()),
                          make_cuComplex(b.real(), b.imag()));
    return std::complex<float>(result.x, result.y);
  }
};

template <>
struct Prod<std::complex<double>> {
  __host__ __device__ std::complex<double> operator()(
      const std::complex<double>& a, const std::complex<double>& b) const {
    auto result = cuCmul(make_cuDoubleComplex(a.real(), a.imag()),
                         make_cuDoubleComplex(b.real(), b.imag()));
    return std::complex<double>(result.x, result.y);
  }
};

template <typename T, typename outT = T>
struct DividesBy {
  T divisor;

  __host__ __device__ explicit DividesBy(T divisor) : divisor(divisor) {}

  __host__ __device__ outT operator()(const T& x) const { return x / divisor; }
};

// needed to work around a compiler bug in nvcc - it doesn't seem to like
// the overloaded ops for std::complex
template <>
struct DividesBy<std::complex<float>> {
  cuFloatComplex divisor;

  __host__ __device__ explicit DividesBy(std::complex<float> divisor)
      : divisor(make_cuComplex(divisor.real(), divisor.imag())) {}

  // implements x / divisor
  __host__ __device__ std::complex<float> operator()(
      const std::complex<float>& x) const {
    auto result = cuCdivf(make_cuComplex(x.real(), x.imag()), divisor);
    return std::complex<float>(result.x, result.y);
  }
};

template <>
struct DividesBy<std::complex<double>> {
  cuDoubleComplex divisor;

  __host__ __device__ explicit DividesBy(std::complex<double> divisor)
      : divisor(make_cuDoubleComplex(divisor.real(), divisor.imag())) {}

  // implements x / divisor
  __host__ __device__ std::complex<double> operator()(
      const std::complex<double>& x) const {
    auto result = cuCdiv(make_cuDoubleComplex(x.real(), x.imag()), divisor);
    return std::complex<double>(result.x, result.y);
  }
};

template <>
struct DividesBy<float, Eigen::half> {
  float divisor;

  __host__ __device__ explicit DividesBy(float divisor) : divisor(divisor) {}

  __host__ __device__ Eigen::half operator()(const float& x) const {
    return Eigen::half(x / divisor);
  }
};

struct HalfToFloat {
  __host__ __device__ float operator()(const Eigen::half& x) const {
    return Eigen::half_impl::half_to_float(x);
  }
};

struct FloatToHalf {
  __host__ __device__ Eigen::half operator()(const float& x) const {
    return Eigen::half_impl::float_to_half_rtne(x);
  }
};

struct And {
  __host__ __device__ bool operator()(const bool& a, const bool& b) const {
    return a && b;
  }
};

struct Or {
  __host__ __device__ bool operator()(const bool& a, const bool& b) const {
    return a || b;
  }
};

// each block does a grid strided loop and reduces its values locally
// the case of one block is used for low latency small reductions to scalars
template <typename T, typename outT, int num_threads, typename Op>
__global__ void BlockReduceKernel(
    T in, outT out, int num_elems, Op op,
    typename std::iterator_traits<T>::value_type initVal) {
  const int bid = blockIdx.x;
  const int tid = threadIdx.x;

  const int gid = bid * blockDim.x + tid;
  const int stride = blockDim.x * gridDim.x;

  typedef typename std::iterator_traits<T>::value_type value_type;

  value_type sum = initVal;
  if (gid < num_elems) {
    sum = in[gid];
    for (int pos = gid + stride; pos < num_elems; pos += stride) {
      sum = op(sum, in[pos]);
    }
  }

  typedef cub::BlockReduce<value_type, num_threads> BlockReduce;

  __shared__ typename BlockReduce::TempStorage temp_storage;

  // only include input values in the reduction
  //
  // elements: -----------------
  // grid:     |====|====|====|====|====|
  const int num_elements_to_reduce =
      max(min(num_elems - bid * blockDim.x, num_threads), 0);

  sum = BlockReduce(temp_storage).Reduce(sum, op, num_elements_to_reduce);

  if (tid == 0) out[bid] = sum;
}

// maps a warp to each row
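// lane i of the warp starts at column i and strides through the row in steps
// of 32, so each lane accumulates roughly num_cols / 32 elements before a
// single WarpReduce combines the 32 partial results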
template <typename T, typename outT, typename Op>
__global__ void RowReduceKernel(
    T in, outT out, int num_rows, int num_cols, Op op,
    typename std::iterator_traits<T>::value_type initVal) {
  typedef typename std::iterator_traits<T>::value_type value_type;
  const int row = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
  const int lane = threadIdx.x % 32;

  if (num_cols == 1) {
    int gid = threadIdx.x + blockIdx.x * blockDim.x;
    if (gid < num_rows) out[gid] = in[gid];
    return;
  }

  value_type sum = initVal;
  int col = lane;

  if (row < num_rows && col < num_cols) {
    sum = in[row * num_cols + col];
    col += 32;
    for (; col < num_cols; col += 32) {
      sum = op(sum, in[row * num_cols + col]);
    }
  }

  typedef cub::WarpReduce<value_type> WarpReduce;

  __shared__ typename WarpReduce::TempStorage temp_storage;

  sum = WarpReduce(temp_storage).Reduce(sum, op, min(num_cols, 32));

  if (row < num_rows && lane == 0) out[row] = sum;
}

// Works only if there are <= 16 columns
// each warp sums over multiple rows at once
template <typename T, typename outT, typename Op>
__global__ void ColumnReduceMax16ColumnsKernel(
    T in, outT out, int num_rows, int num_cols, Op op,
    typename std::iterator_traits<T>::value_type initVal) {
  typedef typename std::iterator_traits<T>::value_type value_type;
  int rows_per_warp = 32 / num_cols;

  const int lane = threadIdx.x % 32;
  const int lane_row = lane / num_cols;

  const int start_row_warp =
      rows_per_warp * (blockIdx.y * blockDim.y + threadIdx.y);
  const int start_row_lane = start_row_warp + lane_row;
  int row = start_row_lane;
  int col = lane % num_cols;

  value_type sum = initVal;
  if (row * num_cols + col < num_rows * num_cols)
    sum = in[row * num_cols + col];

  // 1D array necessary due to bug in CUDA 9 compiler.
  // TODO(nluehr) revert to 2D array when compiler is ready.
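  // the 33-element stride pads the logically 32x32 tile so that threads
  // reading a column of partials below hit different shared-memory banks
  // (standard padding to avoid bank conflicts)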
  __shared__ value_type partial_sums[32 * 33];

  row += rows_per_warp * gridDim.y * blockDim.y;
  for (; row < num_rows; row += rows_per_warp * gridDim.y * blockDim.y) {
    int global_pos = row * num_cols + col;
    if (global_pos < (num_rows * num_cols))
      sum = op(sum, in[row * num_cols + col]);
  }

  const int rows_in_this_warp = min(rows_per_warp, num_rows - start_row_warp);
  // not the most efficient way to do this sum
  for (int i = 1; i < rows_in_this_warp; ++i) {
    value_type tmp =
        cub::ShuffleIndex(sum, threadIdx.x + i * num_cols, 32, 0xffffffff);
    if (lane < num_cols) sum = op(sum, tmp);
  }

  if (lane < num_cols) partial_sums[lane * 33 + threadIdx.y] = sum;

  __syncthreads();

  if (threadIdx.y == 0 && threadIdx.x < num_cols) {
    value_type s = partial_sums[threadIdx.x * 33];

    if (blockDim.y > 1) {
      for (int row = 1; row < blockDim.y; ++row) {
        s = op(s, partial_sums[threadIdx.x * 33 + row]);
      }
    }

    out[col * gridDim.y + blockIdx.y] = s;
  }
}

// Maps each block to a column range 32 wide
template <typename T, typename outT, typename Op>
__global__ void ColumnReduceKernel(
    T in, outT out, int num_rows, int num_cols, Op op,
    typename std::iterator_traits<T>::value_type initVal) {
  typedef typename std::iterator_traits<T>::value_type value_type;
  int row = blockIdx.y * blockDim.y + threadIdx.y;
  int col = blockIdx.x * 32 + threadIdx.x;

  value_type sum = initVal;
  if (row < num_rows && col < num_cols) sum = in[row * num_cols + col];

  // 1D array necessary due to bug in CUDA 9 compiler.
  // TODO(nluehr) revert to 2D array when compiler is ready.
  __shared__ value_type partial_sums[32 * 33];

  row += gridDim.y * blockDim.y;

  if (col < num_cols) {
    for (; row < num_rows; row += gridDim.y * blockDim.y) {
      sum = op(sum, in[row * num_cols + col]);
    }
  }

  partial_sums[threadIdx.x * 33 + threadIdx.y] = sum;

  __syncthreads();

  if (threadIdx.y == 0 && col < num_cols) {
    value_type s = partial_sums[threadIdx.x * 33];

    // only include input values in the reduction
    // elem   block_rows
    //  -         =
    //  -         =
    //  #         #   block boundary
    //  -         =
    //  -         =
    //  #         #   block boundary
    //  -         =
    //            =
    const int numRowsThisBlock =
        min(blockDim.y, num_rows - blockIdx.y * blockDim.y);

    for (int row = 1; row < numRowsThisBlock; ++row) {
      s = op(s, partial_sums[threadIdx.x * 33 + row]);
    }

    out[col * gridDim.y + blockIdx.y] = s;
  }
}

// does multiple warp size segmented reductions in parallel
// segments cannot cross warp boundaries (mainly used for reducing the
// segments that come from the Max16Columns column reduction kernel)
template <typename T, typename outT, typename Op>
__global__ void CleanupSegments(
    T partial_sums, outT out, int num_rows, int num_cols, int segment_size,
    Op op, typename std::iterator_traits<T>::value_type initVal) {
  typedef typename std::iterator_traits<T>::value_type value_type;
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;

  value_type val = initVal;
  if (tid < segment_size * num_cols) val = partial_sums[tid];

  typedef cub::WarpReduce<value_type> WarpReduce;

  __shared__ typename WarpReduce::TempStorage temp_storage;

  const bool head_flag = (threadIdx.x % segment_size) == 0;
  value_type sum =
      WarpReduce(temp_storage).HeadSegmentedReduce(val, head_flag, op);
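
  // HeadSegmentedReduce leaves each segment's reduction in the lane that
  // carries the segment's head flag, so only those lanes write out a result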
  if (head_flag && tid < segment_size * num_cols) {
    out[tid / segment_size] = sum;
  }
}

// assigns one thread to a column
template <typename T, typename outT, typename Op>
__global__ void ColumnReduceSimpleKernel(T in, outT out, int num_planes,
                                         int num_rows, int num_cols, Op op) {
  typedef typename std::iterator_traits<T>::value_type value_type;
  const int gid = threadIdx.x + blockIdx.x * blockDim.x;
  const int elems_per_plane = num_rows * num_cols;

  const int plane = gid / num_cols;
  const int col = gid % num_cols;

  if (plane >= num_planes) return;

  if (num_rows == 1) {
    out[plane * elems_per_plane + col] = in[plane * elems_per_plane + col];
    return;
  }

  value_type sum = op(in[plane * elems_per_plane + col],
                      in[plane * elems_per_plane + num_cols + col]);
  for (int row = 2; row < num_rows; ++row) {
    sum = op(sum, in[plane * elems_per_plane + row * num_cols + col]);
  }

  out[plane * num_cols + col] = sum;
}

struct RowOffset {
  __host__ __device__ explicit RowOffset(const int& cols) : cols_(cols) {}

  __host__ __device__ int operator()(const int& x) const { return cols_ * x; }

  int cols_;
};

struct GatherOp {
  __host__ __device__ GatherOp(const int& extent_x, const int& extent_y,
                               const int& extent_z, bool kOne)
      : extent_x_(extent_x),
        extent_y_(extent_y),
        extent_z_(extent_z),
        kOne_(kOne) {
    if (kOne_)
      group_size_ = extent_y_;
    else
      group_size_ = extent_x_ * extent_z_;
  }

  __host__ __device__ int operator()(const int& ind) const {
    const int group = kOne_ ? ind / group_size_ : ind % group_size_;
    const int offset = kOne_ ? ind % group_size_ : ind / group_size_;

    const int x = group / extent_z_;
    const int z = group % extent_z_;

    return x * extent_y_ * extent_z_ + z + offset * extent_z_;
  }

  int extent_x_;
  int extent_y_;
  int extent_z_;
  bool kOne_;
  int group_size_;
};

template <typename T, typename Op, typename OUT_T, typename IN_T>
void LaunchScalarReduction(OpKernelContext* ctx, OUT_T out, IN_T in,
                           int in_size, Op op, T init,
                           const cudaStream_t& cu_stream) {
  // handle small inputs, where low latency matters, better than CUB does
  if (in_size <= 4096) {
    const int num_blocks = 1;
    const int num_threads = 256;
    BlockReduceKernel<IN_T, OUT_T, num_threads>
        <<<num_blocks, num_threads, 0, cu_stream>>>(in, out, in_size, op, init);
    return;
  } else if (in_size <= 1 << 19) {
    const int num_threads = 256;
    const int num_blocks = std::min(32, Eigen::divup(in_size, num_threads));
    // it seems like tailoring this to the GPU
    // would be more effective, but all attempts
    // at making this a multiple of the number of
    // multiprocessors have led to lower perf
    // in general
    // TODO(eriche) investigate this more

    Tensor temp_storage;
    OP_REQUIRES_OK(
        ctx,
        ctx->allocate_temp(
            DT_INT8, TensorShape({static_cast<int64>(num_blocks * sizeof(T))}),
            &temp_storage));

    BlockReduceKernel<IN_T, T*, num_threads>
        <<<num_blocks, num_threads, 0, cu_stream>>>(
            in, (T*)temp_storage.flat<int8_t>().data(), in_size, op, init);

    // take care that we only reduce blocks that had some valid elements in
    // them
    // TODO(eriche): CUB currently has a bug in HeadSegmentedReduce that
    // requires it to be used with a full warp. Can reduce 32 -> num_blocks
    // when this is fixed.
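    // a single warp cleans up the <= 32 block-level partials produced above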
    CleanupSegments<<<1, 32, 0, cu_stream>>>(
        (T*)temp_storage.flat<int8_t>().data(), out, 1, 1, num_blocks, op,
        init);
    return;
  }
  std::size_t temp_storage_bytes = 0;

  Tensor temp_storage;
  // written as a loop because it reduces clutter
  // first pass allocates memory, second launches kernel(s)
  for (int i = 0; i < 2; ++i) {
    auto success = cub::DeviceReduce::Reduce(
        i == 0 ? nullptr : temp_storage.flat<int8_t>().data(),
        temp_storage_bytes, in, out, in_size, op, init, cu_stream);

    OP_REQUIRES(
        ctx, success == 0,
        errors::Internal("CUB reduce error ", cudaGetErrorString(success)));

    if (i == 0)
      OP_REQUIRES_OK(
          ctx,
          ctx->allocate_temp(
              DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
              &temp_storage));
  }
}
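
// rows of fewer than 1024 elements are reduced with one warp per row; longer
// rows fall through to cub::DeviceSegmentedReduce, with row starts computed
// on the fly as the segment offsets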
template <typename T, typename Op, typename OUT_T, typename IN_T>
void LaunchRowReduction(OpKernelContext* ctx, OUT_T out, IN_T in, int num_rows,
                        int num_cols, Op op, T init,
                        const cudaStream_t& cu_stream) {
  if (num_cols < 1024) {
    const int threads_per_block = 128;
    const int warps_per_block = threads_per_block / 32;
    int num_blocks = (num_rows + warps_per_block - 1) / warps_per_block;

    RowReduceKernel<<<num_blocks, threads_per_block, 0, cu_stream>>>(
        in, out, num_rows, num_cols, op, init);
    return;
  }

  // set up segment offsets with counting and transform iterators
  RowOffset row_offset_op(num_cols);
  cub::CountingInputIterator<int> counting_iter(0);
  cub::TransformInputIterator<int, RowOffset, cub::CountingInputIterator<int>>
      transform_iter(counting_iter, row_offset_op);

  std::size_t temp_storage_bytes = 0;
  Tensor temp_storage;
  for (int i = 0; i < 2; ++i) {
    auto success = cub::DeviceSegmentedReduce::Reduce(
        i == 0 ? nullptr : temp_storage.flat<int8_t>().data(),
        temp_storage_bytes, in, out, num_rows, transform_iter,
        transform_iter + 1, op, init, cu_stream);

    OP_REQUIRES(ctx, success == 0,
                errors::Internal("CUB segmented reduce error ",
                                 cudaGetErrorString(success)));

    if (i == 0)
      OP_REQUIRES_OK(
          ctx,
          ctx->allocate_temp(
              DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
              &temp_storage));
  }
}

template <typename T, typename Op, typename OUT_T, typename IN_T>
void LaunchColumnReduction_LTE16Cols(OpKernelContext* ctx, OUT_T out, IN_T in,
                                     int extent_x, int extent_y, Op op, T init,
                                     const cudaStream_t& cu_stream) {
  int rows_per_warp = 32 / extent_y;
  dim3 block_dim(32, std::min(Eigen::divup(extent_x, rows_per_warp), 32), 1);
  dim3 grid_dim(1,
                Eigen::divup(static_cast<unsigned int>(extent_x),
                             rows_per_warp * block_dim.y),
                1);

  grid_dim.y = std::min((int)grid_dim.y, 32);

  if (grid_dim.y > 2 && grid_dim.y < 32) {
    int log2 = Log2Floor(grid_dim.y);
    grid_dim.y = 1 << log2;
  }

  if (grid_dim.y == 1) {
    ColumnReduceMax16ColumnsKernel<<<grid_dim, block_dim, 0, cu_stream>>>(
        in, out, extent_x, extent_y, op, init);
  } else {
    Tensor temp_storage;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_temp(DT_INT8,
                                      TensorShape({static_cast<int64>(
                                          sizeof(T) * extent_y * grid_dim.y)}),
                                      &temp_storage));
    ColumnReduceMax16ColumnsKernel<<<grid_dim, block_dim, 0, cu_stream>>>(
        in, (T*)temp_storage.flat<int8_t>().data(), extent_x, extent_y, op,
        init);

    dim3 new_grid_dim((grid_dim.y * extent_y + 31) / 32, 1, 1);
    dim3 num_threads(128, 1, 1);
    CleanupSegments<<<new_grid_dim, num_threads, 0, cu_stream>>>(
        (T*)temp_storage.flat<int8_t>().data(), out, extent_x, extent_y,
        grid_dim.y, op, init);
  }
}

template <typename T, typename Op, typename OUT_T, typename IN_T>
void LaunchColumnReduction_LTE4096Cols(OpKernelContext* ctx, OUT_T out, IN_T in,
                                       int extent_x, int extent_y, Op op,
                                       T init, const cudaStream_t& cu_stream) {
  dim3 block_dim(32, std::min(extent_x, 32), 1);
  dim3 grid_dim((extent_y + 31) / 32, 1, 1);

  if (grid_dim.x < 16) grid_dim.y = std::min((extent_x + 31) / 32, 32);

  if (grid_dim.y > 2 && grid_dim.y < 32) {
    int log2 = Log2Floor(grid_dim.y);
    grid_dim.y = 1 << log2;
  }

  if (grid_dim.y == 1) {
    ColumnReduceKernel<<<grid_dim, block_dim, 0, cu_stream>>>(
        in, out, extent_x, extent_y, op, init);
  } else {
    Tensor temp_storage;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_temp(DT_INT8,
                                      TensorShape({static_cast<int64>(
                                          sizeof(T) * extent_y * grid_dim.y)}),
                                      &temp_storage));

    ColumnReduceKernel<<<grid_dim, block_dim, 0, cu_stream>>>(
        in, (T*)temp_storage.flat<int8_t>().data(), extent_x, extent_y, op,
        init);

    dim3 new_grid_dim((grid_dim.y * extent_y + 31) / 32, 1, 1);
    dim3 num_threads(128, 1, 1);
    CleanupSegments<<<new_grid_dim, num_threads, 0, cu_stream>>>(
        (T*)temp_storage.flat<int8_t>().data(), out, extent_x, extent_y,
        grid_dim.y, op, init);
  }
}
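
// dispatches on the column count: <= 16 columns uses the warp-per-row-group
// kernel, <= 4096 columns uses the 32-column tiled kernel (plus a cleanup
// pass when the grid has multiple row blocks), and anything wider falls back
// to one thread per output column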
template <typename T, typename Op, typename OUT_T, typename IN_T>
void LaunchColumnReduction(OpKernelContext* ctx, OUT_T out, IN_T in,
                           int extent_x, int extent_y, Op op, T init,
                           const cudaStream_t& cu_stream) {
  if (extent_y <= 16) {
    LaunchColumnReduction_LTE16Cols(ctx, out, in, extent_x, extent_y, op, init,
                                    cu_stream);
  } else if (extent_y <= 4096) {
    LaunchColumnReduction_LTE4096Cols(ctx, out, in, extent_x, extent_y, op,
                                      init, cu_stream);
  } else {
    int threads_per_block = 128;
    int num_blocks = Eigen::divup(extent_y, threads_per_block);

    ColumnReduceSimpleKernel<<<num_blocks, threads_per_block, 0, cu_stream>>>(
        in, out, 1, extent_x, extent_y, op);
  }
}

template <typename T, typename Op, typename OUT_T, typename IN_T>
void Launch3DYReduction(OpKernelContext* ctx, OUT_T out, IN_T in, int extent_x,
                        int extent_y, int extent_z, Op op, T init,
                        const cudaStream_t& cu_stream) {
  int threads_per_block = 128;
  int num_blocks =
      (extent_x * extent_z + threads_per_block - 1) / threads_per_block;

  // TODO(eriche): this won't be very good in the case of small x,
  // small z, and large y.
  ColumnReduceSimpleKernel<<<num_blocks, threads_per_block, 0, cu_stream>>>(
      in, out, extent_x, extent_y, extent_z, op);
}
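
// reduces over the outer (x) and inner (z) dimensions of a 3D tensor by
// running a segmented reduction over a gathered view of the input: for each
// y, the gather iterator makes the extent_x * extent_z elements being
// reduced appear as one contiguous segment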
template <typename T, typename Op, typename OUT_T, typename IN_T>
void Launch3DXZReduction(OpKernelContext* ctx, OUT_T out, IN_T in, int extent_x,
                         int extent_y, int extent_z, Op op, T init,
                         const cudaStream_t& cu_stream) {
  // set up segment offsets with counting and transform iterators
  RowOffset row_offset_op(extent_x * extent_z);
  cub::CountingInputIterator<int> counting_iter(0);
  cub::TransformInputIterator<int, RowOffset, cub::CountingInputIterator<int>>
      transform_iter(counting_iter, row_offset_op);

  GatherOp gather_op(extent_x, extent_y, extent_z, false);
  typedef cub::TransformInputIterator<int, GatherOp,
                                      cub::CountingInputIterator<int>>
      gatherIterType;
  gatherIterType gather_iter(counting_iter, gather_op);

  PermutationInputIterator<T, IN_T, gatherIterType> permute_iter(in,
                                                                 gather_iter);

  std::size_t temp_storage_bytes = 0;
  Tensor temp_storage;

  for (int i = 0; i < 2; ++i) {
    auto success = cub::DeviceSegmentedReduce::Reduce(
        i == 0 ? nullptr : temp_storage.flat<int8_t>().data(),
        temp_storage_bytes, permute_iter, out, extent_y, transform_iter,
        transform_iter + 1, op, init, cu_stream);

    OP_REQUIRES(ctx, success == 0,
                errors::Internal("CUB segmented reduce error ",
                                 cudaGetErrorString(success)));

    if (i == 0)
      OP_REQUIRES_OK(
          ctx,
          ctx->allocate_temp(
              DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
              &temp_storage));
  }
}

namespace reduction_op_helper {

template <typename T, typename Op>
struct IsSum {
  constexpr static bool value =
      (std::is_same<Op, cub::Sum>::value ||
       std::is_same<Op, Eigen::internal::SumReducer<T>>::value ||
       std::is_same<Op, Sum<T>>::value);
};

template <typename T, typename Op>
struct IsMax {
  constexpr static bool value =
      (std::is_same<Op, cub::Max>::value ||
       std::is_same<Op, Eigen::internal::MaxReducer<T>>::value);
};

template <typename T, typename Op>
struct IsMin {
  constexpr static bool value =
      (std::is_same<Op, cub::Min>::value ||
       std::is_same<Op, Eigen::internal::MinReducer<T>>::value);
};

template <typename T, typename Op>
struct IsProd {
  constexpr static bool value =
      (std::is_same<Op, Prod<T>>::value ||
       std::is_same<Op, Eigen::internal::ProdReducer<T>>::value);
};

template <typename T, typename Op>
struct IdentityValue {
  static_assert(IsSum<T, Op>::value || IsMax<T, Op>::value ||
                    IsMin<T, Op>::value || IsProd<T, Op>::value ||
                    std::is_same<Op, And>::value || std::is_same<Op, Or>::value,
                "IdentityValue not yet defined for this type");

  template <typename U = T, typename OpCopy = Op>
  U operator()(
      typename std::enable_if<IsSum<U, OpCopy>::value, U>::type t = U(0)) {
    return t;
  }

  template <typename U = T, typename OpCopy = Op>
  U operator()(typename std::enable_if<IsMax<U, OpCopy>::value, U>::type t =
                   Eigen::NumTraits<U>::lowest()) {
    return t;
  }

  template <typename U = T, typename OpCopy = Op>
  U operator()(typename std::enable_if<IsMin<U, OpCopy>::value, U>::type t =
                   Eigen::NumTraits<U>::highest()) {
    return t;
  }

  template <typename U = T, typename OpCopy = Op>
  U operator()(
      typename std::enable_if<IsProd<U, OpCopy>::value, U>::type t = U(1)) {
    return t;
  }

  template <typename U = T, typename OpCopy = Op>
  U operator()(typename std::enable_if<std::is_same<OpCopy, And>::value,
                                       bool>::type t = true) {
    return t;
  }

  template <typename U = T, typename OpCopy = Op>
  U operator()(typename std::enable_if<std::is_same<OpCopy, Or>::value,
                                       bool>::type t = false) {
    return t;
  }
};

}  // namespace reduction_op_helper
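
// dispatches on the collapsed input/output ranks and the reduction axes:
//   out_rank 0                          -> full reduction to a scalar
//   in_rank 2, out_rank 1, axis 1       -> row reduction
//   in_rank 2, out_rank 1, axis 0       -> column reduction
//   in_rank 3, out_rank 2, axis 1       -> reduce the middle dimension
//   in_rank 3, out_rank 1, axes {0, 2}  -> reduce the outer and inner dims
// anything else is a caller error and aborts with LOG(FATAL)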
template <typename T, typename Op, typename OUT_T, typename IN_T,
          typename ReductionAxes>
void ReduceImpl(OpKernelContext* ctx, OUT_T out, IN_T in, int in_rank,
                int in_dim0, int in_dim1, int in_dim2, int out_rank,
                const ReductionAxes& reduction_axes, Op op) {
  T init = reduction_op_helper::IdentityValue<T, Op>()();
  const cudaStream_t& cu_stream = GetCudaStream(ctx);
  if (out_rank == 0) {
    const int in_size = in_dim0 * in_dim1 * in_dim2;
    LaunchScalarReduction(ctx, out, in, in_size, op, init, cu_stream);
  } else if (in_rank == 2 && out_rank == 1 &&
             reduction_axes[0] == 1) {  // row reduction
    LaunchRowReduction(ctx, out, in, in_dim0, in_dim1, op, init, cu_stream);
  } else if (in_rank == 2 && out_rank == 1 &&
             reduction_axes[0] == 0) {  // column reduction
    LaunchColumnReduction(ctx, out, in, in_dim0, in_dim1, op, init, cu_stream);
  } else if (in_rank == 3 && out_rank == 2 && reduction_axes[0] == 1) {
    Launch3DYReduction(ctx, out, in, in_dim0, in_dim1, in_dim2, op, init,
                       cu_stream);
  } else if (in_rank == 3 && out_rank == 1 && reduction_axes[0] == 0 &&
             reduction_axes[1] == 2) {
    Launch3DXZReduction(ctx, out, in, in_dim0, in_dim1, in_dim2, op, init,
                        cu_stream);
  } else {
    std::stringstream ss;
    ss << "Invalid reduction requested: in_rank, out_rank, axes " << in_rank
       << " " << out_rank;
    if (out_rank == 1) ss << " " << reduction_axes[0];
    if (out_rank == 2) ss << " " << reduction_axes[1];
    LOG(FATAL) << ss.str();
  }
}

template <typename Reducer>
struct ReduceFunctor<GPUDevice, Reducer> {
  template <typename OUT_T, typename IN_T, typename ReductionAxes>
  static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
                     const ReductionAxes& reduction_axes,
                     const Reducer& reducer);
};

template <typename T>
struct ReduceFunctor<GPUDevice, Eigen::internal::SumReducer<T>> {
  template <typename OUT_T, typename IN_T, typename ReductionAxes>
  static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
                     const ReductionAxes& reduction_axes,
                     const Eigen::internal::SumReducer<T>& reducer) {
    ReduceImpl<T, Sum<T>, T*, T*, ReductionAxes>(
        ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0),
        in.rank() >= 2 ? in.dimension(1) : 1,
        in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
        Sum<T>());
  }

  template <typename OUT_T>
  static void FillIdentity(const GPUDevice& d, OUT_T out,
                           const Eigen::internal::SumReducer<T>& reducer) {
    FillIdentityEigenImpl(d, To32Bit(out), reducer);
  }
};
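
// Mean is implemented as a Sum reduction whose output passes through a
// TransformOutputIterator that divides by the number of reduced elements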
template <typename T>
struct ReduceFunctor<GPUDevice, Eigen::internal::MeanReducer<T>> {
  template <typename OUT_T, typename IN_T, typename ReductionAxes>
  static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
                     const ReductionAxes& reduction_axes,
                     const Eigen::internal::MeanReducer<T>& reducer) {
    int divisor = 1;
    if (out.rank() == 0)
      divisor = in.size();
    else if (out.rank() == 1 && in.rank() == 2 && reduction_axes[0] == 0)
      divisor = in.dimension(0);
    else if (out.rank() == 1 && in.rank() == 2 && reduction_axes[0] == 1)
      divisor = in.dimension(1);
    else if (out.rank() == 1 && in.rank() == 3 && reduction_axes[0] == 0 &&
             reduction_axes[1] == 2)
      divisor = in.dimension(0) * in.dimension(2);
    else if (out.rank() == 2 && in.rank() == 3 && reduction_axes[0] == 1)
      divisor = in.dimension(1);

    DividesBy<T> div_op(static_cast<T>(divisor));
    TransformOutputIterator<T, T, DividesBy<T>> itr((T*)out.data(), div_op);
    ReduceImpl<T, Sum<T>, TransformOutputIterator<T, T, DividesBy<T>>, T*,
               ReductionAxes>(ctx, itr, (T*)in.data(), in.rank(),
                              in.dimension(0),
                              in.rank() >= 2 ? in.dimension(1) : 1,
                              in.rank() >= 3 ? in.dimension(2) : 1, out.rank(),
                              reduction_axes, Sum<T>());
  }

  template <typename OUT_T>
  static void FillIdentity(const GPUDevice& d, OUT_T out,
                           const Eigen::internal::MeanReducer<T>& reducer) {
    FillIdentityEigenImpl(d, To32Bit(out), reducer);
  }
};

template <>
struct ReduceFunctor<GPUDevice, Eigen::internal::MeanReducer<Eigen::half>> {
  template <typename OUT_T, typename IN_T, typename ReductionAxes>
  static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
                     const ReductionAxes& reduction_axes,
                     const Eigen::internal::MeanReducer<Eigen::half>& reducer) {
    float divisor = 1.f;
    if (out.rank() == 0)
      divisor = in.size();
    else if (out.rank() == 1 && in.rank() == 2 && reduction_axes[0] == 0)
      divisor = in.dimension(0);
    else if (out.rank() == 1 && in.rank() == 2 && reduction_axes[0] == 1)
      divisor = in.dimension(1);
    else if (out.rank() == 1 && in.rank() == 3 && reduction_axes[0] == 0 &&
             reduction_axes[1] == 2)
      divisor = in.dimension(0) * in.dimension(2);
    else if (out.rank() == 2 && in.rank() == 3 && reduction_axes[0] == 1)
      divisor = in.dimension(1);
    DividesBy<float, Eigen::half> div_op(divisor);

    typedef cub::TransformInputIterator<float, HalfToFloat, Eigen::half*>
        inputIterType;
    inputIterType input_itr((Eigen::half*)in.data(), HalfToFloat());

    typedef TransformOutputIterator<Eigen::half, float,
                                    DividesBy<float, Eigen::half>>
        outputIterType;
    outputIterType itr((Eigen::half*)out.data(), div_op);

    ReduceImpl<float, cub::Sum, outputIterType, inputIterType, ReductionAxes>(
        ctx, itr, input_itr, in.rank(), in.dimension(0),
        in.rank() >= 2 ? in.dimension(1) : 1,
        in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
        cub::Sum());
  }

  template <typename OUT_T>
  static void FillIdentity(
      const GPUDevice& d, OUT_T out,
      const Eigen::internal::MeanReducer<Eigen::half>& reducer) {
    FillIdentityEigenImpl(d, To32Bit(out), reducer);
  }
};

template <typename T>
struct ReduceFunctor<GPUDevice, Eigen::internal::MaxReducer<T>> {
  template <typename OUT_T, typename IN_T, typename ReductionAxes>
  static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
                     const ReductionAxes& reduction_axes,
                     const Eigen::internal::MaxReducer<T>& reducer) {
    ReduceImpl<T, cub::Max, T*, T*, ReductionAxes>(
        ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0),
        in.rank() >= 2 ? in.dimension(1) : 1,
        in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
        cub::Max());
  }

  template <typename OUT_T>
  static void FillIdentity(const GPUDevice& d, OUT_T out,
                           const Eigen::internal::MaxReducer<T>& reducer) {
    FillIdentityEigenImpl(d, To32Bit(out), reducer);
  }
};

template <typename T>
struct ReduceFunctor<GPUDevice, Eigen::internal::MinReducer<T>> {
  template <typename OUT_T, typename IN_T, typename ReductionAxes>
  static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
                     const ReductionAxes& reduction_axes,
                     const Eigen::internal::MinReducer<T>& reducer) {
    ReduceImpl<T, cub::Min, T*, T*, ReductionAxes>(
        ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0),
        in.rank() >= 2 ? in.dimension(1) : 1,
        in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
        cub::Min());
  }

  template <typename OUT_T>
  static void FillIdentity(const GPUDevice& d, OUT_T out,
                           const Eigen::internal::MinReducer<T>& reducer) {
    FillIdentityEigenImpl(d, To32Bit(out), reducer);
  }
};

template <typename T>
struct ReduceFunctor<GPUDevice, Eigen::internal::ProdReducer<T>> {
  template <typename OUT_T, typename IN_T, typename ReductionAxes>
  static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
                     const ReductionAxes& reduction_axes,
                     const Eigen::internal::ProdReducer<T>& reducer) {
    ReduceImpl<T, Prod<T>, T*, T*, ReductionAxes>(
        ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0),
        in.rank() >= 2 ? in.dimension(1) : 1,
        in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
        Prod<T>());
  }

  template <typename OUT_T>
  static void FillIdentity(const GPUDevice& d, OUT_T out,
                           const Eigen::internal::ProdReducer<T>& reducer) {
    FillIdentityEigenImpl(d, To32Bit(out), reducer);
  }
};

template <>
struct ReduceFunctor<GPUDevice, Eigen::internal::AndReducer> {
  template <typename OUT_T, typename IN_T, typename ReductionAxes>
  static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
                     const ReductionAxes& reduction_axes,
                     const Eigen::internal::AndReducer& reducer) {
    ReduceImpl<bool, And, bool*, bool*, ReductionAxes>(
        ctx, (bool*)out.data(), (bool*)in.data(), in.rank(), in.dimension(0),
        in.rank() >= 2 ? in.dimension(1) : 1,
        in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
        And());
  }

  template <typename OUT_T>
  static void FillIdentity(const GPUDevice& d, OUT_T out,
                           const Eigen::internal::AndReducer& reducer) {
    FillIdentityEigenImpl(d, To32Bit(out), reducer);
  }
};

template <>
struct ReduceFunctor<GPUDevice, Eigen::internal::OrReducer> {
  template <typename OUT_T, typename IN_T, typename ReductionAxes>
  static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
                     const ReductionAxes& reduction_axes,
                     const Eigen::internal::OrReducer& reducer) {
    ReduceImpl<bool, Or, bool*, bool*, ReductionAxes>(
        ctx, (bool*)out.data(), (bool*)in.data(), in.rank(), in.dimension(0),
        in.rank() >= 2 ? in.dimension(1) : 1,
        in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes, Or());
  }

  template <typename OUT_T>
  static void FillIdentity(const GPUDevice& d, OUT_T out,
                           const Eigen::internal::OrReducer& reducer) {
    FillIdentityEigenImpl(d, To32Bit(out), reducer);
  }
};

}  // namespace functor
}  // namespace tensorflow

#endif  // GOOGLE_CUDA