1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/hlo_rematerialization.h" 17 18 #include <algorithm> 19 #include <memory> 20 #include <set> 21 #include <string> 22 23 #include "tensorflow/compiler/xla/map_util.h" 24 #include "tensorflow/compiler/xla/primitive_util.h" 25 #include "tensorflow/compiler/xla/service/flatten_call_graph.h" 26 #include "tensorflow/compiler/xla/service/hlo_computation.h" 27 #include "tensorflow/compiler/xla/service/hlo_dce.h" 28 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 29 #include "tensorflow/compiler/xla/service/hlo_module.h" 30 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 31 #include "tensorflow/compiler/xla/service/hlo_ordering.h" 32 #include "tensorflow/compiler/xla/service/hlo_scheduling.h" 33 #include "tensorflow/compiler/xla/service/liveness_util.h" 34 #include "tensorflow/compiler/xla/service/logical_buffer.h" 35 #include "tensorflow/compiler/xla/status_macros.h" 36 #include "tensorflow/compiler/xla/statusor.h" 37 #include "tensorflow/compiler/xla/types.h" 38 #include "tensorflow/compiler/xla/util.h" 39 #include "tensorflow/core/lib/strings/str_util.h" 40 #include "tensorflow/core/lib/strings/strcat.h" 41 #include "tensorflow/core/lib/strings/stringprintf.h" 42 #include "tensorflow/core/platform/logging.h" 43 44 using ::tensorflow::strings::HumanReadableNumBytes; 45 46 namespace xla { 47 48 namespace { 49 50 // Potential optimizations: 51 // . TODO(b/35244891): Avoid N^2 behavior by keeping a priority queue 52 // of candidates. 53 // . Cache IsRematerializable in Item? Only correct if control 54 // predecessors and successors don't change. 55 56 // Returns true if the given instruction is rematerializable. 57 bool IsRematerializable(const HloInstruction* instruction) { 58 // Don't rematerialize instructions with side effects or instructions which 59 // cannot be cloned safely. 60 switch (instruction->opcode()) { 61 case HloOpcode::kCall: 62 case HloOpcode::kConstant: 63 case HloOpcode::kConditional: 64 case HloOpcode::kCrossReplicaSum: 65 case HloOpcode::kCustomCall: 66 case HloOpcode::kParameter: 67 case HloOpcode::kWhile: 68 return false; 69 default: 70 return !instruction->HasSideEffect(); 71 } 72 } 73 74 // Type holding a unique identifier for each Buffer object. 75 using BufferId = int64; 76 using BufferIdList = tensorflow::gtl::InlinedVector<BufferId, 3>; 77 78 // We wrap HloInstruction* with an Item that holds auxiliary 79 // per-instruction state. 80 struct Item { 81 HloInstruction* instruction; 82 83 // True once the instruction is marked as placed (when BeginInstruction 84 // has been called for this instruction). 85 bool placed = false; 86 87 // To avoid an infinite loop rematerializing the same set of 88 // instructions ad infinitum, keep a blacklist of instructions 89 // which should not be rematerialized. 90 bool blacklisted = false; 91 92 // The buffers defined by this instruction. 93 BufferIdList buffers_defined; 94 95 // The buffers used by this instruction. 96 BufferIdList buffers_used; 97 98 private: 99 friend class InstructionList; 100 101 // Items are arranged in a doubly linked list. 102 Item* next; 103 Item* prev; 104 105 // List is ordered by position, which can however be duplicated as 106 // new instructions are inserted. See InsertBeforeInstructions 107 // comment for details. 108 int64 position; 109 }; 110 111 using ItemList = tensorflow::gtl::InlinedVector<Item*, 3>; 112 113 // Class which maintains an ordered list of instructions with fast insertion 114 // before arbitrary elements. 115 class InstructionList { 116 public: 117 explicit InstructionList(const std::vector<const HloInstruction*>& order) { 118 int64 position = 0; 119 Item* last = nullptr; 120 for (const HloInstruction* inst : order) { 121 // Add a new item to the linked list. 122 Item* item = new Item; 123 item->next = nullptr; 124 item->prev = last; 125 if (last == nullptr) { 126 first_ = item; 127 } else { 128 last->next = item; 129 } 130 last = item; 131 132 // Initially position numbers are uniquely assigned in order. Later as 133 // instructions are added with InsertBefore* methods, some instructions 134 // may have duplicate position numbers, but the values will be guaranteed 135 // to be monotonically increasing through the list, and so is still useful 136 // for quickly(-ish) determining the order of arbitrary instructions in 137 // the list. 138 item->instruction = const_cast<HloInstruction*>(inst); 139 item->position = position; 140 position++; 141 142 item_map_[inst] = item; 143 } 144 } 145 146 ~InstructionList() { 147 for (Item* item = first_; item != nullptr;) { 148 Item* next = item->next; 149 delete item; 150 item = next; 151 } 152 } 153 154 size_t size() const { return item_map_.size(); } 155 156 // For ordered iteration over items. 157 // for (auto item = q.first(); item != nullptr; item = q.next(item)) {...} 158 Item* first() const { return first_; } 159 Item* next(Item* item) const { return item->next; } 160 161 // Creates an Item for the given instruction, but doesn't add it to the list. 162 // (Use InsertBeforeInstructions to add the Item to the list.) 163 Item* CreateItem(HloInstruction* inst) { 164 Item* item = new Item; 165 item->instruction = inst; 166 CHECK(item_map_.insert({inst, item}).second) << "inserting inst twice"; 167 return item; 168 } 169 170 // Return the Item corresponding to inst. 171 Item* GetItem(const HloInstruction* inst) const { 172 auto iter = item_map_.find(inst); 173 CHECK(iter != item_map_.end()) << "Did not find " << inst->name(); 174 return iter->second; 175 } 176 177 // Insert instruction 'to_insert' immediately before the earliest instruction 178 // in 'before_instructions'. 179 // 180 // Each instruction gets a non-decreasing ordinal number. We use this to let 181 // InsertBeforeInstructions quickly insert an instruction before the earliest 182 // instruction in a set of instructions. If position_number_[a] < 183 // position_number_[b] then 'a' comes before 'b' in the list. If the position 184 // numbers are the same then nothing can be said about their order without 185 // examining the list. 186 // 187 // On object construction this ordinal is precisely the instruction's index 188 // in the list. Later, instructions inserted via InsertBefore receive 189 // duplicate values. However, monotonicity is preserved. 190 void InsertBeforeInstructions( 191 Item* to_insert, tensorflow::gtl::ArraySlice<Item*> before_instructions) { 192 VLOG(3) << "InsertBeforeInstructions: " << to_insert->instruction->name() 193 << " before {" 194 << tensorflow::str_util::Join(before_instructions, ", ", 195 [](string* out, Item* item) { 196 tensorflow::strings::StrAppend( 197 out, item->instruction->name()); 198 }) 199 << "}"; 200 201 // Find the minimal position number of any instruction in 202 // 'before_instructions'. 203 CHECK(!before_instructions.empty()); 204 Item* min_position_item = nullptr; 205 for (Item* item : before_instructions) { 206 if (min_position_item == nullptr || 207 item->position < min_position_item->position) { 208 min_position_item = item; 209 } 210 } 211 212 // Because more than one instruction in 'before_instructions' may have a 213 // position number of 'min_position_number', find the first such instruction 214 // with position number 'min_position_number'. 215 216 // First find first instruction with the min position. 217 while (min_position_item->prev != nullptr && 218 min_position_item->position == min_position_item->prev->position) { 219 min_position_item = min_position_item->prev; 220 } 221 222 // Now scan forwards until we find one of the before_instructions. 223 while (std::find(before_instructions.begin(), before_instructions.end(), 224 min_position_item) == before_instructions.end()) { 225 min_position_item = min_position_item->next; 226 } 227 return InsertBefore(to_insert, min_position_item); 228 } 229 230 void Blacklist(const HloInstruction* inst) { 231 GetItem(inst)->blacklisted = true; 232 } 233 234 private: 235 // Insert instruction 'item' immediately before 'before' in the list. 236 void InsertBefore(Item* item, Item* before) { 237 VLOG(3) << "InsertBefore: " << item->instruction->name() << " before " 238 << before->instruction->name(); 239 // Insert new item into linked list. 240 item->prev = before->prev; 241 item->next = before; 242 before->prev = item; 243 if (item->prev != nullptr) { 244 item->prev->next = item; 245 } else { 246 first_ = item; 247 } 248 249 // Assign the same position number to the newly added instruction as 250 // 'before'. This guarantees monotonicity of the position numbers, but not 251 // uniqueness. 252 item->position = before->position; 253 } 254 255 Item* first_; 256 257 // Item for each instruction. 258 tensorflow::gtl::FlatMap<const HloInstruction*, Item*> item_map_; 259 }; 260 261 // Return the items which use the given LogicalBuffer. Sets 262 // has_indirect_users to whether any of the uses is indirect. A use is indirect 263 // if the instruction defining logical_buffer is not an operand of the use. This 264 // can happen via buffer aliasing (eg, tuples). 265 ItemList GetUsers(const InstructionList& instruction_list, 266 const LogicalBuffer* logical_buffer, 267 const TuplePointsToAnalysis& points_to_analysis, 268 bool* has_indirect_users) { 269 ItemList users; 270 // To identify uses iterate through all HloInstruction users of the 271 // BufferAliases of the logical buffer. 272 *has_indirect_users = false; 273 for (const BufferAlias& buffer_alias : 274 points_to_analysis.GetBufferAliases(*logical_buffer)) { 275 for (const HloInstruction* user : buffer_alias.instruction()->users()) { 276 if (DoesNotUseOperandBuffer(buffer_alias.instruction(), 277 buffer_alias.index(), user, 278 points_to_analysis)) { 279 // The alias may be an operand of 'user', but the LogicalBuffer cannot 280 // possibly be used by the instruction so ignore 'user'. This is the 281 // case, for example, for the tuple element buffers in a GetTupleElement 282 // instruction (the GTE instruction only uses the pointer vector). 283 continue; 284 } 285 if (buffer_alias.instruction() != logical_buffer->instruction()) { 286 *has_indirect_users = true; 287 } 288 // A buffer may be used by the instruction via more than one alias. For 289 // example, a buffer which appears in more than one element of a tuple. 290 Item* user_item = instruction_list.GetItem(user); 291 if (std::find(users.begin(), users.end(), user_item) == users.end()) { 292 users.push_back(user_item); 293 } 294 } 295 } 296 return users; 297 } 298 299 // Class for tracking memory usage of a computation as the instructions are 300 // placed sequentially. Memory usage is the sum of the sizes of live values 301 // (LogicalBuffers) at the current point in the instruction sequence. 302 class MemoryUsageTracker { 303 public: 304 MemoryUsageTracker( 305 const HloComputation* computation, 306 const HloRematerialization::ShapeSizeFunction& size_function, 307 const TuplePointsToAnalysis& points_to_analysis, 308 const InstructionList& instruction_list); 309 310 // Starts the placement of the given instruction. This adds the sizes of the 311 // LogicalBuffers defined by the instruction to the current memory 312 // usage. Placement is broken into two steps (BeginInstruction and 313 // EndInstruction) to accurately model memory usage. At BeginInstruction the 314 // memory for the output value(s) of the current instruction is allocated. At 315 // EndInstruction memory for dead operand(s) is freed. 316 Status BeginInstruction(Item* item); 317 318 // Finishes the placement of the current instruction. This frees any dead 319 // operands or dead result of the instruction. This must be called after 320 // each call to BeginInstruction. 321 Status EndInstruction(); 322 323 // Returns the number of bytes that the current memory usage will be reduced 324 // if the given instruction is rematerialized. 325 int64 MemoryReducedIfRematerialized(Item* item) const; 326 327 // Adjusts memory usage to account for the rematerialization of 328 // original_item for all remaining unplaced uses. The rematerialization 329 // is remat_item. This method should be called after the HLO graph has 330 // been transformed (rematerialization instruction created and connected to 331 // uses). 332 Status AddRematerializedInstruction(Item* original_item, Item* remat_item); 333 334 // Returns whether the given instruction has been placed (BeginInstruction 335 // has been called with 'instruction' as the argument). 336 bool IsPlaced(const HloInstruction* instruction) const { 337 return instruction_list_.GetItem(instruction)->placed; 338 } 339 340 // Returns the current memory usage. This is the sum of sizes of all live 341 // values. 342 int64 memory_usage() const { return memory_usage_; } 343 344 // Check invariants of the data structure. This is expensive to call. 345 bool Check() const; 346 347 string ToString() const; 348 349 private: 350 // A Buffer represents a single LogicalBuffer in the computation including 351 // various metadata useful for tracking liveness of the value. A LogicalBuffer 352 // is not used directly because the HLO graph is transformed and 353 // TuplePointsToAnalysis which owns all LogicalBuffers cannot be updated after 354 // HLO graph transformations. 355 struct Buffer { 356 // The unique id of this Buffer. This value is equal to the buffer's index 357 // in the vector buffers_. 358 const BufferId id; 359 360 // The instruction which defines this buffer. 361 Item* defining_instruction; 362 363 // The materialized size of the buffer in bytes. 364 const int64 size; 365 366 // Whether this buffer is live-out of the computation. 367 bool live_out; 368 369 // Whether this buffer has indirect uses. Ie, an instruction which is not a 370 // user of defining_instruction uses this buffer. This can occur due to 371 // buffer aliasing (eg, tuples). 372 bool has_indirect_uses; 373 374 // The instructions which use this buffer. 375 ItemList users; 376 377 // The number of users (HloInstructions) of this buffer which have not yet 378 // been placed in the sequence. 379 int64 unfinished_user_count; 380 381 string ToString() const { 382 return tensorflow::strings::StrCat( 383 "Buffer ", id, " (defined by ", 384 defining_instruction->instruction->name(), ", size ", size, 385 " bytes)"); 386 } 387 }; 388 389 // Creates a Buffer representing the given logical buffer. The buffer is added 390 // to buffers_ and a reference is returned. 391 Buffer& CreateBufferFromLogicalBuffer( 392 const LogicalBuffer* logical_buffer, 393 const TuplePointsToAnalysis& points_to_analysis, 394 const HloRematerialization::ShapeSizeFunction& size_function, 395 bool live_out) { 396 bool has_indirect_uses = false; 397 ItemList users = GetUsers(instruction_list_, logical_buffer, 398 points_to_analysis, &has_indirect_uses); 399 return NewBuffer(instruction_list_.GetItem(logical_buffer->instruction()), 400 size_function(logical_buffer->shape()), std::move(users), 401 live_out, has_indirect_uses); 402 } 403 404 // Create a new buffer representing a rematerialization of given buffer for 405 // the given uses. 406 Buffer& RematerializeBuffer(const Buffer& original_buffer, Item* remat_item, 407 ItemList&& rematerialized_uses) { 408 CHECK(original_buffer.defining_instruction->placed); 409 CHECK(!original_buffer.has_indirect_uses); 410 CHECK(!original_buffer.live_out); 411 for (Item* use : rematerialized_uses) { 412 CHECK(!use->placed); 413 } 414 return NewBuffer(remat_item, original_buffer.size, 415 std::move(rematerialized_uses), /*live_out=*/false, 416 /*has_indirect_uses=*/false); 417 } 418 419 // Return number of bytes allocated for the buffer with the given id. Buffers 420 // allocated by the calling computation (eg, parameter and output buffers) are 421 // considered to have zero bytes because the memory is accounted for in a 422 // different computation. 423 int64 AllocatedSize(BufferId buffer_id) const { 424 const Buffer& buffer = buffers_.at(buffer_id); 425 HloOpcode def_opcode = buffer.defining_instruction->instruction->opcode(); 426 if (buffer.live_out || def_opcode == HloOpcode::kParameter) { 427 return 0; 428 } else { 429 return buffer.size; 430 } 431 } 432 433 // Returns true if BeginInstruction and EndInstruction has been called for the 434 // given instruction. 435 bool IsFinished(Item* item) const { 436 return item->placed && item != in_progress_item_; 437 } 438 439 // Returns whether the given buffer is being used by the in-progress 440 // instruction. 441 bool IsInUse(BufferId buffer_id) const { 442 if (in_progress_item_ == nullptr) { 443 return false; 444 } 445 const BufferIdList& in_progress_uses = in_progress_item_->buffers_used; 446 return std::find(in_progress_uses.begin(), in_progress_uses.end(), 447 buffer_id) != in_progress_uses.end(); 448 } 449 450 // Returns whether the given instruction is live at the current program 451 // point. 452 bool IsCurrentlyLive(BufferId buffer_id) const { 453 const Buffer& buffer = buffers_[buffer_id]; 454 return (buffer.defining_instruction->placed && 455 buffer.unfinished_user_count > 0); 456 } 457 458 // Create a new buffer, add it to buffers_, and return a reference. 459 Buffer& NewBuffer(Item* defining_instruction, int64 size, ItemList&& users, 460 bool live_out, bool has_indirect_uses) { 461 int buffer_id = buffers_.size(); 462 buffers_.push_back(Buffer{buffer_id, defining_instruction, size, live_out, 463 has_indirect_uses, users, 464 static_cast<int64>(users.size())}); 465 return buffers_.back(); 466 } 467 468 const HloComputation* computation_; 469 470 // Instruction list containing the ordering of instructions in 471 // computation_. This is the order in which instructions are placed 472 // (BeginInstruction/EndInstruction calls). 473 const InstructionList& instruction_list_; 474 475 // Memory usage at the currently placed instruction. 476 int64 memory_usage_ = 0; 477 478 // The instruction currently being placed. This value is non-null only 479 // between the calling of BeginInstruction and EndInstruction. 480 Item* in_progress_item_ = nullptr; 481 482 // All buffers in the computation. 483 std::vector<Buffer> buffers_; 484 }; 485 486 MemoryUsageTracker::MemoryUsageTracker( 487 const HloComputation* computation, 488 const HloRematerialization::ShapeSizeFunction& size_function, 489 const TuplePointsToAnalysis& points_to_analysis, 490 const InstructionList& instruction_list) 491 : computation_(computation), instruction_list_(instruction_list) { 492 PointsToSet::BufferSet live_out_set = 493 points_to_analysis.GetPointsToSet(computation_->root_instruction()) 494 .CreateFlattenedSet(); 495 tensorflow::gtl::FlatMap<const LogicalBuffer*, BufferId> 496 logical_buffer_to_buffer_id; 497 498 for (auto* item = instruction_list_.first(); item != nullptr; 499 item = instruction_list_.next(item)) { 500 const HloInstruction* const instruction = item->instruction; 501 for (const LogicalBuffer* logical_buffer : 502 points_to_analysis.GetBuffersDefinedByInstruction(instruction)) { 503 Buffer* buffer; 504 if (instruction->opcode() == HloOpcode::kWhile) { 505 // The while instruction defines no new buffers. Instead it reuses the 506 // buffers of its operand. Find the Buffer of its operand at the 507 // proper ShapeIndex. 508 const PointsToSet& operand_points_to = 509 points_to_analysis.GetPointsToSet(instruction->operand(0)); 510 CHECK_EQ(operand_points_to.element(logical_buffer->index()).size(), 1); 511 const LogicalBuffer* source_logical_buffer = 512 operand_points_to.element(logical_buffer->index())[0]; 513 buffer = 514 &buffers_.at(logical_buffer_to_buffer_id.at(source_logical_buffer)); 515 516 // Mark buffer as has indirect use and live out. 517 buffer->has_indirect_uses = true; 518 buffer->live_out = 519 buffer->live_out || ContainsKey(live_out_set, logical_buffer); 520 521 // Add users of while to Buffer users. 522 bool unused; 523 for (Item* user_item : GetUsers(instruction_list_, logical_buffer, 524 points_to_analysis, &unused)) { 525 if (std::find(buffer->users.begin(), buffer->users.end(), 526 user_item) == buffer->users.end()) { 527 buffer->users.push_back(user_item); 528 buffer->unfinished_user_count++; 529 user_item->buffers_used.push_back(buffer->id); 530 } 531 } 532 } else { 533 buffer = &CreateBufferFromLogicalBuffer( 534 logical_buffer, points_to_analysis, size_function, 535 ContainsKey(live_out_set, logical_buffer)); 536 item->buffers_defined.push_back(buffer->id); 537 for (Item* user : buffer->users) { 538 user->buffers_used.push_back(buffer->id); 539 } 540 } 541 542 logical_buffer_to_buffer_id[logical_buffer] = buffer->id; 543 } 544 } 545 XLA_VLOG_LINES(10, ToString()); 546 DCHECK(Check()); 547 } 548 549 Status MemoryUsageTracker::BeginInstruction(Item* item) { 550 const HloInstruction* instruction = item->instruction; 551 VLOG(3) << "BeginInstruction " << instruction->name(); 552 TF_RET_CHECK(in_progress_item_ == nullptr); 553 in_progress_item_ = item; 554 555 item->placed = true; 556 557 // All buffers defined by this instruction need memory. 558 for (BufferId buffer_id : item->buffers_defined) { 559 VLOG(3) << " Buffer " << buffers_.at(buffer_id).ToString() 560 << " is now live."; 561 memory_usage_ += AllocatedSize(buffer_id); 562 } 563 564 // TODO(b/37686934): Elementwise instructions can share the buffer of a (dead) 565 // operand. Account for this potential reuse here. 566 567 VLOG(3) << " memory usage = " << memory_usage_; 568 VLOG(10) << ToString(); 569 570 if (VLOG_IS_ON(1)) { 571 DCHECK(Check()); 572 } 573 return Status::OK(); 574 } 575 576 Status MemoryUsageTracker::EndInstruction() { 577 TF_RET_CHECK(in_progress_item_ != nullptr); 578 VLOG(3) << "EndInstruction " << in_progress_item_->instruction->name(); 579 580 for (BufferId buffer_id : in_progress_item_->buffers_used) { 581 Buffer& buffer = buffers_.at(buffer_id); 582 buffer.unfinished_user_count--; 583 CHECK_GE(buffer.unfinished_user_count, 0) 584 << buffer.ToString() << " has negative unfinished use count."; 585 if (buffer.unfinished_user_count == 0) { 586 // Buffer is now dead. 587 VLOG(3) << " " << buffer.ToString() << " is now dead."; 588 memory_usage_ -= AllocatedSize(buffer_id); 589 CHECK_GE(memory_usage_, 0); 590 } 591 } 592 593 // If any buffer defined by this instruction has no uses, then memory can be 594 // reclaimed immediately. 595 for (BufferId buffer_id : in_progress_item_->buffers_defined) { 596 const Buffer& buffer = buffers_.at(buffer_id); 597 if (buffer.unfinished_user_count == 0) { 598 VLOG(3) << " " << buffer.ToString() << " is immediately dead."; 599 memory_usage_ -= AllocatedSize(buffer_id); 600 CHECK_GE(memory_usage_, 0); 601 } 602 } 603 604 in_progress_item_ = nullptr; 605 606 VLOG(3) << " memory usage = " << memory_usage_; 607 VLOG(10) << ToString(); 608 609 if (VLOG_IS_ON(1)) { 610 DCHECK(Check()); 611 } 612 return Status::OK(); 613 } 614 615 int64 MemoryUsageTracker::MemoryReducedIfRematerialized(Item* item) const { 616 CHECK_NE(in_progress_item_, nullptr); 617 if (!item->placed || item == in_progress_item_) { 618 return 0; 619 } 620 621 // TODO(b/37687140): Rematerialization can increase peak memory consumption at 622 // an earlier point in the program if rematerialization extends the live range 623 // of the operand of the instruction being rematerialized across the live 624 // range of the value of instruction being rematerialized. Don't rematerialize 625 // in this case (ie, return 0 here). 626 627 // Compute the amount of memory reduced (if any) by rematerializing 628 // 'instruction'. The LogicalBuffers defined by 'instruction' will no longer 629 // be live at this program point, so initially set memory_reduced to the 630 // size of its defined values. 631 int64 memory_reduced = 0; 632 for (BufferId buffer_id : item->buffers_defined) { 633 // Avoid rematerializing instructions with indirect uses as it is difficult 634 // to reason about liveness after rematerializing the instruction. 635 // TODO(b/37714814): Consider rematerialzing instructions with indirect 636 // uses. 637 if (buffers_.at(buffer_id).has_indirect_uses) { 638 return 0; 639 } 640 641 if (IsCurrentlyLive(buffer_id) && !IsInUse(buffer_id)) { 642 memory_reduced += AllocatedSize(buffer_id); 643 } 644 } 645 646 // Account for any logical buffers whose live range must be extended across 647 // this program point. 648 for (BufferId buffer_id : item->buffers_used) { 649 if (!IsCurrentlyLive(buffer_id)) { 650 // This logical buffer is used by 'instruction' but is not live at this 651 // program point. Rematerializing 'instruction' will extend the buffer's 652 // live range across this program point. 653 memory_reduced -= AllocatedSize(buffer_id); 654 } 655 } 656 657 return memory_reduced; 658 } 659 660 Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item, 661 Item* remat_item) { 662 VLOG(3) << "AddRematerializedInstruction: original_instruction = " 663 << original_item->instruction->name() 664 << ", remat_instruction = " << remat_item->instruction->name(); 665 666 TF_RET_CHECK(in_progress_item_ != nullptr); 667 TF_RET_CHECK(original_item->placed); 668 TF_RET_CHECK(!remat_item->placed); 669 670 // Construct the list of buffers used and defined by the rematerialization. 671 remat_item->buffers_used = original_item->buffers_used; 672 673 // Account for the additional buffer uses created by the new rematerialization 674 // instruction. Update memory usage if the rematerialization makes a dead 675 // buffer live again. 676 for (BufferId buffer_id : original_item->buffers_used) { 677 Buffer& buffer = buffers_.at(buffer_id); 678 if (buffer.unfinished_user_count == 0) { 679 // Buffer used by this instruction was dead, now is alive. 680 memory_usage_ += AllocatedSize(buffer.id); 681 } 682 683 buffer.unfinished_user_count++; 684 buffer.users.push_back(remat_item); 685 } 686 687 // Create a new set of Buffers defined by the new rematerialization 688 // instruction. Update the internal data structures and memory use to account 689 // for them. 690 for (BufferId old_buffer_id : original_item->buffers_defined) { 691 Buffer& old_buffer = buffers_.at(old_buffer_id); 692 693 ItemList placed_users; 694 ItemList unplaced_users; 695 for (Item* user : old_buffer.users) { 696 if (user->placed) { 697 CHECK(IsFinished(user)); 698 placed_users.push_back(user); 699 } else { 700 unplaced_users.push_back(user); 701 } 702 } 703 old_buffer.users = std::move(placed_users); 704 old_buffer.unfinished_user_count = 0; 705 706 // Buffer is now dead. 707 memory_usage_ -= AllocatedSize(old_buffer.id); 708 709 Buffer& new_buffer = 710 RematerializeBuffer(old_buffer, remat_item, std::move(unplaced_users)); 711 712 remat_item->buffers_defined.push_back(new_buffer.id); 713 for (Item* user : new_buffer.users) { 714 BufferIdList& buffers_used = user->buffers_used; 715 std::replace(buffers_used.begin(), buffers_used.end(), old_buffer_id, 716 new_buffer.id); 717 } 718 } 719 720 VLOG(3) << " memory usage = " << memory_usage_; 721 XLA_VLOG_LINES(10, ToString()); 722 723 DCHECK(Check()); 724 725 return Status::OK(); 726 } 727 728 string MemoryUsageTracker::ToString() const { 729 string output = tensorflow::strings::StrCat("MemoryUsageTracker for ", 730 computation_->name(), "\n"); 731 tensorflow::strings::StrAppend( 732 &output, "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (", 733 memory_usage(), " bytes)"); 734 for (auto* item = instruction_list_.first(); item != nullptr; 735 item = instruction_list_.next(item)) { 736 const HloInstruction* instruction = item->instruction; 737 string inprogress = item == in_progress_item_ ? " in-progress" : ""; 738 string placed = item->placed ? " placed" : ""; 739 tensorflow::strings::StrAppend(&output, " ", instruction->name(), 740 inprogress, placed, "\n Defines:\n"); 741 for (BufferId buffer_id : item->buffers_defined) { 742 const Buffer& buffer = buffers_[buffer_id]; 743 string live = IsCurrentlyLive(buffer_id) ? " live" : ""; 744 tensorflow::strings::StrAppend(&output, " ", buffer.ToString(), live, 745 ", ", buffer.unfinished_user_count, 746 " unfinished uses\n"); 747 } 748 tensorflow::strings::StrAppend(&output, " Uses:\n"); 749 for (BufferId buffer_id : item->buffers_used) { 750 tensorflow::strings::StrAppend(&output, " ", 751 buffers_[buffer_id].ToString(), "\n"); 752 } 753 } 754 return output; 755 } 756 757 bool MemoryUsageTracker::Check() const { 758 auto elements_are_unique = [](const BufferIdList& vec) { 759 return vec.size() == std::set<BufferId>(vec.begin(), vec.end()).size(); 760 }; 761 762 // Verify buffers_defined per instruction. 763 for (auto* instruction : computation_->instructions()) { 764 const BufferIdList& defined_buffers = 765 instruction_list_.GetItem(instruction)->buffers_defined; 766 CHECK(elements_are_unique(defined_buffers)) 767 << "Instruction " << instruction->name() 768 << " does not have unique defined buffers: " 769 << tensorflow::str_util::Join( 770 defined_buffers, ", ", [this](string* out, BufferId buffer_id) { 771 tensorflow::strings::StrAppend( 772 out, buffers_.at(buffer_id).ToString()); 773 }); 774 775 for (const Buffer& buffer : buffers_) { 776 if (buffer.defining_instruction->instruction == instruction) { 777 CHECK(std::find(defined_buffers.begin(), defined_buffers.end(), 778 buffer.id) != defined_buffers.end()) 779 << "Instruction " << instruction->name() 780 << " defined buffers is missing: " << buffer.ToString(); 781 } 782 } 783 } 784 785 // Verify buffers_used per instruction. 786 for (auto* instruction : computation_->instructions()) { 787 const BufferIdList& used_buffers = 788 instruction_list_.GetItem(instruction)->buffers_used; 789 CHECK(elements_are_unique(used_buffers)) 790 << "Instruction " << instruction->name() 791 << " does not have unique used buffers: " 792 << tensorflow::str_util::Join( 793 used_buffers, ", ", [this](string* out, BufferId buffer_id) { 794 tensorflow::strings::StrAppend( 795 out, buffers_.at(buffer_id).ToString()); 796 }); 797 } 798 for (const Buffer& buffer : buffers_) { 799 int64 unfinished_uses = 0; 800 for (Item* user : buffer.users) { 801 const BufferIdList& used_buffers = user->buffers_used; 802 CHECK(std::find(used_buffers.begin(), used_buffers.end(), buffer.id) != 803 used_buffers.end()) 804 << "Instruction " << user->instruction->name() 805 << " used buffers is missing " << buffer.ToString(); 806 if (!IsFinished(user)) { 807 unfinished_uses++; 808 } 809 } 810 CHECK_EQ(buffer.unfinished_user_count, unfinished_uses) 811 << "Incorrect unplaced use count for " << buffer.ToString(); 812 } 813 return true; 814 } 815 816 // Computes and returns the cost of rematerializing the given instruction. 817 // Cost per rematerialized instruction is defined as: 818 // 819 // memory_limit_bytes / memory_reduced 820 // 821 // The idea is to choose the operation that will save the most memory for 822 // rematerialization and do not worry about how much the compute costs since 823 // running out of memory is more harmful than taking longer to get the answer. 824 int64 RematerializationCost(const HloInstruction* instruction, 825 const MemoryUsageTracker& memory_tracker, 826 int64 memory_reduced, int64 memory_limit_bytes) { 827 // If none of the users of 'instruction' have been placed in the sequence (as 828 // tracked by memory_tracker), then rematerialization of 'instruction' is a 829 // zero-cost move of 'instruction' in the sequence. 830 if (!std::any_of(instruction->users().begin(), instruction->users().end(), 831 [&memory_tracker](const HloInstruction* inst) { 832 return memory_tracker.IsPlaced(inst); 833 })) { 834 return 0; 835 } 836 837 CHECK_GT(memory_reduced, 0); 838 // Return the inverse of the benefit of rematerialization. 839 return memory_limit_bytes / memory_reduced; 840 } 841 842 // Selects and returns the best candidate instruction for rematerialization. 843 // The instruction with lowest rematerialization cost is selected among those 844 // candidate which reduce memory use at the program point of the current 845 // instruction as indicated by memory_tracker. nullptr is returned if no 846 // candidate can be found. 847 Item* PickRematerializationCandidate(const MemoryUsageTracker& memory_tracker, 848 const InstructionList& instruction_list, 849 int64 memory_limit_bytes) { 850 Item* best_item = nullptr; 851 int64 best_cost = 0; 852 853 // TODO(b/35244891): This is currently quadratic in the number of HLO 854 // instructions. 855 for (auto* item = instruction_list.first(); item != nullptr; 856 item = instruction_list.next(item)) { 857 if (!item->placed) { 858 // Only iterate up to the currently placed instruction. 859 // We are trying to reduce memory usage at the placed 860 // instruction so rematerializing later values is of no benefit. 861 break; 862 } 863 HloInstruction* candidate = item->instruction; 864 VLOG(5) << "considering rematerialization candidate " << candidate->name(); 865 866 if (item->blacklisted) { 867 // Skip instructions on the blacklist to avoid infinite loops of 868 // rematerializing the same instruction(s) repeatedly. 869 VLOG(5) << "candidate " << candidate->name() 870 << " is excluded from rematerialization"; 871 continue; 872 } 873 874 if (!IsRematerializable(candidate)) { 875 VLOG(5) << "candidate " << candidate->name() 876 << " not viable: is not rematerializable"; 877 continue; 878 } 879 880 // If any of the candidate's control successor has been placed, we need to 881 // skip this candidate. Otherwise we will violate control dependency. 882 bool control_successor_placed = 883 std::any_of(candidate->control_successors().begin(), 884 candidate->control_successors().end(), 885 [&memory_tracker](const HloInstruction* inst) { 886 return memory_tracker.IsPlaced(inst); 887 }); 888 889 if (control_successor_placed) { 890 continue; 891 } 892 893 const int64 memory_reduced = 894 memory_tracker.MemoryReducedIfRematerialized(item); 895 896 if (memory_reduced <= 0) { 897 VLOG(5) << "candidate " << candidate->name() 898 << " memory reduced = " << memory_reduced << " <= 0"; 899 continue; 900 } 901 902 const int cost = RematerializationCost(candidate, memory_tracker, 903 memory_reduced, memory_limit_bytes); 904 905 VLOG(5) << "candidate " << candidate->name() << ", memory reduced " 906 << memory_reduced << ", cost per byte " << cost; 907 908 if (best_item == nullptr || cost < best_cost) { 909 VLOG(5) << "candidate " << candidate->name() << " now best"; 910 best_item = item; 911 best_cost = cost; 912 } 913 } 914 return best_item; 915 } 916 917 } // namespace 918 919 StatusOr<int64> HloRematerialization::ComputePeakMemory( 920 const HloComputation* computation, 921 const std::vector<const HloInstruction*>& order) const { 922 InstructionList instruction_list(order); 923 MemoryUsageTracker tracker(computation, size_function_, *points_to_analysis_, 924 instruction_list); 925 int64 peak_memory = tracker.memory_usage(); 926 for (auto* item = instruction_list.first(); item != nullptr; 927 item = instruction_list.next(item)) { 928 const HloInstruction* instruction = item->instruction; 929 TF_RETURN_IF_ERROR(tracker.BeginInstruction(item)); 930 TF_ASSIGN_OR_RETURN(int64 callee_usage, 931 CalledComputationsMemoryUsage(instruction)); 932 peak_memory = 933 std::max<int64>(peak_memory, tracker.memory_usage() + callee_usage); 934 TF_RETURN_IF_ERROR(tracker.EndInstruction()); 935 } 936 VLOG(1) << "Peak memory for " << computation->name() << ": " 937 << HumanReadableNumBytes(peak_memory); 938 return peak_memory; 939 } 940 941 StatusOr<int64> HloRematerialization::CalledComputationsMemoryUsage( 942 const HloInstruction* instruction) const { 943 const CallSite* callsite = 944 call_graph_->GetNode(instruction->parent()).GetCallSite(instruction); 945 if (callsite == nullptr || callsite->context() == CallContext::kParallel) { 946 return 0; 947 } 948 int64 callee_usage = 0; 949 for (const HloComputation* computation : callsite->called_computations()) { 950 TF_RET_CHECK(ContainsKey(computation_peak_memory_, computation)); 951 callee_usage += computation_peak_memory_.at(computation); 952 } 953 return callee_usage; 954 } 955 956 StatusOr<bool> HloRematerialization::RematerializeComputation( 957 HloComputation* computation, 958 SequentialHloOrdering::HloModuleSequence* sequence, 959 int64 memory_limit_bytes) { 960 VLOG(1) << "Rematerializing computation " << computation->name() 961 << " with limit " << HumanReadableNumBytes(memory_limit_bytes); 962 VLOG(1) << "peak memory usage is " 963 << HumanReadableNumBytes(computation_peak_memory_.at(computation)); 964 CHECK(!ContainsKey(rematerialized_computations_, computation)); 965 966 InstructionList instruction_list(sequence->at(computation)); 967 MemoryUsageTracker memory_tracker(computation, size_function_, 968 *points_to_analysis_, instruction_list); 969 bool changed = false; 970 971 // If the rematerialization makes the source instruction dead, then the 972 // rematerialization is added to 'remat_move_instructions' (the 973 // rematerialization is essentially a move). If the next rematerialization of 974 // the instruction is also a move then the rematerialization is added to the 975 // blacklist. 976 tensorflow::gtl::FlatSet<const HloInstruction*> remat_move_instructions; 977 978 // The peak memory of the computation at any point in the instruction 979 // sequence. 980 int64 peak_memory = memory_tracker.memory_usage(); 981 982 // Total count of instructions rematerialized. 983 int64 remat_count = 0; 984 // Total count of clones created minus number of original rematerialized 985 // instructions which are dead. 986 int64 net_instructions_added = 0; 987 988 const CallGraphNode& call_graph_node = call_graph_->GetNode(computation); 989 990 // Iterate through all instructions in the sequence. At each instruction 991 // (program point) if memory_usage exceeds the specified limit then 992 // rematerialize HLO instructions until memory_usage is reduced. 993 int64 instruction_index = 0; 994 for (auto* item = instruction_list.first(); item != nullptr; 995 item = instruction_list.next(item)) { 996 const HloInstruction* instruction = item->instruction; 997 TF_ASSIGN_OR_RETURN(int64 callee_usage, 998 CalledComputationsMemoryUsage(instruction)); 999 TF_RETURN_IF_ERROR(memory_tracker.BeginInstruction(item)); 1000 1001 VLOG(2) << "Program point at " << instruction->name() 1002 << ", memory usage = " << memory_tracker.memory_usage() 1003 << ", callee usage = " << callee_usage << ", [" << instruction_index 1004 << "/" << instruction_list.size() << "]"; 1005 instruction_index++; 1006 1007 while (memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) { 1008 VLOG(2) << "Over memory limit at instruction " << instruction->name() 1009 << ", using " 1010 << HumanReadableNumBytes(memory_tracker.memory_usage() + 1011 callee_usage) 1012 << ", limit is " << HumanReadableNumBytes(memory_limit_bytes); 1013 1014 Item* best_item = PickRematerializationCandidate( 1015 memory_tracker, instruction_list, memory_limit_bytes); 1016 1017 if (best_item == nullptr) { 1018 VLOG(3) << "Unable to find rematerialization candidate at program " 1019 "point " 1020 << instruction->name() << ". Memory usage = " 1021 << HumanReadableNumBytes(memory_tracker.memory_usage() + 1022 callee_usage); 1023 break; 1024 } 1025 1026 HloInstruction* best = best_item->instruction; 1027 VLOG(1) << "Rematerializing instruction " << best->name() << " (saving " 1028 << HumanReadableNumBytes( 1029 memory_tracker.MemoryReducedIfRematerialized(best_item)) 1030 << ")"; 1031 changed = true; 1032 remat_count++; 1033 1034 HloInstruction* remat = 1035 computation->AddInstruction(best->Clone(/*suffix=*/"remat")); 1036 1037 // Add control dependencies to the new operation. 1038 for (auto successor : best->control_successors()) { 1039 TF_RETURN_IF_ERROR(remat->AddControlDependencyTo(successor)); 1040 } 1041 for (auto predecessor : best->control_predecessors()) { 1042 TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(remat)); 1043 } 1044 1045 Item* remat_item = instruction_list.CreateItem(remat); 1046 1047 // Replace each remaining use of 'best' with the rematerialization. 1048 std::vector<HloInstruction*> best_users_copy = best->users(); 1049 for (HloInstruction* user : best_users_copy) { 1050 if (!memory_tracker.IsPlaced(user)) { 1051 VLOG(2) << " Replacing use of " << best->name() << " in " 1052 << user->name() << " with " << remat->name(); 1053 TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, remat)); 1054 } 1055 } 1056 1057 // Account for the rematerialization in the memory tracker. 1058 TF_RETURN_IF_ERROR( 1059 memory_tracker.AddRematerializedInstruction(best_item, remat_item)); 1060 1061 // Insert rematerialized instruction right before the earliest unplaced 1062 // use of the instruction *and* the earliest unplaced last use of any 1063 // operands of remat. Unplaced uses of the remat's operands are included 1064 // because we don't want to extend the live range of remat's operands as 1065 // this could increase memory usage. 1066 ItemList place_before; 1067 for (auto user : remat->users()) { 1068 place_before.push_back(instruction_list.GetItem(user)); 1069 } 1070 for (auto* operand : remat->operands()) { 1071 for (auto* operand_user : operand->users()) { 1072 if (operand_user != remat) { 1073 Item* operand_user_item = instruction_list.GetItem(operand_user); 1074 if (!operand_user_item->placed) { 1075 place_before.push_back(operand_user_item); 1076 } 1077 } 1078 } 1079 } 1080 // Insert rematerialized instruction before any of its successors to 1081 // preserve ordering regarding control dependency. 1082 for (auto successor : remat->control_successors()) { 1083 Item* successor_item = instruction_list.GetItem(successor); 1084 // Assert to make sure we never remat an operation with control 1085 // successor already placed. 1086 CHECK(!successor_item->placed); 1087 place_before.push_back(successor_item); 1088 } 1089 instruction_list.InsertBeforeInstructions(remat_item, place_before); 1090 1091 // If the rematerialized instruction is dead then rematerialization is 1092 // essentially a move. Don't delete the instruction now because we don't 1093 // want duplicate HloInstruction* values during the course of the 1094 // transformation because we keep maps with HloInstruction* values as 1095 // keys. 1096 if (best->users().empty()) { 1097 VLOG(2) << best->name() << " is now dead"; 1098 if (ContainsKey(remat_move_instructions, best)) { 1099 // Previously, 'best' was a rematerialization which killed the 1100 // instruction it was a copying of. Now 'remat' is a rematerialization 1101 // of 'best' and kills 'best'. Stop rematerializing this instruction 1102 // to avoid an infinite loop. 1103 instruction_list.Blacklist(remat); 1104 } 1105 remat_move_instructions.insert(remat); 1106 } else { 1107 net_instructions_added++; 1108 } 1109 1110 VLOG(1) << "memory_usage after rematerialization = " 1111 << HumanReadableNumBytes(memory_tracker.memory_usage()); 1112 } 1113 1114 const CallSite* callsite = call_graph_node.GetCallSite(instruction); 1115 if (callsite != nullptr && 1116 callsite->context() == CallContext::kSequential && 1117 memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) { 1118 // Memory usage exceeds the limit. Try to rematerialize any 1119 // subcomputation(s) that this instruction calls. 1120 VLOG(1) << "Memory usage still over the limit (" 1121 << (memory_tracker.memory_usage() + callee_usage) << " > " 1122 << memory_limit_bytes 1123 << "). Rematerializing computations called by " 1124 << instruction->name(); 1125 1126 // Recompute callee usage to account for any rematerialization performed 1127 // in the callee computations. 1128 for (HloComputation* called_computation : 1129 callsite->called_computations()) { 1130 if (!ContainsKey(rematerialized_computations_, called_computation)) { 1131 // Memory limit for the subcomputation is the memory limit less the 1132 // amount of memory used at this point in the computation. 1133 int64 subcomputation_memory_limit_bytes = std::max<int64>( 1134 0, memory_limit_bytes - memory_tracker.memory_usage()); 1135 TF_ASSIGN_OR_RETURN( 1136 bool subcomputation_changed, 1137 RematerializeComputation(called_computation, sequence, 1138 subcomputation_memory_limit_bytes)); 1139 changed |= subcomputation_changed; 1140 } 1141 } 1142 TF_ASSIGN_OR_RETURN(callee_usage, 1143 CalledComputationsMemoryUsage(instruction)); 1144 } 1145 1146 peak_memory = std::max<int64>(peak_memory, 1147 memory_tracker.memory_usage() + callee_usage); 1148 VLOG(3) << "peak memory usage = " << HumanReadableNumBytes(peak_memory); 1149 1150 TF_RETURN_IF_ERROR(memory_tracker.EndInstruction()); 1151 } 1152 1153 // Verify some invariants on the memory tracker. 1154 CHECK_EQ(memory_tracker.memory_usage(), 0); 1155 for (auto* instruction : computation->instructions()) { 1156 CHECK(memory_tracker.IsPlaced(instruction)); 1157 } 1158 1159 VLOG(1) << "In computation " << computation->name() << " rematerialized " 1160 << remat_count << " instructions; " << net_instructions_added 1161 << " net instructions added"; 1162 VLOG(1) << " peak memory usage now " << HumanReadableNumBytes(peak_memory) 1163 << " (was " 1164 << HumanReadableNumBytes(computation_peak_memory_.at(computation)) 1165 << ")"; 1166 1167 // Update peak memory used by computation. 1168 computation_peak_memory_.at(computation) = peak_memory; 1169 1170 // Update order to include rematerialized instructions. 1171 auto& dst = sequence->at(computation); 1172 dst.clear(); 1173 for (auto* item = instruction_list.first(); item != nullptr; 1174 item = instruction_list.next(item)) { 1175 const HloInstruction* instruction = item->instruction; 1176 dst.push_back(instruction); 1177 } 1178 rematerialized_computations_.insert(computation); 1179 1180 instructions_rematerialized_ += remat_count; 1181 net_instructions_added_ += net_instructions_added; 1182 1183 return changed; 1184 } 1185 1186 StatusOr<bool> HloRematerialization::Run( 1187 HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence, 1188 int64 memory_limit_bytes, RematerializationSizes* sizes) { 1189 // The sequence is constructed entirely by this method. 1190 TF_RET_CHECK(sequence->empty()); 1191 1192 VLOG(1) << "HloRematerialization() with memory limit of " 1193 << HumanReadableNumBytes(memory_limit_bytes); 1194 1195 TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module)); 1196 1197 // Adjust memory limit to account for the output of the entry 1198 // computation. This is necessary because the per-computation accounting in 1199 // MemoryUsageTracker do not include output as these are typically allocated 1200 // by the caller. 1201 int64 module_output_size = 0; 1202 ShapeUtil::ForEachSubshape( 1203 module->entry_computation()->root_instruction()->shape(), 1204 [&module_output_size, this](const Shape& subshape, 1205 const ShapeIndex& /*index*/) { 1206 module_output_size += size_function_(subshape); 1207 }); 1208 1209 const int64 adjusted_memory_limit_bytes = 1210 memory_limit_bytes - module_output_size; 1211 VLOG(1) << "Adjusted memory limit accounting for output (" 1212 << HumanReadableNumBytes(module_output_size) 1213 << "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes); 1214 1215 XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); 1216 // Create initial sequence of HLO instructions. 1217 TF_ASSIGN_OR_RETURN(*sequence, CreateMemoryMinimizingSequence( 1218 *module, 1219 [this](const LogicalBuffer& buffer) { 1220 return size_function_(buffer.shape()); 1221 }, 1222 scheduler_algorithm_)); 1223 // Compute peak memory usage of all computations in the module called in a 1224 // sequential context. 1225 call_graph_ = CallGraph::Build(module); 1226 TF_RETURN_IF_ERROR(call_graph_->VisitNodes( 1227 [this, sequence](const CallGraphNode& node) -> Status { 1228 if (node.context() == CallContext::kSequential) { 1229 TF_ASSIGN_OR_RETURN( 1230 computation_peak_memory_[node.computation()], 1231 ComputePeakMemory(node.computation(), 1232 sequence->at(node.computation()))); 1233 } 1234 return Status::OK(); 1235 }, 1236 /*visit_unreachable_nodes=*/false)); 1237 1238 // The peak memory usage of the module equals the peak memory use of the entry 1239 // computation plus the output size of the computation. This is because the 1240 // peak memory for a computation does not include the output as this is 1241 // typically accounted for in the caller. 1242 const int64 before_peak_memory = 1243 computation_peak_memory_.at(module->entry_computation()) + 1244 module_output_size; 1245 VLOG(1) << "Peak memory usage of module (before): " 1246 << HumanReadableNumBytes(before_peak_memory); 1247 1248 // Subcomputations called by the entry computation will also be 1249 // rematerialized. 1250 TF_ASSIGN_OR_RETURN(bool changed, RematerializeComputation( 1251 module->entry_computation(), sequence, 1252 adjusted_memory_limit_bytes)); 1253 1254 // Rematerialization can introduce dead code. This occurs if all uses of an 1255 // instruction are replaced with rematerializations of the instruction. 1256 TF_ASSIGN_OR_RETURN(bool dead_code_removed, HloDCE().Run(module)); 1257 changed |= dead_code_removed; 1258 1259 // After DCE, the module sequence may include instructions which no longer 1260 // exist. 1261 for (const auto* computation : module->MakeNonfusionComputations()) { 1262 if (sequence->at(computation).size() != computation->instruction_count()) { 1263 // A size mismatch between the computation instruction count and the size 1264 // of the ordering of instructions can only be caused by DCE. Rebuild the 1265 // order by removing the deleted instructions from the order. 1266 tensorflow::gtl::FlatSet<const HloInstruction*> instruction_set; 1267 for (const auto& instruction : computation->instructions()) { 1268 instruction_set.insert(instruction); 1269 } 1270 // Move the old order into a temporary vector, then build new order 1271 // inplace. 1272 std::vector<const HloInstruction*>& order = sequence->at(computation); 1273 std::vector<const HloInstruction*> old_order; 1274 using std::swap; 1275 swap(order, old_order); 1276 std::copy_if(old_order.begin(), old_order.end(), 1277 std::back_inserter(order), 1278 [&instruction_set](const HloInstruction* instruction) { 1279 return ContainsKey(instruction_set, instruction); 1280 }); 1281 TF_RET_CHECK(sequence->at(computation).size() == 1282 computation->instruction_count()); 1283 } 1284 } 1285 VLOG(1) << "Rematerialized " << instructions_rematerialized_ 1286 << " instructions in module " << module->name() << "; " 1287 << net_instructions_added_ << " net instructions added"; 1288 const int64 current_peak_memory = 1289 computation_peak_memory_.at(module->entry_computation()) + 1290 module_output_size; 1291 VLOG(1) << "Peak memory usage of module now " 1292 << HumanReadableNumBytes(current_peak_memory) << " (" 1293 << current_peak_memory << " bytes), was " 1294 << HumanReadableNumBytes(before_peak_memory) << " (" 1295 << before_peak_memory << " bytes)"; 1296 const int64 reduced_peak_memory = before_peak_memory - current_peak_memory; 1297 VLOG(1) << "Reduced peak memory by " 1298 << HumanReadableNumBytes(reduced_peak_memory) << " (" 1299 << reduced_peak_memory << " bytes)"; 1300 1301 if (sizes != nullptr) { 1302 sizes->before_bytes = before_peak_memory; 1303 sizes->after_bytes = current_peak_memory; 1304 } 1305 1306 XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString()); 1307 1308 if (current_peak_memory > memory_limit_bytes) { 1309 LOG(WARNING) << tensorflow::strings::Printf( 1310 "Can't reduce memory use below %s (%lld bytes) by rematerialization; " 1311 "only reduced to %s (%lld bytes)", 1312 HumanReadableNumBytes(memory_limit_bytes).c_str(), memory_limit_bytes, 1313 HumanReadableNumBytes(current_peak_memory).c_str(), 1314 current_peak_memory); 1315 } 1316 1317 return changed; 1318 } 1319 1320 /* static */ StatusOr<bool> HloRematerialization::RematerializeAndSchedule( 1321 const HloRematerialization::ShapeSizeFunction& size_function, 1322 int64 memory_limit_bytes, HloModule* hlo_module, 1323 SchedulerAlgorithm scheduler_algorithm, 1324 SequentialHloOrdering::HloModuleSequence* sequence, 1325 RematerializationSizes* sizes) { 1326 HloRematerialization remat(scheduler_algorithm, size_function); 1327 return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes); 1328 } 1329 1330 } // namespace xla 1331