/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The algorithm for dynamic partition has the following steps:
// 1. Let N be the size of partitions. We initialize a new vector indices_in
//    with the values 0, 1, 2, ..., N-1.
// 2. We apply cub::DeviceRadixSort::SortPairs to the key-value pairs given
//    by partitions and indices_in. This will result in two new vectors
//    partitions_out and indices_out, with partitions_out sorted.
// 3. The first dimension of outputs[i] is equal to the number of i-values in
//    partitions_out. We determine it in two steps:
//    - apply cub::DeviceReduce::ReduceByKey to count how many times each
//      value appears in partitions_out,
//    - move the results to partition_count. This handles missing values
//      (corresponding to empty parts).
// 4. Because partition_count is on the GPU, we bring it asynchronously to
//    the CPU. Then we can allocate the output tensors.
// 5. Finally, we use indices_out and the gather functor to collect the
//    output. This works because for each interval of i-values, indices_out
//    points to the slices which should form output[i].

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "external/cub_archive/cub/device/device_radix_sort.cuh"
#include "external/cub_archive/cub/device/device_reduce.cuh"
#include "external/cub_archive/cub/iterator/constant_input_iterator.cuh"
#include "external/cub_archive/cub/thread/thread_operators.cuh"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/gather_functor_gpu.cu.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/transform_output_iterator.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace {

template <typename T>
__global__ void RangeInitKernel(const T start, const T delta, const int32 size,
                                T* out) {
  CUDA_1D_KERNEL_LOOP(i, size) { out[i] = start + i * delta; }
}

__global__ void MoveValuesKernel(const int32* keys, const int32* values,
                                 const int32* size, int32 out_size,
                                 int32* out) {
  int32 N = min(ldg(size), out_size);
  CUDA_1D_KERNEL_LOOP(i, N) {
    int32 key = ldg(keys + i);
    int32 value = ldg(values + i);
    if (FastBoundsCheck(key, out_size)) out[key] = value;
  }
}
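
// Example of MoveValuesKernel's effect (illustrative values): with
// keys = [0, 2], values = [5, 7], *size = 2 and out_size = 4, the kernel
// writes out[0] = 5 and out[2] = 7; positions 1 and 3 are left untouched,
// which is why the caller zero-initializes the destination beforehand.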

// Initialize out with range start, start + delta, start + 2 * delta, ...
// This is needed because tf.range has no GPU implementation.
template <typename T>
void RangeInit(const GPUDevice& d, const T start, const T delta,
               const int32 size, typename TTypes<T>::Flat out) {
  CudaLaunchConfig config = GetCudaLaunchConfig(size, d);
  RangeInitKernel<T>
      <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
          start, delta, size, out.data());
}

// Given *num_runs pairs (key, value), this function moves the value
// corresponding to key i to position i in the array out.
void MoveValues(const GPUDevice& d, int32* keys, int32* values,
                int32* num_runs, int32 out_size, int32* out) {
  // Because num_runs is located on the GPU, we cannot access it directly.
  // So we launch the kernel with size = out_size.
  // This is valid for correct inputs, because then out_size >= *num_runs.
  // For wrong inputs, we may have out_size < *num_runs. In this case we will
  // only handle the first out_size values.
  CudaLaunchConfig config = GetCudaLaunchConfig(out_size, d);
  MoveValuesKernel<<<config.block_count, config.thread_per_block, 0,
                     d.stream()>>>(keys, values, num_runs, out_size, out);
}

template <typename T>
void CallGatherKernel(const GPUDevice& d, const T* params,
                      const int32* indices, T* out, int64 gather_dim_size,
                      int64 indices_size, int64 slice_size, int64 out_size) {
  CudaLaunchConfig config = GetCudaLaunchConfig(out_size, d);
  GatherOpKernel<T, int32, true>
      <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
          params, indices, out, gather_dim_size, indices_size, slice_size,
          out_size);
}

struct IdentityOp {
  __device__ int32 __forceinline__ operator()(const int32& a) const {
    return a;
  }
};
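
// IdentityOp is a no-op transform: it exists only so that the
// BoundedOutputIterator below can reuse TransformOutputIterator's reference
// machinery while adding a bounds check on assignment.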

// Define an output iterator that only allows assignment to
// positions in [base, base + limit).
class BoundedOutputIterator
    : public TransformOutputIterator<int32, int32, IdentityOp> {
 private:
  int32 limit;
  int32* base;

  struct BoundedReference : Reference {
    int32 limit;
    int32* base;
    // Constructor
    __host__ __device__ __forceinline__
    BoundedReference(int32* ptr, int32* base, IdentityOp op, int32 limit)
        : Reference(ptr, op), limit(limit), base(base) {}

    // Assignment
    __host__ __device__ __forceinline__ int32 operator=(int32 val) {
      if (ptr - base < limit && ptr - base >= 0) *ptr = val;
      return val;
    }
  };

 public:
  typedef BoundedOutputIterator self_type;
  typedef BoundedReference reference;

  __host__ __device__ __forceinline__ BoundedOutputIterator(int32* ptr,
                                                             IdentityOp op,
                                                             int32 size)
      : TransformOutputIterator(ptr, op), limit(size), base(ptr) {}

  __host__ __device__ __forceinline__
  BoundedOutputIterator(int32* ptr, int32* base, IdentityOp op, int32 size)
      : TransformOutputIterator(ptr, op), limit(size), base(base) {}

  // Indirection
  __host__ __device__ __forceinline__ reference operator*() const {
    return BoundedReference(ptr, base, conversion_op, limit);
  }

  // Array subscript
  __host__ __device__ __forceinline__ reference operator[](int32 n) const {
    return BoundedReference(ptr + n, base, conversion_op, limit);
  }

  // Addition
  __host__ __device__ __forceinline__ self_type operator+(int32 n) const {
    self_type retval(ptr + n, base, conversion_op, limit);
    return retval;
  }

  // Subtraction
  __host__ __device__ __forceinline__ self_type operator-(int32 n) const {
    self_type retval(ptr - n, base, conversion_op, limit);
    return retval;
  }
};

}  // namespace

// The current implementation has memory cost on GPU
// I + P + max(3N + R + P, O + N), where:
// I - the size of the input
// N - the size of the partitions tensor
// R - the temporary storage used by cub::RadixSort, about 2N
// P - the number of partitions
// O - the size of the output
// So roughly the cost is I + P + max(5N, O + N).
template <typename T>
class DynamicPartitionOpGPU : public AsyncOpKernel {
 public:
  explicit DynamicPartitionOpGPU(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("num_partitions", &num_partitions_));
    OP_REQUIRES(c, num_partitions_ >= 1,
                errors::InvalidArgument("num_partitions must be at least 1"));
  }

  void AllocateTempSpace(OpKernelContext* c, int32 N, Tensor* indices_in,
                         Tensor* partitions_out, Tensor* indices_out,
                         DoneCallback done) {
    int32 M = std::max(N, num_partitions_);
    // indices_in will be made slightly larger to accommodate
    // later computations.
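    // Specifically, indices_in is later reused in CountAndSortParts as the
    // unique-keys output of ReduceByKey, which may be written with up to
    // num_partitions_ entries; hence the size M = max(N, num_partitions_).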
    OP_REQUIRES_OK_ASYNC(
        c, c->allocate_temp(DT_INT32, TensorShape({M}), indices_in), done);
    OP_REQUIRES_OK_ASYNC(
        c, c->allocate_temp(DT_INT32, TensorShape({N}), partitions_out), done);
    OP_REQUIRES_OK_ASYNC(
        c, c->allocate_temp(DT_INT32, TensorShape({N}), indices_out), done);
  }

  void AllocateOutputs(OpKernelContext* c, const Tensor* data,
                       const Tensor* partitions, const Tensor* partition_count,
                       OpOutputList* Tout, DoneCallback done) {
    auto e_part_count = partition_count->flat<int32>();
    // Allocate output tensors of the right size.
    OP_REQUIRES_OK_ASYNC(c, c->output_list("outputs", Tout), done);
    for (int p = 0; p < num_partitions_; p++) {
      TensorShape shape;
      shape.AddDim(e_part_count(p));
      for (int i = partitions->dims(); i < data->dims(); i++) {
        shape.AddDim(data->dim_size(i));
      }
      Tensor* out;
      OP_REQUIRES_OK_ASYNC(c, Tout->allocate(p, shape, &out), done);
    }
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) {
    const Tensor& data = c->input(0);
    const Tensor& partitions = c->input(1);

    OP_REQUIRES_ASYNC(
        c, TensorShapeUtils::StartsWith(data.shape(), partitions.shape()),
        errors::InvalidArgument(
            "data.shape must start with partitions.shape, ",
            "got data.shape = ", data.shape().DebugString(),
            ", partitions.shape = ", partitions.shape().DebugString()),
        done);

    Tensor partition_count;

    // We must handle the case of empty partitions separately,
    // because kernels don't work with 0-sized tensors.
    if (partitions.NumElements() == 0) {
      AllocatorAttributes alloc_attr;
      alloc_attr.set_on_host(true);
      OP_REQUIRES_OK_ASYNC(
          c,
          c->allocate_temp(DT_INT32, TensorShape({num_partitions_}),
                           &partition_count, alloc_attr),
          done);
      auto e_part_count = partition_count.flat<int32>();
      for (int i = 0; i < num_partitions_; i++) e_part_count(i) = 0;
      OpOutputList outputs;
      this->AllocateOutputs(c, &data, &partitions, &partition_count, &outputs,
                            done);
      if (c->status().ok()) done();
      return;
    }

    // Prepare for counting.
    OP_REQUIRES_OK_ASYNC(
        c,
        c->allocate_temp(DT_INT32, TensorShape({num_partitions_}),
                         &partition_count),
        done);
    Tensor indices_out;
    // Count how many times each partition index occurs.
    // Also sort the info in partitions and output it in indices_out,
    // in preparation for the next step.
    this->CountAndSortParts(c, &partitions, &partition_count, &indices_out,
                            done);
    if (!c->status().ok()) return;

    // In order to allocate the output tensors we have to move partition_count
    // to the CPU.
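    // The device-to-host copy below is enqueued on the op's stream; the
    // EventMgr callback registered at the end of this function runs only
    // after the work already enqueued on the stream (including this copy)
    // has completed, so the callback can safely read cpu_tensor.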
    auto* stream = c->op_device_context()->stream();
    OP_REQUIRES_ASYNC(c, stream, errors::Internal("No GPU stream available."),
                      done);
    Tensor cpu_tensor;
    AllocatorAttributes alloc_attr;
    alloc_attr.set_on_host(true);
    alloc_attr.set_gpu_compatible(true);
    OP_REQUIRES_OK_ASYNC(
        c,
        c->allocate_temp(partition_count.dtype(), partition_count.shape(),
                         &cpu_tensor, alloc_attr),
        done);
    perftools::gputools::DeviceMemoryBase wrapped(
        partition_count.flat<int32>().data(), num_partitions_ * sizeof(int32));
    const bool status =
        stream
            ->ThenMemcpy(cpu_tensor.flat<int32>().data(), wrapped,
                         num_partitions_ * sizeof(int32))
            .ok();
    OP_REQUIRES_ASYNC(
        c, status,
        errors::Internal("Failed to launch copy from device to host."), done);

    // Keep a reference to partition_count so that the buffer
    // is not deallocated at the end of the function, before
    // memcpy is completed.
    TensorReference partition_ref(partition_count);
    auto wrapped_callback = [this, c, &data, &partitions, indices_out,
                             partition_ref, cpu_tensor, done]() {
      OpOutputList outputs;
      this->AllocateOutputs(c, &data, &partitions, &cpu_tensor, &outputs,
                            done);
      if (!c->status().ok()) {
        partition_ref.Unref();
        return;
      }
      int32 N = partitions.NumElements();
      int64 slice_size = data.NumElements() / N;
      this->GatherSlices(c, &data, &indices_out, N, slice_size, outputs);
      partition_ref.Unref();
      done();
    };

    c->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
        stream, wrapped_callback);
  }

 protected:
  void RadixSort(OpKernelContext* c, const Tensor* partitions,
                 Tensor* indices_in, Tensor* partitions_out,
                 Tensor* indices_out, DoneCallback done) {
    int32 N = partitions->NumElements();
    const GPUDevice& device = c->eigen_device<GPUDevice>();
    const cudaStream_t& cu_stream = GetCudaStream(c);

    // Initialize the indices_in tensor using the Range GPU kernel.
    RangeInit(device, 0, 1, N, indices_in->flat<int32>());
    // Obtain the pointers to inner buffers.
    const int32* partitions_ptr = partitions->flat<int32>().data();
    int32* partitions_out_ptr = partitions_out->flat<int32>().data();
    int32* indices_in_ptr = indices_in->flat<int32>().data();
    int32* indices_out_ptr = indices_out->flat<int32>().data();
    // Determine temporary device storage requirements.
    Tensor cub_temp_storage;
    size_t temp_storage_bytes = 0;
    cub::DeviceRadixSort::SortPairs(
        NULL, temp_storage_bytes, partitions_ptr, partitions_out_ptr,
        indices_in_ptr, indices_out_ptr, N, 0, sizeof(int32) * 8, cu_stream);
    // Allocate temporary storage.
    OP_REQUIRES_OK_ASYNC(
        c,
        c->allocate_temp(DT_INT8,
                         TensorShape({static_cast<int64>(temp_storage_bytes)}),
                         &cub_temp_storage),
        done);
    // Radix-sort the partition information.
    cub::DeviceRadixSort::SortPairs(
        cub_temp_storage.flat<int8>().data(), temp_storage_bytes,
        partitions_ptr, partitions_out_ptr, indices_in_ptr, indices_out_ptr, N,
        0, sizeof(int32) * 8, cu_stream);
  }  // At this point cub_temp_storage will be marked for deallocation.
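
  // Example of RadixSort's effect (illustrative values): partitions =
  // [1, 0, 2, 0] together with indices_in = [0, 1, 2, 3] sorts to
  // partitions_out = [0, 0, 1, 2] and indices_out = [1, 3, 0, 2], i.e.
  // indices_out lists, for each partition in order, the positions of its
  // elements in the original input.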

  void CountAndSortParts(OpKernelContext* c, const Tensor* partitions,
                         Tensor* partition_count, Tensor* indices_out,
                         DoneCallback done) {
    const GPUDevice& device = c->eigen_device<GPUDevice>();
    const cudaStream_t& cu_stream = GetCudaStream(c);
    int32 N = partitions->NumElements();
    Tensor indices_in;
    Tensor partitions_out;
    Tensor aggregates_out;

    // Allocate memory for Radix-Sort.
    this->AllocateTempSpace(c, N, &indices_in, &partitions_out, indices_out,
                            done);
    if (!c->status().ok()) return;
    this->RadixSort(c, partitions, &indices_in, &partitions_out, indices_out,
                    done);
    if (!c->status().ok()) return;
    // We will now apply a reduce operation to count how many times
    // each index appears in partitions.

    // Zero-out the partition_count tensor.
    functor::SetZeroFunctor<GPUDevice, int32> zero_functor;
    zero_functor(device, partition_count->flat<int32>());
    // Allocate memory for aggregates_out.
    OP_REQUIRES_OK_ASYNC(
        c,
        c->allocate_temp(DT_INT32, TensorShape({num_partitions_}),
                         &aggregates_out),
        done);
    // Obtain the pointers to inner buffers.
    int32* keys_in_ptr = partitions_out.flat<int32>().data();
    // Here we reuse the indices_in tensor for the unique keys output.
    int32* unique_out_ptr = indices_in.flat<int32>().data();
    int32* aggregates_out_ptr = aggregates_out.flat<int32>().data();
    // We wrap the pointers in bounded output iterators to guard against
    // wrong inputs (more than num_partitions distinct indices).
    IdentityOp id_op;
    BoundedOutputIterator unique_out_it(unique_out_ptr, id_op,
                                        num_partitions_);
    BoundedOutputIterator aggregates_out_it(aggregates_out_ptr, id_op,
                                            num_partitions_);

    cub::ConstantInputIterator<int32> values_in(1);
    cub::Sum reduction_op;

    // Allocate space on GPU for the number of runs. This is required by CUB.
    Tensor num_runs;
    OP_REQUIRES_OK_ASYNC(
        c, c->allocate_temp(DT_INT32, TensorShape({1}), &num_runs), done);
    int32* num_runs_ptr = num_runs.flat<int32>().data();

    // Determine temporary device storage requirements.
    Tensor cub_temp_storage;
    size_t temp_storage_bytes = 0;
    cub::DeviceReduce::ReduceByKey(NULL, temp_storage_bytes, keys_in_ptr,
                                   unique_out_it, values_in, aggregates_out_it,
                                   num_runs_ptr, reduction_op, N, cu_stream);
    // Allocate temporary storage.
    OP_REQUIRES_OK_ASYNC(
        c,
        c->allocate_temp(DT_INT8,
                         TensorShape({static_cast<int64>(temp_storage_bytes)}),
                         &cub_temp_storage),
        done);
    // Run reduce-by-key. The effect is that we count how many times
    // each index appears in partitions. The distinct indices are stored
    // in unique_out, while the count is stored in aggregates_out.
    // The total number of distinct indices is stored in num_runs.
    cub::DeviceReduce::ReduceByKey(cub_temp_storage.flat<int8>().data(),
                                   temp_storage_bytes, keys_in_ptr,
                                   unique_out_it, values_in, aggregates_out_it,
                                   num_runs_ptr, reduction_op, N, cu_stream);
    // We are not done yet. unique_out only contains the indices that appeared
    // at least once in partitions. We move each value from aggregates_out
    // to the corresponding position in partition_count. This will handle
    // possibly empty parts.
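    // Continuing the illustrative example above: partitions_out = [0, 0, 1, 2]
    // with num_partitions_ = 4 yields unique_out = [0, 1, 2],
    // aggregates_out = [2, 1, 1] and *num_runs = 3; MoveValues then scatters
    // these counts into the zeroed partition_count, giving [2, 1, 1, 0].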
    MoveValues(device, unique_out_ptr, aggregates_out_ptr, num_runs_ptr,
               num_partitions_, partition_count->flat<int32>().data());
  }  // At this point indices_in, partitions_out, aggregates_out
     // and cub_temp_storage will be marked for deallocation.

  void GatherSlices(OpKernelContext* c, const Tensor* data,
                    const Tensor* indices, int32 N, int64 slice_size,
                    OpOutputList& outs) {
    const GPUDevice& device = c->eigen_device<GPUDevice>();
    const int32* ind_base = indices->flat<int32>().data();
    const T* data_base = data->flat<T>().data();

    for (int p = 0; p < num_partitions_; p++) {
      int32 indices_size = outs[p]->dim_size(0);
      int64 out_size = outs[p]->NumElements();
      T* out_base = outs[p]->flat<T>().data();
      if (out_size > 0)
        CallGatherKernel<T>(device, data_base, ind_base, out_base, N,
                            indices_size, slice_size, out_size);
      ind_base += indices_size;
    }
  }

  int32 num_partitions_;
};

#define REGISTER_DYNAMIC_PARTITION_GPU(T)                                 \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("DynamicPartition").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      DynamicPartitionOpGPU<T>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_DYNAMIC_PARTITION_GPU);
TF_CALL_complex64(REGISTER_DYNAMIC_PARTITION_GPU);
TF_CALL_complex128(REGISTER_DYNAMIC_PARTITION_GPU);
#undef REGISTER_DYNAMIC_PARTITION_GPU

}  // namespace tensorflow

#endif  // GOOGLE_CUDA