1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/buffer_assignment.h" 17 18 #include <memory> 19 #include <set> 20 #include <string> 21 #include <utility> 22 #include <vector> 23 24 #include "tensorflow/compiler/xla/literal_util.h" 25 #include "tensorflow/compiler/xla/ptr_util.h" 26 #include "tensorflow/compiler/xla/service/call_graph.h" 27 #include "tensorflow/compiler/xla/service/computation_tracker.h" 28 #include "tensorflow/compiler/xla/service/copy_insertion.h" 29 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" 30 #include "tensorflow/compiler/xla/service/flatten_call_graph.h" 31 #include "tensorflow/compiler/xla/service/hlo_computation.h" 32 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 33 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 34 #include "tensorflow/compiler/xla/service/hlo_ordering.h" 35 #include "tensorflow/compiler/xla/service/hlo_scheduling.h" 36 #include "tensorflow/compiler/xla/shape_util.h" 37 #include "tensorflow/compiler/xla/test.h" 38 #include "tensorflow/compiler/xla/test_helpers.h" 39 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 40 #include "tensorflow/compiler/xla/types.h" 41 #include "tensorflow/compiler/xla/xla_data.pb.h" 42 #include "tensorflow/core/platform/macros.h" 43 44 namespace xla { 45 
46 namespace { 47 48 // DFS visitor that collects the instructions referenced by a computation 49 // without descending into nested computations, i.e., only from the operands. 50 class InstructionListVisitor : public DfsHloVisitorWithDefault { 51 public: 52 explicit InstructionListVisitor(const HloInstruction* root) : root_(root) {} 53 54 Status DefaultAction(HloInstruction* hlo) override { 55 // For each instruction, just push it on the list after walking the 56 // operands. 57 instructions_.push_back(hlo); 58 VLOG(0) << "List instruction " << hlo->ToString(); 59 return Status::OK(); 60 } 61 62 std::vector<const HloInstruction*> GetInstructions() { return instructions_; } 63 64 private: 65 // The instruction root of the computation. 66 const HloInstruction* root_; 67 68 // The full set of instructions found (may be duplicates, e.g., kParameter). 69 std::vector<const HloInstruction*> instructions_; 70 71 TF_DISALLOW_COPY_AND_ASSIGN(InstructionListVisitor); 72 }; 73 74 const std::vector<const HloInstruction*> GetInstructions(HloInstruction* root) { 75 InstructionListVisitor main_list(root); 76 TF_CHECK_OK(root->Accept(&main_list)); 77 return main_list.GetInstructions(); 78 } 79 80 class BufferAssignmentTest : public HloTestBase { 81 protected: 82 BufferAssignmentTest() : computation_tracker_() {} 83 ~BufferAssignmentTest() override {} 84 85 std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module, 86 int64 alignment = 1) { 87 return BufferAssigner::Run( 88 module, xla::MakeUnique<DependencyHloOrdering>(module), 89 backend().compiler()->BufferSizeBytesFunction(), 90 [alignment](LogicalBuffer::Color) { return alignment; }) 91 .ConsumeValueOrDie(); 92 } 93 94 std::unique_ptr<BufferAssignment> RunColoredBufferAssignment( 95 HloModule* module, BufferLiveness::Colorer colorer, int64 alignment = 1) { 96 return BufferAssigner::Run( 97 module, xla::MakeUnique<DependencyHloOrdering>(module), 98 backend().compiler()->BufferSizeBytesFunction(), 99 
[alignment](LogicalBuffer::Color) { return alignment; }, false, 100 std::move(colorer)) 101 .ConsumeValueOrDie(); 102 } 103 104 // Builds an x+1.0 computation to use in a Map. 105 std::unique_ptr<HloComputation> BuildMapComputationPlus1(const string& name) { 106 auto builder = HloComputation::Builder(name); 107 auto param = 108 builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x")); 109 auto value = builder.AddInstruction( 110 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 111 builder.AddInstruction( 112 HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, value)); 113 return builder.Build(); 114 } 115 116 // Builds a simple compare-to-limit (x < 4) computation for a While. 117 // 118 // condition: 119 // const4[s32] -----------------------------------\ 120 // \ 121 // param[(s32,f32[4])] --- get-tuple-element[0] --- less-than 122 // 123 std::unique_ptr<HloComputation> BuildWhileConditionComputation( 124 const string& name) { 125 auto builder = HloComputation::Builder(name); 126 auto const4 = builder.AddInstruction( 127 HloInstruction::CreateConstant(Literal::CreateR0<int>(4))); 128 auto param = builder.AddInstruction( 129 HloInstruction::CreateParameter(0, t_s32_f32v4_, "x")); 130 auto index = builder.AddInstruction( 131 HloInstruction::CreateGetTupleElement(const4->shape(), param, 0)); 132 builder.AddInstruction( 133 HloInstruction::CreateBinary(r0f32_, HloOpcode::kLt, index, const4)); 134 return builder.Build(); 135 } 136 137 // Builds a simple body computation for a While. 
138 // 139 // body: 140 // constv[f32[4]] --------------------------------------\ 141 // \ 142 // /--- get-tuple-elementv[1] --- addv ---\ 143 // param[(s32,f32[4])] ---| tuple 144 // \--- get-tuple-elementc[0] --- addc ---/ 145 // / 146 // const1[s32] -----------------------------------------/ 147 // 148 std::unique_ptr<HloComputation> BuildWhileBodyComputation( 149 const string& name) { 150 auto builder = HloComputation::Builder(name); 151 auto const1 = builder.AddInstruction( 152 HloInstruction::CreateConstant(Literal::CreateR0<int>(1))); 153 auto constv = builder.AddInstruction(HloInstruction::CreateConstant( 154 Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f}))); 155 auto param = builder.AddInstruction( 156 HloInstruction::CreateParameter(0, t_s32_f32v4_, "x")); 157 auto indexc = builder.AddInstruction( 158 HloInstruction::CreateGetTupleElement(const1->shape(), param, 0)); 159 auto addc = builder.AddInstruction(HloInstruction::CreateBinary( 160 indexc->shape(), HloOpcode::kAdd, indexc, const1)); 161 auto indexv = builder.AddInstruction( 162 HloInstruction::CreateGetTupleElement(constv->shape(), param, 1)); 163 auto addv = builder.AddInstruction(HloInstruction::CreateBinary( 164 constv->shape(), HloOpcode::kAdd, indexv, constv)); 165 builder.AddInstruction(HloInstruction::CreateTuple({addc, addv})); 166 return builder.Build(); 167 } 168 169 std::unique_ptr<HloComputation> BuildR0F32UnaryOpComputation( 170 HloOpcode opcode, const string& name) { 171 auto builder = HloComputation::Builder(name); 172 auto param = 173 builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x")); 174 builder.AddInstruction(HloInstruction::CreateUnary(r0f32_, opcode, param)); 175 return builder.Build(); 176 } 177 178 // Verifies that the given instruction hlo has a valid input buffer assigned, 179 // i.e., the parameter number matches the op's. 
180 const BufferAllocation& GetAssignedInputAllocation( 181 const BufferAssignment& buffers, HloInstruction* hlo) { 182 LOG(INFO) << "Checking input: " << hlo->ToString(); 183 const BufferAllocation& buffer = 184 *buffers.GetUniqueTopLevelSlice(hlo).ConsumeValueOrDie().allocation(); 185 EXPECT_EQ(hlo->parameter_number(), buffer.parameter_number()); 186 return buffer; 187 } 188 189 // Verifies that the given instruction hlo has a valid output buffer 190 // assigned, and returns it. 191 const BufferAllocation& GetAssignedOutputAllocation( 192 const BufferAssignment& buffers, HloInstruction* hlo) { 193 LOG(INFO) << "Checking output: " << hlo->ToString(); 194 const BufferAllocation& buffer = GetTopLevelAllocation(buffers, hlo); 195 return buffer; 196 } 197 198 // Returns the allocation for the given instruction. 199 const BufferAllocation& GetAllocation(const BufferAssignment& buffers, 200 const HloInstruction* hlo, 201 const ShapeIndex& index) { 202 return *buffers.GetUniqueSlice(hlo, index).ConsumeValueOrDie().allocation(); 203 } 204 const BufferAllocation& GetTopLevelAllocation(const BufferAssignment& buffers, 205 const HloInstruction* hlo) { 206 return *buffers.GetUniqueTopLevelSlice(hlo) 207 .ConsumeValueOrDie() 208 .allocation(); 209 } 210 211 // Verifies that all instructions in the given instruction list except 212 // kConstant have assigned buffers, and returns their total size. If min_index 213 // and max_index are not nullptr, the minimum and maximum buffer indices in 214 // the assignment are written into them. 215 int64 ValidateBuffers(const std::vector<const HloInstruction*>& instructions, 216 const BufferAssignment& buffers) { 217 // Verifies all instructions have buffers, and gets the index ranges. 218 for (const HloInstruction* hlo : instructions) { 219 if (!buffers.HasTopLevelAllocation(hlo)) { 220 // If `hlo` has no assigned buffer, it is either a constant or a nested 221 // parameter. 
222 EXPECT_TRUE(HloOpcode::kConstant == hlo->opcode() || 223 HloOpcode::kParameter == hlo->opcode()); 224 continue; 225 } 226 } 227 228 // Gets the total size of all buffers assigned. 229 int64 total_size = 0; 230 for (auto& allocation : buffers.Allocations()) { 231 total_size += allocation.size(); 232 } 233 return total_size; 234 } 235 236 // Computation tracker for nested computations. 237 ComputationTracker computation_tracker_; 238 239 // Shapes for use in the examples. 240 Shape s32_ = ShapeUtil::MakeShape(xla::S32, {}); 241 Shape r0f32_ = ShapeUtil::MakeShape(xla::F32, {}); 242 Shape f32vec4_ = ShapeUtil::MakeShape(F32, {4}); 243 Shape f32vec10_ = ShapeUtil::MakeShape(F32, {10}); 244 Shape f32vec100_ = ShapeUtil::MakeShape(F32, {100}); 245 Shape f32a100x10_ = ShapeUtil::MakeShape(F32, {100, 10}); 246 Shape t_s32_f32v4_ = ShapeUtil::MakeTupleShape({s32_, f32vec4_}); 247 Shape t_s32_f32v10_ = ShapeUtil::MakeTupleShape({s32_, f32vec10_}); 248 }; 249 250 // Returns true if the buffers assigned to instructions in "a" are distinct 251 // from the buffers assigned to those in "b" (ie, intersection is empty). 252 static bool BuffersDistinct(const std::vector<const HloInstruction*>& a, 253 const std::vector<const HloInstruction*>& b, 254 const BufferAssignment& assignment) { 255 std::set<BufferAllocation::Slice> a_slices; 256 for (const HloInstruction* instruction : a) { 257 if (assignment.HasTopLevelAllocation(instruction)) { 258 a_slices.insert( 259 assignment.GetUniqueTopLevelSlice(instruction).ConsumeValueOrDie()); 260 } 261 } 262 263 for (const HloInstruction* instruction : b) { 264 if (assignment.HasTopLevelAllocation(instruction)) { 265 if (a_slices.count(assignment.GetUniqueTopLevelSlice(instruction) 266 .ConsumeValueOrDie())) { 267 return false; 268 } 269 } 270 } 271 return true; 272 } 273 274 // Tests a computation consisting of a single scalar constant node. 
TEST_F(BufferAssignmentTest, ScalarConstant) {
  auto builder = HloComputation::Builder(TestName());
  auto const0 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
  auto module = CreateNewModule();
  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignment(module.get());
  // Check that the constant does not have a buffer assigned.
  EXPECT_FALSE(buffers->HasTopLevelAllocation(const0));
}

