/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_KERNELS_SCATTER_FUNCTOR_H_
     17 #define TENSORFLOW_KERNELS_SCATTER_FUNCTOR_H_
     18 
#include <cstring>
#include <type_traits>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/platform/types.h"
     24 
     25 namespace tensorflow {
     26 
     27 class OpKernelContext;
     28 typedef Eigen::ThreadPoolDevice CPUDevice;
     29 typedef Eigen::GpuDevice GPUDevice;
     30 #ifdef TENSORFLOW_USE_SYCL
     31 typedef Eigen::SyclDevice SYCLDevice;
     32 #endif  // TENSORFLOW_USE_SYCL
     33 
     34 namespace scatter_op {
     35 
     36 enum class UpdateOp { ASSIGN, ADD, SUB, MUL, DIV };
     37 
     38 namespace internal {
     39 
     40 template <scatter_op::UpdateOp Op>
     41 struct Assign {};
     42 template <>
     43 struct Assign<scatter_op::UpdateOp::ASSIGN> {
     44   template <typename Params, typename Update>
     45   static void Run(Params p, Update u) {
     46     p = u;
     47   }
     48 };
     49 template <>
     50 struct Assign<scatter_op::UpdateOp::ADD> {
     51   template <typename Params, typename Update>
     52   static void Run(Params p, Update u) {
     53     p += u;
     54   }
     55 };
     56 template <>
     57 struct Assign<scatter_op::UpdateOp::SUB> {
     58   template <typename Params, typename Update>
     59   static void Run(Params p, Update u) {
     60     p -= u;
     61   }
     62 };
     63 template <>
     64 struct Assign<scatter_op::UpdateOp::MUL> {
     65   template <typename Params, typename Update>
     66   static void Run(Params p, Update u) {
     67     p *= u;
     68   }
     69 };
     70 template <>
     71 struct Assign<scatter_op::UpdateOp::DIV> {
     72   template <typename Params, typename Update>
     73   static void Run(Params p, Update u) {
     74     p /= u;
     75   }
     76 };
     77 
     78 #ifdef TENSORFLOW_USE_SYCL
     79 template <scatter_op::UpdateOp Op>
     80 struct AssignSYCL {};
     81 template <>
     82 struct AssignSYCL<scatter_op::UpdateOp::ASSIGN> {
     83   template <typename Device, typename Params, typename Update>
     84   static void Run(Device d, Params p, Update u) {
     85     p.device(d) = u;
     86   }
     87 };
     88 
     89 template <>
     90 struct AssignSYCL<scatter_op::UpdateOp::ADD> {
     91   template <typename Device, typename Params, typename Update>
     92   static void Run(Device d, Params p, Update u) {
     93     p.device(d) += u;
     94   }
     95 };
     96 
     97 template <>
     98 struct AssignSYCL<scatter_op::UpdateOp::SUB> {
     99   template <typename Device, typename Params, typename Update>
    100   static void Run(Device d, Params p, Update u) {
    101     p.device(d) -= u;
    102   }
    103 };
    104 
    105 template <>
    106 struct AssignSYCL<scatter_op::UpdateOp::MUL> {
    107   template <typename Device, typename Params, typename Update>
    108   static void Run(Device d, Params p, Update u) {
    109     p.device(d) = p * u;
    110   }
    111 };
    112 
    113 template <>
    114 struct AssignSYCL<scatter_op::UpdateOp::DIV> {
    115   template <typename Device, typename Params, typename Update>
    116   static void Run(Device d, Params p, Update u) {
    117     p.device(d) = p / u;
    118   }
    119 };
    120 #endif  // TENSORFLOW_USE_SYCL
    121 
    122 }  // namespace internal
    123 }  // namespace scatter_op
    124 
    125 namespace functor {
// Performs params[indices(i), :] op= updates(i, :) for each i in
// [0, indices.size()).  Returns -1 on success, or the first position i whose
// index fails the bounds check against params' first dimension (see the
// ScatterFunctorBase implementations below).  Defined per-device via
// specializations.
template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctor {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<T>::Matrix params,
                   typename TTypes<T>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices);
};
    133 
    134 template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
    135 struct ScatterFunctorBase {
    136   Index operator()(OpKernelContext* c, const Device& d,
    137                    typename TTypes<T>::Matrix params,
    138                    typename TTypes<T>::ConstMatrix updates,
    139                    typename TTypes<Index>::ConstFlat indices) {
    140     // indices and params sizes were validated in DoCompute().
    141     const Index N = static_cast<Index>(indices.size());
    142     const Index limit = static_cast<Index>(params.dimension(0));
    143     for (Index i = 0; i < N; i++) {
    144       // Grab the index and check its validity.  An earlier version of the
    145       // code checked it and then grabbed it from memory a second time, which
    146       // was a security risk since it could have changed in between.
    147       const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
    148       if (!FastBoundsCheck(index, limit)) return i;
    149       // Copy last Ndim-1 dimensions of updates[i] to params[index]
    150       scatter_op::internal::Assign<op>::Run(params.template chip<0>(index),
    151                                             updates.template chip<0>(i));
    152     }
    153     return -1;
    154   }
    155 };
    156 
    157 #ifdef TENSORFLOW_USE_SYCL
    158 template <typename T, typename Index, scatter_op::UpdateOp op>
    159 struct ScatterFunctorBase<SYCLDevice, T, Index, op> {
    160   Index operator()(OpKernelContext* c, const SYCLDevice& d,
    161                    typename TTypes<T>::Matrix params,
    162                    typename TTypes<T>::ConstMatrix updates,
    163                    typename TTypes<Index>::ConstFlat indices) {
    164     // indices and params sizes were validated in DoCompute().
    165     const Index N = static_cast<Index>(indices.size());
    166     const Index limit = static_cast<Index>(params.dimension(0));
    167     for (Index i = 0; i < N; i++) {
    168       // Grab the index and check its validity.  An earlier version of the
    169       // code checked it and then grabbed it from memory a second time, which
    170       // was a security risk since it could have changed in between.
    171       const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
    172       if (!FastBoundsCheck(index, limit)) return i;
    173       // Copy last Ndim-1 dimensions of updates[i] to params[index]
    174       scatter_op::internal::AssignSYCL<op>::Run(
    175           d, params.template chip<0>(index), updates.template chip<0>(i));
    176     }
    177     return -1;
    178   }
    179 };
    180 #endif  // TENSORFLOW_USE_SYCL
    181 
    182 template <typename T, typename Index>
    183 struct ScatterFunctorBase<CPUDevice, T, Index, scatter_op::UpdateOp::ASSIGN> {
    184   Index operator()(OpKernelContext* c, const CPUDevice& d,
    185                    typename TTypes<T>::Matrix params,
    186                    typename TTypes<T>::ConstMatrix updates,
    187                    typename TTypes<Index>::ConstFlat indices) {
    188     // indices and params sizes were validated in DoCompute().
    189     const Index N = static_cast<Index>(indices.size());
    190     const Index limit = static_cast<Index>(params.dimension(0));
    191     if (!std::is_same<T, string>::value) {
    192       for (Index i = 0; i < N; i++) {
    193         // Grab the index and check its validity.  An earlier version of the
    194         // code checked it and then grabbed it from memory a second time, which
    195         // was a security risk since it could have changed in between.
    196         const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
    197         if (!FastBoundsCheck(index, limit)) return i;
    198         memmove(params.data() + index * params.dimension(1),
    199                 updates.data() + i * updates.dimension(1),
    200                 updates.dimension(1) * sizeof(T));
    201       }
    202     } else {
    203       for (Index i = 0; i < N; i++) {
    204         // Grab the index and check its validity.  An earlier version of the
    205         // code checked it and then grabbed it from memory a second time, which
    206         // was a security risk since it could have changed in between.
    207         const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
    208         if (!FastBoundsCheck(index, limit)) return i;
    209         // Copy last Ndim-1 dimensions of updates[i] to params[index]
    210         scatter_op::internal::Assign<scatter_op::UpdateOp::ASSIGN>::Run(
    211             params.template chip<0>(index), updates.template chip<0>(i));
    212       }
    213     }
    214     return -1;
    215   }
    216 };
    217 
// CPU entry point: forward ScatterFunctor<CPUDevice, ...> to the (possibly
// op-specialized) ScatterFunctorBase implementation above.
template <typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctor<CPUDevice, T, Index, op>
    : ScatterFunctorBase<CPUDevice, T, Index, op> {};
    221 
    222 #ifdef TENSORFLOW_USE_SYCL
    223 template <typename T, typename Index, scatter_op::UpdateOp op>
    224 struct ScatterFunctorSYCL {
    225   Index operator()(OpKernelContext* c, const SYCLDevice& d,
    226                    typename TTypes<T>::Matrix params,
    227                    typename TTypes<T>::ConstMatrix updates,
    228                    typename TTypes<Index>::Flat indices) {
    229     // indices and params sizes were validated in DoCompute().
    230     const Index N = static_cast<Index>(indices.size());
    231     const Index limit = static_cast<Index>(params.dimension(0));
    232     for (Index i = 0; i < N; i++) {
    233       const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
    234       if (!FastBoundsCheck(index, limit)) return i;
    235       // Copy last Ndim-1 dimensions of updates[i] to params[index]
    236       scatter_op::internal::AssignSYCL<op>::Run(
    237           d, params.template chip<0>(index), updates.template chip<0>(i));
    238     }
    239     return -1;
    240   }
    241 };
    242 #endif  // TENSORFLOW_USE_SYCL
    243 
    244 }  // namespace functor
    245 }  // namespace tensorflow
    246 
    247 #endif  // TENSORFLOW_KERNELS_SCATTER_FUNCTOR_H_
    248