/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include <limits>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/bucketize_op.h"
#include "tensorflow/core/kernels/cuda_device_array.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

// For every element in[i], writes to out[i] the index of the first boundary
// that is strictly greater than in[i] (std::upper_bound semantics): values
// below every boundary map to 0, values >= the last boundary map to
// size_boundaries.
//
// Template parameters:
//   useSharedMem: when true, the whole block first copies all size_boundaries
//     floats into dynamic shared memory and searches that copy instead of the
//     device array. The launch must then reserve at least
//     size_boundaries * sizeof(float) bytes of dynamic shared memory. The
//     branch depends only on this compile-time flag, so every thread of the
//     block reaches the __syncthreads() inside it.
//
// Launch layout: the per-element loop is CUDA_1D_KERNEL_LOOP (TF helper,
// expected to cover all size_in elements for a GetCudaLaunchConfig grid); the
// shared-memory copy flattens a possibly 2-D block.
//
// PRECONDITION: boundaries are sorted in ascending order.
template <typename T, bool useSharedMem>
__global__ void BucketizeCustomKernel(
    const int32 size_in, const T* in, const int32 size_boundaries,
    CudaDeviceArrayStruct<float> boundaries_array, int32* out) {
  const float* boundaries = GetCudaDeviceArrayOnDevice(&boundaries_array);

  extern __shared__ __align__(sizeof(float)) unsigned char shared_mem[];
  float* shared_mem_boundaries = reinterpret_cast<float*>(shared_mem);

  if (useSharedMem) {
    const int32 lidx = threadIdx.y * blockDim.x + threadIdx.x;
    const int32 blockSize = blockDim.x * blockDim.y;

    // Block-wide cooperative copy of the boundaries into shared memory.
    for (int32 i = lidx; i < size_boundaries; i += blockSize) {
      shared_mem_boundaries[i] = boundaries[i];
    }

    // The whole copy must be visible to every thread before it is searched.
    __syncthreads();

    boundaries = shared_mem_boundaries;
  }

  CUDA_1D_KERNEL_LOOP(i, size_in) {
    const T value = in[i];
    // Binary search with the same structure as std::upper_bound.
    int32 bucket = 0;
    int32 count = size_boundaries;
    while (count > 0) {
      const int32 step = count / 2;
      const int32 l = bucket + step;
      // Compare using the usual arithmetic conversions (T vs float) rather
      // than casting the float boundary to T. This matches a host-side
      // std::upper_bound over a std::vector<float> (presumably what the CPU
      // kernel does -- confirm in bucketize_op.cc). A cast to an integer T
      // would truncate fractional boundaries (1.5 -> 1, putting value 1 in
      // the wrong bucket) and is undefined for boundaries outside T's range.
      if (!(value < boundaries[l])) {
        bucket = l + 1;
        count -= step + 1;
      } else {
        count = step;
      }
    }
    out[i] = bucket;
  }
}

namespace functor {

template <typename T>
struct BucketizeFunctor<GPUDevice, T> {
  // Launches BucketizeCustomKernel on the context's GPU stream, writing into
  // `output` the bucket index of each element of `input` with respect to
  // `boundaries_vector`.
  //
  // The boundaries are staged through a CudaDeviceArrayOnHost. When they fit
  // below both the device's per-block shared memory and kMaxSharedMemBytes,
  // the kernel variant that caches them in shared memory is launched.
  //
  // Returns:
  //   InvalidArgument if the input or the boundary list has more than
  //     int32-max entries (the kernel indexes with int32);
  //   any error from staging the boundaries;
  //   Internal if the kernel launch is rejected. Errors raised while the
  //     kernel executes are asynchronous and are not reported here.
  //
  // PRECONDITION: boundaries_vector must be sorted.
  static Status Compute(OpKernelContext* context,
                        const typename TTypes<T, 1>::ConstTensor& input,
                        const std::vector<float>& boundaries_vector,
                        typename TTypes<int32, 1>::Tensor& output) {
    const GPUDevice& d = context->eigen_device<GPUDevice>();

    const int64 num_inputs = input.size();
    const int64 num_boundaries = static_cast<int64>(boundaries_vector.size());
    // Kernel arguments, launch config and loop indices are all int32.
    const int64 kMaxInt32 = std::numeric_limits<int32>::max();
    if (num_inputs > kMaxInt32 || num_boundaries > kMaxInt32) {
      return errors::InvalidArgument(
          "Bucketize on GPU supports at most ", kMaxInt32,
          " input elements and ", kMaxInt32, " boundaries; got ", num_inputs,
          " input elements and ", num_boundaries, " boundaries.");
    }
    // Nothing to compute. Returning here also avoids asking for a launch
    // configuration with zero work (a zero-block grid is an invalid launch).
    if (num_inputs == 0) {
      return Status::OK();
    }
    const int32 size_in = static_cast<int32>(num_inputs);
    const int32 size_boundaries = static_cast<int32>(num_boundaries);

    CudaDeviceArrayOnHost<float> boundaries_array(context, size_boundaries);
    TF_RETURN_IF_ERROR(boundaries_array.Init());
    for (int32 i = 0; i < size_boundaries; ++i) {
      boundaries_array.Set(i, boundaries_vector[i]);
    }
    TF_RETURN_IF_ERROR(boundaries_array.Finalize());

    CudaLaunchConfig config = GetCudaLaunchConfig(size_in, d);
    // 64-bit arithmetic so a long boundary list cannot overflow the byte count.
    const int64 shared_mem_size =
        static_cast<int64>(sizeof(float)) * num_boundaries;
    const int64 kMaxSharedMemBytes = 16384;
    if (shared_mem_size < d.sharedMemPerBlock() &&
        shared_mem_size < kMaxSharedMemBytes) {
      BucketizeCustomKernel<T, true>
          <<<config.block_count, config.thread_per_block,
             static_cast<size_t>(shared_mem_size), d.stream()>>>(
              size_in, input.data(), size_boundaries, boundaries_array.data(),
              output.data());
    } else {
      BucketizeCustomKernel<T, false>
          <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
              size_in, input.data(), size_boundaries, boundaries_array.data(),
              output.data());
    }

    // A <<<...>>> launch returns no status; launch-configuration failures
    // surface through cudaGetLastError(). Note this also reports (and clears)
    // an earlier unchecked launch error on this host thread, if any.
    const cudaError_t launch_status = cudaGetLastError();
    if (launch_status != cudaSuccess) {
      return errors::Internal("Failed to launch BucketizeCustomKernel: ",
                              cudaGetErrorString(launch_status));
    }
    return Status::OK();
  }
};
}  // namespace functor

#define REGISTER_GPU_SPEC(type) \
  template struct functor::BucketizeFunctor<GPUDevice, type>;

REGISTER_GPU_SPEC(int32);
REGISTER_GPU_SPEC(int64);
REGISTER_GPU_SPEC(float);
REGISTER_GPU_SPEC(double);
#undef REGISTER_GPU_SPEC

}  // namespace tensorflow

#endif  // GOOGLE_CUDA