/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_KERNELS_RELU_OP_FUNCTOR_H_
#define TENSORFLOW_KERNELS_RELU_OP_FUNCTOR_H_
// Functor definitions for ReluOp and ReluGradOp (and the related Relu6, Elu,
// and Selu functors); must be compilable by nvcc.

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"

namespace tensorflow {
namespace functor {

// Functor used by ReluOp to do the computations.
template <typename Device, typename T>
struct Relu {
  // Computes Relu activation.
  //
  // features: any shape.
  // activations: same shape as "features".
  void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
                  typename TTypes<T>::Tensor activations) {
    activations.device(d) = features.cwiseMax(static_cast<T>(0));
  }
};
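
// Illustrative sketch (editorial addition, not used by any kernel): how the
// Relu functor above is typically applied once a kernel has a device and a
// pre-allocated output.  In real op kernels the device comes from
// OpKernelContext::eigen_device<Device>() and the tensor maps from
// Tensor::flat<T>().
template <typename Device, typename T>
void RunReluExample(const Device& d, typename TTypes<T>::ConstTensor features,
                    typename TTypes<T>::Tensor activations) {
  // Evaluate max(features, 0) elementwise on the given device.
  Relu<Device, T>()(d, features, activations);
}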

// Functor used by ReluGradOp to do the computations.
template <typename Device, typename T>
struct ReluGrad {
  // Computes ReluGrad backprops.
  //
  // gradients: gradients backpropagated to the Relu op.
  // features: either the inputs that were passed to the Relu op, or its
  //           outputs (using either one yields the same result here).
  // backprops: gradients to backpropagate to the Relu inputs.
  void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
                  typename TTypes<T>::ConstTensor features,
                  typename TTypes<T>::Tensor backprops) {
    // NOTE: When the feature value is exactly zero, we do not propagate the
    // associated gradient.  Because relu(0) == 0, the mask is the same
    // whether "features" holds the Relu's inputs or its outputs.
    backprops.device(d) =
        gradients * (features > static_cast<T>(0)).template cast<T>();
  }
};
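
// Worked example (illustrative only, editorial addition): with
// features = {-1, 0, 2} and incoming gradients = {5, 5, 5}, the mask
// (features > 0) is {0, 0, 1}, so the resulting backprops are {0, 0, 5}.
// The same mask is obtained if "features" instead holds the Relu outputs
// {0, 0, 2}, which is why either tensor may be passed in.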

// Functor used by Relu6Op to do the computations.
template <typename Device, typename T>
struct Relu6 {
  // Computes Relu6 activation.
  //
  // features: any shape.
  // activations: same shape as "features".
  void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
                  typename TTypes<T>::Tensor activations) {
    activations.device(d) =
        features.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(6));
  }
};

// Functor used by Relu6GradOp to do the computations.
template <typename Device, typename T>
struct Relu6Grad {
  // Computes Relu6Grad backprops.
  //
  // gradients: gradients backpropagated to the Relu6 op.
  // features: inputs that were passed to the Relu6 op, or its outputs.
  // backprops: gradients to backpropagate to the Relu6 inputs.
  void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
                  typename TTypes<T>::ConstTensor features,
                  typename TTypes<T>::Tensor backprops) {
    // NOTE: When the activation is exactly zero or six, we make sure not to
    // propagate the associated gradient value. This allows "features" to be
    // either the input or the output of the relu6.
    backprops.device(d) = gradients * ((features > static_cast<T>(0)) *
                                       (features < static_cast<T>(6)))
                                          .template cast<T>();
  }
};
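
// Worked example (illustrative only, editorial addition): with
// features = {-1, 3, 6, 7} and incoming gradients all equal to 1, the mask
// (0 < features && features < 6) is {0, 1, 0, 0}, so the backprops are
// {0, 1, 0, 0}.  Gradients are blocked both below zero and at or above six,
// matching the flat regions of Relu6.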

// Functor used by EluOp to do the computations.
template <typename Device, typename T>
struct Elu {
  // Computes Elu activation.
  //
  // features: any shape.
  // activations: same shape as "features".
  void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
                  typename TTypes<T>::Tensor activations) {
    activations.device(d) =
        (features < static_cast<T>(0))
            .select(features.exp() - features.constant(static_cast<T>(1)),
                    features);
  }
};

// Functor used by EluGradOp to do the computations.
template <typename Device, typename T>
struct EluGrad {
  // Computes EluGrad backprops.
  //
  // gradients: gradients backpropagated to the Elu op.
  // activations: outputs of the Elu op.
  // backprops: gradients to backpropagate to the Elu inputs.
  void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
                  typename TTypes<T>::ConstTensor activations,
                  typename TTypes<T>::Tensor backprops) {
    backprops.device(d) =
        (activations < static_cast<T>(0))
            .select((activations + static_cast<T>(1)) * gradients, gradients);
  }
};
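
// Note on the EluGrad formula (editorial addition): for x < 0 the Elu output
// is a = exp(x) - 1, whose derivative is exp(x) = a + 1.  That is why the
// negative branch can be reconstructed from the activations alone as
// (activations + 1) * gradients, while the gradient passes through unchanged
// where the activation is non-negative.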

// Functor used by SeluOp to do the computations.
template <typename Device, typename T>
struct Selu {
  // Computes Selu activation.
  //
  // features: any shape.
  // activations: same shape as "features".
  void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
                  typename TTypes<T>::Tensor activations) {
    const auto scale = static_cast<T>(1.0507009873554804934193349852946);
    const auto scale_alpha = static_cast<T>(1.7580993408473768599402175208123);
    const auto one = static_cast<T>(1);
    const auto zero = static_cast<T>(0);
    activations.device(d) =
        (features < zero)
            .select(scale_alpha * (features.exp() - features.constant(one)),
                    scale * features);
  }
};
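
// Note on the constants (editorial addition): "scale" is the lambda constant
// from the SELU paper ("Self-Normalizing Neural Networks", Klambauer et al.,
// 2017), and "scale_alpha" is lambda * alpha with alpha ~= 1.6732632423543772,
// so the negative branch computes lambda * alpha * (exp(x) - 1) with the
// product folded into a single constant.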

// Functor used by SeluGradOp to do the computations.
template <typename Device, typename T>
struct SeluGrad {
  // Computes SeluGrad backprops.
  //
  // gradients: gradients backpropagated to the Selu op.
  // activations: outputs of the Selu op.
  // backprops: gradients to backpropagate to the Selu inputs.
  void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
                  typename TTypes<T>::ConstTensor activations,
                  typename TTypes<T>::Tensor backprops) {
    const auto scale = static_cast<T>(1.0507009873554804934193349852946);
    const auto scale_alpha = static_cast<T>(1.7580993408473768599402175208123);
    backprops.device(d) =
        (activations < static_cast<T>(0))
            .select(gradients * (activations + scale_alpha), gradients * scale);
  }
};
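
// Note on the SeluGrad formula (editorial addition): for x < 0 the Selu
// output is a = scale_alpha * (exp(x) - 1), whose derivative with respect to
// x is scale_alpha * exp(x) = a + scale_alpha.  For x >= 0 the derivative is
// simply "scale".  Both branches are recoverable from the op's outputs, which
// is why this functor takes activations rather than features.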

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_KERNELS_RELU_OP_FUNCTOR_H_