Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/copy_insertion.h"
     17 
     18 #include <set>
     19 
     20 #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
     21 #include "tensorflow/compiler/xla/literal_util.h"
     22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     24 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
     25 #include "tensorflow/compiler/xla/service/hlo_module.h"
     26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     27 #include "tensorflow/compiler/xla/service/hlo_runner.h"
     28 #include "tensorflow/compiler/xla/shape_util.h"
     29 #include "tensorflow/compiler/xla/test.h"
     30 #include "tensorflow/compiler/xla/test_helpers.h"
     31 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
     32 #include "tensorflow/compiler/xla/xla_data.pb.h"
     33 #include "tensorflow/core/platform/test_benchmark.h"
     34 
     35 namespace op = xla::testing::opcode_matchers;
     36 
     37 namespace xla {
     38 namespace {
     39 
     40 using ::testing::UnorderedElementsAre;
     41 
     42 int64 CountCopies(const HloComputation& computation) {
     43   int64 count = 0;
     44   for (const auto& instruction : computation.instructions()) {
     45     if (instruction->opcode() == HloOpcode::kCopy) {
     46       count++;
     47     }
     48   }
     49   return count;
     50 }
     51 
     52 int64 CountCopies(const HloModule& module) {
     53   int64 count = 0;
     54   for (const auto& computation : module.computations()) {
     55     count += CountCopies(*computation);
     56   }
     57   return count;
     58 }
     59 
     60 int64 CountControlEdges(const HloComputation& computation) {
     61   int64 count = 0;
     62   for (const auto& instruction : computation.instructions()) {
     63     count += instruction->control_successors().size();
     64   }
     65   return count;
     66 }
     67 
     68 int64 CountControlEdges(const HloModule& module) {
     69   int64 count = 0;
     70   for (const auto& computation : module.computations()) {
     71     count += CountControlEdges(*computation);
     72   }
     73   return count;
     74 }
     75 
     76 class CopyInsertionTest : public HloTestBase {
     77  protected:
     78   void InsertCopies(HloModule* module) {
     79     CopyInsertion copy_insertion;
     80     ASSERT_IS_OK(copy_insertion.Run(module).status());
     81   }
     82 
     83   const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
     84 };
     85 
     86 TEST_F(CopyInsertionTest, SingleParameter) {
     87   // Computation is a single parameter passed into a tuple. The parameter should
     88   // be copied before entering the tuple.
     89   auto builder = HloComputation::Builder(TestName());
     90   HloInstruction* x = builder.AddInstruction(
     91       HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "x"));
     92   HloInstruction* tuple =
     93       builder.AddInstruction(HloInstruction::CreateTuple({x}));
     94 
     95   EXPECT_THAT(x->users(), UnorderedElementsAre(tuple));
     96 
     97   auto module = CreateNewModule();
     98   module->AddEntryComputation(builder.Build());
     99 
    100   InsertCopies(module.get());
    101 
    102   EXPECT_THAT(module->entry_computation()->root_instruction(),
    103               op::Tuple(op::Copy(x)));
    104 }
    105 
    106 TEST_F(CopyInsertionTest, SingleConstant) {
    107   // Computation is a single constant passed into a tuple. The parameter should
    108   // be copied before entering the tuple.
    109   auto builder = HloComputation::Builder(TestName());
    110   HloInstruction* constant = builder.AddInstruction(
    111       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    112   HloInstruction* tuple =
    113       builder.AddInstruction(HloInstruction::CreateTuple({constant}));
    114 
    115   EXPECT_THAT(constant->users(), UnorderedElementsAre(tuple));
    116 
    117   auto module = CreateNewModule();
    118   module->AddEntryComputation(builder.Build());
    119 
    120   InsertCopies(module.get());
    121   EXPECT_EQ(CountCopies(*module), 1);
    122 
    123   EXPECT_THAT(module->entry_computation()->root_instruction(),
    124               op::Tuple(op::Copy(constant)));
    125 }
    126 
    127 TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) {
    128   // Verify that an kCopy instructions which exist in the pass before
    129   // copy-insertion remain in the graph after copy-insertion.
    130   auto module = CreateNewModule();
    131 
    132   auto builder = HloComputation::Builder(TestName());
    133   HloInstruction* constant = builder.AddInstruction(
    134       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    135   HloInstruction* copy_1 = builder.AddInstruction(HloInstruction::CreateUnary(
    136       constant->shape(), HloOpcode::kCopy, constant));
    137   HloInstruction* copy_2 = builder.AddInstruction(HloInstruction::CreateUnary(
    138       constant->shape(), HloOpcode::kCopy, constant));
    139   HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
    140       constant->shape(), HloOpcode::kAdd, copy_1, copy_2));
    141   HloInstruction* add_copy = builder.AddInstruction(
    142       HloInstruction::CreateUnary(constant->shape(), HloOpcode::kCopy, add));
    143 
    144   module->AddEntryComputation(builder.Build());
    145 
    146   EXPECT_EQ(CountCopies(*module), 3);
    147 
    148   InsertCopies(module.get());
    149 
    150   EXPECT_EQ(CountCopies(*module), 3);
    151 
    152   EXPECT_EQ(module->entry_computation()->root_instruction(), add_copy);
    153   EXPECT_THAT(
    154       module->entry_computation()->root_instruction(),
    155       op::Copy(op::Add(op::Copy(op::Constant()), op::Copy(op::Constant()))));
    156 }
    157 
    158 TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) {
    159   // Create a computation with more than one constant and parameter. Only one of
    160   // each constant/parameter is pointed to by the output tuple. Only these
    161   // instructions should be copied.
    162   auto builder = HloComputation::Builder(TestName());
    163 
    164   HloInstruction* constant1 = builder.AddInstruction(
    165       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    166   HloInstruction* constant2 = builder.AddInstruction(
    167       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    168 
    169   HloInstruction* x = builder.AddInstruction(
    170       HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "x"));
    171   HloInstruction* y = builder.AddInstruction(
    172       HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "y"));
    173 
    174   HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
    175       ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, constant1, y));
    176 
    177   builder.AddInstruction(HloInstruction::CreateTuple({constant2, x, add}));
    178 
    179   auto module = CreateNewModule();
    180   module->AddEntryComputation(builder.Build());
    181 
    182   InsertCopies(module.get());
    183   EXPECT_EQ(CountCopies(*module), 2);
    184 
    185   EXPECT_THAT(
    186       module->entry_computation()->root_instruction(),
    187       op::Tuple(op::Copy(constant2), op::Copy(x), op::Add(constant1, y)));
    188 }
    189 
    190 TEST_F(CopyInsertionTest, AmbiguousPointsToSet) {
    191   // Create a computation using select which has an ambiguous points-to set for
    192   // the computation result. Verify that copies are added properly.
    193   auto builder = HloComputation::Builder(TestName());
    194   HloInstruction* constant1 = builder.AddInstruction(
    195       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    196   HloInstruction* constant2 = builder.AddInstruction(
    197       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    198   HloInstruction* constant3 = builder.AddInstruction(
    199       HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
    200 
    201   HloInstruction* tuple1 = builder.AddInstruction(
    202       HloInstruction::CreateTuple({constant1, constant2}));
    203   HloInstruction* tuple2 = builder.AddInstruction(
    204       HloInstruction::CreateTuple({constant3, constant2}));
    205 
    206   HloInstruction* pred = builder.AddInstruction(
    207       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
    208   builder.AddInstruction(HloInstruction::CreateTernary(
    209       tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2));
    210 
    211   EXPECT_THAT(constant1->users(), UnorderedElementsAre(tuple1));
    212   EXPECT_THAT(constant2->users(), UnorderedElementsAre(tuple1, tuple2));
    213   EXPECT_THAT(constant3->users(), UnorderedElementsAre(tuple2));
    214 
    215   auto module = CreateNewModule();
    216   module->AddEntryComputation(builder.Build());
    217 
    218   HloInstruction* old_root = module->entry_computation()->root_instruction();
    219   InsertCopies(module.get());
    220   EXPECT_EQ(CountCopies(*module), 2);
    221 
    222   EXPECT_THAT(module->entry_computation()->root_instruction(),
    223               op::Tuple(op::Copy(op::GetTupleElement(old_root)),
    224                         op::Copy(op::GetTupleElement(old_root))));
    225 }
    226 
    227 TEST_F(CopyInsertionTest, BitcastParameter) {
    228   // The output of a bitcast is its operand (same buffer), so a bitcast
    229   // parameter feeding the result must have a copy added.
    230   auto builder = HloComputation::Builder(TestName());
    231   HloInstruction* x = builder.AddInstruction(
    232       HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {4}), "x"));
    233   HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary(
    234       ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, x));
    235 
    236   auto module = CreateNewModule();
    237   module->AddEntryComputation(builder.Build());
    238 
    239   EXPECT_THAT(x->users(), UnorderedElementsAre(bitcast));
    240 
    241   HloInstruction* old_root = module->entry_computation()->root_instruction();
    242   InsertCopies(module.get());
    243   EXPECT_EQ(CountCopies(*module), 1);
    244 
    245   EXPECT_THAT(module->entry_computation()->root_instruction(),
    246               op::Copy(old_root));
    247 }
    248 
    249 TEST_F(CopyInsertionTest, BitcastConstant) {
    250   // The output of a bitcast is its operand (same buffer), so a bitcast
    251   // constant feeding the result must have a copy added.
    252   auto builder = HloComputation::Builder(TestName());
    253   HloInstruction* constant = builder.AddInstruction(
    254       HloInstruction::CreateConstant(Literal::CreateR1<float>({1.0, 42.0})));
    255   HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary(
    256       ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, constant));
    257 
    258   auto module = CreateNewModule();
    259   module->AddEntryComputation(builder.Build());
    260 
    261   EXPECT_THAT(constant->users(), UnorderedElementsAre(bitcast));
    262 
    263   HloInstruction* old_root = module->entry_computation()->root_instruction();
    264   InsertCopies(module.get());
    265   EXPECT_EQ(CountCopies(*module), 1);
    266 
    267   EXPECT_THAT(module->entry_computation()->root_instruction(),
    268               op::Copy(old_root));
    269 }
    270 
    271 TEST_F(CopyInsertionTest, BitcastTupleElementParameter) {
    272   // Same as BitcastParameter, but the bitcast is wrapped in a tuple.
    273   auto builder = HloComputation::Builder(TestName());
    274   HloInstruction* x = builder.AddInstruction(
    275       HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {4}), "x"));
    276   HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary(
    277       ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, x));
    278   builder.AddInstruction(HloInstruction::CreateTuple({bitcast}));
    279 
    280   auto module = CreateNewModule();
    281   module->AddEntryComputation(builder.Build());
    282 
    283   EXPECT_THAT(x->users(), UnorderedElementsAre(bitcast));
    284 
    285   InsertCopies(module.get());
    286   EXPECT_EQ(CountCopies(*module), 1);
    287 
    288   EXPECT_THAT(module->entry_computation()->root_instruction(),
    289               op::Tuple(op::Copy(bitcast)));
    290 }
    291 
    292 TEST_F(CopyInsertionTest, NestedTupleParameter) {
    293   // Construct a trivial computation where the root of the computation is a
    294   // nested tuple-shaped parameter. The parameter should be deep copied and the
    295   // copy should be the root of the computation.
    296   auto builder = HloComputation::Builder(TestName());
    297 
    298   // Param shape is: ((F32[], S32[1,2,3]), F32[42])
    299   builder.AddInstruction(HloInstruction::CreateParameter(
    300       0,
    301       ShapeUtil::MakeTupleShape(
    302           {ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {}),
    303                                       ShapeUtil::MakeShape(S32, {1, 2, 3})}),
    304            ShapeUtil::MakeShape(F32, {42})}),
    305       "param0"));
    306 
    307   auto module = CreateNewModule();
    308   module->AddEntryComputation(builder.Build());
    309 
    310   EXPECT_EQ(HloOpcode::kParameter,
    311             module->entry_computation()->root_instruction()->opcode());
    312 
    313   HloInstruction* old_root = module->entry_computation()->root_instruction();
    314   InsertCopies(module.get());
    315   EXPECT_EQ(CountCopies(*module), 3);
    316 
    317   HloInstruction* new_root = module->entry_computation()->root_instruction();
    318   EXPECT_NE(old_root, new_root);
    319 
    320   EXPECT_THAT(
    321       new_root,
    322       op::Tuple(
    323           op::Tuple(
    324               op::Copy(op::GetTupleElement(op::GetTupleElement(old_root))),
    325               op::Copy(op::GetTupleElement(op::GetTupleElement(old_root)))),
    326           op::Copy(op::GetTupleElement(old_root))));
    327 }
    328 
    329 TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) {
    330   // Construct a computation where the root of the computation is a tuple
    331   // element of a nested tuple-shaped parameter.
    332   auto builder = HloComputation::Builder(TestName());
    333 
    334   // Param shape is: ((F32[], S32[1,2,3]), F32[42])
    335   auto param = builder.AddInstruction(HloInstruction::CreateParameter(
    336       0,
    337       ShapeUtil::MakeTupleShape(
    338           {ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {}),
    339                                       ShapeUtil::MakeShape(S32, {1, 2, 3})}),
    340            ShapeUtil::MakeShape(F32, {42})}),
    341       "param0"));
    342 
    343   // The return value of the computation is the zero-th element of the nested
    344   // tuple. This element is itself a tuple.
    345   auto gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
    346       ShapeUtil::GetSubshape(param->shape(), {0}), param, 0));
    347 
    348   auto module = CreateNewModule();
    349   module->AddEntryComputation(builder.Build());
    350 
    351   EXPECT_EQ(gte, module->entry_computation()->root_instruction());
    352 
    353   InsertCopies(module.get());
    354   EXPECT_EQ(CountCopies(*module), 2);
    355 
    356   EXPECT_THAT(
    357       module->entry_computation()->root_instruction(),
    358       op::Tuple(op::Copy(op::GetTupleElement(op::GetTupleElement(param))),
    359                 op::Copy(op::GetTupleElement(op::GetTupleElement(param)))));
    360 }
    361 
    362 TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) {
    363   // Create a computation using select which has an ambiguous points-to set for
    364   // the top-level buffer of the root of the computation. Verify that a shallow
    365   // copy is added.
    366   auto builder = HloComputation::Builder(TestName());
    367   HloInstruction* constant1 = builder.AddInstruction(
    368       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    369   HloInstruction* constant2 = builder.AddInstruction(
    370       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    371 
    372   HloInstruction* tuple1 = builder.AddInstruction(
    373       HloInstruction::CreateTuple({constant1, constant2}));
    374   HloInstruction* tuple2 = builder.AddInstruction(
    375       HloInstruction::CreateTuple({constant2, constant1}));
    376 
    377   HloInstruction* pred = builder.AddInstruction(
    378       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
    379   HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary(
    380       tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2));
    381   HloInstruction* gte =
    382       builder.AddInstruction(HloInstruction::CreateGetTupleElement(
    383           ShapeUtil::GetSubshape(select->shape(), {0}), select, 0));
    384 
    385   auto module = CreateNewModule();
    386   module->AddEntryComputation(builder.Build());
    387 
    388   EXPECT_EQ(gte, module->entry_computation()->root_instruction());
    389 
    390   HloInstruction* old_root = module->entry_computation()->root_instruction();
    391   InsertCopies(module.get());
    392   EXPECT_EQ(CountCopies(*module), 1);
    393 
    394   EXPECT_THAT(module->entry_computation()->root_instruction(),
    395               op::Copy(old_root));
    396 }
    397 
    398 class WhileCopyInsertionTest : public CopyInsertionTest {
    399  protected:
    400   WhileCopyInsertionTest() : module_(CreateNewModule()) {}
    401 
    402   // Builds a While condition computation which reads the induction variable
    403   // from the tuple parameter, and returns a predicate indicating whether this
    404   // value is less than the constant '10'.
    405   // The parameter 'nested' specifies the loop state shape from which to
    406   // read the induction variable.
    407   std::unique_ptr<HloComputation> BuildConditionComputation(
    408       const Shape& loop_state_shape) {
    409     auto builder = HloComputation::Builder(TestName() + ".Condition");
    410     auto limit_const = builder.AddInstruction(
    411         HloInstruction::CreateConstant(Literal::CreateR0<int32>(10)));
    412     auto loop_state = builder.AddInstruction(
    413         HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
    414     auto induction_variable =
    415         builder.AddInstruction(HloInstruction::CreateGetTupleElement(
    416             limit_const->shape(), loop_state, 0));
    417     builder.AddInstruction(
    418         HloInstruction::CreateBinary(condition_result_shape_, HloOpcode::kLt,
    419                                      induction_variable, limit_const));
    420     return builder.Build();
    421   }
    422 
    423   // Builds a While body computation with one output tuple element dependent on
    424   // both input tuple elements.
    425   // EX:
    426   // Body({in0, in1})
    427   //   out0 = Add(in0, 1)
    428   //   out1 = Add(BCast(in0), in1)
    429   //   Tuple(out0, out1)
    430   std::unique_ptr<HloComputation> BuildDependentBodyComputation() {
    431     auto builder = HloComputation::Builder(TestName() + ".Body");
    432     // Create param instruction to access loop state.
    433     auto loop_state = builder.AddInstruction(
    434         HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
    435     // Update the induction variable GTE(0).
    436     auto induction_variable =
    437         builder.AddInstruction(HloInstruction::CreateGetTupleElement(
    438             induction_variable_shape_, loop_state, 0));
    439     auto inc = builder.AddInstruction(
    440         HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
    441     auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
    442         induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
    443     // Update data GTE(1).
    444     auto data = builder.AddInstruction(
    445         HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
    446     // Use 'induction_variable' in computation with no path to output tuple.
    447     auto update = builder.AddInstruction(
    448         HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8}));
    449     auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
    450         data_shape_, HloOpcode::kAdd, data, update));
    451     // Create output Tuple.
    452     builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
    453     return builder.Build();
    454   }
    455 
    456   // Builds a While body computation with two output tuple elements dependent on
    457   // both input tuple elements.
    458   //
    459   // EX: Body({in0, in1, in2})
    460   //   out0 = Add(in0, 1)
    461   //   out1 = in1
    462   //   out2 = in2
    463   //   Tuple(out0, out1, out2)
    464   std::unique_ptr<HloComputation> BuildDependentBodyComputation2() {
    465     auto builder = HloComputation::Builder(TestName() + ".Body");
    466 
    467     const Shape& loop_state_shape = ShapeUtil::MakeTupleShape(
    468         {induction_variable_shape_, data_shape_, data_shape_});
    469 
    470     auto loop_state = builder.AddInstruction(
    471         HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
    472 
    473     // Update the induction variable GTE(0).
    474     auto induction_variable =
    475         builder.AddInstruction(HloInstruction::CreateGetTupleElement(
    476             induction_variable_shape_, loop_state, 0));
    477     auto inc = builder.AddInstruction(
    478         HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
    479 
    480     // add0 = Add(in0, 1)
    481     auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
    482         induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
    483     // data1 = GTE(1).
    484     HloInstruction* data1 = builder.AddInstruction(
    485         HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
    486 
    487     // data2 = GTE(2).
    488     HloInstruction* data2 = builder.AddInstruction(
    489         HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 2));
    490 
    491     // Create output Tuple.
    492     builder.AddInstruction(HloInstruction::CreateTuple({add0, data1, data2}));
    493 
    494     return builder.Build();
    495   }
    496 
    497   // Builds a While body computation with read-only tuple element 0.
    498   // EX:
    499   // Body({in0, in1})
    500   //   out0 = in0
    501   //   out1 = Add(BCast(in0), in1)
    502   //   Tuple(out0, out1)
    503   std::unique_ptr<HloComputation> BuildDependentBodyOneReadOnlyComputation() {
    504     auto builder = HloComputation::Builder(TestName() + ".Body");
    505     // Create param instruction to access loop state.
    506     auto loop_state = builder.AddInstruction(
    507         HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
    508     // Update the induction variable GTE(0).
    509     auto induction_variable =
    510         builder.AddInstruction(HloInstruction::CreateGetTupleElement(
    511             induction_variable_shape_, loop_state, 0));
    512     // Update data GTE(1).
    513     auto data = builder.AddInstruction(
    514         HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
    515 
    516     // Use 'induction_variable' in computation with no path to output tuple.
    517     auto update = builder.AddInstruction(
    518         HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8}));
    519     auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
    520         data_shape_, HloOpcode::kAdd, data, update));
    521     // Create output Tuple.
    522     builder.AddInstruction(
    523         HloInstruction::CreateTuple({induction_variable, add1}));
    524     return builder.Build();
    525   }
    526 
    527   // Builds a While body computation with independent outputs.
    528   // EX:
    529   // Body({in0, in1})
    530   //   out0 = Add(in0, 1)
    531   //   out1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1})
    532   //   Tuple(out0, out1)
    533   std::unique_ptr<HloComputation> BuildIndependentBodyComputation(
    534       bool nested = false) {
    535     auto builder = HloComputation::Builder(TestName() + ".Body");
    536     // Create param instruction to access loop state.
    537     const Shape& loop_state_shape =
    538         nested ? nested_loop_state_shape_ : loop_state_shape_;
    539 
    540     auto loop_state = builder.AddInstruction(
    541         HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
    542     // Update the induction variable GTE(0).
    543     auto induction_variable =
    544         builder.AddInstruction(HloInstruction::CreateGetTupleElement(
    545             induction_variable_shape_, loop_state, 0));
    546     auto inc = builder.AddInstruction(
    547         HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
    548     // add0 = Add(in0, 1)
    549     auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
    550         induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
    551     // Update data GTE(1).
    552     HloInstruction* data = nullptr;
    553     if (nested) {
    554       data = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
    555           nested_tuple_shape_, loop_state, 1));
    556       data = builder.AddInstruction(
    557           HloInstruction::CreateGetTupleElement(data_shape_, data, 0));
    558     } else {
    559       data = builder.AddInstruction(
    560           HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
    561     }
    562     auto update = builder.AddInstruction(HloInstruction::CreateConstant(
    563         Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
    564     // add1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1})
    565     auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
    566         data_shape_, HloOpcode::kAdd, data, update));
    567     // Create output Tuple.
    568     if (nested) {
    569       auto nested_tuple =
    570           builder.AddInstruction(HloInstruction::CreateTuple({add1, add1}));
    571       builder.AddInstruction(HloInstruction::CreateTuple({add0, nested_tuple}));
    572     } else {
    573       builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
    574     }
    575     return builder.Build();
    576   }
    577 
    578   // Builds a While body computation with the following nested tuple
    579   // sub-computation:
    580   //                            |
    581   //                    GTE(loop_state, 1)
    582   //                       /           \
    583   // GTE(GTE(loop_state, 1), 0)     GTE(GTE(loop_state, 1), 1)
    584   //           |                              |
    585   //          Add                           Reverse
    586   //           |                              |
    587   std::unique_ptr<HloComputation> BuildNestedBodyComputation() {
    588     auto builder = HloComputation::Builder(TestName() + ".Body");
    589     // Create param instruction to access loop state.
    590     auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter(
    591         0, nested_loop_state_shape_, "loop_state"));
    592     // Update GTE(0).
    593     auto gte0 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
    594         induction_variable_shape_, loop_state, 0));
    595     auto inc = builder.AddInstruction(
    596         HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
    597     auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
    598         gte0->shape(), HloOpcode::kAdd, gte0, inc));
    599 
    600     // GTE(loop_state, 1)
    601     auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
    602         nested_tuple_shape_, loop_state, 1));
    603     // GTE(GTE(loop_state, 1), 0) -> Add
    604     auto gte10 = builder.AddInstruction(
    605         HloInstruction::CreateGetTupleElement(data_shape_, gte1, 0));
    606     auto update10 = builder.AddInstruction(HloInstruction::CreateConstant(
    607         Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
    608     auto add10 = builder.AddInstruction(HloInstruction::CreateBinary(
    609         data_shape_, HloOpcode::kAdd, gte10, update10));
    610 
    611     // GTE(GTE(loop_state, 1), 1) -> Reverse
    612     auto gte11 = builder.AddInstruction(
    613         HloInstruction::CreateGetTupleElement(data_shape_, gte1, 1));
    614     auto rev11 = builder.AddInstruction(
    615         HloInstruction::CreateReverse(data_shape_, gte11, {0}));
    616 
    617     // Create output Tuple.
    618     auto inner_tuple =
    619         builder.AddInstruction(HloInstruction::CreateTuple({add10, rev11}));
    620     builder.AddInstruction(HloInstruction::CreateTuple({add0, inner_tuple}));
    621     return builder.Build();
    622   }
    623 
    624   // Builds a While instruction using 'condition' and 'body' sub-computations,
    625   // installs the built computation as module_'s entry, and returns the While.
    626   HloInstruction* BuildWhileInstruction(HloComputation* condition,
    627                                         HloComputation* body,
    628                                         bool nested = false) {
    629     auto builder = HloComputation::Builder(TestName() + ".While");
    630     auto induction_var_init = builder.AddInstruction(
    631         HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
    632     // Data element of the init: an R1 float constant holding eight zeros.
    633     auto data_init = builder.AddInstruction(HloInstruction::CreateConstant(
    634         Literal::CreateR1<float>({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
    635     // Nested form: init is (induction_var, (data, data)), same constant twice.
    636     if (nested) {
    637       auto inner_init = builder.AddInstruction(
    638           HloInstruction::CreateTuple({data_init, data_init}));
    639       auto loop_state_init = builder.AddInstruction(
    640           HloInstruction::CreateTuple({induction_var_init, inner_init}));
    641       auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile(
    642           loop_state_init->shape(), condition, body, loop_state_init));
    643       module_->AddEntryComputation(builder.Build());
    644       return while_hlo;
    645     }
    646     // Flat form: init is (induction_var, data), typed as loop_state_shape_.
    647     auto loop_state_init = builder.AddInstruction(
    648         HloInstruction::CreateTuple({induction_var_init, data_init}));
    649     auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile(
    650         loop_state_shape_, condition, body, loop_state_init));
    651     module_->AddEntryComputation(builder.Build());
    652     return while_hlo;
    653   }
    654   // Builds a While whose data init element is an R1 constant of eight zeros.
    655   HloInstruction* BuildWhileInstruction_InitPointsToConstant() {
    656     auto builder = HloComputation::Builder(TestName() + ".While");
    657     auto data_init = builder.AddInstruction(HloInstruction::CreateConstant(
    658         Literal::CreateR1<float>({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
    659     return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init,
    660                                                &builder);
    661   }
    662   // Builds a While whose data init element is parameter 0 of data_shape_.
    663   HloInstruction* BuildWhileInstruction_InitPointsToParameter() {
    664     auto builder = HloComputation::Builder(TestName() + ".While");
    665     auto data_init = builder.AddInstruction(
    666         HloInstruction::CreateParameter(0, data_shape_, "data_init"));
    667     return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init,
    668                                                &builder);
    669   }
    670   // Builds a While whose data init is Select(pred, (v1, v2), (v2, v1)).
    671   HloInstruction* BuildWhileInstruction_InitPointsToAmbiguous() {
    672     auto builder = HloComputation::Builder(TestName() + ".While");
    673     // v1 broadcasts the constant 1.0 and v2 the constant 0.0.
    674     auto one = builder.AddInstruction(
    675         HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    676     auto v1 = builder.AddInstruction(
    677         HloInstruction::CreateBroadcast(data_shape_, one, {1}));
    678     auto zero = builder.AddInstruction(
    679         HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
    680     auto v2 = builder.AddInstruction(
    681         HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
    682     // The two tuples hold the same two vectors in opposite order.
    683     auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({v1, v2}));
    684     auto tuple2 = builder.AddInstruction(HloInstruction::CreateTuple({v2, v1}));
    685     // pred is a constant false; the select takes both tuples as operands.
    686     auto pred = builder.AddInstruction(
    687         HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
    688     auto data_init = builder.AddInstruction(HloInstruction::CreateTernary(
    689         nested_tuple_shape_, HloOpcode::kSelect, pred, tuple1, tuple2));
    690     // The selected tuple becomes element 1 of the nested loop state.
    691     return BuildWhileInstructionWithCustomInit(nested_loop_state_shape_,
    692                                                data_init, &builder);
    693   }
    694   // Builds a While whose data init tuple holds the same broadcast twice.
    695   HloInstruction* BuildWhileInstruction_InitPointsToNonDistinct() {
    696     auto builder = HloComputation::Builder(TestName() + ".While");
    697 
    698     auto one = builder.AddInstruction(
    699         HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    700     auto one_vec = builder.AddInstruction(
    701         HloInstruction::CreateBroadcast(data_shape_, one, {1}));
    702     auto data_init =
    703         builder.AddInstruction(HloInstruction::CreateTuple({one_vec, one_vec}));
    704     // Both elements of data_init are the single one_vec instruction.
    705     return BuildWhileInstructionWithCustomInit(nested_loop_state_shape_,
    706                                                data_init, &builder);
    707   }
    708   // Builds a While whose data init is also read by an Add in the entry.
    709   HloInstruction* BuildWhileInstruction_InitPointsToInterfering() {
    710     auto builder = HloComputation::Builder(TestName() + ".While");
    711     auto one = builder.AddInstruction(
    712         HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    713     auto data_init = builder.AddInstruction(
    714         HloInstruction::CreateBroadcast(data_shape_, one, {1}));
    715     auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant(
    716         Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
    717     // Take a reference to 'data_init' to make it interfere with while result.
    718     auto add = builder.AddInstruction(HloInstruction::CreateBinary(
    719         data_shape_, HloOpcode::kAdd, data_init, one_vec));
    720 
    721     auto xla_while = BuildWhileInstructionWithCustomInit(loop_state_shape_,
    722                                                          data_init, &builder);
    723 
    724     // Add an additional binary operation operating on the while and the
    725     // interfering add so that neither operation is dead.
    726     auto gte = xla_while->parent()->AddInstruction(
    727         HloInstruction::CreateGetTupleElement(
    728             ShapeUtil::GetSubshape(xla_while->shape(), {1}), xla_while, 1));
    729     auto sub = xla_while->parent()->AddInstruction(HloInstruction::CreateBinary(
    730         data_shape_, HloOpcode::kSubtract, add, gte));
    731     auto gte0 = xla_while->parent()->AddInstruction(
    732         HloInstruction::CreateGetTupleElement(
    733             ShapeUtil::GetSubshape(xla_while->shape(), {0}), xla_while, 0));
    734     auto tuple = xla_while->parent()->AddInstruction(
    735         HloInstruction::CreateTuple({gte0, sub}));
    736     // New root is (while[0], add - while[1]), which uses the add and the while.
    737     xla_while->parent()->set_root_instruction(tuple);
    738 
    739     return xla_while;
    740   }
    741   // Builds a While with init (0, data_init); adds '*builder' as the entry.
    742   HloInstruction* BuildWhileInstructionWithCustomInit(
    743       const Shape& loop_state_shape, HloInstruction* data_init,
    744       HloComputation::Builder* builder) {
    745     const bool nested =
    746         ShapeUtil::Equal(loop_state_shape, nested_loop_state_shape_);
    747     auto induction_var_init = builder->AddInstruction(
    748         HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
    749     auto condition = module_->AddEmbeddedComputation(
    750         BuildConditionComputation(loop_state_shape));
    751     auto body = module_->AddEmbeddedComputation(
    752         BuildIndependentBodyComputation(nested));
    753     auto loop_state_init = builder->AddInstruction(
    754         HloInstruction::CreateTuple({induction_var_init, data_init}));
    755     auto while_hlo = builder->AddInstruction(HloInstruction::CreateWhile(
    756         loop_state_shape, condition, body, loop_state_init));
    757     module_->AddEntryComputation(builder->Build());
    758     return while_hlo;
    759   }
    760   // State used by the helpers: the module they populate and common shapes.
    761   std::unique_ptr<HloModule> module_;
    762   Shape induction_variable_shape_ = ShapeUtil::MakeShape(S32, {});  // scalar
    763   Shape data_shape_ = ShapeUtil::MakeShape(F32, {8});  // 8-element vector
    764   Shape loop_state_shape_ =
    765       ShapeUtil::MakeTupleShape({induction_variable_shape_, data_shape_});
    766   Shape nested_tuple_shape_ =
    767       ShapeUtil::MakeTupleShape({data_shape_, data_shape_});
    768   Shape nested_loop_state_shape_ = ShapeUtil::MakeTupleShape(
    769       {induction_variable_shape_, nested_tuple_shape_});
    770   Shape condition_result_shape_ = ShapeUtil::MakeShape(PRED, {});  // scalar
    771 };
    772 
    773 // Tests while body computation with independent tuple elements:
    774 //
    775 //   While.Body({in0, in1})
    776 //     out0 = Add(in0, 1)
    777 //     out1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1})
    778 //     Tuple(out0, out1)
    779 //
    780 // CopyInsertion pass should not generate any copies in the body.
    781 //
    782 TEST_F(WhileCopyInsertionTest, IndependentTupleElements) {
    783   auto condition = module_->AddEmbeddedComputation(
    784       BuildConditionComputation(loop_state_shape_));
    785   auto body =
    786       module_->AddEmbeddedComputation(BuildIndependentBodyComputation());
    787   auto while_hlo = BuildWhileInstruction(condition, body);
    788 
    789   InsertCopies(module_.get());
    790 
    791   // No copies in the body (adds are in place); no control edges in the module.
    792   EXPECT_EQ(CountCopies(*body), 0);
    793   EXPECT_EQ(CountControlEdges(*module_), 0);
    794 
    795   // Both init indices need copies as they are constants.
    796   EXPECT_THAT(while_hlo->operand(0),
    797               op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant())));
    798 }
    799 
    800 // Tests while body computation with dependent tuple elements:
    801 //
    802 //   While.Body({in0, in1})
    803 //     out0 = Add(in0, 1)
    804 //     out1 = Add(BCast(in0), in1)
    805 //     Tuple(out0, out1)
    806 //
    807 // CopyInsertion pass should add one copy in the body, feeding both the out0
    808 // Add and the Broadcast (see the root matcher below):
    809 //     Tuple(Add(Copy, Constant), Add(GTE, BCast(Copy)))
    810 //
    811 TEST_F(WhileCopyInsertionTest, DependentTupleElements) {
    812   auto condition = module_->AddEmbeddedComputation(
    813       BuildConditionComputation(loop_state_shape_));
    814   auto body = module_->AddEmbeddedComputation(BuildDependentBodyComputation());
    815   auto while_hlo = BuildWhileInstruction(condition, body);
    816 
    817   InsertCopies(module_.get());
    818 
    819   EXPECT_EQ(CountCopies(*body), 1);
    820   EXPECT_EQ(CountControlEdges(*body), 0);
    821 
    822   EXPECT_THAT(
    823       body->root_instruction(),
    824       op::Tuple(op::Add(), op::Add(op::GetTupleElement(), op::Broadcast())));
    825 
    826   auto add = body->root_instruction()->operand(0);
    827   auto bcast = body->root_instruction()->operand(1)->operand(1);
    828   ASSERT_EQ(add->opcode(), HloOpcode::kAdd);
    829   ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast);
    830 
    831   EXPECT_THAT(
    832       while_hlo->while_body()->root_instruction(),
    833       op::Tuple(op::Add(op::Copy(), op::Constant()),
    834                 op::Add(op::GetTupleElement(), op::Broadcast(op::Copy()))));
    835 
    836   // Both init indices need copies as they are constants.
    837   EXPECT_THAT(while_hlo->operand(0),
    838               op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant())));
    839 }
    840 
    841 // Tests a while body whose tuple element 0 is read-only (passed to the root):
    842 //
    843 //                         PARAMETER
    844 //                         /       \
    845 //                      GTE(0)     GTE(1)
    846 //                        |  \      |
    847 //                        |   BCAST |
    848 //                        |      \  |
    849 //                        |       ADD
    850 //                        |        |
    851 //                         \      /
    852 //                           TUPLE (root)
    853 //
    854 // CopyInsertion pass should not generate any copies for the while body.
    855 TEST_F(WhileCopyInsertionTest, DependentTupleElements_OneReadOnly) {
    856   auto condition = module_->AddEmbeddedComputation(
    857       BuildConditionComputation(loop_state_shape_));
    858   auto body = module_->AddEmbeddedComputation(
    859       BuildDependentBodyOneReadOnlyComputation());
    860   BuildWhileInstruction(condition, body);
    861 
    862   InsertCopies(module_.get());
    863 
    864   // No copies or control edges should be inserted. The body is legal as is.
    865   EXPECT_EQ(CountCopies(*body), 0);
    866   EXPECT_EQ(CountControlEdges(*body), 0);
    867 }
    868 
    869 // Same as above, but with two while loops, sharing entry parameters.
    870 TEST_F(WhileCopyInsertionTest,
    871        DependentTupleElements_OneReadOnly_TwoLoops_EntryParams) {
    872   auto condition1 = module_->AddEmbeddedComputation(
    873       BuildConditionComputation(loop_state_shape_));
    874   auto condition2 = module_->AddEmbeddedComputation(
    875       BuildConditionComputation(loop_state_shape_));
    876   auto body1 = module_->AddEmbeddedComputation(
    877       BuildDependentBodyOneReadOnlyComputation());
    878   auto body2 = module_->AddEmbeddedComputation(
    879       BuildDependentBodyOneReadOnlyComputation());
    880 
    881   auto builder = HloComputation::Builder(TestName() + ".While");
    882   auto iter_param = builder.AddInstruction(
    883       HloInstruction::CreateParameter(0, induction_variable_shape_, "iter"));
    884   auto data_param = builder.AddInstruction(
    885       HloInstruction::CreateParameter(1, data_shape_, "data"));
    886   auto loop_init = builder.AddInstruction(
    887       HloInstruction::CreateTuple({iter_param, data_param}));
    888   // Both whiles consume the same loop_init tuple.
    889   auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile(
    890       loop_state_shape_, condition1, body1, loop_init));
    891   auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile(
    892       loop_state_shape_, condition2, body2, loop_init));
    893 
    894   // Add element 0 of each while's result so that both whiles are live.
    895   auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
    896       ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0));
    897   auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
    898       ShapeUtil::GetSubshape(while_hlo2->shape(), {0}), while_hlo2, 0));
    899   builder.AddInstruction(
    900       HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2));
    901 
    902   auto entry = module_->AddEntryComputation(builder.Build());
    903 
    904   InsertCopies(module_.get());
    905 
    906   // Neither body should have any copies or control edges in them.
    907   EXPECT_EQ(CountCopies(*body1), 0);
    908   EXPECT_EQ(CountCopies(*body2), 0);
    909   EXPECT_EQ(CountControlEdges(*body1), 0);
    910   EXPECT_EQ(CountControlEdges(*body2), 0);
    911 
    912   // Only two copies should be necessary. Each of the whiles should have
    913   // a copy of tuple element 1 (init value is a parameter, and the element is
    914   // not read-only) so each of the while bodies gets its own buffer to write
    915   // element 1 into.
    916   EXPECT_EQ(CountCopies(*entry), 2);
    917 
    918   EXPECT_EQ(while_hlo1->operand(0)->operand(1)->opcode(), HloOpcode::kCopy);
    919   EXPECT_EQ(while_hlo2->operand(0)->operand(1)->opcode(), HloOpcode::kCopy);
    920 
    921   // The two copies of element 1 should be different.
    922   EXPECT_NE(while_hlo1->operand(0)->operand(1),
    923             while_hlo2->operand(0)->operand(1));
    924 }
    925 
    926 // Same as above, but with two while loops, sharing non-parameters.
    927 TEST_F(WhileCopyInsertionTest,
    928        DependentTupleElements_OneReadOnly_TwoLoops_NonParams) {
    929   auto condition1 = module_->AddEmbeddedComputation(
    930       BuildConditionComputation(loop_state_shape_));
    931   auto condition2 = module_->AddEmbeddedComputation(
    932       BuildConditionComputation(loop_state_shape_));
    933   auto body1 = module_->AddEmbeddedComputation(
    934       BuildDependentBodyOneReadOnlyComputation());
    935   auto body2 = module_->AddEmbeddedComputation(
    936       BuildDependentBodyOneReadOnlyComputation());
    937 
    938   auto builder = HloComputation::Builder(TestName() + ".While");
    939   auto iter_param = builder.AddInstruction(
    940       HloInstruction::CreateParameter(0, induction_variable_shape_, "iter"));
    941   auto data_param = builder.AddInstruction(
    942       HloInstruction::CreateParameter(1, data_shape_, "data"));
    943   // Add dummy ops to ensure loop_init elements aren't entry parameters.
    944   auto iter_value = builder.AddInstruction(HloInstruction::CreateUnary(
    945       iter_param->shape(), HloOpcode::kExp, iter_param));
    946   auto data_value = builder.AddInstruction(HloInstruction::CreateUnary(
    947       data_param->shape(), HloOpcode::kExp, data_param));
    948   auto loop_init = builder.AddInstruction(
    949       HloInstruction::CreateTuple({iter_value, data_value}));
    950   // Both whiles consume the same loop_init tuple.
    951   auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile(
    952       loop_state_shape_, condition1, body1, loop_init));
    953   auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile(
    954       loop_state_shape_, condition2, body2, loop_init));
    955 
    956   // Add element 0 of each while's result so that both whiles are not dead.
    957   auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
    958       ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0));
    959   auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
    960       ShapeUtil::GetSubshape(while_hlo2->shape(), {0}), while_hlo2, 0));
    961   builder.AddInstruction(
    962       HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2));
    963   auto entry = module_->AddEntryComputation(builder.Build());
    964 
    965   InsertCopies(module_.get());
    966 
    967   // Ideally only one copy should be necessary: one of the whiles would get a
    968   // copy of tuple element 1 (the non-read-only element) so the two bodies write
    969   // element 1 into different buffers. The checks below expect two copies
    970   // instead: each while gets its own copy of element 1; element 0 is not copied.
    971   EXPECT_EQ(CountCopies(*entry), 2);
    972 
    973   EXPECT_THAT(while_hlo1->operand(0),
    974               op::Tuple(op::Exp(), op::Copy(op::Exp())));
    975   EXPECT_THAT(while_hlo2->operand(0),
    976               op::Tuple(op::Exp(), op::Copy(op::Exp())));
    977 }
    978 
    979 // Tests while body computation with nested tuple elements:
    980 //
    981 //                            |
    982 //                    GTE(loop_state, 1)
    983 //                       /          \
    984 // GTE(GTE(loop_state, 1), 0)     GTE(GTE(loop_state, 1), 1)
    985 //           |                              |
    986 //          Add                           Reverse
    987 //           |                              |
    988 //
    989 // CopyInsertion pass will conceptually generate the following, but with the
    990 // actual GTE and Tuple instructions optimized away:
    991 //
    992 //                    Tuple  // old root
    993 //                   /     \
    994 //                  /       \
    995 //                GTE(0)   GTE(1)
    996 //                  |       /  \
    997 //                  |      /    \
    998 //                  |    GTE(0) GTE(1)
    999 //                  |       |    |
   1000 //                  |       |   Copy
   1001 //                  |       |    |
   1002 //                   \      |   /
   1003 //                    \    Tuple  // "inner" tuple.
   1004 //                     \    /
   1005 //                      \  /
   1006 //                     Tuple  // new root
   1007 //
   1008 TEST_F(WhileCopyInsertionTest, NestedTupleElements) {
   1009   auto condition = module_->AddEmbeddedComputation(
   1010       BuildConditionComputation(nested_loop_state_shape_));
   1011   auto body = module_->AddEmbeddedComputation(BuildNestedBodyComputation());
   1012   BuildWhileInstruction(condition, body, true);
   1013 
   1014   // Run copy insertion over the whole module.
   1015   InsertCopies(module_.get());
   1016 
   1017   // The only copy necessary is for the kReverse as it cannot be done
   1018   // in-place (instruction cannot share buffer with operand). The other elements
   1019   // of the loop state are kAdd instructions which can be done in-place.
   1020   EXPECT_EQ(CountCopies(*body), 1);
   1021 
   1022   // Module total: the body copy plus one per init element (all constants).
   1023   EXPECT_EQ(CountCopies(*module_), 4);
   1024 
   1025   // Either the kReverse itself must be copied or the operand of the kReverse
   1026   // must be copied.
   1027   if (body->root_instruction()->operand(1)->operand(1)->opcode() ==
   1028       HloOpcode::kCopy) {
   1029     EXPECT_THAT(
   1030         body->root_instruction(),
   1031         op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Reverse()))));
   1032   } else {
   1033     EXPECT_THAT(
   1034         body->root_instruction(),
   1035         op::Tuple(op::Add(), op::Tuple(op::Add(), op::Reverse(op::Copy()))));
   1036   }
   1037 }
   1038 
   1039 // Tests while init instruction which points-to a constant.
   1040 //
   1041 //     init = Tuple(Constant(S32, {}), Constant(F32, {8}))
   1042 //
   1043 // CopyInsertion pass should add copies for both constants, none in the body.
   1044 //
   1045 TEST_F(WhileCopyInsertionTest, InitPointsToConstant) {
   1046   auto while_hlo = BuildWhileInstruction_InitPointsToConstant();
   1047 
   1048   InsertCopies(module_.get());
   1049   EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0);
   1050   EXPECT_EQ(CountCopies(*module_), 2);
   1051 
   1052   EXPECT_THAT(while_hlo->operand(0),
   1053               op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant())));
   1054 }
   1055 
   1056 // Tests while init instruction which points-to a parameter.
   1057 //
   1058 //     init = Tuple(Constant(S32, {}), Parameter(F32, {8}))
   1059 //
   1060 // CopyInsertion pass should copy the constant and the parameter only.
   1061 //
   1062 TEST_F(WhileCopyInsertionTest, InitPointsToParameter) {
   1063   auto while_hlo = BuildWhileInstruction_InitPointsToParameter();
   1064 
   1065   InsertCopies(module_.get());
   1066   EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0);
   1067   EXPECT_EQ(CountCopies(*module_), 2);
   1068 
   1069   EXPECT_THAT(while_hlo->operand(0),
   1070               op::Tuple(op::Copy(op::Constant()), op::Copy(op::Parameter())));
   1071 }
   1072 
   1073 // Tests while init instruction which has an ambiguous points-to set.
   1074 //
   1075 //     select = Select(pred, tuple1, tuple2)
   1076 //     init = Tuple(Constant(S32, {}), select)
   1077 //
   1078 // CopyInsertion pass will conceptually generate the following, but with some of
   1079 // the actual GTE and Tuple instructions optimized away:
   1080 //
   1081 //                    Tuple  // old init
   1082 //                   /     \
   1083 //                  /       \
   1084 //                GTE(0)   GTE(1)
   1085 //                  |       /  \
   1086 //                  |      /    \
   1087 //                  |    GTE(0) GTE(1)
   1088 //                  |       |    |
   1089 //                Copy   Copy   Copy
   1090 //                  |       |    |
   1091 //                   \      |   /
   1092 //                    \    Tuple
   1093 //                     \    /
   1094 //                      \  /
   1095 //                     Tuple  // new init
   1096 //
   1097 TEST_F(WhileCopyInsertionTest, InitPointsToAmbiguous) {
   1098   auto while_hlo = BuildWhileInstruction_InitPointsToAmbiguous();
   1099 
   1100   InsertCopies(module_.get());
   1101   EXPECT_EQ(CountCopies(*module_), 4);
   1102   // The entry computation requires three copies to resolve the ambiguity of two
   1103   // init elements and the constant passed in as one of the init elements.
   1104   EXPECT_EQ(CountCopies(*module_->entry_computation()), 3);
   1105   EXPECT_THAT(while_hlo->operand(0),
   1106               op::Tuple(op::Copy(op::Constant()),
   1107                         op::Tuple(op::Copy(op::GetTupleElement()),
   1108                                   op::Copy(op::GetTupleElement()))));
   1109 
   1110   // The body requires one copy because the buffer set is not distinct: the
   1111   // result of one of the adds is written into two elements of the output of the
   1112   // loop body. Either element might be copied.
   1113   EXPECT_EQ(CountCopies(*while_hlo->while_body()), 1);
   1114   if (while_hlo->while_body()
   1115           ->root_instruction()
   1116           ->operand(1)
   1117           ->operand(0)
   1118           ->opcode() == HloOpcode::kCopy) {
   1119     EXPECT_THAT(
   1120         while_hlo->while_body()->root_instruction(),
   1121         op::Tuple(op::Add(), op::Tuple(op::Copy(op::Add()), op::Add())));
   1122   } else {
   1123     EXPECT_THAT(
   1124         while_hlo->while_body()->root_instruction(),
   1125         op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Add()))));
   1126   }
   1127 }
   1128 
   1129 // Tests while init instruction which has a non-distinct points-to set.
   1130 //
   1131 //     init = Tuple(Constant(S32, {}), Tuple({vec_one, vec_one}))
   1132 //
   1133 // CopyInsertion pass will conceptually generate the following, but with some of
   1134 // the actual GTE and Tuple instructions optimized away:
   1135 //
   1136 //                    Tuple  // old init
   1137 //                   /     \
   1138 //                  /       \
   1139 //                GTE(0)   GTE(1)
   1140 //                  |       /  \
   1141 //                  |      /    \
   1142 //                  |    GTE(0) GTE(1)
   1143 //                  |       |    |
   1144 //                Copy   Copy   Copy
   1145 //                  |       |    |
   1146 //                   \      |   /
   1147 //                    \    Tuple
   1148 //                     \    /
   1149 //                      \  /
   1150 //                     Tuple  // new init
   1151 //
   1152 TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) {
   1153   auto while_hlo = BuildWhileInstruction_InitPointsToNonDistinct();
   1154 
   1155   InsertCopies(module_.get());
   1156 
   1157   // The entry computation requires two copies: one to resolve the
   1158   // non-distinctness of the two init elements, and one for the constant passed
   1159   // in as an init element. Either element can be copied for distinctness.
   1160   EXPECT_EQ(CountCopies(*module_->entry_computation()), 2);
   1161   if (while_hlo->operand(0)->operand(1)->operand(0)->opcode() ==
   1162       HloOpcode::kCopy) {
   1163     EXPECT_THAT(
   1164         while_hlo->operand(0),
   1165         op::Tuple(op::Copy(op::Constant()),
   1166                   op::Tuple(op::Copy(op::Broadcast()), op::Broadcast())));
   1167   } else {
   1168     EXPECT_THAT(
   1169         while_hlo->operand(0),
   1170         op::Tuple(op::Copy(op::Constant()),
   1171                   op::Tuple(op::Broadcast(), op::Copy(op::Broadcast()))));
   1172   }
   1173 
   1174   // The body requires one copy because the buffer set is not distinct: the
   1175   // result of one of the adds is written into two elements of the output of the
   1176   // loop body. Either element might be copied.
   1177   EXPECT_EQ(CountCopies(*while_hlo->while_body()), 1);
   1178   if (while_hlo->while_body()
   1179           ->root_instruction()
   1180           ->operand(1)
   1181           ->operand(0)
   1182           ->opcode() == HloOpcode::kCopy) {
   1183     EXPECT_THAT(
   1184         while_hlo->while_body()->root_instruction(),
   1185         op::Tuple(op::Add(), op::Tuple(op::Copy(op::Add()), op::Add())));
   1186   } else {
   1187     EXPECT_THAT(
   1188         while_hlo->while_body()->root_instruction(),
   1189         op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Add()))));
   1190   }
   1191 }
   1192 
   1193 // Tests while init instruction buffer which interferes with while result
   1194 // buffer.
   1195 //
   1196 //     init_data = Broadcast(...)
   1197 //     add_unrelated = Add(init_data, ones) // reference causes interference
   1198 //     init = Tuple(Constant(S32, {}), init_data)
   1199 //
   1200 // CopyInsertion pass should copy both elements of the init tuple.
   1201 //
   1202 TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) {
   1203   auto while_hlo = BuildWhileInstruction_InitPointsToInterfering();
   1204 
   1205   InsertCopies(module_.get());
   1206   EXPECT_EQ(CountCopies(*module_), 2);
   1207   EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0);
   1208 
   1209   EXPECT_THAT(while_hlo->operand(0),
   1210               op::Tuple(op::Copy(op::Constant()), op::Copy(op::Broadcast())));
   1211 }
   1212 
   1213 // Tests while init instruction buffer which has a non-distinct points-to set:
   1214 //
   1215 //     init = Tuple(Parameter(S32, {}), Parameter(F32, {8}),
   1216 //                  Parameter(F32, {8}))
   1217 //
   1218 // where the second and third elements are the same parameter *and* the tuple
   1219 // is shared by another while instruction.
   1220 //
   1221 // Verifies the copies inserted into each while's init tuple: the checks below
   1222 // expect a copy at element 0 only, with elements 1 and 2 left as the shared
   1223 // parameter.
   1224 TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) {
   1225   // The loop state has three elements: the induction variable and two data
   1226   // vectors of identical shape.
   1227   const Shape loop_state_shape = ShapeUtil::MakeTupleShape(
   1228       {induction_variable_shape_, data_shape_, data_shape_});
   1229 
   1230   auto condition1 = module_->AddEmbeddedComputation(
   1231       BuildConditionComputation(loop_state_shape));
   1232   auto condition2 = module_->AddEmbeddedComputation(
   1233       BuildConditionComputation(loop_state_shape));
   1234   auto body1 =
   1235       module_->AddEmbeddedComputation(BuildDependentBodyComputation2());
   1236   auto body2 =
   1237       module_->AddEmbeddedComputation(BuildDependentBodyComputation2());
   1238 
   1239   auto builder = HloComputation::Builder(TestName() + ".While");
   1240 
   1241   auto iter_param = builder.AddInstruction(
   1242       HloInstruction::CreateParameter(0, induction_variable_shape_, "iter"));
   1243   auto data_param = builder.AddInstruction(
   1244       HloInstruction::CreateParameter(1, data_shape_, "data"));
   1245 
   1246   // Loop init tuple contains two identical parameter buffers.
   1247   auto loop_init = builder.AddInstruction(
   1248       HloInstruction::CreateTuple({iter_param, data_param, data_param}));
   1249 
   1250 
   1251   // The two while loops share the same loop init tuple.
   1252   auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile(
   1253       loop_state_shape, condition1, body1, loop_init));
   1254   auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile(
   1255       loop_state_shape, condition2, body2, loop_init));
   1256 
   1257   // Add element 0 of each while's result so that neither while is dead.
   1258   auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
   1259       ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0));
   1260   auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
   1261       ShapeUtil::GetSubshape(while_hlo2->shape(), {0}), while_hlo2, 0));
   1262   builder.AddInstruction(
   1263       HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2));
   1264 
   1265   module_->AddEntryComputation(builder.Build());
   1266 
   1267   InsertCopies(module_.get());
   1268 
   1269   // Neither body should have any copies.
   1270   EXPECT_EQ(CountCopies(*body1), 0);
   1271   EXPECT_EQ(CountCopies(*body2), 0);
   1272 
   1273   // The checks below expect exactly two copies in the entry computation: one
   1274   // as element 0 of each while's init. Elements 1 and 2 reach both whiles as
   1275   // the uncopied parameter. NOTE(review): an earlier comment here said extra
   1276   // copies of elements 1 and 2 were added (b/xxx); confirm which is current.
   1277   EXPECT_EQ(CountCopies(*module_->entry_computation()), 2);
   1278 
   1279   EXPECT_THAT(while_hlo1->operand(0),
   1280               op::Tuple(op::Copy(), op::Parameter(), op::Parameter()));
   1281   EXPECT_THAT(while_hlo2->operand(0),
   1282               op::Tuple(op::Copy(), op::Parameter(), op::Parameter()));
   1283 }
   1284 
   1285 TEST_F(CopyInsertionTest, SwizzlingWhile) {
   1286   // Test a while instruction with a body which permutes its tuple parameter
   1287   // elements.
   1288   auto module = CreateNewModule();
   1289   const Shape loop_state_shape =
   1290       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
   1291 
   1292   // Body simply interchanges the two tuple elements in the loop state.
   1293   auto body_builder = HloComputation::Builder("body");
   1294   auto body_param = body_builder.AddInstruction(
   1295       HloInstruction::CreateParameter(0, loop_state_shape, "param"));
   1296   auto body_element_0 = body_builder.AddInstruction(
   1297       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
   1298   auto body_element_1 = body_builder.AddInstruction(
   1299       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
   1300   body_builder.AddInstruction(
   1301       HloInstruction::CreateTuple({body_element_1, body_element_0}));
   1302   HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
   1303 
   1304   auto cond_builder = HloComputation::Builder("condition");
   1305   cond_builder.AddInstruction(
   1306       HloInstruction::CreateParameter(0, loop_state_shape, "param"));
   1307   auto cond_constant = cond_builder.AddInstruction(
   1308       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
   1309   cond_builder.AddInstruction(HloInstruction::CreateUnary(
   1310       cond_constant->shape(), HloOpcode::kNot, cond_constant));
   1311   HloComputation* condition =
   1312       module->AddEmbeddedComputation(cond_builder.Build());
   1313 
   1314   auto builder = HloComputation::Builder(TestName());
   1315   auto constant1 = builder.AddInstruction(
   1316       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
   1317   auto constant2 = builder.AddInstruction(
   1318       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
   1319   auto tuple = builder.AddInstruction(
   1320       HloInstruction::CreateTuple({constant1, constant2}));
   1321   auto xla_while = builder.AddInstruction(
   1322       HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple));
   1323   module->AddEntryComputation(builder.Build());
   1324 
   1325   InsertCopies(module.get());
   1326 
   1327   EXPECT_EQ(CountCopies(*module), 6);
   1328 
   1329   // The loop state elements should be copied at the parameter and at the root
   1330   // with a control edge in between (see DeepCopyAndAddControlEdges). This is
   1331   // technically one more copy than is strictly necessary, but in order to have
   1332   // only three copies the copies of different loop state elements must be
   1333   // ordered with a control edge.
   1334   EXPECT_EQ(CountCopies(*body), 4);
   1335   EXPECT_EQ(CountControlEdges(*body), 2);
   1336 
   1337   EXPECT_THAT(body->root_instruction(),
   1338               op::Tuple(op::Copy(op::Copy()), op::Copy(op::Copy())));
   1339 
   1340   EXPECT_EQ(CountCopies(*module->entry_computation()), 2);
   1341   EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy()));
   1342 }
   1343 
   1344 TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) {
   1345   // Test a while instruction with a body which permutes its tuple parameter
   1346   // elements and applies one operation to one of the elements. The addition of
   1347   // the operation (instruction) on the element makes the live range of the
   1348   // respective input and output elements different than if the instruction were
   1349   // not there (as in the SwizzlingWhile test above).
   1350   auto module = CreateNewModule();
   1351   const Shape loop_state_shape =
   1352       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
   1353 
   1354   // Body interchanges the two tuple elements in the loop state and negates one
   1355   // of them.
   1356   auto body_builder = HloComputation::Builder("body");
   1357   auto body_param = body_builder.AddInstruction(
   1358       HloInstruction::CreateParameter(0, loop_state_shape, "param"));
   1359   auto body_element_0 = body_builder.AddInstruction(
   1360       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
   1361   auto body_element_1 = body_builder.AddInstruction(
   1362       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
   1363   auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
   1364       scalar_shape_, HloOpcode::kNegate, body_element_1));
   1365   body_builder.AddInstruction(
   1366       HloInstruction::CreateTuple({negate, body_element_0}));
   1367   HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
   1368 
   1369   auto cond_builder = HloComputation::Builder("condition");
   1370   cond_builder.AddInstruction(
   1371       HloInstruction::CreateParameter(0, loop_state_shape, "param"));
   1372   auto cond_constant = cond_builder.AddInstruction(
   1373       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
   1374   cond_builder.AddInstruction(HloInstruction::CreateUnary(
   1375       cond_constant->shape(), HloOpcode::kNot, cond_constant));
   1376   HloComputation* condition =
   1377       module->AddEmbeddedComputation(cond_builder.Build());
   1378 
   1379   auto builder = HloComputation::Builder(TestName());
   1380   auto constant1 = builder.AddInstruction(
   1381       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
   1382   auto constant2 = builder.AddInstruction(
   1383       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
   1384   auto tuple = builder.AddInstruction(
   1385       HloInstruction::CreateTuple({constant1, constant2}));
   1386   auto xla_while = builder.AddInstruction(
   1387       HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple));
   1388   module->AddEntryComputation(builder.Build());
   1389 
   1390   InsertCopies(module.get());
   1391 
   1392   EXPECT_EQ(CountCopies(*module), 6);
   1393 
   1394   // The loop state elements should be copied at the parameter and at the root
   1395   // with a control edge in between (see DeepCopyAndAddControlEdges).
   1396   EXPECT_EQ(CountCopies(*body), 4);
   1397   EXPECT_EQ(CountControlEdges(*body), 2);
   1398 
   1399   EXPECT_THAT(
   1400       body->root_instruction(),
   1401       op::Tuple(op::Copy(op::Negate(op::Copy())), op::Copy(op::Copy())));
   1402 
   1403   EXPECT_EQ(CountCopies(*module->entry_computation()), 2);
   1404   EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy()));
   1405 }
   1406 
   1407 TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) {
   1408   // Test a while instruction with a body which permutes it's tuple parameter
   1409   // elements similar to SwizzlinWhile above. However, in this test the input to
   1410   // the while body is a single constant (both loop state elements are the same
   1411   // constant). This means no copies are necessary because both loop state
   1412   // elements are the same so interchanging them is a no-op.
   1413   auto module = CreateNewModule();
   1414   const Shape loop_state_shape =
   1415       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
   1416 
   1417   // Body simply interchanges the two tuple elements in the loop state.
   1418   auto body_builder = HloComputation::Builder("body");
   1419   auto body_param = body_builder.AddInstruction(
   1420       HloInstruction::CreateParameter(0, loop_state_shape, "param"));
   1421   auto body_element_0 = body_builder.AddInstruction(
   1422       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
   1423   auto body_element_1 = body_builder.AddInstruction(
   1424       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
   1425   body_builder.AddInstruction(
   1426       HloInstruction::CreateTuple({body_element_1, body_element_0}));
   1427   HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
   1428 
   1429   auto cond_builder = HloComputation::Builder("condition");
   1430   cond_builder.AddInstruction(
   1431       HloInstruction::CreateParameter(0, loop_state_shape, "param"));
   1432   auto cond_constant = cond_builder.AddInstruction(
   1433       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
   1434   cond_builder.AddInstruction(HloInstruction::CreateUnary(
   1435       cond_constant->shape(), HloOpcode::kNot, cond_constant));
   1436   HloComputation* condition =
   1437       module->AddEmbeddedComputation(cond_builder.Build());
   1438 
   1439   auto builder = HloComputation::Builder(TestName());
   1440   auto constant = builder.AddInstruction(
   1441       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
   1442   auto tuple =
   1443       builder.AddInstruction(HloInstruction::CreateTuple({constant, constant}));
   1444   builder.AddInstruction(
   1445       HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple));
   1446   module->AddEntryComputation(builder.Build());
   1447 
   1448   InsertCopies(module.get());
   1449 
   1450   EXPECT_EQ(CountCopies(*module), 2);
   1451   EXPECT_EQ(CountCopies(*body), 0);
   1452 
   1453   EXPECT_EQ(CountCopies(*module->entry_computation()), 2);
   1454   EXPECT_THAT(module->entry_computation()->root_instruction(),
   1455               op::Tuple(op::Copy(), op::Copy()));
   1456 }
   1457 
   1458 TEST_F(CopyInsertionTest, SequentialWhiles) {
   1459   // Construct a computation with a series of sequential while instructions
   1460   // containing four loop state elements:
   1461   //
   1462   //   element 0 is passed to each while directly from an entry parameter.
   1463   //
   1464   //   element 1 is passed transparently in series through all the while bodies.
   1465   //
   1466   //   element 2 is negated in each while body. (in-place possible)
   1467   //
   1468   //   element 3 is reversed in each while body. (in-place not possible)
   1469   //
   1470   const Shape element_shape = ShapeUtil::MakeShape(F32, {42});
   1471   const Shape loop_state_shape = ShapeUtil::MakeTupleShape(
   1472       {element_shape, element_shape, element_shape, element_shape});
   1473 
   1474   auto module = CreateNewModule();
   1475   auto builder = HloComputation::Builder(TestName());
   1476   auto param_0 = builder.AddInstruction(
   1477       HloInstruction::CreateParameter(0, element_shape, "param_0"));
   1478   auto param_1 = builder.AddInstruction(
   1479       HloInstruction::CreateParameter(1, element_shape, "param_1"));
   1480   auto param_2 = builder.AddInstruction(
   1481       HloInstruction::CreateParameter(2, element_shape, "param_2"));
   1482   auto param_3 = builder.AddInstruction(
   1483       HloInstruction::CreateParameter(3, element_shape, "param_3"));
   1484 
   1485   // The number of sequential kWhile instructions.
   1486   const int kNumWhiles = 3;
   1487 
   1488   HloInstruction* prev_element_1 = param_1;
   1489   HloInstruction* prev_element_2 = param_2;
   1490   HloInstruction* prev_element_3 = param_3;
   1491 
   1492   // Vector containing all of the while instructions.
   1493   std::vector<const HloInstruction*> whiles;
   1494   for (int i = 0; i < kNumWhiles; ++i) {
   1495     auto body_builder = HloComputation::Builder("body");
   1496     auto body_param = body_builder.AddInstruction(
   1497         HloInstruction::CreateParameter(0, loop_state_shape, "param"));
   1498     auto body_element_0 = body_builder.AddInstruction(
   1499         HloInstruction::CreateGetTupleElement(element_shape, body_param, 0));
   1500     auto body_element_1 = body_builder.AddInstruction(
   1501         HloInstruction::CreateGetTupleElement(element_shape, body_param, 1));
   1502     auto body_element_2 = body_builder.AddInstruction(
   1503         HloInstruction::CreateGetTupleElement(element_shape, body_param, 2));
   1504     auto body_element_3 = body_builder.AddInstruction(
   1505         HloInstruction::CreateGetTupleElement(element_shape, body_param, 3));
   1506     auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
   1507         element_shape, HloOpcode::kNegate, body_element_2));
   1508     auto reverse = body_builder.AddInstruction(
   1509         HloInstruction::CreateReverse(element_shape, body_element_3, {0}));
   1510     body_builder.AddInstruction(HloInstruction::CreateTuple(
   1511         {body_element_0, body_element_1, negate, reverse}));
   1512     HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
   1513 
   1514     auto cond_builder = HloComputation::Builder("condition");
   1515     cond_builder.AddInstruction(
   1516         HloInstruction::CreateParameter(0, loop_state_shape, "param"));
   1517     auto cond_constant = cond_builder.AddInstruction(
   1518         HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
   1519     cond_builder.AddInstruction(HloInstruction::CreateUnary(
   1520         cond_constant->shape(), HloOpcode::kNot, cond_constant));
   1521     HloComputation* condition =
   1522         module->AddEmbeddedComputation(cond_builder.Build());
   1523 
   1524     auto while_init = builder.AddInstruction(HloInstruction::CreateTuple(
   1525         {param_0, prev_element_1, prev_element_2, prev_element_3}));
   1526 
   1527     auto xla_while = builder.AddInstruction(HloInstruction::CreateWhile(
   1528         loop_state_shape, condition, body, while_init));
   1529     whiles.push_back(xla_while);
   1530     if (i != kNumWhiles - 1) {
   1531       prev_element_1 = builder.AddInstruction(
   1532           HloInstruction::CreateGetTupleElement(element_shape, xla_while, 1));
   1533       prev_element_2 = builder.AddInstruction(
   1534           HloInstruction::CreateGetTupleElement(element_shape, xla_while, 2));
   1535       prev_element_3 = builder.AddInstruction(
   1536           HloInstruction::CreateGetTupleElement(element_shape, xla_while, 3));
   1537     }
   1538   }
   1539 
   1540   module->AddEntryComputation(builder.Build());
   1541 
   1542   InsertCopies(module.get());
   1543 
   1544   // Each while body has one copy. And each loop state element is copied once in
   1545   // the entry computation.
   1546   EXPECT_EQ(CountCopies(*module), 4 + kNumWhiles);
   1547 
   1548   // Each while body should have exactly one copy for element three which is an
   1549   // op (kReverse) which cannot be done in place.
   1550   for (const HloInstruction* xla_while : whiles) {
   1551     EXPECT_EQ(CountCopies(*xla_while->while_body()), 1);
   1552   }
   1553 
   1554   EXPECT_THAT(whiles[0]->operand(0), op::Tuple(op::Parameter(), op::Parameter(),
   1555                                                op::Copy(), op::Copy()));
   1556   EXPECT_THAT(module->entry_computation()->root_instruction(),
   1557               op::Tuple(op::Copy(), op::Copy(), op::GetTupleElement(),
   1558                         op::GetTupleElement()));
   1559 }
   1560 
   1561 TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) {
   1562   // Test a while body and condition which are each simply a constant (root of
   1563   // computation is a constant). The body constant should be copied.
   1564   auto module = CreateNewModule();
   1565   auto builder = HloComputation::Builder(TestName());
   1566   auto param_0 = builder.AddInstruction(
   1567       HloInstruction::CreateParameter(0, scalar_shape_, "param_0"));
   1568 
   1569   auto body_builder = HloComputation::Builder("body");
   1570   body_builder.AddInstruction(
   1571       HloInstruction::CreateParameter(0, scalar_shape_, "param"));
   1572   body_builder.AddInstruction(
   1573       HloInstruction::CreateConstant(Literal::CreateR0<float>(123.0)));
   1574   HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
   1575 
   1576   auto cond_builder = HloComputation::Builder("condition");
   1577   cond_builder.AddInstruction(
   1578       HloInstruction::CreateParameter(0, scalar_shape_, "param"));
   1579   cond_builder.AddInstruction(
   1580       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
   1581   HloComputation* condition =
   1582       module->AddEmbeddedComputation(cond_builder.Build());
   1583 
   1584   auto xla_while = builder.AddInstruction(
   1585       HloInstruction::CreateWhile(scalar_shape_, condition, body, param_0));
   1586 
   1587   module->AddEntryComputation(builder.Build());
   1588 
   1589   InsertCopies(module.get());
   1590 
   1591   EXPECT_EQ(CountCopies(*module), 2);
   1592 
   1593   EXPECT_THAT(xla_while->operand(0), op::Copy(op::Parameter()));
   1594   EXPECT_THAT(body->root_instruction(), op::Copy(op::Constant()));
   1595   EXPECT_THAT(condition->root_instruction(), op::Constant());
   1596 }
   1597 
   1598 std::unique_ptr<HloComputation> MakeTrivialCondition(const Shape& shape) {
   1599   auto builder = HloComputation::Builder("trivial_condition");
   1600   builder.AddInstruction(
   1601       HloInstruction::CreateParameter(0, shape, "loop_state"));
   1602   auto constant = builder.AddInstruction(
   1603       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
   1604   builder.AddInstruction(HloInstruction::CreateUnary(
   1605       constant->shape(), HloOpcode::kNot, constant));
   1606   return builder.Build();
   1607 }
   1608 
   1609 std::unique_ptr<HloComputation> MakeBenchmarkWhileBody() {
   1610   auto builder = HloComputation::Builder("benchmark_loop_body");
   1611   const Shape element_shape = ShapeUtil::MakeShape(F32, {42});
   1612   const Shape loop_state_shape =
   1613       ShapeUtil::MakeTupleShape({element_shape, element_shape, element_shape});
   1614   HloInstruction* param = builder.AddInstruction(
   1615       HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
   1616   HloInstruction* element_0 = builder.AddInstruction(
   1617       HloInstruction::CreateGetTupleElement(element_shape, param, 0));
   1618   HloInstruction* element_1 = builder.AddInstruction(
   1619       HloInstruction::CreateGetTupleElement(element_shape, param, 1));
   1620   HloInstruction* element_2 = builder.AddInstruction(
   1621       HloInstruction::CreateGetTupleElement(element_shape, param, 2));
   1622 
   1623   HloInstruction* rev_1 = builder.AddInstruction(
   1624       HloInstruction::CreateReverse(element_shape, element_1, {0}));
   1625   HloInstruction* add_1_2 = builder.AddInstruction(HloInstruction::CreateBinary(
   1626       element_shape, HloOpcode::kAdd, element_1, element_2));
   1627 
   1628   builder.AddInstruction(
   1629       HloInstruction::CreateTuple({element_0, rev_1, add_1_2}));
   1630   return builder.Build();
   1631 }
   1632 
   1633 void BM_SequentialWhiles(int num_iters, int num_whiles) {
   1634   // This benchmark constructs a chain of sequential while instructions.
   1635   tensorflow::testing::StopTiming();
   1636   for (int i = 0; i < num_iters; ++i) {
   1637     HloModuleConfig config;
   1638     config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
   1639     HloModule module("BM_SequentialWhiles", VersionedComputationHandle(),
   1640                      config);
   1641 
   1642     auto builder = HloComputation::Builder("BM_SequentialWhiles");
   1643     HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
   1644         0, ShapeUtil::MakeShape(F32, {42}), "x"));
   1645     HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
   1646         1, ShapeUtil::MakeShape(F32, {42}), "y"));
   1647     HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
   1648         2, ShapeUtil::MakeShape(F32, {42}), "z"));
   1649     HloInstruction* init =
   1650         builder.AddInstruction(HloInstruction::CreateTuple({x, y, z}));
   1651 
   1652     HloInstruction* prev_loop_state = init;
   1653     for (int w = 0; w < num_whiles; ++w) {
   1654       HloComputation* condition =
   1655           module.AddEmbeddedComputation(MakeTrivialCondition(init->shape()));
   1656       HloComputation* body =
   1657           module.AddEmbeddedComputation(MakeBenchmarkWhileBody());
   1658       prev_loop_state = builder.AddInstruction(HloInstruction::CreateWhile(
   1659           init->shape(), condition, body, prev_loop_state));
   1660     }
   1661     module.AddEntryComputation(builder.Build());
   1662 
   1663     CopyInsertion copy_insertion;
   1664 
   1665     tensorflow::testing::StartTiming();
   1666     ASSERT_IS_OK(copy_insertion.Run(&module).status());
   1667     tensorflow::testing::StopTiming();
   1668 
   1669     // The entry computation should have three copies, and each body has one.
   1670     ASSERT_EQ(CountCopies(module), 3 + num_whiles);
   1671   }
   1672 }
   1673 
   1674 void BM_ParallelWhiles(int num_iters, int num_whiles) {
   1675   // This benchmark constructs a fan-out of parallel while instructions.
   1676   tensorflow::testing::StopTiming();
   1677   for (int i = 0; i < num_iters; ++i) {
   1678     HloModuleConfig config;
   1679     config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
   1680     HloModule module("BM_SequentialWhiles", VersionedComputationHandle(),
   1681                      config);
   1682 
   1683     auto builder = HloComputation::Builder("BM_ParallelWhiles");
   1684     HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
   1685         0, ShapeUtil::MakeShape(F32, {42}), "x"));
   1686     HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
   1687         1, ShapeUtil::MakeShape(F32, {42}), "y"));
   1688     HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
   1689         2, ShapeUtil::MakeShape(F32, {42}), "z"));
   1690     HloInstruction* init =
   1691         builder.AddInstruction(HloInstruction::CreateTuple({x, y, z}));
   1692 
   1693     HloInstruction* sum = nullptr;
   1694     for (int w = 0; w < num_whiles; ++w) {
   1695       HloComputation* condition =
   1696           module.AddEmbeddedComputation(MakeTrivialCondition(init->shape()));
   1697       HloComputation* body =
   1698           module.AddEmbeddedComputation(MakeBenchmarkWhileBody());
   1699 
   1700       HloInstruction* xla_while = builder.AddInstruction(
   1701           HloInstruction::CreateWhile(init->shape(), condition, body, init));
   1702 
   1703       if (sum == nullptr) {
   1704         sum = builder.AddInstruction(
   1705             HloInstruction::CreateGetTupleElement(x->shape(), xla_while, 0));
   1706       } else {
   1707         HloInstruction* element_0 = builder.AddInstruction(
   1708             HloInstruction::CreateGetTupleElement(x->shape(), xla_while, 0));
   1709         sum = builder.AddInstruction(HloInstruction::CreateBinary(
   1710             x->shape(), HloOpcode::kAdd, sum, element_0));
   1711       }
   1712     }
   1713     module.AddEntryComputation(builder.Build());
   1714 
   1715     CopyInsertion copy_insertion;
   1716 
   1717     tensorflow::testing::StartTiming();
   1718     ASSERT_IS_OK(copy_insertion.Run(&module).status());
   1719     tensorflow::testing::StopTiming();
   1720 
   1721     // Each body receives of copy of two of the parameters (the corresponding
   1722     // elements in the body are modifed), and there is one copy in each body.
   1723     ASSERT_EQ(CountCopies(module), 3 * num_whiles);
   1724   }
   1725 }
   1726 
   1727 std::unique_ptr<HloComputation> MakeBenchmarkWhileBody(
   1728     const int num_tuple_inputs) {
   1729   auto builder = HloComputation::Builder("benchmark_loop_body");
   1730   const Shape element_shape = ShapeUtil::MakeShape(F32, {});
   1731   std::vector<Shape> input_shape(num_tuple_inputs, element_shape);
   1732   const Shape loop_state_shape = ShapeUtil::MakeTupleShape(input_shape);
   1733   HloInstruction* param = builder.AddInstruction(
   1734       HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
   1735   std::vector<HloInstruction*> gte_nodes(num_tuple_inputs);
   1736   for (int i = 0; i < num_tuple_inputs; ++i) {
   1737     gte_nodes[i] = builder.AddInstruction(
   1738         HloInstruction::CreateGetTupleElement(element_shape, param, i));
   1739   }
   1740   builder.AddInstruction(HloInstruction::CreateTuple(gte_nodes));
   1741   return builder.Build();
   1742 }
   1743 
   1744 void BM_ManyElementTuple(int num_iters, const int num_tuple_inputs) {
   1745   tensorflow::testing::StopTiming();
   1746   HloModuleConfig config;
   1747   config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
   1748   CopyInsertion copy_insertion;
   1749   const Shape element_shape = ShapeUtil::MakeShape(F32, {});
   1750   std::vector<HloInstruction*> tuple_params(num_tuple_inputs);
   1751   for (int i = 0; i < num_iters; ++i) {
   1752     auto builder = HloComputation::Builder("BM_ParallelWhiles");
   1753     HloModule module("BM_ManyElementTuple", VersionedComputationHandle(),
   1754                      config);
   1755     for (int j = 0; j < num_tuple_inputs; ++j) {
   1756       tuple_params[j] = builder.AddInstruction(
   1757           HloInstruction::CreateParameter(j, element_shape, ""));
   1758     }
   1759     HloInstruction* init =
   1760         builder.AddInstruction(HloInstruction::CreateTuple(tuple_params));
   1761     HloComputation* condition =
   1762         module.AddEmbeddedComputation(MakeTrivialCondition(init->shape()));
   1763     HloComputation* body =
   1764         module.AddEmbeddedComputation(MakeBenchmarkWhileBody(num_tuple_inputs));
   1765     HloInstruction* xla_while = builder.AddInstruction(
   1766         HloInstruction::CreateWhile(init->shape(), condition, body, init));
   1767     builder.AddInstruction(HloInstruction::CreateGetTupleElement(
   1768         ShapeUtil::MakeShape(F32, {}), xla_while, 0));
   1769     module.AddEntryComputation(builder.Build());
   1770     tensorflow::testing::StartTiming();
   1771     ASSERT_IS_OK(copy_insertion.Run(&module).status());
   1772     tensorflow::testing::StopTiming();
   1773   }
   1774 }
   1775 
   1776 BENCHMARK(BM_SequentialWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096);
   1777 BENCHMARK(BM_ParallelWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096);
   1778 BENCHMARK(BM_ManyElementTuple)->Arg(1024)->Arg(12288);
   1779 
   1780 TEST_F(CopyInsertionTest, SimpleControlFlowTest) {
   1781   const string& hlo_string = R"(
   1782 HloModule TestModule
   1783 
   1784 if-body.v5 {
   1785   constant.3 = s32[] constant(-1)
   1786   p.1 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
   1787   get-tuple-element.18 = (s32[], s32[], s32[]) get-tuple-element(p.1), index=1
   1788   get-tuple-element.65 = s32[] get-tuple-element(get-tuple-element.18), index=0
   1789   get-tuple-element.66 = s32[] get-tuple-element(get-tuple-element.18), index=1
   1790   add.3 = s32[] add(get-tuple-element.65, get-tuple-element.66)
   1791   tuple.33 = (s32[]) tuple(add.3)
   1792   ROOT tuple.34 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.3, get-tuple-element.18, tuple.33)
   1793 }
   1794 
   1795 if-condition.v4 {
   1796   p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
   1797   get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0
   1798   constant.4 = s32[] constant(0)
   1799   ROOT equal-to = pred[] equal-to(get-tuple-element.67, constant.4)
   1800 }
   1801 
   1802 _functionalize_body_1__.v28 {
   1803   arg_tuple.4 = (s32[], s32[], s32[], s32[]) parameter(0)
   1804   get-tuple-element.68 = s32[] get-tuple-element(arg_tuple.4), index=0
   1805   constant.7 = s32[] constant(1)
   1806   add.4 = s32[] add(get-tuple-element.68, constant.7)
   1807   get-tuple-element.69 = s32[] get-tuple-element(arg_tuple.4), index=1
   1808   get-tuple-element.70 = s32[] get-tuple-element(arg_tuple.4), index=2
   1809   less-than-or-equal-to = pred[] less-than-or-equal-to(get-tuple-element.69, get-tuple-element.70)
   1810   constant.8 = s32[] constant(0)
   1811   select = s32[] select(less-than-or-equal-to, constant.8, constant.7)
   1812   get-tuple-element.71 = s32[] get-tuple-element(arg_tuple.4), index=3
   1813   tuple.35 = (s32[], s32[], s32[]) tuple(get-tuple-element.69, get-tuple-element.71, get-tuple-element.70)
   1814   tuple.36 = (s32[]) tuple(constant.8)
   1815   tuple.37 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(select, tuple.35, tuple.36)
   1816   while = (s32[], (s32[], s32[], s32[]), (s32[])) while(tuple.37), condition=if-condition.v4, body=if-body.v5
   1817   get-tuple-element.72 = (s32[]) get-tuple-element(while), index=2
   1818   get-tuple-element.73 = s32[] get-tuple-element(get-tuple-element.72), index=0
   1819   ROOT tuple.38 = (s32[], s32[], s32[], s32[]) tuple(add.4, get-tuple-element.69, get-tuple-element.70, get-tuple-element.73)
   1820 }
   1821 
   1822 cond_wrapper.v3.1 {
   1823   inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0)
   1824   get-tuple-element.75 = s32[] get-tuple-element(inputs.1), index=0
   1825   constant.11 = s32[] constant(7)
   1826   ROOT less-than.2 = pred[] less-than(get-tuple-element.75, constant.11)
   1827 }
   1828 
   1829 _functionalize_body_2__.v25 {
   1830   arg_tuple.5 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
   1831   get-tuple-element.76 = s32[] get-tuple-element(arg_tuple.5), index=0
   1832   get-tuple-element.77 = s32[] get-tuple-element(arg_tuple.5), index=2
   1833   get-tuple-element.78 = s32[] get-tuple-element(arg_tuple.5), index=3
   1834   get-tuple-element.79 = s32[] get-tuple-element(arg_tuple.5), index=4
   1835   tuple.39 = (s32[], s32[], s32[], s32[]) tuple(get-tuple-element.76, get-tuple-element.77, get-tuple-element.78, get-tuple-element.79)
   1836   while.2 = (s32[], s32[], s32[], s32[]) while(tuple.39), condition=cond_wrapper.v3.1, body=_functionalize_body_1__.v28
   1837   get-tuple-element.80 = s32[] get-tuple-element(while.2), index=0
   1838   get-tuple-element.81 = s32[] get-tuple-element(arg_tuple.5), index=1
   1839   constant.12 = s32[] constant(1)
   1840   add.5 = s32[] add(get-tuple-element.81, constant.12)
   1841   get-tuple-element.82 = s32[] get-tuple-element(while.2), index=3
   1842   ROOT tuple.40 = (s32[], s32[], s32[], s32[], s32[]) tuple(get-tuple-element.80, add.5, get-tuple-element.77, get-tuple-element.78, get-tuple-element.82)
   1843 }
   1844 
   1845 cond_wrapper.v3.2 {
   1846   inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
   1847   get-tuple-element.83 = s32[] get-tuple-element(inputs.2), index=1
   1848   constant.13 = s32[] constant(5)
   1849   ROOT less-than.3 = pred[] less-than(get-tuple-element.83, constant.13)
   1850 }
   1851 
   1852 ENTRY TestComputation {
   1853   arg_tuple.6 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
   1854   ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25
   1855 }
   1856 )";
   1857   auto module_or_status =
   1858       HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
   1859   auto module = module_or_status.ConsumeValueOrDie();
   1860   InsertCopies(module.get());
   1861 }
   1862 
   1863 TEST_F(CopyInsertionTest, ControlFlowTest) {
          // Parses the HLO module text below and runs InsertCopies on it.
          //
          // Shape of the module, as written in the text:
          //   * ENTRY is a single while (while.3) over a 5-element s32 tuple.
          //   * Its body (_functionalize_body_2__.v25) contains a nested while
          //     (while.2) over a 4-element tuple.
          //   * That body (_functionalize_body_1__.v28) contains two chained
          //     whiles: `while` (if-condition.v4 / if-body.v5) whose result is the
          //     operand of `while.1` (if-condition.v4.1 / if-body.v5.1).
          //
          // This test contains no EXPECT/ASSERT statements of its own; it only
          // exercises the parse + InsertCopies path.
   1864   const string& hlo_string = R"(
   1865 HloModule TestModule
   1866 
   1867 if-body.v5 {
   1868   constant.3 = s32[] constant(-1)
   1869   p.1 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
   1870   get-tuple-element.18 = (s32[], s32[], s32[]) get-tuple-element(p.1), index=1
   1871   get-tuple-element.65 = s32[] get-tuple-element(get-tuple-element.18), index=0
   1872   get-tuple-element.66 = s32[] get-tuple-element(get-tuple-element.18), index=1
   1873   add.3 = s32[] add(get-tuple-element.65, get-tuple-element.66)
   1874   tuple.33 = (s32[]) tuple(add.3)
   1875   ROOT tuple.34 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.3, get-tuple-element.18, tuple.33)
   1876 }
   1877 
   1878 if-condition.v4 {
   1879   p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
   1880   get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0
   1881   constant.4 = s32[] constant(0)
   1882   ROOT equal-to = pred[] equal-to(get-tuple-element.67, constant.4)
   1883 }
   1884 
   1885 if-body.v5.1 {
   1886   constant.5 = s32[] constant(-1)
   1887   p.3 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
   1888   get-tuple-element.68 = (s32[], s32[], s32[]) get-tuple-element(p.3), index=1
   1889   get-tuple-element.70 = s32[] get-tuple-element(get-tuple-element.68), index=2
   1890   multiply.1 = s32[] multiply(get-tuple-element.70, get-tuple-element.70)
   1891   tuple.35 = (s32[]) tuple(multiply.1)
   1892   ROOT tuple.36 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.5, get-tuple-element.68, tuple.35)
   1893 }
   1894 
   1895 if-condition.v4.1 {
   1896   p.4 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
   1897   get-tuple-element.71 = s32[] get-tuple-element(p.4), index=0
   1898   constant.6 = s32[] constant(1)
   1899   ROOT equal-to.1 = pred[] equal-to(get-tuple-element.71, constant.6)
   1900 }
   1901 
   1902 _functionalize_body_1__.v28 {
   1903   arg_tuple.4 = (s32[], s32[], s32[], s32[]) parameter(0)
   1904   get-tuple-element.72 = s32[] get-tuple-element(arg_tuple.4), index=0
   1905   constant.7 = s32[] constant(1)
   1906   add.4 = s32[] add(get-tuple-element.72, constant.7)
   1907   get-tuple-element.73 = s32[] get-tuple-element(arg_tuple.4), index=1
   1908   get-tuple-element.74 = s32[] get-tuple-element(arg_tuple.4), index=2
   1909   less-than-or-equal-to = pred[] less-than-or-equal-to(get-tuple-element.73, get-tuple-element.74)
   1910   constant.8 = s32[] constant(0)
   1911   select = s32[] select(less-than-or-equal-to, constant.8, constant.7)
   1912   get-tuple-element.75 = s32[] get-tuple-element(arg_tuple.4), index=3
   1913   tuple.37 = (s32[], s32[], s32[]) tuple(get-tuple-element.73, get-tuple-element.75, get-tuple-element.74)
   1914   tuple.38 = (s32[]) tuple(constant.8)
   1915   tuple.39 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(select, tuple.37, tuple.38)
   1916   while = (s32[], (s32[], s32[], s32[]), (s32[])) while(tuple.39), condition=if-condition.v4, body=if-body.v5
   1917   while.1 = (s32[], (s32[], s32[], s32[]), (s32[])) while(while), condition=if-condition.v4.1, body=if-body.v5.1
   1918   get-tuple-element.76 = (s32[]) get-tuple-element(while.1), index=2
   1919   get-tuple-element.77 = s32[] get-tuple-element(get-tuple-element.76), index=0
   1920   ROOT tuple.40 = (s32[], s32[], s32[], s32[]) tuple(add.4, get-tuple-element.73, get-tuple-element.74, get-tuple-element.77)
   1921 }
   1922 
   1923 cond_wrapper.v3.1 {
   1924   inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0)
   1925   get-tuple-element.78 = s32[] get-tuple-element(inputs.1), index=0
   1926   constant.11 = s32[] constant(7)
   1927   ROOT less-than.2 = pred[] less-than(get-tuple-element.78, constant.11)
   1928 }
   1929 
   1930 _functionalize_body_2__.v25 {
   1931   arg_tuple.5 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
   1932   get-tuple-element.79 = s32[] get-tuple-element(arg_tuple.5), index=0
   1933   get-tuple-element.80 = s32[] get-tuple-element(arg_tuple.5), index=2
   1934   get-tuple-element.81 = s32[] get-tuple-element(arg_tuple.5), index=3
   1935   get-tuple-element.82 = s32[] get-tuple-element(arg_tuple.5), index=4
   1936   tuple.41 = (s32[], s32[], s32[], s32[]) tuple(get-tuple-element.79, get-tuple-element.80, get-tuple-element.81, get-tuple-element.82)
   1937   while.2 = (s32[], s32[], s32[], s32[]) while(tuple.41), condition=cond_wrapper.v3.1, body=_functionalize_body_1__.v28
   1938   get-tuple-element.83 = s32[] get-tuple-element(while.2), index=0
   1939   get-tuple-element.84 = s32[] get-tuple-element(arg_tuple.5), index=1
   1940   constant.12 = s32[] constant(1)
   1941   add.5 = s32[] add(get-tuple-element.84, constant.12)
   1942   get-tuple-element.85 = s32[] get-tuple-element(while.2), index=3
   1943   ROOT tuple.42 = (s32[], s32[], s32[], s32[], s32[]) tuple(get-tuple-element.83, add.5, get-tuple-element.80, get-tuple-element.81, get-tuple-element.85)
   1944 }
   1945 
   1946 cond_wrapper.v3.2 {
   1947   inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
   1948   get-tuple-element.86 = s32[] get-tuple-element(inputs.2), index=1
   1949   constant.13 = s32[] constant(5)
   1950   ROOT less-than.3 = pred[] less-than(get-tuple-element.86, constant.13)
   1951 }
   1952 
   1953 ENTRY TestComputation {
   1954   arg_tuple.6 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
   1955   ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25
   1956 }
   1957 )";
          // Build an HloModule from the text above, using this fixture's debug
          // options.
   1958   auto module_or_status =
   1959       HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest());
          // NOTE(review): the parse status is not checked explicitly before the
          // value is consumed; presumably ConsumeValueOrDie aborts on a failed
          // parse -- confirm.
   1960   auto module = module_or_status.ConsumeValueOrDie();
          // NOTE(review): InsertCopies is declared outside this chunk; presumably
          // it runs the copy-insertion pass on the module -- confirm. Its result,
          // if any, is not inspected here.
   1961   InsertCopies(module.get());
   1962 }
   1963 
   1964 }  // namespace
   1965 }  // namespace xla
   1966