1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/hlo_computation.h" 17 18 #include <set> 19 20 #include "tensorflow/compiler/xla/literal_util.h" 21 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" 22 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 23 #include "tensorflow/compiler/xla/service/hlo_matchers.h" 24 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 25 #include "tensorflow/compiler/xla/shape_util.h" 26 #include "tensorflow/compiler/xla/test.h" 27 #include "tensorflow/compiler/xla/test_helpers.h" 28 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 29 30 namespace op = xla::testing::opcode_matchers; 31 32 namespace xla { 33 34 namespace { 35 36 using ::testing::ElementsAre; 37 using ::testing::UnorderedElementsAre; 38 39 class HloComputationTest : public HloTestBase { 40 protected: 41 HloComputationTest() {} 42 43 // Create a computation which takes a scalar and returns its negation. 44 std::unique_ptr<HloComputation> CreateNegateComputation() { 45 auto builder = HloComputation::Builder("Negate"); 46 auto param = builder.AddInstruction( 47 HloInstruction::CreateParameter(0, r0f32_, "param0")); 48 builder.AddInstruction( 49 HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param)); 50 return builder.Build(); 51 } 52 53 // Creates a computation which calls map with the given computation. 54 std::unique_ptr<HloComputation> CreateMapComputation( 55 HloComputation* map_computation) { 56 auto builder = HloComputation::Builder("Map"); 57 auto param = builder.AddInstruction( 58 HloInstruction::CreateParameter(0, r0f32_, "param0")); 59 builder.AddInstruction( 60 HloInstruction::CreateMap(r0f32_, {param}, map_computation)); 61 return builder.Build(); 62 } 63 64 Shape r0f32_ = ShapeUtil::MakeShape(F32, {}); 65 }; 66 67 TEST_F(HloComputationTest, GetEmbeddedComputationsEmpty) { 68 auto module = CreateNewModule(); 69 auto negate_computation = 70 module->AddEntryComputation(CreateNegateComputation()); 71 EXPECT_TRUE(negate_computation->MakeEmbeddedComputationsList().empty()); 72 } 73 74 TEST_F(HloComputationTest, GetEmbeddedComputationsOneComputation) { 75 // Create computation which calls one other computation. 76 auto module = CreateNewModule(); 77 auto negate_computation = 78 module->AddEmbeddedComputation(CreateNegateComputation()); 79 auto map_computation = 80 module->AddEntryComputation(CreateMapComputation(negate_computation)); 81 EXPECT_TRUE(negate_computation->MakeEmbeddedComputationsList().empty()); 82 EXPECT_THAT(map_computation->MakeEmbeddedComputationsList(), 83 ElementsAre(negate_computation)); 84 } 85 86 TEST_F(HloComputationTest, GetEmbeddedComputationsDiamond) { 87 // Create computations with a diamond-shaped callgraph. 88 auto module = CreateNewModule(); 89 auto negate_computation = 90 module->AddEmbeddedComputation(CreateNegateComputation()); 91 auto map1_computation = 92 module->AddEmbeddedComputation(CreateMapComputation(negate_computation)); 93 auto map2_computation = 94 module->AddEmbeddedComputation(CreateMapComputation(negate_computation)); 95 96 auto builder = HloComputation::Builder(TestName()); 97 auto param = builder.AddInstruction( 98 HloInstruction::CreateParameter(0, r0f32_, "param0")); 99 auto map1 = builder.AddInstruction( 100 HloInstruction::CreateMap(r0f32_, {param}, map1_computation)); 101 auto map2 = builder.AddInstruction( 102 HloInstruction::CreateMap(r0f32_, {param}, map2_computation)); 103 builder.AddInstruction( 104 HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, map1, map2)); 105 auto computation = module->AddEntryComputation(builder.Build()); 106 107 auto embedded_computations = computation->MakeEmbeddedComputationsList(); 108 EXPECT_EQ(3, embedded_computations.size()); 109 // GetEmbeddedComputations returns a post order of the embedded computations, 110 // so the negate computation must come first. 111 EXPECT_EQ(negate_computation, *embedded_computations.begin()); 112 EXPECT_THAT(embedded_computations, 113 UnorderedElementsAre(negate_computation, map1_computation, 114 map2_computation)); 115 } 116 117 TEST_F(HloComputationTest, PostOrderSingleton) { 118 // Test GetInstructionPostOrder for a computation with one instruction. 119 auto builder = HloComputation::Builder(TestName()); 120 auto constant = builder.AddInstruction( 121 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); 122 auto module = CreateNewModule(); 123 auto computation = module->AddEntryComputation(builder.Build()); 124 EXPECT_THAT(computation->MakeInstructionPostOrder(), ElementsAre(constant)); 125 } 126 127 TEST_F(HloComputationTest, PostOrderSimple) { 128 // Test GetInstructionPostOrder for a computation with a chain of 129 // instructions. 130 auto builder = HloComputation::Builder(TestName()); 131 auto constant = builder.AddInstruction( 132 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); 133 auto negate1 = builder.AddInstruction( 134 HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); 135 auto negate2 = builder.AddInstruction( 136 HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, negate1)); 137 auto module = CreateNewModule(); 138 auto computation = module->AddEntryComputation(builder.Build()); 139 EXPECT_THAT(computation->MakeInstructionPostOrder(), 140 ElementsAre(constant, negate1, negate2)); 141 } 142 143 TEST_F(HloComputationTest, PostOrderTrace) { 144 // Test GetInstructionPostOrder for a computation with a trace instruction. 145 auto builder = HloComputation::Builder(TestName()); 146 auto constant = builder.AddInstruction( 147 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); 148 auto negate1 = builder.AddInstruction( 149 HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); 150 auto trace = 151 builder.AddInstruction(HloInstruction::CreateTrace("foobar", negate1)); 152 auto negate2 = builder.AddInstruction( 153 HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, negate1)); 154 auto module = CreateNewModule(); 155 auto computation = module->AddEntryComputation(builder.Build()); 156 // Trace instructions should be at the end of the sort. 157 EXPECT_THAT(computation->MakeInstructionPostOrder(), 158 ElementsAre(constant, negate1, negate2, trace)); 159 } 160 161 TEST_F(HloComputationTest, PostOrderDisconnectedInstructions) { 162 // Test GetInstructionPostOrder for a computation with multiple instructions 163 // which are not connected. 164 auto builder = HloComputation::Builder(TestName()); 165 auto constant1 = builder.AddInstruction( 166 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); 167 auto constant2 = builder.AddInstruction( 168 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); 169 auto constant3 = builder.AddInstruction( 170 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); 171 auto constant4 = builder.AddInstruction( 172 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); 173 auto module = CreateNewModule(); 174 auto computation = module->AddEntryComputation(builder.Build()); 175 EXPECT_THAT(computation->MakeInstructionPostOrder(), 176 UnorderedElementsAre(constant1, constant2, constant3, constant4)); 177 } 178 179 TEST_F(HloComputationTest, PostOrderWithMultipleRoots) { 180 // Test GetInstructionPostOrder for a computation with multiple instructions 181 // which are not connected. 182 auto builder = HloComputation::Builder(TestName()); 183 auto constant1 = builder.AddInstruction( 184 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); 185 auto constant2 = builder.AddInstruction( 186 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); 187 auto constant3 = builder.AddInstruction( 188 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); 189 auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( 190 r0f32_, HloOpcode::kAdd, constant1, constant2)); 191 auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( 192 r0f32_, HloOpcode::kAdd, constant2, constant3)); 193 auto add3 = builder.AddInstruction(HloInstruction::CreateBinary( 194 r0f32_, HloOpcode::kAdd, constant1, constant3)); 195 auto module = CreateNewModule(); 196 auto computation = module->AddEntryComputation(builder.Build()); 197 auto post_order = computation->MakeInstructionPostOrder(); 198 EXPECT_EQ(6, post_order.size()); 199 EXPECT_THAT(post_order, UnorderedElementsAre(constant1, constant2, constant3, 200 add1, add2, add3)); 201 } 202 203 TEST_F(HloComputationTest, VisitWithMultipleRoots) { 204 // Test that Accept visits all instructions in the computation even if the 205 // computation has multiple roots (dead code). 206 auto builder = HloComputation::Builder(TestName()); 207 auto constant1 = builder.AddInstruction( 208 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); 209 auto constant2 = builder.AddInstruction( 210 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); 211 auto constant3 = builder.AddInstruction( 212 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); 213 // Add three disconnected add expressions. 214 builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, 215 constant1, constant2)); 216 builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, 217 constant2, constant3)); 218 builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, 219 constant1, constant3)); 220 auto module = CreateNewModule(); 221 auto computation = module->AddEntryComputation(builder.Build()); 222 // Visitor which keeps track of which instructions have been visited. 223 class TestVisitor : public DfsHloVisitorWithDefault { 224 public: 225 explicit TestVisitor(HloComputation* computation) 226 : computation_(computation) {} 227 228 Status DefaultAction(HloInstruction* hlo_instruction) override { 229 EXPECT_EQ(0, visited_set_.count(hlo_instruction)); 230 visited_set_.insert(hlo_instruction); 231 last_visited_ = hlo_instruction; 232 return Status::OK(); 233 } 234 235 Status FinishVisit(HloInstruction* root) override { 236 EXPECT_EQ(computation_->root_instruction(), root); 237 ++finish_visit_calls_; 238 return Status::OK(); 239 } 240 241 HloComputation* computation_; 242 std::set<HloInstruction*> visited_set_; 243 int64 finish_visit_calls_ = 0; 244 HloInstruction* last_visited_ = nullptr; 245 }; 246 247 TestVisitor visitor(computation); 248 EXPECT_IS_OK(computation->Accept(&visitor)); 249 250 EXPECT_EQ(6, visitor.visited_set_.size()); 251 EXPECT_EQ(1, visitor.finish_visit_calls_); 252 EXPECT_EQ(computation->root_instruction(), visitor.last_visited_); 253 } 254 255 TEST_F(HloComputationTest, DeepCopyArray) { 256 // Test that DeepCopyInstruction properly copies an array. 257 auto builder = HloComputation::Builder(TestName()); 258 auto constant = builder.AddInstruction(HloInstruction::CreateConstant( 259 Literal::CreateR1<float>({1.0, 2.0, 3.0}))); 260 auto module = CreateNewModule(); 261 auto computation = module->AddEntryComputation(builder.Build()); 262 auto copy = computation->DeepCopyInstruction(constant).ValueOrDie(); 263 264 EXPECT_THAT(copy, op::Copy(constant)); 265 } 266 267 TEST_F(HloComputationTest, DeepCopyTuple) { 268 // Test that DeepCopyInstruction properly copies a tuple. 269 auto builder = HloComputation::Builder(TestName()); 270 auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( 271 Literal::CreateR1<float>({1.0, 2.0, 3.0}))); 272 auto constant2 = builder.AddInstruction( 273 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0))); 274 auto tuple = builder.AddInstruction( 275 HloInstruction::CreateTuple({constant1, constant2})); 276 277 auto module = CreateNewModule(); 278 auto computation = module->AddEntryComputation(builder.Build()); 279 auto tuple_copy = computation->DeepCopyInstruction(tuple).ValueOrDie(); 280 281 EXPECT_THAT(tuple_copy, op::Tuple(op::Copy(op::GetTupleElement(tuple)), 282 op::Copy(op::GetTupleElement(tuple)))); 283 EXPECT_EQ(0, tuple_copy->operand(0)->operand(0)->tuple_index()); 284 EXPECT_EQ(1, tuple_copy->operand(1)->operand(0)->tuple_index()); 285 } 286 287 TEST_F(HloComputationTest, DeepCopyArrayAtIndices) { 288 // Test that DeepCopyInstruction properly handles an array when the indices to 289 // copy are specified. 290 auto builder = HloComputation::Builder(TestName()); 291 auto constant = builder.AddInstruction(HloInstruction::CreateConstant( 292 Literal::CreateR1<float>({1.0, 2.0, 3.0}))); 293 auto computation = builder.Build(); 294 295 { 296 // If the index is true, then a copy should be made. 297 ShapeTree<bool> indices_to_copy(constant->shape(), /*init_value=*/true); 298 EXPECT_THAT(computation->DeepCopyInstruction(constant, &indices_to_copy) 299 .ValueOrDie(), 300 op::Copy(constant)); 301 } 302 303 { 304 // If the index is false, then no copy should be made. 305 ShapeTree<bool> indices_to_copy(constant->shape(), /*init_value=*/false); 306 EXPECT_EQ(computation->DeepCopyInstruction(constant, &indices_to_copy) 307 .ValueOrDie(), 308 constant); 309 } 310 } 311 312 TEST_F(HloComputationTest, DeepCopyTupleAtIndices) { 313 // Test that DeepCopyInstruction properly copies elements of a tuple as 314 // specified by the given indices. 315 auto builder = HloComputation::Builder(TestName()); 316 auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( 317 Literal::CreateR1<float>({1.0, 2.0, 3.0}))); 318 auto constant2 = builder.AddInstruction( 319 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0))); 320 auto tuple = builder.AddInstruction( 321 HloInstruction::CreateTuple({constant1, constant2})); 322 auto computation = builder.Build(); 323 324 { 325 // All true values should copy all array elements. 326 ShapeTree<bool> indices_to_copy(tuple->shape(), /*init_value=*/true); 327 ShapeTree<HloInstruction*> copies_added(tuple->shape(), 328 /*init_value=*/nullptr); 329 HloInstruction* deep_copy = 330 computation->DeepCopyInstruction(tuple, &indices_to_copy, &copies_added) 331 .ValueOrDie(); 332 333 EXPECT_THAT(deep_copy, op::Tuple(op::Copy(op::GetTupleElement(tuple)), 334 op::Copy(op::GetTupleElement(tuple)))); 335 EXPECT_THAT(deep_copy, op::Tuple(copies_added.element({0}), 336 copies_added.element({1}))); 337 } 338 339 { 340 // All false elements should copy no array elements, but the GTE and tuple 341 // instruction scaffolding should be built. 342 ShapeTree<bool> indices_to_copy(tuple->shape(), /*init_value=*/false); 343 ShapeTree<HloInstruction*> copies_added(tuple->shape(), 344 /*init_value=*/nullptr); 345 HloInstruction* deep_copy = 346 computation->DeepCopyInstruction(tuple, &indices_to_copy, &copies_added) 347 .ValueOrDie(); 348 349 EXPECT_THAT(deep_copy, op::Tuple(op::GetTupleElement(tuple), 350 op::GetTupleElement(tuple))); 351 EXPECT_TRUE(copies_added.element({}) == nullptr); 352 EXPECT_TRUE(copies_added.element({0}) == nullptr); 353 EXPECT_TRUE(copies_added.element({1}) == nullptr); 354 } 355 356 { 357 // Verify one element copied, the other not. 358 ShapeTree<bool> indices_to_copy(tuple->shape(), /*init_value=*/false); 359 *indices_to_copy.mutable_element({0}) = true; 360 ShapeTree<HloInstruction*> copies_added(tuple->shape(), 361 /*init_value=*/nullptr); 362 HloInstruction* deep_copy = 363 computation->DeepCopyInstruction(tuple, &indices_to_copy, &copies_added) 364 .ValueOrDie(); 365 366 EXPECT_THAT(deep_copy, op::Tuple(op::Copy(op::GetTupleElement(tuple)), 367 op::GetTupleElement(tuple))); 368 EXPECT_TRUE(copies_added.element({}) == nullptr); 369 EXPECT_TRUE(copies_added.element({0}) != nullptr); 370 EXPECT_TRUE(copies_added.element({1}) == nullptr); 371 } 372 } 373 374 TEST_F(HloComputationTest, CycleDetection) { 375 // Test whether the visitor can detect cycles in the graph. 376 auto builder = HloComputation::Builder(TestName()); 377 auto constant = builder.AddInstruction( 378 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); 379 auto negate = builder.AddInstruction( 380 HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); 381 auto add = builder.AddInstruction( 382 HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, negate, negate)); 383 auto module = CreateNewModule(); 384 auto computation = module->AddEntryComputation(builder.Build()); 385 // Add a control dependency to create a cycle. 386 ASSERT_IS_OK(add->AddControlDependencyTo(negate)); 387 388 const auto visitor = [](HloInstruction* instruction) { return Status::OK(); }; 389 auto visit_status = computation->Accept(visitor); 390 ASSERT_FALSE(visit_status.ok()); 391 ASSERT_THAT(visit_status.error_message(), 392 ::testing::ContainsRegex("cycle is detecte")); 393 } 394 395 TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) { 396 // Test RemoveInstructionAndUnusedOperands with an instruction which has a 397 // duplicated (dead) operand. This verifies that the operand is not deleted 398 // twice. 399 auto builder = HloComputation::Builder(TestName()); 400 auto constant = builder.AddInstruction( 401 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); 402 auto dead_negate = builder.AddInstruction( 403 HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); 404 auto dead_add = builder.AddInstruction(HloInstruction::CreateBinary( 405 r0f32_, HloOpcode::kAdd, dead_negate, dead_negate)); 406 auto negate = builder.AddInstruction( 407 HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); 408 auto module = CreateNewModule(); 409 auto computation = module->AddEntryComputation(builder.Build()); 410 EXPECT_EQ(4, computation->instruction_count()); 411 EXPECT_THAT(computation->root_instruction(), op::Negate(constant)); 412 EXPECT_EQ(negate, computation->root_instruction()); 413 414 ASSERT_IS_OK(computation->RemoveInstructionAndUnusedOperands(dead_add)); 415 416 EXPECT_EQ(2, computation->instruction_count()); 417 EXPECT_THAT(computation->root_instruction(), op::Negate(constant)); 418 EXPECT_EQ(negate, computation->root_instruction()); 419 } 420 421 TEST_F(HloComputationTest, CloneWithControlDependency) { 422 auto builder = HloComputation::Builder(TestName()); 423 auto constant1 = builder.AddInstruction( 424 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f))); 425 auto constant2 = builder.AddInstruction( 426 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0f))); 427 auto add = builder.AddInstruction(HloInstruction::CreateBinary( 428 r0f32_, HloOpcode::kAdd, constant1, constant2)); 429 430 auto param = builder.AddInstruction( 431 HloInstruction::CreateParameter(0, r0f32_, "param0")); 432 auto negate = builder.AddInstruction( 433 HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param)); 434 auto module = CreateNewModule(); 435 auto computation = 436 module->AddEntryComputation(builder.Build(/*root_instruction=*/add)); 437 438 TF_CHECK_OK(negate->AddControlDependencyTo(add)); 439 440 auto clone = computation->Clone(); 441 442 auto cloned_add = clone->root_instruction(); 443 EXPECT_EQ(cloned_add->opcode(), HloOpcode::kAdd); 444 445 auto predecessors = cloned_add->control_predecessors(); 446 EXPECT_EQ(1, predecessors.size()); 447 EXPECT_EQ(HloOpcode::kNegate, predecessors[0]->opcode()); 448 auto successors = predecessors[0]->control_successors(); 449 EXPECT_THAT(successors, ::testing::ElementsAre(cloned_add)); 450 } 451 452 TEST_F(HloComputationTest, Reachability) { 453 // Test reachability of a non-trivial computation: 454 // 455 // const1 const2 456 // | | 457 // | +-------+ 458 // | | | 459 // add .. negate 460 // | . | 461 // | .... exp 462 // | | 463 // +---+ +-+---+ 464 // | | | 465 // multiply copy 466 // 467 // There is a control dependency from 'add' to 'exp'. 468 auto builder = HloComputation::Builder(TestName()); 469 auto constant1 = builder.AddInstruction( 470 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f))); 471 auto constant2 = builder.AddInstruction( 472 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0f))); 473 auto add = builder.AddInstruction(HloInstruction::CreateBinary( 474 r0f32_, HloOpcode::kAdd, constant1, constant2)); 475 auto negate = builder.AddInstruction( 476 HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant2)); 477 auto exp = builder.AddInstruction( 478 HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, negate)); 479 auto mul = builder.AddInstruction( 480 HloInstruction::CreateBinary(r0f32_, HloOpcode::kMultiply, add, exp)); 481 auto copy = builder.AddInstruction( 482 HloInstruction::CreateUnary(r0f32_, HloOpcode::kCopy, exp)); 483 484 auto module = CreateNewModule(); 485 auto computation = 486 module->AddEntryComputation(builder.Build(/*root_instruction=*/mul)); 487 488 TF_CHECK_OK(add->AddControlDependencyTo(exp)); 489 auto reachability = computation->ComputeReachability(); 490 491 EXPECT_TRUE(reachability->IsReachable(constant1, constant1)); 492 EXPECT_FALSE(reachability->IsReachable(constant1, constant2)); 493 EXPECT_TRUE(reachability->IsReachable(constant1, add)); 494 EXPECT_FALSE(reachability->IsReachable(constant1, negate)); 495 EXPECT_TRUE(reachability->IsReachable(constant1, exp)); 496 EXPECT_TRUE(reachability->IsReachable(constant1, mul)); 497 EXPECT_TRUE(reachability->IsReachable(constant1, copy)); 498 499 EXPECT_FALSE(reachability->IsReachable(constant2, constant1)); 500 EXPECT_TRUE(reachability->IsReachable(constant2, constant2)); 501 EXPECT_TRUE(reachability->IsReachable(constant2, add)); 502 EXPECT_TRUE(reachability->IsReachable(constant2, negate)); 503 EXPECT_TRUE(reachability->IsReachable(constant2, exp)); 504 EXPECT_TRUE(reachability->IsReachable(constant2, mul)); 505 EXPECT_TRUE(reachability->IsReachable(constant2, copy)); 506 507 EXPECT_FALSE(reachability->IsReachable(exp, constant1)); 508 EXPECT_FALSE(reachability->IsReachable(exp, constant2)); 509 EXPECT_FALSE(reachability->IsReachable(exp, add)); 510 EXPECT_FALSE(reachability->IsReachable(exp, negate)); 511 EXPECT_TRUE(reachability->IsReachable(exp, exp)); 512 EXPECT_TRUE(reachability->IsReachable(exp, mul)); 513 EXPECT_TRUE(reachability->IsReachable(exp, copy)); 514 515 EXPECT_FALSE(reachability->IsReachable(mul, constant1)); 516 EXPECT_FALSE(reachability->IsReachable(mul, constant2)); 517 EXPECT_FALSE(reachability->IsReachable(mul, add)); 518 EXPECT_FALSE(reachability->IsReachable(mul, negate)); 519 EXPECT_FALSE(reachability->IsReachable(mul, exp)); 520 EXPECT_TRUE(reachability->IsReachable(mul, mul)); 521 EXPECT_FALSE(reachability->IsReachable(mul, copy)); 522 523 EXPECT_TRUE(reachability->IsConnected(constant1, copy)); 524 EXPECT_TRUE(reachability->IsConnected(copy, constant1)); 525 EXPECT_FALSE(reachability->IsConnected(negate, add)); 526 EXPECT_FALSE(reachability->IsConnected(add, negate)); 527 528 // Remove the control dependency then update and verify the reachability map 529 ASSERT_IS_OK(add->RemoveControlDependencyTo(exp)); 530 computation->UpdateReachabilityThroughInstruction(exp, reachability.get()); 531 532 EXPECT_TRUE(reachability->IsReachable(constant1, constant1)); 533 EXPECT_FALSE(reachability->IsReachable(constant1, constant2)); 534 EXPECT_TRUE(reachability->IsReachable(constant1, add)); 535 EXPECT_FALSE(reachability->IsReachable(constant1, negate)); 536 EXPECT_FALSE(reachability->IsReachable(constant1, exp)); 537 EXPECT_TRUE(reachability->IsReachable(constant1, mul)); 538 EXPECT_FALSE(reachability->IsReachable(constant1, copy)); 539 540 // Change a use within the graph then update and verify the reachability map 541 ASSERT_IS_OK(constant2->ReplaceUseWith(negate, constant1)); 542 computation->UpdateReachabilityThroughInstruction(negate, reachability.get()); 543 544 EXPECT_FALSE(reachability->IsReachable(constant2, constant1)); 545 EXPECT_TRUE(reachability->IsReachable(constant2, constant2)); 546 EXPECT_TRUE(reachability->IsReachable(constant2, add)); 547 EXPECT_FALSE(reachability->IsReachable(constant2, negate)); 548 EXPECT_FALSE(reachability->IsReachable(constant2, exp)); 549 EXPECT_TRUE(reachability->IsReachable(constant2, mul)); 550 EXPECT_FALSE(reachability->IsReachable(constant2, copy)); 551 } 552 553 } // namespace 554 555 } // namespace xla 556