Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_KERNELS_AGGREGATE_OPS_CPU_H_
     17 #define TENSORFLOW_KERNELS_AGGREGATE_OPS_CPU_H_
     18 
     19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     20 #include "tensorflow/core/framework/tensor_types.h"
     21 
     22 #include "tensorflow/core/kernels/aggregate_ops.h"
     23 
     24 typedef Eigen::ThreadPoolDevice CPUDevice;
     25 
     26 #ifdef TENSORFLOW_USE_SYCL
     27 typedef Eigen::SyclDevice SYCLDevice;
     28 #endif  // TENSORFLOW_USE_SYCL
     29 
     30 namespace tensorflow {
     31 
// Partial specializations for CPUDevice that use the Eigen implementations
// provided by the AddNEigenImpl family (Add2EigenImpl ... Add9EigenImpl).
     34 namespace functor {
     35 template <typename T>
     36 struct Add2Functor<CPUDevice, T> {
     37   void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
     38                   typename TTypes<T>::ConstFlat in1,
     39                   typename TTypes<T>::ConstFlat in2) {
     40     Add2EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2);
     41   }
     42 };
     43 template <typename T>
     44 struct Add3Functor<CPUDevice, T> {
     45   void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
     46                   typename TTypes<T>::ConstFlat in1,
     47                   typename TTypes<T>::ConstFlat in2,
     48                   typename TTypes<T>::ConstFlat in3) {
     49     Add3EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3);
     50   }
     51 };
     52 template <typename T>
     53 struct Add4Functor<CPUDevice, T> {
     54   void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
     55                   typename TTypes<T>::ConstFlat in1,
     56                   typename TTypes<T>::ConstFlat in2,
     57                   typename TTypes<T>::ConstFlat in3,
     58                   typename TTypes<T>::ConstFlat in4) {
     59     Add4EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4);
     60   }
     61 };
     62 template <typename T>
     63 struct Add5Functor<CPUDevice, T> {
     64   void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
     65                   typename TTypes<T>::ConstFlat in1,
     66                   typename TTypes<T>::ConstFlat in2,
     67                   typename TTypes<T>::ConstFlat in3,
     68                   typename TTypes<T>::ConstFlat in4,
     69                   typename TTypes<T>::ConstFlat in5) {
     70     Add5EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5);
     71   }
     72 };
     73 template <typename T>
     74 struct Add6Functor<CPUDevice, T> {
     75   void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
     76                   typename TTypes<T>::ConstFlat in1,
     77                   typename TTypes<T>::ConstFlat in2,
     78                   typename TTypes<T>::ConstFlat in3,
     79                   typename TTypes<T>::ConstFlat in4,
     80                   typename TTypes<T>::ConstFlat in5,
     81                   typename TTypes<T>::ConstFlat in6) {
     82     Add6EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6);
     83   }
     84 };
     85 template <typename T>
     86 struct Add7Functor<CPUDevice, T> {
     87   void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
     88                   typename TTypes<T>::ConstFlat in1,
     89                   typename TTypes<T>::ConstFlat in2,
     90                   typename TTypes<T>::ConstFlat in3,
     91                   typename TTypes<T>::ConstFlat in4,
     92                   typename TTypes<T>::ConstFlat in5,
     93                   typename TTypes<T>::ConstFlat in6,
     94                   typename TTypes<T>::ConstFlat in7) {
     95     Add7EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
     96                                          in7);
     97   }
     98 };
     99 
    100 template <typename T>
    101 struct Add8Functor<CPUDevice, T> {
    102   void operator()(
    103       const CPUDevice& d, typename TTypes<T>::Flat out,
    104       typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
    105       typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
    106       typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
    107       typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
    108     Add8EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
    109                                          in7, in8);
    110   }
    111 };
    112 
    113 template <typename T>
    114 struct Add8pFunctor<CPUDevice, T> {
    115   void operator()(
    116       const CPUDevice& d, typename TTypes<T>::Flat out,
    117       typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
    118       typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
    119       typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
    120       typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
    121     Add8pEigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
    122                                           in7, in8);
    123   }
    124 };
    125 
    126 template <typename T>
    127 struct Add9Functor<CPUDevice, T> {
    128   void operator()(
    129       const CPUDevice& d, typename TTypes<T>::Flat out,
    130       typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
    131       typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
    132       typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
    133       typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8,
    134       typename TTypes<T>::ConstFlat in9) {
    135     Add9EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
    136                                          in7, in8, in9);
    137   }
    138 };
    139 
    140 #ifdef TENSORFLOW_USE_SYCL
