Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #if GOOGLE_CUDA
     17 
     18 #define EIGEN_USE_GPU
     19 
     20 #include "tensorflow/core/kernels/aggregate_ops.h"
     21 
     22 #include "tensorflow/core/framework/register_types.h"
     23 #include "tensorflow/core/framework/tensor_types.h"
     24 #include "tensorflow/core/platform/types.h"
     25 
     26 namespace tensorflow {
     27 
     28 typedef Eigen::GpuDevice GPUDevice;
     29 
     // Partial specializations for GPUDevice that use the Eigen implementations.
     31 namespace functor {
     32 template <typename T>
     33 struct Add2Functor<GPUDevice, T> {
     34   void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
     35                   typename TTypes<T>::ConstFlat in1,
     36                   typename TTypes<T>::ConstFlat in2) {
     37     Add2EigenImpl<GPUDevice, T>::Compute(d, out, in1, in2);
     38   }
     39 };
     40 
     41 template <typename T>
     42 struct Add3Functor<GPUDevice, T> {
     43   void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
     44                   typename TTypes<T>::ConstFlat in1,
     45                   typename TTypes<T>::ConstFlat in2,
     46                   typename TTypes<T>::ConstFlat in3) {
     47     Add3EigenImpl<GPUDevice, T>::Compute(d, out, in1, in2, in3);
     48   }
     49 };
     50 
     51 template <typename T>
     52 struct Add4Functor<GPUDevice, T> {
     53   void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
     54                   typename TTypes<T>::ConstFlat in1,
     55                   typename TTypes<T>::ConstFlat in2,
     56                   typename TTypes<T>::ConstFlat in3,
     57                   typename TTypes<T>::ConstFlat in4) {
     58     Add4EigenImpl<GPUDevice, T>::Compute(d, out, in1, in2, in3, in4);
     59   }
     60 };
     61 
     62 template <typename T>
     63 struct Add5Functor<GPUDevice, T> {
     64   void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
     65                   typename TTypes<T>::ConstFlat in1,
     66                   typename TTypes<T>::ConstFlat in2,
     67                   typename TTypes<T>::ConstFlat in3,
     68                   typename TTypes<T>::ConstFlat in4,
     69                   typename TTypes<T>::ConstFlat in5) {
     70     Add5EigenImpl<GPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5);
     71   }
     72 };
     73 
     74 template <typename T>
     75 struct Add6Functor<GPUDevice, T> {
     76   void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
     77                   typename TTypes<T>::ConstFlat in1,
     78                   typename TTypes<T>::ConstFlat in2,
     79                   typename TTypes<T>::ConstFlat in3,
     80                   typename TTypes<T>::ConstFlat in4,
     81                   typename TTypes<T>::ConstFlat in5,
     82                   typename TTypes<T>::ConstFlat in6) {
     83     Add6EigenImpl<GPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6);
     84   }
     85 };
     86 
     87 template <typename T>
     88 struct Add7Functor<GPUDevice, T> {
     89   void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
     90                   typename TTypes<T>::ConstFlat in1,
     91                   typename TTypes<T>::ConstFlat in2,
     92                   typename TTypes<T>::ConstFlat in3,
     93                   typename TTypes<T>::ConstFlat in4,
     94                   typename TTypes<T>::ConstFlat in5,
     95                   typename TTypes<T>::ConstFlat in6,
     96                   typename TTypes<T>::ConstFlat in7) {
     97     Add7EigenImpl<GPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
     98                                          in7);
     99   }
    100 };
    101 
    102 template <typename T>
    103 struct Add8Functor<GPUDevice, T> {
    104   void operator()(
    105       const GPUDevice& d, typename TTypes<T>::Flat out,
    106       typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
    107       typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
    108       typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
    109       typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
    110     Add8EigenImpl<GPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
    111                                          in7, in8);
    112   }
    113 };
    114 
    115 template <typename T>
    116 struct Add8pFunctor<GPUDevice, T> {
    117   void operator()(
    118       const GPUDevice& d, typename TTypes<T>::Flat out,
    119       typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
    120       typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
    121       typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
    122       typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
    123     Add8pEigenImpl<GPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
    124                                           in7, in8);
    125   }
    126 };
    127 
    128 template <typename T>
    129 struct Add9Functor<GPUDevice, T> {
    130   void operator()(
    131       const GPUDevice& d, typename TTypes<T>::Flat out,
    132       typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
    133       typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
    134       typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
    135       typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8,
    136       typename TTypes<T>::ConstFlat in9) {
    137     Add9EigenImpl<GPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
    138                                          in7, in8, in9);
    139   }
    140 };
    141 
    142 }  // end namespace functor
    143 
    144 // Instantiate the GPU implementation for GPU number types.
    145 #define REGISTER_FUNCTORS(type)                           \
    146   template struct functor::Add2Functor<GPUDevice, type>;  \
    147   template struct functor::Add3Functor<GPUDevice, type>;  \
    148   template struct functor::Add4Functor<GPUDevice, type>;  \
    149   template struct functor::Add5Functor<GPUDevice, type>;  \
    150   template struct functor::Add6Functor<GPUDevice, type>;  \
    151   template struct functor::Add7Functor<GPUDevice, type>;  \
    152   template struct functor::Add8Functor<GPUDevice, type>;  \
    153   template struct functor::Add8pFunctor<GPUDevice, type>; \
    154   template struct functor::Add9Functor<GPUDevice, type>;
    155 
    156 TF_CALL_GPU_NUMBER_TYPES(REGISTER_FUNCTORS);
    157 TF_CALL_complex64(REGISTER_FUNCTORS);
    158 TF_CALL_complex128(REGISTER_FUNCTORS);
    159 
    160 #undef REGISTER_FUNCTORS
    161 
    162 }  // end namespace tensorflow
    163 
    164 #endif  // GOOGLE_CUDA
    165