/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

// Must be defined before the headers below are included.
#define EIGEN_USE_GPU

#include "tensorflow/core/kernels/aggregate_ops.h"

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

using GPUDevice = Eigen::GpuDevice;

namespace functor {

// GPUDevice partial specializations of the AddN functors. Each one is a thin
// forwarder: operator() hands its arguments, unchanged and in the same order,
// to the matching AddNEigenImpl<GPUDevice, T>::Compute. The arithmetic itself
// is defined by those *EigenImpl templates, which live outside this file.

// Two-input variant; forwards to Add2EigenImpl.
template <typename T>
struct Add2Functor<GPUDevice, T> {
  void operator()(const GPUDevice& d,
                  typename TTypes<T>::Flat out,
                  typename TTypes<T>::ConstFlat in1,
                  typename TTypes<T>::ConstFlat in2) {
    Add2EigenImpl<GPUDevice, T>::Compute(d, out, in1, in2);
  }
};

// Three-input variant; forwards to Add3EigenImpl.
template <typename T>
struct Add3Functor<GPUDevice, T> {
  void operator()(const GPUDevice& d,
                  typename TTypes<T>::Flat out,
                  typename TTypes<T>::ConstFlat in1,
                  typename TTypes<T>::ConstFlat in2,
                  typename TTypes<T>::ConstFlat in3) {
    Add3EigenImpl<GPUDevice, T>::Compute(d, out, in1, in2, in3);
  }
};

// Four-input variant; forwards to Add4EigenImpl.
template <typename T>
struct Add4Functor<GPUDevice, T> {
  void operator()(const GPUDevice& d,
                  typename TTypes<T>::Flat out,
                  typename TTypes<T>::ConstFlat in1,
                  typename TTypes<T>::ConstFlat in2,
                  typename TTypes<T>::ConstFlat in3,
                  typename TTypes<T>::ConstFlat in4) {
    Add4EigenImpl<GPUDevice, T>::Compute(d, out, in1, in2, in3, in4);
  }
};

// Five-input variant; forwards to Add5EigenImpl.
template <typename T>
struct Add5Functor<GPUDevice, T> {
  void operator()(const GPUDevice& d,
                  typename TTypes<T>::Flat out,
                  typename TTypes<T>::ConstFlat in1,
                  typename TTypes<T>::ConstFlat in2,
                  typename TTypes<T>::ConstFlat in3,
                  typename TTypes<T>::ConstFlat in4,
                  typename TTypes<T>::ConstFlat in5) {
    Add5EigenImpl<GPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5);
  }
};

// Six-input variant; forwards to Add6EigenImpl.
template <typename T>
struct Add6Functor<GPUDevice, T> {
  void operator()(const GPUDevice& d,
                  typename TTypes<T>::Flat out,
                  typename TTypes<T>::ConstFlat in1,
                  typename TTypes<T>::ConstFlat in2,
                  typename TTypes<T>::ConstFlat in3,
                  typename TTypes<T>::ConstFlat in4,
                  typename TTypes<T>::ConstFlat in5,
                  typename TTypes<T>::ConstFlat in6) {
    Add6EigenImpl<GPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6);
  }
};

// Seven-input variant; forwards to Add7EigenImpl.
template <typename T>
struct Add7Functor<GPUDevice, T> {
  void operator()(const GPUDevice& d,
                  typename TTypes<T>::Flat out,
                  typename TTypes<T>::ConstFlat in1,
                  typename TTypes<T>::ConstFlat in2,
                  typename TTypes<T>::ConstFlat in3,
                  typename TTypes<T>::ConstFlat in4,
                  typename TTypes<T>::ConstFlat in5,
                  typename TTypes<T>::ConstFlat in6,
                  typename TTypes<T>::ConstFlat in7) {
    Add7EigenImpl<GPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
                                         in7);
  }
};

// Eight-input variant; forwards to Add8EigenImpl.
template <typename T>
struct Add8Functor<GPUDevice, T> {
  void operator()(const GPUDevice& d,
                  typename TTypes<T>::Flat out,
                  typename TTypes<T>::ConstFlat in1,
                  typename TTypes<T>::ConstFlat in2,
                  typename TTypes<T>::ConstFlat in3,
                  typename TTypes<T>::ConstFlat in4,
                  typename TTypes<T>::ConstFlat in5,
                  typename TTypes<T>::ConstFlat in6,
                  typename TTypes<T>::ConstFlat in7,
                  typename TTypes<T>::ConstFlat in8) {
    Add8EigenImpl<GPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
                                         in7, in8);
  }
};

// Eight-input "p" variant; same parameter list as Add8Functor but forwards to
// Add8pEigenImpl. How it differs from Add8 is determined by that
// implementation, which is not defined in this file.
template <typename T>
struct Add8pFunctor<GPUDevice, T> {
  void operator()(const GPUDevice& d,
                  typename TTypes<T>::Flat out,
                  typename TTypes<T>::ConstFlat in1,
                  typename TTypes<T>::ConstFlat in2,
                  typename TTypes<T>::ConstFlat in3,
                  typename TTypes<T>::ConstFlat in4,
                  typename TTypes<T>::ConstFlat in5,
                  typename TTypes<T>::ConstFlat in6,
                  typename TTypes<T>::ConstFlat in7,
                  typename TTypes<T>::ConstFlat in8) {
    Add8pEigenImpl<GPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
                                          in7, in8);
  }
};

// Nine-input variant; forwards to Add9EigenImpl.
template <typename T>
struct Add9Functor<GPUDevice, T> {
  void operator()(const GPUDevice& d,
                  typename TTypes<T>::Flat out,
                  typename TTypes<T>::ConstFlat in1,
                  typename TTypes<T>::ConstFlat in2,
                  typename TTypes<T>::ConstFlat in3,
                  typename TTypes<T>::ConstFlat in4,
                  typename TTypes<T>::ConstFlat in5,
                  typename TTypes<T>::ConstFlat in6,
                  typename TTypes<T>::ConstFlat in7,
                  typename TTypes<T>::ConstFlat in8,
                  typename TTypes<T>::ConstFlat in9) {
    Add9EigenImpl<GPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
                                         in7, in8, in9);
  }
};

}  // namespace functor

// Explicitly instantiates every functor specialization above for GPUDevice and
// the given element type. Applied below to the GPU number types, complex64 and
// complex128; the macro is undefined again right after use.
#define REGISTER_FUNCTORS(type)                            \
  template struct functor::Add2Functor<GPUDevice, type>;  \
  template struct functor::Add3Functor<GPUDevice, type>;  \
  template struct functor::Add4Functor<GPUDevice, type>;  \
  template struct functor::Add5Functor<GPUDevice, type>;  \
  template struct functor::Add6Functor<GPUDevice, type>;  \
  template struct functor::Add7Functor<GPUDevice, type>;  \
  template struct functor::Add8Functor<GPUDevice, type>;  \
  template struct functor::Add8pFunctor<GPUDevice, type>; \
  template struct functor::Add9Functor<GPUDevice, type>;

TF_CALL_GPU_NUMBER_TYPES(REGISTER_FUNCTORS);
TF_CALL_complex64(REGISTER_FUNCTORS);
TF_CALL_complex128(REGISTER_FUNCTORS);

#undef REGISTER_FUNCTORS

}  // namespace tensorflow

#endif  // GOOGLE_CUDA