1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" 17 18 #include <map> 19 #include <memory> 20 21 #include "tensorflow/compiler/xla/literal_util.h" 22 #include "tensorflow/compiler/xla/service/flatten_call_graph.h" 23 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" 24 #include "tensorflow/compiler/xla/service/hlo_matchers.h" 25 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 26 #include "tensorflow/compiler/xla/service/hlo_ordering.h" 27 #include "tensorflow/compiler/xla/service/instruction_fusion.h" 28 #include "tensorflow/compiler/xla/shape_util.h" 29 #include "tensorflow/compiler/xla/test.h" 30 #include "tensorflow/compiler/xla/test_helpers.h" 31 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 32 #include "tensorflow/compiler/xla/xla_data.pb.h" 33 #include "tensorflow/core/lib/core/status_test_util.h" 34 #include "tensorflow/core/platform/logging.h" 35 #include "tensorflow/core/platform/test.h" 36 37 namespace xla { 38 namespace { 39 40 using ::testing::UnorderedElementsAre; 41 42 class HloAliasAnalysisTest : public HloTestBase { 43 protected: 44 HloAliasAnalysisTest() : module_(CreateNewModule()) {} 45 46 // Run alias analysis on the member module. 
For convenience returns a 47 // reference to the generated analysis stored in analysis_. 48 HloAliasAnalysis& RunAnalysis() { 49 hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before alias analysis"); 50 analysis_ = HloAliasAnalysis::Run(module_.get()).ConsumeValueOrDie(); 51 return *analysis_; 52 } 53 54 // Return a vector of the buffers in the buffer set at the current position 55 // sorted by buffer id. 56 std::vector<HloBuffer> GetBuffersAt(const HloInstruction* instruction, 57 const ShapeIndex& index = {}) const { 58 std::set<HloBuffer::Id> buffer_ids; 59 for (const HloValue* value : analysis_->dataflow_analysis() 60 .GetValueSet(instruction, index) 61 .values()) { 62 buffer_ids.insert(analysis_->GetBufferContainingValue(*value).id()); 63 } 64 65 std::vector<HloBuffer> buffers; 66 for (HloBuffer::Id id : buffer_ids) { 67 buffers.push_back(analysis_->GetBuffer(id)); 68 } 69 return buffers; 70 } 71 72 // Return a vector containing all of the HloValues in the given buffer. 73 std::vector<HloValue> GetValuesInBuffer(const HloBuffer& buffer) { 74 std::vector<HloValue> values; 75 for (const HloValue* value : buffer.values()) { 76 values.push_back(*value); 77 } 78 return values; 79 } 80 81 // Return the HloValue defined at the given position. 82 const HloValue& GetValueDefinedAt(const HloInstruction* instruction, 83 const ShapeIndex& index = {}) const { 84 return analysis_->dataflow_analysis().GetValueDefinedAt(instruction, index); 85 } 86 87 // Returns true if any values held in the same buffer interfere. Generally, in 88 // the compiler pipeline copy-insertion will guarantee that this interference 89 // never occurs, but HLO graphs with interference can be explicitly 90 // constructed. 
91 bool AnyValuesInSameBufferInterfere() { 92 DependencyHloOrdering ordering(module_.get()); 93 for (const HloBuffer& buffer : analysis_->buffers()) { 94 for (const HloValue* value_a : buffer.values()) { 95 for (const HloValue* value_b : buffer.values()) { 96 if (*value_a != *value_b && 97 ordering.MayInterfere(*value_a, *value_b, 98 analysis_->dataflow_analysis())) { 99 VLOG(1) << *value_a << " interferes with " << *value_b 100 << " in buffer: " << buffer; 101 return true; 102 } 103 } 104 } 105 } 106 return false; 107 } 108 109 std::unique_ptr<HloModule> module_; 110 std::unique_ptr<HloAliasAnalysis> analysis_; 111 112 const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); 113 }; 114 115 TEST_F(HloAliasAnalysisTest, BinaryOperation) { 116 // Test the analysis on a single binary operation (Add). 117 auto builder = HloComputation::Builder(TestName()); 118 auto constant1 = builder.AddInstruction( 119 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 120 auto constant2 = builder.AddInstruction( 121 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 122 auto add = builder.AddInstruction(HloInstruction::CreateBinary( 123 scalar_shape_, HloOpcode::kAdd, constant1, constant2)); 124 module_->AddEntryComputation(builder.Build()); 125 126 const HloAliasAnalysis& analysis = RunAnalysis(); 127 128 EXPECT_EQ(analysis.buffers().size(), 3); 129 130 // All of the buffer sets should trivially contain a single buffer containing 131 // a single value. 132 for (const HloInstruction* instruction : {constant1, constant2, add}) { 133 EXPECT_EQ(analysis.GetUniqueBufferAt(instruction).GetUniqueValue(), 134 GetValueDefinedAt(instruction)); 135 } 136 137 EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(add)); 138 EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(add)); 139 140 EXPECT_FALSE(AnyValuesInSameBufferInterfere()); 141 } 142 143 TEST_F(HloAliasAnalysisTest, TupleAndGtes) { 144 // Verify the analysis for a Tuple and GetTupleElement instructions. 
145 auto builder = HloComputation::Builder(TestName()); 146 auto param0 = builder.AddInstruction( 147 HloInstruction::CreateParameter(0, scalar_shape_, "param0")); 148 auto param1 = builder.AddInstruction( 149 HloInstruction::CreateParameter(1, scalar_shape_, "param1")); 150 auto tuple = 151 builder.AddInstruction(HloInstruction::CreateTuple({param0, param1})); 152 auto gte0 = builder.AddInstruction( 153 HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 0)); 154 auto gte1 = builder.AddInstruction( 155 HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1)); 156 builder.AddInstruction( 157 HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, gte0, gte1)); 158 module_->AddEntryComputation(builder.Build()); 159 160 const HloAliasAnalysis& analysis = RunAnalysis(); 161 162 EXPECT_EQ(analysis.buffers().size(), 4); 163 164 // Verify the expected aliasing of the tuple elements. 165 EXPECT_EQ(analysis.GetUniqueBufferAt(tuple, /*index=*/{}).GetUniqueValue(), 166 GetValueDefinedAt(tuple, /*index=*/{})); 167 EXPECT_EQ(analysis.GetUniqueBufferAt(tuple, /*index=*/{0}).GetUniqueValue(), 168 GetValueDefinedAt(param0)); 169 EXPECT_EQ(analysis.GetUniqueBufferAt(tuple, /*index=*/{1}).GetUniqueValue(), 170 GetValueDefinedAt(param1)); 171 172 // The tuple operand, tuple element, and result of the GTE instruction should 173 // all be the same buffer. 174 EXPECT_EQ(analysis.GetUniqueBufferAt(param0), 175 analysis.GetUniqueBufferAt(tuple, /*index=*/{0})); 176 EXPECT_EQ(analysis.GetUniqueBufferAt(param0), 177 analysis.GetUniqueBufferAt(gte0)); 178 179 // Verify the positions of an aliased buffer. 
180 EXPECT_THAT( 181 analysis.GetUniqueBufferAt(param0).ComputePositions(), 182 UnorderedElementsAre(HloPosition{param0, {}}, HloPosition{tuple, {0}}, 183 HloPosition{gte0, {}})); 184 185 EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(tuple)); 186 EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(tuple)); 187 188 EXPECT_FALSE(AnyValuesInSameBufferInterfere()); 189 } 190 191 TEST_F(HloAliasAnalysisTest, NondistinctTuple) { 192 // Test a expression with a non-distinct buffer set. 193 auto builder = HloComputation::Builder(TestName()); 194 auto param0 = builder.AddInstruction( 195 HloInstruction::CreateParameter(0, scalar_shape_, "param0")); 196 auto param1 = builder.AddInstruction( 197 HloInstruction::CreateParameter(1, scalar_shape_, "param1")); 198 // param0 is included twice in the tuple. 199 auto tuple = builder.AddInstruction( 200 HloInstruction::CreateTuple({param0, param1, param0})); 201 module_->AddEntryComputation(builder.Build()); 202 203 const HloAliasAnalysis& analysis = RunAnalysis(); 204 205 EXPECT_THAT( 206 analysis.GetUniqueBufferAt(param0).ComputePositions(), 207 UnorderedElementsAre(HloPosition{param0, {}}, HloPosition{tuple, {0}}, 208 HloPosition{tuple, {2}})); 209 210 EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(tuple)); 211 EXPECT_FALSE(analysis.InstructionBuffersAreDistinct(tuple)); 212 213 EXPECT_FALSE(AnyValuesInSameBufferInterfere()); 214 } 215 216 TEST_F(HloAliasAnalysisTest, SingleCall) { 217 // Test a single call of a subcomputation. The subcomputation adds its two 218 // array-shaped parameters. 
219 auto subbuilder = HloComputation::Builder("Subcomputation"); 220 auto subparam0 = subbuilder.AddInstruction( 221 HloInstruction::CreateParameter(0, scalar_shape_, "param0")); 222 auto subparam1 = subbuilder.AddInstruction( 223 HloInstruction::CreateParameter(1, scalar_shape_, "param1")); 224 auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary( 225 scalar_shape_, HloOpcode::kAdd, subparam0, subparam1)); 226 HloComputation* called_computation = 227 module_->AddEmbeddedComputation(subbuilder.Build()); 228 229 auto builder = HloComputation::Builder(TestName()); 230 auto constant1 = builder.AddInstruction( 231 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 232 auto constant2 = builder.AddInstruction( 233 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 234 auto call = builder.AddInstruction(HloInstruction::CreateCall( 235 scalar_shape_, {constant1, constant2}, called_computation)); 236 module_->AddEntryComputation(builder.Build()); 237 238 const HloAliasAnalysis& analysis = RunAnalysis(); 239 240 // Verify aliasing of the kCall operands and the subcomputation parameters. 241 EXPECT_THAT(analysis.GetUniqueBufferAt(constant1).ComputePositions(), 242 UnorderedElementsAre(HloPosition{constant1, {}}, 243 HloPosition{subparam0, {}})); 244 EXPECT_THAT(analysis.GetUniqueBufferAt(constant2).ComputePositions(), 245 UnorderedElementsAre(HloPosition{constant2, {}}, 246 HloPosition{subparam1, {}})); 247 248 // The subcomputation root and the kCall itself should alias. 249 EXPECT_THAT( 250 analysis.GetUniqueBufferAt(add).ComputePositions(), 251 UnorderedElementsAre(HloPosition{add, {}}, HloPosition{call, {}})); 252 253 EXPECT_FALSE(AnyValuesInSameBufferInterfere()); 254 } 255 256 TEST_F(HloAliasAnalysisTest, ComputationCalledTwice) { 257 // Test a subcomputation which is called twice with different argument values. 
258 auto subbuilder = HloComputation::Builder("Subcomputation"); 259 auto subparam0 = subbuilder.AddInstruction( 260 HloInstruction::CreateParameter(0, scalar_shape_, "param0")); 261 auto subparam1 = subbuilder.AddInstruction( 262 HloInstruction::CreateParameter(1, scalar_shape_, "param1")); 263 auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary( 264 scalar_shape_, HloOpcode::kAdd, subparam0, subparam1)); 265 HloComputation* called_computation = 266 module_->AddEmbeddedComputation(subbuilder.Build()); 267 268 auto builder = HloComputation::Builder(TestName()); 269 auto constant1 = builder.AddInstruction( 270 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 271 auto constant2 = builder.AddInstruction( 272 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 273 auto call1 = builder.AddInstruction(HloInstruction::CreateCall( 274 scalar_shape_, {constant1, constant2}, called_computation)); 275 auto call2 = builder.AddInstruction(HloInstruction::CreateCall( 276 scalar_shape_, {call1, constant2}, called_computation)); 277 module_->AddEntryComputation(builder.Build()); 278 279 const HloAliasAnalysis& analysis = RunAnalysis(); 280 281 EXPECT_THAT(analysis.GetUniqueBufferAt(constant1).ComputePositions(), 282 UnorderedElementsAre(HloPosition{constant1, {}}, 283 HloPosition{subparam0, {}})); 284 EXPECT_THAT(analysis.GetUniqueBufferAt(constant2).ComputePositions(), 285 UnorderedElementsAre(HloPosition{constant2, {}}, 286 HloPosition{subparam1, {}})); 287 288 // The 'add' (root of the subcomputation) aliases the two call instruction, 289 // and the first parameter of the subcomputation because 'call1' it is passed 290 // as an argument to the subcomputation in 'call2'. 
291 EXPECT_THAT( 292 analysis.GetUniqueBufferAt(add).ComputePositions(), 293 UnorderedElementsAre(HloPosition{add, {}}, HloPosition{call1, {}}, 294 HloPosition{subparam0, {}}, HloPosition{call2, {}})); 295 296 EXPECT_THAT(GetBuffersAt(subparam0), 297 UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1), 298 analysis.GetUniqueBufferAt(add))); 299 EXPECT_THAT(GetBuffersAt(subparam1), 300 UnorderedElementsAre(analysis.GetUniqueBufferAt(constant2))); 301 302 EXPECT_TRUE(analysis.InstructionBuffersAreAmbiguous(subparam0)); 303 EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(subparam1)); 304 EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(subparam0)); 305 EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(subparam1)); 306 307 EXPECT_FALSE(AnyValuesInSameBufferInterfere()); 308 } 309 310 TEST_F(HloAliasAnalysisTest, SingleWhile) { 311 // Test a simple single while instruction. The while body includes a 312 // pass-through value. HLO: 313 // 314 // body((F32[], F32[]) %tuple_param): 315 // %add = Add(%tuple_param{0}, %tuple_param{1}) 316 // return Tuple(%tuple_param{0}, %add) 317 // 318 // condition((F32[], F32[]) %tuple_param): 319 // return Constant(false) 320 // 321 // entry: 322 // %constant1 = Constant(1.0) 323 // %constant2 = Constant(2.0) 324 // %tuple = Tuple(%constant1, %constant2) 325 // return While(%tuple, body, condition) 326 // 327 const Shape tuple_shape = 328 ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); 329 330 // Element 0 passes transparently through the body. 
331 auto body_builder = HloComputation::Builder("body"); 332 auto body_param = body_builder.AddInstruction( 333 HloInstruction::CreateParameter(0, tuple_shape, "param")); 334 auto body_element_0 = body_builder.AddInstruction( 335 HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); 336 auto body_element_1 = body_builder.AddInstruction( 337 HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); 338 auto add = body_builder.AddInstruction(HloInstruction::CreateBinary( 339 scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1)); 340 auto body_tuple = body_builder.AddInstruction( 341 HloInstruction::CreateTuple({body_element_0, add})); 342 HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); 343 344 // Condition computation trivially returns a constant "false". 345 auto cond_builder = HloComputation::Builder("condition"); 346 auto cond_param = cond_builder.AddInstruction( 347 HloInstruction::CreateParameter(0, tuple_shape, "param")); 348 cond_builder.AddInstruction( 349 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 350 HloComputation* condition = 351 module_->AddEmbeddedComputation(cond_builder.Build()); 352 353 auto builder = HloComputation::Builder(TestName()); 354 auto constant1 = builder.AddInstruction( 355 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 356 auto constant2 = builder.AddInstruction( 357 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 358 auto tuple = builder.AddInstruction( 359 HloInstruction::CreateTuple({constant1, constant2})); 360 auto xla_while = builder.AddInstruction( 361 HloInstruction::CreateWhile(tuple_shape, condition, body, tuple)); 362 module_->AddEntryComputation(builder.Build()); 363 364 const HloAliasAnalysis& analysis = RunAnalysis(); 365 366 // Verify the positions of the aliased while buffers. 
367 EXPECT_THAT( 368 analysis.GetUniqueBufferAt(xla_while, /*index=*/{}).ComputePositions(), 369 UnorderedElementsAre(HloPosition{tuple, {}}, HloPosition{xla_while, {}}, 370 HloPosition{body_param, {}}, 371 HloPosition{body_tuple, {}}, 372 HloPosition{cond_param, {}})); 373 EXPECT_THAT( 374 analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}).ComputePositions(), 375 UnorderedElementsAre( 376 HloPosition{constant1, {}}, HloPosition{tuple, {0}}, 377 HloPosition{xla_while, {0}}, HloPosition{body_param, {0}}, 378 HloPosition{body_element_0, {}}, HloPosition{body_tuple, {0}}, 379 HloPosition{cond_param, {0}})); 380 EXPECT_THAT( 381 analysis.GetUniqueBufferAt(xla_while, /*index=*/{1}).ComputePositions(), 382 UnorderedElementsAre( 383 HloPosition{constant2, {}}, HloPosition{tuple, {1}}, 384 HloPosition{xla_while, {1}}, HloPosition{body_param, {1}}, 385 HloPosition{body_element_1, {}}, HloPosition{add, {}}, 386 HloPosition{body_tuple, {1}}, HloPosition{cond_param, {1}})); 387 388 EXPECT_THAT( 389 GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0})), 390 UnorderedElementsAre(GetValueDefinedAt(constant1))); 391 EXPECT_THAT( 392 GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1})), 393 UnorderedElementsAre(GetValueDefinedAt(constant2), 394 GetValueDefinedAt(xla_while, /*index=*/{1}), 395 GetValueDefinedAt(body_param, {1}), 396 GetValueDefinedAt(cond_param, {1}), 397 GetValueDefinedAt(add))); 398 399 EXPECT_FALSE(AnyValuesInSameBufferInterfere()); 400 } 401 402 TEST_F(HloAliasAnalysisTest, SequentialWhiles) { 403 // Test sequential while instructions. The while body includes a 404 // pass-through value. 
HLO: 405 // 406 // body((F32[], F32[]) %tuple_param): 407 // %add = Add(%tuple_param{0}, %tuple_param{1}) 408 // return Tuple(%tuple_param{0}, %add) 409 // 410 // condition((F32[], F32[]) %tuple_param): 411 // return Constant(false) 412 // 413 // entry: 414 // %constant1 = Constant(1.0) 415 // %constant2 = Constant(2.0) 416 // %tuple = Tuple(%constant1, %constant2) 417 // %while0 = While(%tuple, body, condition) 418 // %while1 = While(%while0, body, condition) 419 // return While(%while1, body, condition) 420 // 421 const Shape tuple_shape = 422 ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); 423 424 // Element 0 passes transparently through the body. 425 auto body_builder = HloComputation::Builder("body"); 426 auto body_param = body_builder.AddInstruction( 427 HloInstruction::CreateParameter(0, tuple_shape, "param")); 428 auto body_element_0 = body_builder.AddInstruction( 429 HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); 430 auto body_element_1 = body_builder.AddInstruction( 431 HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); 432 auto add = body_builder.AddInstruction(HloInstruction::CreateBinary( 433 scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1)); 434 body_builder.AddInstruction( 435 HloInstruction::CreateTuple({body_element_0, add})); 436 HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); 437 438 auto cond_builder = HloComputation::Builder("condition"); 439 cond_builder.AddInstruction( 440 HloInstruction::CreateParameter(0, tuple_shape, "param")); 441 cond_builder.AddInstruction( 442 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 443 HloComputation* condition = 444 module_->AddEmbeddedComputation(cond_builder.Build()); 445 446 auto builder = HloComputation::Builder(TestName()); 447 auto constant1 = builder.AddInstruction( 448 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 449 auto constant2 = builder.AddInstruction( 450 
HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 451 auto tuple = builder.AddInstruction( 452 HloInstruction::CreateTuple({constant1, constant2})); 453 auto xla_while0 = builder.AddInstruction( 454 HloInstruction::CreateWhile(tuple_shape, condition, body, tuple)); 455 auto xla_while1 = builder.AddInstruction( 456 HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while0)); 457 auto xla_while2 = builder.AddInstruction( 458 HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while1)); 459 module_->AddEntryComputation(builder.Build()); 460 461 FlattenCallGraph flattener; 462 TF_ASSERT_OK(flattener.Run(module_.get()).status()); 463 464 const HloAliasAnalysis& analysis = RunAnalysis(); 465 466 EXPECT_EQ(analysis.GetUniqueBufferAt(tuple, /*index=*/{}), 467 analysis.GetUniqueBufferAt(xla_while2, /*index=*/{})); 468 EXPECT_EQ(analysis.GetUniqueBufferAt(constant1), 469 analysis.GetUniqueBufferAt(xla_while2, /*index=*/{0})); 470 EXPECT_EQ(analysis.GetUniqueBufferAt(constant2), 471 analysis.GetUniqueBufferAt(xla_while2, /*index=*/{1})); 472 } 473 474 TEST_F(HloAliasAnalysisTest, NestedWhiles) { 475 // Test nested while instructions. The inner body passes through element 0 of 476 // its parameter, and the outer body passes through element 1. 
HLO: 477 // 478 // inner_body((F32[], F32[]) %tuple_param): 479 // %add = Add(%tuple_param{0}, %tuple_param{1}) 480 // return Tuple(%tuple_param{0}, %add) 481 // 482 // outer_body((F32[], F32[]) %tuple_param): 483 // %negate = Negate(%tuple_param{0}) 484 // %tuple = Tuple(%negate, %tuple_param{1}) 485 // return While(%tuple, inner_body, condition) 486 // 487 // entry: 488 // %constant1 = Constant(1.0) 489 // %constant2 = Constant(2.0) 490 // %tuple = Tuple(%constant1, %constant2) 491 // return While(%tuple, outer_body, condition) 492 // 493 const Shape tuple_shape = 494 ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); 495 496 auto build_cond_computation = [&tuple_shape]() { 497 auto cond_builder = HloComputation::Builder("condition"); 498 cond_builder.AddInstruction( 499 HloInstruction::CreateParameter(0, tuple_shape, "param")); 500 cond_builder.AddInstruction( 501 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 502 return cond_builder.Build(); 503 }; 504 // Build separate condition computations so the call graph is flat. The 505 // callgraph is always flattened in the compiler pipeline, and the flattened 506 // callgraph enables representative interference analysis. 507 HloComputation* condition1 = 508 module_->AddEmbeddedComputation(build_cond_computation()); 509 HloComputation* condition2 = 510 module_->AddEmbeddedComputation(build_cond_computation()); 511 512 // Element 0 passes transparently through the body. 
513 auto inner_builder = HloComputation::Builder("inner_body"); 514 auto inner_param = inner_builder.AddInstruction( 515 HloInstruction::CreateParameter(0, tuple_shape, "param")); 516 auto inner_element_0 = inner_builder.AddInstruction( 517 HloInstruction::CreateGetTupleElement(scalar_shape_, inner_param, 0)); 518 auto inner_element_1 = inner_builder.AddInstruction( 519 HloInstruction::CreateGetTupleElement(scalar_shape_, inner_param, 1)); 520 auto add = inner_builder.AddInstruction(HloInstruction::CreateBinary( 521 scalar_shape_, HloOpcode::kAdd, inner_element_0, inner_element_1)); 522 inner_builder.AddInstruction( 523 HloInstruction::CreateTuple({inner_element_0, add})); 524 HloComputation* inner_body = 525 module_->AddEmbeddedComputation(inner_builder.Build()); 526 527 // Element 1 passes transparently through the body. 528 auto outer_builder = HloComputation::Builder("outer_body"); 529 auto outer_param = outer_builder.AddInstruction( 530 HloInstruction::CreateParameter(0, tuple_shape, "param")); 531 auto outer_element_0 = outer_builder.AddInstruction( 532 HloInstruction::CreateGetTupleElement(scalar_shape_, outer_param, 0)); 533 auto negate = outer_builder.AddInstruction(HloInstruction::CreateUnary( 534 scalar_shape_, HloOpcode::kNegate, outer_element_0)); 535 auto outer_element_1 = outer_builder.AddInstruction( 536 HloInstruction::CreateGetTupleElement(scalar_shape_, outer_param, 1)); 537 auto outer_tuple = outer_builder.AddInstruction( 538 HloInstruction::CreateTuple({negate, outer_element_1})); 539 auto nested_while = outer_builder.AddInstruction(HloInstruction::CreateWhile( 540 tuple_shape, condition1, inner_body, outer_tuple)); 541 HloComputation* outer_body = 542 module_->AddEmbeddedComputation(outer_builder.Build()); 543 544 auto builder = HloComputation::Builder(TestName()); 545 auto constant1 = builder.AddInstruction( 546 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 547 auto constant2 = builder.AddInstruction( 548 
HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 549 auto tuple = builder.AddInstruction( 550 HloInstruction::CreateTuple({constant1, constant2})); 551 auto entry_while = builder.AddInstruction( 552 HloInstruction::CreateWhile(tuple_shape, condition2, outer_body, tuple)); 553 module_->AddEntryComputation(builder.Build()); 554 555 const HloAliasAnalysis& analysis = RunAnalysis(); 556 557 EXPECT_EQ(analysis.GetUniqueBufferAt(constant1), 558 analysis.GetUniqueBufferAt(entry_while, /*index=*/{0})); 559 EXPECT_EQ(analysis.GetUniqueBufferAt(constant1), 560 analysis.GetUniqueBufferAt(nested_while, /*index=*/{0})); 561 EXPECT_EQ(analysis.GetUniqueBufferAt(constant1), 562 analysis.GetUniqueBufferAt(inner_element_0)); 563 564 EXPECT_EQ(analysis.GetUniqueBufferAt(constant2), 565 analysis.GetUniqueBufferAt(entry_while, /*index=*/{1})); 566 EXPECT_EQ(analysis.GetUniqueBufferAt(constant2), 567 analysis.GetUniqueBufferAt(nested_while, /*index=*/{1})); 568 EXPECT_EQ(analysis.GetUniqueBufferAt(constant2), 569 analysis.GetUniqueBufferAt(inner_element_1)); 570 571 EXPECT_FALSE(AnyValuesInSameBufferInterfere()); 572 } 573 574 TEST_F(HloAliasAnalysisTest, SwizzlingWhile) { 575 // Test a while instruction with a body which permutes it's tuple parameter 576 // elements. 
HLO: 577 // 578 // body((F32[], F32[], F32[]) %tuple_param): 579 // return Tuple(%tuple_param{1}, %tuple_param{2}, %tuple_param{0}) 580 // 581 // condition((F32[], F32[]) %tuple_param): 582 // return Constant(false) 583 // 584 // entry: 585 // %constant1 = Constant(1.0) 586 // %constant2 = Constant(2.0) 587 // %constant3 = Constant(3.0) 588 // %tuple = Tuple(%constant1, %constant2, %constant3) 589 // return While(%tuple, body, condition) 590 // 591 const Shape tuple_shape = 592 ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_, scalar_shape_}); 593 594 auto body_builder = HloComputation::Builder("body"); 595 auto body_param = body_builder.AddInstruction( 596 HloInstruction::CreateParameter(0, tuple_shape, "param")); 597 auto body_element_0 = body_builder.AddInstruction( 598 HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); 599 auto body_element_1 = body_builder.AddInstruction( 600 HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); 601 auto body_element_2 = body_builder.AddInstruction( 602 HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 2)); 603 body_builder.AddInstruction(HloInstruction::CreateTuple( 604 {body_element_1, body_element_2, body_element_0})); 605 HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); 606 607 auto cond_builder = HloComputation::Builder("condition"); 608 cond_builder.AddInstruction( 609 HloInstruction::CreateParameter(0, tuple_shape, "param")); 610 auto cond_constant = cond_builder.AddInstruction( 611 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 612 HloComputation* condition = 613 module_->AddEmbeddedComputation(cond_builder.Build()); 614 615 auto builder = HloComputation::Builder(TestName()); 616 auto constant1 = builder.AddInstruction( 617 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 618 auto constant2 = builder.AddInstruction( 619 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 620 auto constant3 
= builder.AddInstruction( 621 HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0))); 622 auto tuple = builder.AddInstruction( 623 HloInstruction::CreateTuple({constant1, constant2, constant3})); 624 auto xla_while = builder.AddInstruction( 625 HloInstruction::CreateWhile(tuple_shape, condition, body, tuple)); 626 module_->AddEntryComputation(builder.Build()); 627 628 const HloAliasAnalysis& analysis = RunAnalysis(); 629 630 // The swizzling while makes most positions in the module alias leaving only 3 631 // HloBuffers. 632 EXPECT_THAT( 633 analysis.buffers(), 634 UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1), 635 analysis.GetUniqueBufferAt(tuple, /*index=*/{}), 636 analysis.GetUniqueBufferAt(cond_constant))); 637 638 // The tuple elements of the while and the three constant inputs should all be 639 // smooshed into the same buffer. 640 EXPECT_EQ(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}), 641 analysis.GetUniqueBufferAt(xla_while, /*index=*/{1})); 642 EXPECT_EQ(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}), 643 analysis.GetUniqueBufferAt(xla_while, /*index=*/{2})); 644 EXPECT_EQ(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}), 645 analysis.GetUniqueBufferAt(constant1)); 646 EXPECT_EQ(analysis.GetUniqueBufferAt(constant1), 647 analysis.GetUniqueBufferAt(constant2)); 648 EXPECT_EQ(analysis.GetUniqueBufferAt(constant1), 649 analysis.GetUniqueBufferAt(constant3)); 650 651 // All elements in of the loop state tuple are forced into the same buffer 652 // resulting liveness interference. 653 EXPECT_TRUE(AnyValuesInSameBufferInterfere()); 654 } 655 656 TEST_F(HloAliasAnalysisTest, TupleSelect) { 657 // Test a kSelect of a tuple value. Non-top-level element flow through the 658 // instruction. 
659 auto builder = HloComputation::Builder(TestName()); 660 auto pred = builder.AddInstruction( 661 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 662 auto constant1 = builder.AddInstruction( 663 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 664 auto constant2 = builder.AddInstruction( 665 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 666 auto constant3 = builder.AddInstruction( 667 HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0))); 668 auto constant4 = builder.AddInstruction( 669 HloInstruction::CreateConstant(Literal::CreateR0<float>(4.0))); 670 auto tuple1 = 671 builder.AddInstruction(HloInstruction::CreateTuple({constant1})); 672 auto tuple2 = 673 builder.AddInstruction(HloInstruction::CreateTuple({constant2})); 674 auto tuple3 = 675 builder.AddInstruction(HloInstruction::CreateTuple({constant3})); 676 auto tuple4 = 677 builder.AddInstruction(HloInstruction::CreateTuple({constant4})); 678 const Shape tuple_shape = tuple1->shape(); 679 auto select11 = builder.AddInstruction(HloInstruction::CreateTernary( 680 tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple1)); 681 auto select12 = builder.AddInstruction(HloInstruction::CreateTernary( 682 tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple2)); 683 auto select34 = builder.AddInstruction(HloInstruction::CreateTernary( 684 tuple_shape, HloOpcode::kSelect, pred, tuple3, tuple4)); 685 auto select1234 = builder.AddInstruction(HloInstruction::CreateTernary( 686 tuple_shape, HloOpcode::kSelect, pred, select12, select34)); 687 688 module_->AddEntryComputation(builder.Build()); 689 690 const HloAliasAnalysis& analysis = RunAnalysis(); 691 692 // Verify the buffer sets of each select. 
EXPECT_THAT(GetBuffersAt(select11, /*index=*/{0}),
            UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1)));
  EXPECT_THAT(GetBuffersAt(select12, /*index=*/{0}),
              UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1),
                                   analysis.GetUniqueBufferAt(constant2)));
  EXPECT_THAT(GetBuffersAt(select34, /*index=*/{0}),
              UnorderedElementsAre(analysis.GetUniqueBufferAt(constant3),
                                   analysis.GetUniqueBufferAt(constant4)));
  EXPECT_THAT(GetBuffersAt(select1234, /*index=*/{0}),
              UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1),
                                   analysis.GetUniqueBufferAt(constant2),
                                   analysis.GetUniqueBufferAt(constant3),
                                   analysis.GetUniqueBufferAt(constant4)));

  EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(select11));
  EXPECT_TRUE(analysis.InstructionBuffersAreAmbiguous(select12));
  EXPECT_TRUE(analysis.InstructionBuffersAreAmbiguous(select34));
  EXPECT_TRUE(analysis.InstructionBuffersAreAmbiguous(select1234));

  EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(select11));
  EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(select12));
  EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(select34));
  EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(select1234));

  EXPECT_FALSE(AnyValuesInSameBufferInterfere());
}

