// tensorflow/core/kernels/cwise_ops_gradients.h
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_KERNELS_CWISE_OPS_GRADIENTS_H_
     17 #define TENSORFLOW_KERNELS_CWISE_OPS_GRADIENTS_H_
     18 
     19 #define EIGEN_USE_THREADS
     20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     21 #include "tensorflow/core/kernels/cwise_ops.h"
     22 
     23 namespace Eigen {
     24 namespace internal {
     25 
     26 // Gradient for the tanh function
     27 template <typename T>
     28 struct scalar_tanh_gradient_op {
     29   EIGEN_EMPTY_STRUCT_CTOR(scalar_tanh_gradient_op)
     30   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
     31   operator()(const T& output, const T& output_gradient) const {
     32     return output_gradient * (T(1) - output * output);
     33   }
     34   template <typename Packet>
     35   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
     36   packetOp(const Packet& output, const Packet& output_gradient) const {
     37     return pmul(output_gradient,
     38                 psub(pset1<Packet>(T(1)), pmul(output, output)));
     39   }
     40 };
     41 template <typename T>
     42 struct functor_traits<scalar_tanh_gradient_op<T>> {
     43   enum {
     44     Cost = NumTraits<T>::AddCost + 2 * NumTraits<T>::MulCost,
     45     PacketAccess = packet_traits<T>::HasSub && packet_traits<T>::HasMul,
     46   };
     47 };
     48 
     49 // Gradient for the sigmoid function
     50 template <typename T>
     51 struct scalar_sigmoid_gradient_op {
     52   EIGEN_EMPTY_STRUCT_CTOR(scalar_sigmoid_gradient_op)
     53   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
     54   operator()(const T& output, const T& output_gradient) const {
     55     return output_gradient * output * (T(1) - output);
     56   }
     57   template <typename Packet>
     58   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
     59   packetOp(const Packet& output, const Packet& output_gradient) const {
     60     return pmul(output_gradient,
     61                 pmul(output, psub(pset1<Packet>(T(1)), output)));
     62   }
     63 };
     64 template <typename T>
     65 struct functor_traits<scalar_sigmoid_gradient_op<T>> {
     66   enum {
     67     Cost = NumTraits<T>::AddCost + 2 * NumTraits<T>::MulCost,
     68     PacketAccess = packet_traits<T>::HasSub && packet_traits<T>::HasMul,
     69   };
     70 };
     71 
     72 // Gradient for the inverse function
     73 template <typename T>
     74 struct scalar_inverse_gradient_op {
     75   EIGEN_EMPTY_STRUCT_CTOR(scalar_inverse_gradient_op)
     76   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
     77   operator()(const T& output, const T& output_gradient) const {
     78     const T out_conj = numext::conj(output);
     79     return -output_gradient * out_conj * out_conj;
     80   }
     81   template <typename Packet>
     82   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
     83   packetOp(const Packet& output, const Packet& output_gradient) const {
     84     const Packet out_conj = pconj(output);
     85     return pnegate(pmul(output_gradient, pmul(out_conj, out_conj)));
     86   }
     87 };
     88 template <typename T>
     89 struct functor_traits<scalar_inverse_gradient_op<T>> {
     90   enum {
     91     Cost = NumTraits<T>::AddCost + 2 * NumTraits<T>::MulCost,
     92     PacketAccess = packet_traits<T>::HasMul,
     93   };
     94 };
     95 
     96 // Gradient for the sqrt function
     97 template <typename T>
     98 struct scalar_sqrt_gradient_op {
     99   EIGEN_EMPTY_STRUCT_CTOR(scalar_sqrt_gradient_op)
    100   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
    101   operator()(const T& output, const T& output_gradient) const {
    102     const T out_conj = numext::conj(output);
    103     return static_cast<T>(0.5) * output_gradient / out_conj;
    104   }
    105   template <typename Packet>
    106   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
    107   packetOp(const Packet& output, const Packet& output_gradient) const {
    108     const Packet const_half = pset1<Packet>(static_cast<T>(0.5));
    109     const Packet out_conj = pconj(output);
    110     return pdiv(pmul(const_half, output_gradient), out_conj);
    111   }
    112 };
    113 template <typename T>
    114 struct functor_traits<scalar_sqrt_gradient_op<T>> {
    115   enum {
    116     PacketAccess = packet_traits<T>::HasMul & packet_traits<T>::HasDiv,
    117     Cost = NumTraits<T>::MulCost + scalar_div_cost<T, PacketAccess>::value,
    118   };
    119 };
    120 
    121 // Gradient for the rsqrt function
    122 template <typename T>
    123 struct scalar_rsqrt_gradient_op {
    124   EIGEN_EMPTY_STRUCT_CTOR(scalar_rsqrt_gradient_op)
    125   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
    126   operator()(const T& output, const T& output_gradient) const {
    127     const T out_conj = numext::conj(output);
    128     return static_cast<T>(-0.5) * (output_gradient * out_conj) *
    129            (out_conj * out_conj);
    130   }
    131   template <typename Packet>
    132   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
    133   packetOp(const Packet& output, const Packet& output_gradient) const {
    134     const Packet const_half = pset1<Packet>(static_cast<T>(-0.5));
    135     const Packet out_conj = pconj(output);
    136     return pmul(const_half, pmul(pmul(output_gradient, out_conj),
    137                                  pmul(out_conj, out_conj)));
    138   }
    139 };
    140 template <typename T>
    141 struct functor_traits<scalar_rsqrt_gradient_op<T>> {
    142   enum {
    143     Cost = 4 * NumTraits<T>::MulCost,
    144     PacketAccess = packet_traits<T>::HasMul,
    145   };
    146 };
    147 
    148 }  // end namespace internal
    149 }  // end namespace Eigen
    150 
    151 namespace tensorflow {
    152 
    153 namespace functor {
    154 
// Generic element-wise binary functor dispatcher used by the gradient
// kernels: computes out = Functor(in0, in1) on device `d`.  The primary
// template is declaration-only; usable implementations come from
// per-device partial specializations (the CPU one appears below; others
// are presumably defined elsewhere in the project).
template <typename Device, typename Functor>
struct SimpleBinaryFunctor {
  void operator()(const Device& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1);
};
    161 
// Partial specialization of SimpleBinaryFunctor for CPU devices.
typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename Functor>
struct SimpleBinaryFunctor<CPUDevice, Functor> {
  // Evaluates out = Functor::func()(in0, in1) element-wise; assigning
  // through device(d) lets Eigen evaluate the expression on the thread
  // pool device.
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1) {
    out.device(d) = in0.binaryExpr(in1, typename Functor::func());
  }
};
    173 
#ifdef TENSORFLOW_USE_SYCL
// Partial specialization of SimpleBinaryFunctor for SYCL devices; same
// element-wise evaluation as the CPU path, targeting a SYCL device.
typedef Eigen::SyclDevice SYCLDevice;
template <typename Functor>
struct SimpleBinaryFunctor<SYCLDevice, Functor> {
  // Evaluates out = Functor::func()(in0, in1) element-wise on device `d`.
  void operator()(const SYCLDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1) {
    out.device(d) = in0.binaryExpr(in1, typename Functor::func());
  }
};

#endif  // TENSORFLOW_USE_SYCL
    187 
// Functor tags used by the cwise gradient kernels.  Each wraps one of the
// Eigen scalar gradient ops above via the `base` helper (from cwise_ops.h,
// included at the top of this file), which is presumably what provides the
// tin_type/tout_type typedefs SimpleBinaryFunctor expects — confirm in
// cwise_ops.h.
template <typename T>
struct tanh_grad : base<T, Eigen::internal::scalar_tanh_gradient_op<T>> {};

template <typename T>
struct sigmoid_grad : base<T, Eigen::internal::scalar_sigmoid_gradient_op<T>> {
};

template <typename T>
struct inverse_grad : base<T, Eigen::internal::scalar_inverse_gradient_op<T>> {
};

template <typename T>
struct sqrt_grad : base<T, Eigen::internal::scalar_sqrt_gradient_op<T>> {};

template <typename T>
struct rsqrt_grad : base<T, Eigen::internal::scalar_rsqrt_gradient_op<T>> {};
    204 
    205 }  // end namespace functor
    206 
    207 }  // end namespace tensorflow
    208 #endif  // TENSORFLOW_KERNELS_CWISE_OPS_GRADIENTS_H_
    209