/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/heap_simulator.h"

#include <memory>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_value.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace xla {
namespace {

class MinimumMemoryForSequenceTest : public HloTestBase {};

TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
  auto module = CreateNewVerifiedModule();
  const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape});

  auto cond_builder = HloComputation::Builder("WhileCond");
  // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
  HloInstruction* cond_iter = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0));
  HloInstruction* cond_data = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
  // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte)
  HloInstruction* cond_lt = cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
                                    cond_data, ComparisonDirection::kLt));
  HloComputation* cond_computation =
      module->AddEmbeddedComputation(cond_builder.Build());

  auto body_builder = HloComputation::Builder("WhileBody");
  // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
  HloComputation* body_computation =
      module->AddEmbeddedComputation(body_builder.Build());
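
  // Build the entry computation. The TOTALs in the comments below track the
  // running peak memory as the schedule executes: 8 bytes of entry params,
  // plus 16 bytes of tuple pointers, plus 8 bytes for the while result's
  // scalars, plus the 24 bytes used by whichever of cond/body is running,
  // giving the expected total of 56.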
  auto builder = HloComputation::Builder(TestName());
  // Entry params: 8 bytes (4 bytes per param), TOTAL=8
  HloInstruction* iter = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "param_iter"));
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape, "param_data"));
  // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({iter, data}));
  // While: 8 bytes (4 bytes per element), TOTAL=32
  // Both cond and body use a max of 24 bytes, TOTAL=56
  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
      tuple_shape, cond_computation, body_computation, tuple));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  auto size_fn = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
  };

  HloSchedule schedule(module.get());
  schedule.set_sequence(cond_computation,
                        {cond_param, cond_iter, cond_data, cond_lt});
  schedule.set_sequence(body_computation, {body_param});
  schedule.set_sequence(entry_computation, {iter, data, tuple, while_op});
  TF_ASSERT_OK(schedule.Verify());

  EXPECT_EQ(
      56,
      HeapSimulator::MinimumMemoryForModule(schedule, size_fn).ValueOrDie());
}

TEST_F(MinimumMemoryForSequenceTest, SubcomputationAccounting) {
  // HloModule SubcomputationAccounting

  // %WhileBody (body_param: f32[4]) -> f32[4] {
  //   %body_param = f32[4]{0} parameter(0)
  //   %constant.1 = f32[4]{0} constant({1, 1, 1, 1})
  //   ROOT %subtract = f32[4]{0} subtract(f32[4]{0} %body_param,
  //                                       f32[4]{0} %constant.1)
  // }

  // %WhileCond (cond_param: f32[4]) -> pred[] {
  //   %cond_param = f32[4]{0} parameter(0)
  //   %slice = f32[1]{0} slice(f32[4]{0} %cond_param), slice={[0:1]}
  //   %reshape = f32[] reshape(f32[1]{0} %slice)
  //   %constant = f32[] constant(0)
  //   ROOT %not-equal-to = pred[] compare(f32[] %reshape, f32[] %constant),
  //                        direction=NE
  // }

  // ENTRY %SubcomputationAccounting () -> f32[2,4] {
  //   %constant.3 = f32[2,4]{1,0} constant(f32[2,4] {{1, 2, 3, 4},
  //                                                  {1, 2, 3, 4}})
  //   %transpose = f32[2,4]{1,0} transpose(f32[2,4]{1,0} %constant.3),
  //                dimensions={0,1}
  //   %constant.2 = f32[4]{0} constant({1, 1, 1, 1})
  //   %while = f32[4]{0} while(f32[4]{0} %constant.2), condition=%WhileCond,
  //            body=%WhileBody
  //   %broadcast = f32[2,4]{1,0} broadcast(f32[4]{0} %while), dimensions={1}
  //   ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %transpose,
  //                                 f32[2,4]{1,0} %broadcast)
  // }

  auto module = CreateNewVerifiedModule();
  const Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  const Shape r1f32 = ShapeUtil::MakeShape(F32, {4});
  const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4});

  // reshape(slice(param)) != 0
  // Needs 5 bytes
  auto cond_builder = HloComputation::Builder("WhileCond");
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "cond_param"));
  HloInstruction* slice =
      cond_builder.AddInstruction(HloInstruction::CreateSlice(
          ShapeUtil::MakeShape(F32, {1}), cond_param, {0}, {1}, {1}));
  HloInstruction* reshape =
      cond_builder.AddInstruction(HloInstruction::CreateReshape(r0f32, slice));
  HloInstruction* zero = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
  HloInstruction* cond_comparison = cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), reshape,
                                    zero, ComparisonDirection::kNe));
  auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
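
  // The 5 bytes noted above appear to be the 4-byte f32 slice/reshape result
  // plus the 1-byte pred output; the figure is asserted by hand via
  // memory_by_computation below.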

  // param - 1
  // Needs 16 bytes
  auto body_builder = HloComputation::Builder("WhileBody");
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "body_param"));
  HloInstruction* one_vector =
      body_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1, 1, 1, 1})));
  HloInstruction* subtract =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          r1f32, HloOpcode::kSubtract, body_param, one_vector));
  auto body_computation = module->AddEmbeddedComputation(body_builder.Build());

  // transpose(matrix) + bcast(while)
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* while_init =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1, 1, 1, 1})));
  // Creates 16 bytes, ignoring subcomputations
  HloInstruction* while_loop =
      builder.AddInstruction(HloInstruction::CreateWhile(
          r1f32, cond_computation, body_computation, while_init));

  // Creates 32 bytes and frees 16
  HloInstruction* bcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(r2f32, while_loop, {1}));

  HloInstruction* matrix = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>(
          {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}})));
  // Creates 32 bytes
  HloInstruction* transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(r2f32, matrix, {0, 1}));

  // Creates 32 bytes and frees 64
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, transpose, bcast));

  auto entry_computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  std::vector<HloInstruction*> cond_vec = {cond_param, slice, reshape, zero,
                                           cond_comparison};
  std::vector<HloInstruction*> while_body_vec = {body_param, one_vector,
                                                 subtract};
  std::vector<HloInstruction*> entry_comp_vec = {while_init, while_loop, bcast,
                                                 matrix, transpose, add};
  schedule.set_sequence(cond_computation, cond_vec);
  schedule.set_sequence(body_computation, while_body_vec);
  schedule.set_sequence(entry_computation, entry_comp_vec);

  auto size_fn = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape());
  };
  absl::flat_hash_map<const HloComputation*, int64> memory_by_computation;
  memory_by_computation[cond_computation] = 5;
  memory_by_computation[body_computation] = 16;
  std::unique_ptr<TuplePointsToAnalysis> points_to_analysis =
      TuplePointsToAnalysis::Run(module.get()).ValueOrDie();

  // HeapSimulator accounts for subcomputations. The output buffer is aliased,
  // so we don't double count.
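  // The 64-byte peak presumably occurs when transpose (32) and bcast (32) are
  // simultaneously live; the while loop's 16 bytes plus its body's 16-byte
  // charge stay below that mark.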
  EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation(
                    *entry_computation, schedule.sequence(entry_computation),
                    *points_to_analysis, size_fn, &memory_by_computation)
                    .ValueOrDie());
}

const char kAlloc[] = "Alloc";
const char kFree[] = "Free";
const char kFinish[] = "Finish";

// CallSequence records a sequence of Alloc/Free/Finish calls.
using CallSequence = std::vector<std::pair<string, const BufferValue*>>;

// HeapCallRecorder is a dummy heap algorithm that simply records its calls.
class HeapCallRecorder : public HeapAlgorithm {
 public:
  explicit HeapCallRecorder(CallSequence* calls) : calls_(calls) {}
  ~HeapCallRecorder() override {}

  void Alloc(const BufferValue* buffer, int64 size) override {
    calls_->emplace_back(kAlloc, buffer);
    // Instead of assigning a real offset, we set the cardinality of the Alloc
    // call. This isn't a valid assignment, but allows us to easily test for
    // buffer sharing.
    const int64 offset = result_.chunk_map.size();
    result_.chunk_map.emplace(buffer, Chunk{offset, size});
  }
  void Free(const BufferValue* buffer, int64 size) override {
    calls_->emplace_back(kFree, buffer);
  }
  Result Finish() override {
    calls_->emplace_back(kFinish, nullptr);
    return result_;
  }

 private:
  CallSequence* calls_;
  Result result_;
};

// HeapSimulatorTracker runs the heap simulator, recording the sequence of
// calls made to the underlying heap algorithm. Tests compare the actual call
// sequence against an expected sequence.
class HeapSimulatorTracker {
 public:
  // Constructor for testing a single entry computation.
  HeapSimulatorTracker(
      const string& name, std::unique_ptr<HloComputation> computation,
      const std::vector<HloInstruction*>& instruction_sequence) {
    HloModuleConfig config;
    module_ = absl::make_unique<HloModule>(name, config);
    module_->AddEntryComputation(std::move(computation));
    points_to_analysis_ =
        TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
    // Since we're only tracking the sequence of Alloc/Free calls, the actual
    // size of the buffers doesn't matter, so we always return 0. We rely on
    // the secondary sorting criteria of DecreasingSizeRunsHeap to sort calls
    // by buffer id, for determinism in the tests.
    auto zero_size = [](const BufferValue& buffer) { return 0; };
    auto algorithm = absl::make_unique<DecreasingSizeRunsHeap>(
        absl::make_unique<HeapCallRecorder>(&actual_calls_));
    result_ =
        HeapSimulator::Run(std::move(algorithm), *module_->entry_computation(),
                           HloInstructionSequence(instruction_sequence),
                           *points_to_analysis_, zero_size)
            .ConsumeValueOrDie();
  }

  explicit HeapSimulatorTracker(const string& name) {
    HloModuleConfig config;
    module_ = absl::make_unique<HloModule>(name, config);
  }

  // Similar to the single entry computation constructor above, but runs the
  // simulation over the entire module.
  void RunWholeModule(
      const std::vector<HloInstruction*>& full_module_sequence) {
    points_to_analysis_ =
        TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();

    // Construct the module sequence grouped by computation.
    HloSchedule schedule(module_.get());
    absl::flat_hash_map<const HloInstruction*, int> reverse_position;
    for (int i = 0; i < full_module_sequence.size(); ++i) {
      HloInstruction* instruction = full_module_sequence[i];
      schedule.GetOrCreateSequence(instruction->parent())
          .push_back(instruction);
      reverse_position[instruction] = full_module_sequence.size() - i;
    }

    // Hack the size_fn so that it returns a decreasing value as we step
    // through the sequence. This lets us ensure the Alloc calls are in the
    // sequence order. The Free calls are sorted by BufferValue.id, which is
    // at least deterministic.
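    // For example, with the 7-instruction sequence used in the WholeModule
    // test below, the first instruction maps to size 7 and the last to size
    // 1, so sizes strictly decrease in schedule order.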
    auto size_fn = [&reverse_position](const BufferValue& buffer) {
      return reverse_position[buffer.instruction()];
    };
    auto algorithm = absl::make_unique<DecreasingSizeRunsHeap>(
        absl::make_unique<HeapCallRecorder>(&actual_calls_));
    result_ = HeapSimulator::Run(std::move(algorithm), *module_, schedule,
                                 *points_to_analysis_, size_fn)
                  .ConsumeValueOrDie();
  }

  HloModule* module() { return module_.get(); }

  // Returns the buffer defined at the given instruction and index.
  const BufferValue* BufferAt(const HloInstruction* instruction,
                              const ShapeIndex& index) const {
    return points_to_analysis_->GetBufferDefinedAt(instruction, index)
        .ConsumeValueOrDie();
  }

  int64 OffsetAt(const HloInstruction* instruction, const ShapeIndex& index) {
    const BufferValue* buffer = BufferAt(instruction, index);
    return result_.chunk_map.at(buffer).offset;
  }

  // Ensures the expected sequence of Alloc/Free/Finish calls was performed.
  void ExpectCallSequence(const CallSequence& expected) const {
    EXPECT_EQ(expected, actual_calls_);
  }

  // Ensures the buffers defined by the respective (instruction,index) pairs
  // are shared, relying on the unique offsets assigned in
  // HeapCallRecorder::Alloc.
  void ExpectSharedBuffers(const HloInstruction* instruction_a,
                           const ShapeIndex& index_a,
                           const HloInstruction* instruction_b,
                           const ShapeIndex& index_b) {
    int64 offset_a = OffsetAt(instruction_a, index_a);
    int64 offset_b = OffsetAt(instruction_b, index_b);
    EXPECT_EQ(offset_a, offset_b);
  }

 private:
  std::unique_ptr<HloModule> module_;
  std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
  CallSequence actual_calls_;
  HeapSimulator::Result result_;
};

class HeapSimulatorTest : public HloTestBase {
 protected:
  HeapSimulatorTest() {}
  ~HeapSimulatorTest() override {}

  // Shapes for use in the examples.
  Shape f32scalar_ = ShapeUtil::MakeShape(xla::F32, {});
  Shape f32vec4_ = ShapeUtil::MakeShape(F32, {4});
};

TEST_F(HeapSimulatorTest, ScalarConstant) {
  auto builder = HloComputation::Builder(TestName());
  auto const0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));

  // Constants aren't assigned. See b/32248867
  HeapSimulatorTracker tracker(TestName(), builder.Build(), {const0});
  tracker.ExpectCallSequence({{kFinish, nullptr}});
}

TEST_F(HeapSimulatorTest, OneParam) {
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "param0"));

  // A single parameter which is also the output.
  HeapSimulatorTracker tracker(TestName(), builder.Build(), {param0});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(param0, {})},
      {kFree, tracker.BufferAt(param0, {})},
      {kFinish, nullptr},
  });
}

TEST_F(HeapSimulatorTest, Multiply) {
  auto builder = HloComputation::Builder(TestName());
  auto paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto paramX = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, paramA, paramX));

  // We must keep all parameters and outputs.
  HeapSimulatorTracker tracker(TestName(), builder.Build(),
                               {paramA, paramX, mul});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(paramX, {})},
      {kAlloc, tracker.BufferAt(mul, {})},
      // All params and outputs are freed at the end.
      {kFree, tracker.BufferAt(paramA, {})},
      {kFree, tracker.BufferAt(paramX, {})},
      {kFree, tracker.BufferAt(mul, {})},
      {kFinish, nullptr},
  });
}

TEST_F(HeapSimulatorTest, MultiplyAdd) {
  auto builder = HloComputation::Builder(TestName());
  auto paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto paramX = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  auto paramY = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec4_, "paramY"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, paramA, paramX));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, mul, paramY));

  // The buffer for add is the output, and it's shared with the buffer for mul.
  HeapSimulatorTracker tracker(TestName(), builder.Build(),
                               {paramA, paramX, mul, paramY, add});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(paramX, {})},
      {kAlloc, tracker.BufferAt(mul, {})},
      {kAlloc, tracker.BufferAt(paramY, {})},
      // All params and outputs are freed at the end.
      {kFree, tracker.BufferAt(paramA, {})},
      {kFree, tracker.BufferAt(paramX, {})},
      {kFree, tracker.BufferAt(mul, {})},
      {kFree, tracker.BufferAt(paramY, {})},
      {kFinish, nullptr},
  });
  tracker.ExpectSharedBuffers(add, {}, mul, {});
}

TEST_F(HeapSimulatorTest, BufferReusedOnce) {
  HeapSimulatorTracker tracker(TestName());
  auto builder = HloComputation::Builder(TestName());

  HloComputation::Builder fusion_builder("fusion");
  {
    HloComputation::Builder& builder = fusion_builder;
    auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
        /*parameter_number=*/0, f32vec4_, "A"));
    auto exp = builder.AddInstruction(
        HloInstruction::CreateUnary(f32vec4_, HloOpcode::kExp, a_param));
    auto neg = builder.AddInstruction(
        HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, a_param));

    builder.AddInstruction(HloInstruction::CreateTuple({exp, neg}));
  }
  auto fusion_computation =
      tracker.module()->AddEmbeddedComputation(fusion_builder.Build());
  auto a_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec4_, "paramA"));
  auto neg = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, a_param));
  auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
      ShapeUtil::MakeTupleShape({f32vec4_, f32vec4_}),
      HloInstruction::FusionKind::kLoop, {neg}, fusion_computation));
  tracker.module()->AddEntryComputation(builder.Build());

  tracker.RunWholeModule({a_param, neg, fusion});

  auto neg_buffer = tracker.OffsetAt(neg, {});
  int64 output_buffer_0 = tracker.OffsetAt(fusion, {0});
  int64 output_buffer_1 = tracker.OffsetAt(fusion, {1});
  // Only one buffer should be shared.
  EXPECT_TRUE((neg_buffer == output_buffer_0) ^
              (neg_buffer == output_buffer_1));
}

TEST_F(HeapSimulatorTest, MultiplyDot) {
  auto builder = HloComputation::Builder(TestName());
  auto paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto paramX = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  auto paramY = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, paramA, paramX));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto dot = builder.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));

  // The buffer for dot is the output, and it cannot be shared with the buffer
  // for mul, since dot isn't elementwise.
  HeapSimulatorTracker tracker(TestName(), builder.Build(),
                               {paramA, paramX, mul, paramY, dot});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(paramX, {})},
      {kAlloc, tracker.BufferAt(mul, {})},
      {kAlloc, tracker.BufferAt(paramY, {})},
      {kAlloc, tracker.BufferAt(dot, {})},
      // All params and outputs are freed at the end.
      {kFree, tracker.BufferAt(paramA, {})},
      {kFree, tracker.BufferAt(paramX, {})},
      {kFree, tracker.BufferAt(mul, {})},
      {kFree, tracker.BufferAt(paramY, {})},
      {kFree, tracker.BufferAt(dot, {})},
      {kFinish, nullptr},
  });
}

TEST_F(HeapSimulatorTest, MultiplyDotAdd) {
  auto builder = HloComputation::Builder(TestName());
  auto paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto paramX = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  auto paramY = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, paramA, paramX));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto dot = builder.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, dot, paramA));

  // The buffer for add is the output, and it's shared with the buffer for dot.
  HeapSimulatorTracker tracker(TestName(), builder.Build(),
                               {paramA, paramX, mul, paramY, dot, add});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(paramX, {})},
      {kAlloc, tracker.BufferAt(mul, {})},
      {kAlloc, tracker.BufferAt(paramY, {})},
      {kAlloc, tracker.BufferAt(dot, {})},
      // All params and outputs are freed at the end.
      {kFree, tracker.BufferAt(paramA, {})},
      {kFree, tracker.BufferAt(paramX, {})},
      {kFree, tracker.BufferAt(mul, {})},
      {kFree, tracker.BufferAt(paramY, {})},
      {kFree, tracker.BufferAt(dot, {})},
      {kFinish, nullptr},
  });
  tracker.ExpectSharedBuffers(add, {}, dot, {});
}

TEST_F(HeapSimulatorTest, MultiplyDotDot) {
  auto builder = HloComputation::Builder(TestName());
  auto paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto paramX = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  auto paramY = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, paramA, paramX));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto dot0 = builder.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
  auto dot1 = builder.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2)));

  // The buffer for dot1 is the output. No buffers can be shared. The buffer
  // for mul is freed before the end, since it's no longer used after dot0
  // finishes.
  HeapSimulatorTracker tracker(TestName(), builder.Build(),
                               {paramA, paramX, mul, paramY, dot0, dot1});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(paramX, {})},
      {kAlloc, tracker.BufferAt(mul, {})},
      {kAlloc, tracker.BufferAt(paramY, {})},
      {kAlloc, tracker.BufferAt(dot0, {})},
      {kFree, tracker.BufferAt(mul, {})},  // mul no longer used
      {kAlloc, tracker.BufferAt(dot1, {})},
      // All params and outputs are freed at the end.
      {kFree, tracker.BufferAt(paramA, {})},
      {kFree, tracker.BufferAt(paramX, {})},
      {kFree, tracker.BufferAt(paramY, {})},
      {kFree, tracker.BufferAt(dot0, {})},
      {kFree, tracker.BufferAt(dot1, {})},
      {kFinish, nullptr},
  });
}

TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) {
  auto builder = HloComputation::Builder(TestName());
  auto paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto paramX = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  auto paramY = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, paramA, paramX));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto dot0 = builder.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
  auto dot1 = builder.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2)));
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({dot0, dot1}));

  // The buffers for dot0, dot1 and tuple are the output. No buffers can be
  // shared. The buffer for mul is freed before the end, since it's no longer
  // used after dot0 finishes.
  HeapSimulatorTracker tracker(
      TestName(), builder.Build(),
      {paramA, paramX, mul, paramY, dot0, dot1, tuple});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(paramX, {})},
      {kAlloc, tracker.BufferAt(mul, {})},
      {kAlloc, tracker.BufferAt(paramY, {})},
      {kAlloc, tracker.BufferAt(dot0, {})},
      {kFree, tracker.BufferAt(mul, {})},  // mul no longer used
      {kAlloc, tracker.BufferAt(dot1, {})},
      {kAlloc, tracker.BufferAt(tuple, {})},
      // All params and outputs are freed at the end.
      {kFree, tracker.BufferAt(paramA, {})},
      {kFree, tracker.BufferAt(paramX, {})},
      {kFree, tracker.BufferAt(paramY, {})},
      {kFree, tracker.BufferAt(dot0, {})},
      {kFree, tracker.BufferAt(dot1, {})},
      {kFree, tracker.BufferAt(tuple, {})},
      {kFinish, nullptr},
  });
}

TEST_F(HeapSimulatorTest, IndependentTupleElements) {
  auto builder = HloComputation::Builder(TestName());
  auto paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto paramB = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32scalar_, "paramB"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32scalar_, HloOpcode::kMultiply, paramA, paramB));
  auto add = builder.AddInstruction(HloInstruction::CreateBinary(
      f32scalar_, HloOpcode::kAdd, paramA, paramB));
  auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({mul, add}));
  auto element0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32scalar_, tuple, 0));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec4_, element0, {0}));
  auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32scalar_, HloOpcode::kSubtract, paramA, paramB));
  auto element1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32scalar_, tuple, 1));
  auto output = builder.AddInstruction(
      HloInstruction::CreateTuple({broadcast, sub, element1}));

  HeapSimulatorTracker tracker(TestName(), builder.Build(),
                               {paramA, paramB, mul, add, tuple, element0,
                                broadcast, sub, element1, output});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(paramB, {})},
      {kAlloc, tracker.BufferAt(mul, {})},
      {kAlloc, tracker.BufferAt(add, {})},
      {kAlloc, tracker.BufferAt(tuple, {})},
      {kAlloc, tracker.BufferAt(broadcast, {})},
      // The mul can be freed right after the broadcast happens, even though
      // the other GetTupleElement is still alive.
      {kFree, tracker.BufferAt(mul, {})},
      {kAlloc, tracker.BufferAt(sub, {})},
      // The temporary tuple is now dead.
      {kFree, tracker.BufferAt(tuple, {})},
      {kAlloc, tracker.BufferAt(output, {})},
      // All params and outputs are freed at the end.
      {kFree, tracker.BufferAt(paramA, {})},
      {kFree, tracker.BufferAt(paramB, {})},
      {kFree, tracker.BufferAt(add, {})},
      {kFree, tracker.BufferAt(broadcast, {})},
      {kFree, tracker.BufferAt(sub, {})},
      {kFree, tracker.BufferAt(output, {})},
      {kFinish, nullptr},
  });
}

TEST_F(HeapSimulatorTest, WholeModule) {
  HeapSimulatorTracker tracker(TestName());

  const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape});

  auto cond_builder = HloComputation::Builder("WhileCond");
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
  HloInstruction* cond_iter = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0));
  HloInstruction* cond_data = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
  HloInstruction* cond_lt = cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
                                    cond_data, ComparisonDirection::kLt));
  HloComputation* cond_computation =
      tracker.module()->AddEmbeddedComputation(cond_builder.Build());

  auto body_builder = HloComputation::Builder("WhileBody");
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
  HloComputation* body_computation =
      tracker.module()->AddEmbeddedComputation(body_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));
  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
      tuple_shape, cond_computation, body_computation, param));
  tracker.module()->AddEntryComputation(builder.Build());

  tracker.RunWholeModule(
      {param, while_op, body_param, cond_param, cond_iter, cond_data, cond_lt});
  tracker.ExpectCallSequence({
      // The entry computation param and while_op are allocated first.
      {kAlloc, tracker.BufferAt(param, {})},
      {kAlloc, tracker.BufferAt(param, {0})},
      {kAlloc, tracker.BufferAt(param, {1})},
      {kAlloc, tracker.BufferAt(while_op, {})},
      {kAlloc, tracker.BufferAt(while_op, {0})},
      {kAlloc, tracker.BufferAt(while_op, {1})},

      // Now the while body param is allocated and freed.
      {kAlloc, tracker.BufferAt(body_param, {})},
      {kAlloc, tracker.BufferAt(body_param, {0})},
      {kAlloc, tracker.BufferAt(body_param, {1})},
      {kFree, tracker.BufferAt(body_param, {})},
      {kFree, tracker.BufferAt(body_param, {0})},
      {kFree, tracker.BufferAt(body_param, {1})},

      // Now the while cond param is allocated. The GTE instructions just
      // alias the param elements, so the param tuple can immediately be
      // freed.
      {kAlloc, tracker.BufferAt(cond_param, {})},
      {kAlloc, tracker.BufferAt(cond_param, {0})},
      {kAlloc, tracker.BufferAt(cond_param, {1})},
      {kFree, tracker.BufferAt(cond_param, {})},

      // Now the final cond less-than buffer is allocated.
      {kAlloc, tracker.BufferAt(cond_lt, {})},

      // The order of the remaining Free calls is based on the BufferValue.id,
      // which is deterministic, but not obvious.
      {kFree, tracker.BufferAt(param, {})},
      {kFree, tracker.BufferAt(param, {0})},
      {kFree, tracker.BufferAt(param, {1})},

      {kFree, tracker.BufferAt(while_op, {})},
      {kFree, tracker.BufferAt(while_op, {0})},
      {kFree, tracker.BufferAt(while_op, {1})},

      {kFree, tracker.BufferAt(cond_param, {0})},
      {kFree, tracker.BufferAt(cond_param, {1})},
      {kFree, tracker.BufferAt(cond_lt, {})},

      {kFinish, nullptr},
  });
}
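
// The remaining tests exercise each HeapAlgorithm implementation directly,
// driving it with hand-written Alloc/Free/Finish calls on dummy buffers
// instead of going through the simulator.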

// Base class for heap algorithm tests.
class HeapAlgorithmTestBase : public ::testing::Test {
 protected:
  HeapAlgorithmTestBase() : builder_("heap_simulator_test") {
    buffer_a_ = DummyBufferValue();
    buffer_b_ = DummyBufferValue();
    buffer_c_ = DummyBufferValue();
    buffer_d_ = DummyBufferValue();
    buffer_e_ = DummyBufferValue();
    buffer_f_ = DummyBufferValue();
    buffer_g_ = DummyBufferValue();
    buffer_h_ = DummyBufferValue();
    buffer_i_ = DummyBufferValue();
  }
  ~HeapAlgorithmTestBase() override {}

  const BufferValue* buffer_a_;
  const BufferValue* buffer_b_;
  const BufferValue* buffer_c_;
  const BufferValue* buffer_d_;
  const BufferValue* buffer_e_;
  const BufferValue* buffer_f_;
  const BufferValue* buffer_g_;
  const BufferValue* buffer_h_;
  const BufferValue* buffer_i_;

 private:
  // Create a dummy BufferValue to pass to the heap algorithm.
  const BufferValue* DummyBufferValue() {
    const BufferValue::Id id = buffers_.size();
    auto const0 = builder_.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
    buffers_.emplace_back(
        absl::make_unique<HloValue>(id, const0, ShapeIndex{}));
    return buffers_.back().get();
  }

  HloComputation::Builder builder_;
  std::vector<std::unique_ptr<BufferValue>> buffers_;
};

class NoFragmentationStatsHeapTest : public HeapAlgorithmTestBase {};

TEST_F(NoFragmentationStatsHeapTest, Empty) {
  NoFragmentationStatsHeap heap;
  EXPECT_EQ(0, heap.Finish().heap_size);
}

TEST_F(NoFragmentationStatsHeapTest, Simple) {
  NoFragmentationStatsHeap heap;
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 20);
  heap.Alloc(buffer_c_, 30);
  heap.Alloc(buffer_d_, 30);
  heap.Free(buffer_a_, 10);
  heap.Free(buffer_b_, 20);
  heap.Free(buffer_c_, 30);
  heap.Free(buffer_d_, 30);
  EXPECT_EQ(90, heap.Finish().heap_size);
}

TEST_F(NoFragmentationStatsHeapTest, Mixed) {
  NoFragmentationStatsHeap heap;
  heap.Alloc(buffer_a_, 10);  // max: A

  heap.Alloc(buffer_b_, 20);  // max: A+B
  heap.Free(buffer_b_, 20);

  heap.Alloc(buffer_c_, 30);  // max: A+C
  heap.Free(buffer_c_, 30);

  heap.Alloc(buffer_d_, 5);  // max: A+C
  heap.Free(buffer_d_, 5);

  heap.Free(buffer_a_, 10);
  EXPECT_EQ(40, heap.Finish().heap_size);
}
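
// DecreasingSizeRunsHeap wraps another algorithm and reorders each contiguous
// run of Alloc calls (and each run of Free calls) by decreasing size before
// forwarding them, breaking ties by buffer id; the recorder below observes
// the reordered calls.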
class DecreasingSizeRunsHeapTest : public HeapAlgorithmTestBase {};

TEST_F(DecreasingSizeRunsHeapTest, Empty) {
  CallSequence call_sequence;
  DecreasingSizeRunsHeap heap(
      absl::make_unique<HeapCallRecorder>(&call_sequence));
  heap.Finish();
  EXPECT_EQ(call_sequence, CallSequence({
                               {kFinish, nullptr},
                           }));
}

TEST_F(DecreasingSizeRunsHeapTest, Simple) {
  CallSequence call_sequence;
  DecreasingSizeRunsHeap heap(
      absl::make_unique<HeapCallRecorder>(&call_sequence));
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 20);
  heap.Alloc(buffer_c_, 30);
  heap.Alloc(buffer_d_, 30);
  heap.Free(buffer_a_, 10);
  heap.Free(buffer_b_, 20);
  heap.Free(buffer_c_, 30);
  heap.Free(buffer_d_, 30);
  heap.Finish();
  // Runs of Allocs and Frees are sorted by decreasing size, with buffer id
  // tiebreaker.
  EXPECT_EQ(call_sequence, CallSequence({
                               {kAlloc, buffer_c_},
                               {kAlloc, buffer_d_},
                               {kAlloc, buffer_b_},
                               {kAlloc, buffer_a_},
                               {kFree, buffer_c_},
                               {kFree, buffer_d_},
                               {kFree, buffer_b_},
                               {kFree, buffer_a_},
                               {kFinish, nullptr},
                           }));
}

TEST_F(DecreasingSizeRunsHeapTest, Mixed) {
  CallSequence call_sequence;
  DecreasingSizeRunsHeap heap(
      absl::make_unique<HeapCallRecorder>(&call_sequence));
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 20);
  heap.Free(buffer_b_, 20);

  heap.Alloc(buffer_c_, 30);
  heap.Free(buffer_c_, 30);

  heap.Alloc(buffer_d_, 5);
  heap.Free(buffer_d_, 5);
  heap.Free(buffer_a_, 10);
  heap.Finish();
  // Runs of Allocs and Frees are sorted by decreasing size.
  EXPECT_EQ(call_sequence, CallSequence({
                               {kAlloc, buffer_b_},
                               {kAlloc, buffer_a_},
                               {kFree, buffer_b_},

                               {kAlloc, buffer_c_},
                               {kFree, buffer_c_},

                               {kAlloc, buffer_d_},
                               {kFree, buffer_a_},
                               {kFree, buffer_d_},
                               {kFinish, nullptr},
                           }));
}
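
// LazyBestFitHeap defers assigning a concrete offset to a buffer for as long
// as possible; the "lazy offset" comments below mark allocations whose
// offsets are still pending, and the range/free comments track assignments
// once they are forced.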
class LazyBestFitHeapTest : public HeapAlgorithmTestBase {};

TEST_F(LazyBestFitHeapTest, Empty) {
  LazyBestFitHeap heap(/*alignment=*/1);
  const HeapSimulator::Result result = heap.Finish();
  EXPECT_EQ(0, result.heap_size);
  EXPECT_EQ(0, result.chunk_map.size());
}

TEST_F(LazyBestFitHeapTest, Simple) {
  LazyBestFitHeap heap(/*alignment=*/1);
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 20);
  heap.Alloc(buffer_c_, 30);
  heap.Alloc(buffer_d_, 30);
  heap.Free(buffer_a_, 10);
  heap.Free(buffer_b_, 20);
  heap.Free(buffer_c_, 30);
  heap.Free(buffer_d_, 30);

  const HeapSimulator::Result result = heap.Finish();
  EXPECT_EQ(90, result.heap_size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(30, result.chunk_map.at(buffer_c_).size);
  EXPECT_EQ(30, result.chunk_map.at(buffer_d_).size);

  EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(10, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(30, result.chunk_map.at(buffer_c_).offset);
  EXPECT_EQ(60, result.chunk_map.at(buffer_d_).offset);
}

TEST_F(LazyBestFitHeapTest, Mixed) {
  LazyBestFitHeap heap(/*alignment=*/1);
  heap.Alloc(buffer_a_, 10);  // A lazy offset

  heap.Alloc(buffer_b_, 20);  // B lazy offset
  heap.Free(buffer_b_, 20);   // B range = [0, 20)   free = [0, 20)

  heap.Alloc(buffer_c_, 30);  // C range = [0, 30)
  heap.Free(buffer_c_, 30);   // free = [0, 30)

  heap.Alloc(buffer_d_, 5);   // D range = [0, 5)    free = [5, 30)
  heap.Free(buffer_d_, 5);    // free = [0, 30)

  heap.Free(buffer_a_, 10);   // A range = [30, 40)  free = [0, 40)

  const HeapSimulator::Result result = heap.Finish();
  EXPECT_EQ(40, result.heap_size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(30, result.chunk_map.at(buffer_c_).size);
  EXPECT_EQ(5, result.chunk_map.at(buffer_d_).size);

  EXPECT_EQ(30, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_c_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_d_).offset);
}

TEST_F(LazyBestFitHeapTest, BestFit) {
  LazyBestFitHeap heap(/*alignment=*/1);

  // First alloc/free buffer_a_, to force a big free chunk to appear.
  heap.Alloc(buffer_a_, 200);  // A lazy offset
  heap.Free(buffer_a_, 200);   // A range = [0, 200)  free = [0, 200)

  // Now alloc a bunch of buffers that are allocated out of the free chunk.
  heap.Alloc(buffer_b_, 30);  // B range = [0, 30)     free = [30, 200)
  heap.Alloc(buffer_c_, 30);  // C range = [30, 60)    free = [60, 200)
  heap.Alloc(buffer_d_, 20);  // D range = [60, 80)    free = [80, 200)
  heap.Alloc(buffer_e_, 20);  // E range = [80, 100)   free = [100, 200)
  heap.Alloc(buffer_f_, 10);  // F range = [100, 110)  free = [110, 200)
  heap.Alloc(buffer_g_, 10);  // G range = [110, 120)  free = [120, 200)
  heap.Alloc(buffer_h_, 80);  // H range = [120, 200)

  // Free buffers to create free chunks of different sizes.
  heap.Free(buffer_c_, 30);  // free = [30, 60)
  heap.Free(buffer_e_, 20);  // free = [30, 60), [80, 100)
  heap.Free(buffer_g_, 10);  // free = [30, 60), [80, 100), [110, 120)

  // The best fit is picked out of the existing free chunks.
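  // Candidates: [30, 60) (30 bytes), [80, 100) (20 bytes), and [110, 120)
  // (10 bytes); the 15-byte request fits best in the 20-byte chunk.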
  heap.Alloc(buffer_i_, 15);  // I range = [80, 95)

  // The frees here ensure the buffer-coalescing logic is exercised.
  heap.Free(buffer_b_, 30);
  heap.Free(buffer_d_, 20);
  heap.Free(buffer_f_, 10);
  heap.Free(buffer_h_, 80);
  heap.Free(buffer_i_, 15);

  const HeapSimulator::Result result = heap.Finish();
  EXPECT_EQ(200, result.heap_size);
  EXPECT_EQ(200, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(30, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(30, result.chunk_map.at(buffer_c_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_d_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_e_).size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_f_).size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_g_).size);
  EXPECT_EQ(80, result.chunk_map.at(buffer_h_).size);
  EXPECT_EQ(15, result.chunk_map.at(buffer_i_).size);

  EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(30, result.chunk_map.at(buffer_c_).offset);
  EXPECT_EQ(60, result.chunk_map.at(buffer_d_).offset);
  EXPECT_EQ(80, result.chunk_map.at(buffer_e_).offset);
  EXPECT_EQ(100, result.chunk_map.at(buffer_f_).offset);
  EXPECT_EQ(110, result.chunk_map.at(buffer_g_).offset);
  EXPECT_EQ(120, result.chunk_map.at(buffer_h_).offset);
  EXPECT_EQ(80, result.chunk_map.at(buffer_i_).offset);
}

TEST_F(LazyBestFitHeapTest, Lazy) {
  LazyBestFitHeap heap(/*alignment=*/1);

  // First alloc some buffers, which are all lazily allocated offsets.
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 5);
  heap.Alloc(buffer_c_, 10);

  // Now free some buffers, which forces offset assignment.
  heap.Free(buffer_a_, 10);  // A range = [0, 10)   free = [0, 10)
  heap.Free(buffer_c_, 10);  // C range = [10, 20)  free = [0, 20)

  // If we hadn't lazily assigned offsets, the free chunk wouldn't be large
  // enough to hold the entire allocation.
  heap.Alloc(buffer_d_, 20);  // D range = [0, 20)

  heap.Free(buffer_b_, 5);  // B range = [20, 25)
  heap.Free(buffer_d_, 20);

  const HeapSimulator::Result result = heap.Finish();
  EXPECT_EQ(25, result.heap_size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(5, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_c_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_d_).size);

  EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(20, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(10, result.chunk_map.at(buffer_c_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_d_).offset);
}

TEST_F(LazyBestFitHeapTest, ReuseLastFreeChunk) {
  LazyBestFitHeap heap(/*alignment=*/1);

  // First alloc/free buffer_a_, to force a big free chunk to appear.
  heap.Alloc(buffer_a_, 60);  // A lazy offset
  heap.Free(buffer_a_, 60);   // A range = [0, 60)  free = [0, 60)

  // Now alloc a bunch of buffers that are allocated out of the free chunk.
  heap.Alloc(buffer_b_, 10);  // B range = [0, 10)   free = [10, 60)
  heap.Alloc(buffer_c_, 20);  // C range = [10, 30)  free = [30, 60)
  heap.Alloc(buffer_d_, 30);  // D range = [30, 60)

  // Free buffers to create free chunks of different sizes.
  heap.Free(buffer_b_, 10);  // free = [0, 10)
  heap.Free(buffer_d_, 30);  // free = [0, 10), [30, 60)

  // No free chunks are large enough, but the last free chunk is adjacent to
  // the end of the heap, so we re-use that chunk.
  heap.Alloc(buffer_e_, 40);  // E range = [30, 70)

  heap.Free(buffer_c_, 20);
  heap.Free(buffer_e_, 40);

  const HeapSimulator::Result result = heap.Finish();
  EXPECT_EQ(70, result.heap_size);
  EXPECT_EQ(60, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_c_).size);
  EXPECT_EQ(30, result.chunk_map.at(buffer_d_).size);
  EXPECT_EQ(40, result.chunk_map.at(buffer_e_).size);

  EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(10, result.chunk_map.at(buffer_c_).offset);
  EXPECT_EQ(30, result.chunk_map.at(buffer_d_).offset);
  EXPECT_EQ(30, result.chunk_map.at(buffer_e_).offset);
}
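
// In the Alignment test below, offsets are rounded up to a multiple of 64
// when they are assigned, which is why the freed buffers land at offsets 0,
// 64, and 128 even though each is only a few bytes.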
TEST_F(LazyBestFitHeapTest, Alignment) {
  LazyBestFitHeap heap(/*alignment=*/64);

  // First alloc some buffers, which are all lazily allocated offsets.
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 5);
  heap.Alloc(buffer_c_, 10);

  // Now free some buffers, which forces offset assignment with alignment.
  heap.Free(buffer_a_, 10);  // A range = [0, 10)   free = [0, 10)
  heap.Free(buffer_c_, 10);  // C range = [64, 74)  free = [0, 74)

  // If we hadn't lazily assigned offsets, and accounted for alignment, the
  // free chunk wouldn't be large enough to hold the entire allocation.
  heap.Alloc(buffer_d_, 74);  // D range = [0, 74)  free = [)

  heap.Free(buffer_b_, 5);    // B range = [128, 133)  free = [74, 133)
  heap.Alloc(buffer_e_, 23);  // E range = [128, 151)  free = [74, 128)

  heap.Free(buffer_d_, 74);  // free = [0, 128)
  heap.Free(buffer_e_, 23);  // free = [0, 151)

  const HeapSimulator::Result result = heap.Finish();
  EXPECT_EQ(151, result.heap_size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(5, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_c_).size);
  EXPECT_EQ(74, result.chunk_map.at(buffer_d_).size);
  EXPECT_EQ(23, result.chunk_map.at(buffer_e_).size);

  EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(128, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(64, result.chunk_map.at(buffer_c_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_d_).offset);
  EXPECT_EQ(128, result.chunk_map.at(buffer_e_).offset);
}

class GlobalDecreasingSizeBestFitHeapTest : public HeapAlgorithmTestBase {};

TEST_F(GlobalDecreasingSizeBestFitHeapTest, Empty) {
  GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
  const HeapSimulator::Result result = heap.Finish();
  EXPECT_EQ(0, result.heap_size);
  EXPECT_EQ(0, result.chunk_map.size());
}

TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSize) {
  // space
  //   ^
  //   |  +---a---+
  //   |  +-------+
  //   |  +---c---+
  //   |  +-------+
  //   |  |   b   |
  //   |  +-------+
  //   |  +-------+
  //   |  |       |
  //   |  |   d   |
  //   |  +-------+
  //   -----------------> time
  GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 30);
  heap.Alloc(buffer_c_, 20);
  heap.Alloc(buffer_d_, 40);
  heap.Free(buffer_a_, 10);
  heap.Free(buffer_b_, 30);
  heap.Free(buffer_c_, 20);
  heap.Free(buffer_d_, 40);

  const HeapSimulator::Result result = heap.Finish();
  EXPECT_EQ(100, result.heap_size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(30, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_c_).size);
  EXPECT_EQ(40, result.chunk_map.at(buffer_d_).size);

  EXPECT_EQ(90, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(40, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(70, result.chunk_map.at(buffer_c_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_d_).offset);
}
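
// GlobalDecreasingSizeBestFitHeap appears to consider all buffers up front in
// decreasing-size order, best-fitting each into the space left by earlier
// placements over its whole live range; in DecreasingSize above this puts d
// at 0, b at 40, c at 70, and a at 90.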
TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSizeWithAlignment) {
  // space
  //   ^
  //   |             +-------+
  //   |             +---b---+
  //   |             +-------+
  //   |             |       |
  //   |             |   d   |
  //   | +---a---+   +-------+
  //   |
  //   | +-------+
  //   | |       |
  //   | |   c   |
  //   | |       |
  //   | +-------+
  //   ---------------------> time
  GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/20);
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 20);
  heap.Alloc(buffer_c_, 50);
  heap.Free(buffer_a_, 10);
  heap.Alloc(buffer_d_, 40);
  heap.Free(buffer_b_, 20);
  heap.Free(buffer_c_, 50);
  heap.Free(buffer_d_, 40);

  const HeapSimulator::Result result = heap.Finish();
  EXPECT_EQ(120, result.heap_size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(50, result.chunk_map.at(buffer_c_).size);
  EXPECT_EQ(40, result.chunk_map.at(buffer_d_).size);

  EXPECT_EQ(60, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(100, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_c_).offset);
  EXPECT_EQ(60, result.chunk_map.at(buffer_d_).offset);
}

TEST_F(GlobalDecreasingSizeBestFitHeapTest, BestFit) {
  // space
  //   ^
  //   |          +-------+
  //   |          +---b---+
  //   |          +-------+
  //   |          |   d   |
  //   | +--a--+  +-------+
  //   | +-------+
  //   | |       |
  //   | |   c   |
  //   | +-------+
  //   |        +-------+
  //   |        |       |
  //   |        |   e   |
  //   |        |       |
  //   |        +-------+
  //   ---------------------> time
  GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 20);
  heap.Alloc(buffer_c_, 40);
  heap.Free(buffer_a_, 10);
  heap.Alloc(buffer_d_, 30);
  heap.Alloc(buffer_e_, 50);
  heap.Free(buffer_b_, 20);
  heap.Free(buffer_c_, 40);
  heap.Free(buffer_d_, 30);
  heap.Free(buffer_e_, 50);

  const HeapSimulator::Result result = heap.Finish();
  EXPECT_EQ(140, result.heap_size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(40, result.chunk_map.at(buffer_c_).size);
  EXPECT_EQ(30, result.chunk_map.at(buffer_d_).size);
  EXPECT_EQ(50, result.chunk_map.at(buffer_e_).size);

  EXPECT_EQ(90, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(120, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(50, result.chunk_map.at(buffer_c_).offset);
  EXPECT_EQ(90, result.chunk_map.at(buffer_d_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_e_).offset);
}

}  // namespace
}  // namespace xla