Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/copy_insertion.h"
     17 
     18 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
     19 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     20 #include "tensorflow/compiler/xla/service/hlo_dce.h"
     21 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
     22 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     23 #include "tensorflow/compiler/xla/service/hlo_module.h"
     24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     25 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
     26 #include "tensorflow/compiler/xla/service/liveness_util.h"
     27 #include "tensorflow/compiler/xla/service/logical_buffer.h"
     28 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
     29 #include "tensorflow/compiler/xla/status_macros.h"
     30 #include "tensorflow/compiler/xla/statusor.h"
     31 #include "tensorflow/compiler/xla/types.h"
     32 #include "tensorflow/compiler/xla/util.h"
     33 #include "tensorflow/core/lib/gtl/flatmap.h"
     34 #include "tensorflow/core/lib/gtl/flatset.h"
     35 #include "tensorflow/core/lib/strings/str_util.h"
     36 #include "tensorflow/core/lib/strings/strcat.h"
     37 #include "tensorflow/core/platform/logging.h"
     38 
     39 namespace xla {
     40 
     41 using ::tensorflow::str_util::Join;
     42 using ::tensorflow::strings::StrAppend;
     43 using ::tensorflow::strings::StrCat;
     44 
     45 namespace {
     46 
     47 bool IsEntryParameterValue(const HloValue& value) {
     48   const HloComputation* computation = value.defining_instruction()->parent();
     49   return value.defining_instruction()->opcode() == HloOpcode::kParameter &&
     50          computation == computation->parent()->entry_computation();
     51 }
     52 
     53 bool IsConstantValue(const HloValue& value) {
     54   return value.defining_instruction()->opcode() == HloOpcode::kConstant;
     55 }
     56 
     57 bool ValueIsReadOnly(const HloValue& value) {
     58   return IsConstantValue(value) || IsEntryParameterValue(value);
     59 }
     60 
     61 // Deep copy the given instructions 'from' and 'to' at the ShapeIndexes given in
     62 // 'indices_to_copy'. Add control edges from the respective kCopy instructions
     63 // in deep copy of 'from' to the respective kCopy instruction in the deep copy
     64 // of 'to'.
     65 //
     66 // Requirements: 'from' and 'to' must have compatible shapes.
     67 //
     68 // For example, suppose 'from' and 'to' are two-element tuples where index 0 is
     69 // the only index to copy. Prior to deep-copying we have:
     70 //
     71 //
     72 //      'from'
     73 //         |
     74 //        ...
     75 //         |
     76 //       'to'
     77 //
     78 // DeepCopyAndAddControlEdges produces:
     79 //
     80 //       'from'
     81 //        /   \
     82 //      GTE   GTE
     83 //       |     |
     84 //     Copy    |
     85 //    /   \   /
     86 //   |    Tuple
     87 //   |      |
     88 //  ctrl   ...
     89 //  edge    |
     90 //   |      |
     91 //   |    'to'
     92 //   |    /   \
     93 //   |  GTE   GTE
     94 //    \  |     |
     95 //     Copy    |
     96 //        \   /
     97 //        Tuple
     98 //
     99 StatusOr<std::pair<HloInstruction*, HloInstruction*>>
    100 DeepCopyAndAddControlEdges(HloInstruction* from, HloInstruction* to,
    101                            const ShapeTree<bool>& indices_to_copy) {
    102   DCHECK(ShapeUtil::Compatible(from->shape(), to->shape()));
    103   // to/from_copy_tree hold the kCopy instruction produces by the deep
    104   // copies. Elements which are not copied (indices_to_copy.element(index) ==
    105   // false) have nullptr at that index.
    106   ShapeTree<HloInstruction*> from_copy_tree(from->shape(),
    107                                             /*init_value=*/nullptr);
    108   TF_ASSIGN_OR_RETURN(HloInstruction * from_deep_copy,
    109                       from->parent()->DeepCopyInstruction(
    110                           from, &indices_to_copy, &from_copy_tree));
    111 
    112   ShapeTree<HloInstruction*> to_copy_tree(to->shape(), /*init_value=*/nullptr);
    113   TF_ASSIGN_OR_RETURN(
    114       HloInstruction * to_deep_copy,
    115       to->parent()->DeepCopyInstruction(to, &indices_to_copy, &to_copy_tree));
    116 
    117   // Add control edges between the respective kCopy instructions.
    118   for (const auto& pair : from_copy_tree) {
    119     const ShapeIndex& index = pair.first;
    120     HloInstruction* from_copy = pair.second;
    121     HloInstruction* to_copy = to_copy_tree.element(index);
    122     if (from_copy == nullptr) {
    123       TF_RET_CHECK(to_copy == nullptr);
    124       continue;
    125     }
    126     TF_RET_CHECK(to_copy != nullptr);
    127     TF_RETURN_IF_ERROR(from_copy->AddControlDependencyTo(to_copy));
    128   }
    129 
    130   return std::make_pair(from_deep_copy, to_deep_copy);
    131 }
    132 
    133 // Compute the indices of the loop state which need copies in order to avoid
    134 // live range interference. Generally, an element in the loop state does not
    135 // need to be copied if the element is passed through transparently through the
    136 // body.
    137 //
    138 // Returns whether any indices need to be copied.
    139 bool IndicesToCopyForWhile(const HloDataflowAnalysis& dataflow,
    140                            const HloInstruction* xla_while,
    141                            ShapeTree<bool>* indices_to_copy) {
    142   DCHECK(ShapeUtil::Compatible(indices_to_copy->shape(), xla_while->shape()));
    143 
    144   bool any_copies = false;
    145   const HloInstruction* init = xla_while->operand(0);
    146   for (auto& pair : *indices_to_copy) {
    147     const ShapeIndex& index = pair.first;
    148     bool& should_copy = pair.second;
    149     // If there is any ambiguity, then loop state must be copied.
    150     if (dataflow.GetValueSet(init, index).values().size() > 1 ||
    151         dataflow.GetValueSet(xla_while, index).values().size() > 1) {
    152       should_copy = true;
    153     } else {
    154       // If the output of the while instruction is not the same as the init
    155       // value of the while, then this element is not passed through the body
    156       // transparently and must be copied.
    157       should_copy = dataflow.GetUniqueValueAt(xla_while, index) !=
    158                     dataflow.GetUniqueValueAt(init, index);
    159     }
    160     any_copies |= should_copy;
    161   }
    162   return any_copies;
    163 }
    164 
    165 // Add kCopy instructions around the given kWhile instruction to eliminate any
    166 // possible live range interference of HLO values assuming a dependency-based
    167 // ordering (HloDependencyOrdering). Copies are added conservatively. There
// likely are copies which are not strictly necessary, but they are removed
    169 // later in the pass via CopyRemover.
    170 //
    171 //
    172 // Elements (each ShapeIndex) in the loop state are considered independently.  A
    173 // copy is added to each element of the loop state which is modified in the
    174 // while body. For each such element, a total of three kCopy instructions are
    175 // added at following locations:
    176 //
    177 //   (1) The init value is copied before the kWhile instruction. Before:
    178 //
    179 //           (Init)
    180 //             |
    181 //           kWhile
    182 //             |
    183 //            ...
    184 //
    185 //       After:
    186 //
    187 //           (Init)
    188 //             |
    189 //           kCopy
    190 //             |
    191 //           kWhile
    192 //             |
    193 //            ...
    194 //
    195 //       This copy is necessary in case the init value is simultaneously live
    196 //       with the kWhile.
    197 //
    198 //   (2) Copies are added to the parameter and root of the while body
    199 //       computation. Before:
    200 //
    201 //           kParameter
    202 //               |
    203 //              ...
    204 //               |
    205 //           (body root)
    206 //
    207 //       After:
    208 //
    209 //           kParameter
    210 //               |
    211 //             kCopy ----------+
    212 //               |             |
    213 //              ...           ctrl
    214 //               |            edge
    215 //           (body root)       |
    216 //               |             |
    217 //             kCopy <---------+
    218 //
    219 //       The root kCopy becomes the new root of the computation. Both copies are
//       necessary to avoid any potential interference between the parameter value and
    221 //       the root value. The control edge prevents potential interference
    222 //       between the copies themselves.
    223 //
    224 // If the loop state is a tuple then the above kCopy instructions are a deep
// copy constructed of kCopy, kGetTupleElement, and kTuple instructions as
    226 // constructed by HloInstruction::DeepCopyInstruction.
    227 Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis,
    228                          HloInstruction* xla_while) {
    229   VLOG(2) << "Adding copies for kWhile instruction " << xla_while->name();
    230   TF_RET_CHECK(xla_while->opcode() == HloOpcode::kWhile);
    231 
    232   ShapeTree<bool> indices_to_copy(xla_while->shape());
    233   if (!IndicesToCopyForWhile(alias_analysis.dataflow_analysis(), xla_while,
    234                              &indices_to_copy)) {
    235     VLOG(2) << "No copies necessary for kWhile instruction "
    236             << xla_while->name();
    237     return Status::OK();
    238   }
    239 
    240   VLOG(2) << "Adding copies for " << xla_while->name() << " at indices:";
    241   for (auto& pair : indices_to_copy) {
    242     if (pair.second) {
    243       VLOG(2) << "  " << pair.first;
    244     }
    245   }
    246 
    247   // Deep copy init.
    248   HloInstruction* while_init = xla_while->mutable_operand(0);
    249   TF_ASSIGN_OR_RETURN(
    250       HloInstruction * while_init_copy,
    251       xla_while->parent()->DeepCopyInstruction(while_init, &indices_to_copy));
    252   TF_RETURN_IF_ERROR(while_init->ReplaceUseWith(xla_while, while_init_copy));
    253 
    254   // Deep copy the parameter and the root. Extend a control edge from the copy
    255   // of the parameter value to the corresponding copy value of the root.
    256   HloComputation* body = xla_while->while_body();
    257   HloInstruction* param = body->parameter_instruction(0);
    258   HloInstruction* root = body->root_instruction();
    259 
    260   // If param is the root then all indices should have been passed through the
    261   // while body and we should have returned early above.
    262   TF_RET_CHECK(param != root);
    263 
    264   // Copy users before making a deep copy of the parameter as the deep copy
    265   // will create new users of the parameter (eg, the GTE instructions of the
    266   // deep copy).
    267   std::vector<HloInstruction*> param_users = param->users();
    268 
    269   ShapeIndex current_index;
    270   TF_ASSIGN_OR_RETURN(auto pair,
    271                       DeepCopyAndAddControlEdges(param, root, indices_to_copy));
    272 
    273   HloInstruction* param_copy = pair.first;
    274   HloInstruction* root_copy = pair.second;
    275 
    276   for (HloInstruction* user : param_users) {
    277     TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, param_copy));
    278   }
    279 
    280   body->set_root_instruction(root_copy);
    281 
    282   return Status::OK();
    283 }
    284 
    285 // Removes any control dependencies to or from the given instruction.
    286 Status StripControlDependenciesFrom(HloInstruction* instruction) {
    287   while (!instruction->control_successors().empty()) {
    288     TF_RETURN_IF_ERROR(instruction->RemoveControlDependencyTo(
    289         instruction->control_successors().front()));
    290   }
    291 
    292   while (!instruction->control_predecessors().empty()) {
    293     TF_RETURN_IF_ERROR(
    294         instruction->control_predecessors().front()->RemoveControlDependencyTo(
    295             instruction));
    296   }
    297 
    298   return Status::OK();
    299 }
    300 
    301 // Add kCopy instructions to the given module to guarantee there is no
    302 // live-range interference. Generally interference can only occur around kWhile
    303 // instructions which have update-in-place semantics.
    304 Status AddCopiesToResolveInterference(HloModule* module) {
    305   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
    306                       HloAliasAnalysis::Run(module));
    307 
    308   for (HloComputation* computation : module->computations()) {
    309     for (HloInstruction* instruction : computation->instructions()) {
    310       if (instruction->opcode() == HloOpcode::kWhile) {
    311         TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction));
    312       }
    313     }
    314   }
    315   return Status::OK();
    316 }
    317 
    318 // Class for removing unnecessary copies from the module.
    319 //
    320 // kCopy instructions are added conservatively to guarantee no live range
    321 // interference between HLO values. This class uses a more fine-grained analysis
    322 // to remove some of these added copies which are not strictly necessary.
    323 class CopyRemover {
    324  public:
  // Constructs a CopyRemover for 'module'. 'alias_analysis' and 'ordering'
  // are captured by reference and must outlive this object; they drive the
  // live-range reasoning used to decide which copies may be elided.
  CopyRemover(const HloAliasAnalysis& alias_analysis,
              const HloOrdering& ordering, HloModule* module)
      : module_(module),
        alias_analysis_(alias_analysis),
        ordering_(ordering),
        buffer_value_tracker_(*module, alias_analysis, ordering) {}
    331 
    332   // Try to elide the given copy. The copy is elided if the instruction is not
    333   // necessary to prevent live-range interference of HLO values. Returns true if
    334   // copy was elided.
    335   //
    336   // The copy instruction is not actually removed here. Instead it is left for
    337   // dead in the graph. Later calls to DCE will remove the instruction.
    338   StatusOr<bool> TryElideCopy(HloInstruction* copy) {
    339     if (buffer_value_tracker_.TryElideCopy(copy)) {
    340       TF_RETURN_IF_ERROR(StripControlDependenciesFrom(copy));
    341       TF_RETURN_IF_ERROR(copy->ReplaceAllUsesWith(copy->mutable_operand(0)));
    342       return true;
    343     }
    344     return false;
    345   }
    346 
    347   string ToString() const {
    348     string out = StrCat("CopyRemover, module ", module_->name(), "\n");
    349     StrAppend(&out, "  Buffer values, in dependency order:\n");
    350     for (const HloBuffer& buffer : alias_analysis_.buffers()) {
    351       StrAppend(&out, "    HloBuffer ", buffer.id(), ":\n");
    352     }
    353     return out;
    354   }
    355 
    356  private:
    357   // Class which tracks the HLO values within each HLO buffer in the module
    358   // during copy removal.
    359   //
    360   // The values are held in a linked list where there is one list for each
    361   // buffer. Removing a copy instruction merges together the values in the
    362   // source buffer of the copy to the destination buffer of the copy. This class
    363   // tracks these value lists as copies are removed from the graph (and value
    364   // lists are merged).
    365   //
    366   // The BufferValueTracker object is initialized to match the state of
    367   // HloAliasAnalysis. However, as copies are removed this state diverges. The
    368   // values-to-buffer mapping is maintained outside of HloAliasAnalysis because
    369   // a fully updatable alias analysis is very slow.
    370   class BufferValueTracker {
    371    public:
    372     // The values held in a single HLO buffer are represented using a linked
    373     // list. An element type in this list is ValueNode.
    374     //
    375     // This linked list is hand-rolled to enable efficient splicing of lists
    376     // using only references to list elements without knowing which lists are
    377     // being spliced. std::list requires a reference to the list object to
    378     // splice.
    // A node in the circular doubly-linked list of HLO values held by one
    // buffer. Hand-rolled (rather than std::list) so lists can be spliced via
    // node pointers alone.
    struct ValueNode {
      explicit ValueNode(const HloValue* v) : value(v) {}

      // The HLO value this node represents. Not owned.
      const HloValue* value;

      // The uses are maintained outside of HloValue::uses() because
      // HloValue::uses() is not updatable (a fully updatable dataflow analysis
      // is slow). Entries are removed/added as copies are elided.
      std::vector<const HloUse*> uses;

      // next/prev elements in the linked list. The list is circularly linked so
      // these values are never null for elements in the list.
      ValueNode* prev = nullptr;
      ValueNode* next = nullptr;
    };
    394 
    // Builds one value list per HLO buffer in 'alias_analysis' plus the
    // copy_map_ describing the source/destination list nodes of each
    // unambiguous kCopy in 'module'. 'ordering' must outlive this object.
    BufferValueTracker(const HloModule& module,
                       const HloAliasAnalysis& alias_analysis,
                       const HloOrdering& ordering)
        : dataflow_(alias_analysis.dataflow_analysis()), ordering_(ordering) {
      // Construct a list for each HLO buffer in the alias analysis. Maintain a
      // map from HloValue to the respective list element representing that
      // value. The map is used to construct the copy info map below.
      tensorflow::gtl::FlatMap<const HloValue*, ValueNode*> value_to_node;
      for (const HloBuffer& buffer : alias_analysis.buffers()) {
        // Verify values contained in the buffer are strictly ordered. This
        // should always be the case after adding copies to eliminate
        // interference. Specifically, the addition of the control flow edges
        // between copies added around aliased operations (kWhile) guarantees
        // this strict order. (Debug-only check; O(n^2) in values per buffer.)
        for (const HloValue* value_a : buffer.values()) {
          for (const HloValue* value_b : buffer.values()) {
            if (value_a != value_b) {
              DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b,
                                                       dataflow_) ||
                     ordering_.LiveRangeStrictlyBefore(*value_b, *value_a,
                                                       dataflow_))
                  << value_a->ToShortString() << " and "
                  << value_b->ToShortString() << " are not ordered";
            }
          }
        }

        // Sort into definition order so the list reflects live-range order.
        std::vector<const HloValue*> values = buffer.values();
        std::sort(values.begin(), values.end(),
                  [this](const HloValue* a, const HloValue* b) {
                    return ordering_.IsDefinedBefore(*a, *b);
                  });

        // Create a list containing all of the values in the buffer.
        AddValueList(values, &value_to_node);
      }

      // Create copy_map_ which contains the source and destination values
      // of all copies.
      CreateCopyMap(module, value_to_node);

      XLA_VLOG_LINES(3, ToString());
      TF_DCHECK_OK(Verify());
    }
    439 
    440     // Add a list containing the given values to BufferValueTracker. This
    441     // represents the values contained in a single buffer. For each value in
    442     // 'values' an entry is created in value_to_node which indicates the
    443     // respective ValueNode representing that value.
    444     void AddValueList(
    445         tensorflow::gtl::ArraySlice<const HloValue*> values,
    446         tensorflow::gtl::FlatMap<const HloValue*, ValueNode*>* value_to_node) {
    447       ValueNode* tail = nullptr;
    448       ValueNode* head = nullptr;
    449       for (const HloValue* value : values) {
    450         auto new_node = new ValueNode(value);
    451         (*value_to_node)[value] = new_node;
    452 
    453         // Copy the HLO values's uses into the ValueNode for the value. These
    454         // uses in ValueNode are updated as copies are removed.
    455         new_node->uses.reserve(value->uses().size());
    456         for (const HloUse& use : value->uses()) {
    457           new_node->uses.push_back(&use);
    458         }
    459 
    460         // Connect the new node into the linked list.
    461         if (tail == nullptr) {
    462           head = new_node;
    463         } else {
    464           tail->next = new_node;
    465           new_node->prev = tail;
    466         }
    467         tail = new_node;
    468       }
    469 
    470       // The linked list is circular so connect the head and tail.
    471       tail->next = head;
    472       head->prev = tail;
    473       value_lists_.insert(head);
    474     }
    475 
    476     // This method also fills in copy_map_ which indicates which nodes
    477     // in the value lists corresponding to the source and destination values of
    478     // kCopy instructions. value_to_node should map each HloValue to its
    479     // respective ValueNode.
    480     void CreateCopyMap(
    481         const HloModule& module,
    482         const tensorflow::gtl::FlatMap<const HloValue*, ValueNode*>&
    483             value_to_node) {
    484       for (HloComputation* computation : module.computations()) {
    485         for (HloInstruction* instruction : computation->instructions()) {
    486           // Add copies with unambiguous source values to the map. Copies with
    487           // ambiguous sources are not removable.
    488           if (instruction->opcode() == HloOpcode::kCopy) {
    489             const HloValueSet& src_value_set =
    490                 dataflow_.GetValueSet(instruction->operand(0));
    491             if (src_value_set.values().size() == 1) {
    492               CopyNodes& copy_node = copy_map_[instruction];
    493               copy_node.dest =
    494                   value_to_node.at(&dataflow_.GetUniqueValueAt(instruction));
    495               copy_node.src = value_to_node.at(&src_value_set.GetUniqueValue());
    496             }
    497           }
    498         }
    499       }
    500     }
    501 
    502     ~BufferValueTracker() {
    503       for (const ValueNode* head : value_lists_) {
    504         const ValueNode* p = head;
    505         do {
    506           const ValueNode* tmp = p->next;
    507           delete p;
    508           p = tmp;
    509         } while (p != head);
    510       }
    511     }
    512 
    513     // Verify invariants within the linked lists.
    514     Status Verify() const {
    515       for (const ValueNode* head : value_lists_) {
    516         const ValueNode* p = head;
    517         do {
    518           // Verify links between elements are consistent.
    519           TF_RET_CHECK(p->prev->next == p);
    520           TF_RET_CHECK(p->next->prev == p);
    521 
    522           const HloInstruction* def = p->value->defining_instruction();
    523           if (def->opcode() == HloOpcode::kCopy &&
    524               ContainsKey(copy_map_, def)) {
    525             TF_RET_CHECK(copy_map_.at(def).dest == p);
    526           }
    527           for (const HloUse* use : p->uses) {
    528             if (use->instruction->opcode() == HloOpcode::kCopy &&
    529                 ContainsKey(copy_map_, use->instruction)) {
    530               TF_RET_CHECK(copy_map_.at(use->instruction).src == p);
    531             }
    532           }
    533 
    534           p = p->next;
    535         } while (p != head);
    536       }
    537       return Status::OK();
    538     }
    539 
    540     // Try to elide the given copy. Elision of a copy is possible only if no
    541     // live range interference is introduced by the copy's elimination. If
    542     // elision is possible, then the internal state (value lists) are updated,
    543     // and true is returned. Returns false otherwise.
    544     bool TryElideCopy(const HloInstruction* copy) {
    545       VLOG(2) << "Trying to remove " << copy->name();
    546 
    547       if (!ContainsKey(copy_map_, copy)) {
    548         VLOG(2) << copy->name() << " is not removable";
    549         return false;
    550       }
    551 
    552       const CopyNodes& copy_node = copy_map_.at(copy);
    553       ValueNode* src = copy_node.src;
    554       ValueNode* dest = copy_node.dest;
    555       DCHECK(src != nullptr);
    556       DCHECK(dest != nullptr);
    557 
    558       auto is_live_range_before = [this](const ValueNode& a,
    559                                          const ValueNode& b) {
    560         if (LiveRangeBefore(a, b)) {
    561           VLOG(2) << "  Live range of " << a.value->ToShortString()
    562                   << " is before " << b.value->ToShortString();
    563           return true;
    564         } else {
    565           VLOG(2) << "  Live range of " << a.value->ToShortString()
    566                   << " is not before " << b.value->ToShortString();
    567           return false;
    568         }
    569       };
    570 
    571       VLOG(3) << copy->name() << " copies value "
    572               << src->value->ToShortString();
    573       VLOG(3) << "Source buffer values: " << ValueListToString(src);
    574       VLOG(3) << "Dest buffer values: " << ValueListToString(src);
    575 
    576       // A kCopy instruction copies an HLO value from a source buffer and
    577       // defines an HLO value in a destination buffer. Most generally, the
    578       // source and destination buffers may each hold more than one value at
    579       // different points in the computation so we define the following:
    580       //
    581       //   Values in source buffer:      {s_0, ..., s_n}
    582       //   Values in destination buffer: {d_0, ..., d_m}
    583       //
    584       // A kCopy instruction between these buffers copies a value s_x in the
    585       // source buffer and defines a value d_y in the destination buffer. The
    586       // elision of a copy merges the source and destination buffers together,
    587       // so the list of values for the source and destination buffers are
    588       // merged.
    589       //
    590       // We handle two different cases for copy elision:
    591       //
    592       //  (1) the kCopy defines the first value in the destination buffer (d_0).
    593       //
    594       //  (2) the kCopy copies the last value in the source buffer (s_n).
    595       //
    596       // For the remaining case where the kCopy copies a not-last value from the
    597       // source buffer to a not-first value of the destination buffer, the kCopy
    598       // instruction cannot be removed. This case is generated, for example, if
    599       // the kCopy copies a while body parameter of the loop state at one tuple
    600       // index to a different tuple index in the while body root. Removal of the
    601       // copy necessarily results in live range interference of values in the
    602       // loop state at the two different tuple indices.
    603       //
    604       //  We can only perform copy elision if the resulting merged values have
    605       //  totally ordered live ranges; otherwise the merged buffer would have
    606       //  live range interference.
    607       if (IsHead(*dest)) {
    608         // The copy copies an arbitrary value in the source buffer (call it s_x)
    609         // and defines d_0, the first value in the destination buffer. After
    610         // merging, the values in the combined buffer must be strictly ordered
    611         // as follows** to elide the copy:
    612         //
    613         // {s_0, ..., s_x, d_1, ..., d_m, s_{x+1}, ..., s_n}
    614         //
    615         // Removing the copy eliminates d_0, and uses of d_0 become uses of
    616         // s_x. In the above ordering, the live range of d_m must be ordered
    617         // before the live range of s_{x+1} and the definition and all uses of
    618         // s_x must be ordered before the definition of d_1. These conditions
    619         // are checked below prior to elision.
    620         //
    621         // ** Technically it might be possible to have a non-interfering
    622         //    non-trivial interleaving of the values of the source and
    623         //    destination buffers in the resulting order. However, this case is
    624         //    slow and complicated to check and likely not worth it. So instead
    625         //    we simply check for the case where *all* values of the destination
    626         //    buffer (d_1 through d_m) are spliced into the point where the copy
    627         //    used to be.
    628         VLOG(2) << copy->name() << " defines the first value in its buffer";
    629         ValueNode* next_dest = Next(*dest);
    630         if (next_dest != nullptr) {
    631           // Live range of 'from' value (s_x) must be before 'next_dest' (d_1);
    632           if (!is_live_range_before(*src, *next_dest)) {
    633             return false;
    634           }
    635         }
    636         ValueNode* next_src = Next(*src);
    637 
    638         if (next_src != nullptr) {
    639           // Live range of 'last_dest' (d_m) must be before 'next_src' s_{x+1}.
    640           ValueNode* last_dest = dest->prev;
    641           DCHECK(IsTail(*last_dest));
    642           if (!is_live_range_before(*last_dest, *next_src)) {
    643             return false;
    644           }
    645         }
    646 
    647         // Splice in destination buffer values list right after 'src'.
    648         SpliceAfter(dest, src);
    649       } else if (IsTail(*src)) {
    650         // The copy copies the last value in the source buffer, s_n, and defines
    651         // an arbitrary value in the destination buffer, d_y.  After
    652         // merging, the values in the combined buffer must be strictly ordered
    653         // as follows** to elide the copy:
    654         //
    655         // {d_0, ..., d_{y-1}, s_0, ..., s_n, d_{y+1}, ..., d_m}
    656         //
    657         // Removing the copy eliminates d_y, and uses of d_y become uses of
    658         // s_n. To enforce the above order, the live range of d_{y-1} must be
    659         // before the live range of s_0, and the live range of s_n must be
    660         // before the live range of d_{y+1}.
    661         //
    662         // ** See comment above in the code handling Case (1).
    663         VLOG(2) << copy->name() << " copies the last value ("
    664                 << src->value->ToShortString() << ") in its buffer";
    665 
    666         ValueNode* prev_dest = Prev(*dest);
    667         // nullptr condition handled above in the first 'if' case.
    668         DCHECK(prev_dest != nullptr);
    669         ValueNode* first_src = src->next;
    670         DCHECK(IsHead(*first_src));
    671         if (!is_live_range_before(*prev_dest, *first_src)) {
    672           // Live range of value d_{y-1} is not before s_0.
    673           return false;
    674         }
    675         ValueNode* next_dest = Next(*dest);
    676         if (next_dest != nullptr) {
    677           if (!is_live_range_before(*src, *next_dest)) {
    678             // Live range of value s_n is not before d_{y+1}.
    679             return false;
    680           }
    681         }
    682 
    683         // Splice source buffer values list right after 'prev_dest'.
    684         SpliceAfter(first_src, prev_dest);
    685       } else {
    686         VLOG(2)
    687             << copy->name()
    688             << " copies value in middle of source buffer to value in middle "
    689                "of destination buffer";
    690         return false;
    691       }
    692 
    693       RemoveCopyValue(dest);
    694 
    695       XLA_VLOG_LINES(4, ToString());
    696       TF_DCHECK_OK(Verify());
    697 
    698       return true;
    699     }
    700 
    701     // Delete the given ValueNode associated with a elided kCopy
    702     // instruction. This should be called after splicing the value lists of the
    703     // source and destination buffers together.
    704     void RemoveCopyValue(ValueNode* copy_value_node) {
    705       CHECK_EQ(copy_value_node->value->defining_instruction()->opcode(),
    706                HloOpcode::kCopy);
    707       ValueNode* operand_node = copy_value_node->prev;
    708       CHECK(operand_node != copy_value_node);
    709 
    710       VLOG(2) << "Removing copy " << operand_node->value->ToShortString()
    711               << " => " << copy_value_node->value->ToShortString();
    712 
    713       // Splice out the copy value node.
    714       operand_node->next = copy_value_node->next;
    715       copy_value_node->next->prev = operand_node;
    716 
    717       // Patch up uses. Remove use of copy from operand_node uses.
    718       auto it =
    719           std::find_if(operand_node->uses.begin(), operand_node->uses.end(),
    720                        [copy_value_node](const HloUse* use) {
    721                          return use->instruction ==
    722                                 copy_value_node->value->defining_instruction();
    723                        });
    724       CHECK(it != operand_node->uses.end());
    725       operand_node->uses.erase(it);
    726 
    727       // If the elided copy has any uses which are themselves kCopy instructions
    728       // then patch up the copy info to reflect the that this kCopy instruction
    729       // has a different operand (the operand of the elided copy).
    730       for (const HloUse* copy_use : copy_value_node->uses) {
    731         operand_node->uses.push_back(copy_use);
    732         if (copy_use->instruction->opcode() == HloOpcode::kCopy &&
    733             ContainsKey(copy_map_, copy_use->instruction)) {
    734           copy_map_.at(copy_use->instruction).src = operand_node;
    735         }
    736       }
    737 
    738       // Delete the copy info and the value node.
    739       copy_map_.erase(copy_value_node->value->defining_instruction());
    740       delete copy_value_node;
    741     }
    742 
    743     // Returns true if the live range of given value 'a' is before the live
    744     // range of 'b'.
    745     //
    746     // We cannot use LiveRangeStrictlyBefore because HloValue::uses() is not
    747     // updated as copies are removed.
    748     bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) {
    749       if (a.uses.empty()) {
    750         VLOG(2) << "Empty uses";
    751         return ordering_.IsDefinedBefore(*a.value, *b.value);
    752       }
    753       for (const HloUse* use : a.uses) {
    754         VLOG(2) << "use: " << *use;
    755         VLOG(2) << "is before:" << *b.value;
    756         if (!ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_)) {
    757           VLOG(2) << "Not before";
    758           return false;
    759         }
    760       }
    761       return true;
    762     }
    763 
    764     // Returns whether 'node' is the last node in its list.
    765     bool IsTail(const ValueNode& node) const {
    766       return ContainsKey(value_lists_, node.next);
    767     }
    768 
    769     // Returns whether 'node' is the first node in its list.
    770     bool IsHead(const ValueNode& node) const {
    771       return ContainsKey(value_lists_, &node);
    772     }
    773 
    774     // Returns the next node in the list after 'node'. If 'node' is the
    775     // tail, then nullptr is returned.
    776     ValueNode* Next(const ValueNode& node) const {
    777       if (IsTail(node)) {
    778         return nullptr;
    779       } else {
    780         return node.next;
    781       }
    782     }
    783 
    784     // Returns the previous node in the list before 'node'. If 'node'
    785     // is the head, then nullptr is returned.
    786     ValueNode* Prev(const ValueNode& node) const {
    787       if (IsHead(node)) {
    788         return nullptr;
    789       } else {
    790         return node.prev;
    791       }
    792     }
    793 
    794     // Splices the entire linked list with 'head' as its head right after the
    795     // node 'insert_after' in another linked list.
    796     void SpliceAfter(ValueNode* head, ValueNode* insert_after) {
    797       DCHECK(IsHead(*head));
    798       value_lists_.erase(head);
    799 
    800       ValueNode* tail = head->prev;
    801       tail->next = insert_after->next;
    802       insert_after->next->prev = tail;
    803 
    804       insert_after->next = head;
    805       head->prev = insert_after;
    806     }
    807 
    808     string ValueListToString(const ValueNode* element) {
    809       const ValueNode* head = element;
    810       while (!IsHead(*head)) {
    811         head = Prev(*head);
    812       }
    813       std::vector<const HloValue*> values;
    814       for (const ValueNode* p = head; p != nullptr; p = Next(*p)) {
    815         values.push_back(p->value);
    816       }
    817       return StrCat("{",
    818                     Join(values, ", ",
    819                          [](string* s, const HloValue* value) {
    820                            StrAppend(s, value->ToShortString());
    821                          }),
    822                     "}");
    823     }
    824 
    825     string ToString() const {
    826       string out = StrCat("BufferValueTracker:\n");
    827       StrAppend(&out, "  Def-use chains in each buffer:\n");
    828       for (const ValueNode* head : value_lists_) {
    829         StrAppend(&out, "    Buffer defined by ", head->value->ToShortString(),
    830                   ":\n");
    831         const ValueNode* p = head;
    832         do {
    833           StrAppend(&out, "      ", p->value->ToShortString(), ", uses: ",
    834                     Join(p->uses, "; ",
    835                          [](string* s, const HloUse* use) {
    836                            StrAppend(s, use->ToString());
    837                          }),
    838                     "\n");
    839 
    840           p = p->next;
    841         } while (p != head);
    842       }
    843       StrAppend(&out, "  Potentially removable copies:\n");
    844       for (const auto& pair : copy_map_) {
    845         const HloInstruction* copy = pair.first;
    846         const CopyNodes& copy_info = pair.second;
    847 
    848         StrAppend(&out, "    ", copy->name(), " : ",
    849                   copy_info.src->value->ToShortString(), " => ",
    850                   copy_info.dest->value->ToShortString(), "\n");
    851       }
    852       return out;
    853     }
    854 
    855    private:
    856     const HloDataflowAnalysis& dataflow_;
    857     const HloOrdering& ordering_;
    858 
    859     // The heads of all the value lists. Each value list represents the HLO
    860     // values contained in a particular HLO buffer. The values in the list are
    861     // in dependency order.
    862     tensorflow::gtl::FlatSet<const ValueNode*> value_lists_;
    863 
    864     // Copy removal requires fast access to the value list elements
    865     // corresponding to the source and destination values of the kCopy
    866     // instruction. This data structure holds pointers to these elements for
    867     // each kCopy instruction in the graph.
    868     struct CopyNodes {
    869       // The source and destinations values of the kCopy instruction.
    870       ValueNode* src = nullptr;
    871       ValueNode* dest = nullptr;
    872     };
    873     tensorflow::gtl::FlatMap<const HloInstruction*, CopyNodes> copy_map_;
    874   };
    875 
    876   HloModule* module_;
    877   const HloAliasAnalysis& alias_analysis_;
    878   const HloOrdering& ordering_;
    879 
    880   // Object tracking the HLO values contained in each HLO buffer.
    881   BufferValueTracker buffer_value_tracker_;
    882 };
    883 
    884 // Try to remove as many copies from the module as possible without introducing
    885 // live range interference. Copy instructions (identified by their unique id) in
    886 // the set copies_to_exclude are not considered for removal.
    887 Status RemoveUnnecessaryCopies(
    888     const HloOrdering& ordering,
    889     const tensorflow::gtl::FlatSet<int>& copies_to_exclude, HloModule* module) {
    890   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
    891                       HloAliasAnalysis::Run(module));
    892   CopyRemover copy_remover(*alias_analysis, ordering, module);
    893   XLA_VLOG_LINES(3, copy_remover.ToString());
    894 
    895   tensorflow::gtl::FlatSet<int> existing_copies;
    896   for (HloComputation* computation : module->computations()) {
    897     for (HloInstruction* instruction : computation->instructions()) {
    898       if (instruction->opcode() == HloOpcode::kCopy &&
    899           !ContainsKey(copies_to_exclude, instruction->unique_id())) {
    900         TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status());
    901       }
    902     }
    903   }
    904 
    905   return Status::OK();
    906 }
    907 
    908 // Add copies to address special constraints on the roots of computations not
    909 // related to live range interference:
    910 //
    911 //    (1) Entry computation root must be unambiguous and distinct.
    912 //
    913 //    (2) Any computation called by a kCall instruction must have an
    914 //        unambiguous root.
    915 //
    916 //    (3) Constants and parameters cannot be live out of the entry computation
    917 //
Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module));

  // Identify which shape indices of which instructions need to be copied. Store
  // these results in 'instructions_to_copy'.
  std::unordered_map<HloInstruction*, ShapeTree<bool>> instructions_to_copy;
  // Marks 'index' of 'instruction' for copying, lazily creating the
  // instruction's all-false ShapeTree entry on first use.
  auto add_index_to_copy = [&instructions_to_copy](HloInstruction* instruction,
                                                   const ShapeIndex& index) {
    auto it = instructions_to_copy.find(instruction);
    if (it == instructions_to_copy.end()) {
      auto it_added = instructions_to_copy.emplace(
          std::piecewise_construct, std::forward_as_tuple(instruction),
          std::forward_as_tuple(instruction->shape(), /*init_value=*/false));
      it = it_added.first;
    }
    *it->second.mutable_element(index) = true;
  };

  // Iterate through values of all constants and entry parameters. These values
  // are special because they are held in read-only buffers. If any of these
  // values share a buffer with other values (for example, the init value of a
  // while is a constant) then copy the value at its definition and replace all
  // its uses with the copy.
  for (const HloValue* value : alias_analysis->dataflow_analysis().values()) {
    if (ValueIsReadOnly(*value) &&
        alias_analysis->GetBufferContainingValue(*value).values().size() > 1) {
      VLOG(2) << "Value " << value->ToShortString()
              << " is read only, but its buffer contains more than one value. "
                 "Copying.";
      add_index_to_copy(value->defining_instruction(), value->defining_index());
    }
  }

  // Identify copies which must be added at root instructions
  for (HloComputation* computation : module->computations()) {
    const CallGraphNode& node = call_graph.GetNode(computation);
    // Computations called in a parallel context (e.g. map bodies) are skipped.
    if (node.context() == CallContext::kParallel) {
      continue;
    }
    TF_RET_CHECK(node.context() == CallContext::kSequential);

    const bool is_entry = computation == module->entry_computation();
    HloInstruction* root = computation->root_instruction();

    // Mark nondistinct/ambiguous indices.
    tensorflow::gtl::FlatSet<const HloBuffer*> seen;
    ShapeUtil::ForEachSubshape(
        root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) {
          std::vector<const HloBuffer*> buffers_at_index =
              alias_analysis->ComputeBuffersAt(root, index);
          bool buffer_seen_before = false;
          // insert().second is false when the buffer already appeared at a
          // previously visited index of this root, i.e. non-distinct.
          for (const HloBuffer* buffer : buffers_at_index) {
            buffer_seen_before |= !seen.insert(buffer).second;
          }
          // More than one buffer at an index means the root is ambiguous
          // there; a repeated buffer at the entry root means non-distinct.
          if (buffers_at_index.size() > 1 || (buffer_seen_before && is_entry)) {
            VLOG(2) << "Index " << index << " of root of computation "
                    << computation->name() << " (" << root->name()
                    << ") has ambiguous or non-distinct buffer. Copying.";
            add_index_to_copy(root, index);
          }
        });

    // For entry instructions, mark any parameter or constant values.
    if (is_entry) {
      for (const auto& pair :
           alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) {
        const ShapeIndex& index = pair.first;
        const HloValueSet& value_set = pair.second;
        for (const HloValue* value : value_set.values()) {
          if (ValueIsReadOnly(*value)) {
            VLOG(2) << "Root of entry computation (" << root->name()
                    << ") has constant or entry parameter value at index "
                    << index << ". Copying.";
            add_index_to_copy(root, index);
          }
        }
      }
    }
  }

  // Add copy instructions indicated in 'instructions_to_copy' to the module.
  for (const auto& pair : instructions_to_copy) {
    HloInstruction* instruction = pair.first;
    const ShapeTree<bool>& indices_to_copy = pair.second;

    // Snapshot the users first so the newly created deep copy — which
    // presumably consumes 'instruction' itself — is not rewritten below.
    std::vector<HloInstruction*> users = instruction->users();
    TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
                        instruction->parent()->DeepCopyInstruction(
                            instruction, &indices_to_copy));
    for (HloInstruction* user : users) {
      TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy));
    }
    if (instruction == instruction->parent()->root_instruction()) {
      instruction->parent()->set_root_instruction(deep_copy);
    }
  }

  return Status::OK();
}
   1018 
   1019 Status VerifyNoLiveRangeInterference(HloModule* module) {
   1020   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
   1021                       HloAliasAnalysis::Run(module));
   1022   DependencyHloOrdering ordering(module);
   1023   TF_RET_CHECK(!alias_analysis->HasLiveRangeInterference(ordering));
   1024   return Status::OK();
   1025 }
   1026 
   1027 void MaybeDumpModule(const string& message, const HloModule& module) {
   1028   if (VLOG_IS_ON(3)) {
   1029     VLOG(3) << message;
   1030     XLA_VLOG_LINES(3, module.ToString());
   1031     hlo_graph_dumper::MaybeDumpHloModule(module, message);
   1032   }
   1033 }
   1034 
   1035 }  // namespace
   1036 
