/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <vector>

#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/lib/strings/strcat.h"

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"

namespace tensorflow {
namespace ops {
namespace {

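// Gradient functions for array ops. Each registration below is recorded in
// the C++ GradOpRegistry and looked up when symbolic gradients are built via
// tensorflow/cc/framework/gradients.h.
//
// The following ops either produce non-differentiable metadata (shapes,
// ranks, sizes, indices) or explicitly block gradient flow, so they are
// registered as having no gradient.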
REGISTER_NO_GRADIENT_OP("Const");
REGISTER_NO_GRADIENT_OP("StopGradient");
REGISTER_NO_GRADIENT_OP("ConcatOffset");
REGISTER_NO_GRADIENT_OP("EditDistance");
REGISTER_NO_GRADIENT_OP("ZerosLike");
REGISTER_NO_GRADIENT_OP("InvertPermutation");
REGISTER_NO_GRADIENT_OP("Shape");
REGISTER_NO_GRADIENT_OP("ShapeN");
REGISTER_NO_GRADIENT_OP("Rank");
REGISTER_NO_GRADIENT_OP("Size");
REGISTER_NO_GRADIENT_OP("BroadcastGradientArgs");
REGISTER_NO_GRADIENT_OP("OneHot");

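// Pack (a.k.a. Stack) stacks N tensors along `axis`; its gradient unstacks
// the incoming gradient along the same axis, yielding one gradient per input.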
Status PackGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  int N;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "N", &N));
  int axis;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));

  grad_outputs->reserve(N);
  auto grad_op = Unstack(scope, grad_inputs[0], N, Unstack::Axis(axis));
  for (const Output& o : grad_op.output) {
    grad_outputs->emplace_back(o);
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("Pack", PackGrad);

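// Unpack (a.k.a. Unstack) is the inverse of Pack: re-stacking the per-output
// gradients along the same axis gives the gradient of the original input.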
Status UnpackGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  int axis;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
  grad_outputs->push_back(Stack(scope, grad_inputs, Stack::Axis(axis)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Unpack", UnpackGrad);

Status IdentityGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Identity", IdentityGrad);

Status RefIdentityGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("RefIdentity", RefIdentityGrad);

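// The QuantizeAndDequantize family is differentiated as a straight-through
// estimator: the incoming gradient passes through unchanged, while the
// range inputs (and num_bits in V3) receive no gradient.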
Status QuantizeAndDequantizeGrad(const Scope& scope, const Operation& op,
                                 const std::vector<Output>& grad_inputs,
                                 std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad);

Status QuantizeAndDequantizeV2Grad(const Scope& scope, const Operation& op,
                                   const std::vector<Output>& grad_inputs,
                                   std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("QuantizeAndDequantizeV2", QuantizeAndDequantizeV2Grad);

Status QuantizeAndDequantizeV3Grad(const Scope& scope, const Operation& op,
                                   const std::vector<Output>& grad_inputs,
                                   std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("QuantizeAndDequantizeV3", QuantizeAndDequantizeV3Grad);

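// For Split, input 0 is the split dimension (not differentiable) and input 1
// is the value being split; the value's gradient is the concatenation of the
// per-output gradients along that same dimension.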
Status SplitGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(Concat(scope, grad_inputs, op.input(0)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Split", SplitGrad);

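// Diag and DiagPart are inverses of each other, so each op's gradient is
// simply the other op applied to the incoming gradient.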
Status DiagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(DiagPart(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Diag", DiagGrad);

Status DiagPartGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Diag(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("DiagPart", DiagPartGrad);

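// MatrixDiag writes its input onto the diagonal of a (batched) matrix, so the
// gradient is the diagonal part of the incoming gradient.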
Status MatrixDiagGrad(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(MatrixDiagPart(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("MatrixDiag", MatrixDiagGrad);

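// MatrixBandPart zeroes everything outside a band of the matrix; its gradient
// applies the same band mask to the incoming gradient. The integer band
// bounds (num_lower, num_upper) receive no gradient.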
Status MatrixBandPartGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  auto num_lower = op.input(1);
  auto num_upper = op.input(2);
  grad_outputs->push_back(
      MatrixBandPart(scope, grad_inputs[0], num_lower, num_upper));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MatrixBandPart", MatrixBandPartGrad);

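// GatherNd's gradient scatters the incoming gradient back into a zero tensor
// with the shape of the original parameters, at the gathered indices.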
Status GatherNdGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  auto ref = op.input(0);
  auto ref_shape = Shape(scope, ref);
  auto indices = op.input(1);
  grad_outputs->push_back(ScatterNd(scope, indices, grad_inputs[0], ref_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("GatherNd", GatherNdGrad);

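// CheckNumerics is an identity op with a runtime NaN/Inf check, so its
// gradient is another CheckNumerics whose message marks the backward pass.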
Status CheckNumericsGrad(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  string message;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "message", &message));
  string err_msg = strings::StrCat(
      "Not a number (NaN) or infinity (Inf) values detected in gradient. ",
      message);
  grad_outputs->push_back(CheckNumerics(scope, grad_inputs[0], err_msg));
  return scope.status();
}
REGISTER_GRADIENT_OP("CheckNumerics", CheckNumericsGrad);

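// Reshape, ExpandDims, and Squeeze only rearrange elements, so their
// gradients reshape the incoming gradient back to the original input shape.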
Status ReshapeGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Reshape", ReshapeGrad);

Status ExpandDimsGrad(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ExpandDims", ExpandDimsGrad);

Status SqueezeGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  return scope.status();
}
REGISTER_GRADIENT_OP("Squeeze", SqueezeGrad);

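// The gradient of Transpose is a transpose by the inverted permutation; the
// permutation input itself gets no gradient.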
Status TransposeGrad(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  auto inverted_perm = InvertPermutation(scope, op.input(1));
  grad_outputs->push_back(Transpose(scope, grad_inputs[0], inverted_perm));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Transpose", TransposeGrad);

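// ReverseSequence is its own inverse for a fixed seq_lengths, so the gradient
// reverses the same sequences in the incoming gradient.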
Status ReverseSequenceGrad(const Scope& scope, const Operation& op,
                           const std::vector<Output>& grad_inputs,
                           std::vector<Output>* grad_outputs) {
  auto seq_lengths = op.input(1);
  int batch_dim;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "batch_dim", &batch_dim));
  int seq_dim;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "seq_dim", &seq_dim));
  grad_outputs->push_back(
      ReverseSequence(scope, grad_inputs[0], seq_lengths, seq_dim,
                      ReverseSequence::BatchDim(batch_dim)));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ReverseSequence", ReverseSequenceGrad);

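// Reversing the incoming gradient along the same axes undoes ReverseV2; the
// axes input gets no gradient.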
Status ReverseGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto reverse_dims = op.input(1);
  grad_outputs->push_back(Reverse(scope, grad_inputs[0], reverse_dims));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ReverseV2", ReverseGrad);

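// ScatterNd's gradient with respect to the updates is a GatherNd of the
// incoming gradient at the scatter indices; the indices and shape inputs get
// no gradient.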
Status ScatterNdGrad(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  auto indices = op.input(0);
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ScatterNd", ScatterNdGrad);

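// ScatterNdNonAliasingAdd adds updates into a copy of the input, so the input
// gradient is the incoming gradient itself and the updates gradient is a
// GatherNd of it at the update indices.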
Status ScatterNdNonAliasingAddGrad(const Scope& scope, const Operation& op,
                                   const std::vector<Output>& grad_inputs,
                                   std::vector<Output>* grad_outputs) {
  auto indices = op.input(1);
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices));
  return scope.status();
}
REGISTER_GRADIENT_OP("ScatterNdNonAliasingAdd", ScatterNdNonAliasingAddGrad);

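// The gradient of Pad/PadV2 slices the region that corresponds to the
// original (unpadded) input out of the incoming gradient, using the first
// column of the paddings tensor as the slice offset.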
template <bool IsPadV2>
Status PadGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x = op.input(0);
  auto a = op.input(1);  // [Rank(x), 2]
  // Takes a slice of a. The 1st column. [Rank(x), 1].
  auto size = Stack(scope, {Rank(scope, x), 1});
  auto pad_before = Slice(scope, a, {0, 0}, size);
  // Make it a 1-D tensor.
  auto begin = Reshape(scope, pad_before, {-1});
  grad_outputs->push_back(Slice(scope, grad_inputs[0], begin, Shape(scope, x)));
  grad_outputs->push_back(NoGradient());
  // PadV2 adds a "constant_values" input.
  if (IsPadV2) {
    grad_outputs->push_back(NoGradient());
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("Pad", PadGrad<false>);
REGISTER_GRADIENT_OP("PadV2", PadGrad<true>);

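// SpaceToBatch(ND) and BatchToSpace(ND) are inverse rearrangements, so each
// op's gradient applies the inverse op with the same block size/shape and
// paddings/crops; those auxiliary inputs get no gradient.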
Status SpaceToBatchGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
  grad_outputs->push_back(
      BatchToSpace(scope, grad_inputs[0], op.input(1), block_size));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToBatch", SpaceToBatchGrad);

Status SpaceToBatchNDGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(
      BatchToSpaceND(scope, grad_inputs[0], op.input(1), op.input(2)));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToBatchND", SpaceToBatchNDGrad);

Status BatchToSpaceGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
  grad_outputs->push_back(
      SpaceToBatch(scope, grad_inputs[0], op.input(1), block_size));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("BatchToSpace", BatchToSpaceGrad);

Status BatchToSpaceNDGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(
      SpaceToBatchND(scope, grad_inputs[0], op.input(1), op.input(2)));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("BatchToSpaceND", BatchToSpaceNDGrad);

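// SpaceToDepth and DepthToSpace are likewise inverses; each gradient applies
// the other op with the same block size.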
Status SpaceToDepthGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
  grad_outputs->push_back(DepthToSpace(scope, grad_inputs[0], block_size));
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToDepth", SpaceToDepthGrad);

Status DepthToSpaceGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
  grad_outputs->push_back(SpaceToDepth(scope, grad_inputs[0], block_size));
  return scope.status();
}
REGISTER_GRADIENT_OP("DepthToSpace", DepthToSpaceGrad);

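// The gradient of MirrorPad is delegated to the internal MirrorPadGrad op,
// which accumulates the contributions from the reflected padding back onto
// the original input region; the paddings input gets no gradient.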
Status MirrorPadGrad(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  string mode;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode));
  grad_outputs->push_back(tensorflow::ops::internal::MirrorPadGrad(
      scope, grad_inputs[0], op.input(1), mode));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MirrorPad", MirrorPadGrad);

// TODO(suharshs): b/34770860. This gradient was within 1e-3 but not 1e-4.
Status MirrorPadGradGrad(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  string mode;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode));
  grad_outputs->push_back(MirrorPad(scope, grad_inputs[0], op.input(1), mode));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MirrorPadGrad", MirrorPadGradGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow