/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_rematerialization.h"

#include <algorithm>
#include <memory>
#include <set>
#include <string>

#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/service/liveness_util.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"

using ::tensorflow::strings::HumanReadableNumBytes;

namespace xla {

namespace {

// Potential optimizations:
// . TODO(b/35244891): Avoid N^2 behavior by keeping a priority queue
//   of candidates.
// . Cache IsRematerializable in Item?  Only correct if control
//   predecessors and successors don't change.

// Returns true if the given instruction is rematerializable.
bool IsRematerializable(const HloInstruction* instruction) {
  // Don't rematerialize instructions with side effects or instructions which
  // cannot be cloned safely.
  switch (instruction->opcode()) {
    case HloOpcode::kCall:
    case HloOpcode::kConstant:
    case HloOpcode::kConditional:
    case HloOpcode::kCrossReplicaSum:
    case HloOpcode::kCustomCall:
    case HloOpcode::kParameter:
    case HloOpcode::kWhile:
      return false;
    default:
      return !instruction->HasSideEffect();
  }
}
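
// Illustrative example (not part of the original pass): an elementwise add
// has no side effects and is not in the opcode list above, so it is
// rematerializable, whereas a parameter or while is not:
//
//   IsRematerializable(add_instruction)        // true
//   IsRematerializable(parameter_instruction)  // false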

// Type holding a unique identifier for each Buffer object.
using BufferId = int64;
using BufferIdList = tensorflow::gtl::InlinedVector<BufferId, 3>;

// We wrap HloInstruction* with an Item that holds auxiliary
// per-instruction state.
struct Item {
  HloInstruction* instruction;

  // True once the instruction is marked as placed (when BeginInstruction
  // has been called for this instruction).
  bool placed = false;

  // To avoid rematerializing the same set of instructions ad infinitum,
  // keep a blacklist of instructions which should not be rematerialized.
  bool blacklisted = false;

  // The buffers defined by this instruction.
  BufferIdList buffers_defined;

  // The buffers used by this instruction.
  BufferIdList buffers_used;

 private:
  friend class InstructionList;

  // Items are arranged in a doubly linked list.
  Item* next;
  Item* prev;

  // The list is ordered by position; positions may be duplicated as new
  // instructions are inserted.  See the InsertBeforeInstructions comment
  // for details.
  int64 position;
};

using ItemList = tensorflow::gtl::InlinedVector<Item*, 3>;

// Class which maintains an ordered list of instructions with fast insertion
// before arbitrary elements.
class InstructionList {
 public:
  explicit InstructionList(const std::vector<const HloInstruction*>& order) {
    int64 position = 0;
    Item* last = nullptr;
    for (const HloInstruction* inst : order) {
      // Add a new item to the linked list.
      Item* item = new Item;
      item->next = nullptr;
      item->prev = last;
      if (last == nullptr) {
        first_ = item;
      } else {
        last->next = item;
      }
      last = item;

      // Initially position numbers are uniquely assigned in order. Later, as
      // instructions are added with InsertBefore* methods, some instructions
      // may have duplicate position numbers, but the values are guaranteed to
      // be monotonically non-decreasing through the list, so they are still
      // useful for quickly(-ish) determining the order of arbitrary
      // instructions in the list.
      item->instruction = const_cast<HloInstruction*>(inst);
      item->position = position;
      position++;

      item_map_[inst] = item;
    }
  }

  ~InstructionList() {
    for (Item* item = first_; item != nullptr;) {
      Item* next = item->next;
      delete item;
      item = next;
    }
  }

  size_t size() const { return item_map_.size(); }

  // For ordered iteration over items.
  //    for (auto item = q.first(); item != nullptr; item = q.next(item)) {...}
  Item* first() const { return first_; }
  Item* next(Item* item) const { return item->next; }

  // Creates an Item for the given instruction, but doesn't add it to the list.
  // (Use InsertBeforeInstructions to add the Item to the list.)
  Item* CreateItem(HloInstruction* inst) {
    Item* item = new Item;
    item->instruction = inst;
    CHECK(item_map_.insert({inst, item}).second) << "inserting inst twice";
    return item;
  }

  // Return the Item corresponding to inst.
  Item* GetItem(const HloInstruction* inst) const {
    auto iter = item_map_.find(inst);
    CHECK(iter != item_map_.end()) << "Did not find " << inst->name();
    return iter->second;
  }

  // Insert instruction 'to_insert' immediately before the earliest instruction
  // in 'before_instructions'.
  //
  // Each instruction gets a non-decreasing ordinal number. We use this to let
  // InsertBeforeInstructions quickly insert an instruction before the earliest
  // instruction in a set of instructions.  If position_number_[a] <
  // position_number_[b] then 'a' comes before 'b' in the list. If the position
  // numbers are the same then nothing can be said about their order without
  // examining the list.
  //
  // On object construction this ordinal is precisely the instruction's index
  // in the list. Later, instructions inserted via InsertBefore receive
  // duplicate values. However, monotonicity is preserved.
  void InsertBeforeInstructions(
      Item* to_insert, tensorflow::gtl::ArraySlice<Item*> before_instructions) {
    VLOG(3) << "InsertBeforeInstructions: " << to_insert->instruction->name()
            << " before {"
            << tensorflow::str_util::Join(before_instructions, ", ",
                                          [](string* out, Item* item) {
                                            tensorflow::strings::StrAppend(
                                                out, item->instruction->name());
                                          })
            << "}";

    // Find the minimal position number of any instruction in
    // 'before_instructions'.
    CHECK(!before_instructions.empty());
    Item* min_position_item = nullptr;
    for (Item* item : before_instructions) {
      if (min_position_item == nullptr ||
          item->position < min_position_item->position) {
        min_position_item = item;
      }
    }

    // Because more than one instruction in 'before_instructions' may have the
    // minimal position number, scan backwards to the first instruction with
    // that position number.
    while (min_position_item->prev != nullptr &&
           min_position_item->position == min_position_item->prev->position) {
      min_position_item = min_position_item->prev;
    }

    // Now scan forwards until we find one of the before_instructions.
    while (std::find(before_instructions.begin(), before_instructions.end(),
                     min_position_item) == before_instructions.end()) {
      min_position_item = min_position_item->next;
    }
    return InsertBefore(to_insert, min_position_item);
  }
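
  // Illustrative example of the position-number scheme (not from the original
  // source): if the list initially holds a, b, c with positions 0, 1, 2, then
  // inserting x before {b, c} picks b (the minimal position, 1) and assigns x
  // the same position:
  //
  //   a(0), x(1), b(1), c(2)
  //
  // Positions stay monotonically non-decreasing, so comparing positions
  // remains a valid (though no longer strict) order test.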

  void Blacklist(const HloInstruction* inst) {
    GetItem(inst)->blacklisted = true;
  }

 private:
  // Insert instruction 'item' immediately before 'before' in the list.
  void InsertBefore(Item* item, Item* before) {
    VLOG(3) << "InsertBefore: " << item->instruction->name() << " before "
            << before->instruction->name();
    // Insert new item into linked list.
    item->prev = before->prev;
    item->next = before;
    before->prev = item;
    if (item->prev != nullptr) {
      item->prev->next = item;
    } else {
      first_ = item;
    }

    // Assign the same position number to the newly added instruction as
    // 'before'. This guarantees monotonicity of the position numbers, but not
    // uniqueness.
    item->position = before->position;
  }

  Item* first_;

  // Item for each instruction.
  tensorflow::gtl::FlatMap<const HloInstruction*, Item*> item_map_;
};

// Return the items which use the given LogicalBuffer. Sets
// has_indirect_users to whether any of the uses is indirect. A use is indirect
// if the instruction defining logical_buffer is not an operand of the use. This
// can happen via buffer aliasing (eg, tuples).
ItemList GetUsers(const InstructionList& instruction_list,
                  const LogicalBuffer* logical_buffer,
                  const TuplePointsToAnalysis& points_to_analysis,
                  bool* has_indirect_users) {
  ItemList users;
  // To identify uses iterate through all HloInstruction users of the
  // BufferAliases of the logical buffer.
  *has_indirect_users = false;
  for (const BufferAlias& buffer_alias :
       points_to_analysis.GetBufferAliases(*logical_buffer)) {
    for (const HloInstruction* user : buffer_alias.instruction()->users()) {
      if (DoesNotUseOperandBuffer(buffer_alias.instruction(),
                                  buffer_alias.index(), user,
                                  points_to_analysis)) {
        // The alias may be an operand of 'user', but the LogicalBuffer cannot
        // possibly be used by the instruction so ignore 'user'. This is the
        // case, for example, for the tuple element buffers in a GetTupleElement
        // instruction (the GTE instruction only uses the pointer vector).
        continue;
      }
      if (buffer_alias.instruction() != logical_buffer->instruction()) {
        *has_indirect_users = true;
      }
      // A buffer may be used by the instruction via more than one alias. For
      // example, a buffer which appears in more than one element of a tuple.
      Item* user_item = instruction_list.GetItem(user);
      if (std::find(users.begin(), users.end(), user_item) == users.end()) {
        users.push_back(user_item);
      }
    }
  }
  return users;
}
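
// Illustrative example of an indirect use (hypothetical HLO, not from the
// original source):
//
//   %t   = tuple(%a, %b)
//   %gte = get-tuple-element(%t), index=0
//   %u   = add(%gte, %c)
//
// %u uses the buffer defined by %a through the aliases created by %t and
// %gte, but %a is not an operand of %u, so %a's buffer is marked as having
// indirect users.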

// Class for tracking memory usage of a computation as the instructions are
// placed sequentially. Memory usage is the sum of the sizes of live values
// (LogicalBuffers) at the current point in the instruction sequence.
class MemoryUsageTracker {
 public:
  MemoryUsageTracker(
      const HloComputation* computation,
      const HloRematerialization::ShapeSizeFunction& size_function,
      const TuplePointsToAnalysis& points_to_analysis,
      const InstructionList& instruction_list);

  // Starts the placement of the given instruction. This adds the sizes of the
  // LogicalBuffers defined by the instruction to the current memory
  // usage. Placement is broken into two steps (BeginInstruction and
  // EndInstruction) to accurately model memory usage. At BeginInstruction the
  // memory for the output value(s) of the current instruction is allocated. At
  // EndInstruction memory for dead operand(s) is freed.
  Status BeginInstruction(Item* item);

  // Finishes the placement of the current instruction. This frees any dead
  // operands or dead result of the instruction. This must be called after
  // each call to BeginInstruction.
  Status EndInstruction();

  // Returns the number of bytes that the current memory usage will be reduced
  // if the given instruction is rematerialized.
  int64 MemoryReducedIfRematerialized(Item* item) const;

  // Adjusts memory usage to account for the rematerialization of
  // original_item for all remaining unplaced uses. The rematerialization
  // is remat_item. This method should be called after the HLO graph has
  // been transformed (rematerialization instruction created and connected to
  // uses).
  Status AddRematerializedInstruction(Item* original_item, Item* remat_item);

  // Returns whether the given instruction has been placed (BeginInstruction
  // has been called with 'instruction' as the argument).
  bool IsPlaced(const HloInstruction* instruction) const {
    return instruction_list_.GetItem(instruction)->placed;
  }

  // Returns the current memory usage. This is the sum of sizes of all live
  // values.
  int64 memory_usage() const { return memory_usage_; }

  // Check invariants of the data structure. This is expensive to call.
  bool Check() const;

  string ToString() const;

 private:
  // A Buffer represents a single LogicalBuffer in the computation including
  // various metadata useful for tracking liveness of the value. A LogicalBuffer
  // is not used directly because the HLO graph is transformed and
  // TuplePointsToAnalysis which owns all LogicalBuffers cannot be updated after
  // HLO graph transformations.
  struct Buffer {
    // The unique id of this Buffer. This value is equal to the buffer's index
    // in the vector buffers_.
    const BufferId id;

    // The instruction which defines this buffer.
    Item* defining_instruction;

    // The materialized size of the buffer in bytes.
    const int64 size;

    // Whether this buffer is live-out of the computation.
    bool live_out;

    // Whether this buffer has indirect uses. Ie, an instruction which is not a
    // user of defining_instruction uses this buffer. This can occur due to
    // buffer aliasing (eg, tuples).
    bool has_indirect_uses;

    // The instructions which use this buffer.
    ItemList users;

    // The number of users (HloInstructions) of this buffer which have not yet
    // been placed in the sequence.
    int64 unfinished_user_count;

    string ToString() const {
      return tensorflow::strings::StrCat(
          "Buffer ", id, " (defined by ",
          defining_instruction->instruction->name(), ", size ", size,
          " bytes)");
    }
  };

  // Creates a Buffer representing the given logical buffer. The buffer is added
  // to buffers_ and a reference is returned.
  Buffer& CreateBufferFromLogicalBuffer(
      const LogicalBuffer* logical_buffer,
      const TuplePointsToAnalysis& points_to_analysis,
      const HloRematerialization::ShapeSizeFunction& size_function,
      bool live_out) {
    bool has_indirect_uses = false;
    ItemList users = GetUsers(instruction_list_, logical_buffer,
                              points_to_analysis, &has_indirect_uses);
    return NewBuffer(instruction_list_.GetItem(logical_buffer->instruction()),
                     size_function(logical_buffer->shape()), std::move(users),
                     live_out, has_indirect_uses);
  }

  // Create a new buffer representing a rematerialization of the given buffer
  // for the given uses.
  Buffer& RematerializeBuffer(const Buffer& original_buffer, Item* remat_item,
                              ItemList&& rematerialized_uses) {
    CHECK(original_buffer.defining_instruction->placed);
    CHECK(!original_buffer.has_indirect_uses);
    CHECK(!original_buffer.live_out);
    for (Item* use : rematerialized_uses) {
      CHECK(!use->placed);
    }
    return NewBuffer(remat_item, original_buffer.size,
                     std::move(rematerialized_uses), /*live_out=*/false,
                     /*has_indirect_uses=*/false);
  }

  // Return number of bytes allocated for the buffer with the given id. Buffers
  // allocated by the calling computation (eg, parameter and output buffers) are
  // considered to have zero bytes because the memory is accounted for in a
  // different computation.
  int64 AllocatedSize(BufferId buffer_id) const {
    const Buffer& buffer = buffers_.at(buffer_id);
    HloOpcode def_opcode = buffer.defining_instruction->instruction->opcode();
    if (buffer.live_out || def_opcode == HloOpcode::kParameter) {
      return 0;
    } else {
      return buffer.size;
    }
  }
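
  // Illustrative example (assumed sizes, not from the original source): a
  // parameter-defined or live-out buffer contributes 0 bytes here because its
  // memory is charged to the calling computation, while a 4 MB temporary
  // defined inside the computation contributes the full 4 MB.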

  // Returns true if BeginInstruction and EndInstruction have been called for
  // the given instruction.
  bool IsFinished(Item* item) const {
    return item->placed && item != in_progress_item_;
  }

  // Returns whether the given buffer is being used by the in-progress
  // instruction.
  bool IsInUse(BufferId buffer_id) const {
    if (in_progress_item_ == nullptr) {
      return false;
    }
    const BufferIdList& in_progress_uses = in_progress_item_->buffers_used;
    return std::find(in_progress_uses.begin(), in_progress_uses.end(),
                     buffer_id) != in_progress_uses.end();
  }

  // Returns whether the given buffer is live at the current program
  // point.
  bool IsCurrentlyLive(BufferId buffer_id) const {
    const Buffer& buffer = buffers_[buffer_id];
    return (buffer.defining_instruction->placed &&
            buffer.unfinished_user_count > 0);
  }

  // Create a new buffer, add it to buffers_, and return a reference.
  Buffer& NewBuffer(Item* defining_instruction, int64 size, ItemList&& users,
                    bool live_out, bool has_indirect_uses) {
    BufferId buffer_id = buffers_.size();
    buffers_.push_back(Buffer{buffer_id, defining_instruction, size, live_out,
                              has_indirect_uses, users,
                              static_cast<int64>(users.size())});
    return buffers_.back();
  }

  const HloComputation* computation_;

  // Instruction list containing the ordering of instructions in
  // computation_. This is the order in which instructions are placed
  // (BeginInstruction/EndInstruction calls).
  const InstructionList& instruction_list_;

  // Memory usage at the currently placed instruction.
  int64 memory_usage_ = 0;

  // The instruction currently being placed. This value is non-null only
  // between the calling of BeginInstruction and EndInstruction.
  Item* in_progress_item_ = nullptr;

  // All buffers in the computation.
  std::vector<Buffer> buffers_;
};

MemoryUsageTracker::MemoryUsageTracker(
    const HloComputation* computation,
    const HloRematerialization::ShapeSizeFunction& size_function,
    const TuplePointsToAnalysis& points_to_analysis,
    const InstructionList& instruction_list)
    : computation_(computation), instruction_list_(instruction_list) {
  PointsToSet::BufferSet live_out_set =
      points_to_analysis.GetPointsToSet(computation_->root_instruction())
          .CreateFlattenedSet();
  tensorflow::gtl::FlatMap<const LogicalBuffer*, BufferId>
      logical_buffer_to_buffer_id;

  for (auto* item = instruction_list_.first(); item != nullptr;
       item = instruction_list_.next(item)) {
    const HloInstruction* const instruction = item->instruction;
    for (const LogicalBuffer* logical_buffer :
         points_to_analysis.GetBuffersDefinedByInstruction(instruction)) {
      Buffer* buffer;
      if (instruction->opcode() == HloOpcode::kWhile) {
        // The while instruction defines no new buffers. Instead it reuses the
        // buffers of its operand. Find the Buffer of its operand at the
        // proper ShapeIndex.
        const PointsToSet& operand_points_to =
            points_to_analysis.GetPointsToSet(instruction->operand(0));
        CHECK_EQ(operand_points_to.element(logical_buffer->index()).size(), 1);
        const LogicalBuffer* source_logical_buffer =
            operand_points_to.element(logical_buffer->index())[0];
        buffer =
            &buffers_.at(logical_buffer_to_buffer_id.at(source_logical_buffer));

        // Mark the buffer as having indirect uses and as live out.
        buffer->has_indirect_uses = true;
        buffer->live_out =
            buffer->live_out || ContainsKey(live_out_set, logical_buffer);

        // Add users of while to Buffer users.
        bool unused;
        for (Item* user_item : GetUsers(instruction_list_, logical_buffer,
                                        points_to_analysis, &unused)) {
          if (std::find(buffer->users.begin(), buffer->users.end(),
                        user_item) == buffer->users.end()) {
            buffer->users.push_back(user_item);
            buffer->unfinished_user_count++;
            user_item->buffers_used.push_back(buffer->id);
          }
        }
      } else {
        buffer = &CreateBufferFromLogicalBuffer(
            logical_buffer, points_to_analysis, size_function,
            ContainsKey(live_out_set, logical_buffer));
        item->buffers_defined.push_back(buffer->id);
        for (Item* user : buffer->users) {
          user->buffers_used.push_back(buffer->id);
        }
      }

      logical_buffer_to_buffer_id[logical_buffer] = buffer->id;
    }
  }
  XLA_VLOG_LINES(10, ToString());
  DCHECK(Check());
}
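
// Illustrative example of the kWhile special case above (hypothetical HLO,
// not from the original source):
//
//   %init  = ...
//   %while = while(%init), condition=%cond, body=%body
//
// %while defines no new buffers: the Buffer created for %init is reused,
// marked has_indirect_uses, and extended with %while's users.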

Status MemoryUsageTracker::BeginInstruction(Item* item) {
  const HloInstruction* instruction = item->instruction;
  VLOG(3) << "BeginInstruction " << instruction->name();
  TF_RET_CHECK(in_progress_item_ == nullptr);
  in_progress_item_ = item;

  item->placed = true;

  // All buffers defined by this instruction need memory.
  for (BufferId buffer_id : item->buffers_defined) {
    VLOG(3) << "  Buffer " << buffers_.at(buffer_id).ToString()
            << " is now live.";
    memory_usage_ += AllocatedSize(buffer_id);
  }

  // TODO(b/37686934): Elementwise instructions can share the buffer of a (dead)
  // operand. Account for this potential reuse here.

  VLOG(3) << "  memory usage = " << memory_usage_;
  VLOG(10) << ToString();

  if (VLOG_IS_ON(1)) {
    DCHECK(Check());
  }
  return Status::OK();
}

Status MemoryUsageTracker::EndInstruction() {
  TF_RET_CHECK(in_progress_item_ != nullptr);
  VLOG(3) << "EndInstruction " << in_progress_item_->instruction->name();

  for (BufferId buffer_id : in_progress_item_->buffers_used) {
    Buffer& buffer = buffers_.at(buffer_id);
    buffer.unfinished_user_count--;
    CHECK_GE(buffer.unfinished_user_count, 0)
        << buffer.ToString() << " has negative unfinished use count.";
    if (buffer.unfinished_user_count == 0) {
      // Buffer is now dead.
      VLOG(3) << "  " << buffer.ToString() << " is now dead.";
      memory_usage_ -= AllocatedSize(buffer_id);
      CHECK_GE(memory_usage_, 0);
    }
  }

  // If any buffer defined by this instruction has no uses, then memory can be
  // reclaimed immediately.
  for (BufferId buffer_id : in_progress_item_->buffers_defined) {
    const Buffer& buffer = buffers_.at(buffer_id);
    if (buffer.unfinished_user_count == 0) {
      VLOG(3) << "  " << buffer.ToString() << " is immediately dead.";
      memory_usage_ -= AllocatedSize(buffer_id);
      CHECK_GE(memory_usage_, 0);
    }
  }

  in_progress_item_ = nullptr;

  VLOG(3) << "  memory usage = " << memory_usage_;
  VLOG(10) << ToString();

  if (VLOG_IS_ON(1)) {
    DCHECK(Check());
  }
  return Status::OK();
}
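
// Minimal usage sketch of the Begin/End protocol (assumed driver code; the
// real callers are ComputePeakMemory and RematerializeComputation below):
//
//   for (auto* item = list.first(); item != nullptr; item = list.next(item)) {
//     TF_RETURN_IF_ERROR(tracker.BeginInstruction(item));  // outputs go live
//     peak = std::max(peak, tracker.memory_usage());
//     TF_RETURN_IF_ERROR(tracker.EndInstruction());  // dead operands freed
//   }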

int64 MemoryUsageTracker::MemoryReducedIfRematerialized(Item* item) const {
  CHECK_NE(in_progress_item_, nullptr);
  if (!item->placed || item == in_progress_item_) {
    return 0;
  }

  // TODO(b/37687140): Rematerialization can increase peak memory consumption at
  // an earlier point in the program if rematerialization extends the live range
  // of the operand of the instruction being rematerialized across the live
  // range of the value of the instruction being rematerialized. Don't
  // rematerialize in this case (ie, return 0 here).

  // Compute the amount of memory reduced (if any) by rematerializing
  // 'instruction'. The LogicalBuffers defined by 'instruction' will no longer
  // be live at this program point, so initially set memory_reduced to the
  // size of its defined values.
  int64 memory_reduced = 0;
  for (BufferId buffer_id : item->buffers_defined) {
    // Avoid rematerializing instructions with indirect uses as it is difficult
    // to reason about liveness after rematerializing the instruction.
    // TODO(b/37714814): Consider rematerializing instructions with indirect
    // uses.
    if (buffers_.at(buffer_id).has_indirect_uses) {
      return 0;
    }

    if (IsCurrentlyLive(buffer_id) && !IsInUse(buffer_id)) {
      memory_reduced += AllocatedSize(buffer_id);
    }
  }

  // Account for any logical buffers whose live range must be extended across
  // this program point.
  for (BufferId buffer_id : item->buffers_used) {
    if (!IsCurrentlyLive(buffer_id)) {
      // This logical buffer is used by 'instruction' but is not live at this
      // program point. Rematerializing 'instruction' will extend the buffer's
      // live range across this program point.
      memory_reduced -= AllocatedSize(buffer_id);
    }
  }

  return memory_reduced;
}
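
// Worked example for MemoryReducedIfRematerialized (assumed sizes, not from
// the original source): suppose 'item' defines one 16 MB buffer that is live
// but not in use at this program point, and uses one 4 MB buffer that is no
// longer live here. Rematerializing frees the 16 MB result but extends the
// 4 MB operand's live range across this point, so the net reduction returned
// is 16 MB - 4 MB = 12 MB.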

Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
                                                        Item* remat_item) {
  VLOG(3) << "AddRematerializedInstruction: original_instruction = "
          << original_item->instruction->name()
          << ", remat_instruction = " << remat_item->instruction->name();

  TF_RET_CHECK(in_progress_item_ != nullptr);
  TF_RET_CHECK(original_item->placed);
  TF_RET_CHECK(!remat_item->placed);

  // Construct the list of buffers used and defined by the rematerialization.
  remat_item->buffers_used = original_item->buffers_used;

  // Account for the additional buffer uses created by the new rematerialization
  // instruction. Update memory usage if the rematerialization makes a dead
  // buffer live again.
  for (BufferId buffer_id : original_item->buffers_used) {
    Buffer& buffer = buffers_.at(buffer_id);
    if (buffer.unfinished_user_count == 0) {
      // The buffer was dead; the new rematerialization use makes it live again.
      memory_usage_ += AllocatedSize(buffer.id);
    }

    buffer.unfinished_user_count++;
    buffer.users.push_back(remat_item);
  }

  // Create a new set of Buffers defined by the new rematerialization
  // instruction. Update the internal data structures and memory use to account
  // for them.
  for (BufferId old_buffer_id : original_item->buffers_defined) {
    Buffer& old_buffer = buffers_.at(old_buffer_id);

    ItemList placed_users;
    ItemList unplaced_users;
    for (Item* user : old_buffer.users) {
      if (user->placed) {
        CHECK(IsFinished(user));
        placed_users.push_back(user);
      } else {
        unplaced_users.push_back(user);
      }
    }
    old_buffer.users = std::move(placed_users);
    old_buffer.unfinished_user_count = 0;

    // Buffer is now dead.
    memory_usage_ -= AllocatedSize(old_buffer.id);

    Buffer& new_buffer =
        RematerializeBuffer(old_buffer, remat_item, std::move(unplaced_users));

    remat_item->buffers_defined.push_back(new_buffer.id);
    for (Item* user : new_buffer.users) {
      BufferIdList& buffers_used = user->buffers_used;
      std::replace(buffers_used.begin(), buffers_used.end(), old_buffer_id,
                   new_buffer.id);
    }
  }

  VLOG(3) << "  memory usage = " << memory_usage_;
  XLA_VLOG_LINES(10, ToString());

  DCHECK(Check());

  return Status::OK();
}
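
// Illustrative example of the bookkeeping above (hypothetical users, not from
// the original source): if the original buffer B0 had users {u1 (placed), u2,
// u3 (unplaced)}, then afterwards B0 keeps {u1} with an unfinished count of 0
// (dead), and a new buffer B1 defined by the remat instruction takes {u2, u3},
// whose buffers_used entries are rewritten from B0 to B1.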

string MemoryUsageTracker::ToString() const {
  string output = tensorflow::strings::StrCat("MemoryUsageTracker for ",
                                              computation_->name(), "\n");
  tensorflow::strings::StrAppend(
      &output, "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (",
      memory_usage(), " bytes)");
  for (auto* item = instruction_list_.first(); item != nullptr;
       item = instruction_list_.next(item)) {
    const HloInstruction* instruction = item->instruction;
    string inprogress = item == in_progress_item_ ? " in-progress" : "";
    string placed = item->placed ? " placed" : "";
    tensorflow::strings::StrAppend(&output, "  ", instruction->name(),
                                   inprogress, placed, "\n    Defines:\n");
    for (BufferId buffer_id : item->buffers_defined) {
      const Buffer& buffer = buffers_[buffer_id];
      string live = IsCurrentlyLive(buffer_id) ? " live" : "";
      tensorflow::strings::StrAppend(&output, "      ", buffer.ToString(), live,
                                     ", ", buffer.unfinished_user_count,
                                     " unfinished uses\n");
    }
    tensorflow::strings::StrAppend(&output, "    Uses:\n");
    for (BufferId buffer_id : item->buffers_used) {
      tensorflow::strings::StrAppend(&output, "      ",
                                     buffers_[buffer_id].ToString(), "\n");
    }
  }
  return output;
}

bool MemoryUsageTracker::Check() const {
  auto elements_are_unique = [](const BufferIdList& vec) {
    return vec.size() == std::set<BufferId>(vec.begin(), vec.end()).size();
  };

  // Verify buffers_defined per instruction.
  for (auto* instruction : computation_->instructions()) {
    const BufferIdList& defined_buffers =
        instruction_list_.GetItem(instruction)->buffers_defined;
    CHECK(elements_are_unique(defined_buffers))
        << "Instruction " << instruction->name()
        << " does not have unique defined buffers: "
        << tensorflow::str_util::Join(
               defined_buffers, ", ", [this](string* out, BufferId buffer_id) {
                 tensorflow::strings::StrAppend(
                     out, buffers_.at(buffer_id).ToString());
               });

    for (const Buffer& buffer : buffers_) {
      if (buffer.defining_instruction->instruction == instruction) {
        CHECK(std::find(defined_buffers.begin(), defined_buffers.end(),
                        buffer.id) != defined_buffers.end())
            << "Instruction " << instruction->name()
            << " defined buffers is missing: " << buffer.ToString();
      }
    }
  }

  // Verify buffers_used per instruction.
  for (auto* instruction : computation_->instructions()) {
    const BufferIdList& used_buffers =
        instruction_list_.GetItem(instruction)->buffers_used;
    CHECK(elements_are_unique(used_buffers))
        << "Instruction " << instruction->name()
        << " does not have unique used buffers: "
        << tensorflow::str_util::Join(
               used_buffers, ", ", [this](string* out, BufferId buffer_id) {
                 tensorflow::strings::StrAppend(
                     out, buffers_.at(buffer_id).ToString());
               });
  }
  for (const Buffer& buffer : buffers_) {
    int64 unfinished_uses = 0;
    for (Item* user : buffer.users) {
      const BufferIdList& used_buffers = user->buffers_used;
      CHECK(std::find(used_buffers.begin(), used_buffers.end(), buffer.id) !=
            used_buffers.end())
          << "Instruction " << user->instruction->name()
          << " used buffers is missing " << buffer.ToString();
      if (!IsFinished(user)) {
        unfinished_uses++;
      }
    }
    CHECK_EQ(buffer.unfinished_user_count, unfinished_uses)
        << "Incorrect unplaced use count for " << buffer.ToString();
  }
  return true;
}

// Computes and returns the cost of rematerializing the given instruction.
// Cost per rematerialized instruction is defined as:
//
// memory_limit_bytes / memory_reduced
//
// The idea is to choose the operation that will save the most memory and not
// worry about how much the compute costs, since running out of memory is more
// harmful than taking longer to get the answer.
int64 RematerializationCost(const HloInstruction* instruction,
                            const MemoryUsageTracker& memory_tracker,
                            int64 memory_reduced, int64 memory_limit_bytes) {
  // If none of the users of 'instruction' have been placed in the sequence (as
  // tracked by memory_tracker), then rematerialization of 'instruction' is a
  // zero-cost move of 'instruction' in the sequence.
  if (!std::any_of(instruction->users().begin(), instruction->users().end(),
                   [&memory_tracker](const HloInstruction* inst) {
                     return memory_tracker.IsPlaced(inst);
                   })) {
    return 0;
  }

  CHECK_GT(memory_reduced, 0);
  // Return the inverse of the benefit of rematerialization.
  return memory_limit_bytes / memory_reduced;
}
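
// Worked example (assumed numbers, not from the original source): with a
// 1 GiB memory limit, a candidate that reduces memory by 16 MiB costs
// 1073741824 / 16777216 = 64, while one that reduces memory by 64 MiB costs
// 16. The candidate with the lower cost (larger saving) is preferred.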

// Selects and returns the best candidate instruction for rematerialization.
// The instruction with the lowest rematerialization cost is selected among
// those candidates which reduce memory use at the program point of the
// current instruction, as indicated by memory_tracker. Returns nullptr if no
// candidate can be found.
Item* PickRematerializationCandidate(const MemoryUsageTracker& memory_tracker,
                                     const InstructionList& instruction_list,
                                     int64 memory_limit_bytes) {
  Item* best_item = nullptr;
  int64 best_cost = 0;

  // TODO(b/35244891): This is currently quadratic in the number of HLO
  // instructions.
  for (auto* item = instruction_list.first(); item != nullptr;
       item = instruction_list.next(item)) {
    if (!item->placed) {
      // Only iterate up to the currently placed instruction.
      // We are trying to reduce memory usage at the placed
      // instruction so rematerializing later values is of no benefit.
      break;
    }
    HloInstruction* candidate = item->instruction;
    VLOG(5) << "considering rematerialization candidate " << candidate->name();

    if (item->blacklisted) {
      // Skip instructions on the blacklist to avoid infinite loops of
      // rematerializing the same instruction(s) repeatedly.
      VLOG(5) << "candidate " << candidate->name()
              << " is excluded from rematerialization";
      continue;
    }

    if (!IsRematerializable(candidate)) {
      VLOG(5) << "candidate " << candidate->name()
              << " not viable: is not rematerializable";
      continue;
    }

    // If any of the candidate's control successors has been placed, we need to
    // skip this candidate. Otherwise we would violate a control dependency.
    bool control_successor_placed =
        std::any_of(candidate->control_successors().begin(),
                    candidate->control_successors().end(),
                    [&memory_tracker](const HloInstruction* inst) {
                      return memory_tracker.IsPlaced(inst);
                    });

    if (control_successor_placed) {
      continue;
    }

    const int64 memory_reduced =
        memory_tracker.MemoryReducedIfRematerialized(item);

    if (memory_reduced <= 0) {
      VLOG(5) << "candidate " << candidate->name()
              << " memory reduced = " << memory_reduced << " <=  0";
      continue;
    }

    const int64 cost = RematerializationCost(
        candidate, memory_tracker, memory_reduced, memory_limit_bytes);

    VLOG(5) << "candidate " << candidate->name() << ", memory reduced "
            << memory_reduced << ", cost per byte " << cost;

    if (best_item == nullptr || cost < best_cost) {
      VLOG(5) << "candidate " << candidate->name() << " now best";
      best_item = item;
      best_cost = cost;
    }
  }
  return best_item;
}

}  // namespace

StatusOr<int64> HloRematerialization::ComputePeakMemory(
    const HloComputation* computation,
    const std::vector<const HloInstruction*>& order) const {
  InstructionList instruction_list(order);
  MemoryUsageTracker tracker(computation, size_function_, *points_to_analysis_,
                             instruction_list);
  int64 peak_memory = tracker.memory_usage();
  for (auto* item = instruction_list.first(); item != nullptr;
       item = instruction_list.next(item)) {
    const HloInstruction* instruction = item->instruction;
    TF_RETURN_IF_ERROR(tracker.BeginInstruction(item));
    TF_ASSIGN_OR_RETURN(int64 callee_usage,
                        CalledComputationsMemoryUsage(instruction));
    peak_memory =
        std::max<int64>(peak_memory, tracker.memory_usage() + callee_usage);
    TF_RETURN_IF_ERROR(tracker.EndInstruction());
  }
  VLOG(1) << "Peak memory for " << computation->name() << ": "
          << HumanReadableNumBytes(peak_memory);
  return peak_memory;
}

StatusOr<int64> HloRematerialization::CalledComputationsMemoryUsage(
    const HloInstruction* instruction) const {
  const CallSite* callsite =
      call_graph_->GetNode(instruction->parent()).GetCallSite(instruction);
  if (callsite == nullptr || callsite->context() == CallContext::kParallel) {
    return 0;
  }
  int64 callee_usage = 0;
  for (const HloComputation* computation : callsite->called_computations()) {
    TF_RET_CHECK(ContainsKey(computation_peak_memory_, computation));
    callee_usage += computation_peak_memory_.at(computation);
  }
  return callee_usage;
}

StatusOr<bool> HloRematerialization::RematerializeComputation(
    HloComputation* computation,
    SequentialHloOrdering::HloModuleSequence* sequence,
    int64 memory_limit_bytes) {
  VLOG(1) << "Rematerializing computation " << computation->name()
          << " with limit " << HumanReadableNumBytes(memory_limit_bytes);
  VLOG(1) << "peak memory usage is "
          << HumanReadableNumBytes(computation_peak_memory_.at(computation));
  CHECK(!ContainsKey(rematerialized_computations_, computation));

  InstructionList instruction_list(sequence->at(computation));
  MemoryUsageTracker memory_tracker(computation, size_function_,
                                    *points_to_analysis_, instruction_list);
  bool changed = false;

  // If a rematerialization makes the source instruction dead, it is added to
  // 'remat_move_instructions' (the rematerialization is essentially a move).
  // If the next rematerialization of the instruction is also a move, the
  // instruction is blacklisted.
  tensorflow::gtl::FlatSet<const HloInstruction*> remat_move_instructions;

  // The peak memory of the computation at any point in the instruction
  // sequence.
  int64 peak_memory = memory_tracker.memory_usage();

  // Total count of instructions rematerialized.
  int64 remat_count = 0;
  // Total count of clones created minus number of original rematerialized
  // instructions which are dead.
  int64 net_instructions_added = 0;

  const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);

  // Iterate through all instructions in the sequence. At each instruction
  // (program point) if memory_usage exceeds the specified limit then
  // rematerialize HLO instructions until memory_usage is reduced.
  int64 instruction_index = 0;
  for (auto* item = instruction_list.first(); item != nullptr;
       item = instruction_list.next(item)) {
    const HloInstruction* instruction = item->instruction;
    TF_ASSIGN_OR_RETURN(int64 callee_usage,
                        CalledComputationsMemoryUsage(instruction));
    TF_RETURN_IF_ERROR(memory_tracker.BeginInstruction(item));

    VLOG(2) << "Program point at " << instruction->name()
            << ", memory usage = " << memory_tracker.memory_usage()
            << ", callee usage = " << callee_usage << ", [" << instruction_index
            << "/" << instruction_list.size() << "]";
    instruction_index++;

    while (memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) {
      VLOG(2) << "Over memory limit at instruction " << instruction->name()
              << ", using "
              << HumanReadableNumBytes(memory_tracker.memory_usage() +
                                       callee_usage)
              << ", limit is " << HumanReadableNumBytes(memory_limit_bytes);

      Item* best_item = PickRematerializationCandidate(
          memory_tracker, instruction_list, memory_limit_bytes);

      if (best_item == nullptr) {
        VLOG(3) << "Unable to find rematerialization candidate at program "
                   "point "
                << instruction->name() << ". Memory usage = "
                << HumanReadableNumBytes(memory_tracker.memory_usage() +
                                         callee_usage);
        break;
      }

      HloInstruction* best = best_item->instruction;
      VLOG(1) << "Rematerializing instruction " << best->name() << " (saving "
              << HumanReadableNumBytes(
                     memory_tracker.MemoryReducedIfRematerialized(best_item))
              << ")";
      changed = true;
      remat_count++;

      HloInstruction* remat =
          computation->AddInstruction(best->Clone(/*suffix=*/"remat"));

      // Add control dependencies to the new operation.
      for (auto successor : best->control_successors()) {
        TF_RETURN_IF_ERROR(remat->AddControlDependencyTo(successor));
      }
      for (auto predecessor : best->control_predecessors()) {
        TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(remat));
      }

      Item* remat_item = instruction_list.CreateItem(remat);

      // Replace each remaining use of 'best' with the rematerialization.
      std::vector<HloInstruction*> best_users_copy = best->users();
      for (HloInstruction* user : best_users_copy) {
        if (!memory_tracker.IsPlaced(user)) {
          VLOG(2) << "  Replacing use of " << best->name() << " in "
                  << user->name() << " with " << remat->name();
          TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, remat));
        }
      }

      // Account for the rematerialization in the memory tracker.
      TF_RETURN_IF_ERROR(
          memory_tracker.AddRematerializedInstruction(best_item, remat_item));

      // Insert rematerialized instruction right before the earliest unplaced
      // use of the instruction *and* the earliest unplaced last use of any
      // operands of remat. Unplaced uses of the remat's operands are included
      // because we don't want to extend the live range of remat's operands as
      // this could increase memory usage.
      ItemList place_before;
      for (auto user : remat->users()) {
        place_before.push_back(instruction_list.GetItem(user));
      }
      for (auto* operand : remat->operands()) {
        for (auto* operand_user : operand->users()) {
          if (operand_user != remat) {
            Item* operand_user_item = instruction_list.GetItem(operand_user);
            if (!operand_user_item->placed) {
              place_before.push_back(operand_user_item);
            }
          }
        }
      }
      // Insert rematerialized instruction before any of its successors to
      // preserve ordering regarding control dependency.
      for (auto successor : remat->control_successors()) {
        Item* successor_item = instruction_list.GetItem(successor);
        // Assert to make sure we never rematerialize an operation whose
        // control successor has already been placed.
        CHECK(!successor_item->placed);
        place_before.push_back(successor_item);
      }
      instruction_list.InsertBeforeInstructions(remat_item, place_before);

      // If the rematerialized instruction is dead then rematerialization is
      // essentially a move. Don't delete the instruction now because we don't
      // want duplicate HloInstruction* values during the course of the
      // transformation because we keep maps with HloInstruction* values as
      // keys.
      if (best->users().empty()) {
        VLOG(2) << best->name() << " is now dead";
        if (ContainsKey(remat_move_instructions, best)) {
          // Previously, 'best' was a rematerialization which killed the
          // instruction it was a copy of. Now 'remat' is a rematerialization
          // of 'best' and kills 'best'. Stop rematerializing this instruction
          // to avoid an infinite loop.
   1103           instruction_list.Blacklist(remat);
   1104         }
   1105         remat_move_instructions.insert(remat);
   1106       } else {
   1107         net_instructions_added++;
   1108       }
   1109 
   1110       VLOG(1) << "memory_usage after rematerialization = "
   1111               << HumanReadableNumBytes(memory_tracker.memory_usage());
   1112     }
   1113 
   1114     const CallSite* callsite = call_graph_node.GetCallSite(instruction);
   1115     if (callsite != nullptr &&
   1116         callsite->context() == CallContext::kSequential &&
   1117         memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) {
   1118       // Memory usage exceeds the limit. Try to rematerialize any
   1119       // subcomputation(s) that this instruction calls.
   1120       VLOG(1) << "Memory usage still over the limit ("
   1121               << (memory_tracker.memory_usage() + callee_usage) << " > "
   1122               << memory_limit_bytes
   1123               << "). Rematerializing computations called by "
   1124               << instruction->name();
   1125 
   1126       // Recompute callee usage to account for any rematerialization performed
   1127       // in the callee computations.
   1128       for (HloComputation* called_computation :
   1129            callsite->called_computations()) {
   1130         if (!ContainsKey(rematerialized_computations_, called_computation)) {
   1131           // Memory limit for the subcomputation is the memory limit less the
   1132           // amount of memory used at this point in the computation.
   1133           int64 subcomputation_memory_limit_bytes = std::max<int64>(
   1134               0, memory_limit_bytes - memory_tracker.memory_usage());
   1135           TF_ASSIGN_OR_RETURN(
   1136               bool subcomputation_changed,
   1137               RematerializeComputation(called_computation, sequence,
   1138                                        subcomputation_memory_limit_bytes));
   1139           changed |= subcomputation_changed;
   1140         }
   1141       }
   1142       TF_ASSIGN_OR_RETURN(callee_usage,
   1143                           CalledComputationsMemoryUsage(instruction));
   1144     }
   1145 
    peak_memory = std::max<int64>(peak_memory,
                                  memory_tracker.memory_usage() + callee_usage);
    VLOG(3) << "peak memory usage = " << HumanReadableNumBytes(peak_memory);

    TF_RETURN_IF_ERROR(memory_tracker.EndInstruction());
  }

  // Verify some invariants on the memory tracker.
  CHECK_EQ(memory_tracker.memory_usage(), 0);
  for (auto* instruction : computation->instructions()) {
    CHECK(memory_tracker.IsPlaced(instruction));
  }

  VLOG(1) << "In computation " << computation->name() << " rematerialized "
          << remat_count << " instructions; " << net_instructions_added
          << " net instructions added";
  VLOG(1) << "  peak memory usage now " << HumanReadableNumBytes(peak_memory)
          << " (was "
          << HumanReadableNumBytes(computation_peak_memory_.at(computation))
          << ")";

  // Update peak memory used by computation.
  computation_peak_memory_.at(computation) = peak_memory;

  // Update order to include rematerialized instructions.
  auto& dst = sequence->at(computation);
  dst.clear();
  for (auto* item = instruction_list.first(); item != nullptr;
       item = instruction_list.next(item)) {
    const HloInstruction* instruction = item->instruction;
    dst.push_back(instruction);
  }
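  // The sequence now mirrors instruction_list, with rematerialized clones at
  // the positions where they were inserted. Record this computation as
  // processed so that repeated callsites do not rematerialize it again.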
  rematerialized_computations_.insert(computation);

  instructions_rematerialized_ += remat_count;
  net_instructions_added_ += net_instructions_added;

  return changed;
}

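// Runs rematerialization end to end: schedules the module, measures the peak
// memory of each computation called in a sequential context, rematerializes
// instructions in the entry computation (recursing into sequentially called
// subcomputations), then removes dead code and repairs the schedule.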
StatusOr<bool> HloRematerialization::Run(
    HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence,
    int64 memory_limit_bytes, RematerializationSizes* sizes) {
  // The sequence is constructed entirely by this method.
  TF_RET_CHECK(sequence->empty());

  VLOG(1) << "HloRematerialization() with memory limit of "
          << HumanReadableNumBytes(memory_limit_bytes);

  TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));

  // Adjust the memory limit to account for the output of the entry
  // computation. This is necessary because the per-computation accounting in
  // MemoryUsageTracker does not include the output, as it is typically
  // allocated by the caller.
  int64 module_output_size = 0;
  ShapeUtil::ForEachSubshape(
      module->entry_computation()->root_instruction()->shape(),
      [&module_output_size, this](const Shape& subshape,
                                  const ShapeIndex& /*index*/) {
        module_output_size += size_function_(subshape);
      });
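  // Illustrative example: for an entry root of shape (f32[1024], f32[512]),
  // the leaf buffers contribute 4096 + 2048 bytes with 4-byte floats, plus
  // whatever size_function_ assigns to the enclosing tuple shape itself.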

  const int64 adjusted_memory_limit_bytes =
      memory_limit_bytes - module_output_size;
  VLOG(1) << "Adjusted memory limit accounting for output ("
          << HumanReadableNumBytes(module_output_size)
          << "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes);
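  // For instance, a 16 GiB limit with a 1 GiB module output leaves a 15 GiB
  // budget for the buffers live inside the entry computation.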

  XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());
  // Create initial sequence of HLO instructions.
  TF_ASSIGN_OR_RETURN(*sequence, CreateMemoryMinimizingSequence(
                                     *module,
                                     [this](const LogicalBuffer& buffer) {
                                       return size_function_(buffer.shape());
                                     },
                                     scheduler_algorithm_));
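  // The scheduler provides the baseline ordering; rematerialization rewrites
  // this sequence in place rather than rescheduling from scratch.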
  // Compute the peak memory usage of each computation in the module that is
  // called in a sequential context.
  call_graph_ = CallGraph::Build(module);
  TF_RETURN_IF_ERROR(call_graph_->VisitNodes(
      [this, sequence](const CallGraphNode& node) -> Status {
        if (node.context() == CallContext::kSequential) {
          TF_ASSIGN_OR_RETURN(
              computation_peak_memory_[node.computation()],
              ComputePeakMemory(node.computation(),
                                sequence->at(node.computation())));
        }
        return Status::OK();
      },
      /*visit_unreachable_nodes=*/false));
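  // Only computations called in a sequential context are visited;
  // /*visit_unreachable_nodes=*/false additionally skips computations not
  // reachable from the entry computation.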

  // The peak memory usage of the module is the peak memory usage of the entry
  // computation plus the size of its output. The output is added back because
  // a computation's peak memory does not include its own output, which is
  // typically accounted for in the caller.
  const int64 before_peak_memory =
      computation_peak_memory_.at(module->entry_computation()) +
      module_output_size;
  VLOG(1) << "Peak memory usage of module (before): "
          << HumanReadableNumBytes(before_peak_memory);

  // Subcomputations called by the entry computation will also be
  // rematerialized.
  TF_ASSIGN_OR_RETURN(bool changed, RematerializeComputation(
                                        module->entry_computation(), sequence,
                                        adjusted_memory_limit_bytes));

  // Rematerialization can introduce dead code. This occurs if all uses of an
  // instruction are replaced with rematerializations of the instruction.
  TF_ASSIGN_OR_RETURN(bool dead_code_removed, HloDCE().Run(module));
  changed |= dead_code_removed;

  // After DCE, the module sequence may include instructions which no longer
  // exist.
  for (const auto* computation : module->MakeNonfusionComputations()) {
    if (sequence->at(computation).size() != computation->instruction_count()) {
      // A size mismatch between the computation's instruction count and the
      // length of its instruction ordering can only be caused by DCE. Rebuild
      // the order by removing the deleted instructions from it.
      tensorflow::gtl::FlatSet<const HloInstruction*> instruction_set;
      for (const auto& instruction : computation->instructions()) {
        instruction_set.insert(instruction);
      }
      // Move the old order into a temporary vector, then build the new order
      // in place.
      std::vector<const HloInstruction*>& order = sequence->at(computation);
      std::vector<const HloInstruction*> old_order;
      using std::swap;
      swap(order, old_order);
      std::copy_if(old_order.begin(), old_order.end(),
                   std::back_inserter(order),
                   [&instruction_set](const HloInstruction* instruction) {
                     return ContainsKey(instruction_set, instruction);
                   });
      TF_RET_CHECK(sequence->at(computation).size() ==
                   computation->instruction_count());
    }
  }
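  // std::copy_if preserves the relative order of the surviving instructions,
  // and the TF_RET_CHECK above confirms that the rebuilt sequence has exactly
  // one entry per remaining instruction.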
  VLOG(1) << "Rematerialized " << instructions_rematerialized_
          << " instructions in module " << module->name() << "; "
          << net_instructions_added_ << " net instructions added";
  const int64 current_peak_memory =
      computation_peak_memory_.at(module->entry_computation()) +
      module_output_size;
  VLOG(1) << "Peak memory usage of module now "
          << HumanReadableNumBytes(current_peak_memory) << " ("
          << current_peak_memory << " bytes), was "
          << HumanReadableNumBytes(before_peak_memory) << " ("
          << before_peak_memory << " bytes)";
  const int64 reduced_peak_memory = before_peak_memory - current_peak_memory;
  VLOG(1) << "Reduced peak memory by "
          << HumanReadableNumBytes(reduced_peak_memory) << " ("
          << reduced_peak_memory << " bytes)";

  if (sizes != nullptr) {
    sizes->before_bytes = before_peak_memory;
    sizes->after_bytes = current_peak_memory;
  }

  XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString());

  if (current_peak_memory > memory_limit_bytes) {
    LOG(WARNING) << tensorflow::strings::Printf(
        "Can't reduce memory use below %s (%lld bytes) by rematerialization; "
        "only reduced to %s (%lld bytes)",
        HumanReadableNumBytes(memory_limit_bytes).c_str(), memory_limit_bytes,
        HumanReadableNumBytes(current_peak_memory).c_str(),
        current_peak_memory);
  }

  return changed;
}

/* static */ StatusOr<bool> HloRematerialization::RematerializeAndSchedule(
    const HloRematerialization::ShapeSizeFunction& size_function,
    int64 memory_limit_bytes, HloModule* hlo_module,
    SchedulerAlgorithm scheduler_algorithm,
    SequentialHloOrdering::HloModuleSequence* sequence,
    RematerializationSizes* sizes) {
  HloRematerialization remat(scheduler_algorithm, size_function);
  return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes);
}
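
// A minimal usage sketch (hypothetical caller; `module`, `size_function`, and
// `scheduler_algorithm` are assumed to be in scope):
//
//   SequentialHloOrdering::HloModuleSequence sequence;
//   TF_ASSIGN_OR_RETURN(
//       bool changed,
//       HloRematerialization::RematerializeAndSchedule(
//           size_function, /*memory_limit_bytes=*/16LL << 30, module,
//           scheduler_algorithm, &sequence, /*sizes=*/nullptr));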

}  // namespace xla