1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 17 18 #include <set> 19 #include <unordered_map> 20 #include <utility> 21 #include <vector> 22 23 #include "tensorflow/compiler/xla/literal_util.h" 24 #include "tensorflow/compiler/xla/protobuf_util.h" 25 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" 26 #include "tensorflow/compiler/xla/service/hlo_computation.h" 27 #include "tensorflow/compiler/xla/shape_util.h" 28 #include "tensorflow/compiler/xla/test.h" 29 #include "tensorflow/compiler/xla/test_helpers.h" 30 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 31 #include "tensorflow/compiler/xla/util.h" 32 33 namespace xla { 34 namespace { 35 36 using ::testing::ElementsAre; 37 using ::testing::UnorderedElementsAre; 38 39 class HloInstructionTest : public HloTestBase { 40 protected: 41 HloInstructionTest() {} 42 43 Shape r0f32_ = ShapeUtil::MakeShape(F32, {}); 44 }; 45 46 // Simple visitor that collects the number of users and operands for certain HLO 47 // nodes. It also verifies some of the DFS visiting invariants (operands visited 48 // before their users, nodes not visited twice, etc.) 49 class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault { 50 public: 51 Status DefaultAction(HloInstruction* hlo_instruction) override { 52 return Unimplemented("not implemented %s", 53 HloOpcodeString(hlo_instruction->opcode()).c_str()); 54 } 55 56 Status HandleParameter(HloInstruction* parameter) override { 57 EXPECT_EQ(0, count_.count(parameter)); 58 count_[parameter] = GetCountsForNode(parameter); 59 return Status::OK(); 60 } 61 62 Status HandleConstant(HloInstruction* constant) override { 63 EXPECT_EQ(0, count_.count(constant)); 64 count_[constant] = GetCountsForNode(constant); 65 return Status::OK(); 66 } 67 68 Status HandleAdd(HloInstruction* add) override { 69 auto lhs = add->operand(0); 70 auto rhs = add->operand(1); 71 EXPECT_EQ(0, count_.count(add)); 72 EXPECT_GT(count_.count(lhs), 0); 73 EXPECT_GT(count_.count(rhs), 0); 74 count_[add] = GetCountsForNode(add); 75 return Status::OK(); 76 } 77 78 Status HandleNegate(HloInstruction* negate) override { 79 auto operand = negate->operand(0); 80 EXPECT_EQ(0, count_.count(negate)); 81 EXPECT_GT(count_.count(operand), 0); 82 count_[negate] = GetCountsForNode(negate); 83 return Status::OK(); 84 } 85 86 Status HandleMap(HloInstruction* map) override { 87 EXPECT_EQ(0, count_.count(map)); 88 for (HloInstruction* arg : map->operands()) { 89 EXPECT_GT(count_.count(arg), 0); 90 } 91 count_[map] = GetCountsForNode(map); 92 return Status::OK(); 93 } 94 95 Status HandleReduce(HloInstruction* reduce) override { 96 auto arg = reduce->operand(0); 97 auto init_value = reduce->operand(1); 98 EXPECT_EQ(0, count_.count(reduce)); 99 EXPECT_GT(count_.count(arg), 0); 100 EXPECT_GT(count_.count(init_value), 0); 101 count_[reduce] = GetCountsForNode(reduce); 102 return Status::OK(); 103 } 104 105 int64 NumOperands(const HloInstruction* node) { 106 auto count_iterator = count_.find(node); 107 EXPECT_NE(count_.end(), count_iterator); 108 return count_iterator->second.operand_count; 109 } 110 111 int64 NumUsers(const HloInstruction* node) { 112 auto count_iterator = count_.find(node); 113 EXPECT_NE(count_.end(), count_iterator); 114 return count_iterator->second.user_count; 115 } 116 117 private: 118 struct NumOpsAndUsers { 119 int64 operand_count; 120 int64 user_count; 121 }; 122 123 // Helper function to count operands and users for the given HLO. 124 NumOpsAndUsers GetCountsForNode(const HloInstruction* node) { 125 NumOpsAndUsers counts{node->operand_count(), node->user_count()}; 126 return counts; 127 } 128 129 // Counters for HLOs. Maps HLO to a NumOpsAndUsers. 130 std::unordered_map<const HloInstruction*, NumOpsAndUsers> count_; 131 }; 132 133 TEST_F(HloInstructionTest, BasicProperties) { 134 auto parameter = HloInstruction::CreateParameter(1, r0f32_, "foo"); 135 136 EXPECT_EQ(HloOpcode::kParameter, parameter->opcode()); 137 EXPECT_TRUE(ShapeUtil::IsScalarF32(parameter->shape())); 138 EXPECT_EQ(0, parameter->operand_count()); 139 } 140 141 TEST_F(HloInstructionTest, UserWithTwoOperands) { 142 // [Param foo]-----> |-----| 143 // | Add | 144 // [Param bar]-----> |-----| 145 HloComputation::Builder builder(TestName()); 146 auto foo = 147 builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo")); 148 auto bar = 149 builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar")); 150 auto add = builder.AddInstruction( 151 HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar)); 152 HloModule module(TestName()); 153 module.AddEntryComputation(builder.Build()); 154 155 EXPECT_THAT(add->operands(), UnorderedElementsAre(foo, bar)); 156 EXPECT_THAT(foo->users(), UnorderedElementsAre(add)); 157 EXPECT_THAT(bar->users(), UnorderedElementsAre(add)); 158 159 OpAndUserCollectingVisitor visitor; 160 ASSERT_IS_OK(add->Accept(&visitor)); 161 162 EXPECT_EQ(2, visitor.NumOperands(add)); 163 EXPECT_EQ(0, visitor.NumUsers(add)); 164 EXPECT_EQ(1, visitor.NumUsers(foo)); 165 EXPECT_EQ(1, visitor.NumUsers(bar)); 166 } 167 168 TEST_F(HloInstructionTest, MultipleUsers) { 169 // [Param foo] 170 // / | \ 171 // / | \ [Param bar] 172 // / | \ | 173 // | | | | 174 // V V V V 175 // ------- ------- ----------- 176 // | exp | | exp | | add | 177 // ------- ------- ----------- 178 HloComputation::Builder builder(TestName()); 179 auto foo = 180 builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo")); 181 auto bar = 182 builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar")); 183 auto exp1 = builder.AddInstruction( 184 HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo)); 185 auto exp2 = builder.AddInstruction( 186 HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo)); 187 auto add = builder.AddInstruction( 188 HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar)); 189 HloModule module(TestName()); 190 module.AddEntryComputation(builder.Build()); 191 192 EXPECT_EQ(3, foo->user_count()); 193 EXPECT_EQ(1, bar->user_count()); 194 EXPECT_EQ(0, exp1->user_count()); 195 EXPECT_EQ(0, exp2->user_count()); 196 EXPECT_EQ(0, add->user_count()); 197 198 OpAndUserCollectingVisitor visitor; 199 ASSERT_IS_OK(add->Accept(&visitor)); 200 201 EXPECT_EQ(2, visitor.NumOperands(add)); 202 EXPECT_EQ(3, visitor.NumUsers(foo)); 203 } 204 205 TEST_F(HloInstructionTest, RepeatedUser) { 206 // Here we have a user 'add' nodes that uses the same HLO in both operands. 207 // Make sure we don't count it as two distinct users. 208 // 209 // [Param foo] 210 // | | 211 // | | 212 // | | 213 // V V 214 // ------- 215 // | add | 216 // ------- 217 HloComputation::Builder builder(TestName()); 218 auto foo = 219 builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo")); 220 auto add = builder.AddInstruction( 221 HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo)); 222 HloModule module(TestName()); 223 module.AddEntryComputation(builder.Build()); 224 225 EXPECT_EQ(1, foo->user_count()); 226 227 // But 'add' still has two operands, even if both are the same HLO. 228 EXPECT_EQ(2, add->operand_count()); 229 } 230 231 TEST_F(HloInstructionTest, MultipleUsersAndOperands) { 232 // [param0] [param1] 233 // | | 234 // | [c0] | 235 // | | | 236 // V | V 237 // ------- | ------- 238 // | add | <---^---> | add | 239 // ------- ------- 240 // | | 241 // \ ------- / 242 // ---->| add |<---- 243 // ------- 244 HloComputation::Builder builder(TestName()); 245 auto param0 = builder.AddInstruction( 246 HloInstruction::CreateParameter(0, r0f32_, "param0")); 247 auto param1 = builder.AddInstruction( 248 HloInstruction::CreateParameter(1, r0f32_, "param1")); 249 auto c0 = builder.AddInstruction( 250 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f))); 251 auto addleft = builder.AddInstruction( 252 HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param0, c0)); 253 auto addright = builder.AddInstruction( 254 HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c0, param1)); 255 auto addtotal = builder.AddInstruction( 256 HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, addleft, addright)); 257 HloModule module(TestName()); 258 module.AddEntryComputation(builder.Build()); 259 260 OpAndUserCollectingVisitor visitor; 261 ASSERT_IS_OK(addtotal->Accept(&visitor)); 262 263 EXPECT_EQ(2, visitor.NumUsers(c0)); 264 EXPECT_EQ(2, visitor.NumOperands(addleft)); 265 EXPECT_EQ(2, visitor.NumOperands(addright)); 266 EXPECT_EQ(2, visitor.NumOperands(addtotal)); 267 } 268 269 TEST_F(HloInstructionTest, MultipleUsersAndOperandsWithUnaryOps) { 270 // [param0] [c0] [param1] 271 // | | | 272 // | V | 273 // | ------- | 274 // | | neg | | 275 // | ------- | 276 // V | V 277 // ------- | ------- 278 // | add | <---^---> | add | 279 // ------- ------- 280 // | | 281 // \ ------- / 282 // ---->| add |<---- 283 // ------- 284 // | 285 // V 286 // ------- 287 // | neg | 288 // ------- 289 HloComputation::Builder builder(TestName()); 290 auto param0 = builder.AddInstruction( 291 HloInstruction::CreateParameter(0, r0f32_, "param0")); 292 auto param1 = builder.AddInstruction( 293 HloInstruction::CreateParameter(1, r0f32_, "param1")); 294 auto c0 = builder.AddInstruction( 295 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f))); 296 auto neg1 = builder.AddInstruction( 297 HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, c0)); 298 auto addleft = builder.AddInstruction( 299 HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param0, neg1)); 300 auto addright = builder.AddInstruction( 301 HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, neg1, param1)); 302 auto addtotal = builder.AddInstruction( 303 HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, addleft, addright)); 304 auto neg2 = builder.AddInstruction( 305 HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, addtotal)); 306 HloModule module(TestName()); 307 module.AddEntryComputation(builder.Build()); 308 309 OpAndUserCollectingVisitor visitor; 310 ASSERT_IS_OK(neg2->Accept(&visitor)); 311 312 EXPECT_EQ(1, visitor.NumUsers(c0)); 313 EXPECT_EQ(2, visitor.NumUsers(neg1)); 314 EXPECT_EQ(2, visitor.NumOperands(addleft)); 315 EXPECT_EQ(2, visitor.NumOperands(addright)); 316 EXPECT_EQ(2, visitor.NumOperands(addtotal)); 317 EXPECT_EQ(1, visitor.NumOperands(neg2)); 318 EXPECT_EQ(0, visitor.NumUsers(neg2)); 319 } 320 321 TEST_F(HloInstructionTest, TrivialMap) { 322 // This tests creating a trivial x+1 map as the only operation. 323 // 324 // param0[100x10] ---> (map x+1) 325 // 326 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 327 Shape f32a100x10 = ShapeUtil::MakeShape(F32, {100, 10}); 328 HloModule module(TestName()); 329 330 // Builds an x+1.0 computation to use in a Map. 331 auto embedded_builder = HloComputation::Builder("f32+1"); 332 auto param = embedded_builder.AddInstruction( 333 HloInstruction::CreateParameter(0, r0f32, "x")); 334 auto value = embedded_builder.AddInstruction( 335 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 336 embedded_builder.AddInstruction( 337 HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, value)); 338 auto add_f32 = module.AddEmbeddedComputation(embedded_builder.Build()); 339 340 // Builds a parameter and feeds it to the map. 341 HloComputation::Builder builder(TestName()); 342 auto param0 = builder.AddInstruction( 343 HloInstruction::CreateParameter(0, f32a100x10, "")); 344 auto map = builder.AddInstruction( 345 HloInstruction::CreateMap(f32a100x10, {param0}, add_f32)); 346 module.AddEntryComputation(builder.Build()); 347 348 OpAndUserCollectingVisitor visitor; 349 ASSERT_IS_OK(map->Accept(&visitor)); 350 351 // Check counts. We aren't walking the mapper computation yet. 352 EXPECT_EQ(1, visitor.NumUsers(param0)); 353 EXPECT_EQ(0, visitor.NumUsers(map)); 354 EXPECT_EQ(1, visitor.NumOperands(map)); 355 356 // TODO(dehnert): Add walking and counters for the wrapped computation. 357 } 358 359 TEST_F(HloInstructionTest, TrivialReduce) { 360 // This tests creating a trivial x+y reduce as the only operation. 361 // 362 // param0[100x10] ---> (reduce x+y) 363 // 364 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 365 Shape f32v100 = ShapeUtil::MakeShape(F32, {100}); 366 Shape f32a100x10 = ShapeUtil::MakeShape(F32, {100, 10}); 367 368 // Builds an x+y computation to use in a Reduce. 369 auto embedded_builder = HloComputation::Builder("f32+f32"); 370 auto paramx = embedded_builder.AddInstruction( 371 HloInstruction::CreateParameter(0, r0f32, "x")); 372 auto paramy = embedded_builder.AddInstruction( 373 HloInstruction::CreateParameter(1, r0f32, "y")); 374 embedded_builder.AddInstruction( 375 HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, paramx, paramy)); 376 HloModule module(TestName()); 377 auto add_f32 = module.AddEmbeddedComputation(embedded_builder.Build()); 378 379 // Builds a parameter and an initial value and feeds them to the reduce. 380 HloComputation::Builder builder(TestName()); 381 auto param0 = builder.AddInstruction( 382 HloInstruction::CreateParameter(0, f32a100x10, "")); 383 auto const0 = builder.AddInstruction( 384 HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))); 385 builder.AddInstruction( 386 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f))); 387 auto reduce = builder.AddInstruction( 388 HloInstruction::CreateReduce(f32v100, param0, const0, 389 /*dimensions_to_reduce=*/{1}, add_f32)); 390 module.AddEntryComputation(builder.Build()); 391 392 OpAndUserCollectingVisitor visitor; 393 ASSERT_IS_OK(reduce->Accept(&visitor)); 394 395 // Check counts. We aren't walking the reducer computation. 396 EXPECT_EQ(1, visitor.NumUsers(param0)); 397 EXPECT_EQ(1, visitor.NumUsers(const0)); 398 EXPECT_EQ(0, visitor.NumUsers(reduce)); 399 EXPECT_EQ(2, visitor.NumOperands(reduce)); 400 } 401 402 TEST_F(HloInstructionTest, ReplaceUseInBinaryOps) { 403 // Construct a graph of a few binary ops using two different 404 // parameters. Replace one of the parameters with the other parameter in one 405 // of the instructions. 406 HloComputation::Builder builder(TestName()); 407 auto foo = 408 builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo")); 409 auto bar = 410 builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar")); 411 auto add_foobar = builder.AddInstruction( 412 HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar)); 413 auto add_foofoo = builder.AddInstruction( 414 HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo)); 415 builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, 416 add_foobar, add_foofoo)); 417 HloModule module(TestName()); 418 module.AddEntryComputation(builder.Build()); 419 420 EXPECT_EQ(2, foo->user_count()); 421 EXPECT_EQ(1, bar->user_count()); 422 423 // Replace the use of foo in add_foofoo with bar. 424 ASSERT_IS_OK(foo->ReplaceUseWith(add_foofoo, bar)); 425 426 EXPECT_EQ(1, foo->user_count()); 427 EXPECT_EQ(2, bar->user_count()); 428 429 EXPECT_THAT(foo->users(), UnorderedElementsAre(add_foobar)); 430 EXPECT_THAT(add_foobar->operands(), ElementsAre(foo, bar)); 431 432 EXPECT_THAT(bar->users(), UnorderedElementsAre(add_foobar, add_foofoo)); 433 EXPECT_THAT(add_foobar->operands(), ElementsAre(foo, bar)); 434 EXPECT_THAT(add_foofoo->operands(), ElementsAre(bar, bar)); 435 } 436 437 TEST_F(HloInstructionTest, ReplaceUseInVariadicOp) { 438 // Construct a tuple containing several parameters. Replace one parameter with 439 // another in the tuple. 440 HloComputation::Builder builder(TestName()); 441 auto foo = 442 builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo")); 443 auto bar = 444 builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar")); 445 auto baz = 446 builder.AddInstruction(HloInstruction::CreateParameter(2, r0f32_, "baz")); 447 448 auto tuple = 449 builder.AddInstruction(HloInstruction::CreateTuple({foo, bar, baz, foo})); 450 auto add_foobar = builder.AddInstruction( 451 HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar)); 452 HloModule module(TestName()); 453 module.AddEntryComputation(builder.Build()); 454 455 EXPECT_EQ(2, foo->user_count()); 456 EXPECT_THAT(foo->users(), UnorderedElementsAre(tuple, add_foobar)); 457 458 // Replace the use of foo in tuple with bar. 459 ASSERT_IS_OK(foo->ReplaceUseWith(tuple, bar)); 460 461 EXPECT_THAT(foo->users(), UnorderedElementsAre(add_foobar)); 462 463 // Both uses of foo in tuple should have been replaced with bar. 464 EXPECT_THAT(tuple->operands(), ElementsAre(bar, bar, baz, bar)); 465 } 466 467 TEST_F(HloInstructionTest, ReplaceUseInUnaryOp) { 468 // Construct a couple unary instructions which use a parameter. Replace the 469 // use of a parameter in one of the unary ops with the other parameter. 470 HloComputation::Builder builder(TestName()); 471 auto foo = 472 builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo")); 473 auto bar = 474 builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar")); 475 476 auto exp = builder.AddInstruction( 477 HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo)); 478 auto log = builder.AddInstruction( 479 HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo)); 480 HloModule module(TestName()); 481 module.AddEntryComputation(builder.Build()); 482 483 EXPECT_EQ(2, foo->user_count()); 484 EXPECT_THAT(foo->users(), UnorderedElementsAre(exp, log)); 485 EXPECT_EQ(0, bar->user_count()); 486 487 // Replace the use of foo in exp with bar. 488 ASSERT_IS_OK(foo->ReplaceUseWith(exp, bar)); 489 490 // The use of foo in log should not have been affected. 491 EXPECT_EQ(1, foo->user_count()); 492 EXPECT_THAT(foo->users(), UnorderedElementsAre(log)); 493 EXPECT_THAT(log->operands(), ElementsAre(foo)); 494 495 // Bar should now be used in exp. 496 EXPECT_EQ(1, bar->user_count()); 497 EXPECT_EQ(*bar->users().begin(), exp); 498 EXPECT_EQ(1, exp->operands().size()); 499 EXPECT_EQ(*exp->operands().begin(), bar); 500 } 501 502 TEST_F(HloInstructionTest, ReplaceAllUsesWithInBinaryOps) { 503 // Construct a simple graph of a few binary ops using two different 504 // parameters. Replace all uses of one of the parameters with the other 505 // parameter. 506 HloComputation::Builder builder(TestName()); 507 auto foo = 508 builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo")); 509 auto bar = 510 builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar")); 511 auto add_foobar = builder.AddInstruction( 512 HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar)); 513 auto add_foofoo = builder.AddInstruction( 514 HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo)); 515 builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, 516 add_foobar, add_foofoo)); 517 HloModule module(TestName()); 518 module.AddEntryComputation(builder.Build()); 519 520 EXPECT_EQ(2, foo->user_count()); 521 EXPECT_EQ(1, bar->user_count()); 522 523 // Replace all uses of foo with bar. 524 ASSERT_IS_OK(foo->ReplaceAllUsesWith(bar)); 525 526 EXPECT_EQ(0, foo->user_count()); 527 EXPECT_EQ(2, bar->user_count()); 528 529 EXPECT_THAT(bar->users(), UnorderedElementsAre(add_foobar, add_foofoo)); 530 } 531 532 TEST_F(HloInstructionTest, ReplaceAllUsesInMultipleOps) { 533 // Construct a graph containing several ops (a unary, binary, and variadic) 534 // which use two parameters. Replace all uses of one of the parameters with 535 // the other parameter. 536 HloComputation::Builder builder(TestName()); 537 auto foo = 538 builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo")); 539 auto bar = 540 builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar")); 541 542 auto add_foobar = builder.AddInstruction( 543 HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar)); 544 auto exp = builder.AddInstruction( 545 HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo)); 546 auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({foo, bar})); 547 HloModule module(TestName()); 548 module.AddEntryComputation(builder.Build()); 549 550 EXPECT_EQ(3, foo->user_count()); 551 EXPECT_EQ(2, bar->user_count()); 552 553 // Replace all uses of foo with bar. 554 ASSERT_IS_OK(foo->ReplaceAllUsesWith(bar)); 555 556 EXPECT_EQ(0, foo->user_count()); 557 EXPECT_EQ(3, bar->user_count()); 558 559 EXPECT_THAT(bar->users(), UnorderedElementsAre(add_foobar, exp, tuple)); 560 } 561 562 // Simple visitor that collects and post-processes each node in the graph. 563 class NodeCollectorAndPostProcessor : public DfsHloVisitorWithDefault { 564 public: 565 NodeCollectorAndPostProcessor() {} 566 567 Status Postprocess(HloInstruction* hlo) override { 568 post_processed_nodes_.push_back(hlo); 569 return Status::OK(); 570 } 571 572 Status DefaultAction(HloInstruction* hlo_instruction) override { 573 visited_nodes_.push_back(hlo_instruction); 574 return Status::OK(); 575 } 576 577 const std::vector<const HloInstruction*>& visited_nodes() { 578 return visited_nodes_; 579 } 580 581 const std::vector<const HloInstruction*>& post_processed_nodes() { 582 return post_processed_nodes_; 583 } 584 585 private: 586 std::vector<const HloInstruction*> visited_nodes_; 587 std::vector<const HloInstruction*> post_processed_nodes_; 588 }; 589 590 // Returns true if "vec" contains distinct nodes. 591 bool Distinct(const std::vector<const HloInstruction*>& vec) { 592 std::set<const HloInstruction*> distinct_nodes(vec.begin(), vec.end()); 593 return distinct_nodes.size() == vec.size(); 594 } 595 596 TEST_F(HloInstructionTest, PostProcessAllVisitedNodes) { 597 // Verifies all the nodes are visited and post-processed in the same order, 598 // and that each node is visited exactly once. 599 // 600 // /--> exp --\ 601 // foo add 602 // \--> log --/ 603 HloComputation::Builder builder(TestName()); 604 auto foo = 605 builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo")); 606 auto exp = builder.AddInstruction( 607 HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo)); 608 auto log = builder.AddInstruction( 609 HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo)); 610 auto add = builder.AddInstruction( 611 HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, exp, log)); 612 HloModule module(TestName()); 613 module.AddEntryComputation(builder.Build()); 614 615 NodeCollectorAndPostProcessor visitor; 616 ASSERT_IS_OK(add->Accept(&visitor)); 617 // Verifies all the nodes are visited and post-processed in the same order. 618 EXPECT_EQ(visitor.visited_nodes(), visitor.post_processed_nodes()); 619 // Verifies each node is visited exactly once. 620 EXPECT_TRUE(Distinct(visitor.visited_nodes())); 621 } 622 623 TEST_F(HloInstructionTest, SingletonFusionOp) { 624 HloComputation::Builder builder(TestName()); 625 // Create a fusion instruction containing a single unary operation. 626 auto constant = builder.AddInstruction( 627 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f))); 628 auto exp = builder.AddInstruction( 629 HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant)); 630 HloModule module(TestName()); 631 auto* computation = module.AddEntryComputation(builder.Build()); 632 auto* fusion = computation->CreateFusionInstruction( 633 {exp}, HloInstruction::FusionKind::kLoop); 634 635 EXPECT_THAT(fusion->operands(), ElementsAre(constant)); 636 EXPECT_THAT(constant->users(), ElementsAre(fusion)); 637 } 638 639 TEST_F(HloInstructionTest, BinaryFusionOp) { 640 HloComputation::Builder builder(TestName()); 641 // Create a fusion instruction containing a single binary operation. 642 auto constant1 = builder.AddInstruction( 643 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f))); 644 auto constant2 = builder.AddInstruction( 645 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.1f))); 646 auto add = builder.AddInstruction(HloInstruction::CreateBinary( 647 r0f32_, HloOpcode::kAdd, constant1, constant2)); 648 HloModule module(TestName()); 649 auto* computation = module.AddEntryComputation(builder.Build()); 650 auto* fusion = computation->CreateFusionInstruction( 651 {add}, HloInstruction::FusionKind::kLoop); 652 653 EXPECT_THAT(fusion->operands(), ElementsAre(constant1, constant2)); 654 EXPECT_THAT(constant1->users(), ElementsAre(fusion)); 655 EXPECT_THAT(constant2->users(), ElementsAre(fusion)); 656 } 657 658 TEST_F(HloInstructionTest, ChainFusionOp) { 659 HloComputation::Builder builder(TestName()); 660 // Create a chain of fused unary ops. 661 auto constant = builder.AddInstruction( 662 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f))); 663 auto exp1 = builder.AddInstruction( 664 HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant)); 665 auto exp2 = builder.AddInstruction( 666 HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1)); 667 auto exp3 = builder.AddInstruction( 668 HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp2)); 669 670 HloModule module(TestName()); 671 auto* computation = module.AddEntryComputation(builder.Build()); 672 auto* fusion = computation->CreateFusionInstruction( 673 {exp3, exp2, exp1}, HloInstruction::FusionKind::kLoop); 674 675 EXPECT_THAT(fusion->operands(), ElementsAre(constant)); 676 EXPECT_THAT(constant->users(), ElementsAre(fusion)); 677 } 678 679 TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) { 680 HloComputation::Builder builder(TestName()); 681 // Create a chain of fused unary ops. 682 auto constant = builder.AddInstruction( 683 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f))); 684 auto exp1 = builder.AddInstruction( 685 HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant)); 686 auto exp2 = builder.AddInstruction( 687 HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1)); 688 OpMetadata metadata; 689 metadata.set_op_name("tf_op"); 690 exp1->set_metadata(metadata); 691 exp2->set_metadata(metadata); 692 693 HloModule module(TestName()); 694 auto* computation = module.AddEntryComputation(builder.Build()); 695 auto* fusion = computation->CreateFusionInstruction( 696 {exp2, exp1}, HloInstruction::FusionKind::kLoop); 697 698 EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fusion->metadata())); 699 EXPECT_TRUE(protobuf_util::ProtobufEquals( 700 metadata, fusion->fused_expression_root()->metadata())); 701 EXPECT_TRUE(protobuf_util::ProtobufEquals( 702 metadata, fusion->fused_expression_root()->operand(0)->metadata())); 703 704 auto cloned = fusion->CloneWithNewOperands(fusion->shape(), {}); 705 EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fusion->metadata())); 706 } 707 708 TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) { 709 HloComputation::Builder builder(TestName()); 710 auto constant = builder.AddInstruction( 711 HloInstruction::CreateConstant(Literal::CreateR2<float>({ 712 {1, 2}, 713 {3, 4}, 714 }))); 715 auto shape10 = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}); 716 auto shape01 = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}); 717 auto outfeed10 = builder.AddInstruction( 718 HloInstruction::CreateOutfeed(shape10, constant, "")); 719 auto outfeed01 = builder.AddInstruction( 720 HloInstruction::CreateOutfeed(shape01, constant, "")); 721 722 auto clone01 = builder.AddInstruction(outfeed01->Clone()); 723 auto clone10 = builder.AddInstruction(outfeed10->Clone()); 724 725 EXPECT_TRUE(ShapeUtil::Equal(clone01->outfeed_shape(), shape01)); 726 EXPECT_TRUE(ShapeUtil::Equal(clone10->outfeed_shape(), shape10)); 727 } 728 729 TEST_F(HloInstructionTest, PreserveTupleShapeThroughClone) { 730 HloComputation::Builder builder(TestName()); 731 auto* constant = builder.AddInstruction( 732 HloInstruction::CreateConstant(Literal::CreateR2<float>({ 733 {1, 2}, 734 {3, 4}, 735 }))); 736 auto* tuple = 737 builder.AddInstruction(HloInstruction::CreateTuple({constant, constant})); 738 *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(), {0}) 739 ->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); 740 *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(), {1}) 741 ->mutable_layout() = LayoutUtil::MakeLayout({1, 0}); 742 auto tuple_clone = tuple->Clone(); 743 EXPECT_TRUE(ShapeUtil::Equal(tuple_clone->shape(), tuple->shape())); 744 } 745 746 TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { 747 // Create a fusion instruction containing a single unary operation. 748 const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); 749 HloModule module(TestName()); 750 751 auto make_map_computation = [&]() { 752 auto builder = HloComputation::Builder("FusionMap"); 753 builder.AddInstruction( 754 HloInstruction::CreateParameter(0, scalar_shape, "param")); 755 return module.AddEmbeddedComputation(builder.Build()); 756 }; 757 758 HloComputation* computation_x = make_map_computation(); 759 HloComputation* computation_y = make_map_computation(); 760 761 HloComputation::Builder builder(TestName()); 762 auto constant = builder.AddInstruction( 763 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f))); 764 auto map_1_x = builder.AddInstruction(HloInstruction::CreateMap( 765 scalar_shape, {constant}, computation_x, /*static_operands=*/{})); 766 auto map_2_x = builder.AddInstruction(HloInstruction::CreateMap( 767 scalar_shape, {map_1_x}, computation_x, /*static_operands=*/{})); 768 auto map_3_y = builder.AddInstruction(HloInstruction::CreateMap( 769 scalar_shape, {map_2_x}, computation_y, /*static_operands=*/{})); 770 auto* computation = module.AddEntryComputation(builder.Build()); 771 772 auto* fusion = computation->CreateFusionInstruction( 773 {map_3_y}, HloInstruction::FusionKind::kLoop); 774 auto* fused_computation = fusion->fused_instructions_computation(); 775 EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation)); 776 777 fusion->FuseInstruction(map_2_x); 778 EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation)); 779 780 fusion->FuseInstruction(map_1_x); 781 EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation)); 782 } 783 784 TEST_F(HloInstructionTest, ComplexFusionOp) { 785 HloComputation::Builder builder(TestName()); 786 // Fuse all instructions in complicated expression: 787 // 788 // add = Add(C1, C2) 789 // clamp = Clamp(C2, add, add) 790 // exp = Exp(add) 791 // mul = Mul(exp, C3) 792 // sub = Sub(mul, clamp) 793 // tuple = Tuple({sub, sub, mul, C1}) 794 // 795 // Notable complexities are repeated operands in the same instruction, 796 // different shapes, use of value in different expressions. 797 auto c1 = builder.AddInstruction( 798 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f))); 799 auto c2 = builder.AddInstruction( 800 HloInstruction::CreateConstant(Literal::CreateR0<float>(2.1f))); 801 auto c3 = builder.AddInstruction( 802 HloInstruction::CreateConstant(Literal::CreateR0<float>(9.0f))); 803 804 auto add = builder.AddInstruction( 805 HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1, c2)); 806 auto clamp = builder.AddInstruction( 807 HloInstruction::CreateTernary(r0f32_, HloOpcode::kClamp, c2, add, add)); 808 auto exp = builder.AddInstruction( 809 HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, add)); 810 auto mul = builder.AddInstruction( 811 HloInstruction::CreateBinary(r0f32_, HloOpcode::kMultiply, exp, c3)); 812 auto sub = builder.AddInstruction( 813 HloInstruction::CreateBinary(r0f32_, HloOpcode::kSubtract, mul, clamp)); 814 auto tuple = 815 builder.AddInstruction(HloInstruction::CreateTuple({sub, sub, mul, c1})); 816 817 HloModule module(TestName()); 818 auto* computation = module.AddEntryComputation(builder.Build()); 819 auto* fusion = computation->CreateFusionInstruction( 820 {tuple, sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop); 821 822 // Operands in the fusion instruction's operands() vector should be in the 823 // order in which their users were added fused. 824 EXPECT_THAT(fusion->operands(), ElementsAre(c1, c3, c2)); 825 EXPECT_THAT(c1->users(), ElementsAre(fusion)); 826 } 827 828 // Convenience function for comparing two HloInstructions. 829 static bool Identical(const HloInstruction& instruction1, 830 const HloInstruction& instruction2) { 831 // Verify Identical is reflexive for both instructions. 832 EXPECT_TRUE(instruction1.Identical(instruction1)); 833 EXPECT_TRUE(instruction2.Identical(instruction2)); 834 835 bool is_equal = instruction1.Identical(instruction2); 836 // Verify Identical is symmetric. 837 EXPECT_EQ(is_equal, instruction2.Identical(instruction1)); 838 return is_equal; 839 } 840 841 // Convenience function for comparing two HloInstructions for structural 842 // equality. 843 static bool StructuralEqual(const HloInstruction& instruction1, 844 const HloInstruction& instruction2) { 845 auto eq_operand_shapes = [](const HloInstruction* a, 846 const HloInstruction* b) { 847 return ShapeUtil::Equal(a->shape(), b->shape()); 848 }; 849 auto eq_computations = [](const HloComputation* a, const HloComputation* b) { 850 return *a == *b; 851 }; 852 853 // Verify Identical is reflexive for both instructions. 854 EXPECT_TRUE( 855 instruction1.Identical(instruction1, eq_operand_shapes, eq_computations)); 856 EXPECT_TRUE( 857 instruction2.Identical(instruction2, eq_operand_shapes, eq_computations)); 858 859 bool is_equal = 860 instruction1.Identical(instruction2, eq_operand_shapes, eq_computations); 861 // Verify Identical is symmetric. 862 EXPECT_EQ(is_equal, instruction2.Identical(instruction1, eq_operand_shapes, 863 eq_computations)); 864 return is_equal; 865 } 866 867 TEST_F(HloInstructionTest, IdenticalInstructions) { 868 // Test HloInstruction::Identical with some subset of instructions types. 869 870 // Create a set of random constant operands to use below. Make them matrices 871 // so dimensions are interesting. 872 auto operand1 = HloInstruction::CreateConstant( 873 Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})); 874 auto operand2 = HloInstruction::CreateConstant( 875 Literal::CreateR2<float>({{10.0, 20.0}, {30.0, 40.0}})); 876 auto vector_operand = 877 HloInstruction::CreateConstant(Literal::CreateR1<float>({42.0, 123.0})); 878 Shape shape = operand1->shape(); 879 880 // Convenient short names for the operands. 881 HloInstruction* op1 = operand1.get(); 882 HloInstruction* op2 = operand2.get(); 883 884 // Operations which only depend on their operands and opcode. 885 EXPECT_TRUE( 886 Identical(*HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1), 887 *HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1))); 888 EXPECT_FALSE( 889 Identical(*HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1), 890 *HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op2))); 891 EXPECT_FALSE( 892 Identical(*HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1), 893 *HloInstruction::CreateUnary(shape, HloOpcode::kNegate, op1))); 894 895 // Tuples. 896 EXPECT_TRUE(Identical(*HloInstruction::CreateTuple({op1, op2}), 897 *HloInstruction::CreateTuple({op1, op2}))); 898 EXPECT_FALSE(Identical(*HloInstruction::CreateTuple({op1, op2}), 899 *HloInstruction::CreateTuple({op2, op1}))); 900 901 // Broadcasts. 902 EXPECT_TRUE(Identical(*HloInstruction::CreateBroadcast(shape, op1, {0, 1}), 903 *HloInstruction::CreateBroadcast(shape, op1, {0, 1}))); 904 EXPECT_FALSE(Identical(*HloInstruction::CreateBroadcast(shape, op1, {0, 1}), 905 *HloInstruction::CreateBroadcast(shape, op1, {1, 0}))); 906 Shape bcast_shape1 = ShapeUtil::MakeShape(F32, {2, 2, 42}); 907 Shape bcast_shape2 = ShapeUtil::MakeShape(F32, {2, 2, 123}); 908 EXPECT_FALSE( 909 Identical(*HloInstruction::CreateBroadcast(bcast_shape1, op1, {0, 1}), 910 *HloInstruction::CreateBroadcast(bcast_shape2, op1, {0, 1}))); 911 912 // Binary operands. 913 EXPECT_TRUE(Identical( 914 *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2), 915 *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2))); 916 EXPECT_FALSE(Identical( 917 *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2), 918 *HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op2, op1))); 919 EXPECT_FALSE(Identical( 920 *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2), 921 *HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op1, op2))); 922 } 923 924 TEST_F(HloInstructionTest, FunctionVisitor) { 925 // Verify the function visitor HloInstruction::Accept visits all instructions 926 // from a root properly given the following graph: 927 // 928 // param 929 // / \ 930 // negate exp 931 // \ / 932 // add 933 const Shape f32 = ShapeUtil::MakeShape(F32, {}); 934 HloComputation::Builder builder(TestName()); 935 auto param = 936 builder.AddInstruction(HloInstruction::CreateParameter(0, f32, "0")); 937 auto negate = builder.AddInstruction( 938 HloInstruction::CreateUnary(f32, HloOpcode::kNegate, param)); 939 auto exp = builder.AddInstruction( 940 HloInstruction::CreateUnary(f32, HloOpcode::kExp, param)); 941 auto add = builder.AddInstruction( 942 HloInstruction::CreateBinary(f32, HloOpcode::kAdd, negate, exp)); 943 HloModule module(TestName()); 944 module.AddEntryComputation(builder.Build()); 945 946 int visit_num = 0; 947 std::unordered_map<HloInstruction*, int> visit_order; 948 EXPECT_IS_OK(add->Accept([&visit_num, &visit_order](HloInstruction* inst) { 949 EXPECT_EQ(0, visit_order.count(inst)); 950 visit_order[inst] = visit_num; 951 visit_num++; 952 return Status::OK(); 953 })); 954 955 EXPECT_EQ(0, visit_order.at(param)); 956 // negate and exp can be visited in an arbitrary order. 957 EXPECT_TRUE(visit_order.at(exp) == 1 || visit_order.at(exp) == 2); 958 EXPECT_TRUE(visit_order.at(negate) == 1 || visit_order.at(negate) == 2); 959 EXPECT_NE(visit_order.at(exp), visit_order.at(negate)); 960 EXPECT_EQ(3, visit_order.at(add)); 961 } 962 963 TEST_F(HloInstructionTest, FullyElementwise) { 964 const Shape r1f32 = ShapeUtil::MakeShape(F32, {5}); 965 HloComputation::Builder builder(TestName()); 966 auto x = 967 builder.AddInstruction(HloInstruction::CreateParameter(0, r1f32, "x")); 968 auto y = 969 builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y")); 970 auto add = builder.AddInstruction( 971 HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, x, y)); 972 HloModule module(TestName()); 973 module.AddEntryComputation(builder.Build()); 974 975 EXPECT_TRUE(add->IsElementwise()); 976 for (int i = 0; i < add->operand_count(); ++i) { 977 EXPECT_TRUE(add->IsElementwiseOnOperand(i)); 978 } 979 } 980 981 TEST_F(HloInstructionTest, PartiallyElementwise) { 982 const Shape r1f32 = ShapeUtil::MakeShape(F32, {5}); 983 const Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 5}); 984 985 // Fused expression: 986 // 987 // p0 p1 p2 p3 988 // \ / / | 989 // mul / | 990 // \ / | 991 // div broadcast 992 // \ / 993 // max 994 // 995 // The fusion instruction is not elementwise on p3 because the broadcast is 996 // not elementwise. 997 HloComputation::Builder builder("PartiallyElementwise"); 998 HloInstruction* p0 = 999 builder.AddInstruction(HloInstruction::CreateParameter(0, r2f32, "p0")); 1000 HloInstruction* p1 = 1001 builder.AddInstruction(HloInstruction::CreateParameter(1, r2f32, "p1")); 1002 HloInstruction* p2 = 1003 builder.AddInstruction(HloInstruction::CreateParameter(2, r2f32, "p2")); 1004 HloInstruction* p3 = 1005 builder.AddInstruction(HloInstruction::CreateParameter(3, r1f32, "p3")); 1006 HloInstruction* mul = builder.AddInstruction( 1007 HloInstruction::CreateBinary(r2f32, HloOpcode::kMultiply, p0, p1)); 1008 HloInstruction* div = builder.AddInstruction( 1009 HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, mul, p2)); 1010 // Dimension 0 of shape [5] is mapped to dimension 1 of shape [3x5]. 1011 HloInstruction* broadcast = 1012 builder.AddInstruction(HloInstruction::CreateBroadcast(r2f32, p3, {1})); 1013 HloInstruction* max = builder.AddInstruction( 1014 HloInstruction::CreateBinary(r2f32, HloOpcode::kMaximum, div, broadcast)); 1015 1016 HloModule module(TestName()); 1017 auto* computation = module.AddEntryComputation(builder.Build()); 1018 HloInstruction* fusion = computation->CreateFusionInstruction( 1019 {max, broadcast, div, mul}, HloInstruction::FusionKind::kLoop); 1020 EXPECT_FALSE(fusion->IsElementwise()); 1021 for (int64 operand_idx = 0; operand_idx < fusion->operand_count(); 1022 ++operand_idx) { 1023 const HloInstruction* operand = fusion->operand(operand_idx); 1024 if (operand == p3) { 1025 EXPECT_FALSE(fusion->IsElementwiseOnOperand(operand_idx)); 1026 } else { 1027 EXPECT_TRUE(fusion->IsElementwiseOnOperand(operand_idx)); 1028 } 1029 } 1030 } 1031 1032 TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) { 1033 // Fused expression: 1034 // 1035 // x y 1036 // \ / \ 1037 // min broadcast 1038 // \ / 1039 // sub 1040 // 1041 // The fusion instruction is elementwise on `x` because the only path from x 1042 // to sub contains only elementwise operations. It is not elementwise on `y` 1043 // because the path y->broadcast->sub is not all elementwise. 1044 const Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 1045 const Shape r1f32 = ShapeUtil::MakeShape(F32, {5}); 1046 1047 HloComputation::Builder builder("PartiallyElementwiseWithReuse"); 1048 HloInstruction* x = 1049 builder.AddInstruction(HloInstruction::CreateParameter(0, r1f32, "x")); 1050 HloInstruction* y = 1051 builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32, "y")); 1052 HloInstruction* min = builder.AddInstruction( 1053 HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, x, y)); 1054 HloInstruction* broadcast = 1055 builder.AddInstruction(HloInstruction::CreateBroadcast(r1f32, y, {0})); 1056 HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary( 1057 r1f32, HloOpcode::kSubtract, min, broadcast)); 1058 1059 HloModule module(TestName()); 1060 auto* computation = module.AddEntryComputation(builder.Build()); 1061 HloInstruction* fusion = computation->CreateFusionInstruction( 1062 {sub, broadcast, min}, HloInstruction::FusionKind::kLoop); 1063 EXPECT_FALSE(fusion->IsElementwise()); 1064 for (int64 operand_idx = 0; operand_idx < fusion->operand_count(); 1065 ++operand_idx) { 1066 if (fusion->operand(operand_idx) == x) { 1067 EXPECT_TRUE(fusion->IsElementwiseOnOperand(operand_idx)); 1068 } else { 1069 EXPECT_FALSE(fusion->IsElementwiseOnOperand(operand_idx)); 1070 } 1071 } 1072 } 1073 1074 TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { 1075 // Fused expression: 1076 // 1077 // x y 1078 // | | 1079 // | transpose 1080 // \ / 1081 // dot 1082 // 1083 // Tests that shapes aren't mangled by Clone(). 1084 const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); 1085 const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); 1086 const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); 1087 const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); 1088 1089 HloComputation::Builder builder("TransposeDot"); 1090 HloInstruction* x = 1091 builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); 1092 HloInstruction* y = 1093 builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); 1094 HloInstruction* reshape = 1095 builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); 1096 DotDimensionNumbers dot_dnums; 1097 dot_dnums.add_lhs_contracting_dimensions(1); 1098 dot_dnums.add_rhs_contracting_dimensions(0); 1099 HloInstruction* dot = builder.AddInstruction( 1100 HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); 1101 1102 HloModule module(TestName()); 1103 auto* computation = module.AddEntryComputation(builder.Build()); 1104 HloInstruction* fusion = computation->CreateFusionInstruction( 1105 {dot, reshape}, HloInstruction::FusionKind::kTransposeDot); 1106 1107 auto fusion2 = fusion->Clone(); 1108 const HloInstruction* root = fusion->fused_expression_root(); 1109 const HloInstruction* root2 = fusion2->fused_expression_root(); 1110 EXPECT_TRUE(ShapeUtil::Equal(root->shape(), root2->shape())); 1111 EXPECT_TRUE( 1112 ShapeUtil::Equal(root->operand(0)->shape(), root2->operand(0)->shape())); 1113 EXPECT_TRUE( 1114 ShapeUtil::Equal(root->operand(1)->shape(), root2->operand(1)->shape())); 1115 EXPECT_TRUE(ShapeUtil::Equal(root->operand(1)->operand(0)->shape(), 1116 root2->operand(1)->operand(0)->shape())); 1117 EXPECT_TRUE(StructuralEqual(*fusion, *fusion2)); 1118 } 1119 1120 TEST_F(HloInstructionTest, FusionEquality) { 1121 HloModule module(TestName()); 1122 HloComputation::Builder builder(TestName()); 1123 1124 // Create two fusion instructions containing a single unary operation. 1125 auto parameter = 1126 builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x")); 1127 auto exp = builder.AddInstruction( 1128 HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, parameter)); 1129 auto neg = builder.AddInstruction( 1130 HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, parameter)); 1131 auto* computation = module.AddEntryComputation(builder.Build()); 1132 auto* fusion = computation->CreateFusionInstruction( 1133 {exp}, HloInstruction::FusionKind::kLoop); 1134 auto* fusion2 = computation->CreateFusionInstruction( 1135 {neg}, HloInstruction::FusionKind::kLoop); 1136 EXPECT_FALSE(StructuralEqual(*fusion, *fusion2)); 1137 1138 auto clone = fusion->Clone(); 1139 EXPECT_TRUE(StructuralEqual(*fusion, *clone)); 1140 } 1141 1142 TEST_F(HloInstructionTest, NestedFusionEquality) { 1143 HloModule module(TestName()); 1144 HloComputation::Builder builder(TestName()); 1145 1146 // Build a nested fusion computation. 1147 Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); 1148 auto a = builder.AddInstruction(HloInstruction::CreateConstant( 1149 Literal::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}}))); 1150 auto b = builder.AddInstruction(HloInstruction::CreateConstant( 1151 Literal::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}}))); 1152 auto b_t = builder.AddInstruction( 1153 HloInstruction::CreateTranspose(data_shape, b, {1, 0})); 1154 DotDimensionNumbers dot_dnums; 1155 dot_dnums.add_lhs_contracting_dimensions(1); 1156 dot_dnums.add_rhs_contracting_dimensions(0); 1157 auto dot = builder.AddInstruction( 1158 HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums)); 1159 auto one = builder.AddInstruction( 1160 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 1161 auto add_operand = builder.AddInstruction( 1162 HloInstruction::CreateBroadcast(data_shape, one, {1})); 1163 auto add = builder.AddInstruction(HloInstruction::CreateBinary( 1164 data_shape, HloOpcode::kAdd, dot, add_operand)); 1165 auto sub = builder.AddInstruction(HloInstruction::CreateBinary( 1166 data_shape, HloOpcode::kSubtract, dot, add_operand)); 1167 builder.AddInstruction( 1168 HloInstruction::CreateBinary(data_shape, HloOpcode::kMultiply, add, sub)); 1169 auto computation = module.AddEntryComputation(builder.Build()); 1170 1171 auto nested_fusion = computation->CreateFusionInstruction( 1172 {dot, b_t}, HloInstruction::FusionKind::kTransposeDot); 1173 1174 auto fusion = computation->CreateFusionInstruction( 1175 {add, nested_fusion}, HloInstruction::FusionKind::kOutput); 1176 auto fusion2 = computation->CreateFusionInstruction( 1177 {sub, nested_fusion}, HloInstruction::FusionKind::kOutput); 1178 auto clone = fusion->Clone(); 1179 EXPECT_TRUE(StructuralEqual(*fusion, *clone)); 1180 EXPECT_FALSE(StructuralEqual(*fusion, *fusion2)); 1181 } 1182 1183 TEST_F(HloInstructionTest, CloneSuffixNames) { 1184 // Test that the suffix string added to cloned instructions is not 1185 // duplicated. Rather a numeric incrementing value should be appended. That 1186 // is, we want "foo.clone2", not "foo.clone.clone". 1187 1188 // Test cloning the same instruction multiple times. 1189 auto foo = 1190 HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "foo"); 1191 EXPECT_EQ(foo->Clone()->name(), "foo.clone"); 1192 EXPECT_EQ(foo->Clone()->Clone()->name(), "foo.clone2"); 1193 EXPECT_EQ(foo->Clone()->Clone()->Clone()->name(), "foo.clone3"); 1194 1195 // Test custom suffixes. 1196 EXPECT_EQ(foo->Clone("bar")->name(), "foo.bar"); 1197 EXPECT_EQ(foo->Clone("bar")->Clone("bar")->name(), "foo.bar2"); 1198 EXPECT_EQ(foo->Clone("bar")->Clone("bar")->Clone()->name(), "foo.bar2.clone"); 1199 1200 // Test instruction name with a dot. 1201 auto foo_baz = HloInstruction::CreateParameter( 1202 0, ShapeUtil::MakeShape(F32, {}), "foo.baz"); 1203 EXPECT_EQ(foo_baz->Clone()->name(), "foo.baz.clone"); 1204 1205 // Test incrementing a large number after the suffix. 1206 auto foo_clone234 = HloInstruction::CreateParameter( 1207 0, ShapeUtil::MakeShape(F32, {}), "foo.clone234"); 1208 EXPECT_EQ(foo_clone234->Clone()->name(), "foo.clone235"); 1209 1210 // Test a non-numeric string after the cloning suffix. 1211 auto foo_clonexyz = HloInstruction::CreateParameter( 1212 0, ShapeUtil::MakeShape(F32, {}), "foo.clonexyz"); 1213 EXPECT_EQ(foo_clonexyz->Clone()->name(), "foo.clonexyz.clone"); 1214 1215 // Test a name with multiple appearances of the suffix. 1216 auto foo_clone_clone3 = HloInstruction::CreateParameter( 1217 0, ShapeUtil::MakeShape(F32, {}), "foo.clone.clone3"); 1218 EXPECT_EQ(foo_clone_clone3->Clone()->name(), "foo.clone.clone4"); 1219 } 1220 1221 TEST_F(HloInstructionTest, Stringification) { 1222 // Tests stringification of a simple op, fusion, while, and conditional. 1223 const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); 1224 const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); 1225 const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); 1226 const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); 1227 1228 HloComputation::Builder builder("TransposeDot"); 1229 HloInstruction* x = 1230 builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); 1231 HloInstruction* y = 1232 builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); 1233 HloInstruction* reshape = 1234 builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); 1235 DotDimensionNumbers dot_dnums; 1236 dot_dnums.add_lhs_contracting_dimensions(1); 1237 dot_dnums.add_rhs_contracting_dimensions(0); 1238 HloInstruction* dot = builder.AddInstruction( 1239 HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); 1240 1241 auto options = HloPrintOptions().set_print_metadata(false); 1242 1243 EXPECT_EQ(dot->ToString(options), 1244 "%dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} " 1245 "%transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}"); 1246 1247 HloModule module(TestName()); 1248 auto* computation = module.AddEntryComputation(builder.Build()); 1249 HloInstruction* fusion = computation->CreateFusionInstruction( 1250 {dot, reshape}, HloInstruction::FusionKind::kTransposeDot); 1251 1252 EXPECT_EQ( 1253 fusion->ToString(options), 1254 "%dot_fusion = f32[5,20]{1,0} fusion(f32[5,10]{1,0} %x, " 1255 "f32[20,10]{1,0} %y), kind=kTransposeDot, calls=%fused_computation"); 1256 1257 HloInstruction* loop = builder.AddInstruction( 1258 HloInstruction::CreateWhile(sout, computation, computation, x)); 1259 EXPECT_EQ(loop->ToString(options), 1260 "%while = f32[5,20]{1,0} while(f32[5,10]{1,0} %x), " 1261 "condition=%TransposeDot, body=%TransposeDot"); 1262 1263 auto pred = builder.AddInstruction( 1264 HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))); 1265 HloInstruction* conditional = 1266 builder.AddInstruction(HloInstruction::CreateConditional( 1267 sout, pred, x, computation, x, computation)); 1268 EXPECT_EQ(conditional->ToString(options), 1269 "%conditional = f32[5,20]{1,0} conditional(pred[] %constant, " 1270 "f32[5,10]{1,0} %x, f32[5,10]{1,0} %x), " 1271 "true_computation=%TransposeDot, false_computation=%TransposeDot"); 1272 } 1273 1274 TEST_F(HloInstructionTest, StringifyGather) { 1275 Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); 1276 Shape gather_indices_tensor_shape = 1277 ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5}); 1278 Shape gather_result_shape = 1279 ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}); 1280 1281 HloComputation::Builder builder("Gather"); 1282 HloInstruction* input = builder.AddInstruction( 1283 HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor")); 1284 HloInstruction* gather_indices = 1285 builder.AddInstruction(HloInstruction::CreateParameter( 1286 1, gather_indices_tensor_shape, "gather_indices")); 1287 1288 HloInstruction* gather_instruction = 1289 builder.AddInstruction(HloInstruction::CreateGather( 1290 gather_result_shape, input, gather_indices, 1291 HloInstruction::MakeGatherDimNumbers( 1292 /*output_window_dims=*/{4, 5, 6, 7, 8}, 1293 /*elided_window_dims=*/{}, 1294 /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), 1295 /*window_bounds=*/{30, 29, 28, 27, 26})); 1296 1297 HloModule module(TestName()); 1298 module.AddEntryComputation(builder.Build()); 1299 1300 EXPECT_EQ(gather_instruction->ToString(), 1301 "%gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} " 1302 "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, " 1303 "s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), " 1304 "output_window_dims={4,5,6,7,8}, elided_window_dims={}, " 1305 "gather_dims_to_operand_dims={0,1,2,3,4}, " 1306 "window_bounds={30,29,28,27,26}"); 1307 } 1308 1309 } // namespace 1310 } // namespace xla 1311