1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/kernels/fill_functor.h" 17 18 #define EIGEN_USE_THREADS 19 20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 21 #include "tensorflow/core/framework/register_types.h" 22 #include "tensorflow/core/framework/tensor_types.h" 23 #include "tensorflow/core/framework/types.h" 24 #include "tensorflow/core/framework/variant_encode_decode.h" 25 26 namespace tensorflow { 27 namespace functor { 28 29 template <typename T> 30 void SetZeroFunctor<Eigen::ThreadPoolDevice, T>::operator()( 31 const Eigen::ThreadPoolDevice& d, typename TTypes<T>::Flat out) { 32 out.device(d) = out.constant(T(0)); 33 } 34 35 void SetZeroFunctor<Eigen::ThreadPoolDevice, string>::operator()( 36 const Eigen::ThreadPoolDevice& d, typename TTypes<string>::Flat out) { 37 out.device(d) = out.constant(string()); 38 } 39 40 // Explicit instantiations. 41 #define DEFINE_SETZERO_CPU(T) \ 42 template struct SetZeroFunctor<Eigen::ThreadPoolDevice, T>; 43 DEFINE_SETZERO_CPU(bool); 44 DEFINE_SETZERO_CPU(Eigen::half); 45 DEFINE_SETZERO_CPU(bfloat16); 46 DEFINE_SETZERO_CPU(float); 47 DEFINE_SETZERO_CPU(double); 48 DEFINE_SETZERO_CPU(uint8); 49 DEFINE_SETZERO_CPU(int8); 50 DEFINE_SETZERO_CPU(uint16); 51 DEFINE_SETZERO_CPU(int16); 52 DEFINE_SETZERO_CPU(int32); 53 DEFINE_SETZERO_CPU(int64); 54 DEFINE_SETZERO_CPU(complex64); 55 DEFINE_SETZERO_CPU(complex128); 56 DEFINE_SETZERO_CPU(Variant); 57 #undef DEFINE_SETZERO_CPU 58 59 #ifdef TENSORFLOW_USE_SYCL 60 template <typename T> 61 void SetZeroFunctor<Eigen::SyclDevice, T>::operator()( 62 const Eigen::SyclDevice& d, typename TTypes<T>::Flat out) { 63 To32Bit(out).device(d) = To32Bit(out).constant(T(0)); 64 } 65 66 #define DEFINE_SETZERO_SYCL(T) \ 67 template struct SetZeroFunctor<Eigen::SyclDevice, T>; 68 DEFINE_SETZERO_SYCL(bool); 69 DEFINE_SETZERO_SYCL(float); 70 DEFINE_SETZERO_SYCL(double); 71 DEFINE_SETZERO_SYCL(uint8); 72 DEFINE_SETZERO_SYCL(int8); 73 DEFINE_SETZERO_SYCL(uint16); 74 DEFINE_SETZERO_SYCL(int16); 75 DEFINE_SETZERO_SYCL(int32); 76 DEFINE_SETZERO_SYCL(int64); 77 #undef DEFINE_SETZERO_SYCL 78 #endif // TENSORFLOW_USE_SYCL 79 80 template <typename T> 81 void SetOneFunctor<Eigen::ThreadPoolDevice, T>::operator()( 82 const Eigen::ThreadPoolDevice& d, typename TTypes<T>::Flat out) { 83 out.device(d) = out.constant(T(1)); 84 } 85 86 // Explicit instantiations. 87 #define DEFINE_SETONE_CPU(T) \ 88 template struct SetOneFunctor<Eigen::ThreadPoolDevice, T>; 89 DEFINE_SETONE_CPU(bool); 90 DEFINE_SETONE_CPU(Eigen::half); 91 DEFINE_SETONE_CPU(bfloat16); 92 DEFINE_SETONE_CPU(float); 93 DEFINE_SETONE_CPU(double); 94 DEFINE_SETONE_CPU(uint8); 95 DEFINE_SETONE_CPU(int8); 96 DEFINE_SETONE_CPU(uint16); 97 DEFINE_SETONE_CPU(int16); 98 DEFINE_SETONE_CPU(int32); 99 DEFINE_SETONE_CPU(int64); 100 DEFINE_SETONE_CPU(complex64); 101 DEFINE_SETONE_CPU(complex128); 102 #undef DEFINE_SETONE_CPU 103 104 #ifdef TENSORFLOW_USE_SYCL 105 template <typename T> 106 void SetOneFunctor<Eigen::SyclDevice, T>::operator()( 107 const Eigen::SyclDevice& d, typename TTypes<T>::Flat out) { 108 out.device(d) = out.constant(T(1)); 109 } 110 111 #define DEFINE_SETONE_SYCL(T) \ 112 template struct SetOneFunctor<Eigen::SyclDevice, T>; 113 DEFINE_SETONE_SYCL(float); 114 DEFINE_SETONE_SYCL(bool); 115 DEFINE_SETONE_SYCL(double); 116 #undef DEFINE_SETONE_SYCL 117 #endif // TENSORFLOW_USE_SYCL 118 119 template <typename T> 120 struct FillFunctor<Eigen::ThreadPoolDevice, T> { 121 void operator()(const Eigen::ThreadPoolDevice& d, 122 typename TTypes<T>::Flat out, 123 typename TTypes<T>::ConstScalar in) { 124 out.device(d) = out.constant(in()); 125 } 126 }; 127 128 // Explicit instantiations. 129 #define DEFINE_FILL_CPU(T) \ 130 template struct FillFunctor<Eigen::ThreadPoolDevice, T>; 131 132 TF_CALL_ALL_TYPES(DEFINE_FILL_CPU); 133 DEFINE_FILL_CPU(quint8); 134 DEFINE_FILL_CPU(quint16); 135 #undef DEFINE_FILL_CPU 136 137 #ifdef TENSORFLOW_USE_SYCL 138 template <typename T> 139 struct FillFunctor<Eigen::SyclDevice, T> { 140 void operator()(const Eigen::SyclDevice& d, typename TTypes<T>::Flat out, 141 typename TTypes<T>::ConstScalar in) { 142 #if !defined(EIGEN_HAS_INDEX_LIST) 143 Eigen::array<int, 1> rank1{1}; 144 #else 145 Eigen::IndexList<Eigen::type2index<1> > rank1; 146 #endif 147 const int size = out.dimension(0); 148 Eigen::array<int, 1> broadcast_dims{size}; 149 150 To32Bit(out).device(d) = in.reshape(rank1).broadcast(broadcast_dims); 151 } 152 }; 153 154 #define DEFINE_FILL_SYCL(T) template struct FillFunctor<Eigen::SyclDevice, T>; 155 DEFINE_FILL_SYCL(float); 156 DEFINE_FILL_SYCL(double); 157 TF_CALL_INTEGRAL_TYPES(DEFINE_FILL_SYCL) 158 #undef DEFINE_FILL_SYCL 159 #endif // TENSORFLOW_USE_SYCL 160 161 } // namespace functor 162 } // namespace tensorflow 163