TEST_F(BufferAssignmentTest, BufferForConst) {
  // Addition of two vector constants: checks that internal constant nodes have
  // no buffers assigned, and their consumer has a buffer.
  auto builder = HloComputation::Builder(TestName());
  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
      Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
  auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
      Literal::CreateR1<float>({4.1f, 4.2f, 4.3f, 4.4f})));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, const0, const1));
  auto module = CreateNewModule();
  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignment(module.get());
  // The two constant nodes have no buffers assigned.
  EXPECT_FALSE(buffers->HasTopLevelAllocation(const0));
  EXPECT_FALSE(buffers->HasTopLevelAllocation(const1));
  // The add node has an output buffer.
  GetAssignedOutputAllocation(*buffers, add);
}

TEST_F(BufferAssignmentTest, HasAllocationAt) {
  // Create a tuple with non-const and const elements and check that
  // HasAllocationAt works correctly.
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<int>(1)));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  // Tuple element order fixes the shape indices checked below:
  // {0}=negate, {1}=param0, {2}=constant.
  auto tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({negate, param0, constant}));
  auto module = CreateNewModule();
  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignment(module.get());
  // Make sure that HasAllocationAt() agrees with what HasTopLevelAllocation()
  // reports for the instruction directly.
  EXPECT_EQ(buffers->HasTopLevelAllocation(tuple),
            buffers->HasAllocationAt(tuple, /*index=*/{}));
  EXPECT_EQ(buffers->HasTopLevelAllocation(negate),
            buffers->HasAllocationAt(tuple, /*index=*/{0}));
  EXPECT_EQ(buffers->HasTopLevelAllocation(param0),
            buffers->HasAllocationAt(tuple, /*index=*/{1}));
  EXPECT_EQ(buffers->HasTopLevelAllocation(constant),
            buffers->HasAllocationAt(tuple, /*index=*/{2}));
}

TEST_F(BufferAssignmentTest, BufferForOutputConst) {
  // This computation copies a constant to output.
  auto builder = HloComputation::Builder(TestName());
  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
      Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
  auto copy = builder.AddInstruction(
      HloInstruction::CreateUnary(const0->shape(), HloOpcode::kCopy, const0));
  auto module = CreateNewModule();
  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignment(module.get());
  // The copy node now has an output buffer.
  GetAssignedOutputAllocation(*buffers, copy);
}

TEST_F(BufferAssignmentTest, Basic) {
  // paramscalar ------- (mul) -- (add) -- (sub)
  //                     /        /        /
  // param0[100] -------/        /        /
  //                            /        /
  // param1[100] --------------/--------/
  auto builder = HloComputation::Builder(TestName());
  auto paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, ""));
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, ""));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, ""));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kSubtract, add, param1));
  auto module = CreateNewModule();
  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignment(module.get());

  // Distinct input buffers were assigned for parameters.
  BufferAllocation paramscalar_buffer =
      GetAssignedInputAllocation(*buffers, paramscalar);
  BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
  BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
  EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
  EXPECT_NE(param0_buffer.index(), param1_buffer.index());

  // The mul node has a valid buffer assigned, doesn't share with input.
  const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
  EXPECT_NE(mul_buffer.index(), param0_buffer.index());

  // The add node can reuse the mul node's buffer.
  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
  EXPECT_EQ(add_buffer.index(), mul_buffer.index());

  // The sub node has a valid output buffer assigned.
  GetAssignedOutputAllocation(*buffers, sub);
}

TEST_F(BufferAssignmentTest, BasicUniquelyColored) {
  // paramscalar ------- (mul) -- (add) -- (sub)
  //                     /        /        /
  // param0[100] -------/        /        /
  //                            /        /
  // param1[100] --------------/--------/
  // The output of each op is colored with a different color, so we can not
  // share anything.
  auto builder = HloComputation::Builder(TestName());
  auto paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, ""));
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, ""));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, ""));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kSubtract, add, param1));
  auto module = CreateNewModule();
  module->AddEntryComputation(builder.Build());

  // Assigns a fresh color to every logical buffer, making all buffers
  // mutually incompatible for sharing.
  auto colorer = [](const BufferLiveness& buffer_liveness) {
    int color = 0;

    for (LogicalBuffer::Id id = 0;
         id < buffer_liveness.points_to_analysis().num_logical_buffers();
         id++) {
      auto& buffer = buffer_liveness.points_to_analysis().logical_buffer(id);
      buffer.set_color(LogicalBuffer::Color(color++));
    }
    return Status::OK();
  };

  auto buffers = RunColoredBufferAssignment(module.get(), colorer);

  // Distinct input buffers were assigned for parameters.
  BufferAllocation paramscalar_buffer =
      GetAssignedInputAllocation(*buffers, paramscalar);
  BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
  BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
  EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
  EXPECT_NE(param0_buffer.index(), param1_buffer.index());

  // The mul node has a valid buffer assigned, doesn't share with input.
  const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
  EXPECT_NE(mul_buffer.index(), param0_buffer.index());

  // The add node can not reuse the mul node's buffer due to coloring.
  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
  EXPECT_NE(add_buffer.index(), mul_buffer.index());

  // The sub node has a valid output buffer assigned.
  GetAssignedOutputAllocation(*buffers, sub);
}

TEST_F(BufferAssignmentTest, BasicPartiallyColored) {
  // paramscalar ------- (mul) -- (add) -- (sub)
  //                     /        /        /
  // param0[100] -------/        /        /
  //                            /        /
  // param1[100] --------------/--------/
  // The output of the mul and the add have the color 1, and the other buffers
  // have the color 0, which allows the mul and add to share buffers.
  auto builder = HloComputation::Builder(TestName());
  auto paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, ""));
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, ""));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, ""));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kSubtract, add, param1));
  auto module = CreateNewModule();
  module->AddEntryComputation(builder.Build());

  // Colors buffers aliased by kAdd/kMultiply instructions with color 1 and
  // everything else with color 0.
  auto colorer = [](const BufferLiveness& buffer_liveness) {
    for (LogicalBuffer::Id id = 0;
         id < buffer_liveness.points_to_analysis().num_logical_buffers();
         id++) {
      auto& buffer = buffer_liveness.points_to_analysis().logical_buffer(id);
      const auto& aliases =
          buffer_liveness.points_to_analysis().GetBufferAliases(buffer);
      for (const auto& alias : aliases) {
        if (alias.instruction()->opcode() == HloOpcode::kAdd ||
            alias.instruction()->opcode() == HloOpcode::kMultiply) {
          buffer.set_color(LogicalBuffer::Color(1));
        }
      }
      if (!buffer.has_color()) {
        buffer.set_color(LogicalBuffer::Color(0));
      }
    }
    return Status::OK();
  };

  auto buffers = RunColoredBufferAssignment(module.get(), colorer);

  // Distinct input buffers were assigned for parameters.
  BufferAllocation paramscalar_buffer =
      GetAssignedInputAllocation(*buffers, paramscalar);
  BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
  BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
  EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
  EXPECT_NE(param0_buffer.index(), param1_buffer.index());

  // The mul node has a valid buffer assigned, doesn't share with input.
  const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
  EXPECT_NE(mul_buffer.index(), param0_buffer.index());

  // The add node can reuse the mul node's buffer.
  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
  EXPECT_EQ(add_buffer.index(), mul_buffer.index());

  // The sub node has a valid output buffer assigned.
  GetAssignedOutputAllocation(*buffers, sub);
}

TEST_F(BufferAssignmentTest, MultipleUsersForNode) {
  // This is similar to the Basic test, with the difference that (sub) is
  // another user of (mul)'s result, so (mul)'s buffer cannot be reused for
  // (add)'s output.
  //
  // paramscalar -------\               /-----------\
  //                     \             /             \
  // param0[100] ------- (mul) -- (add) -- (sub)
  //                               /
  // param1[100] ----------------/
  //
  auto builder = HloComputation::Builder(TestName());
  auto paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, ""));
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, ""));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, ""));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  // Note: sub consumes mul (not param1), keeping mul live past add.
  auto sub = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kSubtract, add, mul));
  auto module = CreateNewModule();
  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignment(module.get());

  // Input buffers were assigned for parameters.
  BufferAllocation paramscalar_buffer =
      GetAssignedInputAllocation(*buffers, paramscalar);
  BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
  BufferAllocation param1_index = GetAssignedInputAllocation(*buffers, param1);
  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
  EXPECT_NE(paramscalar_buffer.index(), param1_index.index());
  EXPECT_NE(param0_buffer.index(), param1_index.index());

  // The mul node had a buffer allocated.
  const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);

  // Now the add node can't reuse the mul node's buffer.
  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
  EXPECT_NE(add_buffer.index(), mul_buffer.index());

  // Log size information for inspection.
  const std::vector<const HloInstruction*> level0 = GetInstructions(sub);
  int64 size0 = ValidateBuffers(level0, *buffers);
  LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size()
            << " for " << level0.size() << " instructions; "
            << "total buffer size " << size0;
}

