Home | History | Annotate | Download | only in cpu
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
     17 
     18 #include <algorithm>
     19 #include <set>
     20 
     21 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
     22 #include "tensorflow/compiler/xla/service/transpose_folding.h"
     23 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
     24 #include "tensorflow/core/lib/gtl/array_slice.h"
     25 
     26 namespace op = xla::testing::opcode_matchers;
     27 
     28 namespace xla {
     29 namespace cpu {
     30 namespace {
     31 
// Fixture for the dot-fusion tests below; a plain alias for HloTestBase, so it
// adds no state or helpers of its own.
using InstructionFusionTest = HloTestBase;
     33 
     34 std::unique_ptr<HloInstruction> MakeDot(const Shape& shape, HloInstruction* lhs,
     35                                         HloInstruction* rhs) {
     36   DotDimensionNumbers dot_dnums;
     37   dot_dnums.add_lhs_contracting_dimensions(1);
     38   dot_dnums.add_rhs_contracting_dimensions(0);
     39   return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums);
     40 }
     41 
     42 TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) {
     43   HloComputation::Builder builder(TestName());
     44   HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
     45       0, ShapeUtil::MakeShape(F32, {1024, 256}), "arg0"));
     46   HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
     47       1, ShapeUtil::MakeShape(F32, {256, 1}), "arg1"));
     48 
     49   HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary(
     50       ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kExp, arg0));
     51   HloInstruction* dot = builder.AddInstruction(
     52       MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), exp0, arg1));
     53 
     54   auto module = CreateNewModule();
     55   auto computation = module->AddEntryComputation(builder.Build());
     56   EXPECT_EQ(dot, computation->root_instruction());
     57   EXPECT_TRUE(CpuInstructionFusion().Run(module.get()).ValueOrDie());
     58   EXPECT_THAT(computation->root_instruction(), op::Fusion());
     59 }
     60 
     61 TEST_F(InstructionFusionTest, DotOperationFusion_Basic_1) {
     62   HloComputation::Builder builder(TestName());
     63   HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
     64       0, ShapeUtil::MakeShape(F32, {1, 256}), "arg0"));
     65   HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
     66       1, ShapeUtil::MakeShape(F32, {256, 1024}), "arg1"));
     67 
     68   HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
     69       ShapeUtil::MakeShape(S32, {256, 1024}), HloOpcode::kExp, arg1));
     70   HloInstruction* dot = builder.AddInstruction(
     71       MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, exp1));
     72 
     73   auto module = CreateNewModule();
     74   auto computation = module->AddEntryComputation(builder.Build());
     75   EXPECT_EQ(dot, computation->root_instruction());
     76   EXPECT_TRUE(CpuInstructionFusion().Run(module.get()).ValueOrDie());
     77   EXPECT_THAT(computation->root_instruction(), op::Fusion());
     78 }
     79 
     80 TEST_F(InstructionFusionTest, DotOperationFusion_Bitcast) {
     81   HloComputation::Builder builder(TestName());
     82   HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
     83       0, ShapeUtil::MakeShape(F32, {2, 512, 2, 128}), "arg0"));
     84   HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
     85       1, ShapeUtil::MakeShape(F32, {256, 1}), "arg1"));
     86 
     87   HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary(
     88       ShapeUtil::MakeShape(S32, {2, 512, 2, 128}), HloOpcode::kExp, arg0));
     89   HloInstruction* bitcast0 = builder.AddInstruction(HloInstruction::CreateUnary(
     90       ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kBitcast, exp0));
     91   HloInstruction* dot = builder.AddInstruction(
     92       MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), bitcast0, arg1));
     93 
     94   auto module = CreateNewModule();
     95   auto computation = module->AddEntryComputation(builder.Build());
     96   EXPECT_EQ(dot, computation->root_instruction());
     97   EXPECT_TRUE(CpuInstructionFusion().Run(module.get()).ValueOrDie());
     98   EXPECT_THAT(computation->root_instruction(), op::Fusion());
     99 }
    100 
    101 TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) {
    102   HloComputation::Builder builder(TestName());
    103   HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
    104       0, ShapeUtil::MakeShape(F32, {2, 512, 2, 128}), "arg0"));
    105   HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
    106       1, ShapeUtil::MakeShape(F32, {256, 1}), "arg1"));
    107 
    108   HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary(
    109       ShapeUtil::MakeShape(S32, {2, 512, 2, 128}), HloOpcode::kExp, arg0));
    110   HloInstruction* reshape0 =
    111       builder.AddInstruction(HloInstruction::CreateReshape(
    112           ShapeUtil::MakeShape(S32, {1024, 256}), exp0));
    113   HloInstruction* dot = builder.AddInstruction(
    114       MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), reshape0, arg1));
    115 
    116   auto module = CreateNewModule();
    117   auto computation = module->AddEntryComputation(builder.Build());
    118   EXPECT_EQ(dot, computation->root_instruction());
    119   EXPECT_TRUE(CpuInstructionFusion().Run(module.get()).ValueOrDie());
    120   EXPECT_THAT(computation->root_instruction(), op::Fusion());
    121 }
    122 
    123 TEST_F(InstructionFusionTest, DotOperationFusion_TooLarge) {
    124   HloComputation::Builder builder(TestName());
    125   HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
    126       0, ShapeUtil::MakeShape(F32, {1, 32 * 1024}), "arg0"));
    127   HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
    128       1, ShapeUtil::MakeShape(F32, {256, 32 * 1024}), "arg1"));
    129 
    130   HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
    131       ShapeUtil::MakeShape(S32, {256, 32 * 1024}), HloOpcode::kExp, arg1));
    132   HloInstruction* dot = builder.AddInstruction(
    133       MakeDot(ShapeUtil::MakeShape(F32, {1, 32 * 1024}), arg0, exp1));
    134 
    135   auto module = CreateNewModule();
    136   auto computation = module->AddEntryComputation(builder.Build());
    137   EXPECT_EQ(dot, computation->root_instruction());
    138   EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie());
    139   EXPECT_EQ(dot, computation->root_instruction());
    140 }
    141 
    142 TEST_F(InstructionFusionTest, DotOperationFusion_ElementReuse) {
    143   HloComputation::Builder builder(TestName());
    144   HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
    145       0, ShapeUtil::MakeShape(F32, {2, 256}), "arg0"));
    146   HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
    147       1, ShapeUtil::MakeShape(F32, {256, 1024}), "arg1"));
    148 
    149   HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
    150       ShapeUtil::MakeShape(S32, {256, 1024}), HloOpcode::kExp, arg1));
    151   HloInstruction* dot = builder.AddInstruction(
    152       MakeDot(ShapeUtil::MakeShape(F32, {2, 1024}), arg0, exp1));
    153 
    154   auto module = CreateNewModule();
    155   auto computation = module->AddEntryComputation(builder.Build());
    156   EXPECT_EQ(dot, computation->root_instruction());
    157   EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie());
    158   EXPECT_EQ(dot, computation->root_instruction());
    159 }
    160 
    161 TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion) {
    162   HloComputation::Builder builder(TestName());
    163   HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
    164       0, ShapeUtil::MakeShape(F32, {1, 256}), "arg0"));
    165   HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
    166       1, ShapeUtil::MakeShape(F32, {1024, 256}), "arg1"));
    167 
    168   HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
    169       ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kExp, arg1));
    170   HloInstruction* transpose1 =
    171       builder.AddInstruction(HloInstruction::CreateTranspose(
    172           ShapeUtil::MakeShape(S32, {256, 1024}), exp1, {1, 0}));
    173   builder.AddInstruction(
    174       MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, transpose1));
    175 
    176   auto module = CreateNewModule();
    177   auto computation = module->AddEntryComputation(builder.Build());
    178   TransposeFolding transpose_folding(
    179       [](const HloInstruction& dot,
    180          const TransposeFolding::OperandIndices& candidate_operands) {
    181         return candidate_operands;
    182       },
    183       TransposeFolding::NeverFoldTranspose);
    184   EXPECT_TRUE(transpose_folding.Run(module.get()).ValueOrDie());
    185   EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kFusion);
    186   EXPECT_EQ(computation->root_instruction()->fusion_kind(),
    187             HloInstruction::FusionKind::kTransposeDot);
    188   EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie());
    189   EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kFusion);
    190   EXPECT_EQ(computation->root_instruction()->fusion_kind(),
    191             HloInstruction::FusionKind::kTransposeDot);
    192 }
    193 
    194 class OpcodeFusionTest : public InstructionFusionTest {
    195  protected:
    196   // Runs CPU instruction fusion on the given module, and tests that the result
    197   // contains a fused op at the root with exactly the given multiset of opcodes.
    198   void RunFusionAndCheckOpcodesWereFused(
    199       HloModule* module, const std::multiset<HloOpcode>& expected_opcodes,
    200       HloInstruction::FusionKind fusion_kind =
    201           HloInstruction::FusionKind::kLoop) {
    202     auto computation = module->entry_computation();
    203     auto did_fusion = CpuInstructionFusion().Run(module);
    204     ASSERT_TRUE(did_fusion.ok());
    205     EXPECT_TRUE(did_fusion.ValueOrDie());
    206 
    207     HloInstruction* root = computation->root_instruction();
    208     ASSERT_THAT(root, op::Fusion());
    209     EXPECT_EQ(root->fusion_kind(), fusion_kind);
    210 
    211     std::vector<HloOpcode> fused_opcodes(root->fused_instruction_count());
    212     std::transform(root->fused_instructions().begin(),
    213                    root->fused_instructions().end(), fused_opcodes.begin(),
    214                    [](const HloInstruction* hlo) { return hlo->opcode(); });
    215 
    216     EXPECT_EQ(
    217         std::multiset<HloOpcode>(fused_opcodes.begin(), fused_opcodes.end()),
    218         expected_opcodes);
    219   }
    220 
    221   HloComputation* CreateAdderToOne(HloModule* module) {
    222     HloComputation::Builder builder(TestName());
    223     HloInstruction* arg0 =
    224         builder.AddInstruction(HloInstruction::CreateParameter(
    225             0, ShapeUtil::MakeShape(F32, {}), "arg0"));
    226     HloInstruction* one = builder.AddInstruction(
    227         HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
    228     builder.AddInstruction(HloInstruction::CreateBinary(
    229         ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, arg0, one));
    230     return module->AddEmbeddedComputation(builder.Build());
    231   }
    232 
    233   HloComputation* CreateMax(HloModule* module) {
    234     HloComputation::Builder builder(TestName());
    235     HloInstruction* arg0 =
    236         builder.AddInstruction(HloInstruction::CreateParameter(
    237             0, ShapeUtil::MakeShape(F32, {}), "arg0"));
    238     HloInstruction* arg1 =
    239         builder.AddInstruction(HloInstruction::CreateParameter(
    240             1, ShapeUtil::MakeShape(F32, {}), "arg1"));
    241     builder.AddInstruction(HloInstruction::CreateBinary(
    242         ShapeUtil::MakeShape(F32, {}), HloOpcode::kMaximum, arg0, arg1));
    243     return module->AddEmbeddedComputation(builder.Build());
    244   }
    245 };
    246 
    247 TEST_F(OpcodeFusionTest, Exponential_Bitcast_Negate) {
    248   HloComputation::Builder builder(TestName());
    249   Shape param_shape = ShapeUtil::MakeShape(F32, {1, 4});
    250   Shape result_shape = ShapeUtil::MakeShape(F32, {4});
    251   HloInstruction* param0 = builder.AddInstruction(
    252       HloInstruction::CreateParameter(0, param_shape, "param"));
    253   // InstructionFusion::ShouldFuse() precludes fusing a bitcast whose operand
    254   // is a parameter, so create an operand between the parameter and bitcast.
    255   HloInstruction* exp1 = builder.AddInstruction(
    256       HloInstruction::CreateUnary(param_shape, HloOpcode::kExp, param0));
    257   HloInstruction* bitcast2 = builder.AddInstruction(
    258       HloInstruction::CreateUnary(result_shape, HloOpcode::kBitcast, exp1));
    259   builder.AddInstruction(
    260       HloInstruction::CreateUnary(result_shape, HloOpcode::kNegate, bitcast2));
    261 
    262   auto module = CreateNewModule();
    263   module->AddEntryComputation(builder.Build());
    264 
    265   RunFusionAndCheckOpcodesWereFused(
    266       module.get(), {HloOpcode::kNegate, HloOpcode::kBitcast, HloOpcode::kExp,
    267                      HloOpcode::kParameter});
    268 }
    269 
    270 TEST_F(OpcodeFusionTest, Broadcast_Bitcast_DynamicSlice_Tanh) {
    271   HloComputation::Builder builder(TestName());
    272   Shape param_shape = ShapeUtil::MakeShape(F32, {8});
    273   Shape starts_shape = ShapeUtil::MakeShape(F32, {2});
    274   Shape broadcast_shape = ShapeUtil::MakeShape(F32, {1, 8, 8});
    275   Shape bitcast_shape = ShapeUtil::MakeShape(F32, {8, 8});
    276   Shape dynamic_slice_shape = ShapeUtil::MakeShape(F32, {4, 4});
    277   HloInstruction* param0 = builder.AddInstruction(
    278       HloInstruction::CreateParameter(0, param_shape, "param"));
    279   HloInstruction* param1 = builder.AddInstruction(
    280       HloInstruction::CreateParameter(1, starts_shape, "starts"));
    281   HloInstruction* broadcast2 = builder.AddInstruction(
    282       HloInstruction::CreateBroadcast(broadcast_shape, param0, {1}));
    283   HloInstruction* bitcast3 = builder.AddInstruction(HloInstruction::CreateUnary(
    284       bitcast_shape, HloOpcode::kBitcast, broadcast2));
    285   HloInstruction* dynamic_slice4 =
    286       builder.AddInstruction(HloInstruction::CreateDynamicSlice(
    287           dynamic_slice_shape, bitcast3, param1, {4, 4}));
    288   builder.AddInstruction(HloInstruction::CreateUnary(
    289       dynamic_slice_shape, HloOpcode::kTanh, dynamic_slice4));
    290 
    291   auto module = CreateNewModule();
    292   module->AddEntryComputation(builder.Build());
    293 
    294   RunFusionAndCheckOpcodesWereFused(
    295       module.get(),
    296       {HloOpcode::kTanh, HloOpcode::kDynamicSlice, HloOpcode::kBitcast,
    297        HloOpcode::kBroadcast, HloOpcode::kParameter, HloOpcode::kParameter});
    298 }
    299 
    300 TEST_F(OpcodeFusionTest, Broadcast_Negate) {
    301   HloComputation::Builder builder(TestName());
    302   Shape param_shape = ShapeUtil::MakeShape(F32, {8});
    303   Shape result_shape = ShapeUtil::MakeShape(F32, {8, 8});
    304   HloInstruction* param0 = builder.AddInstruction(
    305       HloInstruction::CreateParameter(0, param_shape, "param"));
    306   HloInstruction* broadcast1 = builder.AddInstruction(
    307       HloInstruction::CreateBroadcast(result_shape, param0, {1}));
    308   builder.AddInstruction(HloInstruction::CreateUnary(
    309       result_shape, HloOpcode::kNegate, broadcast1));
    310 
    311   auto module = CreateNewModule();
    312   module->AddEntryComputation(builder.Build());
    313 
    314   RunFusionAndCheckOpcodesWereFused(
    315       module.get(),
    316       {HloOpcode::kNegate, HloOpcode::kBroadcast, HloOpcode::kParameter});
    317 }
    318 
    319 TEST_F(OpcodeFusionTest, DynamicSlice_Negate) {
    320   HloComputation::Builder builder(TestName());
    321   Shape param_shape = ShapeUtil::MakeShape(F32, {4});
    322   Shape slice_shape = ShapeUtil::MakeShape(F32, {1});
    323   Shape result_shape = ShapeUtil::MakeShape(F32, {2});
    324   HloInstruction* param0 = builder.AddInstruction(
    325       HloInstruction::CreateParameter(0, param_shape, "param"));
    326   HloInstruction* param1 = builder.AddInstruction(
    327       HloInstruction::CreateParameter(1, slice_shape, "starts"));
    328   HloInstruction* dynamic_slice2 = builder.AddInstruction(
    329       HloInstruction::CreateDynamicSlice(result_shape, param0, param1, {2}));
    330   builder.AddInstruction(HloInstruction::CreateUnary(
    331       result_shape, HloOpcode::kNegate, dynamic_slice2));
    332 
    333   auto module = CreateNewModule();
    334   module->AddEntryComputation(builder.Build());
    335 
    336   RunFusionAndCheckOpcodesWereFused(
    337       module.get(), {HloOpcode::kNegate, HloOpcode::kDynamicSlice,
    338                      HloOpcode::kParameter, HloOpcode::kParameter});
    339 }
    340 
    341 TEST_F(OpcodeFusionTest, Exponential_Negate) {
    342   HloComputation::Builder builder(TestName());
    343   Shape param_shape = ShapeUtil::MakeShape(F32, {4});
    344   HloInstruction* param0 = builder.AddInstruction(
    345       HloInstruction::CreateParameter(0, param_shape, "param"));
    346   HloInstruction* exp1 = builder.AddInstruction(
    347       HloInstruction::CreateUnary(param_shape, HloOpcode::kExp, param0));
    348   builder.AddInstruction(
    349       HloInstruction::CreateUnary(param_shape, HloOpcode::kNegate, exp1));
    350 
    351   auto module = CreateNewModule();
    352   module->AddEntryComputation(builder.Build());
    353 
    354   RunFusionAndCheckOpcodesWereFused(
    355       module.get(),
    356       {HloOpcode::kNegate, HloOpcode::kExp, HloOpcode::kParameter});
    357 }
    358 
    359 TEST_F(OpcodeFusionTest, Reshape_Negate) {
    360   HloComputation::Builder builder(TestName());
    361   Shape param_shape = ShapeUtil::MakeShape(F32, {4, 4});
    362   Shape result_shape = ShapeUtil::MakeShape(F32, {16});
    363   HloInstruction* param0 = builder.AddInstruction(
    364       HloInstruction::CreateParameter(0, param_shape, "param"));
    365   HloInstruction* reshape1 = builder.AddInstruction(
    366       HloInstruction::CreateReshape(result_shape, param0));
    367   builder.AddInstruction(
    368       HloInstruction::CreateUnary(result_shape, HloOpcode::kNegate, reshape1));
    369 
    370   auto module = CreateNewModule();
    371   module->AddEntryComputation(builder.Build());
    372 
    373   RunFusionAndCheckOpcodesWereFused(
    374       module.get(),
    375       {HloOpcode::kNegate, HloOpcode::kReshape, HloOpcode::kParameter});
    376 }
    377 
    378 TEST_F(OpcodeFusionTest, Reverse_Negate) {
    379   HloComputation::Builder builder(TestName());
    380   Shape param_shape = ShapeUtil::MakeShape(F32, {8});
    381   HloInstruction* param0 = builder.AddInstruction(
    382       HloInstruction::CreateParameter(0, param_shape, "param"));
    383   HloInstruction* reverse1 = builder.AddInstruction(
    384       HloInstruction::CreateReverse(param_shape, param0, {0}));
    385   builder.AddInstruction(
    386       HloInstruction::CreateUnary(param_shape, HloOpcode::kNegate, reverse1));
    387 
    388   auto module = CreateNewModule();
    389   module->AddEntryComputation(builder.Build());
    390 
    391   RunFusionAndCheckOpcodesWereFused(
    392       module.get(),
    393       {HloOpcode::kNegate, HloOpcode::kReverse, HloOpcode::kParameter});
    394 }
    395 
    396 TEST_F(OpcodeFusionTest, Slice_Negate) {
    397   HloComputation::Builder builder(TestName());
    398   Shape param_shape = ShapeUtil::MakeShape(F32, {4});
    399   Shape slice_shape = ShapeUtil::MakeShape(F32, {2});
    400   HloInstruction* param0 = builder.AddInstruction(
    401       HloInstruction::CreateParameter(0, param_shape, "param"));
    402   HloInstruction* slice1 = builder.AddInstruction(
    403       HloInstruction::CreateSlice(slice_shape, param0, {0}, {4}, {2}));
    404   builder.AddInstruction(HloInstruction::CreateUnary(
    405       ShapeUtil::MakeShape(S32, {2}), HloOpcode::kNegate, slice1));
    406 
    407   auto module = CreateNewModule();
    408   module->AddEntryComputation(builder.Build());
    409 
    410   RunFusionAndCheckOpcodesWereFused(
    411       module.get(),
    412       {HloOpcode::kNegate, HloOpcode::kSlice, HloOpcode::kParameter});
    413 }
    414 
    415 TEST_F(OpcodeFusionTest, Exponential_Transpose_Negate) {
    416   HloComputation::Builder builder(TestName());
    417   Shape param_shape = ShapeUtil::MakeShape(F32, {3, 4});
    418   Shape result_shape = ShapeUtil::MakeShape(F32, {4, 3});
    419   HloInstruction* param0 = builder.AddInstruction(
    420       HloInstruction::CreateParameter(0, param_shape, "param"));
    421   // InstructionFusion::ShouldFuse() precludes fusing a transpose whose operand
    422   // is a parameter, so create an operand between the parameter and transpose.
    423   HloInstruction* exp1 = builder.AddInstruction(
    424       HloInstruction::CreateUnary(param_shape, HloOpcode::kExp, param0));
    425   HloInstruction* transpose2 = builder.AddInstruction(
    426       HloInstruction::CreateTranspose(result_shape, exp1, {1, 0}));
    427   builder.AddInstruction(HloInstruction::CreateUnary(
    428       result_shape, HloOpcode::kNegate, transpose2));
    429 
    430   auto module = CreateNewModule();
    431   module->AddEntryComputation(builder.Build());
    432 
    433   RunFusionAndCheckOpcodesWereFused(
    434       module.get(), {HloOpcode::kNegate, HloOpcode::kTranspose, HloOpcode::kExp,
    435                      HloOpcode::kParameter});
    436 }
    437 
    438 TEST_F(OpcodeFusionTest, UnaryMapOfExp) {
    439   auto module = CreateNewModule();
    440 
    441   HloComputation::Builder builder(TestName());
    442   Shape shape = ShapeUtil::MakeShape(F32, {3, 4});
    443   HloInstruction* param0 = builder.AddInstruction(
    444       HloInstruction::CreateParameter(0, shape, "param"));
    445 
    446   HloInstruction* exp = builder.AddInstruction(
    447       HloInstruction::CreateUnary(shape, HloOpcode::kExp, param0));
    448   builder.AddInstruction(HloInstruction::CreateMap(
    449       shape, {exp}, CreateAdderToOne(module.get()), /*static_operands=*/{}));
    450 
    451   module->AddEntryComputation(builder.Build());
    452 
    453   RunFusionAndCheckOpcodesWereFused(
    454       module.get(), {HloOpcode::kParameter, HloOpcode::kExp, HloOpcode::kMap});
    455 }
    456 
    457 TEST_F(OpcodeFusionTest, BinaryMapOfExps) {
    458   auto module = CreateNewModule();
    459 
    460   HloComputation::Builder builder(TestName());
    461   Shape shape = ShapeUtil::MakeShape(F32, {3, 4});
    462   HloInstruction* param0 = builder.AddInstruction(
    463       HloInstruction::CreateParameter(0, shape, "param"));
    464   HloInstruction* param1 = builder.AddInstruction(
    465       HloInstruction::CreateParameter(1, shape, "param"));
    466 
    467   HloInstruction* exp0 = builder.AddInstruction(
    468       HloInstruction::CreateUnary(shape, HloOpcode::kExp, param0));
    469   HloInstruction* exp1 = builder.AddInstruction(
    470       HloInstruction::CreateUnary(shape, HloOpcode::kExp, param1));
    471 
    472   builder.AddInstruction(HloInstruction::CreateMap(
    473       shape, {exp0, exp1}, CreateMax(module.get()), /*static_operands=*/{}));
    474 
    475   module->AddEntryComputation(builder.Build());
    476 
    477   RunFusionAndCheckOpcodesWereFused(
    478       module.get(), {HloOpcode::kParameter, HloOpcode::kParameter,
    479                      HloOpcode::kExp, HloOpcode::kExp, HloOpcode::kMap});
    480 }
    481 
    482 TEST_F(OpcodeFusionTest, DynamicSliceWithDynamicUpdateSlice) {
    483   auto module = CreateNewModule();
    484 
    485   HloComputation::Builder builder(TestName());
    486   Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000});
    487   Shape slice_shape = ShapeUtil::MakeShape(F32, {10, 1, 1000});
    488 
    489   HloInstruction* slice =
    490       builder.AddInstruction(HloInstruction::CreateDynamicSlice(
    491           slice_shape,
    492           builder.AddInstruction(
    493               HloInstruction::CreateParameter(0, full_shape, "slice_from")),
    494           builder.AddInstruction(HloInstruction::CreateParameter(
    495               1, ShapeUtil::MakeShape(U32, {3}), "slice_indices")),
    496           /*slice_sizes=*/{10, 1, 1000}));
    497 
    498   builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
    499       full_shape,
    500       builder.AddInstruction(
    501           HloInstruction::CreateParameter(2, full_shape, "to_update")),
    502       slice,
    503       builder.AddInstruction(HloInstruction::CreateParameter(
    504           3, ShapeUtil::MakeShape(U32, {3}), "update_indices"))));
    505 
    506   module->AddEntryComputation(builder.Build());
    507   RunFusionAndCheckOpcodesWereFused(
    508       module.get(), {HloOpcode::kDynamicSlice, HloOpcode::kDynamicUpdateSlice,
    509                      HloOpcode::kParameter, HloOpcode::kParameter,
    510                      HloOpcode::kParameter, HloOpcode::kParameter});
    511 }
    512 
    513 TEST_F(OpcodeFusionTest, MessOfFusileNodes) {
    514   auto module = CreateNewModule();
    515   HloComputation::Builder builder(TestName());
    516 
    517   Shape full_shape = ShapeUtil::MakeShape(F32, {4, 100, 10, 100, 50});
    518 
    519   auto loop_idx = builder.AddInstruction(HloInstruction::CreateReshape(
    520       ShapeUtil::MakeShape(S32, {1}),
    521       builder.AddInstruction(HloInstruction::CreateParameter(
    522           0, ShapeUtil::MakeShape(S32, {}), "param0"))));
    523 
    524   auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
    525       1, ShapeUtil::MakeShape(S32, {1}), "param1"));
    526   auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate(
    527       ShapeUtil::MakeShape(S32, {5}),
    528       {loop_idx, param1, param1, param1, param1}, /*dimension=*/0));
    529 
    530   auto idx_choice = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
    531       ShapeUtil::MakeShape(S32, {1}),
    532       builder.AddInstruction(HloInstruction::CreateParameter(
    533           2, ShapeUtil::MakeShape(S32, {4}), "param2")),
    534       loop_idx,
    535       /*slice_sizes=*/{1}));
    536 
    537   PaddingConfig padding_config;
    538   padding_config.add_dimensions()->set_edge_padding_high(4);
    539   auto pad = builder.AddInstruction(HloInstruction::CreatePad(
    540       ShapeUtil::MakeShape(S32, {5}), idx_choice,
    541       builder.AddInstruction(
    542           HloInstruction::CreateConstant(Literal::CreateR0(0))),
    543       padding_config));
    544 
    545   auto slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
    546       ShapeUtil::MakeShape(F32, {1, 100, 10, 100, 50}),
    547       builder.AddInstruction(HloInstruction::CreateParameter(
    548           3, ShapeUtil::MakeShape(F32, {100, 100, 10, 100, 50}), "param3")),
    549       pad, /*slice_sizes=*/{1, 100, 10, 100, 50}));
    550 
    551   builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
    552       full_shape,
    553       builder.AddInstruction(
    554           HloInstruction::CreateParameter(4, full_shape, "param4")),
    555       slice, concat));
    556 
    557   module->AddEntryComputation(builder.Build());
    558   RunFusionAndCheckOpcodesWereFused(
    559       module.get(),
    560       {HloOpcode::kConcatenate, HloOpcode::kPad, HloOpcode::kDynamicSlice,
    561        HloOpcode::kDynamicSlice, HloOpcode::kDynamicUpdateSlice,
    562        HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter,
    563        HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter});
    564 }
    565 
    566 // Tests that we do not fuse instructions in cases where instructions in the
    567 // fusion would reuse elements from its operand due to an implicit broadcast.
    568 TEST_F(OpcodeFusionTest, ReuseViaImplicitBroadcastUnary) {
    569   Shape small_shape = ShapeUtil::MakeShape(F32, {1, 4});
    570   Shape large_shape = ShapeUtil::MakeShape(F32, {3, 4});
    571 
    572   HloComputation::Builder builder(TestName());
    573 
    574   HloInstruction* small_param =
    575       builder.AddInstruction(HloInstruction::CreateParameter(
    576           /*parameter_number=*/0, small_shape, "param"));
    577   HloInstruction* small_exp = builder.AddInstruction(
    578       HloInstruction::CreateUnary(small_shape, HloOpcode::kExp, small_param));
    579   builder.AddInstruction(
    580       HloInstruction::CreateUnary(large_shape, HloOpcode::kExp, small_exp));
    581 
    582   std::unique_ptr<HloModule> module = CreateNewModule();
    583   module->AddEntryComputation(builder.Build());
    584 
    585   auto did_fusion = CpuInstructionFusion().Run(module.get());
    586   ASSERT_TRUE(did_fusion.ok());
    587   EXPECT_FALSE(did_fusion.ValueOrDie());
    588   ASSERT_THAT(module->entry_computation()->root_instruction(),
    589               Not(op::Fusion()));
    590 }
    591 
    592 // Like ReuseViaImplicitBroadcastUnary but with a binary operation.
    593 TEST_F(OpcodeFusionTest, ReuseViaImplicitBroadcastBinary) {
    594   Shape small_shape = ShapeUtil::MakeShape(F32, {1, 4});
    595   Shape large_shape = ShapeUtil::MakeShape(F32, {3, 4});
    596 
    597   HloComputation::Builder builder(TestName());
    598 
    599   HloInstruction* small_param =
    600       builder.AddInstruction(HloInstruction::CreateParameter(
    601           /*parameter_number=*/0, small_shape, "param"));
    602   HloInstruction* large_param =
    603       builder.AddInstruction(HloInstruction::CreateParameter(
    604           /*parameter_number=*/1, large_shape, "param"));
    605   HloInstruction* small_exp = builder.AddInstruction(
    606       HloInstruction::CreateUnary(small_shape, HloOpcode::kExp, small_param));
    607 
    608   builder.AddInstruction(HloInstruction::CreateBinary(
    609       large_shape, HloOpcode::kAdd, small_exp, large_param));
    610 
    611   std::unique_ptr<HloModule> module = CreateNewModule();
    612   module->AddEntryComputation(builder.Build());
    613 
    614   auto did_fusion = CpuInstructionFusion().Run(module.get());
    615   ASSERT_TRUE(did_fusion.ok());
    616   EXPECT_FALSE(did_fusion.ValueOrDie());
    617   ASSERT_THAT(module->entry_computation()->root_instruction(),
    618               Not(op::Fusion()));
    619 }
    620 
    621 void CreateComputationForDotAddOutputFusionTest(const string& test_name,
    622                                                 HloModule* module, int m, int k,
    623                                                 int n,
    624                                                 bool add_extra_use_for_dot) {
    625   HloComputation::Builder builder(test_name);
    626 
    627   Shape dot_lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
    628   Shape dot_rhs_shape = ShapeUtil::MakeShape(F32, {k, n});
    629   Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n});
    630 
    631   auto* dot_lhs = builder.AddInstruction(
    632       HloInstruction::CreateParameter(0, dot_lhs_shape, "param0"));
    633   auto* dot_rhs = builder.AddInstruction(
    634       HloInstruction::CreateParameter(1, dot_rhs_shape, "param1"));
    635   auto* addend = builder.AddInstruction(
    636       HloInstruction::CreateParameter(2, dot_shape, "param2"));
    637 
    638   auto* dot = builder.AddInstruction(
    639       HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs));
    640   builder.AddInstruction(
    641       HloInstruction::CreateBinary(dot_shape, HloOpcode::kAdd, dot, addend));
    642 
    643   if (add_extra_use_for_dot) {
    644     builder.AddInstruction(
    645         HloInstruction::CreateOutfeed(dot_shape, dot, "no_config"));
    646   }
    647 
    648   module->AddEntryComputation(builder.Build());
    649 }
    650 
    651 TEST_F(OpcodeFusionTest, DotAddOutputFusion_1x50x19) {
    652   auto module = CreateNewModule();
    653   CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/1,
    654                                              /*k=*/50, /*n=*/19,
    655                                              /*add_extra_use_for_dot=*/false);
    656 
    657   RunFusionAndCheckOpcodesWereFused(
    658       module.get(),
    659       {HloOpcode::kDot, HloOpcode::kAdd, HloOpcode::kParameter,
    660        HloOpcode::kParameter, HloOpcode::kParameter},
    661       HloInstruction::FusionKind::kOutput);
    662 }
    663 
    664 TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1) {
    665   auto module = CreateNewModule();
    666   CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19,
    667                                              /*k=*/50, /*n=*/1,
    668                                              /*add_extra_use_for_dot=*/false);
    669 
    670   RunFusionAndCheckOpcodesWereFused(
    671       module.get(),
    672       {HloOpcode::kDot, HloOpcode::kAdd, HloOpcode::kParameter,
    673        HloOpcode::kParameter, HloOpcode::kParameter},
    674       HloInstruction::FusionKind::kOutput);
    675 }
    676 
    677 TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x19) {
    678   auto module = CreateNewModule();
    679   CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19,
    680                                              /*k=*/50, /*n=*/19,
    681                                              /*add_extra_use_for_dot=*/false);
    682 
    683   TF_ASSERT_OK_AND_ASSIGN(bool fused_something,
    684                           CpuInstructionFusion().Run(module.get()));
    685   EXPECT_FALSE(fused_something);
    686   EXPECT_THAT(module->entry_computation()->root_instruction(),
    687               Not(op::Fusion()));
    688 }
    689 
    690 TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1_multi_use) {
    691   auto module = CreateNewModule();
    692   CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19,
    693                                              /*k=*/50, /*n=*/1,
    694                                              /*add_extra_use_for_dot=*/true);
    695 
    696   TF_ASSERT_OK_AND_ASSIGN(bool fused_something,
    697                           CpuInstructionFusion().Run(module.get()));
    698   EXPECT_FALSE(fused_something);
    699   EXPECT_THAT(module->entry_computation()->root_instruction(),
    700               Not(op::Fusion()));
    701 }
    702 
    703 }  // namespace
    704 }  // namespace cpu
    705 }  // namespace xla
    706