Home | History | Annotate | Download | only in cpu
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
     17 
     18 #include <algorithm>
     19 #include <set>
     20 
     21 #include "absl/strings/str_cat.h"
     22 #include "absl/types/span.h"
     23 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
     24 #include "tensorflow/compiler/xla/service/hlo_parser.h"
     25 #include "tensorflow/compiler/xla/service/transpose_folding.h"
     26 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
     27 #include "tensorflow/compiler/xla/tests/test_utils.h"
     28 
// Short alias for the HLO opcode matchers used by the EXPECT_THAT /
// ASSERT_THAT assertions in this file (op::Fusion(), op::Dot(), ...).
namespace op = xla::testing::opcode_matchers;

namespace xla {
namespace cpu {
namespace {

// Test fixture for the dot / transpose-folding tests; it adds nothing on top
// of HloTestBase.
using InstructionFusionTest = HloTestBase;
     36 
     37 std::unique_ptr<HloInstruction> MakeDot(const Shape& shape, HloInstruction* lhs,
     38                                         HloInstruction* rhs) {
     39   DotDimensionNumbers dot_dnums;
     40   dot_dnums.add_lhs_contracting_dimensions(1);
     41   dot_dnums.add_rhs_contracting_dimensions(0);
     42   PrecisionConfig precision_config;
     43   precision_config.mutable_operand_precision()->Resize(
     44       2, PrecisionConfig::DEFAULT);
     45   return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums,
     46                                    precision_config);
     47 }
     48 
     49 TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) {
     50   HloComputation::Builder builder(TestName());
     51   HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
     52       0, ShapeUtil::MakeShape(F32, {1024, 256}), "arg0"));
     53   HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
     54       1, ShapeUtil::MakeShape(F32, {256, 1}), "arg1"));
     55 
     56   HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary(
     57       ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kExp, arg0));
     58   HloInstruction* dot = builder.AddInstruction(
     59       MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), exp0, arg1));
     60 
     61   auto module = CreateNewUnverifiedModule();
     62   auto computation = module->AddEntryComputation(builder.Build());
     63   EXPECT_EQ(dot, computation->root_instruction());
     64   EXPECT_TRUE(CpuInstructionFusion().Run(module.get()).ValueOrDie());
     65   EXPECT_THAT(computation->root_instruction(), op::Fusion());
     66 }
     67 
     68 TEST_F(InstructionFusionTest, DotOperationFusion_Basic_1) {
     69   HloComputation::Builder builder(TestName());
     70   HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
     71       0, ShapeUtil::MakeShape(F32, {1, 256}), "arg0"));
     72   HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
     73       1, ShapeUtil::MakeShape(F32, {256, 1024}), "arg1"));
     74 
     75   HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
     76       ShapeUtil::MakeShape(S32, {256, 1024}), HloOpcode::kExp, arg1));
     77   HloInstruction* dot = builder.AddInstruction(
     78       MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, exp1));
     79 
     80   auto module = CreateNewUnverifiedModule();
     81   auto computation = module->AddEntryComputation(builder.Build());
     82   EXPECT_EQ(dot, computation->root_instruction());
     83   EXPECT_TRUE(CpuInstructionFusion().Run(module.get()).ValueOrDie());
     84   EXPECT_THAT(computation->root_instruction(), op::Fusion());
     85 }
     86 
     87 TEST_F(InstructionFusionTest, DotOperationNoFusion_Bitcast) {
     88   HloComputation::Builder builder(TestName());
     89   HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
     90       0, ShapeUtil::MakeShape(F32, {2, 512, 2, 128}), "arg0"));
     91   HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
     92       1, ShapeUtil::MakeShape(F32, {256, 1}), "arg1"));
     93 
     94   HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary(
     95       ShapeUtil::MakeShape(S32, {2, 512, 2, 128}), HloOpcode::kExp, arg0));
     96   HloInstruction* bitcast0 = builder.AddInstruction(HloInstruction::CreateUnary(
     97       ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kBitcast, exp0));
     98   HloInstruction* dot = builder.AddInstruction(
     99       MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), bitcast0, arg1));
    100 
    101   auto module = CreateNewUnverifiedModule();
    102   auto computation = module->AddEntryComputation(builder.Build());
    103   EXPECT_EQ(dot, computation->root_instruction());
    104   EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie());
    105 }
    106 
    107 TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) {
    108   HloComputation::Builder builder(TestName());
    109   HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
    110       0, ShapeUtil::MakeShape(F32, {2, 512, 2, 128}), "arg0"));
    111   HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
    112       1, ShapeUtil::MakeShape(F32, {256, 1}), "arg1"));
    113 
    114   HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary(
    115       ShapeUtil::MakeShape(S32, {2, 512, 2, 128}), HloOpcode::kExp, arg0));
    116   HloInstruction* reshape0 =
    117       builder.AddInstruction(HloInstruction::CreateReshape(
    118           ShapeUtil::MakeShape(S32, {1024, 256}), exp0));
    119   HloInstruction* dot = builder.AddInstruction(
    120       MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), reshape0, arg1));
    121 
    122   auto module = CreateNewUnverifiedModule();
    123   auto computation = module->AddEntryComputation(builder.Build());
    124   EXPECT_EQ(dot, computation->root_instruction());
    125   EXPECT_TRUE(CpuInstructionFusion().Run(module.get()).ValueOrDie());
    126   EXPECT_THAT(computation->root_instruction(), op::Fusion());
    127 }
    128 
    129 TEST_F(InstructionFusionTest, DotOperationFusion_TooLarge) {
    130   HloComputation::Builder builder(TestName());
    131   HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
    132       0, ShapeUtil::MakeShape(F32, {1, 32 * 1024}), "arg0"));
    133   HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
    134       1, ShapeUtil::MakeShape(F32, {256, 32 * 1024}), "arg1"));
    135 
    136   HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
    137       ShapeUtil::MakeShape(S32, {256, 32 * 1024}), HloOpcode::kExp, arg1));
    138   HloInstruction* dot = builder.AddInstruction(
    139       MakeDot(ShapeUtil::MakeShape(F32, {1, 32 * 1024}), arg0, exp1));
    140 
    141   auto module = CreateNewUnverifiedModule();
    142   auto computation = module->AddEntryComputation(builder.Build());
    143   EXPECT_EQ(dot, computation->root_instruction());
    144   EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie());
    145   EXPECT_EQ(dot, computation->root_instruction());
    146 }
    147 
    148 TEST_F(InstructionFusionTest, DotOperationFusion_ElementReuse) {
    149   HloComputation::Builder builder(TestName());
    150   HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
    151       0, ShapeUtil::MakeShape(F32, {2, 256}), "arg0"));
    152   HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
    153       1, ShapeUtil::MakeShape(F32, {256, 1024}), "arg1"));
    154 
    155   HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
    156       ShapeUtil::MakeShape(S32, {256, 1024}), HloOpcode::kExp, arg1));
    157   HloInstruction* dot = builder.AddInstruction(
    158       MakeDot(ShapeUtil::MakeShape(F32, {2, 1024}), arg0, exp1));
    159 
    160   auto module = CreateNewUnverifiedModule();
    161   auto computation = module->AddEntryComputation(builder.Build());
    162   EXPECT_EQ(dot, computation->root_instruction());
    163   EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie());
    164   EXPECT_EQ(dot, computation->root_instruction());
    165 }
    166 
    167 TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion_RHS) {
    168   string hlo_string = R"(
    169 HloModule DotOperationFusion_TransposeFusion
    170 
    171 ENTRY DotOperationFusion_TransposeFusion {
    172   arg0 = f32[1,256] parameter(0)
    173   arg1 = f32[1024,256] parameter(1)
    174   exponential = s32[1024,256] exponential(arg1)
    175   transpose = s32[256,1024] transpose(exponential), dimensions={1,0}
    176   ROOT dot = f32[1,1024] dot(arg0, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
    177 }
    178 )";
    179 
    180   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
    181                           ParseHloString(hlo_string));
    182   HloComputation* computation = module->entry_computation();
    183 
    184   TransposeFolding transpose_folding(
    185       [](const HloInstruction& dot,
    186          const TransposeFolding::OperandIndices& candidate_operands) {
    187         return candidate_operands;
    188       },
    189       TransposeFolding::NeverFoldTranspose);
    190   TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get()));
    191   ASSERT_TRUE(changed);
    192   ASSERT_THAT(computation->root_instruction(),
    193               op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)),
    194                       /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1));
    195 }
    196 
    197 TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion_LHS) {
    198   string hlo_string = R"(
    199 HloModule DotOperationFusion_TransposeFusion
    200 
    201 ENTRY DotOperationFusion_TransposeFusion {
    202   arg0 = f32[256,1] parameter(0)
    203   arg1 = f32[256,1024] parameter(1)
    204   transpose = s32[1,256] transpose(arg0), dimensions={1,0}
    205   exponential = s32[256,1024] exponential(arg1)
    206   ROOT dot = f32[1,1024] dot(transpose, exponential), lhs_contracting_dims={1}, rhs_contracting_dims={0}
    207 }
    208 )";
    209 
    210   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
    211                           ParseHloString(hlo_string));
    212   HloComputation* computation = module->entry_computation();
    213 
    214   TransposeFolding transpose_folding(
    215       [](const HloInstruction& dot,
    216          const TransposeFolding::OperandIndices& candidate_operands) {
    217         return candidate_operands;
    218       },
    219       TransposeFolding::NeverFoldTranspose);
    220   TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get()));
    221   ASSERT_TRUE(changed);
    222   ASSERT_THAT(computation->root_instruction(),
    223               op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)),
    224                       /*lhs_contracting_dim=*/0, /*rhs_contracting_dim=*/0));
    225 }
    226 
    227 TEST_F(InstructionFusionTest,
    228        DotOperationFusion_TransposeFusion_LHS_NonDefault) {
    229   string hlo_string = R"(
    230 HloModule DotOperationFusion_TransposeFusion
    231 
    232 ENTRY DotOperationFusion_TransposeFusion {
    233   arg0 = f32[1,256] parameter(0)
    234   arg1 = f32[256,1024] parameter(1)
    235   transpose = s32[256,1] transpose(arg0), dimensions={1,0}
    236   exponential = s32[256,1024] exponential(arg1)
    237   ROOT dot = f32[1,1024] dot(transpose, exponential), lhs_contracting_dims={0}, rhs_contracting_dims={0}
    238 }
    239 )";
    240 
    241   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
    242                           ParseHloString(hlo_string));
    243   HloComputation* computation = module->entry_computation();
    244 
    245   TransposeFolding transpose_folding(
    246       [](const HloInstruction& dot,
    247          const TransposeFolding::OperandIndices& candidate_operands) {
    248         return candidate_operands;
    249       },
    250       TransposeFolding::NeverFoldTranspose);
    251   TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get()));
    252   ASSERT_TRUE(changed);
    253   ASSERT_THAT(computation->root_instruction(),
    254               op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)),
    255                       /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/0));
    256 }
    257 
    258 class OpcodeFusionTest : public InstructionFusionTest {
    259  protected:
    260   // Runs CPU instruction fusion on the given module, and tests that the result
    261   // contains a fused op at the root with exactly the given multiset of opcodes.
    262   void RunFusionAndCheckOpcodesWereFused(
    263       HloModule* module, const std::multiset<HloOpcode>& expected_opcodes,
    264       HloInstruction::FusionKind fusion_kind =
    265           HloInstruction::FusionKind::kLoop) {
    266     auto computation = module->entry_computation();
    267     auto did_fusion = CpuInstructionFusion().Run(module);
    268     ASSERT_TRUE(did_fusion.ok());
    269     EXPECT_TRUE(did_fusion.ValueOrDie());
    270 
    271     HloInstruction* root = computation->root_instruction();
    272     ASSERT_THAT(root, op::Fusion());
    273     EXPECT_EQ(root->fusion_kind(), fusion_kind);
    274 
    275     std::vector<HloOpcode> fused_opcodes(root->fused_instruction_count());
    276     std::transform(root->fused_instructions().begin(),
    277                    root->fused_instructions().end(), fused_opcodes.begin(),
    278                    [](const HloInstruction* hlo) { return hlo->opcode(); });
    279 
    280     EXPECT_EQ(
    281         std::multiset<HloOpcode>(fused_opcodes.begin(), fused_opcodes.end()),
    282         expected_opcodes);
    283   }
    284 
    285   HloComputation* CreateAdderToOne(HloModule* module) {
    286     HloComputation::Builder builder(TestName());
    287     HloInstruction* arg0 =
    288         builder.AddInstruction(HloInstruction::CreateParameter(
    289             0, ShapeUtil::MakeShape(F32, {}), "arg0"));
    290     HloInstruction* one = builder.AddInstruction(
    291         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
    292     builder.AddInstruction(HloInstruction::CreateBinary(
    293         ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, arg0, one));
    294     return module->AddEmbeddedComputation(builder.Build());
    295   }
    296 
    297   HloComputation* CreateMax(HloModule* module) {
    298     HloComputation::Builder builder(TestName());
    299     HloInstruction* arg0 =
    300         builder.AddInstruction(HloInstruction::CreateParameter(
    301             0, ShapeUtil::MakeShape(F32, {}), "arg0"));
    302     HloInstruction* arg1 =
    303         builder.AddInstruction(HloInstruction::CreateParameter(
    304             1, ShapeUtil::MakeShape(F32, {}), "arg1"));
    305     builder.AddInstruction(HloInstruction::CreateBinary(
    306         ShapeUtil::MakeShape(F32, {}), HloOpcode::kMaximum, arg0, arg1));
    307     return module->AddEmbeddedComputation(builder.Build());
    308   }
    309 };
    310 
    311 TEST_F(OpcodeFusionTest, Exponential_Reshape_Negate) {
    312   HloComputation::Builder builder(TestName());
    313   Shape param_shape = ShapeUtil::MakeShape(F32, {1, 4});
    314   Shape result_shape = ShapeUtil::MakeShape(F32, {4});
    315   HloInstruction* param0 = builder.AddInstruction(
    316       HloInstruction::CreateParameter(0, param_shape, "param"));
    317   HloInstruction* exp1 = builder.AddInstruction(
    318       HloInstruction::CreateUnary(param_shape, HloOpcode::kExp, param0));
    319   HloInstruction* reshape2 =
    320       builder.AddInstruction(HloInstruction::CreateReshape(result_shape, exp1));
    321   builder.AddInstruction(
    322       HloInstruction::CreateUnary(result_shape, HloOpcode::kNegate, reshape2));
    323 
    324   auto module = CreateNewVerifiedModule();
    325   module->AddEntryComputation(builder.Build());
    326 
    327   RunFusionAndCheckOpcodesWereFused(
    328       module.get(), {HloOpcode::kNegate, HloOpcode::kReshape, HloOpcode::kExp,
    329                      HloOpcode::kParameter});
    330 }
    331 
    332 TEST_F(OpcodeFusionTest, Broadcast_Reshape_DynamicSlice_Tanh) {
    333   HloComputation::Builder builder(TestName());
    334   Shape param_shape = ShapeUtil::MakeShape(F32, {8});
    335   Shape starts_shape = ShapeUtil::MakeShape(F32, {});
    336   Shape broadcast_shape = ShapeUtil::MakeShape(F32, {1, 8, 8});
    337   Shape reshape_shape = ShapeUtil::MakeShape(F32, {8, 8});
    338   Shape dynamic_slice_shape = ShapeUtil::MakeShape(F32, {4, 4});
    339   HloInstruction* param0 = builder.AddInstruction(
    340       HloInstruction::CreateParameter(0, param_shape, "param"));
    341   HloInstruction* param1 = builder.AddInstruction(
    342       HloInstruction::CreateParameter(1, starts_shape, "starts"));
    343   HloInstruction* param2 = builder.AddInstruction(
    344       HloInstruction::CreateParameter(2, starts_shape, "starts"));
    345   HloInstruction* broadcast2 = builder.AddInstruction(
    346       HloInstruction::CreateBroadcast(broadcast_shape, param0, {1}));
    347   HloInstruction* reshape3 = builder.AddInstruction(
    348       HloInstruction::CreateReshape(reshape_shape, broadcast2));
    349   HloInstruction* dynamic_slice4 =
    350       builder.AddInstruction(HloInstruction::CreateDynamicSlice(
    351           dynamic_slice_shape, reshape3, {param1, param2}, {4, 4}));
    352   builder.AddInstruction(HloInstruction::CreateUnary(
    353       dynamic_slice_shape, HloOpcode::kTanh, dynamic_slice4));
    354 
    355   auto module = CreateNewUnverifiedModule();
    356   module->AddEntryComputation(builder.Build());
    357 
    358   RunFusionAndCheckOpcodesWereFused(
    359       module.get(),
    360       {HloOpcode::kTanh, HloOpcode::kDynamicSlice, HloOpcode::kReshape,
    361        HloOpcode::kBroadcast, HloOpcode::kParameter, HloOpcode::kParameter,
    362        HloOpcode::kParameter});
    363 }
    364 
    365 TEST_F(OpcodeFusionTest, Broadcast_Negate) {
    366   HloComputation::Builder builder(TestName());
    367   Shape param_shape = ShapeUtil::MakeShape(F32, {8});
    368   Shape result_shape = ShapeUtil::MakeShape(F32, {8, 8});
    369   HloInstruction* param0 = builder.AddInstruction(
    370       HloInstruction::CreateParameter(0, param_shape, "param"));
    371   HloInstruction* broadcast1 = builder.AddInstruction(
    372       HloInstruction::CreateBroadcast(result_shape, param0, {1}));
    373   builder.AddInstruction(HloInstruction::CreateUnary(
    374       result_shape, HloOpcode::kNegate, broadcast1));
    375 
    376   auto module = CreateNewVerifiedModule();
    377   module->AddEntryComputation(builder.Build());
    378 
    379   RunFusionAndCheckOpcodesWereFused(
    380       module.get(),
    381       {HloOpcode::kNegate, HloOpcode::kBroadcast, HloOpcode::kParameter});
    382 }
    383 
    384 TEST_F(OpcodeFusionTest, DynamicSlice_Negate) {
    385   HloComputation::Builder builder(TestName());
    386   Shape param_shape = ShapeUtil::MakeShape(F32, {4});
    387   Shape slice_shape = ShapeUtil::MakeShape(F32, {});
    388   Shape result_shape = ShapeUtil::MakeShape(F32, {2});
    389   HloInstruction* param0 = builder.AddInstruction(
    390       HloInstruction::CreateParameter(0, param_shape, "param"));
    391   HloInstruction* param1 = builder.AddInstruction(
    392       HloInstruction::CreateParameter(1, slice_shape, "starts"));
    393   HloInstruction* dynamic_slice2 = builder.AddInstruction(
    394       HloInstruction::CreateDynamicSlice(result_shape, param0, {param1}, {2}));
    395   builder.AddInstruction(HloInstruction::CreateUnary(
    396       result_shape, HloOpcode::kNegate, dynamic_slice2));
    397 
    398   auto module = CreateNewUnverifiedModule();
    399   module->AddEntryComputation(builder.Build());
    400 
    401   RunFusionAndCheckOpcodesWereFused(
    402       module.get(), {HloOpcode::kNegate, HloOpcode::kDynamicSlice,
    403                      HloOpcode::kParameter, HloOpcode::kParameter});
    404 }
    405 
    406 TEST_F(OpcodeFusionTest, Exponential_Negate) {
    407   HloComputation::Builder builder(TestName());
    408   Shape param_shape = ShapeUtil::MakeShape(F32, {4});
    409   HloInstruction* param0 = builder.AddInstruction(
    410       HloInstruction::CreateParameter(0, param_shape, "param"));
    411   HloInstruction* exp1 = builder.AddInstruction(
    412       HloInstruction::CreateUnary(param_shape, HloOpcode::kExp, param0));
    413   builder.AddInstruction(
    414       HloInstruction::CreateUnary(param_shape, HloOpcode::kNegate, exp1));
    415 
    416   auto module = CreateNewVerifiedModule();
    417   module->AddEntryComputation(builder.Build());
    418 
    419   RunFusionAndCheckOpcodesWereFused(
    420       module.get(),
    421       {HloOpcode::kNegate, HloOpcode::kExp, HloOpcode::kParameter});
    422 }
    423 
    424 TEST_F(OpcodeFusionTest, Reshape_Negate) {
    425   HloComputation::Builder builder(TestName());
    426   Shape param_shape = ShapeUtil::MakeShape(F32, {4, 4});
    427   Shape result_shape = ShapeUtil::MakeShape(F32, {16});
    428   HloInstruction* param0 = builder.AddInstruction(
    429       HloInstruction::CreateParameter(0, param_shape, "param"));
    430   HloInstruction* reshape1 = builder.AddInstruction(
    431       HloInstruction::CreateReshape(result_shape, param0));
    432   builder.AddInstruction(
    433       HloInstruction::CreateUnary(result_shape, HloOpcode::kNegate, reshape1));
    434 
    435   auto module = CreateNewVerifiedModule();
    436   module->AddEntryComputation(builder.Build());
    437 
    438   RunFusionAndCheckOpcodesWereFused(
    439       module.get(),
    440       {HloOpcode::kNegate, HloOpcode::kReshape, HloOpcode::kParameter});
    441 }
    442 
    443 TEST_F(OpcodeFusionTest, Reverse_Negate) {
    444   HloComputation::Builder builder(TestName());
    445   Shape param_shape = ShapeUtil::MakeShape(F32, {8});
    446   HloInstruction* param0 = builder.AddInstruction(
    447       HloInstruction::CreateParameter(0, param_shape, "param"));
    448   HloInstruction* reverse1 = builder.AddInstruction(
    449       HloInstruction::CreateReverse(param_shape, param0, {0}));
    450   builder.AddInstruction(
    451       HloInstruction::CreateUnary(param_shape, HloOpcode::kNegate, reverse1));
    452 
    453   auto module = CreateNewVerifiedModule();
    454   module->AddEntryComputation(builder.Build());
    455 
    456   RunFusionAndCheckOpcodesWereFused(
    457       module.get(),
    458       {HloOpcode::kNegate, HloOpcode::kReverse, HloOpcode::kParameter});
    459 }
    460 
    461 TEST_F(OpcodeFusionTest, Slice_Negate) {
    462   HloComputation::Builder builder(TestName());
    463   Shape param_shape = ShapeUtil::MakeShape(F32, {4});
    464   Shape slice_shape = ShapeUtil::MakeShape(F32, {2});
    465   HloInstruction* param0 = builder.AddInstruction(
    466       HloInstruction::CreateParameter(0, param_shape, "param"));
    467   HloInstruction* slice1 = builder.AddInstruction(
    468       HloInstruction::CreateSlice(slice_shape, param0, {0}, {4}, {2}));
    469   builder.AddInstruction(HloInstruction::CreateUnary(
    470       ShapeUtil::MakeShape(S32, {2}), HloOpcode::kNegate, slice1));
    471 
    472   auto module = CreateNewUnverifiedModule();
    473   module->AddEntryComputation(builder.Build());
    474 
    475   RunFusionAndCheckOpcodesWereFused(
    476       module.get(),
    477       {HloOpcode::kNegate, HloOpcode::kSlice, HloOpcode::kParameter});
    478 }
    479 
    480 TEST_F(OpcodeFusionTest, Exponential_Transpose_Negate) {
    481   HloComputation::Builder builder(TestName());
    482   Shape param_shape = ShapeUtil::MakeShape(F32, {3, 4});
    483   Shape result_shape = ShapeUtil::MakeShape(F32, {4, 3});
    484   HloInstruction* param0 = builder.AddInstruction(
    485       HloInstruction::CreateParameter(0, param_shape, "param"));
    486   // InstructionFusion::ShouldFuse() precludes fusing a transpose whose operand
    487   // is a parameter, so create an operand between the parameter and transpose.
    488   HloInstruction* exp1 = builder.AddInstruction(
    489       HloInstruction::CreateUnary(param_shape, HloOpcode::kExp, param0));
    490   HloInstruction* transpose2 = builder.AddInstruction(
    491       HloInstruction::CreateTranspose(result_shape, exp1, {1, 0}));
    492   builder.AddInstruction(HloInstruction::CreateUnary(
    493       result_shape, HloOpcode::kNegate, transpose2));
    494 
    495   auto module = CreateNewVerifiedModule();
    496   module->AddEntryComputation(builder.Build());
    497 
    498   RunFusionAndCheckOpcodesWereFused(
    499       module.get(), {HloOpcode::kNegate, HloOpcode::kTranspose, HloOpcode::kExp,
    500                      HloOpcode::kParameter});
    501 }
    502 
    503 TEST_F(OpcodeFusionTest, UnaryMapOfExp) {
    504   auto module = CreateNewVerifiedModule();
    505 
    506   HloComputation::Builder builder(TestName());
    507   Shape shape = ShapeUtil::MakeShape(F32, {3, 4});
    508   HloInstruction* param0 = builder.AddInstruction(
    509       HloInstruction::CreateParameter(0, shape, "param"));
    510 
    511   HloInstruction* exp = builder.AddInstruction(
    512       HloInstruction::CreateUnary(shape, HloOpcode::kExp, param0));
    513   builder.AddInstruction(
    514       HloInstruction::CreateMap(shape, {exp}, CreateAdderToOne(module.get())));
    515 
    516   module->AddEntryComputation(builder.Build());
    517 
    518   RunFusionAndCheckOpcodesWereFused(
    519       module.get(), {HloOpcode::kParameter, HloOpcode::kExp, HloOpcode::kMap});
    520 }
    521 
    522 TEST_F(OpcodeFusionTest, BinaryMapOfExps) {
    523   auto module = CreateNewVerifiedModule();
    524 
    525   HloComputation::Builder builder(TestName());
    526   Shape shape = ShapeUtil::MakeShape(F32, {3, 4});
    527   HloInstruction* param0 = builder.AddInstruction(
    528       HloInstruction::CreateParameter(0, shape, "param"));
    529   HloInstruction* param1 = builder.AddInstruction(
    530       HloInstruction::CreateParameter(1, shape, "param"));
    531 
    532   HloInstruction* exp0 = builder.AddInstruction(
    533       HloInstruction::CreateUnary(shape, HloOpcode::kExp, param0));
    534   HloInstruction* exp1 = builder.AddInstruction(
    535       HloInstruction::CreateUnary(shape, HloOpcode::kExp, param1));
    536 
    537   builder.AddInstruction(
    538       HloInstruction::CreateMap(shape, {exp0, exp1}, CreateMax(module.get())));
    539 
    540   module->AddEntryComputation(builder.Build());
    541 
    542   RunFusionAndCheckOpcodesWereFused(
    543       module.get(), {HloOpcode::kParameter, HloOpcode::kParameter,
    544                      HloOpcode::kExp, HloOpcode::kExp, HloOpcode::kMap});
    545 }
    546 
    547 TEST_F(OpcodeFusionTest, DynamicSliceWithDynamicUpdateSlice) {
    548   auto module = CreateNewVerifiedModule();
    549 
    550   HloComputation::Builder builder(TestName());
    551   Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000});
    552   Shape slice_shape = ShapeUtil::MakeShape(F32, {10, 1, 1000});
    553 
    554   std::vector<HloInstruction*> slice_indices, update_indices;
    555   for (int i = 0; i < 3; ++i) {
    556     slice_indices.push_back(
    557         builder.AddInstruction(HloInstruction::CreateParameter(
    558             1 + i, ShapeUtil::MakeShape(U32, {}), "slice_indices")));
    559     update_indices.push_back(
    560         builder.AddInstruction(HloInstruction::CreateParameter(
    561             5 + i, ShapeUtil::MakeShape(U32, {}), "update_indices")));
    562   }
    563   HloInstruction* slice =
    564       builder.AddInstruction(HloInstruction::CreateDynamicSlice(
    565           slice_shape,
    566           builder.AddInstruction(
    567               HloInstruction::CreateParameter(0, full_shape, "slice_from")),
    568           slice_indices,
    569           /*slice_sizes=*/{10, 1, 1000}));
    570 
    571   builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
    572       full_shape,
    573       builder.AddInstruction(
    574           HloInstruction::CreateParameter(4, full_shape, "to_update")),
    575       slice, update_indices));
    576 
    577   module->AddEntryComputation(builder.Build());
    578   RunFusionAndCheckOpcodesWereFused(
    579       module.get(),
    580       {HloOpcode::kDynamicSlice, HloOpcode::kDynamicUpdateSlice,
    581        HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter,
    582        HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter,
    583        HloOpcode::kParameter, HloOpcode::kParameter});
    584 }
    585 
    586 TEST_F(OpcodeFusionTest, MessOfFusibleNodes) {
    587   auto module = CreateNewVerifiedModule();
    588   HloComputation::Builder builder(TestName());
    589 
    590   Shape full_shape = ShapeUtil::MakeShape(F32, {4, 100, 10, 100, 50});
    591 
    592   auto loop_idx = builder.AddInstruction(HloInstruction::CreateParameter(
    593       0, ShapeUtil::MakeShape(S32, {}), "param0"));
    594   auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
    595       1, ShapeUtil::MakeShape(S32, {}), "param1"));
    596 
    597   auto idx_choice = builder.AddInstruction(HloInstruction::CreateReshape(
    598       ShapeUtil::MakeShape(S32, {}),
    599       builder.AddInstruction(HloInstruction::CreateDynamicSlice(
    600           ShapeUtil::MakeShape(S32, {1}),
    601           builder.AddInstruction(HloInstruction::CreateParameter(
    602               2, ShapeUtil::MakeShape(S32, {4}), "param2")),
    603           {loop_idx},
    604           /*slice_sizes=*/{1}))));
    605   auto zero = builder.AddInstruction(
    606       HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));
    607 
    608   auto slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
    609       ShapeUtil::MakeShape(F32, {1, 100, 10, 100, 50}),
    610       builder.AddInstruction(HloInstruction::CreateParameter(
    611           3, ShapeUtil::MakeShape(F32, {100, 100, 10, 100, 50}), "param3")),
    612       {idx_choice, zero, zero, zero, zero},
    613       /*slice_sizes=*/{1, 100, 10, 100, 50}));
    614 
    615   builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
    616       full_shape,
    617       builder.AddInstruction(
    618           HloInstruction::CreateParameter(4, full_shape, "param4")),
    619       slice, {loop_idx, param1, param1, param1, param1}));
    620 
    621   module->AddEntryComputation(builder.Build());
    622   RunFusionAndCheckOpcodesWereFused(
    623       module.get(),
    624       {HloOpcode::kDynamicSlice, HloOpcode::kDynamicSlice,
    625        HloOpcode::kDynamicUpdateSlice, HloOpcode::kReshape,
    626        HloOpcode::kConstant, HloOpcode::kParameter, HloOpcode::kParameter,
    627        HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter});
    628 }
    629 
    630 void CreateComputationForDotAddOutputFusionTest(const string& test_name,
    631                                                 HloModule* module, int m, int k,
    632                                                 int n,
    633                                                 bool add_extra_use_for_dot) {
    634   HloComputation::Builder builder(test_name);
    635 
    636   Shape dot_lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
    637   Shape dot_rhs_shape = ShapeUtil::MakeShape(F32, {k, n});
    638   Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n});
    639 
    640   auto* dot_lhs = builder.AddInstruction(
    641       HloInstruction::CreateParameter(0, dot_lhs_shape, "param0"));
    642   auto* dot_rhs = builder.AddInstruction(
    643       HloInstruction::CreateParameter(1, dot_rhs_shape, "param1"));
    644   auto* addend = builder.AddInstruction(
    645       HloInstruction::CreateParameter(2, dot_shape, "param2"));
    646 
    647   auto* dot =
    648       builder.AddInstruction(CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs));
    649   builder.AddInstruction(
    650       HloInstruction::CreateBinary(dot_shape, HloOpcode::kAdd, dot, addend));
    651 
    652   if (add_extra_use_for_dot) {
    653     auto* token = builder.AddInstruction(HloInstruction::CreateToken());
    654     builder.AddInstruction(
    655         HloInstruction::CreateOutfeed(dot_shape, dot, token, "no_config"));
    656   }
    657 
    658   module->AddEntryComputation(builder.Build());
    659 }
    660 
    661 TEST_F(OpcodeFusionTest, DotAddOutputFusion_1x50x19) {
    662   auto module = CreateNewVerifiedModule();
    663   CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/1,
    664                                              /*k=*/50, /*n=*/19,
    665                                              /*add_extra_use_for_dot=*/false);
    666 
    667   RunFusionAndCheckOpcodesWereFused(
    668       module.get(),
    669       {HloOpcode::kDot, HloOpcode::kAdd, HloOpcode::kParameter,
    670        HloOpcode::kParameter, HloOpcode::kParameter},
    671       HloInstruction::FusionKind::kOutput);
    672 }
    673 
    674 TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1) {
    675   auto module = CreateNewVerifiedModule();
    676   CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19,
    677                                              /*k=*/50, /*n=*/1,
    678                                              /*add_extra_use_for_dot=*/false);
    679 
    680   RunFusionAndCheckOpcodesWereFused(
    681       module.get(),
    682       {HloOpcode::kDot, HloOpcode::kAdd, HloOpcode::kParameter,
    683        HloOpcode::kParameter, HloOpcode::kParameter},
    684       HloInstruction::FusionKind::kOutput);
    685 }
    686 
    687 TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x19) {
    688   auto module = CreateNewVerifiedModule();
    689   CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19,
    690                                              /*k=*/50, /*n=*/19,
    691                                              /*add_extra_use_for_dot=*/false);
    692 
    693   TF_ASSERT_OK_AND_ASSIGN(bool fused_something,
    694                           CpuInstructionFusion().Run(module.get()));
    695   EXPECT_FALSE(fused_something);
    696   EXPECT_THAT(module->entry_computation()->root_instruction(),
    697               Not(op::Fusion()));
    698 }
    699 
    700 TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1_multi_use) {
    701   auto module = CreateNewVerifiedModule();
    702   CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19,
    703                                              /*k=*/50, /*n=*/1,
    704                                              /*add_extra_use_for_dot=*/true);
    705 
    706   TF_ASSERT_OK_AND_ASSIGN(bool fused_something,
    707                           CpuInstructionFusion().Run(module.get()));
    708   EXPECT_FALSE(fused_something);
    709   EXPECT_THAT(module->entry_computation()->root_instruction(),
    710               Not(op::Fusion()));
    711 }
    712 
    713 TEST_F(InstructionFusionTest,
    714        DotOperationFusion_DontOutputFuseDuplicateOperands) {
    715   absl::string_view module_string = R"(
    716 HloModule module
    717 
    718 ENTRY main {
    719   a = f32[50,60]{1,0} parameter(0)
    720   b = f32[60,1]{1,0} parameter(1)
    721   c = f32[50,1]{1,0} dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
    722   ROOT d = f32[50,1]{1,0} add(c, c)
    723 }
    724 )";
    725 
    726   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
    727                           ParseAndReturnVerifiedModule(module_string));
    728   TF_ASSERT_OK_AND_ASSIGN(bool fused_something,
    729                           CpuInstructionFusion().Run(module.get()));
    730   EXPECT_FALSE(fused_something);
    731   EXPECT_THAT(module->entry_computation()->root_instruction(),
    732               Not(op::Fusion()));
    733 }
    734 
    735 struct GatherLoopFusionTestSpec {
    736   string test_name;
    737   string hlo_computation_text;
    738 
    739   static string Name(
    740       const ::testing::TestParamInfo<GatherLoopFusionTestSpec>& info) {
    741     return info.param.test_name;
    742   }
    743 };
    744 
    745 class GatherLoopFusionTest
    746     : public OpcodeFusionTest,
    747       public ::testing::WithParamInterface<GatherLoopFusionTestSpec> {};
    748 
    749 TEST_P(GatherLoopFusionTest, GatherLoopFusion) {
    750   const GatherLoopFusionTestSpec& spec = GetParam();
    751   string hlo_string = absl::StrCat("HloModule ", spec.test_name, "\n\n",
    752                                    spec.hlo_computation_text);
    753   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
    754                           ParseHloString(hlo_string));
    755 
    756   RunFusionAndCheckOpcodesWereFused(
    757       module.get(),
    758       {HloOpcode::kGather, HloOpcode::kAdd, HloOpcode::kBroadcast,
    759        HloOpcode::kConstant, HloOpcode::kParameter, HloOpcode::kParameter});
    760 }
    761 
    762 std::vector<GatherLoopFusionTestSpec> GetGatherLoopFusionTestSpecs() {
    763   std::vector<GatherLoopFusionTestSpec> result;
    764 
    765   result.push_back({"FusedTensorFlowGatherV2", R"(
    766 ENTRY main {
    767   operand = s32[3,3] parameter(0)
    768   indices = s32[2] parameter(1)
    769   gather = s32[3,2] gather(operand, indices),
    770       offset_dims={0},
    771       collapsed_slice_dims={1},
    772       start_index_map={1},
    773       index_vector_dim=1,
    774       slice_sizes={3, 1}
    775   one = s32[] constant(1)
    776   one_broadcasted = s32[3,2] broadcast(one), dimensions={}
    777   ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted)
    778 }
    779 )"});
    780 
    781   result.push_back({"FusedTensorFlowGatherMultipleBatchDims", R"(
    782 ENTRY main {
    783   operand = s32[3,3] parameter(0)
    784   indices = s32[2,2] parameter(1)
    785   gather = s32[2,3,2] gather(operand, indices),
    786       offset_dims={1},
    787       collapsed_slice_dims={1},
    788       start_index_map={1},
    789       index_vector_dim=2,
    790       slice_sizes={3, 1}
    791   one = s32[] constant(1)
    792   one_broadcasted = s32[2,3,2] broadcast(one), dimensions={}
    793   ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted)
    794 }
    795 )"});
    796 
    797   result.push_back({"FusedTensorFlowGatherNdMultipleBatchDims", R"(
    798 ENTRY main {
    799   operand = s32[3,3] parameter(0)
    800   indices = s32[2,2,2] parameter(1)
    801   gather = s32[2,2] gather(operand, indices),
    802       offset_dims={},
    803       collapsed_slice_dims={0,1},
    804       start_index_map={0,1},
    805       index_vector_dim=2,
    806       slice_sizes={1, 1}
    807   one = s32[] constant(1)
    808   one_broadcasted = s32[2,2] broadcast(one), dimensions={}
    809   ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
    810 }
    811 )"});
    812 
    813   result.push_back({"FusedTensorFlowGatherNd_0", R"(
    814 ENTRY main {
    815   operand = s32[3,3,2] parameter(0)
    816   indices = s32[2,2] parameter(1)
    817   gather = s32[2,2] gather(operand, indices),
    818       offset_dims={1},
    819       collapsed_slice_dims={0,1},
    820       start_index_map={0,1},
    821       index_vector_dim=1,
    822       slice_sizes={1,1,2}
    823   one = s32[] constant(1)
    824   one_broadcasted = s32[2,2] broadcast(one), dimensions={}
    825   ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
    826 }
    827 )"});
    828 
    829   result.push_back({"FusedTensorFlowGatherNd_1", R"(
    830 ENTRY main {
    831   operand = s32[3,3,2] parameter(0)
    832   indices = s32[2,2] parameter(1)
    833   gather = s32[2,2] gather(operand, indices),
    834       offset_dims={1},
    835       collapsed_slice_dims={0,1},
    836       start_index_map={0,1},
    837       index_vector_dim=0,
    838       slice_sizes={1,1,2}
    839   one = s32[] constant(1)
    840   one_broadcasted = s32[2,2] broadcast(one), dimensions={}
    841   ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
    842 }
    843 )"});
    844 
    845   result.push_back({"FusedDynamicSlice", R"(
    846 ENTRY main {
    847   operand = s32[3,3] parameter(0)
    848   indices = s32[2] parameter(1)
    849   gather = s32[1,1] gather(operand, indices),
    850       offset_dims={0,1},
    851       collapsed_slice_dims={},
    852       start_index_map={0,1},
    853       index_vector_dim=0,
    854       slice_sizes={1,1}
    855   one = s32[] constant(1)
    856   one_broadcasted = s32[1,1] broadcast(one), dimensions={}
    857   ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted)
    858 }
    859 )"});
    860 
    861   result.push_back({"FusedBatchDynamicSlice", R"(
    862 ENTRY main {
    863   operand = s32[3,3] parameter(0)
    864   indices = s32[2,2] parameter(1)
    865   gather = s32[2,1,1] gather(operand, indices),
    866       offset_dims={1,2},
    867       collapsed_slice_dims={},
    868       start_index_map={0,1},
    869       index_vector_dim=0,
    870       slice_sizes={1,1}
    871   one = s32[] constant(1)
    872   one_broadcasted = s32[2,1,1] broadcast(one), dimensions={}
    873   ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted)
    874 }
    875 )"});
    876 
    877   return result;
    878 }
    879 
    880 INSTANTIATE_TEST_SUITE_P(GatherLoopFusionTestInstantiation,
    881                          GatherLoopFusionTest,
    882                          ::testing::ValuesIn(GetGatherLoopFusionTestSpecs()),
    883                          GatherLoopFusionTestSpec::Name);
    884 }  // namespace
    885 }  // namespace cpu
    886 }  // namespace xla
    887