/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_SCAN_OPS_GPU_H_
#define TENSORFLOW_CORE_KERNELS_SCAN_OPS_GPU_H_

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#if CUDA_VERSION >= 9000
#define CUB_USE_COOPERATIVE_GROUPS
#endif  // CUDA_VERSION >= 9000

#include <type_traits>

#include "third_party/cub/block/block_load.cuh"
#include "third_party/cub/block/block_scan.cuh"
#include "third_party/cub/block/block_store.cuh"
#include "third_party/cub/iterator/counting_input_iterator.cuh"
#include "third_party/cub/iterator/transform_input_iterator.cuh"
#include "cuda/include/cuComplex.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/util/cuda_launch_config.h"
#include "tensorflow/core/util/permutation_input_iterator.h"
#include "tensorflow/core/util/permutation_output_iterator.h"

#include "tensorflow/core/kernels/scan_ops.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;
typedef Eigen::Index Index;

namespace functor {

// Maps a contiguous range of indices to the actual memory locations,
// depending on which axis the scan runs over and whether or not it is
// reversed. The input tensor is viewed as (dimx, dimy, dimz) with the scan
// taken along dimy.
struct MapIndexToLocation {
  __host__ __device__ MapIndexToLocation(int dimx, int dimy, int dimz,
                                         bool reverse = false)
      : dimx_(dimx), dimy_(dimy), dimz_(dimz), reverse_(reverse) {}

  __host__ __device__ int operator()(int id) const {
    if (dimx_ == 1) {
      int row = id % dimy_;
      int col = id / dimy_;

      if (reverse_) return (dimy_ - row - 1) * dimz_ + col;

      return row * dimz_ + col;
    } else if (dimz_ == 1) {
      if (reverse_) {
        int row = id / dimy_;
        int col = id % dimy_;
        return row * dimy_ + (dimy_ - col - 1);
      }
      return id;
    } else {
      int col = id % dimy_;
      int tmp = id / dimy_;

      int row1 = id / (dimy_ * dimz_);
      int col1 = tmp % dimz_;

      if (reverse_)
        return row1 * dimy_ * dimz_ + (dimy_ - col - 1) * dimz_ + col1;

      return row1 * dimy_ * dimz_ + col * dimz_ + col1;
    }
  }

  int dimx_;
  int dimy_;
  int dimz_;
  bool reverse_;
};
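// Worked example (added for illustration, not part of the original source):
// for a tensor viewed as (dimx=2, dimy=3, dimz=4) scanned along dimy, the
// general branch above walks consecutive ids down one column of the
// row-major (dimy, dimz) plane: ids 0, 1, 2 map to offsets 0, 4, 8, and ids
// 3, 4, 5 map to offsets 1, 5, 9, so each length-dimy sequence occupies a
// contiguous range in id space even though it is strided in memory. With
// reverse = true, the same ids map to 8, 4, 0 and 9, 5, 1.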

template <typename T, typename Op>
struct BlockPrefixCallbackOp {
  // Running prefix
  T running_total_;
  Op op_;

  __device__ BlockPrefixCallbackOp(T running_total, Op op)
      : running_total_(running_total), op_(op) {}

  // Callback operator to be entered by the first warp of threads in the
  // block. tid 0 is responsible for returning a value for seeding the
  // block-wide scan.
  __device__ T operator()(T block_aggregate) {
    T old_prefix = running_total_;
    running_total_ = op_(old_prefix, block_aggregate);
    return old_prefix;
  }
};

template <typename T>
struct Sum {
  __host__ __device__ T operator()(const T& a, const T& b) const {
    return a + b;
  }
};

template <typename T>
struct Prod {
  __host__ __device__ T operator()(const T& a, const T& b) const {
    return a * b;
  }
};

template <typename T, typename Op>
struct IsSum {
  constexpr static bool value =
      (std::is_same<Op, Sum<T>>::value ||
       std::is_same<Op, Eigen::internal::SumReducer<T>>::value);
};

template <typename T, typename Op>
struct IsProd {
  constexpr static bool value =
      (std::is_same<Op, Prod<T>>::value ||
       std::is_same<Op, Eigen::internal::ProdReducer<T>>::value);
};

template <typename T, typename Op>
struct IdentityValue {
  static_assert(IsSum<T, Op>::value || IsProd<T, Op>::value,
                "IdentityValue not yet defined for this type.");

  template <typename U = T, typename OpCopy = Op>
  __host__ __device__ U operator()(
      typename std::enable_if<IsSum<U, OpCopy>::value, U>::type t = U(0)) {
    return t;
  }

  template <typename U = T, typename OpCopy = Op>
  __host__ __device__ U operator()(
      typename std::enable_if<IsProd<U, OpCopy>::value, U>::type t = U(1)) {
    return t;
  }
};
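// Usage sketch (added for illustration; mirrors what scan_kernel below
// does): IdentityValue uses SFINAE on the default-argument type so that
// exactly one operator() overload is viable per Op, yielding 0 for sums and
// 1 for products. That identity seeds the running prefix carried across
// tiles, e.g.
//
//   Sum<float> op;
//   BlockPrefixCallbackOp<float, Sum<float>> prefix_op(
//       IdentityValue<float, Sum<float>>()(), op);
//   // The first callback returns 0.0f; each later callback returns the
//   // running total of all previously scanned tiles.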

// Each block is mapped to one sequence. A contiguous range is mapped to the
// appropriate locations in memory by the permutation iterators. This is
// ideal for 1-D and row-based scans. Column scans would be better served by
// a block load followed by a local transpose. CUB's device-wide scan is not
// used in the large 1-D case, even though it would be more efficient,
// because it is not deterministic.
template <typename T, typename Op, int BlockDim = 128, int ItemsPerThread = 4>
__global__ void scan_kernel(const T* in, T* out, int dimx, int dimy, int dimz,
                            bool exclusive, bool reverse, Op op) {
  typedef cub::BlockLoad<T, BlockDim, ItemsPerThread, cub::BLOCK_LOAD_TRANSPOSE>
      BlockLoad;
  typedef cub::BlockStore<T, BlockDim, ItemsPerThread,
                          cub::BLOCK_STORE_TRANSPOSE>
      BlockStore;
  typedef cub::BlockScan<T, BlockDim> BlockScan;

  // Allocate aliased shared memory for BlockLoad, BlockStore, and BlockScan.
  __shared__ union {
    typename BlockLoad::TempStorage load;
    typename BlockScan::TempStorage scan;
    typename BlockStore::TempStorage store;
  } temp_storage;

  int problem_length = dimy;

  // Initialize the running total to the identity element of Op.
  BlockPrefixCallbackOp<T, Op> prefix_op(IdentityValue<T, Op>()(), op);

  MapIndexToLocation map_op(dimx, dimy, dimz, reverse);
  int block_start = problem_length * blockIdx.x;
  // Have the block iterate over segments of items.
  for (int block_offset = block_start;
       block_offset < block_start + problem_length;
       block_offset += BlockDim * ItemsPerThread) {
    int valid_items = min(BlockDim * ItemsPerThread,
                          problem_length - (block_offset % problem_length));

    // First construct a counting iterator that has the desired start point.
    typedef cub::TransformInputIterator<int, MapIndexToLocation,
                                        cub::CountingInputIterator<int>>
        MapIterType;

    cub::CountingInputIterator<int> counting_iter(block_offset);

    // Next map the iterator to the actual locations in memory.
    MapIterType map_iter(counting_iter, map_op);

    PermutationInputIterator<T, const T*, MapIterType> permutein_iter(in,
                                                                      map_iter);
    PermutationOutputIterator<T, T*, MapIterType> permuteout_iter(out,
                                                                  map_iter);

    // Load a segment of consecutive items that are blocked across threads.
    T thread_data[ItemsPerThread];
    BlockLoad(temp_storage.load).Load(permutein_iter, thread_data, valid_items);
    __syncthreads();

    // Collectively compute the block-wide scan.
    if (exclusive) {
      BlockScan(temp_storage.scan)
          .ExclusiveScan(thread_data, thread_data, op, prefix_op);
    } else {
      BlockScan(temp_storage.scan)
          .InclusiveScan(thread_data, thread_data, op, prefix_op);
    }
    __syncthreads();

    // Store the scanned items to the output segment.
    BlockStore(temp_storage.store)
        .Store(permuteout_iter, thread_data, valid_items);
    __syncthreads();
  }
}
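// Illustrative launch (added for illustration; LaunchScan below is the real
// entry point): one block per scanned sequence, so the grid has dimx * dimz
// blocks of BlockDim threads and no dynamic shared memory, e.g.
//
//   scan_kernel<float, Sum<float>, /*BlockDim=*/128, /*ItemsPerThread=*/4>
//       <<<dimx * dimz, 128, 0, stream>>>(in, out, dimx, dimy, dimz,
//                                         /*exclusive=*/false,
//                                         /*reverse=*/false, Sum<float>());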

template <typename T, typename Op>
void LaunchScan(const GPUDevice& d, typename TTypes<T, 3>::ConstTensor in,
                typename TTypes<T, 3>::Tensor out, Op op, const bool reverse,
                const bool exclusive) {
  const int items_per_thread = 4;

  int dimx = in.dimension(0);
  int dimy = in.dimension(1);
  int dimz = in.dimension(2);
  int num_blocks = dimx * dimz;

  int ideal_block_size = dimy / items_per_thread;

  // There seems to be a bug when the type is not float and block_size is
  // 1024, so only float uses 1024 threads. Otherwise, launch with the
  // largest power-of-2 block size that does not exceed the ideal size, down
  // to a minimum of 32.
  if (ideal_block_size >= 1024 && std::is_same<T, float>::value) {
    const int block_size = 1024;
    TF_CHECK_OK(
        CudaLaunchKernel(scan_kernel<T, Op, block_size, items_per_thread>,
                         num_blocks, block_size, 0, d.stream(), in.data(),
                         out.data(), dimx, dimy, dimz, exclusive, reverse, op));
  } else if (ideal_block_size >= 512) {
    const int block_size = 512;
    TF_CHECK_OK(
        CudaLaunchKernel(scan_kernel<T, Op, block_size, items_per_thread>,
                         num_blocks, block_size, 0, d.stream(), in.data(),
                         out.data(), dimx, dimy, dimz, exclusive, reverse, op));
  } else if (ideal_block_size >= 256) {
    const int block_size = 256;
    TF_CHECK_OK(
        CudaLaunchKernel(scan_kernel<T, Op, block_size, items_per_thread>,
                         num_blocks, block_size, 0, d.stream(), in.data(),
                         out.data(), dimx, dimy, dimz, exclusive, reverse, op));
  } else if (ideal_block_size >= 128) {
    const int block_size = 128;
    TF_CHECK_OK(
        CudaLaunchKernel(scan_kernel<T, Op, block_size, items_per_thread>,
                         num_blocks, block_size, 0, d.stream(), in.data(),
                         out.data(), dimx, dimy, dimz, exclusive, reverse, op));
  } else if (ideal_block_size >= 64) {
    const int block_size = 64;
    TF_CHECK_OK(
        CudaLaunchKernel(scan_kernel<T, Op, block_size, items_per_thread>,
                         num_blocks, block_size, 0, d.stream(), in.data(),
                         out.data(), dimx, dimy, dimz, exclusive, reverse, op));
  } else {
    const int block_size = 32;
    TF_CHECK_OK(
        CudaLaunchKernel(scan_kernel<T, Op, block_size, items_per_thread>,
                         num_blocks, block_size, 0, d.stream(), in.data(),
                         out.data(), dimx, dimy, dimz, exclusive, reverse, op));
  }
}

template <typename T>
struct Scan<GPUDevice, Eigen::internal::SumReducer<T>, T> {
  void operator()(const GPUDevice& d, typename TTypes<T, 3>::ConstTensor in,
                  typename TTypes<T, 3>::Tensor out,
                  const Eigen::internal::SumReducer<T>& reducer,
                  const bool reverse, const bool exclusive) {
    LaunchScan<T, Sum<T>>(d, in, out, Sum<T>(), reverse, exclusive);
  }
};

template <typename T>
struct Scan<GPUDevice, Eigen::internal::ProdReducer<T>, T> {
  void operator()(const GPUDevice& d, typename TTypes<T, 3>::ConstTensor in,
                  typename TTypes<T, 3>::Tensor out,
                  const Eigen::internal::ProdReducer<T>& reducer,
                  const bool reverse, const bool exclusive) {
    LaunchScan<T, Prod<T>>(d, in, out, Prod<T>(), reverse, exclusive);
  }
};

}  // namespace functor
}  // namespace tensorflow

#endif  // GOOGLE_CUDA

#endif  // TENSORFLOW_CORE_KERNELS_SCAN_OPS_GPU_H_