Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_KERNELS_AGGREGATE_OPS_H_
     17 #define TENSORFLOW_KERNELS_AGGREGATE_OPS_H_
     18 
     19 // Functor definitions for Aggregate ops, must be compilable by nvcc.
     20 
     21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     22 #include "tensorflow/core/framework/tensor_types.h"
     23 
     24 namespace tensorflow {
     25 namespace functor {
     26 
     27 template <typename Device, typename T>
     28 struct Add2Functor {
     29   void operator()(const Device& d, typename TTypes<T>::Flat out,
     30                   typename TTypes<T>::ConstFlat in1,
     31                   typename TTypes<T>::ConstFlat in2);
     32 };
     33 
     34 template <typename Device, typename T>
     35 struct Add2EigenImpl {
     36   static void Compute(const Device& d, typename TTypes<T>::Flat out,
     37                       typename TTypes<T>::ConstFlat in1,
     38                       typename TTypes<T>::ConstFlat in2) {
     39     out.device(d) = in1 + in2;
     40   }
     41 };
     42 
     43 template <typename Device, typename T>
     44 struct Add3Functor {
     45   void operator()(const Device& d, typename TTypes<T>::Flat out,
     46                   typename TTypes<T>::ConstFlat in1,
     47                   typename TTypes<T>::ConstFlat in2,
     48                   typename TTypes<T>::ConstFlat in3);
     49 };
     50 
     51 template <typename Device, typename T>
     52 struct Add3EigenImpl {
     53   static void Compute(const Device& d, typename TTypes<T>::Flat out,
     54                       typename TTypes<T>::ConstFlat in1,
     55                       typename TTypes<T>::ConstFlat in2,
     56                       typename TTypes<T>::ConstFlat in3) {
     57     out.device(d) = in1 + in2 + in3;
     58   }
     59 };
     60 
     61 template <typename Device, typename T>
     62 struct Add4Functor {
     63   void operator()(const Device& d, typename TTypes<T>::Flat out,
     64                   typename TTypes<T>::ConstFlat in1,
     65                   typename TTypes<T>::ConstFlat in2,
     66                   typename TTypes<T>::ConstFlat in3,
     67                   typename TTypes<T>::ConstFlat in4);
     68 };
     69 
     70 template <typename Device, typename T>
     71 struct Add4EigenImpl {
     72   static void Compute(const Device& d, typename TTypes<T>::Flat out,
     73                       typename TTypes<T>::ConstFlat in1,
     74                       typename TTypes<T>::ConstFlat in2,
     75                       typename TTypes<T>::ConstFlat in3,
     76                       typename TTypes<T>::ConstFlat in4) {
     77     out.device(d) = in1 + in2 + in3 + in4;
     78   }
     79 };
     80 
     81 template <typename Device, typename T>
     82 struct Add5Functor {
     83   void operator()(const Device& d, typename TTypes<T>::Flat out,
     84                   typename TTypes<T>::ConstFlat in1,
     85                   typename TTypes<T>::ConstFlat in2,
     86                   typename TTypes<T>::ConstFlat in3,
     87                   typename TTypes<T>::ConstFlat in4,
     88                   typename TTypes<T>::ConstFlat in5);
     89 };
     90 
     91 template <typename Device, typename T>
     92 struct Add5EigenImpl {
     93   static void Compute(const Device& d, typename TTypes<T>::Flat out,
     94                       typename TTypes<T>::ConstFlat in1,
     95                       typename TTypes<T>::ConstFlat in2,
     96                       typename TTypes<T>::ConstFlat in3,
     97                       typename TTypes<T>::ConstFlat in4,
     98                       typename TTypes<T>::ConstFlat in5) {
     99     out.device(d) = in1 + in2 + in3 + in4 + in5;
    100   }
    101 };
    102 
    103 template <typename Device, typename T>
    104 struct Add6Functor {
    105   void operator()(const Device& d, typename TTypes<T>::Flat out,
    106                   typename TTypes<T>::ConstFlat in1,
    107                   typename TTypes<T>::ConstFlat in2,
    108                   typename TTypes<T>::ConstFlat in3,
    109                   typename TTypes<T>::ConstFlat in4,
    110                   typename TTypes<T>::ConstFlat in5,
    111                   typename TTypes<T>::ConstFlat in6);
    112 };
    113 
    114 template <typename Device, typename T>
    115 struct Add6EigenImpl {
    116   static void Compute(const Device& d, typename TTypes<T>::Flat out,
    117                       typename TTypes<T>::ConstFlat in1,
    118                       typename TTypes<T>::ConstFlat in2,
    119                       typename TTypes<T>::ConstFlat in3,
    120                       typename TTypes<T>::ConstFlat in4,
    121                       typename TTypes<T>::ConstFlat in5,
    122                       typename TTypes<T>::ConstFlat in6) {
    123     out.device(d) = in1 + in2 + in3 + in4 + in5 + in6;
    124   }
    125 };
    126 
    127 template <typename Device, typename T>
    128 struct Add7Functor {
    129   void operator()(const Device& d, typename TTypes<T>::Flat out,
    130                   typename TTypes<T>::ConstFlat in1,
    131                   typename TTypes<T>::ConstFlat in2,
    132                   typename TTypes<T>::ConstFlat in3,
    133                   typename TTypes<T>::ConstFlat in4,
    134                   typename TTypes<T>::ConstFlat in5,
    135                   typename TTypes<T>::ConstFlat in6,
    136                   typename TTypes<T>::ConstFlat in7);
    137 };
    138 
    139 template <typename Device, typename T>
    140 struct Add7EigenImpl {
    141   static void Compute(const Device& d, typename TTypes<T>::Flat out,
    142                       typename TTypes<T>::ConstFlat in1,
    143                       typename TTypes<T>::ConstFlat in2,
    144                       typename TTypes<T>::ConstFlat in3,
    145                       typename TTypes<T>::ConstFlat in4,
    146                       typename TTypes<T>::ConstFlat in5,
    147                       typename TTypes<T>::ConstFlat in6,
    148                       typename TTypes<T>::ConstFlat in7) {
    149     out.device(d) = in1 + in2 + in3 + in4 + in5 + in6 + in7;
    150   }
    151 };
    152 
    153 template <typename Device, typename T>
    154 struct Add8Functor {
    155   void operator()(
    156       const Device& d, typename TTypes<T>::Flat out,
    157       typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
    158       typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
    159       typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
    160       typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8);
    161 };
    162 
    163 template <typename Device, typename T>
    164 struct Add8EigenImpl {
    165   static void Compute(
    166       const Device& d, typename TTypes<T>::Flat out,
    167       typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
    168       typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
    169       typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
    170       typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
    171     out.device(d) = in1 + in2 + in3 + in4 + in5 + in6 + in7 + in8;
    172   }
    173 };
    174 
    175 // Add8p is like Add8 except the underlying implementation should +=
    176 // rather than assign to the output.
    177 template <typename Device, typename T>
    178 struct Add8pFunctor {
    179   void operator()(
    180       const Device& d, typename TTypes<T>::Flat out,
    181       typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
    182       typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
    183       typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
    184       typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8);
    185 };
    186 
    187 template <typename Device, typename T>
    188 struct Add8pEigenImpl {
    189   static void Compute(
    190       const Device& d, typename TTypes<T>::Flat out,
    191       typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
    192       typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
    193       typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
    194       typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
    195     out.device(d) += in1 + in2 + in3 + in4 + in5 + in6 + in7 + in8;
    196   }
    197 };
    198 
    199 template <typename Device, typename T>
    200 struct Add9Functor {
    201   void operator()(
    202       const Device& d, typename TTypes<T>::Flat out,
    203       typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
    204       typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
    205       typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
    206       typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8,
    207       typename TTypes<T>::ConstFlat in9);
    208 };
    209 
    210 template <typename Device, typename T>
    211 struct Add9EigenImpl {
    212   static void Compute(
    213       const Device& d, typename TTypes<T>::Flat out,
    214       typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
    215       typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
    216       typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
    217       typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8,
    218       typename TTypes<T>::ConstFlat in9) {
    219     out.device(d) = in1 + in2 + in3 + in4 + in5 + in6 + in7 + in8 + in9;
    220   }
    221 };
    222 
    223 }  // namespace functor
    224 }  // namespace tensorflow
    225 
    226 #endif  // TENSORFLOW_KERNELS_AGGREGATE_OPS_H_
    227