Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
     17 #define TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
     18 
// This file requires the following include because it uses CudaAtomicAdd,
// CudaAtomicMul, CudaAtomicMax and CudaAtomicMin:
// #include "tensorflow/core/util/cuda_kernel_helper.h"

// Unfortunately we can't add the #include, since it breaks compilation for
// non-GPU targets. This only breaks in clang, because it's more strict for
// template code and these atomics are used in template context.
     39 
     40 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     41 #include "tensorflow/core/framework/tensor.h"
     42 #include "tensorflow/core/framework/tensor_shape.h"
     43 #include "tensorflow/core/framework/tensor_types.h"
     44 
     45 namespace tensorflow {
     46 
     47 class OpKernelContext;
     48 
     49 namespace functor {
     50 
     51 #ifdef GOOGLE_CUDA
     52 typedef Eigen::GpuDevice GPUDevice;
     53 // Functor for SegmentSumGPUOp.
     54 // output_rows: the number of output segments (unique segment ids in
     55 //                'segment_ids').
     56 // segment_ids_shape: shape of 'segment_ids' tensor.
     57 // segment_ids: unsorted map from input to output segment ids at which to
     58 //                perform segment sum operation.
     59 // data_size: size of input data tensor.
     60 // data: input data tensor.
     61 // output: output reshaped to {output_rows, output.size/output_rows}
     62 template <typename T, typename Index>
     63 struct SegmentSumFunctor {
     64   void operator()(OpKernelContext* ctx, const GPUDevice& d,
     65                   const Index output_rows, const TensorShape& segment_ids_shape,
     66                   typename TTypes<Index>::ConstFlat segment_ids,
     67                   const Index data_size, const T* data,
     68                   typename TTypes<T, 2>::Tensor output);
     69 };
     70 
     71 #endif
     72 
     73 template <typename Device, typename T, typename Index, typename InitialValueF,
     74           typename ReductionF>
     75 struct UnsortedSegmentFunctor {
     76   void operator()(OpKernelContext* ctx, const Index num_segments,
     77                   const TensorShape& segment_ids_shape,
     78                   typename TTypes<Index>::ConstFlat segment_ids,
     79                   const Index data_size, const T* data,
     80                   typename TTypes<T, 2>::Tensor output);
     81 };
     82 
     83 #ifdef GOOGLE_CUDA
     84 // reduction functors for the gpu
     85 template <typename T>
     86 struct SumOpGpu {
     87   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
     88                                                         const T& value) {
     89     CudaAtomicAdd(dest, value);
     90   }
     91 };
     92 
     93 template <typename T>
     94 struct ProdOpGpu {
     95   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
     96                                                         const T& value) {
     97     CudaAtomicMul(dest, value);
     98   }
     99 };
    100 
    101 template <typename T>
    102 struct MaxOpGpu {
    103   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
    104                                                         const T& value) {
    105     CudaAtomicMax(dest, value);
    106   }
    107 };
    108 
    109 template <typename T>
    110 struct MinOpGpu {
    111   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
    112                                                         const T& value) {
    113     CudaAtomicMin(dest, value);
    114   }
    115 };
    116 
    117 #endif  // GOOGLE_CUDA
    118 
    119 // initial value functors
    120 template <typename T>
    121 struct Zero {
    122   EIGEN_STRONG_INLINE T operator()() const { return T(0); }
    123 };
    124 
    125 template <typename T>
    126 struct One {
    127   EIGEN_STRONG_INLINE T operator()() const { return T(1); }
    128 };
    129 
    130 template <typename T>
    131 struct Lowest {
    132   EIGEN_STRONG_INLINE T operator()() const {
    133     return Eigen::NumTraits<T>::lowest();
    134   }
    135 };
    136 
    137 template <typename T>
    138 struct Highest {
    139   EIGEN_STRONG_INLINE T operator()() const {
    140     return Eigen::NumTraits<T>::highest();
    141   }
    142 };
    143 
    144 }  // namespace functor
    145 }  // namespace tensorflow
    146 
    147 #endif  // TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
    148