/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/nn_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"

namespace tensorflow {
namespace ops {
namespace {

Status SoftmaxGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  // Softmax gradient function.
  // y = softmax(x) maps [batch, n] to [batch, n] (the shape is preserved).
  // Per batch row, the Jacobian is
  // dy/dx = [dy0/dx0   ... dy0/dxn-1  ]
  //         [  ...           ...      ]
  //         [dyn-1/dx0 ... dyn-1/dxn-1]
  //       = diag(y) - y * y^T,
  // so the chain rule dL/dx = dy/dx * dL/dy simplifies to
  // dL/dx = dL/dy * y - sum(dL/dy * y) * y
  //       = (dL/dy - sum(dL/dy * y)) * y
  auto y = op.output(0);
  auto dyy = Mul(scope, grad_inputs[0], y);
  auto sum = Reshape(scope, Sum(scope, dyy, {1}), {-1, 1});
  auto sub = Sub(scope, grad_inputs[0], sum);
  auto dx = Mul(scope, sub, y);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Softmax", SoftmaxGrad);

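// LogSoftmax gradient function.
// For y = log_softmax(x), dL/dx = dL/dy - sum(dL/dy, axis=1) * softmax(x),
// where softmax(x) is recovered from the op's output as exp(y).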
Status LogSoftmaxGrad(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  auto softmax = Exp(scope, op.output(0));
  auto sum = Sum(scope, grad_inputs[0], {1}, Sum::KeepDims(true));
  auto mul = Mul(scope, sum, softmax);
  auto dx = Sub(scope, grad_inputs[0], mul);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("LogSoftmax", LogSoftmaxGrad);

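// Relu gradient: dL/dy is passed through where the forward input x was > 0
// and zeroed elsewhere; internal::ReluGrad computes this from (dL/dy, x).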
Status ReluGradHelper(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  auto dx = internal::ReluGrad(scope, grad_inputs[0], op.input(0));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Relu", ReluGradHelper);

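// Relu6 gradient: like Relu, but dL/dy is passed through only where
// 0 < x < 6, i.e. where Relu6 is not clipping.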
Status Relu6GradHelper(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  auto dx = internal::Relu6Grad(scope, grad_inputs[0], op.input(0));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Relu6", Relu6GradHelper);

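// Elu gradient: for y = elu(x), dy/dx is 1 for x > 0 and y + 1 for x <= 0;
// internal::EluGrad evaluates this from (dL/dy, y), which is why the op's
// output rather than its input is passed in.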
Status EluGradHelper(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  auto dx = internal::EluGrad(scope, grad_inputs[0], op.output(0));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Elu", EluGradHelper);

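// Selu gradient: analogous to Elu but with Selu's fixed scale and alpha
// constants; internal::SeluGrad is likewise evaluated from (dL/dy, y).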
Status SeluGradHelper(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  auto dx = internal::SeluGrad(scope, grad_inputs[0], op.output(0));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Selu", SeluGradHelper);

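// L2Loss gradient: L2Loss(x) = sum(x^2) / 2, so dL/dx = x * dL/dy,
// with the scalar dL/dy broadcast against x.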
Status L2LossGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Mul(scope, op.input(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("L2Loss", L2LossGrad);

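// BiasAdd gradient: the gradient w.r.t. the value input is dL/dy unchanged,
// and the gradient w.r.t. the bias is dL/dy reduced over every non-channel
// dimension, which BiasAddGrad computes for the recorded data_format.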
Status BiasAddGradHelper(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  string data_format;
  BiasAddGrad::Attrs input_attrs;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.output(0).node()->attrs(), "data_format", &data_format));
  input_attrs.DataFormat(data_format);
  auto dx_1 = BiasAddGrad(scope, grad_inputs[0], input_attrs);
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  grad_outputs->push_back(dx_1);
  return scope.status();
}
REGISTER_GRADIENT_OP("BiasAdd", BiasAddGradHelper);

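// Conv2D gradient: the forward op's attributes are read back and forwarded
// to Conv2DBackpropInput (gradient w.r.t. the input, from the filter and
// dL/dy) and Conv2DBackpropFilter (gradient w.r.t. the filter, from the
// input and dL/dy). Each backprop op also needs the shape of the argument
// it differentiates with respect to, computed at runtime via Shape().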
Status Conv2DGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  string data_format;
  string padding;
  std::vector<int32> strides;
  bool use_cudnn_on_gpu;
  auto attrs = op.output(0).node()->attrs();
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "use_cudnn_on_gpu", &use_cudnn_on_gpu));
  Conv2DBackpropInput::Attrs input_attrs;
  input_attrs.DataFormat(data_format);
  input_attrs.UseCudnnOnGpu(use_cudnn_on_gpu);
  auto dx_1 = Conv2DBackpropInput(scope, Shape(scope, op.input(0)),
                                  op.input(1), grad_inputs[0],
                                  strides, padding, input_attrs);
  grad_outputs->push_back(dx_1);
  Conv2DBackpropFilter::Attrs filter_attrs;
  filter_attrs.DataFormat(data_format);
  filter_attrs.UseCudnnOnGpu(use_cudnn_on_gpu);
  auto dx_2 = Conv2DBackpropFilter(scope, op.input(0),
                                   Shape(scope, op.input(1)), grad_inputs[0],
                                   strides, padding, filter_attrs);
  grad_outputs->push_back(dx_2);
  return scope.status();
}
REGISTER_GRADIENT_OP("Conv2D", Conv2DGrad);

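// MaxPool gradient: internal::MaxPoolGrad routes dL/dy back to the input
// positions that produced each pooled maximum, given the original input,
// the pooled output, and the forward ksize/strides/padding/data_format.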
Status MaxPoolGradHelper(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  string data_format;
  string padding;
  std::vector<int32> strides;
  std::vector<int32> ksize;
  auto attrs = op.output(0).node()->attrs();
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides));
  internal::MaxPoolGrad::Attrs grad_attrs;
  grad_attrs.DataFormat(data_format);
  auto dx = internal::MaxPoolGrad(scope, op.input(0),
                                  op.output(0),
                                  grad_inputs[0],
                                  ksize, strides,
                                  padding, grad_attrs);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("MaxPool", MaxPoolGradHelper);

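// MaxPoolV2 gradient: same as MaxPool, except ksize and strides are tensor
// inputs (op.input(1) and op.input(2)) rather than attributes, so they are
// forwarded to MaxPoolGradV2 and receive NoGradient() placeholder entries.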
Status MaxPoolGradV2Helper(const Scope& scope, const Operation& op,
                           const std::vector<Output>& grad_inputs,
                           std::vector<Output>* grad_outputs) {
  string data_format;
  string padding;
  auto attrs = op.output(0).node()->attrs();
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
  MaxPoolGradV2::Attrs grad_attrs;
  grad_attrs.DataFormat(data_format);
  auto dx = MaxPoolGradV2(scope, op.input(0),
                          op.output(0),
                          grad_inputs[0],
                          op.input(1),
                          op.input(2),
                          padding,
                          grad_attrs);
  grad_outputs->push_back(dx);
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MaxPoolV2", MaxPoolGradV2Helper);

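// LRN gradient: internal::LRNGrad computes dL/dx from dL/dy, the original
// input, and the LRN output. Note that grad_attrs is left at its defaults;
// if the forward LRN op used non-default depth_radius/bias/alpha/beta,
// those values would presumably need to be read from the node and forwarded
// here for the gradient to match.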
Status LRNGradHelper(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  internal::LRNGrad::Attrs grad_attrs;
  auto dx = internal::LRNGrad(scope, grad_inputs[0], op.input(0), op.output(0),
                              grad_attrs);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("LRN", LRNGradHelper);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow