/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_KERNELS_SCATTER_FUNCTOR_H_
#define TENSORFLOW_KERNELS_SCATTER_FUNCTOR_H_

#include <type_traits>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

class OpKernelContext;
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

namespace scatter_op {

// Element-wise operation a Scatter* kernel applies when folding a row of
// `updates` into the selected row of `params`.
enum class UpdateOp { ASSIGN, ADD, SUB, MUL, DIV };

namespace internal {

// Assign<Op>::Run(p, u) applies Op to the Eigen slice expression `p`
// (a row of params) using `u` (the matching row of updates).
// The primary template is intentionally empty so that instantiating it
// with an Op that has no specialization fails to compile.
template <scatter_op::UpdateOp Op>
struct Assign {};
template <>
struct Assign<scatter_op::UpdateOp::ASSIGN> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p = u;
  }
};
template <>
struct Assign<scatter_op::UpdateOp::ADD> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p += u;
  }
};
template <>
struct Assign<scatter_op::UpdateOp::SUB> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p -= u;
  }
};
template <>
struct Assign<scatter_op::UpdateOp::MUL> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p *= u;
  }
};
template <>
struct Assign<scatter_op::UpdateOp::DIV> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p /= u;
  }
};

#ifdef TENSORFLOW_USE_SYCL
// SYCL variant of Assign<Op>: evaluates the expression on the given SYCL
// device via Eigen's `.device(d)` assignment instead of on the host.
// Note MUL/DIV are written as full reassignments (`p.device(d) = p * u`)
// because Eigen's device expressions are used with plain `=` here.
template <scatter_op::UpdateOp Op>
struct AssignSYCL {};
template <>
struct AssignSYCL<scatter_op::UpdateOp::ASSIGN> {
  template <typename Device, typename Params, typename Update>
  static void Run(Device d, Params p, Update u) {
    p.device(d) = u;
  }
};

template <>
struct AssignSYCL<scatter_op::UpdateOp::ADD> {
  template <typename Device, typename Params, typename Update>
  static void Run(Device d, Params p, Update u) {
    p.device(d) += u;
  }
};

template <>
struct AssignSYCL<scatter_op::UpdateOp::SUB> {
  template <typename Device, typename Params, typename Update>
  static void Run(Device d, Params p, Update u) {
    p.device(d) -= u;
  }
};

template <>
struct AssignSYCL<scatter_op::UpdateOp::MUL> {
  template <typename Device, typename Params, typename Update>
  static void Run(Device d, Params p, Update u) {
    p.device(d) = p * u;
  }
};

template <>
struct AssignSYCL<scatter_op::UpdateOp::DIV> {
  template <typename Device, typename Params, typename Update>
  static void Run(Device d, Params p, Update u) {
    p.device(d) = p / u;
  }
};
#endif  // TENSORFLOW_USE_SYCL

}  // namespace internal
}  // namespace scatter_op

namespace functor {

// Scatters rows of `updates` into `params` at the positions given by
// `indices`, combining with `op`. Returns -1 on success, or the position
// within `indices` of the first out-of-range index on failure.
// Specializations for each device are defined below / in the .cu.cc files.
template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctor {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<T>::Matrix params,
                   typename TTypes<T>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices);
};

// Generic sequential implementation shared by the device specializations.
// Same return contract as ScatterFunctor: -1 on success, otherwise the
// position of the first bad index.
template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctorBase {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<T>::Matrix params,
                   typename TTypes<T>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    for (Index i = 0; i < N; i++) {
      // Grab the index and check its validity.  An earlier version of the
      // code checked it and then grabbed it from memory a second time, which
      // was a security risk since it could have changed in between.
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Copy last Ndim-1 dimensions of updates[i] to params[index]
      scatter_op::internal::Assign<op>::Run(params.template chip<0>(index),
                                            updates.template chip<0>(i));
    }
    return -1;
  }
};

#ifdef TENSORFLOW_USE_SYCL
// SYCL specialization: identical control flow to the generic base, but
// each row update is dispatched to the SYCL device via AssignSYCL.
template <typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctorBase<SYCLDevice, T, Index, op> {
  Index operator()(OpKernelContext* c, const SYCLDevice& d,
                   typename TTypes<T>::Matrix params,
                   typename TTypes<T>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    for (Index i = 0; i < N; i++) {
      // Grab the index and check its validity.  An earlier version of the
      // code checked it and then grabbed it from memory a second time, which
      // was a security risk since it could have changed in between.
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Copy last Ndim-1 dimensions of updates[i] to params[index]
      scatter_op::internal::AssignSYCL<op>::Run(
          d, params.template chip<0>(index), updates.template chip<0>(i));
    }
    return -1;
  }
};
#endif  // TENSORFLOW_USE_SYCL

// CPU fast path for plain assignment: rows can be copied with memmove when
// the element type is byte-copyable.  The branch on is_same<T, string> is a
// runtime `if` on a compile-time constant; both arms must compile for every
// T, but only the element-wise arm ever executes for string elements, whose
// heap-owning representation cannot be moved as raw bytes.
template <typename T, typename Index>
struct ScatterFunctorBase<CPUDevice, T, Index, scatter_op::UpdateOp::ASSIGN> {
  Index operator()(OpKernelContext* c, const CPUDevice& d,
                   typename TTypes<T>::Matrix params,
                   typename TTypes<T>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    if (!std::is_same<T, string>::value) {
      for (Index i = 0; i < N; i++) {
        // Grab the index and check its validity.  An earlier version of the
        // code checked it and then grabbed it from memory a second time, which
        // was a security risk since it could have changed in between.
        const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
        if (!FastBoundsCheck(index, limit)) return i;
        // memmove tolerates overlap, so aliasing rows are still safe.
        memmove(params.data() + index * params.dimension(1),
                updates.data() + i * updates.dimension(1),
                updates.dimension(1) * sizeof(T));
      }
    } else {
      for (Index i = 0; i < N; i++) {
        // Grab the index and check its validity.  An earlier version of the
        // code checked it and then grabbed it from memory a second time, which
        // was a security risk since it could have changed in between.
        const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
        if (!FastBoundsCheck(index, limit)) return i;
        // Copy last Ndim-1 dimensions of updates[i] to params[index]
        scatter_op::internal::Assign<scatter_op::UpdateOp::ASSIGN>::Run(
            params.template chip<0>(index), updates.template chip<0>(i));
      }
    }
    return -1;
  }
};

// CPU ScatterFunctor simply inherits the sequential base implementation.
template <typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctor<CPUDevice, T, Index, op>
    : ScatterFunctorBase<CPUDevice, T, Index, op> {};

#ifdef TENSORFLOW_USE_SYCL
// Standalone SYCL functor used by the SYCL kernels.  Same contract as
// ScatterFunctor; note `indices` is a mutable Flat here (unlike the
// ConstFlat taken elsewhere), matching its SYCL call sites.
template <typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctorSYCL {
  Index operator()(OpKernelContext* c, const SYCLDevice& d,
                   typename TTypes<T>::Matrix params,
                   typename TTypes<T>::ConstMatrix updates,
                   typename TTypes<Index>::Flat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    for (Index i = 0; i < N; i++) {
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Copy last Ndim-1 dimensions of updates[i] to params[index]
      scatter_op::internal::AssignSYCL<op>::Run(
          d, params.template chip<0>(index), updates.template chip<0>(i));
    }
    return -1;
  }
};
#endif  // TENSORFLOW_USE_SYCL

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_KERNELS_SCATTER_FUNCTOR_H_