TEST_F(BufferAssignmentTest, TrivialMap) {
  // This tests a trivial x+1 map as the only operation.
  //
  // param0[100x10] ---> (map x+1)
  //
  // Builds the map function.
  auto module = CreateNewModule();
  auto map_computation =
      module->AddEmbeddedComputation(BuildMapComputationPlus1("f32+1"));
  auto inner_last = map_computation->root_instruction();

  // Creates the main kernel and verifies instruction counts.
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32a100x10_, ""));
  auto map = builder.AddInstruction(
      HloInstruction::CreateMap(f32a100x10_, {param0}, map_computation));
  module->AddEntryComputation(builder.Build());

  const std::vector<const HloInstruction*> level0 = GetInstructions(map);
  EXPECT_EQ(2, level0.size()) << "Invalid main kernel size";
  const std::vector<const HloInstruction*> level1 = GetInstructions(inner_last);
  EXPECT_EQ(3, level1.size()) << "Invalid nested add+1 size";

  // Assigns buffers and fetches sizes.
  auto buffers = RunBufferAssignment(module.get());
  int64 size0 = ValidateBuffers(level0, *buffers);
  int64 size1 = ValidateBuffers(level1, *buffers);

  // Both algorithms assign the map's buffer before processing the embedded
  // computation, so we can verify that the buffers aren't shared between them
  // by checking:
  EXPECT_TRUE(BuffersDistinct(level0, level1, *buffers))
      << "Reuse between main kernel and embedded mapping.";

  // An input buffer was assigned for the parameter.
  BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);

  // An output buffer was assigned for the map.
  BufferAllocation map_buffer = GetAssignedOutputAllocation(*buffers, map);
  EXPECT_NE(param0_buffer.index(), map_buffer.index());

  // The final computation node of the map is an add of an f32 param and a
  // constant.
  EXPECT_EQ(HloOpcode::kAdd, inner_last->opcode());
  const BufferAllocation& inner_add_buffer =
      GetTopLevelAllocation(*buffers, inner_last);
  EXPECT_NE(inner_add_buffer.index(), map_buffer.index());

  // Log size information for inspection.
  LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size()
            << " for " << level0.size() + level1.size() << " instructions; "
            << "total buffer size " << size0 + size1;
}

TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) {
  // Make sure that the input buffer of a reduce cannot be reused for its
  // output. (Reuse is not safe in the general case, as it reshapes and some
  // out-of-order reductions could overwrite an element before a use.)
634 // 635 // param0[100] --- (exp1) --- (exp2) --- (reduce x+1) --- (exp3) 636 auto module = CreateNewModule(); 637 auto reduce_computation = 638 module->AddEmbeddedComputation(BuildMapComputationPlus1("f32+1")); 639 640 auto builder = HloComputation::Builder(TestName()); 641 auto param0 = builder.AddInstruction( 642 HloInstruction::CreateParameter(0, f32a100x10_, "")); 643 auto exp1 = builder.AddInstruction( 644 HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, param0)); 645 auto exp2 = builder.AddInstruction( 646 HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, exp1)); 647 auto const0 = builder.AddInstruction( 648 HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))); 649 auto reduce = builder.AddInstruction(HloInstruction::CreateReduce( 650 /*shape=*/f32vec10_, 651 /*operand=*/exp2, 652 /*init_value=*/const0, 653 /*dimensions_to_reduce=*/{0}, reduce_computation)); 654 auto exp3 = builder.AddInstruction( 655 HloInstruction::CreateUnary(f32vec10_, HloOpcode::kExp, reduce)); 656 657 module->AddEntryComputation(builder.Build()); 658 659 auto buffers = RunBufferAssignment(module.get()); 660 const std::vector<const HloInstruction*> instrs = GetInstructions(exp3); 661 ValidateBuffers(instrs, *buffers); 662 663 const BufferAllocation& exp1_buffer = GetTopLevelAllocation(*buffers, exp1); 664 const BufferAllocation& exp2_buffer = GetTopLevelAllocation(*buffers, exp2); 665 const BufferAllocation& reduce_buffer = 666 GetTopLevelAllocation(*buffers, reduce); 667 668 // The buffer of exp1 is trivially reusable for exp2 - this is just for sanity 669 // checking. 670 EXPECT_EQ(exp1_buffer.index(), exp2_buffer.index()); 671 672 // The buffer of exp2 cannot be used for reduce, even though it's the only 673 // operand. 674 EXPECT_NE(exp2_buffer.index(), reduce_buffer.index()); 675 } 676 677 TEST_F(BufferAssignmentTest, ExampleWhile) { 678 // This tests a While loop example from the ir_semantics document. 
679 // 680 // condition (s32,f32[4]) -> bool -- see BuildWhileConditionComputation. 681 // body: (s32,f32[4]) -> (s32,f32[4]) -- see BuildWhileBodyComputation. 682 // 683 // const3[s32] -------\ 684 // const4[f32[4]] --- tuple --- while[condition, body] 685 // 686 // Builds the nested condition and body. 687 auto module = CreateNewModule(); 688 auto condition_computation = 689 module->AddEmbeddedComputation(BuildWhileConditionComputation("if<4")); 690 auto body_computation = 691 module->AddEmbeddedComputation(BuildWhileBodyComputation("add-update")); 692 693 // Creates the main kernel and verifies instruction counts. 694 auto builder = HloComputation::Builder(TestName()); 695 auto const3 = builder.AddInstruction( 696 HloInstruction::CreateConstant(Literal::CreateR0<int>(0))); 697 auto const4 = builder.AddInstruction(HloInstruction::CreateConstant( 698 Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f}))); 699 auto tuple = 700 builder.AddInstruction(HloInstruction::CreateTuple({const3, const4})); 701 auto while_op = builder.AddInstruction(HloInstruction::CreateWhile( 702 t_s32_f32v4_, condition_computation, body_computation, tuple)); 703 module->AddEntryComputation(builder.Build()); 704 705 const std::vector<const HloInstruction*> level0 = GetInstructions(while_op); 706 EXPECT_EQ(4, level0.size()) << "Invalid while kernel size"; 707 const std::vector<const HloInstruction*> levelc = 708 GetInstructions(condition_computation->root_instruction()); 709 EXPECT_EQ(4, levelc.size()) << "Invalid nested condition size"; 710 const std::vector<const HloInstruction*> levelb = 711 GetInstructions(body_computation->root_instruction()); 712 EXPECT_EQ(8, levelb.size()) << "Invalid nested body size"; 713 714 // Assigns buffers and fetches sizes. 
715 auto buffers = RunBufferAssignment(module.get()); 716 int64 size0 = ValidateBuffers(level0, *buffers); 717 int64 sizec = ValidateBuffers(levelc, *buffers); 718 int64 sizeb = ValidateBuffers(levelb, *buffers); 719 720 // BufferAssignment will assign a single allocation for the following 721 // instructions: while, while.cond.param, while.body.param, while.body.result. 722 EXPECT_FALSE(BuffersDistinct(level0, levelc, *buffers)) 723 << "Should be reuse between main kernel and embedded condition."; 724 EXPECT_FALSE(BuffersDistinct(levelb, levelc, *buffers)) 725 << "Should be reuse between embedded condition and body."; 726 // Expect buffer reuse between main kernel and body computation. 727 EXPECT_FALSE(BuffersDistinct(level0, levelb, *buffers)) 728 << "Should be reuse between main kernel and embedded body."; 729 730 // The final computation node of the while body is a tuple of s32 and 731 // f32[4] adds. 732 HloInstruction* body_root = body_computation->root_instruction(); 733 EXPECT_EQ(HloOpcode::kTuple, body_root->opcode()); 734 735 // Check that buffer for each subshape of 'while_op' shares allocation with 736 // corresponding buffer from while body computation at same index. 737 ShapeUtil::ForEachSubshape( 738 while_op->shape(), 739 [this, &buffers, while_op, body_root](const Shape& /*subshape*/, 740 const ShapeIndex& index) { 741 auto while_op_allocation = GetAllocation(*buffers, while_op, index); 742 auto body_root_allocation = GetAllocation(*buffers, body_root, index); 743 EXPECT_EQ(while_op_allocation.index(), body_root_allocation.index()); 744 }); 745 746 // Log size information for inspection. 
  // Log size information for inspection.
  LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size()
            << " for " << level0.size() + levelc.size() + levelb.size()
            << " instructions; total buffer size " << size0 + sizec + sizeb;
}

// Verifies buffer sharing between a kConditional and the roots of both of
// its branch computations.
TEST_F(BufferAssignmentTest, ExampleConditional) {
  auto module = CreateNewModule();
  auto true_computation = module->AddEmbeddedComputation(
      BuildR0F32UnaryOpComputation(HloOpcode::kCeil, "Ceil"));
  auto false_computation = module->AddEmbeddedComputation(
      BuildR0F32UnaryOpComputation(HloOpcode::kFloor, "Floor"));

  auto builder = HloComputation::Builder(TestName());
  auto pred = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
  auto const1 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(56.4f)));
  auto const2 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(12.4f)));
  auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
      r0f32_, pred, const1, true_computation, const2, false_computation));
  module->AddEntryComputation(builder.Build());

  // Four entry instructions: pred, const1, const2, and the conditional.
  const std::vector<const HloInstruction*> conditional_instrs =
      GetInstructions(conditional);
  const std::vector<const HloInstruction*> true_instrs =
      GetInstructions(true_computation->root_instruction());
  const std::vector<const HloInstruction*> false_instrs =
      GetInstructions(false_computation->root_instruction());
  EXPECT_EQ(4, conditional_instrs.size());
  EXPECT_EQ(2, true_instrs.size());
  EXPECT_EQ(2, false_instrs.size());

  auto buffers = RunBufferAssignment(module.get());
  ValidateBuffers(conditional_instrs, *buffers);
  ValidateBuffers(true_instrs, *buffers);
  ValidateBuffers(false_instrs, *buffers);

  // The conditional's buffers are shared with both branches.
  EXPECT_FALSE(BuffersDistinct(conditional_instrs, true_instrs, *buffers))
      << "Should be reuse between conditional and true computation.";
  EXPECT_FALSE(BuffersDistinct(conditional_instrs, false_instrs, *buffers))
      << "Should be reuse between conditional and false computation.";
  EXPECT_FALSE(BuffersDistinct(true_instrs, false_instrs, *buffers))
      << "Should be reuse between true and false computations.";

  const BufferAllocation& conditional_buffer =
      GetTopLevelAllocation(*buffers, conditional);
  const BufferAllocation& true_buffer =
      GetTopLevelAllocation(*buffers, true_computation->root_instruction());
  const BufferAllocation& false_buffer =
      GetTopLevelAllocation(*buffers, false_computation->root_instruction());
  EXPECT_EQ(conditional_buffer.size(), true_buffer.size());
  EXPECT_EQ(conditional_buffer.size(), false_buffer.size());
}

