Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
     17 
     18 #include "tensorflow/compiler/xla/literal_util.h"
     19 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     20 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
     21 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
     22 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     23 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
     24 #include "tensorflow/compiler/xla/service/instruction_fusion.h"
     25 #include "tensorflow/compiler/xla/shape_util.h"
     26 #include "tensorflow/compiler/xla/status_macros.h"
     27 #include "tensorflow/compiler/xla/test.h"
     28 #include "tensorflow/compiler/xla/test_helpers.h"
     29 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
     30 #include "tensorflow/compiler/xla/xla_data.pb.h"
     31 #include "tensorflow/core/platform/logging.h"
     32 #include "tensorflow/core/platform/test.h"
     33 
     34 namespace xla {
     35 namespace {
     36 
     37 using ::testing::ElementsAre;
     38 using ::testing::UnorderedElementsAre;
     39 
     40 // Test is parameterized on a bool which is whether the dataflow analysis is
     41 // performed with SSA form.
     42 class HloDataflowAnalysisTest : public HloTestBase,
     43                                 public ::testing::WithParamInterface<bool> {
     44  protected:
     45   HloDataflowAnalysisTest() : module_(CreateNewModule()) {}
     46 
     47   // Run dataflow analysis on the member module. For convenience returns a
     48   // reference to the generated analysis stored in analysis_.
     49   const HloDataflowAnalysis& RunAnalysis(bool ssa_form,
     50                                          bool bitcast_defines_value = false) {
     51     hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before dataflow analysis");
     52     analysis_ =
     53         HloDataflowAnalysis::Run(*module_, ssa_form, bitcast_defines_value)
     54             .ConsumeValueOrDie();
     55     return *analysis_;
     56   }
     57 
     58   // Return a vector of the HloValues at the given program position.
     59   std::vector<HloValue> HloValuesAt(const HloInstruction* instruction,
     60                                     const ShapeIndex& index = {}) {
     61     CHECK(analysis_ != nullptr);
     62     std::vector<HloValue> values;
     63     for (const HloValue* value :
     64          analysis_->GetValueSet(instruction, index).values()) {
     65       values.push_back(*value);
     66     }
     67     return values;
     68   }
     69 
     70   // Returns true if the top-level values for instructions 'a' and 'b' may
     71   // interfere. Precondition: 'a' and 'b' define array-shaped values.
     72   bool InstructionsMayInterfere(const HloOrdering& ordering,
     73                                 const HloInstruction* a,
     74                                 const HloInstruction* b) {
     75     EXPECT_FALSE(ShapeUtil::IsTuple(a->shape()));
     76     EXPECT_FALSE(ShapeUtil::IsTuple(b->shape()));
     77     return ordering.MayInterfere(analysis_->GetValueDefinedAt(a),
     78                                  analysis_->GetValueDefinedAt(b), *analysis_);
     79   }
     80 
     81   std::unique_ptr<HloComputation> CreateR0F32UnaryOpComputation(
     82       HloOpcode opcode) {
     83     HloComputation::Builder builder(TestName() + "." + HloOpcodeString(opcode));
     84     HloInstruction* param0 = builder.AddInstruction(
     85         HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
     86     builder.AddInstruction(
     87         HloInstruction::CreateUnary(scalar_shape_, opcode, param0));
     88     return builder.Build();
     89   }
     90 
     91   std::unique_ptr<HloModule> module_;
     92   std::unique_ptr<HloDataflowAnalysis> analysis_;
     93 
     94   const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
     95   const Shape vector_shape_ = ShapeUtil::MakeShape(F32, {42});
     96   const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
     97       {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})});
     98 };
     99 
    100 TEST_P(HloDataflowAnalysisTest, BinaryOperation) {
    101   // Test the dataflow for a simple binary operation (Add).
    102   auto builder = HloComputation::Builder(TestName());
    103   auto constant1 = builder.AddInstruction(
    104       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    105   auto constant2 = builder.AddInstruction(
    106       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    107   auto add = builder.AddInstruction(HloInstruction::CreateBinary(
    108       scalar_shape_, HloOpcode::kAdd, constant1, constant2));
    109   module_->AddEntryComputation(builder.Build());
    110 
    111   bool ssa_form = GetParam();
    112   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
    113 
    114   // Each instruction should define a single value.
    115   EXPECT_EQ(analysis.values().size(), 3);
    116   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1));
    117   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));
    118   EXPECT_TRUE(analysis.ValueIsDefinedAt(add));
    119 
    120   // Verify the positions of the values. These positions are all trivial because
    121   // there are no instructions which forward values.
    122   EXPECT_THAT(analysis.GetValueDefinedAt(constant1).positions(),
    123               UnorderedElementsAre(HloPosition{constant1, {}}));
    124   EXPECT_THAT(analysis.GetValueDefinedAt(constant2).positions(),
    125               UnorderedElementsAre(HloPosition{constant2, {}}));
    126   EXPECT_THAT(analysis.GetValueDefinedAt(add).positions(),
    127               UnorderedElementsAre(HloPosition{add, {}}));
    128 
    129   // Verify the uses of the values.
    130   EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
    131               UnorderedElementsAre(HloUse{add, 0, {}}));
    132   EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(),
    133               UnorderedElementsAre(HloUse{add, 1, {}}));
    134   EXPECT_TRUE(analysis.GetValueDefinedAt(add).uses().empty());
    135 
    136   // Verify liveout values from the module.
    137   EXPECT_FALSE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
    138   EXPECT_FALSE(analysis.GetValueDefinedAt(constant2).live_out_of_module());
    139   EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module());
    140 }
    141 
    142 TEST_P(HloDataflowAnalysisTest, TupleAndGtes) {
    143   // Verify the dataflow through a Tuple and GetTupleElement instructions.
    144   auto builder = HloComputation::Builder(TestName());
    145   auto param0 = builder.AddInstruction(
    146       HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
    147   auto param1 = builder.AddInstruction(
    148       HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
    149   auto tuple =
    150       builder.AddInstruction(HloInstruction::CreateTuple({param0, param1}));
    151   auto gte0 = builder.AddInstruction(
    152       HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 0));
    153   auto gte1 = builder.AddInstruction(
    154       HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1));
    155   auto add = builder.AddInstruction(
    156       HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, gte0, gte1));
    157   module_->AddEntryComputation(builder.Build());
    158 
    159   bool ssa_form = GetParam();
    160   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
    161 
    162   // The two params, tuple, and add should each define one value.
    163   EXPECT_EQ(analysis.values().size(), 4);
    164 
    165   EXPECT_TRUE(analysis.ValueIsDefinedAt(param0));
    166   EXPECT_TRUE(analysis.ValueIsDefinedAt(param1));
    167   EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple, /*index=*/{}));
    168   EXPECT_FALSE(analysis.ValueIsDefinedAt(tuple, /*index=*/{0}));
    169   EXPECT_FALSE(analysis.ValueIsDefinedAt(tuple, /*index=*/{1}));
    170   EXPECT_FALSE(analysis.ValueIsDefinedAt(gte0));
    171   EXPECT_FALSE(analysis.ValueIsDefinedAt(gte1));
    172   EXPECT_TRUE(analysis.ValueIsDefinedAt(add));
    173 
    174   // Verify the positions of the values.
    175   EXPECT_THAT(
    176       analysis.GetValueDefinedAt(param0).positions(),
    177       UnorderedElementsAre(HloPosition{param0, {}}, HloPosition{tuple, {0}},
    178                            HloPosition{gte0, {}}));
    179   EXPECT_THAT(
    180       analysis.GetValueDefinedAt(param1).positions(),
    181       UnorderedElementsAre(HloPosition{param1, {}}, HloPosition{tuple, {1}},
    182                            HloPosition{gte1, {}}));
    183   EXPECT_THAT(analysis.GetValueDefinedAt(tuple).positions(),
    184               UnorderedElementsAre(HloPosition{tuple, {}}));
    185 
    186   // Verify uses. Of interest is that a GetTupleElement instruction is only a
    187   // use of the top-level value in the tuple operand.
    188   EXPECT_THAT(analysis.GetValueDefinedAt(param0).uses(),
    189               UnorderedElementsAre(HloUse{add, 0, {}}));
    190   EXPECT_THAT(analysis.GetValueDefinedAt(param1).uses(),
    191               UnorderedElementsAre(HloUse{add, 1, {}}));
    192   EXPECT_THAT(analysis.GetValueDefinedAt(tuple, /*index=*/{}).uses(),
    193               UnorderedElementsAre(HloUse{gte0, 0, {}}, HloUse{gte1, 0, {}}));
    194   EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module());
    195 }
    196 
    197 TEST_P(HloDataflowAnalysisTest, NestedTuple) {
    198   // Verify the dataflow through a nested tuple.
    199   auto builder = HloComputation::Builder(TestName());
    200   auto constant1 = builder.AddInstruction(
    201       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    202   auto constant2 = builder.AddInstruction(
    203       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    204   auto tuple = builder.AddInstruction(
    205       HloInstruction::CreateTuple({constant1, constant2}));
    206   auto nested_tuple = builder.AddInstruction(
    207       HloInstruction::CreateTuple({tuple, tuple, constant1}));
    208   auto gte_tuple = builder.AddInstruction(
    209       HloInstruction::CreateGetTupleElement(tuple->shape(), nested_tuple, 1));
    210   auto gte_out = builder.AddInstruction(
    211       HloInstruction::CreateGetTupleElement(scalar_shape_, gte_tuple, 0));
    212   module_->AddEntryComputation(builder.Build());
    213 
    214   bool ssa_form = GetParam();
    215   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
    216 
    217   EXPECT_EQ(analysis.values().size(), 4);
    218 
    219   // Verify positions and uses.
    220   EXPECT_THAT(
    221       analysis.GetValueDefinedAt(constant1).positions(),
    222       UnorderedElementsAre(
    223           HloPosition{constant1, {}}, HloPosition{tuple, {0}},
    224           HloPosition{nested_tuple, {0, 0}}, HloPosition{nested_tuple, {1, 0}},
    225           HloPosition{nested_tuple, {2}}, HloPosition{gte_tuple, {0}},
    226           HloPosition{gte_out, {}}));
    227   // Constant values should have only a single use, which is the root of the
    228   // computation.
    229   EXPECT_THAT(analysis.GetValueDefinedAt(constant1, /*index=*/{}).uses(),
    230               UnorderedElementsAre(HloUse{gte_out, 0, {0}}));
    231   EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).uses().empty());
    232 
    233   // The top-level tuple values are used in GTE instructions.
    234   EXPECT_THAT(analysis.GetValueDefinedAt(tuple, /*index=*/{}).uses(),
    235               UnorderedElementsAre(HloUse{gte_out, 0, {}}));
    236   EXPECT_THAT(analysis.GetValueDefinedAt(nested_tuple, /*index=*/{}).uses(),
    237               UnorderedElementsAre(HloUse{gte_tuple, 0, {}}));
    238 
    239   EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
    240   EXPECT_FALSE(analysis.GetValueDefinedAt(constant2).live_out_of_module());
    241   EXPECT_FALSE(
    242       analysis.GetValueDefinedAt(tuple, /*index=*/{}).live_out_of_module());
    243   EXPECT_FALSE(analysis.GetValueDefinedAt(nested_tuple, /*index=*/{})
    244                    .live_out_of_module());
    245 }
    246 
    247 TEST_P(HloDataflowAnalysisTest, SingleCall) {
    248   // Test a single call of a subcomputation. The subcomputation adds its two
    249   // array-shaped parameters.
    250   auto subbuilder = HloComputation::Builder("Subcomputation");
    251   auto subparam0 = subbuilder.AddInstruction(
    252       HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
    253   auto subparam1 = subbuilder.AddInstruction(
    254       HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
    255   auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary(
    256       scalar_shape_, HloOpcode::kAdd, subparam0, subparam1));
    257   HloComputation* called_computation =
    258       module_->AddEmbeddedComputation(subbuilder.Build());
    259 
    260   auto builder = HloComputation::Builder(TestName());
    261   auto constant1 = builder.AddInstruction(
    262       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    263   auto constant2 = builder.AddInstruction(
    264       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    265   auto call = builder.AddInstruction(HloInstruction::CreateCall(
    266       scalar_shape_, {constant1, constant2}, called_computation));
    267   module_->AddEntryComputation(builder.Build());
    268 
    269   bool ssa_form = GetParam();
    270   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
    271 
    272   EXPECT_EQ(analysis.values().size(), 3);
    273 
    274   // The parameters of the subcomputation and the call instruction itself should
    275   // not define values. Their values flow from elsewhere.
    276   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1));
    277   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));
    278   EXPECT_FALSE(analysis.ValueIsDefinedAt(subparam0));
    279   EXPECT_FALSE(analysis.ValueIsDefinedAt(subparam1));
    280   EXPECT_TRUE(analysis.ValueIsDefinedAt(add));
    281   EXPECT_FALSE(analysis.ValueIsDefinedAt(call));
    282 
    283   EXPECT_EQ(analysis.GetUniqueValueAt(subparam0),
    284             analysis.GetValueDefinedAt(constant1));
    285   EXPECT_EQ(analysis.GetUniqueValueAt(subparam1),
    286             analysis.GetValueDefinedAt(constant2));
    287   EXPECT_EQ(analysis.GetUniqueValueAt(call), analysis.GetValueDefinedAt(add));
    288 
    289   EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
    290               UnorderedElementsAre(HloUse{call, 0, {}}, HloUse{add, 0, {}}));
    291   EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(),
    292               UnorderedElementsAre(HloUse{call, 1, {}}, HloUse{add, 1, {}}));
    293 
    294   EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module());
    295 }
    296 
    297 TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) {
    298   // Test a subcomputation which is called twice with identical values.
    299   auto subbuilder = HloComputation::Builder("Subcomputation");
    300   auto subparam0 = subbuilder.AddInstruction(
    301       HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
    302   auto subparam1 = subbuilder.AddInstruction(
    303       HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
    304   auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary(
    305       scalar_shape_, HloOpcode::kAdd, subparam0, subparam1));
    306   HloComputation* called_computation =
    307       module_->AddEmbeddedComputation(subbuilder.Build());
    308 
    309   auto builder = HloComputation::Builder(TestName());
    310   auto constant1 = builder.AddInstruction(
    311       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    312   auto constant2 = builder.AddInstruction(
    313       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    314   auto call1 = builder.AddInstruction(HloInstruction::CreateCall(
    315       scalar_shape_, {constant1, constant2}, called_computation));
    316   auto call2 = builder.AddInstruction(HloInstruction::CreateCall(
    317       scalar_shape_, {constant1, constant2}, called_computation));
    318   auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
    319       scalar_shape_, HloOpcode::kSubtract, call1, call2));
    320   module_->AddEntryComputation(builder.Build());
    321 
    322   bool ssa_form = GetParam();
    323   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
    324 
    325   EXPECT_EQ(analysis.values().size(), 4);
    326 
    327   // Definitions should be identical to the single callsite case.
    328   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1));
    329   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));
    330   EXPECT_FALSE(analysis.ValueIsDefinedAt(subparam0));
    331   EXPECT_FALSE(analysis.ValueIsDefinedAt(subparam1));
    332   EXPECT_TRUE(analysis.ValueIsDefinedAt(add));
    333   EXPECT_FALSE(analysis.ValueIsDefinedAt(call1));
    334   EXPECT_FALSE(analysis.ValueIsDefinedAt(call2));
    335   EXPECT_TRUE(analysis.ValueIsDefinedAt(sub));
    336 
    337   EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
    338               UnorderedElementsAre(HloUse{call1, 0, {}}, HloUse{call2, 0, {}},
    339                                    HloUse{add, 0, {}}));
    340   EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(),
    341               UnorderedElementsAre(HloUse{call1, 1, {}}, HloUse{call2, 1, {}},
    342                                    HloUse{add, 1, {}}));
    343   // The Add from the subcomputation is used as both operands of the Subtract.
    344   EXPECT_THAT(analysis.GetValueDefinedAt(add).uses(),
    345               UnorderedElementsAre(HloUse{sub, 0, {}}, HloUse{sub, 1, {}}));
    346 
    347   EXPECT_FALSE(analysis.GetValueDefinedAt(add).live_out_of_module());
    348   EXPECT_TRUE(analysis.GetValueDefinedAt(sub).live_out_of_module());
    349 }
    350 
    351 TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithDifferentArguments) {
    352   // Test a subcomputation which is called twice with different argument values.
    353   auto subbuilder = HloComputation::Builder("Subcomputation");
    354   auto subparam0 = subbuilder.AddInstruction(
    355       HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
    356   auto subparam1 = subbuilder.AddInstruction(
    357       HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
    358   auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary(
    359       scalar_shape_, HloOpcode::kAdd, subparam0, subparam1));
    360   HloComputation* called_computation =
    361       module_->AddEmbeddedComputation(subbuilder.Build());
    362 
    363   auto builder = HloComputation::Builder(TestName());
    364   auto constant1 = builder.AddInstruction(
    365       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    366   auto constant2 = builder.AddInstruction(
    367       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    368   auto call1 = builder.AddInstruction(HloInstruction::CreateCall(
    369       scalar_shape_, {constant1, constant2}, called_computation));
    370   auto call2 = builder.AddInstruction(HloInstruction::CreateCall(
    371       scalar_shape_, {call1, constant2}, called_computation));
    372   module_->AddEntryComputation(builder.Build());
    373 
    374   bool ssa_form = GetParam();
    375   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
    376 
    377   EXPECT_FALSE(analysis.ValueIsDefinedAt(call1));
    378   EXPECT_FALSE(analysis.ValueIsDefinedAt(call2));
    379 
    380   EXPECT_FALSE(analysis.ValueIsDefinedAt(subparam0));
    381 
    382   EXPECT_THAT(HloValuesAt(subparam0),
    383               UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
    384                                    analysis.GetValueDefinedAt(add)));
    385   EXPECT_THAT(HloValuesAt(subparam1),
    386               UnorderedElementsAre(analysis.GetValueDefinedAt(constant2)));
    387 
    388   EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module());
    389 }
    390 
    391 TEST_P(HloDataflowAnalysisTest, NestedCalls) {
    392   // Test a module with nested computations. HLO is:
    393   //
    394   // F32[] inner_computation(F32[] %param0, F32[] %param1):
    395   //   %add = Add(%param0, %param1)
    396   //
    397   // F32[] outer_computation((F32[] %param0, F32[] %param1):
    398   //  ;; Note that parameters are interchanged in the call.
    399   //   %nested_call = Call(inner_computation, {%param1, %param0})
    400   //
    401   // F32[] entry:
    402   //   %constant1 = Constant(1.0)
    403   //   %constant2 = Constant(2.0)
    404   //   %call = Call(outer_computation, {%constant1, %constant2})
    405   //
    406   auto inner_builder = HloComputation::Builder("InnerComputation");
    407   auto inner_param0 = inner_builder.AddInstruction(
    408       HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
    409   auto inner_param1 = inner_builder.AddInstruction(
    410       HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
    411   auto add = inner_builder.AddInstruction(HloInstruction::CreateBinary(
    412       scalar_shape_, HloOpcode::kAdd, inner_param0, inner_param1));
    413   HloComputation* inner_computation =
    414       module_->AddEmbeddedComputation(inner_builder.Build());
    415 
    416   auto outer_builder = HloComputation::Builder("OuterComputation");
    417   auto outer_param0 = outer_builder.AddInstruction(
    418       HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
    419   auto outer_param1 = outer_builder.AddInstruction(
    420       HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
    421   // Swizzle parameters.
    422   auto nested_call = outer_builder.AddInstruction(HloInstruction::CreateCall(
    423       scalar_shape_, {outer_param1, outer_param0}, inner_computation));
    424   HloComputation* outer_computation =
    425       module_->AddEmbeddedComputation(outer_builder.Build());
    426 
    427   auto builder = HloComputation::Builder(TestName());
    428   auto constant1 = builder.AddInstruction(
    429       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    430   auto constant2 = builder.AddInstruction(
    431       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    432   auto call = builder.AddInstruction(HloInstruction::CreateCall(
    433       scalar_shape_, {constant1, constant2}, outer_computation));
    434   module_->AddEntryComputation(builder.Build());
    435 
    436   bool ssa_form = GetParam();
    437   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
    438 
    439   // Only three values should be defined. Most instructions just pass through
    440   // their operand values.
    441   EXPECT_EQ(analysis.values().size(), 3);
    442 
    443   // Verify that the uses of the constants are properly swizzled by parameter
    444   // permutation in nested_call.
    445   EXPECT_THAT(
    446       analysis.GetValueDefinedAt(constant1).uses(),
    447       UnorderedElementsAre(HloUse{call, 0, {}}, HloUse{nested_call, 1, {}},
    448                            HloUse{add, 1, {}}));
    449   EXPECT_THAT(
    450       analysis.GetValueDefinedAt(constant2).uses(),
    451       UnorderedElementsAre(HloUse{call, 1, {}}, HloUse{nested_call, 0, {}},
    452                            HloUse{add, 0, {}}));
    453 
    454   EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module());
    455 }
    456 
    457 TEST_P(HloDataflowAnalysisTest, SingleWhile) {
    458   // Test a simple single while instruction. The while body includes a
    459   // pass-through value. HLO:
    460   //
    461   // body((F32[], F32[]) %tuple_param):
    462   //   %add = Add(%tuple_param{0}, %tuple_param{1})
    463   //   return Tuple(%tuple_param{0}, %add)
    464   //
    465   // condition((F32[], F32[]) %tuple_param):
    466   //   return Constant(false)
    467   //
    468   // entry:
    469   //   %constant1 = Constant(1.0)
    470   //   %constant2 = Constant(2.0)
    471   //   %tuple = Tuple(%constant1, %constant2)
    472   //   return While(%tuple, body, condition)
    473   //
    474   const Shape tuple_shape =
    475       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
    476 
    477   // Element 0 passes transparently through the body.
    478   auto body_builder = HloComputation::Builder("body");
    479   auto body_param = body_builder.AddInstruction(
    480       HloInstruction::CreateParameter(0, tuple_shape, "param"));
    481   auto body_element_0 = body_builder.AddInstruction(
    482       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
    483   auto body_element_1 = body_builder.AddInstruction(
    484       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
    485   auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
    486       scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1));
    487   auto body_root = body_builder.AddInstruction(
    488       HloInstruction::CreateTuple({body_element_0, add}));
    489   HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
    490 
    491   // Condition computation trivially returns a constant "false".
    492   auto cond_builder = HloComputation::Builder("condition");
    493   auto cond_param = cond_builder.AddInstruction(
    494       HloInstruction::CreateParameter(0, tuple_shape, "param"));
    495   auto cond_constant = cond_builder.AddInstruction(
    496       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
    497   HloComputation* condition =
    498       module_->AddEmbeddedComputation(cond_builder.Build());
    499 
    500   auto builder = HloComputation::Builder(TestName());
    501   auto constant1 = builder.AddInstruction(
    502       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    503   auto constant2 = builder.AddInstruction(
    504       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    505   auto tuple = builder.AddInstruction(
    506       HloInstruction::CreateTuple({constant1, constant2}));
    507   auto xla_while = builder.AddInstruction(
    508       HloInstruction::CreateWhile(tuple_shape, condition, body, tuple));
    509   module_->AddEntryComputation(builder.Build());
    510 
    511   bool ssa_form = GetParam();
    512   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
    513 
    514   EXPECT_FALSE(analysis.GetValueDefinedAt(cond_constant).live_out_of_module());
    515 
    516   if (ssa_form) {
    517     // Element 0 of the tuple passed through the body so no phi value is
    518     // defined.
    519     EXPECT_FALSE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
    520     EXPECT_FALSE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0}));
    521     EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0}));
    522 
    523     // Element 1 of the tuple should be a phi value.
    524     EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1}));
    525     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}).is_phi());
    526     EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{1}));
    527     EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{1}).is_phi());
    528     EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1}));
    529     EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{1}).is_phi());
    530 
    531     EXPECT_THAT(
    532         analysis.GetValueDefinedAt(constant1).uses(),
    533         UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{body_root, 0, {}},
    534                              HloUse{xla_while, 0, {0}}));
    535 
    536     // Constant1 passes through the body and out of the module.
    537     EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
    538     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1})
    539                     .live_out_of_module());
    540 
    541     EXPECT_FALSE(analysis.GetValueDefinedAt(add).live_out_of_module());
    542   } else {
    543     // While instruction and subcomputation parameters should not define values
    544     // in non-ssa form.
    545     EXPECT_FALSE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
    546     EXPECT_FALSE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1}));
    547     EXPECT_FALSE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0}));
    548     EXPECT_FALSE(analysis.ValueIsDefinedAt(body_param, /*index=*/{1}));
    549     EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0}));
    550     EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1}));
    551 
    552     EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
    553     EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module());
    554   }
    555 }
    556 
    557 TEST_P(HloDataflowAnalysisTest, SequentialWhiles) {
    558   // Test sequential while instructions. The while body includes a
    559   // pass-through value. HLO:
    560   //
    561   // body((F32[], F32[]) %tuple_param):
    562   //   %add = Add(%tuple_param{0}, %tuple_param{1})
    563   //   return Tuple(%tuple_param{0}, %add)
    564   //
    565   // condition((F32[], F32[]) %tuple_param):
    566   //   return Constant(false)
    567   //
    568   // entry:
    569   //   %constant1 = Constant(1.0)
    570   //   %constant2 = Constant(2.0)
    571   //   %tuple = Tuple(%constant1, %constant2)
    572   //   %while0 = While(%tuple, body, condition)
    573   //   %while1 = While(%while0, body, condition)
    574   //   return While(%while1, body, condition)
    575   //
    576   const Shape tuple_shape =
    577       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
    578 
    579   // Element 0 passes transparently through the body.
    580   auto body_builder = HloComputation::Builder("body");
    581   auto body_param = body_builder.AddInstruction(
    582       HloInstruction::CreateParameter(0, tuple_shape, "param"));
    583   auto body_element_0 = body_builder.AddInstruction(
    584       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
    585   auto body_element_1 = body_builder.AddInstruction(
    586       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
    587   auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
    588       scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1));
    589   body_builder.AddInstruction(
    590       HloInstruction::CreateTuple({body_element_0, add}));
    591   HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
    592 
    593   auto cond_builder = HloComputation::Builder("condition");
    594   cond_builder.AddInstruction(
    595       HloInstruction::CreateParameter(0, tuple_shape, "param"));
    596   cond_builder.AddInstruction(
    597       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
    598   HloComputation* condition =
    599       module_->AddEmbeddedComputation(cond_builder.Build());
    600 
    601   auto builder = HloComputation::Builder(TestName());
    602   auto constant1 = builder.AddInstruction(
    603       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    604   auto constant2 = builder.AddInstruction(
    605       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    606   auto tuple = builder.AddInstruction(
    607       HloInstruction::CreateTuple({constant1, constant2}));
    608   auto xla_while0 = builder.AddInstruction(
    609       HloInstruction::CreateWhile(tuple_shape, condition, body, tuple));
    610   auto xla_while1 = builder.AddInstruction(
    611       HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while0));
    612   auto xla_while2 = builder.AddInstruction(
    613       HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while1));
    614   module_->AddEntryComputation(builder.Build());
    615 
    616   bool ssa_form = GetParam();
    617   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
    618 
    619   // Element 0 is passed through all the while instructions and out of the
    620   // module..
    621   EXPECT_EQ(analysis.GetUniqueValueAt(xla_while0, /*index=*/{0}),
    622             analysis.GetValueDefinedAt(constant1));
    623   EXPECT_EQ(analysis.GetUniqueValueAt(xla_while1, /*index=*/{0}),
    624             analysis.GetValueDefinedAt(constant1));
    625   EXPECT_EQ(analysis.GetUniqueValueAt(xla_while2, /*index=*/{0}),
    626             analysis.GetValueDefinedAt(constant1));
    627   EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
    628 }
    629 
    630 TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
    631   // Test nested while instructions. The inner body passes through element 0 of
    632   // its parameter, and the outer body passes through element 1.  HLO:
    633   //
    634   // inner_body((F32[], F32[]) %tuple_param):
    635   //   %add = Add(%tuple_param{0}, %tuple_param{1})
    636   //   return Tuple(%tuple_param{0}, %add)
    637   //
    638   // outer_body((F32[], F32[]) %tuple_param):
    639   //   %negate = Negate(%tuple_param{0})
    640   //   %tuple = Tuple(%negate, %tuple_param{1})
    641   //   return While(%tuple, inner_body, condition)
    642   //
    643   // entry:
    644   //   %constant1 = Constant(1.0)
    645   //   %constant2 = Constant(2.0)
    646   //   %tuple = Tuple(%constant1, %constant2)
    647   //   return While(%tuple, outer_body, condition)
    648   //
    649   const Shape tuple_shape =
    650       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
    651 
    652   auto cond_builder = HloComputation::Builder("condition");
    653   cond_builder.AddInstruction(
    654       HloInstruction::CreateParameter(0, tuple_shape, "param"));
    655   cond_builder.AddInstruction(
    656       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
    657   HloComputation* condition =
    658       module_->AddEmbeddedComputation(cond_builder.Build());
    659 
    660   // Element 0 passes transparently through the body.
    661   auto inner_builder = HloComputation::Builder("inner_body");
    662   auto inner_param = inner_builder.AddInstruction(
    663       HloInstruction::CreateParameter(0, tuple_shape, "param"));
    664   auto inner_element_0 = inner_builder.AddInstruction(
    665       HloInstruction::CreateGetTupleElement(scalar_shape_, inner_param, 0));
    666   auto inner_element_1 = inner_builder.AddInstruction(
    667       HloInstruction::CreateGetTupleElement(scalar_shape_, inner_param, 1));
    668   auto add = inner_builder.AddInstruction(HloInstruction::CreateBinary(
    669       scalar_shape_, HloOpcode::kAdd, inner_element_0, inner_element_1));
    670   inner_builder.AddInstruction(
    671       HloInstruction::CreateTuple({inner_element_0, add}));
    672   HloComputation* inner_body =
    673       module_->AddEmbeddedComputation(inner_builder.Build());
    674 
    675   // Element 1 passes transparently through the body.
    676   auto outer_builder = HloComputation::Builder("outer_body");
    677   auto outer_param = outer_builder.AddInstruction(
    678       HloInstruction::CreateParameter(0, tuple_shape, "param"));
    679   auto outer_element_0 = outer_builder.AddInstruction(
    680       HloInstruction::CreateGetTupleElement(scalar_shape_, outer_param, 0));
    681   auto negate = outer_builder.AddInstruction(HloInstruction::CreateUnary(
    682       scalar_shape_, HloOpcode::kNegate, outer_element_0));
    683   auto outer_element_1 = outer_builder.AddInstruction(
    684       HloInstruction::CreateGetTupleElement(scalar_shape_, outer_param, 1));
    685   auto outer_tuple = outer_builder.AddInstruction(
    686       HloInstruction::CreateTuple({negate, outer_element_1}));
    687   auto nested_while = outer_builder.AddInstruction(HloInstruction::CreateWhile(
    688       tuple_shape, condition, inner_body, outer_tuple));
    689   HloComputation* outer_body =
    690       module_->AddEmbeddedComputation(outer_builder.Build());
    691 
    692   auto builder = HloComputation::Builder(TestName());
    693   auto constant1 = builder.AddInstruction(
    694       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    695   auto constant2 = builder.AddInstruction(
    696       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    697   auto tuple = builder.AddInstruction(
    698       HloInstruction::CreateTuple({constant1, constant2}));
    699   auto entry_while = builder.AddInstruction(
    700       HloInstruction::CreateWhile(tuple_shape, condition, outer_body, tuple));
    701   module_->AddEntryComputation(builder.Build());
    702 
    703   bool ssa_form = GetParam();
    704   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
    705 
    706   EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}),
    707               UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
    708   if (ssa_form) {
    709     EXPECT_TRUE(analysis.ValueIsDefinedAt(inner_param, /*index=*/{1}));
    710     EXPECT_TRUE(
    711         analysis.GetValueDefinedAt(inner_param, /*index=*/{1}).is_phi());
    712 
    713     // Element 0 of the nested while is %negate.
    714     EXPECT_FALSE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{0}));
    715     EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}),
    716                 UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
    717     // Element 1 is a phi value (join of %add and %constant2).
    718     EXPECT_TRUE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{1}));
    719     EXPECT_TRUE(
    720         analysis.GetValueDefinedAt(nested_while, /*index=*/{1}).is_phi());
    721 
    722     EXPECT_TRUE(analysis.ValueIsDefinedAt(entry_while, /*index=*/{0}));
    723     EXPECT_TRUE(
    724         analysis.GetValueDefinedAt(entry_while, /*index=*/{0}).is_phi());
    725 
    726     EXPECT_TRUE(analysis.ValueIsDefinedAt(entry_while, /*index=*/{1}));
    727     EXPECT_TRUE(
    728         analysis.GetValueDefinedAt(entry_while, /*index=*/{1}).is_phi());
    729   } else {
    730     EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{1}),
    731                 UnorderedElementsAre(analysis.GetValueDefinedAt(add),
    732                                      analysis.GetValueDefinedAt(constant2)));
    733 
    734     EXPECT_THAT(HloValuesAt(nested_while, /*index=*/{0}),
    735                 UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
    736     EXPECT_THAT(HloValuesAt(nested_while, /*index=*/{1}),
    737                 UnorderedElementsAre(analysis.GetValueDefinedAt(add),
    738                                      analysis.GetValueDefinedAt(constant2)));
    739 
    740     EXPECT_THAT(HloValuesAt(entry_while, /*index=*/{0}),
    741                 UnorderedElementsAre(analysis.GetValueDefinedAt(negate),
    742                                      analysis.GetValueDefinedAt(constant1)));
    743     EXPECT_THAT(HloValuesAt(entry_while, /*index=*/{1}),
    744                 UnorderedElementsAre(analysis.GetValueDefinedAt(add),
    745                                      analysis.GetValueDefinedAt(constant2)));
    746   }
    747 }
    748 
    749 TEST_P(HloDataflowAnalysisTest, SwizzlingWhile) {
    750   // Test a while instruction with a body which permutes it's tuple parameter
    751   // elements. HLO:
    752   //
    753   // body((F32[], F32[]) %tuple_param):
    754   //   return Tuple(%tuple_param{1}, %tuple_param{0})
    755   //
    756   // condition((F32[], F32[]) %tuple_param):
    757   //   return Constant(false)
    758   //
    759   // entry:
    760   //   %constant1 = Constant(1.0)
    761   //   %constant2 = Constant(2.0)
    762   //   %tuple = Tuple(%constant1, %constant2)
    763   //   return While(%tuple, body, condition)
    764   //
    765   const Shape tuple_shape =
    766       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
    767 
    768   auto body_builder = HloComputation::Builder("body");
    769   auto body_param = body_builder.AddInstruction(
    770       HloInstruction::CreateParameter(0, tuple_shape, "param"));
    771   auto body_element_0 = body_builder.AddInstruction(
    772       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
    773   auto body_element_1 = body_builder.AddInstruction(
    774       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
    775   body_builder.AddInstruction(
    776       HloInstruction::CreateTuple({body_element_1, body_element_0}));
    777   HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
    778 
    779   auto cond_builder = HloComputation::Builder("condition");
    780   auto cond_param = cond_builder.AddInstruction(
    781       HloInstruction::CreateParameter(0, tuple_shape, "param"));
    782   cond_builder.AddInstruction(
    783       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
    784   HloComputation* condition =
    785       module_->AddEmbeddedComputation(cond_builder.Build());
    786 
    787   auto builder = HloComputation::Builder(TestName());
    788   auto constant1 = builder.AddInstruction(
    789       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    790   auto constant2 = builder.AddInstruction(
    791       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    792   auto tuple = builder.AddInstruction(
    793       HloInstruction::CreateTuple({constant1, constant2}));
    794   auto xla_while = builder.AddInstruction(
    795       HloInstruction::CreateWhile(tuple_shape, condition, body, tuple));
    796   module_->AddEntryComputation(builder.Build());
    797 
    798   bool ssa_form = GetParam();
    799   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
    800 
    801   if (ssa_form) {
    802     // Element 0 and 1 in the while should both be phi values.
    803     EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0}));
    804     EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{0}).is_phi());
    805     EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{1}));
    806     EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{1}).is_phi());
    807 
    808     EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
    809     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0}).is_phi());
    810     EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1}));
    811     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}).is_phi());
    812 
    813     EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0}));
    814     EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{0}).is_phi());
    815     EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1}));
    816     EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{1}).is_phi());
    817 
    818     EXPECT_FALSE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
    819     EXPECT_FALSE(analysis.GetValueDefinedAt(constant2).live_out_of_module());
    820     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{})
    821                     .live_out_of_module());
    822     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0})
    823                     .live_out_of_module());
    824     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1})
    825                     .live_out_of_module());
    826   } else {
    827     // Elements 0 and 1 have both constants as reaching definitions.
    828     EXPECT_THAT(HloValuesAt(xla_while, /*index=*/{0}),
    829                 UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
    830                                      analysis.GetValueDefinedAt(constant2)));
    831     EXPECT_THAT(HloValuesAt(xla_while, /*index=*/{1}),
    832                 UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
    833                                      analysis.GetValueDefinedAt(constant2)));
    834     EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
    835     EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).live_out_of_module());
    836   }
    837 }
    838 
    839 TEST_P(HloDataflowAnalysisTest, ArraySelect) {
    840   // Test a kSelect of an array value.
    841   auto builder = HloComputation::Builder(TestName());
    842   auto pred = builder.AddInstruction(
    843       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
    844   auto constant1 = builder.AddInstruction(
    845       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    846   auto constant2 = builder.AddInstruction(
    847       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    848   auto select = builder.AddInstruction(HloInstruction::CreateTernary(
    849       scalar_shape_, HloOpcode::kSelect, pred, constant1, constant2));
    850 
    851   module_->AddEntryComputation(builder.Build());
    852 
    853   bool ssa_form = GetParam();
    854   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
    855 
    856   EXPECT_TRUE(analysis.ValueIsDefinedAt(select));
    857   EXPECT_FALSE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
    858   EXPECT_FALSE(analysis.GetValueDefinedAt(constant2).live_out_of_module());
    859   EXPECT_TRUE(analysis.GetValueDefinedAt(select).live_out_of_module());
    860 }
    861 
    862 TEST_P(HloDataflowAnalysisTest, TupleSelect) {
    863   // Test a kSelect of a tuple value. Non-top-level element flow through the
    864   // instruction.
    865   auto builder = HloComputation::Builder(TestName());
    866   auto pred = builder.AddInstruction(
    867       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
    868   auto constant1 = builder.AddInstruction(
    869       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    870   auto constant2 = builder.AddInstruction(
    871       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    872   auto constant3 = builder.AddInstruction(
    873       HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
    874   auto constant4 = builder.AddInstruction(
    875       HloInstruction::CreateConstant(Literal::CreateR0<float>(4.0)));
    876   auto tuple1 =
    877       builder.AddInstruction(HloInstruction::CreateTuple({constant1}));
    878   auto tuple2 =
    879       builder.AddInstruction(HloInstruction::CreateTuple({constant2}));
    880   auto tuple3 =
    881       builder.AddInstruction(HloInstruction::CreateTuple({constant3}));
    882   auto tuple4 =
    883       builder.AddInstruction(HloInstruction::CreateTuple({constant4}));
    884   const Shape tuple_shape = tuple1->shape();
    885   auto select11 = builder.AddInstruction(HloInstruction::CreateTernary(
    886       tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple1));
    887   auto select12 = builder.AddInstruction(HloInstruction::CreateTernary(
    888       tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple2));
    889   auto select34 = builder.AddInstruction(HloInstruction::CreateTernary(
    890       tuple_shape, HloOpcode::kSelect, pred, tuple3, tuple4));
    891   auto select1234 = builder.AddInstruction(HloInstruction::CreateTernary(
    892       tuple_shape, HloOpcode::kSelect, pred, select12, select34));
    893 
    894   module_->AddEntryComputation(builder.Build());
    895 
    896   bool ssa_form = GetParam();
    897   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
    898 
    899   // Top-level value is always defined by a kSelect.
    900   EXPECT_TRUE(analysis.ValueIsDefinedAt(select11));
    901   EXPECT_TRUE(analysis.ValueIsDefinedAt(select12));
    902   EXPECT_TRUE(analysis.ValueIsDefinedAt(select34));
    903   EXPECT_TRUE(analysis.ValueIsDefinedAt(select1234));
    904 
    905   EXPECT_FALSE(analysis.ValueIsDefinedAt(select11, /*index=*/{0}));
    906   EXPECT_FALSE(analysis.ValueIsDefinedAt(select12, /*index=*/{0}));
    907   EXPECT_FALSE(analysis.ValueIsDefinedAt(select34, /*index=*/{0}));
    908   EXPECT_FALSE(analysis.ValueIsDefinedAt(select1234, /*index=*/{0}));
    909 
    910   EXPECT_THAT(HloValuesAt(select11, /*index=*/{0}),
    911               UnorderedElementsAre(analysis.GetValueDefinedAt(constant1)));
    912   EXPECT_THAT(HloValuesAt(select12, /*index=*/{0}),
    913               UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
    914                                    analysis.GetValueDefinedAt(constant2)));
    915   EXPECT_THAT(HloValuesAt(select34, /*index=*/{0}),
    916               UnorderedElementsAre(analysis.GetValueDefinedAt(constant3),
    917                                    analysis.GetValueDefinedAt(constant4)));
    918   EXPECT_THAT(HloValuesAt(select1234, /*index=*/{0}),
    919               UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
    920                                    analysis.GetValueDefinedAt(constant2),
    921                                    analysis.GetValueDefinedAt(constant3),
    922                                    analysis.GetValueDefinedAt(constant4)));
    923 
    924   EXPECT_THAT(
    925       analysis.GetValueDefinedAt(tuple1, /*index=*/{}).uses(),
    926       UnorderedElementsAre(HloUse{select11, 1, {}}, HloUse{select11, 2, {}},
    927                            HloUse{select12, 1, {}}));
    928 
    929   // The two constant values just pass through the Selects and are not
    930   // used except at the root. They are live out however.
    931   EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
    932               UnorderedElementsAre(HloUse{select1234, 1, {0}}));
    933   EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(),
    934               UnorderedElementsAre(HloUse{select1234, 1, {0}}));
    935   EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
    936   EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).live_out_of_module());
    937 }
    938 
    939 TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) {
    940   // Test kSelect of a nested tuple.
    941   auto builder = HloComputation::Builder(TestName());
    942   auto pred = builder.AddInstruction(
    943       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
    944   auto constant1 = builder.AddInstruction(
    945       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    946   auto constant2 = builder.AddInstruction(
    947       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
    948   auto constant3 = builder.AddInstruction(
    949       HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
    950   auto constant4 = builder.AddInstruction(
    951       HloInstruction::CreateConstant(Literal::CreateR0<float>(4.0)));
    952   auto constant5 = builder.AddInstruction(
    953       HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0)));
    954   auto inner_tuple1 = builder.AddInstruction(
    955       HloInstruction::CreateTuple({constant2, constant3}));
    956   auto tuple1 = builder.AddInstruction(
    957       HloInstruction::CreateTuple({constant1, inner_tuple1}));
    958   auto inner_tuple2 = builder.AddInstruction(
    959       HloInstruction::CreateTuple({constant5, constant3}));
    960   auto tuple2 = builder.AddInstruction(
    961       HloInstruction::CreateTuple({constant4, inner_tuple2}));
    962   auto select = builder.AddInstruction(HloInstruction::CreateTernary(
    963       tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2));
    964 
    965   module_->AddEntryComputation(builder.Build());
    966 
    967   bool ssa_form = GetParam();
    968   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
    969 
    970   EXPECT_TRUE(analysis.ValueIsDefinedAt(select));
    971 
    972   EXPECT_THAT(HloValuesAt(select, /*index=*/{0}),
    973               UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
    974                                    analysis.GetValueDefinedAt(constant4)));
    975   EXPECT_THAT(HloValuesAt(select, /*index=*/{1}),
    976               UnorderedElementsAre(analysis.GetValueDefinedAt(inner_tuple1),
    977                                    analysis.GetValueDefinedAt(inner_tuple2)));
    978   EXPECT_THAT(HloValuesAt(select, /*index=*/{1, 0}),
    979               UnorderedElementsAre(analysis.GetValueDefinedAt(constant2),
    980                                    analysis.GetValueDefinedAt(constant5)));
    981   EXPECT_THAT(HloValuesAt(select, /*index=*/{1, 1}),
    982               UnorderedElementsAre(analysis.GetValueDefinedAt(constant3)));
    983 }
    984 
    985 TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) {
    986   // Test a tuple-shaped kSelect feeding a kWhile instruction. HLO:
    987   //
    988   // body((F32[], F32[]) %tuple_param):
    989   //   %add = Add(%tuple_param{0}, %tuple_param{1})
    990   //   return Tuple(%tuple_param{0}, %add)
    991   //
    992   // condition((F32[], F32[]) %tuple_param):
    993   //   return Constant(false)
    994   //
    995   // entry:
    996   //   %constant1 = Constant(1.0)
    997   //   %constant2 = Constant(2.0)
    998   //   %constant3 = Constant(3.0)
    999   //   %tuple1 = Tuple(%constant1)
   1000   //   %tuple2 = Tuple(%constant2)
   1001   //   %select = Select(%tuple1, %tuple2)
   1002   //   %gte = GetTupleElement(%select, 0)
   1003   //   %tuple = Tuple(%gte, %constant3)
   1004   //   return While(%tuple, body, condition)
   1005   //
   1006   auto builder = HloComputation::Builder(TestName());
   1007 
   1008   const Shape tuple_shape =
   1009       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
   1010 
   1011   // Element 0 passes transparently through the body.
   1012   auto body_builder = HloComputation::Builder("body");
   1013   auto body_param = body_builder.AddInstruction(
   1014       HloInstruction::CreateParameter(0, tuple_shape, "param"));
   1015   auto body_element_0 = body_builder.AddInstruction(
   1016       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
   1017   auto body_element_1 = body_builder.AddInstruction(
   1018       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
   1019   auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
   1020       scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1));
   1021   body_builder.AddInstruction(
   1022       HloInstruction::CreateTuple({body_element_0, add}));
   1023   HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
   1024 
   1025   auto cond_builder = HloComputation::Builder("condition");
   1026   cond_builder.AddInstruction(
   1027       HloInstruction::CreateParameter(0, tuple_shape, "param"));
   1028   cond_builder.AddInstruction(
   1029       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
   1030   HloComputation* condition =
   1031       module_->AddEmbeddedComputation(cond_builder.Build());
   1032 
   1033   auto pred = builder.AddInstruction(
   1034       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
   1035   auto constant1 = builder.AddInstruction(
   1036       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
   1037   auto constant2 = builder.AddInstruction(
   1038       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
   1039   auto constant3 = builder.AddInstruction(
   1040       HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
   1041   auto tuple1 =
   1042       builder.AddInstruction(HloInstruction::CreateTuple({constant1}));
   1043   auto tuple2 =
   1044       builder.AddInstruction(HloInstruction::CreateTuple({constant2}));
   1045   auto select = builder.AddInstruction(HloInstruction::CreateTernary(
   1046       tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2));
   1047   auto gte = builder.AddInstruction(
   1048       HloInstruction::CreateGetTupleElement(scalar_shape_, select, 0));
   1049   auto tuple =
   1050       builder.AddInstruction(HloInstruction::CreateTuple({gte, constant3}));
   1051   auto xla_while = builder.AddInstruction(
   1052       HloInstruction::CreateWhile(tuple->shape(), condition, body, tuple));
   1053 
   1054   module_->AddEntryComputation(builder.Build());
   1055 
   1056   bool ssa_form = GetParam();
   1057   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
   1058 
   1059   if (ssa_form) {
   1060     EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
   1061     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0}).is_phi());
   1062     EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1}));
   1063     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}).is_phi());
   1064 
   1065     EXPECT_FALSE(analysis.ValueIsDefinedAt(select, /*index=*/{0}));
   1066 
   1067     EXPECT_FALSE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
   1068     EXPECT_FALSE(analysis.GetValueDefinedAt(constant2).live_out_of_module());
   1069     EXPECT_FALSE(analysis.GetValueDefinedAt(constant3).live_out_of_module());
   1070     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1})
   1071                     .live_out_of_module());
   1072   } else {
   1073     EXPECT_THAT(HloValuesAt(gte),
   1074                 UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
   1075                                      analysis.GetValueDefinedAt(constant2)));
   1076     EXPECT_THAT(HloValuesAt(xla_while, /*index=*/{0}),
   1077                 UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
   1078                                      analysis.GetValueDefinedAt(constant2)));
   1079     EXPECT_THAT(HloValuesAt(xla_while, /*index=*/{1}),
   1080                 UnorderedElementsAre(analysis.GetValueDefinedAt(add),
   1081                                      analysis.GetValueDefinedAt(constant3)));
   1082     EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
   1083     EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).live_out_of_module());
   1084     EXPECT_TRUE(analysis.GetValueDefinedAt(constant3).live_out_of_module());
   1085   }
   1086 }
   1087 
   1088 TEST_P(HloDataflowAnalysisTest, BitcastDefinesValue) {
   1089   // Test the bitcast_defines_value flag to the dataflow analysis.
   1090   auto builder = HloComputation::Builder(TestName());
   1091   auto constant = builder.AddInstruction(
   1092       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
   1093   auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary(
   1094       scalar_shape_, HloOpcode::kBitcast, constant));
   1095 
   1096   module_->AddEntryComputation(builder.Build());
   1097 
   1098   bool ssa_form = GetParam();
   1099   {
   1100     const HloDataflowAnalysis& analysis =
   1101         RunAnalysis(ssa_form, /*bitcast_defines_value=*/true);
   1102 
   1103     EXPECT_EQ(analysis.values().size(), 2);
   1104 
   1105     EXPECT_TRUE(analysis.ValueIsDefinedAt(constant));
   1106     EXPECT_TRUE(analysis.ValueIsDefinedAt(bitcast));
   1107     EXPECT_FALSE(analysis.GetValueDefinedAt(constant).live_out_of_module());
   1108     EXPECT_TRUE(analysis.GetValueDefinedAt(bitcast).live_out_of_module());
   1109   }
   1110   {
   1111     const HloDataflowAnalysis& analysis =
   1112         RunAnalysis(ssa_form, /*bitcast_defines_value=*/false);
   1113     EXPECT_EQ(analysis.values().size(), 1);
   1114 
   1115     EXPECT_TRUE(analysis.ValueIsDefinedAt(constant));
   1116     EXPECT_FALSE(analysis.ValueIsDefinedAt(bitcast));
   1117     EXPECT_TRUE(analysis.GetValueDefinedAt(constant).live_out_of_module());
   1118   }
   1119 }
   1120 
   1121 TEST_P(HloDataflowAnalysisTest, TupleCopy) {
   1122   // Test that a tuple-shaped copy only copies (defines) the top-level value.
   1123   auto builder = HloComputation::Builder(TestName());
   1124   auto param0 = builder.AddInstruction(
   1125       HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
   1126   auto param1 = builder.AddInstruction(
   1127       HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
   1128   auto tuple =
   1129       builder.AddInstruction(HloInstruction::CreateTuple({param0, param1}));
   1130   auto copy = builder.AddInstruction(
   1131       HloInstruction::CreateUnary(tuple->shape(), HloOpcode::kCopy, tuple));
   1132   module_->AddEntryComputation(builder.Build());
   1133 
   1134   bool ssa_form = GetParam();
   1135   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
   1136 
   1137   EXPECT_EQ(analysis.values().size(), 4);
   1138 
   1139   EXPECT_TRUE(analysis.ValueIsDefinedAt(param0));
   1140   EXPECT_TRUE(analysis.ValueIsDefinedAt(param1));
   1141   EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple, /*index=*/{}));
   1142   EXPECT_FALSE(analysis.ValueIsDefinedAt(tuple, /*index=*/{0}));
   1143   EXPECT_FALSE(analysis.ValueIsDefinedAt(tuple, /*index=*/{1}));
   1144   EXPECT_TRUE(analysis.ValueIsDefinedAt(copy, /*index=*/{}));
   1145   EXPECT_FALSE(analysis.ValueIsDefinedAt(copy, /*index=*/{0}));
   1146   EXPECT_FALSE(analysis.ValueIsDefinedAt(copy, /*index=*/{1}));
   1147 
   1148   EXPECT_THAT(HloValuesAt(copy, /*index=*/{0}),
   1149               UnorderedElementsAre(analysis.GetValueDefinedAt(param0)));
   1150   EXPECT_THAT(HloValuesAt(copy, /*index=*/{1}),
   1151               UnorderedElementsAre(analysis.GetValueDefinedAt(param1)));
   1152   EXPECT_TRUE(
   1153       analysis.GetValueDefinedAt(copy, /*index=*/{}).live_out_of_module());
   1154 }
   1155 
   1156 TEST_P(HloDataflowAnalysisTest, SendAndSendDone) {
   1157   // Test that a Send forwards its operand to the output tuple at {0}.
   1158   auto builder = HloComputation::Builder(TestName());
   1159   auto param = builder.AddInstruction(
   1160       HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
   1161   auto send = builder.AddInstruction(
   1162       HloInstruction::CreateSend(param, /*channel_id=*/0));
   1163   auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));
   1164   module_->AddEntryComputation(builder.Build());
   1165 
   1166   bool ssa_form = GetParam();
   1167   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
   1168 
   1169   EXPECT_EQ(analysis.values().size(), 4);
   1170 
   1171   EXPECT_TRUE(analysis.ValueIsDefinedAt(param));
   1172   EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{}));
   1173   EXPECT_FALSE(analysis.ValueIsDefinedAt(send, /*index=*/{0}));
   1174   EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{1}));
   1175   EXPECT_TRUE(analysis.ValueIsDefinedAt(send_done));
   1176   EXPECT_THAT(HloValuesAt(send, /*index=*/{0}),
   1177               UnorderedElementsAre(analysis.GetValueDefinedAt(param)));
   1178 }
   1179 
   1180 TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) {
   1181   // Test that a RecvDone forwards its operand tuple element at {0} to the
   1182   // output.
   1183   auto builder = HloComputation::Builder(TestName());
   1184   auto recv = builder.AddInstruction(
   1185       HloInstruction::CreateRecv(scalar_shape_, /*channel_id=*/0));
   1186   auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
   1187   module_->AddEntryComputation(builder.Build());
   1188 
   1189   bool ssa_form = GetParam();
   1190   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
   1191 
   1192   EXPECT_EQ(analysis.values().size(), 3);
   1193 
   1194   EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{}));
   1195   EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{0}));
   1196   EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{1}));
   1197   EXPECT_FALSE(analysis.ValueIsDefinedAt(recv_done));
   1198   EXPECT_THAT(HloValuesAt(recv_done),
   1199               UnorderedElementsAre(analysis.GetValueDefinedAt(recv, {0})));
   1200   EXPECT_TRUE(
   1201       analysis.GetValueDefinedAt(recv, /*index=*/{0}).live_out_of_module());
   1202 }
   1203 
   1204 TEST_P(HloDataflowAnalysisTest, ElementwiseChainInterference) {
   1205   // A simple chain of elementwise operations. No values should interfere.
   1206   //
   1207   // param --> negate -> exp -> log
   1208   //
   1209   auto builder = HloComputation::Builder(TestName());
   1210   auto param = builder.AddInstruction(
   1211       HloInstruction::CreateParameter(0, vector_shape_, "param"));
   1212   auto negate = builder.AddInstruction(
   1213       HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param));
   1214   auto exp = builder.AddInstruction(
   1215       HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, negate));
   1216   auto log = builder.AddInstruction(
   1217       HloInstruction::CreateUnary(vector_shape_, HloOpcode::kLog, exp));
   1218 
   1219   module_->AddEntryComputation(builder.Build());
   1220   RunAnalysis(GetParam());
   1221 
   1222   DependencyHloOrdering ordering(module_.get());
   1223 
   1224   // No values should interfere.
   1225   EXPECT_FALSE(InstructionsMayInterfere(ordering, param, negate));
   1226   EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp));
   1227   EXPECT_FALSE(InstructionsMayInterfere(ordering, param, log));
   1228   EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, exp));
   1229   EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, log));
   1230   EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, negate));
   1231   EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, log));
   1232   EXPECT_FALSE(InstructionsMayInterfere(ordering, log, negate));
   1233   EXPECT_FALSE(InstructionsMayInterfere(ordering, log, exp));
   1234 
   1235   // Values should interfere with itself.
   1236   EXPECT_TRUE(InstructionsMayInterfere(ordering, exp, exp));
   1237 }
   1238 
   1239 TEST_P(HloDataflowAnalysisTest, MultipleEntryParameters_Sequential) {
   1240   // Two entry params, which interfere with each other.
   1241   //
   1242   // param0 --> negate ---------------\
   1243   //                param1 --> exp --> add
   1244   auto builder = HloComputation::Builder(TestName());
   1245   auto param0 = builder.AddInstruction(
   1246       HloInstruction::CreateParameter(0, vector_shape_, "param0"));
   1247   auto param1 = builder.AddInstruction(
   1248       HloInstruction::CreateParameter(1, vector_shape_, "param1"));
   1249   auto negate = builder.AddInstruction(
   1250       HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param0));
   1251   auto exp = builder.AddInstruction(
   1252       HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param1));
   1253   auto add = builder.AddInstruction(HloInstruction::CreateBinary(
   1254       vector_shape_, HloOpcode::kAdd, negate, exp));
   1255 
   1256   auto entry = module_->AddEntryComputation(builder.Build());
   1257   RunAnalysis(GetParam());
   1258 
   1259   SequentialHloOrdering::HloModuleSequence sequence;
   1260   sequence.insert({entry, {param0, negate, param1, exp, add}});
   1261   SequentialHloOrdering ordering(module_.get(), sequence);
   1262 
   1263   // Entry parameters interfere as if they are defined simultaneously at
   1264   // the very beginning.
   1265   EXPECT_TRUE(InstructionsMayInterfere(ordering, param0, param1));
   1266   EXPECT_FALSE(InstructionsMayInterfere(ordering, param0, negate));
   1267   EXPECT_FALSE(InstructionsMayInterfere(ordering, param0, exp));
   1268   EXPECT_FALSE(InstructionsMayInterfere(ordering, param0, add));
   1269   EXPECT_TRUE(InstructionsMayInterfere(ordering, param1, param0));
   1270   EXPECT_TRUE(InstructionsMayInterfere(ordering, param1, negate));
   1271   EXPECT_FALSE(InstructionsMayInterfere(ordering, param1, exp));
   1272   EXPECT_FALSE(InstructionsMayInterfere(ordering, param1, add));
   1273 
   1274   // Negate and exp still interfere.
   1275   EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, exp));
   1276   EXPECT_TRUE(InstructionsMayInterfere(ordering, exp, negate));
   1277 
   1278   // But {negate, add} and {exp, add} don't interfere.
   1279   EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, add));
   1280   EXPECT_FALSE(InstructionsMayInterfere(ordering, add, negate));
   1281   EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, add));
   1282   EXPECT_FALSE(InstructionsMayInterfere(ordering, add, exp));
   1283 }
   1284 
   1285 TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) {
   1286   // Similar to MultipleEntryParameters_Sequential, but the parameter is of
   1287   // while body computation. Body computation in the sequential order:
   1288   //
   1289   //  %constant = Constant(...)
   1290   //  %exp = Exp(%constant)
   1291   //  %param = Param(0)
   1292   //  %add = Add(%param, %exp)  ;; Root of body
   1293   //  %dead_constant = Constant(...)
   1294   //  %dead_negate = Negate(%dead_constant)
   1295   //
   1296   // %constant and its only use %exp are ordered before 'param'. However, the
   1297   // %constant and %param values still interfere because the parameter is
   1298   // considered live into the while body.
   1299   //
   1300   // Similarly, %dead_constant and %dead_negate are ordered after the root of
   1301   // the body computation %add. However, %add is liveout of the computation so
   1302   // %dead_constant and %add interfere.
   1303   auto body_builder = HloComputation::Builder(TestName());
   1304   auto body_param = body_builder.AddInstruction(
   1305       HloInstruction::CreateParameter(0, scalar_shape_, "body_param"));
   1306   auto constant = body_builder.AddInstruction(
   1307       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
   1308   auto exp = body_builder.AddInstruction(
   1309       HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kExp, constant));
   1310   auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
   1311       scalar_shape_, HloOpcode::kAdd, exp, body_param));
   1312   auto dead_constant = body_builder.AddInstruction(
   1313       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
   1314   auto dead_negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
   1315       scalar_shape_, HloOpcode::kNegate, dead_constant));
   1316   HloComputation* body = module_->AddEmbeddedComputation(
   1317       body_builder.Build(/*root_instruction=*/add));
   1318 
   1319   auto cond_builder = HloComputation::Builder("condition");
   1320   auto cond_param = cond_builder.AddInstruction(
   1321       HloInstruction::CreateParameter(0, scalar_shape_, "cond_param"));
   1322   auto cond_constant = cond_builder.AddInstruction(
   1323       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
   1324   HloComputation* condition =
   1325       module_->AddEmbeddedComputation(cond_builder.Build());
   1326 
   1327   auto builder = HloComputation::Builder(TestName());
   1328   auto param = builder.AddInstruction(
   1329       HloInstruction::CreateParameter(0, scalar_shape_, "param"));
   1330   auto xla_while = builder.AddInstruction(
   1331       HloInstruction::CreateWhile(scalar_shape_, condition, body, param));
   1332 
   1333   auto entry = module_->AddEntryComputation(builder.Build());
   1334   bool ssa_form = GetParam();
   1335   RunAnalysis(ssa_form);
   1336 
   1337   SequentialHloOrdering::HloModuleSequence sequence;
   1338   sequence.insert({entry, {param, xla_while}});
   1339   sequence.insert({condition, {cond_param, cond_constant}});
   1340   // Construct the order such that 'constant' and its use 'exp' are before
   1341   // body_param.
   1342   sequence.insert({body, {constant, exp, body_param, add}});
   1343 
   1344   SequentialHloOrdering ordering(module_.get(), sequence);
   1345 
   1346   // 'add' is live out of the body and will interfere with an later instructions
   1347   // such as 'dead_constant' and 'dead_negate'.
   1348   EXPECT_TRUE(InstructionsMayInterfere(ordering, add, dead_constant));
   1349   EXPECT_TRUE(InstructionsMayInterfere(ordering, add, dead_negate));
   1350 
   1351   // The remaining checks test phi values defined by body and condition
   1352   // parameters which only occur in the SSA form of the analysis.
   1353   if (ssa_form) {
   1354     // Though the ordering suggests 'constant' and 'param' should not interfere,
   1355     // 'param' is live in and thus interferes with any earlier instruction of
   1356     // the computation in the order (eg 'constant')'
   1357     EXPECT_TRUE(InstructionsMayInterfere(ordering, body_param, constant));
   1358     EXPECT_TRUE(InstructionsMayInterfere(ordering, body_param, exp));
   1359     EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, add));
   1360 
   1361     // The following values end up in the same buffer:
   1362     //  (1) the init value: 'param'
   1363     //  (2) the body parameter: 'body_param'
   1364     //  (3) the condition parameter: 'cond_param'
   1365     //  (4) the root value of the while body: 'add'
   1366     //  (5) the while value: 'xla_while'
   1367     // None should interfere.
   1368     EXPECT_FALSE(InstructionsMayInterfere(ordering, param, body_param));
   1369     EXPECT_FALSE(InstructionsMayInterfere(ordering, param, cond_param));
   1370     EXPECT_FALSE(InstructionsMayInterfere(ordering, param, add));
   1371     EXPECT_FALSE(InstructionsMayInterfere(ordering, param, xla_while));
   1372 
   1373     EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, cond_param));
   1374     EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, add));
   1375     EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, xla_while));
   1376 
   1377     EXPECT_FALSE(InstructionsMayInterfere(ordering, cond_param, add));
   1378     EXPECT_FALSE(InstructionsMayInterfere(ordering, cond_param, xla_while));
   1379 
   1380     EXPECT_FALSE(InstructionsMayInterfere(ordering, add, xla_while));
   1381   }
   1382 }
   1383 
   1384 TEST_P(HloDataflowAnalysisTest, NonElementwiseOperand) {
   1385   // A chain of operations with two elementwise and one non-elementwise. The
   1386   // elementwise op should not interfere with its operand, while the
   1387   // non-elementwise op should interfere. Entry params always interfere.
   1388   //
   1389   // param --> exp -> negate -> reverse
   1390   //
   1391   auto builder = HloComputation::Builder(TestName());
   1392   auto param = builder.AddInstruction(
   1393       HloInstruction::CreateParameter(0, vector_shape_, "param"));
   1394   auto exp = builder.AddInstruction(
   1395       HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param));
   1396   auto negate = builder.AddInstruction(
   1397       HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, exp));
   1398   auto reverse = builder.AddInstruction(
   1399       HloInstruction::CreateReverse(vector_shape_, negate, {0}));
   1400 
   1401   module_->AddEntryComputation(builder.Build());
   1402   RunAnalysis(GetParam());
   1403 
   1404   DependencyHloOrdering ordering(module_.get());
   1405 
   1406   EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp));
   1407   EXPECT_FALSE(InstructionsMayInterfere(ordering, param, negate));
   1408   EXPECT_FALSE(InstructionsMayInterfere(ordering, param, reverse));
   1409 
   1410   // Negate is elementwise, so doesn't interfere with its operand.
   1411   // Reverse is non-elementwise, so does interfere with its operand.
   1412   EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, negate));
   1413   EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, reverse));
   1414 }
   1415 
   1416 TEST_P(HloDataflowAnalysisTest, OverlappedValues) {
   1417   // Verify simultaneously live values interfere (exp and negate).
   1418   //
   1419   // param --> negate -> add
   1420   //     \---> exp -----/
   1421   //
   1422   auto builder = HloComputation::Builder(TestName());
   1423   auto param = builder.AddInstruction(
   1424       HloInstruction::CreateParameter(0, vector_shape_, "param"));
   1425   auto negate = builder.AddInstruction(
   1426       HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param));
   1427   auto exp = builder.AddInstruction(
   1428       HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param));
   1429   auto add = builder.AddInstruction(HloInstruction::CreateBinary(
   1430       vector_shape_, HloOpcode::kAdd, negate, exp));
   1431 
   1432   module_->AddEntryComputation(builder.Build());
   1433   RunAnalysis(GetParam());
   1434 
   1435   DependencyHloOrdering ordering(module_.get());
   1436 
   1437   EXPECT_TRUE(InstructionsMayInterfere(ordering, param, negate));
   1438   EXPECT_TRUE(InstructionsMayInterfere(ordering, param, exp));
   1439   EXPECT_FALSE(InstructionsMayInterfere(ordering, param, add));
   1440 
   1441   // Negate and exp interfere with each other, but not with add.
   1442   EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, exp));
   1443   EXPECT_TRUE(InstructionsMayInterfere(ordering, exp, negate));
   1444   EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, add));
   1445   EXPECT_FALSE(InstructionsMayInterfere(ordering, add, negate));
   1446   EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, add));
   1447   EXPECT_FALSE(InstructionsMayInterfere(ordering, add, exp));
   1448 }
   1449 
   1450 TEST_P(HloDataflowAnalysisTest, OverlappedValuesSequentialOrder) {
   1451   // Identical to the test OverlappedValue but using a sequential ordering of
   1452   // HLO instructions.
   1453   //
   1454   // param --> negate -> add
   1455   //     \---> exp -----/
   1456   //
   1457   // Sequential order:
   1458   //  param, negate, exp, add
   1459   //
   1460   // Liveness is identical to the DependencyHloOrdering.
   1461   auto builder = HloComputation::Builder(TestName());
   1462   auto param = builder.AddInstruction(
   1463       HloInstruction::CreateParameter(0, vector_shape_, "param"));
   1464   auto negate = builder.AddInstruction(
   1465       HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param));
   1466   auto exp = builder.AddInstruction(
   1467       HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param));
   1468   auto add = builder.AddInstruction(HloInstruction::CreateBinary(
   1469       vector_shape_, HloOpcode::kAdd, negate, exp));
   1470 
   1471   auto entry = module_->AddEntryComputation(builder.Build());
   1472   RunAnalysis(GetParam());
   1473 
   1474   SequentialHloOrdering::HloModuleSequence sequence;
   1475   std::vector<const HloInstruction*> order = {param, negate, exp, add};
   1476   sequence.emplace(entry, order);
   1477 
   1478   SequentialHloOrdering ordering(module_.get(), sequence);
   1479 
   1480   EXPECT_TRUE(InstructionsMayInterfere(ordering, param, negate));
   1481   EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp));
   1482   EXPECT_FALSE(InstructionsMayInterfere(ordering, param, add));
   1483 
   1484   // Negate and exp interfere with each other, but not with add.
   1485   EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, exp));
   1486   EXPECT_TRUE(InstructionsMayInterfere(ordering, exp, negate));
   1487   EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, add));
   1488   EXPECT_FALSE(InstructionsMayInterfere(ordering, add, negate));
   1489   EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, add));
   1490   EXPECT_FALSE(InstructionsMayInterfere(ordering, add, exp));
   1491 }
   1492 
   1493 TEST_P(HloDataflowAnalysisTest, EmbeddedComputationInterference) {
   1494   // Test MayInterfere() for embedded computation, specifically the interference
   1495   // of values in different computations.
   1496   //
   1497   // embedded_computation:
   1498   //   %embedded_param = Param(0)
   1499   //   %embedded_log = Log(%embedded_param)
   1500   //
   1501   // entry computation:
   1502   //   %param = Param(0)
   1503   //   %negate = Negate(%param)
   1504   //   %exp = Negate(%exp)
   1505   //   %call = Call(embedded_computation, {%exp})
   1506   //   %add = Add(%negate, %call)
   1507   //
   1508   // Note %negate is live across the call and should interfere with all values
   1509   // in the embedded computation.
   1510   auto embedded_builder = HloComputation::Builder(TestName() + "_embedded");
   1511   auto embedded_param = embedded_builder.AddInstruction(
   1512       HloInstruction::CreateParameter(0, vector_shape_, "embedded_param"));
   1513   auto embedded_log =
   1514       embedded_builder.AddInstruction(HloInstruction::CreateUnary(
   1515           vector_shape_, HloOpcode::kLog, embedded_param));
   1516   auto embedded_computation =
   1517       module_->AddEmbeddedComputation(embedded_builder.Build());
   1518 
   1519   auto builder = HloComputation::Builder(TestName());
   1520   auto param = builder.AddInstruction(
   1521       HloInstruction::CreateParameter(0, vector_shape_, "param"));
   1522   auto negate = builder.AddInstruction(
   1523       HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param));
   1524   auto exp = builder.AddInstruction(
   1525       HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param));
   1526   auto call = builder.AddInstruction(
   1527       HloInstruction::CreateCall(vector_shape_, {exp}, embedded_computation));
   1528   builder.AddInstruction(HloInstruction::CreateBinary(
   1529       vector_shape_, HloOpcode::kAdd, negate, call));
   1530   module_->AddEntryComputation(builder.Build());
   1531   RunAnalysis(GetParam());
   1532 
   1533   DependencyHloOrdering ordering(module_.get());
   1534 
   1535   // Exp only use is the call so it should not interfere with values inside the
   1536   // embedded computation.
   1537   EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, embedded_log));
   1538 
   1539   // Negate is live across the call and should interfere with values in the
   1540   // embedded computation
   1541   EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, embedded_log));
   1542 }
   1543 
   1544 TEST_P(HloDataflowAnalysisTest, ConditionalWithIdentity) {
   1545   // Test conditional with identity computations in both true and false cases.
   1546   //
   1547   // true_computation(F32[] %true_param):
   1548   //   return %true_param
   1549   //
   1550   // false_computation(F32[] %false_param):
   1551   //   return %false_param
   1552   //
   1553   // entry:
   1554   //   %pred = Constant(true)
   1555   //   %constant1 = Constant(56.0)
   1556   //   %constant2 = Constant(12.0)
   1557   //   return Conditional(%pred, %constant1, true_computation,
   1558   //                      %constant2, false_computation)
   1559 
   1560   auto true_builder = HloComputation::Builder(TestName() + "_true");
   1561   auto true_param = true_builder.AddInstruction(
   1562       HloInstruction::CreateParameter(0, scalar_shape_, "true_param"));
   1563   HloComputation* true_computation =
   1564       module_->AddEmbeddedComputation(true_builder.Build());
   1565 
   1566   auto false_builder = HloComputation::Builder(TestName() + "_false");
   1567   auto false_param = false_builder.AddInstruction(
   1568       HloInstruction::CreateParameter(0, scalar_shape_, "false_param"));
   1569   HloComputation* false_computation =
   1570       module_->AddEmbeddedComputation(false_builder.Build());
   1571 
   1572   auto builder = HloComputation::Builder(TestName());
   1573   auto pred = builder.AddInstruction(
   1574       HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
   1575   auto constant1 = builder.AddInstruction(
   1576       HloInstruction::CreateConstant(Literal::CreateR0<float>(56.0f)));
   1577   auto constant2 = builder.AddInstruction(
   1578       HloInstruction::CreateConstant(Literal::CreateR0<float>(12.0f)));
   1579   auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
   1580       scalar_shape_, pred, constant1, true_computation, constant2,
   1581       false_computation));
   1582   module_->AddEntryComputation(builder.Build());
   1583 
   1584   const HloDataflowAnalysis& analysis = RunAnalysis(GetParam());
   1585 
   1586   EXPECT_TRUE(analysis.ValueIsDefinedAt(pred));
   1587   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1));
   1588   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));
   1589 
   1590   EXPECT_FALSE(analysis.ValueIsDefinedAt(true_param));
   1591   EXPECT_FALSE(analysis.ValueIsDefinedAt(false_param));
   1592 
   1593   EXPECT_EQ(analysis.GetUniqueValueAt(true_param),
   1594             analysis.GetValueDefinedAt(constant1));
   1595   EXPECT_EQ(analysis.GetUniqueValueAt(false_param),
   1596             analysis.GetValueDefinedAt(constant2));
   1597 
   1598   EXPECT_THAT(analysis.GetValueDefinedAt(pred).uses(),
   1599               ElementsAre(HloUse{conditional, 0, {}}));
   1600   EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
   1601               ElementsAre(HloUse{conditional, 1, {}}));
   1602   EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(),
   1603               ElementsAre(HloUse{conditional, 2, {}}));
   1604 
   1605   EXPECT_EQ(analysis.values().size(), 3);
   1606   EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
   1607   EXPECT_THAT(HloValuesAt(conditional),
   1608               UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
   1609                                    analysis.GetValueDefinedAt(constant2)));
   1610 }
   1611 
   1612 TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) {
   1613   // Test conditional with true and false computations taking a tuple operand.
   1614   //
   1615   // true_computation((F32[], F32[]) %true_param):
   1616   //   %true_x = GetTupleElement(%true_param, 0)
   1617   //   %true_y = GetTupleElement(%true_param, 1)
   1618   //   return Add(%true_x, %true_y)
   1619   //
   1620   // false_computation((F32[], F32[]) %false_param):
   1621   //   %false_x = GetTupleElement(%false_param, 0)
   1622   //   %false_y = GetTupleElement(%false_param, 1)
   1623   //   return Subtract(%false_x, %false_y)
   1624   //
   1625   // entry:
   1626   //   %pred = Constant(true)
   1627   //   %constant1 = Constant(56.0)
   1628   //   %constant2 = Constant(12.0)
   1629   //   %tuple_operand = Tuple(%constant1, %constant2)
   1630   //   return Conditional(%pred, %tuple_operand, true_computation,
   1631   //                      %tuple_operand, false_computation)
   1632 
   1633   auto true_builder = HloComputation::Builder(TestName() + "_true");
   1634   auto true_param = true_builder.AddInstruction(
   1635       HloInstruction::CreateParameter(0, tuple_shape_, "true_param"));
   1636   auto true_x = true_builder.AddInstruction(
   1637       HloInstruction::CreateGetTupleElement(scalar_shape_, true_param, 0));
   1638   auto true_y = true_builder.AddInstruction(
   1639       HloInstruction::CreateGetTupleElement(scalar_shape_, true_param, 1));
   1640   auto add = true_builder.AddInstruction(HloInstruction::CreateBinary(
   1641       scalar_shape_, HloOpcode::kAdd, true_x, true_y));
   1642   HloComputation* true_computation =
   1643       module_->AddEmbeddedComputation(true_builder.Build());
   1644 
   1645   auto false_builder = HloComputation::Builder(TestName() + "_false");
   1646   auto false_param = false_builder.AddInstruction(
   1647       HloInstruction::CreateParameter(0, tuple_shape_, "false_param"));
   1648   auto false_x = false_builder.AddInstruction(
   1649       HloInstruction::CreateGetTupleElement(scalar_shape_, false_param, 0));
   1650   auto false_y = false_builder.AddInstruction(
   1651       HloInstruction::CreateGetTupleElement(scalar_shape_, false_param, 1));
   1652   auto sub = false_builder.AddInstruction(HloInstruction::CreateBinary(
   1653       scalar_shape_, HloOpcode::kSubtract, false_x, false_y));
   1654   HloComputation* false_computation =
   1655       module_->AddEmbeddedComputation(false_builder.Build());
   1656 
   1657   auto builder = HloComputation::Builder(TestName());
   1658   auto pred = builder.AddInstruction(
   1659       HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
   1660   auto constant1 = builder.AddInstruction(
   1661       HloInstruction::CreateConstant(Literal::CreateR0<float>(56.0f)));
   1662   auto constant2 = builder.AddInstruction(
   1663       HloInstruction::CreateConstant(Literal::CreateR0<float>(12.0f)));
   1664   auto tuple_operand = builder.AddInstruction(
   1665       HloInstruction::CreateTuple({constant1, constant2}));
   1666   auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
   1667       scalar_shape_, pred, tuple_operand, true_computation, tuple_operand,
   1668       false_computation));
   1669   module_->AddEntryComputation(builder.Build());
   1670 
   1671   const HloDataflowAnalysis& analysis = RunAnalysis(GetParam());
   1672 
   1673   EXPECT_TRUE(analysis.ValueIsDefinedAt(pred));
   1674   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1));
   1675   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));
   1676   EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple_operand));
   1677   EXPECT_TRUE(analysis.ValueIsDefinedAt(add));
   1678   EXPECT_TRUE(analysis.ValueIsDefinedAt(sub));
   1679 
   1680   EXPECT_FALSE(analysis.ValueIsDefinedAt(true_param));
   1681   EXPECT_FALSE(analysis.ValueIsDefinedAt(false_param));
   1682   EXPECT_FALSE(analysis.ValueIsDefinedAt(true_x));
   1683   EXPECT_FALSE(analysis.ValueIsDefinedAt(true_y));
   1684   EXPECT_FALSE(analysis.ValueIsDefinedAt(false_x));
   1685   EXPECT_FALSE(analysis.ValueIsDefinedAt(false_y));
   1686 
   1687   EXPECT_EQ(analysis.GetUniqueValueAt(true_param),
   1688             analysis.GetValueDefinedAt(tuple_operand));
   1689   EXPECT_EQ(analysis.GetUniqueValueAt(false_param),
   1690             analysis.GetValueDefinedAt(tuple_operand));
   1691   EXPECT_EQ(analysis.GetUniqueValueAt(true_x),
   1692             analysis.GetValueDefinedAt(constant1));
   1693   EXPECT_EQ(analysis.GetUniqueValueAt(true_y),
   1694             analysis.GetValueDefinedAt(constant2));
   1695   EXPECT_EQ(analysis.GetUniqueValueAt(false_x),
   1696             analysis.GetValueDefinedAt(constant1));
   1697   EXPECT_EQ(analysis.GetUniqueValueAt(false_y),
   1698             analysis.GetValueDefinedAt(constant2));
   1699 
   1700   EXPECT_THAT(analysis.GetValueDefinedAt(pred).uses(),
   1701               ElementsAre(HloUse{conditional, 0, {}}));
   1702   EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
   1703               UnorderedElementsAre(HloUse{conditional, 1, {0}},
   1704                                    HloUse{conditional, 2, {0}},
   1705                                    HloUse{add, 0, {}}, HloUse{sub, 0, {}}));
   1706   EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(),
   1707               UnorderedElementsAre(HloUse{conditional, 1, {1}},
   1708                                    HloUse{conditional, 2, {1}},
   1709                                    HloUse{add, 1, {}}, HloUse{sub, 1, {}}));
   1710   EXPECT_THAT(analysis.GetValueDefinedAt(tuple_operand).uses(),
   1711               UnorderedElementsAre(
   1712                   HloUse{conditional, 1, {}}, HloUse{conditional, 2, {}},
   1713                   HloUse{true_x, 0, {}}, HloUse{true_y, 0, {}},
   1714                   HloUse{false_x, 0, {}}, HloUse{false_y, 0, {}}));
   1715 
   1716   EXPECT_EQ(analysis.values().size(), 6);
   1717   EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
   1718   EXPECT_THAT(HloValuesAt(conditional),
   1719               UnorderedElementsAre(analysis.GetValueDefinedAt(add),
   1720                                    analysis.GetValueDefinedAt(sub)));
   1721 }
   1722 
   1723 TEST_P(HloDataflowAnalysisTest, NestedConditionals) {
   1724   // computation1(F32[] %param1):
   1725   //   %ceil = Ceil(%param1)
   1726   //   return %ceil
   1727   //
   1728   // computation2(F32[] %param2):
   1729   //   %floor = Floor(%param2)
   1730   //   return %floor
   1731   //
   1732   // computation3(F32[] %param3):
   1733   //   %negate = Negate(%param3)
   1734   //   return %negate
   1735   //
   1736   // inner_conditional((PRED, F32[], F32[]) %param_cond):
   1737   //   %pred_cond = GetTupleElement(%param_cond, 0)
   1738   //   %true_operand_cond = GetTupleElement(%param_cond, 1)
   1739   //   %false_opearnd_cond = GetTupleElement(%param_cond, 2)
   1740   //   return Conditional(%pred_cond, %true_operand_cond, computation1,
   1741   //                      %false_operand_cond, computation2)
   1742   //
   1743   // entry:
   1744   //   %pred1 = Constant(true)
   1745   //   %pred2 = Constant(false)
   1746   //   %constant1 = Constant(1.1);
   1747   //   %constant2 = Constant(2.2);
   1748   //   %constant3 = Constant(3.3);
   1749   //   return Conditional(%pred1, (%pred2, %constant1, %constant2),
   1750   //                      inner_conditional, %constant3, computation3)
   1751 
   1752   auto computation1 = module_->AddEmbeddedComputation(
   1753       CreateR0F32UnaryOpComputation(HloOpcode::kCeil));
   1754   auto computation2 = module_->AddEmbeddedComputation(
   1755       CreateR0F32UnaryOpComputation(HloOpcode::kFloor));
   1756   auto computation3 = module_->AddEmbeddedComputation(
   1757       CreateR0F32UnaryOpComputation(HloOpcode::kNegate));
   1758 
   1759   // Build inner_conditional computation.
   1760   const Shape scalar_bool_shape = ShapeUtil::MakeShape(PRED, {});
   1761   const Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
   1762       {scalar_bool_shape, scalar_shape_, scalar_shape_});
   1763   auto inner_builder =
   1764       HloComputation::Builder(TestName() + "_inner_conditional");
   1765   auto param_cond = inner_builder.AddInstruction(
   1766       HloInstruction::CreateParameter(0, tuple_param_shape, "param_cond"));
   1767   auto pred_cond = inner_builder.AddInstruction(
   1768       HloInstruction::CreateGetTupleElement(scalar_bool_shape, param_cond, 0));
   1769   auto true_operand_cond = inner_builder.AddInstruction(
   1770       HloInstruction::CreateGetTupleElement(scalar_shape_, param_cond, 1));
   1771   auto false_operand_cond = inner_builder.AddInstruction(
   1772       HloInstruction::CreateGetTupleElement(scalar_shape_, param_cond, 2));
   1773   auto inner_conditional =
   1774       inner_builder.AddInstruction(HloInstruction::CreateConditional(
   1775           scalar_shape_, pred_cond, true_operand_cond, computation1,
   1776           false_operand_cond, computation2));
   1777   auto inner_conditional_computation =
   1778       module_->AddEmbeddedComputation(inner_builder.Build());
   1779 
   1780   // Build entry computation.
   1781   auto builder = HloComputation::Builder(TestName());
   1782   auto pred1 = builder.AddInstruction(
   1783       HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
   1784   auto pred2 = builder.AddInstruction(
   1785       HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
   1786   auto constant1 = builder.AddInstruction(
   1787       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
   1788   auto constant2 = builder.AddInstruction(
   1789       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.2f)));
   1790   auto constant3 = builder.AddInstruction(
   1791       HloInstruction::CreateConstant(Literal::CreateR0<float>(3.3f)));
   1792   auto tuple_operand = builder.AddInstruction(
   1793       HloInstruction::CreateTuple({pred2, constant1, constant2}));
   1794   auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
   1795       scalar_shape_, pred1, tuple_operand, inner_conditional_computation,
   1796       constant3, computation3));
   1797   module_->AddEntryComputation(builder.Build());
   1798 
   1799   const HloDataflowAnalysis& analysis = RunAnalysis(GetParam());
   1800 
   1801   EXPECT_TRUE(analysis.ValueIsDefinedAt(pred1));
   1802   EXPECT_TRUE(analysis.ValueIsDefinedAt(pred2));
   1803   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1));
   1804   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));
   1805   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant3));
   1806   EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple_operand));
   1807   EXPECT_TRUE(analysis.ValueIsDefinedAt(computation1->root_instruction()));
   1808   EXPECT_TRUE(analysis.ValueIsDefinedAt(computation2->root_instruction()));
   1809   EXPECT_TRUE(analysis.ValueIsDefinedAt(computation3->root_instruction()));
   1810 
   1811   auto computation1_param = computation1->parameter_instruction(0);
   1812   auto computation2_param = computation2->parameter_instruction(0);
   1813   auto computation3_param = computation3->parameter_instruction(0);
   1814   EXPECT_FALSE(analysis.ValueIsDefinedAt(computation1_param));
   1815   EXPECT_FALSE(analysis.ValueIsDefinedAt(computation2_param));
   1816   EXPECT_FALSE(analysis.ValueIsDefinedAt(computation3_param));
   1817   EXPECT_EQ(analysis.GetUniqueValueAt(computation1_param),
   1818             analysis.GetValueDefinedAt(constant1));
   1819   EXPECT_EQ(analysis.GetUniqueValueAt(computation2_param),
   1820             analysis.GetValueDefinedAt(constant2));
   1821   EXPECT_EQ(analysis.GetUniqueValueAt(computation3_param),
   1822             analysis.GetValueDefinedAt(constant3));
   1823 
   1824   EXPECT_FALSE(analysis.ValueIsDefinedAt(param_cond));
   1825   EXPECT_FALSE(analysis.ValueIsDefinedAt(pred_cond));
   1826   EXPECT_FALSE(analysis.ValueIsDefinedAt(true_operand_cond));
   1827   EXPECT_FALSE(analysis.ValueIsDefinedAt(false_operand_cond));
   1828   EXPECT_EQ(analysis.GetUniqueValueAt(param_cond),
   1829             analysis.GetValueDefinedAt(tuple_operand));
   1830   EXPECT_EQ(analysis.GetUniqueValueAt(pred_cond),
   1831             analysis.GetValueDefinedAt(pred2));
   1832   EXPECT_EQ(analysis.GetUniqueValueAt(true_operand_cond),
   1833             analysis.GetValueDefinedAt(constant1));
   1834   EXPECT_EQ(analysis.GetUniqueValueAt(false_operand_cond),
   1835             analysis.GetValueDefinedAt(constant2));
   1836 
   1837   EXPECT_EQ(analysis.values().size(), 9);
   1838   EXPECT_FALSE(analysis.ValueIsDefinedAt(inner_conditional));
   1839   EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
   1840   EXPECT_THAT(
   1841       HloValuesAt(inner_conditional),
   1842       UnorderedElementsAre(
   1843           analysis.GetValueDefinedAt(computation1->root_instruction()),
   1844           analysis.GetValueDefinedAt(computation2->root_instruction())));
   1845   EXPECT_THAT(
   1846       HloValuesAt(conditional),
   1847       UnorderedElementsAre(
   1848           analysis.GetValueDefinedAt(computation1->root_instruction()),
   1849           analysis.GetValueDefinedAt(computation2->root_instruction()),
   1850           analysis.GetValueDefinedAt(computation3->root_instruction())));
   1851 }
   1852 
   1853 INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation,
   1854                         HloDataflowAnalysisTest,
   1855                         ::testing::Values(false, true));
   1856 
   1857 }  // namespace
   1858 }  // namespace xla
   1859