StatusOr<bool> CopyInsertion::Run(HloModule* module) {
  // Copy insertion is performed in three steps:
  //
  // (1) Add copies conservatively to guarantee that there is no live-range
  //     interference. This is done simplistically and usually results in more
  //     copies than is strictly necessary.
  //
  // (2) Using a more fine-grained analysis, remove as many copies that were
  //     added in (1) as possible while ensuring no live-range interference.
  //
  // (3) Add copies to resolve issues not related to live range interference
  //     such as parameters and constants live out of the entry computation.
  //
  // We add copies then remove them (step (1) then (2)) rather than simply
  // adding only the copies that are necessary because, in general, it is
  // difficult to figure out the minimal set of copies to add once there is
  // interference. On the other hand, it is easy to determine if removing a copy
  // will introduce interference.
  //
  // The final copy insertion in (3) is done separately to simplify the
  // implementation of copy removal in (2) which is the most complicated part of
  // the pass. As is, copy removal only has to reason about live range
  // interference. If all copies were added in step (1) then copy removal would
  // also have to reason about things like constants and parameters live out of
  // the computation.
  MaybeDumpModule("before copy insertion", *module);

  // The pass requires a flattened call graph (each computation called from at
  // most one place); bail out early otherwise.
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
  if (!call_graph->IsFlattened()) {
    return FailedPrecondition(
        "Call graph must be flattened before copy insertion.");
  }

  // Gather Ids of existing kCopy instructions in the module. We avoid removing
  // these copies (except via DCE in TupleSimplifier) because they may have been
  // added for reasons not considered by copy insertion (eg, layout assignment).
  // Instruction id is used instead of HloInstruction* because the pointer
  // values may be recycled.
  tensorflow::gtl::FlatSet<int> existing_copies;
  for (HloComputation* computation : module->computations()) {
    for (HloInstruction* instruction : computation->instructions()) {
      if (instruction->opcode() == HloOpcode::kCopy) {
        existing_copies.insert(instruction->unique_id());
      }
    }
  }

  // Step (1): conservatively add copies.
  TF_RETURN_IF_ERROR(AddCopiesToResolveInterference(module));

  // Simplify the tuple structures introduced by the deep copies. This should be
  // done before removing copies (RemoveUnnecessaryCopies) because tuple
  // simplification changes dependencies in the graph which changes live range
  // interference in the graph. Also run DCE to remove the dead Tuple/GTE
  // instructions introduced by tuple simplification.
  TupleSimplifier tuple_simplifier;
  HloDCE dce;
  TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
  TF_RETURN_IF_ERROR(dce.Run(module).status());

  // Debug-build sanity check: step (1) must have eliminated interference.
  TF_DCHECK_OK(VerifyNoLiveRangeInterference(module));

  MaybeDumpModule("after adding copies to resolve interference", *module);

  // Step (2): remove as many of the added copies as possible. Copies that
  // pre-dated this pass ('existing_copies') are excluded from removal.
  DependencyHloOrdering ordering(module);
  TF_RETURN_IF_ERROR(
      RemoveUnnecessaryCopies(ordering, existing_copies, module));

  MaybeDumpModule("after removing unnecessary copies", *module);

  // Step (3): copies for special-case constraints (ambiguous/non-distinct
  // roots, constants/parameters live out of entry).
  TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));

  MaybeDumpModule("after adding special-case copies", *module);

  // Clean up again after the special-case copies, then re-verify.
  TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
  TF_RETURN_IF_ERROR(dce.Run(module).status());
  TF_DCHECK_OK(VerifyNoLiveRangeInterference(module));

  MaybeDumpModule("after copy insertion", *module);

  // Report before/after copy counts for debugging.
  if (VLOG_IS_ON(1)) {
    int64 num_total_copies = 0;
    for (HloComputation* computation : module->computations()) {
      for (HloInstruction* instruction : computation->instructions()) {
        if (instruction->opcode() == HloOpcode::kCopy) {
          num_total_copies++;
        }
      }
    }
    VLOG(1) << "Num copies before copy-insertion: " << existing_copies.size();
    VLOG(1) << "Num copies after copy-insertion: " << num_total_copies;
  }

  // NOTE(review): the pass unconditionally reports that it changed the
  // module, even if no copies were ultimately added or removed.
  return true;
}
   1131 
   1132 namespace {
   1133 
   1134 bool IsWhileBody(const HloComputation* computation,
   1135                  const CallGraph& call_graph) {
   1136   const CallGraphNode& node = call_graph.GetNode(computation);
   1137 
   1138   if (node.context() == CallContext::kSequential &&
   1139       !node.caller_callsites().empty()) {
   1140     // Callgraph should be flattened so sequential context computations can
   1141     // have at most one caller.
   1142     CHECK_EQ(node.caller_callsites().size(), 1);
   1143     const HloInstruction* calling_instruction =
   1144         node.caller_callsites()[0].instruction();
   1145     if (calling_instruction->opcode() == HloOpcode::kWhile &&
   1146         calling_instruction->while_body() == node.computation()) {
   1147       return true;
   1148     }
   1149   }
   1150   return false;
   1151 }
   1152 
   1153 }  // namespace
   1154 
   1155 /* static */ StatusOr<bool> CopyInsertion::AddCopiesForBufferAssignment(
   1156     HloModule* module) {
   1157   std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
   1158   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow,
   1159                       HloDataflowAnalysis::Run(*module));
   1160 
   1161   bool changed = false;
   1162 
   1163   // If a buffer live out of a computation is a constant, a parameter, or not
   1164   // defined in the computation, then copy it to account for the limited
   1165   // computation-scoped analysis in buffer assignment. An exception to this rule
   1166   // is the while body which is handled properly without copies.
   1167   for (HloComputation* computation : module->computations()) {
   1168     if (computation == module->entry_computation() ||
   1169         IsWhileBody(computation, *call_graph)) {
   1170       continue;
   1171     }
   1172 
   1173     HloInstruction* root = computation->root_instruction();
   1174     ShapeTree<bool> indices_to_copy(root->shape(), /*init_value=*/false);
   1175     bool copy_root = false;
   1176     for (const auto& pair : dataflow->GetInstructionValueSet(root)) {
   1177       const ShapeIndex& index = pair.first;
   1178       const HloValueSet& value_set = pair.second;
   1179       for (const HloValue* value : value_set.values()) {
   1180         HloInstruction* def = value->defining_instruction();
   1181         if (def->parent() != computation ||
   1182             def->opcode() == HloOpcode::kConstant ||
   1183             def->opcode() == HloOpcode::kParameter) {
   1184           *indices_to_copy.mutable_element(index) = true;
   1185           copy_root = true;
   1186         }
   1187       }
   1188     }
   1189     if (copy_root) {
   1190       TF_ASSIGN_OR_RETURN(
   1191           HloInstruction * root_copy,
   1192           computation->DeepCopyInstruction(root, &indices_to_copy));
   1193       computation->set_root_instruction(root_copy);
   1194       changed = true;
   1195     }
   1196   }
   1197 
   1198   TupleSimplifier tuple_simplifier;
   1199   HloDCE dce;
   1200   TF_ASSIGN_OR_RETURN(bool tuple_simplifier_changed,
   1201                       tuple_simplifier.Run(module));
   1202   TF_ASSIGN_OR_RETURN(bool dce_changed, dce.Run(module));
   1203 
   1204   return changed || tuple_simplifier_changed || dce_changed;
   1205 }
   1206 
   1207 }  // namespace xla
   1208