1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" 17 18 #include "tensorflow/compiler/xla/literal_util.h" 19 #include "tensorflow/compiler/xla/service/hlo_computation.h" 20 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" 21 #include "tensorflow/compiler/xla/service/hlo_matchers.h" 22 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 23 #include "tensorflow/compiler/xla/service/hlo_ordering.h" 24 #include "tensorflow/compiler/xla/service/instruction_fusion.h" 25 #include "tensorflow/compiler/xla/shape_util.h" 26 #include "tensorflow/compiler/xla/status_macros.h" 27 #include "tensorflow/compiler/xla/test.h" 28 #include "tensorflow/compiler/xla/test_helpers.h" 29 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 30 #include "tensorflow/compiler/xla/xla_data.pb.h" 31 #include "tensorflow/core/platform/logging.h" 32 #include "tensorflow/core/platform/test.h" 33 34 namespace xla { 35 namespace { 36 37 using ::testing::ElementsAre; 38 using ::testing::UnorderedElementsAre; 39 40 // Test is parameterized on a bool which is whether the dataflow analysis is 41 // performed with SSA form. 
42 class HloDataflowAnalysisTest : public HloTestBase, 43 public ::testing::WithParamInterface<bool> { 44 protected: 45 HloDataflowAnalysisTest() : module_(CreateNewModule()) {} 46 47 // Run dataflow analysis on the member module. For convenience returns a 48 // reference to the generated analysis stored in analysis_. 49 const HloDataflowAnalysis& RunAnalysis(bool ssa_form, 50 bool bitcast_defines_value = false) { 51 hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before dataflow analysis"); 52 analysis_ = 53 HloDataflowAnalysis::Run(*module_, ssa_form, bitcast_defines_value) 54 .ConsumeValueOrDie(); 55 return *analysis_; 56 } 57 58 // Return a vector of the HloValues at the given program position. 59 std::vector<HloValue> HloValuesAt(const HloInstruction* instruction, 60 const ShapeIndex& index = {}) { 61 CHECK(analysis_ != nullptr); 62 std::vector<HloValue> values; 63 for (const HloValue* value : 64 analysis_->GetValueSet(instruction, index).values()) { 65 values.push_back(*value); 66 } 67 return values; 68 } 69 70 // Returns true if the top-level values for instructions 'a' and 'b' may 71 // interfere. Precondition: 'a' and 'b' define array-shaped values. 72 bool InstructionsMayInterfere(const HloOrdering& ordering, 73 const HloInstruction* a, 74 const HloInstruction* b) { 75 EXPECT_FALSE(ShapeUtil::IsTuple(a->shape())); 76 EXPECT_FALSE(ShapeUtil::IsTuple(b->shape())); 77 return ordering.MayInterfere(analysis_->GetValueDefinedAt(a), 78 analysis_->GetValueDefinedAt(b), *analysis_); 79 } 80 81 std::unique_ptr<HloComputation> CreateR0F32UnaryOpComputation( 82 HloOpcode opcode) { 83 HloComputation::Builder builder(TestName() + "." 
+ HloOpcodeString(opcode)); 84 HloInstruction* param0 = builder.AddInstruction( 85 HloInstruction::CreateParameter(0, scalar_shape_, "param0")); 86 builder.AddInstruction( 87 HloInstruction::CreateUnary(scalar_shape_, opcode, param0)); 88 return builder.Build(); 89 } 90 91 std::unique_ptr<HloModule> module_; 92 std::unique_ptr<HloDataflowAnalysis> analysis_; 93 94 const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); 95 const Shape vector_shape_ = ShapeUtil::MakeShape(F32, {42}); 96 const Shape tuple_shape_ = ShapeUtil::MakeTupleShape( 97 {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}); 98 }; 99 100 TEST_P(HloDataflowAnalysisTest, BinaryOperation) { 101 // Test the dataflow for a simple binary operation (Add). 102 auto builder = HloComputation::Builder(TestName()); 103 auto constant1 = builder.AddInstruction( 104 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 105 auto constant2 = builder.AddInstruction( 106 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 107 auto add = builder.AddInstruction(HloInstruction::CreateBinary( 108 scalar_shape_, HloOpcode::kAdd, constant1, constant2)); 109 module_->AddEntryComputation(builder.Build()); 110 111 bool ssa_form = GetParam(); 112 const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); 113 114 // Each instruction should define a single value. 115 EXPECT_EQ(analysis.values().size(), 3); 116 EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1)); 117 EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2)); 118 EXPECT_TRUE(analysis.ValueIsDefinedAt(add)); 119 120 // Verify the positions of the values. These positions are all trivial because 121 // there are no instructions which forward values. 
122 EXPECT_THAT(analysis.GetValueDefinedAt(constant1).positions(), 123 UnorderedElementsAre(HloPosition{constant1, {}})); 124 EXPECT_THAT(analysis.GetValueDefinedAt(constant2).positions(), 125 UnorderedElementsAre(HloPosition{constant2, {}})); 126 EXPECT_THAT(analysis.GetValueDefinedAt(add).positions(), 127 UnorderedElementsAre(HloPosition{add, {}})); 128 129 // Verify the uses of the values. 130 EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), 131 UnorderedElementsAre(HloUse{add, 0, {}})); 132 EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), 133 UnorderedElementsAre(HloUse{add, 1, {}})); 134 EXPECT_TRUE(analysis.GetValueDefinedAt(add).uses().empty()); 135 136 // Verify liveout values from the module. 137 EXPECT_FALSE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); 138 EXPECT_FALSE(analysis.GetValueDefinedAt(constant2).live_out_of_module()); 139 EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); 140 } 141 142 TEST_P(HloDataflowAnalysisTest, TupleAndGtes) { 143 // Verify the dataflow through a Tuple and GetTupleElement instructions. 
144 auto builder = HloComputation::Builder(TestName()); 145 auto param0 = builder.AddInstruction( 146 HloInstruction::CreateParameter(0, scalar_shape_, "param0")); 147 auto param1 = builder.AddInstruction( 148 HloInstruction::CreateParameter(1, scalar_shape_, "param1")); 149 auto tuple = 150 builder.AddInstruction(HloInstruction::CreateTuple({param0, param1})); 151 auto gte0 = builder.AddInstruction( 152 HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 0)); 153 auto gte1 = builder.AddInstruction( 154 HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1)); 155 auto add = builder.AddInstruction( 156 HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, gte0, gte1)); 157 module_->AddEntryComputation(builder.Build()); 158 159 bool ssa_form = GetParam(); 160 const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); 161 162 // The two params, tuple, and add should each define one value. 163 EXPECT_EQ(analysis.values().size(), 4); 164 165 EXPECT_TRUE(analysis.ValueIsDefinedAt(param0)); 166 EXPECT_TRUE(analysis.ValueIsDefinedAt(param1)); 167 EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple, /*index=*/{})); 168 EXPECT_FALSE(analysis.ValueIsDefinedAt(tuple, /*index=*/{0})); 169 EXPECT_FALSE(analysis.ValueIsDefinedAt(tuple, /*index=*/{1})); 170 EXPECT_FALSE(analysis.ValueIsDefinedAt(gte0)); 171 EXPECT_FALSE(analysis.ValueIsDefinedAt(gte1)); 172 EXPECT_TRUE(analysis.ValueIsDefinedAt(add)); 173 174 // Verify the positions of the values. 175 EXPECT_THAT( 176 analysis.GetValueDefinedAt(param0).positions(), 177 UnorderedElementsAre(HloPosition{param0, {}}, HloPosition{tuple, {0}}, 178 HloPosition{gte0, {}})); 179 EXPECT_THAT( 180 analysis.GetValueDefinedAt(param1).positions(), 181 UnorderedElementsAre(HloPosition{param1, {}}, HloPosition{tuple, {1}}, 182 HloPosition{gte1, {}})); 183 EXPECT_THAT(analysis.GetValueDefinedAt(tuple).positions(), 184 UnorderedElementsAre(HloPosition{tuple, {}})); 185 186 // Verify uses. 
Of interest is that a GetTupleElement instruction is only a 187 // use of the top-level value in the tuple operand. 188 EXPECT_THAT(analysis.GetValueDefinedAt(param0).uses(), 189 UnorderedElementsAre(HloUse{add, 0, {}})); 190 EXPECT_THAT(analysis.GetValueDefinedAt(param1).uses(), 191 UnorderedElementsAre(HloUse{add, 1, {}})); 192 EXPECT_THAT(analysis.GetValueDefinedAt(tuple, /*index=*/{}).uses(), 193 UnorderedElementsAre(HloUse{gte0, 0, {}}, HloUse{gte1, 0, {}})); 194 EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); 195 } 196 197 TEST_P(HloDataflowAnalysisTest, NestedTuple) { 198 // Verify the dataflow through a nested tuple. 199 auto builder = HloComputation::Builder(TestName()); 200 auto constant1 = builder.AddInstruction( 201 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 202 auto constant2 = builder.AddInstruction( 203 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 204 auto tuple = builder.AddInstruction( 205 HloInstruction::CreateTuple({constant1, constant2})); 206 auto nested_tuple = builder.AddInstruction( 207 HloInstruction::CreateTuple({tuple, tuple, constant1})); 208 auto gte_tuple = builder.AddInstruction( 209 HloInstruction::CreateGetTupleElement(tuple->shape(), nested_tuple, 1)); 210 auto gte_out = builder.AddInstruction( 211 HloInstruction::CreateGetTupleElement(scalar_shape_, gte_tuple, 0)); 212 module_->AddEntryComputation(builder.Build()); 213 214 bool ssa_form = GetParam(); 215 const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); 216 217 EXPECT_EQ(analysis.values().size(), 4); 218 219 // Verify positions and uses. 
220 EXPECT_THAT( 221 analysis.GetValueDefinedAt(constant1).positions(), 222 UnorderedElementsAre( 223 HloPosition{constant1, {}}, HloPosition{tuple, {0}}, 224 HloPosition{nested_tuple, {0, 0}}, HloPosition{nested_tuple, {1, 0}}, 225 HloPosition{nested_tuple, {2}}, HloPosition{gte_tuple, {0}}, 226 HloPosition{gte_out, {}})); 227 // Constant values should have only a single use, which is the root of the 228 // computation. 229 EXPECT_THAT(analysis.GetValueDefinedAt(constant1, /*index=*/{}).uses(), 230 UnorderedElementsAre(HloUse{gte_out, 0, {0}})); 231 EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).uses().empty()); 232 233 // The top-level tuple values are used in GTE instructions. 234 EXPECT_THAT(analysis.GetValueDefinedAt(tuple, /*index=*/{}).uses(), 235 UnorderedElementsAre(HloUse{gte_out, 0, {}})); 236 EXPECT_THAT(analysis.GetValueDefinedAt(nested_tuple, /*index=*/{}).uses(), 237 UnorderedElementsAre(HloUse{gte_tuple, 0, {}})); 238 239 EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); 240 EXPECT_FALSE(analysis.GetValueDefinedAt(constant2).live_out_of_module()); 241 EXPECT_FALSE( 242 analysis.GetValueDefinedAt(tuple, /*index=*/{}).live_out_of_module()); 243 EXPECT_FALSE(analysis.GetValueDefinedAt(nested_tuple, /*index=*/{}) 244 .live_out_of_module()); 245 } 246 247 TEST_P(HloDataflowAnalysisTest, SingleCall) { 248 // Test a single call of a subcomputation. The subcomputation adds its two 249 // array-shaped parameters. 
250 auto subbuilder = HloComputation::Builder("Subcomputation"); 251 auto subparam0 = subbuilder.AddInstruction( 252 HloInstruction::CreateParameter(0, scalar_shape_, "param0")); 253 auto subparam1 = subbuilder.AddInstruction( 254 HloInstruction::CreateParameter(1, scalar_shape_, "param1")); 255 auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary( 256 scalar_shape_, HloOpcode::kAdd, subparam0, subparam1)); 257 HloComputation* called_computation = 258 module_->AddEmbeddedComputation(subbuilder.Build()); 259 260 auto builder = HloComputation::Builder(TestName()); 261 auto constant1 = builder.AddInstruction( 262 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 263 auto constant2 = builder.AddInstruction( 264 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 265 auto call = builder.AddInstruction(HloInstruction::CreateCall( 266 scalar_shape_, {constant1, constant2}, called_computation)); 267 module_->AddEntryComputation(builder.Build()); 268 269 bool ssa_form = GetParam(); 270 const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); 271 272 EXPECT_EQ(analysis.values().size(), 3); 273 274 // The parameters of the subcomputation and the call instruction itself should 275 // not define values. Their values flow from elsewhere. 
276 EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1)); 277 EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2)); 278 EXPECT_FALSE(analysis.ValueIsDefinedAt(subparam0)); 279 EXPECT_FALSE(analysis.ValueIsDefinedAt(subparam1)); 280 EXPECT_TRUE(analysis.ValueIsDefinedAt(add)); 281 EXPECT_FALSE(analysis.ValueIsDefinedAt(call)); 282 283 EXPECT_EQ(analysis.GetUniqueValueAt(subparam0), 284 analysis.GetValueDefinedAt(constant1)); 285 EXPECT_EQ(analysis.GetUniqueValueAt(subparam1), 286 analysis.GetValueDefinedAt(constant2)); 287 EXPECT_EQ(analysis.GetUniqueValueAt(call), analysis.GetValueDefinedAt(add)); 288 289 EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), 290 UnorderedElementsAre(HloUse{call, 0, {}}, HloUse{add, 0, {}})); 291 EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), 292 UnorderedElementsAre(HloUse{call, 1, {}}, HloUse{add, 1, {}})); 293 294 EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); 295 } 296 297 TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) { 298 // Test a subcomputation which is called twice with identical values. 
299 auto subbuilder = HloComputation::Builder("Subcomputation"); 300 auto subparam0 = subbuilder.AddInstruction( 301 HloInstruction::CreateParameter(0, scalar_shape_, "param0")); 302 auto subparam1 = subbuilder.AddInstruction( 303 HloInstruction::CreateParameter(1, scalar_shape_, "param1")); 304 auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary( 305 scalar_shape_, HloOpcode::kAdd, subparam0, subparam1)); 306 HloComputation* called_computation = 307 module_->AddEmbeddedComputation(subbuilder.Build()); 308 309 auto builder = HloComputation::Builder(TestName()); 310 auto constant1 = builder.AddInstruction( 311 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 312 auto constant2 = builder.AddInstruction( 313 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 314 auto call1 = builder.AddInstruction(HloInstruction::CreateCall( 315 scalar_shape_, {constant1, constant2}, called_computation)); 316 auto call2 = builder.AddInstruction(HloInstruction::CreateCall( 317 scalar_shape_, {constant1, constant2}, called_computation)); 318 auto sub = builder.AddInstruction(HloInstruction::CreateBinary( 319 scalar_shape_, HloOpcode::kSubtract, call1, call2)); 320 module_->AddEntryComputation(builder.Build()); 321 322 bool ssa_form = GetParam(); 323 const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); 324 325 EXPECT_EQ(analysis.values().size(), 4); 326 327 // Definitions should be identical to the single callsite case. 
328 EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1)); 329 EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2)); 330 EXPECT_FALSE(analysis.ValueIsDefinedAt(subparam0)); 331 EXPECT_FALSE(analysis.ValueIsDefinedAt(subparam1)); 332 EXPECT_TRUE(analysis.ValueIsDefinedAt(add)); 333 EXPECT_FALSE(analysis.ValueIsDefinedAt(call1)); 334 EXPECT_FALSE(analysis.ValueIsDefinedAt(call2)); 335 EXPECT_TRUE(analysis.ValueIsDefinedAt(sub)); 336 337 EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), 338 UnorderedElementsAre(HloUse{call1, 0, {}}, HloUse{call2, 0, {}}, 339 HloUse{add, 0, {}})); 340 EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), 341 UnorderedElementsAre(HloUse{call1, 1, {}}, HloUse{call2, 1, {}}, 342 HloUse{add, 1, {}})); 343 // The Add from the subcomputation is used as both operands of the Subtract. 344 EXPECT_THAT(analysis.GetValueDefinedAt(add).uses(), 345 UnorderedElementsAre(HloUse{sub, 0, {}}, HloUse{sub, 1, {}})); 346 347 EXPECT_FALSE(analysis.GetValueDefinedAt(add).live_out_of_module()); 348 EXPECT_TRUE(analysis.GetValueDefinedAt(sub).live_out_of_module()); 349 } 350 351 TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithDifferentArguments) { 352 // Test a subcomputation which is called twice with different argument values. 
353 auto subbuilder = HloComputation::Builder("Subcomputation"); 354 auto subparam0 = subbuilder.AddInstruction( 355 HloInstruction::CreateParameter(0, scalar_shape_, "param0")); 356 auto subparam1 = subbuilder.AddInstruction( 357 HloInstruction::CreateParameter(1, scalar_shape_, "param1")); 358 auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary( 359 scalar_shape_, HloOpcode::kAdd, subparam0, subparam1)); 360 HloComputation* called_computation = 361 module_->AddEmbeddedComputation(subbuilder.Build()); 362 363 auto builder = HloComputation::Builder(TestName()); 364 auto constant1 = builder.AddInstruction( 365 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 366 auto constant2 = builder.AddInstruction( 367 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 368 auto call1 = builder.AddInstruction(HloInstruction::CreateCall( 369 scalar_shape_, {constant1, constant2}, called_computation)); 370 auto call2 = builder.AddInstruction(HloInstruction::CreateCall( 371 scalar_shape_, {call1, constant2}, called_computation)); 372 module_->AddEntryComputation(builder.Build()); 373 374 bool ssa_form = GetParam(); 375 const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); 376 377 EXPECT_FALSE(analysis.ValueIsDefinedAt(call1)); 378 EXPECT_FALSE(analysis.ValueIsDefinedAt(call2)); 379 380 EXPECT_FALSE(analysis.ValueIsDefinedAt(subparam0)); 381 382 EXPECT_THAT(HloValuesAt(subparam0), 383 UnorderedElementsAre(analysis.GetValueDefinedAt(constant1), 384 analysis.GetValueDefinedAt(add))); 385 EXPECT_THAT(HloValuesAt(subparam1), 386 UnorderedElementsAre(analysis.GetValueDefinedAt(constant2))); 387 388 EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); 389 } 390 391 TEST_P(HloDataflowAnalysisTest, NestedCalls) { 392 // Test a module with nested computations. 
HLO is: 393 // 394 // F32[] inner_computation(F32[] %param0, F32[] %param1): 395 // %add = Add(%param0, %param1) 396 // 397 // F32[] outer_computation((F32[] %param0, F32[] %param1): 398 // ;; Note that parameters are interchanged in the call. 399 // %nested_call = Call(inner_computation, {%param1, %param0}) 400 // 401 // F32[] entry: 402 // %constant1 = Constant(1.0) 403 // %constant2 = Constant(2.0) 404 // %call = Call(outer_computation, {%constant1, %constant2}) 405 // 406 auto inner_builder = HloComputation::Builder("InnerComputation"); 407 auto inner_param0 = inner_builder.AddInstruction( 408 HloInstruction::CreateParameter(0, scalar_shape_, "param0")); 409 auto inner_param1 = inner_builder.AddInstruction( 410 HloInstruction::CreateParameter(1, scalar_shape_, "param1")); 411 auto add = inner_builder.AddInstruction(HloInstruction::CreateBinary( 412 scalar_shape_, HloOpcode::kAdd, inner_param0, inner_param1)); 413 HloComputation* inner_computation = 414 module_->AddEmbeddedComputation(inner_builder.Build()); 415 416 auto outer_builder = HloComputation::Builder("OuterComputation"); 417 auto outer_param0 = outer_builder.AddInstruction( 418 HloInstruction::CreateParameter(0, scalar_shape_, "param0")); 419 auto outer_param1 = outer_builder.AddInstruction( 420 HloInstruction::CreateParameter(1, scalar_shape_, "param1")); 421 // Swizzle parameters. 
422 auto nested_call = outer_builder.AddInstruction(HloInstruction::CreateCall( 423 scalar_shape_, {outer_param1, outer_param0}, inner_computation)); 424 HloComputation* outer_computation = 425 module_->AddEmbeddedComputation(outer_builder.Build()); 426 427 auto builder = HloComputation::Builder(TestName()); 428 auto constant1 = builder.AddInstruction( 429 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 430 auto constant2 = builder.AddInstruction( 431 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 432 auto call = builder.AddInstruction(HloInstruction::CreateCall( 433 scalar_shape_, {constant1, constant2}, outer_computation)); 434 module_->AddEntryComputation(builder.Build()); 435 436 bool ssa_form = GetParam(); 437 const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); 438 439 // Only three values should be defined. Most instructions just pass through 440 // their operand values. 441 EXPECT_EQ(analysis.values().size(), 3); 442 443 // Verify that the uses of the constants are properly swizzled by parameter 444 // permutation in nested_call. 445 EXPECT_THAT( 446 analysis.GetValueDefinedAt(constant1).uses(), 447 UnorderedElementsAre(HloUse{call, 0, {}}, HloUse{nested_call, 1, {}}, 448 HloUse{add, 1, {}})); 449 EXPECT_THAT( 450 analysis.GetValueDefinedAt(constant2).uses(), 451 UnorderedElementsAre(HloUse{call, 1, {}}, HloUse{nested_call, 0, {}}, 452 HloUse{add, 0, {}})); 453 454 EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); 455 } 456 457 TEST_P(HloDataflowAnalysisTest, SingleWhile) { 458 // Test a simple single while instruction. The while body includes a 459 // pass-through value. 
HLO: 460 // 461 // body((F32[], F32[]) %tuple_param): 462 // %add = Add(%tuple_param{0}, %tuple_param{1}) 463 // return Tuple(%tuple_param{0}, %add) 464 // 465 // condition((F32[], F32[]) %tuple_param): 466 // return Constant(false) 467 // 468 // entry: 469 // %constant1 = Constant(1.0) 470 // %constant2 = Constant(2.0) 471 // %tuple = Tuple(%constant1, %constant2) 472 // return While(%tuple, body, condition) 473 // 474 const Shape tuple_shape = 475 ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); 476 477 // Element 0 passes transparently through the body. 478 auto body_builder = HloComputation::Builder("body"); 479 auto body_param = body_builder.AddInstruction( 480 HloInstruction::CreateParameter(0, tuple_shape, "param")); 481 auto body_element_0 = body_builder.AddInstruction( 482 HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); 483 auto body_element_1 = body_builder.AddInstruction( 484 HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); 485 auto add = body_builder.AddInstruction(HloInstruction::CreateBinary( 486 scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1)); 487 auto body_root = body_builder.AddInstruction( 488 HloInstruction::CreateTuple({body_element_0, add})); 489 HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); 490 491 // Condition computation trivially returns a constant "false". 
492 auto cond_builder = HloComputation::Builder("condition"); 493 auto cond_param = cond_builder.AddInstruction( 494 HloInstruction::CreateParameter(0, tuple_shape, "param")); 495 auto cond_constant = cond_builder.AddInstruction( 496 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 497 HloComputation* condition = 498 module_->AddEmbeddedComputation(cond_builder.Build()); 499 500 auto builder = HloComputation::Builder(TestName()); 501 auto constant1 = builder.AddInstruction( 502 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 503 auto constant2 = builder.AddInstruction( 504 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 505 auto tuple = builder.AddInstruction( 506 HloInstruction::CreateTuple({constant1, constant2})); 507 auto xla_while = builder.AddInstruction( 508 HloInstruction::CreateWhile(tuple_shape, condition, body, tuple)); 509 module_->AddEntryComputation(builder.Build()); 510 511 bool ssa_form = GetParam(); 512 const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); 513 514 EXPECT_FALSE(analysis.GetValueDefinedAt(cond_constant).live_out_of_module()); 515 516 if (ssa_form) { 517 // Element 0 of the tuple passed through the body so no phi value is 518 // defined. 519 EXPECT_FALSE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0})); 520 EXPECT_FALSE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0})); 521 EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0})); 522 523 // Element 1 of the tuple should be a phi value. 
524 EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1})); 525 EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}).is_phi()); 526 EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{1})); 527 EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{1}).is_phi()); 528 EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1})); 529 EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{1}).is_phi()); 530 531 EXPECT_THAT( 532 analysis.GetValueDefinedAt(constant1).uses(), 533 UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{body_root, 0, {}}, 534 HloUse{xla_while, 0, {0}})); 535 536 // Constant1 passes through the body and out of the module. 537 EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); 538 EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}) 539 .live_out_of_module()); 540 541 EXPECT_FALSE(analysis.GetValueDefinedAt(add).live_out_of_module()); 542 } else { 543 // While instruction and subcomputation parameters should not define values 544 // in non-ssa form. 545 EXPECT_FALSE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0})); 546 EXPECT_FALSE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1})); 547 EXPECT_FALSE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0})); 548 EXPECT_FALSE(analysis.ValueIsDefinedAt(body_param, /*index=*/{1})); 549 EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0})); 550 EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1})); 551 552 EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); 553 EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); 554 } 555 } 556 557 TEST_P(HloDataflowAnalysisTest, SequentialWhiles) { 558 // Test sequential while instructions. The while body includes a 559 // pass-through value. 
HLO: 560 // 561 // body((F32[], F32[]) %tuple_param): 562 // %add = Add(%tuple_param{0}, %tuple_param{1}) 563 // return Tuple(%tuple_param{0}, %add) 564 // 565 // condition((F32[], F32[]) %tuple_param): 566 // return Constant(false) 567 // 568 // entry: 569 // %constant1 = Constant(1.0) 570 // %constant2 = Constant(2.0) 571 // %tuple = Tuple(%constant1, %constant2) 572 // %while0 = While(%tuple, body, condition) 573 // %while1 = While(%while0, body, condition) 574 // return While(%while1, body, condition) 575 // 576 const Shape tuple_shape = 577 ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); 578 579 // Element 0 passes transparently through the body. 580 auto body_builder = HloComputation::Builder("body"); 581 auto body_param = body_builder.AddInstruction( 582 HloInstruction::CreateParameter(0, tuple_shape, "param")); 583 auto body_element_0 = body_builder.AddInstruction( 584 HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); 585 auto body_element_1 = body_builder.AddInstruction( 586 HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); 587 auto add = body_builder.AddInstruction(HloInstruction::CreateBinary( 588 scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1)); 589 body_builder.AddInstruction( 590 HloInstruction::CreateTuple({body_element_0, add})); 591 HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); 592 593 auto cond_builder = HloComputation::Builder("condition"); 594 cond_builder.AddInstruction( 595 HloInstruction::CreateParameter(0, tuple_shape, "param")); 596 cond_builder.AddInstruction( 597 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 598 HloComputation* condition = 599 module_->AddEmbeddedComputation(cond_builder.Build()); 600 601 auto builder = HloComputation::Builder(TestName()); 602 auto constant1 = builder.AddInstruction( 603 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 604 auto constant2 = builder.AddInstruction( 605 
HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 606 auto tuple = builder.AddInstruction( 607 HloInstruction::CreateTuple({constant1, constant2})); 608 auto xla_while0 = builder.AddInstruction( 609 HloInstruction::CreateWhile(tuple_shape, condition, body, tuple)); 610 auto xla_while1 = builder.AddInstruction( 611 HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while0)); 612 auto xla_while2 = builder.AddInstruction( 613 HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while1)); 614 module_->AddEntryComputation(builder.Build()); 615 616 bool ssa_form = GetParam(); 617 const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); 618 619 // Element 0 is passed through all the while instructions and out of the 620 // module.. 621 EXPECT_EQ(analysis.GetUniqueValueAt(xla_while0, /*index=*/{0}), 622 analysis.GetValueDefinedAt(constant1)); 623 EXPECT_EQ(analysis.GetUniqueValueAt(xla_while1, /*index=*/{0}), 624 analysis.GetValueDefinedAt(constant1)); 625 EXPECT_EQ(analysis.GetUniqueValueAt(xla_while2, /*index=*/{0}), 626 analysis.GetValueDefinedAt(constant1)); 627 EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); 628 } 629 630 TEST_P(HloDataflowAnalysisTest, NestedWhiles) { 631 // Test nested while instructions. The inner body passes through element 0 of 632 // its parameter, and the outer body passes through element 1. 
HLO: 633 // 634 // inner_body((F32[], F32[]) %tuple_param): 635 // %add = Add(%tuple_param{0}, %tuple_param{1}) 636 // return Tuple(%tuple_param{0}, %add) 637 // 638 // outer_body((F32[], F32[]) %tuple_param): 639 // %negate = Negate(%tuple_param{0}) 640 // %tuple = Tuple(%negate, %tuple_param{1}) 641 // return While(%tuple, inner_body, condition) 642 // 643 // entry: 644 // %constant1 = Constant(1.0) 645 // %constant2 = Constant(2.0) 646 // %tuple = Tuple(%constant1, %constant2) 647 // return While(%tuple, outer_body, condition) 648 // 649 const Shape tuple_shape = 650 ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); 651 652 auto cond_builder = HloComputation::Builder("condition"); 653 cond_builder.AddInstruction( 654 HloInstruction::CreateParameter(0, tuple_shape, "param")); 655 cond_builder.AddInstruction( 656 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 657 HloComputation* condition = 658 module_->AddEmbeddedComputation(cond_builder.Build()); 659 660 // Element 0 passes transparently through the body. 661 auto inner_builder = HloComputation::Builder("inner_body"); 662 auto inner_param = inner_builder.AddInstruction( 663 HloInstruction::CreateParameter(0, tuple_shape, "param")); 664 auto inner_element_0 = inner_builder.AddInstruction( 665 HloInstruction::CreateGetTupleElement(scalar_shape_, inner_param, 0)); 666 auto inner_element_1 = inner_builder.AddInstruction( 667 HloInstruction::CreateGetTupleElement(scalar_shape_, inner_param, 1)); 668 auto add = inner_builder.AddInstruction(HloInstruction::CreateBinary( 669 scalar_shape_, HloOpcode::kAdd, inner_element_0, inner_element_1)); 670 inner_builder.AddInstruction( 671 HloInstruction::CreateTuple({inner_element_0, add})); 672 HloComputation* inner_body = 673 module_->AddEmbeddedComputation(inner_builder.Build()); 674 675 // Element 1 passes transparently through the body. 
676 auto outer_builder = HloComputation::Builder("outer_body"); 677 auto outer_param = outer_builder.AddInstruction( 678 HloInstruction::CreateParameter(0, tuple_shape, "param")); 679 auto outer_element_0 = outer_builder.AddInstruction( 680 HloInstruction::CreateGetTupleElement(scalar_shape_, outer_param, 0)); 681 auto negate = outer_builder.AddInstruction(HloInstruction::CreateUnary( 682 scalar_shape_, HloOpcode::kNegate, outer_element_0)); 683 auto outer_element_1 = outer_builder.AddInstruction( 684 HloInstruction::CreateGetTupleElement(scalar_shape_, outer_param, 1)); 685 auto outer_tuple = outer_builder.AddInstruction( 686 HloInstruction::CreateTuple({negate, outer_element_1})); 687 auto nested_while = outer_builder.AddInstruction(HloInstruction::CreateWhile( 688 tuple_shape, condition, inner_body, outer_tuple)); 689 HloComputation* outer_body = 690 module_->AddEmbeddedComputation(outer_builder.Build()); 691 692 auto builder = HloComputation::Builder(TestName()); 693 auto constant1 = builder.AddInstruction( 694 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 695 auto constant2 = builder.AddInstruction( 696 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 697 auto tuple = builder.AddInstruction( 698 HloInstruction::CreateTuple({constant1, constant2})); 699 auto entry_while = builder.AddInstruction( 700 HloInstruction::CreateWhile(tuple_shape, condition, outer_body, tuple)); 701 module_->AddEntryComputation(builder.Build()); 702 703 bool ssa_form = GetParam(); 704 const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); 705 706 EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}), 707 UnorderedElementsAre(analysis.GetValueDefinedAt(negate))); 708 if (ssa_form) { 709 EXPECT_TRUE(analysis.ValueIsDefinedAt(inner_param, /*index=*/{1})); 710 EXPECT_TRUE( 711 analysis.GetValueDefinedAt(inner_param, /*index=*/{1}).is_phi()); 712 713 // Element 0 of the nested while is %negate. 
714 EXPECT_FALSE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{0})); 715 EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}), 716 UnorderedElementsAre(analysis.GetValueDefinedAt(negate))); 717 // Element 1 is a phi value (join of %add and %constant2). 718 EXPECT_TRUE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{1})); 719 EXPECT_TRUE( 720 analysis.GetValueDefinedAt(nested_while, /*index=*/{1}).is_phi()); 721 722 EXPECT_TRUE(analysis.ValueIsDefinedAt(entry_while, /*index=*/{0})); 723 EXPECT_TRUE( 724 analysis.GetValueDefinedAt(entry_while, /*index=*/{0}).is_phi()); 725 726 EXPECT_TRUE(analysis.ValueIsDefinedAt(entry_while, /*index=*/{1})); 727 EXPECT_TRUE( 728 analysis.GetValueDefinedAt(entry_while, /*index=*/{1}).is_phi()); 729 } else { 730 EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{1}), 731 UnorderedElementsAre(analysis.GetValueDefinedAt(add), 732 analysis.GetValueDefinedAt(constant2))); 733 734 EXPECT_THAT(HloValuesAt(nested_while, /*index=*/{0}), 735 UnorderedElementsAre(analysis.GetValueDefinedAt(negate))); 736 EXPECT_THAT(HloValuesAt(nested_while, /*index=*/{1}), 737 UnorderedElementsAre(analysis.GetValueDefinedAt(add), 738 analysis.GetValueDefinedAt(constant2))); 739 740 EXPECT_THAT(HloValuesAt(entry_while, /*index=*/{0}), 741 UnorderedElementsAre(analysis.GetValueDefinedAt(negate), 742 analysis.GetValueDefinedAt(constant1))); 743 EXPECT_THAT(HloValuesAt(entry_while, /*index=*/{1}), 744 UnorderedElementsAre(analysis.GetValueDefinedAt(add), 745 analysis.GetValueDefinedAt(constant2))); 746 } 747 } 748 749 TEST_P(HloDataflowAnalysisTest, SwizzlingWhile) { 750 // Test a while instruction with a body which permutes it's tuple parameter 751 // elements. 
HLO: 752 // 753 // body((F32[], F32[]) %tuple_param): 754 // return Tuple(%tuple_param{1}, %tuple_param{0}) 755 // 756 // condition((F32[], F32[]) %tuple_param): 757 // return Constant(false) 758 // 759 // entry: 760 // %constant1 = Constant(1.0) 761 // %constant2 = Constant(2.0) 762 // %tuple = Tuple(%constant1, %constant2) 763 // return While(%tuple, body, condition) 764 // 765 const Shape tuple_shape = 766 ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); 767 768 auto body_builder = HloComputation::Builder("body"); 769 auto body_param = body_builder.AddInstruction( 770 HloInstruction::CreateParameter(0, tuple_shape, "param")); 771 auto body_element_0 = body_builder.AddInstruction( 772 HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); 773 auto body_element_1 = body_builder.AddInstruction( 774 HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); 775 body_builder.AddInstruction( 776 HloInstruction::CreateTuple({body_element_1, body_element_0})); 777 HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); 778 779 auto cond_builder = HloComputation::Builder("condition"); 780 auto cond_param = cond_builder.AddInstruction( 781 HloInstruction::CreateParameter(0, tuple_shape, "param")); 782 cond_builder.AddInstruction( 783 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 784 HloComputation* condition = 785 module_->AddEmbeddedComputation(cond_builder.Build()); 786 787 auto builder = HloComputation::Builder(TestName()); 788 auto constant1 = builder.AddInstruction( 789 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 790 auto constant2 = builder.AddInstruction( 791 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 792 auto tuple = builder.AddInstruction( 793 HloInstruction::CreateTuple({constant1, constant2})); 794 auto xla_while = builder.AddInstruction( 795 HloInstruction::CreateWhile(tuple_shape, condition, body, tuple)); 796 
module_->AddEntryComputation(builder.Build()); 797 798 bool ssa_form = GetParam(); 799 const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); 800 801 if (ssa_form) { 802 // Element 0 and 1 in the while should both be phi values. 803 EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0})); 804 EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{0}).is_phi()); 805 EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{1})); 806 EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{1}).is_phi()); 807 808 EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0})); 809 EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0}).is_phi()); 810 EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1})); 811 EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}).is_phi()); 812 813 EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0})); 814 EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{0}).is_phi()); 815 EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1})); 816 EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{1}).is_phi()); 817 818 EXPECT_FALSE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); 819 EXPECT_FALSE(analysis.GetValueDefinedAt(constant2).live_out_of_module()); 820 EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{}) 821 .live_out_of_module()); 822 EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0}) 823 .live_out_of_module()); 824 EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}) 825 .live_out_of_module()); 826 } else { 827 // Elements 0 and 1 have both constants as reaching definitions. 
828 EXPECT_THAT(HloValuesAt(xla_while, /*index=*/{0}), 829 UnorderedElementsAre(analysis.GetValueDefinedAt(constant1), 830 analysis.GetValueDefinedAt(constant2))); 831 EXPECT_THAT(HloValuesAt(xla_while, /*index=*/{1}), 832 UnorderedElementsAre(analysis.GetValueDefinedAt(constant1), 833 analysis.GetValueDefinedAt(constant2))); 834 EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); 835 EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).live_out_of_module()); 836 } 837 } 838 839 TEST_P(HloDataflowAnalysisTest, ArraySelect) { 840 // Test a kSelect of an array value. 841 auto builder = HloComputation::Builder(TestName()); 842 auto pred = builder.AddInstruction( 843 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 844 auto constant1 = builder.AddInstruction( 845 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 846 auto constant2 = builder.AddInstruction( 847 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 848 auto select = builder.AddInstruction(HloInstruction::CreateTernary( 849 scalar_shape_, HloOpcode::kSelect, pred, constant1, constant2)); 850 851 module_->AddEntryComputation(builder.Build()); 852 853 bool ssa_form = GetParam(); 854 const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); 855 856 EXPECT_TRUE(analysis.ValueIsDefinedAt(select)); 857 EXPECT_FALSE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); 858 EXPECT_FALSE(analysis.GetValueDefinedAt(constant2).live_out_of_module()); 859 EXPECT_TRUE(analysis.GetValueDefinedAt(select).live_out_of_module()); 860 } 861 862 TEST_P(HloDataflowAnalysisTest, TupleSelect) { 863 // Test a kSelect of a tuple value. Non-top-level element flow through the 864 // instruction. 
865 auto builder = HloComputation::Builder(TestName()); 866 auto pred = builder.AddInstruction( 867 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 868 auto constant1 = builder.AddInstruction( 869 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 870 auto constant2 = builder.AddInstruction( 871 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 872 auto constant3 = builder.AddInstruction( 873 HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0))); 874 auto constant4 = builder.AddInstruction( 875 HloInstruction::CreateConstant(Literal::CreateR0<float>(4.0))); 876 auto tuple1 = 877 builder.AddInstruction(HloInstruction::CreateTuple({constant1})); 878 auto tuple2 = 879 builder.AddInstruction(HloInstruction::CreateTuple({constant2})); 880 auto tuple3 = 881 builder.AddInstruction(HloInstruction::CreateTuple({constant3})); 882 auto tuple4 = 883 builder.AddInstruction(HloInstruction::CreateTuple({constant4})); 884 const Shape tuple_shape = tuple1->shape(); 885 auto select11 = builder.AddInstruction(HloInstruction::CreateTernary( 886 tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple1)); 887 auto select12 = builder.AddInstruction(HloInstruction::CreateTernary( 888 tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple2)); 889 auto select34 = builder.AddInstruction(HloInstruction::CreateTernary( 890 tuple_shape, HloOpcode::kSelect, pred, tuple3, tuple4)); 891 auto select1234 = builder.AddInstruction(HloInstruction::CreateTernary( 892 tuple_shape, HloOpcode::kSelect, pred, select12, select34)); 893 894 module_->AddEntryComputation(builder.Build()); 895 896 bool ssa_form = GetParam(); 897 const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); 898 899 // Top-level value is always defined by a kSelect. 
900 EXPECT_TRUE(analysis.ValueIsDefinedAt(select11)); 901 EXPECT_TRUE(analysis.ValueIsDefinedAt(select12)); 902 EXPECT_TRUE(analysis.ValueIsDefinedAt(select34)); 903 EXPECT_TRUE(analysis.ValueIsDefinedAt(select1234)); 904 905 EXPECT_FALSE(analysis.ValueIsDefinedAt(select11, /*index=*/{0})); 906 EXPECT_FALSE(analysis.ValueIsDefinedAt(select12, /*index=*/{0})); 907 EXPECT_FALSE(analysis.ValueIsDefinedAt(select34, /*index=*/{0})); 908 EXPECT_FALSE(analysis.ValueIsDefinedAt(select1234, /*index=*/{0})); 909 910 EXPECT_THAT(HloValuesAt(select11, /*index=*/{0}), 911 UnorderedElementsAre(analysis.GetValueDefinedAt(constant1))); 912 EXPECT_THAT(HloValuesAt(select12, /*index=*/{0}), 913 UnorderedElementsAre(analysis.GetValueDefinedAt(constant1), 914 analysis.GetValueDefinedAt(constant2))); 915 EXPECT_THAT(HloValuesAt(select34, /*index=*/{0}), 916 UnorderedElementsAre(analysis.GetValueDefinedAt(constant3), 917 analysis.GetValueDefinedAt(constant4))); 918 EXPECT_THAT(HloValuesAt(select1234, /*index=*/{0}), 919 UnorderedElementsAre(analysis.GetValueDefinedAt(constant1), 920 analysis.GetValueDefinedAt(constant2), 921 analysis.GetValueDefinedAt(constant3), 922 analysis.GetValueDefinedAt(constant4))); 923 924 EXPECT_THAT( 925 analysis.GetValueDefinedAt(tuple1, /*index=*/{}).uses(), 926 UnorderedElementsAre(HloUse{select11, 1, {}}, HloUse{select11, 2, {}}, 927 HloUse{select12, 1, {}})); 928 929 // The two constant values just pass through the Selects and are not 930 // used except at the root. They are live out however. 
931 EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), 932 UnorderedElementsAre(HloUse{select1234, 1, {0}})); 933 EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), 934 UnorderedElementsAre(HloUse{select1234, 1, {0}})); 935 EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); 936 EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).live_out_of_module()); 937 } 938 939 TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) { 940 // Test kSelect of a nested tuple. 941 auto builder = HloComputation::Builder(TestName()); 942 auto pred = builder.AddInstruction( 943 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 944 auto constant1 = builder.AddInstruction( 945 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 946 auto constant2 = builder.AddInstruction( 947 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 948 auto constant3 = builder.AddInstruction( 949 HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0))); 950 auto constant4 = builder.AddInstruction( 951 HloInstruction::CreateConstant(Literal::CreateR0<float>(4.0))); 952 auto constant5 = builder.AddInstruction( 953 HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0))); 954 auto inner_tuple1 = builder.AddInstruction( 955 HloInstruction::CreateTuple({constant2, constant3})); 956 auto tuple1 = builder.AddInstruction( 957 HloInstruction::CreateTuple({constant1, inner_tuple1})); 958 auto inner_tuple2 = builder.AddInstruction( 959 HloInstruction::CreateTuple({constant5, constant3})); 960 auto tuple2 = builder.AddInstruction( 961 HloInstruction::CreateTuple({constant4, inner_tuple2})); 962 auto select = builder.AddInstruction(HloInstruction::CreateTernary( 963 tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); 964 965 module_->AddEntryComputation(builder.Build()); 966 967 bool ssa_form = GetParam(); 968 const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); 969 970 EXPECT_TRUE(analysis.ValueIsDefinedAt(select)); 971 972 
EXPECT_THAT(HloValuesAt(select, /*index=*/{0}), 973 UnorderedElementsAre(analysis.GetValueDefinedAt(constant1), 974 analysis.GetValueDefinedAt(constant4))); 975 EXPECT_THAT(HloValuesAt(select, /*index=*/{1}), 976 UnorderedElementsAre(analysis.GetValueDefinedAt(inner_tuple1), 977 analysis.GetValueDefinedAt(inner_tuple2))); 978 EXPECT_THAT(HloValuesAt(select, /*index=*/{1, 0}), 979 UnorderedElementsAre(analysis.GetValueDefinedAt(constant2), 980 analysis.GetValueDefinedAt(constant5))); 981 EXPECT_THAT(HloValuesAt(select, /*index=*/{1, 1}), 982 UnorderedElementsAre(analysis.GetValueDefinedAt(constant3))); 983 } 984 985 TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) { 986 // Test a tuple-shaped kSelect feeding a kWhile instruction. HLO: 987 // 988 // body((F32[], F32[]) %tuple_param): 989 // %add = Add(%tuple_param{0}, %tuple_param{1}) 990 // return Tuple(%tuple_param{0}, %add) 991 // 992 // condition((F32[], F32[]) %tuple_param): 993 // return Constant(false) 994 // 995 // entry: 996 // %constant1 = Constant(1.0) 997 // %constant2 = Constant(2.0) 998 // %constant3 = Constant(3.0) 999 // %tuple1 = Tuple(%constant1) 1000 // %tuple2 = Tuple(%constant2) 1001 // %select = Select(%tuple1, %tuple2) 1002 // %gte = GetTupleElement(%select, 0) 1003 // %tuple = Tuple(%gte, %constant3) 1004 // return While(%tuple, body, condition) 1005 // 1006 auto builder = HloComputation::Builder(TestName()); 1007 1008 const Shape tuple_shape = 1009 ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); 1010 1011 // Element 0 passes transparently through the body. 
1012 auto body_builder = HloComputation::Builder("body"); 1013 auto body_param = body_builder.AddInstruction( 1014 HloInstruction::CreateParameter(0, tuple_shape, "param")); 1015 auto body_element_0 = body_builder.AddInstruction( 1016 HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); 1017 auto body_element_1 = body_builder.AddInstruction( 1018 HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); 1019 auto add = body_builder.AddInstruction(HloInstruction::CreateBinary( 1020 scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1)); 1021 body_builder.AddInstruction( 1022 HloInstruction::CreateTuple({body_element_0, add})); 1023 HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); 1024 1025 auto cond_builder = HloComputation::Builder("condition"); 1026 cond_builder.AddInstruction( 1027 HloInstruction::CreateParameter(0, tuple_shape, "param")); 1028 cond_builder.AddInstruction( 1029 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 1030 HloComputation* condition = 1031 module_->AddEmbeddedComputation(cond_builder.Build()); 1032 1033 auto pred = builder.AddInstruction( 1034 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 1035 auto constant1 = builder.AddInstruction( 1036 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 1037 auto constant2 = builder.AddInstruction( 1038 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 1039 auto constant3 = builder.AddInstruction( 1040 HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0))); 1041 auto tuple1 = 1042 builder.AddInstruction(HloInstruction::CreateTuple({constant1})); 1043 auto tuple2 = 1044 builder.AddInstruction(HloInstruction::CreateTuple({constant2})); 1045 auto select = builder.AddInstruction(HloInstruction::CreateTernary( 1046 tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); 1047 auto gte = builder.AddInstruction( 1048 HloInstruction::CreateGetTupleElement(scalar_shape_, 
select, 0)); 1049 auto tuple = 1050 builder.AddInstruction(HloInstruction::CreateTuple({gte, constant3})); 1051 auto xla_while = builder.AddInstruction( 1052 HloInstruction::CreateWhile(tuple->shape(), condition, body, tuple)); 1053 1054 module_->AddEntryComputation(builder.Build()); 1055 1056 bool ssa_form = GetParam(); 1057 const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); 1058 1059 if (ssa_form) { 1060 EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0})); 1061 EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0}).is_phi()); 1062 EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1})); 1063 EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}).is_phi()); 1064 1065 EXPECT_FALSE(analysis.ValueIsDefinedAt(select, /*index=*/{0})); 1066 1067 EXPECT_FALSE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); 1068 EXPECT_FALSE(analysis.GetValueDefinedAt(constant2).live_out_of_module()); 1069 EXPECT_FALSE(analysis.GetValueDefinedAt(constant3).live_out_of_module()); 1070 EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}) 1071 .live_out_of_module()); 1072 } else { 1073 EXPECT_THAT(HloValuesAt(gte), 1074 UnorderedElementsAre(analysis.GetValueDefinedAt(constant1), 1075 analysis.GetValueDefinedAt(constant2))); 1076 EXPECT_THAT(HloValuesAt(xla_while, /*index=*/{0}), 1077 UnorderedElementsAre(analysis.GetValueDefinedAt(constant1), 1078 analysis.GetValueDefinedAt(constant2))); 1079 EXPECT_THAT(HloValuesAt(xla_while, /*index=*/{1}), 1080 UnorderedElementsAre(analysis.GetValueDefinedAt(add), 1081 analysis.GetValueDefinedAt(constant3))); 1082 EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); 1083 EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).live_out_of_module()); 1084 EXPECT_TRUE(analysis.GetValueDefinedAt(constant3).live_out_of_module()); 1085 } 1086 } 1087 1088 TEST_P(HloDataflowAnalysisTest, BitcastDefinesValue) { 1089 // Test the bitcast_defines_value flag to the dataflow 
analysis. 1090 auto builder = HloComputation::Builder(TestName()); 1091 auto constant = builder.AddInstruction( 1092 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 1093 auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary( 1094 scalar_shape_, HloOpcode::kBitcast, constant)); 1095 1096 module_->AddEntryComputation(builder.Build()); 1097 1098 bool ssa_form = GetParam(); 1099 { 1100 const HloDataflowAnalysis& analysis = 1101 RunAnalysis(ssa_form, /*bitcast_defines_value=*/true); 1102 1103 EXPECT_EQ(analysis.values().size(), 2); 1104 1105 EXPECT_TRUE(analysis.ValueIsDefinedAt(constant)); 1106 EXPECT_TRUE(analysis.ValueIsDefinedAt(bitcast)); 1107 EXPECT_FALSE(analysis.GetValueDefinedAt(constant).live_out_of_module()); 1108 EXPECT_TRUE(analysis.GetValueDefinedAt(bitcast).live_out_of_module()); 1109 } 1110 { 1111 const HloDataflowAnalysis& analysis = 1112 RunAnalysis(ssa_form, /*bitcast_defines_value=*/false); 1113 EXPECT_EQ(analysis.values().size(), 1); 1114 1115 EXPECT_TRUE(analysis.ValueIsDefinedAt(constant)); 1116 EXPECT_FALSE(analysis.ValueIsDefinedAt(bitcast)); 1117 EXPECT_TRUE(analysis.GetValueDefinedAt(constant).live_out_of_module()); 1118 } 1119 } 1120 1121 TEST_P(HloDataflowAnalysisTest, TupleCopy) { 1122 // Test that a tuple-shaped copy only copies (defines) the top-level value. 
1123 auto builder = HloComputation::Builder(TestName()); 1124 auto param0 = builder.AddInstruction( 1125 HloInstruction::CreateParameter(0, scalar_shape_, "param0")); 1126 auto param1 = builder.AddInstruction( 1127 HloInstruction::CreateParameter(1, scalar_shape_, "param1")); 1128 auto tuple = 1129 builder.AddInstruction(HloInstruction::CreateTuple({param0, param1})); 1130 auto copy = builder.AddInstruction( 1131 HloInstruction::CreateUnary(tuple->shape(), HloOpcode::kCopy, tuple)); 1132 module_->AddEntryComputation(builder.Build()); 1133 1134 bool ssa_form = GetParam(); 1135 const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); 1136 1137 EXPECT_EQ(analysis.values().size(), 4); 1138 1139 EXPECT_TRUE(analysis.ValueIsDefinedAt(param0)); 1140 EXPECT_TRUE(analysis.ValueIsDefinedAt(param1)); 1141 EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple, /*index=*/{})); 1142 EXPECT_FALSE(analysis.ValueIsDefinedAt(tuple, /*index=*/{0})); 1143 EXPECT_FALSE(analysis.ValueIsDefinedAt(tuple, /*index=*/{1})); 1144 EXPECT_TRUE(analysis.ValueIsDefinedAt(copy, /*index=*/{})); 1145 EXPECT_FALSE(analysis.ValueIsDefinedAt(copy, /*index=*/{0})); 1146 EXPECT_FALSE(analysis.ValueIsDefinedAt(copy, /*index=*/{1})); 1147 1148 EXPECT_THAT(HloValuesAt(copy, /*index=*/{0}), 1149 UnorderedElementsAre(analysis.GetValueDefinedAt(param0))); 1150 EXPECT_THAT(HloValuesAt(copy, /*index=*/{1}), 1151 UnorderedElementsAre(analysis.GetValueDefinedAt(param1))); 1152 EXPECT_TRUE( 1153 analysis.GetValueDefinedAt(copy, /*index=*/{}).live_out_of_module()); 1154 } 1155 1156 TEST_P(HloDataflowAnalysisTest, SendAndSendDone) { 1157 // Test that a Send forwards its operand to the output tuple at {0}. 
1158 auto builder = HloComputation::Builder(TestName()); 1159 auto param = builder.AddInstruction( 1160 HloInstruction::CreateParameter(0, scalar_shape_, "param0")); 1161 auto send = builder.AddInstruction( 1162 HloInstruction::CreateSend(param, /*channel_id=*/0)); 1163 auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); 1164 module_->AddEntryComputation(builder.Build()); 1165 1166 bool ssa_form = GetParam(); 1167 const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); 1168 1169 EXPECT_EQ(analysis.values().size(), 4); 1170 1171 EXPECT_TRUE(analysis.ValueIsDefinedAt(param)); 1172 EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{})); 1173 EXPECT_FALSE(analysis.ValueIsDefinedAt(send, /*index=*/{0})); 1174 EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{1})); 1175 EXPECT_TRUE(analysis.ValueIsDefinedAt(send_done)); 1176 EXPECT_THAT(HloValuesAt(send, /*index=*/{0}), 1177 UnorderedElementsAre(analysis.GetValueDefinedAt(param))); 1178 } 1179 1180 TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) { 1181 // Test that a RecvDone forwards its operand tuple element at {0} to the 1182 // output. 
1183 auto builder = HloComputation::Builder(TestName()); 1184 auto recv = builder.AddInstruction( 1185 HloInstruction::CreateRecv(scalar_shape_, /*channel_id=*/0)); 1186 auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); 1187 module_->AddEntryComputation(builder.Build()); 1188 1189 bool ssa_form = GetParam(); 1190 const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); 1191 1192 EXPECT_EQ(analysis.values().size(), 3); 1193 1194 EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{})); 1195 EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{0})); 1196 EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{1})); 1197 EXPECT_FALSE(analysis.ValueIsDefinedAt(recv_done)); 1198 EXPECT_THAT(HloValuesAt(recv_done), 1199 UnorderedElementsAre(analysis.GetValueDefinedAt(recv, {0}))); 1200 EXPECT_TRUE( 1201 analysis.GetValueDefinedAt(recv, /*index=*/{0}).live_out_of_module()); 1202 } 1203 1204 TEST_P(HloDataflowAnalysisTest, ElementwiseChainInterference) { 1205 // A simple chain of elementwise operations. No values should interfere. 1206 // 1207 // param --> negate -> exp -> log 1208 // 1209 auto builder = HloComputation::Builder(TestName()); 1210 auto param = builder.AddInstruction( 1211 HloInstruction::CreateParameter(0, vector_shape_, "param")); 1212 auto negate = builder.AddInstruction( 1213 HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param)); 1214 auto exp = builder.AddInstruction( 1215 HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, negate)); 1216 auto log = builder.AddInstruction( 1217 HloInstruction::CreateUnary(vector_shape_, HloOpcode::kLog, exp)); 1218 1219 module_->AddEntryComputation(builder.Build()); 1220 RunAnalysis(GetParam()); 1221 1222 DependencyHloOrdering ordering(module_.get()); 1223 1224 // No values should interfere. 
1225 EXPECT_FALSE(InstructionsMayInterfere(ordering, param, negate)); 1226 EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp)); 1227 EXPECT_FALSE(InstructionsMayInterfere(ordering, param, log)); 1228 EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, exp)); 1229 EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, log)); 1230 EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, negate)); 1231 EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, log)); 1232 EXPECT_FALSE(InstructionsMayInterfere(ordering, log, negate)); 1233 EXPECT_FALSE(InstructionsMayInterfere(ordering, log, exp)); 1234 1235 // Values should interfere with itself. 1236 EXPECT_TRUE(InstructionsMayInterfere(ordering, exp, exp)); 1237 } 1238 1239 TEST_P(HloDataflowAnalysisTest, MultipleEntryParameters_Sequential) { 1240 // Two entry params, which interfere with each other. 1241 // 1242 // param0 --> negate ---------------\ 1243 // param1 --> exp --> add 1244 auto builder = HloComputation::Builder(TestName()); 1245 auto param0 = builder.AddInstruction( 1246 HloInstruction::CreateParameter(0, vector_shape_, "param0")); 1247 auto param1 = builder.AddInstruction( 1248 HloInstruction::CreateParameter(1, vector_shape_, "param1")); 1249 auto negate = builder.AddInstruction( 1250 HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param0)); 1251 auto exp = builder.AddInstruction( 1252 HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param1)); 1253 auto add = builder.AddInstruction(HloInstruction::CreateBinary( 1254 vector_shape_, HloOpcode::kAdd, negate, exp)); 1255 1256 auto entry = module_->AddEntryComputation(builder.Build()); 1257 RunAnalysis(GetParam()); 1258 1259 SequentialHloOrdering::HloModuleSequence sequence; 1260 sequence.insert({entry, {param0, negate, param1, exp, add}}); 1261 SequentialHloOrdering ordering(module_.get(), sequence); 1262 1263 // Entry parameters interfere as if they are defined simultaneously at 1264 // the very beginning. 
1265 EXPECT_TRUE(InstructionsMayInterfere(ordering, param0, param1)); 1266 EXPECT_FALSE(InstructionsMayInterfere(ordering, param0, negate)); 1267 EXPECT_FALSE(InstructionsMayInterfere(ordering, param0, exp)); 1268 EXPECT_FALSE(InstructionsMayInterfere(ordering, param0, add)); 1269 EXPECT_TRUE(InstructionsMayInterfere(ordering, param1, param0)); 1270 EXPECT_TRUE(InstructionsMayInterfere(ordering, param1, negate)); 1271 EXPECT_FALSE(InstructionsMayInterfere(ordering, param1, exp)); 1272 EXPECT_FALSE(InstructionsMayInterfere(ordering, param1, add)); 1273 1274 // Negate and exp still interfere. 1275 EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, exp)); 1276 EXPECT_TRUE(InstructionsMayInterfere(ordering, exp, negate)); 1277 1278 // But {negate, add} and {exp, add} don't interfere. 1279 EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, add)); 1280 EXPECT_FALSE(InstructionsMayInterfere(ordering, add, negate)); 1281 EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, add)); 1282 EXPECT_FALSE(InstructionsMayInterfere(ordering, add, exp)); 1283 } 1284 1285 TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) { 1286 // Similar to MultipleEntryParameters_Sequential, but the parameter is of 1287 // while body computation. Body computation in the sequential order: 1288 // 1289 // %constant = Constant(...) 1290 // %exp = Exp(%constant) 1291 // %param = Param(0) 1292 // %add = Add(%param, %exp) ;; Root of body 1293 // %dead_constant = Constant(...) 1294 // %dead_negate = Negate(%dead_constant) 1295 // 1296 // %constant and its only use %exp are ordered before 'param'. However, the 1297 // %constant and %param values still interfere because the parameter is 1298 // considered live into the while body. 1299 // 1300 // Similarly, %dead_constant and %dead_negate are ordered after the root of 1301 // the body computation %add. However, %add is liveout of the computation so 1302 // %dead_constant and %add interfere. 
1303 auto body_builder = HloComputation::Builder(TestName()); 1304 auto body_param = body_builder.AddInstruction( 1305 HloInstruction::CreateParameter(0, scalar_shape_, "body_param")); 1306 auto constant = body_builder.AddInstruction( 1307 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 1308 auto exp = body_builder.AddInstruction( 1309 HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kExp, constant)); 1310 auto add = body_builder.AddInstruction(HloInstruction::CreateBinary( 1311 scalar_shape_, HloOpcode::kAdd, exp, body_param)); 1312 auto dead_constant = body_builder.AddInstruction( 1313 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 1314 auto dead_negate = body_builder.AddInstruction(HloInstruction::CreateUnary( 1315 scalar_shape_, HloOpcode::kNegate, dead_constant)); 1316 HloComputation* body = module_->AddEmbeddedComputation( 1317 body_builder.Build(/*root_instruction=*/add)); 1318 1319 auto cond_builder = HloComputation::Builder("condition"); 1320 auto cond_param = cond_builder.AddInstruction( 1321 HloInstruction::CreateParameter(0, scalar_shape_, "cond_param")); 1322 auto cond_constant = cond_builder.AddInstruction( 1323 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 1324 HloComputation* condition = 1325 module_->AddEmbeddedComputation(cond_builder.Build()); 1326 1327 auto builder = HloComputation::Builder(TestName()); 1328 auto param = builder.AddInstruction( 1329 HloInstruction::CreateParameter(0, scalar_shape_, "param")); 1330 auto xla_while = builder.AddInstruction( 1331 HloInstruction::CreateWhile(scalar_shape_, condition, body, param)); 1332 1333 auto entry = module_->AddEntryComputation(builder.Build()); 1334 bool ssa_form = GetParam(); 1335 RunAnalysis(ssa_form); 1336 1337 SequentialHloOrdering::HloModuleSequence sequence; 1338 sequence.insert({entry, {param, xla_while}}); 1339 sequence.insert({condition, {cond_param, cond_constant}}); 1340 // Construct the order such that 'constant' and its use 
'exp' are before 1341 // body_param. 1342 sequence.insert({body, {constant, exp, body_param, add}}); 1343 1344 SequentialHloOrdering ordering(module_.get(), sequence); 1345 1346 // 'add' is live out of the body and will interfere with an later instructions 1347 // such as 'dead_constant' and 'dead_negate'. 1348 EXPECT_TRUE(InstructionsMayInterfere(ordering, add, dead_constant)); 1349 EXPECT_TRUE(InstructionsMayInterfere(ordering, add, dead_negate)); 1350 1351 // The remaining checks test phi values defined by body and condition 1352 // parameters which only occur in the SSA form of the analysis. 1353 if (ssa_form) { 1354 // Though the ordering suggests 'constant' and 'param' should not interfere, 1355 // 'param' is live in and thus interferes with any earlier instruction of 1356 // the computation in the order (eg 'constant')' 1357 EXPECT_TRUE(InstructionsMayInterfere(ordering, body_param, constant)); 1358 EXPECT_TRUE(InstructionsMayInterfere(ordering, body_param, exp)); 1359 EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, add)); 1360 1361 // The following values end up in the same buffer: 1362 // (1) the init value: 'param' 1363 // (2) the body parameter: 'body_param' 1364 // (3) the condition parameter: 'cond_param' 1365 // (4) the root value of the while body: 'add' 1366 // (5) the while value: 'xla_while' 1367 // None should interfere. 
EXPECT_FALSE(InstructionsMayInterfere(ordering, param, body_param));
    EXPECT_FALSE(InstructionsMayInterfere(ordering, param, cond_param));
    EXPECT_FALSE(InstructionsMayInterfere(ordering, param, add));
    EXPECT_FALSE(InstructionsMayInterfere(ordering, param, xla_while));

    EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, cond_param));
    EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, add));
    EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, xla_while));

    EXPECT_FALSE(InstructionsMayInterfere(ordering, cond_param, add));
    EXPECT_FALSE(InstructionsMayInterfere(ordering, cond_param, xla_while));

    EXPECT_FALSE(InstructionsMayInterfere(ordering, add, xla_while));
  }
}