TEST_F(HloAliasAnalysisTest, TupleSelectToWhile) {
  // Test a tuple-shaped kSelect feeding a kWhile instruction. HLO:
  //
  // body((F32[]) %tuple_param):
  //   %negate = Negate(%tuple_param{0})
  //   return Tuple(%negate)
  //
  // condition((F32[]) %tuple_param):
  //   return Constant(false)
  //
  // entry:
  //   %pred = Constant(false)
  //   %constant1 = Constant(1.0)
  //   %constant2 = Constant(2.0)
  //   %tuple1 = Tuple(%constant1)
  //   %tuple2 = Tuple(%constant2)
  //   %select = Select(%pred, %tuple1, %tuple2)
  //   return While(%select, body, condition)
  //
  auto builder = HloComputation::Builder(TestName());

  // Single-element tuple shape shared by the select, the while and both
  // embedded computations' parameters.
  const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_});

  // The body negates element 0 of its parameter and returns the result wrapped
  // in a new single-element tuple.
  auto body_builder = HloComputation::Builder("body");
  auto body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));
  auto body_element = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
  auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
      scalar_shape_, HloOpcode::kNegate, body_element));
  body_builder.AddInstruction(HloInstruction::CreateTuple({negate}));
  HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());

  // The condition takes the tuple parameter and produces a constant false.
  auto cond_builder = HloComputation::Builder("condition");
  auto cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));
  cond_builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
  HloComputation* condition =
      module_->AddEmbeddedComputation(cond_builder.Build());

  // Entry computation: select between two single-element tuples and use the
  // selected tuple as the init value of the while.
  auto pred = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
  auto tuple1 =
      builder.AddInstruction(HloInstruction::CreateTuple({constant1}));
  auto tuple2 =
      builder.AddInstruction(HloInstruction::CreateTuple({constant2}));
  auto select = builder.AddInstruction(HloInstruction::CreateTernary(
      tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple2));
  auto xla_while = builder.AddInstruction(
      HloInstruction::CreateWhile(tuple_shape, condition, body, select));

  module_->AddEntryComputation(builder.Build());

  const HloAliasAnalysis& analysis = RunAnalysis();

  // The while should flatten the ambiguous select buffer set so that the
  // buffer set contents (constant1 and constant2) become a single buffer.
  EXPECT_EQ(analysis.GetUniqueBufferAt(constant1),
            analysis.GetUniqueBufferAt(constant2));
  EXPECT_EQ(analysis.GetUniqueBufferAt(constant1),
            analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}));

  // That single buffer is expected to hold both constants together with the
  // values defined at element 0 of the while, of the body parameter and of the
  // condition parameter, plus the negate computed in the body.
  EXPECT_THAT(GetValuesInBuffer(analysis.GetUniqueBufferAt(constant1)),
              UnorderedElementsAre(GetValueDefinedAt(constant1),
                                   GetValueDefinedAt(constant2),
                                   GetValueDefinedAt(xla_while, /*index=*/{0}),
                                   GetValueDefinedAt(body_param, /*index=*/{0}),
                                   GetValueDefinedAt(cond_param, /*index=*/{0}),
                                   GetValueDefinedAt(negate)));
  EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(select));
  EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(xla_while));

  EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(select));
  EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(xla_while));

  // The two operands of the select get flattened into the same buffer,
  // resulting in liveness interference.
  EXPECT_TRUE(AnyValuesInSameBufferInterfere());
}

