/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/copy_insertion.h"

#include <algorithm>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/liveness_util.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

using ::tensorflow::str_util::Join;
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;

namespace {

bool IsEntryParameterValue(const HloValue& value) {
  const HloComputation* computation = value.defining_instruction()->parent();
  return value.defining_instruction()->opcode() == HloOpcode::kParameter &&
         computation == computation->parent()->entry_computation();
}

bool IsConstantValue(const HloValue& value) {
  return value.defining_instruction()->opcode() == HloOpcode::kConstant;
}

bool ValueIsReadOnly(const HloValue& value) {
  return IsConstantValue(value) || IsEntryParameterValue(value);
}

// Deep copy the given instructions 'from' and 'to' at the ShapeIndexes given
// in 'indices_to_copy'. Add control edges from the respective kCopy
// instructions in the deep copy of 'from' to the respective kCopy
// instructions in the deep copy of 'to'.
//
// Requirements: 'from' and 'to' must have compatible shapes.
//
// For example, suppose 'from' and 'to' are two-element tuples where index 0
// is the only index to copy. Prior to deep-copying we have:
//
//       'from'
//          |
//         ...
//          |
//        'to'
//
// DeepCopyAndAddControlEdges produces:
//
//       'from'
//        /   \
//      GTE   GTE
//       |     |
//     Copy    |
//    /   \   /
//   |    Tuple
//   |      |
//  ctrl   ...
//  edge    |
//   |      |
//   |    'to'
//   |    /   \
//   |  GTE   GTE
//    \  |     |
//     Copy    |
//        \   /
//        Tuple
//
StatusOr<std::pair<HloInstruction*, HloInstruction*>>
DeepCopyAndAddControlEdges(HloInstruction* from, HloInstruction* to,
                           const ShapeTree<bool>& indices_to_copy) {
  DCHECK(ShapeUtil::Compatible(from->shape(), to->shape()));
  // to/from_copy_tree hold the kCopy instructions produced by the deep
  // copies. Elements which are not copied (indices_to_copy.element(index) ==
  // false) have nullptr at that index.
  ShapeTree<HloInstruction*> from_copy_tree(from->shape(),
                                            /*init_value=*/nullptr);
  TF_ASSIGN_OR_RETURN(HloInstruction * from_deep_copy,
                      from->parent()->DeepCopyInstruction(
                          from, &indices_to_copy, &from_copy_tree));

  ShapeTree<HloInstruction*> to_copy_tree(to->shape(), /*init_value=*/nullptr);
  TF_ASSIGN_OR_RETURN(
      HloInstruction * to_deep_copy,
      to->parent()->DeepCopyInstruction(to, &indices_to_copy, &to_copy_tree));

  // Add control edges between the respective kCopy instructions.
  for (const auto& pair : from_copy_tree) {
    const ShapeIndex& index = pair.first;
    HloInstruction* from_copy = pair.second;
    HloInstruction* to_copy = to_copy_tree.element(index);
    if (from_copy == nullptr) {
      TF_RET_CHECK(to_copy == nullptr);
      continue;
    }
    TF_RET_CHECK(to_copy != nullptr);
    TF_RETURN_IF_ERROR(from_copy->AddControlDependencyTo(to_copy));
  }

  return std::make_pair(from_deep_copy, to_deep_copy);
}

// Compute the indices of the loop state which need copies in order to avoid
// live range interference. Generally, an element in the loop state does not
// need to be copied if the element is passed transparently through the body.
//
// Returns whether any indices need to be copied.
bool IndicesToCopyForWhile(const HloDataflowAnalysis& dataflow,
                           const HloInstruction* xla_while,
                           ShapeTree<bool>* indices_to_copy) {
  DCHECK(ShapeUtil::Compatible(indices_to_copy->shape(), xla_while->shape()));

  bool any_copies = false;
  const HloInstruction* init = xla_while->operand(0);
  for (auto& pair : *indices_to_copy) {
    const ShapeIndex& index = pair.first;
    bool& should_copy = pair.second;
    // If there is any ambiguity, then the loop state must be copied.
    if (dataflow.GetValueSet(init, index).values().size() > 1 ||
        dataflow.GetValueSet(xla_while, index).values().size() > 1) {
      should_copy = true;
    } else {
      // If the output of the while instruction is not the same as the init
      // value of the while, then this element is not passed through the body
      // transparently and must be copied.
      should_copy = dataflow.GetUniqueValueAt(xla_while, index) !=
                    dataflow.GetUniqueValueAt(init, index);
    }
    any_copies |= should_copy;
  }
  return any_copies;
}
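
// For example, if a while loop's state is a two-element tuple whose body
// updates element {0} but forwards element {1} from the parameter to the
// root unchanged, IndicesToCopyForWhile above marks element {0} for copying
// and leaves element {1} as-is.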

// Add kCopy instructions around the given kWhile instruction to eliminate any
// possible live range interference of HLO values assuming a dependency-based
// ordering (HloDependencyOrdering). Copies are added conservatively. There
// likely are copies which are not strictly necessary, but they are removed
// later in the pass via CopyRemover.
//
// Elements (each ShapeIndex) in the loop state are considered independently.
// A copy is added to each element of the loop state which is modified in the
// while body. For each such element, a total of three kCopy instructions are
// added at the following locations:
//
//   (1) The init value is copied before the kWhile instruction. Before:
//
//           (Init)
//              |
//           kWhile
//              |
//             ...
//
//       After:
//
//           (Init)
//              |
//            kCopy
//              |
//           kWhile
//              |
//             ...
//
//       This copy is necessary in case the init value is simultaneously live
//       with the kWhile.
//
//   (2) Copies are added to the parameter and root of the while body
//       computation. Before:
//
//           kParameter
//               |
//              ...
//               |
//          (body root)
//
//       After:
//
//           kParameter
//               |
//             kCopy ----------+
//               |             |
//              ...          ctrl
//               |           edge
//          (body root)        |
//               |             |
//             kCopy <---------+
//
//       The root kCopy becomes the new root of the computation. Both copies
//       are necessary to avoid any potential interference between the
//       parameter value and the root value. The control edge prevents
//       potential interference between the copies themselves.
//
// If the loop state is a tuple then the above kCopy instructions are a deep
// copy constructed of kCopy, kGetTupleElement, and kTuple instructions as
// constructed by HloInstruction::DeepCopyInstruction.
Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis,
                         HloInstruction* xla_while) {
  VLOG(2) << "Adding copies for kWhile instruction " << xla_while->name();
  TF_RET_CHECK(xla_while->opcode() == HloOpcode::kWhile);

  ShapeTree<bool> indices_to_copy(xla_while->shape());
  if (!IndicesToCopyForWhile(alias_analysis.dataflow_analysis(), xla_while,
                             &indices_to_copy)) {
    VLOG(2) << "No copies necessary for kWhile instruction "
            << xla_while->name();
    return Status::OK();
  }

  VLOG(2) << "Adding copies for " << xla_while->name() << " at indices:";
  for (auto& pair : indices_to_copy) {
    if (pair.second) {
      VLOG(2) << "  " << pair.first;
    }
  }

  // Deep copy init.
  HloInstruction* while_init = xla_while->mutable_operand(0);
  TF_ASSIGN_OR_RETURN(
      HloInstruction * while_init_copy,
      xla_while->parent()->DeepCopyInstruction(while_init, &indices_to_copy));
  TF_RETURN_IF_ERROR(while_init->ReplaceUseWith(xla_while, while_init_copy));

  // Deep copy the parameter and the root. Extend a control edge from the copy
  // of the parameter value to the corresponding copy value of the root.
  HloComputation* body = xla_while->while_body();
  HloInstruction* param = body->parameter_instruction(0);
  HloInstruction* root = body->root_instruction();

  // If param is the root then all indices should have been passed through the
  // while body and we should have returned early above.
  TF_RET_CHECK(param != root);

  // Copy users before making a deep copy of the parameter as the deep copy
  // will create new users of the parameter (eg, the GTE instructions of the
  // deep copy).
  std::vector<HloInstruction*> param_users = param->users();

  TF_ASSIGN_OR_RETURN(
      auto pair, DeepCopyAndAddControlEdges(param, root, indices_to_copy));

  HloInstruction* param_copy = pair.first;
  HloInstruction* root_copy = pair.second;

  for (HloInstruction* user : param_users) {
    TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, param_copy));
  }

  body->set_root_instruction(root_copy);

  return Status::OK();
}

// Removes any control dependencies to or from the given instruction.
Status StripControlDependenciesFrom(HloInstruction* instruction) {
  while (!instruction->control_successors().empty()) {
    TF_RETURN_IF_ERROR(instruction->RemoveControlDependencyTo(
        instruction->control_successors().front()));
  }

  while (!instruction->control_predecessors().empty()) {
    TF_RETURN_IF_ERROR(
        instruction->control_predecessors().front()->RemoveControlDependencyTo(
            instruction));
  }

  return Status::OK();
}

// Add kCopy instructions to the given module to guarantee there is no
// live-range interference. Generally interference can only occur around
// kWhile instructions which have update-in-place semantics.
Status AddCopiesToResolveInterference(HloModule* module) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module));

  for (HloComputation* computation : module->computations()) {
    for (HloInstruction* instruction : computation->instructions()) {
      if (instruction->opcode() == HloOpcode::kWhile) {
        TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction));
      }
    }
  }
  return Status::OK();
}

// Class for removing unnecessary copies from the module.
//
// kCopy instructions are added conservatively to guarantee no live range
// interference between HLO values. This class uses a more fine-grained
// analysis to remove some of these added copies which are not strictly
// necessary.
class CopyRemover {
 public:
  CopyRemover(const HloAliasAnalysis& alias_analysis,
              const HloOrdering& ordering, HloModule* module)
      : module_(module),
        alias_analysis_(alias_analysis),
        ordering_(ordering),
        buffer_value_tracker_(*module, alias_analysis, ordering) {}

  // Try to elide the given copy. The copy is elided if the instruction is not
  // necessary to prevent live-range interference of HLO values. Returns true
  // if the copy was elided.
  //
  // The copy instruction is not actually removed here. Instead it is left for
  // dead in the graph. Later calls to DCE will remove the instruction.
  StatusOr<bool> TryElideCopy(HloInstruction* copy) {
    if (buffer_value_tracker_.TryElideCopy(copy)) {
      TF_RETURN_IF_ERROR(StripControlDependenciesFrom(copy));
      TF_RETURN_IF_ERROR(copy->ReplaceAllUsesWith(copy->mutable_operand(0)));
      return true;
    }
    return false;
  }

  string ToString() const {
    string out = StrCat("CopyRemover, module ", module_->name(), "\n");
    StrAppend(&out, "  Buffer values, in dependency order:\n");
    for (const HloBuffer& buffer : alias_analysis_.buffers()) {
      StrAppend(&out, "    HloBuffer ", buffer.id(), ":\n");
    }
    return out;
  }

 private:
  // Class which tracks the HLO values within each HLO buffer in the module
  // during copy removal.
  //
  // The values are held in a linked list where there is one list for each
  // buffer.
  // Removing a copy instruction merges the values of the copy's source
  // buffer into the copy's destination buffer. This class tracks these value
  // lists as copies are removed from the graph (and value lists are merged).
  //
  // The BufferValueTracker object is initialized to match the state of
  // HloAliasAnalysis. However, as copies are removed this state diverges. The
  // values-to-buffer mapping is maintained outside of HloAliasAnalysis
  // because a fully updatable alias analysis is very slow.
  class BufferValueTracker {
   public:
    // The values held in a single HLO buffer are represented using a linked
    // list. Each element of the list is a ValueNode.
    //
    // This linked list is hand-rolled to enable efficient splicing of lists
    // using only references to list elements without knowing which lists are
    // being spliced. std::list requires a reference to the list object to
    // splice.
    struct ValueNode {
      explicit ValueNode(const HloValue* v) : value(v) {}

      const HloValue* value;

      // The uses are maintained outside of HloValue::uses() because
      // HloValue::uses() is not updatable (a fully updatable dataflow
      // analysis is slow).
      std::vector<const HloUse*> uses;

      // next/prev elements in the linked list. The list is circularly linked
      // so these values are never null for elements in the list.
      ValueNode* prev = nullptr;
      ValueNode* next = nullptr;
    };

    BufferValueTracker(const HloModule& module,
                       const HloAliasAnalysis& alias_analysis,
                       const HloOrdering& ordering)
        : dataflow_(alias_analysis.dataflow_analysis()), ordering_(ordering) {
      // Construct a list for each HLO buffer in the alias analysis. Maintain
      // a map from HloValue to the respective list element representing that
      // value. The map is used to construct the copy info map below.
      tensorflow::gtl::FlatMap<const HloValue*, ValueNode*> value_to_node;
      for (const HloBuffer& buffer : alias_analysis.buffers()) {
        // Verify values contained in the buffer are strictly ordered. This
        // should always be the case after adding copies to eliminate
        // interference. Specifically, the addition of the control flow edges
        // between copies added around aliased operations (kWhile) guarantees
        // this strict order.
        for (const HloValue* value_a : buffer.values()) {
          for (const HloValue* value_b : buffer.values()) {
            if (value_a != value_b) {
              DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b,
                                                       dataflow_) ||
                     ordering_.LiveRangeStrictlyBefore(*value_b, *value_a,
                                                       dataflow_))
                  << value_a->ToShortString() << " and "
                  << value_b->ToShortString() << " are not ordered";
            }
          }
        }

        std::vector<const HloValue*> values = buffer.values();
        std::sort(values.begin(), values.end(),
                  [this](const HloValue* a, const HloValue* b) {
                    return ordering_.IsDefinedBefore(*a, *b);
                  });

        // Create a list containing all of the values in the buffer.
        AddValueList(values, &value_to_node);
      }

      // Create copy_map_ which contains the source and destination values of
      // all copies.
      CreateCopyMap(module, value_to_node);

      XLA_VLOG_LINES(3, ToString());
      TF_DCHECK_OK(Verify());
    }
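
    // As an illustration of the state built above: for an HLO snippet
    //
    //   p = parameter(0)
    //   c = copy(p)
    //
    // where 'p' and 'c' end up in separate buffers, value_lists_ holds two
    // single-element (circular) lists, one for the value of 'p' and one for
    // the value of 'c', and copy_map_ maps the copy instruction to
    // {src = node of 'p', dest = node of 'c'}.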

    // Add a list containing the given values to BufferValueTracker. This
    // represents the values contained in a single buffer. For each value in
    // 'values' an entry is created in value_to_node mapping it to the
    // ValueNode representing that value.
    void AddValueList(
        tensorflow::gtl::ArraySlice<const HloValue*> values,
        tensorflow::gtl::FlatMap<const HloValue*, ValueNode*>* value_to_node) {
      ValueNode* tail = nullptr;
      ValueNode* head = nullptr;
      for (const HloValue* value : values) {
        auto new_node = new ValueNode(value);
        (*value_to_node)[value] = new_node;

        // Copy the HLO value's uses into the ValueNode for the value. These
        // uses in ValueNode are updated as copies are removed.
        new_node->uses.reserve(value->uses().size());
        for (const HloUse& use : value->uses()) {
          new_node->uses.push_back(&use);
        }

        // Connect the new node into the linked list.
        if (tail == nullptr) {
          head = new_node;
        } else {
          tail->next = new_node;
          new_node->prev = tail;
        }
        tail = new_node;
      }

      // The linked list is circular so connect the head and tail.
      tail->next = head;
      head->prev = tail;
      value_lists_.insert(head);
    }

    // Fill in copy_map_ which indicates which nodes in the value lists
    // correspond to the source and destination values of kCopy instructions.
    // value_to_node should map each HloValue to its respective ValueNode.
    void CreateCopyMap(
        const HloModule& module,
        const tensorflow::gtl::FlatMap<const HloValue*, ValueNode*>&
            value_to_node) {
      for (HloComputation* computation : module.computations()) {
        for (HloInstruction* instruction : computation->instructions()) {
          // Add copies with unambiguous source values to the map. Copies with
          // ambiguous sources are not removable.
          if (instruction->opcode() == HloOpcode::kCopy) {
            const HloValueSet& src_value_set =
                dataflow_.GetValueSet(instruction->operand(0));
            if (src_value_set.values().size() == 1) {
              CopyNodes& copy_node = copy_map_[instruction];
              copy_node.dest =
                  value_to_node.at(&dataflow_.GetUniqueValueAt(instruction));
              copy_node.src =
                  value_to_node.at(&src_value_set.GetUniqueValue());
            }
          }
        }
      }
    }

    ~BufferValueTracker() {
      for (const ValueNode* head : value_lists_) {
        const ValueNode* p = head;
        do {
          const ValueNode* tmp = p->next;
          delete p;
          p = tmp;
        } while (p != head);
      }
    }

    // Verify invariants within the linked lists.
    Status Verify() const {
      for (const ValueNode* head : value_lists_) {
        const ValueNode* p = head;
        do {
          // Verify links between elements are consistent.
          TF_RET_CHECK(p->prev->next == p);
          TF_RET_CHECK(p->next->prev == p);

          const HloInstruction* def = p->value->defining_instruction();
          if (def->opcode() == HloOpcode::kCopy &&
              ContainsKey(copy_map_, def)) {
            TF_RET_CHECK(copy_map_.at(def).dest == p);
          }
          for (const HloUse* use : p->uses) {
            if (use->instruction->opcode() == HloOpcode::kCopy &&
                ContainsKey(copy_map_, use->instruction)) {
              TF_RET_CHECK(copy_map_.at(use->instruction).src == p);
            }
          }

          p = p->next;
        } while (p != head);
      }
      return Status::OK();
    }

    // Try to elide the given copy. Elision of a copy is possible only if no
    // live range interference is introduced by the copy's elimination. If
    // elision is possible, then the internal state (value lists) is updated,
    // and true is returned. Returns false otherwise.
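    //
    // For example, suppose the source buffer holds values {s_0, s_1}, the
    // destination buffer holds values {d_0, d_1}, and the copy reads s_0 and
    // defines d_0 (case (1) in the comments below). Eliding the copy splices
    // the remaining destination values in right after the copied source
    // value, giving the merged order {s_0, d_1, s_1}; this is only done if
    // the live range of s_0 is before d_1 and the live range of d_1 is
    // before s_1.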
    bool TryElideCopy(const HloInstruction* copy) {
      VLOG(2) << "Trying to remove " << copy->name();

      if (!ContainsKey(copy_map_, copy)) {
        VLOG(2) << copy->name() << " is not removable";
        return false;
      }

      const CopyNodes& copy_node = copy_map_.at(copy);
      ValueNode* src = copy_node.src;
      ValueNode* dest = copy_node.dest;
      DCHECK(src != nullptr);
      DCHECK(dest != nullptr);

      auto is_live_range_before = [this](const ValueNode& a,
                                         const ValueNode& b) {
        if (LiveRangeBefore(a, b)) {
          VLOG(2) << "  Live range of " << a.value->ToShortString()
                  << " is before " << b.value->ToShortString();
          return true;
        } else {
          VLOG(2) << "  Live range of " << a.value->ToShortString()
                  << " is not before " << b.value->ToShortString();
          return false;
        }
      };

      VLOG(3) << copy->name() << " copies value "
              << src->value->ToShortString();
      VLOG(3) << "Source buffer values: " << ValueListToString(src);
      VLOG(3) << "Dest buffer values: " << ValueListToString(dest);

      // A kCopy instruction copies an HLO value from a source buffer and
      // defines an HLO value in a destination buffer. Most generally, the
      // source and destination buffers may each hold more than one value at
      // different points in the computation so we define the following:
      //
      //   Values in source buffer:      {s_0, ..., s_n}
      //   Values in destination buffer: {d_0, ..., d_m}
      //
      // A kCopy instruction between these buffers copies a value s_x in the
      // source buffer and defines a value d_y in the destination buffer. The
      // elision of a copy merges the source and destination buffers together,
      // so the lists of values for the source and destination buffers are
      // merged.
      //
      // We handle two different cases for copy elision:
      //
      //   (1) the kCopy defines the first value in the destination buffer
      //       (d_0).
      //
      //   (2) the kCopy copies the last value in the source buffer (s_n).
      //
      // For the remaining case where the kCopy copies a not-last value from
      // the source buffer to a not-first value of the destination buffer, the
      // kCopy instruction cannot be removed. This case is generated, for
      // example, if the kCopy copies a while body parameter of the loop state
      // at one tuple index to a different tuple index in the while body root.
      // Removal of the copy necessarily results in live range interference of
      // values in the loop state at the two different tuple indices.
      //
      // We can only perform copy elision if the resulting merged values have
      // totally ordered live ranges; otherwise the merged buffer would have
      // live range interference.
      if (IsHead(*dest)) {
        // The copy copies an arbitrary value in the source buffer (call it
        // s_x) and defines d_0, the first value in the destination buffer.
        // After merging, the values in the combined buffer must be strictly
        // ordered as follows** to elide the copy:
        //
        //   {s_0, ..., s_x, d_1, ..., d_m, s_{x+1}, ..., s_n}
        //
        // Removing the copy eliminates d_0, and uses of d_0 become uses of
        // s_x. In the above ordering, the live range of d_m must be ordered
        // before the live range of s_{x+1} and the definition and all uses of
        // s_x must be ordered before the definition of d_1. These conditions
        // are checked below prior to elision.
        //
        // ** Technically it might be possible to have a non-interfering
        //    non-trivial interleaving of the values of the source and
        //    destination buffers in the resulting order. However, this case
        //    is slow and complicated to check and likely not worth it. So
        //    instead we simply check for the case where *all* values of the
        //    destination buffer (d_1 through d_m) are spliced into the point
        //    where the copy used to be.
        VLOG(2) << copy->name() << " defines the first value in its buffer";
        ValueNode* next_dest = Next(*dest);
        if (next_dest != nullptr) {
          // Live range of 'src' value (s_x) must be before 'next_dest' (d_1).
          if (!is_live_range_before(*src, *next_dest)) {
            return false;
          }
        }
        ValueNode* next_src = Next(*src);

        if (next_src != nullptr) {
          // Live range of 'last_dest' (d_m) must be before 'next_src'
          // (s_{x+1}).
          ValueNode* last_dest = dest->prev;
          DCHECK(IsTail(*last_dest));
          if (!is_live_range_before(*last_dest, *next_src)) {
            return false;
          }
        }

        // Splice in destination buffer values list right after 'src'.
        SpliceAfter(dest, src);
      } else if (IsTail(*src)) {
        // The copy copies the last value in the source buffer, s_n, and
        // defines an arbitrary value in the destination buffer, d_y. After
        // merging, the values in the combined buffer must be strictly ordered
        // as follows** to elide the copy:
        //
        //   {d_0, ..., d_{y-1}, s_0, ..., s_n, d_{y+1}, ..., d_m}
        //
        // Removing the copy eliminates d_y, and uses of d_y become uses of
        // s_n. To enforce the above order, the live range of d_{y-1} must be
        // before the live range of s_0, and the live range of s_n must be
        // before the live range of d_{y+1}.
        //
        // ** See comment above in the code handling Case (1).
        VLOG(2) << copy->name() << " copies the last value ("
                << src->value->ToShortString() << ") in its buffer";

        ValueNode* prev_dest = Prev(*dest);
        // nullptr condition handled above in the first 'if' case.
        DCHECK(prev_dest != nullptr);
        ValueNode* first_src = src->next;
        DCHECK(IsHead(*first_src));
        if (!is_live_range_before(*prev_dest, *first_src)) {
          // Live range of value d_{y-1} is not before s_0.
          return false;
        }
        ValueNode* next_dest = Next(*dest);
        if (next_dest != nullptr) {
          if (!is_live_range_before(*src, *next_dest)) {
            // Live range of value s_n is not before d_{y+1}.
            return false;
          }
        }

        // Splice source buffer values list right after 'prev_dest'.
        SpliceAfter(first_src, prev_dest);
      } else {
        VLOG(2) << copy->name()
                << " copies value in middle of source buffer to value in "
                   "middle of destination buffer";
        return false;
      }

      RemoveCopyValue(dest);

      XLA_VLOG_LINES(4, ToString());
      TF_DCHECK_OK(Verify());

      return true;
    }

    // Delete the given ValueNode associated with an elided kCopy
    // instruction. This should be called after splicing the value lists of
    // the source and destination buffers together.
    void RemoveCopyValue(ValueNode* copy_value_node) {
      CHECK_EQ(copy_value_node->value->defining_instruction()->opcode(),
               HloOpcode::kCopy);
      ValueNode* operand_node = copy_value_node->prev;
      CHECK(operand_node != copy_value_node);

      VLOG(2) << "Removing copy " << operand_node->value->ToShortString()
              << " => " << copy_value_node->value->ToShortString();

      // Splice out the copy value node.
      operand_node->next = copy_value_node->next;
      copy_value_node->next->prev = operand_node;

      // Patch up uses. Remove use of copy from operand_node uses.
      auto it =
          std::find_if(operand_node->uses.begin(), operand_node->uses.end(),
                       [copy_value_node](const HloUse* use) {
                         return use->instruction ==
                                copy_value_node->value->defining_instruction();
                       });
      CHECK(it != operand_node->uses.end());
      operand_node->uses.erase(it);

      // If the elided copy has any uses which are themselves kCopy
      // instructions then patch up the copy info to reflect that these kCopy
      // instructions now have a different operand (the operand of the elided
      // copy).
      for (const HloUse* copy_use : copy_value_node->uses) {
        operand_node->uses.push_back(copy_use);
        if (copy_use->instruction->opcode() == HloOpcode::kCopy &&
            ContainsKey(copy_map_, copy_use->instruction)) {
          copy_map_.at(copy_use->instruction).src = operand_node;
        }
      }

      // Delete the copy info and the value node.
      copy_map_.erase(copy_value_node->value->defining_instruction());
      delete copy_value_node;
    }

    // Returns true if the live range of given value 'a' is before the live
    // range of 'b'.
    //
    // We cannot use LiveRangeStrictlyBefore because HloValue::uses() is not
    // updated as copies are removed.
    bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) {
      if (a.uses.empty()) {
        VLOG(2) << "Empty uses";
        return ordering_.IsDefinedBefore(*a.value, *b.value);
      }
      for (const HloUse* use : a.uses) {
        VLOG(2) << "use: " << *use;
        VLOG(2) << "is before: " << *b.value;
        if (!ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_)) {
          VLOG(2) << "Not before";
          return false;
        }
      }
      return true;
    }

    // Returns whether 'node' is the last node in its list.
    bool IsTail(const ValueNode& node) const {
      return ContainsKey(value_lists_, node.next);
    }

    // Returns whether 'node' is the first node in its list.
    bool IsHead(const ValueNode& node) const {
      return ContainsKey(value_lists_, &node);
    }

    // Returns the next node in the list after 'node'. If 'node' is the tail,
    // then nullptr is returned.
    ValueNode* Next(const ValueNode& node) const {
      if (IsTail(node)) {
        return nullptr;
      } else {
        return node.next;
      }
    }

    // Returns the previous node in the list before 'node'. If 'node' is the
    // head, then nullptr is returned.
    ValueNode* Prev(const ValueNode& node) const {
      if (IsHead(node)) {
        return nullptr;
      } else {
        return node.prev;
      }
    }

    // Splices the entire linked list with 'head' as its head right after the
    // node 'insert_after' in another linked list.
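    //
    // For example, if one list holds {a, b, c} and another holds {x, y, z},
    // then SpliceAfter(/*head=*/x, /*insert_after=*/b) leaves the single list
    // {a, b, x, y, z, c} and removes 'x' from value_lists_.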
    void SpliceAfter(ValueNode* head, ValueNode* insert_after) {
      DCHECK(IsHead(*head));
      value_lists_.erase(head);

      ValueNode* tail = head->prev;
      tail->next = insert_after->next;
      insert_after->next->prev = tail;

      insert_after->next = head;
      head->prev = insert_after;
    }

    string ValueListToString(const ValueNode* element) {
      const ValueNode* head = element;
      while (!IsHead(*head)) {
        head = Prev(*head);
      }
      std::vector<const HloValue*> values;
      for (const ValueNode* p = head; p != nullptr; p = Next(*p)) {
        values.push_back(p->value);
      }
      return StrCat("{",
                    Join(values, ", ",
                         [](string* s, const HloValue* value) {
                           StrAppend(s, value->ToShortString());
                         }),
                    "}");
    }

    string ToString() const {
      string out = StrCat("BufferValueTracker:\n");
      StrAppend(&out, "  Def-use chains in each buffer:\n");
      for (const ValueNode* head : value_lists_) {
        StrAppend(&out, "    Buffer defined by ", head->value->ToShortString(),
                  ":\n");
        const ValueNode* p = head;
        do {
          StrAppend(&out, "      ", p->value->ToShortString(), ", uses: ",
                    Join(p->uses, "; ",
                         [](string* s, const HloUse* use) {
                           StrAppend(s, use->ToString());
                         }),
                    "\n");

          p = p->next;
        } while (p != head);
      }
      StrAppend(&out, "  Potentially removable copies:\n");
      for (const auto& pair : copy_map_) {
        const HloInstruction* copy = pair.first;
        const CopyNodes& copy_info = pair.second;

        StrAppend(&out, "    ", copy->name(), " : ",
                  copy_info.src->value->ToShortString(), " => ",
                  copy_info.dest->value->ToShortString(), "\n");
      }
      return out;
    }

   private:
    const HloDataflowAnalysis& dataflow_;
    const HloOrdering& ordering_;

    // The heads of all the value lists. Each value list represents the HLO
    // values contained in a particular HLO buffer. The values in the list are
    // in dependency order.
    tensorflow::gtl::FlatSet<const ValueNode*> value_lists_;

    // Copy removal requires fast access to the value list elements
    // corresponding to the source and destination values of the kCopy
    // instruction. This data structure holds pointers to these elements for
    // each kCopy instruction in the graph.
    struct CopyNodes {
      // The source and destination values of the kCopy instruction.
      ValueNode* src = nullptr;
      ValueNode* dest = nullptr;
    };
    tensorflow::gtl::FlatMap<const HloInstruction*, CopyNodes> copy_map_;
  };

  HloModule* module_;
  const HloAliasAnalysis& alias_analysis_;
  const HloOrdering& ordering_;

  // Object tracking the HLO values contained in each HLO buffer.
  BufferValueTracker buffer_value_tracker_;
};

// Try to remove as many copies from the module as possible without
// introducing live range interference. Copy instructions (identified by
// their unique id) in the set copies_to_exclude are not considered for
// removal.
Status RemoveUnnecessaryCopies(
    const HloOrdering& ordering,
    const tensorflow::gtl::FlatSet<int>& copies_to_exclude,
    HloModule* module) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module));
  CopyRemover copy_remover(*alias_analysis, ordering, module);
  XLA_VLOG_LINES(3, copy_remover.ToString());

  for (HloComputation* computation : module->computations()) {
    for (HloInstruction* instruction : computation->instructions()) {
      if (instruction->opcode() == HloOpcode::kCopy &&
          !ContainsKey(copies_to_exclude, instruction->unique_id())) {
        TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status());
      }
    }
  }

  return Status::OK();
}

// Add copies to address special constraints on the roots of computations not
// related to live range interference:
//
//   (1) The entry computation root must be unambiguous and distinct.
//
//   (2) Any computation called by a kCall instruction must have an
//       unambiguous root.
//
//   (3) Constants and parameters cannot be live out of the entry computation.
//
Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module));

  // Identify which shape indices of which instructions need to be copied.
  // Store these results in 'instructions_to_copy'.
  std::unordered_map<HloInstruction*, ShapeTree<bool>> instructions_to_copy;
  auto add_index_to_copy = [&instructions_to_copy](HloInstruction* instruction,
                                                   const ShapeIndex& index) {
    auto it = instructions_to_copy.find(instruction);
    if (it == instructions_to_copy.end()) {
      auto it_added = instructions_to_copy.emplace(
          std::piecewise_construct, std::forward_as_tuple(instruction),
          std::forward_as_tuple(instruction->shape(), /*init_value=*/false));
      it = it_added.first;
    }
    *it->second.mutable_element(index) = true;
  };

  // Iterate through values of all constants and entry parameters. These
  // values are special because they are held in read-only buffers. If any of
  // these values share a buffer with other values (for example, the init
  // value of a while is a constant) then copy the value at its definition and
  // replace all its uses with the copy.
  for (const HloValue* value : alias_analysis->dataflow_analysis().values()) {
    if (ValueIsReadOnly(*value) &&
        alias_analysis->GetBufferContainingValue(*value).values().size() > 1) {
      VLOG(2) << "Value " << value->ToShortString()
              << " is read only, but its buffer contains more than one value. "
                 "Copying.";
      add_index_to_copy(value->defining_instruction(),
                        value->defining_index());
    }
  }

  // Identify copies which must be added at root instructions.
  for (HloComputation* computation : module->computations()) {
    const CallGraphNode& node = call_graph.GetNode(computation);
    if (node.context() == CallContext::kParallel) {
      continue;
    }
    TF_RET_CHECK(node.context() == CallContext::kSequential);

    const bool is_entry = computation == module->entry_computation();
    HloInstruction* root = computation->root_instruction();

    // Mark nondistinct/ambiguous indices.
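    //
    // For example, if the entry root is 'tuple(p, p)' for some instruction
    // 'p', indices {0} and {1} both refer to the buffer of 'p' (non-distinct)
    // and a copy is added at one of them. An index whose value may come from
    // more than one buffer is ambiguous and is likewise copied.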
    tensorflow::gtl::FlatSet<const HloBuffer*> seen;
    ShapeUtil::ForEachSubshape(
        root->shape(),
        [&](const Shape& /*subshape*/, const ShapeIndex& index) {
          std::vector<const HloBuffer*> buffers_at_index =
              alias_analysis->ComputeBuffersAt(root, index);
          bool buffer_seen_before = false;
          for (const HloBuffer* buffer : buffers_at_index) {
            buffer_seen_before |= !seen.insert(buffer).second;
          }
          if (buffers_at_index.size() > 1 ||
              (buffer_seen_before && is_entry)) {
            VLOG(2) << "Index " << index << " of root of computation "
                    << computation->name() << " (" << root->name()
                    << ") has ambiguous or non-distinct buffer. Copying.";
            add_index_to_copy(root, index);
          }
        });

    // For entry instructions, mark any parameter or constant values.
    if (is_entry) {
      for (const auto& pair :
           alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) {
        const ShapeIndex& index = pair.first;
        const HloValueSet& value_set = pair.second;
        for (const HloValue* value : value_set.values()) {
          if (ValueIsReadOnly(*value)) {
            VLOG(2) << "Root of entry computation (" << root->name()
                    << ") has constant or entry parameter value at index "
                    << index << ". Copying.";
            add_index_to_copy(root, index);
          }
        }
      }
    }
  }

  // Add copy instructions indicated in 'instructions_to_copy' to the module.
  for (const auto& pair : instructions_to_copy) {
    HloInstruction* instruction = pair.first;
    const ShapeTree<bool>& indices_to_copy = pair.second;

    std::vector<HloInstruction*> users = instruction->users();
    TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
                        instruction->parent()->DeepCopyInstruction(
                            instruction, &indices_to_copy));
    for (HloInstruction* user : users) {
      TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy));
    }
    if (instruction == instruction->parent()->root_instruction()) {
      instruction->parent()->set_root_instruction(deep_copy);
    }
  }

  return Status::OK();
}

Status VerifyNoLiveRangeInterference(HloModule* module) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module));
  DependencyHloOrdering ordering(module);
  TF_RET_CHECK(!alias_analysis->HasLiveRangeInterference(ordering));
  return Status::OK();
}

void MaybeDumpModule(const string& message, const HloModule& module) {
  if (VLOG_IS_ON(3)) {
    VLOG(3) << message;
    XLA_VLOG_LINES(3, module.ToString());
    hlo_graph_dumper::MaybeDumpHloModule(module, message);
  }
}

}  // namespace

StatusOr<bool> CopyInsertion::Run(HloModule* module) {
  // Copy insertion is performed in three steps:
  //
  //   (1) Add copies conservatively to guarantee that there is no live-range
  //       interference. This is done simplistically and usually results in
  //       more copies than are strictly necessary.
  //
  //   (2) Using a more fine-grained analysis, remove as many of the copies
  //       added in (1) as possible while ensuring no live-range interference.
  //
  //   (3) Add copies to resolve issues not related to live range interference
  //       such as parameters and constants live out of the entry computation.
  //
  // We add copies then remove them (step (1) then (2)) rather than simply
  // adding only the copies that are necessary because, in general, it is
  // difficult to figure out the minimal set of copies to add once there is
  // interference. On the other hand, it is easy to determine if removing a
  // copy will introduce interference.
  //
  // The final copy insertion in (3) is done separately to simplify the
  // implementation of copy removal in (2), which is the most complicated part
  // of the pass. As is, copy removal only has to reason about live range
  // interference. If all copies were added in step (1) then copy removal
  // would also have to reason about things like constants and parameters live
  // out of the computation.
  MaybeDumpModule("before copy insertion", *module);

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
  if (!call_graph->IsFlattened()) {
    return FailedPrecondition(
        "Call graph must be flattened before copy insertion.");
  }

  // Gather the unique ids of existing kCopy instructions in the module. We
  // avoid removing these copies (except via DCE in TupleSimplifier) because
  // they may have been added for reasons not considered by copy insertion
  // (eg, layout assignment). Instruction ids are used instead of
  // HloInstruction* because pointer values may be recycled.
  tensorflow::gtl::FlatSet<int> existing_copies;
  for (HloComputation* computation : module->computations()) {
    for (HloInstruction* instruction : computation->instructions()) {
      if (instruction->opcode() == HloOpcode::kCopy) {
        existing_copies.insert(instruction->unique_id());
      }
    }
  }

  TF_RETURN_IF_ERROR(AddCopiesToResolveInterference(module));

  // Simplify the tuple structures introduced by the deep copies. This should
  // be done before removing copies (RemoveUnnecessaryCopies) because tuple
  // simplification changes dependencies in the graph, which changes live
  // range interference in the graph. Also run DCE to remove the dead
  // Tuple/GTE instructions introduced by tuple simplification.
  TupleSimplifier tuple_simplifier;
  HloDCE dce;
  TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
  TF_RETURN_IF_ERROR(dce.Run(module).status());

  TF_DCHECK_OK(VerifyNoLiveRangeInterference(module));

  MaybeDumpModule("after adding copies to resolve interference", *module);

  DependencyHloOrdering ordering(module);
  TF_RETURN_IF_ERROR(
      RemoveUnnecessaryCopies(ordering, existing_copies, module));

  MaybeDumpModule("after removing unnecessary copies", *module);

  TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));

  MaybeDumpModule("after adding special-case copies", *module);

  TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
  TF_RETURN_IF_ERROR(dce.Run(module).status());
  TF_DCHECK_OK(VerifyNoLiveRangeInterference(module));

  MaybeDumpModule("after copy insertion", *module);

  if (VLOG_IS_ON(1)) {
    int64 num_total_copies = 0;
    for (HloComputation* computation : module->computations()) {
      for (HloInstruction* instruction : computation->instructions()) {
        if (instruction->opcode() == HloOpcode::kCopy) {
          num_total_copies++;
        }
      }
    }
    VLOG(1) << "Num copies before copy-insertion: " << existing_copies.size();
    VLOG(1) << "Num copies after copy-insertion: " << num_total_copies;
  }

  return true;
}

namespace {

bool IsWhileBody(const HloComputation* computation,
                 const CallGraph& call_graph) {
  const CallGraphNode& node = call_graph.GetNode(computation);

  if (node.context() == CallContext::kSequential &&
      !node.caller_callsites().empty()) {
    // The call graph should be flattened, so a computation with a sequential
    // context can have at most one caller.
    CHECK_EQ(node.caller_callsites().size(), 1);
    const HloInstruction* calling_instruction =
        node.caller_callsites()[0].instruction();
    if (calling_instruction->opcode() == HloOpcode::kWhile &&
        calling_instruction->while_body() == node.computation()) {
      return true;
    }
  }
  return false;
}

}  // namespace

/* static */ StatusOr<bool> CopyInsertion::AddCopiesForBufferAssignment(
    HloModule* module) {
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow,
                      HloDataflowAnalysis::Run(*module));

  bool changed = false;

  // If a buffer live out of a computation is a constant, a parameter, or not
  // defined in the computation, then copy it to account for the limited
  // computation-scoped analysis in buffer assignment. An exception to this
  // rule is the while body, which is handled properly without copies.
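  //
  // For example, a computation called via kCall whose root simply forwards
  // one of its parameters (or returns a constant) at some index gets a copy
  // inserted at that index of its root here.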
  for (HloComputation* computation : module->computations()) {
    if (computation == module->entry_computation() ||
        IsWhileBody(computation, *call_graph)) {
      continue;
    }

    HloInstruction* root = computation->root_instruction();
    ShapeTree<bool> indices_to_copy(root->shape(), /*init_value=*/false);
    bool copy_root = false;
    for (const auto& pair : dataflow->GetInstructionValueSet(root)) {
      const ShapeIndex& index = pair.first;
      const HloValueSet& value_set = pair.second;
      for (const HloValue* value : value_set.values()) {
        HloInstruction* def = value->defining_instruction();
        if (def->parent() != computation ||
            def->opcode() == HloOpcode::kConstant ||
            def->opcode() == HloOpcode::kParameter) {
          *indices_to_copy.mutable_element(index) = true;
          copy_root = true;
        }
      }
    }
    if (copy_root) {
      TF_ASSIGN_OR_RETURN(
          HloInstruction * root_copy,
          computation->DeepCopyInstruction(root, &indices_to_copy));
      computation->set_root_instruction(root_copy);
      changed = true;
    }
  }

  TupleSimplifier tuple_simplifier;
  HloDCE dce;
  TF_ASSIGN_OR_RETURN(bool tuple_simplifier_changed,
                      tuple_simplifier.Run(module));
  TF_ASSIGN_OR_RETURN(bool dce_changed, dce.Run(module));

  return changed || tuple_simplifier_changed || dce_changed;
}

}  // namespace xla