1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <deque> 17 #include <vector> 18 19 #include "tensorflow/cc/framework/grad_op_registry.h" 20 #include "tensorflow/cc/framework/gradients.h" 21 #include "tensorflow/cc/framework/while_gradients.h" 22 #include "tensorflow/cc/ops/standard_ops.h" 23 #include "tensorflow/core/framework/function.h" 24 #include "tensorflow/core/framework/node_def_util.h" 25 #include "tensorflow/core/framework/op.h" 26 #include "tensorflow/core/framework/op_kernel.h" 27 #include "tensorflow/core/graph/algorithm.h" 28 #include "tensorflow/core/graph/graph_constructor.h" 29 #include "tensorflow/core/graph/while_context.h" 30 #include "tensorflow/core/lib/gtl/map_util.h" 31 #include "tensorflow/core/platform/macros.h" 32 33 namespace tensorflow { 34 namespace { 35 36 struct OutputHash { 37 uint64 operator()(const Output& x) const { 38 return x.hash(); 39 } 40 }; 41 42 struct OutputEq { 43 bool operator()(const Output& x, const Output& y) const { 44 return (x.node() == y.node()) && (x.index() == y.index()); 45 } 46 }; 47 48 class SymbolicGradientBuilder { 49 public: 50 SymbolicGradientBuilder(const Scope& scope, 51 const ops::GradOpRegistry* registry, 52 const std::vector<Output>& outputs, 53 const std::vector<Output>& inputs, 54 const std::vector<Output>& grad_inputs, 55 std::vector<Output>* grad_outputs); 56 57 Status AddGradients(); 58 59 static Output NoGradient() { return Output(nullptr, -1); } 60 61 private: 62 Status Initialize(); 63 64 // For each forward edge from `src` to `dst` in the initial/forward graph: 65 // propagates gradients `dst_grad` backwards along the edge from `src` 66 // to `dst` in the graph. This will add `dst_grad` to the list of pending 67 // gradients for the node associated with `src`. 68 Status BackpropAlongEdge(const Output& dst_grad, const Output& src); 69 70 // Adds a node to the graph (returned in `grad`) that sums the in-bound 71 // gradients to `src` (if there are more than one). 72 Status SumGradients(const Output& src, Output* grad); 73 74 // Returns true if `opname` is registered in `registry_` with no gradient 75 // function, false otherwise. 76 bool IsPrimitiveOpWithNoGrad(const string& opname); 77 78 // Call the gradient function for `op`, storing the result in `grad_outputs`. 79 Status CallGradFunction(const Operation& op, 80 const std::vector<Output>& grad_inputs, 81 std::vector<Output>* grad_outputs); 82 83 // Returns a list mapping whether each node in the graph is reachable 84 // from outputs_. Keyed by node id. 85 std::vector<bool> GetReachableNodes(); 86 87 // Creates the gradient subgraph for a while loop (or just stores 88 // `summed_grads` if not all incoming gradients are available yet). All exit 89 // nodes (which are the first nodes of a loop encountered in the backwards 90 // pass) are passed to this function rather than processed normally. 91 // `summed_grads` is the sum of `exit_node`s gradients. 92 Status ProcessWhileLoop(Node* exit_node, const Output& summed_grads); 93 94 // Gets the set of node ids at which to stop backprop. These are all elements 95 // of `outputs_` that do not get transitively consumed by other `outputs_`. 96 // Used to identify nodes at which to stop backprop. 97 std::unordered_set<int> GetStopBackpropNodes( 98 const std::vector<bool>& reachable_nodes, 99 std::unordered_set<int> output_nodes); 100 101 const Scope& scope_; 102 const ops::GradOpRegistry* registry_; 103 const std::vector<Output>& outputs_; 104 const std::vector<Output>& inputs_; 105 const std::vector<Output>& grad_inputs_; 106 std::vector<Output>* grad_outputs_; 107 108 // A vector of output endpoints which represents backpropagated gradients. 109 typedef std::vector<Output> BackproppedGradients; 110 111 // backprops_ is a map from a node output to its accumulated 112 // gradients. When a node output has accumulated all its 113 // gradients, we add a node which sums them up. 114 std::unordered_map<Output, BackproppedGradients, OutputHash, OutputEq> 115 backprops_; 116 117 // pending[i] is count-down counter for i-th node's expected 118 // backprops. When pending[i] becomes zero, we collected all 119 // backprop gradients for all outputs of the ith-node. 120 std::vector<int> pending_; 121 122 // `ready` keeps track of nodes that have been completely 123 // backpropped. Initially, for every output in `outputs_`, we add initial 124 // gradients from `grad_inputs_`. 125 std::deque<Node*> ready_; 126 127 // The set of node ids in `inputs_`. Used to identify nodes at backprop 128 // frontier. Maps from Output -> index into `grad_outputs_`. 129 std::unordered_map<Output, int, OutputHash, OutputEq> input_nodes_; 130 131 // For each while loop in the graph, collects the summed gradients for each of 132 // the loop's exit nodes. Note that unlike backprops_, this map contains the 133 // output of SumGradients(), not the input (i.e. each exit node may have 134 // multiple incoming gradients, but we only store the combined Output here). 135 std::map<WhileContext*, std::map<Node*, Output>> while_backprops_; 136 137 TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientBuilder); 138 }; 139 140 SymbolicGradientBuilder::SymbolicGradientBuilder( 141 const Scope& scope, const ops::GradOpRegistry* registry, 142 const std::vector<Output>& outputs, const std::vector<Output>& inputs, 143 const std::vector<Output>& grad_inputs, std::vector<Output>* grad_outputs) 144 : scope_(scope), 145 registry_(registry), 146 outputs_(outputs), 147 inputs_(inputs), 148 grad_inputs_(grad_inputs), 149 grad_outputs_(grad_outputs) {} 150 151 Status SymbolicGradientBuilder::BackpropAlongEdge(const Output& dst_grad, 152 const Output& src) { 153 if (src.node() == nullptr) { 154 return errors::Internal("Attempted to backprop along an invalid edge."); 155 } 156 auto iter = backprops_.find(src); 157 if (iter != backprops_.end()) { 158 auto* grads = &iter->second; 159 grads->push_back(dst_grad); 160 if (--pending_[src.node()->id()] == 0) { 161 ready_.push_back(src.node()); 162 } 163 } 164 return Status::OK(); 165 } 166 167 std::vector<bool> SymbolicGradientBuilder::GetReachableNodes() { 168 std::vector<bool> reachable_nodes(scope_.graph()->num_node_ids(), false); 169 std::deque<Node*> queue; 170 std::vector<bool> visited(scope_.graph()->num_node_ids(), false); 171 for (const Output& out : outputs_) { 172 if (!reachable_nodes[out.node()->id()]) { 173 queue.push_back(out.node()); 174 reachable_nodes[out.node()->id()] = true; 175 } 176 } 177 178 while (!queue.empty()) { 179 Node* n = queue.front(); 180 queue.pop_front(); 181 for (const Edge* e : n->in_edges()) { 182 if (e->IsControlEdge()) continue; 183 if (visited[e->src()->id()]) continue; 184 queue.push_back(e->src()); 185 reachable_nodes[e->src()->id()] = true; 186 visited[e->src()->id()] = true; 187 } 188 } 189 return reachable_nodes; 190 } 191 192 std::unordered_set<int> SymbolicGradientBuilder::GetStopBackpropNodes( 193 const std::vector<bool>& reachable_nodes, 194 std::unordered_set<int> output_nodes) { 195 // Output nodes that get transitively consumed by other `outputs_` are stored 196 // in `internal_outputs`. 197 std::unordered_set<int> internal_outputs; 198 std::unordered_set<Node*> visited; 199 // Initialize `queue` for BFS traversal. Nodes in `queue` hold upcoming nodes 200 // along with the last Node in `output_` encountered along that path. If no 201 // `output_` node was encountered, pair.second will be nullptr. 202 std::deque<std::pair<Node*, Node*>> queue; 203 for (const Output& nout : inputs_) { 204 if (visited.find(nout.node()) == visited.end()) { 205 queue.push_back(std::make_pair(nout.node(), static_cast<Node*>(nullptr))); 206 visited.insert(nout.node()); 207 } 208 } 209 // BFS from nodes in 'inputs_' along out edges for the entire graph. Internal 210 // output nodes are recorded during the traversal. All nodes that are output 211 // nodes but not internal output nodes are considered the frontier of the 212 // output nodes, and thus our stop backprop nodes. 213 while (!queue.empty()) { 214 std::pair<Node*, Node*> p = queue.front(); 215 Node* n = p.first; 216 queue.pop_front(); 217 for (const Edge* e : n->out_edges()) { 218 // If a node is not reachable from outputs_, we can stop. 219 if (e->IsControlEdge() || !reachable_nodes[e->dst()->id()]) continue; 220 if (visited.find(e->dst()) != visited.end()) continue; 221 222 int node_id = e->dst()->id(); 223 Node* last_output_node = p.second; 224 if (output_nodes.find(node_id) != output_nodes.end()) { 225 // We reached an output node. 226 if (last_output_node != nullptr) { 227 // If we had already found an output node on this path so we mark 228 // it as an internal output. 229 internal_outputs.insert(last_output_node->id()); 230 } 231 // Mark this newly found output node to insert in the queue. 232 last_output_node = e->dst(); 233 } 234 queue.push_back(std::make_pair(e->dst(), last_output_node)); 235 visited.insert(e->dst()); 236 } 237 } 238 // Finally, we set stop_backprop_nodes to all output_nodes that aren't also 239 // internal_outputs. 240 std::unordered_set<int> stop_backprop_nodes; 241 for (int output_node : output_nodes) { 242 if (internal_outputs.find(output_node) == internal_outputs.end()) { 243 stop_backprop_nodes.insert(output_node); 244 } 245 } 246 return stop_backprop_nodes; 247 } 248 249 Status SymbolicGradientBuilder::Initialize() { 250 if (outputs_.size() != grad_inputs_.size()) { 251 return errors::InvalidArgument( 252 "Must specify a gradient input for each output."); 253 } 254 std::vector<bool> reachable_nodes = GetReachableNodes(); 255 for (const Output& input : inputs_) { 256 if (!reachable_nodes[input.node()->id()]) { 257 return errors::InvalidArgument( 258 "Cannot compute the partial derivative for node '", 259 input.node()->name(), 260 "' as it's unreachable from the output node(s)."); 261 } 262 } 263 grad_outputs_->clear(); 264 grad_outputs_->resize(inputs_.size()); 265 266 std::unordered_set<int> output_nodes; 267 output_nodes.reserve(outputs_.size()); 268 for (size_t i = 0; i < outputs_.size(); ++i) { 269 output_nodes.insert(outputs_[i].node()->id()); 270 } 271 272 std::unordered_set<int> stop_backprop_nodes = 273 GetStopBackpropNodes(reachable_nodes, output_nodes); 274 275 // Populate `input_nodes_` from Outputs in `inputs_`. 276 input_nodes_.reserve(inputs_.size()); 277 for (size_t i = 0; i < inputs_.size(); ++i) { 278 input_nodes_.insert({inputs_[i], i}); 279 } 280 281 // TODO(andydavis) Consider a more efficient data structure for `pending_` to 282 // handle computing gradients over small subgraphs from a very large graph. 283 pending_.resize(scope_.graph()->num_node_ids(), 0); 284 { 285 backprops_.clear(); 286 std::unordered_set<Node*> visited; 287 std::deque<Node*> queue; 288 for (const Output& nout : inputs_) { 289 if (visited.find(nout.node()) == visited.end()) { 290 queue.push_back(nout.node()); 291 visited.insert(nout.node()); 292 } 293 } 294 295 // Going forward to figure out which endpoints need backprop-ed. 296 // A node's endpoints need to be backprop-ed only if one of the 297 // arg node can reach the node via data edges. 298 while (!queue.empty()) { 299 Node* n = queue.front(); 300 queue.pop_front(); 301 for (int i = 0; i < n->num_outputs(); ++i) { 302 backprops_[{n, i}].clear(); 303 } 304 int num_expected_backprops = 0; 305 if (stop_backprop_nodes.find(n->id()) == stop_backprop_nodes.end()) { 306 // Internal node: continue BFS along connected outputs. 307 for (const Edge* e : n->out_edges()) { 308 // If a node is not reachable from outputs_, 309 // we don't expect it to receive a backpropagated gradient. 310 // It will not be counted in num_expected_backprops. 311 if (e->IsControlEdge() || !reachable_nodes[e->dst()->id()]) continue; 312 if (visited.find(e->dst()) == visited.end()) { 313 queue.push_back(e->dst()); 314 visited.insert(e->dst()); 315 } 316 ++num_expected_backprops; 317 } 318 } 319 if (output_nodes.find(n->id()) != output_nodes.end()) { 320 // Output node: update `num_expected_backprops` for each Output in 321 // `outputs_` that references `n`. 322 for (const Output& output : outputs_) { 323 if (output.node() == n) { 324 ++num_expected_backprops; 325 } 326 } 327 } 328 pending_[n->id()] = num_expected_backprops; 329 } 330 } 331 332 { 333 // Initialize backprop with `grad_inputs_`. 334 const size_t num_dy = grad_inputs_.size(); 335 for (size_t i = 0; i < num_dy; ++i) { 336 TF_RETURN_IF_ERROR(BackpropAlongEdge(grad_inputs_[i], outputs_[i])); 337 } 338 } 339 return Status::OK(); 340 } 341 342 Status SymbolicGradientBuilder::SumGradients(const Output& src, Output* grad) { 343 auto iter = backprops_.find(src); 344 if (iter == backprops_.end()) { 345 return errors::Internal( 346 "Unable to find backprop list for node.id ", src.node()->name()); 347 } 348 const auto& grads = iter->second; 349 // Filter any backproped 'NoGradient' Outputs from 'grads' (if needed). 350 // Return any valid backproped gradients that remain after filtering, 351 // or 'NoGradient' otherwise. 352 std::vector<Output> grads_to_keep; 353 for (const Output& o : grads) { 354 if (o == NoGradient()) continue; 355 grads_to_keep.push_back(o); 356 } 357 358 if (grads_to_keep.empty()) { 359 // Nothing propagated back. Return 'NoGradient'. 360 *grad = NoGradient(); 361 } else if (grads_to_keep.size() == 1) { 362 // Just one backprop edge. 363 *grad = grads_to_keep[0]; 364 } else { 365 // Otherwise, adds backprop-ed gradients. 366 // TODO(andydavis) Use a better accumulator here. 367 *grad = ops::AddN(scope_, grads_to_keep); 368 } 369 370 return Status::OK(); 371 } 372 373 bool SymbolicGradientBuilder::IsPrimitiveOpWithNoGrad(const string& opname) { 374 ops::GradFunc grad_fn; 375 Status s = registry_->Lookup(opname, &grad_fn); 376 return s.ok() && (grad_fn == nullptr); 377 } 378 379 Status SymbolicGradientBuilder::CallGradFunction( 380 const Operation& op, 381 const std::vector<Output>& grad_inputs, 382 std::vector<Output>* grad_outputs) { 383 ops::GradFunc grad_fn; 384 TF_RETURN_IF_ERROR(registry_->Lookup(op.node()->type_string(), &grad_fn)); 385 TF_RETURN_IF_ERROR(grad_fn(scope_, op, grad_inputs, grad_outputs)); 386 TF_RETURN_IF_ERROR(scope_.status()); 387 return Status::OK(); 388 } 389 390 Status SymbolicGradientBuilder::ProcessWhileLoop(Node* exit_node, 391 const Output& summed_grads) { 392 // TODO(skyewm): detect second-order gradient and return bad status 393 // TODO(skyewm): handle (or at least detect) nested while loops 394 395 // TODO(skyewm): handle NoGradient in while loop 396 if (summed_grads == NoGradient()) { 397 return errors::Unimplemented( 398 "Missing gradient into while loop not yet implemented"); 399 } 400 401 DCHECK(exit_node->IsExit()); 402 WhileContext* while_ctx = exit_node->while_ctx(); 403 DCHECK(while_ctx != nullptr); 404 405 // Record 'summed_grads' as the backprop input associated with 'exit_node' 406 std::map<Node*, Output>& backprops = while_backprops_[while_ctx]; 407 DCHECK(backprops.find(exit_node) == backprops.end()); 408 backprops[exit_node] = summed_grads; 409 410 // Wait until we have all exit nodes' backprops collected before processing 411 // the while loop. 412 // TODO(skyewm): what if not all the exit nodes are reachable? 413 if (backprops.size() < while_ctx->exit_nodes().size()) return Status::OK(); 414 415 // We've seen all the exit nodes for this loop and have collected all the 416 // backprops. Create the gradient graph for the while loop. 417 Scope while_scope = 418 scope_.NewSubScope(strings::StrCat(while_ctx->frame_name(), "_grad")); 419 std::vector<Output> dy; 420 for (Node* n : while_ctx->exit_nodes()) dy.push_back(backprops[n]); 421 std::vector<Output> dx; 422 TF_RETURN_IF_ERROR(AddWhileLoopGradient(while_ctx, while_scope, dy, &dx)); 423 424 // Backprop along the in edges to the while loop (i.e. the inputs to the enter 425 // nodes) 426 DCHECK_EQ(dx.size(), while_ctx->enter_nodes().size()); 427 for (int i = 0; i < dx.size(); ++i) { 428 Node* enter_node = while_ctx->enter_nodes()[i]; 429 for (const Edge* e : enter_node->in_edges()) { 430 if (e->IsControlEdge()) continue; 431 TF_RETURN_IF_ERROR(BackpropAlongEdge(dx[i], {e->src(), e->src_output()})); 432 } 433 } 434 return Status::OK(); 435 } 436 437 Status SymbolicGradientBuilder::AddGradients() { 438 // Initialize backprops. 439 TF_RETURN_IF_ERROR(Initialize()); 440 441 // Backward propagation. 442 std::vector<Output> dy; 443 while (!ready_.empty()) { 444 // n has collected all gradients. 445 Node* n = ready_.front(); 446 ready_.pop_front(); 447 448 // dy[i] is the sum of i-th output's backpropped gradients. 449 const int num_y = n->num_outputs(); 450 dy.clear(); 451 dy.resize(num_y, {nullptr, 0}); 452 std::vector<int> no_grad_dy_indices; 453 for (int i = 0; i < num_y; ++i) { 454 TF_RETURN_IF_ERROR(SumGradients({n, i}, &dy[i])); 455 if (dy[i] == NoGradient()) { 456 no_grad_dy_indices.push_back(i); 457 } 458 auto iter = input_nodes_.find({n, i}); 459 if (iter != input_nodes_.end()) { 460 // Return gradients for Output in 'grad_outputs_'. 461 (*grad_outputs_)[iter->second] = dy[i]; 462 } 463 } 464 465 // Stop backprop if none of the inputs to `n` are in `backprops_'. 466 bool stop_node = true; 467 for (const Edge* e : n->in_edges()) { 468 if (e->IsControlEdge()) continue; 469 if (backprops_.find({e->src(), e->src_output()}) != backprops_.end()) { 470 stop_node = false; 471 break; 472 } 473 } 474 475 if (stop_node) { 476 continue; 477 } 478 479 // Special case: if we find an exit node, process the associated while loop. 480 // Note that ProcessWhileLoop() calls BackpropAlongEdge() if necessary 481 // (which updates ready_), and we skip all the regular processing below 482 // after calling it. 483 if (n->IsExit()) { 484 DCHECK_EQ(dy.size(), 1); 485 TF_RETURN_IF_ERROR(ProcessWhileLoop(n, dy[0])); 486 continue; 487 } 488 // All loop-specific control flow ops should have been handled above 489 DCHECK(!n->IsEnter() && !n->IsNextIteration()) << n->DebugString(); 490 491 const size_t num_no_grad = no_grad_dy_indices.size(); 492 if (IsPrimitiveOpWithNoGrad(n->type_string()) || num_no_grad == num_y) { 493 // No grad defined for this op, or all outputs returned 'NoGradient': 494 // Backprop 'NoGradient' along the in edges. 495 for (const Edge* e : n->in_edges()) { 496 if (e->IsControlEdge()) continue; 497 TF_RETURN_IF_ERROR( 498 BackpropAlongEdge(NoGradient(), {e->src(), e->src_output()})); 499 } 500 continue; 501 } 502 503 if (num_no_grad > 0 && num_no_grad < num_y) { 504 // The outputs of 'n' returned a mixture of valid gradients and 505 // 'NoGradient'. Therefore, we need to add 'ZerosLike' nodes for each 506 // 'NoGradient' output before we call the gradient function for 'n'. 507 // TODO(andydavis) If static shapes are known, replace 'ZerosLike' with 508 // zero-filled Constant node of appropriate shape. 509 for (const int dy_index : no_grad_dy_indices) { 510 dy[dy_index] = ops::ZerosLike(scope_, Output(n, dy_index)); 511 } 512 } 513 514 // TODO(andydavis) Add option to encapsulate grad function in 515 // SymbolicGradientOp (as opposed to inlining into the graph). 516 std::vector<Output> dx; 517 TF_RETURN_IF_ERROR(CallGradFunction(Operation(n), dy, &dx)); 518 519 // Backprop along the in edges. 520 // TODO(andydavis) Find cleaner way to map each grad output returned by 521 // gradient function to the src node/output to which it should be 522 // backproped. Maybe grad functions can return a vector of Output pairs to 523 // make this association explicit. 524 size_t dx_index = 0; 525 for (const Edge* e : n->in_edges()) { 526 if (e->IsControlEdge()) continue; 527 if (dx_index == dx.size()) { 528 return errors::Internal( 529 "Invalid gradient output index: ", dx_index, " size: ", dx.size()); 530 } 531 TF_RETURN_IF_ERROR( 532 BackpropAlongEdge(dx[dx_index++], {e->src(), e->src_output()})); 533 } 534 } 535 536 // Check if any input nodes still have pending gradients and have not been 537 // processed yet. This happens if not all outputs of a node are in 'inputs_'. 538 std::unordered_map<Node*, int> requested_grads; 539 for (const Output& nout : inputs_) { 540 if (pending_[nout.node()->id()] > 0) { 541 DCHECK_GT(nout.node()->num_outputs(), 1); 542 int idx = input_nodes_[nout]; 543 DCHECK(((*grad_outputs_)[idx].node() == nullptr)); 544 TF_RETURN_IF_ERROR(SumGradients(nout, &(*grad_outputs_)[idx])); 545 ++requested_grads[nout.node()]; 546 } 547 } 548 for (const auto& p : requested_grads) { 549 int num_requested_inputs = p.first->num_outputs() - pending_[p.first->id()]; 550 CHECK_EQ(num_requested_inputs, p.second); 551 } 552 return Status::OK(); 553 } 554 555 } // namespace 556 557 Status AddSymbolicGradients(const Scope& scope, 558 const std::vector<Output>& outputs, 559 const std::vector<Output>& inputs, 560 const std::vector<Output>& grad_inputs, 561 std::vector<Output>* grad_outputs) { 562 SymbolicGradientBuilder builder(scope, ops::GradOpRegistry::Global(), outputs, 563 inputs, grad_inputs, grad_outputs); 564 return builder.AddGradients(); 565 } 566 567 Status AddSymbolicGradients(const Scope& scope, 568 const std::vector<Output>& outputs, 569 const std::vector<Output>& inputs, 570 std::vector<Output>* grad_outputs) { 571 std::vector<Output> grad_inputs; 572 grad_inputs.reserve(outputs.size()); 573 for (const Output& output : outputs) { 574 grad_inputs.emplace_back(ops::OnesLike(scope, output)); 575 } 576 return AddSymbolicGradients(scope, outputs, inputs, grad_inputs, 577 grad_outputs); 578 } 579 580 Output NoGradient() { return SymbolicGradientBuilder::NoGradient(); } 581 582 } // end namespace tensorflow 583