1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/cc/ops/data_flow_ops.h" 17 #include "tensorflow/cc/ops/data_flow_ops_internal.h" 18 #include "tensorflow/cc/ops/standard_ops.h" 19 20 #include "tensorflow/cc/framework/grad_op_registry.h" 21 #include "tensorflow/cc/framework/gradients.h" 22 23 namespace tensorflow { 24 namespace ops { 25 namespace { 26 27 REGISTER_NO_GRADIENT_OP("Queue"); 28 REGISTER_NO_GRADIENT_OP("QueueEnqueue"); 29 REGISTER_NO_GRADIENT_OP("QueueEnqueueMany"); 30 REGISTER_NO_GRADIENT_OP("QueueDequeue"); 31 REGISTER_NO_GRADIENT_OP("QueueDequeueMany"); 32 REGISTER_NO_GRADIENT_OP("QueueDequeueUpTo"); 33 REGISTER_NO_GRADIENT_OP("QueueClose"); 34 REGISTER_NO_GRADIENT_OP("QueueSize"); 35 REGISTER_NO_GRADIENT_OP("Stack"); 36 REGISTER_NO_GRADIENT_OP("StackPush"); 37 REGISTER_NO_GRADIENT_OP("StackPop"); 38 REGISTER_NO_GRADIENT_OP("StackClose"); 39 REGISTER_NO_GRADIENT_OP("GetSessionHandle"); 40 REGISTER_NO_GRADIENT_OP("GetSessionHandleV2"); 41 REGISTER_NO_GRADIENT_OP("GetSessionTensor"); 42 REGISTER_NO_GRADIENT_OP("DeleteSessionTensor"); 43 44 Status DynamicPartitionGrad(const Scope& scope, const Operation& op, 45 const std::vector<Output>& grad_inputs, 46 std::vector<Output>* grad_outputs) { 47 // DynamicPartition only moves input values into various positions 48 // in the output, so the gradient operation only has to map incoming 49 // gradients into their input source locations. 50 // running example: 51 // data = [10, 20, 30, 40, 50] 52 // partitions = [0, 0, 1, 1, 0] 53 // num_partitions = 2 54 // dynamic_partition(data, partitions, num_partitions) = { 55 // [10, 20, 50], 56 // [30, 40] 57 // } 58 // grads = { 59 // [g1, g2, g3], 60 // [g4, g5] 61 // } 62 // The desired propagation of the gradients back to the data inputs is: 63 // [g1, g2, g4, g5, g3] 64 auto data = op.input(0); 65 auto partitions = op.input(1); 66 int32 num_partitions; 67 TF_RETURN_IF_ERROR( 68 GetNodeAttr(op.node()->attrs(), "num_partitions", &num_partitions)); 69 70 // Note: the shape of the partitions is a prefix of the data shape. 71 // shape(partitions) = [5] 72 auto partitions_shape = Shape(scope, partitions); 73 // We now create a partitions-shaped tensor with integers from 74 // [0..size(partitions)) This will be dynamic_partitioned with the 75 // input parameters, providing the destination index for a given 76 // source item. 77 // partitions_size = prod([5]) = 5 78 // reshape(range(partitions_size), [5]) = [0, 1, 2, 3, 4] 79 auto zero = Const(scope, 0); 80 auto one = Const(scope, 1); 81 auto original_indices = Reshape( 82 scope, Range(scope, zero, Prod(scope, partitions_shape, zero), one), 83 partitions_shape); 84 // dynamic_partition( 85 // [0, 1, 2, 3, 4], 86 // [0, 0, 1, 1, 0], 2) 87 // = { [0, 1, 4], 88 // [2, 3] } 89 auto partitioned_indices = 90 DynamicPartition(scope, original_indices, partitions, num_partitions); 91 92 // Invert these indices with dynamic_stitch to map the incoming 93 // gradients to their source inputs. 94 // dynamic_stitch( 95 // { [0, 1, 4], [2, 3] }, 96 // { [g1, g2, g3], [g4, g5] }) 97 // = [g1, g2, g4, g5, g3] 98 auto reconstructed = 99 DynamicStitch(scope, partitioned_indices.outputs, grad_inputs); 100 // reshape back into a data-shaped tensor to propagate gradients for the data 101 // input. 102 grad_outputs->push_back(Reshape(scope, reconstructed, Shape(scope, data))); 103 // Stop propagation along the partitions input 104 grad_outputs->push_back(NoGradient()); 105 return scope.status(); 106 } 107 REGISTER_GRADIENT_OP("DynamicPartition", DynamicPartitionGrad); 108 109 Status DynamicStitchGrad(const Scope& scope, const Operation& op, 110 const std::vector<Output>& grad_inputs, 111 std::vector<Output>* grad_outputs) { 112 // Running example: 113 // indices = {2, [1, 0]} 114 // data = {[d_1, d_2], [[d_3, d_4], [d_5, d_6]]} 115 // out = [[d_5, d_6], [d_3, d_4], [d_1, d_2]] 116 // grad = [[g_1, g_2], [g_3, g_4], [g_5, g_6]] 117 118 // indices and data are two equal-sized lists passed 119 // into DynamicStitch. 120 // num_values = 2 121 int32 num_values = op.num_inputs() / 2; 122 123 // Stop propagation along the indices list 124 for (int32 i = 0; i < num_values; i++) { 125 grad_outputs->push_back(NoGradient()); 126 } 127 128 // DynamicStitch shuffles its data to the output (using items in 129 // indices) so the gradient propagated to a given data input simply 130 // selects the gradient for its output position. 131 for (int32 i = 0; i < num_values; i++) { 132 // index has the destination positions for the i'th data 133 // element. We cast it into an int32 if necessary, so we can use 134 // it from a Gather op. 135 // i = 0: index = 2 136 // i = 1: index = [1, 0] 137 auto index = op.input(i); 138 if (index.type() != DT_INT32) { 139 index = Cast(scope, index, DT_INT32); 140 } 141 // Gather the index specified locations in the gradient and 142 // propagate it as the gradient for the i'th data item. 143 // i = 0: gather(grad, 2) = [g_5, g_6] 144 // i = 1: gather(grad, [1, 0]) = [[g_3, g_4], [g_1, g_2]] 145 grad_outputs->push_back(Gather(scope, grad_inputs[0], index)); 146 } 147 148 return scope.status(); 149 } 150 REGISTER_GRADIENT_OP("DynamicStitch", DynamicStitchGrad); 151 REGISTER_GRADIENT_OP("ParallelDynamicStitch", DynamicStitchGrad); 152 153 } // anonymous namespace 154 } // namespace ops 155 } // namespace tensorflow 156