/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"

#include "tensorflow/core/util/cuda_kernel_helper.h"

#include "tensorflow/core/kernels/reduction_gpu_kernels.cu.h"
#include "tensorflow/core/kernels/reduction_ops_common.h"

namespace tensorflow {

namespace {

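// This file implements softmax and log-softmax on GPU in the numerically
// stable form
//
//   softmax(x)_i    = exp(x_i - max(x)) / sum_j exp(x_j - max(x))
//   logsoftmax(x)_i = (x_i - max(x)) - log(sum_j exp(x_j - max(x)))
//
// applied independently to each row of a [num_rows, num_cols] logits matrix.
// For example, the row [1, 2, 3] has max = 3 and
// sum = e^-2 + e^-1 + e^0 ~= 1.5032, so softmax ~= [0.090, 0.245, 0.665].

// Final normalization pass: one thread per output element. max_logits and
// sum_probs hold the per-row max and sum-of-exponentials computed by the
// reduction passes below.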
template <typename T>
__global__ void GenerateNormalizedProb(const T* logits, const T* sum_probs,
                                       const T* max_logits, T* output,
                                       const int num_rows, const int num_cols,
                                       const bool in_log_space) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;

  const int row = tid / num_cols;
  const int col = tid % num_cols;

  if (row < num_rows && col < num_cols) {
    if (in_log_space)
      output[tid] =
          logits[tid] - ldg(max_logits + row) - log(ldg(sum_probs + row));
    else
      output[tid] =
          exp(logits[tid] - ldg(max_logits + row)) / ldg(sum_probs + row);
  }
}

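// Maps a flat element index to exp(logits[i] - row_max). Fed through a
// cub::TransformInputIterator, it lets the row-sum reduction consume the
// shifted exponentials on the fly instead of materializing them in a
// temporary tensor.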
template <typename T>
struct SubtractAndExpFunctor {
  __host__ __device__ SubtractAndExpFunctor(const T* logits,
                                            const T* max_logits,
                                            const int num_cols)
      : logits_(logits), max_logits_(max_logits), num_cols_(num_cols) {}

  __host__ __device__ T operator()(const int gid) const {
    return exp(logits_[gid] - ldg(max_logits_ + gid / num_cols_));
  }

  const T* logits_;
  const T* max_logits_;
  const int num_cols_;
};

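// Reduces each row of a [rows, cols] input to one output value with the
// given cub reduction op, reusing the ReduceImpl machinery from
// reduction_gpu_kernels.cu.h.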
template <typename T, typename Op, typename InputIter>
void DoRowReduction(OpKernelContext* context, T* output, InputIter input,
                    int rows, int cols) {
  typedef const Eigen::array<TTypes<float>::Tensor::Index, 1>& ReductionAxes;
  Constants<GPUDevice> constants;

  Op op;

  functor::ReduceImpl<T, Op, T*, InputIter, ReductionAxes>(
      context, output, input, 2, rows, cols, 1, 1, constants.kOne, op);
}

}  // namespace

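// Softmax/LogSoftmax GPU op. Compute() makes three passes over the logits:
//   1. a row-wise max reduction (cub::Max),
//   2. a row-wise sum of exp(logit - row_max) (cub::Sum over a transform
//      iterator),
//   3. the GenerateNormalizedProb kernel, which writes the normalized
//      result in either softmax or log-softmax form.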
template <typename T>
class SoftmaxOpGPU : public OpKernel {
 public:
  explicit SoftmaxOpGPU(OpKernelConstruction* context) : OpKernel(context) {
    log_ = StringPiece(type_string()).starts_with("Log");
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& logits_in_ = context->input(0);
    // Validate the shape before interpreting the input as a matrix.
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits_in_.shape()),
                errors::InvalidArgument("logits must be 2-dimensional"));
    auto logits_in = logits_in_.matrix<T>();
    const int rows = logits_in.dimension(0);
    const int cols = logits_in.dimension(1);
    Tensor* softmax_out = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0}, 0, logits_in_.shape(), &softmax_out));

    const cudaStream_t& cu_stream = GetCudaStream(context);
    if (logits_in_.NumElements() > 0) {
      Tensor max_logits;
      Tensor sum_probs;
      OP_REQUIRES_OK(context,
                     context->allocate_temp(DataTypeToEnum<T>::value,
                                            softmax_out->shape(), &max_logits));
      OP_REQUIRES_OK(context,
                     context->allocate_temp(DataTypeToEnum<T>::value,
                                            softmax_out->shape(), &sum_probs));

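      // Pass 1: per-row maximum of the logits.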
      DoRowReduction<T, cub::Max, const T*>(context,
                                            max_logits.flat<T>().data(),
                                            logits_in_.flat<T>().data(),
                                            rows, cols);

      const int numThreads = 128;
      const int numBlocks = Eigen::divup(rows * cols, numThreads);

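      // Pass 2: per-row sum of exp(logit - row_max). The transform iterator
      // evaluates SubtractAndExpFunctor lazily, so no exponentiated copy of
      // the logits is ever stored.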
      cub::CountingInputIterator<int> counting_iterator(0);
      typedef cub::TransformInputIterator<T, SubtractAndExpFunctor<T>,
                                          cub::CountingInputIterator<int>>
          InputIterType;

      InputIterType input_itr(
          counting_iterator,
          SubtractAndExpFunctor<T>(logits_in_.flat<T>().data(),
                                   max_logits.flat<T>().data(), cols));

      DoRowReduction<T, cub::Sum, InputIterType>(
          context, sum_probs.flat<T>().data(), input_itr, rows, cols);

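      // Pass 3: normalize every element; in log space when this kernel is
      // registered as LogSoftmax.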
      GenerateNormalizedProb<<<numBlocks, numThreads, 0, cu_stream>>>(
          logits_in_.flat<T>().data(), sum_probs.flat<T>().data(),
          max_logits.flat<T>().data(), softmax_out->flat<T>().data(), rows,
          cols, log_);
    }
  }

 private:
  bool log_;
};

REGISTER_KERNEL_BUILDER(
    Name("Softmax").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
    SoftmaxOpGPU<Eigen::half>);
REGISTER_KERNEL_BUILDER(
    Name("Softmax").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    SoftmaxOpGPU<float>);
REGISTER_KERNEL_BUILDER(
    Name("Softmax").Device(DEVICE_GPU).TypeConstraint<double>("T"),
    SoftmaxOpGPU<double>);
REGISTER_KERNEL_BUILDER(
    Name("LogSoftmax").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
    SoftmaxOpGPU<Eigen::half>);
REGISTER_KERNEL_BUILDER(
    Name("LogSoftmax").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    SoftmaxOpGPU<float>);

}  // end namespace tensorflow

#endif  // GOOGLE_CUDA