TEST_P(HloDataflowAnalysisTest, NonElementwiseOperand) {
  // A chain of operations with two elementwise and one non-elementwise. The
  // elementwise op should not interfere with its operand, while the
  // non-elementwise op should interfere. The entry parameter is expected not
  // to interfere with any instruction in the chain (see the assertions below).
  //
  // param --> exp -> negate -> reverse
  //
  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  auto exp = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, exp));
  auto reverse = builder.AddInstruction(
      HloInstruction::CreateReverse(vector_shape_, negate, {0}));

  module_->AddEntryComputation(builder.Build());
  RunAnalysis(GetParam());

  DependencyHloOrdering ordering(module_.get());

  // The parameter must not interfere with any op in the chain.
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param, negate));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param, reverse));

  // Negate is elementwise, so doesn't interfere with its operand.
  // Reverse is non-elementwise, so does interfere with its operand.
  EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, negate));
  EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, reverse));
}

TEST_P(HloDataflowAnalysisTest, OverlappedValues) {
  // Verify simultaneously live values interfere (exp and negate).
  //
  // param --> negate -> add
  //     \---> exp -----/
  //
  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param));
  auto exp = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param));
  auto add = builder.AddInstruction(HloInstruction::CreateBinary(
      vector_shape_, HloOpcode::kAdd, negate, exp));

  module_->AddEntryComputation(builder.Build());
  RunAnalysis(GetParam());

  DependencyHloOrdering ordering(module_.get());

  // With the dependency ordering, param is expected to interfere with both of
  // its users (negate and exp) but not with add.
  EXPECT_TRUE(InstructionsMayInterfere(ordering, param, negate));
  EXPECT_TRUE(InstructionsMayInterfere(ordering, param, exp));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param, add));

  // Negate and exp interfere with each other, but not with add. Each pair is
  // checked in both argument orders.
  EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, exp));
  EXPECT_TRUE(InstructionsMayInterfere(ordering, exp, negate));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, add));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, add, negate));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, add));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, add, exp));
}