TEST_F(BufferAssignmentTest, UnaryOpReuseChain) {
  // param0[100] ---> (exp) ---> (tanh) ---> (exp) ---> (neg)
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, ""));
  auto exp1 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, param0));
  auto tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kTanh, exp1));
  auto exp2 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, tanh));
  auto neg = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, exp2));

  auto module = CreateNewModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // tanh, exp2, and neg can all reuse exp1's buffer (each operand is dead
  // once its consumer runs, and all shapes match).
  EXPECT_TRUE(assignment->HasTopLevelAllocation(exp1));
  auto& buffer_for_exp1 = GetTopLevelAllocation(*assignment, exp1);
  EXPECT_EQ(buffer_for_exp1, GetTopLevelAllocation(*assignment, tanh));
  EXPECT_EQ(buffer_for_exp1, GetTopLevelAllocation(*assignment, exp2));
  EXPECT_EQ(buffer_for_exp1, GetTopLevelAllocation(*assignment, neg));
}

TEST_F(BufferAssignmentTest, ReuseNonOperandBuffer) {
  // This computation is a chain of operations which decreases in buffer size
  // (via slice) then increases in size (via broadcast):
  //
  //   param ---> (negate) ---> (slice) ---> (broadcast)
  //
  // The negate should share a buffer with broadcast.
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  auto slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));

  auto module = CreateNewModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // negate and broadcast should share a buffer.
  EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast));
  auto& buffer_for_bcast = GetTopLevelAllocation(*assignment, broadcast);
  EXPECT_EQ(buffer_for_bcast, GetTopLevelAllocation(*assignment, negate));

  // Slice should have its own buffer.
  EXPECT_NE(buffer_for_bcast, GetTopLevelAllocation(*assignment, slice));
}

TEST_F(BufferAssignmentTest, NoReuseLiveBuffer) {
  // This computation is identical to that in ReuseNonOperandBuffer, but the
  // negate value is live until the end of the computation (due to it being an
  // operand of the output tuple) preventing reuse.
  //
  //   param ---> (negate) ---> (slice) ---> (broadcast)-> (tuple)
  //                  \-----------------------------------/
  //
  // The negate should not share a buffer with broadcast.
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  auto slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
  // Keeping negate in the output tuple extends its live range to the end of
  // the computation, which blocks reuse of its buffer.
  builder.AddInstruction(HloInstruction::CreateTuple({negate, broadcast}));

  auto module = CreateNewModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // The instructions should not share buffers.
  EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
            GetTopLevelAllocation(*assignment, negate));
  EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
            GetTopLevelAllocation(*assignment, slice));
  EXPECT_NE(GetTopLevelAllocation(*assignment, negate),
            GetTopLevelAllocation(*assignment, slice));
}

TEST_F(BufferAssignmentTest, NoReuseAliasedBuffer) {
  // This computation is identical to that in ReuseNonOperandBuffer, but the
  // negate value is placed into a tuple which lives to the end of the
  // computation. This extends the live range of negate's buffer preventing
  // reuse due to buffer aliasing.
  //
  //   param ---> (negate) ---> (tuple) -> (slice) ---> (broadcast)-> (tuple)
  //                               \-----------------------------------/
  //
  // The negate should not share a buffer with broadcast.
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  // negate's buffer is aliased by this tuple, which is itself an operand of
  // the output tuple below.
  auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({negate}));
  auto tuple_element = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32vec100_, tuple, 0));
  auto slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, tuple_element, {0}, {10}, {1}));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
  builder.AddInstruction(HloInstruction::CreateTuple({tuple, broadcast}));

  auto module = CreateNewModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // The instructions should not share buffers.
  EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
            GetTopLevelAllocation(*assignment, negate));
  EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
            GetTopLevelAllocation(*assignment, slice));
  EXPECT_NE(GetTopLevelAllocation(*assignment, negate),
            GetTopLevelAllocation(*assignment, slice));
}

TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBuffer) {
  // This computation is very similar to ReuseNonOperandBuffer except the
  // broadcast has a smaller output than the negate. This should block reuse of
  // negate's buffer by broadcast because the output buffer(s) of a computation
  // should be exactly sized for the value.
  //
  //   param ---> (negate) ---> (slice) ---> (broadcast)
  //
  // Neither negate nor slice may share a buffer with broadcast.
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  // Negate output is 100 elements.
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  // Slice output is 10 elements.
  auto slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
  // Broadcast output is 40 elements.
  auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {10, 4}), slice, {0}));

  auto module = CreateNewModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // The broadcast output buffer cannot be shared.
  EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
            GetTopLevelAllocation(*assignment, negate));
  EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
            GetTopLevelAllocation(*assignment, slice));
}

TEST_F(BufferAssignmentTest, ReuseOutputBufferIfExactlySized) {
  // This is identical to DoNotReuseOversizedOutputBuffer except the broadcast
  // output is exactly the same size as the negate (rather than being
  // smaller). This enables reuse of negate's buffer by the broadcast because
  // the output buffer will be sized exactly to its value.
  //
  //   param ---> (negate) ---> (slice) ---> (broadcast)
  //
  // The negate should share a buffer with broadcast (see the EXPECT_EQ on
  // their allocations below; the test name states the intent).
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  // Negate output is 100 elements.
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  // Slice output is 10 elements.
  auto slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
  // Broadcast output is 100 elements (10x10), exactly matching negate's size.
  auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {10, 10}), slice, {0}));

  auto module = CreateNewModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // negate and broadcast should share a buffer.
  EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast));
  auto& buffer_for_bcast = GetTopLevelAllocation(*assignment, broadcast);
  EXPECT_EQ(buffer_for_bcast, GetTopLevelAllocation(*assignment, negate));

  // Slice should have its own buffer.
  EXPECT_NE(buffer_for_bcast, GetTopLevelAllocation(*assignment, slice));
}

TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBufferInTuple) {
  // This computation is very similar to ReuseNonOperandBuffer except the
  // broadcast has a smaller output than the negate, and the broadcast is
  // contained in the computation output as a tuple element. This should block
  // reuse of the negate's buffer by the broadcast because the output buffer(s)
  // of a computation should be exactly sized for the value. This includes those
  // buffers aliased in the output (eg, contained as tuple elements).
  //
  //   param ---> (negate) ---> (slice) ---> (broadcast) --> (tuple)
  //
  // Neither negate nor slice may share a buffer with broadcast.
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  // Negate output is 100 elements.
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  // Slice output is 10 elements.
  auto slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
  // Broadcast output is 40 elements.
  auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {10, 4}), slice, {0}));
  builder.AddInstruction(HloInstruction::CreateTuple({broadcast}));

  auto module = CreateNewModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // The broadcast output buffer cannot be shared.
  EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
            GetTopLevelAllocation(*assignment, negate));
  EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
            GetTopLevelAllocation(*assignment, slice));
}

TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) {
  // Verify that buffers for embedded computations are properly marked as
  // thread-local and that embedded parameters are not marked as
  // is_entry_computation_parameter.
  auto module = CreateNewModule();
  auto vec_shape = ShapeUtil::MakeShape(F32, {42});
  auto scalar_shape = ShapeUtil::MakeShape(F32, {});

  // Create a scalar computation to use in a map.
  auto map_builder = HloComputation::Builder(TestName() + "_map");
  auto map_param = map_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "map_param"));
  auto map_root = map_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, map_param));
  auto map_computation = module->AddEmbeddedComputation(map_builder.Build());

  // Create a vector computation to use in a kCall.
  auto call_builder = HloComputation::Builder(TestName() + "_call");
  auto call_param = call_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec_shape, "vec_param"));
  auto call_root = call_builder.AddInstruction(
      HloInstruction::CreateUnary(vec_shape, HloOpcode::kExp, call_param));
  auto call_computation = module->AddEmbeddedComputation(call_builder.Build());

  // Create entry computation which kCalls call_computation and then calls map
  // with map_computation on the result.
  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec_shape, "param"));
  auto call = builder.AddInstruction(
      HloInstruction::CreateCall(vec_shape, {param}, call_computation));
  auto map = builder.AddInstruction(
      HloInstruction::CreateMap(vec_shape, {call}, map_computation));
  module->AddEntryComputation(builder.Build());

  auto assignment = RunBufferAssignment(module.get());

  // Allocations for the map computation should be thread-local and not
  // live-out.
  auto& map_param_alloc = GetTopLevelAllocation(*assignment, map_param);
  EXPECT_FALSE(map_param_alloc.is_entry_computation_parameter());
  EXPECT_FALSE(map_param_alloc.maybe_live_out());
  EXPECT_TRUE(map_param_alloc.is_thread_local());

  auto& map_root_alloc = GetTopLevelAllocation(*assignment, map_root);
  EXPECT_FALSE(map_root_alloc.is_entry_computation_parameter());
  EXPECT_FALSE(map_root_alloc.maybe_live_out());
  EXPECT_TRUE(map_root_alloc.is_thread_local());

  // Allocations for the call computation should not be thread-local.
  auto& call_param_alloc = GetTopLevelAllocation(*assignment, call_param);
  EXPECT_FALSE(call_param_alloc.is_entry_computation_parameter());
  EXPECT_FALSE(call_param_alloc.maybe_live_out());
  EXPECT_FALSE(call_param_alloc.is_thread_local());

  // NOTE(review): maybe_live_out() is not asserted for call_root, unlike the
  // other allocations above -- presumably intentional; confirm before adding.
  auto& call_root_alloc = GetTopLevelAllocation(*assignment, call_root);
  EXPECT_FALSE(call_root_alloc.is_entry_computation_parameter());
  EXPECT_FALSE(call_root_alloc.is_thread_local());

  // Entry computation allocations can be marked liveout and
  // is_entry_computation_parameter.
  auto& param_alloc = GetTopLevelAllocation(*assignment, param);
  EXPECT_TRUE(param_alloc.is_entry_computation_parameter());
  EXPECT_FALSE(param_alloc.maybe_live_out());
  EXPECT_FALSE(param_alloc.is_thread_local());

  auto& map_alloc = GetTopLevelAllocation(*assignment, map);
  EXPECT_FALSE(map_alloc.is_entry_computation_parameter());
  EXPECT_TRUE(map_alloc.maybe_live_out());
  EXPECT_FALSE(map_alloc.is_thread_local());
}

