1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/hlo_scheduling.h" 17 18 #include <map> 19 #include <utility> 20 #include <vector> 21 22 #include "tensorflow/compiler/xla/service/heap_simulator.h" 23 #include "tensorflow/compiler/xla/service/hlo_computation.h" 24 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" 25 #include "tensorflow/compiler/xla/shape_util.h" 26 #include "tensorflow/compiler/xla/status_macros.h" 27 #include "tensorflow/compiler/xla/statusor.h" 28 #include "tensorflow/compiler/xla/types.h" 29 #include "tensorflow/compiler/xla/util.h" 30 #include "tensorflow/core/lib/core/errors.h" 31 #include "tensorflow/core/lib/strings/str_util.h" 32 #include "tensorflow/core/lib/strings/stringprintf.h" 33 #include "tensorflow/core/platform/logging.h" 34 35 using ::tensorflow::strings::HumanReadableNumBytes; 36 37 namespace xla { 38 39 StatusOr<int64> MinimumMemoryForSequence( 40 const SequentialHloOrdering::HloModuleSequence& module_sequence, 41 const LogicalBuffer::SizeFunction& size_function) { 42 if (module_sequence.empty()) { 43 return 0; 44 } 45 46 const HloModule* module = module_sequence.begin()->first->parent(); 47 TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis, 48 TuplePointsToAnalysis::Run(module)); 49 50 // The absolute minimum memory required for a given sequence of instructions 51 // is determined by the sequence of Alloc and Free calls on a simulated heap, 52 // ignoring fragmentation. We run the heap simulation on the whole module, 53 // rather than summing each computation, since it gives us a better lower 54 // bound, by minimizing the liveness of sub-computations. 55 TF_ASSIGN_OR_RETURN( 56 HeapSimulator::Result result, 57 HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), *module, 58 module_sequence, *points_to_analysis, size_function)); 59 return result.heap_size; 60 } 61 62 namespace { 63 64 // Class implementing a list scheduler of HLO instructions which produces a 65 // sequence which minimizes memory usage. 66 class ListScheduler { 67 public: 68 // Construct and return a memory-minimizing sequence of HLO instructions 69 // containing the given HLO computation. 70 static StatusOr<std::vector<const HloInstruction*>> Run( 71 const HloComputation& computation, 72 const TuplePointsToAnalysis& points_to_analysis, 73 const LogicalBuffer::SizeFunction& size_function) { 74 ListScheduler scheduler(computation, points_to_analysis, size_function); 75 return scheduler.CreateSchedule(); 76 } 77 78 // Returns whether the memory used by the given HLO should be ignored by the 79 // scheduling heuristic. 80 static bool IgnoreInstruction(const HloInstruction& instruction) { 81 return instruction.opcode() == HloOpcode::kParameter || 82 instruction.opcode() == HloOpcode::kConstant; 83 } 84 85 private: 86 // The scheduling priority of an instruction is first the number of bytes 87 // freed by scheduling the instruction, and second (tie-breaker) by the number 88 // of users. This is represented as a std::pair containing these two values 89 // (first element is the bytes freed). std::pair provides the necessary 90 // comparison operators. 91 using Priority = std::pair<int64, int64>; 92 93 ListScheduler(const HloComputation& computation, 94 const TuplePointsToAnalysis& points_to_analysis, 95 const LogicalBuffer::SizeFunction& size_function) 96 : computation_(computation), 97 points_to_analysis_(points_to_analysis), 98 size_function_(size_function) { 99 // Create a map containing the LogicalBuffer uses for each HLO 100 // instruction. An HLO instruction "uses" a LogicalBuffer if the 101 // LogicalBuffer is in an operand of the instruction as indicated by 102 // points-to analysis. 103 for (auto* instruction : computation.instructions()) { 104 tensorflow::gtl::FlatSet<const LogicalBuffer*> instr_uses; 105 for (auto* operand : instruction->operands()) { 106 for (const LogicalBuffer* buffer : 107 points_to_analysis.GetBuffersDefinedByInstruction(operand)) { 108 instr_uses.insert(buffer); 109 } 110 } 111 buffer_uses_[instruction] = std::vector<const LogicalBuffer*>( 112 instr_uses.begin(), instr_uses.end()); 113 } 114 115 // Create map containing the number of unscheduled uses (hlo instructions) 116 // of each logical buffer. 117 for (auto* instruction : computation.instructions()) { 118 for (auto* buffer : 119 points_to_analysis.GetBuffersDefinedByInstruction(instruction)) { 120 unscheduled_use_count_[buffer] = 0; 121 } 122 } 123 for (auto* instruction : computation.instructions()) { 124 for (const LogicalBuffer* buffer : buffer_uses_.at(instruction)) { 125 ++unscheduled_use_count_[buffer]; 126 } 127 } 128 129 // Buffers live out of the computation have an implicit use at the end of 130 // the computation. 131 for (const LogicalBuffer* live_out_buffer : 132 points_to_analysis.GetPointsToSet(computation.root_instruction()) 133 .CreateFlattenedSet()) { 134 ++unscheduled_use_count_[live_out_buffer]; 135 } 136 } 137 138 // Returns whether the memory used by the given buffer should be ignored by 139 // the scheduling heuristic. 140 static bool IgnoreBuffer(const LogicalBuffer& buffer) { 141 return IgnoreInstruction(*buffer.instruction()); 142 } 143 144 // An entry in the worklist used by CreateSchedule. Corresponds to one 145 // HloInstruction, plus some cached metadata, saved for the purposes of making 146 // BytesFreedIfScheduled fast. 147 struct ReadyListEntry { 148 const HloInstruction* instruction; 149 150 // The total size of all buffers defined by this instruction. 151 int64 bytes_defined; 152 153 // For each buffer B used by this instruction, we keep a pair (B, U), where 154 // U is the number of uses of B that have not yet been scheduled. This pair 155 // is a pointer into the unscheduled_use_count_ map, so it gets updated for 156 // free when we update counts in the map. 157 std::vector<const std::pair<const LogicalBuffer* const, int64>*> 158 used_buffer_unscheduled_use_counts; 159 }; 160 161 // Creates a ReadyListEntry for the given instruction. 162 ReadyListEntry MakeReadyListEntry(const HloInstruction* instruction) { 163 ReadyListEntry entry; 164 entry.instruction = instruction; 165 166 entry.bytes_defined = 0; 167 for (auto* buffer : 168 points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) { 169 if (!IgnoreBuffer(*buffer)) { 170 entry.bytes_defined += size_function_(*buffer); 171 } 172 } 173 174 for (auto* buffer : buffer_uses_.at(instruction)) { 175 if (IgnoreBuffer(*buffer)) { 176 continue; 177 } 178 auto unscheduled_use_count_it = unscheduled_use_count_.find(buffer); 179 CHECK(unscheduled_use_count_it != unscheduled_use_count_.end()); 180 entry.used_buffer_unscheduled_use_counts.push_back( 181 &*unscheduled_use_count_it); 182 } 183 return entry; 184 } 185 186 // Returns the number of bytes freed if the HLO instruction is scheduled. 187 int64 BytesFreedIfScheduled(const ReadyListEntry& entry) { 188 int64 freed_bytes = 0; 189 for (const auto& kv : entry.used_buffer_unscheduled_use_counts) { 190 auto buffer = kv->first; 191 auto use_count = kv->second; 192 if (use_count == 1) { 193 freed_bytes += size_function_(*buffer); 194 } 195 } 196 return freed_bytes - entry.bytes_defined; 197 } 198 199 // Constructs the scheduling priority of the given instruction. 200 Priority GetPriority(const ReadyListEntry& entry) { 201 return {BytesFreedIfScheduled(entry), entry.instruction->user_count()}; 202 } 203 204 std::vector<const HloInstruction*> CreateSchedule() { 205 std::vector<const HloInstruction*> schedule; 206 207 // Populate the ready list with instructions which have no operands or 208 // control predecessors. 209 tensorflow::gtl::FlatMap<const HloInstruction*, int64> 210 unscheduled_pred_count; 211 for (auto* instruction : computation_.instructions()) { 212 // TODO(b/34466113): Replace this and above with successors() or 213 // predecessors() when these methods are added to HloInstruction. 214 for (const HloInstruction* user : instruction->users()) { 215 unscheduled_pred_count[user]++; 216 } 217 for (const HloInstruction* succ : instruction->control_successors()) { 218 unscheduled_pred_count[succ]++; 219 } 220 } 221 222 // Use a multimap to sort ReadyListEntry according to their priority. 223 std::multimap<Priority, ReadyListEntry> ready_queue; 224 225 // Map of ready instructions to their iterators in ready_queue. 226 tensorflow::gtl::FlatMap<const HloInstruction*, 227 std::multimap<Priority, ReadyListEntry>::iterator> 228 ready_instructions; 229 230 auto add_to_ready_queue = [&](HloInstruction* inst) { 231 auto entry = MakeReadyListEntry(inst); 232 auto it = ready_queue.emplace(GetPriority(entry), std::move(entry)); 233 ready_instructions[inst] = it; 234 }; 235 236 for (auto* instruction : computation_.instructions()) { 237 // Instruction with no operands or control predecessors will 238 // not be in the map. 239 if (unscheduled_pred_count.count(instruction) == 0) { 240 add_to_ready_queue(instruction); 241 } 242 } 243 244 while (!ready_queue.empty()) { 245 // Remove the selected instruction from the ready list and add it to the 246 // schedule. 247 auto best_it = ready_queue.end(); 248 --best_it; 249 const HloInstruction* best = best_it->second.instruction; 250 ready_queue.erase(best_it); 251 ready_instructions.erase(best); 252 schedule.push_back(best); 253 scheduled_instructions_.insert(best); 254 255 bool adjust_ready_queue = false; 256 // Update the unscheduled uses of the logical buffers. 257 for (const LogicalBuffer* buffer : buffer_uses_.at(best)) { 258 int64& count = unscheduled_use_count_[buffer]; 259 CHECK_GT(count, 0); 260 --count; 261 if (count == 1) { 262 adjust_ready_queue = true; 263 } 264 } 265 266 // Add new instructions to ready list. 267 auto update_pred_count = [&](HloInstruction* inst) { 268 int64 pred_count = --unscheduled_pred_count.at(inst); 269 CHECK_GE(pred_count, 0); 270 if (pred_count == 0) { 271 add_to_ready_queue(inst); 272 } 273 }; 274 // TODO(b/34466113): Replace this and above with successors() or 275 // predecessors() when these methods are added to HloInstruction. 276 for (HloInstruction* user : best->users()) { 277 update_pred_count(user); 278 } 279 for (HloInstruction* succ : best->control_successors()) { 280 update_pred_count(succ); 281 } 282 // The unscheduled use count for a buffer has changed to 1, so the 283 // priorities of some ready instructions may go up. We update them in the 284 // ready queue, so that they can appear earlier. 285 if (adjust_ready_queue) { 286 for (HloInstruction* operand : best->operands()) { 287 for (HloInstruction* operand_user : operand->users()) { 288 auto ready_instructions_it = ready_instructions.find(operand_user); 289 if (ready_instructions_it == ready_instructions.end()) { 290 continue; 291 } 292 auto ready_queue_it = ready_instructions_it->second; 293 auto& entry = ready_queue_it->second; 294 Priority new_priority = GetPriority(entry); 295 if (new_priority == ready_queue_it->first) { 296 continue; 297 } 298 // Create a new entry in ready_queue, then update 299 // ready_instructions[operand_user] to refer to the new entry. 300 ready_instructions_it->second = 301 ready_queue.emplace(new_priority, std::move(entry)); 302 // Remove the old entry in ready_queue. 303 ready_queue.erase(ready_queue_it); 304 } 305 } 306 } 307 } 308 CHECK_EQ(schedule.size(), computation_.instruction_count()); 309 CHECK_EQ(scheduled_instructions_.size(), computation_.instruction_count()); 310 311 return schedule; 312 } 313 314 const HloComputation& computation_; 315 const TuplePointsToAnalysis& points_to_analysis_; 316 const LogicalBuffer::SizeFunction& size_function_; 317 318 // A map containing the LogicalBuffers that each instruction uses. 319 tensorflow::gtl::FlatMap<const HloInstruction*, 320 std::vector<const LogicalBuffer*>> 321 buffer_uses_; 322 323 // A map containing the count of unscheduled HLOs which using a particular 324 // LogicalBuffer. We rely on iterator stability in this map, and that the map 325 // entries are std::pair's. 326 std::unordered_map<const LogicalBuffer*, int64> unscheduled_use_count_; 327 328 // Set of instructions which have been scheduled. 329 tensorflow::gtl::FlatSet<const HloInstruction*> scheduled_instructions_; 330 }; 331 332 int64 SumLogicalBufferSizes( 333 const TuplePointsToAnalysis::BufferDefinitionVector& buffers, 334 const LogicalBuffer::SizeFunction& size_function) { 335 int64 size = 0; 336 for (const LogicalBuffer* buffer : buffers) { 337 size += size_function(*buffer); 338 } 339 return size; 340 } 341 342 StatusOr<std::vector<const HloInstruction*>> RunDFSMemoryScheduler( 343 const HloComputation& computation, 344 const TuplePointsToAnalysis& points_to_analysis, 345 const LogicalBuffer::SizeFunction& size_function) { 346 // This ordering is based on DFS post-order, with a heuristic to decide which 347 // operand to visit first. The heuristic is based on 'extra_users', which is 348 // simply users-1 for each instruction. By subtracting 1, we're saying that 349 // instructions with no users or a single user don't count; instructions with 350 // lots of fan-out will be visited earlier. 351 tensorflow::gtl::FlatMap<const HloInstruction*, int64> extra_users; 352 tensorflow::gtl::FlatMap<const HloInstruction*, int64> total_sizes; 353 for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) { 354 if (ListScheduler::IgnoreInstruction(*hlo)) { 355 extra_users[hlo] = 0; 356 total_sizes[hlo] = 0; 357 continue; 358 } 359 extra_users[hlo] = hlo->users().empty() ? 0 : hlo->users().size() - 1; 360 total_sizes[hlo] = SumLogicalBufferSizes( 361 points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function); 362 tensorflow::gtl::FlatSet<const HloInstruction*> unique_operands( 363 hlo->operands().begin(), hlo->operands().end()); 364 for (const HloInstruction* operand : unique_operands) { 365 extra_users[hlo] += extra_users[operand]; 366 total_sizes[hlo] += total_sizes[operand]; 367 } 368 } 369 CHECK_EQ(extra_users.size(), computation.instruction_count()); 370 CHECK_EQ(total_sizes.size(), computation.instruction_count()); 371 372 // Construct a total order based on DFS post-order, visiting operands in 373 // decreasing cumulative extra user order, and next by cumulative size, with a 374 // tiebreaker by name for determinism. 375 std::vector<const HloInstruction*> sequence; 376 FunctionVisitor visitor([&sequence](HloInstruction* hlo) { 377 sequence.push_back(hlo); 378 return Status::OK(); 379 }); 380 TF_RETURN_IF_ERROR(computation.AcceptWithOperandOrder( 381 &visitor, [&extra_users, &total_sizes](const HloInstruction* a, 382 const HloInstruction* b) { 383 if (extra_users[a] != extra_users[b]) { 384 return extra_users[a] > extra_users[b]; 385 } 386 if (total_sizes[a] != total_sizes[b]) { 387 return total_sizes[a] > total_sizes[b]; 388 } 389 return a->name() < b->name(); 390 })); 391 CHECK_EQ(sequence.size(), computation.instruction_count()); 392 return sequence; 393 } 394 395 StatusOr<int64> MinimumMemoryForComputation( 396 const HloComputation& computation, 397 const std::vector<const HloInstruction*>& sequence, 398 const TuplePointsToAnalysis& points_to_analysis, 399 const LogicalBuffer::SizeFunction& size_function) { 400 TF_ASSIGN_OR_RETURN( 401 HeapSimulator::Result result, 402 HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), computation, 403 sequence, points_to_analysis, size_function)); 404 return result.heap_size; 405 } 406 407 StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence( 408 const HloComputation& computation, 409 const TuplePointsToAnalysis& points_to_analysis, 410 const LogicalBuffer::SizeFunction& size_function, 411 SchedulerAlgorithm algorithm) { 412 VLOG(2) << "Computation: " << computation.name(); 413 if (algorithm == SchedulerAlgorithm::kListSchedule) { 414 return ListScheduler::Run(computation, points_to_analysis, size_function); 415 } 416 if (algorithm == SchedulerAlgorithm::kDfsSchedule) { 417 return RunDFSMemoryScheduler(computation, points_to_analysis, 418 size_function); 419 } 420 421 // We try both a list-scheduler based ordering and a DFS based ordering, and 422 // choose whichever returns a lower min-memory, not accounting for 423 // fragmentation. 424 // 425 // Note that this is just a heuristic. One obvious inaccuracy is that the 426 // memory required for sub-computations might be different when considered 427 // within the caller's context. But it's good enough for now. 428 TF_ASSIGN_OR_RETURN( 429 std::vector<const HloInstruction*> list_sequence, 430 ListScheduler::Run(computation, points_to_analysis, size_function)); 431 TF_ASSIGN_OR_RETURN( 432 const int64 list_memory, 433 MinimumMemoryForComputation(computation, list_sequence, 434 points_to_analysis, size_function)); 435 VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); 436 437 TF_ASSIGN_OR_RETURN( 438 std::vector<const HloInstruction*> dfs_sequence, 439 RunDFSMemoryScheduler(computation, points_to_analysis, size_function)); 440 TF_ASSIGN_OR_RETURN( 441 const int64 dfs_memory, 442 MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis, 443 size_function)); 444 VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); 445 446 if (list_memory <= dfs_memory) { 447 VLOG(2) << "Chose min-memory list sequence: " 448 << HumanReadableNumBytes(list_memory); 449 return list_sequence; 450 } else { 451 VLOG(2) << "Chose min-memory dfs sequence: " 452 << HumanReadableNumBytes(dfs_memory); 453 return dfs_sequence; 454 } 455 } 456 457 } // namespace 458 459 StatusOr<SequentialHloOrdering::HloModuleSequence> 460 CreateMemoryMinimizingSequence(const HloModule& module, 461 const LogicalBuffer::SizeFunction& size_function, 462 SchedulerAlgorithm algorithm) { 463 SequentialHloOrdering::HloModuleSequence sequence; 464 TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis, 465 TuplePointsToAnalysis::Run(&module)); 466 for (const auto* computation : module.MakeNonfusionComputations()) { 467 TF_ASSIGN_OR_RETURN( 468 sequence[computation], 469 CreateMemoryMinimizingSequence(*computation, *points_to_analysis, 470 size_function, algorithm)); 471 } 472 return sequence; 473 } 474 475 StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence( 476 const HloComputation& computation, 477 const LogicalBuffer::SizeFunction& size_function, 478 SchedulerAlgorithm algorithm) { 479 CHECK(!computation.IsFusionComputation()); 480 TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis, 481 TuplePointsToAnalysis::Run(computation.parent())); 482 return CreateMemoryMinimizingSequence(computation, *points_to_analysis, 483 size_function, algorithm); 484 } 485 486 } // namespace xla 487