TEST_P(HloDataflowAnalysisTest, OverlappedValuesSequentialOrder) {
  // Identical to the test OverlappedValues but using a sequential ordering of
  // HLO instructions.
  //
  // param --> negate -> add
  //     \---> exp -----/
  //
  // Sequential order:
  //  param, negate, exp, add
  //
  // Interference among negate, exp and add is expected to match the
  // DependencyHloOrdering case; the expectation for (param, exp) differs.
  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param));
  auto exp = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param));
  auto add = builder.AddInstruction(HloInstruction::CreateBinary(
      vector_shape_, HloOpcode::kAdd, negate, exp));

  auto entry = module_->AddEntryComputation(builder.Build());
  RunAnalysis(GetParam());

  // Pin the instruction order of the entry computation explicitly.
  SequentialHloOrdering::HloModuleSequence sequence;
  std::vector<const HloInstruction*> order = {param, negate, exp, add};
  sequence.emplace(entry, order);

  SequentialHloOrdering ordering(module_.get(), sequence);

  // Under this sequence param is still expected to interfere with negate, but
  // (unlike the dependency-ordered test) not with exp.
  EXPECT_TRUE(InstructionsMayInterfere(ordering, param, negate));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param, add));

  // Negate and exp interfere with each other, but not with add.
  EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, exp));
  EXPECT_TRUE(InstructionsMayInterfere(ordering, exp, negate));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, add));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, add, negate));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, add));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, add, exp));
}