TEST_F(BufferAssignmentTest, TupleParameterAsOutput) {
  // Test a computation that returns a tuple parameter.
  auto builder = HloComputation::Builder(TestName());
  auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter(
      0,
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
                                 ShapeUtil::MakeShape(F32, {}),
                                 ShapeUtil::MakeShape(S32, {42})}),
      "param0"));

  auto module = CreateNewModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // There should be four allocations: one for vector of pointers, and one for
  // each tuple element.
  EXPECT_EQ(4, assignment->Allocations().size());

  // Verify each buffer allocation is marked as an entry computation parameter
  // and is liveout.
  ShapeUtil::ForEachSubshape(
      tuple_param->shape(),
      [this, &assignment, tuple_param](const Shape& /*subshape*/,
                                       const ShapeIndex& index) {
        auto allocation = GetAllocation(*assignment, tuple_param, index);
        EXPECT_TRUE(allocation.is_entry_computation_parameter());
        EXPECT_EQ(0, allocation.parameter_number());
        EXPECT_TRUE(allocation.maybe_live_out());
      });
}

TEST_F(BufferAssignmentTest, ElementOfNestedTupleParameterAsOutput) {
  // Test a computation which returns a GetElementTuple of a nested tuple
  // parameter.
  auto builder = HloComputation::Builder(TestName());
  auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter(
      0,
      ShapeUtil::MakeTupleShape(
          {ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
           ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {42}),
                                      ShapeUtil::MakeShape(S32, {101})})}),
      "param0"));
  // Extract the nested tuple at index {1}; this is the computation's root.
  auto tuple_element =
      builder.AddInstruction(HloInstruction::CreateGetTupleElement(
          ShapeUtil::GetSubshape(tuple_param->shape(), {1}), tuple_param, 1));

  auto module = CreateNewModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // Only some of the elements of the input param are liveout.
  EXPECT_FALSE(
      GetAllocation(*assignment, tuple_param, /*index=*/{}).maybe_live_out());
  // Tuple element at index={1} is live out because GetTupleElement({1})
  // forwards a pointer to this allocation (instead of defining its own buffer).
  EXPECT_TRUE(
      GetAllocation(*assignment, tuple_param, /*index=*/{1}).maybe_live_out());
  EXPECT_TRUE(GetAllocation(*assignment, tuple_param, /*index=*/{1, 0})
                  .maybe_live_out());
  EXPECT_TRUE(GetAllocation(*assignment, tuple_param, /*index=*/{1, 1})
                  .maybe_live_out());

  // The GetTupleElement output is liveout.
  EXPECT_TRUE(
      GetTopLevelAllocation(*assignment, tuple_element).maybe_live_out());

  // Verify that the GetTupleElement allocations of its elements match the
  // corresponding tuple parameter allocations because they alias.
  EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1, 0}),
            GetAllocation(*assignment, tuple_element, /*index=*/{0}));
  EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1, 1}),
            GetAllocation(*assignment, tuple_element, /*index=*/{1}));

  // GetTupleElement forwards a pointer to its underlying buffer, so verify
  // that it has the same allocation as the corresponding parameter element.
  EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1}),
            GetTopLevelAllocation(*assignment, tuple_element));
}

// TODO(b/32248867): Enable when buffer assignment gives allocations to
// constants.
TEST_F(BufferAssignmentTest, DISABLED_TupleConstantAsOutput) {
  // Test that a tuple constant which is forwarded to the computation output
  // is properly handled.
  auto builder = HloComputation::Builder(TestName());
  builder.AddInstruction(HloInstruction::CreateConstant(Literal::MakeTuple(
      {Literal::CreateR0<int64>(0).get(), Literal::CreateR0<int64>(1).get()})));

  auto module = CreateNewModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // One allocation for the tuple itself plus one per element.
  EXPECT_EQ(3, assignment->Allocations().size());
}

TEST_F(BufferAssignmentTest, TupleCustomCallAsOutput) {
  // Test a computation which returns a tuple custom call value.
  auto builder = HloComputation::Builder(TestName());
  auto custom_call = builder.AddInstruction(HloInstruction::CreateCustomCall(
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
                                 ShapeUtil::MakeShape(S32, {101})}),
      /*operands=*/{}, /*custom_call_target=*/"foo_function"));
  auto module = CreateNewModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // One allocation for the tuple plus one per element; all are live out since
  // the custom call is the computation root.
  EXPECT_EQ(3, assignment->Allocations().size());
  EXPECT_TRUE(
      GetAllocation(*assignment, custom_call, /*index=*/{}).maybe_live_out());
  EXPECT_TRUE(
      GetAllocation(*assignment, custom_call, /*index=*/{0}).maybe_live_out());
  EXPECT_TRUE(
      GetAllocation(*assignment, custom_call, /*index=*/{1}).maybe_live_out());
}

TEST_F(BufferAssignmentTest, TupleCallAsOutput) {
  // Test a computation which returns a tuple call value.
  auto module = CreateNewModule();
  auto elem_shape = f32vec4_;
  auto tuple_shape = ShapeUtil::MakeTupleShape({elem_shape});

  auto sub_builder = HloComputation::Builder(TestName() + "_sub");
  auto sub_param = sub_builder.AddInstruction(
      HloInstruction::CreateParameter(0, elem_shape, "sub_param"));
  auto sub_tuple =
      sub_builder.AddInstruction(HloInstruction::CreateTuple({sub_param}));
  auto sub_computation = module->AddEmbeddedComputation(sub_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, elem_shape, "param"));
  auto call = builder.AddInstruction(
      HloInstruction::CreateCall(tuple_shape, {param}, sub_computation));
  module->AddEntryComputation(builder.Build());

  auto assignment = RunBufferAssignment(module.get());

  EXPECT_EQ(3, assignment->Allocations().size());
  // Buffers for call are colocated with the sub-computation.
  EXPECT_EQ(GetAllocation(*assignment, call, /*index=*/{}),
            GetAllocation(*assignment, sub_tuple, /*index=*/{}));
  EXPECT_EQ(GetAllocation(*assignment, call, /*index=*/{0}),
            GetAllocation(*assignment, sub_param, /*index=*/{}));
  // The parameter isn't aliased with anything.
  EXPECT_NE(GetTopLevelAllocation(*assignment, param),
            GetTopLevelAllocation(*assignment, sub_tuple));
  EXPECT_NE(GetTopLevelAllocation(*assignment, param),
            GetTopLevelAllocation(*assignment, sub_param));
}

TEST_F(BufferAssignmentTest, TupleChainedCallAsOutput) {
  // Test a chain of calls with tuple output. The chain looks like:
  // A: call(B, tuple(param))
  // B: call(C, param)
  // C: call(D, param)
  // D: param
  auto module = CreateNewModule();
  auto elem_shape = f32vec4_;
  auto tuple_shape = ShapeUtil::MakeTupleShape({elem_shape});

  auto d_builder = HloComputation::Builder(TestName() + "_d");
  auto d_param = d_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "d_param"));
  auto d_computation = d_builder.Build();

  auto c_builder = HloComputation::Builder(TestName() + "_c");
  auto c_param = c_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "c_param"));
  auto c_call = c_builder.AddInstruction(
      HloInstruction::CreateCall(tuple_shape, {c_param}, d_computation.get()));
  auto c_computation = c_builder.Build();

  auto b_builder = HloComputation::Builder(TestName() + "_b");
  auto b_param = b_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "b_param"));
  auto b_call = b_builder.AddInstruction(
      HloInstruction::CreateCall(tuple_shape, {b_param}, c_computation.get()));
  auto b_computation = b_builder.Build();

  auto a_builder = HloComputation::Builder(TestName());
  auto a_param = a_builder.AddInstruction(
      HloInstruction::CreateParameter(0, elem_shape, "param"));
  auto a_tuple =
      a_builder.AddInstruction(HloInstruction::CreateTuple({a_param}));
  auto a_call = a_builder.AddInstruction(
      HloInstruction::CreateCall(tuple_shape, {a_tuple}, b_computation.get()));
  auto a_computation = a_builder.Build();

  // Add the computations in an order that doesn't match the dependency
  // post-order, to shake out more possible bugs.
  module->AddEmbeddedComputation(std::move(d_computation));
  module->AddEmbeddedComputation(std::move(c_computation));
  module->AddEntryComputation(std::move(a_computation));
  module->AddEmbeddedComputation(std::move(b_computation));

  auto assignment = RunBufferAssignment(module.get());

  // Buffers for call are colocated with the sub-computations.
  EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{}),
            GetAllocation(*assignment, b_call, /*index=*/{}));
  EXPECT_EQ(GetAllocation(*assignment, b_call, /*index=*/{}),
            GetAllocation(*assignment, c_call, /*index=*/{}));
  EXPECT_EQ(GetAllocation(*assignment, c_call, /*index=*/{}),
            GetAllocation(*assignment, d_param, /*index=*/{}));
  EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{0}),
            GetAllocation(*assignment, b_call, /*index=*/{0}));
  EXPECT_EQ(GetAllocation(*assignment, b_call, /*index=*/{0}),
            GetAllocation(*assignment, c_call, /*index=*/{0}));
  EXPECT_EQ(GetAllocation(*assignment, c_call, /*index=*/{0}),
            GetAllocation(*assignment, d_param, /*index=*/{0}));
  // The parameters aren't aliased with anything.
  EXPECT_TRUE(BuffersDistinct({a_param}, {b_param}, *assignment));
  EXPECT_TRUE(BuffersDistinct({a_param}, {c_param}, *assignment));
  EXPECT_TRUE(BuffersDistinct({a_param}, {d_param}, *assignment));
  EXPECT_TRUE(BuffersDistinct({b_param}, {c_param}, *assignment));
  EXPECT_TRUE(BuffersDistinct({b_param}, {d_param}, *assignment));
  EXPECT_TRUE(BuffersDistinct({c_param}, {d_param}, *assignment));
}

