/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/ops/data_flow_ops.h"
#include "tensorflow/cc/ops/data_flow_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"

namespace tensorflow {
namespace ops {
namespace {

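// These queue, stack, and session-handle ops are stateful and have no
// meaningful gradient; registering them with REGISTER_NO_GRADIENT_OP marks
// them as intentionally non-differentiable for the symbolic gradient builder.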
REGISTER_NO_GRADIENT_OP("Queue");
REGISTER_NO_GRADIENT_OP("QueueEnqueue");
REGISTER_NO_GRADIENT_OP("QueueEnqueueMany");
REGISTER_NO_GRADIENT_OP("QueueDequeue");
REGISTER_NO_GRADIENT_OP("QueueDequeueMany");
REGISTER_NO_GRADIENT_OP("QueueDequeueUpTo");
REGISTER_NO_GRADIENT_OP("QueueClose");
REGISTER_NO_GRADIENT_OP("QueueSize");
REGISTER_NO_GRADIENT_OP("Stack");
REGISTER_NO_GRADIENT_OP("StackPush");
REGISTER_NO_GRADIENT_OP("StackPop");
REGISTER_NO_GRADIENT_OP("StackClose");
REGISTER_NO_GRADIENT_OP("GetSessionHandle");
REGISTER_NO_GRADIENT_OP("GetSessionHandleV2");
REGISTER_NO_GRADIENT_OP("GetSessionTensor");
REGISTER_NO_GRADIENT_OP("DeleteSessionTensor");

Status DynamicPartitionGrad(const Scope& scope, const Operation& op,
                            const std::vector<Output>& grad_inputs,
                            std::vector<Output>* grad_outputs) {
  // DynamicPartition only moves input values into various positions
  // in the output, so the gradient operation only has to map incoming
  // gradients back to their source locations in the input.
  // Running example:
  // data = [10, 20, 30, 40, 50]
  // partitions = [0, 0, 1, 1, 0]
  // num_partitions = 2
  // dynamic_partition(data, partitions, num_partitions) = {
  //   [10, 20, 50],
  //   [30, 40]
  // }
  // grads = {
  //   [g1, g2, g3],
  //   [g4, g5]
  // }
  // The desired propagation of the gradients back to the data input is:
  // [g1, g2, g4, g5, g3]
  auto data = op.input(0);
  auto partitions = op.input(1);
  int32 num_partitions;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "num_partitions", &num_partitions));

  // Note: the shape of the partitions is a prefix of the data shape.
  // shape(partitions) = [5]
  auto partitions_shape = Shape(scope, partitions);
  // We now create a partitions-shaped tensor with integers from
  // [0..size(partitions)). This will be dynamic_partitioned with the
  // input parameters, providing the destination index for a given
  // source item.
  // partitions_size = prod([5]) = 5
  // reshape(range(partitions_size), [5]) = [0, 1, 2, 3, 4]
  auto zero = Const(scope, 0);
  auto one = Const(scope, 1);
  auto original_indices = Reshape(
      scope, Range(scope, zero, Prod(scope, partitions_shape, zero), one),
      partitions_shape);
  // dynamic_partition(
  //   [0, 1, 2, 3, 4],
  //   [0, 0, 1, 1, 0], 2)
  //  = { [0, 1, 4],
  //      [2, 3] }
  auto partitioned_indices =
      DynamicPartition(scope, original_indices, partitions, num_partitions);

  // Invert these indices with dynamic_stitch to map the incoming
  // gradients to their source inputs.
  // dynamic_stitch(
  //   { [0, 1, 4], [2, 3] },
  //   { [g1, g2, g3], [g4, g5] })
  // = [g1, g2, g4, g5, g3]
  auto reconstructed =
      DynamicStitch(scope, partitioned_indices.outputs, grad_inputs);
  // Reshape back into a data-shaped tensor to propagate gradients for the data
  // input.
  grad_outputs->push_back(Reshape(scope, reconstructed, Shape(scope, data)));
  // Stop propagation along the partitions input.
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("DynamicPartition", DynamicPartitionGrad);
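// A minimal sketch of how the DynamicPartition gradient registered above can
// be exercised through AddSymbolicGradients (tensorflow/cc/framework/
// gradients.h); variable names are illustrative only:
//
//   Scope root = Scope::NewRootScope();
//   auto data = Const(root, {10.0f, 20.0f, 30.0f, 40.0f, 50.0f});
//   auto partitions = Const(root, {0, 0, 1, 1, 0});
//   auto parts =
//       DynamicPartition(root, data, partitions, /*num_partitions=*/2);
//   std::vector<Output> grads;
//   TF_CHECK_OK(AddSymbolicGradients(root, parts.outputs, {data},
//                                    {Const(root, {1.0f, 2.0f, 3.0f}),
//                                     Const(root, {4.0f, 5.0f})},
//                                    &grads));
//   // grads[0] evaluates to [1, 2, 4, 5, 3]: every incoming gradient is
//   // routed back to the position its data element came from.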

Status DynamicStitchGrad(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  // Running example:
  // indices = {2, [1, 0]}
  // data = {[d_1, d_2], [[d_3, d_4], [d_5, d_6]]}
  // out = [[d_5, d_6], [d_3, d_4], [d_1, d_2]]
  // grad = [[g_1, g_2], [g_3, g_4], [g_5, g_6]]

  // indices and data are two equal-sized lists passed
  // into DynamicStitch.
  // num_values = 2
  int32 num_values = op.num_inputs() / 2;

  // Stop propagation along the indices list.
  for (int32 i = 0; i < num_values; i++) {
    grad_outputs->push_back(NoGradient());
  }

  // DynamicStitch shuffles its data to the output (using items in
  // indices), so the gradient propagated to a given data input simply
  // selects the gradient at its output positions.
  for (int32 i = 0; i < num_values; i++) {
    // index has the destination positions for the i'th data
    // element. We cast it to int32 if necessary, so we can use
    // it with a Gather op.
    // i = 0: index = 2
    // i = 1: index = [1, 0]
    auto index = op.input(i);
    if (index.type() != DT_INT32) {
      index = Cast(scope, index, DT_INT32);
    }
    // Gather the index-specified locations in the gradient and
    // propagate them as the gradient for the i'th data item.
    // i = 0: gather(grad, 2) = [g_5, g_6]
    // i = 1: gather(grad, [1, 0]) = [[g_3, g_4], [g_1, g_2]]
    grad_outputs->push_back(Gather(scope, grad_inputs[0], index));
  }

  return scope.status();
}
REGISTER_GRADIENT_OP("DynamicStitch", DynamicStitchGrad);
REGISTER_GRADIENT_OP("ParallelDynamicStitch", DynamicStitchGrad);
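// A minimal sketch, mirroring the running example above, of driving the
// DynamicStitch gradient through AddSymbolicGradients; the names are
// illustrative and an all-ones upstream gradient is assumed:
//
//   Scope root = Scope::NewRootScope();
//   auto i0 = Const(root, 2);
//   auto i1 = Const(root, {1, 0});
//   auto d0 = Const(root, {1.0f, 2.0f});
//   auto d1 = Const(root, {{3.0f, 4.0f}, {5.0f, 6.0f}});
//   auto stitched = DynamicStitch(root, {i0, i1}, {d0, d1});
//   std::vector<Output> grads;
//   TF_CHECK_OK(AddSymbolicGradients(root, {stitched}, {d0, d1}, &grads));
//   // grads[0] (shape [2]) is the slice of the output gradient at row 2;
//   // grads[1] (shape [2, 2]) gathers rows 1 and 0, so each data input
//   // receives exactly the gradient of the output rows it produced.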

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow