Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/kernels/fill_functor.h"
     17 
     18 #define EIGEN_USE_THREADS
     19 
     20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     21 #include "tensorflow/core/framework/register_types.h"
     22 #include "tensorflow/core/framework/tensor_types.h"
     23 #include "tensorflow/core/framework/types.h"
     24 #include "tensorflow/core/framework/variant_encode_decode.h"
     25 
     26 namespace tensorflow {
     27 namespace functor {
     28 
     29 template <typename T>
     30 void SetZeroFunctor<Eigen::ThreadPoolDevice, T>::operator()(
     31     const Eigen::ThreadPoolDevice& d, typename TTypes<T>::Flat out) {
     32   out.device(d) = out.constant(T(0));
     33 }
     34 
     35 void SetZeroFunctor<Eigen::ThreadPoolDevice, string>::operator()(
     36     const Eigen::ThreadPoolDevice& d, typename TTypes<string>::Flat out) {
     37   out.device(d) = out.constant(string());
     38 }
     39 
     40 // Explicit instantiations.
     41 #define DEFINE_SETZERO_CPU(T) \
     42   template struct SetZeroFunctor<Eigen::ThreadPoolDevice, T>;
     43 DEFINE_SETZERO_CPU(bool);
     44 DEFINE_SETZERO_CPU(Eigen::half);
     45 DEFINE_SETZERO_CPU(bfloat16);
     46 DEFINE_SETZERO_CPU(float);
     47 DEFINE_SETZERO_CPU(double);
     48 DEFINE_SETZERO_CPU(uint8);
     49 DEFINE_SETZERO_CPU(int8);
     50 DEFINE_SETZERO_CPU(uint16);
     51 DEFINE_SETZERO_CPU(int16);
     52 DEFINE_SETZERO_CPU(int32);
     53 DEFINE_SETZERO_CPU(int64);
     54 DEFINE_SETZERO_CPU(complex64);
     55 DEFINE_SETZERO_CPU(complex128);
     56 DEFINE_SETZERO_CPU(Variant);
     57 #undef DEFINE_SETZERO_CPU
     58 
     59 #ifdef TENSORFLOW_USE_SYCL
     60 template <typename T>
     61 void SetZeroFunctor<Eigen::SyclDevice, T>::operator()(
     62     const Eigen::SyclDevice& d, typename TTypes<T>::Flat out) {
     63   To32Bit(out).device(d) = To32Bit(out).constant(T(0));
     64 }
     65 
     66 #define DEFINE_SETZERO_SYCL(T) \
     67   template struct SetZeroFunctor<Eigen::SyclDevice, T>;
     68 DEFINE_SETZERO_SYCL(bool);
     69 DEFINE_SETZERO_SYCL(float);
     70 DEFINE_SETZERO_SYCL(double);
     71 DEFINE_SETZERO_SYCL(uint8);
     72 DEFINE_SETZERO_SYCL(int8);
     73 DEFINE_SETZERO_SYCL(uint16);
     74 DEFINE_SETZERO_SYCL(int16);
     75 DEFINE_SETZERO_SYCL(int32);
     76 DEFINE_SETZERO_SYCL(int64);
     77 #undef DEFINE_SETZERO_SYCL
     78 #endif  // TENSORFLOW_USE_SYCL
     79 
     80 template <typename T>
     81 void SetOneFunctor<Eigen::ThreadPoolDevice, T>::operator()(
     82     const Eigen::ThreadPoolDevice& d, typename TTypes<T>::Flat out) {
     83   out.device(d) = out.constant(T(1));
     84 }
     85 
     86 // Explicit instantiations.
     87 #define DEFINE_SETONE_CPU(T) \
     88   template struct SetOneFunctor<Eigen::ThreadPoolDevice, T>;
     89 DEFINE_SETONE_CPU(bool);
     90 DEFINE_SETONE_CPU(Eigen::half);
     91 DEFINE_SETONE_CPU(bfloat16);
     92 DEFINE_SETONE_CPU(float);
     93 DEFINE_SETONE_CPU(double);
     94 DEFINE_SETONE_CPU(uint8);
     95 DEFINE_SETONE_CPU(int8);
     96 DEFINE_SETONE_CPU(uint16);
     97 DEFINE_SETONE_CPU(int16);
     98 DEFINE_SETONE_CPU(int32);
     99 DEFINE_SETONE_CPU(int64);
    100 DEFINE_SETONE_CPU(complex64);
    101 DEFINE_SETONE_CPU(complex128);
    102 #undef DEFINE_SETONE_CPU
    103 
    104 #ifdef TENSORFLOW_USE_SYCL
    105 template <typename T>
    106 void SetOneFunctor<Eigen::SyclDevice, T>::operator()(
    107     const Eigen::SyclDevice& d, typename TTypes<T>::Flat out) {
    108   out.device(d) = out.constant(T(1));
    109 }
    110 
    111 #define DEFINE_SETONE_SYCL(T) \
    112   template struct SetOneFunctor<Eigen::SyclDevice, T>;
    113 DEFINE_SETONE_SYCL(float);
    114 DEFINE_SETONE_SYCL(bool);
    115 DEFINE_SETONE_SYCL(double);
    116 #undef DEFINE_SETONE_SYCL
    117 #endif  // TENSORFLOW_USE_SYCL
    118 
    119 template <typename T>
    120 struct FillFunctor<Eigen::ThreadPoolDevice, T> {
    121   void operator()(const Eigen::ThreadPoolDevice& d,
    122                   typename TTypes<T>::Flat out,
    123                   typename TTypes<T>::ConstScalar in) {
    124     out.device(d) = out.constant(in());
    125   }
    126 };
    127 
    128 // Explicit instantiations.
    129 #define DEFINE_FILL_CPU(T) \
    130   template struct FillFunctor<Eigen::ThreadPoolDevice, T>;
    131 
    132 TF_CALL_ALL_TYPES(DEFINE_FILL_CPU);
    133 DEFINE_FILL_CPU(quint8);
    134 DEFINE_FILL_CPU(quint16);
    135 #undef DEFINE_FILL_CPU
    136 
    137 #ifdef TENSORFLOW_USE_SYCL
    138 template <typename T>
    139 struct FillFunctor<Eigen::SyclDevice, T> {
    140   void operator()(const Eigen::SyclDevice& d, typename TTypes<T>::Flat out,
    141                   typename TTypes<T>::ConstScalar in) {
    142 #if !defined(EIGEN_HAS_INDEX_LIST)
    143     Eigen::array<int, 1> rank1{1};
    144 #else
    145     Eigen::IndexList<Eigen::type2index<1> > rank1;
    146 #endif
    147     const int size = out.dimension(0);
    148     Eigen::array<int, 1> broadcast_dims{size};
    149 
    150     To32Bit(out).device(d) = in.reshape(rank1).broadcast(broadcast_dims);
    151   }
    152 };
    153 
    154 #define DEFINE_FILL_SYCL(T) template struct FillFunctor<Eigen::SyclDevice, T>;
    155 DEFINE_FILL_SYCL(float);
    156 DEFINE_FILL_SYCL(double);
    157 TF_CALL_INTEGRAL_TYPES(DEFINE_FILL_SYCL)
    158 #undef DEFINE_FILL_SYCL
    159 #endif  // TENSORFLOW_USE_SYCL
    160 
    161 }  // namespace functor
    162 }  // namespace tensorflow
    163