1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_KERNELS_CWISE_OPS_GRADIENTS_H_ 17 #define TENSORFLOW_KERNELS_CWISE_OPS_GRADIENTS_H_ 18 19 #define EIGEN_USE_THREADS 20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 21 #include "tensorflow/core/kernels/cwise_ops.h" 22 23 namespace Eigen { 24 namespace internal { 25 26 // Gradient for the tanh function 27 template <typename T> 28 struct scalar_tanh_gradient_op { 29 EIGEN_EMPTY_STRUCT_CTOR(scalar_tanh_gradient_op) 30 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T 31 operator()(const T& output, const T& output_gradient) const { 32 return output_gradient * (T(1) - output * output); 33 } 34 template <typename Packet> 35 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet 36 packetOp(const Packet& output, const Packet& output_gradient) const { 37 return pmul(output_gradient, 38 psub(pset1<Packet>(T(1)), pmul(output, output))); 39 } 40 }; 41 template <typename T> 42 struct functor_traits<scalar_tanh_gradient_op<T>> { 43 enum { 44 Cost = NumTraits<T>::AddCost + 2 * NumTraits<T>::MulCost, 45 PacketAccess = packet_traits<T>::HasSub && packet_traits<T>::HasMul, 46 }; 47 }; 48 49 // Gradient for the sigmoid function 50 template <typename T> 51 struct scalar_sigmoid_gradient_op { 52 EIGEN_EMPTY_STRUCT_CTOR(scalar_sigmoid_gradient_op) 53 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T 54 operator()(const T& output, const T& output_gradient) const { 55 return output_gradient * output * (T(1) - output); 56 } 57 template <typename Packet> 58 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet 59 packetOp(const Packet& output, const Packet& output_gradient) const { 60 return pmul(output_gradient, 61 pmul(output, psub(pset1<Packet>(T(1)), output))); 62 } 63 }; 64 template <typename T> 65 struct functor_traits<scalar_sigmoid_gradient_op<T>> { 66 enum { 67 Cost = NumTraits<T>::AddCost + 2 * NumTraits<T>::MulCost, 68 PacketAccess = packet_traits<T>::HasSub && packet_traits<T>::HasMul, 69 }; 70 }; 71 72 // Gradient for the inverse function 73 template <typename T> 74 struct scalar_inverse_gradient_op { 75 EIGEN_EMPTY_STRUCT_CTOR(scalar_inverse_gradient_op) 76 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T 77 operator()(const T& output, const T& output_gradient) const { 78 const T out_conj = numext::conj(output); 79 return -output_gradient * out_conj * out_conj; 80 } 81 template <typename Packet> 82 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet 83 packetOp(const Packet& output, const Packet& output_gradient) const { 84 const Packet out_conj = pconj(output); 85 return pnegate(pmul(output_gradient, pmul(out_conj, out_conj))); 86 } 87 }; 88 template <typename T> 89 struct functor_traits<scalar_inverse_gradient_op<T>> { 90 enum { 91 Cost = NumTraits<T>::AddCost + 2 * NumTraits<T>::MulCost, 92 PacketAccess = packet_traits<T>::HasMul, 93 }; 94 }; 95 96 // Gradient for the sqrt function 97 template <typename T> 98 struct scalar_sqrt_gradient_op { 99 EIGEN_EMPTY_STRUCT_CTOR(scalar_sqrt_gradient_op) 100 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T 101 operator()(const T& output, const T& output_gradient) const { 102 const T out_conj = numext::conj(output); 103 return static_cast<T>(0.5) * output_gradient / out_conj; 104 } 105 template <typename Packet> 106 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet 107 packetOp(const Packet& output, const Packet& output_gradient) const { 108 const Packet const_half = pset1<Packet>(static_cast<T>(0.5)); 109 const Packet out_conj = pconj(output); 110 return pdiv(pmul(const_half, output_gradient), out_conj); 111 } 112 }; 113 template <typename T> 114 struct functor_traits<scalar_sqrt_gradient_op<T>> { 115 enum { 116 PacketAccess = packet_traits<T>::HasMul & packet_traits<T>::HasDiv, 117 Cost = NumTraits<T>::MulCost + scalar_div_cost<T, PacketAccess>::value, 118 }; 119 }; 120 121 // Gradient for the rsqrt function 122 template <typename T> 123 struct scalar_rsqrt_gradient_op { 124 EIGEN_EMPTY_STRUCT_CTOR(scalar_rsqrt_gradient_op) 125 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T 126 operator()(const T& output, const T& output_gradient) const { 127 const T out_conj = numext::conj(output); 128 return static_cast<T>(-0.5) * (output_gradient * out_conj) * 129 (out_conj * out_conj); 130 } 131 template <typename Packet> 132 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet 133 packetOp(const Packet& output, const Packet& output_gradient) const { 134 const Packet const_half = pset1<Packet>(static_cast<T>(-0.5)); 135 const Packet out_conj = pconj(output); 136 return pmul(const_half, pmul(pmul(output_gradient, out_conj), 137 pmul(out_conj, out_conj))); 138 } 139 }; 140 template <typename T> 141 struct functor_traits<scalar_rsqrt_gradient_op<T>> { 142 enum { 143 Cost = 4 * NumTraits<T>::MulCost, 144 PacketAccess = packet_traits<T>::HasMul, 145 }; 146 }; 147 148 } // end namespace internal 149 } // end namespace Eigen 150 151 namespace tensorflow { 152 153 namespace functor { 154 155 template <typename Device, typename Functor> 156 struct SimpleBinaryFunctor { 157 void operator()(const Device& d, typename Functor::tout_type out, 158 typename Functor::tin_type in0, 159 typename Functor::tin_type in1); 160 }; 161 162 // Partial specialization of BinaryFunctor for CPU devices 163 typedef Eigen::ThreadPoolDevice CPUDevice; 164 165 template <typename Functor> 166 struct SimpleBinaryFunctor<CPUDevice, Functor> { 167 void operator()(const CPUDevice& d, typename Functor::tout_type out, 168 typename Functor::tin_type in0, 169 typename Functor::tin_type in1) { 170 out.device(d) = in0.binaryExpr(in1, typename Functor::func()); 171 } 172 }; 173 174 #ifdef TENSORFLOW_USE_SYCL 175 // Partial specialization of BinaryFunctor for SYCL devices 176 typedef Eigen::SyclDevice SYCLDevice; 177 template <typename Functor> 178 struct SimpleBinaryFunctor<SYCLDevice, Functor> { 179 void operator()(const SYCLDevice& d, typename Functor::tout_type out, 180 typename Functor::tin_type in0, 181 typename Functor::tin_type in1) { 182 out.device(d) = in0.binaryExpr(in1, typename Functor::func()); 183 } 184 }; 185 186 #endif // TENSORFLOW_USE_SYCL 187 188 template <typename T> 189 struct tanh_grad : base<T, Eigen::internal::scalar_tanh_gradient_op<T>> {}; 190 191 template <typename T> 192 struct sigmoid_grad : base<T, Eigen::internal::scalar_sigmoid_gradient_op<T>> { 193 }; 194 195 template <typename T> 196 struct inverse_grad : base<T, Eigen::internal::scalar_inverse_gradient_op<T>> { 197 }; 198 199 template <typename T> 200 struct sqrt_grad : base<T, Eigen::internal::scalar_sqrt_gradient_op<T>> {}; 201 202 template <typename T> 203 struct rsqrt_grad : base<T, Eigen::internal::scalar_rsqrt_gradient_op<T>> {}; 204 205 } // end namespace functor 206 207 } // end namespace tensorflow 208 #endif // TENSORFLOW_KERNELS_CWISE_OPS_GRADIENTS_H_ 209