Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/copy_insertion.h"
     17 
     18 #include "absl/container/flat_hash_map.h"
     19 #include "absl/container/flat_hash_set.h"
     20 #include "absl/strings/str_cat.h"
     21 #include "absl/strings/str_join.h"
     22 #include "tensorflow/compiler/xla/service/dump.h"
     23 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
     24 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     25 #include "tensorflow/compiler/xla/service/hlo_dce.h"
     26 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
     27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     28 #include "tensorflow/compiler/xla/service/hlo_module.h"
     29 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     30 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
     31 #include "tensorflow/compiler/xla/service/logical_buffer.h"
     32 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
     33 #include "tensorflow/compiler/xla/status_macros.h"
     34 #include "tensorflow/compiler/xla/statusor.h"
     35 #include "tensorflow/compiler/xla/types.h"
     36 #include "tensorflow/compiler/xla/util.h"
     37 #include "tensorflow/core/platform/logging.h"
     38 
     39 namespace xla {
     40 namespace {
     41 
     42 using absl::StrAppend;
     43 
     44 bool IsReadonlyEntryParameterValue(const HloValue& value) {
     45   const HloComputation* computation = value.defining_instruction()->parent();
     46   return value.defining_instruction()->opcode() == HloOpcode::kParameter &&
     47          computation == computation->parent()->entry_computation() &&
     48          !computation->parent()->input_output_alias_config().ParameterHasAlias(
     49              value.defining_instruction()->parameter_number(), value.index());
     50 }
     51 
     52 bool IsConstantValue(const HloValue& value) {
     53   return value.defining_instruction()->opcode() == HloOpcode::kConstant;
     54 }
     55 
     56 bool ValueIsReadOnly(const HloValue& value) {
     57   return IsConstantValue(value) || IsReadonlyEntryParameterValue(value);
     58 }
     59 
     60 // Data structure describing the action which should be taken on parts of a
     61 // computation buffers, with respect to the adding of special case copies.
     62 struct SpecialCaseCopyPolicy {
     63   // Insert a copy if the same buffer is found at multiple indices within the
     64   // output tuple.
     65   bool copy_root_replicated_buffers = false;
     66   // If true, insert a copy if a buffer coming from a constant or a parameter
     67   // is found within the output tuple.
     68   bool copy_parameters_and_constants = false;
     69 };
     70 
     71 SpecialCaseCopyPolicy GetSpecialCaseCopyPolicy(const CallGraphNode& node,
     72                                                HloModule* module,
     73                                                HloComputation* computation) {
     74   SpecialCaseCopyPolicy policy;
     75   if (computation == module->entry_computation()) {
     76     policy.copy_parameters_and_constants = true;
     77     policy.copy_root_replicated_buffers = true;
     78   }
     79   return policy;
     80 }
     81 
     82 bool ShouldCopyRootValue(const HloValue& value,
     83                          const SpecialCaseCopyPolicy& policy) {
     84   if (policy.copy_parameters_and_constants) {
     85     return ValueIsReadOnly(value);
     86   }
     87   return false;
     88 }
     89 
     90 // Deep copy the given instructions 'from' and 'to' at the ShapeIndexes given in
     91 // 'indices_to_copy'. Add control edges from the respective kCopy instructions
     92 // in deep copy of 'from' to the respective kCopy instruction in the deep copy
     93 // of 'to'.
     94 //
     95 // Requirements: 'from' and 'to' must have compatible shapes.
     96 //
     97 // For example, suppose 'from' and 'to' are two-element tuples where index 0 is
     98 // the only index to copy. Prior to deep-copying we have:
     99 //
    100 //
    101 //      'from'
    102 //         |
    103 //        ...
    104 //         |
    105 //       'to'
    106 //
    107 // DeepCopyAndAddControlEdges produces:
    108 //
    109 //       'from'
    110 //        /   \
    111 //      GTE   GTE
    112 //       |     |
    113 //     Copy    |
    114 //    /   \   /
    115 //   |    Tuple
    116 //   |      |
    117 //  ctrl   ...
    118 //  edge    |
    119 //   |      |
    120 //   |    'to'
    121 //   |    /   \
    122 //   |  GTE   GTE
    123 //    \  |     |
    124 //     Copy    |
    125 //        \   /
    126 //        Tuple
    127 //
    128 StatusOr<std::pair<HloInstruction*, HloInstruction*>>
    129 DeepCopyAndAddControlEdges(HloInstruction* from, HloInstruction* to,
    130                            const ShapeTree<bool>& indices_to_copy) {
    131   DCHECK(ShapeUtil::Compatible(from->shape(), to->shape()));
    132   // to/from_copy_tree hold the kCopy instruction produces by the deep
    133   // copies. Elements which are not copied (indices_to_copy.element(index) ==
    134   // false) have nullptr at that index.
    135   ShapeTree<HloInstruction*> from_copy_tree(from->shape(),
    136                                             /*init_value=*/nullptr);
    137   TF_ASSIGN_OR_RETURN(HloInstruction * from_deep_copy,
    138                       from->parent()->DeepCopyInstruction(
    139                           from, &indices_to_copy, &from_copy_tree));
    140 
    141   ShapeTree<HloInstruction*> to_copy_tree(to->shape(), /*init_value=*/nullptr);
    142   TF_ASSIGN_OR_RETURN(
    143       HloInstruction * to_deep_copy,
    144       to->parent()->DeepCopyInstruction(to, &indices_to_copy, &to_copy_tree));
    145 
    146   // Add control edges between the respective kCopy instructions.
    147   for (const auto& pair : from_copy_tree) {
    148     const ShapeIndex& index = pair.first;
    149     HloInstruction* from_copy = pair.second;
    150     HloInstruction* to_copy = to_copy_tree.element(index);
    151     if (from_copy == nullptr) {
    152       TF_RET_CHECK(to_copy == nullptr);
    153       continue;
    154     }
    155     TF_RET_CHECK(to_copy != nullptr);
    156     TF_RETURN_IF_ERROR(from_copy->AddControlDependencyTo(to_copy));
    157   }
    158 
    159   return std::make_pair(from_deep_copy, to_deep_copy);
    160 }
    161 
    162 // Compute the indices of the loop state which need copies in order to avoid
    163 // live range interference. Generally, an element in the loop state does not
    164 // need to be copied if the element is passed through transparently through the
    165 // body.
    166 //
    167 // Returns whether any indices need to be copied.
    168 bool IndicesToCopyForWhile(const HloDataflowAnalysis& dataflow,
    169                            const HloInstruction* xla_while,
    170                            ShapeTree<bool>* indices_to_copy) {
    171   DCHECK(ShapeUtil::Compatible(indices_to_copy->shape(), xla_while->shape()));
    172 
    173   bool any_copies = false;
    174   const HloInstruction* init = xla_while->operand(0);
    175   for (auto& pair : *indices_to_copy) {
    176     const ShapeIndex& index = pair.first;
    177     bool& should_copy = pair.second;
    178     // If there is any ambiguity, then loop state must be copied.
    179     if (dataflow.GetValueSet(init, index).values().size() > 1 ||
    180         dataflow.GetValueSet(xla_while, index).values().size() > 1) {
    181       should_copy = true;
    182     } else {
    183       // If the output of the while instruction is not the same as the init
    184       // value of the while, then this element is not passed through the body
    185       // transparently and must be copied.
    186       should_copy = dataflow.GetUniqueValueAt(xla_while, index) !=
    187                     dataflow.GetUniqueValueAt(init, index);
    188     }
    189     any_copies |= should_copy;
    190   }
    191   return any_copies;
    192 }
    193 
    194 // Add kCopy instructions around the given kWhile instruction to eliminate any
    195 // possible live range interference of HLO values assuming a dependency-based
    196 // ordering (HloDependencyOrdering). Copies are added conservatively. There
    197 // likely are copies which are not strictly necessary, but they are removed
    198 // later in the pass via RemoveUnnecessaryCopies.
    199 //
    200 //
    201 // Elements (each ShapeIndex) in the loop state are considered independently.  A
    202 // copy is added to each element of the loop state which is modified in the
    203 // while body. For each such element, a total of three kCopy instructions are
    204 // added at following locations:
    205 //
    206 //   (1) The init value is copied before the kWhile instruction. Before:
    207 //
    208 //           (Init)
    209 //             |
    210 //           kWhile
    211 //             |
    212 //            ...
    213 //
    214 //       After:
    215 //
    216 //           (Init)
    217 //             |
    218 //           kCopy
    219 //             |
    220 //           kWhile
    221 //             |
    222 //            ...
    223 //
    224 //       This copy is necessary in case the init value is simultaneously live
    225 //       with the kWhile.
    226 //
    227 //   (2) Copies are added to the parameter and root of the while body
    228 //       computation. Before:
    229 //
    230 //           kParameter
    231 //               |
    232 //              ...
    233 //               |
    234 //           (body root)
    235 //
    236 //       After:
    237 //
    238 //           kParameter
    239 //               |
    240 //             kCopy ----------+
    241 //               |             |
    242 //              ...           ctrl
    243 //               |            edge
    244 //           (body root)       |
    245 //               |             |
    246 //             kCopy <---------+
    247 //
    248 //       The root kCopy becomes the new root of the computation. Both copies are
    249 //       necessary to any potential interference between the parameter value and
    250 //       the root value. The control edge prevents potential interference
    251 //       between the copies themselves.
    252 //
    253 // If the loop state is a tuple then the above kCopy instructions are a deep
    254 // copy constructed of kCopy, KGetTupleElement, and kTuple instruction as
    255 // constructed by HloInstruction::DeepCopyInstruction.
    256 Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis,
    257                          HloInstruction* xla_while) {
    258   VLOG(2) << "Adding copies for kWhile instruction " << xla_while->name();
    259   TF_RET_CHECK(xla_while->opcode() == HloOpcode::kWhile);
    260 
    261   ShapeTree<bool> indices_to_copy(xla_while->shape());
    262   if (!IndicesToCopyForWhile(alias_analysis.dataflow_analysis(), xla_while,
    263                              &indices_to_copy)) {
    264     VLOG(2) << "No copies necessary for kWhile instruction "
    265             << xla_while->name();
    266     return Status::OK();
    267   }
    268 
    269   VLOG(2) << "Adding copies for " << xla_while->name() << " at indices:";
    270   for (auto& pair : indices_to_copy) {
    271     if (pair.second) {
    272       VLOG(2) << "  " << pair.first;
    273     }
    274   }
    275 
    276   // Deep copy init.
    277   HloInstruction* while_init = xla_while->mutable_operand(0);
    278   TF_ASSIGN_OR_RETURN(
    279       HloInstruction * while_init_copy,
    280       xla_while->parent()->DeepCopyInstruction(while_init, &indices_to_copy));
    281   TF_RETURN_IF_ERROR(while_init->ReplaceUseWith(xla_while, while_init_copy));
    282 
    283   // Deep copy the parameter and the root. Extend a control edge from the copy
    284   // of the parameter value to the corresponding copy value of the root.
    285   HloComputation* body = xla_while->while_body();
    286   HloInstruction* param = body->parameter_instruction(0);
    287   HloInstruction* root = body->root_instruction();
    288 
    289   // If param is the root then all indices should have been passed through the
    290   // while body and we should have returned early above.
    291   TF_RET_CHECK(param != root);
    292 
    293   // Copy users before making a deep copy of the parameter as the deep copy
    294   // will create new users of the parameter (eg, the GTE instructions of the
    295   // deep copy).
    296   std::vector<HloInstruction*> param_users = param->users();
    297 
    298   ShapeIndex current_index;
    299   TF_ASSIGN_OR_RETURN(auto pair,
    300                       DeepCopyAndAddControlEdges(param, root, indices_to_copy));
    301 
    302   HloInstruction* param_copy = pair.first;
    303   HloInstruction* root_copy = pair.second;
    304 
    305   for (HloInstruction* user : param_users) {
    306     TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, param_copy));
    307   }
    308 
    309   body->set_root_instruction(root_copy);
    310 
    311   return Status::OK();
    312 }
    313 
    314 // We add copies for all the indices of the true and false computation roots, in
    315 // order to resolve interference. We later rely on RemoveUnnecessaryCopies to
    316 // drop the unnecessary ones.
    317 Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis,
    318                                HloInstruction* conditional) {
    319   VLOG(2) << "Adding copies for kConditional instruction "
    320           << conditional->name();
    321   TF_RET_CHECK(conditional->opcode() == HloOpcode::kConditional);
    322 
    323   for (HloComputation* computation : conditional->branch_computations()) {
    324     HloInstruction* root = computation->root_instruction();
    325     std::vector<HloInstruction*> users = root->users();
    326     TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
    327                         computation->DeepCopyInstruction(root));
    328     for (HloInstruction* user : users) {
    329       TF_RETURN_IF_ERROR(root->ReplaceUseWith(user, deep_copy));
    330     }
    331     computation->set_root_instruction(deep_copy);
    332   }
    333   return Status::OK();
    334 }
    335 
    336 // Conservatively adds copies before root instruction of entry computation and
    337 // each aliased parameter to resolve interference of aliased input and output
    338 // buffer. We later rely on RemoveUnnecessaryCopies to drop the unnecessary
    339 // ones.
    340 Status AddCopiesForAliasedInputOutputs(HloModule* module) {
    341   HloComputation* entry = module->entry_computation();
    342   HloInstruction* root = entry->root_instruction();
    343 
    344   ShapeTree<bool> output_indices_to_copy(root->shape());
    345   std::vector<absl::optional<ShapeTree<HloInstruction*>>> copied_parameters(
    346       entry->num_parameters());
    347   bool has_alias = false;
    348   for (auto* param : entry->parameter_instructions()) {
    349     bool param_has_alias = false;
    350     ShapeTree<bool> param_indices_to_copy(param->shape());
    351 
    352     module->input_output_alias_config().ForEachAlias(
    353         [&](const ShapeIndex& output_index,
    354             const HloInputOutputAliasConfig::Alias& alias) {
    355           if (alias.parameter_number == param->parameter_number()) {
    356             param_has_alias = true;
    357             *(param_indices_to_copy.mutable_element(alias.parameter_index)) =
    358                 true;
    359             *(output_indices_to_copy.mutable_element(output_index)) = true;
    360           }
    361         });
    362 
    363     if (!param_has_alias) {
    364       continue;
    365     }
    366 
    367     TF_RET_CHECK(param->parameter_number() < entry->num_parameters());
    368     TF_RET_CHECK(!copied_parameters[param->parameter_number()]);
    369 
    370     has_alias = true;
    371     // Store a snapshot of users before DeepCopyInstruction, as
    372     // DeepCopyInstruction introduces new users of the instruction.
    373     std::vector<HloInstruction*> users = param->users();
    374     ShapeTree<HloInstruction*> param_copy_tree(param->shape(),
    375                                                /*init_value=*/nullptr);
    376     TF_ASSIGN_OR_RETURN(HloInstruction * copied,
    377                         entry->DeepCopyInstruction(
    378                             param, &param_indices_to_copy, &param_copy_tree));
    379     for (HloInstruction* user : users) {
    380       TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, copied));
    381     }
    382 
    383     copied_parameters[param->parameter_number()] = param_copy_tree;
    384   }
    385 
    386   if (!has_alias) {
    387     return Status::OK();
    388   }
    389 
    390   // Add copies before root instruction.
    391   ShapeTree<HloInstruction*> output_copy_tree(root->shape(),
    392                                               /*init_value=*/nullptr);
    393 
    394   TF_ASSIGN_OR_RETURN(HloInstruction * root_copied,
    395                       root->parent()->DeepCopyInstruction(
    396                           root, &output_indices_to_copy, &output_copy_tree));
    397 
    398   // Add control dependencies between the input/output copies.
    399   TF_RETURN_IF_ERROR(module->input_output_alias_config().ForEachAliasWithStatus(
    400       [&](const ShapeIndex& output_index,
    401           const HloInputOutputAliasConfig::Alias& alias) -> Status {
    402         if (!copied_parameters[alias.parameter_number]) {
    403           return Status::OK();
    404         }
    405         HloInstruction* from =
    406             copied_parameters[alias.parameter_number]->element(
    407                 alias.parameter_index);
    408         HloInstruction* to = output_copy_tree.element(output_index);
    409 
    410         TF_RET_CHECK(from != nullptr);
    411         TF_RET_CHECK(to != nullptr);
    412         TF_RETURN_IF_ERROR(from->AddControlDependencyTo(to));
    413         return Status::OK();
    414       }));
    415 
    416   entry->set_root_instruction(root_copied);
    417 
    418   return Status::OK();
    419 }
    420 
    421 // Removes any control dependencies to or from the given instruction.
    422 Status StripControlDependenciesFrom(HloInstruction* instruction) {
    423   while (!instruction->control_successors().empty()) {
    424     TF_RETURN_IF_ERROR(instruction->RemoveControlDependencyTo(
    425         instruction->control_successors().front()));
    426   }
    427 
    428   while (!instruction->control_predecessors().empty()) {
    429     TF_RETURN_IF_ERROR(
    430         instruction->control_predecessors().front()->RemoveControlDependencyTo(
    431             instruction));
    432   }
    433 
    434   return Status::OK();
    435 }
    436 
    437 // Class which tracks the HLO values within each HLO buffer in the module
    438 // during copy removal.
    439 //
    440 // The values are held in a linked list where there is one list for each
    441 // buffer. Removing a copy instruction merges together the values in the
    442 // source buffer of the copy to the destination buffer of the copy. This class
    443 // tracks these value lists as copies are removed from the graph (and value
    444 // lists are merged).
    445 //
    446 // The CopyRemover object is initialized to match the state of
    447 // HloAliasAnalysis. However, as copies are removed this state diverges. The
    448 // values-to-buffer mapping is maintained outside of HloAliasAnalysis because
    449 // a fully updatable alias analysis is very slow.
    450 class CopyRemover {
    451  public:
    452   // The values held in a single HLO buffer are represented using a linked
    453   // list. An element type in this list is ValueNode.
    454   //
    455   // This linked list is hand-rolled to enable efficient splicing of lists
    456   // using only references to list elements without knowing which lists are
    457   // being spliced. std::list requires a reference to the list object to
    458   // splice.
    459   struct ValueNode {
    460     explicit ValueNode(const HloValue* v) : value(v) {}
    461 
    462     const HloValue* value;
    463 
    464     // The uses are maintained outside of HloValue::uses() because
    465     // HloValue::uses() is not updatable (a fully updatable dataflow analysis
    466     // is slow).
    467     std::vector<const HloUse*> uses;
    468 
    469     // next/prev elements in the linked list. The list is circularly linked so
    470     // these values are never null for elements in the list.
    471     ValueNode* prev = nullptr;
    472     ValueNode* next = nullptr;
    473   };
    474 
    475   CopyRemover(const HloModule& module, const HloAliasAnalysis& alias_analysis,
    476               const HloOrdering& ordering)
    477       : dataflow_(alias_analysis.dataflow_analysis()), ordering_(ordering) {
    478     // Construct a list for each HLO buffer in the alias analysis. Maintain a
    479     // map from HloValue to the respective list element representing that
    480     // value. The map is used to construct the copy info map below.
    481     absl::flat_hash_map<const HloValue*, ValueNode*> value_to_node;
    482     for (const HloBuffer& buffer : alias_analysis.buffers()) {
    483       // Verify values contained in the buffer are strictly ordered. This
    484       // should always be the case after adding copies to eliminate
    485       // interference. Specifically, the addition of the control flow edges
    486       // between copies added around aliased operations (kWhile) guarantees
    487       // this strict order.
    488       for (const HloValue* value_a : buffer.values()) {
    489         if (value_a->shape().IsToken()) {
    490           // Token values have no representation and cannot interfere.
    491           continue;
    492         }
    493         for (const HloValue* value_b : buffer.values()) {
    494           if (value_a != value_b) {
    495             DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b,
    496                                                      dataflow_) ||
    497                    ordering_.LiveRangeStrictlyBefore(*value_b, *value_a,
    498                                                      dataflow_))
    499                 << value_a->ToShortString() << " and "
    500                 << value_b->ToShortString() << " are not ordered";
    501           }
    502         }
    503       }
    504 
    505       std::vector<const HloValue*> values = buffer.values();
    506       absl::c_sort(values, [this](const HloValue* a, const HloValue* b) {
    507         return ordering_.IsDefinedBefore(*a, *b);
    508       });
    509 
    510       // Create a list containing all of the values in the buffer.
    511       AddValueList(values, &value_to_node);
    512     }
    513 
    514     // Create copy_map_ which contains the source and destination values
    515     // of all copies.
    516     CreateCopyMap(module, value_to_node);
    517 
    518     XLA_VLOG_LINES(3, ToString());
    519     TF_DCHECK_OK(Verify());
    520   }
    521 
    522   // Add a list containing the given values to CopyRemover. This
    523   // represents the values contained in a single buffer. For each value in
    524   // 'values' an entry is created in value_to_node which indicates the
    525   // respective ValueNode representing that value.
    526   void AddValueList(
    527       absl::Span<const HloValue* const> values,
    528       absl::flat_hash_map<const HloValue*, ValueNode*>* value_to_node) {
    529     ValueNode* tail = nullptr;
    530     ValueNode* head = nullptr;
    531     for (const HloValue* value : values) {
    532       auto new_node = new ValueNode(value);
    533       (*value_to_node)[value] = new_node;
    534 
    535       // Copy the HLO values's uses into the ValueNode for the value. These
    536       // uses in ValueNode are updated as copies are removed.
    537       new_node->uses.reserve(value->uses().size());
    538       for (const HloUse& use : value->uses()) {
    539         new_node->uses.push_back(&use);
    540       }
    541 
    542       // Connect the new node into the linked list.
    543       if (tail == nullptr) {
    544         head = new_node;
    545       } else {
    546         tail->next = new_node;
    547         new_node->prev = tail;
    548       }
    549       tail = new_node;
    550     }
    551 
    552     // The linked list is circular so connect the head and tail.
    553     tail->next = head;
    554     head->prev = tail;
    555     value_lists_.insert(head);
    556   }
    557 
    558   // This method also fills in copy_map_ which indicates which nodes
    559   // in the value lists corresponding to the source and destination values of
    560   // kCopy instructions. value_to_node should map each HloValue to its
    561   // respective ValueNode.
    562   void CreateCopyMap(
    563       const HloModule& module,
    564       const absl::flat_hash_map<const HloValue*, ValueNode*>& value_to_node) {
    565     for (HloComputation* computation : module.computations()) {
    566       for (HloInstruction* instruction : computation->instructions()) {
    567         // Add copies with unambiguous source values to the map. Copies with
    568         // ambiguous sources are not removable.
    569         if (instruction->opcode() == HloOpcode::kCopy) {
    570           const HloValueSet& src_value_set =
    571               dataflow_.GetValueSet(instruction->operand(0));
    572           if (src_value_set.values().size() == 1) {
    573             CopyNodes& copy_node = copy_map_[instruction];
    574             copy_node.dest =
    575                 value_to_node.at(&dataflow_.GetUniqueValueAt(instruction));
    576             copy_node.src = value_to_node.at(&src_value_set.GetUniqueValue());
    577           }
    578         }
    579       }
    580     }
    581   }
    582 
    583   ~CopyRemover() {
    584     for (const ValueNode* head : value_lists_) {
    585       const ValueNode* p = head;
    586       do {
    587         const ValueNode* tmp = p->next;
    588         delete p;
    589         p = tmp;
    590       } while (p != head);
    591     }
    592   }
    593 
    594   // Verify invariants within the linked lists.
    595   Status Verify() const {
    596     for (const ValueNode* head : value_lists_) {
    597       const ValueNode* p = head;
    598       do {
    599         // Verify links between elements are consistent.
    600         TF_RET_CHECK(p->prev->next == p);
    601         TF_RET_CHECK(p->next->prev == p);
    602 
    603         const HloInstruction* def = p->value->defining_instruction();
    604         if (def->opcode() == HloOpcode::kCopy && ContainsKey(copy_map_, def)) {
    605           TF_RET_CHECK(copy_map_.at(def).dest == p);
    606         }
    607         for (const HloUse* use : p->uses) {
    608           if (use->instruction->opcode() == HloOpcode::kCopy &&
    609               ContainsKey(copy_map_, use->instruction)) {
    610             TF_RET_CHECK(copy_map_.at(use->instruction).src == p);
    611           }
    612         }
    613 
    614         p = p->next;
    615       } while (p != head);
    616     }
    617     return Status::OK();
    618   }
    619 
    620   // Try to elide the given copy. Elision of a copy is possible only if no
    621   // live range interference is introduced by the copy's elimination. If
    622   // elision is possible, then the internal state (value lists) are updated,
    623   // and true is returned. Returns false otherwise.
    624   bool TryElideCopy(const HloInstruction* copy) {
    625     VLOG(2) << "Trying to remove " << copy->name();
    626 
    627     if (!ContainsKey(copy_map_, copy)) {
    628       VLOG(2) << copy->name() << " is not removable";
    629       return false;
    630     }
    631     if (!ShapeUtil::Equal(copy->shape(), copy->operand(0)->shape())) {
    632       VLOG(2) << copy->name() << " is not removable (shape mismatch)";
    633       return false;
    634     }
    635     const CopyNodes& copy_node = copy_map_.at(copy);
    636     ValueNode* src = copy_node.src;
    637     ValueNode* dest = copy_node.dest;
    638     DCHECK(src != nullptr);
    639     DCHECK(dest != nullptr);
    640 
    641     auto is_live_range_before = [this](const ValueNode& a, const ValueNode& b) {
    642       VLOG(3) << "Checking live range of " << *a.value << " WRT " << *b.value;
    643       if (LiveRangeBefore(a, b)) {
    644         VLOG(2) << "  Live range of " << a.value->ToShortString()
    645                 << " is before " << b.value->ToShortString();
    646         return true;
    647       } else {
    648         VLOG(2) << "  Live range of " << a.value->ToShortString()
    649                 << " is not before " << b.value->ToShortString();
    650         return false;
    651       }
    652     };
    653 
    654     VLOG(3) << copy->name() << " copies value " << src->value->ToShortString();
    655     VLOG(3) << "Source buffer values: " << ValueListToString(src);
    656     VLOG(3) << "Dest buffer values: " << ValueListToString(dest);
    657 
    658     // A kCopy instruction copies an HLO value from a source buffer and
    659     // defines an HLO value in a destination buffer. Most generally, the
    660     // source and destination buffers may each hold more than one value at
    661     // different points in the computation so we define the following:
    662     //
    663     //   Values in source buffer:      {s_0, ..., s_n}
    664     //   Values in destination buffer: {d_0, ..., d_m}
    665     //
    666     // A kCopy instruction between these buffers copies a value s_x in the
    667     // source buffer and defines a value d_y in the destination buffer. The
    668     // elision of a copy merges the source and destination buffers together,
    669     // so the list of values for the source and destination buffers are
    670     // merged.
    671     //
    672     // We handle two different cases for copy elision:
    673     //
    674     //  (1) the kCopy defines the first value in the destination buffer (d_0).
    675     //
    676     //  (2) the kCopy copies the last value in the source buffer (s_n).
    677     //
    678     // For the remaining case where the kCopy copies a not-last value from the
    679     // source buffer to a not-first value of the destination buffer, the kCopy
    680     // instruction cannot be removed. This case is generated, for example, if
    681     // the kCopy copies a while body parameter of the loop state at one tuple
    682     // index to a different tuple index in the while body root. Removal of the
    683     // copy necessarily results in live range interference of values in the
    684     // loop state at the two different tuple indices.
    685     //
    686     //  We can only perform copy elision if the resulting merged values have
    687     //  totally ordered live ranges; otherwise the merged buffer would have
    688     //  live range interference.
    689     if (src->next == dest) {
    690       // In the process of eliding copies, its possible for a copy to have the
    691       // same source and destination buffer. In this case, the copy can be
    692       // safely removed.
    693       VLOG(2) << copy->name() << " source and destination buffers are same.";
    694     } else if (IsHead(*dest)) {
    695       // The copy copies an arbitrary value in the source buffer (call it s_x)
    696       // and defines d_0, the first value in the destination buffer. After
    697       // merging, the values in the combined buffer must be strictly ordered
    698       // as follows** to elide the copy:
    699       //
    700       // {s_0, ..., s_x, d_1, ..., d_m, s_{x+1}, ..., s_n}
    701       //
    702       // Removing the copy eliminates d_0, and uses of d_0 become uses of
    703       // s_x. In the above ordering, the live range of d_m must be ordered
    704       // before the live range of s_{x+1} and the definition and all uses of
    705       // s_x must be ordered before the definition of d_1. These conditions
    706       // are checked below prior to elision.
    707       //
    708       // ** Technically it might be possible to have a non-interfering
    709       //    non-trivial interleaving of the values of the source and
    710       //    destination buffers in the resulting order. However, this case is
    711       //    slow and complicated to check and likely not worth it. So instead
    712       //    we simply check for the case where *all* values of the destination
    713       //    buffer (d_1 through d_m) are spliced into the point where the copy
    714       //    used to be.
    715       VLOG(2) << copy->name() << " defines the first value in its buffer";
    716       ValueNode* next_dest = Next(*dest);
    717       if (next_dest != nullptr) {
    718         // Live range of 'from' value (s_x) must be before 'next_dest' (d_1);
    719         if (!is_live_range_before(*src, *next_dest)) {
    720           return false;
    721         }
    722       }
    723       ValueNode* next_src = Next(*src);
    724 
    725       if (next_src != nullptr) {
    726         // Live range of 'last_dest' (d_m) must be before 'next_src' s_{x+1}.
    727         ValueNode* last_dest = dest->prev;
    728         DCHECK(IsTail(*last_dest));
    729         if (!is_live_range_before(*last_dest, *next_src)) {
    730           return false;
    731         }
    732       }
    733 
    734       // Splice in destination buffer values list right after 'src'.
    735       SpliceAfter(dest, src);
    736     } else if (IsTail(*src)) {
    737       // The copy copies the last value in the source buffer, s_n, and defines
    738       // an arbitrary value in the destination buffer, d_y.  After
    739       // merging, the values in the combined buffer must be strictly ordered
    740       // as follows** to elide the copy:
    741       //
    742       // {d_0, ..., d_{y-1}, s_0, ..., s_n, d_{y+1}, ..., d_m}
    743       //
    744       // Removing the copy eliminates d_y, and uses of d_y become uses of
    745       // s_n. To enforce the above order, the live range of d_{y-1} must be
    746       // before the live range of s_0, and the live range of s_n must be
    747       // before the live range of d_{y+1}.
    748       //
    749       // ** See comment above in the code handling Case (1).
    750       VLOG(2) << copy->name() << " copies the last value ("
    751               << src->value->ToShortString() << ") in its buffer";
    752 
    753       ValueNode* prev_dest = Prev(*dest);
    754       // nullptr condition handled above in the first 'if' case.
    755       DCHECK(prev_dest != nullptr);
    756       ValueNode* first_src = src->next;
    757       DCHECK(IsHead(*first_src));
    758       if (!is_live_range_before(*prev_dest, *first_src)) {
    759         // Live range of value d_{y-1} is not before s_0.
    760         return false;
    761       }
    762       ValueNode* next_dest = Next(*dest);
    763       if (next_dest != nullptr) {
    764         if (!is_live_range_before(*src, *next_dest)) {
    765           // Live range of value s_n is not before d_{y+1}.
    766           return false;
    767         }
    768       }
    769 
    770       // Splice source buffer values list right after 'prev_dest'.
    771       SpliceAfter(first_src, prev_dest);
    772     } else {
    773       VLOG(2) << copy->name()
    774               << " copies value in middle of source buffer to value in middle "
    775                  "of destination buffer";
    776       return false;
    777     }
    778 
    779     RemoveCopyValue(dest);
    780 
    781     XLA_VLOG_LINES(4, ToString());
    782     TF_DCHECK_OK(Verify());
    783 
    784     return true;
    785   }
    786 
    787   // Delete the given ValueNode associated with a elided kCopy
    788   // instruction. This should be called after splicing the value lists of the
    789   // source and destination buffers together.
    790   void RemoveCopyValue(ValueNode* copy_value_node) {
    791     CHECK_EQ(copy_value_node->value->defining_instruction()->opcode(),
    792              HloOpcode::kCopy);
    793     ValueNode* operand_node = copy_value_node->prev;
    794     CHECK(operand_node != copy_value_node);
    795 
    796     VLOG(2) << "Removing copy " << operand_node->value->ToShortString()
    797             << " => " << copy_value_node->value->ToShortString();
    798 
    799     // Splice out the copy value node.
    800     operand_node->next = copy_value_node->next;
    801     copy_value_node->next->prev = operand_node;
    802 
    803     // Patch up uses. Remove use of copy from operand_node uses.
    804     auto it = absl::c_find_if(operand_node->uses, [copy_value_node](
    805                                                       const HloUse* use) {
    806       return use->instruction == copy_value_node->value->defining_instruction();
    807     });
    808     CHECK(it != operand_node->uses.end());
    809     operand_node->uses.erase(it);
    810 
    811     // If the elided copy has any uses which are themselves kCopy instructions
    812     // then patch up the copy info to reflect the that this kCopy instruction
    813     // has a different operand (the operand of the elided copy).
    814     for (const HloUse* copy_use : copy_value_node->uses) {
    815       operand_node->uses.push_back(copy_use);
    816       if (copy_use->instruction->opcode() == HloOpcode::kCopy &&
    817           ContainsKey(copy_map_, copy_use->instruction)) {
    818         copy_map_.at(copy_use->instruction).src = operand_node;
    819       }
    820     }
    821 
    822     // Delete the copy info and the value node.
    823     copy_map_.erase(copy_value_node->value->defining_instruction());
    824     delete copy_value_node;
    825   }
    826 
    827   // Returns true if the live range of given value 'a' is before the live
    828   // range of 'b'.
    829   //
    830   // We cannot use LiveRangeStrictlyBefore because HloValue::uses() is not
    831   // updated as copies are removed.
    832   bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) {
    833     if (a.uses.empty()) {
    834       VLOG(2) << "Empty uses for " << *a.value;
    835       return ordering_.IsDefinedBefore(*a.value, *b.value);
    836     }
    837     for (const HloUse* use : a.uses) {
    838       VLOG(2) << "Checking use " << *use << " against " << *b.value;
    839       if (!ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_)) {
    840         VLOG(2) << "Use " << *use << " is NOT before " << *b.value;
    841         return false;
    842       }
    843       VLOG(2) << "Use " << *use << " is before " << *b.value;
    844     }
    845     return true;
    846   }
    847 
    848   // Returns whether 'node' is the last node in its list.
    849   bool IsTail(const ValueNode& node) const {
    850     return ContainsKey(value_lists_, node.next);
    851   }
    852 
    853   // Returns whether 'node' is the first node in its list.
    854   bool IsHead(const ValueNode& node) const {
    855     return ContainsKey(value_lists_, &node);
    856   }
    857 
    858   // Returns the next node in the list after 'node'. If 'node' is the
    859   // tail, then nullptr is returned.
    860   ValueNode* Next(const ValueNode& node) const {
    861     if (IsTail(node)) {
    862       return nullptr;
    863     } else {
    864       return node.next;
    865     }
    866   }
    867 
    868   // Returns the previous node in the list before 'node'. If 'node'
    869   // is the head, then nullptr is returned.
    870   ValueNode* Prev(const ValueNode& node) const {
    871     if (IsHead(node)) {
    872       return nullptr;
    873     } else {
    874       return node.prev;
    875     }
    876   }
    877 
    878   // Splices the entire linked list with 'head' as its head right after the
    879   // node 'insert_after' in another linked list.
    880   void SpliceAfter(ValueNode* head, ValueNode* insert_after) {
    881     DCHECK(IsHead(*head));
    882     value_lists_.erase(head);
    883 
    884     ValueNode* tail = head->prev;
    885     tail->next = insert_after->next;
    886     insert_after->next->prev = tail;
    887 
    888     insert_after->next = head;
    889     head->prev = insert_after;
    890   }
    891 
    892   string ValueListToString(const ValueNode* element) {
    893     const ValueNode* head = element;
    894     while (!IsHead(*head)) {
    895       head = Prev(*head);
    896     }
    897     std::vector<const HloValue*> values;
    898     for (const ValueNode* p = head; p != nullptr; p = Next(*p)) {
    899       values.push_back(p->value);
    900     }
    901     return absl::StrCat("{",
    902                         absl::StrJoin(values, ", ",
    903                                       [](string* s, const HloValue* value) {
    904                                         StrAppend(s, value->ToShortString());
    905                                       }),
    906                         "}");
    907   }
    908 
    909   string ToString() const {
    910     string out = absl::StrCat("CopyRemover:\n");
    911     StrAppend(&out, "  Def-use chains in each buffer:\n");
    912     for (const ValueNode* head : value_lists_) {
    913       StrAppend(&out, "    Buffer defined by ", head->value->ToShortString(),
    914                 ":\n");
    915       const ValueNode* p = head;
    916       do {
    917         StrAppend(&out, "      ", p->value->ToShortString(), ", uses: ",
    918                   absl::StrJoin(p->uses, "; ",
    919                                 [](string* s, const HloUse* use) {
    920                                   StrAppend(s, use->ToString());
    921                                 }),
    922                   "\n");
    923 
    924         p = p->next;
    925       } while (p != head);
    926     }
    927     StrAppend(&out, "  Potentially removable copies:\n");
    928     for (const auto& pair : copy_map_) {
    929       const HloInstruction* copy = pair.first;
    930       const CopyNodes& copy_info = pair.second;
    931 
    932       StrAppend(&out, "    ", copy->name(), " : ",
    933                 copy_info.src->value->ToShortString(), " => ",
    934                 copy_info.dest->value->ToShortString(), "\n");
    935     }
    936     return out;
    937   }
    938 
    939  private:
    940   const HloDataflowAnalysis& dataflow_;
    941   const HloOrdering& ordering_;
    942 
    943   // The heads of all the value lists. Each value list represents the HLO
    944   // values contained in a particular HLO buffer. The values in the list are
    945   // in dependency order.
    946   absl::flat_hash_set<const ValueNode*> value_lists_;
    947 
    948   // Copy removal requires fast access to the value list elements
    949   // corresponding to the source and destination values of the kCopy
    950   // instruction. This data structure holds pointers to these elements for
    951   // each kCopy instruction in the graph.
    952   struct CopyNodes {
    953     // The source and destinations values of the kCopy instruction.
    954     ValueNode* src = nullptr;
    955     ValueNode* dest = nullptr;
    956   };
    957   absl::flat_hash_map<const HloInstruction*, CopyNodes> copy_map_;
    958 };
    959 
    960 }  // namespace
    961 
    962 // Add kCopy instructions to the given module to guarantee there is no
    963 // live-range interference. Generally interference can only occur around kWhile
    964 // instructions which have update-in-place semantics.
    965 Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) {
    966   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
    967                       HloAliasAnalysis::Run(module, fusion_can_share_buffer_));
    968 
    969   for (HloComputation* computation : module->computations()) {
    970     for (HloInstruction* instruction : computation->instructions()) {
    971       if (instruction->opcode() == HloOpcode::kWhile) {
    972         TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction));
    973       } else if (instruction->opcode() == HloOpcode::kConditional) {
    974         TF_RETURN_IF_ERROR(
    975             AddCopiesForConditional(*alias_analysis, instruction));
    976       }
    977     }
    978   }
    979 
    980   TF_RETURN_IF_ERROR(AddCopiesForAliasedInputOutputs(module));
    981   return Status::OK();
    982 }
    983 
    984 Status CopyInsertion::AddSpecialCaseCopies(HloModule* module) {
    985   std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
    986   return AddSpecialCaseCopies(*call_graph, module);
    987 }
    988 
    989 Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
    990                                            HloModule* module) {
    991   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
    992                       HloAliasAnalysis::Run(module, fusion_can_share_buffer_));
    993 
    994   // Identify which shape indices of which instructions need to be copied. Store
    995   // these results in 'instructions_to_copy'.
    996   HloInstructionMap<ShapeTree<bool>> instructions_to_copy;
    997   auto add_index_to_copy = [&instructions_to_copy](HloInstruction* instruction,
    998                                                    const ShapeIndex& index) {
    999     auto it = instructions_to_copy.find(instruction);
   1000     if (it == instructions_to_copy.end()) {
   1001       auto it_added = instructions_to_copy.emplace(
   1002           std::piecewise_construct, std::forward_as_tuple(instruction),
   1003           std::forward_as_tuple(instruction->shape(), /*init_value=*/false));
   1004       it = it_added.first;
   1005     }
   1006     *it->second.mutable_element(index) = true;
   1007   };
   1008 
   1009   // Iterate through values of all constants and entry parameters. These values
   1010   // are special because they are held in read-only buffers. If any of these
   1011   // values share a buffer with other values (for example, the init value of a
   1012   // while is a constant) then copy the value at its definition and replace all
   1013   // its uses with the copy.
   1014   for (const HloValue* value : alias_analysis->dataflow_analysis().values()) {
   1015     if (ValueIsReadOnly(*value) &&
   1016         alias_analysis->GetBufferContainingValue(*value).values().size() > 1) {
   1017       VLOG(2) << "Value " << value->ToShortString()
   1018               << " is read only, but its buffer contains more than one value. "
   1019                  "Copying.";
   1020       add_index_to_copy(value->defining_instruction(), value->defining_index());
   1021     }
   1022   }
   1023 
   1024   // Identify copies which must be added at root instructions
   1025   for (HloComputation* computation : module->computations()) {
   1026     const CallGraphNode& node = call_graph.GetNode(computation);
   1027     if (node.context() == CallContext::kParallel) {
   1028       continue;
   1029     }
   1030     TF_RET_CHECK(node.context() == CallContext::kSequential);
   1031 
   1032     SpecialCaseCopyPolicy policy =
   1033         GetSpecialCaseCopyPolicy(node, module, computation);
   1034     HloInstruction* root = computation->root_instruction();
   1035 
   1036     // Mark nondistinct/ambiguous indices.
   1037     absl::flat_hash_set<const HloBuffer*> seen;
   1038     ShapeUtil::ForEachSubshape(
   1039         root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) {
   1040           std::vector<const HloBuffer*> buffers_at_index =
   1041               alias_analysis->ComputeBuffersAt(root, index);
   1042           bool buffer_seen_before = false;
   1043           for (const HloBuffer* buffer : buffers_at_index) {
   1044             buffer_seen_before |= !seen.insert(buffer).second;
   1045           }
   1046           if (buffers_at_index.size() > 1 ||
   1047               (buffer_seen_before && policy.copy_root_replicated_buffers)) {
   1048             VLOG(2) << "Index " << index << " of computation "
   1049                     << computation->name() << " (" << root->name()
   1050                     << ") has ambiguous or non-distinct buffer. Copying.";
   1051             add_index_to_copy(root, index);
   1052           }
   1053         });
   1054 
   1055     for (const auto& pair :
   1056          alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) {
   1057       const ShapeIndex& index = pair.first;
   1058       const HloValueSet& value_set = pair.second;
   1059       for (const HloValue* value : value_set.values()) {
   1060         if (ShouldCopyRootValue(*value, policy)) {
   1061           VLOG(2) << "Root of (" << root->name() << ") of computation("
   1062                   << computation->name()
   1063                   << ") has constant or parameter value at index " << index
   1064                   << ". Copying.";
   1065           add_index_to_copy(root, index);
   1066         }
   1067       }
   1068     }
   1069   }
   1070 
   1071   // Add copy instructions indicated in 'instructions_to_copy' to the module.
   1072   for (const auto& pair : instructions_to_copy) {
   1073     HloInstruction* instruction = pair.first;
   1074     const ShapeTree<bool>& indices_to_copy = pair.second;
   1075 
   1076     ShapeTree<HloInstruction*> copies_added(indices_to_copy.shape());
   1077     std::vector<HloInstruction*> users = instruction->users();
   1078     TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
   1079                         instruction->parent()->DeepCopyInstruction(
   1080                             instruction, &indices_to_copy, &copies_added));
   1081     for (HloInstruction* user : users) {
   1082       TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy));
   1083     }
   1084     if (instruction == instruction->parent()->root_instruction()) {
   1085       instruction->parent()->set_root_instruction(deep_copy);
   1086     }
   1087   }
   1088   return Status::OK();
   1089 }
   1090 
   1091 Status CopyInsertion::VerifyNoLiveRangeInterference(const HloOrdering& ordering,
   1092                                                     HloModule* module) {
   1093   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
   1094                       HloAliasAnalysis::Run(module, fusion_can_share_buffer_));
   1095   TF_RET_CHECK(!alias_analysis->HasLiveRangeInterference(ordering));
   1096   return Status::OK();
   1097 }
   1098 
   1099 Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
   1100                                               HloModule* module) {
   1101   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
   1102                       HloAliasAnalysis::Run(module, fusion_can_share_buffer_));
   1103 
   1104   CopyRemover copy_remover(*module, *alias_analysis, ordering);
   1105   if (VLOG_IS_ON(3)) {
   1106     LOG(INFO) << "Removing unnecessary copies in " << module->name();
   1107     LOG(INFO) << "Buffer values, in dependency order: ";
   1108     for (const HloBuffer& buffer : alias_analysis->buffers()) {
   1109       LOG(INFO) << "    HloBuffer " << buffer.id();
   1110     }
   1111   }
   1112 
   1113   std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
   1114   for (HloComputation* computation : module->computations()) {
   1115     for (HloInstruction* instruction : computation->instructions()) {
   1116       if (instruction->opcode() == HloOpcode::kCopy &&
   1117           copy_remover.TryElideCopy(instruction)) {
   1118         TF_RETURN_IF_ERROR(StripControlDependenciesFrom(instruction));
   1119         TF_RETURN_IF_ERROR(
   1120             instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
   1121       }
   1122     }
   1123   }
   1124   return Status::OK();
   1125 }
   1126 
   1127 StatusOr<bool> CopyInsertion::Run(HloModule* module) {
   1128   // Copy insertion is performed in three steps:
   1129   //
   1130   // (1) Add copies conservatively to guarantee that there is no live-range
   1131   //     interference. This is done simplistically and usually results in more
   1132   //     copies than is strictly necessary.
   1133   //
   1134   // (2) Using a more fine-grained analysis, remove as many copies that were
   1135   //     added in (1) as possible while ensuring no live-range interference.
   1136   //
   1137   // (3) Add copies to resolve issues not related to live range interference
   1138   //     such as parameters and constants live out of the entry computation.
   1139   //
   1140   // We add copies then remove them (step (1) then (2)) rather than simply
   1141   // adding only the copies that are necessary because, in general, it is
   1142   // difficult to figure out the minimal set of copies to add once there is
   1143   // interference. On the other hand, it is easy to determine if removing a copy
   1144   // will introduce interference.
   1145   //
   1146   // The final copy insertion in (3) is done separately to simplify the
   1147   // implementation of copy removal in (2) which is the most complicated part of
   1148   // the pass. As is, copy removal only has to reason about live range
   1149   // interference. If all copies were added in step (1) then copy removal would
   1150   // also have to reason about things like constants and parameters live out of
   1151   // the computation.
   1152   std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
   1153   if (!call_graph->IsFlattened()) {
   1154     return FailedPrecondition(
   1155         "Call graph must be flattened before copy insertion.");
   1156   }
   1157 
   1158   int64 num_existing_copies = 0;
   1159   if (VLOG_IS_ON(1)) {
   1160     for (HloComputation* computation : module->computations()) {
   1161       for (HloInstruction* instruction : computation->instructions()) {
   1162         if (instruction->opcode() == HloOpcode::kCopy) {
   1163           ++num_existing_copies;
   1164         }
   1165       }
   1166     }
   1167   }
   1168 
   1169   TF_RETURN_IF_ERROR(AddCopiesToResolveInterference(module));
   1170 
   1171   // Simplify the tuple structures introduced by the deep copies. This should be
   1172   // done before removing copies (RemoveUnnecessaryCopies) because tuple
   1173   // simplification changes dependencies in the graph which changes live range
   1174   // interference in the graph. Also run DCE to remove the dead Tuple/GTE
   1175   // instructions introduced by tuple simplification.
   1176   TupleSimplifier tuple_simplifier;
   1177   HloDCE dce;
   1178   TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
   1179   TF_RETURN_IF_ERROR(dce.Run(module).status());
   1180   DumpHloModuleDuringPassIfEnabled(
   1181       name(), "after adding copies to resolve interference", *module);
   1182 
   1183   DependencyHloOrdering dep_ordering(module);
   1184   TF_DCHECK_OK(VerifyNoLiveRangeInterference(dep_ordering, module));
   1185 
   1186   TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(dep_ordering, module));
   1187   DumpHloModuleDuringPassIfEnabled(name(), "after removing unnecessary copies",
   1188                                    *module);
   1189 
   1190   TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));
   1191   DumpHloModuleDuringPassIfEnabled(name(), "after adding special-case copies",
   1192                                    *module);
   1193 
   1194   TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
   1195   TF_RETURN_IF_ERROR(dce.Run(module).status());
   1196   TF_DCHECK_OK(
   1197       VerifyNoLiveRangeInterference(DependencyHloOrdering(module), module));
   1198 
   1199   if (VLOG_IS_ON(1)) {
   1200     int64 num_total_copies = 0;
   1201     for (HloComputation* computation : module->computations()) {
   1202       for (HloInstruction* instruction : computation->instructions()) {
   1203         if (instruction->opcode() == HloOpcode::kCopy) {
   1204           num_total_copies++;
   1205         }
   1206       }
   1207     }
   1208     VLOG(1) << "Num copies before copy-insertion: " << num_existing_copies;
   1209     VLOG(1) << "Num copies after copy-insertion: " << num_total_copies;
   1210   }
   1211 
   1212   return true;
   1213 }
   1214 
   1215 namespace {
   1216 
   1217 bool IsWhileBody(const HloComputation* computation,
   1218                  const CallGraph& call_graph) {
   1219   const CallGraphNode& node = call_graph.GetNode(computation);
   1220 
   1221   if (node.context() == CallContext::kSequential &&
   1222       !node.caller_callsites().empty()) {
   1223     // Callgraph should be flattened so sequential context computations can
   1224     // have at most one caller.
   1225     CHECK_EQ(node.caller_callsites().size(), 1);
   1226     const HloInstruction* calling_instruction =
   1227         node.caller_callsites()[0].instruction();
   1228     if (calling_instruction->opcode() == HloOpcode::kWhile &&
   1229         calling_instruction->while_body() == node.computation()) {
   1230       return true;
   1231     }
   1232   }
   1233   return false;
   1234 }
   1235 
   1236 }  // namespace
   1237 
   1238 /* static */ StatusOr<bool> CopyInsertion::AddCopiesForBufferAssignment(
   1239     HloModule* module) {
   1240   std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
   1241   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow,
   1242                       HloDataflowAnalysis::Run(*module));
   1243 
   1244   bool changed = false;
   1245 
   1246   // If a buffer live out of a computation is a constant, a parameter, or not
   1247   // defined in the computation, then copy it to account for the limited
   1248   // computation-scoped analysis in buffer assignment. An exception to this rule
   1249   // is the while body which is handled properly without copies.
   1250   for (HloComputation* computation : module->computations()) {
   1251     if (computation == module->entry_computation() ||
   1252         IsWhileBody(computation, *call_graph)) {
   1253       continue;
   1254     }
   1255 
   1256     HloInstruction* root = computation->root_instruction();
   1257     ShapeTree<bool> indices_to_copy(root->shape(), /*init_value=*/false);
   1258     bool copy_root = false;
   1259     for (const auto& pair : dataflow->GetInstructionValueSet(root)) {
   1260       const ShapeIndex& index = pair.first;
   1261       const HloValueSet& value_set = pair.second;
   1262       for (const HloValue* value : value_set.values()) {
   1263         HloInstruction* def = value->defining_instruction();
   1264         if (def->parent() != computation ||
   1265             def->opcode() == HloOpcode::kConstant ||
   1266             def->opcode() == HloOpcode::kParameter) {
   1267           *indices_to_copy.mutable_element(index) = true;
   1268           copy_root = true;
   1269         }
   1270       }
   1271     }
   1272     if (copy_root) {
   1273       TF_ASSIGN_OR_RETURN(
   1274           HloInstruction * root_copy,
   1275           computation->DeepCopyInstruction(root, &indices_to_copy));
   1276       computation->set_root_instruction(root_copy);
   1277       changed = true;
   1278     }
   1279   }
   1280 
   1281   TupleSimplifier tuple_simplifier;
   1282   HloDCE dce;
   1283   TF_ASSIGN_OR_RETURN(bool tuple_simplifier_changed,
   1284                       tuple_simplifier.Run(module));
   1285   TF_ASSIGN_OR_RETURN(bool dce_changed, dce.Run(module));
   1286 
   1287   return changed || tuple_simplifier_changed || dce_changed;
   1288 }
   1289 
   1290 }  // namespace xla
   1291