TEST_F(HloAliasAnalysisTest, Bitcast) {
  // Bitcasting a value should not produce a new buffer.
  auto builder = HloComputation::Builder(TestName());
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
  auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary(
      scalar_shape_, HloOpcode::kBitcast, constant));

  module_->AddEntryComputation(builder.Build());

  const HloAliasAnalysis& analysis = RunAnalysis();

  // The constant and its bitcast are the only instructions in the module, and
  // they are expected to share the module's single buffer.
  EXPECT_EQ(analysis.buffers().size(), 1);

  EXPECT_EQ(analysis.GetUniqueBufferAt(constant),
            analysis.GetUniqueBufferAt(bitcast));
}

TEST_F(HloAliasAnalysisTest, BitcastInterference) {
  // A bitcast value simultaneously live with its operand should not cause
  // interference.
  auto builder = HloComputation::Builder(TestName());
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
  auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary(
      scalar_shape_, HloOpcode::kBitcast, constant));
  // The root tuple uses both the constant and its bitcast.
  builder.AddInstruction(HloInstruction::CreateTuple({constant, bitcast}));

  module_->AddEntryComputation(builder.Build());

  const HloAliasAnalysis& analysis = RunAnalysis();

  DependencyHloOrdering ordering(module_.get());
  EXPECT_FALSE(analysis.HasLiveRangeInterference(ordering));
}