TEST_P(HloDataflowAnalysisTest, EmbeddedComputationInterference) {
  // Test MayInterfere() for embedded computation, specifically the interference
  // of values in different computations.
  //
  // embedded_computation:
  //   %embedded_param = Param(0)
  //   %embedded_log = Log(%embedded_param)
  //
  // entry computation:
  //   %param = Param(0)
  //   %negate = Negate(%param)
  //   %exp = Exp(%param)
  //   %call = Call(embedded_computation, {%exp})
  //   %add = Add(%negate, %call)
  //
  // Note %negate is live across the call and should interfere with all values
  // in the embedded computation.
  auto embedded_builder = HloComputation::Builder(TestName() + "_embedded");
  auto embedded_param = embedded_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "embedded_param"));
  auto embedded_log =
      embedded_builder.AddInstruction(HloInstruction::CreateUnary(
          vector_shape_, HloOpcode::kLog, embedded_param));
  auto embedded_computation =
      module_->AddEmbeddedComputation(embedded_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param));
  auto exp = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param));
  auto call = builder.AddInstruction(
      HloInstruction::CreateCall(vector_shape_, {exp}, embedded_computation));
  // The root add is not referenced afterwards, so its pointer is not kept.
  builder.AddInstruction(HloInstruction::CreateBinary(
      vector_shape_, HloOpcode::kAdd, negate, call));
  module_->AddEntryComputation(builder.Build());
  RunAnalysis(GetParam());

  DependencyHloOrdering ordering(module_.get());

  // Exp's only use is the call, so it should not interfere with values inside
  // the embedded computation.
  EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, embedded_log));

  // Negate is live across the call and should interfere with values in the
  // embedded computation.
  EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, embedded_log));
}

