1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/state_ops.cc. 17 18 #include "tensorflow/core/framework/op_kernel.h" 19 #include "tensorflow/core/framework/register_types.h" 20 #include "tensorflow/core/framework/tensor.h" 21 #include "tensorflow/core/kernels/scatter_functor.h" 22 #include "tensorflow/core/platform/mutex.h" 23 #include "tensorflow/core/platform/types.h" 24 #include "tensorflow/core/util/util.h" 25 26 #ifdef TENSORFLOW_USE_SYCL 27 #include "tensorflow/core/common_runtime/sycl/sycl_util.h" 28 #endif // TENSORFLOW_USE_SYCL 29 30 namespace tensorflow { 31 32 typedef Eigen::ThreadPoolDevice CPUDevice; 33 typedef Eigen::GpuDevice GPUDevice; 34 #ifdef TENSORFLOW_USE_SYCL 35 typedef Eigen::SyclDevice SYCLDevice; 36 #endif // TENSORFLOW_USE_SYCL 37 38 // Check whether updates.shape = indices.shape + params.shape[1:] 39 static bool ValidShapes(const Tensor& params, const Tensor& updates, 40 const Tensor& indices) { 41 if (updates.dims() != indices.dims() + params.dims() - 1) return false; 42 for (int d = 0; d < indices.dims(); d++) { 43 if (updates.dim_size(d) != indices.dim_size(d)) { 44 return false; 45 } 46 } 47 for (int d = 1; d < params.dims(); d++) { 48 if (params.dim_size(d) != updates.dim_size(d - 1 + indices.dims())) { 49 return false; 50 } 51 } 52 return true; 53 } 54 55 static void DoValidationChecking(OpKernelContext* c, const Tensor& params, 56 const Tensor& indices, const Tensor& updates) { 57 OP_REQUIRES(c, params.IsInitialized(), 58 errors::FailedPrecondition("Null ref for params")); 59 OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(params.shape()), 60 errors::InvalidArgument("params must be at least 1-D, got shape ", 61 params.shape().DebugString())); 62 OP_REQUIRES( 63 c, ValidShapes(params, updates, indices), 64 errors::InvalidArgument( 65 "Must have updates.shape = indices.shape + params.shape[1:], got ", 66 "updates.shape ", updates.shape().DebugString(), ", indices.shape ", 67 indices.shape().DebugString(), ", params.shape ", 68 params.shape().DebugString())); 69 } 70 71 template <typename Device, typename T, typename Index, scatter_op::UpdateOp op> 72 class ScatterUpdateOp : public OpKernel { 73 public: 74 // QUESTION: It'd be nice to support DT_INT16, DT_UINT8, 75 // etc. here. Should we have the framework do some sort of 76 // integer promotion automatically, or should that be something 77 // that users have to do explicitly with a conversion operator 78 // in the graph? 79 explicit ScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) { 80 OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_)); 81 } 82 83 void Compute(OpKernelContext* c) override { 84 if (use_exclusive_lock_) { 85 // Hold mutex while we apply updates 86 mutex_lock l(*c->input_ref_mutex(0)); 87 DoCompute(c); 88 } else { 89 DoCompute(c); 90 } 91 } 92 93 private: 94 bool use_exclusive_lock_; 95 96 void DoCompute(OpKernelContext* c) { 97 Tensor params = c->mutable_input(0, use_exclusive_lock_); 98 const Tensor& indices = c->input(1); 99 const Tensor& updates = c->input(2); 100 DoValidationChecking(c, params, indices, updates); 101 if (!c->status().ok()) return; 102 103 // Check that we have enough index space 104 const int64 N_big = indices.NumElements(); 105 OP_REQUIRES( 106 c, N_big <= std::numeric_limits<Index>::max(), 107 errors::InvalidArgument("indices has too many elements for ", 108 DataTypeString(DataTypeToEnum<Index>::v()), 109 " indexing: ", N_big, " > ", 110 std::numeric_limits<Index>::max())); 111 const Index N = static_cast<Index>(indices.NumElements()); 112 OP_REQUIRES( 113 c, params.dim_size(0) <= std::numeric_limits<Index>::max(), 114 errors::InvalidArgument("params.shape[0] too large for ", 115 DataTypeString(DataTypeToEnum<Index>::v()), 116 " indexing: ", params.dim_size(0), " > ", 117 std::numeric_limits<Index>::max())); 118 119 // We always return the input ref. 120 c->forward_ref_input_to_ref_output(0, 0); 121 122 if (N > 0) { 123 auto indices_flat = indices.flat<Index>(); 124 auto params_flat = params.flat_outer_dims<T>(); 125 auto updates_flat = updates.shaped<T, 2>({N, updates.NumElements() / N}); 126 127 functor::ScatterFunctor<Device, T, Index, op> functor; 128 const Index bad_i = functor(c, c->template eigen_device<Device>(), 129 params_flat, updates_flat, indices_flat); 130 OP_REQUIRES( 131 c, bad_i < 0, 132 errors::InvalidArgument( 133 "indices", SliceDebugString(indices.shape(), bad_i), " = ", 134 indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")")); 135 } 136 } 137 }; 138 139 #ifdef TENSORFLOW_USE_SYCL 140 template <typename T, typename Index, scatter_op::UpdateOp op> 141 class ScatterUpdateOp<SYCLDevice, T, Index, op> : public OpKernel { 142 public: 143 explicit ScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) { 144 OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_)); 145 } 146 147 void Compute(OpKernelContext* c) override { 148 if (use_exclusive_lock_) { 149 // Hold mutex while we apply updates 150 mutex_lock l(*c->input_ref_mutex(0)); 151 DoCompute(c); 152 } else { 153 DoCompute(c); 154 } 155 } 156 157 private: 158 bool use_exclusive_lock_; 159 160 void DoCompute(OpKernelContext* c) { 161 Tensor params = c->mutable_input(0, use_exclusive_lock_); 162 const Tensor& indices = c->input(1); 163 const Tensor& updates = c->input(2); 164 DoValidationChecking(c, params, indices, updates); 165 if (!c->status().ok()) return; 166 167 // Check that we have enough index space 168 const int64 N_big = indices.NumElements(); 169 OP_REQUIRES( 170 c, N_big <= std::numeric_limits<Index>::max(), 171 errors::InvalidArgument("indices has too many elements for ", 172 DataTypeString(DataTypeToEnum<Index>::v()), 173 " indexing: ", N_big, " > ", 174 std::numeric_limits<Index>::max())); 175 const Index N = static_cast<Index>(indices.NumElements()); 176 OP_REQUIRES( 177 c, params.dim_size(0) <= std::numeric_limits<Index>::max(), 178 errors::InvalidArgument("params.shape[0] too large for ", 179 DataTypeString(DataTypeToEnum<Index>::v()), 180 " indexing: ", params.dim_size(0), " > ", 181 std::numeric_limits<Index>::max())); 182 183 // We always return the input ref. 184 c->forward_ref_input_to_ref_output(0, 0); 185 186 if (N > 0) { 187 auto index_size = indices.NumElements() * sizeof(Index); 188 Tensor indices_host = Tensor(indices.dtype(), indices.shape()); 189 190 auto src_ptr = GetBase(&indices); 191 auto dst_ptr = GetBase(&indices_host); 192 193 c->eigen_sycl_device().memcpyDeviceToHost( 194 dst_ptr, static_cast<const Index*>(src_ptr), index_size); 195 196 auto indices_flat = indices_host.flat<Index>(); 197 auto params_flat = params.flat_outer_dims<T>(); 198 auto updates_flat = updates.shaped<T, 2>({N, updates.NumElements() / N}); 199 200 functor::ScatterFunctorSYCL<T, Index, op> functor; 201 const Index bad_i = functor(c, c->template eigen_device<SYCLDevice>(), 202 params_flat, updates_flat, indices_flat); 203 OP_REQUIRES( 204 c, bad_i < 0, 205 errors::InvalidArgument( 206 "indices", SliceDebugString(indices.shape(), bad_i), " = ", 207 indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")")); 208 } 209 } 210 }; 211 #endif // TENSORFLOW_USE_SYCL 212 213 #define REGISTER_SCATTER_KERNEL_INDEX(type, index_type, dev, name, op) \ 214 REGISTER_KERNEL_BUILDER(Name(name) \ 215 .Device(DEVICE_##dev) \ 216 .TypeConstraint<type>("T") \ 217 .TypeConstraint<index_type>("Tindices"), \ 218 ScatterUpdateOp<dev##Device, type, index_type, op>) 219 220 #define REGISTER_SCATTER_KERNEL(type, dev, name, op) \ 221 REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \ 222 REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op); 223 224 #define REGISTER_SCATTER_ARITHEMTIC(type, dev) \ 225 REGISTER_SCATTER_KERNEL(type, dev, "ScatterAdd", scatter_op::UpdateOp::ADD); \ 226 REGISTER_SCATTER_KERNEL(type, dev, "ScatterDiv", scatter_op::UpdateOp::DIV); \ 227 REGISTER_SCATTER_KERNEL(type, dev, "ScatterMul", scatter_op::UpdateOp::MUL); \ 228 REGISTER_SCATTER_KERNEL(type, dev, "ScatterSub", scatter_op::UpdateOp::SUB); 229 230 #define REGISTER_SCATTER_UPDATE(type, dev) \ 231 REGISTER_SCATTER_KERNEL(type, dev, "ScatterUpdate", \ 232 scatter_op::UpdateOp::ASSIGN); 233 234 // Registers CPU kernels. 235 #define REGISTER_SCATTER_ARITHEMTIC_CPU(type) \ 236 REGISTER_SCATTER_ARITHEMTIC(type, CPU); 237 238 #define REGISTER_SCATTER_UPDATE_CPU(type) REGISTER_SCATTER_UPDATE(type, CPU); 239 240 TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHEMTIC_CPU); 241 TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_CPU); 242 243 // Registers GPU kernels. 244 #if GOOGLE_CUDA 245 #define REGISTER_SCATTER_ARITHEMTIC_GPU(type) \ 246 REGISTER_SCATTER_ARITHEMTIC(type, GPU); 247 248 #define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU); 249 250 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHEMTIC_GPU); 251 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_GPU); 252 253 #endif // GOOGLE_CUDA 254 255 // Registers GPU kernels. 256 #if TENSORFLOW_USE_SYCL 257 #define REGISTER_SCATTER_ARITHEMTIC_SYCL(type) \ 258 REGISTER_SCATTER_ARITHEMTIC(type, SYCL); 259 260 #define REGISTER_SCATTER_UPDATE_SYCL(type) REGISTER_SCATTER_UPDATE(type, SYCL); 261 262 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHEMTIC_SYCL); 263 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_SYCL); 264 265 #undef REGISTER_SCATTER_ARITHEMTIC_SYCL 266 #undef REGISTER_SCATTER_UPDATE_SYCL 267 #endif // TENSORFLOW_USE_SYCL 268 269 #undef REGISTER_SCATTER_ARITHEMTIC 270 #undef REGISTER_SCATTER_ARITHEMTIC_CPU 271 #undef REGISTER_SCATTER_ARITHEMTIC_GPU 272 #undef REGISTER_SCATTER_UPDATE 273 #undef REGISTER_SCATTER_UPDATE_CPU 274 #undef REGISTER_SCATTER_UPDATE_GPU 275 #undef REGISTER_SCATTER_KERNEL 276 #undef REGISTER_SCATTER_KERNEL_INDEX 277 278 } // namespace tensorflow 279