Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_
     17 #define TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_
     18 
     19 // Functor definitions for ScatterND ops, must be compilable by nvcc.
     20 
     21 #define EIGEN_USE_THREADS
     22 
     23 #include <atomic>
     24 
     25 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     26 
     27 #include "tensorflow/core/framework/op_kernel.h"
     28 #include "tensorflow/core/framework/register_types.h"
     29 #include "tensorflow/core/framework/tensor.h"
     30 #include "tensorflow/core/framework/tensor_shape.h"
     31 #include "tensorflow/core/kernels/bounds_check.h"
     32 #include "tensorflow/core/kernels/fill_functor.h"
     33 #include "tensorflow/core/kernels/scatter_nd_op.h"
     34 #include "tensorflow/core/platform/mutex.h"
     35 #include "tensorflow/core/platform/types.h"
     36 #include "tensorflow/core/util/util.h"
     37 
     38 namespace tensorflow {
     39 
     40 typedef Eigen::ThreadPoolDevice CPUDevice;
     41 #ifdef TENSORFLOW_USE_SYCL
     42 typedef Eigen::SyclDevice SYCLDevice;
     43 #endif  // TENSORFLOW_USE_SYCL
     44 
     45 class OpKernelContext;
     46 
     47 // Specialization of UpdateExecutor to CPU
     48 namespace update_executor {
     49 
     50 template <typename Input, typename Update, typename Output,
     51           scatter_nd_op::UpdateOp OP>
     52 class UpdateExecutor {
     53  public:
     54   EIGEN_STRONG_INLINE static void Execute(Input value, Update update,
     55                                           Output output);
     56 };
     57 
     58 template <typename Input, typename Update, typename Output>
     59 class UpdateExecutor<Input, Update, Output, scatter_nd_op::UpdateOp::ASSIGN> {
     60  public:
     61   EIGEN_STRONG_INLINE static void Execute(Input /* input */, Update update,
     62                                           Output output) {
     63     output = update;
     64   }
     65 };
     66 
     67 template <typename Input, typename Update, typename Output>
     68 class UpdateExecutor<Input, Update, Output, scatter_nd_op::UpdateOp::ADD> {
     69  public:
     70   EIGEN_STRONG_INLINE static void Execute(Input /* input */, Update update,
     71                                           Output output) {
     72     output += update;
     73   }
     74 };
     75 
     76 template <typename Input, typename Update, typename Output>
     77 class UpdateExecutor<Input, Update, Output, scatter_nd_op::UpdateOp::SUB> {
     78  public:
     79   EIGEN_STRONG_INLINE static void Execute(Input /* input */, Update update,
     80                                           Output output) {
     81     output -= update;
     82   }
     83 };
     84 
     85 }  // namespace update_executor
     86 
     87 namespace functor {
     88 
     89 // Implementation of update functor for CPU.
     90 template <typename T, typename Index, scatter_nd_op::UpdateOp OP, int IXDIM>
     91 struct ScatterNdFunctor<CPUDevice, T, Index, OP, IXDIM> {
     92   Index operator()(
     93       const CPUDevice& d, const Index slice_size,
     94       const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix,
     95       typename TTypes<T, 2>::Tensor Tparams,
     96       typename TTypes<Index, 2>::ConstTensor Tindices,
     97       typename TTypes<T, 2>::ConstTensor Tupdates,
     98       typename TTypes<T, 2>::Tensor Toutput) {
     99     // error_loc is -1 if there's no out-of-bounds index,
    100     // otherwise it is the location of an OOB index in Tindices.
    101     Index error_loc = -1;
    102 
    103     const Eigen::DenseIndex batch_size = Tindices.dimension(0);
    104 
    105     Index batch_strides[IXDIM];
    106     for (int dim = IXDIM - 1; dim >= 0; --dim) {
    107       if (dim == IXDIM - 1) {
    108         batch_strides[dim] = 1;
    109       } else {
    110         batch_strides[dim] =
    111             batch_strides[dim + 1] * output_shape_prefix[dim + 1];
    112       }
    113     }
    114 
    115     for (Eigen::DenseIndex loc = 0; loc < batch_size; ++loc) {
    116       Index i = 0;
    117       bool out_of_bounds = false;
    118       for (int dim = 0; dim < IXDIM; ++dim) {
    119         const Index ix_d = internal::SubtleMustCopy(Tindices(loc, dim));
    120         out_of_bounds |= !FastBoundsCheck(ix_d, output_shape_prefix[dim]);
    121         i += ix_d * batch_strides[dim];
    122       }
    123       if (TF_PREDICT_FALSE(out_of_bounds)) {
    124         error_loc = loc;
    125         break;
    126       } else {
    127         auto input_chip = Toutput.template chip<0>(i);
    128         auto output_chip = input_chip.device(d);
    129         auto update_chip = Tupdates.template chip<0>(loc);
    130         update_executor::UpdateExecutor<
    131             decltype(input_chip), decltype(update_chip), decltype(output_chip),
    132             OP>::Execute(input_chip, update_chip, output_chip);
    133       }
    134     }
    135 
    136     return error_loc;
    137   }
    138 };
    139 
    140 #define REGISTER_SCATTER_ND_FULL(T, Index, op)                               \
    141   template Index                                                             \
    142   ScatterNdFunctor<CPUDevice, T, Index, op, CPU_PROVIDED_IXDIM>::operator()( \
    143       const CPUDevice& d, const Index slice_size,                            \
    144       const Eigen::array<Eigen::DenseIndex, CPU_PROVIDED_IXDIM>              \
    145           output_shape_prefix,                                               \
    146       typename TTypes<T, 2>::Tensor Tparams,                                 \
    147       typename TTypes<Index, 2>::ConstTensor Tindices,                       \
    148       typename TTypes<T, 2>::ConstTensor Tupdates,                           \
    149       typename TTypes<T, 2>::Tensor Toutput)
    150 
    151 #define REGISTER_SCATTER_ND_INDEX(type, op)  \
    152   REGISTER_SCATTER_ND_FULL(type, int32, op); \
    153   REGISTER_SCATTER_ND_FULL(type, int64, op)
    154 
    155 #define REGISTER_SCATTER_ND_UPDATE(type) \
    156   REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::ASSIGN);
    157 
    158 #define REGISTER_SCATTER_ND_MATH(type)                           \
    159   REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::ADD); \
    160   REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::SUB);
    161 
    162 TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_UPDATE);
    163 TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_MATH)
    164 
    165 #undef REGISTER_SCATTER_ND_MATH
    166 #undef REGISTER_SCATTER_ND_UPDATE
    167 #undef REGISTER_SCATTER_ND_INDEX
    168 #undef REGISTER_SCATTER_ND_FULL
    169 
    170 #ifdef TENSORFLOW_USE_SYCL
    171 // Implementation of update functor for SYCL.
    172 template <typename T, typename Index, scatter_nd_op::UpdateOp OP, int IXDIM>
    173 struct ScatterNdFunctor<SYCLDevice, T, Index, OP, IXDIM> {
    174   Index operator()(
    175       const SYCLDevice& d, const Index slice_size,
    176       const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix,
    177       typename TTypes<T, 2>::Tensor Tparams,
    178       typename TTypes<Index, 2>::ConstTensor Tindices,
    179       typename TTypes<T, 2>::ConstTensor Tupdates,
    180       typename TTypes<T, 2>::Tensor Toutput) {
    181     // error_loc is -1 if there's no out-of-bounds index,
    182     // otherwise it is the location of an OOB index in Tindices.
    183     Index error_loc = -1;
    184 
    185     const Eigen::DenseIndex batch_size = Tindices.dimension(0);
    186 
    187     Index batch_strides[IXDIM];
    188     for (int dim = IXDIM - 1; dim >= 0; --dim) {
    189       if (dim == IXDIM - 1) {
    190         batch_strides[dim] = 1;
    191       } else {
    192         batch_strides[dim] =
    193             batch_strides[dim + 1] * output_shape_prefix[dim + 1];
    194       }
    195     }
    196 
    197     for (Eigen::DenseIndex loc = 0; loc < batch_size; ++loc) {
    198       Index i = 0;
    199       bool out_of_bounds = false;
    200       for (int dim = 0; dim < IXDIM; ++dim) {
    201         const Index ix_d = internal::SubtleMustCopy(Tindices(loc, dim));
    202         out_of_bounds |= !FastBoundsCheck(ix_d, output_shape_prefix[dim]);
    203         i += ix_d * batch_strides[dim];
    204       }
    205       if (TF_PREDICT_FALSE(out_of_bounds)) {
    206         error_loc = loc;
    207         break;
    208       } else {
    209         auto input_chip = Toutput.template chip<0>(i);
    210         auto output_chip = input_chip.device(d);
    211         auto update_chip = Tupdates.template chip<0>(loc);
    212         update_executor::UpdateExecutor<
    213             decltype(input_chip), decltype(update_chip), decltype(output_chip),
    214             OP>::Execute(input_chip, update_chip, output_chip);
    215       }
    216     }
    217 
    218     return error_loc;
    219   }
    220 };
    221 
    222 #define REGISTER_SCATTER_ND_FULL_SYCL(T, Index, op)                           \
    223   template Index                                                              \
    224   ScatterNdFunctor<SYCLDevice, T, Index, op, CPU_PROVIDED_IXDIM>::operator()( \
    225       const SYCLDevice& d, const Index slice_size,                            \
    226       const Eigen::array<Eigen::DenseIndex, CPU_PROVIDED_IXDIM>               \
    227           output_shape_prefix,                                                \
    228       typename TTypes<T, 2>::Tensor Tparams,                                  \
    229       typename TTypes<Index, 2>::ConstTensor Tindices,                        \
    230       typename TTypes<T, 2>::ConstTensor Tupdates,                            \
    231       typename TTypes<T, 2>::Tensor Toutput)
    232 
    233 #define REGISTER_SCATTER_ND_INDEX_SYCL(type, op)  \
    234   REGISTER_SCATTER_ND_FULL_SYCL(type, int32, op); \
    235   REGISTER_SCATTER_ND_FULL_SYCL(type, int64, op)
    236 
    237 #define REGISTER_SCATTER_ND_UPDATE_SYCL(type) \
    238   REGISTER_SCATTER_ND_INDEX_SYCL(type, scatter_nd_op::UpdateOp::ASSIGN);
    239 
    240 #define REGISTER_SCATTER_ND_MATH_SYCL(type)                           \
    241   REGISTER_SCATTER_ND_INDEX_SYCL(type, scatter_nd_op::UpdateOp::ADD); \
    242   REGISTER_SCATTER_ND_INDEX_SYCL(type, scatter_nd_op::UpdateOp::SUB);
    243 
    244 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_SYCL)
    245 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_MATH_SYCL)
    246 REGISTER_SCATTER_ND_UPDATE_SYCL(int32);
    247 REGISTER_SCATTER_ND_MATH_SYCL(int32);
    248 
    249 #undef REGISTER_SCATTER_ND_MATH_SYCL
    250 #undef REGISTER_SCATTER_ND_UPDATE_SYCL
    251 #undef REGISTER_SCATTER_ND_INDEX_SYCL
    252 #undef REGISTER_SCATTER_ND_FULL_SYCL
    253 
    254 #endif  // TENSORFLOW_USE_SYCL
    255 
    256 }  // namespace functor
    257 
    258 }  // namespace tensorflow
    259 
    260 #endif  // TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_
    261