TEST_F(BufferAssignmentTest, BitcastAsOutput) {
  // Test a computation which returns a bitcast value.
  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {42}), "param"));
  auto bitcast = builder.AddInstruction(
      HloInstruction::CreateUnary(param->shape(), HloOpcode::kBitcast, param));

  auto module = CreateNewModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // Bitcast should get the same allocation as the param.
  EXPECT_EQ(1, assignment->Allocations().size());
  EXPECT_EQ(GetTopLevelAllocation(*assignment, param),
            GetTopLevelAllocation(*assignment, bitcast));
}

TEST_F(BufferAssignmentTest, AmbiguousBufferAsOutput) {
  // Test a computation with an output that has an ambiguous points-to set.
  // This is constructed using a select among tuple shapes.
  auto builder = HloComputation::Builder(TestName());
  auto tuple_shape =
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4})});

  auto tuple_param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param0"));
  auto tuple_param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, tuple_shape, "param1"));
  // NOTE(review): the name "param1" duplicates the parameter above and was
  // presumably meant to be "param2"; names are not load-bearing in this test,
  // so this is harmless -- confirm before changing.
  auto pred_param = builder.AddInstruction(HloInstruction::CreateParameter(
      2, ShapeUtil::MakeShape(PRED, {}), "param1"));
  auto select = builder.AddInstruction(HloInstruction::CreateTernary(
      tuple_shape, HloOpcode::kSelect, pred_param, tuple_param0, tuple_param1));

  auto module = CreateNewModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // Select shallow copies one of its operands so it defines its own top-level
  // buffer and receives its own allocation.
  auto select_alloc = GetTopLevelAllocation(*assignment, select);
  EXPECT_EQ(1, select_alloc.assigned_buffers().size());
  EXPECT_EQ(select,
            select_alloc.assigned_buffers().begin()->first->instruction());

  // The buffer for the tuple element of the select is forwarded from one its
  // operands which cannot be determined statically. Therefore its slices
  // should include the slices of both of the elements in the parameters.
  auto element_slices = assignment->GetAllSlices(select, /*index=*/{0});
  EXPECT_EQ(2, element_slices.size());
  EXPECT_THAT(element_slices,
              ::testing::UnorderedElementsAre(
                  assignment->GetUniqueSlice(tuple_param0, /*index=*/{0})
                      .ConsumeValueOrDie(),
                  assignment->GetUniqueSlice(tuple_param1, /*index=*/{0})
                      .ConsumeValueOrDie()));
}

