1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h" 17 18 #include <algorithm> 19 #include <set> 20 21 #include "absl/strings/str_cat.h" 22 #include "absl/types/span.h" 23 #include "tensorflow/compiler/xla/service/hlo_matchers.h" 24 #include "tensorflow/compiler/xla/service/hlo_parser.h" 25 #include "tensorflow/compiler/xla/service/transpose_folding.h" 26 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 27 #include "tensorflow/compiler/xla/tests/test_utils.h" 28 29 namespace op = xla::testing::opcode_matchers; 30 31 namespace xla { 32 namespace cpu { 33 namespace { 34 35 using InstructionFusionTest = HloTestBase; 36 37 std::unique_ptr<HloInstruction> MakeDot(const Shape& shape, HloInstruction* lhs, 38 HloInstruction* rhs) { 39 DotDimensionNumbers dot_dnums; 40 dot_dnums.add_lhs_contracting_dimensions(1); 41 dot_dnums.add_rhs_contracting_dimensions(0); 42 PrecisionConfig precision_config; 43 precision_config.mutable_operand_precision()->Resize( 44 2, PrecisionConfig::DEFAULT); 45 return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums, 46 precision_config); 47 } 48 49 TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) { 50 HloComputation::Builder builder(TestName()); 51 HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( 52 0, ShapeUtil::MakeShape(F32, {1024, 256}), "arg0")); 53 HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter( 54 1, ShapeUtil::MakeShape(F32, {256, 1}), "arg1")); 55 56 HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary( 57 ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kExp, arg0)); 58 HloInstruction* dot = builder.AddInstruction( 59 MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), exp0, arg1)); 60 61 auto module = CreateNewUnverifiedModule(); 62 auto computation = module->AddEntryComputation(builder.Build()); 63 EXPECT_EQ(dot, computation->root_instruction()); 64 EXPECT_TRUE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); 65 EXPECT_THAT(computation->root_instruction(), op::Fusion()); 66 } 67 68 TEST_F(InstructionFusionTest, DotOperationFusion_Basic_1) { 69 HloComputation::Builder builder(TestName()); 70 HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( 71 0, ShapeUtil::MakeShape(F32, {1, 256}), "arg0")); 72 HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter( 73 1, ShapeUtil::MakeShape(F32, {256, 1024}), "arg1")); 74 75 HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( 76 ShapeUtil::MakeShape(S32, {256, 1024}), HloOpcode::kExp, arg1)); 77 HloInstruction* dot = builder.AddInstruction( 78 MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, exp1)); 79 80 auto module = CreateNewUnverifiedModule(); 81 auto computation = module->AddEntryComputation(builder.Build()); 82 EXPECT_EQ(dot, computation->root_instruction()); 83 EXPECT_TRUE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); 84 EXPECT_THAT(computation->root_instruction(), op::Fusion()); 85 } 86 87 TEST_F(InstructionFusionTest, DotOperationNoFusion_Bitcast) { 88 HloComputation::Builder builder(TestName()); 89 HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( 90 0, ShapeUtil::MakeShape(F32, {2, 512, 2, 128}), "arg0")); 91 HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter( 92 1, ShapeUtil::MakeShape(F32, {256, 1}), "arg1")); 93 94 HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary( 95 ShapeUtil::MakeShape(S32, {2, 512, 2, 128}), HloOpcode::kExp, arg0)); 96 HloInstruction* bitcast0 = builder.AddInstruction(HloInstruction::CreateUnary( 97 ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kBitcast, exp0)); 98 HloInstruction* dot = builder.AddInstruction( 99 MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), bitcast0, arg1)); 100 101 auto module = CreateNewUnverifiedModule(); 102 auto computation = module->AddEntryComputation(builder.Build()); 103 EXPECT_EQ(dot, computation->root_instruction()); 104 EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); 105 } 106 107 TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) { 108 HloComputation::Builder builder(TestName()); 109 HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( 110 0, ShapeUtil::MakeShape(F32, {2, 512, 2, 128}), "arg0")); 111 HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter( 112 1, ShapeUtil::MakeShape(F32, {256, 1}), "arg1")); 113 114 HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary( 115 ShapeUtil::MakeShape(S32, {2, 512, 2, 128}), HloOpcode::kExp, arg0)); 116 HloInstruction* reshape0 = 117 builder.AddInstruction(HloInstruction::CreateReshape( 118 ShapeUtil::MakeShape(S32, {1024, 256}), exp0)); 119 HloInstruction* dot = builder.AddInstruction( 120 MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), reshape0, arg1)); 121 122 auto module = CreateNewUnverifiedModule(); 123 auto computation = module->AddEntryComputation(builder.Build()); 124 EXPECT_EQ(dot, computation->root_instruction()); 125 EXPECT_TRUE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); 126 EXPECT_THAT(computation->root_instruction(), op::Fusion()); 127 } 128 129 TEST_F(InstructionFusionTest, DotOperationFusion_TooLarge) { 130 HloComputation::Builder builder(TestName()); 131 HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( 132 0, ShapeUtil::MakeShape(F32, {1, 32 * 1024}), "arg0")); 133 HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter( 134 1, ShapeUtil::MakeShape(F32, {256, 32 * 1024}), "arg1")); 135 136 HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( 137 ShapeUtil::MakeShape(S32, {256, 32 * 1024}), HloOpcode::kExp, arg1)); 138 HloInstruction* dot = builder.AddInstruction( 139 MakeDot(ShapeUtil::MakeShape(F32, {1, 32 * 1024}), arg0, exp1)); 140 141 auto module = CreateNewUnverifiedModule(); 142 auto computation = module->AddEntryComputation(builder.Build()); 143 EXPECT_EQ(dot, computation->root_instruction()); 144 EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); 145 EXPECT_EQ(dot, computation->root_instruction()); 146 } 147 148 TEST_F(InstructionFusionTest, DotOperationFusion_ElementReuse) { 149 HloComputation::Builder builder(TestName()); 150 HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( 151 0, ShapeUtil::MakeShape(F32, {2, 256}), "arg0")); 152 HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter( 153 1, ShapeUtil::MakeShape(F32, {256, 1024}), "arg1")); 154 155 HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( 156 ShapeUtil::MakeShape(S32, {256, 1024}), HloOpcode::kExp, arg1)); 157 HloInstruction* dot = builder.AddInstruction( 158 MakeDot(ShapeUtil::MakeShape(F32, {2, 1024}), arg0, exp1)); 159 160 auto module = CreateNewUnverifiedModule(); 161 auto computation = module->AddEntryComputation(builder.Build()); 162 EXPECT_EQ(dot, computation->root_instruction()); 163 EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); 164 EXPECT_EQ(dot, computation->root_instruction()); 165 } 166 167 TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion_RHS) { 168 string hlo_string = R"( 169 HloModule DotOperationFusion_TransposeFusion 170 171 ENTRY DotOperationFusion_TransposeFusion { 172 arg0 = f32[1,256] parameter(0) 173 arg1 = f32[1024,256] parameter(1) 174 exponential = s32[1024,256] exponential(arg1) 175 transpose = s32[256,1024] transpose(exponential), dimensions={1,0} 176 ROOT dot = f32[1,1024] dot(arg0, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} 177 } 178 )"; 179 180 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, 181 ParseHloString(hlo_string)); 182 HloComputation* computation = module->entry_computation(); 183 184 TransposeFolding transpose_folding( 185 [](const HloInstruction& dot, 186 const TransposeFolding::OperandIndices& candidate_operands) { 187 return candidate_operands; 188 }, 189 TransposeFolding::NeverFoldTranspose); 190 TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get())); 191 ASSERT_TRUE(changed); 192 ASSERT_THAT(computation->root_instruction(), 193 op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)), 194 /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1)); 195 } 196 197 TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion_LHS) { 198 string hlo_string = R"( 199 HloModule DotOperationFusion_TransposeFusion 200 201 ENTRY DotOperationFusion_TransposeFusion { 202 arg0 = f32[256,1] parameter(0) 203 arg1 = f32[256,1024] parameter(1) 204 transpose = s32[1,256] transpose(arg0), dimensions={1,0} 205 exponential = s32[256,1024] exponential(arg1) 206 ROOT dot = f32[1,1024] dot(transpose, exponential), lhs_contracting_dims={1}, rhs_contracting_dims={0} 207 } 208 )"; 209 210 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, 211 ParseHloString(hlo_string)); 212 HloComputation* computation = module->entry_computation(); 213 214 TransposeFolding transpose_folding( 215 [](const HloInstruction& dot, 216 const TransposeFolding::OperandIndices& candidate_operands) { 217 return candidate_operands; 218 }, 219 TransposeFolding::NeverFoldTranspose); 220 TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get())); 221 ASSERT_TRUE(changed); 222 ASSERT_THAT(computation->root_instruction(), 223 op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)), 224 /*lhs_contracting_dim=*/0, /*rhs_contracting_dim=*/0)); 225 } 226 227 TEST_F(InstructionFusionTest, 228 DotOperationFusion_TransposeFusion_LHS_NonDefault) { 229 string hlo_string = R"( 230 HloModule DotOperationFusion_TransposeFusion 231 232 ENTRY DotOperationFusion_TransposeFusion { 233 arg0 = f32[1,256] parameter(0) 234 arg1 = f32[256,1024] parameter(1) 235 transpose = s32[256,1] transpose(arg0), dimensions={1,0} 236 exponential = s32[256,1024] exponential(arg1) 237 ROOT dot = f32[1,1024] dot(transpose, exponential), lhs_contracting_dims={0}, rhs_contracting_dims={0} 238 } 239 )"; 240 241 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, 242 ParseHloString(hlo_string)); 243 HloComputation* computation = module->entry_computation(); 244 245 TransposeFolding transpose_folding( 246 [](const HloInstruction& dot, 247 const TransposeFolding::OperandIndices& candidate_operands) { 248 return candidate_operands; 249 }, 250 TransposeFolding::NeverFoldTranspose); 251 TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get())); 252 ASSERT_TRUE(changed); 253 ASSERT_THAT(computation->root_instruction(), 254 op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)), 255 /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/0)); 256 } 257 258 class OpcodeFusionTest : public InstructionFusionTest { 259 protected: 260 // Runs CPU instruction fusion on the given module, and tests that the result 261 // contains a fused op at the root with exactly the given multiset of opcodes. 262 void RunFusionAndCheckOpcodesWereFused( 263 HloModule* module, const std::multiset<HloOpcode>& expected_opcodes, 264 HloInstruction::FusionKind fusion_kind = 265 HloInstruction::FusionKind::kLoop) { 266 auto computation = module->entry_computation(); 267 auto did_fusion = CpuInstructionFusion().Run(module); 268 ASSERT_TRUE(did_fusion.ok()); 269 EXPECT_TRUE(did_fusion.ValueOrDie()); 270 271 HloInstruction* root = computation->root_instruction(); 272 ASSERT_THAT(root, op::Fusion()); 273 EXPECT_EQ(root->fusion_kind(), fusion_kind); 274 275 std::vector<HloOpcode> fused_opcodes(root->fused_instruction_count()); 276 std::transform(root->fused_instructions().begin(), 277 root->fused_instructions().end(), fused_opcodes.begin(), 278 [](const HloInstruction* hlo) { return hlo->opcode(); }); 279 280 EXPECT_EQ( 281 std::multiset<HloOpcode>(fused_opcodes.begin(), fused_opcodes.end()), 282 expected_opcodes); 283 } 284 285 HloComputation* CreateAdderToOne(HloModule* module) { 286 HloComputation::Builder builder(TestName()); 287 HloInstruction* arg0 = 288 builder.AddInstruction(HloInstruction::CreateParameter( 289 0, ShapeUtil::MakeShape(F32, {}), "arg0")); 290 HloInstruction* one = builder.AddInstruction( 291 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); 292 builder.AddInstruction(HloInstruction::CreateBinary( 293 ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, arg0, one)); 294 return module->AddEmbeddedComputation(builder.Build()); 295 } 296 297 HloComputation* CreateMax(HloModule* module) { 298 HloComputation::Builder builder(TestName()); 299 HloInstruction* arg0 = 300 builder.AddInstruction(HloInstruction::CreateParameter( 301 0, ShapeUtil::MakeShape(F32, {}), "arg0")); 302 HloInstruction* arg1 = 303 builder.AddInstruction(HloInstruction::CreateParameter( 304 1, ShapeUtil::MakeShape(F32, {}), "arg1")); 305 builder.AddInstruction(HloInstruction::CreateBinary( 306 ShapeUtil::MakeShape(F32, {}), HloOpcode::kMaximum, arg0, arg1)); 307 return module->AddEmbeddedComputation(builder.Build()); 308 } 309 }; 310 311 TEST_F(OpcodeFusionTest, Exponential_Reshape_Negate) { 312 HloComputation::Builder builder(TestName()); 313 Shape param_shape = ShapeUtil::MakeShape(F32, {1, 4}); 314 Shape result_shape = ShapeUtil::MakeShape(F32, {4}); 315 HloInstruction* param0 = builder.AddInstruction( 316 HloInstruction::CreateParameter(0, param_shape, "param")); 317 HloInstruction* exp1 = builder.AddInstruction( 318 HloInstruction::CreateUnary(param_shape, HloOpcode::kExp, param0)); 319 HloInstruction* reshape2 = 320 builder.AddInstruction(HloInstruction::CreateReshape(result_shape, exp1)); 321 builder.AddInstruction( 322 HloInstruction::CreateUnary(result_shape, HloOpcode::kNegate, reshape2)); 323 324 auto module = CreateNewVerifiedModule(); 325 module->AddEntryComputation(builder.Build()); 326 327 RunFusionAndCheckOpcodesWereFused( 328 module.get(), {HloOpcode::kNegate, HloOpcode::kReshape, HloOpcode::kExp, 329 HloOpcode::kParameter}); 330 } 331 332 TEST_F(OpcodeFusionTest, Broadcast_Reshape_DynamicSlice_Tanh) { 333 HloComputation::Builder builder(TestName()); 334 Shape param_shape = ShapeUtil::MakeShape(F32, {8}); 335 Shape starts_shape = ShapeUtil::MakeShape(F32, {}); 336 Shape broadcast_shape = ShapeUtil::MakeShape(F32, {1, 8, 8}); 337 Shape reshape_shape = ShapeUtil::MakeShape(F32, {8, 8}); 338 Shape dynamic_slice_shape = ShapeUtil::MakeShape(F32, {4, 4}); 339 HloInstruction* param0 = builder.AddInstruction( 340 HloInstruction::CreateParameter(0, param_shape, "param")); 341 HloInstruction* param1 = builder.AddInstruction( 342 HloInstruction::CreateParameter(1, starts_shape, "starts")); 343 HloInstruction* param2 = builder.AddInstruction( 344 HloInstruction::CreateParameter(2, starts_shape, "starts")); 345 HloInstruction* broadcast2 = builder.AddInstruction( 346 HloInstruction::CreateBroadcast(broadcast_shape, param0, {1})); 347 HloInstruction* reshape3 = builder.AddInstruction( 348 HloInstruction::CreateReshape(reshape_shape, broadcast2)); 349 HloInstruction* dynamic_slice4 = 350 builder.AddInstruction(HloInstruction::CreateDynamicSlice( 351 dynamic_slice_shape, reshape3, {param1, param2}, {4, 4})); 352 builder.AddInstruction(HloInstruction::CreateUnary( 353 dynamic_slice_shape, HloOpcode::kTanh, dynamic_slice4)); 354 355 auto module = CreateNewUnverifiedModule(); 356 module->AddEntryComputation(builder.Build()); 357 358 RunFusionAndCheckOpcodesWereFused( 359 module.get(), 360 {HloOpcode::kTanh, HloOpcode::kDynamicSlice, HloOpcode::kReshape, 361 HloOpcode::kBroadcast, HloOpcode::kParameter, HloOpcode::kParameter, 362 HloOpcode::kParameter}); 363 } 364 365 TEST_F(OpcodeFusionTest, Broadcast_Negate) { 366 HloComputation::Builder builder(TestName()); 367 Shape param_shape = ShapeUtil::MakeShape(F32, {8}); 368 Shape result_shape = ShapeUtil::MakeShape(F32, {8, 8}); 369 HloInstruction* param0 = builder.AddInstruction( 370 HloInstruction::CreateParameter(0, param_shape, "param")); 371 HloInstruction* broadcast1 = builder.AddInstruction( 372 HloInstruction::CreateBroadcast(result_shape, param0, {1})); 373 builder.AddInstruction(HloInstruction::CreateUnary( 374 result_shape, HloOpcode::kNegate, broadcast1)); 375 376 auto module = CreateNewVerifiedModule(); 377 module->AddEntryComputation(builder.Build()); 378 379 RunFusionAndCheckOpcodesWereFused( 380 module.get(), 381 {HloOpcode::kNegate, HloOpcode::kBroadcast, HloOpcode::kParameter}); 382 } 383 384 TEST_F(OpcodeFusionTest, DynamicSlice_Negate) { 385 HloComputation::Builder builder(TestName()); 386 Shape param_shape = ShapeUtil::MakeShape(F32, {4}); 387 Shape slice_shape = ShapeUtil::MakeShape(F32, {}); 388 Shape result_shape = ShapeUtil::MakeShape(F32, {2}); 389 HloInstruction* param0 = builder.AddInstruction( 390 HloInstruction::CreateParameter(0, param_shape, "param")); 391 HloInstruction* param1 = builder.AddInstruction( 392 HloInstruction::CreateParameter(1, slice_shape, "starts")); 393 HloInstruction* dynamic_slice2 = builder.AddInstruction( 394 HloInstruction::CreateDynamicSlice(result_shape, param0, {param1}, {2})); 395 builder.AddInstruction(HloInstruction::CreateUnary( 396 result_shape, HloOpcode::kNegate, dynamic_slice2)); 397 398 auto module = CreateNewUnverifiedModule(); 399 module->AddEntryComputation(builder.Build()); 400 401 RunFusionAndCheckOpcodesWereFused( 402 module.get(), {HloOpcode::kNegate, HloOpcode::kDynamicSlice, 403 HloOpcode::kParameter, HloOpcode::kParameter}); 404 } 405 406 TEST_F(OpcodeFusionTest, Exponential_Negate) { 407 HloComputation::Builder builder(TestName()); 408 Shape param_shape = ShapeUtil::MakeShape(F32, {4}); 409 HloInstruction* param0 = builder.AddInstruction( 410 HloInstruction::CreateParameter(0, param_shape, "param")); 411 HloInstruction* exp1 = builder.AddInstruction( 412 HloInstruction::CreateUnary(param_shape, HloOpcode::kExp, param0)); 413 builder.AddInstruction( 414 HloInstruction::CreateUnary(param_shape, HloOpcode::kNegate, exp1)); 415 416 auto module = CreateNewVerifiedModule(); 417 module->AddEntryComputation(builder.Build()); 418 419 RunFusionAndCheckOpcodesWereFused( 420 module.get(), 421 {HloOpcode::kNegate, HloOpcode::kExp, HloOpcode::kParameter}); 422 } 423 424 TEST_F(OpcodeFusionTest, Reshape_Negate) { 425 HloComputation::Builder builder(TestName()); 426 Shape param_shape = ShapeUtil::MakeShape(F32, {4, 4}); 427 Shape result_shape = ShapeUtil::MakeShape(F32, {16}); 428 HloInstruction* param0 = builder.AddInstruction( 429 HloInstruction::CreateParameter(0, param_shape, "param")); 430 HloInstruction* reshape1 = builder.AddInstruction( 431 HloInstruction::CreateReshape(result_shape, param0)); 432 builder.AddInstruction( 433 HloInstruction::CreateUnary(result_shape, HloOpcode::kNegate, reshape1)); 434 435 auto module = CreateNewVerifiedModule(); 436 module->AddEntryComputation(builder.Build()); 437 438 RunFusionAndCheckOpcodesWereFused( 439 module.get(), 440 {HloOpcode::kNegate, HloOpcode::kReshape, HloOpcode::kParameter}); 441 } 442 443 TEST_F(OpcodeFusionTest, Reverse_Negate) { 444 HloComputation::Builder builder(TestName()); 445 Shape param_shape = ShapeUtil::MakeShape(F32, {8}); 446 HloInstruction* param0 = builder.AddInstruction( 447 HloInstruction::CreateParameter(0, param_shape, "param")); 448 HloInstruction* reverse1 = builder.AddInstruction( 449 HloInstruction::CreateReverse(param_shape, param0, {0})); 450 builder.AddInstruction( 451 HloInstruction::CreateUnary(param_shape, HloOpcode::kNegate, reverse1)); 452 453 auto module = CreateNewVerifiedModule(); 454 module->AddEntryComputation(builder.Build()); 455 456 RunFusionAndCheckOpcodesWereFused( 457 module.get(), 458 {HloOpcode::kNegate, HloOpcode::kReverse, HloOpcode::kParameter}); 459 } 460 461 TEST_F(OpcodeFusionTest, Slice_Negate) { 462 HloComputation::Builder builder(TestName()); 463 Shape param_shape = ShapeUtil::MakeShape(F32, {4}); 464 Shape slice_shape = ShapeUtil::MakeShape(F32, {2}); 465 HloInstruction* param0 = builder.AddInstruction( 466 HloInstruction::CreateParameter(0, param_shape, "param")); 467 HloInstruction* slice1 = builder.AddInstruction( 468 HloInstruction::CreateSlice(slice_shape, param0, {0}, {4}, {2})); 469 builder.AddInstruction(HloInstruction::CreateUnary( 470 ShapeUtil::MakeShape(S32, {2}), HloOpcode::kNegate, slice1)); 471 472 auto module = CreateNewUnverifiedModule(); 473 module->AddEntryComputation(builder.Build()); 474 475 RunFusionAndCheckOpcodesWereFused( 476 module.get(), 477 {HloOpcode::kNegate, HloOpcode::kSlice, HloOpcode::kParameter}); 478 } 479 480 TEST_F(OpcodeFusionTest, Exponential_Transpose_Negate) { 481 HloComputation::Builder builder(TestName()); 482 Shape param_shape = ShapeUtil::MakeShape(F32, {3, 4}); 483 Shape result_shape = ShapeUtil::MakeShape(F32, {4, 3}); 484 HloInstruction* param0 = builder.AddInstruction( 485 HloInstruction::CreateParameter(0, param_shape, "param")); 486 // InstructionFusion::ShouldFuse() precludes fusing a transpose whose operand 487 // is a parameter, so create an operand between the parameter and transpose. 488 HloInstruction* exp1 = builder.AddInstruction( 489 HloInstruction::CreateUnary(param_shape, HloOpcode::kExp, param0)); 490 HloInstruction* transpose2 = builder.AddInstruction( 491 HloInstruction::CreateTranspose(result_shape, exp1, {1, 0})); 492 builder.AddInstruction(HloInstruction::CreateUnary( 493 result_shape, HloOpcode::kNegate, transpose2)); 494 495 auto module = CreateNewVerifiedModule(); 496 module->AddEntryComputation(builder.Build()); 497 498 RunFusionAndCheckOpcodesWereFused( 499 module.get(), {HloOpcode::kNegate, HloOpcode::kTranspose, HloOpcode::kExp, 500 HloOpcode::kParameter}); 501 } 502 503 TEST_F(OpcodeFusionTest, UnaryMapOfExp) { 504 auto module = CreateNewVerifiedModule(); 505 506 HloComputation::Builder builder(TestName()); 507 Shape shape = ShapeUtil::MakeShape(F32, {3, 4}); 508 HloInstruction* param0 = builder.AddInstruction( 509 HloInstruction::CreateParameter(0, shape, "param")); 510 511 HloInstruction* exp = builder.AddInstruction( 512 HloInstruction::CreateUnary(shape, HloOpcode::kExp, param0)); 513 builder.AddInstruction( 514 HloInstruction::CreateMap(shape, {exp}, CreateAdderToOne(module.get()))); 515 516 module->AddEntryComputation(builder.Build()); 517 518 RunFusionAndCheckOpcodesWereFused( 519 module.get(), {HloOpcode::kParameter, HloOpcode::kExp, HloOpcode::kMap}); 520 } 521 522 TEST_F(OpcodeFusionTest, BinaryMapOfExps) { 523 auto module = CreateNewVerifiedModule(); 524 525 HloComputation::Builder builder(TestName()); 526 Shape shape = ShapeUtil::MakeShape(F32, {3, 4}); 527 HloInstruction* param0 = builder.AddInstruction( 528 HloInstruction::CreateParameter(0, shape, "param")); 529 HloInstruction* param1 = builder.AddInstruction( 530 HloInstruction::CreateParameter(1, shape, "param")); 531 532 HloInstruction* exp0 = builder.AddInstruction( 533 HloInstruction::CreateUnary(shape, HloOpcode::kExp, param0)); 534 HloInstruction* exp1 = builder.AddInstruction( 535 HloInstruction::CreateUnary(shape, HloOpcode::kExp, param1)); 536 537 builder.AddInstruction( 538 HloInstruction::CreateMap(shape, {exp0, exp1}, CreateMax(module.get()))); 539 540 module->AddEntryComputation(builder.Build()); 541 542 RunFusionAndCheckOpcodesWereFused( 543 module.get(), {HloOpcode::kParameter, HloOpcode::kParameter, 544 HloOpcode::kExp, HloOpcode::kExp, HloOpcode::kMap}); 545 } 546 547 TEST_F(OpcodeFusionTest, DynamicSliceWithDynamicUpdateSlice) { 548 auto module = CreateNewVerifiedModule(); 549 550 HloComputation::Builder builder(TestName()); 551 Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); 552 Shape slice_shape = ShapeUtil::MakeShape(F32, {10, 1, 1000}); 553 554 std::vector<HloInstruction*> slice_indices, update_indices; 555 for (int i = 0; i < 3; ++i) { 556 slice_indices.push_back( 557 builder.AddInstruction(HloInstruction::CreateParameter( 558 1 + i, ShapeUtil::MakeShape(U32, {}), "slice_indices"))); 559 update_indices.push_back( 560 builder.AddInstruction(HloInstruction::CreateParameter( 561 5 + i, ShapeUtil::MakeShape(U32, {}), "update_indices"))); 562 } 563 HloInstruction* slice = 564 builder.AddInstruction(HloInstruction::CreateDynamicSlice( 565 slice_shape, 566 builder.AddInstruction( 567 HloInstruction::CreateParameter(0, full_shape, "slice_from")), 568 slice_indices, 569 /*slice_sizes=*/{10, 1, 1000})); 570 571 builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( 572 full_shape, 573 builder.AddInstruction( 574 HloInstruction::CreateParameter(4, full_shape, "to_update")), 575 slice, update_indices)); 576 577 module->AddEntryComputation(builder.Build()); 578 RunFusionAndCheckOpcodesWereFused( 579 module.get(), 580 {HloOpcode::kDynamicSlice, HloOpcode::kDynamicUpdateSlice, 581 HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter, 582 HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter, 583 HloOpcode::kParameter, HloOpcode::kParameter}); 584 } 585 586 TEST_F(OpcodeFusionTest, MessOfFusibleNodes) { 587 auto module = CreateNewVerifiedModule(); 588 HloComputation::Builder builder(TestName()); 589 590 Shape full_shape = ShapeUtil::MakeShape(F32, {4, 100, 10, 100, 50}); 591 592 auto loop_idx = builder.AddInstruction(HloInstruction::CreateParameter( 593 0, ShapeUtil::MakeShape(S32, {}), "param0")); 594 auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( 595 1, ShapeUtil::MakeShape(S32, {}), "param1")); 596 597 auto idx_choice = builder.AddInstruction(HloInstruction::CreateReshape( 598 ShapeUtil::MakeShape(S32, {}), 599 builder.AddInstruction(HloInstruction::CreateDynamicSlice( 600 ShapeUtil::MakeShape(S32, {1}), 601 builder.AddInstruction(HloInstruction::CreateParameter( 602 2, ShapeUtil::MakeShape(S32, {4}), "param2")), 603 {loop_idx}, 604 /*slice_sizes=*/{1})))); 605 auto zero = builder.AddInstruction( 606 HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); 607 608 auto slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice( 609 ShapeUtil::MakeShape(F32, {1, 100, 10, 100, 50}), 610 builder.AddInstruction(HloInstruction::CreateParameter( 611 3, ShapeUtil::MakeShape(F32, {100, 100, 10, 100, 50}), "param3")), 612 {idx_choice, zero, zero, zero, zero}, 613 /*slice_sizes=*/{1, 100, 10, 100, 50})); 614 615 builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( 616 full_shape, 617 builder.AddInstruction( 618 HloInstruction::CreateParameter(4, full_shape, "param4")), 619 slice, {loop_idx, param1, param1, param1, param1})); 620 621 module->AddEntryComputation(builder.Build()); 622 RunFusionAndCheckOpcodesWereFused( 623 module.get(), 624 {HloOpcode::kDynamicSlice, HloOpcode::kDynamicSlice, 625 HloOpcode::kDynamicUpdateSlice, HloOpcode::kReshape, 626 HloOpcode::kConstant, HloOpcode::kParameter, HloOpcode::kParameter, 627 HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter}); 628 } 629 630 void CreateComputationForDotAddOutputFusionTest(const string& test_name, 631 HloModule* module, int m, int k, 632 int n, 633 bool add_extra_use_for_dot) { 634 HloComputation::Builder builder(test_name); 635 636 Shape dot_lhs_shape = ShapeUtil::MakeShape(F32, {m, k}); 637 Shape dot_rhs_shape = ShapeUtil::MakeShape(F32, {k, n}); 638 Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n}); 639 640 auto* dot_lhs = builder.AddInstruction( 641 HloInstruction::CreateParameter(0, dot_lhs_shape, "param0")); 642 auto* dot_rhs = builder.AddInstruction( 643 HloInstruction::CreateParameter(1, dot_rhs_shape, "param1")); 644 auto* addend = builder.AddInstruction( 645 HloInstruction::CreateParameter(2, dot_shape, "param2")); 646 647 auto* dot = 648 builder.AddInstruction(CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs)); 649 builder.AddInstruction( 650 HloInstruction::CreateBinary(dot_shape, HloOpcode::kAdd, dot, addend)); 651 652 if (add_extra_use_for_dot) { 653 auto* token = builder.AddInstruction(HloInstruction::CreateToken()); 654 builder.AddInstruction( 655 HloInstruction::CreateOutfeed(dot_shape, dot, token, "no_config")); 656 } 657 658 module->AddEntryComputation(builder.Build()); 659 } 660 661 TEST_F(OpcodeFusionTest, DotAddOutputFusion_1x50x19) { 662 auto module = CreateNewVerifiedModule(); 663 CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/1, 664 /*k=*/50, /*n=*/19, 665 /*add_extra_use_for_dot=*/false); 666 667 RunFusionAndCheckOpcodesWereFused( 668 module.get(), 669 {HloOpcode::kDot, HloOpcode::kAdd, HloOpcode::kParameter, 670 HloOpcode::kParameter, HloOpcode::kParameter}, 671 HloInstruction::FusionKind::kOutput); 672 } 673 674 TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1) { 675 auto module = CreateNewVerifiedModule(); 676 CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, 677 /*k=*/50, /*n=*/1, 678 /*add_extra_use_for_dot=*/false); 679 680 RunFusionAndCheckOpcodesWereFused( 681 module.get(), 682 {HloOpcode::kDot, HloOpcode::kAdd, HloOpcode::kParameter, 683 HloOpcode::kParameter, HloOpcode::kParameter}, 684 HloInstruction::FusionKind::kOutput); 685 } 686 687 TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x19) { 688 auto module = CreateNewVerifiedModule(); 689 CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, 690 /*k=*/50, /*n=*/19, 691 /*add_extra_use_for_dot=*/false); 692 693 TF_ASSERT_OK_AND_ASSIGN(bool fused_something, 694 CpuInstructionFusion().Run(module.get())); 695 EXPECT_FALSE(fused_something); 696 EXPECT_THAT(module->entry_computation()->root_instruction(), 697 Not(op::Fusion())); 698 } 699 700 TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1_multi_use) { 701 auto module = CreateNewVerifiedModule(); 702 CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, 703 /*k=*/50, /*n=*/1, 704 /*add_extra_use_for_dot=*/true); 705 706 TF_ASSERT_OK_AND_ASSIGN(bool fused_something, 707 CpuInstructionFusion().Run(module.get())); 708 EXPECT_FALSE(fused_something); 709 EXPECT_THAT(module->entry_computation()->root_instruction(), 710 Not(op::Fusion())); 711 } 712 713 TEST_F(InstructionFusionTest, 714 DotOperationFusion_DontOutputFuseDuplicateOperands) { 715 absl::string_view module_string = R"( 716 HloModule module 717 718 ENTRY main { 719 a = f32[50,60]{1,0} parameter(0) 720 b = f32[60,1]{1,0} parameter(1) 721 c = f32[50,1]{1,0} dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0} 722 ROOT d = f32[50,1]{1,0} add(c, c) 723 } 724 )"; 725 726 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, 727 ParseAndReturnVerifiedModule(module_string)); 728 TF_ASSERT_OK_AND_ASSIGN(bool fused_something, 729 CpuInstructionFusion().Run(module.get())); 730 EXPECT_FALSE(fused_something); 731 EXPECT_THAT(module->entry_computation()->root_instruction(), 732 Not(op::Fusion())); 733 } 734 735 struct GatherLoopFusionTestSpec { 736 string test_name; 737 string hlo_computation_text; 738 739 static string Name( 740 const ::testing::TestParamInfo<GatherLoopFusionTestSpec>& info) { 741 return info.param.test_name; 742 } 743 }; 744 745 class GatherLoopFusionTest 746 : public OpcodeFusionTest, 747 public ::testing::WithParamInterface<GatherLoopFusionTestSpec> {}; 748 749 TEST_P(GatherLoopFusionTest, GatherLoopFusion) { 750 const GatherLoopFusionTestSpec& spec = GetParam(); 751 string hlo_string = absl::StrCat("HloModule ", spec.test_name, "\n\n", 752 spec.hlo_computation_text); 753 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, 754 ParseHloString(hlo_string)); 755 756 RunFusionAndCheckOpcodesWereFused( 757 module.get(), 758 {HloOpcode::kGather, HloOpcode::kAdd, HloOpcode::kBroadcast, 759 HloOpcode::kConstant, HloOpcode::kParameter, HloOpcode::kParameter}); 760 } 761 762 std::vector<GatherLoopFusionTestSpec> GetGatherLoopFusionTestSpecs() { 763 std::vector<GatherLoopFusionTestSpec> result; 764 765 result.push_back({"FusedTensorFlowGatherV2", R"( 766 ENTRY main { 767 operand = s32[3,3] parameter(0) 768 indices = s32[2] parameter(1) 769 gather = s32[3,2] gather(operand, indices), 770 offset_dims={0}, 771 collapsed_slice_dims={1}, 772 start_index_map={1}, 773 index_vector_dim=1, 774 slice_sizes={3, 1} 775 one = s32[] constant(1) 776 one_broadcasted = s32[3,2] broadcast(one), dimensions={} 777 ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted) 778 } 779 )"}); 780 781 result.push_back({"FusedTensorFlowGatherMultipleBatchDims", R"( 782 ENTRY main { 783 operand = s32[3,3] parameter(0) 784 indices = s32[2,2] parameter(1) 785 gather = s32[2,3,2] gather(operand, indices), 786 offset_dims={1}, 787 collapsed_slice_dims={1}, 788 start_index_map={1}, 789 index_vector_dim=2, 790 slice_sizes={3, 1} 791 one = s32[] constant(1) 792 one_broadcasted = s32[2,3,2] broadcast(one), dimensions={} 793 ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted) 794 } 795 )"}); 796 797 result.push_back({"FusedTensorFlowGatherNdMultipleBatchDims", R"( 798 ENTRY main { 799 operand = s32[3,3] parameter(0) 800 indices = s32[2,2,2] parameter(1) 801 gather = s32[2,2] gather(operand, indices), 802 offset_dims={}, 803 collapsed_slice_dims={0,1}, 804 start_index_map={0,1}, 805 index_vector_dim=2, 806 slice_sizes={1, 1} 807 one = s32[] constant(1) 808 one_broadcasted = s32[2,2] broadcast(one), dimensions={} 809 ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) 810 } 811 )"}); 812 813 result.push_back({"FusedTensorFlowGatherNd_0", R"( 814 ENTRY main { 815 operand = s32[3,3,2] parameter(0) 816 indices = s32[2,2] parameter(1) 817 gather = s32[2,2] gather(operand, indices), 818 offset_dims={1}, 819 collapsed_slice_dims={0,1}, 820 start_index_map={0,1}, 821 index_vector_dim=1, 822 slice_sizes={1,1,2} 823 one = s32[] constant(1) 824 one_broadcasted = s32[2,2] broadcast(one), dimensions={} 825 ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) 826 } 827 )"}); 828 829 result.push_back({"FusedTensorFlowGatherNd_1", R"( 830 ENTRY main { 831 operand = s32[3,3,2] parameter(0) 832 indices = s32[2,2] parameter(1) 833 gather = s32[2,2] gather(operand, indices), 834 offset_dims={1}, 835 collapsed_slice_dims={0,1}, 836 start_index_map={0,1}, 837 index_vector_dim=0, 838 slice_sizes={1,1,2} 839 one = s32[] constant(1) 840 one_broadcasted = s32[2,2] broadcast(one), dimensions={} 841 ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) 842 } 843 )"}); 844 845 result.push_back({"FusedDynamicSlice", R"( 846 ENTRY main { 847 operand = s32[3,3] parameter(0) 848 indices = s32[2] parameter(1) 849 gather = s32[1,1] gather(operand, indices), 850 offset_dims={0,1}, 851 collapsed_slice_dims={}, 852 start_index_map={0,1}, 853 index_vector_dim=0, 854 slice_sizes={1,1} 855 one = s32[] constant(1) 856 one_broadcasted = s32[1,1] broadcast(one), dimensions={} 857 ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted) 858 } 859 )"}); 860 861 result.push_back({"FusedBatchDynamicSlice", R"( 862 ENTRY main { 863 operand = s32[3,3] parameter(0) 864 indices = s32[2,2] parameter(1) 865 gather = s32[2,1,1] gather(operand, indices), 866 offset_dims={1,2}, 867 collapsed_slice_dims={}, 868 start_index_map={0,1}, 869 index_vector_dim=0, 870 slice_sizes={1,1} 871 one = s32[] constant(1) 872 one_broadcasted = s32[2,1,1] broadcast(one), dimensions={} 873 ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted) 874 } 875 )"}); 876 877 return result; 878 } 879 880 INSTANTIATE_TEST_SUITE_P(GatherLoopFusionTestInstantiation, 881 GatherLoopFusionTest, 882 ::testing::ValuesIn(GetGatherLoopFusionTestSpecs()), 883 GatherLoopFusionTestSpec::Name); 884 } // namespace 885 } // namespace cpu 886 } // namespace xla 887