      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
     17 
     18 #include <algorithm>
     19 #include <queue>
     20 #include <unordered_map>
     21 #include <unordered_set>
     22 #include <vector>
     23 
     24 #include "tensorflow/core/framework/attr_value.pb.h"
     25 #include "tensorflow/core/framework/node_def.pb.h"
     26 #include "tensorflow/core/framework/op.h"
     27 #include "tensorflow/core/framework/tensor_shape.pb.h"
     28 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
     29 #include "tensorflow/core/grappler/costs/graph_memory.h"
     30 #include "tensorflow/core/grappler/costs/graph_properties.h"
     31 #include "tensorflow/core/grappler/graph_view.h"
     32 #include "tensorflow/core/grappler/grappler_item.h"
     33 #include "tensorflow/core/grappler/op_types.h"
     34 #include "tensorflow/core/grappler/optimizers/graph_rewriter.h"
     35 #include "tensorflow/core/grappler/optimizers/static_schedule.h"
     36 #include "tensorflow/core/grappler/utils.h"
     37 #include "tensorflow/core/grappler/utils/topological_sort.h"
     38 #include "tensorflow/core/grappler/utils/traversal.h"
     39 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
     40 
     41 namespace tensorflow {
     42 namespace grappler {
     43 
     44 // Prefix added to nodes which are recomputed.
     45 const char* kRecomputedNodePrefix = "Recomputed";
     46 const char* kRecomputeTriggerNodePrefix = "RecomputeTrigger";
     47 // Attribute which may be added to nodes to manually allow them to be
     48 // recomputed.
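         // Only the presence of the attribute matters to this optimizer; its value is
         // ignored.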
     49 const char* kRecomputeHint = "_recompute_hint";
     50 
     51 // Ops which we wouldn't mind recomputing to save memory.
     52 // TODO(allenl): Replace this list with a cost model.
     53 std::unordered_set<string> GetCheapToRecomputeOps() {
     54   std::unordered_set<string> cheap_ops = {
     55       "Add",      "AddN",       "BiasAdd",        "Cast",   "Fill",
     56       "FloorDiv", "FloorMod",   "FusedBatchNorm", "Mul",    "Neg",
     57       "RealDiv",  "Reciprocal", "Relu",           "Relu6",  "Reshape",
     58       "Rsqrt",    "Sigmoid",    "Sqrt",           "Square", "SquaredDifference",
     59       "Sub",      "Tile",       "Transpose"};
     60   return cheap_ops;
     61 }
     62 
     63 // Find recomputable ops which feed into target nodes.
     64 std::unordered_set<const NodeDef*> FindCandidateRecomputeNodes(
     65     const NodeMap& node_map, const GraphDef* graph,
     66     const std::function<bool(const NodeDef&)>& is_candidate,
     67     const std::function<bool(const NodeDef&)>& is_target) {
     68   std::unordered_set<const NodeDef*> candidate_recompute_nodes;
     69   for (const auto& node : graph->node()) {
     70     if (!is_candidate(node)) {
     71       continue;
     72     }
     73     bool has_target_output = false;
     74     for (const NodeDef* output : node_map.GetOutputs(node.name())) {
     75       // It only makes sense to recompute this if it feeds into a target
     76       // node. We expand this to dependencies in GetOpGroupsToRecompute.
     77       if (is_target(*output)) {
     78         has_target_output = true;
     79         break;
     80       }
     81     }
     82     if (!has_target_output) {
     83       continue;
     84     }
     85     bool has_target_input = false;
     86     for (const string& input_name : node.input()) {
     87       // Don't recompute nodes which depend on target nodes.
     88       const NodeDef* input_node = node_map.GetNode(input_name);
     89       if (is_target(*input_node)) {
     90         has_target_input = true;
     91         break;
     92       }
     93     }
     94     if (has_target_input) {
     95       continue;
     96     }
     97     candidate_recompute_nodes.insert(&node);
     98   }
     99   return candidate_recompute_nodes;
    100 }
    101 
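         // Expands `expanded_nodes` in place to the connected component reachable from
         // its initial contents, following node inputs and/or outputs (depending on
         // `collect_inputs` and `collect_outputs`) and only traversing nodes for which
         // `is_candidate` returns true.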
    102 void connected_subgraph(const NodeMap& node_map, bool collect_inputs,
    103                         bool collect_outputs,
    104                         const std::function<bool(const NodeDef&)>& is_candidate,
    105                         std::unordered_set<const NodeDef*>* expanded_nodes) {
    106   std::queue<const NodeDef*> to_visit;
    107   for (const NodeDef* starting_node : *expanded_nodes) {
    108     to_visit.push(starting_node);
    109   }
    110   expanded_nodes->clear();
    111   while (!to_visit.empty()) {
    112     const NodeDef* current_node = to_visit.front();
    113     to_visit.pop();
    114     if (!expanded_nodes->insert(current_node).second) {
    115       // We already visited this node
    116       continue;
    117     }
    118     if (collect_inputs) {
    119       // Add inputs and outputs to this subgraph if they are candidates
    120       for (const string& input_name_raw : current_node->input()) {
    121         const NodeDef* input_node = node_map.GetNode(input_name_raw);
    122         if (expanded_nodes->count(input_node) == 0 &&
    123             is_candidate(*input_node)) {
    124           to_visit.push(input_node);
    125         }
    126       }
    127     }
    128     if (collect_outputs) {
    129       for (const NodeDef* output : node_map.GetOutputs(current_node->name())) {
    130         if (expanded_nodes->count(output) == 0 && is_candidate(*output)) {
    131           to_visit.push(output);
    132         }
    133       }
    134     }
    135   }
    136 }
    137 
    138 struct RecomputedSubGraph {
    139   std::unordered_set<const NodeDef*> recomputed_source_nodes;
    140   std::unordered_set<NodeDef*> target_nodes;
    141 };
    142 
    143 // Find groups of ops to recompute together based on `should_recompute`.
    144 std::vector<RecomputedSubGraph> GetOpGroupsToRecompute(
    145     const GraphDef* graph, const NodeMap& node_map,
    146     const std::function<bool(const NodeDef&)>& should_recompute,
    147     const std::function<bool(const NodeDef&)>& is_target) {
    148   std::unordered_set<const NodeDef*> visited_nodes;
    149   std::vector<RecomputedSubGraph> subgraphs_to_recompute;
    150   std::unordered_set<const NodeDef*> candidate_recompute_nodes =
    151       FindCandidateRecomputeNodes(node_map, graph, should_recompute, is_target);
    152   for (const NodeDef* recompute_node : candidate_recompute_nodes) {
    153     if (visited_nodes.count(recompute_node) > 0) {
    154       continue;
    155     }
    156     RecomputedSubGraph current_recomputation;
    157     // Build out recomputation groups by expanding to inexpensive-to-recompute
    158     // nodes which do not feed target nodes. The goal is to capture some
    159     // intermediate activations within this graph.
    160     std::unordered_set<const NodeDef*> unpruned_recompute_nodes;
    161     unpruned_recompute_nodes.insert(recompute_node);
    162     connected_subgraph(node_map,
    163                        true,  // Collect inputs
    164                        true,  // Collect outputs
    165                        should_recompute, &unpruned_recompute_nodes);
    166     visited_nodes.insert(unpruned_recompute_nodes.begin(),
    167                          unpruned_recompute_nodes.end());
    168     for (const NodeDef* recompute_node : unpruned_recompute_nodes) {
    169       bool inserted_feed = false;
    170       for (NodeDef* output : node_map.GetOutputs(recompute_node->name())) {
    171         if (is_target(*output)) {
    172           current_recomputation.target_nodes.insert(output);
    173           if (!inserted_feed) {
    174             // Keep track of nodes which feed directly into a target node. These
    175             // and nodes which feed into them will define the recomputed
    176             // subgraph.
    177             current_recomputation.recomputed_source_nodes.insert(
    178                 recompute_node);
    179             inserted_feed = true;
    180           }
    181         }
    182       }
    183     }
    184     // Recompute only nodes which eventually feed into a target node.
    185     connected_subgraph(node_map,
    186                        true,   // Collect inputs
    187                        false,  // Collect outputs
    188                        [&unpruned_recompute_nodes](const NodeDef& node) {
    189                          return unpruned_recompute_nodes.count(&node) != 0;
    190                        },
    191                        &current_recomputation.recomputed_source_nodes);
    192     if (current_recomputation.target_nodes.empty()) {
    193       continue;
    194     }
    195     subgraphs_to_recompute.push_back(current_recomputation);
    196   }
    197   return subgraphs_to_recompute;
    198 }
    199 
    200 // Computes the maximum topological numbers of (1) target node components
    201 // (gradient nodes being fed by the recomputation), and (2) child recompute node
     202 // components for each recomputed node. A recomputation's trigger is only given
     203 // control dependencies on nodes whose component numbers are strictly greater
     204 // than this value (to prevent cycles).
    205 std::unordered_map<const NodeDef*, int> GetMaxDownstreamComponents(
    206     const std::unordered_set<const NodeDef*>& recomputed_source_nodes,
    207     const std::unordered_set<NodeDef*>& target_nodes, const NodeMap& node_map,
    208     const std::unordered_map<const NodeDef*, int>& components) {
    209   std::unordered_map<const NodeDef*, int> recomputed_node_components;
    210   // Start by setting component numbers to the maximum among target nodes.
    211   for (const NodeDef* original_recompute_node : recomputed_source_nodes) {
    212     int max_target_component = -1;
    213     for (NodeDef* output :
    214          node_map.GetOutputs(original_recompute_node->name())) {
    215       if (target_nodes.count(output) != 0) {
    216         int current_target_component = components.find(output)->second;
    217         if (current_target_component > max_target_component) {
    218           max_target_component = current_target_component;
    219         }
    220       }
    221     }
    222     if (max_target_component > -1) {
    223       recomputed_node_components[original_recompute_node] =
    224           max_target_component;
    225     }
    226   }
    227   // Sort recomputed nodes topologically (based on the original graph) so we can
    228   // efficiently assign to each node the maximum of its recomputed child
    229   // components and its own targets.
    230   std::vector<const NodeDef*> recomputed_source_nodes_topological(
    231       recomputed_source_nodes.begin(), recomputed_source_nodes.end());
    232   std::sort(recomputed_source_nodes_topological.begin(),
    233             recomputed_source_nodes_topological.end(),
    234             [&components](const NodeDef* first, const NodeDef* second) {
    235               return components.find(first)->second <
    236                      components.find(second)->second;
    237             });
    238   for (const NodeDef* original_recompute_node :
    239        recomputed_source_nodes_topological) {
    240     int max_component;
    241     auto recomputed_component_iterator =
    242         recomputed_node_components.find(original_recompute_node);
    243     if (recomputed_component_iterator != recomputed_node_components.end()) {
    244       max_component = recomputed_component_iterator->second;
    245     } else {
    246       max_component = -1;
    247     }
    248     for (NodeDef* output :
    249          node_map.GetOutputs(original_recompute_node->name())) {
    250       if (recomputed_source_nodes.count(output) == 0) {
    251         continue;
    252       }
    253       auto child_component_iterator = recomputed_node_components.find(output);
    254       CHECK(child_component_iterator != recomputed_node_components.end());
    255       int child_component = child_component_iterator->second;
    256       if (child_component > max_component) {
    257         max_component = child_component;
    258       }
    259     }
    260     CHECK_GE(max_component, 0);
    261     recomputed_node_components[original_recompute_node] = max_component;
    262   }
    263   return recomputed_node_components;
    264 }
    265 
    266 // Modifies `graph`, adding trigger nodes and returning a mapping from
    267 // `recomputed_source_nodes` to trigger nodes which will not create loops in the
    268 // graph (using the component numberings in `components` and
    269 // `recomputed_node_max_feed_components`). The copied nodes (not the nodes in
    270 // recomputed_source_nodes, which are the originals) eventually get these
    271 // control dependencies.
    272 std::unordered_map<const NodeDef*, const NodeDef*>
    273 AddRecomputeControlDependencyNodes(
    274     const std::unordered_set<const NodeDef*>& recomputed_source_nodes,
    275     const std::unordered_set<NodeDef*>& target_nodes, const NodeMap& node_map,
    276     const std::unordered_map<const NodeDef*, int>& components,
    277     const std::unordered_map<const NodeDef*, int>&
    278         recomputed_node_max_feed_components,
    279     GraphDef* graph) {
    280   // Sort recomputed nodes based on max downstream components.
    281   std::vector<const NodeDef*> recomputed_source_nodes_topological(
    282       recomputed_source_nodes.begin(), recomputed_source_nodes.end());
    283   std::sort(recomputed_source_nodes_topological.begin(),
    284             recomputed_source_nodes_topological.end(),
    285             [&recomputed_node_max_feed_components](const NodeDef* first,
    286                                                    const NodeDef* second) {
    287               int first_component =
    288                   recomputed_node_max_feed_components.find(first)->second;
    289               int second_component =
    290                   recomputed_node_max_feed_components.find(second)->second;
    291               return first_component > second_component
    292                      // Ensure a consistent ordering. This is necessary because
    293                      // we're working not with node component numbers (which are
    294                      // unique) but with the maximum across nodes they feed into
    295                      // (very much not unique).
    296                      || (first_component == second_component &&
    297                          first->name() > second->name());
    298             });
    299   // Create merged control dependency nodes by sorting target inputs
    300   // topologically and zipper merging with the sorted recomputed nodes.
    301   std::vector<const NodeDef*> target_inputs_topological;
    302   for (const NodeDef* target_node : target_nodes) {
    303     for (const string& target_input_name_raw : target_node->input()) {
    304       const NodeDef* target_input = node_map.GetNode(target_input_name_raw);
    305       // If this node has already had one of its inputs recomputed during this
    306       // rewriting pass, we ignore that recomputed node here (it will not be in
    307       // the NodeMap).
    308       if (target_input == nullptr ||
    309           recomputed_source_nodes.count(target_input) != 0 ||
    310           components.find(target_node)->second ==
    311               components.find(target_input)->second) {
    312         continue;
    313       }
    314       target_inputs_topological.push_back(target_input);
    315     }
    316   }
    317   std::sort(target_inputs_topological.begin(), target_inputs_topological.end(),
    318             [&components](const NodeDef* first, const NodeDef* second) {
    319               return components.find(first)->second >
    320                      components.find(second)->second;
    321             });
    322   auto target_input_iterator = target_inputs_topological.begin();
    323   NodeDef* current_trigger_node = nullptr;
    324   std::unordered_map<const NodeDef*, const NodeDef*> triggers;
    325   for (const NodeDef* original_recomputed_node :
    326        recomputed_source_nodes_topological) {
    327     NodeDef* new_trigger_node = graph->add_node();
    328     new_trigger_node->set_name(AddPrefixToNodeName(
    329         original_recomputed_node->name(), kRecomputeTriggerNodePrefix));
    330     new_trigger_node->set_op("NoOp");
    331     new_trigger_node->set_device(original_recomputed_node->device());
    332     if (current_trigger_node != nullptr) {
    333       *new_trigger_node->add_input() =
    334           strings::StrCat("^", current_trigger_node->name());
    335     }
    336     current_trigger_node = new_trigger_node;
    337     triggers[original_recomputed_node] = current_trigger_node;
    338     for (;
    339          target_input_iterator != target_inputs_topological.end() &&
    340          components.find(*target_input_iterator)->second >
    341              recomputed_node_max_feed_components.find(original_recomputed_node)
    342                  ->second;
    343          ++target_input_iterator) {
    344       *current_trigger_node->add_input() =
    345           strings::StrCat("^", (*target_input_iterator)->name());
    346       VLOG(2) << "  Recomputation trigger " << current_trigger_node->name()
    347               << " depends on " << (*target_input_iterator)->name();
    348     }
    349   }
    350   return triggers;
    351 }
    352 
    353 string RecomputedOrOriginalNodeName(
    354     const std::unordered_set<string>& recomputed_node_names,
    355     const string& original_node_name) {
    356   if (recomputed_node_names.find(original_node_name) ==
    357       recomputed_node_names.end()) {
    358     return original_node_name;
    359   } else {
    360     return AddPrefixToNodeName(original_node_name, kRecomputedNodePrefix);
    361   }
    362 }
    363 
    364 // Helper function to recompute a sub-graph (recomputed_source_nodes). Edges
    365 // from recomputed_source_nodes to target_nodes are changed to start from the
    366 // recomputed nodes.
    367 void RecomputeSubgraph(
    368     const std::unordered_set<const NodeDef*>& recomputed_source_nodes,
    369     const std::unordered_set<NodeDef*>& target_nodes, const NodeMap& node_map,
    370     const std::unordered_map<const NodeDef*, int>& components,
    371     GraphDef* graph) {
    372   std::unordered_set<string> recomputed_node_names;
    373   VLOG(1) << "Recomputing a " << recomputed_source_nodes.size()
    374           << " node subgraph";
    375   std::unordered_map<const NodeDef*, int> recomputed_node_components =
    376       GetMaxDownstreamComponents(recomputed_source_nodes, target_nodes,
    377                                  node_map, components);
    378   for (const NodeDef* original_node : recomputed_source_nodes) {
    379     VLOG(2) << "  " << original_node->name();
    380     recomputed_node_names.insert(original_node->name());
    381   }
    382   std::unordered_map<const NodeDef*, const NodeDef*> triggers =
    383       AddRecomputeControlDependencyNodes(recomputed_source_nodes, target_nodes,
    384                                          node_map, components,
    385                                          recomputed_node_components, graph);
    386   // Create the recomputed sub-graph
    387   for (const NodeDef* original_node : recomputed_source_nodes) {
    388     NodeDef* copied_node = graph->add_node();
    389     copied_node->set_name(
    390         AddPrefixToNodeName(original_node->name(), kRecomputedNodePrefix));
    391     copied_node->set_op(original_node->op());
    392     *copied_node->mutable_attr() = original_node->attr();
    393     copied_node->set_device(original_node->device());
    394     for (const string& original_input_name : original_node->input()) {
    395       // Set inputs which are internal to the copied subgraph to their copied
    396       // versions.
    397       *copied_node->add_input() = RecomputedOrOriginalNodeName(
    398           recomputed_node_names, original_input_name);
    399     }
    400     // Each recomputed node gets a control dependency to prevent it from being
    401     // recomputed immediately.
    402     *copied_node->add_input() =
    403         strings::StrCat("^", triggers[original_node]->name());
    404   }
    405   // Set the inputs of nodes in the target subgraph to the recomputed nodes
    406   // where applicable.
    407   for (NodeDef* target_node : target_nodes) {
    408     for (string& target_input_name : *target_node->mutable_input()) {
    409       target_input_name = RecomputedOrOriginalNodeName(recomputed_node_names,
    410                                                        target_input_name);
    411     }
    412   }
    413 }
    414 
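         // Duplicates cheap-to-recompute (or manually annotated) forward nodes that feed
         // nodes whose names start with `recomputation_targets_name_prefix`, and rewires
         // those targets to consume the copies instead. Trigger control dependencies
         // delay the copies until shortly before the gradient computation needs them,
         // trading extra compute for a lower peak memory footprint.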
    415 void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
    416                                 const string& recomputation_targets_name_prefix,
    417                                 GraphDef* graph, const GrapplerItem& item) {
    418   if (optimization_level != RewriterConfig::RECOMPUTATION_HEURISTICS &&
    419       optimization_level != RewriterConfig::HEURISTICS &&
    420       optimization_level != RewriterConfig::MANUAL) {
    421     // Nothing to do
    422     return;
    423   }
    424   // The topological numberings and NodeMap will be stale as soon as we start
    425   // modifying the graph in RecomputeSubgraph. However, RecomputeSubgraph only
    426   // looks up nodes which were in the original graph, and preserves the graph
    427   // topology it's interested in.
    428   // We don't use the results of this topological sort until later, but this
    429   // call invalidates all NodeDef pointers, so it needs to be done before we
    430   // start collecting those.
    431   TF_CHECK_OK(TopologicalSort(graph));
    432   NodeMap node_map(graph);
    433   std::vector<RecomputedSubGraph> recomputed_subgraphs;
    434   // Do not recompute nodes which are fed, since the recomputed node would not
    435   // take on the fed value (i.e. gradients would be incorrect).
    436   std::unordered_set<string> feeds;
    437   for (const auto& feed : item.feed) {
    438     feeds.insert(NodeName(feed.first));
    439   }
    440   std::function<bool(const NodeDef&)> is_target =
    441       [&recomputation_targets_name_prefix](const NodeDef& node) {
    442         // Nodes whose inputs we may want to recompute. Typically targets will
    443         // be gradients (recomputation_targets_name_prefix="gradients/"),
    444         // although the prefix is configurable since gradients may be created
    445         // in a name scope.
    446         // TODO(allenl): Use a static schedule
    447         // (grappler::EstimateEarliestExecutionTimes) to recompute only nodes
    448         // whose outputs will sit around for a while.
    449         return node.name().find(recomputation_targets_name_prefix) == 0;
    450       };
    451 
    452   if (optimization_level == RewriterConfig::RECOMPUTATION_HEURISTICS ||
    453       optimization_level == RewriterConfig::HEURISTICS) {
    454     // TODO(allenl): Handle ResNet-like architectures better. Right now all of
    455     // the cheap forward ops get grouped into a single subgraph which must
    456     // execute before gradients start executing (unless layers are manually
    457     // separated by identity ops).
    458     std::unordered_set<string> cheap_to_recompute_ops =
    459         GetCheapToRecomputeOps();
    460     recomputed_subgraphs = GetOpGroupsToRecompute(
    461         graph, node_map,
    462         [&cheap_to_recompute_ops, &feeds, &is_target](const NodeDef& node) {
    463           return !is_target(node) && feeds.count(node.name()) == 0 &&
    464                  (cheap_to_recompute_ops.count(node.op()) > 0 ||
    465                   node.attr().count(kRecomputeHint) > 0);
    466         },
    467         is_target);
    468   } else if (optimization_level == RewriterConfig::MANUAL) {
    469     recomputed_subgraphs = GetOpGroupsToRecompute(
    470         graph, node_map,
    471         [&feeds, &is_target](const NodeDef& node) {
    472           return !is_target(node) && feeds.count(node.name()) == 0 &&
    473                  node.attr().count(kRecomputeHint) > 0;
    474         },
    475         is_target);
    476   }
    477   if (!recomputed_subgraphs.empty()) {
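             // The graph is still in the topological order established by the sort
             // above, so assign inverted indices: earlier nodes get larger numbers, and
             // a node can only depend on nodes with strictly greater numbers. The
             // comparisons in GetMaxDownstreamComponents and
             // AddRecomputeControlDependencyNodes rely on this orientation.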
    478     std::unordered_map<const NodeDef*, int> topological_numbering;
    479     for (int node_number = 0; node_number < graph->node().size();
    480          ++node_number) {
    481       topological_numbering[graph->mutable_node(node_number)] =
    482           graph->node().size() - node_number - 1;
    483     }
    484     // Duplicate the indicated sub-graphs and set up control dependencies
    485     for (const RecomputedSubGraph& subgraph : recomputed_subgraphs) {
    486       RecomputeSubgraph(subgraph.recomputed_source_nodes, subgraph.target_nodes,
    487                         node_map, topological_numbering, graph);
    488     }
    489   }
    490 }
    491 
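         // On devices using more than 80% of their memory at peak, rewrites AddN (and
         // AccumulateNV2) nodes that consume a tensor live at the peak into a chain of
         // in-place accumulations into a temporary variable, so that each input can be
         // consumed as soon as it is available instead of all inputs having to be
         // resident at once. Returns true if the graph was modified.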
    492 bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
    493   // Look for AddN nodes (and equivalent) and record input names.
    494   GraphView view(&item->graph);
    495 
    496   std::unordered_map<string, std::unordered_set<NodeDef*>> addn_list;
    497   for (NodeDef& node : *item->graph.mutable_node()) {
    498     if (!IsAddN(node) && node.op() != "AccumulateNV2") {
    499       continue;
    500     }
    501     // There is nothing to gain by optimizing nodes with 2 or fewer inputs.
    502     if (view.NumFanins(node, false) <= 2) {
    503       continue;
    504     }
    505     for (const auto& input : view.GetFanins(node, false)) {
    506       if (input.node->device() == node.device()) {
    507         string tensor_name =
    508             strings::StrCat(input.node->name(), ":", input.port_id);
    509         addn_list[tensor_name].insert(&node);
    510       }
    511     }
    512   }
    513 
    514   if (addn_list.empty()) {
    515     return false;
    516   }
    517 
    518   GraphMemory memory(*item);
    519   const std::unordered_map<string, DeviceProperties>& devices =
    520       cluster->GetDevices();
    521   Status s = memory.InferStatically(devices);
    522   if (!s.ok()) {
    523     VLOG(1) << "Failed to infer memory usage: " << s.error_message();
    524     return false;
    525   }
    526 
    527   std::unordered_set<NodeDef*> addn_to_rewrite;
    528   for (const auto& device : devices) {
    529     const string& name = device.first;
    530     const DeviceProperties& prop = device.second;
    531     if (prop.memory_size() <= 0) {
    532       VLOG(1) << "Available memory unknown for device " << name;
    533       continue;
    534     }
    535     const GraphMemory::MemoryUsage& mem_usage = memory.GetPeakMemoryUsage(name);
    536 
    537     if (mem_usage.used_memory <= prop.memory_size() * 0.8) {
    538       continue;
    539     }
    540 
    541     for (const auto& live : mem_usage.live_tensors) {
    542       string tensor_name = strings::StrCat(live.node, ":", live.output_id);
    543       auto it = addn_list.find(tensor_name);
    544       if (it != addn_list.end()) {
    545         addn_to_rewrite.insert(it->second.begin(), it->second.end());
    546       }
    547     }
    548   }
    549 
    550   if (addn_to_rewrite.empty()) {
    551     return false;
    552   }
    553   GraphProperties properties(*item);
    554   s = properties.InferStatically(false);
    555   if (!s.ok()) {
    556     VLOG(1) << "Failed to infer shapes: " << s.error_message();
    557     return false;
    558   }
    559 
    560   bool updated_graph = false;
    561   // Rewrite the AddN.
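           // Each rewritten node becomes a TemporaryVariable initialized to zero, one
           // AssignAdd per input, and a DestroyTemporaryVariable (reusing the original
           // node) that returns the accumulated value once control dependencies on all
           // of the AssignAdd nodes are satisfied.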
    562   for (NodeDef* node : addn_to_rewrite) {
    563     if (!properties.HasOutputProperties(node->name())) {
    564       VLOG(1) << "Missing properties for " << node->name();
    565       continue;
    566     }
    567     const TensorShapeProto& shape =
    568         properties.GetOutputProperties(node->name())[0].shape();
    569     PartialTensorShape shp(shape);
    570     if (!shp.IsFullyDefined()) {
    571       VLOG(1) << "Shape not fully known for " << node->name();
    572       continue;
    573     }
    574 
    575     // Compute a topological ordering for the node fanin.
    576     std::unordered_map<NodeDef*, int> topo_order;
    577     ReverseDfs(view, {node}, nullptr,
    578                [&topo_order](NodeDef* n) {
    579                  int topo_index = topo_order.size();
    580                  topo_order[n] = topo_index;
    581                },
    582                nullptr);
    583 
    584     std::vector<int> input_topo_index;
    585 
    586     for (int i = 0; i < node->input_size(); ++i) {
     587       const string& input = node->input(i);
     588       const string input_node_name = NodeName(input);
     589       NodeDef* input_node = view.GetNode(input_node_name);
     590       input_topo_index.push_back(topo_order.at(input_node));
    591     }
    592     int min_input_topo_index = INT_MAX;
    593     int min_input_id = -1;
    594     for (int i = 0; i < node->input_size(); ++i) {
    595       if (IsControlInput(node->input(i))) {
    596         // control inputs are always last.
    597         break;
    598       }
    599       const int current = input_topo_index[i];
    600       if (current < min_input_topo_index) {
    601         min_input_topo_index = current;
    602         min_input_id = i;
    603       }
    604     }
    605     CHECK_LE(0, min_input_id);
    606     std::vector<string> pre_ctrl_deps;
    607     std::vector<string> post_ctrl_deps;
    608     for (int i = node->input_size() - 1; i >= 0; --i) {
    609       if (!IsControlInput(node->input(i))) {
    610         // control inputs are always last.
    611         break;
    612       }
    613       if (input_topo_index[i] < min_input_topo_index) {
    614         // These control dependencies can be executed before the node.
    615         pre_ctrl_deps.push_back(node->input(i));
    616       } else {
    617         // These control dependencies should be executed after the node.
    618         post_ctrl_deps.push_back(node->input(i));
    619       }
    620     }
    621 
    622     DataType dtype = node->attr().at("T").type();
    623     const string& device = node->device();
    624 
    625     // Create the temporary variable that will hold intermediate results
    626     NodeDef* tmp_var = item->graph.add_node();
    627     tmp_var->set_name(strings::StrCat(node->name(), "/tmp_var"));
    628     tmp_var->set_op("TemporaryVariable");
    629     tmp_var->set_device(device);
    630     (*tmp_var->mutable_attr())["dtype"].set_type(dtype);
    631     *(*tmp_var->mutable_attr())["shape"].mutable_shape() = shape;
    632     (*tmp_var->mutable_attr())["var_name"].set_s(tmp_var->name());
    633 
    634     for (const string& ctrl_dep : pre_ctrl_deps) {
    635       *tmp_var->add_input() = ctrl_dep;
    636     }
    637     *tmp_var->add_input() =
    638         AsControlDependency(NodeName(node->input(min_input_id)));
    639 
    640     // Initialize it to zero
    641     NodeDef* zeros = item->graph.add_node();
    642     zeros->set_name(strings::StrCat(node->name(), "/tmp_var_zeros"));
    643     zeros->set_op("ZerosLike");
    644     zeros->set_device(device);
    645     (*zeros->mutable_attr())["T"].set_type(dtype);
    646     *zeros->add_input() = node->input(min_input_id);
    647 
    648     NodeDef* initialize = item->graph.add_node();
    649     initialize->set_name(strings::StrCat(node->name(), "/tmp_var_initializer"));
    650     initialize->set_op("Assign");
    651     initialize->set_device(device);
    652     (*initialize->mutable_attr())["T"].set_type(dtype);
    653     (*initialize->mutable_attr())["use_locking"].set_b(false);
    654     (*initialize->mutable_attr())["validate_shape"].set_b(false);
    655     *initialize->add_input() = tmp_var->name();
    656     *initialize->add_input() = zeros->name();
    657 
     658     // Add the AssignAdd nodes that accumulate the inputs.
    659     std::vector<NodeDef*> accumulates;
    660     for (int i = 0; i < node->input_size(); ++i) {
    661       const string& input = node->input(i);
    662       if (!IsControlInput(input)) {
    663         NodeDef* accumulate = item->graph.add_node();
    664         accumulate->set_name(
    665             strings::StrCat(node->name(), "/tmp_var_accum_", i));
    666         accumulate->set_op("AssignAdd");
    667         accumulate->set_device(device);
    668         (*accumulate->mutable_attr())["T"].set_type(dtype);
    669         (*accumulate->mutable_attr())["use_locking"].set_b(true);
    670         *accumulate->add_input() = initialize->name();
    671         *accumulate->add_input() = input;
    672         accumulates.push_back(accumulate);
    673       }
    674     }
    675 
     676     // Rewrite the AddN node as a DestroyTemporaryVariable op.
    677     node->set_op("DestroyTemporaryVariable");
    678     node->clear_input();
    679     node->clear_attr();
    680     (*node->mutable_attr())["T"].set_type(dtype);
    681     (*node->mutable_attr())["var_name"].set_s(tmp_var->name());
    682     *node->add_input() = initialize->name();
    683     for (const NodeDef* accum : accumulates) {
    684       *node->add_input() = AsControlDependency(accum->name());
    685     }
    686     for (const string& ctrl_dep : post_ctrl_deps) {
    687       *node->add_input() = ctrl_dep;
    688     }
    689 
    690     updated_graph = true;
    691   }
    692 
    693   return updated_graph;
    694 }
    695 
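         // Builds a pair of Identity nodes that copy input `input_to_swap` of `node` out
         // to the host (CPU) and back to the node's device. On success, `swap_pair` is
         // set to {swap_out, swap_in}; the caller is responsible for rerouting the input
         // through the pair and for adding the control dependencies that schedule the
         // transfers.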
    696 Status BuildSwapPair(NodeDef* node, int input_to_swap,
    697                      const std::unordered_map<string, const NodeDef*>& name_map,
    698                      GraphDef* graph,
    699                      std::pair<NodeDef*, NodeDef*>* swap_pair) {
    700   const OpDef* op_def;
    701   TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node->op(), &op_def));
    702   DataType input_type;
    703   TF_RETURN_IF_ERROR(
    704       InputTypeForNode(*node, *op_def, input_to_swap, &input_type));
    705   if (IsRefType(input_type)) {
    706     return errors::InvalidArgument("Can't swap input ", input_to_swap,
    707                                    " of node ", node->name(),
    708                                    " since it expects a reference");
    709   }
    710 
    711   string tensor_to_swap = strings::StrCat(node->name(), "_", input_to_swap);
    712   string swap_out_name = strings::StrCat("swap_out_", tensor_to_swap);
    713   string swap_in_name = strings::StrCat("swap_in_", tensor_to_swap);
    714   if (name_map.find(swap_out_name) != name_map.end() ||
    715       name_map.find(swap_in_name) != name_map.end()) {
    716     return errors::InvalidArgument("Input ", input_to_swap, " of node ",
    717                                    node->name(), " is already swapped");
    718   }
    719 
     720   // Force the tensor to be copied to CPU.
    721   NodeDef* swap_out_node = graph->add_node();
    722   swap_out_node->set_name(swap_out_name);
    723   swap_out_node->set_op("Identity");
    724   swap_out_node->set_device("/device:CPU:0");
    725 
    726   // Force the tensor to be restored to the device.
    727   NodeDef* swap_in_node = graph->add_node();
    728   swap_in_node->set_name(swap_in_name);
    729   swap_in_node->set_op("Identity");
    730   *swap_in_node->add_input() = swap_out_node->name();
    731 
     732   // Colocate the swap_in node with the node itself.
    733   swap_in_node->set_device(node->device());
    734   string coloc_group = strings::StrCat("loc@", tensor_to_swap);
    735   (*swap_in_node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
    736   (*node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
    737 
    738   (*swap_in_node->mutable_attr())["T"].set_type(input_type);
    739   (*swap_out_node->mutable_attr())["T"].set_type(input_type);
    740   *swap_pair = std::make_pair(swap_out_node, swap_in_node);
    741 
    742   return Status::OK();
    743 }
    744 
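         // Returns a lower bound, in bytes, on the size of the tensor described by `t`:
         // unknown dimensions are assumed to be 1, and an unknown rank is treated as a
         // scalar.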
    745 static int64 EstimateSize(const OpInfo::TensorProperties& t) {
    746   DataType dtype = t.dtype();
    747   int64 size = DataTypeSize(dtype);
    748   TensorShapeProto shape = t.shape();
    749   if (shape.unknown_rank()) {
    750     // Can't infer the size if the rank is unknown. It has to be at least a
    751     // scalar though.
    752     return size;
    753   }
    754   // If one of the dimensions is unknown statically, assume it's at least one.
    755   for (int i = 0; i < shape.dim_size(); ++i) {
    756     if (shape.dim(i).size() < 0) {
    757       shape.mutable_dim(i)->set_size(1);
    758     }
    759   }
    760   int64 num_elems = TensorShape(shape).num_elements();
    761   return num_elems * size;
    762 }
    763 
    764 struct SwapInfo {
    765   std::vector<int> inputs_to_swap;
    766   Costs::NanoSeconds time_to_swap = 0;
    767 };
    768 
    769 static const NodeDef* FindSwapInTrigger(
    770     const NodeDef* node, const SwapInfo& swap_info,
    771     const std::unordered_map<string, const NodeDef*>& name_map,
    772     const std::unordered_map<const NodeDef*, Costs::NanoSeconds>&
    773         execution_times) {
    774   // max_trigger_time stores the time before which the swap operation needs to
    775   // be started in order to load the data back onto the accelerator without
    776   // delaying the downstream computation.
    777   Costs::NanoSeconds max_trigger_time(0);
    778   std::set<string> possible_inputs;
    779   for (int i = 0; i < node->input_size(); ++i) {
    780     const string input_node_name = NodeName(node->input(i));
    781     auto it1 = name_map.find(input_node_name);
    782     if (it1 == name_map.end()) {
    783       return nullptr;
    784     }
    785     const NodeDef* input_node = it1->second;
    786 
    787     auto it2 = execution_times.find(input_node);
    788     if (it2 == execution_times.end()) {
    789       return nullptr;
    790     }
    791     max_trigger_time = std::max(max_trigger_time, it2->second);
    792     possible_inputs.insert(input_node_name);
    793   }
    794 
    795   for (const int i : swap_info.inputs_to_swap) {
    796     const string input_node_name = NodeName(node->input(i));
    797     possible_inputs.erase(input_node_name);
    798   }
    799   if (possible_inputs.empty()) {
    800     return nullptr;
    801   }
    802 
    803   max_trigger_time -= swap_info.time_to_swap;
    804 
    805   std::map<Costs::NanoSeconds, const NodeDef*> candidates;
    806   std::set<string> already_processed;
    807 
    808   while (!possible_inputs.empty()) {
    809     const string input_node_name = *possible_inputs.begin();
    810     possible_inputs.erase(possible_inputs.begin());
    811     already_processed.insert(input_node_name);
    812     auto it1 = name_map.find(input_node_name);
    813     if (it1 == name_map.end()) {
    814       return nullptr;
    815     }
    816     const NodeDef* input_node = it1->second;
    817     // Don't jump over frames, since adding a control dependency from one frame
    818     // to the next isn't supported. Don't go through branches, since we don't
    819     // know whether they'll be executed or not.
    820     if (ModifiesFrameInfo(*input_node) || IsSwitch(*input_node) ||
    821         IsMerge(*input_node)) {
    822       continue;
    823     }
    824     auto it2 = execution_times.find(input_node);
    825     if (it2 == execution_times.end()) {
    826       return nullptr;
    827     }
    828     if (it2->second < max_trigger_time) {
    829       candidates[it2->second] = input_node;
    830     } else {
    831       for (const string& fanin : input_node->input()) {
    832         string name = NodeName(fanin);
    833         if (already_processed.find(name) == already_processed.end()) {
    834           possible_inputs.insert(name);
    835         }
    836       }
    837     }
    838   }
    839 
     840   // Select the candidate that will execute last: we want to swap the data back
     841   // in at the last possible moment, while still leaving enough time for the
     842   // transfer to complete before the downstream nodes need the data.
    843   if (!candidates.empty()) {
    844     return candidates.rbegin()->second;
    845   }
    846   return nullptr;
    847 }
    848 
    849 static bool IsSwappable(const GraphView& graph, GraphView::OutputPort output) {
    850   const NodeDef& node = *output.node;
     851   // There is no point in swapping out persistent tensors, since they will
     852   // continue to use memory.
    853   if (IsPersistent(node)) {
    854     return false;
    855   }
    856 
    857   const OpDef* op_def;
    858   if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
    859     return false;
    860   }
    861   DataType dtype;
    862   if (!OutputTypeForNode(node, *op_def, output.port_id, &dtype).ok()) {
    863     return false;
    864   }
    865   // References can only refer to persistent memory: therefore the node isn't
    866   // swappable.
    867   if (IsRefType(dtype)) {
    868     return false;
    869   }
    870 
    871   if (output.node->op() == "Identity" || output.node->op() == "Reshape") {
    872     // If placed on the same device, these nodes are just forwarding references
    873     // to their input. Therefore they are swappable iff their fanin is swappable
    874     // or it resides on a different device.
    875     GraphView::InputPort input;
    876     input.node = output.node;
    877     input.port_id = 0;
    878     GraphView::OutputPort fanin = graph.GetRegularFanin(input);
    879     if (fanin.node->device() == node.device()) {
    880       return IsSwappable(graph, fanin);
    881     }
    882   }
    883   return true;
    884 }
    885 
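         // Returns the earliest-executing consumer (other than `node`) of the tensor
         // being swapped, or nullptr if there is none. The caller makes this consumer
         // wait for the swap-out so that the copy to the host is scheduled as early as
         // possible.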
    886 static NodeDef* FindSwapOutTrigger(
    887     const NodeDef* node, int input_id, const GraphView& view,
    888     const std::unordered_map<const NodeDef*, Costs::NanoSeconds>&
    889         execution_times) {
    890   // Find the output port that generated the tensor to swap.
    891   GraphView::InputPort swap;
    892   swap.node = const_cast<NodeDef*>(node);
    893   swap.port_id = input_id;
    894   GraphView::OutputPort generator = view.GetRegularFanin(swap);
    895   if (!generator.node) {
    896     return nullptr;
    897   }
    898 
    899   const std::unordered_set<GraphView::InputPort, GraphView::HashPort>& fanout =
    900       view.GetFanout(generator);
    901   NodeDef* trigger = nullptr;
    902   Costs::NanoSeconds earliest_fanout(Costs::NanoSeconds::infinity());
    903 
    904   for (const auto& port : fanout) {
    905     if (port.node == node) {
    906       continue;
    907     }
    908     auto it = execution_times.find(port.node);
    909     if (it != execution_times.end() && it->second < earliest_fanout) {
    910       earliest_fanout = it->second;
    911       trigger = port.node;
    912     }
    913   }
    914 
    915   return trigger;
    916 }
    917 
    918 static bool IsSwappable(GraphView::InputPort input) {
    919   const NodeDef& node = *input.node;
    920 
    921   const OpDef* op_def;
    922   if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
    923     return false;
    924   }
    925 
    926   DataType dtype;
    927   if (!InputTypeForNode(node, *op_def, input.port_id, &dtype).ok()) {
    928     return false;
    929   }
    930 
    931   return !IsRefType(dtype);
    932 }
    933 
    934 struct MemInfo {
    935   GraphView::OutputPort port;
    936   int64 memory_used;
    937   std::vector<GraphView::InputPort> uses_left;
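           // Lower (more negative) fitness is better; candidates are processed in
           // ascending order.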
    938   double fitness;
    939 
    940   bool operator<(const MemInfo& other) const { return fitness < other.fitness; }
    941 };
    942 
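         // For each GPU that is over its memory budget, finds tensors that are live at
         // the time of peak memory usage and are worth swapping out, and records which
         // inputs of their downstream consumers should be swapped in `nodes_to_swap`.
         // Returns true if any candidates were recorded.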
    943 static bool IdentifySwappingCandidates(
    944     Cluster* cluster, GrapplerItem* item, std::unordered_set<string>* skip_list,
    945     std::unordered_map<NodeDef*, SwapInfo>* nodes_to_swap) {
    946   GraphMemory memory(*item);
    947   const std::unordered_map<string, DeviceProperties>& devices =
    948       cluster->GetDevices();
    949   Status s = memory.InferStatically(devices);
    950   if (!s.ok()) {
    951     VLOG(1) << "Failed to infer memory usage: " << s.error_message();
    952     return false;
    953   }
    954 
    955   bool updated_graph = false;
    956   for (const auto& device : devices) {
    957     const string& name = device.first;
    958     const DeviceProperties& prop = device.second;
    959     if (prop.type() != "GPU") {
    960       continue;
    961     }
    962     if (prop.memory_size() <= 0) {
    963       VLOG(1) << "Peak memory usage unknown for device " << name;
    964       continue;
    965     }
    966     const GraphMemory::MemoryUsage& mem_usage = memory.GetPeakMemoryUsage(name);
    967 
    968     if (mem_usage.used_memory <= prop.memory_size()) {
    969       continue;
    970     }
    971     int64 required_savings = mem_usage.used_memory - prop.memory_size();
    972 
    973     std::unordered_map<string, Costs::NanoSeconds> op_completion_times;
    974     {
    975       VirtualCluster vcluster(cluster->GetDevices());
    976       if (!vcluster.Provision().ok()) {
    977         return false;
    978       }
    979       if (!vcluster.Initialize(*item).ok()) {
    980         return false;
    981       }
    982       RunMetadata metadata;
    983       Status s = vcluster.Run(item->graph, item->feed, item->fetch, &metadata);
    984       if (!s.ok() && s.code() != error::RESOURCE_EXHAUSTED) {
    985         return false;
    986       }
    987 
    988       for (const auto& dev_stats : metadata.step_stats().dev_stats()) {
    989         for (const auto& node_stats : dev_stats.node_stats()) {
    990           Costs::NanoSeconds exec_time =
    991               Costs::NanoSeconds(1) +
    992               Costs::MicroSeconds(node_stats.all_start_micros() +
    993                                   node_stats.op_end_rel_micros());
    994           op_completion_times.emplace(node_stats.node_name(), exec_time);
    995         }
    996       }
    997     }
    998 
    999     Costs::Duration peak_time = -1;
   1000     for (const auto& live_tensor : mem_usage.live_tensors) {
   1001       if (live_tensor.allocation_time > peak_time) {
   1002         peak_time = live_tensor.allocation_time;
   1003       }
   1004     }
   1005 
   1006     std::vector<MemInfo> mem_state;
   1007 
   1008     GraphView graph(&item->graph);
   1009     for (const auto& live_tensor : mem_usage.live_tensors) {
   1010       if (live_tensor.memory_used <= 1024) {
   1011         // Don't bother with small tensors.
   1012         continue;
   1013       }
   1014       if (live_tensor.deallocation_time - live_tensor.allocation_time <=
   1015           Costs::Duration(1e6)) {
   1016         // Not enough time to swap.
   1017         VLOG(1) << "Not enough time to swap: skipping " << live_tensor.node;
   1018         continue;
   1019       }
   1020 
   1021       if (skip_list->find(live_tensor.node) != skip_list->end()) {
   1022         continue;
   1023       }
   1024       GraphView::OutputPort port =
   1025           graph.GetOutputPort(live_tensor.node, live_tensor.output_id);
   1026       if (!IsSwappable(graph, port)) {
   1027         continue;
   1028       }
   1029       MemInfo mem_info;
   1030       mem_info.port = port;
   1031       mem_info.memory_used = live_tensor.memory_used;
   1032       Costs::Duration allocation_time = live_tensor.allocation_time;
   1033       Costs::Duration earliest_use(Costs::Duration::infinity());
   1034       bool valid = true;
   1035       for (GraphView::InputPort input : graph.GetFanout(port)) {
   1036         // Get execution time.
   1037         auto it = op_completion_times.find(input.node->name());
   1038         if (it == op_completion_times.end()) {
   1039           valid = false;
   1040           break;
   1041         }
   1042         if (it->second <= peak_time) {
   1043           continue;
   1044         }
   1045 
   1046         if (skip_list->find(input.node->name()) != skip_list->end()) {
   1047           valid = false;
   1048           break;
   1049         }
   1050         string input_name =
   1051             strings::StrCat(input.node->name(), ":", input.port_id);
   1052         if (skip_list->find(input_name) != skip_list->end()) {
   1053           valid = false;
   1054           break;
   1055         }
   1056         if (!IsSwappable(input)) {
   1057           valid = false;
   1058           break;
   1059         }
   1060 
    1061         // Record this use, and track the earliest use time after the peak.
   1062         mem_info.uses_left.emplace_back(input);
   1063         earliest_use = std::min(earliest_use, it->second);
   1064       }
   1065       if (valid && !mem_info.uses_left.empty()) {
    1066         // Compute the fitness: we want the tensor to be generated well before
    1067         // the time of peak memory usage (so that there is enough time to swap it
    1068         // out), and to be used well after the peak, so that swapping it back in
    1069         // won't recreate the memory bottleneck. Last but not least, we want the
    1070         // tensor to have as few remaining uses as possible, since each remaining
    1071         // use triggers a swap back in.
   1072         mem_info.fitness = std::pow((earliest_use - peak_time).count(), 2);
   1073         mem_info.fitness /= std::pow(mem_info.uses_left.size(), 2);
   1074         mem_info.fitness += std::pow((allocation_time - peak_time).count(), 2);
   1075         mem_info.fitness = -mem_info.fitness;
   1076         mem_state.push_back(mem_info);
   1077       }
   1078     }
   1079 
   1080     // Sort by fitness
   1081     std::sort(mem_state.begin(), mem_state.end());
   1082 
   1083     for (const MemInfo& mem_info : mem_state) {
   1084       for (const GraphView::InputPort fanout_to_swap : mem_info.uses_left) {
   1085         VLOG(1) << "Will swap fanout " << fanout_to_swap.node->name() << ":"
   1086                 << fanout_to_swap.port_id << " of tensor "
   1087                 << mem_info.port.node->name() << ":" << mem_info.port.port_id
   1088                 << " of size " << mem_info.memory_used;
   1089 
   1090         (*nodes_to_swap)[fanout_to_swap.node].inputs_to_swap.push_back(
   1091             fanout_to_swap.port_id);
   1092       }
   1093       required_savings -= mem_info.memory_used;
   1094       updated_graph = true;
   1095       if (required_savings < 0) {
   1096         break;
   1097       }
   1098     }
   1099   }
   1100   return updated_graph;
   1101 }
   1102 
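         // Swaps selected tensors out to host memory between their production and their
         // first use after the memory peak. Candidates come from the swapping heuristics
         // (IdentifySwappingCandidates) and/or from the manual "_swap_to_host" node
         // attribute.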
   1103 bool SwappingPass(RewriterConfig::MemOptType optimization_level,
   1104                   Cluster* cluster, GrapplerItem* item,
   1105                   std::unordered_set<string>* skip_list) {
   1106   std::unordered_map<NodeDef*, SwapInfo> nodes_to_swap;
   1107   if (optimization_level == RewriterConfig::SWAPPING_HEURISTICS ||
   1108       optimization_level == RewriterConfig::HEURISTICS) {
    1109     // Use heuristics to figure out what needs to be swapped.
   1110     IdentifySwappingCandidates(cluster, item, skip_list, &nodes_to_swap);
   1111   }
    1112   // Look for manual annotations in the graph.
   1113   for (auto& node : *item->graph.mutable_node()) {
   1114     if (node.attr().count("_swap_to_host") != 0) {
   1115       SwapInfo& swap_info = nodes_to_swap[&node];
   1116       const AttrValue& val = node.attr().at("_swap_to_host");
   1117       if (val.has_list()) {
   1118         for (int64 input_id : val.list().i()) {
   1119           swap_info.inputs_to_swap.push_back(input_id);
   1120         }
   1121       } else {
   1122         int64 input_id = val.i();
   1123         swap_info.inputs_to_swap.push_back(input_id);
   1124       }
   1125     }
   1126   }
   1127   if (nodes_to_swap.empty()) {
   1128     // Nothing to do.
   1129     return false;
   1130   }
   1131 
   1132   // Estimate the size of the data to swap for each node.
   1133   GraphProperties properties(*item);
   1134   if (!properties.InferStatically(true).ok()) {
   1135     return false;
   1136   }
   1137   for (auto& swap : nodes_to_swap) {
   1138     const NodeDef* node = swap.first;
   1139     const std::vector<OpInfo::TensorProperties>& props =
   1140         properties.GetInputProperties(node->name());
   1141     SwapInfo& swap_info = swap.second;
   1142     int64 bytes_to_swap = 0;
   1143     for (int64 input_id : swap_info.inputs_to_swap) {
   1144       const OpInfo::TensorProperties& t = props[input_id];
   1145       bytes_to_swap += EstimateSize(t);
   1146     }
   1147     // Let's assume we're going to swap over PCIe running at 16 GBps.
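             // time_to_swap is in nanoseconds and 16 GB/s is 16 bytes per nanosecond, so
             // dividing the byte count by 16 gives the transfer time directly.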
   1148     swap_info.time_to_swap = bytes_to_swap / 16;
   1149   }
   1150 
   1151   std::unordered_map<const NodeDef*, Costs::NanoSeconds> execution_times;
   1152   if (!EstimateEarliestExecutionTimes(*item, cluster, &execution_times).ok()) {
   1153     return false;
   1154   }
   1155 
   1156   std::unordered_map<string, const NodeDef*> name_map;
   1157   for (const auto& node : item->graph.node()) {
   1158     name_map[node.name()] = &node;
   1159   }
   1160   GraphView view(&item->graph);
   1161 
   1162   bool updated_graph = false;
   1163 
   1164   for (auto& swap : nodes_to_swap) {
   1165     NodeDef* node = swap.first;
   1166     const SwapInfo& swap_info = swap.second;
   1167     if (skip_list->find(node->name()) != skip_list->end()) {
   1168       continue;
   1169     }
   1170 
    1171     // Make sure the tensor isn't swapped back in right away: look for a node
    1172     // that will execute just before we need the data back, and add a control
    1173     // dependency from that node to the swap-in node.
   1174     const NodeDef* in_trigger =
   1175         FindSwapInTrigger(node, swap_info, name_map, execution_times);
   1176     // If we failed, don't attempt to reprocess this node in a subsequent pass.
   1177     if (!in_trigger) {
   1178       skip_list->insert(node->name());
   1179       continue;
   1180     }
   1181 
   1182     // Swap all the tensors that are marked with the 'swap_to_host' attribute.
   1183     for (int input_id : swap_info.inputs_to_swap) {
   1184       string input_name = strings::StrCat(node->name(), ":", input_id);
   1185       if (skip_list->find(input_name) != skip_list->end()) {
   1186         continue;
   1187       } else {
   1188         // Don't attempt to reprocess this input in a subsequent pass.
   1189         skip_list->insert(input_name);
   1190       }
   1191 
    1192       // Make sure the tensor is swapped out quickly: look for a node that
    1193       // will execute just after the tensor is generated and add a control
    1194       // dependency from the swap-out node to that node.
   1195       NodeDef* out_trigger =
   1196           FindSwapOutTrigger(node, input_id, view, execution_times);
   1197       if (!out_trigger) {
   1198         continue;
   1199       }
   1200 
   1201       std::pair<NodeDef*, NodeDef*> swap_nodes;
   1202       if (!BuildSwapPair(node, input_id, name_map, &item->graph, &swap_nodes)
   1203                .ok()) {
   1204         continue;
   1205       }
   1206       *swap_nodes.first->add_input() = node->input(input_id);
   1207       *node->mutable_input(input_id) = swap_nodes.second->name();
   1208 
   1209       // Add the control dependencies needed to delay the execution of the swap.
   1210       out_trigger->add_input(strings::StrCat("^", swap_nodes.first->name()));
   1211       swap_nodes.second->add_input(strings::StrCat("^", in_trigger->name()));
   1212 
   1213       // Make sure we won't try to swap the swap nodes in subsequent passes.
   1214       skip_list->insert(swap_nodes.first->name());
    1215       skip_list->insert(swap_nodes.second->name());
               // Record that the graph was modified so that the caller keeps iterating.
               updated_graph = true;
   1216     }
   1217   }
   1218   return updated_graph;
   1219 }
   1220 
   1221 Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
   1222                                  GraphDef* optimized_graph) {
   1223   *optimized_graph = item.graph;
   1224 
   1225   RecomputationRewritingPass(optimization_level_,
   1226                              recomputation_targets_name_prefix_,
   1227                              optimized_graph, item);
   1228 
   1229   GrapplerItem optimized_item(item, std::move(*optimized_graph));
   1230   std::unordered_set<string> skip_list;
   1231   // Bound the number of rewrite passes to avoid long processing times on graphs
   1232   // that simply won't fit in memory.
   1233   bool updated_graph = true;
   1234   for (int i = 0; i < 25 && updated_graph; ++i) {
   1235     updated_graph = false;
   1236     if ((optimization_level_ == RewriterConfig::DEFAULT_MEM_OPT ||
   1237          optimization_level_ == RewriterConfig::SCHEDULING_HEURISTICS ||
   1238          optimization_level_ == RewriterConfig::HEURISTICS) &&
   1239         cluster != nullptr) {
   1240       updated_graph |= SchedulingPass(cluster, &optimized_item);
   1241     }
   1242 
   1243     if ((optimization_level_ == RewriterConfig::SWAPPING_HEURISTICS ||
   1244          optimization_level_ == RewriterConfig::HEURISTICS ||
   1245          optimization_level_ == RewriterConfig::MANUAL) &&
   1246         cluster != nullptr) {
   1247       updated_graph |= SwappingPass(optimization_level_, cluster,
   1248                                     &optimized_item, &skip_list);
   1249     }
   1250   }
   1251 
   1252   optimized_graph->Swap(&optimized_item.graph);
   1253   return Status::OK();
   1254 }
   1255 
   1256 void MemoryOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
   1257                                const GraphDef& optimized_graph, double result) {
   1258   // Nothing to do for MemoryOptimizer.
   1259 }
   1260 
   1261 }  // end namespace grappler
   1262 }  // end namespace tensorflow
   1263