/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"

#include <algorithm>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/costs/graph_memory.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/graph_rewriter.h"
#include "tensorflow/core/grappler/optimizers/static_schedule.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/grappler/utils/traversal.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"

namespace tensorflow {
namespace grappler {

// Prefix added to nodes which are recomputed.
const char* kRecomputedNodePrefix = "Recomputed";
const char* kRecomputeTriggerNodePrefix = "RecomputeTrigger";
// Attribute which may be added to nodes to manually allow them to be
// recomputed.
const char* kRecomputeHint = "_recompute_hint";

// Ops which we wouldn't mind recomputing to save memory.
// TODO(allenl): Replace this list with a cost model.
std::unordered_set<string> GetCheapToRecomputeOps() {
  std::unordered_set<string> cheap_ops = {
      "Add",      "AddN",       "BiasAdd",        "Cast",   "Fill",
      "FloorDiv", "FloorMod",   "FusedBatchNorm", "Mul",    "Neg",
      "RealDiv",  "Reciprocal", "Relu",           "Relu6",  "Reshape",
      "Rsqrt",    "Sigmoid",    "Sqrt",           "Square", "SquaredDifference",
      "Sub",      "Tile",       "Transpose"};
  return cheap_ops;
}

// Find recomputable ops which feed into target nodes.
std::unordered_set<const NodeDef*> FindCandidateRecomputeNodes(
    const NodeMap& node_map, const GraphDef* graph,
    const std::function<bool(const NodeDef&)>& is_candidate,
    const std::function<bool(const NodeDef&)>& is_target) {
  std::unordered_set<const NodeDef*> candidate_recompute_nodes;
  for (const auto& node : graph->node()) {
    if (!is_candidate(node)) {
      continue;
    }
    bool has_target_output = false;
    for (const NodeDef* output : node_map.GetOutputs(node.name())) {
      // It only makes sense to recompute this if it feeds into a target
      // node. We expand this to dependencies in GetOpGroupsToRecompute.
      if (is_target(*output)) {
        has_target_output = true;
        break;
      }
    }
    if (!has_target_output) {
      continue;
    }
    bool has_target_input = false;
    for (const string& input_name : node.input()) {
      // Don't recompute nodes which depend on target nodes.
      const NodeDef* input_node = node_map.GetNode(input_name);
      if (is_target(*input_node)) {
        has_target_input = true;
        break;
      }
    }
    if (has_target_input) {
      continue;
    }
    candidate_recompute_nodes.insert(&node);
  }
  return candidate_recompute_nodes;
}

void connected_subgraph(const NodeMap& node_map, bool collect_inputs,
                        bool collect_outputs,
                        const std::function<bool(const NodeDef&)>& is_candidate,
                        std::unordered_set<const NodeDef*>* expanded_nodes) {
  std::queue<const NodeDef*> to_visit;
  for (const NodeDef* starting_node : *expanded_nodes) {
    to_visit.push(starting_node);
  }
  expanded_nodes->clear();
  while (!to_visit.empty()) {
    const NodeDef* current_node = to_visit.front();
    to_visit.pop();
    if (!expanded_nodes->insert(current_node).second) {
      // We already visited this node.
      continue;
    }
    if (collect_inputs) {
      // Add inputs and outputs to this subgraph if they are candidates.
      for (const string& input_name_raw : current_node->input()) {
        const NodeDef* input_node = node_map.GetNode(input_name_raw);
        if (expanded_nodes->count(input_node) == 0 &&
            is_candidate(*input_node)) {
          to_visit.push(input_node);
        }
      }
    }
    if (collect_outputs) {
      for (const NodeDef* output : node_map.GetOutputs(current_node->name())) {
        if (expanded_nodes->count(output) == 0 && is_candidate(*output)) {
          to_visit.push(output);
        }
      }
    }
  }
}

struct RecomputedSubGraph {
  std::unordered_set<const NodeDef*> recomputed_source_nodes;
  std::unordered_set<NodeDef*> target_nodes;
};

// Find groups of ops to recompute together based on `should_recompute`.
std::vector<RecomputedSubGraph> GetOpGroupsToRecompute(
    const GraphDef* graph, const NodeMap& node_map,
    const std::function<bool(const NodeDef&)>& should_recompute,
    const std::function<bool(const NodeDef&)>& is_target) {
  std::unordered_set<const NodeDef*> visited_nodes;
  std::vector<RecomputedSubGraph> subgraphs_to_recompute;
  std::unordered_set<const NodeDef*> candidate_recompute_nodes =
      FindCandidateRecomputeNodes(node_map, graph, should_recompute, is_target);
  for (const NodeDef* recompute_node : candidate_recompute_nodes) {
    if (visited_nodes.count(recompute_node) > 0) {
      continue;
    }
    RecomputedSubGraph current_recomputation;
    // Build out recomputation groups by expanding to inexpensive-to-recompute
    // nodes which do not feed target nodes. The goal is to capture some
    // intermediate activations within this graph.
    std::unordered_set<const NodeDef*> unpruned_recompute_nodes;
    unpruned_recompute_nodes.insert(recompute_node);
    connected_subgraph(node_map,
                       true,  // Collect inputs
                       true,  // Collect outputs
                       should_recompute, &unpruned_recompute_nodes);
    visited_nodes.insert(unpruned_recompute_nodes.begin(),
                         unpruned_recompute_nodes.end());
    for (const NodeDef* recompute_node : unpruned_recompute_nodes) {
      bool inserted_feed = false;
      for (NodeDef* output : node_map.GetOutputs(recompute_node->name())) {
        if (is_target(*output)) {
          current_recomputation.target_nodes.insert(output);
          if (!inserted_feed) {
            // Keep track of nodes which feed directly into a target node.
            // These and nodes which feed into them will define the recomputed
            // subgraph.
            current_recomputation.recomputed_source_nodes.insert(
                recompute_node);
            inserted_feed = true;
          }
        }
      }
    }
    // Recompute only nodes which eventually feed into a target node.
    connected_subgraph(node_map,
                       true,   // Collect inputs
                       false,  // Collect outputs
                       [&unpruned_recompute_nodes](const NodeDef& node) {
                         return unpruned_recompute_nodes.count(&node) != 0;
                       },
                       &current_recomputation.recomputed_source_nodes);
    if (current_recomputation.target_nodes.empty()) {
      continue;
    }
    subgraphs_to_recompute.push_back(current_recomputation);
  }
  return subgraphs_to_recompute;
}

// Computes the maximum topological numbers of (1) target node components
// (gradient nodes being fed by the recomputation), and (2) child recompute
// node components for each recomputed node. We will not attach any control
// dependencies to a recomputation unless they have component numbers greater
// than this value (to prevent cycles).
std::unordered_map<const NodeDef*, int> GetMaxDownstreamComponents(
    const std::unordered_set<const NodeDef*>& recomputed_source_nodes,
    const std::unordered_set<NodeDef*>& target_nodes, const NodeMap& node_map,
    const std::unordered_map<const NodeDef*, int>& components) {
  std::unordered_map<const NodeDef*, int> recomputed_node_components;
  // Start by setting component numbers to the maximum among target nodes.
  for (const NodeDef* original_recompute_node : recomputed_source_nodes) {
    int max_target_component = -1;
    for (NodeDef* output :
         node_map.GetOutputs(original_recompute_node->name())) {
      if (target_nodes.count(output) != 0) {
        int current_target_component = components.find(output)->second;
        if (current_target_component > max_target_component) {
          max_target_component = current_target_component;
        }
      }
    }
    if (max_target_component > -1) {
      recomputed_node_components[original_recompute_node] =
          max_target_component;
    }
  }
  // Sort recomputed nodes topologically (based on the original graph) so we
  // can efficiently assign to each node the maximum of its recomputed child
  // components and its own targets.
  std::vector<const NodeDef*> recomputed_source_nodes_topological(
      recomputed_source_nodes.begin(), recomputed_source_nodes.end());
  std::sort(recomputed_source_nodes_topological.begin(),
            recomputed_source_nodes_topological.end(),
            [&components](const NodeDef* first, const NodeDef* second) {
              return components.find(first)->second <
                     components.find(second)->second;
            });
  for (const NodeDef* original_recompute_node :
       recomputed_source_nodes_topological) {
    int max_component;
    auto recomputed_component_iterator =
        recomputed_node_components.find(original_recompute_node);
    if (recomputed_component_iterator != recomputed_node_components.end()) {
      max_component = recomputed_component_iterator->second;
    } else {
      max_component = -1;
    }
    for (NodeDef* output :
         node_map.GetOutputs(original_recompute_node->name())) {
      if (recomputed_source_nodes.count(output) == 0) {
        continue;
      }
      auto child_component_iterator = recomputed_node_components.find(output);
      CHECK(child_component_iterator != recomputed_node_components.end());
      int child_component = child_component_iterator->second;
      if (child_component > max_component) {
        max_component = child_component;
      }
    }
    CHECK_GE(max_component, 0);
    recomputed_node_components[original_recompute_node] = max_component;
  }
  return recomputed_node_components;
}

// Modifies `graph`, adding trigger nodes and returning a mapping from
// `recomputed_source_nodes` to trigger nodes which will not create loops in
// the graph (using the component numberings in `components` and
// `recomputed_node_max_feed_components`). The copied nodes (not the nodes in
// recomputed_source_nodes, which are the originals) eventually get these
// control dependencies.
std::unordered_map<const NodeDef*, const NodeDef*>
AddRecomputeControlDependencyNodes(
    const std::unordered_set<const NodeDef*>& recomputed_source_nodes,
    const std::unordered_set<NodeDef*>& target_nodes, const NodeMap& node_map,
    const std::unordered_map<const NodeDef*, int>& components,
    const std::unordered_map<const NodeDef*, int>&
        recomputed_node_max_feed_components,
    GraphDef* graph) {
  // Sort recomputed nodes based on max downstream components.
  std::vector<const NodeDef*> recomputed_source_nodes_topological(
      recomputed_source_nodes.begin(), recomputed_source_nodes.end());
  std::sort(recomputed_source_nodes_topological.begin(),
            recomputed_source_nodes_topological.end(),
            [&recomputed_node_max_feed_components](const NodeDef* first,
                                                   const NodeDef* second) {
              int first_component =
                  recomputed_node_max_feed_components.find(first)->second;
              int second_component =
                  recomputed_node_max_feed_components.find(second)->second;
              return first_component > second_component
                     // Ensure a consistent ordering. This is necessary because
                     // we're working not with node component numbers (which
                     // are unique) but with the maximum across nodes they feed
                     // into (very much not unique).
                     || (first_component == second_component &&
                         first->name() > second->name());
            });
  // Create merged control dependency nodes by sorting target inputs
  // topologically and zipper merging with the sorted recomputed nodes.
  std::vector<const NodeDef*> target_inputs_topological;
  for (const NodeDef* target_node : target_nodes) {
    for (const string& target_input_name_raw : target_node->input()) {
      const NodeDef* target_input = node_map.GetNode(target_input_name_raw);
      // If this node has already had one of its inputs recomputed during this
      // rewriting pass, we ignore that recomputed node here (it will not be
      // in the NodeMap).
      if (target_input == nullptr ||
          recomputed_source_nodes.count(target_input) != 0 ||
          components.find(target_node)->second ==
              components.find(target_input)->second) {
        continue;
      }
      target_inputs_topological.push_back(target_input);
    }
  }
  std::sort(target_inputs_topological.begin(), target_inputs_topological.end(),
            [&components](const NodeDef* first, const NodeDef* second) {
              return components.find(first)->second >
                     components.find(second)->second;
            });
  auto target_input_iterator = target_inputs_topological.begin();
  NodeDef* current_trigger_node = nullptr;
  std::unordered_map<const NodeDef*, const NodeDef*> triggers;
  for (const NodeDef* original_recomputed_node :
       recomputed_source_nodes_topological) {
    NodeDef* new_trigger_node = graph->add_node();
    new_trigger_node->set_name(AddPrefixToNodeName(
        original_recomputed_node->name(), kRecomputeTriggerNodePrefix));
    new_trigger_node->set_op("NoOp");
    new_trigger_node->set_device(original_recomputed_node->device());
    if (current_trigger_node != nullptr) {
      *new_trigger_node->add_input() =
          strings::StrCat("^", current_trigger_node->name());
    }
    current_trigger_node = new_trigger_node;
    triggers[original_recomputed_node] = current_trigger_node;
    for (;
         target_input_iterator != target_inputs_topological.end() &&
         components.find(*target_input_iterator)->second >
             recomputed_node_max_feed_components.find(original_recomputed_node)
                 ->second;
         ++target_input_iterator) {
      *current_trigger_node->add_input() =
          strings::StrCat("^", (*target_input_iterator)->name());
      VLOG(2) << " Recomputation trigger " << current_trigger_node->name()
              << " depends on " << (*target_input_iterator)->name();
    }
  }
  return triggers;
}

string RecomputedOrOriginalNodeName(
    const std::unordered_set<string>& recomputed_node_names,
    const string& original_node_name) {
  if (recomputed_node_names.find(original_node_name) ==
      recomputed_node_names.end()) {
    return original_node_name;
  } else {
    return AddPrefixToNodeName(original_node_name, kRecomputedNodePrefix);
  }
}

// Helper function to recompute a sub-graph (recomputed_source_nodes). Edges
// from recomputed_source_nodes to target_nodes are changed to start from the
// recomputed nodes.
void RecomputeSubgraph(
    const std::unordered_set<const NodeDef*>& recomputed_source_nodes,
    const std::unordered_set<NodeDef*>& target_nodes, const NodeMap& node_map,
    const std::unordered_map<const NodeDef*, int>& components,
    GraphDef* graph) {
  std::unordered_set<string> recomputed_node_names;
  VLOG(1) << "Recomputing a " << recomputed_source_nodes.size()
          << " node subgraph";
  std::unordered_map<const NodeDef*, int> recomputed_node_components =
      GetMaxDownstreamComponents(recomputed_source_nodes, target_nodes,
                                 node_map, components);
  for (const NodeDef* original_node : recomputed_source_nodes) {
    VLOG(2) << " " << original_node->name();
    recomputed_node_names.insert(original_node->name());
  }
  std::unordered_map<const NodeDef*, const NodeDef*> triggers =
      AddRecomputeControlDependencyNodes(recomputed_source_nodes, target_nodes,
                                         node_map, components,
                                         recomputed_node_components, graph);
  // Create the recomputed sub-graph.
  for (const NodeDef* original_node : recomputed_source_nodes) {
    NodeDef* copied_node = graph->add_node();
    copied_node->set_name(
        AddPrefixToNodeName(original_node->name(), kRecomputedNodePrefix));
    copied_node->set_op(original_node->op());
    *copied_node->mutable_attr() = original_node->attr();
    copied_node->set_device(original_node->device());
    for (const string& original_input_name : original_node->input()) {
      // Set inputs which are internal to the copied subgraph to their copied
      // versions.
      *copied_node->add_input() = RecomputedOrOriginalNodeName(
          recomputed_node_names, original_input_name);
    }
    // Each recomputed node gets a control dependency to prevent it from being
    // recomputed immediately.
    *copied_node->add_input() =
        strings::StrCat("^", triggers[original_node]->name());
  }
  // Set the inputs of nodes in the target subgraph to the recomputed nodes
  // where applicable.
  for (NodeDef* target_node : target_nodes) {
    for (string& target_input_name : *target_node->mutable_input()) {
      target_input_name = RecomputedOrOriginalNodeName(recomputed_node_names,
                                                       target_input_name);
    }
  }
}

void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
                                const string& recomputation_targets_name_prefix,
                                GraphDef* graph, const GrapplerItem& item) {
  if (optimization_level != RewriterConfig::RECOMPUTATION_HEURISTICS &&
      optimization_level != RewriterConfig::HEURISTICS &&
      optimization_level != RewriterConfig::MANUAL) {
    // Nothing to do.
    return;
  }
  // The topological numberings and NodeMap will be stale as soon as we start
  // modifying the graph in RecomputeSubgraph. However, RecomputeSubgraph only
  // looks up nodes which were in the original graph, and preserves the graph
  // topology it's interested in.
  // We don't use the results of this topological sort until later, but this
  // call invalidates all NodeDef pointers, so it needs to be done before we
  // start collecting those.
  TF_CHECK_OK(TopologicalSort(graph));
  NodeMap node_map(graph);
  std::vector<RecomputedSubGraph> recomputed_subgraphs;
  // Do not recompute nodes which are fed, since the recomputed node would not
  // take on the fed value (i.e. gradients would be incorrect).
  std::unordered_set<string> feeds;
  for (const auto& feed : item.feed) {
    feeds.insert(NodeName(feed.first));
  }
  std::function<bool(const NodeDef&)> is_target =
      [&recomputation_targets_name_prefix](const NodeDef& node) {
        // Nodes whose inputs we may want to recompute. Typically targets will
        // be gradients (recomputation_targets_name_prefix="gradients/"),
        // although the prefix is configurable since gradients may be created
        // in a name scope.
        // TODO(allenl): Use a static schedule
        // (grappler::EstimateEarliestExecutionTimes) to recompute only nodes
        // whose outputs will sit around for a while.
        return node.name().find(recomputation_targets_name_prefix) == 0;
      };

  if (optimization_level == RewriterConfig::RECOMPUTATION_HEURISTICS ||
      optimization_level == RewriterConfig::HEURISTICS) {
    // TODO(allenl): Handle ResNet-like architectures better. Right now all of
    // the cheap forward ops get grouped into a single subgraph which must
    // execute before gradients start executing (unless layers are manually
    // separated by identity ops).
    std::unordered_set<string> cheap_to_recompute_ops =
        GetCheapToRecomputeOps();
    recomputed_subgraphs = GetOpGroupsToRecompute(
        graph, node_map,
        [&cheap_to_recompute_ops, &feeds, &is_target](const NodeDef& node) {
          return !is_target(node) && feeds.count(node.name()) == 0 &&
                 (cheap_to_recompute_ops.count(node.op()) > 0 ||
                  node.attr().count(kRecomputeHint) > 0);
        },
        is_target);
  } else if (optimization_level == RewriterConfig::MANUAL) {
    recomputed_subgraphs = GetOpGroupsToRecompute(
        graph, node_map,
        [&feeds, &is_target](const NodeDef& node) {
          return !is_target(node) && feeds.count(node.name()) == 0 &&
                 node.attr().count(kRecomputeHint) > 0;
        },
        is_target);
  }
  if (!recomputed_subgraphs.empty()) {
    std::unordered_map<const NodeDef*, int> topological_numbering;
    for (int node_number = 0; node_number < graph->node().size();
         ++node_number) {
      topological_numbering[graph->mutable_node(node_number)] =
          graph->node().size() - node_number - 1;
    }
    // Duplicate the indicated sub-graphs and set up control dependencies.
    for (const RecomputedSubGraph& subgraph : recomputed_subgraphs) {
      RecomputeSubgraph(subgraph.recomputed_source_nodes, subgraph.target_nodes,
                        node_map, topological_numbering, graph);
    }
  }
}

bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
  // Look for AddN nodes (and equivalent) and record input names.
  GraphView view(&item->graph);

  std::unordered_map<string, std::unordered_set<NodeDef*>> addn_list;
  for (NodeDef& node : *item->graph.mutable_node()) {
    if (!IsAddN(node) && node.op() != "AccumulateNV2") {
      continue;
    }
    // There is nothing to gain by optimizing nodes with 2 or fewer inputs.
    if (view.NumFanins(node, false) <= 2) {
      continue;
    }
    for (const auto& input : view.GetFanins(node, false)) {
      if (input.node->device() == node.device()) {
        string tensor_name =
            strings::StrCat(input.node->name(), ":", input.port_id);
        addn_list[tensor_name].insert(&node);
      }
    }
  }

  if (addn_list.empty()) {
    return false;
  }

  GraphMemory memory(*item);
  const std::unordered_map<string, DeviceProperties>& devices =
      cluster->GetDevices();
  Status s = memory.InferStatically(devices);
  if (!s.ok()) {
    VLOG(1) << "Failed to infer memory usage: " << s.error_message();
    return false;
  }

  std::unordered_set<NodeDef*> addn_to_rewrite;
  for (const auto& device : devices) {
    const string& name = device.first;
    const DeviceProperties& prop = device.second;
    if (prop.memory_size() <= 0) {
      VLOG(1) << "Available memory unknown for device " << name;
      continue;
    }
    const GraphMemory::MemoryUsage& mem_usage = memory.GetPeakMemoryUsage(name);

    if (mem_usage.used_memory <= prop.memory_size() * 0.8) {
      continue;
    }

    for (const auto& live : mem_usage.live_tensors) {
      string tensor_name = strings::StrCat(live.node, ":", live.output_id);
      auto it = addn_list.find(tensor_name);
      if (it != addn_list.end()) {
        addn_to_rewrite.insert(it->second.begin(), it->second.end());
      }
    }
  }

  if (addn_to_rewrite.empty()) {
    return false;
  }
  GraphProperties properties(*item);
  s = properties.InferStatically(false);
  if (!s.ok()) {
    VLOG(1) << "Failed to infer shapes: " << s.error_message();
    return false;
  }

  bool updated_graph = false;
  // Rewrite the AddN.
  for (NodeDef* node : addn_to_rewrite) {
    if (!properties.HasOutputProperties(node->name())) {
      VLOG(1) << "Missing properties for " << node->name();
      continue;
    }
    const TensorShapeProto& shape =
        properties.GetOutputProperties(node->name())[0].shape();
    PartialTensorShape shp(shape);
    if (!shp.IsFullyDefined()) {
      VLOG(1) << "Shape not fully known for " << node->name();
      continue;
    }

    // Compute a topological ordering for the node fanin.
    std::unordered_map<NodeDef*, int> topo_order;
    ReverseDfs(view, {node}, nullptr,
               [&topo_order](NodeDef* n) {
                 int topo_index = topo_order.size();
                 topo_order[n] = topo_index;
               },
               nullptr);

    std::vector<int> input_topo_index;

    for (int i = 0; i < node->input_size(); ++i) {
      const string& input = node->input(i);
      const string node_name = NodeName(input);
      NodeDef* fanin = view.GetNode(node_name);
      input_topo_index.push_back(topo_order.at(fanin));
    }
    int min_input_topo_index = INT_MAX;
    int min_input_id = -1;
    for (int i = 0; i < node->input_size(); ++i) {
      if (IsControlInput(node->input(i))) {
        // Control inputs are always last.
        break;
      }
      const int current = input_topo_index[i];
      if (current < min_input_topo_index) {
        min_input_topo_index = current;
        min_input_id = i;
      }
    }
    CHECK_LE(0, min_input_id);
    std::vector<string> pre_ctrl_deps;
    std::vector<string> post_ctrl_deps;
    for (int i = node->input_size() - 1; i >= 0; --i) {
      if (!IsControlInput(node->input(i))) {
        // Control inputs are always last.
        break;
      }
      if (input_topo_index[i] < min_input_topo_index) {
        // These control dependencies can be executed before the node.
        pre_ctrl_deps.push_back(node->input(i));
      } else {
        // These control dependencies should be executed after the node.
        post_ctrl_deps.push_back(node->input(i));
      }
    }

    DataType dtype = node->attr().at("T").type();
    const string& device = node->device();

    // Create the temporary variable that will hold intermediate results.
    NodeDef* tmp_var = item->graph.add_node();
    tmp_var->set_name(strings::StrCat(node->name(), "/tmp_var"));
    tmp_var->set_op("TemporaryVariable");
    tmp_var->set_device(device);
    (*tmp_var->mutable_attr())["dtype"].set_type(dtype);
    *(*tmp_var->mutable_attr())["shape"].mutable_shape() = shape;
    (*tmp_var->mutable_attr())["var_name"].set_s(tmp_var->name());

    for (const string& ctrl_dep : pre_ctrl_deps) {
      *tmp_var->add_input() = ctrl_dep;
    }
    *tmp_var->add_input() =
        AsControlDependency(NodeName(node->input(min_input_id)));

    // Initialize it to zero.
    NodeDef* zeros = item->graph.add_node();
    zeros->set_name(strings::StrCat(node->name(), "/tmp_var_zeros"));
    zeros->set_op("ZerosLike");
    zeros->set_device(device);
    (*zeros->mutable_attr())["T"].set_type(dtype);
    *zeros->add_input() = node->input(min_input_id);

    NodeDef* initialize = item->graph.add_node();
    initialize->set_name(strings::StrCat(node->name(), "/tmp_var_initializer"));
    initialize->set_op("Assign");
    initialize->set_device(device);
    (*initialize->mutable_attr())["T"].set_type(dtype);
    (*initialize->mutable_attr())["use_locking"].set_b(false);
    (*initialize->mutable_attr())["validate_shape"].set_b(false);
    *initialize->add_input() = tmp_var->name();
    *initialize->add_input() = zeros->name();

    // Add the AssignAdd nodes that accumulate the inputs.
    std::vector<NodeDef*> accumulates;
    for (int i = 0; i < node->input_size(); ++i) {
      const string& input = node->input(i);
      if (!IsControlInput(input)) {
        NodeDef* accumulate = item->graph.add_node();
        accumulate->set_name(
            strings::StrCat(node->name(), "/tmp_var_accum_", i));
        accumulate->set_op("AssignAdd");
        accumulate->set_device(device);
        (*accumulate->mutable_attr())["T"].set_type(dtype);
        (*accumulate->mutable_attr())["use_locking"].set_b(true);
        *accumulate->add_input() = initialize->name();
        *accumulate->add_input() = input;
        accumulates.push_back(accumulate);
      }
    }

    // Rewrite the AddN node as a DestroyTemporaryVariable op.
    node->set_op("DestroyTemporaryVariable");
    node->clear_input();
    node->clear_attr();
    (*node->mutable_attr())["T"].set_type(dtype);
    (*node->mutable_attr())["var_name"].set_s(tmp_var->name());
    *node->add_input() = initialize->name();
    for (const NodeDef* accum : accumulates) {
      *node->add_input() = AsControlDependency(accum->name());
    }
    for (const string& ctrl_dep : post_ctrl_deps) {
      *node->add_input() = ctrl_dep;
    }

    updated_graph = true;
  }

  return updated_graph;
}

Status BuildSwapPair(NodeDef* node, int input_to_swap,
                     const std::unordered_map<string, const NodeDef*>& name_map,
                     GraphDef* graph,
                     std::pair<NodeDef*, NodeDef*>* swap_pair) {
  const OpDef* op_def;
  TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node->op(), &op_def));
  DataType input_type;
  TF_RETURN_IF_ERROR(
      InputTypeForNode(*node, *op_def, input_to_swap, &input_type));
  if (IsRefType(input_type)) {
    return errors::InvalidArgument("Can't swap input ", input_to_swap,
                                   " of node ", node->name(),
                                   " since it expects a reference");
  }

  string tensor_to_swap = strings::StrCat(node->name(), "_", input_to_swap);
  string swap_out_name = strings::StrCat("swap_out_", tensor_to_swap);
  string swap_in_name = strings::StrCat("swap_in_", tensor_to_swap);
  if (name_map.find(swap_out_name) != name_map.end() ||
      name_map.find(swap_in_name) != name_map.end()) {
    return errors::InvalidArgument("Input ", input_to_swap, " of node ",
                                   node->name(), " is already swapped");
  }

  // Force the tensor to be copied to CPU.
  NodeDef* swap_out_node = graph->add_node();
  swap_out_node->set_name(swap_out_name);
  swap_out_node->set_op("Identity");
  swap_out_node->set_device("/device:CPU:0");

  // Force the tensor to be restored to the device.
  NodeDef* swap_in_node = graph->add_node();
  swap_in_node->set_name(swap_in_name);
  swap_in_node->set_op("Identity");
  *swap_in_node->add_input() = swap_out_node->name();

  // Colocate the swap_in_ node with the node itself.
  swap_in_node->set_device(node->device());
  string coloc_group = strings::StrCat("loc@", tensor_to_swap);
  (*swap_in_node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
  (*node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);

  (*swap_in_node->mutable_attr())["T"].set_type(input_type);
  (*swap_out_node->mutable_attr())["T"].set_type(input_type);
  *swap_pair = std::make_pair(swap_out_node, swap_in_node);

  return Status::OK();
}

static int64 EstimateSize(const OpInfo::TensorProperties& t) {
  DataType dtype = t.dtype();
  int64 size = DataTypeSize(dtype);
  TensorShapeProto shape = t.shape();
  if (shape.unknown_rank()) {
    // Can't infer the size if the rank is unknown. It has to be at least a
    // scalar though.
    return size;
  }
  // If one of the dimensions is unknown statically, assume it's at least one.
  for (int i = 0; i < shape.dim_size(); ++i) {
    if (shape.dim(i).size() < 0) {
      shape.mutable_dim(i)->set_size(1);
    }
  }
  int64 num_elems = TensorShape(shape).num_elements();
  return num_elems * size;
}

struct SwapInfo {
  std::vector<int> inputs_to_swap;
  Costs::NanoSeconds time_to_swap = 0;
};

static const NodeDef* FindSwapInTrigger(
    const NodeDef* node, const SwapInfo& swap_info,
    const std::unordered_map<string, const NodeDef*>& name_map,
    const std::unordered_map<const NodeDef*, Costs::NanoSeconds>&
        execution_times) {
  // max_trigger_time stores the time before which the swap operation needs to
  // be started in order to load the data back onto the accelerator without
  // delaying the downstream computation.
  Costs::NanoSeconds max_trigger_time(0);
  std::set<string> possible_inputs;
  for (int i = 0; i < node->input_size(); ++i) {
    const string input_node_name = NodeName(node->input(i));
    auto it1 = name_map.find(input_node_name);
    if (it1 == name_map.end()) {
      return nullptr;
    }
    const NodeDef* input_node = it1->second;

    auto it2 = execution_times.find(input_node);
    if (it2 == execution_times.end()) {
      return nullptr;
    }
    max_trigger_time = std::max(max_trigger_time, it2->second);
    possible_inputs.insert(input_node_name);
  }

  for (const int i : swap_info.inputs_to_swap) {
    const string input_node_name = NodeName(node->input(i));
    possible_inputs.erase(input_node_name);
  }
  if (possible_inputs.empty()) {
    return nullptr;
  }

  max_trigger_time -= swap_info.time_to_swap;

  std::map<Costs::NanoSeconds, const NodeDef*> candidates;
  std::set<string> already_processed;

  while (!possible_inputs.empty()) {
    const string input_node_name = *possible_inputs.begin();
    possible_inputs.erase(possible_inputs.begin());
    already_processed.insert(input_node_name);
    auto it1 = name_map.find(input_node_name);
    if (it1 == name_map.end()) {
      return nullptr;
    }
    const NodeDef* input_node = it1->second;
    // Don't jump over frames, since adding a control dependency from one
    // frame to the next isn't supported. Don't go through branches, since we
    // don't know whether they'll be executed or not.
    if (ModifiesFrameInfo(*input_node) || IsSwitch(*input_node) ||
        IsMerge(*input_node)) {
      continue;
    }
    auto it2 = execution_times.find(input_node);
    if (it2 == execution_times.end()) {
      return nullptr;
    }
    if (it2->second < max_trigger_time) {
      candidates[it2->second] = input_node;
    } else {
      for (const string& fanin : input_node->input()) {
        string name = NodeName(fanin);
        if (already_processed.find(name) == already_processed.end()) {
          possible_inputs.insert(name);
        }
      }
    }
  }

  // Select the candidate that will execute last, so that the data is swapped
  // back in as late as possible while still arriving in time to feed the
  // downstream nodes.
  if (!candidates.empty()) {
    return candidates.rbegin()->second;
  }
  return nullptr;
}

static bool IsSwappable(const GraphView& graph, GraphView::OutputPort output) {
  const NodeDef& node = *output.node;
  // There is no point in swapping out persistent tensors, since the tensor
  // will continue to use memory.
  if (IsPersistent(node)) {
    return false;
  }

  const OpDef* op_def;
  if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
    return false;
  }
  DataType dtype;
  if (!OutputTypeForNode(node, *op_def, output.port_id, &dtype).ok()) {
    return false;
  }
  // References can only refer to persistent memory: therefore the node isn't
  // swappable.
  if (IsRefType(dtype)) {
    return false;
  }

  if (output.node->op() == "Identity" || output.node->op() == "Reshape") {
    // If placed on the same device, these nodes are just forwarding references
    // to their input. Therefore they are swappable iff their fanin is
    // swappable or it resides on a different device.
    GraphView::InputPort input;
    input.node = output.node;
    input.port_id = 0;
    GraphView::OutputPort fanin = graph.GetRegularFanin(input);
    if (fanin.node->device() == node.device()) {
      return IsSwappable(graph, fanin);
    }
  }
  return true;
}

static NodeDef* FindSwapOutTrigger(
    const NodeDef* node, int input_id, const GraphView& view,
    const std::unordered_map<const NodeDef*, Costs::NanoSeconds>&
        execution_times) {
  // Find the output port that generated the tensor to swap.
  GraphView::InputPort swap;
  swap.node = const_cast<NodeDef*>(node);
  swap.port_id = input_id;
  GraphView::OutputPort generator = view.GetRegularFanin(swap);
  if (!generator.node) {
    return nullptr;
  }

  const std::unordered_set<GraphView::InputPort, GraphView::HashPort>& fanout =
      view.GetFanout(generator);
  NodeDef* trigger = nullptr;
  Costs::NanoSeconds earliest_fanout(Costs::NanoSeconds::infinity());

  for (const auto& port : fanout) {
    if (port.node == node) {
      continue;
    }
    auto it = execution_times.find(port.node);
    if (it != execution_times.end() && it->second < earliest_fanout) {
      earliest_fanout = it->second;
      trigger = port.node;
    }
  }

  return trigger;
}

static bool IsSwappable(GraphView::InputPort input) {
  const NodeDef& node = *input.node;

  const OpDef* op_def;
  if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
    return false;
  }

  DataType dtype;
  if (!InputTypeForNode(node, *op_def, input.port_id, &dtype).ok()) {
    return false;
  }

  return !IsRefType(dtype);
}

struct MemInfo {
  GraphView::OutputPort port;
  int64 memory_used;
  std::vector<GraphView::InputPort> uses_left;
  double fitness;

  bool operator<(const MemInfo& other) const { return fitness < other.fitness; }
};

static bool IdentifySwappingCandidates(
    Cluster* cluster, GrapplerItem* item, std::unordered_set<string>* skip_list,
    std::unordered_map<NodeDef*, SwapInfo>* nodes_to_swap) {
  GraphMemory memory(*item);
  const std::unordered_map<string, DeviceProperties>& devices =
      cluster->GetDevices();
  Status s = memory.InferStatically(devices);
  if (!s.ok()) {
    VLOG(1) << "Failed to infer memory usage: " << s.error_message();
    return false;
  }

  bool updated_graph = false;
  for (const auto& device : devices) {
    const string& name = device.first;
    const DeviceProperties& prop = device.second;
    if (prop.type() != "GPU") {
      continue;
    }
    if (prop.memory_size() <= 0) {
      VLOG(1) << "Peak memory usage unknown for device " << name;
      continue;
    }
    const GraphMemory::MemoryUsage& mem_usage = memory.GetPeakMemoryUsage(name);

    if (mem_usage.used_memory <= prop.memory_size()) {
      continue;
    }
    int64 required_savings = mem_usage.used_memory - prop.memory_size();

    std::unordered_map<string, Costs::NanoSeconds> op_completion_times;
    {
      VirtualCluster vcluster(cluster->GetDevices());
      if (!vcluster.Provision().ok()) {
        return false;
      }
      if (!vcluster.Initialize(*item).ok()) {
        return false;
      }
      RunMetadata metadata;
      Status s = vcluster.Run(item->graph, item->feed, item->fetch, &metadata);
      if (!s.ok() && s.code() != error::RESOURCE_EXHAUSTED) {
        return false;
      }

      for (const auto& dev_stats : metadata.step_stats().dev_stats()) {
        for (const auto& node_stats : dev_stats.node_stats()) {
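          // Record each op's completion time: its start time plus its
          // duration, offset by one nanosecond so the stored value is always
          // strictly positive.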
          Costs::NanoSeconds exec_time =
              Costs::NanoSeconds(1) +
              Costs::MicroSeconds(node_stats.all_start_micros() +
                                  node_stats.op_end_rel_micros());
          op_completion_times.emplace(node_stats.node_name(), exec_time);
        }
      }
    }

    Costs::Duration peak_time = -1;
    for (const auto& live_tensor : mem_usage.live_tensors) {
      if (live_tensor.allocation_time > peak_time) {
        peak_time = live_tensor.allocation_time;
      }
    }

    std::vector<MemInfo> mem_state;

    GraphView graph(&item->graph);
    for (const auto& live_tensor : mem_usage.live_tensors) {
      if (live_tensor.memory_used <= 1024) {
        // Don't bother with small tensors.
        continue;
      }
      if (live_tensor.deallocation_time - live_tensor.allocation_time <=
          Costs::Duration(1e6)) {
        // Not enough time to swap.
        VLOG(1) << "Not enough time to swap: skipping " << live_tensor.node;
        continue;
      }

      if (skip_list->find(live_tensor.node) != skip_list->end()) {
        continue;
      }
      GraphView::OutputPort port =
          graph.GetOutputPort(live_tensor.node, live_tensor.output_id);
      if (!IsSwappable(graph, port)) {
        continue;
      }
      MemInfo mem_info;
      mem_info.port = port;
      mem_info.memory_used = live_tensor.memory_used;
      Costs::Duration allocation_time = live_tensor.allocation_time;
      Costs::Duration earliest_use(Costs::Duration::infinity());
      bool valid = true;
      for (GraphView::InputPort input : graph.GetFanout(port)) {
        // Get execution time.
        auto it = op_completion_times.find(input.node->name());
        if (it == op_completion_times.end()) {
          valid = false;
          break;
        }
        if (it->second <= peak_time) {
          continue;
        }

        if (skip_list->find(input.node->name()) != skip_list->end()) {
          valid = false;
          break;
        }
        string input_name =
            strings::StrCat(input.node->name(), ":", input.port_id);
        if (skip_list->find(input_name) != skip_list->end()) {
          valid = false;
          break;
        }
        if (!IsSwappable(input)) {
          valid = false;
          break;
        }

        // Set earliest use time that's after peak.
        mem_info.uses_left.emplace_back(input);
        earliest_use = std::min(earliest_use, it->second);
      }
      if (valid && !mem_info.uses_left.empty()) {
        // Compute the fitness: we need the tensor to be generated well before
        // the time of peak memory usage (to ensure there is enough time to
        // swap it out). We also need to ensure it's used well after the peak
        // time, so that swapping the tensor back in won't recreate the memory
        // bottleneck. Last but not least, we want the tensor to have as few
        // remaining uses as possible.
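        //
        // Concretely, with smaller (more negative) fitness sorting first:
        //   fitness = -((earliest_use - peak_time)^2 / |uses_left|^2 +
        //               (allocation_time - peak_time)^2)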
        mem_info.fitness = std::pow((earliest_use - peak_time).count(), 2);
        mem_info.fitness /= std::pow(mem_info.uses_left.size(), 2);
        mem_info.fitness += std::pow((allocation_time - peak_time).count(), 2);
        mem_info.fitness = -mem_info.fitness;
        mem_state.push_back(mem_info);
      }
    }

    // Sort by fitness.
    std::sort(mem_state.begin(), mem_state.end());

    for (const MemInfo& mem_info : mem_state) {
      for (const GraphView::InputPort fanout_to_swap : mem_info.uses_left) {
        VLOG(1) << "Will swap fanout " << fanout_to_swap.node->name() << ":"
                << fanout_to_swap.port_id << " of tensor "
                << mem_info.port.node->name() << ":" << mem_info.port.port_id
                << " of size " << mem_info.memory_used;

        (*nodes_to_swap)[fanout_to_swap.node].inputs_to_swap.push_back(
            fanout_to_swap.port_id);
      }
      required_savings -= mem_info.memory_used;
      updated_graph = true;
      if (required_savings < 0) {
        break;
      }
    }
  }
  return updated_graph;
}

bool SwappingPass(RewriterConfig::MemOptType optimization_level,
                  Cluster* cluster, GrapplerItem* item,
                  std::unordered_set<string>* skip_list) {
  std::unordered_map<NodeDef*, SwapInfo> nodes_to_swap;
  if (optimization_level == RewriterConfig::SWAPPING_HEURISTICS ||
      optimization_level == RewriterConfig::HEURISTICS) {
    // Use heuristics to figure out what needs to be swapped.
    IdentifySwappingCandidates(cluster, item, skip_list, &nodes_to_swap);
  }
  // Look for manual annotations in the graph.
  for (auto& node : *item->graph.mutable_node()) {
    if (node.attr().count("_swap_to_host") != 0) {
      SwapInfo& swap_info = nodes_to_swap[&node];
      const AttrValue& val = node.attr().at("_swap_to_host");
      if (val.has_list()) {
        for (int64 input_id : val.list().i()) {
          swap_info.inputs_to_swap.push_back(input_id);
        }
      } else {
        int64 input_id = val.i();
        swap_info.inputs_to_swap.push_back(input_id);
      }
    }
  }
  if (nodes_to_swap.empty()) {
    // Nothing to do.
    return false;
  }

  // Estimate the size of the data to swap for each node.
  GraphProperties properties(*item);
  if (!properties.InferStatically(true).ok()) {
    return false;
  }
  for (auto& swap : nodes_to_swap) {
    const NodeDef* node = swap.first;
    const std::vector<OpInfo::TensorProperties>& props =
        properties.GetInputProperties(node->name());
    SwapInfo& swap_info = swap.second;
    int64 bytes_to_swap = 0;
    for (int64 input_id : swap_info.inputs_to_swap) {
      const OpInfo::TensorProperties& t = props[input_id];
      bytes_to_swap += EstimateSize(t);
    }
    // Let's assume we're going to swap over PCIe running at 16 GBps.
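    // 16 GB/s is 16 bytes per nanosecond, so dividing the byte count by 16
    // gives the transfer time in nanoseconds.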
    swap_info.time_to_swap = bytes_to_swap / 16;
  }

  std::unordered_map<const NodeDef*, Costs::NanoSeconds> execution_times;
  if (!EstimateEarliestExecutionTimes(*item, cluster, &execution_times).ok()) {
    return false;
  }

  std::unordered_map<string, const NodeDef*> name_map;
  for (const auto& node : item->graph.node()) {
    name_map[node.name()] = &node;
  }
  GraphView view(&item->graph);

  bool updated_graph = false;

  for (auto& swap : nodes_to_swap) {
    NodeDef* node = swap.first;
    const SwapInfo& swap_info = swap.second;
    if (skip_list->find(node->name()) != skip_list->end()) {
      continue;
    }

    // Make sure the tensor isn't swapped back in right away: look for a node
    // that will execute just before we need to swap the data back, and add a
    // control dependency from that node to the swap node.
    const NodeDef* in_trigger =
        FindSwapInTrigger(node, swap_info, name_map, execution_times);
    // If we failed, don't attempt to reprocess this node in a subsequent pass.
    if (!in_trigger) {
      skip_list->insert(node->name());
      continue;
    }

    // Swap all the tensors that are marked with the 'swap_to_host' attribute.
    for (int input_id : swap_info.inputs_to_swap) {
      string input_name = strings::StrCat(node->name(), ":", input_id);
      if (skip_list->find(input_name) != skip_list->end()) {
        continue;
      } else {
        // Don't attempt to reprocess this input in a subsequent pass.
        skip_list->insert(input_name);
      }

      // Make sure the tensor is swapped out quickly: look for a node that
      // will execute just after the tensor is generated and add a control
      // dependency from the swap out node to that node.
      NodeDef* out_trigger =
          FindSwapOutTrigger(node, input_id, view, execution_times);
      if (!out_trigger) {
        continue;
      }

      std::pair<NodeDef*, NodeDef*> swap_nodes;
      if (!BuildSwapPair(node, input_id, name_map, &item->graph, &swap_nodes)
               .ok()) {
        continue;
      }
      *swap_nodes.first->add_input() = node->input(input_id);
      *node->mutable_input(input_id) = swap_nodes.second->name();
      updated_graph = true;

      // Add the control dependencies needed to delay the execution of the
      // swap.
      out_trigger->add_input(strings::StrCat("^", swap_nodes.first->name()));
      swap_nodes.second->add_input(strings::StrCat("^", in_trigger->name()));

      // Make sure we won't try to swap the swap nodes in subsequent passes.
      skip_list->insert(swap_nodes.first->name());
      skip_list->insert(swap_nodes.second->name());
    }
  }
  return updated_graph;
}

Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                 GraphDef* optimized_graph) {
  *optimized_graph = item.graph;

  RecomputationRewritingPass(optimization_level_,
                             recomputation_targets_name_prefix_,
                             optimized_graph, item);

  GrapplerItem optimized_item(item, std::move(*optimized_graph));
  std::unordered_set<string> skip_list;
  // Bound the number of rewrite passes to avoid long processing times on
  // graphs that simply won't fit in memory.
  bool updated_graph = true;
  for (int i = 0; i < 25 && updated_graph; ++i) {
    updated_graph = false;
    if ((optimization_level_ == RewriterConfig::DEFAULT_MEM_OPT ||
         optimization_level_ == RewriterConfig::SCHEDULING_HEURISTICS ||
         optimization_level_ == RewriterConfig::HEURISTICS) &&
        cluster != nullptr) {
      updated_graph |= SchedulingPass(cluster, &optimized_item);
    }

    if ((optimization_level_ == RewriterConfig::SWAPPING_HEURISTICS ||
         optimization_level_ == RewriterConfig::HEURISTICS ||
         optimization_level_ == RewriterConfig::MANUAL) &&
        cluster != nullptr) {
      updated_graph |= SwappingPass(optimization_level_, cluster,
                                    &optimized_item, &skip_list);
    }
  }

  optimized_graph->Swap(&optimized_item.graph);
  return Status::OK();
}

void MemoryOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
                               const GraphDef& optimized_graph, double result) {
  // Nothing to do for MemoryOptimizer.
}

}  // end namespace grappler
}  // end namespace tensorflow