1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" 17 18 #include <map> 19 #include <memory> 20 21 #include "tensorflow/compiler/xla/literal_util.h" 22 #include "tensorflow/compiler/xla/service/hlo_matchers.h" 23 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 24 #include "tensorflow/compiler/xla/service/instruction_fusion.h" 25 #include "tensorflow/compiler/xla/shape_util.h" 26 #include "tensorflow/compiler/xla/test.h" 27 #include "tensorflow/compiler/xla/test_helpers.h" 28 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 29 #include "tensorflow/compiler/xla/xla_data.pb.h" 30 #include "tensorflow/core/platform/logging.h" 31 #include "tensorflow/core/platform/test.h" 32 33 namespace op = xla::testing::opcode_matchers; 34 35 namespace xla { 36 namespace { 37 38 using ::testing::UnorderedElementsAre; 39 using ::testing::UnorderedElementsAreArray; 40 41 class TuplePointsToAnalysisTest : public HloTestBase { 42 protected: 43 // Builds a module with the given entry computation and runs points to 44 // analysis. 45 void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) { 46 BuildModule(std::move(computation)); 47 RunAnalysis(); 48 } 49 50 void BuildModule(std::unique_ptr<HloComputation> computation) { 51 module_ = CreateNewModule(); 52 module_->AddEntryComputation(std::move(computation)); 53 } 54 55 void RunAnalysis() { 56 CHECK_NOTNULL(module_.get()); 57 points_to_analysis_ = 58 TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); 59 } 60 61 // Returns the LogicalBuffer defined at the given instruction and 62 // index. CHECKs if no buffer is defined at that point. 63 const LogicalBuffer* const GetBuffer(const HloInstruction* instruction, 64 const ShapeIndex& index) { 65 const auto& pointed_to = 66 points_to_analysis_->GetPointsToSet(instruction).element(index); 67 CHECK_EQ(1, pointed_to.size()); 68 CHECK_EQ(instruction, pointed_to[0]->instruction()); 69 CHECK(index == pointed_to[0]->index()); 70 return pointed_to[0]; 71 } 72 73 // Checks that the given points-to set contains exactly (unordered) the given 74 // LogicalBuffers. 75 void ExpectHasBuffers( 76 const PointsToSet::BufferList& points_to_set, 77 tensorflow::gtl::ArraySlice<const LogicalBuffer*> buffers) { 78 std::vector<const LogicalBuffer*> vec(buffers.begin(), buffers.end()); 79 EXPECT_THAT(points_to_set, UnorderedElementsAreArray(vec)); 80 } 81 82 // Checks that the given points-to set contains exactly (unordered) the 83 // top-level buffers of the given instructions. 84 void ExpectHasTopLevelBuffers( 85 const PointsToSet::BufferList& points_to_set, 86 tensorflow::gtl::ArraySlice<HloInstruction*> instructions) { 87 PointsToSet::BufferList buffers; 88 for (auto instruction : instructions) { 89 buffers.push_back(GetBuffer(instruction, /*index=*/{})); 90 } 91 ExpectHasBuffers(points_to_set, buffers); 92 } 93 94 // Overload which takes a set instead of a vector. 95 void ExpectHasTopLevelBuffers( 96 const PointsToSet::BufferSet& points_to_set, 97 tensorflow::gtl::ArraySlice<HloInstruction*> instructions) { 98 ExpectHasTopLevelBuffers( 99 PointsToSet::BufferList(points_to_set.begin(), points_to_set.end()), 100 instructions); 101 } 102 103 // Checks that the buffer defined at the given instruction and index has 104 // aliases which are exactly (unordered) the given instruction/index pairs. 105 void ExpectHasBufferAliases( 106 const HloInstruction* instruction, const ShapeIndex& index, 107 tensorflow::gtl::ArraySlice<std::pair<HloInstruction*, ShapeIndex>> 108 expected) { 109 const LogicalBuffer* buffer = 110 points_to_analysis_->GetBufferDefinedAt(instruction, index) 111 .ValueOrDie(); 112 std::vector<BufferAlias> expected_aliases; 113 for (auto& pair : expected) { 114 expected_aliases.push_back(BufferAlias(pair.first, pair.second)); 115 } 116 EXPECT_THAT(points_to_analysis_->GetBufferAliases(*buffer), 117 UnorderedElementsAreArray(expected_aliases)); 118 } 119 120 std::unique_ptr<HloModule> module_; 121 std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_; 122 }; 123 124 TEST_F(TuplePointsToAnalysisTest, SimpleTuple) { 125 auto builder = HloComputation::Builder(TestName()); 126 auto constant1 = builder.AddInstruction( 127 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 128 auto constant2 = builder.AddInstruction( 129 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 130 auto tuple = builder.AddInstruction( 131 HloInstruction::CreateTuple({constant1, constant2})); 132 133 BuildModuleAndRunAnalysis(builder.Build()); 134 EXPECT_EQ(1, points_to_analysis_->GetPointsToSet(constant1).size()); 135 ExpectHasTopLevelBuffers( 136 points_to_analysis_->GetPointsToSet(constant1).element({}), {constant1}); 137 EXPECT_TRUE( 138 points_to_analysis_->GetPointsToSet(constant1).tuple_sources({}).empty()); 139 EXPECT_TRUE(points_to_analysis_->GetPointsToSet(tuple).IsDistinct()); 140 141 EXPECT_EQ(1, points_to_analysis_->GetPointsToSet(constant2).size()); 142 ExpectHasTopLevelBuffers( 143 points_to_analysis_->GetPointsToSet(constant2).element({}), {constant2}); 144 EXPECT_TRUE( 145 points_to_analysis_->GetPointsToSet(constant2).tuple_sources({}).empty()); 146 147 EXPECT_EQ(3, points_to_analysis_->GetPointsToSet(tuple).size()); 148 EXPECT_FALSE(points_to_analysis_->GetPointsToSet(tuple).IsAmbiguous()); 149 EXPECT_THAT(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({}), 150 UnorderedElementsAre(tuple)); 151 152 ExpectHasTopLevelBuffers( 153 points_to_analysis_->GetPointsToSet(tuple).CreateFlattenedSet(), 154 {constant1, constant2, tuple}); 155 ExpectHasTopLevelBuffers( 156 points_to_analysis_->GetPointsToSet(tuple).element({}), {tuple}); 157 ExpectHasTopLevelBuffers( 158 points_to_analysis_->GetPointsToSet(tuple).element({0}), {constant1}); 159 ExpectHasTopLevelBuffers( 160 points_to_analysis_->GetPointsToSet(tuple).element({1}), {constant2}); 161 162 const PointsToSet& tuple_points_to_set = 163 points_to_analysis_->GetPointsToSet(tuple); 164 EXPECT_TRUE(tuple_points_to_set.ContainsBufferAtIndex( 165 *GetBuffer(constant1, {}), {0})); 166 EXPECT_TRUE(tuple_points_to_set.ContainsBufferAtIndex( 167 *GetBuffer(constant2, {}), {1})); 168 EXPECT_FALSE(tuple_points_to_set.ContainsBufferAtIndex( 169 *GetBuffer(constant2, {}), {0})); 170 EXPECT_TRUE(tuple_points_to_set.ContainsBuffer(*GetBuffer(constant1, {}))); 171 EXPECT_TRUE(tuple_points_to_set.ContainsBuffer(*GetBuffer(constant2, {}))); 172 } 173 174 TEST_F(TuplePointsToAnalysisTest, NestedTuple) { 175 // Create a (nested) tuple containing an inner tuple. The points-to set of the 176 // outer tuple should contain all elements of the points-to set of the inner 177 // tuple. 178 auto builder = HloComputation::Builder(TestName()); 179 auto constant1 = builder.AddInstruction( 180 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 181 auto constant2 = builder.AddInstruction( 182 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 183 auto inner_tuple = builder.AddInstruction( 184 HloInstruction::CreateTuple({constant1, constant2})); 185 186 auto constant3 = builder.AddInstruction( 187 HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0))); 188 auto tuple = builder.AddInstruction( 189 HloInstruction::CreateTuple({inner_tuple, constant3})); 190 191 BuildModuleAndRunAnalysis(builder.Build()); 192 ExpectHasTopLevelBuffers( 193 points_to_analysis_->GetPointsToSet(constant1).element({}), {constant1}); 194 ExpectHasTopLevelBuffers( 195 points_to_analysis_->GetPointsToSet(constant2).element({}), {constant2}); 196 ExpectHasTopLevelBuffers( 197 points_to_analysis_->GetPointsToSet(constant3).element({}), {constant3}); 198 199 EXPECT_EQ(3, points_to_analysis_->GetPointsToSet(inner_tuple).size()); 200 EXPECT_FALSE(points_to_analysis_->GetPointsToSet(inner_tuple).IsAmbiguous()); 201 EXPECT_TRUE(points_to_analysis_->GetPointsToSet(inner_tuple).IsDistinct()); 202 ExpectHasTopLevelBuffers( 203 points_to_analysis_->GetPointsToSet(inner_tuple).CreateFlattenedSet(), 204 {constant1, constant2, inner_tuple}); 205 ExpectHasTopLevelBuffers( 206 points_to_analysis_->GetPointsToSet(inner_tuple).element({}), 207 {inner_tuple}); 208 EXPECT_THAT( 209 points_to_analysis_->GetPointsToSet(inner_tuple).tuple_sources({}), 210 UnorderedElementsAre(inner_tuple)); 211 212 EXPECT_EQ(5, points_to_analysis_->GetPointsToSet(tuple).size()); 213 EXPECT_FALSE(points_to_analysis_->GetPointsToSet(tuple).IsAmbiguous()); 214 ExpectHasTopLevelBuffers( 215 points_to_analysis_->GetPointsToSet(tuple).CreateFlattenedSet(), 216 {constant1, constant2, constant3, inner_tuple, tuple}); 217 218 EXPECT_THAT(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({}), 219 UnorderedElementsAre(tuple)); 220 EXPECT_THAT(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({0}), 221 UnorderedElementsAre(inner_tuple)); 222 EXPECT_TRUE( 223 points_to_analysis_->GetPointsToSet(tuple).tuple_sources({1}).empty()); 224 225 ExpectHasTopLevelBuffers( 226 points_to_analysis_->GetPointsToSet(tuple).element({0}), {inner_tuple}); 227 ExpectHasTopLevelBuffers( 228 points_to_analysis_->GetPointsToSet(tuple).element({0, 0}), {constant1}); 229 ExpectHasTopLevelBuffers( 230 points_to_analysis_->GetPointsToSet(tuple).element({0, 1}), {constant2}); 231 ExpectHasTopLevelBuffers( 232 points_to_analysis_->GetPointsToSet(tuple).element({1}), {constant3}); 233 } 234 235 TEST_F(TuplePointsToAnalysisTest, GetTupleElement) { 236 // Create a nested tuple, then extract the inner tuple with GetTupleElement. 237 // The points-to set of the GetTupleElement should be the same as the inner 238 // tuple. 239 auto builder = HloComputation::Builder(TestName()); 240 auto constant1 = builder.AddInstruction( 241 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 242 auto constant2 = builder.AddInstruction( 243 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 244 auto inner_tuple = builder.AddInstruction( 245 HloInstruction::CreateTuple({constant1, constant2})); 246 247 auto constant3 = builder.AddInstruction( 248 HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0))); 249 auto tuple = builder.AddInstruction( 250 HloInstruction::CreateTuple({inner_tuple, constant3})); 251 252 auto get_tuple_element = builder.AddInstruction( 253 HloInstruction::CreateGetTupleElement(inner_tuple->shape(), tuple, 0)); 254 255 BuildModuleAndRunAnalysis(builder.Build()); 256 257 auto& points_to_set = points_to_analysis_->GetPointsToSet(get_tuple_element); 258 EXPECT_EQ(3, points_to_set.size()); 259 EXPECT_FALSE(points_to_set.IsAmbiguous()); 260 EXPECT_TRUE(points_to_set.IsDistinct()); 261 ExpectHasTopLevelBuffers(points_to_set.CreateFlattenedSet(), 262 {constant1, constant2, inner_tuple}); 263 ExpectHasTopLevelBuffers(points_to_set.element({}), {inner_tuple}); 264 265 EXPECT_THAT(points_to_set.tuple_sources({}), 266 UnorderedElementsAre(inner_tuple)); 267 } 268 269 TEST_F(TuplePointsToAnalysisTest, DuplicatedElement) { 270 // Create a tuple which contains duplicate elements. 271 auto builder = HloComputation::Builder(TestName()); 272 auto constant = builder.AddInstruction( 273 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 274 auto tuple = builder.AddInstruction( 275 HloInstruction::CreateTuple({constant, constant, constant})); 276 277 BuildModuleAndRunAnalysis(builder.Build()); 278 279 EXPECT_EQ(2, points_to_analysis_->GetPointsToSet(tuple).size()); 280 EXPECT_FALSE(points_to_analysis_->GetPointsToSet(tuple).IsAmbiguous()); 281 EXPECT_FALSE(points_to_analysis_->GetPointsToSet(tuple).IsDistinct()); 282 ExpectHasTopLevelBuffers( 283 points_to_analysis_->GetPointsToSet(tuple).element({}), {tuple}); 284 ExpectHasTopLevelBuffers( 285 points_to_analysis_->GetPointsToSet(tuple).CreateFlattenedSet(), 286 {constant, tuple}); 287 } 288 289 TEST_F(TuplePointsToAnalysisTest, TupleCopy) { 290 // Create a copy (HloOpcode::kCopy) of a tuple. The points to sets should be 291 // the same. 292 auto builder = HloComputation::Builder(TestName()); 293 auto constant1 = builder.AddInstruction( 294 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 295 auto constant2 = builder.AddInstruction( 296 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 297 auto tuple = builder.AddInstruction( 298 HloInstruction::CreateTuple({constant1, constant2})); 299 auto copy = builder.AddInstruction( 300 HloInstruction::CreateUnary(tuple->shape(), HloOpcode::kCopy, tuple)); 301 302 BuildModuleAndRunAnalysis(builder.Build()); 303 304 EXPECT_FALSE(points_to_analysis_->GetPointsToSet(copy).IsAmbiguous()); 305 EXPECT_TRUE(points_to_analysis_->GetPointsToSet(copy).IsDistinct()); 306 ExpectHasTopLevelBuffers( 307 points_to_analysis_->GetPointsToSet(tuple).CreateFlattenedSet(), 308 {constant1, constant2, tuple}); 309 ExpectHasTopLevelBuffers( 310 points_to_analysis_->GetPointsToSet(copy).element({}), {copy}); 311 ExpectHasTopLevelBuffers( 312 points_to_analysis_->GetPointsToSet(copy).CreateFlattenedSet(), 313 {constant1, constant2, copy}); 314 } 315 316 TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) { 317 // Send forwards its operand to the output tuple at {0}. 318 auto builder = HloComputation::Builder(TestName()); 319 auto constant = builder.AddInstruction( 320 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 321 auto send = builder.AddInstruction( 322 HloInstruction::CreateSend(constant, /*channel_id=*/0)); 323 auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); 324 325 BuildModuleAndRunAnalysis(builder.Build()); 326 327 EXPECT_FALSE(points_to_analysis_->GetPointsToSet(send).IsAmbiguous()); 328 EXPECT_TRUE(points_to_analysis_->GetPointsToSet(send).IsDistinct()); 329 EXPECT_FALSE(points_to_analysis_->GetPointsToSet(send_done).IsAmbiguous()); 330 EXPECT_TRUE(points_to_analysis_->GetPointsToSet(send_done).IsDistinct()); 331 332 ExpectHasTopLevelBuffers( 333 points_to_analysis_->GetPointsToSet(send).element({}), {send}); 334 ExpectHasTopLevelBuffers( 335 points_to_analysis_->GetPointsToSet(send).element({0}), {constant}); 336 ExpectHasTopLevelBuffers( 337 points_to_analysis_->GetPointsToSet(send_done).CreateFlattenedSet(), 338 {send_done}); 339 ExpectHasBufferAliases(constant, {}, {{constant, {}}, {send, {0}}}); 340 } 341 342 TEST_F(TuplePointsToAnalysisTest, RecvAndRecvDone) { 343 // RecvDone forwards its operand tuple element at {0} to the output. 344 auto builder = HloComputation::Builder(TestName()); 345 auto recv = builder.AddInstruction(HloInstruction::CreateRecv( 346 ShapeUtil::MakeShape(F32, {1, 2, 3}), /*channel_id=*/0)); 347 auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); 348 349 BuildModuleAndRunAnalysis(builder.Build()); 350 351 EXPECT_FALSE(points_to_analysis_->GetPointsToSet(recv).IsAmbiguous()); 352 EXPECT_TRUE(points_to_analysis_->GetPointsToSet(recv).IsDistinct()); 353 EXPECT_FALSE(points_to_analysis_->GetPointsToSet(recv_done).IsAmbiguous()); 354 EXPECT_TRUE(points_to_analysis_->GetPointsToSet(recv_done).IsDistinct()); 355 356 ExpectHasTopLevelBuffers( 357 points_to_analysis_->GetPointsToSet(recv).element({}), {recv}); 358 ExpectHasBufferAliases(recv, {0}, {{recv, {0}}, {recv_done, {}}}); 359 } 360 361 TEST_F(TuplePointsToAnalysisTest, TupleSelect) { 362 // Select from two different tuples. This should create an ambiguous points to 363 // set containing the union of both sides. 364 auto builder = HloComputation::Builder(TestName()); 365 auto constant1 = builder.AddInstruction( 366 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 367 auto constant2 = builder.AddInstruction( 368 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 369 auto tuple1 = builder.AddInstruction( 370 HloInstruction::CreateTuple({constant1, constant2})); 371 auto tuple2 = builder.AddInstruction( 372 HloInstruction::CreateTuple({constant2, constant2})); 373 374 auto pred = builder.AddInstruction( 375 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 376 auto select = builder.AddInstruction(HloInstruction::CreateTernary( 377 tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); 378 379 BuildModuleAndRunAnalysis(builder.Build()); 380 381 auto& points_to_set = points_to_analysis_->GetPointsToSet(select); 382 EXPECT_EQ(3, points_to_set.size()); 383 EXPECT_TRUE(points_to_set.IsAmbiguous()); 384 EXPECT_FALSE(points_to_set.IsDistinct()); 385 ExpectHasTopLevelBuffers(points_to_set.element({}), {select}); 386 ExpectHasTopLevelBuffers(points_to_set.element({0}), {constant1, constant2}); 387 ExpectHasTopLevelBuffers(points_to_set.element({1}), {constant2}); 388 ExpectHasTopLevelBuffers(points_to_set.CreateFlattenedSet(), 389 {constant1, constant2, select}); 390 } 391 392 TEST_F(TuplePointsToAnalysisTest, SelectTupleParameters) { 393 // Create a Select which selects between two tuple parameters. Verify the 394 // points-to sets and tuple sources are properly set. 395 Shape tuple_shape = ShapeUtil::MakeTupleShape( 396 {ShapeUtil::MakeShape(F32, {1, 2, 3}), ShapeUtil::MakeShape(U32, {5})}); 397 398 auto builder = HloComputation::Builder(TestName()); 399 auto param0 = builder.AddInstruction( 400 HloInstruction::CreateParameter(0, tuple_shape, "param0")); 401 auto param1 = builder.AddInstruction( 402 HloInstruction::CreateParameter(1, tuple_shape, "param1")); 403 auto pred = builder.AddInstruction( 404 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 405 auto select = builder.AddInstruction(HloInstruction::CreateTernary( 406 tuple_shape, HloOpcode::kSelect, pred, param0, param1)); 407 auto copy = builder.AddInstruction( 408 HloInstruction::CreateUnary(tuple_shape, HloOpcode::kCopy, select)); 409 410 BuildModuleAndRunAnalysis(builder.Build()); 411 412 // The points-to set of each element of a tuple parameters should be itself 413 // with the appropriate index. 414 ExpectHasBuffers(points_to_analysis_->GetPointsToSet(param0).element({}), 415 {GetBuffer(param0, {})}); 416 ExpectHasBuffers(points_to_analysis_->GetPointsToSet(param0).element({0}), 417 {GetBuffer(param0, {0})}); 418 ExpectHasBuffers(points_to_analysis_->GetPointsToSet(param0).element({1}), 419 {GetBuffer(param0, {1})}); 420 421 // Select's point-to set of its subelements should be the respective 422 // subelements of param0 and param1. The top-level buffer, however, does not 423 // alias as it is created by the select instruction. 424 ExpectHasBuffers(points_to_analysis_->GetPointsToSet(select).element({}), 425 {GetBuffer(select, {})}); 426 ExpectHasBuffers(points_to_analysis_->GetPointsToSet(select).element({0}), 427 {GetBuffer(param0, {0}), GetBuffer(param1, {0})}); 428 ExpectHasBuffers(points_to_analysis_->GetPointsToSet(select).element({1}), 429 {GetBuffer(param0, {1}), GetBuffer(param1, {1})}); 430 431 // Copy should be identical to select other than the top-level buffer. 432 ExpectHasBuffers(points_to_analysis_->GetPointsToSet(copy).element({}), 433 {GetBuffer(copy, {})}); 434 ExpectHasBuffers(points_to_analysis_->GetPointsToSet(copy).element({0}), 435 {GetBuffer(param0, {0}), GetBuffer(param1, {0})}); 436 ExpectHasBuffers(points_to_analysis_->GetPointsToSet(copy).element({1}), 437 {GetBuffer(param0, {1}), GetBuffer(param1, {1})}); 438 } 439 440 TEST_F(TuplePointsToAnalysisTest, UnambiguousTupleSelect) { 441 // Select from two identical tuples. The result should not be ambiguous. 442 auto builder = HloComputation::Builder(TestName()); 443 auto constant1 = builder.AddInstruction( 444 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 445 auto constant2 = builder.AddInstruction( 446 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 447 auto tuple1 = builder.AddInstruction( 448 HloInstruction::CreateTuple({constant1, constant2})); 449 auto tuple2 = builder.AddInstruction( 450 HloInstruction::CreateTuple({constant1, constant2})); 451 452 auto pred = builder.AddInstruction( 453 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 454 auto select = builder.AddInstruction(HloInstruction::CreateTernary( 455 tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); 456 457 BuildModuleAndRunAnalysis(builder.Build()); 458 459 auto& points_to_set = points_to_analysis_->GetPointsToSet(select); 460 EXPECT_EQ(3, points_to_set.size()); 461 EXPECT_FALSE(points_to_set.IsAmbiguous()); 462 EXPECT_TRUE(points_to_set.IsDistinct()); 463 ExpectHasTopLevelBuffers(points_to_set.element({}), {select}); 464 ExpectHasTopLevelBuffers(points_to_set.element({0}), {constant1}); 465 ExpectHasTopLevelBuffers(points_to_set.element({1}), {constant2}); 466 ExpectHasTopLevelBuffers(points_to_set.CreateFlattenedSet(), 467 {constant1, constant2, select}); 468 } 469 470 TEST_F(TuplePointsToAnalysisTest, NestedTupleSelect) { 471 // Select from nested tuples. Verify that the nested points-to sets contain 472 // the right values. 473 auto builder = HloComputation::Builder(TestName()); 474 auto constant1 = builder.AddInstruction( 475 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 476 auto constant2 = builder.AddInstruction( 477 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 478 auto inner_tuple1 = builder.AddInstruction( 479 HloInstruction::CreateTuple({constant1, constant2})); 480 auto inner_tuple2 = builder.AddInstruction( 481 HloInstruction::CreateTuple({constant2, constant2})); 482 483 auto tuple1 = 484 builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple1})); 485 auto tuple2 = 486 builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple2})); 487 488 auto pred = builder.AddInstruction( 489 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 490 auto select = builder.AddInstruction(HloInstruction::CreateTernary( 491 tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); 492 493 BuildModuleAndRunAnalysis(builder.Build()); 494 495 auto& points_to_set = points_to_analysis_->GetPointsToSet(select); 496 EXPECT_EQ(5, points_to_set.size()); 497 EXPECT_TRUE(points_to_set.IsAmbiguous()); 498 EXPECT_FALSE(points_to_set.IsDistinct()); 499 500 // Verify points-to set. 501 ExpectHasTopLevelBuffers(points_to_set.element({}), {select}); 502 ExpectHasTopLevelBuffers(points_to_set.element({0}), 503 {inner_tuple1, inner_tuple2}); 504 ExpectHasTopLevelBuffers(points_to_set.element({0, 0}), 505 {constant1, constant2}); 506 ExpectHasTopLevelBuffers(points_to_set.element({0, 1}), {constant2}); 507 508 // Verify tuple sources. 509 EXPECT_THAT(points_to_set.tuple_sources({}), 510 UnorderedElementsAre(tuple1, tuple2)); 511 EXPECT_THAT(points_to_set.tuple_sources({0}), 512 UnorderedElementsAre(inner_tuple1, inner_tuple2)); 513 EXPECT_EQ(0, points_to_set.tuple_sources({0, 0}).size()); 514 EXPECT_EQ(0, points_to_set.tuple_sources({0, 1}).size()); 515 } 516 517 TEST_F(TuplePointsToAnalysisTest, TupleWithBitcast) { 518 // Bitcast is an alias of its operand. A tuple with a bitcast element should 519 // have the operand of the bitcast in its points-to set. 520 auto builder = HloComputation::Builder(TestName()); 521 auto constant1 = builder.AddInstruction( 522 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 523 auto constant2 = builder.AddInstruction( 524 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 525 auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary( 526 constant2->shape(), HloOpcode::kBitcast, constant2)); 527 auto tuple = 528 builder.AddInstruction(HloInstruction::CreateTuple({constant1, bitcast})); 529 530 BuildModuleAndRunAnalysis(builder.Build()); 531 532 EXPECT_EQ(1, points_to_analysis_->GetPointsToSet(bitcast).size()); 533 ExpectHasTopLevelBuffers( 534 points_to_analysis_->GetPointsToSet(bitcast).element({}), {constant2}); 535 EXPECT_TRUE( 536 points_to_analysis_->GetPointsToSet(bitcast).tuple_sources({}).empty()); 537 538 EXPECT_EQ(3, points_to_analysis_->GetPointsToSet(tuple).size()); 539 EXPECT_FALSE(points_to_analysis_->GetPointsToSet(tuple).IsAmbiguous()); 540 EXPECT_THAT(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({}), 541 UnorderedElementsAre(tuple)); 542 543 ExpectHasTopLevelBuffers( 544 points_to_analysis_->GetPointsToSet(tuple).CreateFlattenedSet(), 545 {constant1, constant2, tuple}); 546 ExpectHasTopLevelBuffers( 547 points_to_analysis_->GetPointsToSet(tuple).element({}), {tuple}); 548 ExpectHasTopLevelBuffers( 549 points_to_analysis_->GetPointsToSet(tuple).element({0}), {constant1}); 550 ExpectHasTopLevelBuffers( 551 points_to_analysis_->GetPointsToSet(tuple).element({1}), {constant2}); 552 } 553 554 TEST_F(TuplePointsToAnalysisTest, PointsToTupleConstantElements) { 555 // Construct a tuple constant and kCopy it. Verify the points-to set of the 556 // copy correctly correctly points into the nested elements of the constant. 557 auto builder = HloComputation::Builder(TestName()); 558 auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant( 559 Literal::MakeTuple({Literal::CreateR2<float>({{1.0}, {2.0}}).get(), 560 Literal::CreateR1<float>({2.0, 42}).get()}))); 561 auto copy = builder.AddInstruction(HloInstruction::CreateUnary( 562 tuple_constant->shape(), HloOpcode::kCopy, tuple_constant)); 563 564 BuildModuleAndRunAnalysis(builder.Build()); 565 566 auto& points_to_set = points_to_analysis_->GetPointsToSet(copy); 567 568 ExpectHasBuffers(points_to_set.element({}), {GetBuffer(copy, {})}); 569 ExpectHasBuffers(points_to_set.element({0}), 570 {GetBuffer(tuple_constant, {0})}); 571 ExpectHasBuffers(points_to_set.element({1}), 572 {GetBuffer(tuple_constant, {1})}); 573 } 574 575 TEST_F(TuplePointsToAnalysisTest, BufferAliases) { 576 // Create a nested tuple in which individual elements appear multiple 577 // times. Verify buffer alias sets. 578 auto builder = HloComputation::Builder(TestName()); 579 auto constant1 = builder.AddInstruction( 580 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 581 auto constant2 = builder.AddInstruction( 582 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); 583 auto inner_tuple = builder.AddInstruction( 584 HloInstruction::CreateTuple({constant1, constant2})); 585 auto tuple = builder.AddInstruction( 586 HloInstruction::CreateTuple({inner_tuple, constant2})); 587 588 BuildModuleAndRunAnalysis(builder.Build()); 589 590 ExpectHasBufferAliases( 591 constant1, /*index=*/{}, 592 {{constant1, {}}, {inner_tuple, {0}}, {tuple, {0, 0}}}); 593 ExpectHasBufferAliases( 594 constant2, /*index=*/{}, 595 {{constant2, {}}, {inner_tuple, {1}}, {tuple, {0, 1}}, {tuple, {1}}}); 596 ExpectHasBufferAliases(inner_tuple, /*index=*/{}, 597 {{inner_tuple, {}}, {tuple, {0}}}); 598 ExpectHasBufferAliases(tuple, /*index=*/{}, {{tuple, {}}}); 599 } 600 601 class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest { 602 protected: 603 // Builds a computation, runs instruction fusion HloPass, runs points-to 604 // analysis, then checks for expected results (see unit test cases for 605 // example computation graphs). 606 void Run(const bool add_additional_gte0_user) { 607 Shape input_shape = ShapeUtil::MakeShape(F32, {8}); 608 Shape update_shape = ShapeUtil::MakeShape(F32, {3}); 609 Shape starts_shape = ShapeUtil::MakeShape(S32, {1}); 610 Shape tuple_shape = 611 ShapeUtil::MakeTupleShape({input_shape, update_shape, starts_shape}); 612 613 auto builder = HloComputation::Builder(TestName()); 614 // Create tuple-shaped parameter. 615 auto tuple_param0 = builder.AddInstruction( 616 HloInstruction::CreateParameter(0, tuple_shape, "param0")); 617 // Create 'tuple_element1' = GetTupleElement(tuple_param0, 1). 618 auto tuple_element1 = builder.AddInstruction( 619 HloInstruction::CreateGetTupleElement(update_shape, tuple_param0, 1)); 620 auto ones = builder.AddInstruction(HloInstruction::CreateConstant( 621 Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f}))); 622 // Create 'update' = Add(GetTupleElement(tuple_param0, 1), ones) 623 auto update = builder.AddInstruction(HloInstruction::CreateBinary( 624 update_shape, HloOpcode::kAdd, tuple_element1, ones)); 625 // Create 'input' = GetTupleElement(tuple_param0, 0). 626 auto input = builder.AddInstruction( 627 HloInstruction::CreateGetTupleElement(input_shape, tuple_param0, 0)); 628 629 if (add_additional_gte0_user) { 630 // Create 'slice' as an additional user of 'input'. 631 auto slice = builder.AddInstruction( 632 HloInstruction::CreateSlice(update_shape, input, {0}, {3}, {1})); 633 // Modify 'update' to take 'slice' output. 634 update = builder.AddInstruction(HloInstruction::CreateBinary( 635 update_shape, HloOpcode::kAdd, update, slice)); 636 } 637 638 // Create slice 'starts' = GetTupleElement(tuple_param0, 2). 639 auto starts = builder.AddInstruction( 640 HloInstruction::CreateGetTupleElement(starts_shape, tuple_param0, 2)); 641 // Update 'input' with 'update' at dynamic 'starts' indices. 642 builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( 643 input_shape, input, update, starts)); 644 645 // Build computation and add it to module as entry computation. 646 BuildModule(builder.Build()); 647 // Run instruction fusion HloPass. 648 EXPECT_TRUE(InstructionFusion(InstructionFusion::IsExpensive) 649 .Run(module_.get()) 650 .ValueOrDie()); 651 // Get computation root instruction (should be a kFusion). 652 auto* fusion = module_->entry_computation()->root_instruction(); 653 EXPECT_THAT(fusion, op::Fusion(tuple_param0)); 654 // Run points-to analysis (should include fused instructions from 'fusion'). 655 RunAnalysis(); 656 657 // Check points-to set of fusion parameter associated with 'tuple_param0'. 658 auto* fusion_param = GetFusionParameterForOperand(fusion, tuple_param0); 659 ExpectHasBuffers( 660 points_to_analysis_->GetPointsToSet(fusion_param).element({}), 661 {GetBuffer(fusion_param, {})}); 662 ExpectHasBuffers( 663 points_to_analysis_->GetPointsToSet(fusion_param).element({0}), 664 {GetBuffer(fusion_param, {0})}); 665 ExpectHasBuffers( 666 points_to_analysis_->GetPointsToSet(fusion_param).element({1}), 667 {GetBuffer(fusion_param, {1})}); 668 ExpectHasBuffers( 669 points_to_analysis_->GetPointsToSet(fusion_param).element({2}), 670 {GetBuffer(fusion_param, {2})}); 671 672 // Check that Gte at tuple_index = 0 points-to fusion_param({0}) 673 auto fused_gte0 = GetUniqueFusionParameterUserAt(fusion_param, 0); 674 ExpectHasBuffers( 675 points_to_analysis_->GetPointsToSet(fused_gte0).element({}), 676 {GetBuffer(fusion_param, {0})}); 677 // Check that Gte at tuple_index = 1 points-to fusion_param({1}) 678 auto fused_gte1 = GetUniqueFusionParameterUserAt(fusion_param, 1); 679 ExpectHasBuffers( 680 points_to_analysis_->GetPointsToSet(fused_gte1).element({}), 681 {GetBuffer(fusion_param, {1})}); 682 // Check that Gte at tuple_index = 2 points-to fusion_param({2}) 683 auto fused_gte2 = GetUniqueFusionParameterUserAt(fusion_param, 2); 684 ExpectHasBuffers( 685 points_to_analysis_->GetPointsToSet(fused_gte2).element({}), 686 {GetBuffer(fusion_param, {2})}); 687 688 // Check buffer aliases of 'fusion_param' at shape index {0}. 689 ExpectHasBufferAliases(fusion_param, /*index=*/{0}, 690 {{fusion_param, {0}}, {fused_gte0, {}}}); 691 // Check buffer aliases of 'fusion_param' at shape index {1}. 692 ExpectHasBufferAliases(fusion_param, /*index=*/{1}, 693 {{fusion_param, {1}}, {fused_gte1, {}}}); 694 // Check buffer aliases of 'fusion_param' at shape index {2}. 695 ExpectHasBufferAliases(fusion_param, /*index=*/{2}, 696 {{fusion_param, {2}}, {fused_gte2, {}}}); 697 698 // Check number of users of 'fusion_param' aliases at shape index {0}. 699 ExpectNumUsersOfAliases(fusion_param, {0}, 700 add_additional_gte0_user ? 2 : 1); 701 } 702 703 // Returns fusion parameter (from 'fusion.fused_instructions') corresponding 704 // to fusion 'operand'. 705 HloInstruction* GetFusionParameterForOperand(HloInstruction* fusion, 706 HloInstruction* operand) { 707 auto it = std::find_if( 708 fusion->fused_instructions().begin(), 709 fusion->fused_instructions().end(), [=](const HloInstruction* fused) { 710 return fused->opcode() == HloOpcode::kParameter && 711 fusion->operand(fused->parameter_number()) == operand; 712 }); 713 CHECK(it != fusion->fused_instructions().end()); 714 return *it; 715 } 716 717 // Returns all users of 'fusion_paran' at 'tuple_index'. 718 std::vector<HloInstruction*> GetFusionParameterUsersAt( 719 HloInstruction* fusion_param, int64 tuple_index) { 720 CHECK(ShapeUtil::IsTuple(fusion_param->shape())); 721 std::vector<HloInstruction*> users_at_tuple_index; 722 for (auto user : fusion_param->users()) { 723 CHECK_EQ(HloOpcode::kGetTupleElement, user->opcode()); 724 if (user->tuple_index() == tuple_index) { 725 users_at_tuple_index.push_back(user); 726 } 727 } 728 return users_at_tuple_index; 729 } 730 731 // Returns the unique user of 'fusion_param' at 'tuple_index'. 732 HloInstruction* GetUniqueFusionParameterUserAt(HloInstruction* fusion_param, 733 int64 tuple_index) { 734 std::vector<HloInstruction*> users = 735 GetFusionParameterUsersAt(fusion_param, tuple_index); 736 CHECK_EQ(1, users.size()); 737 return users[0]; 738 } 739 740 // Checks that the count of all users of all aliases of 'instruction' at 741 // 'index' match 'expected_num_users'. 742 void ExpectNumUsersOfAliases(const HloInstruction* instruction, 743 const ShapeIndex& index, 744 const int64 expected_num_users) { 745 const auto* buffer = GetBuffer(instruction, index); 746 int64 num_users = 0; 747 for (const auto& alias : points_to_analysis_->GetBufferAliases(*buffer)) { 748 for (auto user : alias.instruction()->users()) { 749 if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) { 750 // Gte instructions only access the top-level buffer of their operand. 751 continue; 752 } 753 ++num_users; 754 } 755 } 756 EXPECT_EQ(expected_num_users, num_users); 757 } 758 }; 759 760 // Tests the points-to set of tuple-shaped fusion parameter 0 and all GTE users. 761 // Tests the alias set of tuple-shaped fusion parameter 0 at all shape indices. 762 // Tests that there is a single user of the aliases of tuple-shaped fusion 763 // parameter 0 at shape index {0}. 764 // 765 // Param0 Const 766 // \ / 767 // Fusion 768 // / \ 769 // FusionParam0 FusionParam1 770 // / | \ | 771 // Gte(0) Gte(2) Gte(1) / 772 // \ | \ / 773 // \ | Add 774 // \ | / 775 // \0 |2 /1 776 // DynamicUpdateSlice // fused root. 777 // 778 TEST_F(FusionPointsToAnalysisTest, FusionParam0OneUser) { 779 Run(/*add_additional_gte0_user=*/false); 780 } 781 782 // Tests the points-to set of tuple-shaped fusion parameter 0 and all GTE users. 783 // Tests the alias set of tuple-shaped fusion parameter 0 at all shape indices. 784 // Tests that there are two users of the aliases of tuple-shaped fusion 785 // parameter 0 at shape index {0}. 786 // 787 // Param0 Const 788 // \ / 789 // Fusion 790 // / \ 791 // FusionParam0 FusionParam1 792 // / | \ | 793 // Gte(2) Gte(0) Gte(1) / 794 // \ | \ / 795 // \ |\ Add 796 // \ | \ / 797 // | | Slice / 798 // | | \ / 799 // | | Add 800 // | | | 801 // |2 |0 |1 802 // DynamicUpdateSlice // fused root. 803 // 804 TEST_F(FusionPointsToAnalysisTest, FusionParam0TwoUsers) { 805 Run(/*add_additional_gte0_user=*/true); 806 } 807 808 } // namespace 809 } // namespace xla 810