1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/copy_insertion.h" 17 18 #include <set> 19 20 #include "tensorflow/compiler/xla/debug_options_flags.h" 21 #include "tensorflow/compiler/xla/literal.h" 22 #include "tensorflow/compiler/xla/service/hlo_computation.h" 23 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 24 #include "tensorflow/compiler/xla/service/hlo_matchers.h" 25 #include "tensorflow/compiler/xla/service/hlo_module.h" 26 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 27 #include "tensorflow/compiler/xla/service/hlo_runner.h" 28 #include "tensorflow/compiler/xla/shape_util.h" 29 #include "tensorflow/compiler/xla/test.h" 30 #include "tensorflow/compiler/xla/test_helpers.h" 31 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 32 #include "tensorflow/compiler/xla/xla_data.pb.h" 33 #include "tensorflow/core/platform/test_benchmark.h" 34 35 namespace op = xla::testing::opcode_matchers; 36 37 namespace xla { 38 namespace { 39 40 using ::testing::UnorderedElementsAre; 41 42 int64 CountCopies(const HloComputation& computation) { 43 int64 count = 0; 44 for (const auto& instruction : computation.instructions()) { 45 if (instruction->opcode() == HloOpcode::kCopy) { 46 count++; 47 } 48 } 49 return count; 50 } 51 52 int64 CountCopies(const HloModule& module) { 53 int64 count = 0; 54 for (const auto& computation : module.computations()) { 55 count += CountCopies(*computation); 56 } 57 return count; 58 } 59 60 int64 CountControlEdges(const HloComputation& computation) { 61 int64 count = 0; 62 for (const auto& instruction : computation.instructions()) { 63 count += instruction->control_successors().size(); 64 } 65 return count; 66 } 67 68 int64 CountControlEdges(const HloModule& module) { 69 int64 count = 0; 70 for (const auto& computation : module.computations()) { 71 count += CountControlEdges(*computation); 72 } 73 return count; 74 } 75 76 class CopyInsertionTest : public HloTestBase { 77 protected: 78 void InsertCopies(HloModule* module) { 79 CopyInsertion copy_insertion; 80 ASSERT_IS_OK(copy_insertion.Run(module).status()); 81 } 82 83 const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); 84 }; 85 86 TEST_F(CopyInsertionTest, SingleParameter) { 87 // Computation is a single parameter passed into a tuple. The parameter should 88 // be copied before entering the tuple. 89 auto builder = HloComputation::Builder(TestName()); 90 HloInstruction* x = builder.AddInstruction( 91 HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "x")); 92 HloInstruction* tuple = 93 builder.AddInstruction(HloInstruction::CreateTuple({x})); 94 95 EXPECT_THAT(x->users(), UnorderedElementsAre(tuple)); 96 97 auto module = CreateNewVerifiedModule(); 98 module->AddEntryComputation(builder.Build()); 99 100 InsertCopies(module.get()); 101 102 EXPECT_THAT(module->entry_computation()->root_instruction(), 103 op::Tuple(op::Copy(x))); 104 } 105 106 TEST_F(CopyInsertionTest, SingleConstant) { 107 // Computation is a single constant passed into a tuple. The parameter should 108 // be copied before entering the tuple. 109 auto builder = HloComputation::Builder(TestName()); 110 HloInstruction* constant = builder.AddInstruction( 111 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); 112 HloInstruction* tuple = 113 builder.AddInstruction(HloInstruction::CreateTuple({constant})); 114 115 EXPECT_THAT(constant->users(), UnorderedElementsAre(tuple)); 116 117 auto module = CreateNewVerifiedModule(); 118 module->AddEntryComputation(builder.Build()); 119 120 InsertCopies(module.get()); 121 EXPECT_EQ(CountCopies(*module), 1); 122 123 EXPECT_THAT(module->entry_computation()->root_instruction(), 124 op::Tuple(op::Copy(constant))); 125 } 126 127 TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) { 128 // Verify that kCopy instructions which change layout and exist before 129 // copy-insertion remain in the graph after copy-insertion. 130 auto module = CreateNewVerifiedModule(); 131 132 auto builder = HloComputation::Builder(TestName()); 133 HloInstruction* constant = 134 builder.AddInstruction(HloInstruction::CreateConstant( 135 LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}}))); 136 auto minor_to_major = LayoutUtil::MinorToMajor(constant->shape()); 137 Layout reversed_layout = 138 LayoutUtil::MakeLayoutFromMajorToMinor(minor_to_major); 139 Shape copy_shape = constant->shape(); 140 *copy_shape.mutable_layout() = reversed_layout; 141 HloInstruction* copy_1 = builder.AddInstruction( 142 HloInstruction::CreateUnary(copy_shape, HloOpcode::kCopy, constant)); 143 HloInstruction* copy_2 = builder.AddInstruction( 144 HloInstruction::CreateUnary(copy_shape, HloOpcode::kCopy, constant)); 145 HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( 146 constant->shape(), HloOpcode::kAdd, copy_1, copy_2)); 147 builder.AddInstruction( 148 HloInstruction::CreateUnary(add->shape(), HloOpcode::kCopy, add)); 149 150 module->AddEntryComputation(builder.Build()); 151 152 EXPECT_EQ(CountCopies(*module), 3); 153 154 InsertCopies(module.get()); 155 156 EXPECT_EQ(CountCopies(*module), 2); 157 158 EXPECT_EQ(module->entry_computation()->root_instruction(), add); 159 EXPECT_THAT(module->entry_computation()->root_instruction(), 160 op::Add(op::Copy(op::Constant()), op::Copy(op::Constant()))); 161 } 162 163 TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) { 164 // Create a computation with more than one constant and parameter. Only one of 165 // each constant/parameter is pointed to by the output tuple. Only these 166 // instructions should be copied. 167 auto builder = HloComputation::Builder(TestName()); 168 169 HloInstruction* constant1 = builder.AddInstruction( 170 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); 171 HloInstruction* constant2 = builder.AddInstruction( 172 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0))); 173 174 HloInstruction* x = builder.AddInstruction( 175 HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "x")); 176 HloInstruction* y = builder.AddInstruction( 177 HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "y")); 178 179 HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( 180 ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, constant1, y)); 181 182 builder.AddInstruction(HloInstruction::CreateTuple({constant2, x, add})); 183 184 auto module = CreateNewVerifiedModule(); 185 module->AddEntryComputation(builder.Build()); 186 187 InsertCopies(module.get()); 188 EXPECT_EQ(CountCopies(*module), 2); 189 190 EXPECT_THAT( 191 module->entry_computation()->root_instruction(), 192 op::Tuple(op::Copy(constant2), op::Copy(x), op::Add(constant1, y))); 193 } 194 195 TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { 196 // Create a computation using select which has an ambiguous points-to set for 197 // the computation result. Verify that copies are added properly. 198 auto builder = HloComputation::Builder(TestName()); 199 HloInstruction* constant1 = builder.AddInstruction( 200 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); 201 HloInstruction* constant2 = builder.AddInstruction( 202 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0))); 203 HloInstruction* constant3 = builder.AddInstruction( 204 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0))); 205 206 HloInstruction* tuple1 = builder.AddInstruction( 207 HloInstruction::CreateTuple({constant1, constant2})); 208 HloInstruction* tuple2 = builder.AddInstruction( 209 HloInstruction::CreateTuple({constant3, constant2})); 210 211 HloInstruction* pred = builder.AddInstruction( 212 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false))); 213 builder.AddInstruction(HloInstruction::CreateTernary( 214 tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2)); 215 216 EXPECT_THAT(constant1->users(), UnorderedElementsAre(tuple1)); 217 EXPECT_THAT(constant2->users(), UnorderedElementsAre(tuple1, tuple2)); 218 EXPECT_THAT(constant3->users(), UnorderedElementsAre(tuple2)); 219 220 auto module = CreateNewVerifiedModule(); 221 module->AddEntryComputation(builder.Build()); 222 223 HloInstruction* old_root = module->entry_computation()->root_instruction(); 224 InsertCopies(module.get()); 225 EXPECT_EQ(CountCopies(*module), 2); 226 227 EXPECT_THAT(module->entry_computation()->root_instruction(), 228 op::Tuple(op::Copy(op::GetTupleElement(old_root)), 229 op::Copy(op::GetTupleElement(old_root)))); 230 } 231 232 TEST_F(CopyInsertionTest, BitcastParameter) { 233 // The output of a bitcast is its operand (same buffer), so a bitcast 234 // parameter feeding the result must have a copy added. 235 auto builder = HloComputation::Builder(TestName()); 236 HloInstruction* x = builder.AddInstruction( 237 HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {4}), "x")); 238 HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary( 239 ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, x)); 240 241 auto module = CreateNewVerifiedModule(); 242 module->AddEntryComputation(builder.Build()); 243 244 EXPECT_THAT(x->users(), UnorderedElementsAre(bitcast)); 245 246 HloInstruction* old_root = module->entry_computation()->root_instruction(); 247 InsertCopies(module.get()); 248 EXPECT_EQ(CountCopies(*module), 1); 249 250 EXPECT_THAT(module->entry_computation()->root_instruction(), 251 op::Copy(old_root)); 252 } 253 254 TEST_F(CopyInsertionTest, BitcastConstant) { 255 // The output of a bitcast is its operand (same buffer), so a bitcast 256 // constant feeding the result must have a copy added. 257 auto builder = HloComputation::Builder(TestName()); 258 HloInstruction* constant = 259 builder.AddInstruction(HloInstruction::CreateConstant( 260 LiteralUtil::CreateR1<float>({1.0, 42.0}))); 261 HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary( 262 ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, constant)); 263 264 auto module = CreateNewVerifiedModule(); 265 module->AddEntryComputation(builder.Build()); 266 267 EXPECT_THAT(constant->users(), UnorderedElementsAre(bitcast)); 268 269 HloInstruction* old_root = module->entry_computation()->root_instruction(); 270 InsertCopies(module.get()); 271 EXPECT_EQ(CountCopies(*module), 1); 272 273 EXPECT_THAT(module->entry_computation()->root_instruction(), 274 op::Copy(old_root)); 275 } 276 277 TEST_F(CopyInsertionTest, BitcastTupleElementParameter) { 278 // Same as BitcastParameter, but the bitcast is wrapped in a tuple. 279 auto builder = HloComputation::Builder(TestName()); 280 HloInstruction* x = builder.AddInstruction( 281 HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {4}), "x")); 282 HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary( 283 ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, x)); 284 builder.AddInstruction(HloInstruction::CreateTuple({bitcast})); 285 286 auto module = CreateNewVerifiedModule(); 287 module->AddEntryComputation(builder.Build()); 288 289 EXPECT_THAT(x->users(), UnorderedElementsAre(bitcast)); 290 291 InsertCopies(module.get()); 292 EXPECT_EQ(CountCopies(*module), 1); 293 294 EXPECT_THAT(module->entry_computation()->root_instruction(), 295 op::Tuple(op::Copy(bitcast))); 296 } 297 298 TEST_F(CopyInsertionTest, NestedTupleParameter) { 299 // Construct a trivial computation where the root of the computation is a 300 // nested tuple-shaped parameter. The parameter should be deep copied and the 301 // copy should be the root of the computation. 302 auto builder = HloComputation::Builder(TestName()); 303 304 // Param shape is: ((F32[], S32[1,2,3]), F32[42]) 305 builder.AddInstruction(HloInstruction::CreateParameter( 306 0, 307 ShapeUtil::MakeTupleShape( 308 {ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {}), 309 ShapeUtil::MakeShape(S32, {1, 2, 3})}), 310 ShapeUtil::MakeShape(F32, {42})}), 311 "param0")); 312 313 auto module = CreateNewVerifiedModule(); 314 module->AddEntryComputation(builder.Build()); 315 316 EXPECT_EQ(HloOpcode::kParameter, 317 module->entry_computation()->root_instruction()->opcode()); 318 319 HloInstruction* old_root = module->entry_computation()->root_instruction(); 320 InsertCopies(module.get()); 321 EXPECT_EQ(CountCopies(*module), 3); 322 323 HloInstruction* new_root = module->entry_computation()->root_instruction(); 324 EXPECT_NE(old_root, new_root); 325 326 EXPECT_THAT( 327 new_root, 328 op::Tuple( 329 op::Tuple( 330 op::Copy(op::GetTupleElement(op::GetTupleElement(old_root))), 331 op::Copy(op::GetTupleElement(op::GetTupleElement(old_root)))), 332 op::Copy(op::GetTupleElement(old_root)))); 333 } 334 335 TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) { 336 // Construct a computation where the root of the computation is a tuple 337 // element of a nested tuple-shaped parameter. 338 auto builder = HloComputation::Builder(TestName()); 339 340 // Param shape is: ((F32[], S32[1,2,3]), F32[42]) 341 auto param = builder.AddInstruction(HloInstruction::CreateParameter( 342 0, 343 ShapeUtil::MakeTupleShape( 344 {ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {}), 345 ShapeUtil::MakeShape(S32, {1, 2, 3})}), 346 ShapeUtil::MakeShape(F32, {42})}), 347 "param0")); 348 349 // The return value of the computation is the zero-th element of the nested 350 // tuple. This element is itself a tuple. 351 auto gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement( 352 ShapeUtil::GetSubshape(param->shape(), {0}), param, 0)); 353 354 auto module = CreateNewVerifiedModule(); 355 module->AddEntryComputation(builder.Build()); 356 357 EXPECT_EQ(gte, module->entry_computation()->root_instruction()); 358 359 InsertCopies(module.get()); 360 EXPECT_EQ(CountCopies(*module), 2); 361 362 EXPECT_THAT( 363 module->entry_computation()->root_instruction(), 364 op::Tuple(op::Copy(op::GetTupleElement(op::GetTupleElement(param))), 365 op::Copy(op::GetTupleElement(op::GetTupleElement(param))))); 366 } 367 368 TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { 369 // Create a computation using select which has an ambiguous points-to set for 370 // the top-level buffer of the root of the computation. Verify that a shallow 371 // copy is added. 372 auto builder = HloComputation::Builder(TestName()); 373 HloInstruction* constant1 = builder.AddInstruction( 374 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); 375 HloInstruction* constant2 = builder.AddInstruction( 376 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0))); 377 378 HloInstruction* tuple1 = builder.AddInstruction( 379 HloInstruction::CreateTuple({constant1, constant2})); 380 HloInstruction* tuple2 = builder.AddInstruction( 381 HloInstruction::CreateTuple({constant2, constant1})); 382 383 HloInstruction* pred = builder.AddInstruction( 384 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false))); 385 HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary( 386 tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2)); 387 HloInstruction* gte = 388 builder.AddInstruction(HloInstruction::CreateGetTupleElement( 389 ShapeUtil::GetSubshape(select->shape(), {0}), select, 0)); 390 391 auto module = CreateNewVerifiedModule(); 392 module->AddEntryComputation(builder.Build()); 393 394 EXPECT_EQ(gte, module->entry_computation()->root_instruction()); 395 396 HloInstruction* old_root = module->entry_computation()->root_instruction(); 397 InsertCopies(module.get()); 398 EXPECT_EQ(CountCopies(*module), 1); 399 400 EXPECT_THAT(module->entry_computation()->root_instruction(), 401 op::Copy(old_root)); 402 } 403 404 class WhileCopyInsertionTest : public CopyInsertionTest { 405 protected: 406 WhileCopyInsertionTest() : module_(CreateNewUnverifiedModule()) {} 407 408 // Builds a While condition computation which reads the induction variable 409 // from the tuple parameter, and returns a predicate indicating whether this 410 // value is less than the constant '10'. 411 // The parameter 'nested' specifies the loop state shape from which to 412 // read the induction variable. 413 std::unique_ptr<HloComputation> BuildConditionComputation( 414 const Shape& loop_state_shape) { 415 auto builder = HloComputation::Builder(TestName() + ".Condition"); 416 auto limit_const = builder.AddInstruction( 417 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(10))); 418 auto loop_state = builder.AddInstruction( 419 HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); 420 auto induction_variable = 421 builder.AddInstruction(HloInstruction::CreateGetTupleElement( 422 limit_const->shape(), loop_state, 0)); 423 builder.AddInstruction(HloInstruction::CreateCompare( 424 condition_result_shape_, induction_variable, limit_const, 425 ComparisonDirection::kLt)); 426 return builder.Build(); 427 } 428 429 // Builds a While body computation with one output tuple element dependent on 430 // both input tuple elements. 431 // EX: 432 // Body({in0, in1}) 433 // out0 = Add(in0, 1) 434 // out1 = Add(BCast(in0), in1) 435 // Tuple(out0, out1) 436 std::unique_ptr<HloComputation> BuildDependentBodyComputation() { 437 auto builder = HloComputation::Builder(TestName() + ".Body"); 438 // Create param instruction to access loop state. 439 auto loop_state = builder.AddInstruction( 440 HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); 441 // Update the induction variable GTE(0). 442 auto induction_variable = 443 builder.AddInstruction(HloInstruction::CreateGetTupleElement( 444 induction_variable_shape_, loop_state, 0)); 445 auto inc = builder.AddInstruction( 446 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1))); 447 auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( 448 induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); 449 // Update data GTE(1). 450 auto data = builder.AddInstruction( 451 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); 452 // Use 'induction_variable' in computation with no path to output tuple. 453 auto update = builder.AddInstruction( 454 HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8})); 455 auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( 456 data_shape_, HloOpcode::kAdd, data, update)); 457 // Create output Tuple. 458 builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); 459 return builder.Build(); 460 } 461 462 // Builds a While body computation with two output tuple elements dependent on 463 // both input tuple elements. 464 // 465 // EX: Body({in0, in1, in2}) 466 // out0 = Add(in0, 1) 467 // out1 = in1 468 // out2 = in2 469 // Tuple(out0, out1, out2) 470 std::unique_ptr<HloComputation> BuildDependentBodyComputation2() { 471 auto builder = HloComputation::Builder(TestName() + ".Body"); 472 473 const Shape& loop_state_shape = ShapeUtil::MakeTupleShape( 474 {induction_variable_shape_, data_shape_, data_shape_}); 475 476 auto loop_state = builder.AddInstruction( 477 HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); 478 479 // Update the induction variable GTE(0). 480 auto induction_variable = 481 builder.AddInstruction(HloInstruction::CreateGetTupleElement( 482 induction_variable_shape_, loop_state, 0)); 483 auto inc = builder.AddInstruction( 484 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1))); 485 486 // add0 = Add(in0, 1) 487 auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( 488 induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); 489 // data1 = GTE(1). 490 HloInstruction* data1 = builder.AddInstruction( 491 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); 492 493 // data2 = GTE(2). 494 HloInstruction* data2 = builder.AddInstruction( 495 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 2)); 496 497 // Create output Tuple. 498 builder.AddInstruction(HloInstruction::CreateTuple({add0, data1, data2})); 499 500 return builder.Build(); 501 } 502 503 // Builds a While body computation with read-only tuple element 0. 504 // EX: 505 // Body({in0, in1}) 506 // out0 = in0 507 // out1 = Add(BCast(in0), in1) 508 // Tuple(out0, out1) 509 std::unique_ptr<HloComputation> BuildDependentBodyOneReadOnlyComputation() { 510 auto builder = HloComputation::Builder(TestName() + ".Body"); 511 // Create param instruction to access loop state. 512 auto loop_state = builder.AddInstruction( 513 HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); 514 // Update the induction variable GTE(0). 515 auto induction_variable = 516 builder.AddInstruction(HloInstruction::CreateGetTupleElement( 517 induction_variable_shape_, loop_state, 0)); 518 // Update data GTE(1). 519 auto data = builder.AddInstruction( 520 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); 521 522 // Use 'induction_variable' in computation with no path to output tuple. 523 auto update = builder.AddInstruction( 524 HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8})); 525 auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( 526 data_shape_, HloOpcode::kAdd, data, update)); 527 // Create output Tuple. 528 builder.AddInstruction( 529 HloInstruction::CreateTuple({induction_variable, add1})); 530 return builder.Build(); 531 } 532 533 // Builds a While body computation with independent outputs. 534 // EX: 535 // Body({in0, in1}) 536 // out0 = Add(in0, 1) 537 // out1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1}) 538 // Tuple(out0, out1) 539 std::unique_ptr<HloComputation> BuildIndependentBodyComputation( 540 bool nested = false) { 541 auto builder = HloComputation::Builder(TestName() + ".Body"); 542 // Create param instruction to access loop state. 543 const Shape& loop_state_shape = 544 nested ? nested_loop_state_shape_ : loop_state_shape_; 545 546 auto loop_state = builder.AddInstruction( 547 HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); 548 // Update the induction variable GTE(0). 549 auto induction_variable = 550 builder.AddInstruction(HloInstruction::CreateGetTupleElement( 551 induction_variable_shape_, loop_state, 0)); 552 auto inc = builder.AddInstruction( 553 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1))); 554 // add0 = Add(in0, 1) 555 auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( 556 induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); 557 // Update data GTE(1). 558 HloInstruction* data = nullptr; 559 if (nested) { 560 data = builder.AddInstruction(HloInstruction::CreateGetTupleElement( 561 nested_tuple_shape_, loop_state, 1)); 562 data = builder.AddInstruction( 563 HloInstruction::CreateGetTupleElement(data_shape_, data, 0)); 564 } else { 565 data = builder.AddInstruction( 566 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); 567 } 568 auto update = builder.AddInstruction( 569 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>( 570 {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); 571 // add1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1}) 572 auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( 573 data_shape_, HloOpcode::kAdd, data, update)); 574 // Create output Tuple. 575 if (nested) { 576 auto nested_tuple = 577 builder.AddInstruction(HloInstruction::CreateTuple({add1, add1})); 578 builder.AddInstruction(HloInstruction::CreateTuple({add0, nested_tuple})); 579 } else { 580 builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); 581 } 582 return builder.Build(); 583 } 584 585 // Builds a While body computation with the following nested tuple 586 // sub-computation: 587 // | 588 // GTE(loop_state, 1) 589 // / \ 590 // GTE(GTE(loop_state, 1), 0) GTE(GTE(loop_state, 1), 1) 591 // | | 592 // Add Reverse 593 // | | 594 std::unique_ptr<HloComputation> BuildNestedBodyComputation() { 595 auto builder = HloComputation::Builder(TestName() + ".Body"); 596 // Create param instruction to access loop state. 597 auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter( 598 0, nested_loop_state_shape_, "loop_state")); 599 // Update GTE(0). 600 auto gte0 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( 601 induction_variable_shape_, loop_state, 0)); 602 auto inc = builder.AddInstruction( 603 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1))); 604 auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( 605 gte0->shape(), HloOpcode::kAdd, gte0, inc)); 606 607 // GTE(loop_state, 1) 608 auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( 609 nested_tuple_shape_, loop_state, 1)); 610 // GTE(GTE(loop_state, 1), 0) -> Add 611 auto gte10 = builder.AddInstruction( 612 HloInstruction::CreateGetTupleElement(data_shape_, gte1, 0)); 613 auto update10 = builder.AddInstruction( 614 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>( 615 {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); 616 auto add10 = builder.AddInstruction(HloInstruction::CreateBinary( 617 data_shape_, HloOpcode::kAdd, gte10, update10)); 618 619 // GTE(GTE(loop_state, 1), 1) -> Reverse 620 auto gte11 = builder.AddInstruction( 621 HloInstruction::CreateGetTupleElement(data_shape_, gte1, 1)); 622 auto rev11 = builder.AddInstruction( 623 HloInstruction::CreateReverse(data_shape_, gte11, {0})); 624 625 // Create output Tuple. 626 auto inner_tuple = 627 builder.AddInstruction(HloInstruction::CreateTuple({add10, rev11})); 628 builder.AddInstruction(HloInstruction::CreateTuple({add0, inner_tuple})); 629 return builder.Build(); 630 } 631 632 // Builds a While instruction using 'condition' and 'body' sub-computations. 633 // Init operand is initialized to zeros of appropriate shape. 634 HloInstruction* BuildWhileInstruction(HloComputation* condition, 635 HloComputation* body, 636 bool nested = false) { 637 auto builder = HloComputation::Builder(TestName() + ".While"); 638 auto induction_var_init = builder.AddInstruction( 639 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0))); 640 641 auto data_init = builder.AddInstruction( 642 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>( 643 {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); 644 645 if (nested) { 646 auto inner_init = builder.AddInstruction( 647 HloInstruction::CreateTuple({data_init, data_init})); 648 auto loop_state_init = builder.AddInstruction( 649 HloInstruction::CreateTuple({induction_var_init, inner_init})); 650 auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile( 651 loop_state_init->shape(), condition, body, loop_state_init)); 652 module_->AddEntryComputation(builder.Build()); 653 return while_hlo; 654 } 655 656 auto loop_state_init = builder.AddInstruction( 657 HloInstruction::CreateTuple({induction_var_init, data_init})); 658 auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile( 659 loop_state_shape_, condition, body, loop_state_init)); 660 module_->AddEntryComputation(builder.Build()); 661 return while_hlo; 662 } 663 664 HloInstruction* BuildWhileInstruction_InitPointsToConstant() { 665 auto builder = HloComputation::Builder(TestName() + ".While"); 666 auto data_init = builder.AddInstruction( 667 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>( 668 {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); 669 return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init, 670 &builder); 671 } 672 673 HloInstruction* BuildWhileInstruction_InitPointsToParameter() { 674 auto builder = HloComputation::Builder(TestName() + ".While"); 675 auto data_init = builder.AddInstruction( 676 HloInstruction::CreateParameter(0, data_shape_, "data_init")); 677 return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init, 678 &builder); 679 } 680 681 HloInstruction* BuildWhileInstruction_InitPointsToAmbiguous() { 682 auto builder = HloComputation::Builder(TestName() + ".While"); 683 684 auto one = builder.AddInstruction( 685 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); 686 auto v1 = builder.AddInstruction( 687 HloInstruction::CreateBroadcast(data_shape_, one, {1})); 688 auto zero = builder.AddInstruction( 689 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); 690 auto v2 = builder.AddInstruction( 691 HloInstruction::CreateBroadcast(data_shape_, zero, {1})); 692 693 auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({v1, v2})); 694 auto tuple2 = builder.AddInstruction(HloInstruction::CreateTuple({v2, v1})); 695 696 auto pred = builder.AddInstruction( 697 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false))); 698 auto data_init = builder.AddInstruction(HloInstruction::CreateTernary( 699 nested_tuple_shape_, HloOpcode::kTupleSelect, pred, tuple1, tuple2)); 700 701 return BuildWhileInstructionWithCustomInit(nested_loop_state_shape_, 702 data_init, &builder); 703 } 704 705 HloInstruction* BuildWhileInstruction_InitPointsToNonDistinct() { 706 auto builder = HloComputation::Builder(TestName() + ".While"); 707 708 auto one = builder.AddInstruction( 709 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); 710 auto one_vec = builder.AddInstruction( 711 HloInstruction::CreateBroadcast(data_shape_, one, {1})); 712 auto data_init = 713 builder.AddInstruction(HloInstruction::CreateTuple({one_vec, one_vec})); 714 715 return BuildWhileInstructionWithCustomInit(nested_loop_state_shape_, 716 data_init, &builder); 717 } 718 719 HloInstruction* BuildWhileInstruction_InitPointsToInterfering() { 720 auto builder = HloComputation::Builder(TestName() + ".While"); 721 auto one = builder.AddInstruction( 722 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); 723 auto data_init = builder.AddInstruction( 724 HloInstruction::CreateBroadcast(data_shape_, one, {1})); 725 auto one_vec = builder.AddInstruction( 726 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>( 727 {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); 728 // Take a reference to 'data_init' to make it interfere with while result. 729 auto add = builder.AddInstruction(HloInstruction::CreateBinary( 730 data_shape_, HloOpcode::kAdd, data_init, one_vec)); 731 732 auto xla_while = BuildWhileInstructionWithCustomInit(loop_state_shape_, 733 data_init, &builder); 734 735 // Add an additional binary operation operating on the while and the 736 // interfering add so that neither operation is dead. 737 auto gte = xla_while->parent()->AddInstruction( 738 HloInstruction::CreateGetTupleElement( 739 ShapeUtil::GetSubshape(xla_while->shape(), {1}), xla_while, 1)); 740 auto sub = xla_while->parent()->AddInstruction(HloInstruction::CreateBinary( 741 data_shape_, HloOpcode::kSubtract, add, gte)); 742 auto gte0 = xla_while->parent()->AddInstruction( 743 HloInstruction::CreateGetTupleElement( 744 ShapeUtil::GetSubshape(xla_while->shape(), {0}), xla_while, 0)); 745 auto tuple = xla_while->parent()->AddInstruction( 746 HloInstruction::CreateTuple({gte0, sub})); 747 748 xla_while->parent()->set_root_instruction(tuple); 749 750 return xla_while; 751 } 752 753 HloInstruction* BuildWhileInstructionWithCustomInit( 754 const Shape& loop_state_shape, HloInstruction* data_init, 755 HloComputation::Builder* builder) { 756 const bool nested = 757 ShapeUtil::Equal(loop_state_shape, nested_loop_state_shape_); 758 auto induction_var_init = builder->AddInstruction( 759 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0))); 760 auto condition = module_->AddEmbeddedComputation( 761 BuildConditionComputation(loop_state_shape)); 762 auto body = module_->AddEmbeddedComputation( 763 BuildIndependentBodyComputation(nested)); 764 auto loop_state_init = builder->AddInstruction( 765 HloInstruction::CreateTuple({induction_var_init, data_init})); 766 auto while_hlo = builder->AddInstruction(HloInstruction::CreateWhile( 767 loop_state_shape, condition, body, loop_state_init)); 768 module_->AddEntryComputation(builder->Build()); 769 return while_hlo; 770 } 771 772 std::unique_ptr<HloModule> module_; 773 Shape induction_variable_shape_ = ShapeUtil::MakeShape(S32, {}); 774 Shape data_shape_ = ShapeUtil::MakeShape(F32, {8}); 775 Shape loop_state_shape_ = 776 ShapeUtil::MakeTupleShape({induction_variable_shape_, data_shape_}); 777 Shape nested_tuple_shape_ = 778 ShapeUtil::MakeTupleShape({data_shape_, data_shape_}); 779 Shape nested_loop_state_shape_ = ShapeUtil::MakeTupleShape( 780 {induction_variable_shape_, nested_tuple_shape_}); 781 Shape condition_result_shape_ = ShapeUtil::MakeShape(PRED, {}); 782 }; 783 784 // Tests while body computation with independent tuple elements: 785 // 786 // While.Body({in0, in1}) 787 // out0 = Add(in0, 1) 788 // out1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1}) 789 // Tuple(out0, out1) 790 // 791 // CopyInsertion pass should not generate any copies. 792 // 793 TEST_F(WhileCopyInsertionTest, IndependentTupleElements) { 794 auto condition = module_->AddEmbeddedComputation( 795 BuildConditionComputation(loop_state_shape_)); 796 auto body = 797 module_->AddEmbeddedComputation(BuildIndependentBodyComputation()); 798 auto while_hlo = BuildWhileInstruction(condition, body); 799 800 InsertCopies(module_.get()); 801 802 // Body should have no copies as the adds can be done inplace. 803 EXPECT_EQ(CountCopies(*body), 0); 804 EXPECT_EQ(CountControlEdges(*module_), 0); 805 806 // Both init indices need copies as they are constants. 807 EXPECT_THAT(while_hlo->operand(0), 808 op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant()))); 809 } 810 811 // Tests while body computation with dependent tuple elements: 812 // 813 // While.Body({in0, in1}) 814 // out0 = Add(in0, 1) 815 // out1 = Add(BCast(in0), in1) 816 // Tuple(out0, out1) 817 // 818 // CopyInsertion pass should convert the root instruction to: 819 // 820 // Tuple(Copy(out0), out1) 821 // 822 TEST_F(WhileCopyInsertionTest, DependentTupleElements) { 823 auto condition = module_->AddEmbeddedComputation( 824 BuildConditionComputation(loop_state_shape_)); 825 auto body = module_->AddEmbeddedComputation(BuildDependentBodyComputation()); 826 auto while_hlo = BuildWhileInstruction(condition, body); 827 828 InsertCopies(module_.get()); 829 830 EXPECT_EQ(CountCopies(*body), 1); 831 EXPECT_EQ(CountControlEdges(*body), 0); 832 833 EXPECT_THAT( 834 body->root_instruction(), 835 op::Tuple(op::Add(), op::Add(op::GetTupleElement(), op::Broadcast()))); 836 837 auto add = body->root_instruction()->operand(0); 838 auto bcast = body->root_instruction()->operand(1)->operand(1); 839 ASSERT_EQ(add->opcode(), HloOpcode::kAdd); 840 ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast); 841 842 EXPECT_THAT( 843 while_hlo->while_body()->root_instruction(), 844 op::Tuple(op::Add(op::Copy(), op::Constant()), 845 op::Add(op::GetTupleElement(), op::Broadcast(op::Copy())))); 846 847 // Both init indices need copies as they are constants. 848 EXPECT_THAT(while_hlo->operand(0), 849 op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant()))); 850 } 851 852 // Tests while body computation with read-only tuple element 0: 853 // 854 // PARAMETER 855 // / \ 856 // GTE(0) GTE(1) 857 // | \ | 858 // | BCAST | 859 // | \ | 860 // | ADD 861 // | | 862 // \ / 863 // TUPLE (root) 864 // 865 // CopyInsertion pass should not generate any copies for the while body. 866 TEST_F(WhileCopyInsertionTest, DependentTupleElements_OneReadOnly) { 867 auto condition = module_->AddEmbeddedComputation( 868 BuildConditionComputation(loop_state_shape_)); 869 auto body = module_->AddEmbeddedComputation( 870 BuildDependentBodyOneReadOnlyComputation()); 871 BuildWhileInstruction(condition, body); 872 873 InsertCopies(module_.get()); 874 875 // No copies or control edges should be inserted. The body is legal as is. 876 EXPECT_EQ(CountCopies(*body), 0); 877 EXPECT_EQ(CountControlEdges(*body), 0); 878 } 879 880 // Same as above, but with two while loops, sharing entry parameters. 881 TEST_F(WhileCopyInsertionTest, 882 DependentTupleElements_OneReadOnly_TwoLoops_EntryParams) { 883 auto condition1 = module_->AddEmbeddedComputation( 884 BuildConditionComputation(loop_state_shape_)); 885 auto condition2 = module_->AddEmbeddedComputation( 886 BuildConditionComputation(loop_state_shape_)); 887 auto body1 = module_->AddEmbeddedComputation( 888 BuildDependentBodyOneReadOnlyComputation()); 889 auto body2 = module_->AddEmbeddedComputation( 890 BuildDependentBodyOneReadOnlyComputation()); 891 892 auto builder = HloComputation::Builder(TestName() + ".While"); 893 auto iter_param = builder.AddInstruction( 894 HloInstruction::CreateParameter(0, induction_variable_shape_, "iter")); 895 auto data_param = builder.AddInstruction( 896 HloInstruction::CreateParameter(1, data_shape_, "data")); 897 auto loop_init = builder.AddInstruction( 898 HloInstruction::CreateTuple({iter_param, data_param})); 899 900 auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile( 901 loop_state_shape_, condition1, body1, loop_init)); 902 auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile( 903 loop_state_shape_, condition2, body2, loop_init)); 904 905 // Add a couple elements from each of the while so both whiles are live. 906 auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( 907 ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0)); 908 auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( 909 ShapeUtil::GetSubshape(while_hlo2->shape(), {0}), while_hlo2, 0)); 910 builder.AddInstruction( 911 HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2)); 912 913 auto entry = module_->AddEntryComputation(builder.Build()); 914 915 InsertCopies(module_.get()); 916 917 // Neither body should have any copies or control edges in them. 918 EXPECT_EQ(CountCopies(*body1), 0); 919 EXPECT_EQ(CountCopies(*body2), 0); 920 EXPECT_EQ(CountControlEdges(*body1), 0); 921 EXPECT_EQ(CountControlEdges(*body2), 0); 922 923 // Only two copies should be necessary. Each of the whiles should have 924 // a copy of tuple element 1 (init value is a parameter, and the element is 925 // not non-read-only) so each of the while bodies gets its own buffer to write 926 // element 1 into. 927 EXPECT_EQ(CountCopies(*entry), 2); 928 929 EXPECT_EQ(while_hlo1->operand(0)->operand(1)->opcode(), HloOpcode::kCopy); 930 EXPECT_EQ(while_hlo2->operand(0)->operand(1)->opcode(), HloOpcode::kCopy); 931 932 // The two copies of element 1 should be different. 933 EXPECT_NE(while_hlo1->operand(0)->operand(1), 934 while_hlo2->operand(0)->operand(1)); 935 } 936 937 // Same as above, but with two while loops, sharing non-parameters. 938 TEST_F(WhileCopyInsertionTest, 939 DependentTupleElements_OneReadOnly_TwoLoops_NonParams) { 940 auto condition1 = module_->AddEmbeddedComputation( 941 BuildConditionComputation(loop_state_shape_)); 942 auto condition2 = module_->AddEmbeddedComputation( 943 BuildConditionComputation(loop_state_shape_)); 944 auto body1 = module_->AddEmbeddedComputation( 945 BuildDependentBodyOneReadOnlyComputation()); 946 auto body2 = module_->AddEmbeddedComputation( 947 BuildDependentBodyOneReadOnlyComputation()); 948 949 auto builder = HloComputation::Builder(TestName() + ".While"); 950 auto iter_param = builder.AddInstruction( 951 HloInstruction::CreateParameter(0, induction_variable_shape_, "iter")); 952 auto data_param = builder.AddInstruction( 953 HloInstruction::CreateParameter(1, data_shape_, "data")); 954 // Add dummy ops to ensure loop_init elements aren't entry parameters. 955 auto iter_value = builder.AddInstruction(HloInstruction::CreateUnary( 956 iter_param->shape(), HloOpcode::kExp, iter_param)); 957 auto data_value = builder.AddInstruction(HloInstruction::CreateUnary( 958 data_param->shape(), HloOpcode::kExp, data_param)); 959 auto loop_init = builder.AddInstruction( 960 HloInstruction::CreateTuple({iter_value, data_value})); 961 962 auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile( 963 loop_state_shape_, condition1, body1, loop_init)); 964 auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile( 965 loop_state_shape_, condition2, body2, loop_init)); 966 967 // Add a couple elements from each of the while so both whiles are not dead. 968 auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( 969 ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0)); 970 auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( 971 ShapeUtil::GetSubshape(while_hlo2->shape(), {0}), while_hlo2, 0)); 972 builder.AddInstruction( 973 HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2)); 974 auto entry = module_->AddEntryComputation(builder.Build()); 975 976 InsertCopies(module_.get()); 977 978 // Ideally only one copy should be necessary. One of the whiles should 979 // have a copy of tuple element 1 (the non-read-only element) so each of the 980 // while bodies gets its own buffer to write element 1 into. However, the 981 // analysis isn't perfect and adds an additional copy of element 0. 982 EXPECT_EQ(CountCopies(*entry), 2); 983 984 EXPECT_THAT(while_hlo1->operand(0), 985 op::Tuple(op::Exp(), op::Copy(op::Exp()))); 986 EXPECT_THAT(while_hlo2->operand(0), 987 op::Tuple(op::Exp(), op::Copy(op::Exp()))); 988 } 989 990 // Tests while body computation with nested tuple elements: 991 // 992 // | 993 // GTE(loop_state, 1) 994 // / \ 995 // GTE(GTE(loop_state, 1), 0) GTE(GTE(loop_state, 1), 1) 996 // | | 997 // Add Reverse 998 // | | 999 // 1000 // CopyInsertion pass will conceptually generate the following, but with the 1001 // actual GTE and Tuple instructions optimized away: 1002 // 1003 // Tuple // old root 1004 // / \ 1005 // / \ 1006 // GTE(0) GTE(1) 1007 // | / \ 1008 // | / \ 1009 // | GTE(0) GTE(1) 1010 // | | | 1011 // | | Copy 1012 // | | | 1013 // \ | / 1014 // \ Tuple // "inner" tuple. 1015 // \ / 1016 // \ / 1017 // Tuple // new root 1018 // 1019 TEST_F(WhileCopyInsertionTest, NestedTupleElements) { 1020 auto condition = module_->AddEmbeddedComputation( 1021 BuildConditionComputation(nested_loop_state_shape_)); 1022 auto body = module_->AddEmbeddedComputation(BuildNestedBodyComputation()); 1023 BuildWhileInstruction(condition, body, true); 1024 1025 // HloInstruction* old_root = body->root_instruction(); 1026 InsertCopies(module_.get()); 1027 1028 // The only copy necessary is for the kReverse as it cannot be done 1029 // in-place (instruction can share buffer with operand). The other elements of 1030 // the loop state are kAdd instructions which can be done in-place. 1031 EXPECT_EQ(CountCopies(*body), 1); 1032 1033 // Each element of the init needs a copy as all are constants. 1034 EXPECT_EQ(CountCopies(*module_), 4); 1035 1036 // Either the kReverse itself must be copied or the operand of the kReverse 1037 // must be copied. 1038 if (body->root_instruction()->operand(1)->operand(1)->opcode() == 1039 HloOpcode::kCopy) { 1040 EXPECT_THAT( 1041 body->root_instruction(), 1042 op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Reverse())))); 1043 } else { 1044 EXPECT_THAT( 1045 body->root_instruction(), 1046 op::Tuple(op::Add(), op::Tuple(op::Add(), op::Reverse(op::Copy())))); 1047 } 1048 } 1049 1050 // Tests while init instruction which points-to a constant. 1051 // 1052 // init = Tuple(Constant(S32, {}), Constant(F32, {8})) 1053 // 1054 // CopyInsertion pass should add copies for both constants. 1055 // 1056 TEST_F(WhileCopyInsertionTest, InitPointsToConstant) { 1057 auto while_hlo = BuildWhileInstruction_InitPointsToConstant(); 1058 1059 InsertCopies(module_.get()); 1060 EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0); 1061 EXPECT_EQ(CountCopies(*module_), 2); 1062 1063 EXPECT_THAT(while_hlo->operand(0), 1064 op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant()))); 1065 } 1066 1067 // Tests while init instruction which points-to a parameter. 1068 // 1069 // init = Tuple(Constant(S32, {}), Parameter(F32, {8})) 1070 // 1071 // CopyInsertion pass should add copies for both the constant and parameter. 1072 // 1073 TEST_F(WhileCopyInsertionTest, InitPointsToParameter) { 1074 auto while_hlo = BuildWhileInstruction_InitPointsToParameter(); 1075 1076 InsertCopies(module_.get()); 1077 EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0); 1078 EXPECT_EQ(CountCopies(*module_), 2); 1079 1080 EXPECT_THAT(while_hlo->operand(0), 1081 op::Tuple(op::Copy(op::Constant()), op::Copy(op::Parameter()))); 1082 } 1083 1084 // Tests while init instruction which has an ambiguous points-to set. 1085 // 1086 // select = Select(pred, tuple1, tuple2) 1087 // init = Tuple(Constant(S32, {}), Parameter(F32, {8})) 1088 // 1089 // CopyInsertion pass will conceptually generate the following, but with some of 1090 // the actual GTE and Tuple instructions optimized away: 1091 // 1092 // Tuple // old init 1093 // / \ 1094 // / \ 1095 // GTE(0) GTE(1) 1096 // | / \ 1097 // | / \ 1098 // | GTE(0) GTE(1) 1099 // | | | 1100 // Copy Copy Copy 1101 // | | | 1102 // \ | / 1103 // \ Tuple 1104 // \ / 1105 // \ / 1106 // Tuple // new init 1107 // 1108 TEST_F(WhileCopyInsertionTest, InitPointsToAmbiguous) { 1109 auto while_hlo = BuildWhileInstruction_InitPointsToAmbiguous(); 1110 1111 InsertCopies(module_.get()); 1112 EXPECT_EQ(CountCopies(*module_), 4); 1113 // The entry computation requires three copies to resolve the ambiguity of two 1114 // init elements and the constant passed in as one of the init elements. 1115 EXPECT_EQ(CountCopies(*module_->entry_computation()), 3); 1116 EXPECT_THAT(while_hlo->operand(0), 1117 op::Tuple(op::Copy(op::Constant()), 1118 op::Tuple(op::Copy(op::GetTupleElement()), 1119 op::Copy(op::GetTupleElement())))); 1120 1121 // The body requires one copy because the buffer set is not distinct: the 1122 // result of one of the adds is written into two elements of the output of the 1123 // loop body. Either element might be copied. 1124 EXPECT_EQ(CountCopies(*while_hlo->while_body()), 1); 1125 if (while_hlo->while_body() 1126 ->root_instruction() 1127 ->operand(1) 1128 ->operand(0) 1129 ->opcode() == HloOpcode::kCopy) { 1130 EXPECT_THAT( 1131 while_hlo->while_body()->root_instruction(), 1132 op::Tuple(op::Add(), op::Tuple(op::Copy(op::Add()), op::Add()))); 1133 } else { 1134 EXPECT_THAT( 1135 while_hlo->while_body()->root_instruction(), 1136 op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Add())))); 1137 } 1138 } 1139 1140 // Tests while init instruction which has a non-distinct points-to set. 1141 // 1142 // init = Tuple(Constant(S32, {}), Tuple({vec_one, vec_one})) 1143 // 1144 // CopyInsertion pass will conceptually generate the following, but with some of 1145 // the actual GTE and Tuple instructions optimized away: 1146 // 1147 // Tuple // old init 1148 // / \ 1149 // / \ 1150 // GTE(0) GTE(1) 1151 // | / \ 1152 // | / \ 1153 // | GTE(0) GTE(1) 1154 // | | | 1155 // Copy Copy Copy 1156 // | | | 1157 // \ | / 1158 // \ Tuple 1159 // \ / 1160 // \ / 1161 // Tuple // new init 1162 // 1163 TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) { 1164 auto while_hlo = BuildWhileInstruction_InitPointsToNonDistinct(); 1165 1166 InsertCopies(module_.get()); 1167 1168 // The entry computation requires two copies to resolve the non-disinctness of 1169 // two init elements and the constant passed in as one of the init 1170 // elements. Either element can be copied for the distinctness issue. 1171 EXPECT_EQ(CountCopies(*module_->entry_computation()), 2); 1172 if (while_hlo->operand(0)->operand(1)->operand(0)->opcode() == 1173 HloOpcode::kCopy) { 1174 EXPECT_THAT( 1175 while_hlo->operand(0), 1176 op::Tuple(op::Copy(op::Constant()), 1177 op::Tuple(op::Copy(op::Broadcast()), op::Broadcast()))); 1178 } else { 1179 EXPECT_THAT( 1180 while_hlo->operand(0), 1181 op::Tuple(op::Copy(op::Constant()), 1182 op::Tuple(op::Broadcast(), op::Copy(op::Broadcast())))); 1183 } 1184 1185 // The body requires one copy because the buffer set is not distinct: the 1186 // result of one of the adds is written into two elements of the output of the 1187 // loop body. Either element might be copied. 1188 EXPECT_EQ(CountCopies(*while_hlo->while_body()), 1); 1189 if (while_hlo->while_body() 1190 ->root_instruction() 1191 ->operand(1) 1192 ->operand(0) 1193 ->opcode() == HloOpcode::kCopy) { 1194 EXPECT_THAT( 1195 while_hlo->while_body()->root_instruction(), 1196 op::Tuple(op::Add(), op::Tuple(op::Copy(op::Add()), op::Add()))); 1197 } else { 1198 EXPECT_THAT( 1199 while_hlo->while_body()->root_instruction(), 1200 op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Add())))); 1201 } 1202 } 1203 1204 // Tests while init instruction buffer which interferes with while result 1205 // buffer. 1206 // 1207 // init_data = Broadcast(...) 1208 // add_unrelated = Add(init_data) // takes a reference to cause interference 1209 // init = Tuple(Constant(S32, {}), init_data)) 1210 // 1211 // CopyInsertion pass should copy both operands. 1212 // 1213 TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) { 1214 auto while_hlo = BuildWhileInstruction_InitPointsToInterfering(); 1215 1216 InsertCopies(module_.get()); 1217 EXPECT_EQ(CountCopies(*module_), 2); 1218 EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0); 1219 1220 EXPECT_THAT(while_hlo->operand(0), 1221 op::Tuple(op::Copy(op::Constant()), op::Copy(op::Broadcast()))); 1222 } 1223 1224 // Tests while init instruction buffer which has a non-distinct points-to set: 1225 // 1226 // init = Tuple(Parameter(S32, {}), Parameter(F32, {8}, 1227 // Parameter(F32, {8}))) 1228 // 1229 // where the second and third parameters are identical *and* the tuple shared 1230 // by another while instruction. 1231 // 1232 // Verifies that the resulting point-to set is distinct in the resulting Tuple 1233 // (non-identical Copys). In other words, verifies that copy sharing does not 1234 // insert identical copies to the resulting tuple. 1235 TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) { 1236 // Loop body that outputs tuple comprises two elements dependent on the init 1237 // tuple. 1238 const Shape& loop_state_shape = ShapeUtil::MakeTupleShape( 1239 {induction_variable_shape_, data_shape_, data_shape_}); 1240 1241 auto condition1 = module_->AddEmbeddedComputation( 1242 BuildConditionComputation(loop_state_shape)); 1243 auto condition2 = module_->AddEmbeddedComputation( 1244 BuildConditionComputation(loop_state_shape)); 1245 auto body1 = 1246 module_->AddEmbeddedComputation(BuildDependentBodyComputation2()); 1247 auto body2 = 1248 module_->AddEmbeddedComputation(BuildDependentBodyComputation2()); 1249 1250 auto builder = HloComputation::Builder(TestName() + ".While"); 1251 1252 auto iter_param = builder.AddInstruction( 1253 HloInstruction::CreateParameter(0, induction_variable_shape_, "iter")); 1254 auto data_param = builder.AddInstruction( 1255 HloInstruction::CreateParameter(1, data_shape_, "data")); 1256 1257 // Loop init tuple contains two identical parameter buffers. 1258 auto loop_init = builder.AddInstruction( 1259 HloInstruction::CreateTuple({iter_param, data_param, data_param})); 1260 1261 // Two while loops shares the same loop init tuple. 1262 auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile( 1263 loop_state_shape, condition1, body1, loop_init)); 1264 auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile( 1265 loop_state_shape, condition2, body2, loop_init)); 1266 1267 // Add add instruction so neither while is dead. 1268 auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( 1269 ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0)); 1270 auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( 1271 ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo2, 0)); 1272 builder.AddInstruction( 1273 HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2)); 1274 1275 module_->AddEntryComputation(builder.Build()); 1276 1277 InsertCopies(module_.get()); 1278 1279 // None of the bodies should have copies or control flow edges. 1280 EXPECT_EQ(CountCopies(*body1), 0); 1281 EXPECT_EQ(CountCopies(*body2), 0); 1282 1283 // The loop bodies pass through elements 1 and 2 in the init tuple, so ideally 1284 // these should not need to be copied before either while. However, copy 1285 // insertion is not able to reason about the transparency of elements through 1286 // while bodies in all circumstances so extra copies are added (b/xxx). 1287 EXPECT_EQ(CountCopies(*module_->entry_computation()), 2); 1288 1289 EXPECT_THAT(while_hlo1->operand(0), 1290 op::Tuple(op::Copy(), op::Parameter(), op::Parameter())); 1291 EXPECT_THAT(while_hlo2->operand(0), 1292 op::Tuple(op::Copy(), op::Parameter(), op::Parameter())); 1293 } 1294 1295 TEST_F(CopyInsertionTest, SwizzlingWhile) { 1296 // Test a while instruction with a body which permutes its tuple parameter 1297 // elements. 1298 auto module = CreateNewVerifiedModule(); 1299 const Shape loop_state_shape = 1300 ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); 1301 1302 // Body simply interchanges the two tuple elements in the loop state. 1303 auto body_builder = HloComputation::Builder("body"); 1304 auto body_param = body_builder.AddInstruction( 1305 HloInstruction::CreateParameter(0, loop_state_shape, "param")); 1306 auto body_element_0 = body_builder.AddInstruction( 1307 HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); 1308 auto body_element_1 = body_builder.AddInstruction( 1309 HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); 1310 body_builder.AddInstruction( 1311 HloInstruction::CreateTuple({body_element_1, body_element_0})); 1312 HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); 1313 1314 auto cond_builder = HloComputation::Builder("condition"); 1315 cond_builder.AddInstruction( 1316 HloInstruction::CreateParameter(0, loop_state_shape, "param")); 1317 auto cond_constant = cond_builder.AddInstruction( 1318 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false))); 1319 cond_builder.AddInstruction(HloInstruction::CreateUnary( 1320 cond_constant->shape(), HloOpcode::kNot, cond_constant)); 1321 HloComputation* condition = 1322 module->AddEmbeddedComputation(cond_builder.Build()); 1323 1324 auto builder = HloComputation::Builder(TestName()); 1325 auto constant1 = builder.AddInstruction( 1326 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); 1327 auto constant2 = builder.AddInstruction( 1328 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0))); 1329 auto tuple = builder.AddInstruction( 1330 HloInstruction::CreateTuple({constant1, constant2})); 1331 auto xla_while = builder.AddInstruction( 1332 HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple)); 1333 module->AddEntryComputation(builder.Build()); 1334 1335 InsertCopies(module.get()); 1336 1337 EXPECT_EQ(CountCopies(*module), 6); 1338 1339 // The loop state elements should be copied at the parameter and at the root 1340 // with a control edge in between (see DeepCopyAndAddControlEdges). This is 1341 // technically one more copy than is strictly necessary, but in order to have 1342 // only three copies the copies of different loop state elements must be 1343 // ordered with a control edge. 1344 EXPECT_EQ(CountCopies(*body), 4); 1345 EXPECT_EQ(CountControlEdges(*body), 2); 1346 1347 EXPECT_THAT(body->root_instruction(), 1348 op::Tuple(op::Copy(op::Copy()), op::Copy(op::Copy()))); 1349 1350 EXPECT_EQ(CountCopies(*module->entry_computation()), 2); 1351 EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy())); 1352 } 1353 1354 TEST_F(CopyInsertionTest, CrossingParameters) { 1355 // Test a case where two parameters' dataflow cross with each other while 1356 // input and output are aliased with same index: 1357 // 1358 // (p0 , p1) 1359 // | \ /| 1360 // | \ / | 1361 // alias X alias 1362 // | / \ | 1363 // | / \| 1364 // (p1 , p0) 1365 auto module = CreateNewVerifiedModule(); 1366 const Shape tuple_shape = 1367 ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); 1368 1369 auto builder = HloComputation::Builder(TestName()); 1370 auto param = builder.AddInstruction( 1371 HloInstruction::CreateParameter(0, tuple_shape, "0")); 1372 auto gte0 = builder.AddInstruction( 1373 HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); 1374 auto gte1 = builder.AddInstruction( 1375 HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); 1376 builder.AddInstruction(HloInstruction::CreateTuple({gte1, gte0})); 1377 module->AddEntryComputation(builder.Build()); 1378 ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( 1379 /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, 1380 /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); 1381 ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( 1382 /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}, 1383 /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); 1384 InsertCopies(module.get()); 1385 1386 EXPECT_EQ(CountCopies(*module), 4); 1387 } 1388 1389 TEST_F(CopyInsertionTest, ParametersAliasing) { 1390 // Test a case where two parameters' dataflow don't interfere with each other 1391 // while aliased. 1392 // 1393 // (p0 , p1) 1394 // | | 1395 // | | 1396 // alias alias 1397 // | | 1398 // | | 1399 // (p0 , p1) 1400 auto module = CreateNewVerifiedModule(); 1401 const Shape tuple_shape = 1402 ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); 1403 1404 auto builder = HloComputation::Builder(TestName()); 1405 auto param = builder.AddInstruction( 1406 HloInstruction::CreateParameter(0, tuple_shape, "p0")); 1407 auto gte0 = builder.AddInstruction( 1408 HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); 1409 auto gte1 = builder.AddInstruction( 1410 HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); 1411 builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); 1412 module->AddEntryComputation(builder.Build()); 1413 ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( 1414 /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, 1415 /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); 1416 ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( 1417 /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}, 1418 /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); 1419 InsertCopies(module.get()); 1420 1421 EXPECT_EQ(CountCopies(*module), 0); 1422 } 1423 1424 TEST_F(CopyInsertionTest, ParameterWithNoAliasing) { 1425 // Test a case where no parameter is aliased with result. In this case, copy 1426 // should be added 1427 // 1428 // (p0 , p1) 1429 // | | 1430 // | | 1431 // | | 1432 // | | 1433 // | | 1434 // (p0 , p1) 1435 auto module = CreateNewVerifiedModule(); 1436 const Shape tuple_shape = 1437 ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); 1438 1439 auto builder = HloComputation::Builder(TestName()); 1440 auto param = builder.AddInstruction( 1441 HloInstruction::CreateParameter(0, tuple_shape, "p0")); 1442 auto gte0 = builder.AddInstruction( 1443 HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); 1444 auto gte1 = builder.AddInstruction( 1445 HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); 1446 builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); 1447 module->AddEntryComputation(builder.Build()); 1448 InsertCopies(module.get()); 1449 1450 EXPECT_THAT(module->entry_computation()->root_instruction(), 1451 op::Tuple(op::Copy(op::GetTupleElement(param, 0)), 1452 op::Copy(op::GetTupleElement(param, 1)))); 1453 1454 EXPECT_EQ(CountCopies(*module), 2); 1455 } 1456 1457 TEST_F(CopyInsertionTest, ParameterWithPartialAliasing) { 1458 // Test a case where one parameter is aliased with result while another one 1459 // isn't. 1460 // 1461 // (p0 , p1) 1462 // | | 1463 // | | 1464 // alias | 1465 // | | 1466 // | | 1467 // (p0 , p1) 1468 auto module = CreateNewVerifiedModule(); 1469 const Shape tuple_shape = 1470 ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); 1471 1472 auto builder = HloComputation::Builder(TestName()); 1473 auto param = builder.AddInstruction( 1474 HloInstruction::CreateParameter(0, tuple_shape, "p0")); 1475 auto gte0 = builder.AddInstruction( 1476 HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); 1477 auto gte1 = builder.AddInstruction( 1478 HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); 1479 builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); 1480 module->AddEntryComputation(builder.Build()); 1481 ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( 1482 /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, 1483 /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); 1484 InsertCopies(module.get()); 1485 1486 EXPECT_THAT(module->entry_computation()->root_instruction(), 1487 op::Tuple(op::GetTupleElement(param, 0), 1488 op::Copy(op::GetTupleElement(param, 1)))); 1489 1490 EXPECT_EQ(CountCopies(*module), 1); 1491 } 1492 1493 TEST_F(CopyInsertionTest, ParameterAndParallelOpsWithPartialAliasing) { 1494 // Test a case where one parameter is aliased with result while another one 1495 // isn't. 1496 // 1497 // +-- (p0 , p1) 1498 // | | | 1499 // | | | 1500 // alias Negate Negate 1501 // | | | 1502 // | | | 1503 // +-- (p0 , p1) 1504 auto module = CreateNewVerifiedModule(); 1505 const Shape tuple_shape = 1506 ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); 1507 1508 auto builder = HloComputation::Builder(TestName()); 1509 auto param = builder.AddInstruction( 1510 HloInstruction::CreateParameter(0, tuple_shape, "p0")); 1511 auto gte0 = builder.AddInstruction( 1512 HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); 1513 auto gte1 = builder.AddInstruction( 1514 HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); 1515 1516 auto negate0 = builder.AddInstruction( 1517 HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0)); 1518 1519 auto negate1 = builder.AddInstruction( 1520 HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1)); 1521 builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1})); 1522 module->AddEntryComputation(builder.Build()); 1523 ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( 1524 /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, 1525 /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); 1526 InsertCopies(module.get()); 1527 1528 EXPECT_EQ(CountCopies(*module), 0); 1529 } 1530 1531 TEST_F(CopyInsertionTest, ParameterAndOpsWithPartialAliasing) { 1532 // Test a case where one parameter is aliased with result while another one 1533 // isn't. 1534 // 1535 // +-- (p0 , p1) 1536 // | | | 1537 // | | | 1538 // alias Negate Negate 1539 // | | | 1540 // | Add----+ 1541 // | | | 1542 // +-- (p0 , p1) 1543 auto module = CreateNewVerifiedModule(); 1544 const Shape tuple_shape = 1545 ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); 1546 1547 auto builder = HloComputation::Builder(TestName()); 1548 auto param = builder.AddInstruction( 1549 HloInstruction::CreateParameter(0, tuple_shape, "p0")); 1550 auto gte0 = builder.AddInstruction( 1551 HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); 1552 auto gte1 = builder.AddInstruction( 1553 HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); 1554 1555 auto negate0 = builder.AddInstruction( 1556 HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0)); 1557 1558 auto negate1 = builder.AddInstruction( 1559 HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1)); 1560 1561 auto add = builder.AddInstruction(HloInstruction::CreateBinary( 1562 scalar_shape_, HloOpcode::kAdd, negate0, negate1)); 1563 builder.AddInstruction(HloInstruction::CreateTuple({add, negate1})); 1564 module->AddEntryComputation(builder.Build()); 1565 ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( 1566 /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, 1567 /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); 1568 InsertCopies(module.get()); 1569 1570 EXPECT_EQ(CountCopies(*module), 0); 1571 } 1572 1573 TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) { 1574 // Test a while instruction with a body which permutes its tuple parameter 1575 // elements and applies one operation to one of the elements. The addition of 1576 // the operation (instruction) on the element makes the live range of the 1577 // respective input and output elements different than if the instruction were 1578 // not there (as in the SwizzlingWhile test above). 1579 auto module = CreateNewVerifiedModule(); 1580 const Shape loop_state_shape = 1581 ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); 1582 1583 // Body interchanges the two tuple elements in the loop state and negates one 1584 // of them. 1585 auto body_builder = HloComputation::Builder("body"); 1586 auto body_param = body_builder.AddInstruction( 1587 HloInstruction::CreateParameter(0, loop_state_shape, "param")); 1588 auto body_element_0 = body_builder.AddInstruction( 1589 HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); 1590 auto body_element_1 = body_builder.AddInstruction( 1591 HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); 1592 auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary( 1593 scalar_shape_, HloOpcode::kNegate, body_element_1)); 1594 body_builder.AddInstruction( 1595 HloInstruction::CreateTuple({negate, body_element_0})); 1596 HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); 1597 1598 auto cond_builder = HloComputation::Builder("condition"); 1599 cond_builder.AddInstruction( 1600 HloInstruction::CreateParameter(0, loop_state_shape, "param")); 1601 auto cond_constant = cond_builder.AddInstruction( 1602 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false))); 1603 cond_builder.AddInstruction(HloInstruction::CreateUnary( 1604 cond_constant->shape(), HloOpcode::kNot, cond_constant)); 1605 HloComputation* condition = 1606 module->AddEmbeddedComputation(cond_builder.Build()); 1607 1608 auto builder = HloComputation::Builder(TestName()); 1609 auto constant1 = builder.AddInstruction( 1610 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); 1611 auto constant2 = builder.AddInstruction( 1612 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0))); 1613 auto tuple = builder.AddInstruction( 1614 HloInstruction::CreateTuple({constant1, constant2})); 1615 auto xla_while = builder.AddInstruction( 1616 HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple)); 1617 module->AddEntryComputation(builder.Build()); 1618 1619 InsertCopies(module.get()); 1620 1621 EXPECT_EQ(CountCopies(*module), 6); 1622 1623 // The loop state elements should be copied at the parameter and at the root 1624 // with a control edge in between (see DeepCopyAndAddControlEdges). 1625 EXPECT_EQ(CountCopies(*body), 4); 1626 EXPECT_EQ(CountControlEdges(*body), 2); 1627 1628 EXPECT_THAT( 1629 body->root_instruction(), 1630 op::Tuple(op::Copy(op::Negate(op::Copy())), op::Copy(op::Copy()))); 1631 1632 EXPECT_EQ(CountCopies(*module->entry_computation()), 2); 1633 EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy())); 1634 } 1635 1636 TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) { 1637 // Test a while instruction with a body which permutes it's tuple parameter 1638 // elements similar to SwizzlinWhile above. However, in this test the input to 1639 // the while body is a single constant (both loop state elements are the same 1640 // constant). This means no copies are necessary because both loop state 1641 // elements are the same so interchanging them is a no-op. 1642 auto module = CreateNewVerifiedModule(); 1643 const Shape loop_state_shape = 1644 ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); 1645 1646 // Body simply interchanges the two tuple elements in the loop state. 1647 auto body_builder = HloComputation::Builder("body"); 1648 auto body_param = body_builder.AddInstruction( 1649 HloInstruction::CreateParameter(0, loop_state_shape, "param")); 1650 auto body_element_0 = body_builder.AddInstruction( 1651 HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); 1652 auto body_element_1 = body_builder.AddInstruction( 1653 HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); 1654 body_builder.AddInstruction( 1655 HloInstruction::CreateTuple({body_element_1, body_element_0})); 1656 HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); 1657 1658 auto cond_builder = HloComputation::Builder("condition"); 1659 cond_builder.AddInstruction( 1660 HloInstruction::CreateParameter(0, loop_state_shape, "param")); 1661 auto cond_constant = cond_builder.AddInstruction( 1662 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false))); 1663 cond_builder.AddInstruction(HloInstruction::CreateUnary( 1664 cond_constant->shape(), HloOpcode::kNot, cond_constant)); 1665 HloComputation* condition = 1666 module->AddEmbeddedComputation(cond_builder.Build()); 1667 1668 auto builder = HloComputation::Builder(TestName()); 1669 auto constant = builder.AddInstruction( 1670 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); 1671 auto tuple = 1672 builder.AddInstruction(HloInstruction::CreateTuple({constant, constant})); 1673 builder.AddInstruction( 1674 HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple)); 1675 module->AddEntryComputation(builder.Build()); 1676 1677 InsertCopies(module.get()); 1678 1679 EXPECT_EQ(CountCopies(*module), 2); 1680 EXPECT_EQ(CountCopies(*body), 0); 1681 1682 EXPECT_EQ(CountCopies(*module->entry_computation()), 2); 1683 EXPECT_THAT(module->entry_computation()->root_instruction(), 1684 op::Tuple(op::Copy(), op::Copy())); 1685 } 1686 1687 TEST_F(CopyInsertionTest, SequentialWhiles) { 1688 // Construct a computation with a series of sequential while instructions 1689 // containing four loop state elements: 1690 // 1691 // element 0 is passed to each while directly from an entry parameter. 1692 // 1693 // element 1 is passed transparently in series through all the while bodies. 1694 // 1695 // element 2 is negated in each while body. (in-place possible) 1696 // 1697 // element 3 is reversed in each while body. (in-place not possible) 1698 // 1699 const Shape element_shape = ShapeUtil::MakeShape(F32, {42}); 1700 const Shape loop_state_shape = ShapeUtil::MakeTupleShape( 1701 {element_shape, element_shape, element_shape, element_shape}); 1702 1703 auto module = CreateNewVerifiedModule(); 1704 auto builder = HloComputation::Builder(TestName()); 1705 auto param_0 = builder.AddInstruction( 1706 HloInstruction::CreateParameter(0, element_shape, "param_0")); 1707 auto param_1 = builder.AddInstruction( 1708 HloInstruction::CreateParameter(1, element_shape, "param_1")); 1709 auto param_2 = builder.AddInstruction( 1710 HloInstruction::CreateParameter(2, element_shape, "param_2")); 1711 auto param_3 = builder.AddInstruction( 1712 HloInstruction::CreateParameter(3, element_shape, "param_3")); 1713 1714 // The number of sequential kWhile instructions. 1715 const int kNumWhiles = 3; 1716 1717 HloInstruction* prev_element_1 = param_1; 1718 HloInstruction* prev_element_2 = param_2; 1719 HloInstruction* prev_element_3 = param_3; 1720 1721 // Vector containing all of the while instructions. 1722 std::vector<const HloInstruction*> whiles; 1723 for (int i = 0; i < kNumWhiles; ++i) { 1724 auto body_builder = HloComputation::Builder("body"); 1725 auto body_param = body_builder.AddInstruction( 1726 HloInstruction::CreateParameter(0, loop_state_shape, "param")); 1727 auto body_element_0 = body_builder.AddInstruction( 1728 HloInstruction::CreateGetTupleElement(element_shape, body_param, 0)); 1729 auto body_element_1 = body_builder.AddInstruction( 1730 HloInstruction::CreateGetTupleElement(element_shape, body_param, 1)); 1731 auto body_element_2 = body_builder.AddInstruction( 1732 HloInstruction::CreateGetTupleElement(element_shape, body_param, 2)); 1733 auto body_element_3 = body_builder.AddInstruction( 1734 HloInstruction::CreateGetTupleElement(element_shape, body_param, 3)); 1735 auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary( 1736 element_shape, HloOpcode::kNegate, body_element_2)); 1737 auto reverse = body_builder.AddInstruction( 1738 HloInstruction::CreateReverse(element_shape, body_element_3, {0})); 1739 body_builder.AddInstruction(HloInstruction::CreateTuple( 1740 {body_element_0, body_element_1, negate, reverse})); 1741 HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); 1742 1743 auto cond_builder = HloComputation::Builder("condition"); 1744 cond_builder.AddInstruction( 1745 HloInstruction::CreateParameter(0, loop_state_shape, "param")); 1746 auto cond_constant = cond_builder.AddInstruction( 1747 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false))); 1748 cond_builder.AddInstruction(HloInstruction::CreateUnary( 1749 cond_constant->shape(), HloOpcode::kNot, cond_constant)); 1750 HloComputation* condition = 1751 module->AddEmbeddedComputation(cond_builder.Build()); 1752 1753 auto while_init = builder.AddInstruction(HloInstruction::CreateTuple( 1754 {param_0, prev_element_1, prev_element_2, prev_element_3})); 1755 1756 auto xla_while = builder.AddInstruction(HloInstruction::CreateWhile( 1757 loop_state_shape, condition, body, while_init)); 1758 whiles.push_back(xla_while); 1759 if (i != kNumWhiles - 1) { 1760 prev_element_1 = builder.AddInstruction( 1761 HloInstruction::CreateGetTupleElement(element_shape, xla_while, 1)); 1762 prev_element_2 = builder.AddInstruction( 1763 HloInstruction::CreateGetTupleElement(element_shape, xla_while, 2)); 1764 prev_element_3 = builder.AddInstruction( 1765 HloInstruction::CreateGetTupleElement(element_shape, xla_while, 3)); 1766 } 1767 } 1768 1769 module->AddEntryComputation(builder.Build()); 1770 1771 InsertCopies(module.get()); 1772 1773 // Each while body has one copy. And each loop state element is copied once in 1774 // the entry computation. 1775 EXPECT_EQ(CountCopies(*module), 4 + kNumWhiles); 1776 1777 // Each while body should have exactly one copy for element three which is an 1778 // op (kReverse) which cannot be done in place. 1779 for (const HloInstruction* xla_while : whiles) { 1780 EXPECT_EQ(CountCopies(*xla_while->while_body()), 1); 1781 } 1782 1783 EXPECT_THAT(whiles[0]->operand(0), op::Tuple(op::Parameter(), op::Parameter(), 1784 op::Copy(), op::Copy())); 1785 EXPECT_THAT(module->entry_computation()->root_instruction(), 1786 op::Tuple(op::Copy(), op::Copy(), op::GetTupleElement(), 1787 op::GetTupleElement())); 1788 } 1789 1790 TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) { 1791 // Test a while body and condition which are each simply a constant (root of 1792 // computation is a constant). The body constant should be copied. 1793 auto module = CreateNewVerifiedModule(); 1794 auto builder = HloComputation::Builder(TestName()); 1795 auto param_0 = builder.AddInstruction( 1796 HloInstruction::CreateParameter(0, scalar_shape_, "param_0")); 1797 1798 auto body_builder = HloComputation::Builder("body"); 1799 body_builder.AddInstruction( 1800 HloInstruction::CreateParameter(0, scalar_shape_, "param")); 1801 body_builder.AddInstruction( 1802 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0))); 1803 HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); 1804 1805 auto cond_builder = HloComputation::Builder("condition"); 1806 cond_builder.AddInstruction( 1807 HloInstruction::CreateParameter(0, scalar_shape_, "param")); 1808 cond_builder.AddInstruction( 1809 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false))); 1810 HloComputation* condition = 1811 module->AddEmbeddedComputation(cond_builder.Build()); 1812 1813 auto xla_while = builder.AddInstruction( 1814 HloInstruction::CreateWhile(scalar_shape_, condition, body, param_0)); 1815 1816 module->AddEntryComputation(builder.Build()); 1817 1818 InsertCopies(module.get()); 1819 1820 EXPECT_EQ(CountCopies(*module), 2); 1821 1822 EXPECT_THAT(xla_while->operand(0), op::Copy(op::Parameter())); 1823 EXPECT_THAT(body->root_instruction(), op::Copy(op::Constant())); 1824 EXPECT_THAT(condition->root_instruction(), op::Constant()); 1825 } 1826 1827 TEST_F(CopyInsertionTest, TokensShouldNotBeCopied) { 1828 string module_string = R"( 1829 HloModule TokensShouldNotBeCopied 1830 1831 %Body (param.1: (s32[], token[])) -> (s32[], token[]) { 1832 %param.1 = (s32[], token[]) parameter(0) 1833 %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 1834 %constant.1 = s32[] constant(1) 1835 %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) 1836 %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 1837 %after-all = token[] after-all(token[] %get-tuple-element.2) 1838 ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) 1839 } 1840 1841 %Cond (param: (s32[], token[])) -> pred[] { 1842 %param = (s32[], token[]) parameter(0) 1843 %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 1844 %constant = s32[] constant(42) 1845 ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT 1846 } 1847 1848 ENTRY %TokensShouldNotBeCopied () -> s32[] { 1849 %one = s32[] constant(1) 1850 %negative_one = s32[] negate(%one) 1851 %init_token = token[] after-all() 1852 %init_tuple = (s32[], token[]) tuple(s32[] %negative_one, token[] %init_token) 1853 %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body 1854 ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 1855 } 1856 )"; 1857 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, 1858 ParseAndReturnVerifiedModule(module_string)); 1859 InsertCopies(module.get()); 1860 1861 // There should be no copies added because tokens should not be copied. 1862 EXPECT_EQ(CountCopies(*module), 0); 1863 } 1864 1865 std::unique_ptr<HloComputation> MakeTrivialCondition(const Shape& shape) { 1866 auto builder = HloComputation::Builder("trivial_condition"); 1867 builder.AddInstruction( 1868 HloInstruction::CreateParameter(0, shape, "loop_state")); 1869 auto constant = builder.AddInstruction( 1870 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false))); 1871 builder.AddInstruction(HloInstruction::CreateUnary( 1872 constant->shape(), HloOpcode::kNot, constant)); 1873 return builder.Build(); 1874 } 1875 1876 std::unique_ptr<HloComputation> MakeBenchmarkWhileBody() { 1877 auto builder = HloComputation::Builder("benchmark_loop_body"); 1878 const Shape element_shape = ShapeUtil::MakeShape(F32, {42}); 1879 const Shape loop_state_shape = 1880 ShapeUtil::MakeTupleShape({element_shape, element_shape, element_shape}); 1881 HloInstruction* param = builder.AddInstruction( 1882 HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); 1883 HloInstruction* element_0 = builder.AddInstruction( 1884 HloInstruction::CreateGetTupleElement(element_shape, param, 0)); 1885 HloInstruction* element_1 = builder.AddInstruction( 1886 HloInstruction::CreateGetTupleElement(element_shape, param, 1)); 1887 HloInstruction* element_2 = builder.AddInstruction( 1888 HloInstruction::CreateGetTupleElement(element_shape, param, 2)); 1889 1890 HloInstruction* rev_1 = builder.AddInstruction( 1891 HloInstruction::CreateReverse(element_shape, element_1, {0})); 1892 HloInstruction* add_1_2 = builder.AddInstruction(HloInstruction::CreateBinary( 1893 element_shape, HloOpcode::kAdd, element_1, element_2)); 1894 1895 builder.AddInstruction( 1896 HloInstruction::CreateTuple({element_0, rev_1, add_1_2})); 1897 return builder.Build(); 1898 } 1899 1900 void BM_SequentialWhiles(int num_iters, int num_whiles) { 1901 // This benchmark constructs a chain of sequential while instructions. 1902 tensorflow::testing::StopTiming(); 1903 for (int i = 0; i < num_iters; ++i) { 1904 HloModuleConfig config; 1905 config.set_debug_options(GetDebugOptionsFromFlags()); 1906 HloModule module("BM_SequentialWhiles", config); 1907 1908 auto builder = HloComputation::Builder("BM_SequentialWhiles"); 1909 HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( 1910 0, ShapeUtil::MakeShape(F32, {42}), "x")); 1911 HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( 1912 1, ShapeUtil::MakeShape(F32, {42}), "y")); 1913 HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( 1914 2, ShapeUtil::MakeShape(F32, {42}), "z")); 1915 HloInstruction* init = 1916 builder.AddInstruction(HloInstruction::CreateTuple({x, y, z})); 1917 1918 HloInstruction* prev_loop_state = init; 1919 for (int w = 0; w < num_whiles; ++w) { 1920 HloComputation* condition = 1921 module.AddEmbeddedComputation(MakeTrivialCondition(init->shape())); 1922 HloComputation* body = 1923 module.AddEmbeddedComputation(MakeBenchmarkWhileBody()); 1924 prev_loop_state = builder.AddInstruction(HloInstruction::CreateWhile( 1925 init->shape(), condition, body, prev_loop_state)); 1926 } 1927 module.AddEntryComputation(builder.Build()); 1928 1929 CopyInsertion copy_insertion; 1930 1931 tensorflow::testing::StartTiming(); 1932 ASSERT_IS_OK(copy_insertion.Run(&module).status()); 1933 tensorflow::testing::StopTiming(); 1934 1935 // The entry computation should have three copies, and each body has one. 1936 ASSERT_EQ(CountCopies(module), 3 + num_whiles); 1937 } 1938 } 1939 1940 void BM_ParallelWhiles(int num_iters, int num_whiles) { 1941 // This benchmark constructs a fan-out of parallel while instructions. 1942 tensorflow::testing::StopTiming(); 1943 for (int i = 0; i < num_iters; ++i) { 1944 HloModuleConfig config; 1945 config.set_debug_options(GetDebugOptionsFromFlags()); 1946 HloModule module("BM_SequentialWhiles", config); 1947 1948 auto builder = HloComputation::Builder("BM_ParallelWhiles"); 1949 HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( 1950 0, ShapeUtil::MakeShape(F32, {42}), "x")); 1951 HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( 1952 1, ShapeUtil::MakeShape(F32, {42}), "y")); 1953 HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( 1954 2, ShapeUtil::MakeShape(F32, {42}), "z")); 1955 HloInstruction* init = 1956 builder.AddInstruction(HloInstruction::CreateTuple({x, y, z})); 1957 1958 HloInstruction* sum = nullptr; 1959 for (int w = 0; w < num_whiles; ++w) { 1960 HloComputation* condition = 1961 module.AddEmbeddedComputation(MakeTrivialCondition(init->shape())); 1962 HloComputation* body = 1963 module.AddEmbeddedComputation(MakeBenchmarkWhileBody()); 1964 1965 HloInstruction* xla_while = builder.AddInstruction( 1966 HloInstruction::CreateWhile(init->shape(), condition, body, init)); 1967 1968 if (sum == nullptr) { 1969 sum = builder.AddInstruction( 1970 HloInstruction::CreateGetTupleElement(x->shape(), xla_while, 0)); 1971 } else { 1972 HloInstruction* element_0 = builder.AddInstruction( 1973 HloInstruction::CreateGetTupleElement(x->shape(), xla_while, 0)); 1974 sum = builder.AddInstruction(HloInstruction::CreateBinary( 1975 x->shape(), HloOpcode::kAdd, sum, element_0)); 1976 } 1977 } 1978 module.AddEntryComputation(builder.Build()); 1979 1980 CopyInsertion copy_insertion; 1981 1982 tensorflow::testing::StartTiming(); 1983 ASSERT_IS_OK(copy_insertion.Run(&module).status()); 1984 tensorflow::testing::StopTiming(); 1985 1986 // Each body receives of copy of two of the parameters (the corresponding 1987 // elements in the body are modifed), and there is one copy in each body. 1988 ASSERT_EQ(CountCopies(module), 3 * num_whiles); 1989 } 1990 } 1991 1992 std::unique_ptr<HloComputation> MakeBenchmarkWhileBody( 1993 const int num_tuple_inputs) { 1994 auto builder = HloComputation::Builder("benchmark_loop_body"); 1995 const Shape element_shape = ShapeUtil::MakeShape(F32, {}); 1996 std::vector<Shape> input_shape(num_tuple_inputs, element_shape); 1997 const Shape loop_state_shape = ShapeUtil::MakeTupleShape(input_shape); 1998 HloInstruction* param = builder.AddInstruction( 1999 HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); 2000 std::vector<HloInstruction*> gte_nodes(num_tuple_inputs); 2001 for (int i = 0; i < num_tuple_inputs; ++i) { 2002 gte_nodes[i] = builder.AddInstruction( 2003 HloInstruction::CreateGetTupleElement(element_shape, param, i)); 2004 } 2005 builder.AddInstruction(HloInstruction::CreateTuple(gte_nodes)); 2006 return builder.Build(); 2007 } 2008 2009 void BM_ManyElementTuple(int num_iters, const int num_tuple_inputs) { 2010 tensorflow::testing::StopTiming(); 2011 HloModuleConfig config; 2012 config.set_debug_options(GetDebugOptionsFromFlags()); 2013 CopyInsertion copy_insertion; 2014 const Shape element_shape = ShapeUtil::MakeShape(F32, {}); 2015 std::vector<HloInstruction*> tuple_params(num_tuple_inputs); 2016 for (int i = 0; i < num_iters; ++i) { 2017 auto builder = HloComputation::Builder("BM_ParallelWhiles"); 2018 HloModule module("BM_ManyElementTuple", config); 2019 for (int j = 0; j < num_tuple_inputs; ++j) { 2020 tuple_params[j] = builder.AddInstruction( 2021 HloInstruction::CreateParameter(j, element_shape, "")); 2022 } 2023 HloInstruction* init = 2024 builder.AddInstruction(HloInstruction::CreateTuple(tuple_params)); 2025 HloComputation* condition = 2026 module.AddEmbeddedComputation(MakeTrivialCondition(init->shape())); 2027 HloComputation* body = 2028 module.AddEmbeddedComputation(MakeBenchmarkWhileBody(num_tuple_inputs)); 2029 HloInstruction* xla_while = builder.AddInstruction( 2030 HloInstruction::CreateWhile(init->shape(), condition, body, init)); 2031 builder.AddInstruction(HloInstruction::CreateGetTupleElement( 2032 ShapeUtil::MakeShape(F32, {}), xla_while, 0)); 2033 module.AddEntryComputation(builder.Build()); 2034 tensorflow::testing::StartTiming(); 2035 ASSERT_IS_OK(copy_insertion.Run(&module).status()); 2036 tensorflow::testing::StopTiming(); 2037 } 2038 } 2039 2040 BENCHMARK(BM_SequentialWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096); 2041 BENCHMARK(BM_ParallelWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096); 2042 BENCHMARK(BM_ManyElementTuple)->Arg(1024)->Arg(12288); 2043 2044 TEST_F(CopyInsertionTest, SimpleControlFlowTest) { 2045 const string& hlo_string = R"( 2046 HloModule TestModule 2047 2048 if-body.v5 { 2049 constant.3 = s32[] constant(-1) 2050 p.1 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) 2051 get-tuple-element.18 = (s32[], s32[], s32[]) get-tuple-element(p.1), index=1 2052 get-tuple-element.65 = s32[] get-tuple-element(get-tuple-element.18), index=0 2053 get-tuple-element.66 = s32[] get-tuple-element(get-tuple-element.18), index=1 2054 add.3 = s32[] add(get-tuple-element.65, get-tuple-element.66) 2055 tuple.33 = (s32[]) tuple(add.3) 2056 ROOT tuple.34 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.3, get-tuple-element.18, tuple.33) 2057 } 2058 2059 if-condition.v4 { 2060 p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) 2061 get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0 2062 constant.4 = s32[] constant(0) 2063 ROOT equal-to = pred[] compare(get-tuple-element.67, constant.4), direction=EQ 2064 } 2065 2066 _functionalize_body_1__.v28 { 2067 arg_tuple.4 = (s32[], s32[], s32[], s32[]) parameter(0) 2068 get-tuple-element.68 = s32[] get-tuple-element(arg_tuple.4), index=0 2069 constant.7 = s32[] constant(1) 2070 add.4 = s32[] add(get-tuple-element.68, constant.7) 2071 get-tuple-element.69 = s32[] get-tuple-element(arg_tuple.4), index=1 2072 get-tuple-element.70 = s32[] get-tuple-element(arg_tuple.4), index=2 2073 less-than-or-equal-to = pred[] compare(get-tuple-element.69, get-tuple-element.70), direction=LE 2074 constant.8 = s32[] constant(0) 2075 select = s32[] select(less-than-or-equal-to, constant.8, constant.7) 2076 get-tuple-element.71 = s32[] get-tuple-element(arg_tuple.4), index=3 2077 tuple.35 = (s32[], s32[], s32[]) tuple(get-tuple-element.69, get-tuple-element.71, get-tuple-element.70) 2078 tuple.36 = (s32[]) tuple(constant.8) 2079 tuple.37 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(select, tuple.35, tuple.36) 2080 while = (s32[], (s32[], s32[], s32[]), (s32[])) while(tuple.37), condition=if-condition.v4, body=if-body.v5 2081 get-tuple-element.72 = (s32[]) get-tuple-element(while), index=2 2082 get-tuple-element.73 = s32[] get-tuple-element(get-tuple-element.72), index=0 2083 ROOT tuple.38 = (s32[], s32[], s32[], s32[]) tuple(add.4, get-tuple-element.69, get-tuple-element.70, get-tuple-element.73) 2084 } 2085 2086 cond_wrapper.v3.1 { 2087 inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0) 2088 get-tuple-element.75 = s32[] get-tuple-element(inputs.1), index=0 2089 constant.11 = s32[] constant(7) 2090 ROOT less-than.2 = pred[] compare(get-tuple-element.75, constant.11), direction=LT 2091 } 2092 2093 _functionalize_body_2__.v25 { 2094 arg_tuple.5 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) 2095 get-tuple-element.76 = s32[] get-tuple-element(arg_tuple.5), index=0 2096 get-tuple-element.77 = s32[] get-tuple-element(arg_tuple.5), index=2 2097 get-tuple-element.78 = s32[] get-tuple-element(arg_tuple.5), index=3 2098 get-tuple-element.79 = s32[] get-tuple-element(arg_tuple.5), index=4 2099 tuple.39 = (s32[], s32[], s32[], s32[]) tuple(get-tuple-element.76, get-tuple-element.77, get-tuple-element.78, get-tuple-element.79) 2100 while.2 = (s32[], s32[], s32[], s32[]) while(tuple.39), condition=cond_wrapper.v3.1, body=_functionalize_body_1__.v28 2101 get-tuple-element.80 = s32[] get-tuple-element(while.2), index=0 2102 get-tuple-element.81 = s32[] get-tuple-element(arg_tuple.5), index=1 2103 constant.12 = s32[] constant(1) 2104 add.5 = s32[] add(get-tuple-element.81, constant.12) 2105 get-tuple-element.82 = s32[] get-tuple-element(while.2), index=3 2106 ROOT tuple.40 = (s32[], s32[], s32[], s32[], s32[]) tuple(get-tuple-element.80, add.5, get-tuple-element.77, get-tuple-element.78, get-tuple-element.82) 2107 } 2108 2109 cond_wrapper.v3.2 { 2110 inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) 2111 get-tuple-element.83 = s32[] get-tuple-element(inputs.2), index=1 2112 constant.13 = s32[] constant(5) 2113 ROOT less-than.3 = pred[] compare(get-tuple-element.83, constant.13), direction=LT 2114 } 2115 2116 ENTRY TestComputation { 2117 arg_tuple.6 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) 2118 ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25 2119 } 2120 )"; 2121 auto module_or_status = ParseAndReturnVerifiedModule(hlo_string); 2122 auto module = module_or_status.ConsumeValueOrDie(); 2123 InsertCopies(module.get()); 2124 } 2125 2126 TEST_F(CopyInsertionTest, ControlFlowTest) { 2127 const string& hlo_string = R"( 2128 HloModule TestModule 2129 2130 if-body.v5 { 2131 constant.3 = s32[] constant(-1) 2132 p.1 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) 2133 get-tuple-element.18 = (s32[], s32[], s32[]) get-tuple-element(p.1), index=1 2134 get-tuple-element.65 = s32[] get-tuple-element(get-tuple-element.18), index=0 2135 get-tuple-element.66 = s32[] get-tuple-element(get-tuple-element.18), index=1 2136 add.3 = s32[] add(get-tuple-element.65, get-tuple-element.66) 2137 tuple.33 = (s32[]) tuple(add.3) 2138 ROOT tuple.34 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.3, get-tuple-element.18, tuple.33) 2139 } 2140 2141 if-condition.v4 { 2142 p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) 2143 get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0 2144 constant.4 = s32[] constant(0) 2145 ROOT equal-to = pred[] compare(get-tuple-element.67, constant.4), direction=EQ 2146 } 2147 2148 if-body.v5.1 { 2149 constant.5 = s32[] constant(-1) 2150 p.3 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) 2151 get-tuple-element.68 = (s32[], s32[], s32[]) get-tuple-element(p.3), index=1 2152 get-tuple-element.70 = s32[] get-tuple-element(get-tuple-element.68), index=2 2153 multiply.1 = s32[] multiply(get-tuple-element.70, get-tuple-element.70) 2154 tuple.35 = (s32[]) tuple(multiply.1) 2155 ROOT tuple.36 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.5, get-tuple-element.68, tuple.35) 2156 } 2157 2158 if-condition.v4.1 { 2159 p.4 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) 2160 get-tuple-element.71 = s32[] get-tuple-element(p.4), index=0 2161 constant.6 = s32[] constant(1) 2162 ROOT equal-to.1 = pred[] compare(get-tuple-element.71, constant.6), direction=EQ 2163 } 2164 2165 _functionalize_body_1__.v28 { 2166 arg_tuple.4 = (s32[], s32[], s32[], s32[]) parameter(0) 2167 get-tuple-element.72 = s32[] get-tuple-element(arg_tuple.4), index=0 2168 constant.7 = s32[] constant(1) 2169 add.4 = s32[] add(get-tuple-element.72, constant.7) 2170 get-tuple-element.73 = s32[] get-tuple-element(arg_tuple.4), index=1 2171 get-tuple-element.74 = s32[] get-tuple-element(arg_tuple.4), index=2 2172 less-than-or-equal-to = pred[] compare(get-tuple-element.73, get-tuple-element.74), direction=LE 2173 constant.8 = s32[] constant(0) 2174 select = s32[] select(less-than-or-equal-to, constant.8, constant.7) 2175 get-tuple-element.75 = s32[] get-tuple-element(arg_tuple.4), index=3 2176 tuple.37 = (s32[], s32[], s32[]) tuple(get-tuple-element.73, get-tuple-element.75, get-tuple-element.74) 2177 tuple.38 = (s32[]) tuple(constant.8) 2178 tuple.39 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(select, tuple.37, tuple.38) 2179 while = (s32[], (s32[], s32[], s32[]), (s32[])) while(tuple.39), condition=if-condition.v4, body=if-body.v5 2180 while.1 = (s32[], (s32[], s32[], s32[]), (s32[])) while(while), condition=if-condition.v4.1, body=if-body.v5.1 2181 get-tuple-element.76 = (s32[]) get-tuple-element(while.1), index=2 2182 get-tuple-element.77 = s32[] get-tuple-element(get-tuple-element.76), index=0 2183 ROOT tuple.40 = (s32[], s32[], s32[], s32[]) tuple(add.4, get-tuple-element.73, get-tuple-element.74, get-tuple-element.77) 2184 } 2185 2186 cond_wrapper.v3.1 { 2187 inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0) 2188 get-tuple-element.78 = s32[] get-tuple-element(inputs.1), index=0 2189 constant.11 = s32[] constant(7) 2190 ROOT less-than.2 = pred[] compare(get-tuple-element.78, constant.11), direction=LT 2191 } 2192 2193 _functionalize_body_2__.v25 { 2194 arg_tuple.5 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) 2195 get-tuple-element.79 = s32[] get-tuple-element(arg_tuple.5), index=0 2196 get-tuple-element.80 = s32[] get-tuple-element(arg_tuple.5), index=2 2197 get-tuple-element.81 = s32[] get-tuple-element(arg_tuple.5), index=3 2198 get-tuple-element.82 = s32[] get-tuple-element(arg_tuple.5), index=4 2199 tuple.41 = (s32[], s32[], s32[], s32[]) tuple(get-tuple-element.79, get-tuple-element.80, get-tuple-element.81, get-tuple-element.82) 2200 while.2 = (s32[], s32[], s32[], s32[]) while(tuple.41), condition=cond_wrapper.v3.1, body=_functionalize_body_1__.v28 2201 get-tuple-element.83 = s32[] get-tuple-element(while.2), index=0 2202 get-tuple-element.84 = s32[] get-tuple-element(arg_tuple.5), index=1 2203 constant.12 = s32[] constant(1) 2204 add.5 = s32[] add(get-tuple-element.84, constant.12) 2205 get-tuple-element.85 = s32[] get-tuple-element(while.2), index=3 2206 ROOT tuple.42 = (s32[], s32[], s32[], s32[], s32[]) tuple(get-tuple-element.83, add.5, get-tuple-element.80, get-tuple-element.81, get-tuple-element.85) 2207 } 2208 2209 cond_wrapper.v3.2 { 2210 inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) 2211 get-tuple-element.86 = s32[] get-tuple-element(inputs.2), index=1 2212 constant.13 = s32[] constant(5) 2213 ROOT less-than.3 = pred[] compare(get-tuple-element.86, constant.13), direction=LT 2214 } 2215 2216 ENTRY TestComputation { 2217 arg_tuple.6 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) 2218 ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25 2219 } 2220 )"; 2221 auto module_or_status = ParseAndReturnVerifiedModule(hlo_string); 2222 auto module = module_or_status.ConsumeValueOrDie(); 2223 InsertCopies(module.get()); 2224 } 2225 2226 TEST_F(CopyInsertionTest, NestedWhiles) { 2227 // Verify that only no unnecessary copies remain after copy insertion for 2228 // trivial nested whiles (b/112472605). 2229 const string& hlo_string = R"( 2230 HloModule TestModule 2231 2232 cond.inner { 2233 ROOT param.cond.inner = pred[] parameter(0) 2234 } 2235 2236 body.inner { 2237 param.body.inner = pred[] parameter(0) 2238 ROOT not = pred[] not(param.body.inner) 2239 } 2240 2241 cond.outer { 2242 ROOT param.cond.outer = pred[] parameter(0) 2243 } 2244 2245 body.outer { 2246 param.cond.outer = pred[] parameter(0) 2247 ROOT while = pred[] while(param.cond.outer), condition=cond.inner, body=body.inner 2248 } 2249 2250 ENTRY TestComputation { 2251 entry_param = pred[] parameter(0) 2252 ROOT while = pred[] while(entry_param), condition=cond.outer, body=body.outer 2253 } 2254 )"; 2255 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, 2256 ParseAndReturnVerifiedModule(hlo_string)); 2257 InsertCopies(module.get()); 2258 2259 // There should only be a single copy inserted, and it's in the entry 2260 // computation. 2261 EXPECT_EQ(CountCopies(*module), 1); 2262 EXPECT_THAT(module->entry_computation()->root_instruction(), 2263 op::While(op::Copy(op::Parameter()))); 2264 } 2265 2266 } // namespace 2267 } // namespace xla 2268