/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"

#include "tensorflow/core/util/cuda_kernel_helper.h"

#include "tensorflow/core/kernels/reduction_gpu_kernels.cu.h"
#include "tensorflow/core/kernels/reduction_ops_common.h"

namespace tensorflow {

namespace {

// Block size used when launching GenerateNormalizedProb.
const int kSoftmaxThreadsPerBlock = 128;

// Writes the final (log-)softmax value for every element of a row-major
// [num_rows, num_cols] matrix.
//
// Launch layout: 1-D grid of 1-D blocks, one thread per matrix element; the
// grid must cover at least num_rows * num_cols threads. No shared memory.
//
//   logits      input matrix, num_rows * num_cols elements.
//   sum_probs   per-row sum of exp(logit - row max), num_rows elements.
//   max_logits  per-row maximum logit, num_rows elements.
//   output      result matrix, same layout as logits. Each thread reads
//               logits[tid] before writing output[tid] and touches no other
//               element of either buffer, so output may alias logits.
//   in_log_space  true  -> logit - max - log(sum)
//                 false -> exp(logit - max) / sum
template <typename T>
__global__ void GenerateNormalizedProb(const T* logits, const T* sum_probs,
                                       const T* max_logits, T* output,
                                       const int num_rows, const int num_cols,
                                       const bool in_log_space) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int row = tid / num_cols;

  // Threads past the end of the matrix (tail of the last block) do nothing.
  // A separate column check is unnecessary: tid % num_cols is always a valid
  // column index.
  if (row < num_rows) {
    if (in_log_space) {
      output[tid] =
          logits[tid] - ldg(max_logits + row) - log(ldg(sum_probs + row));
    } else {
      output[tid] =
          exp(logits[tid] - ldg(max_logits + row)) / ldg(sum_probs + row);
    }
  }
}

// Maps a flat element index `gid` of a row-major matrix with `num_cols`
// columns to exp(logits[gid] - max_logits[gid / num_cols]). Used as the
// transform of a cub::TransformInputIterator so the exponentials can be fed
// straight into a row reduction.
template <typename T>
struct SubtractAndExpFunctor {
  __host__ __device__ SubtractAndExpFunctor(const T* logits,
                                            const T* max_logits,
                                            const int num_cols)
      : logits_(logits), max_logits_(max_logits), num_cols_(num_cols) {}

  __host__ __device__ T operator()(const int gid) const {
    return exp(logits_[gid] - ldg(max_logits_ + gid / num_cols_));
  }

  const T* logits_;
  const T* max_logits_;
  const int num_cols_;
};

// Reduces `input`, viewed as a [rows, cols] matrix, along axis 1 with `Op`
// via functor::ReduceImpl, writing the results to `output`.
template <typename T, typename Op, typename InputIter>
void DoRowReduction(OpKernelContext* context, T* output, InputIter input,
                    int rows, int cols) {
  typedef const Eigen::array<TTypes<float>::Tensor::Index, 1>& ReductionAxes;
  Constants<GPUDevice> constants;

  Op op;

  // Arguments: in_rank = 2, dims = (rows, cols, 1), out_rank = 1, reduce over
  // axis 1.
  functor::ReduceImpl<T, Op, T*, InputIter, ReductionAxes>(
      context, output, input, 2, rows, cols, 1, 1, constants.kOne, op);
}

}  // namespace