TEST_F(HloAliasAnalysisTest, WhileInterference) {
  // Build a while loop which has a parallel use of the init value. Depending on
  // ordering there may be interference between the update-in-place while and
  // the other use of the init.
  auto builder = HloComputation::Builder(TestName());
  auto init = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));

  // Condition computation: a parameter followed by a constant false.
  auto cond_builder = HloComputation::Builder("condition");
  auto cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, init->shape(), "param"));
  auto cond_root = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
  HloComputation* condition =
      module_->AddEmbeddedComputation(cond_builder.Build());

  // Body computation: applies kExp to the parameter.
  auto body_builder = HloComputation::Builder("body");
  auto body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, init->shape(), "param"));
  auto body_root = body_builder.AddInstruction(
      HloInstruction::CreateUnary(init->shape(), HloOpcode::kExp, body_param));
  HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());

  auto xla_while = builder.AddInstruction(
      HloInstruction::CreateWhile(init->shape(), condition, body, init));

  // The negate is the second use of init, alongside the while.
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(init->shape(), HloOpcode::kNegate, init));
  auto entry_root =
      builder.AddInstruction(HloInstruction::CreateTuple({negate, xla_while}));

  HloComputation* entry = module_->AddEntryComputation(builder.Build());

  const HloAliasAnalysis& analysis = RunAnalysis();

  {
    // Dependency ordering should interfere because the negate and while are
    // unordered.
    DependencyHloOrdering ordering(module_.get());
    EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering));
  }

  // For a sequential order, there is interference iff the negate is after the
  // while. The body and condition sequences are shared by both entry orders
  // checked below; only the entry sequence changes.
  SequentialHloOrdering::HloModuleSequence sequence;
  sequence[body] = {body_param, body_root};
  sequence[condition] = {cond_param, cond_root};
  {
    // Negate scheduled after the while: interference expected.
    sequence[entry] = {init, xla_while, negate, entry_root};
    SequentialHloOrdering ordering(module_.get(), sequence);
    EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering));
  }

  {
    // Negate scheduled before the while: no interference expected.
    sequence[entry] = {init, negate, xla_while, entry_root};
    SequentialHloOrdering ordering(module_.get(), sequence);
    EXPECT_FALSE(analysis.HasLiveRangeInterference(ordering));
  }
}

}  // namespace
}  // namespace xla