/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include <complex>
#include <limits>

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/diag_op.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {
namespace functor {

typedef Eigen::GpuDevice GPUDevice;

// Writes the flattened `size` x `size` output of DiagOp.
//
// `out` is treated as a flat array of size * size elements. Flat position
// `index` lies on the diagonal exactly when index == i * (size + 1) for some
// i; it then receives in[i]. Every other position is set to T(0).
//
// Launch: 1-D grid. CUDA_1D_KERNEL_LOOP (cuda_kernel_helper.h) visits each
// index in [0, num_threads). The caller passes the launch config's
// virtual_thread_count for size * size, so num_threads is expected to equal
// size * size. No shared memory is used.
template <typename T>
__global__ void DiagCudaKernel(const int num_threads, const int64 size,
                               const T* in, T* out) {
  CUDA_1D_KERNEL_LOOP(index, num_threads) {
    // Fill the diagonal elements; zero every other position.
    if (index % (1 + size) == 0) {
      out[index] = in[index / (1 + size)];
    } else {
      out[index] = T(0);
    }
  }
}

// GPU specialization of DiagFunctor: builds the diagonal matrix of `in`
// asynchronously on the device's stream.
//
// Args:
//   context: op context supplying the GPU device and its stream.
//   size:    number of input elements; the output holds size * size elements.
//   in:      pointer to `size` input elements, read by the kernel.
//   out:     pointer to `size * size` output elements, written by the kernel.
// Returns:
//   Status::OK() on success, including size == 0 (nothing is launched).
//   errors::Internal if size * size does not fit in an int, or if the kernel
//   launch reports an error via cudaGetLastError(). Faults that occur while
//   the kernel executes are not observed here.
template <typename T>
struct DiagFunctor<GPUDevice, T> {
  EIGEN_ALWAYS_INLINE Status operator()(OpKernelContext* context,
                                        const int64 size, const T* in, T* out) {
    // Empty tensor couldn't launch the kernel.
    if (size == 0) {
      return Status::OK();
    }

    // CudaLaunchConfig uses an int for virtual_thread_count, so size * size
    // must fit in an int. The check divides instead of multiplying so that
    // the check itself can never overflow (signed overflow is undefined).
    if (size > std::numeric_limits<int>::max() / size) {
      return errors::Internal("DiagOp got input size too large.");
    }
    const int virtual_thread_count = static_cast<int>(size * size);

    // Launch the GPU kernel.
    const GPUDevice& device = context->eigen_device<GPUDevice>();
    CudaLaunchConfig diag_config =
        GetCudaLaunchConfig(virtual_thread_count, device);
    DiagCudaKernel<<<diag_config.block_count, diag_config.thread_per_block, 0,
                     device.stream()>>>(diag_config.virtual_thread_count, size,
                                        in, out);

    // Surface launch-configuration errors; the launch itself returns nothing.
    auto err = cudaGetLastError();
    if (err != cudaSuccess) {
      return errors::Internal(
          "Could not launch DiagOp kernel: ", cudaGetErrorString(err), ".");
    }
    return Status::OK();
  }
};

template struct DiagFunctor<GPUDevice, double>;
template struct DiagFunctor<GPUDevice, float>;
template struct DiagFunctor<GPUDevice, int32>;
template struct DiagFunctor<GPUDevice, int64>;
template struct DiagFunctor<GPUDevice, complex64>;
template struct DiagFunctor<GPUDevice, complex128>;

// Extracts the diagonal of a flattened `size` x `size` input:
// out[i] = in[i * (size + 1)].
//
// The offset (1 + size) * index is computed in int64 because `size` is int64,
// so it does not overflow int even though the input has size * size elements.
// Launch: 1-D grid. CUDA_1D_KERNEL_LOOP (cuda_kernel_helper.h) visits each
// index in [0, num_threads). num_threads is expected to equal size.
// No shared memory is used.
template <typename T>
__global__ void DiagPartCudaKernel(const int num_threads, const int64 size,
                                   const T* in, T* out) {
  CUDA_1D_KERNEL_LOOP(index, num_threads) {
    out[index] = in[(1 + size) * index];
  }
}

// GPU specialization of DiagPartFunctor: copies the diagonal of `in` into
// `out` asynchronously on the device's stream.
//
// Args:
//   context: op context supplying the GPU device and its stream.
//   size:    number of output elements; the input holds size * size elements.
//   in:      pointer to `size * size` input elements, read by the kernel.
//   out:     pointer to `size` output elements, written by the kernel.
// Returns:
//   Status::OK() on success, including size == 0 (nothing is launched).
//   errors::Internal if `size` does not fit in an int, or if the kernel
//   launch reports an error via cudaGetLastError().
template <typename T>
struct DiagPartFunctor<GPUDevice, T> {
  EIGEN_ALWAYS_INLINE Status operator()(OpKernelContext* context,
                                        const int64 size, const T* in, T* out) {
    // Empty tensor couldn't launch the kernel.
    if (size == 0) {
      return Status::OK();
    }

    // GetCudaLaunchConfig takes an int element count; reject sizes that
    // would be silently truncated by the narrowing conversion.
    if (size > std::numeric_limits<int>::max()) {
      return errors::Internal("DiagPartOp got input size too large.");
    }
    const int virtual_thread_count = static_cast<int>(size);
    const GPUDevice& device = context->eigen_device<GPUDevice>();

    // Extract the diagonal elements.
    CudaLaunchConfig diag_config =
        GetCudaLaunchConfig(virtual_thread_count, device);
    DiagPartCudaKernel<<<diag_config.block_count, diag_config.thread_per_block,
                         0, device.stream()>>>(diag_config.virtual_thread_count,
                                               size, in, out);

    // Surface launch-configuration errors; the launch itself returns nothing.
    auto err = cudaGetLastError();
    if (err != cudaSuccess) {
      return errors::Internal(
          "Could not launch DiagPartOp kernel: ", cudaGetErrorString(err), ".");
    }
    return Status::OK();
  }
};

template struct DiagPartFunctor<GPUDevice, double>;
template struct DiagPartFunctor<GPUDevice, float>;
template struct DiagPartFunctor<GPUDevice, int32>;
template struct DiagPartFunctor<GPUDevice, int64>;
template struct DiagPartFunctor<GPUDevice, complex64>;
template struct DiagPartFunctor<GPUDevice, complex128>;

}  // end namespace functor
}  // end namespace tensorflow

#endif  // GOOGLE_CUDA