// GPU kernel for the "Softmax" and "LogSoftmax" ops on 2-D logits. The log
// variant is selected when the registered op name starts with "Log".
//
// The computation is done in three steps on the op's CUDA stream:
//   1. per-row max of the logits,
//   2. per-row sum of exp(logit - max),
//   3. element-wise normalization (GenerateNormalizedProb).
template <typename T>
class SoftmaxOpGPU : public OpKernel {
 public:
  explicit SoftmaxOpGPU(OpKernelConstruction* context) : OpKernel(context) {
    log_ = StringPiece(type_string()).starts_with("Log");
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& logits_in_ = context->input(0);

    // Validate the rank before anything interprets the tensor as a matrix, so
    // a bad input is reported as InvalidArgument.
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits_in_.shape()),
                errors::InvalidArgument("logits must be 2-dimensional"));

    // Element indices below are 32-bit ints, and the kernel launch rounds the
    // thread count up to a whole block, so leave one block of headroom.
    OP_REQUIRES(
        context,
        logits_in_.NumElements() <= kint32max - kSoftmaxThreadsPerBlock,
        errors::InvalidArgument(
            "logits has too many elements for 32-bit GPU indexing: ",
            logits_in_.NumElements()));

    Tensor* softmax_out = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0}, 0, logits_in_.shape(), &softmax_out));

    // Nothing to compute for an empty matrix.
    if (logits_in_.NumElements() == 0) return;

    const int rows = static_cast<int>(logits_in_.dim_size(0));
    const int cols = static_cast<int>(logits_in_.dim_size(1));

    // The row reductions produce one value per row, so the scratch buffers
    // only need `rows` elements each.
    Tensor max_logits;
    Tensor sum_probs;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DataTypeToEnum<T>::value,
                                          TensorShape({rows}), &max_logits));
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DataTypeToEnum<T>::value,
                                          TensorShape({rows}), &sum_probs));

    const T* logits_ptr = logits_in_.flat<T>().data();
    T* max_ptr = max_logits.flat<T>().data();
    T* sum_ptr = sum_probs.flat<T>().data();
    T* out_ptr = softmax_out->flat<T>().data();

    // Step 1: max_logits[r] = max over c of logits[r, c].
    DoRowReduction<T, cub::Max, const T*>(context, max_ptr, logits_ptr, rows,
                                          cols);

    // Step 2: sum_probs[r] = sum over c of exp(logits[r, c] - max_logits[r]).
    // The exponentials are generated on the fly by a transform iterator over
    // flat element indices.
    cub::CountingInputIterator<int> counting_iterator(0);
    typedef cub::TransformInputIterator<T, SubtractAndExpFunctor<T>,
                                        cub::CountingInputIterator<int>>
        InputIterType;

    InputIterType input_itr(
        counting_iterator,
        SubtractAndExpFunctor<T>(logits_ptr, max_ptr, cols));

    DoRowReduction<T, cub::Sum, InputIterType>(context, sum_ptr, input_itr,
                                               rows, cols);

    // Step 3: normalize every element, one thread per element.
    const cudaStream_t& cu_stream = GetCudaStream(context);
    const int numBlocks = Eigen::divup(rows * cols, kSoftmaxThreadsPerBlock);

    GenerateNormalizedProb<T>
        <<<numBlocks, kSoftmaxThreadsPerBlock, 0, cu_stream>>>(
            logits_ptr, sum_ptr, max_ptr, out_ptr, rows, cols, log_);
  }

 private:
  // True when this kernel instance computes LogSoftmax rather than Softmax.
  bool log_;
};

REGISTER_KERNEL_BUILDER(
    Name("Softmax").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
    SoftmaxOpGPU<Eigen::half>);
REGISTER_KERNEL_BUILDER(
    Name("Softmax").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    SoftmaxOpGPU<float>);
REGISTER_KERNEL_BUILDER(
    Name("Softmax").Device(DEVICE_GPU).TypeConstraint<double>("T"),
    SoftmaxOpGPU<double>);
// NOTE(review): unlike Softmax, LogSoftmax has no double registration in this
// file -- confirm whether that is intentional before adding one.
REGISTER_KERNEL_BUILDER(
    Name("LogSoftmax").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
    SoftmaxOpGPU<Eigen::half>);
REGISTER_KERNEL_BUILDER(
    Name("LogSoftmax").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    SoftmaxOpGPU<float>);

}  // end namespace tensorflow

#endif  // GOOGLE_CUDA