/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/state_ops.cc.

#include <limits>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/scatter_functor.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"

#ifdef TENSORFLOW_USE_SYCL
#include "tensorflow/core/common_runtime/sycl/sycl_util.h"
#endif  // TENSORFLOW_USE_SYCL

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

// Check whether updates.shape = indices.shape + params.shape[1:]
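// For example, with params.shape = [P0, D1, D2] and indices.shape = [N0, N1],
// a valid updates tensor has shape [N0, N1, D1, D2].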
static bool ValidShapes(const Tensor& params, const Tensor& updates,
                        const Tensor& indices) {
  if (updates.dims() != indices.dims() + params.dims() - 1) return false;
  for (int d = 0; d < indices.dims(); d++) {
    if (updates.dim_size(d) != indices.dim_size(d)) {
      return false;
    }
  }
  for (int d = 1; d < params.dims(); d++) {
    if (params.dim_size(d) != updates.dim_size(d - 1 + indices.dims())) {
      return false;
    }
  }
  return true;
}

static void DoValidationChecking(OpKernelContext* c, const Tensor& params,
                                 const Tensor& indices, const Tensor& updates) {
  OP_REQUIRES(c, params.IsInitialized(),
              errors::FailedPrecondition("Null ref for params"));
  OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
              errors::InvalidArgument("params must be at least 1-D, got shape ",
                                      params.shape().DebugString()));
  OP_REQUIRES(
      c, ValidShapes(params, updates, indices),
      errors::InvalidArgument(
          "Must have updates.shape = indices.shape + params.shape[1:], got ",
          "updates.shape ", updates.shape().DebugString(), ", indices.shape ",
          indices.shape().DebugString(), ", params.shape ",
          params.shape().DebugString()));
}

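// Scatters `updates` into the `params` ref tensor: each row
// params[indices[i], ...] is combined with updates[i, ...] according to `op`
// (ASSIGN, ADD, SUB, MUL, or DIV).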
template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
class ScatterUpdateOp : public OpKernel {
 public:
  //   QUESTION: It'd be nice to support DT_INT16, DT_UINT8,
  //   etc. here.  Should we have the framework do some sort of
  //   integer promotion automatically, or should that be something
  //   that users have to do explicitly with a conversion operator
  //   in the graph?
  explicit ScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_));
  }

  void Compute(OpKernelContext* c) override {
    if (use_exclusive_lock_) {
      // Hold mutex while we apply updates
      mutex_lock l(*c->input_ref_mutex(0));
      DoCompute(c);
    } else {
      DoCompute(c);
    }
  }

 private:
  bool use_exclusive_lock_;

  void DoCompute(OpKernelContext* c) {
    Tensor params = c->mutable_input(0, use_exclusive_lock_);
    const Tensor& indices = c->input(1);
    const Tensor& updates = c->input(2);
    DoValidationChecking(c, params, indices, updates);
    if (!c->status().ok()) return;

    // Check that we have enough index space
    const int64 N_big = indices.NumElements();
    OP_REQUIRES(
        c, N_big <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("indices has too many elements for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", N_big, " > ",
                                std::numeric_limits<Index>::max()));
    const Index N = static_cast<Index>(indices.NumElements());
    OP_REQUIRES(
        c, params.dim_size(0) <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("params.shape[0] too large for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", params.dim_size(0), " > ",
                                std::numeric_limits<Index>::max()));

    // We always return the input ref.
    c->forward_ref_input_to_ref_output(0, 0);

    if (N > 0) {
      auto indices_flat = indices.flat<Index>();
      auto params_flat = params.flat_outer_dims<T>();
      auto updates_flat = updates.shaped<T, 2>({N, updates.NumElements() / N});

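      // The functor performs the actual scatter and returns the position of
      // the first out-of-range index, or a negative value if every index is
      // within [0, params.dim_size(0)).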
      functor::ScatterFunctor<Device, T, Index, op> functor;
      const Index bad_i = functor(c, c->template eigen_device<Device>(),
                                  params_flat, updates_flat, indices_flat);
      OP_REQUIRES(
          c, bad_i < 0,
          errors::InvalidArgument(
              "indices", SliceDebugString(indices.shape(), bad_i), " = ",
              indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")"));
    }
  }
};

#ifdef TENSORFLOW_USE_SYCL
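// Specialization of ScatterUpdateOp for SYCL devices.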
template <typename T, typename Index, scatter_op::UpdateOp op>
class ScatterUpdateOp<SYCLDevice, T, Index, op> : public OpKernel {
 public:
  explicit ScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_));
  }

  void Compute(OpKernelContext* c) override {
    if (use_exclusive_lock_) {
      // Hold mutex while we apply updates
      mutex_lock l(*c->input_ref_mutex(0));
      DoCompute(c);
    } else {
      DoCompute(c);
    }
  }

 private:
  bool use_exclusive_lock_;

  void DoCompute(OpKernelContext* c) {
    Tensor params = c->mutable_input(0, use_exclusive_lock_);
    const Tensor& indices = c->input(1);
    const Tensor& updates = c->input(2);
    DoValidationChecking(c, params, indices, updates);
    if (!c->status().ok()) return;

    // Check that we have enough index space
    const int64 N_big = indices.NumElements();
    OP_REQUIRES(
        c, N_big <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("indices has too many elements for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", N_big, " > ",
                                std::numeric_limits<Index>::max()));
    const Index N = static_cast<Index>(indices.NumElements());
    OP_REQUIRES(
        c, params.dim_size(0) <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("params.shape[0] too large for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", params.dim_size(0), " > ",
                                std::numeric_limits<Index>::max()));

    // We always return the input ref.
    c->forward_ref_input_to_ref_output(0, 0);

    if (N > 0) {
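      // The SYCL scatter functor reads the indices on the host, so copy them
      // off the device first.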
      auto index_size = indices.NumElements() * sizeof(Index);
      Tensor indices_host = Tensor(indices.dtype(), indices.shape());

      auto src_ptr = GetBase(&indices);
      auto dst_ptr = GetBase(&indices_host);

      c->eigen_sycl_device().memcpyDeviceToHost(
          dst_ptr, static_cast<const Index*>(src_ptr), index_size);

      auto indices_flat = indices_host.flat<Index>();
      auto params_flat = params.flat_outer_dims<T>();
      auto updates_flat = updates.shaped<T, 2>({N, updates.NumElements() / N});

      functor::ScatterFunctorSYCL<T, Index, op> functor;
      const Index bad_i = functor(c, c->template eigen_device<SYCLDevice>(),
                                  params_flat, updates_flat, indices_flat);
      OP_REQUIRES(
          c, bad_i < 0,
          errors::InvalidArgument(
              "indices", SliceDebugString(indices.shape(), bad_i), " = ",
              indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")"));
    }
  }
};
#endif  // TENSORFLOW_USE_SYCL

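// Helper macros that stamp out kernel registrations for each supported
// combination of value type, index type, device, and update op.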
#define REGISTER_SCATTER_KERNEL_INDEX(type, index_type, dev, name, op) \
  REGISTER_KERNEL_BUILDER(Name(name)                                   \
                              .Device(DEVICE_##dev)                    \
                              .TypeConstraint<type>("T")               \
                              .TypeConstraint<index_type>("Tindices"), \
                          ScatterUpdateOp<dev##Device, type, index_type, op>)

#define REGISTER_SCATTER_KERNEL(type, dev, name, op)         \
  REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \
  REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op);

#define REGISTER_SCATTER_ARITHMETIC(type, dev)                                 \
  REGISTER_SCATTER_KERNEL(type, dev, "ScatterAdd", scatter_op::UpdateOp::ADD); \
  REGISTER_SCATTER_KERNEL(type, dev, "ScatterDiv", scatter_op::UpdateOp::DIV); \
  REGISTER_SCATTER_KERNEL(type, dev, "ScatterMul", scatter_op::UpdateOp::MUL); \
  REGISTER_SCATTER_KERNEL(type, dev, "ScatterSub", scatter_op::UpdateOp::SUB);

#define REGISTER_SCATTER_UPDATE(type, dev)            \
  REGISTER_SCATTER_KERNEL(type, dev, "ScatterUpdate", \
                          scatter_op::UpdateOp::ASSIGN);

// Registers CPU kernels.
#define REGISTER_SCATTER_ARITHMETIC_CPU(type) \
  REGISTER_SCATTER_ARITHMETIC(type, CPU);

#define REGISTER_SCATTER_UPDATE_CPU(type) REGISTER_SCATTER_UPDATE(type, CPU);

TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU);
TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_CPU);

// Registers GPU kernels.
#if GOOGLE_CUDA
#define REGISTER_SCATTER_ARITHMETIC_GPU(type) \
  REGISTER_SCATTER_ARITHMETIC(type, GPU);

#define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU);

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_GPU);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_GPU);

#endif  // GOOGLE_CUDA

// Registers SYCL kernels.
#if TENSORFLOW_USE_SYCL
#define REGISTER_SCATTER_ARITHMETIC_SYCL(type) \
  REGISTER_SCATTER_ARITHMETIC(type, SYCL);

#define REGISTER_SCATTER_UPDATE_SYCL(type) REGISTER_SCATTER_UPDATE(type, SYCL);

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_SYCL);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_SYCL);

#undef REGISTER_SCATTER_ARITHMETIC_SYCL
#undef REGISTER_SCATTER_UPDATE_SYCL
#endif  // TENSORFLOW_USE_SYCL

#undef REGISTER_SCATTER_ARITHMETIC
#undef REGISTER_SCATTER_ARITHMETIC_CPU
#undef REGISTER_SCATTER_ARITHMETIC_GPU
#undef REGISTER_SCATTER_UPDATE
#undef REGISTER_SCATTER_UPDATE_CPU
#undef REGISTER_SCATTER_UPDATE_GPU
#undef REGISTER_SCATTER_KERNEL
#undef REGISTER_SCATTER_KERNEL_INDEX

}  // namespace tensorflow