1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_ 17 #define TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_ 18 19 // Functor definitions for ScatterND ops, must be compilable by nvcc. 20 21 #define EIGEN_USE_THREADS 22 23 #include <atomic> 24 25 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 26 27 #include "tensorflow/core/framework/op_kernel.h" 28 #include "tensorflow/core/framework/register_types.h" 29 #include "tensorflow/core/framework/tensor.h" 30 #include "tensorflow/core/framework/tensor_shape.h" 31 #include "tensorflow/core/kernels/bounds_check.h" 32 #include "tensorflow/core/kernels/fill_functor.h" 33 #include "tensorflow/core/kernels/scatter_nd_op.h" 34 #include "tensorflow/core/platform/mutex.h" 35 #include "tensorflow/core/platform/types.h" 36 #include "tensorflow/core/util/util.h" 37 38 namespace tensorflow { 39 40 typedef Eigen::ThreadPoolDevice CPUDevice; 41 #ifdef TENSORFLOW_USE_SYCL 42 typedef Eigen::SyclDevice SYCLDevice; 43 #endif // TENSORFLOW_USE_SYCL 44 45 class OpKernelContext; 46 47 // Specialization of UpdateExecutor to CPU 48 namespace update_executor { 49 50 template <typename Input, typename Update, typename Output, 51 scatter_nd_op::UpdateOp OP> 52 class UpdateExecutor { 53 public: 54 EIGEN_STRONG_INLINE static void Execute(Input value, Update update, 55 Output output); 56 }; 57 58 template <typename Input, typename Update, typename Output> 59 class UpdateExecutor<Input, Update, Output, scatter_nd_op::UpdateOp::ASSIGN> { 60 public: 61 EIGEN_STRONG_INLINE static void Execute(Input /* input */, Update update, 62 Output output) { 63 output = update; 64 } 65 }; 66 67 template <typename Input, typename Update, typename Output> 68 class UpdateExecutor<Input, Update, Output, scatter_nd_op::UpdateOp::ADD> { 69 public: 70 EIGEN_STRONG_INLINE static void Execute(Input /* input */, Update update, 71 Output output) { 72 output += update; 73 } 74 }; 75 76 template <typename Input, typename Update, typename Output> 77 class UpdateExecutor<Input, Update, Output, scatter_nd_op::UpdateOp::SUB> { 78 public: 79 EIGEN_STRONG_INLINE static void Execute(Input /* input */, Update update, 80 Output output) { 81 output -= update; 82 } 83 }; 84 85 } // namespace update_executor 86 87 namespace functor { 88 89 // Implementation of update functor for CPU. 90 template <typename T, typename Index, scatter_nd_op::UpdateOp OP, int IXDIM> 91 struct ScatterNdFunctor<CPUDevice, T, Index, OP, IXDIM> { 92 Index operator()( 93 const CPUDevice& d, const Index slice_size, 94 const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix, 95 typename TTypes<T, 2>::Tensor Tparams, 96 typename TTypes<Index, 2>::ConstTensor Tindices, 97 typename TTypes<T, 2>::ConstTensor Tupdates, 98 typename TTypes<T, 2>::Tensor Toutput) { 99 // error_loc is -1 if there's no out-of-bounds index, 100 // otherwise it is the location of an OOB index in Tindices. 101 Index error_loc = -1; 102 103 const Eigen::DenseIndex batch_size = Tindices.dimension(0); 104 105 Index batch_strides[IXDIM]; 106 for (int dim = IXDIM - 1; dim >= 0; --dim) { 107 if (dim == IXDIM - 1) { 108 batch_strides[dim] = 1; 109 } else { 110 batch_strides[dim] = 111 batch_strides[dim + 1] * output_shape_prefix[dim + 1]; 112 } 113 } 114 115 for (Eigen::DenseIndex loc = 0; loc < batch_size; ++loc) { 116 Index i = 0; 117 bool out_of_bounds = false; 118 for (int dim = 0; dim < IXDIM; ++dim) { 119 const Index ix_d = internal::SubtleMustCopy(Tindices(loc, dim)); 120 out_of_bounds |= !FastBoundsCheck(ix_d, output_shape_prefix[dim]); 121 i += ix_d * batch_strides[dim]; 122 } 123 if (TF_PREDICT_FALSE(out_of_bounds)) { 124 error_loc = loc; 125 break; 126 } else { 127 auto input_chip = Toutput.template chip<0>(i); 128 auto output_chip = input_chip.device(d); 129 auto update_chip = Tupdates.template chip<0>(loc); 130 update_executor::UpdateExecutor< 131 decltype(input_chip), decltype(update_chip), decltype(output_chip), 132 OP>::Execute(input_chip, update_chip, output_chip); 133 } 134 } 135 136 return error_loc; 137 } 138 }; 139 140 #define REGISTER_SCATTER_ND_FULL(T, Index, op) \ 141 template Index \ 142 ScatterNdFunctor<CPUDevice, T, Index, op, CPU_PROVIDED_IXDIM>::operator()( \ 143 const CPUDevice& d, const Index slice_size, \ 144 const Eigen::array<Eigen::DenseIndex, CPU_PROVIDED_IXDIM> \ 145 output_shape_prefix, \ 146 typename TTypes<T, 2>::Tensor Tparams, \ 147 typename TTypes<Index, 2>::ConstTensor Tindices, \ 148 typename TTypes<T, 2>::ConstTensor Tupdates, \ 149 typename TTypes<T, 2>::Tensor Toutput) 150 151 #define REGISTER_SCATTER_ND_INDEX(type, op) \ 152 REGISTER_SCATTER_ND_FULL(type, int32, op); \ 153 REGISTER_SCATTER_ND_FULL(type, int64, op) 154 155 #define REGISTER_SCATTER_ND_UPDATE(type) \ 156 REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::ASSIGN); 157 158 #define REGISTER_SCATTER_ND_MATH(type) \ 159 REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::ADD); \ 160 REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::SUB); 161 162 TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_UPDATE); 163 TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_MATH) 164 165 #undef REGISTER_SCATTER_ND_MATH 166 #undef REGISTER_SCATTER_ND_UPDATE 167 #undef REGISTER_SCATTER_ND_INDEX 168 #undef REGISTER_SCATTER_ND_FULL 169 170 #ifdef TENSORFLOW_USE_SYCL 171 // Implementation of update functor for SYCL. 172 template <typename T, typename Index, scatter_nd_op::UpdateOp OP, int IXDIM> 173 struct ScatterNdFunctor<SYCLDevice, T, Index, OP, IXDIM> { 174 Index operator()( 175 const SYCLDevice& d, const Index slice_size, 176 const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix, 177 typename TTypes<T, 2>::Tensor Tparams, 178 typename TTypes<Index, 2>::ConstTensor Tindices, 179 typename TTypes<T, 2>::ConstTensor Tupdates, 180 typename TTypes<T, 2>::Tensor Toutput) { 181 // error_loc is -1 if there's no out-of-bounds index, 182 // otherwise it is the location of an OOB index in Tindices. 183 Index error_loc = -1; 184 185 const Eigen::DenseIndex batch_size = Tindices.dimension(0); 186 187 Index batch_strides[IXDIM]; 188 for (int dim = IXDIM - 1; dim >= 0; --dim) { 189 if (dim == IXDIM - 1) { 190 batch_strides[dim] = 1; 191 } else { 192 batch_strides[dim] = 193 batch_strides[dim + 1] * output_shape_prefix[dim + 1]; 194 } 195 } 196 197 for (Eigen::DenseIndex loc = 0; loc < batch_size; ++loc) { 198 Index i = 0; 199 bool out_of_bounds = false; 200 for (int dim = 0; dim < IXDIM; ++dim) { 201 const Index ix_d = internal::SubtleMustCopy(Tindices(loc, dim)); 202 out_of_bounds |= !FastBoundsCheck(ix_d, output_shape_prefix[dim]); 203 i += ix_d * batch_strides[dim]; 204 } 205 if (TF_PREDICT_FALSE(out_of_bounds)) { 206 error_loc = loc; 207 break; 208 } else { 209 auto input_chip = Toutput.template chip<0>(i); 210 auto output_chip = input_chip.device(d); 211 auto update_chip = Tupdates.template chip<0>(loc); 212 update_executor::UpdateExecutor< 213 decltype(input_chip), decltype(update_chip), decltype(output_chip), 214 OP>::Execute(input_chip, update_chip, output_chip); 215 } 216 } 217 218 return error_loc; 219 } 220 }; 221 222 #define REGISTER_SCATTER_ND_FULL_SYCL(T, Index, op) \ 223 template Index \ 224 ScatterNdFunctor<SYCLDevice, T, Index, op, CPU_PROVIDED_IXDIM>::operator()( \ 225 const SYCLDevice& d, const Index slice_size, \ 226 const Eigen::array<Eigen::DenseIndex, CPU_PROVIDED_IXDIM> \ 227 output_shape_prefix, \ 228 typename TTypes<T, 2>::Tensor Tparams, \ 229 typename TTypes<Index, 2>::ConstTensor Tindices, \ 230 typename TTypes<T, 2>::ConstTensor Tupdates, \ 231 typename TTypes<T, 2>::Tensor Toutput) 232 233 #define REGISTER_SCATTER_ND_INDEX_SYCL(type, op) \ 234 REGISTER_SCATTER_ND_FULL_SYCL(type, int32, op); \ 235 REGISTER_SCATTER_ND_FULL_SYCL(type, int64, op) 236 237 #define REGISTER_SCATTER_ND_UPDATE_SYCL(type) \ 238 REGISTER_SCATTER_ND_INDEX_SYCL(type, scatter_nd_op::UpdateOp::ASSIGN); 239 240 #define REGISTER_SCATTER_ND_MATH_SYCL(type) \ 241 REGISTER_SCATTER_ND_INDEX_SYCL(type, scatter_nd_op::UpdateOp::ADD); \ 242 REGISTER_SCATTER_ND_INDEX_SYCL(type, scatter_nd_op::UpdateOp::SUB); 243 244 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_SYCL) 245 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_MATH_SYCL) 246 REGISTER_SCATTER_ND_UPDATE_SYCL(int32); 247 REGISTER_SCATTER_ND_MATH_SYCL(int32); 248 249 #undef REGISTER_SCATTER_ND_MATH_SYCL 250 #undef REGISTER_SCATTER_ND_UPDATE_SYCL 251 #undef REGISTER_SCATTER_ND_INDEX_SYCL 252 #undef REGISTER_SCATTER_ND_FULL_SYCL 253 254 #endif // TENSORFLOW_USE_SYCL 255 256 } // namespace functor 257 258 } // namespace tensorflow 259 260 #endif // TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_ 261