TEST_P(HloDataflowAnalysisTest, ConditionalWithIdentity) {
  // Test conditional with identity computations in both true and false cases.
  //
  // true_computation(F32[] %true_param):
  //   return %true_param
  //
  // false_computation(F32[] %false_param):
  //   return %false_param
  //
  // entry:
  //   %pred = Constant(true)
  //   %constant1 = Constant(56.0)
  //   %constant2 = Constant(12.0)
  //   return Conditional(%pred, %constant1, true_computation,
  //                      %constant2, false_computation)

  auto true_builder = HloComputation::Builder(TestName() + "_true");
  auto true_param = true_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "true_param"));
  HloComputation* true_computation =
      module_->AddEmbeddedComputation(true_builder.Build());

  auto false_builder = HloComputation::Builder(TestName() + "_false");
  auto false_param = false_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "false_param"));
  HloComputation* false_computation =
      module_->AddEmbeddedComputation(false_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  auto pred = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(56.0f)));
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(12.0f)));
  auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
      scalar_shape_, pred, constant1, true_computation, constant2,
      false_computation));
  module_->AddEntryComputation(builder.Build());

  const HloDataflowAnalysis& analysis = RunAnalysis(GetParam());

  // Only the three constants are expected to define values.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(pred));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));

  // The branch parameters define nothing themselves; each one carries the
  // value of the conditional operand passed to its computation.
  EXPECT_FALSE(analysis.ValueIsDefinedAt(true_param));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(false_param));

  EXPECT_EQ(analysis.GetUniqueValueAt(true_param),
            analysis.GetValueDefinedAt(constant1));
  EXPECT_EQ(analysis.GetUniqueValueAt(false_param),
            analysis.GetValueDefinedAt(constant2));

  // Each constant's single use is the conditional, at operand 0 (pred),
  // 1 (true-branch operand) and 2 (false-branch operand) respectively.
  EXPECT_THAT(analysis.GetValueDefinedAt(pred).uses(),
              ElementsAre(HloUse{conditional, 0, {}}));
  EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
              ElementsAre(HloUse{conditional, 1, {}}));
  EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(),
              ElementsAre(HloUse{conditional, 2, {}}));

  // The conditional defines no value of its own; its value set holds both
  // constants, one per branch.
  EXPECT_EQ(analysis.values().size(), 3);
  EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
  EXPECT_THAT(HloValuesAt(conditional),
              UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
                                   analysis.GetValueDefinedAt(constant2)));
}

TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) {
  // Test conditional with true and false computations taking a tuple operand.
  //
  // true_computation((F32[], F32[]) %true_param):
  //   %true_x = GetTupleElement(%true_param, 0)
  //   %true_y = GetTupleElement(%true_param, 1)
  //   return Add(%true_x, %true_y)
  //
  // false_computation((F32[], F32[]) %false_param):
  //   %false_x = GetTupleElement(%false_param, 0)
  //   %false_y = GetTupleElement(%false_param, 1)
  //   return Subtract(%false_x, %false_y)
  //
  // entry:
  //   %pred = Constant(true)
  //   %constant1 = Constant(56.0)
  //   %constant2 = Constant(12.0)
  //   %tuple_operand = Tuple(%constant1, %constant2)
  //   return Conditional(%pred, %tuple_operand, true_computation,
  //                      %tuple_operand, false_computation)

  auto true_builder = HloComputation::Builder(TestName() + "_true");
  auto true_param = true_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape_, "true_param"));
  auto true_x = true_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, true_param, 0));
  auto true_y = true_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, true_param, 1));
  auto add = true_builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_shape_, HloOpcode::kAdd, true_x, true_y));
  HloComputation* true_computation =
      module_->AddEmbeddedComputation(true_builder.Build());

  auto false_builder = HloComputation::Builder(TestName() + "_false");
  auto false_param = false_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape_, "false_param"));
  auto false_x = false_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, false_param, 0));
  auto false_y = false_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, false_param, 1));
  auto sub = false_builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_shape_, HloOpcode::kSubtract, false_x, false_y));
  HloComputation* false_computation =
      module_->AddEmbeddedComputation(false_builder.Build());

  // The same tuple is passed as the operand of both branches.
  auto builder = HloComputation::Builder(TestName());
  auto pred = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(56.0f)));
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(12.0f)));
  auto tuple_operand = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));
  auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
      scalar_shape_, pred, tuple_operand, true_computation, tuple_operand,
      false_computation));
  module_->AddEntryComputation(builder.Build());

  const HloDataflowAnalysis& analysis = RunAnalysis(GetParam());

  // Six instructions are expected to define values: the three constants, the
  // tuple, and the root of each branch computation.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(pred));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple_operand));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(add));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(sub));

  // Branch parameters and the get-tuple-elements reading them define nothing.
  EXPECT_FALSE(analysis.ValueIsDefinedAt(true_param));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(false_param));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(true_x));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(true_y));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(false_x));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(false_y));

  // Instead they carry the tuple value and its element values from the entry
  // computation.
  EXPECT_EQ(analysis.GetUniqueValueAt(true_param),
            analysis.GetValueDefinedAt(tuple_operand));
  EXPECT_EQ(analysis.GetUniqueValueAt(false_param),
            analysis.GetValueDefinedAt(tuple_operand));
  EXPECT_EQ(analysis.GetUniqueValueAt(true_x),
            analysis.GetValueDefinedAt(constant1));
  EXPECT_EQ(analysis.GetUniqueValueAt(true_y),
            analysis.GetValueDefinedAt(constant2));
  EXPECT_EQ(analysis.GetUniqueValueAt(false_x),
            analysis.GetValueDefinedAt(constant1));
  EXPECT_EQ(analysis.GetUniqueValueAt(false_y),
            analysis.GetValueDefinedAt(constant2));

  // The element values are used via both conditional operands (operands 1 and
  // 2; the trailing {0}/{1} appears to be the index within the tuple operand)
  // and directly by the add/sub of the branch computations.
  EXPECT_THAT(analysis.GetValueDefinedAt(pred).uses(),
              ElementsAre(HloUse{conditional, 0, {}}));
  EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
              UnorderedElementsAre(HloUse{conditional, 1, {0}},
                                   HloUse{conditional, 2, {0}},
                                   HloUse{add, 0, {}}, HloUse{sub, 0, {}}));
  EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(),
              UnorderedElementsAre(HloUse{conditional, 1, {1}},
                                   HloUse{conditional, 2, {1}},
                                   HloUse{add, 1, {}}, HloUse{sub, 1, {}}));
  EXPECT_THAT(analysis.GetValueDefinedAt(tuple_operand).uses(),
              UnorderedElementsAre(
                  HloUse{conditional, 1, {}}, HloUse{conditional, 2, {}},
                  HloUse{true_x, 0, {}}, HloUse{true_y, 0, {}},
                  HloUse{false_x, 0, {}}, HloUse{false_y, 0, {}}));

  // The conditional defines no value; it may hold either branch root.
  EXPECT_EQ(analysis.values().size(), 6);
  EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
  EXPECT_THAT(HloValuesAt(conditional),
              UnorderedElementsAre(analysis.GetValueDefinedAt(add),
                                   analysis.GetValueDefinedAt(sub)));
}

