1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/nn_ops.cc. 17 18 #ifndef TENSORFLOW_KERNELS_RELU_OP_H_ 19 #define TENSORFLOW_KERNELS_RELU_OP_H_ 20 21 #define EIGEN_USE_THREADS 22 23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 24 #include "tensorflow/core/framework/numeric_op.h" 25 #include "tensorflow/core/framework/op_kernel.h" 26 #include "tensorflow/core/framework/register_types.h" 27 #include "tensorflow/core/framework/tensor.h" 28 #include "tensorflow/core/kernels/relu_op_functor.h" 29 #include "tensorflow/core/lib/core/errors.h" 30 31 namespace tensorflow { 32 33 template <typename Device, typename T> 34 class ReluOp : public UnaryElementWiseOp<T, ReluOp<Device, T>> { 35 public: 36 using UnaryElementWiseOp<T, ReluOp<Device, T>>::UnaryElementWiseOp; 37 38 void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { 39 functor::Relu<Device, T> functor; 40 functor(context->eigen_device<Device>(), input.flat<T>(), 41 output->flat<T>()); 42 } 43 }; 44 45 // Out of line check to save code space (we have this code once, rather 46 // than once for every NDIMS * NumTypes * Num_different_relu_variants 47 // functions. 48 struct ReluHelpers { 49 static void ValidateSameSizeHelper(OpKernelContext* context, const Tensor& g, 50 const Tensor& a) { 51 OP_REQUIRES(context, a.IsSameSize(g), 52 errors::InvalidArgument("g and a must be the same size")); 53 } 54 static bool ValidateSameSize(OpKernelContext* context, const Tensor& g, 55 const Tensor& a) { 56 ValidateSameSizeHelper(context, g, a); 57 return context->status().ok(); 58 } 59 }; 60 61 template <typename Device, typename T> 62 class ReluGradOp : public BinaryElementWiseOp<T, ReluGradOp<Device, T>> { 63 public: 64 using BinaryElementWiseOp<T, ReluGradOp<Device, T>>::BinaryElementWiseOp; 65 66 void OperateNoTemplate(OpKernelContext* context, const Tensor& g, 67 const Tensor& a, Tensor* output); 68 69 // INPUTS: 70 // g (gradients): backpropagated gradients 71 // a (inputs): either the inputs that were passed to ReluOp(), or its 72 // outputs (using either one yields the same result here). 73 // OUTPUT: 74 // gradients to backprop 75 template <int NDIMS> 76 void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, 77 Tensor* output) { 78 OperateNoTemplate(context, g, a, output); 79 } 80 }; 81 82 template <typename Device, typename T> 83 void ReluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context, 84 const Tensor& g, const Tensor& a, 85 Tensor* output) { 86 if (!ReluHelpers::ValidateSameSize(context, g, a)) return; 87 functor::ReluGrad<Device, T> functor; 88 functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(), 89 output->flat<T>()); 90 } 91 92 template <typename Device, typename T> 93 class Relu6Op : public UnaryElementWiseOp<T, Relu6Op<Device, T>> { 94 public: 95 using UnaryElementWiseOp<T, Relu6Op<Device, T>>::UnaryElementWiseOp; 96 97 void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { 98 functor::Relu6<Device, T> functor; 99 functor(context->eigen_device<Device>(), input.flat<T>(), 100 output->flat<T>()); 101 } 102 }; 103 104 template <typename Device, typename T> 105 class Relu6GradOp : public BinaryElementWiseOp<T, Relu6GradOp<Device, T>> { 106 public: 107 using BinaryElementWiseOp<T, Relu6GradOp<Device, T>>::BinaryElementWiseOp; 108 109 void OperateNoTemplate(OpKernelContext* context, const Tensor& g, 110 const Tensor& a, Tensor* output); 111 112 // INPUTS: 113 // g (gradients): backpropagated gradients 114 // a (inputs): inputs that were passed to Relu6Op() 115 // OUTPUT: 116 // gradients to backprop 117 template <int NDIMS> 118 void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, 119 Tensor* output) { 120 OperateNoTemplate(context, g, a, output); 121 } 122 }; 123 124 template <typename Device, typename T> 125 void Relu6GradOp<Device, T>::OperateNoTemplate(OpKernelContext* context, 126 const Tensor& g, const Tensor& a, 127 Tensor* output) { 128 if (!ReluHelpers::ValidateSameSize(context, g, a)) return; 129 functor::Relu6Grad<Device, T> functor; 130 functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(), 131 output->flat<T>()); 132 } 133 134 template <typename Device, typename T> 135 class EluOp : public UnaryElementWiseOp<T, EluOp<Device, T>> { 136 public: 137 using UnaryElementWiseOp<T, EluOp<Device, T>>::UnaryElementWiseOp; 138 139 void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { 140 functor::Elu<Device, T> functor; 141 functor(context->eigen_device<Device>(), input.flat<T>(), 142 output->flat<T>()); 143 } 144 }; 145 146 template <typename Device, typename T> 147 class EluGradOp : public BinaryElementWiseOp<T, EluGradOp<Device, T>> { 148 public: 149 using BinaryElementWiseOp<T, EluGradOp<Device, T>>::BinaryElementWiseOp; 150 151 void OperateNoTemplate(OpKernelContext* context, const Tensor& g, 152 const Tensor& a, Tensor* output); 153 154 // INPUTS: 155 // g (gradients): backpropagated gradients 156 // a (outputs): outputs of the EluOp() 157 // OUTPUT: 158 // gradients to backprop 159 template <int NDIMS> 160 void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, 161 Tensor* output) { 162 OperateNoTemplate(context, g, a, output); 163 } 164 }; 165 166 template <typename Device, typename T> 167 void EluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context, 168 const Tensor& g, const Tensor& a, 169 Tensor* output) { 170 if (!ReluHelpers::ValidateSameSize(context, g, a)) return; 171 functor::EluGrad<Device, T> functor; 172 functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(), 173 output->flat<T>()); 174 } 175 176 template <typename Device, typename T> 177 class SeluOp : public UnaryElementWiseOp<T, SeluOp<Device, T>> { 178 public: 179 using UnaryElementWiseOp<T, SeluOp<Device, T>>::UnaryElementWiseOp; 180 181 void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { 182 functor::Selu<Device, T> functor; 183 functor(context->eigen_device<Device>(), input.flat<T>(), 184 output->flat<T>()); 185 } 186 }; 187 188 template <typename Device, typename T> 189 class SeluGradOp : public BinaryElementWiseOp<T, SeluGradOp<Device, T>> { 190 public: 191 using BinaryElementWiseOp<T, SeluGradOp<Device, T>>::BinaryElementWiseOp; 192 193 void OperateNoTemplate(OpKernelContext* context, const Tensor& g, 194 const Tensor& a, Tensor* output); 195 196 // INPUTS: 197 // g (gradients): backpropagated gradients 198 // a (outputs): outputs of the SeluOp() 199 // OUTPUT: 200 // gradients to backprop 201 template <int NDIMS> 202 void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, 203 Tensor* output) { 204 OperateNoTemplate(context, g, a, output); 205 } 206 }; 207 208 template <typename Device, typename T> 209 void SeluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context, 210 const Tensor& g, const Tensor& a, 211 Tensor* output) { 212 if (!ReluHelpers::ValidateSameSize(context, g, a)) return; 213 functor::SeluGrad<Device, T> functor; 214 functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(), 215 output->flat<T>()); 216 } 217 218 } // namespace tensorflow 219 220 #undef EIGEN_USE_THREADS 221 222 #endif // TENSORFLOW_KERNELS_RELU_OP_H_ 223