/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/copy_insertion.h"

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace {

using absl::StrAppend;

bool IsReadonlyEntryParameterValue(const HloValue& value) {
  const HloComputation* computation = value.defining_instruction()->parent();
  return value.defining_instruction()->opcode() == HloOpcode::kParameter &&
         computation == computation->parent()->entry_computation() &&
         !computation->parent()->input_output_alias_config().ParameterHasAlias(
             value.defining_instruction()->parameter_number(), value.index());
}

bool IsConstantValue(const HloValue& value) {
  return value.defining_instruction()->opcode() == HloOpcode::kConstant;
}

bool ValueIsReadOnly(const HloValue& value) {
  return IsConstantValue(value) || IsReadonlyEntryParameterValue(value);
}
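// For illustration (hypothetical HLO, names made up): in the entry
// computation below, c0 is always read-only, and p0 is read-only as long as
// the module's input_output_alias_config does not alias it with an output.
//
//   ENTRY entry {
//     p0 = f32[8] parameter(0)
//     c0 = f32[8] constant({...})
//     ROOT add = f32[8] add(p0, c0)
//   }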
// Data structure describing the action which should be taken on parts of a
// computation's buffers, with respect to the adding of special-case copies.
struct SpecialCaseCopyPolicy {
  // Insert a copy if the same buffer is found at multiple indices within the
  // output tuple.
  bool copy_root_replicated_buffers = false;
  // If true, insert a copy if a buffer coming from a constant or a parameter
  // is found within the output tuple.
  bool copy_parameters_and_constants = false;
};

SpecialCaseCopyPolicy GetSpecialCaseCopyPolicy(const CallGraphNode& node,
                                               HloModule* module,
                                               HloComputation* computation) {
  SpecialCaseCopyPolicy policy;
  if (computation == module->entry_computation()) {
    policy.copy_parameters_and_constants = true;
    policy.copy_root_replicated_buffers = true;
  }
  return policy;
}

bool ShouldCopyRootValue(const HloValue& value,
                         const SpecialCaseCopyPolicy& policy) {
  if (policy.copy_parameters_and_constants) {
    return ValueIsReadOnly(value);
  }
  return false;
}

// Deep copy the given instructions 'from' and 'to' at the ShapeIndexes given
// in 'indices_to_copy'. Add control edges from the respective kCopy
// instructions in the deep copy of 'from' to the respective kCopy
// instructions in the deep copy of 'to'.
//
// Requirements: 'from' and 'to' must have compatible shapes.
//
// For example, suppose 'from' and 'to' are two-element tuples where index 0
// is the only index to copy. Prior to deep-copying we have:
//
//
//       'from'
//          |
//         ...
//          |
//        'to'
//
// DeepCopyAndAddControlEdges produces:
//
//       'from'
//        /   \
//      GTE   GTE
//       |     |
//     Copy    |
//    /   \   /
//   |    Tuple
//   |      |
//  ctrl   ...
//  edge    |
//   |      |
//   |    'to'
//   |    /   \
//   |  GTE   GTE
//    \  |     |
//     Copy    |
//        \   /
//        Tuple
//
StatusOr<std::pair<HloInstruction*, HloInstruction*>>
DeepCopyAndAddControlEdges(HloInstruction* from, HloInstruction* to,
                           const ShapeTree<bool>& indices_to_copy) {
  DCHECK(ShapeUtil::Compatible(from->shape(), to->shape()));
  // from_copy_tree and to_copy_tree hold the kCopy instructions produced by
  // the deep copies. Elements which are not copied
  // (indices_to_copy.element(index) == false) have nullptr at that index.
  ShapeTree<HloInstruction*> from_copy_tree(from->shape(),
                                            /*init_value=*/nullptr);
  TF_ASSIGN_OR_RETURN(HloInstruction * from_deep_copy,
                      from->parent()->DeepCopyInstruction(
                          from, &indices_to_copy, &from_copy_tree));

  ShapeTree<HloInstruction*> to_copy_tree(to->shape(), /*init_value=*/nullptr);
  TF_ASSIGN_OR_RETURN(
      HloInstruction * to_deep_copy,
      to->parent()->DeepCopyInstruction(to, &indices_to_copy, &to_copy_tree));

  // Add control edges between the respective kCopy instructions.
  for (const auto& pair : from_copy_tree) {
    const ShapeIndex& index = pair.first;
    HloInstruction* from_copy = pair.second;
    HloInstruction* to_copy = to_copy_tree.element(index);
    if (from_copy == nullptr) {
      TF_RET_CHECK(to_copy == nullptr);
      continue;
    }
    TF_RET_CHECK(to_copy != nullptr);
    TF_RETURN_IF_ERROR(from_copy->AddControlDependencyTo(to_copy));
  }

  return std::make_pair(from_deep_copy, to_deep_copy);
}
// Compute the indices of the loop state which need copies in order to avoid
// live range interference. Generally, an element in the loop state does not
// need to be copied if the element is passed transparently through the body.
//
// Returns whether any indices need to be copied.
bool IndicesToCopyForWhile(const HloDataflowAnalysis& dataflow,
                           const HloInstruction* xla_while,
                           ShapeTree<bool>* indices_to_copy) {
  DCHECK(ShapeUtil::Compatible(indices_to_copy->shape(), xla_while->shape()));

  bool any_copies = false;
  const HloInstruction* init = xla_while->operand(0);
  for (auto& pair : *indices_to_copy) {
    const ShapeIndex& index = pair.first;
    bool& should_copy = pair.second;
    // If there is any ambiguity, then the loop state must be copied.
    if (dataflow.GetValueSet(init, index).values().size() > 1 ||
        dataflow.GetValueSet(xla_while, index).values().size() > 1) {
      should_copy = true;
    } else {
      // If the output of the while instruction is not the same as the init
      // value of the while, then this element is not passed through the body
      // transparently and must be copied.
      should_copy = dataflow.GetUniqueValueAt(xla_while, index) !=
                    dataflow.GetUniqueValueAt(init, index);
    }
    any_copies |= should_copy;
  }
  return any_copies;
}
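// For illustration, a hypothetical while body whose loop state is a
// two-element tuple. Element 0 is forwarded untouched, so the value at the
// while output is the same HloValue as at the init and needs no copy;
// element 1 is recomputed each iteration, so its leaf index must be copied:
//
//   body {
//     param = (f32[], f32[]) parameter(0)
//     gte0 = f32[] get-tuple-element(param), index=0   // passed through
//     gte1 = f32[] get-tuple-element(param), index=1
//     add = f32[] add(gte1, gte1)                      // modified
//     ROOT tuple = (f32[], f32[]) tuple(gte0, add)
//   }
//
// Here IndicesToCopyForWhile would mark leaf index {1} but not {0} (the
// non-leaf tuple index is considered separately).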
// Add kCopy instructions around the given kWhile instruction to eliminate any
// possible live range interference of HLO values assuming a dependency-based
// ordering (HloDependencyOrdering). Copies are added conservatively. There
// likely are copies which are not strictly necessary, but they are removed
// later in the pass via RemoveUnnecessaryCopies.
//
// Elements (each ShapeIndex) in the loop state are considered independently.
// A copy is added to each element of the loop state which is modified in the
// while body. For each such element, a total of three kCopy instructions are
// added at the following locations:
//
//   (1) The init value is copied before the kWhile instruction. Before:
//
//           (Init)
//             |
//           kWhile
//             |
//            ...
//
//       After:
//
//           (Init)
//             |
//           kCopy
//             |
//           kWhile
//             |
//            ...
//
//       This copy is necessary in case the init value is simultaneously live
//       with the kWhile.
//
//   (2) Copies are added to the parameter and root of the while body
//       computation. Before:
//
//           kParameter
//               |
//              ...
//               |
//           (body root)
//
//       After:
//
//           kParameter
//               |
//             kCopy ----------+
//               |             |
//              ...           ctrl
//               |            edge
//           (body root)       |
//               |             |
//             kCopy <---------+
//
//       The root kCopy becomes the new root of the computation. Both copies
//       are necessary to avoid any potential interference between the
//       parameter value and the root value. The control edge prevents
//       potential interference between the copies themselves.
//
// If the loop state is a tuple then the above kCopy instructions are a deep
// copy constructed of kCopy, kGetTupleElement, and kTuple instructions as
// constructed by HloInstruction::DeepCopyInstruction.
Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis,
                         HloInstruction* xla_while) {
  VLOG(2) << "Adding copies for kWhile instruction " << xla_while->name();
  TF_RET_CHECK(xla_while->opcode() == HloOpcode::kWhile);

  ShapeTree<bool> indices_to_copy(xla_while->shape());
  if (!IndicesToCopyForWhile(alias_analysis.dataflow_analysis(), xla_while,
                             &indices_to_copy)) {
    VLOG(2) << "No copies necessary for kWhile instruction "
            << xla_while->name();
    return Status::OK();
  }

  VLOG(2) << "Adding copies for " << xla_while->name() << " at indices:";
  for (auto& pair : indices_to_copy) {
    if (pair.second) {
      VLOG(2) << "  " << pair.first;
    }
  }

  // Deep copy init.
  HloInstruction* while_init = xla_while->mutable_operand(0);
  TF_ASSIGN_OR_RETURN(
      HloInstruction * while_init_copy,
      xla_while->parent()->DeepCopyInstruction(while_init, &indices_to_copy));
  TF_RETURN_IF_ERROR(while_init->ReplaceUseWith(xla_while, while_init_copy));

  // Deep copy the parameter and the root. Extend a control edge from the copy
  // of the parameter value to the corresponding copy value of the root.
  HloComputation* body = xla_while->while_body();
  HloInstruction* param = body->parameter_instruction(0);
  HloInstruction* root = body->root_instruction();

  // If param is the root then all indices should have been passed through the
  // while body and we should have returned early above.
  TF_RET_CHECK(param != root);

  // Copy users before making a deep copy of the parameter as the deep copy
  // will create new users of the parameter (e.g., the GTE instructions of the
  // deep copy).
  std::vector<HloInstruction*> param_users = param->users();

  TF_ASSIGN_OR_RETURN(
      auto pair, DeepCopyAndAddControlEdges(param, root, indices_to_copy));

  HloInstruction* param_copy = pair.first;
  HloInstruction* root_copy = pair.second;

  for (HloInstruction* user : param_users) {
    TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, param_copy));
  }

  body->set_root_instruction(root_copy);

  return Status::OK();
}

// We add copies for all indices of the roots of the branch computations, in
// order to resolve interference. We later rely on RemoveUnnecessaryCopies to
// drop the unnecessary ones.
Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis,
                               HloInstruction* conditional) {
  VLOG(2) << "Adding copies for kConditional instruction "
          << conditional->name();
  TF_RET_CHECK(conditional->opcode() == HloOpcode::kConditional);

  for (HloComputation* computation : conditional->branch_computations()) {
    HloInstruction* root = computation->root_instruction();
    std::vector<HloInstruction*> users = root->users();
    TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
                        computation->DeepCopyInstruction(root));
    for (HloInstruction* user : users) {
      TF_RETURN_IF_ERROR(root->ReplaceUseWith(user, deep_copy));
    }
    computation->set_root_instruction(deep_copy);
  }
  return Status::OK();
}
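// For illustration, a hypothetical branch computation that returns its
// parameter directly:
//
//   true_branch {
//     ROOT param = f32[8] parameter(0)
//   }
//
// Without a copy the branch root aliases the branch operand, which may
// interfere with the conditional's result buffer. AddCopiesForConditional
// makes copy(param) the new root; RemoveUnnecessaryCopies later elides the
// copy if it turns out to be safe.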
// Conservatively adds copies before the root instruction of the entry
// computation and each aliased parameter to resolve interference of aliased
// input and output buffers. We later rely on RemoveUnnecessaryCopies to drop
// the unnecessary ones.
Status AddCopiesForAliasedInputOutputs(HloModule* module) {
  HloComputation* entry = module->entry_computation();
  HloInstruction* root = entry->root_instruction();

  ShapeTree<bool> output_indices_to_copy(root->shape());
  std::vector<absl::optional<ShapeTree<HloInstruction*>>> copied_parameters(
      entry->num_parameters());
  bool has_alias = false;
  for (auto* param : entry->parameter_instructions()) {
    bool param_has_alias = false;
    ShapeTree<bool> param_indices_to_copy(param->shape());

    module->input_output_alias_config().ForEachAlias(
        [&](const ShapeIndex& output_index,
            const HloInputOutputAliasConfig::Alias& alias) {
          if (alias.parameter_number == param->parameter_number()) {
            param_has_alias = true;
            *(param_indices_to_copy.mutable_element(alias.parameter_index)) =
                true;
            *(output_indices_to_copy.mutable_element(output_index)) = true;
          }
        });

    if (!param_has_alias) {
      continue;
    }

    TF_RET_CHECK(param->parameter_number() < entry->num_parameters());
    TF_RET_CHECK(!copied_parameters[param->parameter_number()]);

    has_alias = true;
    // Store a snapshot of users before DeepCopyInstruction, as
    // DeepCopyInstruction introduces new users of the instruction.
    std::vector<HloInstruction*> users = param->users();
    ShapeTree<HloInstruction*> param_copy_tree(param->shape(),
                                               /*init_value=*/nullptr);
    TF_ASSIGN_OR_RETURN(HloInstruction * copied,
                        entry->DeepCopyInstruction(
                            param, &param_indices_to_copy, &param_copy_tree));
    for (HloInstruction* user : users) {
      TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, copied));
    }

    copied_parameters[param->parameter_number()] = param_copy_tree;
  }

  if (!has_alias) {
    return Status::OK();
  }

  // Add copies before the root instruction.
  ShapeTree<HloInstruction*> output_copy_tree(root->shape(),
                                              /*init_value=*/nullptr);

  TF_ASSIGN_OR_RETURN(HloInstruction * root_copied,
                      root->parent()->DeepCopyInstruction(
                          root, &output_indices_to_copy, &output_copy_tree));

  // Add control dependencies between the input/output copies.
  TF_RETURN_IF_ERROR(module->input_output_alias_config().ForEachAliasWithStatus(
      [&](const ShapeIndex& output_index,
          const HloInputOutputAliasConfig::Alias& alias) -> Status {
        if (!copied_parameters[alias.parameter_number]) {
          return Status::OK();
        }
        HloInstruction* from =
            copied_parameters[alias.parameter_number]->element(
                alias.parameter_index);
        HloInstruction* to = output_copy_tree.element(output_index);

        TF_RET_CHECK(from != nullptr);
        TF_RET_CHECK(to != nullptr);
        TF_RETURN_IF_ERROR(from->AddControlDependencyTo(to));
        return Status::OK();
      }));

  entry->set_root_instruction(root_copied);

  return Status::OK();
}
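// For illustration, a hypothetical module whose alias config pairs output
// index {0} with parameter 0 (at parameter index {}):
//
//   p0 = f32[8] parameter(0)              // aliased with output index {0}
//   ...
//   ROOT t = (f32[8], f32[8]) tuple(a, b)
//
// The pass copies p0 (rerouting its users through the copy) and deep-copies
// the root at index {0}, then adds a control edge from the parameter copy to
// the output copy so the shared buffer is fully read before it is
// overwritten.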
// Removes any control dependencies to or from the given instruction.
Status StripControlDependenciesFrom(HloInstruction* instruction) {
  while (!instruction->control_successors().empty()) {
    TF_RETURN_IF_ERROR(instruction->RemoveControlDependencyTo(
        instruction->control_successors().front()));
  }

  while (!instruction->control_predecessors().empty()) {
    TF_RETURN_IF_ERROR(
        instruction->control_predecessors().front()->RemoveControlDependencyTo(
            instruction));
  }

  return Status::OK();
}

// Class which tracks the HLO values within each HLO buffer in the module
// during copy removal.
//
// The values are held in a linked list where there is one list for each
// buffer. Removing a copy instruction merges the values in the source buffer
// of the copy into the destination buffer of the copy. This class tracks
// these value lists as copies are removed from the graph (and value lists are
// merged).
//
// The CopyRemover object is initialized to match the state of
// HloAliasAnalysis. However, as copies are removed this state diverges. The
// values-to-buffer mapping is maintained outside of HloAliasAnalysis because
// a fully updatable alias analysis is very slow.
class CopyRemover {
 public:
  // The values held in a single HLO buffer are represented using a linked
  // list. The element type in this list is ValueNode.
  //
  // This linked list is hand-rolled to enable efficient splicing of lists
  // using only references to list elements without knowing which lists are
  // being spliced. std::list requires a reference to the list object to
  // splice.
  struct ValueNode {
    explicit ValueNode(const HloValue* v) : value(v) {}

    const HloValue* value;

    // The uses are maintained outside of HloValue::uses() because
    // HloValue::uses() is not updatable (a fully updatable dataflow analysis
    // is slow).
    std::vector<const HloUse*> uses;

    // next/prev elements in the linked list. The list is circularly linked so
    // these values are never null for elements in the list.
    ValueNode* prev = nullptr;
    ValueNode* next = nullptr;
  };
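  // For illustration, a hypothetical buffer holding three values (a while
  // init value, the while body parameter value, and the while value itself)
  // is represented by the circular list:
  //
  //   init -> body_param -> while
  //    ^                      |
  //    +----------------------+
  //
  // Eliding a copy splices two such lists together, which only requires
  // pointers to individual nodes; this is why the list is hand-rolled rather
  // than std::list.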
  CopyRemover(const HloModule& module, const HloAliasAnalysis& alias_analysis,
              const HloOrdering& ordering)
      : dataflow_(alias_analysis.dataflow_analysis()), ordering_(ordering) {
    // Construct a list for each HLO buffer in the alias analysis. Maintain a
    // map from HloValue to the respective list element representing that
    // value. The map is used to construct the copy info map below.
    absl::flat_hash_map<const HloValue*, ValueNode*> value_to_node;
    for (const HloBuffer& buffer : alias_analysis.buffers()) {
      // Verify values contained in the buffer are strictly ordered. This
      // should always be the case after adding copies to eliminate
      // interference. Specifically, the addition of the control flow edges
      // between copies added around aliased operations (kWhile) guarantees
      // this strict order.
      for (const HloValue* value_a : buffer.values()) {
        if (value_a->shape().IsToken()) {
          // Token values have no representation and cannot interfere.
          continue;
        }
        for (const HloValue* value_b : buffer.values()) {
          if (value_a != value_b) {
            DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b,
                                                     dataflow_) ||
                   ordering_.LiveRangeStrictlyBefore(*value_b, *value_a,
                                                     dataflow_))
                << value_a->ToShortString() << " and "
                << value_b->ToShortString() << " are not ordered";
          }
        }
      }

      std::vector<const HloValue*> values = buffer.values();
      absl::c_sort(values, [this](const HloValue* a, const HloValue* b) {
        return ordering_.IsDefinedBefore(*a, *b);
      });

      // Create a list containing all of the values in the buffer.
      AddValueList(values, &value_to_node);
    }

    // Create copy_map_ which contains the source and destination values of
    // all copies.
    CreateCopyMap(module, value_to_node);

    XLA_VLOG_LINES(3, ToString());
    TF_DCHECK_OK(Verify());
  }

  // Add a list containing the given values to CopyRemover. This represents
  // the values contained in a single buffer. For each value in 'values' an
  // entry is created in value_to_node which indicates the respective
  // ValueNode representing that value.
  void AddValueList(
      absl::Span<const HloValue* const> values,
      absl::flat_hash_map<const HloValue*, ValueNode*>* value_to_node) {
    ValueNode* tail = nullptr;
    ValueNode* head = nullptr;
    for (const HloValue* value : values) {
      auto new_node = new ValueNode(value);
      (*value_to_node)[value] = new_node;

      // Copy the HLO value's uses into the ValueNode for the value. These
      // uses in ValueNode are updated as copies are removed.
      new_node->uses.reserve(value->uses().size());
      for (const HloUse& use : value->uses()) {
        new_node->uses.push_back(&use);
      }

      // Connect the new node into the linked list.
      if (tail == nullptr) {
        head = new_node;
      } else {
        tail->next = new_node;
        new_node->prev = tail;
      }
      tail = new_node;
    }

    // The linked list is circular so connect the head and tail.
    tail->next = head;
    head->prev = tail;
    value_lists_.insert(head);
  }

  // This method fills in copy_map_, which indicates which nodes in the value
  // lists correspond to the source and destination values of kCopy
  // instructions. value_to_node should map each HloValue to its respective
  // ValueNode.
  void CreateCopyMap(
      const HloModule& module,
      const absl::flat_hash_map<const HloValue*, ValueNode*>& value_to_node) {
    for (HloComputation* computation : module.computations()) {
      for (HloInstruction* instruction : computation->instructions()) {
        // Add copies with unambiguous source values to the map. Copies with
        // ambiguous sources are not removable.
        if (instruction->opcode() == HloOpcode::kCopy) {
          const HloValueSet& src_value_set =
              dataflow_.GetValueSet(instruction->operand(0));
          if (src_value_set.values().size() == 1) {
            CopyNodes& copy_node = copy_map_[instruction];
            copy_node.dest =
                value_to_node.at(&dataflow_.GetUniqueValueAt(instruction));
            copy_node.src = value_to_node.at(&src_value_set.GetUniqueValue());
          }
        }
      }
    }
  }

  ~CopyRemover() {
    for (const ValueNode* head : value_lists_) {
      const ValueNode* p = head;
      do {
        const ValueNode* tmp = p->next;
        delete p;
        p = tmp;
      } while (p != head);
    }
  }

  // Verify invariants within the linked lists.
  Status Verify() const {
    for (const ValueNode* head : value_lists_) {
      const ValueNode* p = head;
      do {
        // Verify links between elements are consistent.
        TF_RET_CHECK(p->prev->next == p);
        TF_RET_CHECK(p->next->prev == p);

        const HloInstruction* def = p->value->defining_instruction();
        if (def->opcode() == HloOpcode::kCopy && ContainsKey(copy_map_, def)) {
          TF_RET_CHECK(copy_map_.at(def).dest == p);
        }
        for (const HloUse* use : p->uses) {
          if (use->instruction->opcode() == HloOpcode::kCopy &&
              ContainsKey(copy_map_, use->instruction)) {
            TF_RET_CHECK(copy_map_.at(use->instruction).src == p);
          }
        }

        p = p->next;
      } while (p != head);
    }
    return Status::OK();
  }
  // Try to elide the given copy. Elision of a copy is possible only if no
  // live range interference is introduced by the copy's elimination. If
  // elision is possible, then the internal state (value lists) is updated,
  // and true is returned. Returns false otherwise.
  bool TryElideCopy(const HloInstruction* copy) {
    VLOG(2) << "Trying to remove " << copy->name();

    if (!ContainsKey(copy_map_, copy)) {
      VLOG(2) << copy->name() << " is not removable";
      return false;
    }
    if (!ShapeUtil::Equal(copy->shape(), copy->operand(0)->shape())) {
      VLOG(2) << copy->name() << " is not removable (shape mismatch)";
      return false;
    }
    const CopyNodes& copy_node = copy_map_.at(copy);
    ValueNode* src = copy_node.src;
    ValueNode* dest = copy_node.dest;
    DCHECK(src != nullptr);
    DCHECK(dest != nullptr);

    auto is_live_range_before = [this](const ValueNode& a,
                                       const ValueNode& b) {
      VLOG(3) << "Checking live range of " << *a.value << " WRT " << *b.value;
      if (LiveRangeBefore(a, b)) {
        VLOG(2) << "  Live range of " << a.value->ToShortString()
                << " is before " << b.value->ToShortString();
        return true;
      } else {
        VLOG(2) << "  Live range of " << a.value->ToShortString()
                << " is not before " << b.value->ToShortString();
        return false;
      }
    };

    VLOG(3) << copy->name() << " copies value " << src->value->ToShortString();
    VLOG(3) << "Source buffer values: " << ValueListToString(src);
    VLOG(3) << "Dest buffer values: " << ValueListToString(dest);

    // A kCopy instruction copies an HLO value from a source buffer and
    // defines an HLO value in a destination buffer. Most generally, the
    // source and destination buffers may each hold more than one value at
    // different points in the computation so we define the following:
    //
    //   Values in source buffer:      {s_0, ..., s_n}
    //   Values in destination buffer: {d_0, ..., d_m}
    //
    // A kCopy instruction between these buffers copies a value s_x in the
    // source buffer and defines a value d_y in the destination buffer. The
    // elision of a copy merges the source and destination buffers together,
    // so the lists of values for the source and destination buffers are
    // merged.
    //
    // We handle two different cases for copy elision:
    //
    //  (1) the kCopy defines the first value in the destination buffer (d_0).
    //
    //  (2) the kCopy copies the last value in the source buffer (s_n).
    //
    // For the remaining case where the kCopy copies a not-last value from the
    // source buffer to a not-first value of the destination buffer, the kCopy
    // instruction cannot be removed. This case is generated, for example, if
    // the kCopy copies a while body parameter of the loop state at one tuple
    // index to a different tuple index in the while body root. Removal of the
    // copy necessarily results in live range interference of values in the
    // loop state at the two different tuple indices.
    //
    // We can only perform copy elision if the resulting merged values have
    // totally ordered live ranges; otherwise the merged buffer would have
    // live range interference.
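    // As a concrete (hypothetical) illustration of case (1): if the source
    // buffer holds {s_0, s_1} and the copy of s_0 defines d_0 in a
    // destination buffer holding {d_0, d_1}, elision merges the lists into
    // {s_0, d_1, s_1}. This is legal only if the live range of d_1 ends
    // before that of s_1 begins, and all uses of s_0 come before the
    // definition of d_1; both conditions are checked below.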
    if (src->next == dest) {
      // In the process of eliding copies, it's possible for a copy to have
      // the same source and destination buffer. In this case, the copy can be
      // safely removed.
      VLOG(2) << copy->name() << " source and destination buffers are same.";
    } else if (IsHead(*dest)) {
      // The copy copies an arbitrary value in the source buffer (call it s_x)
      // and defines d_0, the first value in the destination buffer. After
      // merging, the values in the combined buffer must be strictly ordered
      // as follows** to elide the copy:
      //
      //   {s_0, ..., s_x, d_1, ..., d_m, s_{x+1}, ..., s_n}
      //
      // Removing the copy eliminates d_0, and uses of d_0 become uses of
      // s_x. In the above ordering, the live range of d_m must be ordered
      // before the live range of s_{x+1} and the definition and all uses of
      // s_x must be ordered before the definition of d_1. These conditions
      // are checked below prior to elision.
      //
      // ** Technically it might be possible to have a non-interfering
      //    non-trivial interleaving of the values of the source and
      //    destination buffers in the resulting order. However, this case is
      //    slow and complicated to check and likely not worth it. So instead
      //    we simply check for the case where *all* values of the destination
      //    buffer (d_1 through d_m) are spliced into the point where the copy
      //    used to be.
      VLOG(2) << copy->name() << " defines the first value in its buffer";
      ValueNode* next_dest = Next(*dest);
      if (next_dest != nullptr) {
        // Live range of 'from' value (s_x) must be before 'next_dest' (d_1).
        if (!is_live_range_before(*src, *next_dest)) {
          return false;
        }
      }
      ValueNode* next_src = Next(*src);

      if (next_src != nullptr) {
        // Live range of 'last_dest' (d_m) must be before 'next_src' (s_{x+1}).
        ValueNode* last_dest = dest->prev;
        DCHECK(IsTail(*last_dest));
        if (!is_live_range_before(*last_dest, *next_src)) {
          return false;
        }
      }

      // Splice in destination buffer values list right after 'src'.
      SpliceAfter(dest, src);
    } else if (IsTail(*src)) {
      // The copy copies the last value in the source buffer, s_n, and defines
      // an arbitrary value in the destination buffer, d_y. After merging, the
      // values in the combined buffer must be strictly ordered as follows**
      // to elide the copy:
      //
      //   {d_0, ..., d_{y-1}, s_0, ..., s_n, d_{y+1}, ..., d_m}
      //
      // Removing the copy eliminates d_y, and uses of d_y become uses of
      // s_n. To enforce the above order, the live range of d_{y-1} must be
      // before the live range of s_0, and the live range of s_n must be
      // before the live range of d_{y+1}.
      //
      // ** See comment above in the code handling case (1).
      VLOG(2) << copy->name() << " copies the last value ("
              << src->value->ToShortString() << ") in its buffer";

      ValueNode* prev_dest = Prev(*dest);
      // nullptr condition handled above in the first 'if' case.
      DCHECK(prev_dest != nullptr);
      ValueNode* first_src = src->next;
      DCHECK(IsHead(*first_src));
      if (!is_live_range_before(*prev_dest, *first_src)) {
        // Live range of value d_{y-1} is not before s_0.
        return false;
      }
      ValueNode* next_dest = Next(*dest);
      if (next_dest != nullptr) {
        if (!is_live_range_before(*src, *next_dest)) {
          // Live range of value s_n is not before d_{y+1}.
          return false;
        }
      }

      // Splice source buffer values list right after 'prev_dest'.
      SpliceAfter(first_src, prev_dest);
    } else {
      VLOG(2) << copy->name()
              << " copies value in middle of source buffer to value in middle "
                 "of destination buffer";
      return false;
    }

    RemoveCopyValue(dest);

    XLA_VLOG_LINES(4, ToString());
    TF_DCHECK_OK(Verify());

    return true;
  }

  // Delete the given ValueNode associated with an elided kCopy instruction.
  // This should be called after splicing the value lists of the source and
  // destination buffers together.
  void RemoveCopyValue(ValueNode* copy_value_node) {
    CHECK_EQ(copy_value_node->value->defining_instruction()->opcode(),
             HloOpcode::kCopy);
    ValueNode* operand_node = copy_value_node->prev;
    CHECK(operand_node != copy_value_node);

    VLOG(2) << "Removing copy " << operand_node->value->ToShortString()
            << " => " << copy_value_node->value->ToShortString();

    // Splice out the copy value node.
    operand_node->next = copy_value_node->next;
    copy_value_node->next->prev = operand_node;

    // Patch up uses. Remove use of copy from operand_node uses.
    auto it = absl::c_find_if(
        operand_node->uses, [copy_value_node](const HloUse* use) {
          return use->instruction ==
                 copy_value_node->value->defining_instruction();
        });
    CHECK(it != operand_node->uses.end());
    operand_node->uses.erase(it);

    // If the elided copy has any uses which are themselves kCopy instructions
    // then patch up the copy info to reflect that this kCopy instruction has
    // a different operand (the operand of the elided copy).
    for (const HloUse* copy_use : copy_value_node->uses) {
      operand_node->uses.push_back(copy_use);
      if (copy_use->instruction->opcode() == HloOpcode::kCopy &&
          ContainsKey(copy_map_, copy_use->instruction)) {
        copy_map_.at(copy_use->instruction).src = operand_node;
      }
    }

    // Delete the copy info and the value node.
    copy_map_.erase(copy_value_node->value->defining_instruction());
    delete copy_value_node;
  }

  // Returns true if the live range of given value 'a' is before the live
  // range of 'b'.
  //
  // We cannot use LiveRangeStrictlyBefore because HloValue::uses() is not
  // updated as copies are removed.
  bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) {
    if (a.uses.empty()) {
      VLOG(2) << "Empty uses for " << *a.value;
      return ordering_.IsDefinedBefore(*a.value, *b.value);
    }
    for (const HloUse* use : a.uses) {
      VLOG(2) << "Checking use " << *use << " against " << *b.value;
      if (!ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_)) {
        VLOG(2) << "Use " << *use << " is NOT before " << *b.value;
        return false;
      }
      VLOG(2) << "Use " << *use << " is before " << *b.value;
    }
    return true;
  }

  // Returns whether 'node' is the last node in its list.
  bool IsTail(const ValueNode& node) const {
    return ContainsKey(value_lists_, node.next);
  }

  // Returns whether 'node' is the first node in its list.
  bool IsHead(const ValueNode& node) const {
    return ContainsKey(value_lists_, &node);
  }

  // Returns the next node in the list after 'node'. If 'node' is the tail,
  // then nullptr is returned.
  ValueNode* Next(const ValueNode& node) const {
    if (IsTail(node)) {
      return nullptr;
    } else {
      return node.next;
    }
  }
  // Returns the previous node in the list before 'node'. If 'node' is the
  // head, then nullptr is returned.
  ValueNode* Prev(const ValueNode& node) const {
    if (IsHead(node)) {
      return nullptr;
    } else {
      return node.prev;
    }
  }

  // Splices the entire linked list with 'head' as its head right after the
  // node 'insert_after' in another linked list.
  void SpliceAfter(ValueNode* head, ValueNode* insert_after) {
    DCHECK(IsHead(*head));
    value_lists_.erase(head);

    ValueNode* tail = head->prev;
    tail->next = insert_after->next;
    insert_after->next->prev = tail;

    insert_after->next = head;
    head->prev = insert_after;
  }
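  // For illustration, given a list {x, y} and a second list {a, b, c} whose
  // head is 'a', SpliceAfter(a, x) yields the single list {x, a, b, c, y}
  // and removes 'a' from value_lists_ since it is no longer a list head.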
  string ValueListToString(const ValueNode* element) {
    const ValueNode* head = element;
    while (!IsHead(*head)) {
      head = Prev(*head);
    }
    std::vector<const HloValue*> values;
    for (const ValueNode* p = head; p != nullptr; p = Next(*p)) {
      values.push_back(p->value);
    }
    return absl::StrCat("{",
                        absl::StrJoin(values, ", ",
                                      [](string* s, const HloValue* value) {
                                        StrAppend(s, value->ToShortString());
                                      }),
                        "}");
  }

  string ToString() const {
    string out = absl::StrCat("CopyRemover:\n");
    StrAppend(&out, "  Def-use chains in each buffer:\n");
    for (const ValueNode* head : value_lists_) {
      StrAppend(&out, "    Buffer defined by ", head->value->ToShortString(),
                ":\n");
      const ValueNode* p = head;
      do {
        StrAppend(&out, "      ", p->value->ToShortString(), ", uses: ",
                  absl::StrJoin(p->uses, "; ",
                                [](string* s, const HloUse* use) {
                                  StrAppend(s, use->ToString());
                                }),
                  "\n");

        p = p->next;
      } while (p != head);
    }
    StrAppend(&out, "  Potentially removable copies:\n");
    for (const auto& pair : copy_map_) {
      const HloInstruction* copy = pair.first;
      const CopyNodes& copy_info = pair.second;

      StrAppend(&out, "    ", copy->name(), " : ",
                copy_info.src->value->ToShortString(), " => ",
                copy_info.dest->value->ToShortString(), "\n");
    }
    return out;
  }

 private:
  const HloDataflowAnalysis& dataflow_;
  const HloOrdering& ordering_;

  // The heads of all the value lists. Each value list represents the HLO
  // values contained in a particular HLO buffer. The values in the list are
  // in dependency order.
  absl::flat_hash_set<const ValueNode*> value_lists_;

  // Copy removal requires fast access to the value list elements
  // corresponding to the source and destination values of the kCopy
  // instruction. This data structure holds pointers to these elements for
  // each kCopy instruction in the graph.
  struct CopyNodes {
    // The source and destination values of the kCopy instruction.
    ValueNode* src = nullptr;
    ValueNode* dest = nullptr;
  };
  absl::flat_hash_map<const HloInstruction*, CopyNodes> copy_map_;
};

}  // namespace

// Add kCopy instructions to the given module to guarantee there is no
// live-range interference. Generally interference can only occur around
// kWhile instructions which have update-in-place semantics.
Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module, fusion_can_share_buffer_));

  for (HloComputation* computation : module->computations()) {
    for (HloInstruction* instruction : computation->instructions()) {
      if (instruction->opcode() == HloOpcode::kWhile) {
        TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction));
      } else if (instruction->opcode() == HloOpcode::kConditional) {
        TF_RETURN_IF_ERROR(
            AddCopiesForConditional(*alias_analysis, instruction));
      }
    }
  }

  TF_RETURN_IF_ERROR(AddCopiesForAliasedInputOutputs(module));
  return Status::OK();
}

Status CopyInsertion::AddSpecialCaseCopies(HloModule* module) {
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
  return AddSpecialCaseCopies(*call_graph, module);
}

Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
                                           HloModule* module) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module, fusion_can_share_buffer_));

  // Identify which shape indices of which instructions need to be copied.
  // Store these results in 'instructions_to_copy'.
  HloInstructionMap<ShapeTree<bool>> instructions_to_copy;
  auto add_index_to_copy = [&instructions_to_copy](HloInstruction* instruction,
                                                   const ShapeIndex& index) {
    auto it = instructions_to_copy.find(instruction);
    if (it == instructions_to_copy.end()) {
      auto it_added = instructions_to_copy.emplace(
          std::piecewise_construct, std::forward_as_tuple(instruction),
          std::forward_as_tuple(instruction->shape(), /*init_value=*/false));
      it = it_added.first;
    }
    *it->second.mutable_element(index) = true;
  };

  // Iterate through values of all constants and entry parameters. These
  // values are special because they are held in read-only buffers. If any of
  // these values share a buffer with other values (for example, the init
  // value of a while is a constant) then copy the value at its definition and
  // replace all its uses with the copy.
  for (const HloValue* value : alias_analysis->dataflow_analysis().values()) {
    if (ValueIsReadOnly(*value) &&
        alias_analysis->GetBufferContainingValue(*value).values().size() > 1) {
      VLOG(2) << "Value " << value->ToShortString()
              << " is read-only, but its buffer contains more than one value. "
                 "Copying.";
      add_index_to_copy(value->defining_instruction(),
                        value->defining_index());
    }
  }

  // Identify copies which must be added at root instructions.
  for (HloComputation* computation : module->computations()) {
    const CallGraphNode& node = call_graph.GetNode(computation);
    if (node.context() == CallContext::kParallel) {
      continue;
    }
    TF_RET_CHECK(node.context() == CallContext::kSequential);

    SpecialCaseCopyPolicy policy =
        GetSpecialCaseCopyPolicy(node, module, computation);
    HloInstruction* root = computation->root_instruction();

    // Mark nondistinct/ambiguous indices.
    absl::flat_hash_set<const HloBuffer*> seen;
    ShapeUtil::ForEachSubshape(
        root->shape(),
        [&](const Shape& /*subshape*/, const ShapeIndex& index) {
          std::vector<const HloBuffer*> buffers_at_index =
              alias_analysis->ComputeBuffersAt(root, index);
          bool buffer_seen_before = false;
          for (const HloBuffer* buffer : buffers_at_index) {
            buffer_seen_before |= !seen.insert(buffer).second;
          }
          if (buffers_at_index.size() > 1 ||
              (buffer_seen_before && policy.copy_root_replicated_buffers)) {
            VLOG(2) << "Index " << index << " of computation "
                    << computation->name() << " (" << root->name()
                    << ") has ambiguous or non-distinct buffer. Copying.";
            add_index_to_copy(root, index);
          }
        });

    for (const auto& pair :
         alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) {
      const ShapeIndex& index = pair.first;
      const HloValueSet& value_set = pair.second;
      for (const HloValue* value : value_set.values()) {
        if (ShouldCopyRootValue(*value, policy)) {
          VLOG(2) << "Root (" << root->name() << ") of computation ("
                  << computation->name()
                  << ") has constant or parameter value at index " << index
                  << ". Copying.";
          add_index_to_copy(root, index);
        }
      }
    }
  }

  // Add copy instructions indicated in 'instructions_to_copy' to the module.
  for (const auto& pair : instructions_to_copy) {
    HloInstruction* instruction = pair.first;
    const ShapeTree<bool>& indices_to_copy = pair.second;

    ShapeTree<HloInstruction*> copies_added(indices_to_copy.shape());
    std::vector<HloInstruction*> users = instruction->users();
    TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
                        instruction->parent()->DeepCopyInstruction(
                            instruction, &indices_to_copy, &copies_added));
    for (HloInstruction* user : users) {
      TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy));
    }
    if (instruction == instruction->parent()->root_instruction()) {
      instruction->parent()->set_root_instruction(deep_copy);
    }
  }
  return Status::OK();
}
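// For illustration, a hypothetical entry computation returning the same
// parameter at two tuple indices:
//
//   ENTRY entry {
//     p0 = f32[8] parameter(0)
//     ROOT t = (f32[8], f32[8]) tuple(p0, p0)
//   }
//
// AddSpecialCaseCopies inserts copies at both root indices: the value is
// read-only (a parameter live out of the entry computation) and its buffer
// appears at more than one output index.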
Status CopyInsertion::VerifyNoLiveRangeInterference(const HloOrdering& ordering,
                                                    HloModule* module) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module, fusion_can_share_buffer_));
  TF_RET_CHECK(!alias_analysis->HasLiveRangeInterference(ordering));
  return Status::OK();
}

Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
                                              HloModule* module) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module, fusion_can_share_buffer_));

  CopyRemover copy_remover(*module, *alias_analysis, ordering);
  if (VLOG_IS_ON(3)) {
    LOG(INFO) << "Removing unnecessary copies in " << module->name();
    LOG(INFO) << "Buffer values, in dependency order: ";
    for (const HloBuffer& buffer : alias_analysis->buffers()) {
      LOG(INFO) << "    HloBuffer " << buffer.id();
    }
  }

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
  for (HloComputation* computation : module->computations()) {
    for (HloInstruction* instruction : computation->instructions()) {
      if (instruction->opcode() == HloOpcode::kCopy &&
          copy_remover.TryElideCopy(instruction)) {
        TF_RETURN_IF_ERROR(StripControlDependenciesFrom(instruction));
        TF_RETURN_IF_ERROR(
            instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
      }
    }
  }
  return Status::OK();
}

StatusOr<bool> CopyInsertion::Run(HloModule* module) {
  // Copy insertion is performed in three steps:
  //
  // (1) Add copies conservatively to guarantee that there is no live-range
  //     interference. This is done simplistically and usually results in more
  //     copies than is strictly necessary.
  //
  // (2) Using a more fine-grained analysis, remove as many of the copies
  //     added in (1) as possible while ensuring no live-range interference.
  //
  // (3) Add copies to resolve issues not related to live range interference
  //     such as parameters and constants live out of the entry computation.
  //
  // We add copies then remove them (step (1) then (2)) rather than simply
  // adding only the copies that are necessary because, in general, it is
  // difficult to figure out the minimal set of copies to add once there is
  // interference. On the other hand, it is easy to determine if removing a
  // copy will introduce interference.
  //
  // The final copy insertion in (3) is done separately to simplify the
  // implementation of copy removal in (2), which is the most complicated part
  // of the pass. As is, copy removal only has to reason about live range
  // interference. If all copies were added in step (1) then copy removal
  // would also have to reason about things like constants and parameters
  // live out of the computation.
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
  if (!call_graph->IsFlattened()) {
    return FailedPrecondition(
        "Call graph must be flattened before copy insertion.");
  }

  int64 num_existing_copies = 0;
  if (VLOG_IS_ON(1)) {
    for (HloComputation* computation : module->computations()) {
      for (HloInstruction* instruction : computation->instructions()) {
        if (instruction->opcode() == HloOpcode::kCopy) {
          ++num_existing_copies;
        }
      }
    }
  }

  TF_RETURN_IF_ERROR(AddCopiesToResolveInterference(module));

  // Simplify the tuple structures introduced by the deep copies. This should
  // be done before removing copies (RemoveUnnecessaryCopies) because tuple
  // simplification changes dependencies in the graph which changes live range
  // interference in the graph. Also run DCE to remove the dead Tuple/GTE
  // instructions introduced by tuple simplification.
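  // For example (hypothetical): a deep copy of a two-element tuple where only
  // index 0 needed copying produces
  //
  //   tuple(copy(gte(x, 0)), gte(x, 1))
  //
  // and the pass-through gte/tuple plumbing at index 1 is the kind of dead
  // structure that TupleSimplifier and HloDCE clean up.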
  TupleSimplifier tuple_simplifier;
  HloDCE dce;
  TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
  TF_RETURN_IF_ERROR(dce.Run(module).status());
  DumpHloModuleDuringPassIfEnabled(
      name(), "after adding copies to resolve interference", *module);

  DependencyHloOrdering dep_ordering(module);
  TF_DCHECK_OK(VerifyNoLiveRangeInterference(dep_ordering, module));

  TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(dep_ordering, module));
  DumpHloModuleDuringPassIfEnabled(name(), "after removing unnecessary copies",
                                   *module);

  TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));
  DumpHloModuleDuringPassIfEnabled(name(), "after adding special-case copies",
                                   *module);

  TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
  TF_RETURN_IF_ERROR(dce.Run(module).status());
  TF_DCHECK_OK(
      VerifyNoLiveRangeInterference(DependencyHloOrdering(module), module));

  if (VLOG_IS_ON(1)) {
    int64 num_total_copies = 0;
    for (HloComputation* computation : module->computations()) {
      for (HloInstruction* instruction : computation->instructions()) {
        if (instruction->opcode() == HloOpcode::kCopy) {
          num_total_copies++;
        }
      }
    }
    VLOG(1) << "Num copies before copy-insertion: " << num_existing_copies;
    VLOG(1) << "Num copies after copy-insertion: " << num_total_copies;
  }

  return true;
}

namespace {

bool IsWhileBody(const HloComputation* computation,
                 const CallGraph& call_graph) {
  const CallGraphNode& node = call_graph.GetNode(computation);

  if (node.context() == CallContext::kSequential &&
      !node.caller_callsites().empty()) {
    // The call graph should be flattened, so computations with a sequential
    // context can have at most one caller.
    CHECK_EQ(node.caller_callsites().size(), 1);
    const HloInstruction* calling_instruction =
        node.caller_callsites()[0].instruction();
    if (calling_instruction->opcode() == HloOpcode::kWhile &&
        calling_instruction->while_body() == node.computation()) {
      return true;
    }
  }
  return false;
}

}  // namespace

/* static */ StatusOr<bool> CopyInsertion::AddCopiesForBufferAssignment(
    HloModule* module) {
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow,
                      HloDataflowAnalysis::Run(*module));

  bool changed = false;

  // If a buffer live out of a computation is a constant, a parameter, or not
  // defined in the computation, then copy it to account for the limited
  // computation-scoped analysis in buffer assignment. An exception to this
  // rule is the while body, which is handled properly without copies.
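  // For example (hypothetical): a computation invoked via kCall whose root is
  // a kConstant gets its root copied here, so that buffer assignment's
  // per-computation analysis can give the constant and the call result
  // distinct allocations.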
  for (HloComputation* computation : module->computations()) {
    if (computation == module->entry_computation() ||
        IsWhileBody(computation, *call_graph)) {
      continue;
    }

    HloInstruction* root = computation->root_instruction();
    ShapeTree<bool> indices_to_copy(root->shape(), /*init_value=*/false);
    bool copy_root = false;
    for (const auto& pair : dataflow->GetInstructionValueSet(root)) {
      const ShapeIndex& index = pair.first;
      const HloValueSet& value_set = pair.second;
      for (const HloValue* value : value_set.values()) {
        HloInstruction* def = value->defining_instruction();
        if (def->parent() != computation ||
            def->opcode() == HloOpcode::kConstant ||
            def->opcode() == HloOpcode::kParameter) {
          *indices_to_copy.mutable_element(index) = true;
          copy_root = true;
        }
      }
    }
    if (copy_root) {
      TF_ASSIGN_OR_RETURN(
          HloInstruction * root_copy,
          computation->DeepCopyInstruction(root, &indices_to_copy));
      computation->set_root_instruction(root_copy);
      changed = true;
    }
  }

  TupleSimplifier tuple_simplifier;
  HloDCE dce;
  TF_ASSIGN_OR_RETURN(bool tuple_simplifier_changed,
                      tuple_simplifier.Run(module));
  TF_ASSIGN_OR_RETURN(bool dce_changed, dce.Run(module));

  return changed || tuple_simplifier_changed || dce_changed;
}

}  // namespace xla