// Partial specializations for SYCLDevice that use the Eigen implementations
// provided by the AddNEigenImpl family (Add2EigenImpl ... Add9EigenImpl).
    143 template <typename T>
    144 struct Add2Functor<SYCLDevice, T> {
    145   void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
    146                   typename TTypes<T>::ConstFlat in1,
    147                   typename TTypes<T>::ConstFlat in2) {
    148     Add2EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2);
    149   }
    150 };
    151 template <typename T>
    152 struct Add3Functor<SYCLDevice, T> {
    153   void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
    154                   typename TTypes<T>::ConstFlat in1,
    155                   typename TTypes<T>::ConstFlat in2,
    156                   typename TTypes<T>::ConstFlat in3) {
    157     Add3EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3);
    158   }
    159 };
    160 template <typename T>
    161 struct Add4Functor<SYCLDevice, T> {
    162   void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
    163                   typename TTypes<T>::ConstFlat in1,
    164                   typename TTypes<T>::ConstFlat in2,
    165                   typename TTypes<T>::ConstFlat in3,
    166                   typename TTypes<T>::ConstFlat in4) {
    167     Add4EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4);
    168   }
    169 };
    170 template <typename T>
    171 struct Add5Functor<SYCLDevice, T> {
    172   void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
    173                   typename TTypes<T>::ConstFlat in1,
    174                   typename TTypes<T>::ConstFlat in2,
    175                   typename TTypes<T>::ConstFlat in3,
    176                   typename TTypes<T>::ConstFlat in4,
    177                   typename TTypes<T>::ConstFlat in5) {
    178     Add5EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5);
    179   }
    180 };
    181 template <typename T>
    182 struct Add6Functor<SYCLDevice, T> {
    183   void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
    184                   typename TTypes<T>::ConstFlat in1,
    185                   typename TTypes<T>::ConstFlat in2,
    186                   typename TTypes<T>::ConstFlat in3,
    187                   typename TTypes<T>::ConstFlat in4,
    188                   typename TTypes<T>::ConstFlat in5,
    189                   typename TTypes<T>::ConstFlat in6) {
    190     Add6EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6);
    191   }
    192 };
    193 template <typename T>
    194 struct Add7Functor<SYCLDevice, T> {
    195   void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
    196                   typename TTypes<T>::ConstFlat in1,
    197                   typename TTypes<T>::ConstFlat in2,
    198                   typename TTypes<T>::ConstFlat in3,
    199                   typename TTypes<T>::ConstFlat in4,
    200                   typename TTypes<T>::ConstFlat in5,
    201                   typename TTypes<T>::ConstFlat in6,
    202                   typename TTypes<T>::ConstFlat in7) {
    203     Add7EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
    204                                           in7);
    205   }
    206 };
    207 
    208 template <typename T>
    209 struct Add8Functor<SYCLDevice, T> {
    210   void operator()(
    211       const SYCLDevice& d, typename TTypes<T>::Flat out,
    212       typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
    213       typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
    214       typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
    215       typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
    216     Add8EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
    217                                           in7, in8);
    218   }
    219 };
    220 
    221 template <typename T>
    222 struct Add8pFunctor<SYCLDevice, T> {
    223   void operator()(
    224       const SYCLDevice& d, typename TTypes<T>::Flat out,
    225       typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
    226       typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
    227       typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
    228       typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
    229     Add8pEigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
    230                                            in7, in8);
    231   }
    232 };
    233 
    234 template <typename T>
    235 struct Add9Functor<SYCLDevice, T> {
    236   void operator()(
    237       const SYCLDevice& d, typename TTypes<T>::Flat out,
    238       typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
    239       typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
    240       typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
    241       typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8,
    242       typename TTypes<T>::ConstFlat in9) {
    243     Add9EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
    244                                           in7, in8, in9);
    245   }
    246 };
    247 #endif  // TENSORFLOW_USE_SYCL
    248 
    249 }  // namespace functor
    250 
    251 }  // namespace tensorflow
    252 
    253 #endif  // TENSORFLOW_KERNELS_AGGREGATE_OPS_CPU_H_
    254