1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/buffer_liveness.h" 17 18 #include <memory> 19 #include <string> 20 21 #include "tensorflow/compiler/xla/ptr_util.h" 22 #include "tensorflow/compiler/xla/service/hlo_computation.h" 23 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 24 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 25 #include "tensorflow/compiler/xla/shape_util.h" 26 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 27 #include "tensorflow/compiler/xla/types.h" 28 #include "tensorflow/compiler/xla/xla_data.pb.h" 29 30 namespace xla { 31 namespace { 32 33 class BufferLivenessTest : public HloTestBase { 34 protected: 35 // Returns the LogicalBuffer defined at the given instruction and 36 // index. CHECKs if no buffer is defined at that point. 37 const LogicalBuffer& GetBuffer(const BufferLiveness& liveness, 38 const HloInstruction* instruction, 39 const ShapeIndex& index) { 40 const auto& pointed_to = liveness.points_to_analysis() 41 .GetPointsToSet(instruction) 42 .element(index); 43 CHECK_EQ(1, pointed_to.size()); 44 CHECK_EQ(instruction, pointed_to[0]->instruction()); 45 CHECK(index == pointed_to[0]->index()); 46 return *pointed_to[0]; 47 } 48 49 // Returns true if the top-level buffers for instructions 'a' and 'b' may 50 // interfere. Precondition: 'a' and 'b' are array-shaped. 51 bool InstructionsMayInterfere(const BufferLiveness& liveness, 52 HloInstruction* a, HloInstruction* b) { 53 EXPECT_FALSE(ShapeUtil::IsTuple(a->shape())); 54 EXPECT_FALSE(ShapeUtil::IsTuple(b->shape())); 55 return liveness.MayInterfere( 56 GetBuffer(liveness, /*instruction=*/a, /*index=*/{}), 57 GetBuffer(liveness, /*instruction=*/b, /*index=*/{})); 58 } 59 60 // Returns true if the tuple elements at 'index' for instructions 'a' and 'b' 61 // may interfere. Precondition: 'a' and 'b' are tuple-shaped, with equal 62 // tuple element sub-shapes. 63 bool TupleElementsMayInterfere(const BufferLiveness& liveness, 64 HloInstruction* a, HloInstruction* b, 65 const ShapeIndex& index) { 66 // Check that top-level shapes are tuple and tuple element shapes are equal. 67 EXPECT_TRUE(ShapeUtil::IsTuple(a->shape())); 68 EXPECT_TRUE(ShapeUtil::IsTuple(b->shape())); 69 EXPECT_TRUE( 70 ShapeUtil::Compatible(ShapeUtil::GetSubshape(a->shape(), index), 71 ShapeUtil::GetSubshape(b->shape(), index))); 72 // Lookup PointsTo set for instructions 'a' and 'b'. 73 auto& points_to_analysis = liveness.points_to_analysis(); 74 const auto& points_to_a = 75 points_to_analysis.GetPointsToSet(a).element(index); 76 const auto& points_to_b = 77 points_to_analysis.GetPointsToSet(b).element(index); 78 // Make sure PointsTo sets for 'a' and 'b' are unambiguous. 79 EXPECT_EQ(1, points_to_a.size()); 80 EXPECT_EQ(points_to_a.size(), points_to_b.size()); 81 // Check interference. 82 return liveness.MayInterfere(*points_to_a[0], *points_to_b[0]); 83 } 84 85 // Returns true if the top-level buffers for the given instruction maybe 86 // liveout of the entry computation. 87 // Precondition: instruction is array-shaped. 88 bool InstructionMaybeLiveOut(const BufferLiveness& liveness, 89 HloInstruction* instruction) { 90 return liveness.MaybeLiveOut( 91 GetBuffer(liveness, instruction, /*index=*/{})); 92 } 93 94 std::unique_ptr<HloComputation> BuildDummyComputation() { 95 auto builder = HloComputation::Builder(TestName() + "_dummy"); 96 builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param")); 97 return builder.Build(); 98 } 99 100 const Shape vec_ = ShapeUtil::MakeShape(xla::F32, {42}); 101 }; 102 103 TEST_F(BufferLivenessTest, ElementwiseChain) { 104 // A simple chain of elementwise operations. No buffers should interfere. 105 // 106 // param --> negate -> exp -> log 107 // 108 auto builder = HloComputation::Builder(TestName()); 109 auto param = 110 builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param")); 111 auto negate = builder.AddInstruction( 112 HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param)); 113 auto exp = builder.AddInstruction( 114 HloInstruction::CreateUnary(vec_, HloOpcode::kExp, negate)); 115 auto log = builder.AddInstruction( 116 HloInstruction::CreateUnary(vec_, HloOpcode::kLog, exp)); 117 118 auto module = CreateNewModule(); 119 module->AddEntryComputation(builder.Build()); 120 121 auto liveness = 122 BufferLiveness::Run(module.get(), 123 xla::MakeUnique<DependencyHloOrdering>(module.get())) 124 .ConsumeValueOrDie(); 125 126 EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate)); 127 EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp)); 128 EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, log)); 129 130 // No buffers should interfere. 131 EXPECT_FALSE(InstructionsMayInterfere(*liveness, negate, exp)); 132 EXPECT_FALSE(InstructionsMayInterfere(*liveness, negate, log)); 133 EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, negate)); 134 EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, log)); 135 EXPECT_FALSE(InstructionsMayInterfere(*liveness, log, negate)); 136 EXPECT_FALSE(InstructionsMayInterfere(*liveness, log, exp)); 137 138 // Buffers should interfere with itself. 139 EXPECT_TRUE(InstructionsMayInterfere(*liveness, exp, exp)); 140 141 // Only log is live out. 142 EXPECT_FALSE(InstructionMaybeLiveOut(*liveness, param)); 143 EXPECT_FALSE(InstructionMaybeLiveOut(*liveness, negate)); 144 EXPECT_FALSE(InstructionMaybeLiveOut(*liveness, exp)); 145 EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, log)); 146 } 147 148 TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) { 149 // Two entry params, which interfere with each other. 150 // 151 // param0 --> negate ---------------\ 152 // param1 --> exp --> add 153 auto builder = HloComputation::Builder(TestName()); 154 auto param0 = builder.AddInstruction( 155 HloInstruction::CreateParameter(0, vec_, "param0")); 156 auto param1 = builder.AddInstruction( 157 HloInstruction::CreateParameter(1, vec_, "param1")); 158 auto negate = builder.AddInstruction( 159 HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param0)); 160 auto exp = builder.AddInstruction( 161 HloInstruction::CreateUnary(vec_, HloOpcode::kExp, param1)); 162 auto add = builder.AddInstruction( 163 HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp)); 164 165 auto module = CreateNewModule(); 166 HloComputation* entry = module->AddEntryComputation(builder.Build()); 167 168 SequentialHloOrdering::HloModuleSequence sequence; 169 sequence.insert({entry, {param0, negate, param1, exp, add}}); 170 auto liveness = 171 BufferLiveness::Run(module.get(), xla::MakeUnique<SequentialHloOrdering>( 172 module.get(), sequence)) 173 .ConsumeValueOrDie(); 174 175 // Entry parameters interfere as if they are defined simultaneously at 176 // the very beginning. 177 EXPECT_TRUE(InstructionsMayInterfere(*liveness, param0, param1)); 178 EXPECT_FALSE(InstructionsMayInterfere(*liveness, param0, negate)); 179 EXPECT_FALSE(InstructionsMayInterfere(*liveness, param0, exp)); 180 EXPECT_FALSE(InstructionsMayInterfere(*liveness, param0, add)); 181 EXPECT_TRUE(InstructionsMayInterfere(*liveness, param1, param0)); 182 EXPECT_TRUE(InstructionsMayInterfere(*liveness, param1, negate)); 183 EXPECT_FALSE(InstructionsMayInterfere(*liveness, param1, exp)); 184 EXPECT_FALSE(InstructionsMayInterfere(*liveness, param1, add)); 185 186 // Negate and exp still interfere. 187 EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, exp)); 188 EXPECT_TRUE(InstructionsMayInterfere(*liveness, exp, negate)); 189 190 // But {negate, add} and {exp, add} don't interfere. 191 EXPECT_FALSE(InstructionsMayInterfere(*liveness, negate, add)); 192 EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, negate)); 193 EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, add)); 194 EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, exp)); 195 } 196 197 TEST_F(BufferLivenessTest, NonElementwiseOperand) { 198 // A chain of operations with two elementwise and one non-elementwise. The 199 // elementwise op should not interfere with its operand, while the 200 // non-elementwise op should interfere. Entry params always interfere. 201 // 202 // param --> exp -> negate -> reverse 203 // 204 auto builder = HloComputation::Builder(TestName()); 205 auto param = 206 builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param")); 207 auto exp = builder.AddInstruction( 208 HloInstruction::CreateUnary(vec_, HloOpcode::kExp, param)); 209 auto negate = builder.AddInstruction( 210 HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, exp)); 211 auto reverse = 212 builder.AddInstruction(HloInstruction::CreateReverse(vec_, negate, {0})); 213 214 auto module = CreateNewModule(); 215 module->AddEntryComputation(builder.Build()); 216 217 auto liveness = 218 BufferLiveness::Run(module.get(), 219 xla::MakeUnique<DependencyHloOrdering>(module.get())) 220 .ConsumeValueOrDie(); 221 222 EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp)); 223 EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate)); 224 EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, reverse)); 225 226 // Negate is elementwise, so doesn't interfere with its operand. 227 // Reverse is non-elementwise, so does interfere with its operand. 228 EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, negate)); 229 EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, reverse)); 230 } 231 232 TEST_F(BufferLivenessTest, OverlappedBuffers) { 233 // Verify simultaneously live buffers interfere (exp and negate). 234 // 235 // param --> negate -> add 236 // \---> exp -----/ 237 // 238 auto builder = HloComputation::Builder(TestName()); 239 auto param = 240 builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param")); 241 auto negate = builder.AddInstruction( 242 HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param)); 243 auto exp = builder.AddInstruction( 244 HloInstruction::CreateUnary(vec_, HloOpcode::kExp, param)); 245 auto add = builder.AddInstruction( 246 HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp)); 247 248 auto module = CreateNewModule(); 249 module->AddEntryComputation(builder.Build()); 250 251 auto liveness = 252 BufferLiveness::Run(module.get(), 253 xla::MakeUnique<DependencyHloOrdering>(module.get())) 254 .ConsumeValueOrDie(); 255 256 EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate)); 257 EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, exp)); 258 EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add)); 259 260 // Negate and exp interfere with each other, but not with add. 261 EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, exp)); 262 EXPECT_TRUE(InstructionsMayInterfere(*liveness, exp, negate)); 263 EXPECT_FALSE(InstructionsMayInterfere(*liveness, negate, add)); 264 EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, negate)); 265 EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, add)); 266 EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, exp)); 267 } 268 269 TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) { 270 // Identical to the test OverlappedBuffer but using a sequential ordering of 271 // HLO instructions. 272 // 273 // param --> negate -> add 274 // \---> exp -----/ 275 // 276 // Sequential order: 277 // param, negate, exp, add 278 // 279 // Liveness is identical to the DependencyHloOrdering. 280 auto builder = HloComputation::Builder(TestName()); 281 auto param = 282 builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param")); 283 auto negate = builder.AddInstruction( 284 HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param)); 285 auto exp = builder.AddInstruction( 286 HloInstruction::CreateUnary(vec_, HloOpcode::kExp, param)); 287 auto add = builder.AddInstruction( 288 HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp)); 289 290 auto module = CreateNewModule(); 291 auto computation = module->AddEntryComputation(builder.Build()); 292 293 SequentialHloOrdering::HloModuleSequence module_sequence; 294 std::vector<const HloInstruction*> order = {param, negate, exp, add}; 295 module_sequence.emplace(computation, order); 296 auto liveness = 297 BufferLiveness::Run(module.get(), xla::MakeUnique<SequentialHloOrdering>( 298 module.get(), module_sequence)) 299 .ConsumeValueOrDie(); 300 301 EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate)); 302 EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp)); 303 EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add)); 304 305 // Negate and exp interfere with each other, but not with add. 306 EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, exp)); 307 EXPECT_TRUE(InstructionsMayInterfere(*liveness, exp, negate)); 308 EXPECT_FALSE(InstructionsMayInterfere(*liveness, negate, add)); 309 EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, negate)); 310 EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, add)); 311 EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, exp)); 312 } 313 314 TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) { 315 // Tests that when the root instruction is not the last instruction in the 316 // schedule, the live range of its buffers interfere with the buffers of the 317 // later instructions. 318 // 319 // Two sets of independent instructions are executed in the computation. 320 // param --> add (root) 321 // recv --> recv-done --> send --> send-done 322 // 323 // Sequential order: 324 // param, add (root), recv, recv-done, send, send-done 325 auto builder = HloComputation::Builder(TestName()); 326 auto param = 327 builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param")); 328 auto add = builder.AddInstruction( 329 HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, param, param)); 330 auto recv = builder.AddInstruction( 331 HloInstruction::CreateRecv(vec_, /*channel_id=*/0)); 332 auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); 333 auto send = builder.AddInstruction( 334 HloInstruction::CreateSend(recv_done, /*channel_id=*/1)); 335 auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); 336 337 auto module = CreateNewModule(); 338 auto computation = module->AddEntryComputation(builder.Build(add)); 339 340 SequentialHloOrdering::HloModuleSequence module_sequence; 341 std::vector<const HloInstruction*> order = {param, add, recv, 342 recv_done, send, send_done}; 343 module_sequence.emplace(computation, order); 344 auto liveness = 345 BufferLiveness::Run(module.get(), xla::MakeUnique<SequentialHloOrdering>( 346 module.get(), module_sequence)) 347 .ConsumeValueOrDie(); 348 349 EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add)); 350 // Check the root instruction (add) buffer interferes with the recv buffer. 351 EXPECT_TRUE( 352 liveness->MayInterfere(GetBuffer(*liveness, add, /*index=*/{}), 353 GetBuffer(*liveness, recv, /*index=*/{0}))); 354 } 355 356 TEST_F(BufferLivenessTest, TupleLiveOut) { 357 // Verify MaybeLiveOut with nested tuples. Result of computation looks like: 358 // 359 // Tuple({Tuple({Negate(Param)}, Exp(Negate(Param)))}) 360 // 361 // All values should be live out except Param. 362 auto builder = HloComputation::Builder(TestName()); 363 auto param = 364 builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param")); 365 auto negate = builder.AddInstruction( 366 HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param)); 367 auto inner_tuple = 368 builder.AddInstruction(HloInstruction::CreateTuple({negate})); 369 auto exp = builder.AddInstruction( 370 HloInstruction::CreateUnary(vec_, HloOpcode::kExp, negate)); 371 auto outer_tuple = 372 builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple, exp})); 373 374 auto module = CreateNewModule(); 375 module->AddEntryComputation(builder.Build()); 376 377 auto liveness = 378 BufferLiveness::Run(module.get(), 379 xla::MakeUnique<DependencyHloOrdering>(module.get())) 380 .ConsumeValueOrDie(); 381 382 // All buffers should be live out except the param 383 EXPECT_FALSE(InstructionMaybeLiveOut(*liveness, param)); 384 EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, negate)); 385 EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, inner_tuple)); 386 EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, exp)); 387 EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, outer_tuple)); 388 } 389 390 // bitcast liveout. 391 392 TEST_F(BufferLivenessTest, EmbeddedComputation) { 393 // Test MaybeLiveOut and MayInterfere for embedded computation. 394 auto module = CreateNewModule(); 395 396 auto embedded_builder = HloComputation::Builder(TestName() + "_embedded"); 397 auto embedded_param = embedded_builder.AddInstruction( 398 HloInstruction::CreateParameter(0, vec_, "embedded_param")); 399 auto embedded_log = embedded_builder.AddInstruction( 400 HloInstruction::CreateUnary(vec_, HloOpcode::kLog, embedded_param)); 401 402 auto embedded_computation = 403 module->AddEmbeddedComputation(embedded_builder.Build()); 404 405 auto builder = HloComputation::Builder(TestName()); 406 auto param = 407 builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param")); 408 auto call = builder.AddInstruction( 409 HloInstruction::CreateCall(vec_, {param}, embedded_computation)); 410 411 module->AddEntryComputation(builder.Build()); 412 413 auto liveness = 414 BufferLiveness::Run(module.get(), 415 xla::MakeUnique<DependencyHloOrdering>(module.get())) 416 .ConsumeValueOrDie(); 417 418 // Buffers in different computations should always interfere. 419 EXPECT_TRUE(InstructionsMayInterfere(*liveness, embedded_log, call)); 420 EXPECT_TRUE(InstructionsMayInterfere(*liveness, embedded_param, param)); 421 EXPECT_FALSE( 422 InstructionsMayInterfere(*liveness, embedded_param, embedded_log)); 423 424 // The only buffers for which MaybeLiveOut == true are those live out 425 // of the entry computation. Buffers live out of embedded computations should 426 // return false for this method. 427 EXPECT_FALSE(InstructionMaybeLiveOut(*liveness, embedded_log)); 428 EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, call)); 429 } 430 431 TEST_F(BufferLivenessTest, TupleConstantLiveOut) { 432 // Verify non top-level elements of a nested tuple constant are properly 433 // marked as liveout. Computation: 434 // 435 // GetTupleElement(0, TupleConstant({{0, 1}, {3}}) 436 // 437 // Only the array buffers containing 0 and 1 are liveout of the 438 // computation. The buffer containing {0, 1} is copied by GetTupleElement, and 439 // the buffers containing {3} and 3 are dead. 440 auto builder = HloComputation::Builder(TestName()); 441 auto inner_tuple0 = Literal::MakeTuple( 442 {Literal::CreateR0<int64>(0).get(), Literal::CreateR0<int64>(1).get()}); 443 auto inner_tuple1 = Literal::MakeTuple({Literal::CreateR0<int64>(3).get()}); 444 auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant( 445 Literal::MakeTuple({inner_tuple0.get(), inner_tuple1.get()}))); 446 builder.AddInstruction(HloInstruction::CreateGetTupleElement( 447 inner_tuple0->shape(), tuple_constant, 0)); 448 449 auto module = CreateNewModule(); 450 module->AddEntryComputation(builder.Build()); 451 452 auto liveness = 453 BufferLiveness::Run(module.get(), 454 xla::MakeUnique<DependencyHloOrdering>(module.get())) 455 .ConsumeValueOrDie(); 456 457 // Only the element buffers of the tuple constant which are pointed to by 458 // the GetTupleElement instruction should be liveout. 459 EXPECT_FALSE(liveness->MaybeLiveOut( 460 GetBuffer(*liveness, tuple_constant, /*index=*/{}))); 461 EXPECT_TRUE(liveness->MaybeLiveOut( 462 GetBuffer(*liveness, tuple_constant, /*index=*/{0}))); 463 EXPECT_TRUE(liveness->MaybeLiveOut( 464 GetBuffer(*liveness, tuple_constant, /*index=*/{0, 0}))); 465 EXPECT_TRUE(liveness->MaybeLiveOut( 466 GetBuffer(*liveness, tuple_constant, /*index=*/{0, 1}))); 467 EXPECT_FALSE(liveness->MaybeLiveOut( 468 GetBuffer(*liveness, tuple_constant, /*index=*/{1}))); 469 EXPECT_FALSE(liveness->MaybeLiveOut( 470 GetBuffer(*liveness, tuple_constant, /*index=*/{1, 0}))); 471 EXPECT_FALSE(liveness->MaybeLiveOut( 472 GetBuffer(*liveness, tuple_constant, /*index=*/{1, 0}))); 473 } 474 475 TEST_F(BufferLivenessTest, IndependentTupleElements) { 476 auto builder = HloComputation::Builder(TestName()); 477 // Create param0 Tuple. 478 auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter( 479 0, 480 ShapeUtil::MakeTupleShape( 481 {ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(S32, {4})}), 482 "param0")); 483 // Create independent computations for each tuple elememt. 484 485 // Tuple element0 computation: 486 // Add(GetTupleElement(tuple_param0, 0), const0) 487 auto tuple_element0_shape = 488 ShapeUtil::GetSubshape(tuple_param0->shape(), {0}); 489 auto tuple_element0 = 490 builder.AddInstruction(HloInstruction::CreateGetTupleElement( 491 tuple_element0_shape, tuple_param0, 0)); 492 auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( 493 Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); 494 auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( 495 tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0)); 496 497 // Tuple element1 computation: 498 // Add(GetTupleElement(tuple_param0, 1), const1) 499 auto tuple_element1_shape = 500 ShapeUtil::GetSubshape(tuple_param0->shape(), {1}); 501 auto tuple_element1 = 502 builder.AddInstruction(HloInstruction::CreateGetTupleElement( 503 tuple_element1_shape, tuple_param0, 1)); 504 auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( 505 Literal::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}))); 506 auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( 507 tuple_element1_shape, HloOpcode::kAdd, tuple_element1, const1)); 508 509 // Create output tuple. 510 auto tuple_root = 511 builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); 512 513 auto module = CreateNewModule(); 514 module->AddEntryComputation(BuildDummyComputation()); 515 module->AddEmbeddedComputation(builder.Build()); 516 517 auto liveness = 518 BufferLiveness::Run(module.get(), 519 xla::MakeUnique<DependencyHloOrdering>(module.get())) 520 .ConsumeValueOrDie(); 521 522 // We compare tuple element pairs that are input/output to the computation: 523 // 1) (input_tuple_element, output_tuple_element) = ('tuple_element0', 'add0') 524 // 2) (input_tuple_element, output_tuple_element) = ('tuple_element1', 'add1') 525 526 // Tuple output element 'add0' does not depend on input 'tuple_element1'. 527 // Tuple output element 'add1' does not depend on input 'tuple_element0'. 528 529 // Both element pair does not interfere, because there is no other dependency 530 // on the pairs tuple input element, and so liveness can compute that all 531 // users of the input tuple element execute before the associated output 532 // tuple element. 533 EXPECT_FALSE( 534 TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {0})); 535 EXPECT_FALSE( 536 TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1})); 537 } 538 539 TEST_F(BufferLivenessTest, DependentTupleElements) { 540 auto builder = HloComputation::Builder(TestName()); 541 // Create param0 Tuple. 542 auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter( 543 0, 544 ShapeUtil::MakeTupleShape( 545 {ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(F32, {8})}), 546 "param0")); 547 // Create dependent computations for each tuple elememt. 548 549 // Tuple element0 computation: 550 // Add(GetTupleElement(tuple_param0, 0), const0) 551 auto tuple_element0_shape = 552 ShapeUtil::GetSubshape(tuple_param0->shape(), {0}); 553 auto tuple_element0 = 554 builder.AddInstruction(HloInstruction::CreateGetTupleElement( 555 tuple_element0_shape, tuple_param0, 0)); 556 auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( 557 Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); 558 auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( 559 tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0)); 560 561 // Tuple element1 computation: 562 // Add(GetTupleElement(tuple_param0, 0), GetTupleElement(tuple_param0, 1)) 563 auto tuple_element1_shape = 564 ShapeUtil::GetSubshape(tuple_param0->shape(), {1}); 565 auto tuple_element1 = 566 builder.AddInstruction(HloInstruction::CreateGetTupleElement( 567 tuple_element1_shape, tuple_param0, 1)); 568 auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( 569 tuple_element1_shape, HloOpcode::kAdd, tuple_element0, tuple_element1)); 570 571 // Create output tuple. 572 auto tuple_root = 573 builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); 574 575 auto module = CreateNewModule(); 576 module->AddEntryComputation(BuildDummyComputation()); 577 module->AddEmbeddedComputation(builder.Build()); 578 579 auto liveness = 580 BufferLiveness::Run(module.get(), 581 xla::MakeUnique<DependencyHloOrdering>(module.get())) 582 .ConsumeValueOrDie(); 583 584 // We compare tuple element pairs that are input/output to the computation: 585 // 1) (input_tuple_element, output_tuple_element) = ('tuple_element0', 'add0') 586 // 2) (input_tuple_element, output_tuple_element) = ('tuple_element1', 'add1') 587 588 // The first tuple element pair output 'add0', has no dependency on second 589 // tuple element pairs input 'tuple_element1'. 590 591 // The second tuple element pair output 'add1', has a dependency on first 592 // tuple element pairs input 'tuple_element0'. 593 594 // The first tuple element pair does interfere, because liveness cannot 595 // compute that all references to 'tuple_element0' are executed before 'add0' 596 // (because of the depenency of 'add1' on 'tuple_element0'). 597 EXPECT_TRUE( 598 TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {0})); 599 600 // The second tuple element pair does not interfere, because there is no 601 // other dependency on 'tuple_element1', and so liveness can compute that 602 // all users execute before 'add1'. 603 EXPECT_FALSE( 604 TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1})); 605 } 606 607 class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { 608 protected: 609 // Builds and runs a computation (see test case computation graphs below). 610 // Runs BufferLiveness on this computation. 611 // Returns whether buffer interference is detected between tuple-shaped 612 // parameter and root instructions at tuple element 1. 613 bool Run(const bool update_uses_tuple_element1, 614 const bool fuse_gte0 = false) { 615 auto builder = HloComputation::Builder(TestName()); 616 // Create param0 Tuple. 617 Shape data_shape = ShapeUtil::MakeShape(F32, {8}); 618 Shape update_shape = ShapeUtil::MakeShape(F32, {3}); 619 auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter( 620 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "param0")); 621 622 auto gte0 = builder.AddInstruction( 623 HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 0)); 624 625 auto gte1 = builder.AddInstruction( 626 HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1)); 627 628 auto update = builder.AddInstruction(HloInstruction::CreateConstant( 629 Literal::CreateR1<float>({2.f, 2.f, 2.f}))); 630 HloInstruction* slice = nullptr; 631 if (update_uses_tuple_element1) { 632 // Create a slice instruction as an additional user of 'gte1'. 633 slice = builder.AddInstruction( 634 HloInstruction::CreateSlice(update_shape, gte1, {0}, {3}, {1})); 635 update = builder.AddInstruction(HloInstruction::CreateBinary( 636 update_shape, HloOpcode::kAdd, update, slice)); 637 } 638 // Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'. 639 auto starts = builder.AddInstruction( 640 HloInstruction::CreateConstant(Literal::CreateR1<int32>({2}))); 641 auto dynamic_update_slice = 642 builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( 643 data_shape, gte1, update, starts)); 644 // Create output tuple. 645 auto tuple_root = builder.AddInstruction( 646 HloInstruction::CreateTuple({gte0, dynamic_update_slice})); 647 // Build module and get reference to entry computation. 648 auto module = CreateNewModule(); 649 module->AddEntryComputation(BuildDummyComputation()); 650 auto* computation = module->AddEmbeddedComputation(builder.Build()); 651 // Create fusion instruction based on number of tuple element 1 users. 652 if (update_uses_tuple_element1) { 653 computation->CreateFusionInstruction( 654 {dynamic_update_slice, starts, update, CHECK_NOTNULL(slice), gte1}, 655 HloInstruction::FusionKind::kLoop); 656 } else { 657 computation->CreateFusionInstruction( 658 {dynamic_update_slice, starts, update, gte1}, 659 HloInstruction::FusionKind::kLoop); 660 } 661 // Create fusion instruction for tuple element 0 (if requested). 662 if (fuse_gte0) { 663 computation->CreateFusionInstruction({gte0}, 664 HloInstruction::FusionKind::kLoop); 665 } 666 667 // Run BufferLiveness on 'module'. 668 auto liveness = 669 BufferLiveness::Run( 670 module.get(), xla::MakeUnique<DependencyHloOrdering>(module.get())) 671 .ConsumeValueOrDie(); 672 // Return whether or not buffers interference is detected between 673 // 'tuple_param0' and 'tuple_root' at shape index '{1}'. 674 return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1}); 675 } 676 }; 677 678 // Tests that live ranges of buffers Param0[1] and Tuple[1] (which alias fusion) 679 // do not overlap with the following computation: 680 // 681 // Param0 682 // / \ 683 // GTE(0) Fusion -----------> FusionParam 684 // | | | 685 // | | GTE(1) Const Const 686 // | | \ | / 687 // | | DynamicUpdateSlice // fused root 688 // \ / 689 // Tuple // computation root 690 // 691 TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterference) { 692 EXPECT_FALSE(Run(/*update_uses_tuple_element1=*/false)); 693 } 694 695 // Tests that live ranges of buffers Param0[1] and Tuple[1] (which aliases 696 // 'fusion1') do not overlap in the presence of another fusion instruction 697 // (which is a user of 'param0' at a different tuple index). 698 // BufferLiveness should detect no uses of Param0 at index {1} in Fusion0 699 // (because Fusion0 only uses Param0 at index {0}). 700 // 701 // Param0 702 // / \ 703 // FusionParam <----- Fusion0 Fusion1 ------> FusionParam 704 // | | | | 705 // GTE(0) | | GTE(1) Const Const 706 // | | \ | / 707 // \ / DynamicUpdateSlice 708 // Tuple 709 // 710 TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterferenceWithUnrelatedFusion) { 711 EXPECT_FALSE(Run(/*update_uses_tuple_element1=*/false, /*fuse_gte0=*/true)); 712 } 713 714 // Tests that live ranges of buffers Param0[1] and Tuple[1] (which alias fusion) 715 // do overlap because GTE(1) has two users: 716 // 1) DynamicUpdateSlice at operand 0. 717 // 2) Slice at operand 0. 718 // 719 // Param0 720 // / \ Const 721 // / \ / 722 // GTE(0) Fusion -----------> FusionParam FusionParam 723 // | | | | 724 // | | GTE(1) / 725 // | | | \ / 726 // | | | Slice / 727 // | | | \ / 728 // | | | Add Const 729 // | | | | | 730 // | | DynamicUpdateSlice // fused root 731 // \ / 732 // Tuple // computation root 733 // 734 TEST_F(FusedDynamicUpdateSliceLivenessTest, WithInterference) { 735 EXPECT_TRUE(Run(/*update_uses_tuple_element1=*/true)); 736 } 737 738 class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { 739 protected: 740 // Builds and runs a computation (see test case computation graphs below). 741 // Runs BufferLiveness on this computation. 742 // Returns whether buffer interference is detected between tuple-shaped 743 // parameter and root instructions at tuple element 1. 744 bool Run(const bool tuple_element1_has_two_uses) { 745 auto builder = HloComputation::Builder(TestName()); 746 // Create param0 Tuple. 747 Shape data_shape = ShapeUtil::MakeShape(F32, {8}); 748 Shape update_shape = ShapeUtil::MakeShape(F32, {3}); 749 auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter( 750 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "param0")); 751 752 auto gte0 = builder.AddInstruction( 753 HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 0)); 754 755 auto gte1 = builder.AddInstruction( 756 HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1)); 757 758 auto update = builder.AddInstruction(HloInstruction::CreateConstant( 759 Literal::CreateR1<float>({2.f, 2.f, 2.f}))); 760 761 if (tuple_element1_has_two_uses) { 762 // Add 'gte0' and 'gte1' to create another user of 'gte1'. 763 gte0 = builder.AddInstruction(HloInstruction::CreateBinary( 764 data_shape, HloOpcode::kAdd, gte0, gte1)); 765 } 766 // Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'. 767 auto starts = builder.AddInstruction( 768 HloInstruction::CreateConstant(Literal::CreateR1<int32>({2}))); 769 auto dynamic_update_slice = 770 builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( 771 data_shape, gte1, update, starts)); 772 // Create output tuple. 773 auto tuple_root = builder.AddInstruction( 774 HloInstruction::CreateTuple({gte0, dynamic_update_slice})); 775 // Build module and get reference to entry computation. 776 auto module = CreateNewModule(); 777 module->AddEntryComputation(BuildDummyComputation()); 778 module->AddEmbeddedComputation(builder.Build()); 779 // Run BufferLiveness on 'module'. 780 auto liveness = 781 BufferLiveness::Run( 782 module.get(), xla::MakeUnique<DependencyHloOrdering>(module.get())) 783 .ConsumeValueOrDie(); 784 // Return whether or not buffers interference is detected between 785 // 'tuple_param0' and 'tuple_root' at shape index '{1}'. 786 return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1}); 787 } 788 }; 789 790 // Tests that live ranges of buffers Param0[1] and Tuple[1] do not overlap in 791 // the following computation (because DynamicUpdateSlice (at operand 0) is the 792 // unique user): 793 // 794 // Parameter0 795 // | | 796 // GTE(0) GTE(1) Const Const 797 // | \ | / 798 // | DynamicUpdateSlice 799 // \ / 800 // Tuple 801 // 802 TEST_F(DynamicUpdateSliceLivenessTest, NoInterference) { 803 EXPECT_FALSE(Run(/*tuple_element1_has_two_uses=*/false)); 804 } 805 806 // Tests that live ranges of buffers Param0[1] and Tuple[1] do overlap because 807 // GTE(1) has two users: 808 // 1) DynamicUpdateSlice at operand 0. 809 // 2) Add at operand 1. 810 // 811 // Parameter0 812 // | | 813 // GTE(0) GTE(1) 814 // | / | 815 // | / | 816 // Add | Const Const 817 // | | | | 818 // | DynamicUpdateSlice 819 // \ / 820 // Tuple 821 // 822 TEST_F(DynamicUpdateSliceLivenessTest, WithInterference) { 823 EXPECT_TRUE(Run(/*tuple_element1_has_two_uses=*/true)); 824 } 825 826 } // namespace 827 828 } // namespace xla 829