Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     17 
     18 #include <set>
     19 #include <unordered_map>
     20 #include <utility>
     21 #include <vector>
     22 
     23 #include "tensorflow/compiler/xla/literal_util.h"
     24 #include "tensorflow/compiler/xla/protobuf_util.h"
     25 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
     26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     27 #include "tensorflow/compiler/xla/shape_util.h"
     28 #include "tensorflow/compiler/xla/test.h"
     29 #include "tensorflow/compiler/xla/test_helpers.h"
     30 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
     31 #include "tensorflow/compiler/xla/util.h"
     32 
     33 namespace xla {
     34 namespace {
     35 
     36 using ::testing::ElementsAre;
     37 using ::testing::UnorderedElementsAre;
     38 
     39 class HloInstructionTest : public HloTestBase {
     40  protected:
     41   HloInstructionTest() {}
     42 
     43   Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
     44 };
     45 
     46 // Simple visitor that collects the number of users and operands for certain HLO
     47 // nodes. It also verifies some of the DFS visiting invariants (operands visited
     48 // before their users, nodes not visited twice, etc.)
     49 class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault {
     50  public:
     51   Status DefaultAction(HloInstruction* hlo_instruction) override {
     52     return Unimplemented("not implemented %s",
     53                          HloOpcodeString(hlo_instruction->opcode()).c_str());
     54   }
     55 
     56   Status HandleParameter(HloInstruction* parameter) override {
     57     EXPECT_EQ(0, count_.count(parameter));
     58     count_[parameter] = GetCountsForNode(parameter);
     59     return Status::OK();
     60   }
     61 
     62   Status HandleConstant(HloInstruction* constant) override {
     63     EXPECT_EQ(0, count_.count(constant));
     64     count_[constant] = GetCountsForNode(constant);
     65     return Status::OK();
     66   }
     67 
     68   Status HandleAdd(HloInstruction* add) override {
     69     auto lhs = add->operand(0);
     70     auto rhs = add->operand(1);
     71     EXPECT_EQ(0, count_.count(add));
     72     EXPECT_GT(count_.count(lhs), 0);
     73     EXPECT_GT(count_.count(rhs), 0);
     74     count_[add] = GetCountsForNode(add);
     75     return Status::OK();
     76   }
     77 
     78   Status HandleNegate(HloInstruction* negate) override {
     79     auto operand = negate->operand(0);
     80     EXPECT_EQ(0, count_.count(negate));
     81     EXPECT_GT(count_.count(operand), 0);
     82     count_[negate] = GetCountsForNode(negate);
     83     return Status::OK();
     84   }
     85 
     86   Status HandleMap(HloInstruction* map) override {
     87     EXPECT_EQ(0, count_.count(map));
     88     for (HloInstruction* arg : map->operands()) {
     89       EXPECT_GT(count_.count(arg), 0);
     90     }
     91     count_[map] = GetCountsForNode(map);
     92     return Status::OK();
     93   }
     94 
     95   Status HandleReduce(HloInstruction* reduce) override {
     96     auto arg = reduce->operand(0);
     97     auto init_value = reduce->operand(1);
     98     EXPECT_EQ(0, count_.count(reduce));
     99     EXPECT_GT(count_.count(arg), 0);
    100     EXPECT_GT(count_.count(init_value), 0);
    101     count_[reduce] = GetCountsForNode(reduce);
    102     return Status::OK();
    103   }
    104 
    105   int64 NumOperands(const HloInstruction* node) {
    106     auto count_iterator = count_.find(node);
    107     EXPECT_NE(count_.end(), count_iterator);
    108     return count_iterator->second.operand_count;
    109   }
    110 
    111   int64 NumUsers(const HloInstruction* node) {
    112     auto count_iterator = count_.find(node);
    113     EXPECT_NE(count_.end(), count_iterator);
    114     return count_iterator->second.user_count;
    115   }
    116 
    117  private:
    118   struct NumOpsAndUsers {
    119     int64 operand_count;
    120     int64 user_count;
    121   };
    122 
    123   // Helper function to count operands and users for the given HLO.
    124   NumOpsAndUsers GetCountsForNode(const HloInstruction* node) {
    125     NumOpsAndUsers counts{node->operand_count(), node->user_count()};
    126     return counts;
    127   }
    128 
    129   // Counters for HLOs. Maps HLO to a NumOpsAndUsers.
    130   std::unordered_map<const HloInstruction*, NumOpsAndUsers> count_;
    131 };
    132 
    133 TEST_F(HloInstructionTest, BasicProperties) {
    134   auto parameter = HloInstruction::CreateParameter(1, r0f32_, "foo");
    135 
    136   EXPECT_EQ(HloOpcode::kParameter, parameter->opcode());
    137   EXPECT_TRUE(ShapeUtil::IsScalarF32(parameter->shape()));
    138   EXPECT_EQ(0, parameter->operand_count());
    139 }
    140 
    141 TEST_F(HloInstructionTest, UserWithTwoOperands) {
    142   // [Param foo]----->  |-----|
    143   //                    | Add |
    144   // [Param bar]----->  |-----|
    145   HloComputation::Builder builder(TestName());
    146   auto foo =
    147       builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
    148   auto bar =
    149       builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));
    150   auto add = builder.AddInstruction(
    151       HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar));
    152   HloModule module(TestName());
    153   module.AddEntryComputation(builder.Build());
    154 
    155   EXPECT_THAT(add->operands(), UnorderedElementsAre(foo, bar));
    156   EXPECT_THAT(foo->users(), UnorderedElementsAre(add));
    157   EXPECT_THAT(bar->users(), UnorderedElementsAre(add));
    158 
    159   OpAndUserCollectingVisitor visitor;
    160   ASSERT_IS_OK(add->Accept(&visitor));
    161 
    162   EXPECT_EQ(2, visitor.NumOperands(add));
    163   EXPECT_EQ(0, visitor.NumUsers(add));
    164   EXPECT_EQ(1, visitor.NumUsers(foo));
    165   EXPECT_EQ(1, visitor.NumUsers(bar));
    166 }
    167 
    168 TEST_F(HloInstructionTest, MultipleUsers) {
    169   //        [Param foo]
    170   //       /     |     \
    171   //      /      |      \     [Param bar]
    172   //     /       |       \         |
    173   //     |       |       |         |
    174   //     V       V       V         V
    175   //  -------  -------   -----------
    176   //  | exp |  | exp |   |   add   |
    177   //  -------  -------   -----------
    178   HloComputation::Builder builder(TestName());
    179   auto foo =
    180       builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
    181   auto bar =
    182       builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));
    183   auto exp1 = builder.AddInstruction(
    184       HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo));
    185   auto exp2 = builder.AddInstruction(
    186       HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo));
    187   auto add = builder.AddInstruction(
    188       HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar));
    189   HloModule module(TestName());
    190   module.AddEntryComputation(builder.Build());
    191 
    192   EXPECT_EQ(3, foo->user_count());
    193   EXPECT_EQ(1, bar->user_count());
    194   EXPECT_EQ(0, exp1->user_count());
    195   EXPECT_EQ(0, exp2->user_count());
    196   EXPECT_EQ(0, add->user_count());
    197 
    198   OpAndUserCollectingVisitor visitor;
    199   ASSERT_IS_OK(add->Accept(&visitor));
    200 
    201   EXPECT_EQ(2, visitor.NumOperands(add));
    202   EXPECT_EQ(3, visitor.NumUsers(foo));
    203 }
    204 
    205 TEST_F(HloInstructionTest, RepeatedUser) {
    206   // Here we have a user 'add' nodes that uses the same HLO in both operands.
    207   // Make sure we don't count it as two distinct users.
    208   //
    209   //        [Param foo]
    210   //           |   |
    211   //           |   |
    212   //           |   |
    213   //           V   V
    214   //          -------
    215   //          | add |
    216   //          -------
    217   HloComputation::Builder builder(TestName());
    218   auto foo =
    219       builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
    220   auto add = builder.AddInstruction(
    221       HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo));
    222   HloModule module(TestName());
    223   module.AddEntryComputation(builder.Build());
    224 
    225   EXPECT_EQ(1, foo->user_count());
    226 
    227   // But 'add' still has two operands, even if both are the same HLO.
    228   EXPECT_EQ(2, add->operand_count());
    229 }
    230 
    231 TEST_F(HloInstructionTest, MultipleUsersAndOperands) {
    232   //        [param0]          [param1]
    233   //           |                 |
    234   //           |       [c0]      |
    235   //           |        |        |
    236   //           V        |        V
    237   //        -------     |     -------
    238   //        | add | <---^---> | add |
    239   //        -------           -------
    240   //           |                 |
    241   //           \     -------     /
    242   //            ---->| add |<----
    243   //                 -------
    244   HloComputation::Builder builder(TestName());
    245   auto param0 = builder.AddInstruction(
    246       HloInstruction::CreateParameter(0, r0f32_, "param0"));
    247   auto param1 = builder.AddInstruction(
    248       HloInstruction::CreateParameter(1, r0f32_, "param1"));
    249   auto c0 = builder.AddInstruction(
    250       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
    251   auto addleft = builder.AddInstruction(
    252       HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param0, c0));
    253   auto addright = builder.AddInstruction(
    254       HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c0, param1));
    255   auto addtotal = builder.AddInstruction(
    256       HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, addleft, addright));
    257   HloModule module(TestName());
    258   module.AddEntryComputation(builder.Build());
    259 
    260   OpAndUserCollectingVisitor visitor;
    261   ASSERT_IS_OK(addtotal->Accept(&visitor));
    262 
    263   EXPECT_EQ(2, visitor.NumUsers(c0));
    264   EXPECT_EQ(2, visitor.NumOperands(addleft));
    265   EXPECT_EQ(2, visitor.NumOperands(addright));
    266   EXPECT_EQ(2, visitor.NumOperands(addtotal));
    267 }
    268 
    269 TEST_F(HloInstructionTest, MultipleUsersAndOperandsWithUnaryOps) {
    270   //        [param0]   [c0]   [param1]
    271   //           |        |        |
    272   //           |        V        |
    273   //           |     -------     |
    274   //           |     | neg |     |
    275   //           |     -------     |
    276   //           V        |        V
    277   //        -------     |     -------
    278   //        | add | <---^---> | add |
    279   //        -------           -------
    280   //           |                 |
    281   //           \     -------     /
    282   //            ---->| add |<----
    283   //                 -------
    284   //                    |
    285   //                    V
    286   //                 -------
    287   //                 | neg |
    288   //                 -------
    289   HloComputation::Builder builder(TestName());
    290   auto param0 = builder.AddInstruction(
    291       HloInstruction::CreateParameter(0, r0f32_, "param0"));
    292   auto param1 = builder.AddInstruction(
    293       HloInstruction::CreateParameter(1, r0f32_, "param1"));
    294   auto c0 = builder.AddInstruction(
    295       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
    296   auto neg1 = builder.AddInstruction(
    297       HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, c0));
    298   auto addleft = builder.AddInstruction(
    299       HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param0, neg1));
    300   auto addright = builder.AddInstruction(
    301       HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, neg1, param1));
    302   auto addtotal = builder.AddInstruction(
    303       HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, addleft, addright));
    304   auto neg2 = builder.AddInstruction(
    305       HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, addtotal));
    306   HloModule module(TestName());
    307   module.AddEntryComputation(builder.Build());
    308 
    309   OpAndUserCollectingVisitor visitor;
    310   ASSERT_IS_OK(neg2->Accept(&visitor));
    311 
    312   EXPECT_EQ(1, visitor.NumUsers(c0));
    313   EXPECT_EQ(2, visitor.NumUsers(neg1));
    314   EXPECT_EQ(2, visitor.NumOperands(addleft));
    315   EXPECT_EQ(2, visitor.NumOperands(addright));
    316   EXPECT_EQ(2, visitor.NumOperands(addtotal));
    317   EXPECT_EQ(1, visitor.NumOperands(neg2));
    318   EXPECT_EQ(0, visitor.NumUsers(neg2));
    319 }
    320 
    321 TEST_F(HloInstructionTest, TrivialMap) {
    322   // This tests creating a trivial x+1 map as the only operation.
    323   //
    324   // param0[100x10] ---> (map x+1)
    325   //
    326   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
    327   Shape f32a100x10 = ShapeUtil::MakeShape(F32, {100, 10});
    328   HloModule module(TestName());
    329 
    330   // Builds an x+1.0 computation to use in a Map.
    331   auto embedded_builder = HloComputation::Builder("f32+1");
    332   auto param = embedded_builder.AddInstruction(
    333       HloInstruction::CreateParameter(0, r0f32, "x"));
    334   auto value = embedded_builder.AddInstruction(
    335       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    336   embedded_builder.AddInstruction(
    337       HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, value));
    338   auto add_f32 = module.AddEmbeddedComputation(embedded_builder.Build());
    339 
    340   // Builds a parameter and feeds it to the map.
    341   HloComputation::Builder builder(TestName());
    342   auto param0 = builder.AddInstruction(
    343       HloInstruction::CreateParameter(0, f32a100x10, ""));
    344   auto map = builder.AddInstruction(
    345       HloInstruction::CreateMap(f32a100x10, {param0}, add_f32));
    346   module.AddEntryComputation(builder.Build());
    347 
    348   OpAndUserCollectingVisitor visitor;
    349   ASSERT_IS_OK(map->Accept(&visitor));
    350 
    351   // Check counts.  We aren't walking the mapper computation yet.
    352   EXPECT_EQ(1, visitor.NumUsers(param0));
    353   EXPECT_EQ(0, visitor.NumUsers(map));
    354   EXPECT_EQ(1, visitor.NumOperands(map));
    355 
    356   // TODO(dehnert):  Add walking and counters for the wrapped computation.
    357 }
    358 
    359 TEST_F(HloInstructionTest, TrivialReduce) {
    360   // This tests creating a trivial x+y reduce as the only operation.
    361   //
    362   // param0[100x10] ---> (reduce x+y)
    363   //
    364   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
    365   Shape f32v100 = ShapeUtil::MakeShape(F32, {100});
    366   Shape f32a100x10 = ShapeUtil::MakeShape(F32, {100, 10});
    367 
    368   // Builds an x+y computation to use in a Reduce.
    369   auto embedded_builder = HloComputation::Builder("f32+f32");
    370   auto paramx = embedded_builder.AddInstruction(
    371       HloInstruction::CreateParameter(0, r0f32, "x"));
    372   auto paramy = embedded_builder.AddInstruction(
    373       HloInstruction::CreateParameter(1, r0f32, "y"));
    374   embedded_builder.AddInstruction(
    375       HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, paramx, paramy));
    376   HloModule module(TestName());
    377   auto add_f32 = module.AddEmbeddedComputation(embedded_builder.Build());
    378 
    379   // Builds a parameter and an initial value and feeds them to the reduce.
    380   HloComputation::Builder builder(TestName());
    381   auto param0 = builder.AddInstruction(
    382       HloInstruction::CreateParameter(0, f32a100x10, ""));
    383   auto const0 = builder.AddInstruction(
    384       HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
    385   builder.AddInstruction(
    386       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
    387   auto reduce = builder.AddInstruction(
    388       HloInstruction::CreateReduce(f32v100, param0, const0,
    389                                    /*dimensions_to_reduce=*/{1}, add_f32));
    390   module.AddEntryComputation(builder.Build());
    391 
    392   OpAndUserCollectingVisitor visitor;
    393   ASSERT_IS_OK(reduce->Accept(&visitor));
    394 
    395   // Check counts.  We aren't walking the reducer computation.
    396   EXPECT_EQ(1, visitor.NumUsers(param0));
    397   EXPECT_EQ(1, visitor.NumUsers(const0));
    398   EXPECT_EQ(0, visitor.NumUsers(reduce));
    399   EXPECT_EQ(2, visitor.NumOperands(reduce));
    400 }
    401 
    402 TEST_F(HloInstructionTest, ReplaceUseInBinaryOps) {
    403   // Construct a graph of a few binary ops using two different
    404   // parameters. Replace one of the parameters with the other parameter in one
    405   // of the instructions.
    406   HloComputation::Builder builder(TestName());
    407   auto foo =
    408       builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
    409   auto bar =
    410       builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));
    411   auto add_foobar = builder.AddInstruction(
    412       HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar));
    413   auto add_foofoo = builder.AddInstruction(
    414       HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo));
    415   builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
    416                                                       add_foobar, add_foofoo));
    417   HloModule module(TestName());
    418   module.AddEntryComputation(builder.Build());
    419 
    420   EXPECT_EQ(2, foo->user_count());
    421   EXPECT_EQ(1, bar->user_count());
    422 
    423   // Replace the use of foo in add_foofoo with bar.
    424   ASSERT_IS_OK(foo->ReplaceUseWith(add_foofoo, bar));
    425 
    426   EXPECT_EQ(1, foo->user_count());
    427   EXPECT_EQ(2, bar->user_count());
    428 
    429   EXPECT_THAT(foo->users(), UnorderedElementsAre(add_foobar));
    430   EXPECT_THAT(add_foobar->operands(), ElementsAre(foo, bar));
    431 
    432   EXPECT_THAT(bar->users(), UnorderedElementsAre(add_foobar, add_foofoo));
    433   EXPECT_THAT(add_foobar->operands(), ElementsAre(foo, bar));
    434   EXPECT_THAT(add_foofoo->operands(), ElementsAre(bar, bar));
    435 }
    436 
    437 TEST_F(HloInstructionTest, ReplaceUseInVariadicOp) {
    438   // Construct a tuple containing several parameters. Replace one parameter with
    439   // another in the tuple.
    440   HloComputation::Builder builder(TestName());
    441   auto foo =
    442       builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
    443   auto bar =
    444       builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));
    445   auto baz =
    446       builder.AddInstruction(HloInstruction::CreateParameter(2, r0f32_, "baz"));
    447 
    448   auto tuple =
    449       builder.AddInstruction(HloInstruction::CreateTuple({foo, bar, baz, foo}));
    450   auto add_foobar = builder.AddInstruction(
    451       HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar));
    452   HloModule module(TestName());
    453   module.AddEntryComputation(builder.Build());
    454 
    455   EXPECT_EQ(2, foo->user_count());
    456   EXPECT_THAT(foo->users(), UnorderedElementsAre(tuple, add_foobar));
    457 
    458   // Replace the use of foo in tuple with bar.
    459   ASSERT_IS_OK(foo->ReplaceUseWith(tuple, bar));
    460 
    461   EXPECT_THAT(foo->users(), UnorderedElementsAre(add_foobar));
    462 
    463   // Both uses of foo in tuple should have been replaced with bar.
    464   EXPECT_THAT(tuple->operands(), ElementsAre(bar, bar, baz, bar));
    465 }
    466 
    467 TEST_F(HloInstructionTest, ReplaceUseInUnaryOp) {
    468   // Construct a couple unary instructions which use a parameter. Replace the
    469   // use of a parameter in one of the unary ops with the other parameter.
    470   HloComputation::Builder builder(TestName());
    471   auto foo =
    472       builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
    473   auto bar =
    474       builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));
    475 
    476   auto exp = builder.AddInstruction(
    477       HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo));
    478   auto log = builder.AddInstruction(
    479       HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo));
    480   HloModule module(TestName());
    481   module.AddEntryComputation(builder.Build());
    482 
    483   EXPECT_EQ(2, foo->user_count());
    484   EXPECT_THAT(foo->users(), UnorderedElementsAre(exp, log));
    485   EXPECT_EQ(0, bar->user_count());
    486 
    487   // Replace the use of foo in exp with bar.
    488   ASSERT_IS_OK(foo->ReplaceUseWith(exp, bar));
    489 
    490   // The use of foo in log should not have been affected.
    491   EXPECT_EQ(1, foo->user_count());
    492   EXPECT_THAT(foo->users(), UnorderedElementsAre(log));
    493   EXPECT_THAT(log->operands(), ElementsAre(foo));
    494 
    495   // Bar should now be used in exp.
    496   EXPECT_EQ(1, bar->user_count());
    497   EXPECT_EQ(*bar->users().begin(), exp);
    498   EXPECT_EQ(1, exp->operands().size());
    499   EXPECT_EQ(*exp->operands().begin(), bar);
    500 }
    501 
    502 TEST_F(HloInstructionTest, ReplaceAllUsesWithInBinaryOps) {
    503   // Construct a simple graph of a few binary ops using two different
    504   // parameters. Replace all uses of one of the parameters with the other
    505   // parameter.
    506   HloComputation::Builder builder(TestName());
    507   auto foo =
    508       builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
    509   auto bar =
    510       builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));
    511   auto add_foobar = builder.AddInstruction(
    512       HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar));
    513   auto add_foofoo = builder.AddInstruction(
    514       HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo));
    515   builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
    516                                                       add_foobar, add_foofoo));
    517   HloModule module(TestName());
    518   module.AddEntryComputation(builder.Build());
    519 
    520   EXPECT_EQ(2, foo->user_count());
    521   EXPECT_EQ(1, bar->user_count());
    522 
    523   // Replace all uses of foo with bar.
    524   ASSERT_IS_OK(foo->ReplaceAllUsesWith(bar));
    525 
    526   EXPECT_EQ(0, foo->user_count());
    527   EXPECT_EQ(2, bar->user_count());
    528 
    529   EXPECT_THAT(bar->users(), UnorderedElementsAre(add_foobar, add_foofoo));
    530 }
    531 
    532 TEST_F(HloInstructionTest, ReplaceAllUsesInMultipleOps) {
    533   // Construct a graph containing several ops (a unary, binary, and variadic)
    534   // which use two parameters. Replace all uses of one of the parameters with
    535   // the other parameter.
    536   HloComputation::Builder builder(TestName());
    537   auto foo =
    538       builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
    539   auto bar =
    540       builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));
    541 
    542   auto add_foobar = builder.AddInstruction(
    543       HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar));
    544   auto exp = builder.AddInstruction(
    545       HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo));
    546   auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({foo, bar}));
    547   HloModule module(TestName());
    548   module.AddEntryComputation(builder.Build());
    549 
    550   EXPECT_EQ(3, foo->user_count());
    551   EXPECT_EQ(2, bar->user_count());
    552 
    553   // Replace all uses of foo with bar.
    554   ASSERT_IS_OK(foo->ReplaceAllUsesWith(bar));
    555 
    556   EXPECT_EQ(0, foo->user_count());
    557   EXPECT_EQ(3, bar->user_count());
    558 
    559   EXPECT_THAT(bar->users(), UnorderedElementsAre(add_foobar, exp, tuple));
    560 }
    561 
    562 // Simple visitor that collects and post-processes each node in the graph.
    563 class NodeCollectorAndPostProcessor : public DfsHloVisitorWithDefault {
    564  public:
    565   NodeCollectorAndPostProcessor() {}
    566 
    567   Status Postprocess(HloInstruction* hlo) override {
    568     post_processed_nodes_.push_back(hlo);
    569     return Status::OK();
    570   }
    571 
    572   Status DefaultAction(HloInstruction* hlo_instruction) override {
    573     visited_nodes_.push_back(hlo_instruction);
    574     return Status::OK();
    575   }
    576 
    577   const std::vector<const HloInstruction*>& visited_nodes() {
    578     return visited_nodes_;
    579   }
    580 
    581   const std::vector<const HloInstruction*>& post_processed_nodes() {
    582     return post_processed_nodes_;
    583   }
    584 
    585  private:
    586   std::vector<const HloInstruction*> visited_nodes_;
    587   std::vector<const HloInstruction*> post_processed_nodes_;
    588 };
    589 
    590 // Returns true if "vec" contains distinct nodes.
    591 bool Distinct(const std::vector<const HloInstruction*>& vec) {
    592   std::set<const HloInstruction*> distinct_nodes(vec.begin(), vec.end());
    593   return distinct_nodes.size() == vec.size();
    594 }
    595 
    596 TEST_F(HloInstructionTest, PostProcessAllVisitedNodes) {
    597   // Verifies all the nodes are visited and post-processed in the same order,
    598   // and that each node is visited exactly once.
    599   //
    600   //    /--> exp --\
    601   // foo            add
    602   //    \--> log --/
    603   HloComputation::Builder builder(TestName());
    604   auto foo =
    605       builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
    606   auto exp = builder.AddInstruction(
    607       HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo));
    608   auto log = builder.AddInstruction(
    609       HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo));
    610   auto add = builder.AddInstruction(
    611       HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, exp, log));
    612   HloModule module(TestName());
    613   module.AddEntryComputation(builder.Build());
    614 
    615   NodeCollectorAndPostProcessor visitor;
    616   ASSERT_IS_OK(add->Accept(&visitor));
    617   // Verifies all the nodes are visited and post-processed in the same order.
    618   EXPECT_EQ(visitor.visited_nodes(), visitor.post_processed_nodes());
    619   // Verifies each node is visited exactly once.
    620   EXPECT_TRUE(Distinct(visitor.visited_nodes()));
    621 }
    622 
    623 TEST_F(HloInstructionTest, SingletonFusionOp) {
    624   HloComputation::Builder builder(TestName());
    625   // Create a fusion instruction containing a single unary operation.
    626   auto constant = builder.AddInstruction(
    627       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
    628   auto exp = builder.AddInstruction(
    629       HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant));
    630   HloModule module(TestName());
    631   auto* computation = module.AddEntryComputation(builder.Build());
    632   auto* fusion = computation->CreateFusionInstruction(
    633       {exp}, HloInstruction::FusionKind::kLoop);
    634 
    635   EXPECT_THAT(fusion->operands(), ElementsAre(constant));
    636   EXPECT_THAT(constant->users(), ElementsAre(fusion));
    637 }
    638 
    639 TEST_F(HloInstructionTest, BinaryFusionOp) {
    640   HloComputation::Builder builder(TestName());
    641   // Create a fusion instruction containing a single binary operation.
    642   auto constant1 = builder.AddInstruction(
    643       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
    644   auto constant2 = builder.AddInstruction(
    645       HloInstruction::CreateConstant(Literal::CreateR0<float>(42.1f)));
    646   auto add = builder.AddInstruction(HloInstruction::CreateBinary(
    647       r0f32_, HloOpcode::kAdd, constant1, constant2));
    648   HloModule module(TestName());
    649   auto* computation = module.AddEntryComputation(builder.Build());
    650   auto* fusion = computation->CreateFusionInstruction(
    651       {add}, HloInstruction::FusionKind::kLoop);
    652 
    653   EXPECT_THAT(fusion->operands(), ElementsAre(constant1, constant2));
    654   EXPECT_THAT(constant1->users(), ElementsAre(fusion));
    655   EXPECT_THAT(constant2->users(), ElementsAre(fusion));
    656 }
    657 
    658 TEST_F(HloInstructionTest, ChainFusionOp) {
    659   HloComputation::Builder builder(TestName());
    660   // Create a chain of fused unary ops.
    661   auto constant = builder.AddInstruction(
    662       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
    663   auto exp1 = builder.AddInstruction(
    664       HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant));
    665   auto exp2 = builder.AddInstruction(
    666       HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1));
    667   auto exp3 = builder.AddInstruction(
    668       HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp2));
    669 
    670   HloModule module(TestName());
    671   auto* computation = module.AddEntryComputation(builder.Build());
    672   auto* fusion = computation->CreateFusionInstruction(
    673       {exp3, exp2, exp1}, HloInstruction::FusionKind::kLoop);
    674 
    675   EXPECT_THAT(fusion->operands(), ElementsAre(constant));
    676   EXPECT_THAT(constant->users(), ElementsAre(fusion));
    677 }
    678 
    679 TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) {
    680   HloComputation::Builder builder(TestName());
    681   // Create a chain of fused unary ops.
    682   auto constant = builder.AddInstruction(
    683       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
    684   auto exp1 = builder.AddInstruction(
    685       HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant));
    686   auto exp2 = builder.AddInstruction(
    687       HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1));
    688   OpMetadata metadata;
    689   metadata.set_op_name("tf_op");
    690   exp1->set_metadata(metadata);
    691   exp2->set_metadata(metadata);
    692 
    693   HloModule module(TestName());
    694   auto* computation = module.AddEntryComputation(builder.Build());
    695   auto* fusion = computation->CreateFusionInstruction(
    696       {exp2, exp1}, HloInstruction::FusionKind::kLoop);
    697 
    698   EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fusion->metadata()));
    699   EXPECT_TRUE(protobuf_util::ProtobufEquals(
    700       metadata, fusion->fused_expression_root()->metadata()));
    701   EXPECT_TRUE(protobuf_util::ProtobufEquals(
    702       metadata, fusion->fused_expression_root()->operand(0)->metadata()));
    703 
    704   auto cloned = fusion->CloneWithNewOperands(fusion->shape(), {});
    705   EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fusion->metadata()));
    706 }
    707 
    708 TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) {
    709   HloComputation::Builder builder(TestName());
    710   auto constant = builder.AddInstruction(
    711       HloInstruction::CreateConstant(Literal::CreateR2<float>({
    712           {1, 2},
    713           {3, 4},
    714       })));
    715   auto shape10 = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0});
    716   auto shape01 = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1});
    717   auto outfeed10 = builder.AddInstruction(
    718       HloInstruction::CreateOutfeed(shape10, constant, ""));
    719   auto outfeed01 = builder.AddInstruction(
    720       HloInstruction::CreateOutfeed(shape01, constant, ""));
    721 
    722   auto clone01 = builder.AddInstruction(outfeed01->Clone());
    723   auto clone10 = builder.AddInstruction(outfeed10->Clone());
    724 
    725   EXPECT_TRUE(ShapeUtil::Equal(clone01->outfeed_shape(), shape01));
    726   EXPECT_TRUE(ShapeUtil::Equal(clone10->outfeed_shape(), shape10));
    727 }
    728 
    729 TEST_F(HloInstructionTest, PreserveTupleShapeThroughClone) {
    730   HloComputation::Builder builder(TestName());
    731   auto* constant = builder.AddInstruction(
    732       HloInstruction::CreateConstant(Literal::CreateR2<float>({
    733           {1, 2},
    734           {3, 4},
    735       })));
    736   auto* tuple =
    737       builder.AddInstruction(HloInstruction::CreateTuple({constant, constant}));
    738   *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(), {0})
    739        ->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
    740   *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(), {1})
    741        ->mutable_layout() = LayoutUtil::MakeLayout({1, 0});
    742   auto tuple_clone = tuple->Clone();
    743   EXPECT_TRUE(ShapeUtil::Equal(tuple_clone->shape(), tuple->shape()));
    744 }
    745 
    746 TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
    747   // Create a fusion instruction containing a single unary operation.
    748   const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
    749   HloModule module(TestName());
    750 
    751   auto make_map_computation = [&]() {
    752     auto builder = HloComputation::Builder("FusionMap");
    753     builder.AddInstruction(
    754         HloInstruction::CreateParameter(0, scalar_shape, "param"));
    755     return module.AddEmbeddedComputation(builder.Build());
    756   };
    757 
    758   HloComputation* computation_x = make_map_computation();
    759   HloComputation* computation_y = make_map_computation();
    760 
    761   HloComputation::Builder builder(TestName());
    762   auto constant = builder.AddInstruction(
    763       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
    764   auto map_1_x = builder.AddInstruction(HloInstruction::CreateMap(
    765       scalar_shape, {constant}, computation_x, /*static_operands=*/{}));
    766   auto map_2_x = builder.AddInstruction(HloInstruction::CreateMap(
    767       scalar_shape, {map_1_x}, computation_x, /*static_operands=*/{}));
    768   auto map_3_y = builder.AddInstruction(HloInstruction::CreateMap(
    769       scalar_shape, {map_2_x}, computation_y, /*static_operands=*/{}));
    770   auto* computation = module.AddEntryComputation(builder.Build());
    771 
    772   auto* fusion = computation->CreateFusionInstruction(
    773       {map_3_y}, HloInstruction::FusionKind::kLoop);
    774   auto* fused_computation = fusion->fused_instructions_computation();
    775   EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation));
    776 
    777   fusion->FuseInstruction(map_2_x);
    778   EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation));
    779 
    780   fusion->FuseInstruction(map_1_x);
    781   EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation));
    782 }
    783 
    784 TEST_F(HloInstructionTest, ComplexFusionOp) {
    785   HloComputation::Builder builder(TestName());
    786   // Fuse all instructions in complicated expression:
    787   //
    788   //   add = Add(C1, C2)
    789   //   clamp = Clamp(C2, add, add)
    790   //   exp = Exp(add)
    791   //   mul = Mul(exp, C3)
    792   //   sub = Sub(mul, clamp)
    793   //   tuple = Tuple({sub, sub, mul, C1})
    794   //
    795   // Notable complexities are repeated operands in the same instruction,
    796   // different shapes, use of value in different expressions.
    797   auto c1 = builder.AddInstruction(
    798       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
    799   auto c2 = builder.AddInstruction(
    800       HloInstruction::CreateConstant(Literal::CreateR0<float>(2.1f)));
    801   auto c3 = builder.AddInstruction(
    802       HloInstruction::CreateConstant(Literal::CreateR0<float>(9.0f)));
    803 
    804   auto add = builder.AddInstruction(
    805       HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1, c2));
    806   auto clamp = builder.AddInstruction(
    807       HloInstruction::CreateTernary(r0f32_, HloOpcode::kClamp, c2, add, add));
    808   auto exp = builder.AddInstruction(
    809       HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, add));
    810   auto mul = builder.AddInstruction(
    811       HloInstruction::CreateBinary(r0f32_, HloOpcode::kMultiply, exp, c3));
    812   auto sub = builder.AddInstruction(
    813       HloInstruction::CreateBinary(r0f32_, HloOpcode::kSubtract, mul, clamp));
    814   auto tuple =
    815       builder.AddInstruction(HloInstruction::CreateTuple({sub, sub, mul, c1}));
    816 
    817   HloModule module(TestName());
    818   auto* computation = module.AddEntryComputation(builder.Build());
    819   auto* fusion = computation->CreateFusionInstruction(
    820       {tuple, sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop);
    821 
    822   // Operands in the fusion instruction's operands() vector should be in the
    823   // order in which their users were added fused.
    824   EXPECT_THAT(fusion->operands(), ElementsAre(c1, c3, c2));
    825   EXPECT_THAT(c1->users(), ElementsAre(fusion));
    826 }
    827 
    828 // Convenience function for comparing two HloInstructions.
    829 static bool Identical(const HloInstruction& instruction1,
    830                       const HloInstruction& instruction2) {
    831   // Verify Identical is reflexive for both instructions.
    832   EXPECT_TRUE(instruction1.Identical(instruction1));
    833   EXPECT_TRUE(instruction2.Identical(instruction2));
    834 
    835   bool is_equal = instruction1.Identical(instruction2);
    836   // Verify Identical is symmetric.
    837   EXPECT_EQ(is_equal, instruction2.Identical(instruction1));
    838   return is_equal;
    839 }
    840 
    841 // Convenience function for comparing two HloInstructions for structural
    842 // equality.
    843 static bool StructuralEqual(const HloInstruction& instruction1,
    844                             const HloInstruction& instruction2) {
    845   auto eq_operand_shapes = [](const HloInstruction* a,
    846                               const HloInstruction* b) {
    847     return ShapeUtil::Equal(a->shape(), b->shape());
    848   };
    849   auto eq_computations = [](const HloComputation* a, const HloComputation* b) {
    850     return *a == *b;
    851   };
    852 
    853   // Verify Identical is reflexive for both instructions.
    854   EXPECT_TRUE(
    855       instruction1.Identical(instruction1, eq_operand_shapes, eq_computations));
    856   EXPECT_TRUE(
    857       instruction2.Identical(instruction2, eq_operand_shapes, eq_computations));
    858 
    859   bool is_equal =
    860       instruction1.Identical(instruction2, eq_operand_shapes, eq_computations);
    861   // Verify Identical is symmetric.
    862   EXPECT_EQ(is_equal, instruction2.Identical(instruction1, eq_operand_shapes,
    863                                              eq_computations));
    864   return is_equal;
    865 }
    866 
    867 TEST_F(HloInstructionTest, IdenticalInstructions) {
    868   // Test HloInstruction::Identical with some subset of instructions types.
    869 
    870   // Create a set of random constant operands to use below. Make them matrices
    871   // so dimensions are interesting.
    872   auto operand1 = HloInstruction::CreateConstant(
    873       Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
    874   auto operand2 = HloInstruction::CreateConstant(
    875       Literal::CreateR2<float>({{10.0, 20.0}, {30.0, 40.0}}));
    876   auto vector_operand =
    877       HloInstruction::CreateConstant(Literal::CreateR1<float>({42.0, 123.0}));
    878   Shape shape = operand1->shape();
    879 
    880   // Convenient short names for the operands.
    881   HloInstruction* op1 = operand1.get();
    882   HloInstruction* op2 = operand2.get();
    883 
    884   // Operations which only depend on their operands and opcode.
    885   EXPECT_TRUE(
    886       Identical(*HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1),
    887                 *HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1)));
    888   EXPECT_FALSE(
    889       Identical(*HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1),
    890                 *HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op2)));
    891   EXPECT_FALSE(
    892       Identical(*HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1),
    893                 *HloInstruction::CreateUnary(shape, HloOpcode::kNegate, op1)));
    894 
    895   // Tuples.
    896   EXPECT_TRUE(Identical(*HloInstruction::CreateTuple({op1, op2}),
    897                         *HloInstruction::CreateTuple({op1, op2})));
    898   EXPECT_FALSE(Identical(*HloInstruction::CreateTuple({op1, op2}),
    899                          *HloInstruction::CreateTuple({op2, op1})));
    900 
    901   // Broadcasts.
    902   EXPECT_TRUE(Identical(*HloInstruction::CreateBroadcast(shape, op1, {0, 1}),
    903                         *HloInstruction::CreateBroadcast(shape, op1, {0, 1})));
    904   EXPECT_FALSE(Identical(*HloInstruction::CreateBroadcast(shape, op1, {0, 1}),
    905                          *HloInstruction::CreateBroadcast(shape, op1, {1, 0})));
    906   Shape bcast_shape1 = ShapeUtil::MakeShape(F32, {2, 2, 42});
    907   Shape bcast_shape2 = ShapeUtil::MakeShape(F32, {2, 2, 123});
    908   EXPECT_FALSE(
    909       Identical(*HloInstruction::CreateBroadcast(bcast_shape1, op1, {0, 1}),
    910                 *HloInstruction::CreateBroadcast(bcast_shape2, op1, {0, 1})));
    911 
    912   // Binary operands.
    913   EXPECT_TRUE(Identical(
    914       *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2),
    915       *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2)));
    916   EXPECT_FALSE(Identical(
    917       *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2),
    918       *HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op2, op1)));
    919   EXPECT_FALSE(Identical(
    920       *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2),
    921       *HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op1, op2)));
    922 }
    923 
    924 TEST_F(HloInstructionTest, FunctionVisitor) {
    925   // Verify the function visitor HloInstruction::Accept visits all instructions
    926   // from a root properly given the following graph:
    927   //
    928   //        param
    929   //       /     \
    930   //    negate   exp
    931   //        \    /
    932   //         add
    933   const Shape f32 = ShapeUtil::MakeShape(F32, {});
    934   HloComputation::Builder builder(TestName());
    935   auto param =
    936       builder.AddInstruction(HloInstruction::CreateParameter(0, f32, "0"));
    937   auto negate = builder.AddInstruction(
    938       HloInstruction::CreateUnary(f32, HloOpcode::kNegate, param));
    939   auto exp = builder.AddInstruction(
    940       HloInstruction::CreateUnary(f32, HloOpcode::kExp, param));
    941   auto add = builder.AddInstruction(
    942       HloInstruction::CreateBinary(f32, HloOpcode::kAdd, negate, exp));
    943   HloModule module(TestName());
    944   module.AddEntryComputation(builder.Build());
    945 
    946   int visit_num = 0;
    947   std::unordered_map<HloInstruction*, int> visit_order;
    948   EXPECT_IS_OK(add->Accept([&visit_num, &visit_order](HloInstruction* inst) {
    949     EXPECT_EQ(0, visit_order.count(inst));
    950     visit_order[inst] = visit_num;
    951     visit_num++;
    952     return Status::OK();
    953   }));
    954 
    955   EXPECT_EQ(0, visit_order.at(param));
    956   // negate and exp can be visited in an arbitrary order.
    957   EXPECT_TRUE(visit_order.at(exp) == 1 || visit_order.at(exp) == 2);
    958   EXPECT_TRUE(visit_order.at(negate) == 1 || visit_order.at(negate) == 2);
    959   EXPECT_NE(visit_order.at(exp), visit_order.at(negate));
    960   EXPECT_EQ(3, visit_order.at(add));
    961 }
    962 
    963 TEST_F(HloInstructionTest, FullyElementwise) {
    964   const Shape r1f32 = ShapeUtil::MakeShape(F32, {5});
    965   HloComputation::Builder builder(TestName());
    966   auto x =
    967       builder.AddInstruction(HloInstruction::CreateParameter(0, r1f32, "x"));
    968   auto y =
    969       builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y"));
    970   auto add = builder.AddInstruction(
    971       HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, x, y));
    972   HloModule module(TestName());
    973   module.AddEntryComputation(builder.Build());
    974 
    975   EXPECT_TRUE(add->IsElementwise());
    976   for (int i = 0; i < add->operand_count(); ++i) {
    977     EXPECT_TRUE(add->IsElementwiseOnOperand(i));
    978   }
    979 }
    980 
    981 TEST_F(HloInstructionTest, PartiallyElementwise) {
    982   const Shape r1f32 = ShapeUtil::MakeShape(F32, {5});
    983   const Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 5});
    984 
    985   // Fused expression:
    986   //
    987   // p0     p1   p2   p3
    988   //   \   /    /     |
    989   //    mul    /      |
    990   //      \   /       |
    991   //       div     broadcast
    992   //          \    /
    993   //           max
    994   //
    995   // The fusion instruction is not elementwise on p3 because the broadcast is
    996   // not elementwise.
    997   HloComputation::Builder builder("PartiallyElementwise");
    998   HloInstruction* p0 =
    999       builder.AddInstruction(HloInstruction::CreateParameter(0, r2f32, "p0"));
   1000   HloInstruction* p1 =
   1001       builder.AddInstruction(HloInstruction::CreateParameter(1, r2f32, "p1"));
   1002   HloInstruction* p2 =
   1003       builder.AddInstruction(HloInstruction::CreateParameter(2, r2f32, "p2"));
   1004   HloInstruction* p3 =
   1005       builder.AddInstruction(HloInstruction::CreateParameter(3, r1f32, "p3"));
   1006   HloInstruction* mul = builder.AddInstruction(
   1007       HloInstruction::CreateBinary(r2f32, HloOpcode::kMultiply, p0, p1));
   1008   HloInstruction* div = builder.AddInstruction(
   1009       HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, mul, p2));
   1010   // Dimension 0 of shape [5] is mapped to dimension 1 of shape [3x5].
   1011   HloInstruction* broadcast =
   1012       builder.AddInstruction(HloInstruction::CreateBroadcast(r2f32, p3, {1}));
   1013   HloInstruction* max = builder.AddInstruction(
   1014       HloInstruction::CreateBinary(r2f32, HloOpcode::kMaximum, div, broadcast));
   1015 
   1016   HloModule module(TestName());
   1017   auto* computation = module.AddEntryComputation(builder.Build());
   1018   HloInstruction* fusion = computation->CreateFusionInstruction(
   1019       {max, broadcast, div, mul}, HloInstruction::FusionKind::kLoop);
   1020   EXPECT_FALSE(fusion->IsElementwise());
   1021   for (int64 operand_idx = 0; operand_idx < fusion->operand_count();
   1022        ++operand_idx) {
   1023     const HloInstruction* operand = fusion->operand(operand_idx);
   1024     if (operand == p3) {
   1025       EXPECT_FALSE(fusion->IsElementwiseOnOperand(operand_idx));
   1026     } else {
   1027       EXPECT_TRUE(fusion->IsElementwiseOnOperand(operand_idx));
   1028     }
   1029   }
   1030 }
   1031 
   1032 TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) {
   1033   // Fused expression:
   1034   //
   1035   // x     y
   1036   //  \   / \
   1037   //   min   broadcast
   1038   //     \   /
   1039   //      sub
   1040   //
   1041   // The fusion instruction is elementwise on `x` because the only path from x
   1042   // to sub contains only elementwise operations. It is not elementwise on `y`
   1043   // because the path y->broadcast->sub is not all elementwise.
   1044   const Shape r0f32 = ShapeUtil::MakeShape(F32, {});
   1045   const Shape r1f32 = ShapeUtil::MakeShape(F32, {5});
   1046 
   1047   HloComputation::Builder builder("PartiallyElementwiseWithReuse");
   1048   HloInstruction* x =
   1049       builder.AddInstruction(HloInstruction::CreateParameter(0, r1f32, "x"));
   1050   HloInstruction* y =
   1051       builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32, "y"));
   1052   HloInstruction* min = builder.AddInstruction(
   1053       HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, x, y));
   1054   HloInstruction* broadcast =
   1055       builder.AddInstruction(HloInstruction::CreateBroadcast(r1f32, y, {0}));
   1056   HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
   1057       r1f32, HloOpcode::kSubtract, min, broadcast));
   1058 
   1059   HloModule module(TestName());
   1060   auto* computation = module.AddEntryComputation(builder.Build());
   1061   HloInstruction* fusion = computation->CreateFusionInstruction(
   1062       {sub, broadcast, min}, HloInstruction::FusionKind::kLoop);
   1063   EXPECT_FALSE(fusion->IsElementwise());
   1064   for (int64 operand_idx = 0; operand_idx < fusion->operand_count();
   1065        ++operand_idx) {
   1066     if (fusion->operand(operand_idx) == x) {
   1067       EXPECT_TRUE(fusion->IsElementwiseOnOperand(operand_idx));
   1068     } else {
   1069       EXPECT_FALSE(fusion->IsElementwiseOnOperand(operand_idx));
   1070     }
   1071   }
   1072 }
   1073 
   1074 TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) {
   1075   // Fused expression:
   1076   //
   1077   // x     y
   1078   // |     |
   1079   // |  transpose
   1080   //  \   /
   1081   //   dot
   1082   //
   1083   // Tests that shapes aren't mangled by Clone().
   1084   const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
   1085   const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
   1086   const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
   1087   const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});
   1088 
   1089   HloComputation::Builder builder("TransposeDot");
   1090   HloInstruction* x =
   1091       builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
   1092   HloInstruction* y =
   1093       builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
   1094   HloInstruction* reshape =
   1095       builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
   1096   DotDimensionNumbers dot_dnums;
   1097   dot_dnums.add_lhs_contracting_dimensions(1);
   1098   dot_dnums.add_rhs_contracting_dimensions(0);
   1099   HloInstruction* dot = builder.AddInstruction(
   1100       HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
   1101 
   1102   HloModule module(TestName());
   1103   auto* computation = module.AddEntryComputation(builder.Build());
   1104   HloInstruction* fusion = computation->CreateFusionInstruction(
   1105       {dot, reshape}, HloInstruction::FusionKind::kTransposeDot);
   1106 
   1107   auto fusion2 = fusion->Clone();
   1108   const HloInstruction* root = fusion->fused_expression_root();
   1109   const HloInstruction* root2 = fusion2->fused_expression_root();
   1110   EXPECT_TRUE(ShapeUtil::Equal(root->shape(), root2->shape()));
   1111   EXPECT_TRUE(
   1112       ShapeUtil::Equal(root->operand(0)->shape(), root2->operand(0)->shape()));
   1113   EXPECT_TRUE(
   1114       ShapeUtil::Equal(root->operand(1)->shape(), root2->operand(1)->shape()));
   1115   EXPECT_TRUE(ShapeUtil::Equal(root->operand(1)->operand(0)->shape(),
   1116                                root2->operand(1)->operand(0)->shape()));
   1117   EXPECT_TRUE(StructuralEqual(*fusion, *fusion2));
   1118 }
   1119 
   1120 TEST_F(HloInstructionTest, FusionEquality) {
   1121   HloModule module(TestName());
   1122   HloComputation::Builder builder(TestName());
   1123 
   1124   // Create two fusion instructions containing a single unary operation.
   1125   auto parameter =
   1126       builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
   1127   auto exp = builder.AddInstruction(
   1128       HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, parameter));
   1129   auto neg = builder.AddInstruction(
   1130       HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, parameter));
   1131   auto* computation = module.AddEntryComputation(builder.Build());
   1132   auto* fusion = computation->CreateFusionInstruction(
   1133       {exp}, HloInstruction::FusionKind::kLoop);
   1134   auto* fusion2 = computation->CreateFusionInstruction(
   1135       {neg}, HloInstruction::FusionKind::kLoop);
   1136   EXPECT_FALSE(StructuralEqual(*fusion, *fusion2));
   1137 
   1138   auto clone = fusion->Clone();
   1139   EXPECT_TRUE(StructuralEqual(*fusion, *clone));
   1140 }
   1141 
   1142 TEST_F(HloInstructionTest, NestedFusionEquality) {
   1143   HloModule module(TestName());
   1144   HloComputation::Builder builder(TestName());
   1145 
   1146   // Build a nested fusion computation.
   1147   Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
   1148   auto a = builder.AddInstruction(HloInstruction::CreateConstant(
   1149       Literal::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
   1150   auto b = builder.AddInstruction(HloInstruction::CreateConstant(
   1151       Literal::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
   1152   auto b_t = builder.AddInstruction(
   1153       HloInstruction::CreateTranspose(data_shape, b, {1, 0}));
   1154   DotDimensionNumbers dot_dnums;
   1155   dot_dnums.add_lhs_contracting_dimensions(1);
   1156   dot_dnums.add_rhs_contracting_dimensions(0);
   1157   auto dot = builder.AddInstruction(
   1158       HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums));
   1159   auto one = builder.AddInstruction(
   1160       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
   1161   auto add_operand = builder.AddInstruction(
   1162       HloInstruction::CreateBroadcast(data_shape, one, {1}));
   1163   auto add = builder.AddInstruction(HloInstruction::CreateBinary(
   1164       data_shape, HloOpcode::kAdd, dot, add_operand));
   1165   auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
   1166       data_shape, HloOpcode::kSubtract, dot, add_operand));
   1167   builder.AddInstruction(
   1168       HloInstruction::CreateBinary(data_shape, HloOpcode::kMultiply, add, sub));
   1169   auto computation = module.AddEntryComputation(builder.Build());
   1170 
   1171   auto nested_fusion = computation->CreateFusionInstruction(
   1172       {dot, b_t}, HloInstruction::FusionKind::kTransposeDot);
   1173 
   1174   auto fusion = computation->CreateFusionInstruction(
   1175       {add, nested_fusion}, HloInstruction::FusionKind::kOutput);
   1176   auto fusion2 = computation->CreateFusionInstruction(
   1177       {sub, nested_fusion}, HloInstruction::FusionKind::kOutput);
   1178   auto clone = fusion->Clone();
   1179   EXPECT_TRUE(StructuralEqual(*fusion, *clone));
   1180   EXPECT_FALSE(StructuralEqual(*fusion, *fusion2));
   1181 }
   1182 
   1183 TEST_F(HloInstructionTest, CloneSuffixNames) {
   1184   // Test that the suffix string added to cloned instructions is not
   1185   // duplicated. Rather a numeric incrementing value should be appended. That
   1186   // is, we want "foo.clone2", not "foo.clone.clone".
   1187 
   1188   // Test cloning the same instruction multiple times.
   1189   auto foo =
   1190       HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "foo");
   1191   EXPECT_EQ(foo->Clone()->name(), "foo.clone");
   1192   EXPECT_EQ(foo->Clone()->Clone()->name(), "foo.clone2");
   1193   EXPECT_EQ(foo->Clone()->Clone()->Clone()->name(), "foo.clone3");
   1194 
   1195   // Test custom suffixes.
   1196   EXPECT_EQ(foo->Clone("bar")->name(), "foo.bar");
   1197   EXPECT_EQ(foo->Clone("bar")->Clone("bar")->name(), "foo.bar2");
   1198   EXPECT_EQ(foo->Clone("bar")->Clone("bar")->Clone()->name(), "foo.bar2.clone");
   1199 
   1200   // Test instruction name with a dot.
   1201   auto foo_baz = HloInstruction::CreateParameter(
   1202       0, ShapeUtil::MakeShape(F32, {}), "foo.baz");
   1203   EXPECT_EQ(foo_baz->Clone()->name(), "foo.baz.clone");
   1204 
   1205   // Test incrementing a large number after the suffix.
   1206   auto foo_clone234 = HloInstruction::CreateParameter(
   1207       0, ShapeUtil::MakeShape(F32, {}), "foo.clone234");
   1208   EXPECT_EQ(foo_clone234->Clone()->name(), "foo.clone235");
   1209 
   1210   // Test a non-numeric string after the cloning suffix.
   1211   auto foo_clonexyz = HloInstruction::CreateParameter(
   1212       0, ShapeUtil::MakeShape(F32, {}), "foo.clonexyz");
   1213   EXPECT_EQ(foo_clonexyz->Clone()->name(), "foo.clonexyz.clone");
   1214 
   1215   // Test a name with multiple appearances of the suffix.
   1216   auto foo_clone_clone3 = HloInstruction::CreateParameter(
   1217       0, ShapeUtil::MakeShape(F32, {}), "foo.clone.clone3");
   1218   EXPECT_EQ(foo_clone_clone3->Clone()->name(), "foo.clone.clone4");
   1219 }
   1220 
   1221 TEST_F(HloInstructionTest, Stringification) {
   1222   // Tests stringification of a simple op, fusion, while, and conditional.
   1223   const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
   1224   const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
   1225   const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
   1226   const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});
   1227 
   1228   HloComputation::Builder builder("TransposeDot");
   1229   HloInstruction* x =
   1230       builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
   1231   HloInstruction* y =
   1232       builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
   1233   HloInstruction* reshape =
   1234       builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
   1235   DotDimensionNumbers dot_dnums;
   1236   dot_dnums.add_lhs_contracting_dimensions(1);
   1237   dot_dnums.add_rhs_contracting_dimensions(0);
   1238   HloInstruction* dot = builder.AddInstruction(
   1239       HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
   1240 
   1241   auto options = HloPrintOptions().set_print_metadata(false);
   1242 
   1243   EXPECT_EQ(dot->ToString(options),
   1244             "%dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} "
   1245             "%transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}");
   1246 
   1247   HloModule module(TestName());
   1248   auto* computation = module.AddEntryComputation(builder.Build());
   1249   HloInstruction* fusion = computation->CreateFusionInstruction(
   1250       {dot, reshape}, HloInstruction::FusionKind::kTransposeDot);
   1251 
   1252   EXPECT_EQ(
   1253       fusion->ToString(options),
   1254       "%dot_fusion = f32[5,20]{1,0} fusion(f32[5,10]{1,0} %x, "
   1255       "f32[20,10]{1,0} %y), kind=kTransposeDot, calls=%fused_computation");
   1256 
   1257   HloInstruction* loop = builder.AddInstruction(
   1258       HloInstruction::CreateWhile(sout, computation, computation, x));
   1259   EXPECT_EQ(loop->ToString(options),
   1260             "%while = f32[5,20]{1,0} while(f32[5,10]{1,0} %x), "
   1261             "condition=%TransposeDot, body=%TransposeDot");
   1262 
   1263   auto pred = builder.AddInstruction(
   1264       HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
   1265   HloInstruction* conditional =
   1266       builder.AddInstruction(HloInstruction::CreateConditional(
   1267           sout, pred, x, computation, x, computation));
   1268   EXPECT_EQ(conditional->ToString(options),
   1269             "%conditional = f32[5,20]{1,0} conditional(pred[] %constant, "
   1270             "f32[5,10]{1,0} %x, f32[5,10]{1,0} %x), "
   1271             "true_computation=%TransposeDot, false_computation=%TransposeDot");
   1272 }
   1273 
   1274 TEST_F(HloInstructionTest, StringifyGather) {
   1275   Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
   1276   Shape gather_indices_tensor_shape =
   1277       ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5});
   1278   Shape gather_result_shape =
   1279       ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26});
   1280 
   1281   HloComputation::Builder builder("Gather");
   1282   HloInstruction* input = builder.AddInstruction(
   1283       HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor"));
   1284   HloInstruction* gather_indices =
   1285       builder.AddInstruction(HloInstruction::CreateParameter(
   1286           1, gather_indices_tensor_shape, "gather_indices"));
   1287 
   1288   HloInstruction* gather_instruction =
   1289       builder.AddInstruction(HloInstruction::CreateGather(
   1290           gather_result_shape, input, gather_indices,
   1291           HloInstruction::MakeGatherDimNumbers(
   1292               /*output_window_dims=*/{4, 5, 6, 7, 8},
   1293               /*elided_window_dims=*/{},
   1294               /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
   1295           /*window_bounds=*/{30, 29, 28, 27, 26}));
   1296 
   1297   HloModule module(TestName());
   1298   module.AddEntryComputation(builder.Build());
   1299 
   1300   EXPECT_EQ(gather_instruction->ToString(),
   1301             "%gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} "
   1302             "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, "
   1303             "s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), "
   1304             "output_window_dims={4,5,6,7,8}, elided_window_dims={}, "
   1305             "gather_dims_to_operand_dims={0,1,2,3,4}, "
   1306             "window_bounds={30,29,28,27,26}");
   1307 }
   1308 
   1309 }  // namespace
   1310 }  // namespace xla
   1311