TEST_P(HloDataflowAnalysisTest, NestedConditionals) {
  // Test a conditional whose true computation itself contains a conditional.
  //
  // computation1(F32[] %param1):
  //   %ceil = Ceil(%param1)
  //   return %ceil
  //
  // computation2(F32[] %param2):
  //   %floor = Floor(%param2)
  //   return %floor
  //
  // computation3(F32[] %param3):
  //   %negate = Negate(%param3)
  //   return %negate
  //
  // inner_conditional((PRED, F32[], F32[]) %param_cond):
  //   %pred_cond = GetTupleElement(%param_cond, 0)
  //   %true_operand_cond = GetTupleElement(%param_cond, 1)
  //   %false_operand_cond = GetTupleElement(%param_cond, 2)
  //   return Conditional(%pred_cond, %true_operand_cond, computation1,
  //                      %false_operand_cond, computation2)
  //
  // entry:
  //   %pred1 = Constant(true)
  //   %pred2 = Constant(false)
  //   %constant1 = Constant(1.1);
  //   %constant2 = Constant(2.2);
  //   %constant3 = Constant(3.3);
  //   return Conditional(%pred1, (%pred2, %constant1, %constant2),
  //                      inner_conditional, %constant3, computation3)

  auto computation1 = module_->AddEmbeddedComputation(
      CreateR0F32UnaryOpComputation(HloOpcode::kCeil));
  auto computation2 = module_->AddEmbeddedComputation(
      CreateR0F32UnaryOpComputation(HloOpcode::kFloor));
  auto computation3 = module_->AddEmbeddedComputation(
      CreateR0F32UnaryOpComputation(HloOpcode::kNegate));

  // Build inner_conditional computation.
  const Shape scalar_bool_shape = ShapeUtil::MakeShape(PRED, {});
  const Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
      {scalar_bool_shape, scalar_shape_, scalar_shape_});
  auto inner_builder =
      HloComputation::Builder(TestName() + "_inner_conditional");
  auto param_cond = inner_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_param_shape, "param_cond"));
  auto pred_cond = inner_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_bool_shape, param_cond, 0));
  auto true_operand_cond = inner_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, param_cond, 1));
  auto false_operand_cond = inner_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, param_cond, 2));
  auto inner_conditional =
      inner_builder.AddInstruction(HloInstruction::CreateConditional(
          scalar_shape_, pred_cond, true_operand_cond, computation1,
          false_operand_cond, computation2));
  auto inner_conditional_computation =
      module_->AddEmbeddedComputation(inner_builder.Build());

  // Build entry computation.
  auto builder = HloComputation::Builder(TestName());
  auto pred1 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
  auto pred2 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(2.2f)));
  auto constant3 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(3.3f)));
  auto tuple_operand = builder.AddInstruction(
      HloInstruction::CreateTuple({pred2, constant1, constant2}));
  auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
      scalar_shape_, pred1, tuple_operand, inner_conditional_computation,
      constant3, computation3));
  module_->AddEntryComputation(builder.Build());

  const HloDataflowAnalysis& analysis = RunAnalysis(GetParam());

  // Nine instructions are expected to define values: the five constants, the
  // tuple, and the root of each of the three unary computations.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(pred1));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(pred2));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(constant3));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple_operand));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(computation1->root_instruction()));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(computation2->root_instruction()));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(computation3->root_instruction()));

  // Parameters of the unary computations define nothing; each carries the
  // constant that reaches it through the (possibly nested) conditional.
  auto computation1_param = computation1->parameter_instruction(0);
  auto computation2_param = computation2->parameter_instruction(0);
  auto computation3_param = computation3->parameter_instruction(0);
  EXPECT_FALSE(analysis.ValueIsDefinedAt(computation1_param));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(computation2_param));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(computation3_param));
  EXPECT_EQ(analysis.GetUniqueValueAt(computation1_param),
            analysis.GetValueDefinedAt(constant1));
  EXPECT_EQ(analysis.GetUniqueValueAt(computation2_param),
            analysis.GetValueDefinedAt(constant2));
  EXPECT_EQ(analysis.GetUniqueValueAt(computation3_param),
            analysis.GetValueDefinedAt(constant3));

  // Likewise for the parameter of the inner conditional computation and the
  // get-tuple-elements that unpack it.
  EXPECT_FALSE(analysis.ValueIsDefinedAt(param_cond));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(pred_cond));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(true_operand_cond));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(false_operand_cond));
  EXPECT_EQ(analysis.GetUniqueValueAt(param_cond),
            analysis.GetValueDefinedAt(tuple_operand));
  EXPECT_EQ(analysis.GetUniqueValueAt(pred_cond),
            analysis.GetValueDefinedAt(pred2));
  EXPECT_EQ(analysis.GetUniqueValueAt(true_operand_cond),
            analysis.GetValueDefinedAt(constant1));
  EXPECT_EQ(analysis.GetUniqueValueAt(false_operand_cond),
            analysis.GetValueDefinedAt(constant2));

  // Neither conditional defines a value. The inner one may hold the root of
  // computation1 or computation2; the outer one may additionally hold the
  // root of computation3.
  EXPECT_EQ(analysis.values().size(), 9);
  EXPECT_FALSE(analysis.ValueIsDefinedAt(inner_conditional));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
  EXPECT_THAT(
      HloValuesAt(inner_conditional),
      UnorderedElementsAre(
          analysis.GetValueDefinedAt(computation1->root_instruction()),
          analysis.GetValueDefinedAt(computation2->root_instruction())));
  EXPECT_THAT(
      HloValuesAt(conditional),
      UnorderedElementsAre(
          analysis.GetValueDefinedAt(computation1->root_instruction()),
          analysis.GetValueDefinedAt(computation2->root_instruction()),
          analysis.GetValueDefinedAt(computation3->root_instruction())));
}

// Instantiate the parameterized tests for both values of the bool parameter
// (analysis without and with SSA form).
INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation,
                        HloDataflowAnalysisTest,
                        ::testing::Values(false, true));

}  // namespace
}  // namespace xla