Home | History | Annotate | Download | only in service
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/hlo_scheduling.h"
     17 
     18 #include <map>
     19 #include <utility>
     20 #include <vector>
     21 
     22 #include "tensorflow/compiler/xla/service/heap_simulator.h"
     23 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     24 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
     25 #include "tensorflow/compiler/xla/shape_util.h"
     26 #include "tensorflow/compiler/xla/status_macros.h"
     27 #include "tensorflow/compiler/xla/statusor.h"
     28 #include "tensorflow/compiler/xla/types.h"
     29 #include "tensorflow/compiler/xla/util.h"
     30 #include "tensorflow/core/lib/core/errors.h"
     31 #include "tensorflow/core/lib/strings/str_util.h"
     32 #include "tensorflow/core/lib/strings/stringprintf.h"
     33 #include "tensorflow/core/platform/logging.h"
     34 
     35 using ::tensorflow::strings::HumanReadableNumBytes;
     36 
     37 namespace xla {
     38 
     39 StatusOr<int64> MinimumMemoryForSequence(
     40     const SequentialHloOrdering::HloModuleSequence& module_sequence,
     41     const LogicalBuffer::SizeFunction& size_function) {
     42   if (module_sequence.empty()) {
     43     return 0;
     44   }
     45 
     46   const HloModule* module = module_sequence.begin()->first->parent();
     47   TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
     48                       TuplePointsToAnalysis::Run(module));
     49 
     50   // The absolute minimum memory required for a given sequence of instructions
     51   // is determined by the sequence of Alloc and Free calls on a simulated heap,
     52   // ignoring fragmentation. We run the heap simulation on the whole module,
     53   // rather than summing each computation, since it gives us a better lower
     54   // bound, by minimizing the liveness of sub-computations.
     55   TF_ASSIGN_OR_RETURN(
     56       HeapSimulator::Result result,
     57       HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), *module,
     58                          module_sequence, *points_to_analysis, size_function));
     59   return result.heap_size;
     60 }
     61 
     62 namespace {
     63 
// Class implementing a list scheduler of HLO instructions which produces a
// sequence which minimizes memory usage.
//
// The scheduler is greedy: it keeps a queue of "ready" instructions (those
// whose operands and control predecessors have all been scheduled) and
// repeatedly emits the ready instruction with the highest Priority, i.e. the
// one which frees the most bytes, with the number of users as a tie-breaker.
class ListScheduler {
 public:
  // Construct and return a memory-minimizing sequence of HLO instructions
  // containing the given HLO computation.
  //
  // A scheduler object is created for the duration of this call only, so the
  // references it keeps to the arguments do not outlive the call.
  static StatusOr<std::vector<const HloInstruction*>> Run(
      const HloComputation& computation,
      const TuplePointsToAnalysis& points_to_analysis,
      const LogicalBuffer::SizeFunction& size_function) {
    ListScheduler scheduler(computation, points_to_analysis, size_function);
    return scheduler.CreateSchedule();
  }

  // Returns whether the memory used by the given HLO should be ignored by the
  // scheduling heuristic. Parameters and constants are ignored.
  static bool IgnoreInstruction(const HloInstruction& instruction) {
    return instruction.opcode() == HloOpcode::kParameter ||
           instruction.opcode() == HloOpcode::kConstant;
  }

 private:
  // The scheduling priority of an instruction is first the number of bytes
  // freed by scheduling the instruction, and second (tie-breaker) by the number
  // of users. This is represented as a std::pair containing these two values
  // (first element is the bytes freed). std::pair provides the necessary
  // comparison operators.
  using Priority = std::pair<int64, int64>;

  // Precomputes the bookkeeping used by CreateSchedule: the buffers used by
  // each instruction (buffer_uses_) and the number of not-yet-scheduled uses
  // of each buffer (unscheduled_use_count_).
  ListScheduler(const HloComputation& computation,
                const TuplePointsToAnalysis& points_to_analysis,
                const LogicalBuffer::SizeFunction& size_function)
      : computation_(computation),
        points_to_analysis_(points_to_analysis),
        size_function_(size_function) {
    // Create a map containing the LogicalBuffer uses for each HLO
    // instruction. An HLO instruction "uses" a LogicalBuffer if the
    // LogicalBuffer is in an operand of the instruction as indicated by
    // points-to analysis.
    for (auto* instruction : computation.instructions()) {
      // A set is used so that a buffer is recorded only once per instruction.
      tensorflow::gtl::FlatSet<const LogicalBuffer*> instr_uses;
      for (auto* operand : instruction->operands()) {
        for (const LogicalBuffer* buffer :
             points_to_analysis.GetBuffersDefinedByInstruction(operand)) {
          instr_uses.insert(buffer);
        }
      }
      buffer_uses_[instruction] = std::vector<const LogicalBuffer*>(
          instr_uses.begin(), instr_uses.end());
    }

    // Create map containing the number of unscheduled uses (hlo instructions)
    // of each logical buffer.
    // First give every buffer defined in this computation an entry, so that
    // buffers without any use are present with a count of zero.
    for (auto* instruction : computation.instructions()) {
      for (auto* buffer :
           points_to_analysis.GetBuffersDefinedByInstruction(instruction)) {
        unscheduled_use_count_[buffer] = 0;
      }
    }
    // Then count one use per (instruction, used buffer) pair.
    for (auto* instruction : computation.instructions()) {
      for (const LogicalBuffer* buffer : buffer_uses_.at(instruction)) {
        ++unscheduled_use_count_[buffer];
      }
    }

    // Buffers live out of the computation have an implicit use at the end of
    // the computation.
    for (const LogicalBuffer* live_out_buffer :
         points_to_analysis.GetPointsToSet(computation.root_instruction())
             .CreateFlattenedSet()) {
      ++unscheduled_use_count_[live_out_buffer];
    }
  }

  // Returns whether the memory used by the given buffer should be ignored by
  // the scheduling heuristic. A buffer is ignored when the instruction which
  // defines it is ignored (see IgnoreInstruction).
  static bool IgnoreBuffer(const LogicalBuffer& buffer) {
    return IgnoreInstruction(*buffer.instruction());
  }

  // An entry in the worklist used by CreateSchedule.  Corresponds to one
  // HloInstruction, plus some cached metadata, saved for the purposes of making
  // BytesFreedIfScheduled fast.
  struct ReadyListEntry {
    // The instruction which is ready to be scheduled.
    const HloInstruction* instruction;

    // The total size of all buffers defined by this instruction.
    int64 bytes_defined;

    // For each buffer B used by this instruction, we keep a pair (B, U), where
    // U is the number of uses of B that have not yet been scheduled. This pair
    // is a pointer into the unscheduled_use_count_ map, so it gets updated for
    // free when we update counts in the map.
    std::vector<const std::pair<const LogicalBuffer* const, int64>*>
        used_buffer_unscheduled_use_counts;
  };

  // Creates a ReadyListEntry for the given instruction.
  ReadyListEntry MakeReadyListEntry(const HloInstruction* instruction) {
    ReadyListEntry entry;
    entry.instruction = instruction;

    // Sum the sizes of the non-ignored buffers defined by the instruction.
    entry.bytes_defined = 0;
    for (auto* buffer :
         points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) {
      if (!IgnoreBuffer(*buffer)) {
        entry.bytes_defined += size_function_(*buffer);
      }
    }

    // Record a pointer to the use-count map entry of every non-ignored buffer
    // used by the instruction.
    for (auto* buffer : buffer_uses_.at(instruction)) {
      if (IgnoreBuffer(*buffer)) {
        continue;
      }
      auto unscheduled_use_count_it = unscheduled_use_count_.find(buffer);
      CHECK(unscheduled_use_count_it != unscheduled_use_count_.end());
      entry.used_buffer_unscheduled_use_counts.push_back(
          &*unscheduled_use_count_it);
    }
    return entry;
  }

  // Returns the number of bytes freed if the HLO instruction is scheduled.
  // This is the size of the used buffers for which this instruction is the
  // only remaining unscheduled use, minus the bytes the instruction itself
  // defines, so the result is negative when more is defined than is freed.
  int64 BytesFreedIfScheduled(const ReadyListEntry& entry) {
    int64 freed_bytes = 0;
    for (const auto& kv : entry.used_buffer_unscheduled_use_counts) {
      auto buffer = kv->first;
      auto use_count = kv->second;
      // A count of one means this instruction is the last unscheduled use.
      if (use_count == 1) {
        freed_bytes += size_function_(*buffer);
      }
    }
    return freed_bytes - entry.bytes_defined;
  }

  // Constructs the scheduling priority of the given instruction.
  Priority GetPriority(const ReadyListEntry& entry) {
    return {BytesFreedIfScheduled(entry), entry.instruction->user_count()};
  }

  // Produces the schedule: a sequence containing every instruction of the
  // computation, in which an instruction is emitted only after all of its
  // operands and control predecessors. CHECK-fails if the resulting sequence
  // does not contain exactly instruction_count() instructions.
  std::vector<const HloInstruction*> CreateSchedule() {
    std::vector<const HloInstruction*> schedule;

    // Populate the ready list with instructions which have no operands or
    // control predecessors.
    // To find them, first count the unscheduled predecessors of every
    // instruction: one per (operand, user) edge and one per control edge.
    tensorflow::gtl::FlatMap<const HloInstruction*, int64>
        unscheduled_pred_count;
    for (auto* instruction : computation_.instructions()) {
      // TODO(b/34466113): Replace this and above with successors() or
      // predecessors() when these methods are added to HloInstruction.
      for (const HloInstruction* user : instruction->users()) {
        unscheduled_pred_count[user]++;
      }
      for (const HloInstruction* succ : instruction->control_successors()) {
        unscheduled_pred_count[succ]++;
      }
    }

    // Use a multimap to sort ReadyListEntry according to their priority.
    // The multimap is ordered by ascending key, so the entry with the highest
    // priority is the last one.
    std::multimap<Priority, ReadyListEntry> ready_queue;

    // Map of ready instructions to their iterators in ready_queue.
    tensorflow::gtl::FlatMap<const HloInstruction*,
                             std::multimap<Priority, ReadyListEntry>::iterator>
        ready_instructions;

    // Inserts an instruction into ready_queue under its current priority and
    // records the resulting iterator in ready_instructions.
    auto add_to_ready_queue = [&](HloInstruction* inst) {
      auto entry = MakeReadyListEntry(inst);
      auto it = ready_queue.emplace(GetPriority(entry), std::move(entry));
      ready_instructions[inst] = it;
    };

    for (auto* instruction : computation_.instructions()) {
      // Instruction with no operands or control predecessors will
      // not be in the map.
      if (unscheduled_pred_count.count(instruction) == 0) {
        add_to_ready_queue(instruction);
      }
    }

    while (!ready_queue.empty()) {
      // Remove the selected instruction from the ready list and add it to the
      // schedule.
      // The selected instruction is the last, i.e. highest-priority, entry.
      auto best_it = ready_queue.end();
      --best_it;
      const HloInstruction* best = best_it->second.instruction;
      ready_queue.erase(best_it);
      ready_instructions.erase(best);
      schedule.push_back(best);
      scheduled_instructions_.insert(best);

      // Set when some buffer drops to a single remaining unscheduled use,
      // which can change the priority of instructions already in the queue.
      bool adjust_ready_queue = false;
      // Update the unscheduled uses of the logical buffers.
      for (const LogicalBuffer* buffer : buffer_uses_.at(best)) {
        int64& count = unscheduled_use_count_[buffer];
        CHECK_GT(count, 0);
        --count;
        if (count == 1) {
          adjust_ready_queue = true;
        }
      }

      // Add new instructions to ready list.
      // An instruction becomes ready once its predecessor count reaches zero.
      auto update_pred_count = [&](HloInstruction* inst) {
        int64 pred_count = --unscheduled_pred_count.at(inst);
        CHECK_GE(pred_count, 0);
        if (pred_count == 0) {
          add_to_ready_queue(inst);
        }
      };
      // TODO(b/34466113): Replace this and above with successors() or
      // predecessors() when these methods are added to HloInstruction.
      for (HloInstruction* user : best->users()) {
        update_pred_count(user);
      }
      for (HloInstruction* succ : best->control_successors()) {
        update_pred_count(succ);
      }
      // The unscheduled use count for a buffer has changed to 1, so the
      // priorities of some ready instructions may go up. We update them in the
      // ready queue, so that they can appear earlier.
      if (adjust_ready_queue) {
        // Only the other users of the operands of `best` are re-examined.
        for (HloInstruction* operand : best->operands()) {
          for (HloInstruction* operand_user : operand->users()) {
            auto ready_instructions_it = ready_instructions.find(operand_user);
            if (ready_instructions_it == ready_instructions.end()) {
              // This user is not currently in the ready queue.
              continue;
            }
            auto ready_queue_it = ready_instructions_it->second;
            auto& entry = ready_queue_it->second;
            Priority new_priority = GetPriority(entry);
            if (new_priority == ready_queue_it->first) {
              // Priority is unchanged; the entry stays where it is.
              continue;
            }
            // Create a new entry in ready_queue, then update
            // ready_instructions[operand_user] to refer to the new entry.
            ready_instructions_it->second =
                ready_queue.emplace(new_priority, std::move(entry));
            // Remove the old entry in ready_queue.
            ready_queue.erase(ready_queue_it);
          }
        }
      }
    }
    // Every instruction of the computation must have been scheduled.
    CHECK_EQ(schedule.size(), computation_.instruction_count());
    CHECK_EQ(scheduled_instructions_.size(), computation_.instruction_count());

    return schedule;
  }

  // Inputs to the scheduler, held by reference.
  const HloComputation& computation_;
  const TuplePointsToAnalysis& points_to_analysis_;
  const LogicalBuffer::SizeFunction& size_function_;

  // A map containing the LogicalBuffers that each instruction uses.
  tensorflow::gtl::FlatMap<const HloInstruction*,
                           std::vector<const LogicalBuffer*>>
      buffer_uses_;

  // A map containing the count of unscheduled HLOs which use a particular
  // LogicalBuffer.  ReadyListEntry stores pointers to entries of this map, so
  // we rely on pointers to map entries remaining stable, and on the map
  // entries being std::pair's.
  std::unordered_map<const LogicalBuffer*, int64> unscheduled_use_count_;

  // Set of instructions which have been scheduled.
  tensorflow::gtl::FlatSet<const HloInstruction*> scheduled_instructions_;
};
    331 
    332 int64 SumLogicalBufferSizes(
    333     const TuplePointsToAnalysis::BufferDefinitionVector& buffers,
    334     const LogicalBuffer::SizeFunction& size_function) {
    335   int64 size = 0;
    336   for (const LogicalBuffer* buffer : buffers) {
    337     size += size_function(*buffer);
    338   }
    339   return size;
    340 }
    341 
    342 StatusOr<std::vector<const HloInstruction*>> RunDFSMemoryScheduler(
    343     const HloComputation& computation,
    344     const TuplePointsToAnalysis& points_to_analysis,
    345     const LogicalBuffer::SizeFunction& size_function) {
    346   // This ordering is based on DFS post-order, with a heuristic to decide which
    347   // operand to visit first.  The heuristic is based on 'extra_users', which is
    348   // simply users-1 for each instruction.  By subtracting 1, we're saying that
    349   // instructions with no users or a single user don't count; instructions with
    350   // lots of fan-out will be visited earlier.
    351   tensorflow::gtl::FlatMap<const HloInstruction*, int64> extra_users;
    352   tensorflow::gtl::FlatMap<const HloInstruction*, int64> total_sizes;
    353   for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) {
    354     if (ListScheduler::IgnoreInstruction(*hlo)) {
    355       extra_users[hlo] = 0;
    356       total_sizes[hlo] = 0;
    357       continue;
    358     }
    359     extra_users[hlo] = hlo->users().empty() ? 0 : hlo->users().size() - 1;
    360     total_sizes[hlo] = SumLogicalBufferSizes(
    361         points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function);
    362     tensorflow::gtl::FlatSet<const HloInstruction*> unique_operands(
    363         hlo->operands().begin(), hlo->operands().end());
    364     for (const HloInstruction* operand : unique_operands) {
    365       extra_users[hlo] += extra_users[operand];
    366       total_sizes[hlo] += total_sizes[operand];
    367     }
    368   }
    369   CHECK_EQ(extra_users.size(), computation.instruction_count());
    370   CHECK_EQ(total_sizes.size(), computation.instruction_count());
    371 
    372   // Construct a total order based on DFS post-order, visiting operands in
    373   // decreasing cumulative extra user order, and next by cumulative size, with a
    374   // tiebreaker by name for determinism.
    375   std::vector<const HloInstruction*> sequence;
    376   FunctionVisitor visitor([&sequence](HloInstruction* hlo) {
    377     sequence.push_back(hlo);
    378     return Status::OK();
    379   });
    380   TF_RETURN_IF_ERROR(computation.AcceptWithOperandOrder(
    381       &visitor, [&extra_users, &total_sizes](const HloInstruction* a,
    382                                              const HloInstruction* b) {
    383         if (extra_users[a] != extra_users[b]) {
    384           return extra_users[a] > extra_users[b];
    385         }
    386         if (total_sizes[a] != total_sizes[b]) {
    387           return total_sizes[a] > total_sizes[b];
    388         }
    389         return a->name() < b->name();
    390       }));
    391   CHECK_EQ(sequence.size(), computation.instruction_count());
    392   return sequence;
    393 }
    394 
    395 StatusOr<int64> MinimumMemoryForComputation(
    396     const HloComputation& computation,
    397     const std::vector<const HloInstruction*>& sequence,
    398     const TuplePointsToAnalysis& points_to_analysis,
    399     const LogicalBuffer::SizeFunction& size_function) {
    400   TF_ASSIGN_OR_RETURN(
    401       HeapSimulator::Result result,
    402       HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), computation,
    403                          sequence, points_to_analysis, size_function));
    404   return result.heap_size;
    405 }
    406 
    407 StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
    408     const HloComputation& computation,
    409     const TuplePointsToAnalysis& points_to_analysis,
    410     const LogicalBuffer::SizeFunction& size_function,
    411     SchedulerAlgorithm algorithm) {
    412   VLOG(2) << "Computation: " << computation.name();
    413   if (algorithm == SchedulerAlgorithm::kListSchedule) {
    414     return ListScheduler::Run(computation, points_to_analysis, size_function);
    415   }
    416   if (algorithm == SchedulerAlgorithm::kDfsSchedule) {
    417     return RunDFSMemoryScheduler(computation, points_to_analysis,
    418                                  size_function);
    419   }
    420 
    421   // We try both a list-scheduler based ordering and a DFS based ordering, and
    422   // choose whichever returns a lower min-memory, not accounting for
    423   // fragmentation.
    424   //
    425   // Note that this is just a heuristic. One obvious inaccuracy is that the
    426   // memory required for sub-computations might be different when considered
    427   // within the caller's context. But it's good enough for now.
    428   TF_ASSIGN_OR_RETURN(
    429       std::vector<const HloInstruction*> list_sequence,
    430       ListScheduler::Run(computation, points_to_analysis, size_function));
    431   TF_ASSIGN_OR_RETURN(
    432       const int64 list_memory,
    433       MinimumMemoryForComputation(computation, list_sequence,
    434                                   points_to_analysis, size_function));
    435   VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory);
    436 
    437   TF_ASSIGN_OR_RETURN(
    438       std::vector<const HloInstruction*> dfs_sequence,
    439       RunDFSMemoryScheduler(computation, points_to_analysis, size_function));
    440   TF_ASSIGN_OR_RETURN(
    441       const int64 dfs_memory,
    442       MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis,
    443                                   size_function));
    444   VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);
    445 
    446   if (list_memory <= dfs_memory) {
    447     VLOG(2) << "Chose min-memory list sequence: "
    448             << HumanReadableNumBytes(list_memory);
    449     return list_sequence;
    450   } else {
    451     VLOG(2) << "Chose min-memory dfs sequence: "
    452             << HumanReadableNumBytes(dfs_memory);
    453     return dfs_sequence;
    454   }
    455 }
    456 
    457 }  // namespace
    458 
    459 StatusOr<SequentialHloOrdering::HloModuleSequence>
    460 CreateMemoryMinimizingSequence(const HloModule& module,
    461                                const LogicalBuffer::SizeFunction& size_function,
    462                                SchedulerAlgorithm algorithm) {
    463   SequentialHloOrdering::HloModuleSequence sequence;
    464   TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
    465                       TuplePointsToAnalysis::Run(&module));
    466   for (const auto* computation : module.MakeNonfusionComputations()) {
    467     TF_ASSIGN_OR_RETURN(
    468         sequence[computation],
    469         CreateMemoryMinimizingSequence(*computation, *points_to_analysis,
    470                                        size_function, algorithm));
    471   }
    472   return sequence;
    473 }
    474 
    475 StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
    476     const HloComputation& computation,
    477     const LogicalBuffer::SizeFunction& size_function,
    478     SchedulerAlgorithm algorithm) {
    479   CHECK(!computation.IsFusionComputation());
    480   TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
    481                       TuplePointsToAnalysis::Run(computation.parent()));
    482   return CreateMemoryMinimizingSequence(computation, *points_to_analysis,
    483                                         size_function, algorithm);
    484 }
    485 
    486 }  // namespace xla
    487