Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #if GOOGLE_CUDA
     17 
     18 #define EIGEN_USE_GPU
     19 
     20 #include "tensorflow/core/framework/register_types.h"
     21 #include "tensorflow/core/framework/tensor_types.h"
     22 #include "tensorflow/core/kernels/fill_functor.h"
     23 #include "tensorflow/core/platform/types.h"
     24 
     25 namespace Eigen {
     26 namespace internal {
     27 
     28 template <typename T>
     29 struct scalar_const_op {
     30   typedef typename packet_traits<T>::type Packet;
     31 
     32   const T* val;
     33 
     34   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
     35   scalar_const_op(const scalar_const_op& x)
     36       : val(x.val) {}
     37 
     38   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_const_op(const T* v) : val(v) {}
     39 
     40   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()() const {
     41     return *val;
     42   }
     43 
     44   template <typename PacketType = Packet>
     45   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const PacketType packetOp() const {
     46     return internal::pset1<PacketType>(*val);
     47   }
     48 };
     49 
     50 template <typename T>
     51 struct functor_traits<scalar_const_op<T> > {
     52   enum {
     53     Cost = 1,
     54     PacketAccess = packet_traits<T>::Vectorizable,
     55     IsRepeatable = true
     56   };
     57 };
     58 
     59 }  // end namespace internal
     60 }  // end namespace Eigen
     61 
     62 namespace tensorflow {
     63 
     64 namespace functor {
     65 
     66 typedef Eigen::GpuDevice GPUDevice;
     67 
     68 // Partial specialization FillFunctor<Device=GPUDevice, T>
     69 template <typename T>
     70 struct FillFunctor<GPUDevice, T> {
     71   void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
     72                   typename TTypes<T>::ConstScalar in) {
     73     Eigen::internal::scalar_const_op<T> f(in.data());
     74     To32Bit(out).device(d) = To32Bit(out).nullaryExpr(f);
     75   }
     76 };
     77 
     78 #define DEFINE_FILL_GPU(T) template struct FillFunctor<GPUDevice, T>;
     79 TF_CALL_REAL_NUMBER_TYPES(DEFINE_FILL_GPU);
     80 TF_CALL_bool(DEFINE_FILL_GPU);
     81 #undef DEFINE_FILL_GPU
     82 
     83 // Partial specialization of FillFunctor<Device=GPUDevice, T>.
     84 template <typename T>
     85 struct SetZeroFunctor<GPUDevice, T> {
     86   void operator()(const GPUDevice& d, typename TTypes<T>::Flat out) {
     87     To32Bit(out).device(d) = To32Bit(out).constant(T(0));
     88   }
     89 };
     90 
     91 #define DEFINE_SETZERO_GPU(T) template struct SetZeroFunctor<GPUDevice, T>;
     92 TF_CALL_NUMBER_TYPES(DEFINE_SETZERO_GPU);
     93 TF_CALL_bool(DEFINE_SETZERO_GPU);
     94 #undef DEFINE_SETZERO_GPU
     95 
     96 // Partial specialization of FillFunctor<Device=GPUDevice, T>.
     97 template <typename T>
     98 struct SetOneFunctor<GPUDevice, T> {
     99   void operator()(const GPUDevice& d, typename TTypes<T>::Flat out) {
    100     To32Bit(out).device(d) = To32Bit(out).constant(T(1));
    101   }
    102 };
    103 
    104 #define DEFINE_SETONE_GPU(T) template struct SetOneFunctor<GPUDevice, T>;
    105 TF_CALL_NUMBER_TYPES(DEFINE_SETONE_GPU);
    106 TF_CALL_bool(DEFINE_SETONE_GPU);
    107 #undef DEFINE_SETONE_GPU
    108 
    109 }  // end namespace functor
    110 }  // end namespace tensorflow
    111 
    112 #endif  // GOOGLE_CUDA
    113