1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <vector> 17 18 #include "tensorflow/cc/ops/array_ops_internal.h" 19 #include "tensorflow/cc/ops/standard_ops.h" 20 #include "tensorflow/core/lib/strings/strcat.h" 21 22 #include "tensorflow/cc/framework/grad_op_registry.h" 23 #include "tensorflow/cc/framework/gradients.h" 24 25 namespace tensorflow { 26 namespace ops { 27 namespace { 28 29 REGISTER_NO_GRADIENT_OP("Const"); 30 REGISTER_NO_GRADIENT_OP("StopGradient"); 31 REGISTER_NO_GRADIENT_OP("ConcatOffset"); 32 REGISTER_NO_GRADIENT_OP("EditDistance"); 33 REGISTER_NO_GRADIENT_OP("ZerosLike"); 34 REGISTER_NO_GRADIENT_OP("InvertPermutation"); 35 REGISTER_NO_GRADIENT_OP("Shape"); 36 REGISTER_NO_GRADIENT_OP("ShapeN"); 37 REGISTER_NO_GRADIENT_OP("Rank"); 38 REGISTER_NO_GRADIENT_OP("Size"); 39 REGISTER_NO_GRADIENT_OP("BroadcastGradientArgs"); 40 REGISTER_NO_GRADIENT_OP("OneHot"); 41 42 Status PackGrad(const Scope& scope, const Operation& op, 43 const std::vector<Output>& grad_inputs, 44 std::vector<Output>* grad_outputs) { 45 int N; 46 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "N", &N)); 47 int axis; 48 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis)); 49 50 grad_outputs->reserve(N); 51 auto grad_op = Unstack(scope, grad_inputs[0], N, Unstack::Axis(axis)); 52 for (const Output& o : grad_op.output) { 53 grad_outputs->emplace_back(o); 54 } 55 return scope.status(); 56 } 57 REGISTER_GRADIENT_OP("Pack", PackGrad); 58 59 Status UnpackGrad(const Scope& scope, const Operation& op, 60 const std::vector<Output>& grad_inputs, 61 std::vector<Output>* grad_outputs) { 62 int axis; 63 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis)); 64 grad_outputs->push_back(Stack(scope, grad_inputs, Stack::Axis(axis))); 65 return scope.status(); 66 } 67 REGISTER_GRADIENT_OP("Unpack", UnpackGrad); 68 69 Status IdentityGrad(const Scope& scope, const Operation& op, 70 const std::vector<Output>& grad_inputs, 71 std::vector<Output>* grad_outputs) { 72 grad_outputs->push_back(Identity(scope, grad_inputs[0])); 73 return scope.status(); 74 } 75 REGISTER_GRADIENT_OP("Identity", IdentityGrad); 76 77 Status RefIdentityGrad(const Scope& scope, const Operation& op, 78 const std::vector<Output>& grad_inputs, 79 std::vector<Output>* grad_outputs) { 80 grad_outputs->push_back(Identity(scope, grad_inputs[0])); 81 return scope.status(); 82 } 83 REGISTER_GRADIENT_OP("RefIdentity", RefIdentityGrad); 84 85 Status QuantizeAndDequantizeGrad(const Scope& scope, const Operation& op, 86 const std::vector<Output>& grad_inputs, 87 std::vector<Output>* grad_outputs) { 88 grad_outputs->push_back(Identity(scope, grad_inputs[0])); 89 return scope.status(); 90 } 91 REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad); 92 93 Status QuantizeAndDequantizeV2Grad(const Scope& scope, const Operation& op, 94 const std::vector<Output>& grad_inputs, 95 std::vector<Output>* grad_outputs) { 96 grad_outputs->push_back(Identity(scope, grad_inputs[0])); 97 grad_outputs->push_back(NoGradient()); 98 grad_outputs->push_back(NoGradient()); 99 return scope.status(); 100 } 101 REGISTER_GRADIENT_OP("QuantizeAndDequantizeV2", QuantizeAndDequantizeV2Grad); 102 103 Status QuantizeAndDequantizeV3Grad(const Scope& scope, const Operation& op, 104 const std::vector<Output>& grad_inputs, 105 std::vector<Output>* grad_outputs) { 106 grad_outputs->push_back(Identity(scope, grad_inputs[0])); 107 grad_outputs->push_back(NoGradient()); 108 grad_outputs->push_back(NoGradient()); 109 grad_outputs->push_back(NoGradient()); 110 return scope.status(); 111 } 112 REGISTER_GRADIENT_OP("QuantizeAndDequantizeV3", QuantizeAndDequantizeV3Grad); 113 114 Status SplitGrad(const Scope& scope, const Operation& op, 115 const std::vector<Output>& grad_inputs, 116 std::vector<Output>* grad_outputs) { 117 grad_outputs->push_back(NoGradient()); 118 grad_outputs->push_back(Concat(scope, grad_inputs, op.input(0))); 119 return scope.status(); 120 } 121 REGISTER_GRADIENT_OP("Split", SplitGrad); 122 123 Status DiagGrad(const Scope& scope, const Operation& op, 124 const std::vector<Output>& grad_inputs, 125 std::vector<Output>* grad_outputs) { 126 grad_outputs->push_back(DiagPart(scope, grad_inputs[0])); 127 return scope.status(); 128 } 129 REGISTER_GRADIENT_OP("Diag", DiagGrad); 130 131 Status DiagPartGrad(const Scope& scope, const Operation& op, 132 const std::vector<Output>& grad_inputs, 133 std::vector<Output>* grad_outputs) { 134 grad_outputs->push_back(Diag(scope, grad_inputs[0])); 135 return scope.status(); 136 } 137 REGISTER_GRADIENT_OP("DiagPart", DiagPartGrad); 138 139 Status MatrixDiagGrad(const Scope& scope, const Operation& op, 140 const std::vector<Output>& grad_inputs, 141 std::vector<Output>* grad_outputs) { 142 grad_outputs->push_back(MatrixDiagPart(scope, grad_inputs[0])); 143 return scope.status(); 144 } 145 REGISTER_GRADIENT_OP("MatrixDiag", MatrixDiagGrad); 146 147 Status MatrixBandPartGrad(const Scope& scope, const Operation& op, 148 const std::vector<Output>& grad_inputs, 149 std::vector<Output>* grad_outputs) { 150 auto num_lower = op.input(1); 151 auto num_upper = op.input(2); 152 grad_outputs->push_back( 153 MatrixBandPart(scope, grad_inputs[0], num_lower, num_upper)); 154 grad_outputs->push_back(NoGradient()); 155 grad_outputs->push_back(NoGradient()); 156 return scope.status(); 157 } 158 REGISTER_GRADIENT_OP("MatrixBandPart", MatrixBandPartGrad); 159 160 Status GatherNdGrad(const Scope& scope, const Operation& op, 161 const std::vector<Output>& grad_inputs, 162 std::vector<Output>* grad_outputs) { 163 auto ref = op.input(0); 164 auto ref_shape = Shape(scope, ref); 165 auto indices = op.input(1); 166 grad_outputs->push_back(ScatterNd(scope, indices, grad_inputs[0], ref_shape)); 167 grad_outputs->push_back(NoGradient()); 168 return scope.status(); 169 } 170 REGISTER_GRADIENT_OP("GatherNd", GatherNdGrad); 171 172 Status CheckNumericsGrad(const Scope& scope, const Operation& op, 173 const std::vector<Output>& grad_inputs, 174 std::vector<Output>* grad_outputs) { 175 string message; 176 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "message", &message)); 177 string err_msg = strings::StrCat( 178 "Not a number (NaN) or infinity (Inf) values detected in gradient. ", 179 message); 180 grad_outputs->push_back(CheckNumerics(scope, grad_inputs[0], err_msg)); 181 return scope.status(); 182 } 183 REGISTER_GRADIENT_OP("CheckNumerics", CheckNumericsGrad); 184 185 Status ReshapeGrad(const Scope& scope, const Operation& op, 186 const std::vector<Output>& grad_inputs, 187 std::vector<Output>* grad_outputs) { 188 auto input_shape = Shape(scope, op.input(0)); 189 grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape)); 190 grad_outputs->push_back(NoGradient()); 191 return scope.status(); 192 } 193 REGISTER_GRADIENT_OP("Reshape", ReshapeGrad); 194 195 Status ExpandDimsGrad(const Scope& scope, const Operation& op, 196 const std::vector<Output>& grad_inputs, 197 std::vector<Output>* grad_outputs) { 198 auto input_shape = Shape(scope, op.input(0)); 199 grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape)); 200 grad_outputs->push_back(NoGradient()); 201 return scope.status(); 202 } 203 REGISTER_GRADIENT_OP("ExpandDims", ExpandDimsGrad); 204 205 Status SqueezeGrad(const Scope& scope, const Operation& op, 206 const std::vector<Output>& grad_inputs, 207 std::vector<Output>* grad_outputs) { 208 auto input_shape = Shape(scope, op.input(0)); 209 grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape)); 210 return scope.status(); 211 } 212 REGISTER_GRADIENT_OP("Squeeze", SqueezeGrad); 213 214 Status TransposeGrad(const Scope& scope, const Operation& op, 215 const std::vector<Output>& grad_inputs, 216 std::vector<Output>* grad_outputs) { 217 auto inverted_perm = InvertPermutation(scope, op.input(1)); 218 grad_outputs->push_back(Transpose(scope, grad_inputs[0], inverted_perm)); 219 grad_outputs->push_back(NoGradient()); 220 return scope.status(); 221 } 222 REGISTER_GRADIENT_OP("Transpose", TransposeGrad); 223 224 Status ReverseSequenceGrad(const Scope& scope, const Operation& op, 225 const std::vector<Output>& grad_inputs, 226 std::vector<Output>* grad_outputs) { 227 auto seq_lengths = op.input(1); 228 int batch_dim; 229 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "batch_dim", &batch_dim)); 230 int seq_dim; 231 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "seq_dim", &seq_dim)); 232 grad_outputs->push_back( 233 ReverseSequence(scope, grad_inputs[0], seq_lengths, seq_dim, 234 ReverseSequence::BatchDim(batch_dim))); 235 grad_outputs->push_back(NoGradient()); 236 return scope.status(); 237 } 238 REGISTER_GRADIENT_OP("ReverseSequence", ReverseSequenceGrad); 239 240 Status ReverseGrad(const Scope& scope, const Operation& op, 241 const std::vector<Output>& grad_inputs, 242 std::vector<Output>* grad_outputs) { 243 auto reverse_dims = op.input(1); 244 grad_outputs->push_back(Reverse(scope, grad_inputs[0], reverse_dims)); 245 grad_outputs->push_back(NoGradient()); 246 return scope.status(); 247 } 248 REGISTER_GRADIENT_OP("ReverseV2", ReverseGrad); 249 250 Status ScatterNdGrad(const Scope& scope, const Operation& op, 251 const std::vector<Output>& grad_inputs, 252 std::vector<Output>* grad_outputs) { 253 auto indices = op.input(0); 254 grad_outputs->push_back(NoGradient()); 255 grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices)); 256 grad_outputs->push_back(NoGradient()); 257 return scope.status(); 258 } 259 REGISTER_GRADIENT_OP("ScatterNd", ScatterNdGrad); 260 261 Status ScatterNdNonAliasingAddGrad(const Scope& scope, const Operation& op, 262 const std::vector<Output>& grad_inputs, 263 std::vector<Output>* grad_outputs) { 264 auto indices = op.input(1); 265 grad_outputs->push_back(Identity(scope, grad_inputs[0])); 266 grad_outputs->push_back(NoGradient()); 267 grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices)); 268 return scope.status(); 269 } 270 REGISTER_GRADIENT_OP("ScatterNdNonAliasingAdd", ScatterNdNonAliasingAddGrad); 271 272 template <bool IsPadV2> 273 Status PadGrad(const Scope& scope, const Operation& op, 274 const std::vector<Output>& grad_inputs, 275 std::vector<Output>* grad_outputs) { 276 auto x = op.input(0); 277 auto a = op.input(1); // [Rank(x), 2] 278 // Takes a slice of a. The 1st column. [Rank(x), 1]. 279 auto size = Stack(scope, {Rank(scope, x), 1}); 280 auto pad_before = Slice(scope, a, {0, 0}, size); 281 // Make it a 1-D tensor. 282 auto begin = Reshape(scope, pad_before, {-1}); 283 grad_outputs->push_back(Slice(scope, grad_inputs[0], begin, Shape(scope, x))); 284 grad_outputs->push_back(NoGradient()); 285 // PadV2 adds a "constant_values" input. 286 if (IsPadV2) { 287 grad_outputs->push_back(NoGradient()); 288 } 289 return scope.status(); 290 } 291 REGISTER_GRADIENT_OP("Pad", PadGrad<false>); 292 REGISTER_GRADIENT_OP("PadV2", PadGrad<true>); 293 294 Status SpaceToBatchGrad(const Scope& scope, const Operation& op, 295 const std::vector<Output>& grad_inputs, 296 std::vector<Output>* grad_outputs) { 297 int block_size; 298 TF_RETURN_IF_ERROR( 299 GetNodeAttr(op.node()->attrs(), "block_size", &block_size)); 300 grad_outputs->push_back( 301 BatchToSpace(scope, grad_inputs[0], op.input(1), block_size)); 302 grad_outputs->push_back(NoGradient()); 303 return scope.status(); 304 } 305 REGISTER_GRADIENT_OP("SpaceToBatch", SpaceToBatchGrad); 306 307 Status SpaceToBatchNDGrad(const Scope& scope, const Operation& op, 308 const std::vector<Output>& grad_inputs, 309 std::vector<Output>* grad_outputs) { 310 grad_outputs->push_back( 311 BatchToSpaceND(scope, grad_inputs[0], op.input(1), op.input(2))); 312 grad_outputs->push_back(NoGradient()); 313 grad_outputs->push_back(NoGradient()); 314 return scope.status(); 315 } 316 REGISTER_GRADIENT_OP("SpaceToBatchND", SpaceToBatchNDGrad); 317 318 Status BatchToSpaceGrad(const Scope& scope, const Operation& op, 319 const std::vector<Output>& grad_inputs, 320 std::vector<Output>* grad_outputs) { 321 int block_size; 322 TF_RETURN_IF_ERROR( 323 GetNodeAttr(op.node()->attrs(), "block_size", &block_size)); 324 grad_outputs->push_back( 325 SpaceToBatch(scope, grad_inputs[0], op.input(1), block_size)); 326 grad_outputs->push_back(NoGradient()); 327 return scope.status(); 328 } 329 REGISTER_GRADIENT_OP("BatchToSpace", BatchToSpaceGrad); 330 331 Status BatchToSpaceNDGrad(const Scope& scope, const Operation& op, 332 const std::vector<Output>& grad_inputs, 333 std::vector<Output>* grad_outputs) { 334 grad_outputs->push_back( 335 SpaceToBatchND(scope, grad_inputs[0], op.input(1), op.input(2))); 336 grad_outputs->push_back(NoGradient()); 337 grad_outputs->push_back(NoGradient()); 338 return scope.status(); 339 } 340 REGISTER_GRADIENT_OP("BatchToSpaceND", BatchToSpaceNDGrad); 341 342 Status SpaceToDepthGrad(const Scope& scope, const Operation& op, 343 const std::vector<Output>& grad_inputs, 344 std::vector<Output>* grad_outputs) { 345 int block_size; 346 TF_RETURN_IF_ERROR( 347 GetNodeAttr(op.node()->attrs(), "block_size", &block_size)); 348 grad_outputs->push_back(DepthToSpace(scope, grad_inputs[0], block_size)); 349 return scope.status(); 350 } 351 REGISTER_GRADIENT_OP("SpaceToDepth", SpaceToDepthGrad); 352 353 Status DepthToSpaceGrad(const Scope& scope, const Operation& op, 354 const std::vector<Output>& grad_inputs, 355 std::vector<Output>* grad_outputs) { 356 int block_size; 357 TF_RETURN_IF_ERROR( 358 GetNodeAttr(op.node()->attrs(), "block_size", &block_size)); 359 grad_outputs->push_back(SpaceToDepth(scope, grad_inputs[0], block_size)); 360 return scope.status(); 361 } 362 REGISTER_GRADIENT_OP("DepthToSpace", DepthToSpaceGrad); 363 364 Status MirrorPadGrad(const Scope& scope, const Operation& op, 365 const std::vector<Output>& grad_inputs, 366 std::vector<Output>* grad_outputs) { 367 string mode; 368 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode)); 369 grad_outputs->push_back(tensorflow::ops::internal::MirrorPadGrad( 370 scope, grad_inputs[0], op.input(1), mode)); 371 grad_outputs->push_back(NoGradient()); 372 return scope.status(); 373 } 374 REGISTER_GRADIENT_OP("MirrorPad", MirrorPadGrad); 375 376 // TODO(suharshs): b/34770860. This gradient was within 1e-3 but not 1e-4. 377 Status MirrorPadGradGrad(const Scope& scope, const Operation& op, 378 const std::vector<Output>& grad_inputs, 379 std::vector<Output>* grad_outputs) { 380 string mode; 381 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode)); 382 grad_outputs->push_back(MirrorPad(scope, grad_inputs[0], op.input(1), mode)); 383 grad_outputs->push_back(NoGradient()); 384 return scope.status(); 385 } 386 REGISTER_GRADIENT_OP("MirrorPadGrad", MirrorPadGradGrad); 387 388 } // anonymous namespace 389 } // namespace ops 390 } // namespace tensorflow 391