1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/grappler/optimizers/dependency_optimizer.h" 17 18 #include <unordered_map> 19 #include <unordered_set> 20 21 #include "tensorflow/core/framework/node_def.pb.h" 22 #include "tensorflow/core/framework/op.h" 23 #include "tensorflow/core/grappler/costs/graph_properties.h" 24 #include "tensorflow/core/grappler/grappler_item.h" 25 #include "tensorflow/core/grappler/op_types.h" 26 #include "tensorflow/core/grappler/optimizers/constant_folding.h" 27 #include "tensorflow/core/grappler/utils/topological_sort.h" 28 #include "tensorflow/core/lib/core/errors.h" 29 #include "tensorflow/core/lib/core/stringpiece.h" 30 #include "tensorflow/core/lib/gtl/inlined_vector.h" 31 #include "tensorflow/core/lib/strings/strcat.h" 32 #include "tensorflow/core/util/device_name_utils.h" 33 34 namespace tensorflow { 35 namespace grappler { 36 37 namespace { 38 39 bool RemoveInput(NodeDef* node, const string& input, NodeMap* node_map) { 40 bool removed_input = false; 41 int pos = 0; 42 while (pos < node->input_size()) { 43 if (node->input(pos) == input) { 44 node->mutable_input()->SwapElements(pos, node->input_size() - 1); 45 node->mutable_input()->RemoveLast(); 46 node_map->RemoveOutput(NodeName(input), node->name()); 47 removed_input = true; 48 } else { 49 ++pos; 50 } 51 } 52 return removed_input; 53 } 54 55 void DeleteNodes(const std::set<int>& nodes_to_delete, GraphDef* graph) { 56 int last = graph->node_size() - 1; 57 for (auto it = nodes_to_delete.rbegin(); it != nodes_to_delete.rend(); ++it) { 58 const int index = *it; 59 graph->mutable_node()->SwapElements(index, last); 60 last--; 61 } 62 graph->mutable_node()->DeleteSubrange(last + 1, nodes_to_delete.size()); 63 } 64 65 } // namespace 66 67 bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) { 68 if (!IsIdentity(node)) { 69 return true; 70 } 71 72 if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) { 73 return false; 74 } 75 if (!fetch_nodes_known_) { 76 // The output values of this node may be needed. 77 return false; 78 } 79 const NodeDef* input = node_map_->GetNode(NodeName(node.input(0))); 80 CHECK(input != nullptr) << "node = " << node.name() 81 << " input = " << node.input(0); 82 // Don't remove Identity nodes corresponding to Variable reads or following 83 // Recv. 84 if (IsVariable(*input) || IsRecv(*input)) { 85 return false; 86 } else if (IsSwitch(*input)) { 87 // Don't turn Identity nodes following Switch into NoOp or remove them 88 // if it requires anchoring a control dependencies the Switch node, which 89 // is not valid. 90 if (StringPiece(node.name()).starts_with(kConstantFoldingCtrl)) { 91 // TODO(rmlarsen): Try to remove this artificial contraint. 92 return false; 93 } 94 } 95 for (auto consumer : node_map_->GetOutputs(node.name())) { 96 if (node.input_size() > 1 && IsMerge(*consumer)) { 97 return false; 98 } 99 if (IsSwitch(*input)) { 100 for (const string& consumer_input : consumer->input()) { 101 if (consumer_input == AsControlDependency(node.name())) { 102 return false; 103 } 104 } 105 } 106 } 107 return true; 108 } 109 110 bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) { 111 if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) { 112 return false; 113 } 114 if (!fetch_nodes_known_ || NumNonControlOutputs(node, *node_map_) > 0) { 115 // The output values of this node may be needed. 116 return false; 117 } 118 if (IsMerge(node) || IsSwitch(node)) { 119 return false; 120 } 121 if (ModifiesFrameInfo(node)) { 122 return false; 123 } 124 if (!IsFreeOfSideEffect(node)) { 125 return false; 126 } 127 if (node.op() == "ControlTrigger") { 128 return false; 129 } 130 if (node.op().rfind("Submodel", 0) == 0) { 131 return false; 132 } 133 const OpDef* op_def = nullptr; 134 Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); 135 if (!status.ok() || op_def->output_arg_size() == 0) { 136 return false; 137 } 138 139 if (!SafeToRemoveIdentity(node)) { 140 return false; 141 } 142 143 const std::unordered_set<string> do_not_rewrite_ops{ 144 "Assert", "CheckNumerics", "_Retval", 145 "_Arg", "_ParallelConcatUpdate", "_TPUExecute", 146 "_TPUCompile"}; 147 return do_not_rewrite_ops.find(node.op()) == do_not_rewrite_ops.end(); 148 } 149 150 void DependencyOptimizer::OptimizeNode(int node_idx, 151 SetVector<int>* nodes_to_simplify, 152 std::set<int>* nodes_to_delete) { 153 NodeDef* node = optimized_graph_->mutable_node(node_idx); 154 const bool is_noop = IsNoOp(*node); 155 const bool is_identity = IsIdentity(*node); 156 const string node_name = node->name(); 157 // Constant nodes with no input control dependency are always executed early, 158 // so we can prune all their output control dependencies. 159 if (IsConstant(*node) && node->input_size() == 0) { 160 const std::set<NodeDef*> output_nodes = node_map_->GetOutputs(node_name); 161 for (NodeDef* fanout : output_nodes) { 162 bool optimize_fanout = false; 163 bool data_connection = false; 164 for (int i = fanout->input_size() - 1; i >= 0; --i) { 165 int pos; 166 string input_name = ParseNodeName(fanout->input(i), &pos); 167 if (input_name == node_name) { 168 if (pos < 0) { 169 fanout->mutable_input()->SwapElements(i, fanout->input_size() - 1); 170 fanout->mutable_input()->RemoveLast(); 171 optimize_fanout = true; 172 } else { 173 data_connection = true; 174 } 175 } 176 } 177 if (optimize_fanout) { 178 nodes_to_simplify->PushBack(node_to_idx_[fanout]); 179 if (!data_connection) { 180 node_map_->RemoveOutput(node_name, fanout->name()); 181 } 182 } 183 } 184 if (node_map_->GetOutputs(node_name).empty() && fetch_nodes_known_ && 185 nodes_to_preserve_.find(node_name) == nodes_to_preserve_.end()) { 186 // Mark the node for deletion. 187 nodes_to_delete->insert(node_to_idx_[node]); 188 } 189 return; 190 } 191 192 // Change ops that only have control dependencies as outputs to NoOps. 193 if (!is_noop && SafeToConvertToNoOp(*node)) { 194 VLOG(1) << "***** Replacing " << node_name << " (" << node->op() 195 << ") with NoOp."; 196 // The outputs of this node are not consumed. Replace its inputs with 197 // control dependencies and replace the op itself with the NoOp op. 198 std::unordered_set<string> ctrl_inputs; 199 int pos = 0; 200 while (pos < node->input_size()) { 201 const string old_input = node->input(pos); 202 if (IsControlInput(old_input)) { 203 if (!ctrl_inputs.insert(old_input).second) { 204 // We found a duplicate control input. Remove it. 205 node->mutable_input()->SwapElements(pos, node->input_size() - 1); 206 node->mutable_input()->RemoveLast(); 207 } else { 208 ++pos; 209 } 210 continue; 211 } 212 const string ctrl_input = ConstantFolding::AddControlDependency( 213 old_input, optimized_graph_, node_map_.get()); 214 if (ctrl_inputs.insert(ctrl_input).second) { 215 node->set_input(pos, ctrl_input); 216 node_map_->UpdateInput(node_name, old_input, ctrl_input); 217 const NodeDef* old_input_node = node_map_->GetNode(old_input); 218 nodes_to_simplify->PushBack(node_to_idx_[old_input_node]); 219 } 220 ++pos; 221 } 222 node->set_op("NoOp"); 223 node->clear_attr(); 224 nodes_to_simplify->PushBack(node_to_idx_[node]); 225 return; 226 } 227 228 // Remove NoOp nodes if the product of their fan-in and fan-out is less than 229 // or equal to the sum of the fan-in and fan-out. The non-trivial rewrites 230 // take the following form: 231 // 232 // Case a) 233 // x --^> +------+ x --^> +---+ 234 // y --^> | NoOp | --^> a ==> y --^> | a | 235 // ... | | ... | | 236 // z --^> +------+ z --^> +---+ 237 // 238 // Case b) 239 // +------+ --^> a +---+ --^> a 240 // x --^> | NoOp | --^> b ==> | x | --^> b 241 // | | ... | | ... 242 // +------+ --^> c +---+ --^> c 243 // Case c) 244 // +------+ x ---^> a 245 // x --^> | NoOp | --^> a ==> \/ 246 // y --^> | | --^> b /\ 247 // +------+ y ---^> b 248 // 249 // We only apply this optimization if we don't increase the number of control 250 // edges across device boundaries, e.g. in cases a) and b) if NoOp and 251 // a and x, respectively, are on the same device. Control edges across device 252 // boundaries require inter-device communication (Send/Recv pairs to be 253 // inserted in the graph), which is very costly. 254 // 255 // We also remove identity nodes, subject to the same constraints on number of 256 // resulting control edges and device boundary crossings: 257 // 258 // Case a) 259 // +----------+ ---> a +---+ ---> a 260 // x --> | Identity | --^> b ==> | x | --^> b 261 // | | ... | | ... 262 // +----------+ --^> c +---+ --^> c 263 // 264 // Case b) 265 // x ---> +----------+ ---> a x ---> +---+ 266 // y --^> | Identity | ==> y --^> | a | 267 // ... | | ... | | 268 // z --^> +----------+ z --^> +---+ 269 // 270 // Case c) 271 // +----------+ x ---> +---+ 272 // x ---> | Identity | ---> a ==> \--^> | a | 273 // y --^> | | --^> b /\ +---+ 274 // +----------+ y --^> b 275 276 if (is_noop || is_identity) { 277 const auto& output_node_set = node_map_->GetOutputs(node_name); 278 const std::vector<NodeDef*> output_nodes(output_node_set.begin(), 279 output_node_set.end()); 280 const int num_outputs = output_nodes.size(); 281 const int num_inputs = node->input_size(); 282 283 if (num_inputs * num_outputs > num_inputs + num_outputs) { 284 return; 285 } 286 std::vector<NodeDef*> input_nodes; 287 for (int i = 0; i < num_inputs; ++i) { 288 NodeDef* input_node = node_map_->GetNode(node->input(i)); 289 CHECK_NE(input_node, nullptr); 290 input_nodes.push_back(input_node); 291 } 292 293 // Make sure that we don't increase the number of edges that cross 294 // device boundaries. 295 if ((num_inputs == 1 && num_outputs > 1 && 296 input_nodes[0]->device() != node->device()) || 297 (num_inputs > 1 && num_outputs == 1 && 298 output_nodes[0]->device() != node->device())) { 299 return; 300 } 301 if (num_inputs == 2 && num_outputs == 2) { 302 const string& noop_dev = node->device(); 303 const string& in0_dev = input_nodes[0]->device(); 304 const string& in1_dev = input_nodes[1]->device(); 305 const string& out0_dev = output_nodes[0]->device(); 306 const string& out1_dev = output_nodes[1]->device(); 307 const int num_cross_before = static_cast<int>(in0_dev != noop_dev) + 308 static_cast<int>(in1_dev != noop_dev) + 309 static_cast<int>(out0_dev != noop_dev) + 310 static_cast<int>(out1_dev != noop_dev); 311 const int num_cross_after = static_cast<int>(in0_dev != out0_dev) + 312 static_cast<int>(in0_dev != out1_dev) + 313 static_cast<int>(in1_dev != out0_dev) + 314 static_cast<int>(in1_dev != out1_dev); 315 if (num_cross_after > num_cross_before) { 316 return; 317 } 318 // To avoid potentially removing Identity nodes following _Recv nodes, 319 // we require that no device crossings occur in that case. 320 // TODO(rmlarsen): See if we can relax this condition. 321 if (is_identity && (num_cross_after > 0 || num_cross_before > 0)) { 322 return; 323 } 324 } 325 if (is_identity && !SafeToRemoveIdentity(*node)) { 326 return; 327 } 328 329 VLOG(1) << "***** Rerouting input around\n" << node->DebugString(); 330 // Now remove the node and re-wire its inputs to its outputs. 331 for (auto consumer : output_nodes) { 332 bool updated_consumer = false; 333 VLOG(1) << "consumer before:\n" << consumer->DebugString(); 334 for (int i = 0; i < num_inputs; ++i) { 335 const NodeDef* input = input_nodes[i]; 336 // Forward dependency from input to consumer if it doesn't already 337 // depend on it. 338 if (is_identity && i == 0) { 339 // Replace regular input from Identity node. 340 bool found_input = false; 341 string new_input; 342 const string& input_to_forward = node->input(0); 343 CHECK(!IsControlInput(input_to_forward)); 344 for (int j = 0; j < consumer->input_size(); ++j) { 345 const string& old_input = consumer->input(j); 346 if (old_input == node_name) { 347 new_input = input_to_forward; 348 node_map_->UpdateInput(consumer->name(), old_input, new_input); 349 consumer->set_input(j, new_input); 350 found_input = true; 351 } else if (old_input == AsControlDependency(NodeName(node_name))) { 352 new_input = AsControlDependency(NodeName(input_to_forward)); 353 node_map_->UpdateInput(consumer->name(), old_input, new_input); 354 consumer->set_input(j, new_input); 355 found_input = true; 356 } 357 } 358 CHECK(found_input); 359 updated_consumer = true; 360 } else { 361 // Forward dependency from input to consumer if it doesn't already 362 // depend on it. 363 if (node_map_->GetOutputs(input->name()).count(consumer) == 0) { 364 consumer->add_input(AsControlDependency(input->name())); 365 node_map_->AddOutput(input->name(), consumer->name()); 366 nodes_to_simplify->PushBack(node_to_idx_[input]); 367 updated_consumer = true; 368 } 369 } 370 } 371 // Remove dependency on node from consumer. 372 updated_consumer |= RemoveInput(consumer, AsControlDependency(node_name), 373 node_map_.get()); 374 if (updated_consumer) { 375 nodes_to_simplify->PushBack(node_to_idx_[consumer]); 376 } 377 VLOG(1) << "consumer after:\n" << consumer->DebugString(); 378 } 379 node_map_->RemoveOutputs(node_name); 380 if (fetch_nodes_known_ && 381 nodes_to_preserve_.find(node_name) == nodes_to_preserve_.end()) { 382 // Mark the node for deletion. 383 nodes_to_delete->insert(node_idx); 384 385 // Disconnect the node from its inputs to enable further optimizations. 386 node_map_->RemoveInputs(node_name); 387 node->clear_input(); 388 } 389 } 390 } 391 392 void DependencyOptimizer::CleanControlInputs() { 393 for (int i = 0; i < optimized_graph_->node_size(); ++i) { 394 DedupControlInputs(optimized_graph_->mutable_node(i)); 395 } 396 } 397 398 Status DependencyOptimizer::OptimizeDependencies() { 399 SetVector<int> nodes_to_simplify; 400 std::set<int> nodes_to_delete; 401 for (int i = 0; i < optimized_graph_->node_size(); ++i) { 402 const NodeDef& node = optimized_graph_->node(i); 403 if (IsNoOp(node) || IsIdentity(node) || IsConstant(node) || 404 SafeToConvertToNoOp(node)) { 405 nodes_to_simplify.PushBack(i); 406 } 407 } 408 while (!nodes_to_simplify.Empty()) { 409 int node_to_simplify = nodes_to_simplify.PopBack(); 410 // Discard nodes that were marked for deletion already. 411 while (nodes_to_delete.find(node_to_simplify) != nodes_to_delete.end()) { 412 node_to_simplify = nodes_to_simplify.PopBack(); 413 } 414 OptimizeNode(node_to_simplify, &nodes_to_simplify, &nodes_to_delete); 415 } 416 417 if (fetch_nodes_known_) { 418 VLOG(1) << "Deleted " << nodes_to_delete.size() << " out of " 419 << optimized_graph_->node_size() << " nodes."; 420 DeleteNodes(nodes_to_delete, optimized_graph_); 421 node_map_.reset(new NodeMap(optimized_graph_)); 422 BuildNodeToIdx(); 423 } 424 return Status::OK(); 425 } 426 427 Status DependencyOptimizer::TransitiveReduction() { 428 // PRECONDITION: optimized_graph_ must be sorted topologically. 429 const int num_nodes = optimized_graph_->node_size(); 430 // Set up a compressed version of the graph to save a constant factor in the 431 // expensive algorithm below. Also cache the set of control outputs and the 432 // highest index of a target of any control output from each node. 433 int num_controls = 0; 434 std::vector<gtl::InlinedVector<int, 4>> inputs(num_nodes); 435 std::vector<gtl::InlinedVector<std::pair<int, int>, 2>> control_outputs( 436 num_nodes); 437 for (int node_idx = 0; node_idx < num_nodes; ++node_idx) { 438 const NodeDef& node = optimized_graph_->node(node_idx); 439 if (ModifiesFrameInfo(node) || !HasOpDef(node)) { 440 // Ignore function nodes and nodes that modify frame info. 441 continue; 442 } 443 for (int input_slot = 0; input_slot < node.input_size(); ++input_slot) { 444 const string& input = node.input(input_slot); 445 const NodeDef* input_node = node_map_->GetNode(input); 446 if (ModifiesFrameInfo(*input_node) || IsMerge(*input_node)) { 447 // Ignore edges from nodes that modify frame info and from Merge nodes, 448 // because we cannot know which of it's input paths executes. 449 continue; 450 } 451 const int input_node_idx = node_to_idx_[input_node]; 452 inputs[node_idx].push_back(input_node_idx); 453 if (IsControlInput(input)) { 454 ++num_controls; 455 control_outputs[input_node_idx].emplace_back(node_idx, input_slot); 456 } 457 } 458 } 459 460 // Run the longest path in DAG algorithm for each source node that has control 461 // outputs. If, for any target node of a control output, there exists a path 462 // of length > 1, we can drop that control dependency. 463 int num_controls_removed = 0; 464 std::vector<int> longest_distance(num_nodes); 465 // Map from target_index -> set of (input_slot, source_index), representing 466 // the control edges to remove. We sort them in reverse order by input slot, 467 // such that when we swap them out so we don't clobber the 468 // node(target).input() repeated field. 469 typedef std::pair<int, int> InputSlotAndSource; 470 std::unordered_map< 471 int, std::set<InputSlotAndSource, std::greater<InputSlotAndSource>>> 472 control_edges_to_remove; 473 for (int source = 0; source < num_nodes; ++source) { 474 int highest_control_target = -1; 475 for (const auto& control_output : control_outputs[source]) { 476 if (control_output.first > highest_control_target) { 477 highest_control_target = control_output.first; 478 } 479 } 480 if (highest_control_target <= source) { 481 continue; 482 } 483 std::fill(longest_distance.begin() + source, 484 longest_distance.begin() + highest_control_target + 1, 0); 485 for (int target = source + 1; target <= highest_control_target; ++target) { 486 for (int input : inputs[target]) { 487 // If the input node is before source in the topo order, no path 488 // source -> input -> target can exits and we can skip it. 489 // Also only extend a path from the source itself or from nodes that 490 // have a path from source, indicated by longest_distance[input] > 0. 491 if (input == source || 492 (input > source && longest_distance[input] > 0)) { 493 // If source -> input -> target is longer than the longest 494 // path so far from source -> target, update the longest_distance. 495 int candidate_longest_distance = longest_distance[input] + 1; 496 if (candidate_longest_distance > longest_distance[target]) { 497 longest_distance[target] = candidate_longest_distance; 498 } 499 } 500 } 501 } 502 503 // If the longest path from source to target of a control dependency is 504 // longer than 1, there exists an alternate path, and we can eliminate the 505 // redundant direct control dependency. 506 for (const auto& control_output : control_outputs[source]) { 507 const int target = control_output.first; 508 if (longest_distance[target] > 1) { 509 const int input_slot = control_output.second; 510 control_edges_to_remove[target].emplace(input_slot, source); 511 // VLOG(1) << "Removing edge from:\n" 512 // << optimized_graph_->node(source).DebugString() << 513 // "\n\nto:\n\n" 514 // << optimized_graph_->node(target).DebugString(); 515 } 516 } 517 } 518 519 for (const auto& it : control_edges_to_remove) { 520 const int target = it.first; 521 NodeDef* target_node = optimized_graph_->mutable_node(target); 522 for (const InputSlotAndSource& slot_and_source : it.second) { 523 const int input_slot = slot_and_source.first; 524 const int source = slot_and_source.second; 525 const NodeDef& source_node = optimized_graph_->node(source); 526 CHECK_LT(input_slot, target_node->input_size()); 527 target_node->mutable_input()->SwapElements(input_slot, 528 target_node->input_size() - 1); 529 node_map_->RemoveOutput(source_node.name(), target_node->name()); 530 target_node->mutable_input()->RemoveLast(); 531 ++num_controls_removed; 532 } 533 } 534 VLOG(1) << "Removed " << num_controls_removed << " out of " << num_controls 535 << " control dependencies"; 536 return Status::OK(); 537 } 538 539 void DependencyOptimizer::BuildNodeToIdx() { 540 // Set up &node -> index map. 541 node_to_idx_.clear(); 542 for (int i = 0; i < optimized_graph_->node_size(); ++i) { 543 const NodeDef& node = optimized_graph_->node(i); 544 node_to_idx_[&node] = i; 545 } 546 } 547 548 Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, 549 GraphDef* optimized_graph) { 550 optimized_graph_ = optimized_graph; 551 *optimized_graph_ = item.graph; 552 nodes_to_preserve_ = item.NodesToPreserve(); 553 fetch_nodes_known_ = !item.fetch.empty(); 554 CleanControlInputs(); 555 556 const int num_iterations = 2; 557 for (int iteration = 0; iteration < num_iterations; ++iteration) { 558 Status topo_sort_status; 559 // Perform topological sort to prepare the graph for transitive reduction. 560 topo_sort_status = TopologicalSort(optimized_graph_); 561 // Set up index-based graph datastructures to speed up analysis steps below. 562 node_map_.reset(new NodeMap(optimized_graph_)); 563 BuildNodeToIdx(); 564 565 if (topo_sort_status.ok()) { 566 // Remove redundant control dependencies. 567 TF_RETURN_IF_ERROR(TransitiveReduction()); 568 } else { 569 LOG(ERROR) << topo_sort_status.error_message(); 570 } 571 // Turn nodes with only control outputs into NoOps, prune NoOp and Identity 572 // nodes. 573 TF_RETURN_IF_ERROR(OptimizeDependencies()); 574 575 // Dedup control inputs. 576 CleanControlInputs(); 577 } 578 579 return Status::OK(); 580 } 581 582 void DependencyOptimizer::Feedback(Cluster* /*cluster*/, 583 const GrapplerItem& /*item*/, 584 const GraphDef& /*optimized_graph*/, 585 double /*result*/) { 586 // Nothing to do for DependencyOptimizer. 587 } 588 589 } // end namespace grappler 590 } // end namespace tensorflow 591