1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_KERNELS_AGGREGATE_OPS_CPU_H_ 17 #define TENSORFLOW_KERNELS_AGGREGATE_OPS_CPU_H_ 18 19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 20 #include "tensorflow/core/framework/tensor_types.h" 21 22 #include "tensorflow/core/kernels/aggregate_ops.h" 23 24 typedef Eigen::ThreadPoolDevice CPUDevice; 25 26 #ifdef TENSORFLOW_USE_SYCL 27 typedef Eigen::SyclDevice SYCLDevice; 28 #endif // TENSORFLOW_USE_SYCL 29 30 namespace tensorflow { 31 32 // Partial specializations for a CPUDevice, that uses the Eigen implementation 33 // from AddNEigenImpl. 34 namespace functor { 35 template <typename T> 36 struct Add2Functor<CPUDevice, T> { 37 void operator()(const CPUDevice& d, typename TTypes<T>::Flat out, 38 typename TTypes<T>::ConstFlat in1, 39 typename TTypes<T>::ConstFlat in2) { 40 Add2EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2); 41 } 42 }; 43 template <typename T> 44 struct Add3Functor<CPUDevice, T> { 45 void operator()(const CPUDevice& d, typename TTypes<T>::Flat out, 46 typename TTypes<T>::ConstFlat in1, 47 typename TTypes<T>::ConstFlat in2, 48 typename TTypes<T>::ConstFlat in3) { 49 Add3EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3); 50 } 51 }; 52 template <typename T> 53 struct Add4Functor<CPUDevice, T> { 54 void operator()(const CPUDevice& d, typename TTypes<T>::Flat out, 55 typename TTypes<T>::ConstFlat in1, 56 typename TTypes<T>::ConstFlat in2, 57 typename TTypes<T>::ConstFlat in3, 58 typename TTypes<T>::ConstFlat in4) { 59 Add4EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4); 60 } 61 }; 62 template <typename T> 63 struct Add5Functor<CPUDevice, T> { 64 void operator()(const CPUDevice& d, typename TTypes<T>::Flat out, 65 typename TTypes<T>::ConstFlat in1, 66 typename TTypes<T>::ConstFlat in2, 67 typename TTypes<T>::ConstFlat in3, 68 typename TTypes<T>::ConstFlat in4, 69 typename TTypes<T>::ConstFlat in5) { 70 Add5EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5); 71 } 72 }; 73 template <typename T> 74 struct Add6Functor<CPUDevice, T> { 75 void operator()(const CPUDevice& d, typename TTypes<T>::Flat out, 76 typename TTypes<T>::ConstFlat in1, 77 typename TTypes<T>::ConstFlat in2, 78 typename TTypes<T>::ConstFlat in3, 79 typename TTypes<T>::ConstFlat in4, 80 typename TTypes<T>::ConstFlat in5, 81 typename TTypes<T>::ConstFlat in6) { 82 Add6EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6); 83 } 84 }; 85 template <typename T> 86 struct Add7Functor<CPUDevice, T> { 87 void operator()(const CPUDevice& d, typename TTypes<T>::Flat out, 88 typename TTypes<T>::ConstFlat in1, 89 typename TTypes<T>::ConstFlat in2, 90 typename TTypes<T>::ConstFlat in3, 91 typename TTypes<T>::ConstFlat in4, 92 typename TTypes<T>::ConstFlat in5, 93 typename TTypes<T>::ConstFlat in6, 94 typename TTypes<T>::ConstFlat in7) { 95 Add7EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6, 96 in7); 97 } 98 }; 99 100 template <typename T> 101 struct Add8Functor<CPUDevice, T> { 102 void operator()( 103 const CPUDevice& d, typename TTypes<T>::Flat out, 104 typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, 105 typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, 106 typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6, 107 typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) { 108 Add8EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6, 109 in7, in8); 110 } 111 }; 112 113 template <typename T> 114 struct Add8pFunctor<CPUDevice, T> { 115 void operator()( 116 const CPUDevice& d, typename TTypes<T>::Flat out, 117 typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, 118 typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, 119 typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6, 120 typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) { 121 Add8pEigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6, 122 in7, in8); 123 } 124 }; 125 126 template <typename T> 127 struct Add9Functor<CPUDevice, T> { 128 void operator()( 129 const CPUDevice& d, typename TTypes<T>::Flat out, 130 typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, 131 typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, 132 typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6, 133 typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8, 134 typename TTypes<T>::ConstFlat in9) { 135 Add9EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6, 136 in7, in8, in9); 137 } 138 }; 139 140 #ifdef TENSORFLOW_USE_SYCL 141 // Partial specializations for a SYCLDevice, that uses the Eigen implementation 142 // from AddNEigenImpl. 143 template <typename T> 144 struct Add2Functor<SYCLDevice, T> { 145 void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out, 146 typename TTypes<T>::ConstFlat in1, 147 typename TTypes<T>::ConstFlat in2) { 148 Add2EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2); 149 } 150 }; 151 template <typename T> 152 struct Add3Functor<SYCLDevice, T> { 153 void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out, 154 typename TTypes<T>::ConstFlat in1, 155 typename TTypes<T>::ConstFlat in2, 156 typename TTypes<T>::ConstFlat in3) { 157 Add3EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3); 158 } 159 }; 160 template <typename T> 161 struct Add4Functor<SYCLDevice, T> { 162 void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out, 163 typename TTypes<T>::ConstFlat in1, 164 typename TTypes<T>::ConstFlat in2, 165 typename TTypes<T>::ConstFlat in3, 166 typename TTypes<T>::ConstFlat in4) { 167 Add4EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4); 168 } 169 }; 170 template <typename T> 171 struct Add5Functor<SYCLDevice, T> { 172 void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out, 173 typename TTypes<T>::ConstFlat in1, 174 typename TTypes<T>::ConstFlat in2, 175 typename TTypes<T>::ConstFlat in3, 176 typename TTypes<T>::ConstFlat in4, 177 typename TTypes<T>::ConstFlat in5) { 178 Add5EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5); 179 } 180 }; 181 template <typename T> 182 struct Add6Functor<SYCLDevice, T> { 183 void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out, 184 typename TTypes<T>::ConstFlat in1, 185 typename TTypes<T>::ConstFlat in2, 186 typename TTypes<T>::ConstFlat in3, 187 typename TTypes<T>::ConstFlat in4, 188 typename TTypes<T>::ConstFlat in5, 189 typename TTypes<T>::ConstFlat in6) { 190 Add6EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6); 191 } 192 }; 193 template <typename T> 194 struct Add7Functor<SYCLDevice, T> { 195 void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out, 196 typename TTypes<T>::ConstFlat in1, 197 typename TTypes<T>::ConstFlat in2, 198 typename TTypes<T>::ConstFlat in3, 199 typename TTypes<T>::ConstFlat in4, 200 typename TTypes<T>::ConstFlat in5, 201 typename TTypes<T>::ConstFlat in6, 202 typename TTypes<T>::ConstFlat in7) { 203 Add7EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6, 204 in7); 205 } 206 }; 207 208 template <typename T> 209 struct Add8Functor<SYCLDevice, T> { 210 void operator()( 211 const SYCLDevice& d, typename TTypes<T>::Flat out, 212 typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, 213 typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, 214 typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6, 215 typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) { 216 Add8EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6, 217 in7, in8); 218 } 219 }; 220 221 template <typename T> 222 struct Add8pFunctor<SYCLDevice, T> { 223 void operator()( 224 const SYCLDevice& d, typename TTypes<T>::Flat out, 225 typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, 226 typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, 227 typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6, 228 typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) { 229 Add8pEigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6, 230 in7, in8); 231 } 232 }; 233 234 template <typename T> 235 struct Add9Functor<SYCLDevice, T> { 236 void operator()( 237 const SYCLDevice& d, typename TTypes<T>::Flat out, 238 typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2, 239 typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4, 240 typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6, 241 typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8, 242 typename TTypes<T>::ConstFlat in9) { 243 Add9EigenImpl<SYCLDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6, 244 in7, in8, in9); 245 } 246 }; 247 #endif // TENSORFLOW_USE_SYCL 248 249 } // namespace functor 250 251 } // namespace tensorflow 252 253 #endif // TENSORFLOW_KERNELS_AGGREGATE_OPS_CPU_H_ 254