/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#ifndef TENSORFLOW_KERNELS_RELU_OP_H_
#define TENSORFLOW_KERNELS_RELU_OP_H_

#define EIGEN_USE_THREADS

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/relu_op_functor.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

template <typename Device, typename T>
class ReluOp : public UnaryElementWiseOp<T, ReluOp<Device, T>> {
 public:
  using UnaryElementWiseOp<T, ReluOp<Device, T>>::UnaryElementWiseOp;

  void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
    functor::Relu<Device, T> functor;
    functor(context->eigen_device<Device>(), input.flat<T>(),
            output->flat<T>());
  }
};
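
// The Relu functor invoked above computes activations = max(features, 0)
// elementwise. A minimal sketch of the Eigen-based device functor (the real
// definition lives in relu_op_functor.h; names here are illustrative):
//
//   template <typename Device, typename T>
//   struct Relu {
//     void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
//                     typename TTypes<T>::Tensor activations) {
//       activations.device(d) = features.cwiseMax(static_cast<T>(0));
//     }
//   };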

// Out-of-line check to save code space (we have this code once, rather than
// once for every NDIMS * NumTypes * Num_different_relu_variants functions).
// OP_REQUIRES records an InvalidArgument status on the context and returns
// from the void helper; ValidateSameSize then reports whether the context
// status is still OK so callers can bail out early.
struct ReluHelpers {
  static void ValidateSameSizeHelper(OpKernelContext* context, const Tensor& g,
                                     const Tensor& a) {
    OP_REQUIRES(context, a.IsSameSize(g),
                errors::InvalidArgument("g and a must be the same size"));
  }
  static bool ValidateSameSize(OpKernelContext* context, const Tensor& g,
                               const Tensor& a) {
    ValidateSameSizeHelper(context, g, a);
    return context->status().ok();
  }
};

template <typename Device, typename T>
class ReluGradOp : public BinaryElementWiseOp<T, ReluGradOp<Device, T>> {
 public:
  using BinaryElementWiseOp<T, ReluGradOp<Device, T>>::BinaryElementWiseOp;

  void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
                         const Tensor& a, Tensor* output);

  // INPUTS:
  //   g (gradients): backpropagated gradients
  //   a (inputs): either the inputs that were passed to ReluOp(), or its
  //               outputs (using either one yields the same result here).
  // OUTPUT:
  //   gradients to backprop
  template <int NDIMS>
  void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
               Tensor* output) {
    OperateNoTemplate(context, g, a, output);
  }
};

template <typename Device, typename T>
void ReluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
                                              const Tensor& g, const Tensor& a,
                                              Tensor* output) {
  if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
  functor::ReluGrad<Device, T> functor;
  functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
          output->flat<T>());
}
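
// The ReluGrad functor propagates the incoming gradient only where the
// activation is positive: output = g * (a > 0). Because relu(x) > 0 exactly
// when x > 0, the mask is the same whether `a` holds the op's inputs or its
// outputs, which is why either may be fed in above. A sketch of the
// elementwise expression (illustrative, not the exact definition in
// relu_op_functor.h):
//
//   backprops.device(d) =
//       gradients * (features > static_cast<T>(0)).template cast<T>();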

template <typename Device, typename T>
class Relu6Op : public UnaryElementWiseOp<T, Relu6Op<Device, T>> {
 public:
  using UnaryElementWiseOp<T, Relu6Op<Device, T>>::UnaryElementWiseOp;

  void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
    functor::Relu6<Device, T> functor;
    functor(context->eigen_device<Device>(), input.flat<T>(),
            output->flat<T>());
  }
};
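
// The Relu6 functor clips the activation to [0, 6]:
// activations = min(max(features, 0), 6). A sketch of the elementwise
// expression (illustrative):
//
//   activations.device(d) = features.cwiseMax(static_cast<T>(0))
//                               .cwiseMin(static_cast<T>(6));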

template <typename Device, typename T>
class Relu6GradOp : public BinaryElementWiseOp<T, Relu6GradOp<Device, T>> {
 public:
  using BinaryElementWiseOp<T, Relu6GradOp<Device, T>>::BinaryElementWiseOp;

  void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
                         const Tensor& a, Tensor* output);

  // INPUTS:
  //   g (gradients): backpropagated gradients
  //   a (inputs): inputs that were passed to Relu6Op()
  // OUTPUT:
  //   gradients to backprop
  template <int NDIMS>
  void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
               Tensor* output) {
    OperateNoTemplate(context, g, a, output);
  }
};

template <typename Device, typename T>
void Relu6GradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
                                               const Tensor& g, const Tensor& a,
                                               Tensor* output) {
  if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
  functor::Relu6Grad<Device, T> functor;
  functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
          output->flat<T>());
}
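
// Relu6 is flat outside (0, 6), so Relu6Grad passes the incoming gradient
// through only inside that interval: output = g where 0 < a < 6, and 0
// elsewhere (per the comment above, `a` here is the op's original inputs).
// Sketch of the elementwise expression (illustrative):
//
//   backprops.device(d) = gradients *
//       ((features > static_cast<T>(0)) && (features < static_cast<T>(6)))
//           .template cast<T>();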

template <typename Device, typename T>
class EluOp : public UnaryElementWiseOp<T, EluOp<Device, T>> {
 public:
  using UnaryElementWiseOp<T, EluOp<Device, T>>::UnaryElementWiseOp;

  void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
    functor::Elu<Device, T> functor;
    functor(context->eigen_device<Device>(), input.flat<T>(),
            output->flat<T>());
  }
};
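
// The Elu functor computes the exponential linear unit:
//   activations = features            if features > 0
//               = exp(features) - 1   otherwise.
// Sketch of the elementwise expression (illustrative):
//
//   activations.device(d) =
//       (features < static_cast<T>(0))
//           .select(features.exp() - features.constant(static_cast<T>(1)),
//                   features);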

template <typename Device, typename T>
class EluGradOp : public BinaryElementWiseOp<T, EluGradOp<Device, T>> {
 public:
  using BinaryElementWiseOp<T, EluGradOp<Device, T>>::BinaryElementWiseOp;

  void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
                         const Tensor& a, Tensor* output);

  // INPUTS:
  //   g (gradients): backpropagated gradients
  //   a (outputs): outputs of the EluOp()
  // OUTPUT:
  //   gradients to backprop
  template <int NDIMS>
  void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
               Tensor* output) {
    OperateNoTemplate(context, g, a, output);
  }
};

template <typename Device, typename T>
void EluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
                                             const Tensor& g, const Tensor& a,
                                             Tensor* output) {
  if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
  functor::EluGrad<Device, T> functor;
  functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
          output->flat<T>());
}
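
// EluGrad is computed from the op's *outputs*: for x > 0 the derivative is 1,
// and for x <= 0 it is exp(x) = elu(x) + 1, so
//   output = g            where a > 0
//          = g * (a + 1)  otherwise.
// Sketch of the elementwise expression (illustrative):
//
//   backprops.device(d) =
//       (activations < static_cast<T>(0))
//           .select((activations + static_cast<T>(1)) * gradients, gradients);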

template <typename Device, typename T>
class SeluOp : public UnaryElementWiseOp<T, SeluOp<Device, T>> {
 public:
  using UnaryElementWiseOp<T, SeluOp<Device, T>>::UnaryElementWiseOp;

  void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
    functor::Selu<Device, T> functor;
    functor(context->eigen_device<Device>(), input.flat<T>(),
            output->flat<T>());
  }
};
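
// The Selu functor computes the scaled exponential linear unit
// (Klambauer et al., 2017):
//   activations = scale * features                     if features > 0
//               = scale * alpha * (exp(features) - 1)  otherwise,
// with the self-normalizing constants alpha ~= 1.6732632 and
// scale ~= 1.0507010.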

template <typename Device, typename T>
class SeluGradOp : public BinaryElementWiseOp<T, SeluGradOp<Device, T>> {
 public:
  using BinaryElementWiseOp<T, SeluGradOp<Device, T>>::BinaryElementWiseOp;

  void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
                         const Tensor& a, Tensor* output);

  // INPUTS:
  //   g (gradients): backpropagated gradients
  //   a (outputs): outputs of the SeluOp()
  // OUTPUT:
  //   gradients to backprop
  template <int NDIMS>
  void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
               Tensor* output) {
    OperateNoTemplate(context, g, a, output);
  }
};

template <typename Device, typename T>
void SeluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
                                              const Tensor& g, const Tensor& a,
                                              Tensor* output) {
  if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
  functor::SeluGrad<Device, T> functor;
  functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
          output->flat<T>());
}
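
// Like EluGrad, SeluGrad is computed from the op's *outputs*: for x > 0 the
// derivative is scale, and for x <= 0 it is scale * alpha * exp(x), which
// equals a + scale * alpha, so
//   output = g * scale                where a > 0
//          = g * (a + scale * alpha)  otherwise.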

}  // namespace tensorflow

#undef EIGEN_USE_THREADS

#endif  // TENSORFLOW_KERNELS_RELU_OP_H_