// TODO(b/34669761): Remove this test when buffers are allowed to share
// allocations.
1382 TEST_F(BufferAssignmentTest, TupleBufferNotReused) { 1383 // Test a computation that returns a tuple parameter. 1384 auto builder = HloComputation::Builder(TestName()); 1385 auto scalar_shape = ShapeUtil::MakeShape(F32, {}); 1386 auto param = builder.AddInstruction( 1387 HloInstruction::CreateParameter(0, scalar_shape, "param0")); 1388 auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({param})); 1389 auto tuple_element = builder.AddInstruction( 1390 HloInstruction::CreateGetTupleElement(scalar_shape, tuple, 0)); 1391 auto copy = builder.AddInstruction(HloInstruction::CreateUnary( 1392 scalar_shape, HloOpcode::kCopy, tuple_element)); 1393 1394 auto module = CreateNewModule(); 1395 module->AddEntryComputation(builder.Build()); 1396 auto assignment = RunBufferAssignment(module.get()); 1397 1398 // There should be no buffer reuse. The copy should not reuse the tuple 1399 // buffer. 1400 EXPECT_EQ(3, assignment->Allocations().size()); 1401 EXPECT_NE(GetTopLevelAllocation(*assignment, tuple), 1402 GetTopLevelAllocation(*assignment, copy)); 1403 } 1404 1405 TEST_F(BufferAssignmentTest, OneTempAllocation) { 1406 // Test a computation that requires multiple temp buffers, and ensure they 1407 // are combined into a single allocation. 1408 auto builder = HloComputation::Builder(TestName()); 1409 Shape shape_2x3 = ShapeUtil::MakeShape(F32, {2, 3}); 1410 Shape shape_2x4 = ShapeUtil::MakeShape(F32, {2, 4}); 1411 Shape shape_3x4 = ShapeUtil::MakeShape(F32, {3, 4}); 1412 Shape shape_4x4 = ShapeUtil::MakeShape(F32, {4, 4}); 1413 Shape shape_5x4 = ShapeUtil::MakeShape(F32, {5, 4}); 1414 1415 // There should be separate temp buffers for dot_ab and dot_bc. 
1416 auto param_a = builder.AddInstruction( 1417 HloInstruction::CreateParameter(0, shape_2x3, "param_a")); 1418 auto param_b = builder.AddInstruction( 1419 HloInstruction::CreateParameter(1, shape_3x4, "param_b")); 1420 auto param_c = builder.AddInstruction( 1421 HloInstruction::CreateParameter(2, shape_4x4, "param_c")); 1422 DotDimensionNumbers dot_dnums; 1423 dot_dnums.add_lhs_contracting_dimensions(1); 1424 dot_dnums.add_rhs_contracting_dimensions(0); 1425 auto dot_ab = builder.AddInstruction( 1426 HloInstruction::CreateDot(shape_2x4, param_a, param_b, dot_dnums)); 1427 auto dot_bc = builder.AddInstruction( 1428 HloInstruction::CreateDot(shape_3x4, param_b, param_c, dot_dnums)); 1429 builder.AddInstruction( 1430 HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 1)); 1431 1432 // Run buffer assignment with alignment=1. 1433 auto module = CreateNewModule(); 1434 module->AddEntryComputation(builder.Build()); 1435 auto assignment = RunBufferAssignment(module.get(), /*alignment=*/1); 1436 1437 // There are 5 allocations: 3 parameters, 1 output, and 1 temp. 1438 EXPECT_EQ(5, assignment->Allocations().size()); 1439 1440 // Ensure the temp buffers for dot_ab and dot_bc share a single allocation, 1441 // and each occupies different slices of that allocation. 1442 BufferAllocation::Slice slice_ab = 1443 assignment->GetUniqueTopLevelSlice(dot_ab).ConsumeValueOrDie(); 1444 BufferAllocation::Slice slice_bc = 1445 assignment->GetUniqueTopLevelSlice(dot_bc).ConsumeValueOrDie(); 1446 EXPECT_EQ(slice_ab.allocation(), slice_bc.allocation()); 1447 EXPECT_NE(slice_ab, slice_bc); 1448 EXPECT_EQ(32, slice_ab.size()); 1449 EXPECT_EQ(48, slice_bc.size()); 1450 EXPECT_EQ(80, slice_ab.allocation()->size()); 1451 EXPECT_EQ(80, slice_bc.allocation()->size()); 1452 1453 // Re-run buffer assignment with alignment=64. 
1454 assignment = RunBufferAssignment(module.get(), /*alignment=*/64); 1455 EXPECT_EQ(5, assignment->Allocations().size()); 1456 slice_ab = assignment->GetUniqueTopLevelSlice(dot_ab).ConsumeValueOrDie(); 1457 slice_bc = assignment->GetUniqueTopLevelSlice(dot_bc).ConsumeValueOrDie(); 1458 EXPECT_EQ(slice_ab.allocation(), slice_bc.allocation()); 1459 EXPECT_NE(slice_ab, slice_bc); 1460 EXPECT_EQ(32, slice_ab.size()); 1461 EXPECT_EQ(48, slice_bc.size()); 1462 // Ensure the offsets and allocation size account for the alignment, without 1463 // assuming which buffer gets assigned first. 1464 if (slice_ab.offset() == 0) { 1465 EXPECT_EQ(64, slice_bc.offset()); 1466 EXPECT_EQ(64 + 48, slice_ab.allocation()->size()); 1467 EXPECT_EQ(64 + 48, slice_bc.allocation()->size()); 1468 } else { 1469 EXPECT_EQ(64, slice_ab.offset()); 1470 EXPECT_EQ(0, slice_bc.offset()); 1471 EXPECT_EQ(64 + 32, slice_ab.allocation()->size()); 1472 EXPECT_EQ(64 + 32, slice_bc.allocation()->size()); 1473 } 1474 } 1475 1476 class WhileBufferAssignmentTest : public HloTestBase { 1477 protected: 1478 std::unique_ptr<HloComputation> BuildWhileConditionComputation( 1479 const string& name) { 1480 auto builder = HloComputation::Builder(name); 1481 builder.AddInstruction( 1482 HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); 1483 auto zero = builder.AddInstruction( 1484 HloInstruction::CreateConstant(Literal::CreateR0<int>(0))); 1485 auto ten = builder.AddInstruction( 1486 HloInstruction::CreateConstant(Literal::CreateR0<int>(10))); 1487 builder.AddInstruction(HloInstruction::CreateBinary( 1488 ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, zero, ten)); 1489 return builder.Build(); 1490 } 1491 1492 std::unique_ptr<HloComputation> BuildWhileBodyComputation( 1493 const string& name) { 1494 auto builder = HloComputation::Builder(name); 1495 auto loop_state = builder.AddInstruction( 1496 HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); 1497 auto input = 
builder.AddInstruction( 1498 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 0)); 1499 auto weights = builder.AddInstruction( 1500 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); 1501 auto output = builder.AddInstruction(HloInstruction::CreateBinary( 1502 data_shape_, HloOpcode::kMultiply, input, weights)); 1503 builder.AddInstruction( 1504 HloInstruction::CreateTuple({input, weights, output})); 1505 return builder.Build(); 1506 } 1507 1508 std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module, 1509 int64 alignment = 1) { 1510 auto sequence = 1511 CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie(); 1512 return BufferAssigner::Run( 1513 module, xla::MakeUnique<SequentialHloOrdering>(module, sequence), 1514 ByteSizeOf, 1515 [alignment](LogicalBuffer::Color) { return alignment; }) 1516 .ConsumeValueOrDie(); 1517 } 1518 1519 static int64 ByteSizeOf(const LogicalBuffer& buffer) { 1520 return ShapeUtil::ByteSizeOf(buffer.shape(), sizeof(void*)); 1521 } 1522 1523 Shape data_shape_ = ShapeUtil::MakeShape(F32, {4}); 1524 Shape loop_state_shape_ = 1525 ShapeUtil::MakeTupleShape({data_shape_, data_shape_, data_shape_}); 1526 }; 1527 1528 static void RunCopyInsertion(HloModule* module) { 1529 CopyInsertion copy_insertion; 1530 EXPECT_IS_OK(copy_insertion.Run(module).status()); 1531 } 1532 1533 TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { 1534 auto module = xla::MakeUnique<HloModule>(TestName()); 1535 auto builder = HloComputation::Builder("entry"); 1536 1537 auto input0 = builder.AddInstruction( 1538 HloInstruction::CreateParameter(0, data_shape_, "input0")); 1539 auto weights0 = builder.AddInstruction( 1540 HloInstruction::CreateParameter(1, data_shape_, "weights0")); 1541 auto weights1 = builder.AddInstruction( 1542 HloInstruction::CreateParameter(2, data_shape_, "weights1")); 1543 1544 auto zero = builder.AddInstruction( 1545 HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0))); 
1546 auto output0 = builder.AddInstruction( 1547 HloInstruction::CreateBroadcast(data_shape_, zero, {1})); 1548 auto output1 = builder.AddInstruction( 1549 HloInstruction::CreateBroadcast(data_shape_, zero, {1})); 1550 1551 auto cond0 = 1552 module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); 1553 auto body0 = 1554 module->AddEmbeddedComputation(BuildWhileBodyComputation("body")); 1555 1556 auto tuple0 = builder.AddInstruction( 1557 HloInstruction::CreateTuple({input0, weights0, output0})); 1558 auto while0 = builder.AddInstruction( 1559 HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0)); 1560 1561 auto cond1 = 1562 module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); 1563 auto body1 = 1564 module->AddEmbeddedComputation(BuildWhileBodyComputation("body")); 1565 auto input1 = builder.AddInstruction( 1566 HloInstruction::CreateGetTupleElement(data_shape_, while0, 2)); 1567 auto tuple1 = builder.AddInstruction( 1568 HloInstruction::CreateTuple({input1, weights1, output1})); 1569 auto while1 = builder.AddInstruction( 1570 HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1)); 1571 1572 module->AddEntryComputation(builder.Build()); 1573 RunCopyInsertion(module.get()); 1574 auto assignment = RunBufferAssignment(module.get()); 1575 1576 // Verify 'input0' and read-only use while0{0} alias. 1577 EXPECT_EQ(assignment->GetUniqueSlice(input0, {}).ConsumeValueOrDie(), 1578 assignment->GetUniqueSlice(while0, {0}).ConsumeValueOrDie()); 1579 // Verify 'weights0' and read-only use while0{1} alias. 1580 EXPECT_EQ(assignment->GetUniqueSlice(weights0, {}).ConsumeValueOrDie(), 1581 assignment->GetUniqueSlice(while0, {1}).ConsumeValueOrDie()); 1582 // Verify 'while0{2}' and read-only use while1{0} alias. 1583 EXPECT_EQ(assignment->GetUniqueSlice(while0, {2}).ConsumeValueOrDie(), 1584 assignment->GetUniqueSlice(while1, {0}).ConsumeValueOrDie()); 1585 // Verify 'weights1' and read-only use while1{1} alias. 
1586 EXPECT_EQ(assignment->GetUniqueSlice(weights1, {}).ConsumeValueOrDie(), 1587 assignment->GetUniqueSlice(while1, {1}).ConsumeValueOrDie()); 1588 } 1589 1590 // Tests that the colocated buffers for while instructions are properly assigned 1591 // during buffer assignment such that the result tuple elements are not assigned 1592 // to the same buffer. 1593 // 1594 // %infeed --> %while.0 --> %while.1 --+ 1595 // +-- %tuple 1596 // %zero --> %add --> %while.2 --+ 1597 // 1598 // Execution Order: 1599 // %infeed -> %while.0 -> %while.1 -> %zero -> %add -> %while.2 -> %tuple 1600 // 1601 // The HLO computation used in this test requires specific ordering to expose 1602 // the bug (b/72496031). During buffer assignment, the visitation order of 1603 // colocated buffers is %while.2 -> while.0 -> while.1, and the buffer 1604 // assignment was coalescing the colocated buffers for all 3 while instructions, 1605 // therefore assigning the same buffer to the two result tuple elements. 1606 TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { 1607 const Shape r0s32 = ShapeUtil::MakeShape(S32, {}); 1608 1609 // Builds a condition computation: x -> x < 4 1610 auto build_cond = [&]() { 1611 auto builder = HloComputation::Builder("cond"); 1612 auto const4 = builder.AddInstruction( 1613 HloInstruction::CreateConstant(Literal::CreateR0<int>(4))); 1614 auto param = 1615 builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x")); 1616 builder.AddInstruction(HloInstruction::CreateBinary( 1617 ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param, const4)); 1618 return builder.Build(); 1619 }; 1620 1621 // Builds a body computation: x -> x + 9 1622 auto build_body = [&]() { 1623 auto builder = HloComputation::Builder("body"); 1624 auto const9 = builder.AddInstruction( 1625 HloInstruction::CreateConstant(Literal::CreateR0<int>(9))); 1626 auto param = 1627 builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x")); 1628 builder.AddInstruction( 1629 
HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, param, const9)); 1630 return builder.Build(); 1631 }; 1632 1633 // Build the entry computation as described in the comment above. 1634 auto module = xla::MakeUnique<HloModule>(TestName()); 1635 auto builder = HloComputation::Builder("entry"); 1636 1637 auto infeed = builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, "")); 1638 auto cond0 = module->AddEmbeddedComputation(build_cond()); 1639 auto body0 = module->AddEmbeddedComputation(build_body()); 1640 auto while0 = builder.AddInstruction( 1641 HloInstruction::CreateWhile(r0s32, cond0, body0, infeed)); 1642 1643 auto cond1 = module->AddEmbeddedComputation(build_cond()); 1644 auto body1 = module->AddEmbeddedComputation(build_body()); 1645 auto while1 = builder.AddInstruction( 1646 HloInstruction::CreateWhile(r0s32, cond1, body1, while0)); 1647 1648 auto zero = builder.AddInstruction( 1649 HloInstruction::CreateConstant(Literal::CreateR0<int32>(0))); 1650 auto add = builder.AddInstruction( 1651 HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, zero, zero)); 1652 auto cond2 = module->AddEmbeddedComputation(build_cond()); 1653 auto body2 = module->AddEmbeddedComputation(build_body()); 1654 auto while2 = builder.AddInstruction( 1655 HloInstruction::CreateWhile(r0s32, cond2, body2, add)); 1656 1657 auto tuple = 1658 builder.AddInstruction(HloInstruction::CreateTuple({while2, while1})); 1659 module->AddEntryComputation(builder.Build()); 1660 1661 // Run CopyInsertion and check if the graph constructed above doesn't need 1662 // any copies inserted for BufferAssignment to run. 
1663 int64 instruction_count = module->instruction_count(); 1664 CopyInsertion copy_insertion; 1665 ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); 1666 ASSERT_EQ(instruction_count, module->instruction_count()); 1667 1668 // Create a sequential order among all the instructions in the entry 1669 // computation, since the issue this test stresses depends on the order the 1670 // nodes are traversed during BufferAssignment. 1671 SequentialHloOrdering::HloModuleSequence sequence; 1672 sequence[module->entry_computation()] = {infeed, while0, while1, zero, 1673 add, while2, tuple}; 1674 TF_ASSERT_OK_AND_ASSIGN( 1675 auto assignment, 1676 BufferAssigner::Run( 1677 module.get(), 1678 xla::MakeUnique<SequentialHloOrdering>(module.get(), sequence), 1679 backend().compiler()->BufferSizeBytesFunction(), 1680 [](LogicalBuffer::Color) { return 1; })); 1681 1682 // The result tuple elements must be assigned with different buffers. 1683 TF_ASSERT_OK_AND_ASSIGN(auto slice0, assignment->GetUniqueSlice(tuple, {0})); 1684 TF_ASSERT_OK_AND_ASSIGN(auto slice1, assignment->GetUniqueSlice(tuple, {1})); 1685 EXPECT_NE(slice0, slice1); 1686 1687 // while0 and while1 result buffers must be equal to slice1. 1688 TF_ASSERT_OK_AND_ASSIGN(auto slice_while0, 1689 assignment->GetUniqueSlice(while0, {})); 1690 TF_ASSERT_OK_AND_ASSIGN(auto slice_while1, 1691 assignment->GetUniqueSlice(while1, {})); 1692 EXPECT_EQ(slice1, slice_while0); 1693 EXPECT_EQ(slice1, slice_while1); 1694 1695 // while2 result buffer must be equal to slice0. 
1696 TF_ASSERT_OK_AND_ASSIGN(auto slice_while2, 1697 assignment->GetUniqueSlice(while2, {})); 1698 EXPECT_EQ(slice0, slice_while2); 1699 } 1700 1701 TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { 1702 auto module = xla::MakeUnique<HloModule>(TestName()); 1703 auto builder = HloComputation::Builder("entry"); 1704 1705 auto input0 = builder.AddInstruction( 1706 HloInstruction::CreateParameter(0, data_shape_, "input0")); 1707 auto weights0 = builder.AddInstruction( 1708 HloInstruction::CreateParameter(1, data_shape_, "weights0")); 1709 1710 auto zero = builder.AddInstruction( 1711 HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0))); 1712 auto output0 = builder.AddInstruction( 1713 HloInstruction::CreateBroadcast(data_shape_, zero, {1})); 1714 1715 auto cond0 = 1716 module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); 1717 auto body0 = 1718 module->AddEmbeddedComputation(BuildWhileBodyComputation("body")); 1719 1720 auto tuple0 = builder.AddInstruction( 1721 HloInstruction::CreateTuple({input0, weights0, output0})); 1722 auto while0 = builder.AddInstruction( 1723 HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0)); 1724 1725 auto cond1 = 1726 module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); 1727 auto body1 = 1728 module->AddEmbeddedComputation(BuildWhileBodyComputation("body")); 1729 1730 auto while1 = builder.AddInstruction( 1731 HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, while0)); 1732 1733 module->AddEntryComputation(builder.Build()); 1734 RunCopyInsertion(module.get()); 1735 auto assignment = RunBufferAssignment(module.get()); 1736 1737 // while0 and while1 buffers should be completely aligned. 
1738 EXPECT_EQ(assignment->GetUniqueSlice(while0, {0}).ConsumeValueOrDie(), 1739 assignment->GetUniqueSlice(while1, {0}).ConsumeValueOrDie()); 1740 EXPECT_EQ(assignment->GetUniqueSlice(while0, {1}).ConsumeValueOrDie(), 1741 assignment->GetUniqueSlice(while1, {1}).ConsumeValueOrDie()); 1742 EXPECT_EQ(assignment->GetUniqueSlice(while0, {2}).ConsumeValueOrDie(), 1743 assignment->GetUniqueSlice(while1, {2}).ConsumeValueOrDie()); 1744 } 1745 1746 TEST_F(BufferAssignmentTest, TwoCalls) { 1747 auto module = xla::MakeUnique<HloModule>(TestName()); 1748 Shape r0f32 = ShapeUtil::MakeShape(xla::F32, {}); 1749 HloComputation* sub_computation; 1750 { 1751 auto builder = HloComputation::Builder(TestName() + "_sub_comp"); 1752 auto param = builder.AddInstruction( 1753 HloInstruction::CreateParameter(0, r0f32, "param")); 1754 auto constant1 = builder.AddInstruction( 1755 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 1756 auto add = builder.AddInstruction( 1757 HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, constant1)); 1758 sub_computation = module->AddEmbeddedComputation(builder.Build(add)); 1759 } 1760 auto builder = HloComputation::Builder(TestName()); 1761 auto constant2 = builder.AddInstruction( 1762 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 1763 auto constant3 = builder.AddInstruction( 1764 HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0))); 1765 auto call1 = builder.AddInstruction( 1766 HloInstruction::CreateCall(r0f32, {constant2}, sub_computation)); 1767 auto call2 = builder.AddInstruction( 1768 HloInstruction::CreateCall(r0f32, {constant3}, sub_computation)); 1769 auto add1 = builder.AddInstruction( 1770 HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, call1, constant2)); 1771 auto add2 = builder.AddInstruction( 1772 HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, call2, add1)); 1773 module->AddEntryComputation(builder.Build(add2)); 1774 1775 { 1776 FlattenCallGraph flatten; 1777 
TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get())); 1778 EXPECT_TRUE(result); 1779 std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get()); 1780 } 1781 1782 RunCopyInsertion(module.get()); 1783 auto assignment = RunBufferAssignment(module.get()); 1784 1785 EXPECT_TRUE(BuffersDistinct({call1}, {call2}, *assignment)); 1786 } 1787 1788 static bool IsPostOrderTraversal( 1789 const std::vector<const HloInstruction*>& sequence) { 1790 tensorflow::gtl::FlatSet<const HloInstruction*> seen_so_far; 1791 auto has_not_been_seen_yet = [&](const HloInstruction* instruction) { 1792 return seen_so_far.count(instruction) == 0; 1793 }; 1794 1795 for (auto instruction : sequence) { 1796 if (std::any_of(instruction->operands().begin(), 1797 instruction->operands().end(), has_not_been_seen_yet) || 1798 std::any_of(instruction->control_predecessors().begin(), 1799 instruction->control_predecessors().end(), 1800 has_not_been_seen_yet)) { 1801 return false; // Not a post order. 1802 } 1803 if (!seen_so_far.insert(instruction).second) { 1804 return false; // Not a "traversal". 
1805 } 1806 } 1807 1808 return true; 1809 } 1810 1811 TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { 1812 auto module = xla::MakeUnique<HloModule>(TestName()); 1813 auto builder = HloComputation::Builder(TestName()); 1814 1815 auto zero = builder.AddInstruction( 1816 HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0))); 1817 auto one = builder.AddInstruction( 1818 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 1819 1820 auto input0 = builder.AddInstruction( 1821 HloInstruction::CreateParameter(0, data_shape_, "input0")); 1822 auto weights0 = builder.AddInstruction( 1823 HloInstruction::CreateParameter(1, data_shape_, "weights0")); 1824 auto output0 = builder.AddInstruction( 1825 HloInstruction::CreateBroadcast(data_shape_, zero, {1})); 1826 1827 auto input1 = builder.AddInstruction( 1828 HloInstruction::CreateParameter(2, data_shape_, "input1")); 1829 auto weights1 = builder.AddInstruction( 1830 HloInstruction::CreateParameter(3, data_shape_, "weights1")); 1831 auto output1 = builder.AddInstruction( 1832 HloInstruction::CreateBroadcast(data_shape_, one, {1})); 1833 1834 auto cond = 1835 module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); 1836 auto body = module->AddEmbeddedComputation(BuildWhileBodyComputation("body")); 1837 1838 auto tuple0 = builder.AddInstruction( 1839 HloInstruction::CreateTuple({input0, weights0, output0})); 1840 auto tuple1 = builder.AddInstruction( 1841 HloInstruction::CreateTuple({input1, weights1, output1})); 1842 1843 auto while0 = builder.AddInstruction( 1844 HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple0)); 1845 auto while1 = builder.AddInstruction( 1846 HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple1)); 1847 1848 auto gte0 = builder.AddInstruction( 1849 HloInstruction::CreateGetTupleElement(data_shape_, while0, 0)); 1850 auto gte1 = builder.AddInstruction( 1851 HloInstruction::CreateGetTupleElement(data_shape_, while1, 1)); 1852 
auto root_add = builder.AddInstruction(HloInstruction::CreateBinary( 1853 while0->shape(), HloOpcode::kAdd, gte0, gte1)); 1854 1855 module->AddEntryComputation(builder.Build()); 1856 1857 { 1858 FlattenCallGraph flatten; 1859 TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get())); 1860 EXPECT_TRUE(result); 1861 } 1862 1863 RunCopyInsertion(module.get()); 1864 1865 auto sequence = 1866 CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie(); 1867 1868 // To trigger b/38494731, we want a specific Hlo sequence for the 1869 // root computation, so we overwrite that entry with a manually 1870 // crafted sequence. 1871 sequence[module->entry_computation()] = { 1872 input1, weights1, one, output1, while1->operand(0), while1, 1873 input0, weights0, zero, output0, while0->operand(0), while0, 1874 gte0, gte1, root_add}; 1875 1876 // If this ASSERT_TRUE fails, we constructed a bogus sequence above 1877 // and this test itself is buggy. 1878 ASSERT_TRUE(IsPostOrderTraversal(sequence[module->entry_computation()])); 1879 1880 auto assignment = 1881 BufferAssigner::Run( 1882 module.get(), 1883 xla::MakeUnique<SequentialHloOrdering>(module.get(), sequence), 1884 ByteSizeOf, [](LogicalBuffer::Color) { return 1; }) 1885 .ConsumeValueOrDie(); 1886 1887 EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); 1888 } 1889 1890 TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) { 1891 auto module = xla::MakeUnique<HloModule>(TestName()); 1892 auto builder = HloComputation::Builder("entry"); 1893 1894 auto input0 = builder.AddInstruction( 1895 HloInstruction::CreateParameter(0, data_shape_, "input0")); 1896 auto weights0 = builder.AddInstruction( 1897 HloInstruction::CreateParameter(1, data_shape_, "weights0")); 1898 1899 auto zero = builder.AddInstruction( 1900 HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0))); 1901 auto output0 = builder.AddInstruction( 1902 HloInstruction::CreateBroadcast(data_shape_, zero, {1})); 1903 
auto output1 = builder.AddInstruction( 1904 HloInstruction::CreateBroadcast(data_shape_, zero, {1})); 1905 1906 auto cond0 = 1907 module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); 1908 auto body0 = 1909 module->AddEmbeddedComputation(BuildWhileBodyComputation("body")); 1910 1911 auto tuple0 = builder.AddInstruction( 1912 HloInstruction::CreateTuple({input0, weights0, output0})); 1913 auto while0 = builder.AddInstruction( 1914 HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0)); 1915 1916 // Get output of 'while0' and feed as input to 'while1'. 1917 auto while0_out = builder.AddInstruction( 1918 HloInstruction::CreateGetTupleElement(data_shape_, while0, 2)); 1919 1920 auto cond1 = 1921 module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); 1922 auto body1 = 1923 module->AddEmbeddedComputation(BuildWhileBodyComputation("body")); 1924 1925 auto tuple1 = builder.AddInstruction( 1926 HloInstruction::CreateTuple({while0_out, weights0, output1})); 1927 auto while1 = builder.AddInstruction( 1928 HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1)); 1929 1930 // Get output of 'while1' so that it is live out of computation. 1931 auto while1_out = builder.AddInstruction( 1932 HloInstruction::CreateGetTupleElement(data_shape_, while1, 2)); 1933 1934 module->AddEntryComputation(builder.Build()); 1935 RunCopyInsertion(module.get()); 1936 auto assignment = RunBufferAssignment(module.get()); 1937 // Get BufferAllocation for root instruction. 1938 auto* root_alloc = assignment->GetUniqueTopLevelSlice(while1_out) 1939 .ConsumeValueOrDie() 1940 .allocation(); 1941 // Test that root instruction allocation is live out. 1942 EXPECT_TRUE(root_alloc->maybe_live_out()); 1943 // Test that root instruction allocation is not an entry parameter. 1944 EXPECT_FALSE(root_alloc->is_entry_computation_parameter()); 1945 } 1946 1947 } // namespace 1948 } // namespace xla 1949