1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/copy_insertion.h" 17 18 #include <set> 19 20 #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" 21 #include "tensorflow/compiler/xla/literal_util.h" 22 #include "tensorflow/compiler/xla/service/hlo_computation.h" 23 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 24 #include "tensorflow/compiler/xla/service/hlo_matchers.h" 25 #include "tensorflow/compiler/xla/service/hlo_module.h" 26 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 27 #include "tensorflow/compiler/xla/service/hlo_runner.h" 28 #include "tensorflow/compiler/xla/shape_util.h" 29 #include "tensorflow/compiler/xla/test.h" 30 #include "tensorflow/compiler/xla/test_helpers.h" 31 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 32 #include "tensorflow/compiler/xla/xla_data.pb.h" 33 #include "tensorflow/core/platform/test_benchmark.h" 34 35 namespace op = xla::testing::opcode_matchers; 36 37 namespace xla { 38 namespace { 39 40 using ::testing::UnorderedElementsAre; 41 42 int64 CountCopies(const HloComputation& computation) { 43 int64 count = 0; 44 for (const auto& instruction : computation.instructions()) { 45 if (instruction->opcode() == HloOpcode::kCopy) { 46 count++; 47 } 48 } 49 return count; 50 } 51 52 int64 CountCopies(const HloModule& module) { 53 int64 count = 0; 54 for (const auto& computation : module.computations()) { 55 count += CountCopies(*computation); 56 } 57 return count; 58 } 59 60 int64 CountControlEdges(const HloComputation& computation) { 61 int64 count = 0; 62 for (const auto& instruction : computation.instructions()) { 63 count += instruction->control_successors().size(); 64 } 65 return count; 66 } 67 68 int64 CountControlEdges(const HloModule& module) { 69 int64 count = 0; 70 for (const auto& computation : module.computations()) { 71 count += CountControlEdges(*computation); 72 } 73 return count; 74 } 75 76 class CopyInsertionTest : public HloTestBase { 77 protected: 78 void InsertCopies(HloModule* module) { 79 CopyInsertion copy_insertion; 80 ASSERT_IS_OK(copy_insertion.Run(module).status()); 81 } 82 83 const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); 84 }; 85 86 TEST_F(CopyInsertionTest, SingleParameter) { 87 // Computation is a single parameter passed into a tuple. The parameter should 88 // be copied before entering the tuple. 89 auto builder = HloComputation::Builder(TestName()); 90 HloInstruction* x = builder.AddInstruction( 91 HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "x")); 92 HloInstruction* tuple = 93 builder.AddInstruction(HloInstruction::CreateTuple({x})); 94 95 EXPECT_THAT(x->users(), UnorderedElementsAre(tuple)); 96 97 auto module = CreateNewModule(); 98 module->AddEntryComputation(builder.Build()); 99 100 InsertCopies(module.get()); 101 102 EXPECT_THAT(module->entry_computation()->root_instruction(), 103 op::Tuple(op::Copy(x))); 104 } 105 106 TEST_F(CopyInsertionTest, SingleConstant) { 107 // Computation is a single constant passed into a tuple. The parameter should 108 // be copied before entering the tuple. 109 auto builder = HloComputation::Builder(TestName()); 110 HloInstruction* constant = builder.AddInstruction( 111 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 112 HloInstruction* tuple = 113 builder.AddInstruction(HloInstruction::CreateTuple({constant})); 114 115 EXPECT_THAT(constant->users(), UnorderedElementsAre(tuple)); 116 117 auto module = CreateNewModule(); 118 module->AddEntryComputation(builder.Build()); 119 120 InsertCopies(module.get()); 121 EXPECT_EQ(CountCopies(*module), 1); 122 123 EXPECT_THAT(module->entry_computation()->root_instruction(), 124 op::Tuple(op::Copy(constant))); 125 } 126 127 TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) { 128 // Verify that an kCopy instructions which exist in the pass before 129 // copy-insertion remain in the graph after copy-insertion. 130 auto module = CreateNewModule(); 131 132 auto builder = HloComputation::Builder(TestName()); 133 HloInstruction* constant = builder.AddInstruction( 134 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 135 HloInstruction* copy_1 = builder.AddInstruction(HloInstruction::CreateUnary( 136 constant->shape(), HloOpcode::kCopy, constant)); 137 HloInstruction* copy_2 = builder.AddInstruction(HloInstruction::CreateUnary( 138 constant->shape(), HloOpcode::kCopy, constant)); 139 HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( 140 constant->shape(), HloOpcode::kAdd, copy_1, copy_2)); 141 HloInstruction* add_copy = builder.AddInstruction( 142 HloInstruction::CreateUnary(constant->shape(), HloOpcode::kCopy, add)); 143 144 module->AddEntryComputation(builder.Build()); 145 146 EXPECT_EQ(CountCopies(*module), 3); 147 148 InsertCopies(module.get()); 149 150 EXPECT_EQ(CountCopies(*module), 3); 151 152 EXPECT_EQ(module->entry_computation()->root_instruction(), add_copy); 153 EXPECT_THAT( 154 module->entry_computation()->root_instruction(), 155 op::Copy(op::Add(op::Copy(op::Constant()), op::Copy(op::Constant())))); 156 } 157 158 TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) { 159 // Create a computation with more than one constant and parameter. Only one of 160 // each constant/parameter is pointed to by the output tuple. Only these 161 // instructions should be copied. 162 auto builder = HloComputation::Builder(TestName()); 163 164 HloInstruction* constant1 = builder.AddInstruction( 165 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 166 HloInstruction* constant2 = builder.AddInstruction( 167 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 168 169 HloInstruction* x = builder.AddInstruction( 170 HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "x")); 171 HloInstruction* y = builder.AddInstruction( 172 HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "y")); 173 174 HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( 175 ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, constant1, y)); 176 177 builder.AddInstruction(HloInstruction::CreateTuple({constant2, x, add})); 178 179 auto module = CreateNewModule(); 180 module->AddEntryComputation(builder.Build()); 181 182 InsertCopies(module.get()); 183 EXPECT_EQ(CountCopies(*module), 2); 184 185 EXPECT_THAT( 186 module->entry_computation()->root_instruction(), 187 op::Tuple(op::Copy(constant2), op::Copy(x), op::Add(constant1, y))); 188 } 189 190 TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { 191 // Create a computation using select which has an ambiguous points-to set for 192 // the computation result. Verify that copies are added properly. 193 auto builder = HloComputation::Builder(TestName()); 194 HloInstruction* constant1 = builder.AddInstruction( 195 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 196 HloInstruction* constant2 = builder.AddInstruction( 197 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 198 HloInstruction* constant3 = builder.AddInstruction( 199 HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0))); 200 201 HloInstruction* tuple1 = builder.AddInstruction( 202 HloInstruction::CreateTuple({constant1, constant2})); 203 HloInstruction* tuple2 = builder.AddInstruction( 204 HloInstruction::CreateTuple({constant3, constant2})); 205 206 HloInstruction* pred = builder.AddInstruction( 207 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 208 builder.AddInstruction(HloInstruction::CreateTernary( 209 tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); 210 211 EXPECT_THAT(constant1->users(), UnorderedElementsAre(tuple1)); 212 EXPECT_THAT(constant2->users(), UnorderedElementsAre(tuple1, tuple2)); 213 EXPECT_THAT(constant3->users(), UnorderedElementsAre(tuple2)); 214 215 auto module = CreateNewModule(); 216 module->AddEntryComputation(builder.Build()); 217 218 HloInstruction* old_root = module->entry_computation()->root_instruction(); 219 InsertCopies(module.get()); 220 EXPECT_EQ(CountCopies(*module), 2); 221 222 EXPECT_THAT(module->entry_computation()->root_instruction(), 223 op::Tuple(op::Copy(op::GetTupleElement(old_root)), 224 op::Copy(op::GetTupleElement(old_root)))); 225 } 226 227 TEST_F(CopyInsertionTest, BitcastParameter) { 228 // The output of a bitcast is its operand (same buffer), so a bitcast 229 // parameter feeding the result must have a copy added. 230 auto builder = HloComputation::Builder(TestName()); 231 HloInstruction* x = builder.AddInstruction( 232 HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {4}), "x")); 233 HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary( 234 ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, x)); 235 236 auto module = CreateNewModule(); 237 module->AddEntryComputation(builder.Build()); 238 239 EXPECT_THAT(x->users(), UnorderedElementsAre(bitcast)); 240 241 HloInstruction* old_root = module->entry_computation()->root_instruction(); 242 InsertCopies(module.get()); 243 EXPECT_EQ(CountCopies(*module), 1); 244 245 EXPECT_THAT(module->entry_computation()->root_instruction(), 246 op::Copy(old_root)); 247 } 248 249 TEST_F(CopyInsertionTest, BitcastConstant) { 250 // The output of a bitcast is its operand (same buffer), so a bitcast 251 // constant feeding the result must have a copy added. 252 auto builder = HloComputation::Builder(TestName()); 253 HloInstruction* constant = builder.AddInstruction( 254 HloInstruction::CreateConstant(Literal::CreateR1<float>({1.0, 42.0}))); 255 HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary( 256 ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, constant)); 257 258 auto module = CreateNewModule(); 259 module->AddEntryComputation(builder.Build()); 260 261 EXPECT_THAT(constant->users(), UnorderedElementsAre(bitcast)); 262 263 HloInstruction* old_root = module->entry_computation()->root_instruction(); 264 InsertCopies(module.get()); 265 EXPECT_EQ(CountCopies(*module), 1); 266 267 EXPECT_THAT(module->entry_computation()->root_instruction(), 268 op::Copy(old_root)); 269 } 270 271 TEST_F(CopyInsertionTest, BitcastTupleElementParameter) { 272 // Same as BitcastParameter, but the bitcast is wrapped in a tuple. 273 auto builder = HloComputation::Builder(TestName()); 274 HloInstruction* x = builder.AddInstruction( 275 HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {4}), "x")); 276 HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary( 277 ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, x)); 278 builder.AddInstruction(HloInstruction::CreateTuple({bitcast})); 279 280 auto module = CreateNewModule(); 281 module->AddEntryComputation(builder.Build()); 282 283 EXPECT_THAT(x->users(), UnorderedElementsAre(bitcast)); 284 285 InsertCopies(module.get()); 286 EXPECT_EQ(CountCopies(*module), 1); 287 288 EXPECT_THAT(module->entry_computation()->root_instruction(), 289 op::Tuple(op::Copy(bitcast))); 290 } 291 292 TEST_F(CopyInsertionTest, NestedTupleParameter) { 293 // Construct a trivial computation where the root of the computation is a 294 // nested tuple-shaped parameter. The parameter should be deep copied and the 295 // copy should be the root of the computation. 296 auto builder = HloComputation::Builder(TestName()); 297 298 // Param shape is: ((F32[], S32[1,2,3]), F32[42]) 299 builder.AddInstruction(HloInstruction::CreateParameter( 300 0, 301 ShapeUtil::MakeTupleShape( 302 {ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {}), 303 ShapeUtil::MakeShape(S32, {1, 2, 3})}), 304 ShapeUtil::MakeShape(F32, {42})}), 305 "param0")); 306 307 auto module = CreateNewModule(); 308 module->AddEntryComputation(builder.Build()); 309 310 EXPECT_EQ(HloOpcode::kParameter, 311 module->entry_computation()->root_instruction()->opcode()); 312 313 HloInstruction* old_root = module->entry_computation()->root_instruction(); 314 InsertCopies(module.get()); 315 EXPECT_EQ(CountCopies(*module), 3); 316 317 HloInstruction* new_root = module->entry_computation()->root_instruction(); 318 EXPECT_NE(old_root, new_root); 319 320 EXPECT_THAT( 321 new_root, 322 op::Tuple( 323 op::Tuple( 324 op::Copy(op::GetTupleElement(op::GetTupleElement(old_root))), 325 op::Copy(op::GetTupleElement(op::GetTupleElement(old_root)))), 326 op::Copy(op::GetTupleElement(old_root)))); 327 } 328 329 TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) { 330 // Construct a computation where the root of the computation is a tuple 331 // element of a nested tuple-shaped parameter. 332 auto builder = HloComputation::Builder(TestName()); 333 334 // Param shape is: ((F32[], S32[1,2,3]), F32[42]) 335 auto param = builder.AddInstruction(HloInstruction::CreateParameter( 336 0, 337 ShapeUtil::MakeTupleShape( 338 {ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {}), 339 ShapeUtil::MakeShape(S32, {1, 2, 3})}), 340 ShapeUtil::MakeShape(F32, {42})}), 341 "param0")); 342 343 // The return value of the computation is the zero-th element of the nested 344 // tuple. This element is itself a tuple. 345 auto gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement( 346 ShapeUtil::GetSubshape(param->shape(), {0}), param, 0)); 347 348 auto module = CreateNewModule(); 349 module->AddEntryComputation(builder.Build()); 350 351 EXPECT_EQ(gte, module->entry_computation()->root_instruction()); 352 353 InsertCopies(module.get()); 354 EXPECT_EQ(CountCopies(*module), 2); 355 356 EXPECT_THAT( 357 module->entry_computation()->root_instruction(), 358 op::Tuple(op::Copy(op::GetTupleElement(op::GetTupleElement(param))), 359 op::Copy(op::GetTupleElement(op::GetTupleElement(param))))); 360 } 361 362 TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { 363 // Create a computation using select which has an ambiguous points-to set for 364 // the top-level buffer of the root of the computation. Verify that a shallow 365 // copy is added. 366 auto builder = HloComputation::Builder(TestName()); 367 HloInstruction* constant1 = builder.AddInstruction( 368 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 369 HloInstruction* constant2 = builder.AddInstruction( 370 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 371 372 HloInstruction* tuple1 = builder.AddInstruction( 373 HloInstruction::CreateTuple({constant1, constant2})); 374 HloInstruction* tuple2 = builder.AddInstruction( 375 HloInstruction::CreateTuple({constant2, constant1})); 376 377 HloInstruction* pred = builder.AddInstruction( 378 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 379 HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary( 380 tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); 381 HloInstruction* gte = 382 builder.AddInstruction(HloInstruction::CreateGetTupleElement( 383 ShapeUtil::GetSubshape(select->shape(), {0}), select, 0)); 384 385 auto module = CreateNewModule(); 386 module->AddEntryComputation(builder.Build()); 387 388 EXPECT_EQ(gte, module->entry_computation()->root_instruction()); 389 390 HloInstruction* old_root = module->entry_computation()->root_instruction(); 391 InsertCopies(module.get()); 392 EXPECT_EQ(CountCopies(*module), 1); 393 394 EXPECT_THAT(module->entry_computation()->root_instruction(), 395 op::Copy(old_root)); 396 } 397 398 class WhileCopyInsertionTest : public CopyInsertionTest { 399 protected: 400 WhileCopyInsertionTest() : module_(CreateNewModule()) {} 401 402 // Builds a While condition computation which reads the induction variable 403 // from the tuple parameter, and returns a predicate indicating whether this 404 // value is less than the constant '10'. 405 // The parameter 'nested' specifies the loop state shape from which to 406 // read the induction variable. 407 std::unique_ptr<HloComputation> BuildConditionComputation( 408 const Shape& loop_state_shape) { 409 auto builder = HloComputation::Builder(TestName() + ".Condition"); 410 auto limit_const = builder.AddInstruction( 411 HloInstruction::CreateConstant(Literal::CreateR0<int32>(10))); 412 auto loop_state = builder.AddInstruction( 413 HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); 414 auto induction_variable = 415 builder.AddInstruction(HloInstruction::CreateGetTupleElement( 416 limit_const->shape(), loop_state, 0)); 417 builder.AddInstruction( 418 HloInstruction::CreateBinary(condition_result_shape_, HloOpcode::kLt, 419 induction_variable, limit_const)); 420 return builder.Build(); 421 } 422 423 // Builds a While body computation with one output tuple element dependent on 424 // both input tuple elements. 425 // EX: 426 // Body({in0, in1}) 427 // out0 = Add(in0, 1) 428 // out1 = Add(BCast(in0), in1) 429 // Tuple(out0, out1) 430 std::unique_ptr<HloComputation> BuildDependentBodyComputation() { 431 auto builder = HloComputation::Builder(TestName() + ".Body"); 432 // Create param instruction to access loop state. 433 auto loop_state = builder.AddInstruction( 434 HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); 435 // Update the induction variable GTE(0). 436 auto induction_variable = 437 builder.AddInstruction(HloInstruction::CreateGetTupleElement( 438 induction_variable_shape_, loop_state, 0)); 439 auto inc = builder.AddInstruction( 440 HloInstruction::CreateConstant(Literal::CreateR0<int32>(1))); 441 auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( 442 induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); 443 // Update data GTE(1). 444 auto data = builder.AddInstruction( 445 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); 446 // Use 'induction_variable' in computation with no path to output tuple. 447 auto update = builder.AddInstruction( 448 HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8})); 449 auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( 450 data_shape_, HloOpcode::kAdd, data, update)); 451 // Create output Tuple. 452 builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); 453 return builder.Build(); 454 } 455 456 // Builds a While body computation with two output tuple elements dependent on 457 // both input tuple elements. 458 // 459 // EX: Body({in0, in1, in2}) 460 // out0 = Add(in0, 1) 461 // out1 = in1 462 // out2 = in2 463 // Tuple(out0, out1, out2) 464 std::unique_ptr<HloComputation> BuildDependentBodyComputation2() { 465 auto builder = HloComputation::Builder(TestName() + ".Body"); 466 467 const Shape& loop_state_shape = ShapeUtil::MakeTupleShape( 468 {induction_variable_shape_, data_shape_, data_shape_}); 469 470 auto loop_state = builder.AddInstruction( 471 HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); 472 473 // Update the induction variable GTE(0). 474 auto induction_variable = 475 builder.AddInstruction(HloInstruction::CreateGetTupleElement( 476 induction_variable_shape_, loop_state, 0)); 477 auto inc = builder.AddInstruction( 478 HloInstruction::CreateConstant(Literal::CreateR0<int32>(1))); 479 480 // add0 = Add(in0, 1) 481 auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( 482 induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); 483 // data1 = GTE(1). 484 HloInstruction* data1 = builder.AddInstruction( 485 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); 486 487 // data2 = GTE(2). 488 HloInstruction* data2 = builder.AddInstruction( 489 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 2)); 490 491 // Create output Tuple. 492 builder.AddInstruction(HloInstruction::CreateTuple({add0, data1, data2})); 493 494 return builder.Build(); 495 } 496 497 // Builds a While body computation with read-only tuple element 0. 498 // EX: 499 // Body({in0, in1}) 500 // out0 = in0 501 // out1 = Add(BCast(in0), in1) 502 // Tuple(out0, out1) 503 std::unique_ptr<HloComputation> BuildDependentBodyOneReadOnlyComputation() { 504 auto builder = HloComputation::Builder(TestName() + ".Body"); 505 // Create param instruction to access loop state. 506 auto loop_state = builder.AddInstruction( 507 HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); 508 // Update the induction variable GTE(0). 509 auto induction_variable = 510 builder.AddInstruction(HloInstruction::CreateGetTupleElement( 511 induction_variable_shape_, loop_state, 0)); 512 // Update data GTE(1). 513 auto data = builder.AddInstruction( 514 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); 515 516 // Use 'induction_variable' in computation with no path to output tuple. 517 auto update = builder.AddInstruction( 518 HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8})); 519 auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( 520 data_shape_, HloOpcode::kAdd, data, update)); 521 // Create output Tuple. 522 builder.AddInstruction( 523 HloInstruction::CreateTuple({induction_variable, add1})); 524 return builder.Build(); 525 } 526 527 // Builds a While body computation with independent outputs. 528 // EX: 529 // Body({in0, in1}) 530 // out0 = Add(in0, 1) 531 // out1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1}) 532 // Tuple(out0, out1) 533 std::unique_ptr<HloComputation> BuildIndependentBodyComputation( 534 bool nested = false) { 535 auto builder = HloComputation::Builder(TestName() + ".Body"); 536 // Create param instruction to access loop state. 537 const Shape& loop_state_shape = 538 nested ? nested_loop_state_shape_ : loop_state_shape_; 539 540 auto loop_state = builder.AddInstruction( 541 HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); 542 // Update the induction variable GTE(0). 543 auto induction_variable = 544 builder.AddInstruction(HloInstruction::CreateGetTupleElement( 545 induction_variable_shape_, loop_state, 0)); 546 auto inc = builder.AddInstruction( 547 HloInstruction::CreateConstant(Literal::CreateR0<int32>(1))); 548 // add0 = Add(in0, 1) 549 auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( 550 induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); 551 // Update data GTE(1). 552 HloInstruction* data = nullptr; 553 if (nested) { 554 data = builder.AddInstruction(HloInstruction::CreateGetTupleElement( 555 nested_tuple_shape_, loop_state, 1)); 556 data = builder.AddInstruction( 557 HloInstruction::CreateGetTupleElement(data_shape_, data, 0)); 558 } else { 559 data = builder.AddInstruction( 560 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); 561 } 562 auto update = builder.AddInstruction(HloInstruction::CreateConstant( 563 Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); 564 // add1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1}) 565 auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( 566 data_shape_, HloOpcode::kAdd, data, update)); 567 // Create output Tuple. 568 if (nested) { 569 auto nested_tuple = 570 builder.AddInstruction(HloInstruction::CreateTuple({add1, add1})); 571 builder.AddInstruction(HloInstruction::CreateTuple({add0, nested_tuple})); 572 } else { 573 builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); 574 } 575 return builder.Build(); 576 } 577 578 // Builds a While body computation with the following nested tuple 579 // sub-computation: 580 // | 581 // GTE(loop_state, 1) 582 // / \ 583 // GTE(GTE(loop_state, 1), 0) GTE(GTE(loop_state, 1), 1) 584 // | | 585 // Add Reverse 586 // | | 587 std::unique_ptr<HloComputation> BuildNestedBodyComputation() { 588 auto builder = HloComputation::Builder(TestName() + ".Body"); 589 // Create param instruction to access loop state. 590 auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter( 591 0, nested_loop_state_shape_, "loop_state")); 592 // Update GTE(0). 593 auto gte0 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( 594 induction_variable_shape_, loop_state, 0)); 595 auto inc = builder.AddInstruction( 596 HloInstruction::CreateConstant(Literal::CreateR0<int32>(1))); 597 auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( 598 gte0->shape(), HloOpcode::kAdd, gte0, inc)); 599 600 // GTE(loop_state, 1) 601 auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( 602 nested_tuple_shape_, loop_state, 1)); 603 // GTE(GTE(loop_state, 1), 0) -> Add 604 auto gte10 = builder.AddInstruction( 605 HloInstruction::CreateGetTupleElement(data_shape_, gte1, 0)); 606 auto update10 = builder.AddInstruction(HloInstruction::CreateConstant( 607 Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); 608 auto add10 = builder.AddInstruction(HloInstruction::CreateBinary( 609 data_shape_, HloOpcode::kAdd, gte10, update10)); 610 611 // GTE(GTE(loop_state, 1), 1) -> Reverse 612 auto gte11 = builder.AddInstruction( 613 HloInstruction::CreateGetTupleElement(data_shape_, gte1, 1)); 614 auto rev11 = builder.AddInstruction( 615 HloInstruction::CreateReverse(data_shape_, gte11, {0})); 616 617 // Create output Tuple. 618 auto inner_tuple = 619 builder.AddInstruction(HloInstruction::CreateTuple({add10, rev11})); 620 builder.AddInstruction(HloInstruction::CreateTuple({add0, inner_tuple})); 621 return builder.Build(); 622 } 623 624 // Builds a While instruction using 'condition' and 'body' sub-computations. 625 // Init operand is initialized to zeros of appropriate shape. 626 HloInstruction* BuildWhileInstruction(HloComputation* condition, 627 HloComputation* body, 628 bool nested = false) { 629 auto builder = HloComputation::Builder(TestName() + ".While"); 630 auto induction_var_init = builder.AddInstruction( 631 HloInstruction::CreateConstant(Literal::CreateR0<int32>(0))); 632 633 auto data_init = builder.AddInstruction(HloInstruction::CreateConstant( 634 Literal::CreateR1<float>({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); 635 636 if (nested) { 637 auto inner_init = builder.AddInstruction( 638 HloInstruction::CreateTuple({data_init, data_init})); 639 auto loop_state_init = builder.AddInstruction( 640 HloInstruction::CreateTuple({induction_var_init, inner_init})); 641 auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile( 642 loop_state_init->shape(), condition, body, loop_state_init)); 643 module_->AddEntryComputation(builder.Build()); 644 return while_hlo; 645 } 646 647 auto loop_state_init = builder.AddInstruction( 648 HloInstruction::CreateTuple({induction_var_init, data_init})); 649 auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile( 650 loop_state_shape_, condition, body, loop_state_init)); 651 module_->AddEntryComputation(builder.Build()); 652 return while_hlo; 653 } 654 655 HloInstruction* BuildWhileInstruction_InitPointsToConstant() { 656 auto builder = HloComputation::Builder(TestName() + ".While"); 657 auto data_init = builder.AddInstruction(HloInstruction::CreateConstant( 658 Literal::CreateR1<float>({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); 659 return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init, 660 &builder); 661 } 662 663 HloInstruction* BuildWhileInstruction_InitPointsToParameter() { 664 auto builder = HloComputation::Builder(TestName() + ".While"); 665 auto data_init = builder.AddInstruction( 666 HloInstruction::CreateParameter(0, data_shape_, "data_init")); 667 return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init, 668 &builder); 669 } 670 671 HloInstruction* BuildWhileInstruction_InitPointsToAmbiguous() { 672 auto builder = HloComputation::Builder(TestName() + ".While"); 673 674 auto one = builder.AddInstruction( 675 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 676 auto v1 = builder.AddInstruction( 677 HloInstruction::CreateBroadcast(data_shape_, one, {1})); 678 auto zero = builder.AddInstruction( 679 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 680 auto v2 = builder.AddInstruction( 681 HloInstruction::CreateBroadcast(data_shape_, zero, {1})); 682 683 auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({v1, v2})); 684 auto tuple2 = builder.AddInstruction(HloInstruction::CreateTuple({v2, v1})); 685 686 auto pred = builder.AddInstruction( 687 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 688 auto data_init = builder.AddInstruction(HloInstruction::CreateTernary( 689 nested_tuple_shape_, HloOpcode::kSelect, pred, tuple1, tuple2)); 690 691 return BuildWhileInstructionWithCustomInit(nested_loop_state_shape_, 692 data_init, &builder); 693 } 694 695 HloInstruction* BuildWhileInstruction_InitPointsToNonDistinct() { 696 auto builder = HloComputation::Builder(TestName() + ".While"); 697 698 auto one = builder.AddInstruction( 699 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 700 auto one_vec = builder.AddInstruction( 701 HloInstruction::CreateBroadcast(data_shape_, one, {1})); 702 auto data_init = 703 builder.AddInstruction(HloInstruction::CreateTuple({one_vec, one_vec})); 704 705 return BuildWhileInstructionWithCustomInit(nested_loop_state_shape_, 706 data_init, &builder); 707 } 708 709 HloInstruction* BuildWhileInstruction_InitPointsToInterfering() { 710 auto builder = HloComputation::Builder(TestName() + ".While"); 711 auto one = builder.AddInstruction( 712 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 713 auto data_init = builder.AddInstruction( 714 HloInstruction::CreateBroadcast(data_shape_, one, {1})); 715 auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant( 716 Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); 717 // Take a reference to 'data_init' to make it interfere with while result. 718 auto add = builder.AddInstruction(HloInstruction::CreateBinary( 719 data_shape_, HloOpcode::kAdd, data_init, one_vec)); 720 721 auto xla_while = BuildWhileInstructionWithCustomInit(loop_state_shape_, 722 data_init, &builder); 723 724 // Add an additional binary operation operating on the while and the 725 // interfering add so that neither operation is dead. 726 auto gte = xla_while->parent()->AddInstruction( 727 HloInstruction::CreateGetTupleElement( 728 ShapeUtil::GetSubshape(xla_while->shape(), {1}), xla_while, 1)); 729 auto sub = xla_while->parent()->AddInstruction(HloInstruction::CreateBinary( 730 data_shape_, HloOpcode::kSubtract, add, gte)); 731 auto gte0 = xla_while->parent()->AddInstruction( 732 HloInstruction::CreateGetTupleElement( 733 ShapeUtil::GetSubshape(xla_while->shape(), {0}), xla_while, 0)); 734 auto tuple = xla_while->parent()->AddInstruction( 735 HloInstruction::CreateTuple({gte0, sub})); 736 737 xla_while->parent()->set_root_instruction(tuple); 738 739 return xla_while; 740 } 741 742 HloInstruction* BuildWhileInstructionWithCustomInit( 743 const Shape& loop_state_shape, HloInstruction* data_init, 744 HloComputation::Builder* builder) { 745 const bool nested = 746 ShapeUtil::Equal(loop_state_shape, nested_loop_state_shape_); 747 auto induction_var_init = builder->AddInstruction( 748 HloInstruction::CreateConstant(Literal::CreateR0<int32>(0))); 749 auto condition = module_->AddEmbeddedComputation( 750 BuildConditionComputation(loop_state_shape)); 751 auto body = module_->AddEmbeddedComputation( 752 BuildIndependentBodyComputation(nested)); 753 auto loop_state_init = builder->AddInstruction( 754 HloInstruction::CreateTuple({induction_var_init, data_init})); 755 auto while_hlo = builder->AddInstruction(HloInstruction::CreateWhile( 756 loop_state_shape, condition, body, loop_state_init)); 757 module_->AddEntryComputation(builder->Build()); 758 return while_hlo; 759 } 760 761 std::unique_ptr<HloModule> module_; 762 Shape induction_variable_shape_ = ShapeUtil::MakeShape(S32, {}); 763 Shape data_shape_ = ShapeUtil::MakeShape(F32, {8}); 764 Shape loop_state_shape_ = 765 ShapeUtil::MakeTupleShape({induction_variable_shape_, data_shape_}); 766 Shape nested_tuple_shape_ = 767 ShapeUtil::MakeTupleShape({data_shape_, data_shape_}); 768 Shape nested_loop_state_shape_ = ShapeUtil::MakeTupleShape( 769 {induction_variable_shape_, nested_tuple_shape_}); 770 Shape condition_result_shape_ = ShapeUtil::MakeShape(PRED, {}); 771 }; 772 773 // Tests while body computation with independent tuple elements: 774 // 775 // While.Body({in0, in1}) 776 // out0 = Add(in0, 1) 777 // out1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1}) 778 // Tuple(out0, out1) 779 // 780 // CopyInsertion pass should not generate any copies. 781 // 782 TEST_F(WhileCopyInsertionTest, IndependentTupleElements) { 783 auto condition = module_->AddEmbeddedComputation( 784 BuildConditionComputation(loop_state_shape_)); 785 auto body = 786 module_->AddEmbeddedComputation(BuildIndependentBodyComputation()); 787 auto while_hlo = BuildWhileInstruction(condition, body); 788 789 InsertCopies(module_.get()); 790 791 // Body should have no copies as the adds can be done inplace. 792 EXPECT_EQ(CountCopies(*body), 0); 793 EXPECT_EQ(CountControlEdges(*module_), 0); 794 795 // Both init indices need copies as they are constants. 796 EXPECT_THAT(while_hlo->operand(0), 797 op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant()))); 798 } 799 800 // Tests while body computation with dependent tuple elements: 801 // 802 // While.Body({in0, in1}) 803 // out0 = Add(in0, 1) 804 // out1 = Add(BCast(in0), in1) 805 // Tuple(out0, out1) 806 // 807 // CopyInsertion pass should convert the root instruction to: 808 // 809 // Tuple(Copy(out0), out1) 810 // 811 TEST_F(WhileCopyInsertionTest, DependentTupleElements) { 812 auto condition = module_->AddEmbeddedComputation( 813 BuildConditionComputation(loop_state_shape_)); 814 auto body = module_->AddEmbeddedComputation(BuildDependentBodyComputation()); 815 auto while_hlo = BuildWhileInstruction(condition, body); 816 817 InsertCopies(module_.get()); 818 819 EXPECT_EQ(CountCopies(*body), 1); 820 EXPECT_EQ(CountControlEdges(*body), 0); 821 822 EXPECT_THAT( 823 body->root_instruction(), 824 op::Tuple(op::Add(), op::Add(op::GetTupleElement(), op::Broadcast()))); 825 826 auto add = body->root_instruction()->operand(0); 827 auto bcast = body->root_instruction()->operand(1)->operand(1); 828 ASSERT_EQ(add->opcode(), HloOpcode::kAdd); 829 ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast); 830 831 EXPECT_THAT( 832 while_hlo->while_body()->root_instruction(), 833 op::Tuple(op::Add(op::Copy(), op::Constant()), 834 op::Add(op::GetTupleElement(), op::Broadcast(op::Copy())))); 835 836 // Both init indices need copies as they are constants. 837 EXPECT_THAT(while_hlo->operand(0), 838 op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant()))); 839 } 840 841 // Tests while body computation with read-only tuple element 0: 842 // 843 // PARAMETER 844 // / \ 845 // GTE(0) GTE(1) 846 // | \ | 847 // | BCAST | 848 // | \ | 849 // | ADD 850 // | | 851 // \ / 852 // TUPLE (root) 853 // 854 // CopyInsertion pass should not generate any copies for the while body. 855 TEST_F(WhileCopyInsertionTest, DependentTupleElements_OneReadOnly) { 856 auto condition = module_->AddEmbeddedComputation( 857 BuildConditionComputation(loop_state_shape_)); 858 auto body = module_->AddEmbeddedComputation( 859 BuildDependentBodyOneReadOnlyComputation()); 860 BuildWhileInstruction(condition, body); 861 862 InsertCopies(module_.get()); 863 864 // No copies or control edges should be inserted. The body is legal as is. 865 EXPECT_EQ(CountCopies(*body), 0); 866 EXPECT_EQ(CountControlEdges(*body), 0); 867 } 868 869 // Same as above, but with two while loops, sharing entry parameters. 870 TEST_F(WhileCopyInsertionTest, 871 DependentTupleElements_OneReadOnly_TwoLoops_EntryParams) { 872 auto condition1 = module_->AddEmbeddedComputation( 873 BuildConditionComputation(loop_state_shape_)); 874 auto condition2 = module_->AddEmbeddedComputation( 875 BuildConditionComputation(loop_state_shape_)); 876 auto body1 = module_->AddEmbeddedComputation( 877 BuildDependentBodyOneReadOnlyComputation()); 878 auto body2 = module_->AddEmbeddedComputation( 879 BuildDependentBodyOneReadOnlyComputation()); 880 881 auto builder = HloComputation::Builder(TestName() + ".While"); 882 auto iter_param = builder.AddInstruction( 883 HloInstruction::CreateParameter(0, induction_variable_shape_, "iter")); 884 auto data_param = builder.AddInstruction( 885 HloInstruction::CreateParameter(1, data_shape_, "data")); 886 auto loop_init = builder.AddInstruction( 887 HloInstruction::CreateTuple({iter_param, data_param})); 888 889 auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile( 890 loop_state_shape_, condition1, body1, loop_init)); 891 auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile( 892 loop_state_shape_, condition2, body2, loop_init)); 893 894 // Add a couple elements from each of the while so both whiles are live. 895 auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( 896 ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0)); 897 auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( 898 ShapeUtil::GetSubshape(while_hlo2->shape(), {0}), while_hlo2, 0)); 899 builder.AddInstruction( 900 HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2)); 901 902 auto entry = module_->AddEntryComputation(builder.Build()); 903 904 InsertCopies(module_.get()); 905 906 // Neither body should have any copies or control edges in them. 907 EXPECT_EQ(CountCopies(*body1), 0); 908 EXPECT_EQ(CountCopies(*body2), 0); 909 EXPECT_EQ(CountControlEdges(*body1), 0); 910 EXPECT_EQ(CountControlEdges(*body2), 0); 911 912 // Only two copies should be necessary. Each of the whiles should have 913 // a copy of tuple element 1 (init value is a parameter, and the element is 914 // not non-read-only) so each of the while bodies gets its own buffer to write 915 // element 1 into. 916 EXPECT_EQ(CountCopies(*entry), 2); 917 918 EXPECT_EQ(while_hlo1->operand(0)->operand(1)->opcode(), HloOpcode::kCopy); 919 EXPECT_EQ(while_hlo2->operand(0)->operand(1)->opcode(), HloOpcode::kCopy); 920 921 // The two copies of element 1 should be different. 922 EXPECT_NE(while_hlo1->operand(0)->operand(1), 923 while_hlo2->operand(0)->operand(1)); 924 } 925 926 // Same as above, but with two while loops, sharing non-parameters. 927 TEST_F(WhileCopyInsertionTest, 928 DependentTupleElements_OneReadOnly_TwoLoops_NonParams) { 929 auto condition1 = module_->AddEmbeddedComputation( 930 BuildConditionComputation(loop_state_shape_)); 931 auto condition2 = module_->AddEmbeddedComputation( 932 BuildConditionComputation(loop_state_shape_)); 933 auto body1 = module_->AddEmbeddedComputation( 934 BuildDependentBodyOneReadOnlyComputation()); 935 auto body2 = module_->AddEmbeddedComputation( 936 BuildDependentBodyOneReadOnlyComputation()); 937 938 auto builder = HloComputation::Builder(TestName() + ".While"); 939 auto iter_param = builder.AddInstruction( 940 HloInstruction::CreateParameter(0, induction_variable_shape_, "iter")); 941 auto data_param = builder.AddInstruction( 942 HloInstruction::CreateParameter(1, data_shape_, "data")); 943 // Add dummy ops to ensure loop_init elements aren't entry parameters. 944 auto iter_value = builder.AddInstruction(HloInstruction::CreateUnary( 945 iter_param->shape(), HloOpcode::kExp, iter_param)); 946 auto data_value = builder.AddInstruction(HloInstruction::CreateUnary( 947 data_param->shape(), HloOpcode::kExp, data_param)); 948 auto loop_init = builder.AddInstruction( 949 HloInstruction::CreateTuple({iter_value, data_value})); 950 951 auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile( 952 loop_state_shape_, condition1, body1, loop_init)); 953 auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile( 954 loop_state_shape_, condition2, body2, loop_init)); 955 956 // Add a couple elements from each of the while so both whiles are not dead. 957 auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( 958 ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0)); 959 auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( 960 ShapeUtil::GetSubshape(while_hlo2->shape(), {0}), while_hlo2, 0)); 961 builder.AddInstruction( 962 HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2)); 963 auto entry = module_->AddEntryComputation(builder.Build()); 964 965 InsertCopies(module_.get()); 966 967 // Ideally only one copy should be necessary. One of the whiles should 968 // have a copy of tuple element 1 (the non-read-only element) so each of the 969 // while bodies gets its own buffer to write element 1 into. However, the 970 // analysis isn't perfect and adds an additional copy of element 0. 971 EXPECT_EQ(CountCopies(*entry), 2); 972 973 EXPECT_THAT(while_hlo1->operand(0), 974 op::Tuple(op::Exp(), op::Copy(op::Exp()))); 975 EXPECT_THAT(while_hlo2->operand(0), 976 op::Tuple(op::Exp(), op::Copy(op::Exp()))); 977 } 978 979 // Tests while body computation with nested tuple elements: 980 // 981 // | 982 // GTE(loop_state, 1) 983 // / \ 984 // GTE(GTE(loop_state, 1), 0) GTE(GTE(loop_state, 1), 1) 985 // | | 986 // Add Reverse 987 // | | 988 // 989 // CopyInsertion pass will conceptually generate the following, but with the 990 // actual GTE and Tuple instructions optimized away: 991 // 992 // Tuple // old root 993 // / \ 994 // / \ 995 // GTE(0) GTE(1) 996 // | / \ 997 // | / \ 998 // | GTE(0) GTE(1) 999 // | | | 1000 // | | Copy 1001 // | | | 1002 // \ | / 1003 // \ Tuple // "inner" tuple. 1004 // \ / 1005 // \ / 1006 // Tuple // new root 1007 // 1008 TEST_F(WhileCopyInsertionTest, NestedTupleElements) { 1009 auto condition = module_->AddEmbeddedComputation( 1010 BuildConditionComputation(nested_loop_state_shape_)); 1011 auto body = module_->AddEmbeddedComputation(BuildNestedBodyComputation()); 1012 BuildWhileInstruction(condition, body, true); 1013 1014 // HloInstruction* old_root = body->root_instruction(); 1015 InsertCopies(module_.get()); 1016 1017 // The only copy necessary is for the kReverse as it cannot be done 1018 // in-place (instruction can share buffer with operand). The other elements of 1019 // the loop state are kAdd instructions which can be done in-place. 1020 EXPECT_EQ(CountCopies(*body), 1); 1021 1022 // Each element of the init needs a copy as all are constants. 1023 EXPECT_EQ(CountCopies(*module_), 4); 1024 1025 // Either the kReverse itself must be copied or the operand of the kReverse 1026 // must be copied. 1027 if (body->root_instruction()->operand(1)->operand(1)->opcode() == 1028 HloOpcode::kCopy) { 1029 EXPECT_THAT( 1030 body->root_instruction(), 1031 op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Reverse())))); 1032 } else { 1033 EXPECT_THAT( 1034 body->root_instruction(), 1035 op::Tuple(op::Add(), op::Tuple(op::Add(), op::Reverse(op::Copy())))); 1036 } 1037 } 1038 1039 // Tests while init instruction which points-to a constant. 1040 // 1041 // init = Tuple(Constant(S32, {}), Constant(F32, {8})) 1042 // 1043 // CopyInsertion pass should add copies for both constants. 1044 // 1045 TEST_F(WhileCopyInsertionTest, InitPointsToConstant) { 1046 auto while_hlo = BuildWhileInstruction_InitPointsToConstant(); 1047 1048 InsertCopies(module_.get()); 1049 EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0); 1050 EXPECT_EQ(CountCopies(*module_), 2); 1051 1052 EXPECT_THAT(while_hlo->operand(0), 1053 op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant()))); 1054 } 1055 1056 // Tests while init instruction which points-to a parameter. 1057 // 1058 // init = Tuple(Constant(S32, {}), Parameter(F32, {8})) 1059 // 1060 // CopyInsertion pass should add copies for both the constant and parameter. 1061 // 1062 TEST_F(WhileCopyInsertionTest, InitPointsToParameter) { 1063 auto while_hlo = BuildWhileInstruction_InitPointsToParameter(); 1064 1065 InsertCopies(module_.get()); 1066 EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0); 1067 EXPECT_EQ(CountCopies(*module_), 2); 1068 1069 EXPECT_THAT(while_hlo->operand(0), 1070 op::Tuple(op::Copy(op::Constant()), op::Copy(op::Parameter()))); 1071 } 1072 1073 // Tests while init instruction which has an ambiguous points-to set. 1074 // 1075 // select = Select(pred, tuple1, tuple2) 1076 // init = Tuple(Constant(S32, {}), Parameter(F32, {8})) 1077 // 1078 // CopyInsertion pass will conceptually generate the following, but with some of 1079 // the actual GTE and Tuple instructions optimized away: 1080 // 1081 // Tuple // old init 1082 // / \ 1083 // / \ 1084 // GTE(0) GTE(1) 1085 // | / \ 1086 // | / \ 1087 // | GTE(0) GTE(1) 1088 // | | | 1089 // Copy Copy Copy 1090 // | | | 1091 // \ | / 1092 // \ Tuple 1093 // \ / 1094 // \ / 1095 // Tuple // new init 1096 // 1097 TEST_F(WhileCopyInsertionTest, InitPointsToAmbiguous) { 1098 auto while_hlo = BuildWhileInstruction_InitPointsToAmbiguous(); 1099 1100 InsertCopies(module_.get()); 1101 EXPECT_EQ(CountCopies(*module_), 4); 1102 // The entry computation requires three copies to resolve the ambiguity of two 1103 // init elements and the constant passed in as one of the init elements. 1104 EXPECT_EQ(CountCopies(*module_->entry_computation()), 3); 1105 EXPECT_THAT(while_hlo->operand(0), 1106 op::Tuple(op::Copy(op::Constant()), 1107 op::Tuple(op::Copy(op::GetTupleElement()), 1108 op::Copy(op::GetTupleElement())))); 1109 1110 // The body requires one copy because the buffer set is not distinct: the 1111 // result of one of the adds is written into two elements of the output of the 1112 // loop body. Either element might be copied. 1113 EXPECT_EQ(CountCopies(*while_hlo->while_body()), 1); 1114 if (while_hlo->while_body() 1115 ->root_instruction() 1116 ->operand(1) 1117 ->operand(0) 1118 ->opcode() == HloOpcode::kCopy) { 1119 EXPECT_THAT( 1120 while_hlo->while_body()->root_instruction(), 1121 op::Tuple(op::Add(), op::Tuple(op::Copy(op::Add()), op::Add()))); 1122 } else { 1123 EXPECT_THAT( 1124 while_hlo->while_body()->root_instruction(), 1125 op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Add())))); 1126 } 1127 } 1128 1129 // Tests while init instruction which has a non-distinct points-to set. 1130 // 1131 // init = Tuple(Constant(S32, {}), Tuple({vec_one, vec_one})) 1132 // 1133 // CopyInsertion pass will conceptually generate the following, but with some of 1134 // the actual GTE and Tuple instructions optimized away: 1135 // 1136 // Tuple // old init 1137 // / \ 1138 // / \ 1139 // GTE(0) GTE(1) 1140 // | / \ 1141 // | / \ 1142 // | GTE(0) GTE(1) 1143 // | | | 1144 // Copy Copy Copy 1145 // | | | 1146 // \ | / 1147 // \ Tuple 1148 // \ / 1149 // \ / 1150 // Tuple // new init 1151 // 1152 TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) { 1153 auto while_hlo = BuildWhileInstruction_InitPointsToNonDistinct(); 1154 1155 InsertCopies(module_.get()); 1156 1157 // The entry computation requires two copies to resolve the non-disinctness of 1158 // two init elements and the constant passed in as one of the init 1159 // elements. Either element can be copied for the distinctness issue. 1160 EXPECT_EQ(CountCopies(*module_->entry_computation()), 2); 1161 if (while_hlo->operand(0)->operand(1)->operand(0)->opcode() == 1162 HloOpcode::kCopy) { 1163 EXPECT_THAT( 1164 while_hlo->operand(0), 1165 op::Tuple(op::Copy(op::Constant()), 1166 op::Tuple(op::Copy(op::Broadcast()), op::Broadcast()))); 1167 } else { 1168 EXPECT_THAT( 1169 while_hlo->operand(0), 1170 op::Tuple(op::Copy(op::Constant()), 1171 op::Tuple(op::Broadcast(), op::Copy(op::Broadcast())))); 1172 } 1173 1174 // The body requires one copy because the buffer set is not distinct: the 1175 // result of one of the adds is written into two elements of the output of the 1176 // loop body. Either element might be copied. 1177 EXPECT_EQ(CountCopies(*while_hlo->while_body()), 1); 1178 if (while_hlo->while_body() 1179 ->root_instruction() 1180 ->operand(1) 1181 ->operand(0) 1182 ->opcode() == HloOpcode::kCopy) { 1183 EXPECT_THAT( 1184 while_hlo->while_body()->root_instruction(), 1185 op::Tuple(op::Add(), op::Tuple(op::Copy(op::Add()), op::Add()))); 1186 } else { 1187 EXPECT_THAT( 1188 while_hlo->while_body()->root_instruction(), 1189 op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Add())))); 1190 } 1191 } 1192 1193 // Tests while init instruction buffer which interferes with while result 1194 // buffer. 1195 // 1196 // init_data = Broadcast(...) 1197 // add_unrelated = Add(init_data) // takes a reference to cause interference 1198 // init = Tuple(Constant(S32, {}), init_data)) 1199 // 1200 // CopyInsertion pass should copy both operands. 1201 // 1202 TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) { 1203 auto while_hlo = BuildWhileInstruction_InitPointsToInterfering(); 1204 1205 InsertCopies(module_.get()); 1206 EXPECT_EQ(CountCopies(*module_), 2); 1207 EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0); 1208 1209 EXPECT_THAT(while_hlo->operand(0), 1210 op::Tuple(op::Copy(op::Constant()), op::Copy(op::Broadcast()))); 1211 } 1212 1213 // Tests while init instruction buffer which has a non-distinct points-to set: 1214 // 1215 // init = Tuple(Parameter(S32, {}), Parameter(F32, {8}, 1216 // Parameter(F32, {8}))) 1217 // 1218 // where the second and third parameters are identical *and* the tuple shared 1219 // by another while instruction. 1220 // 1221 // Verifies that the resulting point-to set is distinct in the resulting Tuple 1222 // (non-identical Copys). In other words, verifies that copy sharing does not 1223 // insert identical copies to the resulting tuple. 1224 TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) { 1225 // Loop body that outputs tuple comprises two elements dependent on the init 1226 // tuple. 1227 const Shape& loop_state_shape = ShapeUtil::MakeTupleShape( 1228 {induction_variable_shape_, data_shape_, data_shape_}); 1229 1230 auto condition1 = module_->AddEmbeddedComputation( 1231 BuildConditionComputation(loop_state_shape)); 1232 auto condition2 = module_->AddEmbeddedComputation( 1233 BuildConditionComputation(loop_state_shape)); 1234 auto body1 = 1235 module_->AddEmbeddedComputation(BuildDependentBodyComputation2()); 1236 auto body2 = 1237 module_->AddEmbeddedComputation(BuildDependentBodyComputation2()); 1238 1239 auto builder = HloComputation::Builder(TestName() + ".While"); 1240 1241 auto iter_param = builder.AddInstruction( 1242 HloInstruction::CreateParameter(0, induction_variable_shape_, "iter")); 1243 auto data_param = builder.AddInstruction( 1244 HloInstruction::CreateParameter(1, data_shape_, "data")); 1245 1246 // Loop init tuple contains two identical parameter buffers. 1247 auto loop_init = builder.AddInstruction( 1248 HloInstruction::CreateTuple({iter_param, data_param, data_param})); 1249 1250 1251 // Two while loops shares the same loop init tuple. 1252 auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile( 1253 loop_state_shape, condition1, body1, loop_init)); 1254 auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile( 1255 loop_state_shape, condition2, body2, loop_init)); 1256 1257 // Add add instruction so neither while is dead. 1258 auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( 1259 ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0)); 1260 auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( 1261 ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo2, 0)); 1262 builder.AddInstruction( 1263 HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2)); 1264 1265 module_->AddEntryComputation(builder.Build()); 1266 1267 InsertCopies(module_.get()); 1268 1269 // None of the bodies should have copies or control flow edges. 1270 EXPECT_EQ(CountCopies(*body1), 0); 1271 EXPECT_EQ(CountCopies(*body2), 0); 1272 1273 // The loop bodies pass through elements 1 and 2 in the init tuple, so ideally 1274 // these should not need to be copied before either while. However, copy 1275 // insertion is not able to reason about the transparency of elements through 1276 // while bodies in all circumstances so extra copies are added (b/xxx). 1277 EXPECT_EQ(CountCopies(*module_->entry_computation()), 2); 1278 1279 EXPECT_THAT(while_hlo1->operand(0), 1280 op::Tuple(op::Copy(), op::Parameter(), op::Parameter())); 1281 EXPECT_THAT(while_hlo2->operand(0), 1282 op::Tuple(op::Copy(), op::Parameter(), op::Parameter())); 1283 } 1284 1285 TEST_F(CopyInsertionTest, SwizzlingWhile) { 1286 // Test a while instruction with a body which permutes its tuple parameter 1287 // elements. 1288 auto module = CreateNewModule(); 1289 const Shape loop_state_shape = 1290 ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); 1291 1292 // Body simply interchanges the two tuple elements in the loop state. 1293 auto body_builder = HloComputation::Builder("body"); 1294 auto body_param = body_builder.AddInstruction( 1295 HloInstruction::CreateParameter(0, loop_state_shape, "param")); 1296 auto body_element_0 = body_builder.AddInstruction( 1297 HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); 1298 auto body_element_1 = body_builder.AddInstruction( 1299 HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); 1300 body_builder.AddInstruction( 1301 HloInstruction::CreateTuple({body_element_1, body_element_0})); 1302 HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); 1303 1304 auto cond_builder = HloComputation::Builder("condition"); 1305 cond_builder.AddInstruction( 1306 HloInstruction::CreateParameter(0, loop_state_shape, "param")); 1307 auto cond_constant = cond_builder.AddInstruction( 1308 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 1309 cond_builder.AddInstruction(HloInstruction::CreateUnary( 1310 cond_constant->shape(), HloOpcode::kNot, cond_constant)); 1311 HloComputation* condition = 1312 module->AddEmbeddedComputation(cond_builder.Build()); 1313 1314 auto builder = HloComputation::Builder(TestName()); 1315 auto constant1 = builder.AddInstruction( 1316 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 1317 auto constant2 = builder.AddInstruction( 1318 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 1319 auto tuple = builder.AddInstruction( 1320 HloInstruction::CreateTuple({constant1, constant2})); 1321 auto xla_while = builder.AddInstruction( 1322 HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple)); 1323 module->AddEntryComputation(builder.Build()); 1324 1325 InsertCopies(module.get()); 1326 1327 EXPECT_EQ(CountCopies(*module), 6); 1328 1329 // The loop state elements should be copied at the parameter and at the root 1330 // with a control edge in between (see DeepCopyAndAddControlEdges). This is 1331 // technically one more copy than is strictly necessary, but in order to have 1332 // only three copies the copies of different loop state elements must be 1333 // ordered with a control edge. 1334 EXPECT_EQ(CountCopies(*body), 4); 1335 EXPECT_EQ(CountControlEdges(*body), 2); 1336 1337 EXPECT_THAT(body->root_instruction(), 1338 op::Tuple(op::Copy(op::Copy()), op::Copy(op::Copy()))); 1339 1340 EXPECT_EQ(CountCopies(*module->entry_computation()), 2); 1341 EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy())); 1342 } 1343 1344 TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) { 1345 // Test a while instruction with a body which permutes its tuple parameter 1346 // elements and applies one operation to one of the elements. The addition of 1347 // the operation (instruction) on the element makes the live range of the 1348 // respective input and output elements different than if the instruction were 1349 // not there (as in the SwizzlingWhile test above). 1350 auto module = CreateNewModule(); 1351 const Shape loop_state_shape = 1352 ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); 1353 1354 // Body interchanges the two tuple elements in the loop state and negates one 1355 // of them. 1356 auto body_builder = HloComputation::Builder("body"); 1357 auto body_param = body_builder.AddInstruction( 1358 HloInstruction::CreateParameter(0, loop_state_shape, "param")); 1359 auto body_element_0 = body_builder.AddInstruction( 1360 HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); 1361 auto body_element_1 = body_builder.AddInstruction( 1362 HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); 1363 auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary( 1364 scalar_shape_, HloOpcode::kNegate, body_element_1)); 1365 body_builder.AddInstruction( 1366 HloInstruction::CreateTuple({negate, body_element_0})); 1367 HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); 1368 1369 auto cond_builder = HloComputation::Builder("condition"); 1370 cond_builder.AddInstruction( 1371 HloInstruction::CreateParameter(0, loop_state_shape, "param")); 1372 auto cond_constant = cond_builder.AddInstruction( 1373 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 1374 cond_builder.AddInstruction(HloInstruction::CreateUnary( 1375 cond_constant->shape(), HloOpcode::kNot, cond_constant)); 1376 HloComputation* condition = 1377 module->AddEmbeddedComputation(cond_builder.Build()); 1378 1379 auto builder = HloComputation::Builder(TestName()); 1380 auto constant1 = builder.AddInstruction( 1381 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 1382 auto constant2 = builder.AddInstruction( 1383 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 1384 auto tuple = builder.AddInstruction( 1385 HloInstruction::CreateTuple({constant1, constant2})); 1386 auto xla_while = builder.AddInstruction( 1387 HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple)); 1388 module->AddEntryComputation(builder.Build()); 1389 1390 InsertCopies(module.get()); 1391 1392 EXPECT_EQ(CountCopies(*module), 6); 1393 1394 // The loop state elements should be copied at the parameter and at the root 1395 // with a control edge in between (see DeepCopyAndAddControlEdges). 1396 EXPECT_EQ(CountCopies(*body), 4); 1397 EXPECT_EQ(CountControlEdges(*body), 2); 1398 1399 EXPECT_THAT( 1400 body->root_instruction(), 1401 op::Tuple(op::Copy(op::Negate(op::Copy())), op::Copy(op::Copy()))); 1402 1403 EXPECT_EQ(CountCopies(*module->entry_computation()), 2); 1404 EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy())); 1405 } 1406 1407 TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) { 1408 // Test a while instruction with a body which permutes it's tuple parameter 1409 // elements similar to SwizzlinWhile above. However, in this test the input to 1410 // the while body is a single constant (both loop state elements are the same 1411 // constant). This means no copies are necessary because both loop state 1412 // elements are the same so interchanging them is a no-op. 1413 auto module = CreateNewModule(); 1414 const Shape loop_state_shape = 1415 ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); 1416 1417 // Body simply interchanges the two tuple elements in the loop state. 1418 auto body_builder = HloComputation::Builder("body"); 1419 auto body_param = body_builder.AddInstruction( 1420 HloInstruction::CreateParameter(0, loop_state_shape, "param")); 1421 auto body_element_0 = body_builder.AddInstruction( 1422 HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); 1423 auto body_element_1 = body_builder.AddInstruction( 1424 HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); 1425 body_builder.AddInstruction( 1426 HloInstruction::CreateTuple({body_element_1, body_element_0})); 1427 HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); 1428 1429 auto cond_builder = HloComputation::Builder("condition"); 1430 cond_builder.AddInstruction( 1431 HloInstruction::CreateParameter(0, loop_state_shape, "param")); 1432 auto cond_constant = cond_builder.AddInstruction( 1433 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 1434 cond_builder.AddInstruction(HloInstruction::CreateUnary( 1435 cond_constant->shape(), HloOpcode::kNot, cond_constant)); 1436 HloComputation* condition = 1437 module->AddEmbeddedComputation(cond_builder.Build()); 1438 1439 auto builder = HloComputation::Builder(TestName()); 1440 auto constant = builder.AddInstruction( 1441 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 1442 auto tuple = 1443 builder.AddInstruction(HloInstruction::CreateTuple({constant, constant})); 1444 builder.AddInstruction( 1445 HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple)); 1446 module->AddEntryComputation(builder.Build()); 1447 1448 InsertCopies(module.get()); 1449 1450 EXPECT_EQ(CountCopies(*module), 2); 1451 EXPECT_EQ(CountCopies(*body), 0); 1452 1453 EXPECT_EQ(CountCopies(*module->entry_computation()), 2); 1454 EXPECT_THAT(module->entry_computation()->root_instruction(), 1455 op::Tuple(op::Copy(), op::Copy())); 1456 } 1457 1458 TEST_F(CopyInsertionTest, SequentialWhiles) { 1459 // Construct a computation with a series of sequential while instructions 1460 // containing four loop state elements: 1461 // 1462 // element 0 is passed to each while directly from an entry parameter. 1463 // 1464 // element 1 is passed transparently in series through all the while bodies. 1465 // 1466 // element 2 is negated in each while body. (in-place possible) 1467 // 1468 // element 3 is reversed in each while body. (in-place not possible) 1469 // 1470 const Shape element_shape = ShapeUtil::MakeShape(F32, {42}); 1471 const Shape loop_state_shape = ShapeUtil::MakeTupleShape( 1472 {element_shape, element_shape, element_shape, element_shape}); 1473 1474 auto module = CreateNewModule(); 1475 auto builder = HloComputation::Builder(TestName()); 1476 auto param_0 = builder.AddInstruction( 1477 HloInstruction::CreateParameter(0, element_shape, "param_0")); 1478 auto param_1 = builder.AddInstruction( 1479 HloInstruction::CreateParameter(1, element_shape, "param_1")); 1480 auto param_2 = builder.AddInstruction( 1481 HloInstruction::CreateParameter(2, element_shape, "param_2")); 1482 auto param_3 = builder.AddInstruction( 1483 HloInstruction::CreateParameter(3, element_shape, "param_3")); 1484 1485 // The number of sequential kWhile instructions. 1486 const int kNumWhiles = 3; 1487 1488 HloInstruction* prev_element_1 = param_1; 1489 HloInstruction* prev_element_2 = param_2; 1490 HloInstruction* prev_element_3 = param_3; 1491 1492 // Vector containing all of the while instructions. 1493 std::vector<const HloInstruction*> whiles; 1494 for (int i = 0; i < kNumWhiles; ++i) { 1495 auto body_builder = HloComputation::Builder("body"); 1496 auto body_param = body_builder.AddInstruction( 1497 HloInstruction::CreateParameter(0, loop_state_shape, "param")); 1498 auto body_element_0 = body_builder.AddInstruction( 1499 HloInstruction::CreateGetTupleElement(element_shape, body_param, 0)); 1500 auto body_element_1 = body_builder.AddInstruction( 1501 HloInstruction::CreateGetTupleElement(element_shape, body_param, 1)); 1502 auto body_element_2 = body_builder.AddInstruction( 1503 HloInstruction::CreateGetTupleElement(element_shape, body_param, 2)); 1504 auto body_element_3 = body_builder.AddInstruction( 1505 HloInstruction::CreateGetTupleElement(element_shape, body_param, 3)); 1506 auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary( 1507 element_shape, HloOpcode::kNegate, body_element_2)); 1508 auto reverse = body_builder.AddInstruction( 1509 HloInstruction::CreateReverse(element_shape, body_element_3, {0})); 1510 body_builder.AddInstruction(HloInstruction::CreateTuple( 1511 {body_element_0, body_element_1, negate, reverse})); 1512 HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); 1513 1514 auto cond_builder = HloComputation::Builder("condition"); 1515 cond_builder.AddInstruction( 1516 HloInstruction::CreateParameter(0, loop_state_shape, "param")); 1517 auto cond_constant = cond_builder.AddInstruction( 1518 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 1519 cond_builder.AddInstruction(HloInstruction::CreateUnary( 1520 cond_constant->shape(), HloOpcode::kNot, cond_constant)); 1521 HloComputation* condition = 1522 module->AddEmbeddedComputation(cond_builder.Build()); 1523 1524 auto while_init = builder.AddInstruction(HloInstruction::CreateTuple( 1525 {param_0, prev_element_1, prev_element_2, prev_element_3})); 1526 1527 auto xla_while = builder.AddInstruction(HloInstruction::CreateWhile( 1528 loop_state_shape, condition, body, while_init)); 1529 whiles.push_back(xla_while); 1530 if (i != kNumWhiles - 1) { 1531 prev_element_1 = builder.AddInstruction( 1532 HloInstruction::CreateGetTupleElement(element_shape, xla_while, 1)); 1533 prev_element_2 = builder.AddInstruction( 1534 HloInstruction::CreateGetTupleElement(element_shape, xla_while, 2)); 1535 prev_element_3 = builder.AddInstruction( 1536 HloInstruction::CreateGetTupleElement(element_shape, xla_while, 3)); 1537 } 1538 } 1539 1540 module->AddEntryComputation(builder.Build()); 1541 1542 InsertCopies(module.get()); 1543 1544 // Each while body has one copy. And each loop state element is copied once in 1545 // the entry computation. 1546 EXPECT_EQ(CountCopies(*module), 4 + kNumWhiles); 1547 1548 // Each while body should have exactly one copy for element three which is an 1549 // op (kReverse) which cannot be done in place. 1550 for (const HloInstruction* xla_while : whiles) { 1551 EXPECT_EQ(CountCopies(*xla_while->while_body()), 1); 1552 } 1553 1554 EXPECT_THAT(whiles[0]->operand(0), op::Tuple(op::Parameter(), op::Parameter(), 1555 op::Copy(), op::Copy())); 1556 EXPECT_THAT(module->entry_computation()->root_instruction(), 1557 op::Tuple(op::Copy(), op::Copy(), op::GetTupleElement(), 1558 op::GetTupleElement())); 1559 } 1560 1561 TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) { 1562 // Test a while body and condition which are each simply a constant (root of 1563 // computation is a constant). The body constant should be copied. 1564 auto module = CreateNewModule(); 1565 auto builder = HloComputation::Builder(TestName()); 1566 auto param_0 = builder.AddInstruction( 1567 HloInstruction::CreateParameter(0, scalar_shape_, "param_0")); 1568 1569 auto body_builder = HloComputation::Builder("body"); 1570 body_builder.AddInstruction( 1571 HloInstruction::CreateParameter(0, scalar_shape_, "param")); 1572 body_builder.AddInstruction( 1573 HloInstruction::CreateConstant(Literal::CreateR0<float>(123.0))); 1574 HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); 1575 1576 auto cond_builder = HloComputation::Builder("condition"); 1577 cond_builder.AddInstruction( 1578 HloInstruction::CreateParameter(0, scalar_shape_, "param")); 1579 cond_builder.AddInstruction( 1580 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 1581 HloComputation* condition = 1582 module->AddEmbeddedComputation(cond_builder.Build()); 1583 1584 auto xla_while = builder.AddInstruction( 1585 HloInstruction::CreateWhile(scalar_shape_, condition, body, param_0)); 1586 1587 module->AddEntryComputation(builder.Build()); 1588 1589 InsertCopies(module.get()); 1590 1591 EXPECT_EQ(CountCopies(*module), 2); 1592 1593 EXPECT_THAT(xla_while->operand(0), op::Copy(op::Parameter())); 1594 EXPECT_THAT(body->root_instruction(), op::Copy(op::Constant())); 1595 EXPECT_THAT(condition->root_instruction(), op::Constant()); 1596 } 1597 1598 std::unique_ptr<HloComputation> MakeTrivialCondition(const Shape& shape) { 1599 auto builder = HloComputation::Builder("trivial_condition"); 1600 builder.AddInstruction( 1601 HloInstruction::CreateParameter(0, shape, "loop_state")); 1602 auto constant = builder.AddInstruction( 1603 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 1604 builder.AddInstruction(HloInstruction::CreateUnary( 1605 constant->shape(), HloOpcode::kNot, constant)); 1606 return builder.Build(); 1607 } 1608 1609 std::unique_ptr<HloComputation> MakeBenchmarkWhileBody() { 1610 auto builder = HloComputation::Builder("benchmark_loop_body"); 1611 const Shape element_shape = ShapeUtil::MakeShape(F32, {42}); 1612 const Shape loop_state_shape = 1613 ShapeUtil::MakeTupleShape({element_shape, element_shape, element_shape}); 1614 HloInstruction* param = builder.AddInstruction( 1615 HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); 1616 HloInstruction* element_0 = builder.AddInstruction( 1617 HloInstruction::CreateGetTupleElement(element_shape, param, 0)); 1618 HloInstruction* element_1 = builder.AddInstruction( 1619 HloInstruction::CreateGetTupleElement(element_shape, param, 1)); 1620 HloInstruction* element_2 = builder.AddInstruction( 1621 HloInstruction::CreateGetTupleElement(element_shape, param, 2)); 1622 1623 HloInstruction* rev_1 = builder.AddInstruction( 1624 HloInstruction::CreateReverse(element_shape, element_1, {0})); 1625 HloInstruction* add_1_2 = builder.AddInstruction(HloInstruction::CreateBinary( 1626 element_shape, HloOpcode::kAdd, element_1, element_2)); 1627 1628 builder.AddInstruction( 1629 HloInstruction::CreateTuple({element_0, rev_1, add_1_2})); 1630 return builder.Build(); 1631 } 1632 1633 void BM_SequentialWhiles(int num_iters, int num_whiles) { 1634 // This benchmark constructs a chain of sequential while instructions. 1635 tensorflow::testing::StopTiming(); 1636 for (int i = 0; i < num_iters; ++i) { 1637 HloModuleConfig config; 1638 config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); 1639 HloModule module("BM_SequentialWhiles", VersionedComputationHandle(), 1640 config); 1641 1642 auto builder = HloComputation::Builder("BM_SequentialWhiles"); 1643 HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( 1644 0, ShapeUtil::MakeShape(F32, {42}), "x")); 1645 HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( 1646 1, ShapeUtil::MakeShape(F32, {42}), "y")); 1647 HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( 1648 2, ShapeUtil::MakeShape(F32, {42}), "z")); 1649 HloInstruction* init = 1650 builder.AddInstruction(HloInstruction::CreateTuple({x, y, z})); 1651 1652 HloInstruction* prev_loop_state = init; 1653 for (int w = 0; w < num_whiles; ++w) { 1654 HloComputation* condition = 1655 module.AddEmbeddedComputation(MakeTrivialCondition(init->shape())); 1656 HloComputation* body = 1657 module.AddEmbeddedComputation(MakeBenchmarkWhileBody()); 1658 prev_loop_state = builder.AddInstruction(HloInstruction::CreateWhile( 1659 init->shape(), condition, body, prev_loop_state)); 1660 } 1661 module.AddEntryComputation(builder.Build()); 1662 1663 CopyInsertion copy_insertion; 1664 1665 tensorflow::testing::StartTiming(); 1666 ASSERT_IS_OK(copy_insertion.Run(&module).status()); 1667 tensorflow::testing::StopTiming(); 1668 1669 // The entry computation should have three copies, and each body has one. 1670 ASSERT_EQ(CountCopies(module), 3 + num_whiles); 1671 } 1672 } 1673 1674 void BM_ParallelWhiles(int num_iters, int num_whiles) { 1675 // This benchmark constructs a fan-out of parallel while instructions. 1676 tensorflow::testing::StopTiming(); 1677 for (int i = 0; i < num_iters; ++i) { 1678 HloModuleConfig config; 1679 config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); 1680 HloModule module("BM_SequentialWhiles", VersionedComputationHandle(), 1681 config); 1682 1683 auto builder = HloComputation::Builder("BM_ParallelWhiles"); 1684 HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( 1685 0, ShapeUtil::MakeShape(F32, {42}), "x")); 1686 HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( 1687 1, ShapeUtil::MakeShape(F32, {42}), "y")); 1688 HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( 1689 2, ShapeUtil::MakeShape(F32, {42}), "z")); 1690 HloInstruction* init = 1691 builder.AddInstruction(HloInstruction::CreateTuple({x, y, z})); 1692 1693 HloInstruction* sum = nullptr; 1694 for (int w = 0; w < num_whiles; ++w) { 1695 HloComputation* condition = 1696 module.AddEmbeddedComputation(MakeTrivialCondition(init->shape())); 1697 HloComputation* body = 1698 module.AddEmbeddedComputation(MakeBenchmarkWhileBody()); 1699 1700 HloInstruction* xla_while = builder.AddInstruction( 1701 HloInstruction::CreateWhile(init->shape(), condition, body, init)); 1702 1703 if (sum == nullptr) { 1704 sum = builder.AddInstruction( 1705 HloInstruction::CreateGetTupleElement(x->shape(), xla_while, 0)); 1706 } else { 1707 HloInstruction* element_0 = builder.AddInstruction( 1708 HloInstruction::CreateGetTupleElement(x->shape(), xla_while, 0)); 1709 sum = builder.AddInstruction(HloInstruction::CreateBinary( 1710 x->shape(), HloOpcode::kAdd, sum, element_0)); 1711 } 1712 } 1713 module.AddEntryComputation(builder.Build()); 1714 1715 CopyInsertion copy_insertion; 1716 1717 tensorflow::testing::StartTiming(); 1718 ASSERT_IS_OK(copy_insertion.Run(&module).status()); 1719 tensorflow::testing::StopTiming(); 1720 1721 // Each body receives of copy of two of the parameters (the corresponding 1722 // elements in the body are modifed), and there is one copy in each body. 1723 ASSERT_EQ(CountCopies(module), 3 * num_whiles); 1724 } 1725 } 1726 1727 std::unique_ptr<HloComputation> MakeBenchmarkWhileBody( 1728 const int num_tuple_inputs) { 1729 auto builder = HloComputation::Builder("benchmark_loop_body"); 1730 const Shape element_shape = ShapeUtil::MakeShape(F32, {}); 1731 std::vector<Shape> input_shape(num_tuple_inputs, element_shape); 1732 const Shape loop_state_shape = ShapeUtil::MakeTupleShape(input_shape); 1733 HloInstruction* param = builder.AddInstruction( 1734 HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); 1735 std::vector<HloInstruction*> gte_nodes(num_tuple_inputs); 1736 for (int i = 0; i < num_tuple_inputs; ++i) { 1737 gte_nodes[i] = builder.AddInstruction( 1738 HloInstruction::CreateGetTupleElement(element_shape, param, i)); 1739 } 1740 builder.AddInstruction(HloInstruction::CreateTuple(gte_nodes)); 1741 return builder.Build(); 1742 } 1743 1744 void BM_ManyElementTuple(int num_iters, const int num_tuple_inputs) { 1745 tensorflow::testing::StopTiming(); 1746 HloModuleConfig config; 1747 config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); 1748 CopyInsertion copy_insertion; 1749 const Shape element_shape = ShapeUtil::MakeShape(F32, {}); 1750 std::vector<HloInstruction*> tuple_params(num_tuple_inputs); 1751 for (int i = 0; i < num_iters; ++i) { 1752 auto builder = HloComputation::Builder("BM_ParallelWhiles"); 1753 HloModule module("BM_ManyElementTuple", VersionedComputationHandle(), 1754 config); 1755 for (int j = 0; j < num_tuple_inputs; ++j) { 1756 tuple_params[j] = builder.AddInstruction( 1757 HloInstruction::CreateParameter(j, element_shape, "")); 1758 } 1759 HloInstruction* init = 1760 builder.AddInstruction(HloInstruction::CreateTuple(tuple_params)); 1761 HloComputation* condition = 1762 module.AddEmbeddedComputation(MakeTrivialCondition(init->shape())); 1763 HloComputation* body = 1764 module.AddEmbeddedComputation(MakeBenchmarkWhileBody(num_tuple_inputs)); 1765 HloInstruction* xla_while = builder.AddInstruction( 1766 HloInstruction::CreateWhile(init->shape(), condition, body, init)); 1767 builder.AddInstruction(HloInstruction::CreateGetTupleElement( 1768 ShapeUtil::MakeShape(F32, {}), xla_while, 0)); 1769 module.AddEntryComputation(builder.Build()); 1770 tensorflow::testing::StartTiming(); 1771 ASSERT_IS_OK(copy_insertion.Run(&module).status()); 1772 tensorflow::testing::StopTiming(); 1773 } 1774 } 1775 1776 BENCHMARK(BM_SequentialWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096); 1777 BENCHMARK(BM_ParallelWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096); 1778 BENCHMARK(BM_ManyElementTuple)->Arg(1024)->Arg(12288); 1779 1780 TEST_F(CopyInsertionTest, SimpleControlFlowTest) { 1781 const string& hlo_string = R"( 1782 HloModule TestModule 1783 1784 if-body.v5 { 1785 constant.3 = s32[] constant(-1) 1786 p.1 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) 1787 get-tuple-element.18 = (s32[], s32[], s32[]) get-tuple-element(p.1), index=1 1788 get-tuple-element.65 = s32[] get-tuple-element(get-tuple-element.18), index=0 1789 get-tuple-element.66 = s32[] get-tuple-element(get-tuple-element.18), index=1 1790 add.3 = s32[] add(get-tuple-element.65, get-tuple-element.66) 1791 tuple.33 = (s32[]) tuple(add.3) 1792 ROOT tuple.34 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.3, get-tuple-element.18, tuple.33) 1793 } 1794 1795 if-condition.v4 { 1796 p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) 1797 get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0 1798 constant.4 = s32[] constant(0) 1799 ROOT equal-to = pred[] equal-to(get-tuple-element.67, constant.4) 1800 } 1801 1802 _functionalize_body_1__.v28 { 1803 arg_tuple.4 = (s32[], s32[], s32[], s32[]) parameter(0) 1804 get-tuple-element.68 = s32[] get-tuple-element(arg_tuple.4), index=0 1805 constant.7 = s32[] constant(1) 1806 add.4 = s32[] add(get-tuple-element.68, constant.7) 1807 get-tuple-element.69 = s32[] get-tuple-element(arg_tuple.4), index=1 1808 get-tuple-element.70 = s32[] get-tuple-element(arg_tuple.4), index=2 1809 less-than-or-equal-to = pred[] less-than-or-equal-to(get-tuple-element.69, get-tuple-element.70) 1810 constant.8 = s32[] constant(0) 1811 select = s32[] select(less-than-or-equal-to, constant.8, constant.7) 1812 get-tuple-element.71 = s32[] get-tuple-element(arg_tuple.4), index=3 1813 tuple.35 = (s32[], s32[], s32[]) tuple(get-tuple-element.69, get-tuple-element.71, get-tuple-element.70) 1814 tuple.36 = (s32[]) tuple(constant.8) 1815 tuple.37 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(select, tuple.35, tuple.36) 1816 while = (s32[], (s32[], s32[], s32[]), (s32[])) while(tuple.37), condition=if-condition.v4, body=if-body.v5 1817 get-tuple-element.72 = (s32[]) get-tuple-element(while), index=2 1818 get-tuple-element.73 = s32[] get-tuple-element(get-tuple-element.72), index=0 1819 ROOT tuple.38 = (s32[], s32[], s32[], s32[]) tuple(add.4, get-tuple-element.69, get-tuple-element.70, get-tuple-element.73) 1820 } 1821 1822 cond_wrapper.v3.1 { 1823 inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0) 1824 get-tuple-element.75 = s32[] get-tuple-element(inputs.1), index=0 1825 constant.11 = s32[] constant(7) 1826 ROOT less-than.2 = pred[] less-than(get-tuple-element.75, constant.11) 1827 } 1828 1829 _functionalize_body_2__.v25 { 1830 arg_tuple.5 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) 1831 get-tuple-element.76 = s32[] get-tuple-element(arg_tuple.5), index=0 1832 get-tuple-element.77 = s32[] get-tuple-element(arg_tuple.5), index=2 1833 get-tuple-element.78 = s32[] get-tuple-element(arg_tuple.5), index=3 1834 get-tuple-element.79 = s32[] get-tuple-element(arg_tuple.5), index=4 1835 tuple.39 = (s32[], s32[], s32[], s32[]) tuple(get-tuple-element.76, get-tuple-element.77, get-tuple-element.78, get-tuple-element.79) 1836 while.2 = (s32[], s32[], s32[], s32[]) while(tuple.39), condition=cond_wrapper.v3.1, body=_functionalize_body_1__.v28 1837 get-tuple-element.80 = s32[] get-tuple-element(while.2), index=0 1838 get-tuple-element.81 = s32[] get-tuple-element(arg_tuple.5), index=1 1839 constant.12 = s32[] constant(1) 1840 add.5 = s32[] add(get-tuple-element.81, constant.12) 1841 get-tuple-element.82 = s32[] get-tuple-element(while.2), index=3 1842 ROOT tuple.40 = (s32[], s32[], s32[], s32[], s32[]) tuple(get-tuple-element.80, add.5, get-tuple-element.77, get-tuple-element.78, get-tuple-element.82) 1843 } 1844 1845 cond_wrapper.v3.2 { 1846 inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) 1847 get-tuple-element.83 = s32[] get-tuple-element(inputs.2), index=1 1848 constant.13 = s32[] constant(5) 1849 ROOT less-than.3 = pred[] less-than(get-tuple-element.83, constant.13) 1850 } 1851 1852 ENTRY TestComputation { 1853 arg_tuple.6 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) 1854 ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25 1855 } 1856 )"; 1857 auto module_or_status = 1858 HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); 1859 auto module = module_or_status.ConsumeValueOrDie(); 1860 InsertCopies(module.get()); 1861 } 1862 1863 TEST_F(CopyInsertionTest, ControlFlowTest) { 1864 const string& hlo_string = R"( 1865 HloModule TestModule 1866 1867 if-body.v5 { 1868 constant.3 = s32[] constant(-1) 1869 p.1 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) 1870 get-tuple-element.18 = (s32[], s32[], s32[]) get-tuple-element(p.1), index=1 1871 get-tuple-element.65 = s32[] get-tuple-element(get-tuple-element.18), index=0 1872 get-tuple-element.66 = s32[] get-tuple-element(get-tuple-element.18), index=1 1873 add.3 = s32[] add(get-tuple-element.65, get-tuple-element.66) 1874 tuple.33 = (s32[]) tuple(add.3) 1875 ROOT tuple.34 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.3, get-tuple-element.18, tuple.33) 1876 } 1877 1878 if-condition.v4 { 1879 p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) 1880 get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0 1881 constant.4 = s32[] constant(0) 1882 ROOT equal-to = pred[] equal-to(get-tuple-element.67, constant.4) 1883 } 1884 1885 if-body.v5.1 { 1886 constant.5 = s32[] constant(-1) 1887 p.3 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) 1888 get-tuple-element.68 = (s32[], s32[], s32[]) get-tuple-element(p.3), index=1 1889 get-tuple-element.70 = s32[] get-tuple-element(get-tuple-element.68), index=2 1890 multiply.1 = s32[] multiply(get-tuple-element.70, get-tuple-element.70) 1891 tuple.35 = (s32[]) tuple(multiply.1) 1892 ROOT tuple.36 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.5, get-tuple-element.68, tuple.35) 1893 } 1894 1895 if-condition.v4.1 { 1896 p.4 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) 1897 get-tuple-element.71 = s32[] get-tuple-element(p.4), index=0 1898 constant.6 = s32[] constant(1) 1899 ROOT equal-to.1 = pred[] equal-to(get-tuple-element.71, constant.6) 1900 } 1901 1902 _functionalize_body_1__.v28 { 1903 arg_tuple.4 = (s32[], s32[], s32[], s32[]) parameter(0) 1904 get-tuple-element.72 = s32[] get-tuple-element(arg_tuple.4), index=0 1905 constant.7 = s32[] constant(1) 1906 add.4 = s32[] add(get-tuple-element.72, constant.7) 1907 get-tuple-element.73 = s32[] get-tuple-element(arg_tuple.4), index=1 1908 get-tuple-element.74 = s32[] get-tuple-element(arg_tuple.4), index=2 1909 less-than-or-equal-to = pred[] less-than-or-equal-to(get-tuple-element.73, get-tuple-element.74) 1910 constant.8 = s32[] constant(0) 1911 select = s32[] select(less-than-or-equal-to, constant.8, constant.7) 1912 get-tuple-element.75 = s32[] get-tuple-element(arg_tuple.4), index=3 1913 tuple.37 = (s32[], s32[], s32[]) tuple(get-tuple-element.73, get-tuple-element.75, get-tuple-element.74) 1914 tuple.38 = (s32[]) tuple(constant.8) 1915 tuple.39 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(select, tuple.37, tuple.38) 1916 while = (s32[], (s32[], s32[], s32[]), (s32[])) while(tuple.39), condition=if-condition.v4, body=if-body.v5 1917 while.1 = (s32[], (s32[], s32[], s32[]), (s32[])) while(while), condition=if-condition.v4.1, body=if-body.v5.1 1918 get-tuple-element.76 = (s32[]) get-tuple-element(while.1), index=2 1919 get-tuple-element.77 = s32[] get-tuple-element(get-tuple-element.76), index=0 1920 ROOT tuple.40 = (s32[], s32[], s32[], s32[]) tuple(add.4, get-tuple-element.73, get-tuple-element.74, get-tuple-element.77) 1921 } 1922 1923 cond_wrapper.v3.1 { 1924 inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0) 1925 get-tuple-element.78 = s32[] get-tuple-element(inputs.1), index=0 1926 constant.11 = s32[] constant(7) 1927 ROOT less-than.2 = pred[] less-than(get-tuple-element.78, constant.11) 1928 } 1929 1930 _functionalize_body_2__.v25 { 1931 arg_tuple.5 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) 1932 get-tuple-element.79 = s32[] get-tuple-element(arg_tuple.5), index=0 1933 get-tuple-element.80 = s32[] get-tuple-element(arg_tuple.5), index=2 1934 get-tuple-element.81 = s32[] get-tuple-element(arg_tuple.5), index=3 1935 get-tuple-element.82 = s32[] get-tuple-element(arg_tuple.5), index=4 1936 tuple.41 = (s32[], s32[], s32[], s32[]) tuple(get-tuple-element.79, get-tuple-element.80, get-tuple-element.81, get-tuple-element.82) 1937 while.2 = (s32[], s32[], s32[], s32[]) while(tuple.41), condition=cond_wrapper.v3.1, body=_functionalize_body_1__.v28 1938 get-tuple-element.83 = s32[] get-tuple-element(while.2), index=0 1939 get-tuple-element.84 = s32[] get-tuple-element(arg_tuple.5), index=1 1940 constant.12 = s32[] constant(1) 1941 add.5 = s32[] add(get-tuple-element.84, constant.12) 1942 get-tuple-element.85 = s32[] get-tuple-element(while.2), index=3 1943 ROOT tuple.42 = (s32[], s32[], s32[], s32[], s32[]) tuple(get-tuple-element.83, add.5, get-tuple-element.80, get-tuple-element.81, get-tuple-element.85) 1944 } 1945 1946 cond_wrapper.v3.2 { 1947 inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) 1948 get-tuple-element.86 = s32[] get-tuple-element(inputs.2), index=1 1949 constant.13 = s32[] constant(5) 1950 ROOT less-than.3 = pred[] less-than(get-tuple-element.86, constant.13) 1951 } 1952 1953 ENTRY TestComputation { 1954 arg_tuple.6 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) 1955 ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25 1956 } 1957 )"; 1958 auto module_or_status = 1959 HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); 1960 auto module = module_or_status.ConsumeValueOrDie(); 1961 InsertCopies(module.get()); 1962 } 1963 1964 } // namespace 1965 } // namespace xla 1966