1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h" 17 18 #include <algorithm> 19 #include <set> 20 21 #include "tensorflow/compiler/xla/service/hlo_matchers.h" 22 #include "tensorflow/compiler/xla/service/transpose_folding.h" 23 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 24 #include "tensorflow/core/lib/gtl/array_slice.h" 25 26 namespace op = xla::testing::opcode_matchers; 27 28 namespace xla { 29 namespace cpu { 30 namespace { 31 32 using InstructionFusionTest = HloTestBase; 33 34 std::unique_ptr<HloInstruction> MakeDot(const Shape& shape, HloInstruction* lhs, 35 HloInstruction* rhs) { 36 DotDimensionNumbers dot_dnums; 37 dot_dnums.add_lhs_contracting_dimensions(1); 38 dot_dnums.add_rhs_contracting_dimensions(0); 39 return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums); 40 } 41 42 TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) { 43 HloComputation::Builder builder(TestName()); 44 HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( 45 0, ShapeUtil::MakeShape(F32, {1024, 256}), "arg0")); 46 HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter( 47 1, ShapeUtil::MakeShape(F32, {256, 1}), "arg1")); 48 49 HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary( 50 ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kExp, arg0)); 51 HloInstruction* dot = builder.AddInstruction( 52 MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), exp0, arg1)); 53 54 auto module = CreateNewModule(); 55 auto computation = module->AddEntryComputation(builder.Build()); 56 EXPECT_EQ(dot, computation->root_instruction()); 57 EXPECT_TRUE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); 58 EXPECT_THAT(computation->root_instruction(), op::Fusion()); 59 } 60 61 TEST_F(InstructionFusionTest, DotOperationFusion_Basic_1) { 62 HloComputation::Builder builder(TestName()); 63 HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( 64 0, ShapeUtil::MakeShape(F32, {1, 256}), "arg0")); 65 HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter( 66 1, ShapeUtil::MakeShape(F32, {256, 1024}), "arg1")); 67 68 HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( 69 ShapeUtil::MakeShape(S32, {256, 1024}), HloOpcode::kExp, arg1)); 70 HloInstruction* dot = builder.AddInstruction( 71 MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, exp1)); 72 73 auto module = CreateNewModule(); 74 auto computation = module->AddEntryComputation(builder.Build()); 75 EXPECT_EQ(dot, computation->root_instruction()); 76 EXPECT_TRUE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); 77 EXPECT_THAT(computation->root_instruction(), op::Fusion()); 78 } 79 80 TEST_F(InstructionFusionTest, DotOperationFusion_Bitcast) { 81 HloComputation::Builder builder(TestName()); 82 HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( 83 0, ShapeUtil::MakeShape(F32, {2, 512, 2, 128}), "arg0")); 84 HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter( 85 1, ShapeUtil::MakeShape(F32, {256, 1}), "arg1")); 86 87 HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary( 88 ShapeUtil::MakeShape(S32, {2, 512, 2, 128}), HloOpcode::kExp, arg0)); 89 HloInstruction* bitcast0 = builder.AddInstruction(HloInstruction::CreateUnary( 90 ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kBitcast, exp0)); 91 HloInstruction* dot = builder.AddInstruction( 92 MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), bitcast0, arg1)); 93 94 auto module = CreateNewModule(); 95 auto computation = module->AddEntryComputation(builder.Build()); 96 EXPECT_EQ(dot, computation->root_instruction()); 97 EXPECT_TRUE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); 98 EXPECT_THAT(computation->root_instruction(), op::Fusion()); 99 } 100 101 TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) { 102 HloComputation::Builder builder(TestName()); 103 HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( 104 0, ShapeUtil::MakeShape(F32, {2, 512, 2, 128}), "arg0")); 105 HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter( 106 1, ShapeUtil::MakeShape(F32, {256, 1}), "arg1")); 107 108 HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary( 109 ShapeUtil::MakeShape(S32, {2, 512, 2, 128}), HloOpcode::kExp, arg0)); 110 HloInstruction* reshape0 = 111 builder.AddInstruction(HloInstruction::CreateReshape( 112 ShapeUtil::MakeShape(S32, {1024, 256}), exp0)); 113 HloInstruction* dot = builder.AddInstruction( 114 MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), reshape0, arg1)); 115 116 auto module = CreateNewModule(); 117 auto computation = module->AddEntryComputation(builder.Build()); 118 EXPECT_EQ(dot, computation->root_instruction()); 119 EXPECT_TRUE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); 120 EXPECT_THAT(computation->root_instruction(), op::Fusion()); 121 } 122 123 TEST_F(InstructionFusionTest, DotOperationFusion_TooLarge) { 124 HloComputation::Builder builder(TestName()); 125 HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( 126 0, ShapeUtil::MakeShape(F32, {1, 32 * 1024}), "arg0")); 127 HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter( 128 1, ShapeUtil::MakeShape(F32, {256, 32 * 1024}), "arg1")); 129 130 HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( 131 ShapeUtil::MakeShape(S32, {256, 32 * 1024}), HloOpcode::kExp, arg1)); 132 HloInstruction* dot = builder.AddInstruction( 133 MakeDot(ShapeUtil::MakeShape(F32, {1, 32 * 1024}), arg0, exp1)); 134 135 auto module = CreateNewModule(); 136 auto computation = module->AddEntryComputation(builder.Build()); 137 EXPECT_EQ(dot, computation->root_instruction()); 138 EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); 139 EXPECT_EQ(dot, computation->root_instruction()); 140 } 141 142 TEST_F(InstructionFusionTest, DotOperationFusion_ElementReuse) { 143 HloComputation::Builder builder(TestName()); 144 HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( 145 0, ShapeUtil::MakeShape(F32, {2, 256}), "arg0")); 146 HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter( 147 1, ShapeUtil::MakeShape(F32, {256, 1024}), "arg1")); 148 149 HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( 150 ShapeUtil::MakeShape(S32, {256, 1024}), HloOpcode::kExp, arg1)); 151 HloInstruction* dot = builder.AddInstruction( 152 MakeDot(ShapeUtil::MakeShape(F32, {2, 1024}), arg0, exp1)); 153 154 auto module = CreateNewModule(); 155 auto computation = module->AddEntryComputation(builder.Build()); 156 EXPECT_EQ(dot, computation->root_instruction()); 157 EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); 158 EXPECT_EQ(dot, computation->root_instruction()); 159 } 160 161 TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion) { 162 HloComputation::Builder builder(TestName()); 163 HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( 164 0, ShapeUtil::MakeShape(F32, {1, 256}), "arg0")); 165 HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter( 166 1, ShapeUtil::MakeShape(F32, {1024, 256}), "arg1")); 167 168 HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( 169 ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kExp, arg1)); 170 HloInstruction* transpose1 = 171 builder.AddInstruction(HloInstruction::CreateTranspose( 172 ShapeUtil::MakeShape(S32, {256, 1024}), exp1, {1, 0})); 173 builder.AddInstruction( 174 MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, transpose1)); 175 176 auto module = CreateNewModule(); 177 auto computation = module->AddEntryComputation(builder.Build()); 178 TransposeFolding transpose_folding( 179 [](const HloInstruction& dot, 180 const TransposeFolding::OperandIndices& candidate_operands) { 181 return candidate_operands; 182 }, 183 TransposeFolding::NeverFoldTranspose); 184 EXPECT_TRUE(transpose_folding.Run(module.get()).ValueOrDie()); 185 EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kFusion); 186 EXPECT_EQ(computation->root_instruction()->fusion_kind(), 187 HloInstruction::FusionKind::kTransposeDot); 188 EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); 189 EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kFusion); 190 EXPECT_EQ(computation->root_instruction()->fusion_kind(), 191 HloInstruction::FusionKind::kTransposeDot); 192 } 193 194 class OpcodeFusionTest : public InstructionFusionTest { 195 protected: 196 // Runs CPU instruction fusion on the given module, and tests that the result 197 // contains a fused op at the root with exactly the given multiset of opcodes. 198 void RunFusionAndCheckOpcodesWereFused( 199 HloModule* module, const std::multiset<HloOpcode>& expected_opcodes, 200 HloInstruction::FusionKind fusion_kind = 201 HloInstruction::FusionKind::kLoop) { 202 auto computation = module->entry_computation(); 203 auto did_fusion = CpuInstructionFusion().Run(module); 204 ASSERT_TRUE(did_fusion.ok()); 205 EXPECT_TRUE(did_fusion.ValueOrDie()); 206 207 HloInstruction* root = computation->root_instruction(); 208 ASSERT_THAT(root, op::Fusion()); 209 EXPECT_EQ(root->fusion_kind(), fusion_kind); 210 211 std::vector<HloOpcode> fused_opcodes(root->fused_instruction_count()); 212 std::transform(root->fused_instructions().begin(), 213 root->fused_instructions().end(), fused_opcodes.begin(), 214 [](const HloInstruction* hlo) { return hlo->opcode(); }); 215 216 EXPECT_EQ( 217 std::multiset<HloOpcode>(fused_opcodes.begin(), fused_opcodes.end()), 218 expected_opcodes); 219 } 220 221 HloComputation* CreateAdderToOne(HloModule* module) { 222 HloComputation::Builder builder(TestName()); 223 HloInstruction* arg0 = 224 builder.AddInstruction(HloInstruction::CreateParameter( 225 0, ShapeUtil::MakeShape(F32, {}), "arg0")); 226 HloInstruction* one = builder.AddInstruction( 227 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 228 builder.AddInstruction(HloInstruction::CreateBinary( 229 ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, arg0, one)); 230 return module->AddEmbeddedComputation(builder.Build()); 231 } 232 233 HloComputation* CreateMax(HloModule* module) { 234 HloComputation::Builder builder(TestName()); 235 HloInstruction* arg0 = 236 builder.AddInstruction(HloInstruction::CreateParameter( 237 0, ShapeUtil::MakeShape(F32, {}), "arg0")); 238 HloInstruction* arg1 = 239 builder.AddInstruction(HloInstruction::CreateParameter( 240 1, ShapeUtil::MakeShape(F32, {}), "arg1")); 241 builder.AddInstruction(HloInstruction::CreateBinary( 242 ShapeUtil::MakeShape(F32, {}), HloOpcode::kMaximum, arg0, arg1)); 243 return module->AddEmbeddedComputation(builder.Build()); 244 } 245 }; 246 247 TEST_F(OpcodeFusionTest, Exponential_Bitcast_Negate) { 248 HloComputation::Builder builder(TestName()); 249 Shape param_shape = ShapeUtil::MakeShape(F32, {1, 4}); 250 Shape result_shape = ShapeUtil::MakeShape(F32, {4}); 251 HloInstruction* param0 = builder.AddInstruction( 252 HloInstruction::CreateParameter(0, param_shape, "param")); 253 // InstructionFusion::ShouldFuse() precludes fusing a bitcast whose operand 254 // is a parameter, so create an operand between the parameter and bitcast. 255 HloInstruction* exp1 = builder.AddInstruction( 256 HloInstruction::CreateUnary(param_shape, HloOpcode::kExp, param0)); 257 HloInstruction* bitcast2 = builder.AddInstruction( 258 HloInstruction::CreateUnary(result_shape, HloOpcode::kBitcast, exp1)); 259 builder.AddInstruction( 260 HloInstruction::CreateUnary(result_shape, HloOpcode::kNegate, bitcast2)); 261 262 auto module = CreateNewModule(); 263 module->AddEntryComputation(builder.Build()); 264 265 RunFusionAndCheckOpcodesWereFused( 266 module.get(), {HloOpcode::kNegate, HloOpcode::kBitcast, HloOpcode::kExp, 267 HloOpcode::kParameter}); 268 } 269 270 TEST_F(OpcodeFusionTest, Broadcast_Bitcast_DynamicSlice_Tanh) { 271 HloComputation::Builder builder(TestName()); 272 Shape param_shape = ShapeUtil::MakeShape(F32, {8}); 273 Shape starts_shape = ShapeUtil::MakeShape(F32, {2}); 274 Shape broadcast_shape = ShapeUtil::MakeShape(F32, {1, 8, 8}); 275 Shape bitcast_shape = ShapeUtil::MakeShape(F32, {8, 8}); 276 Shape dynamic_slice_shape = ShapeUtil::MakeShape(F32, {4, 4}); 277 HloInstruction* param0 = builder.AddInstruction( 278 HloInstruction::CreateParameter(0, param_shape, "param")); 279 HloInstruction* param1 = builder.AddInstruction( 280 HloInstruction::CreateParameter(1, starts_shape, "starts")); 281 HloInstruction* broadcast2 = builder.AddInstruction( 282 HloInstruction::CreateBroadcast(broadcast_shape, param0, {1})); 283 HloInstruction* bitcast3 = builder.AddInstruction(HloInstruction::CreateUnary( 284 bitcast_shape, HloOpcode::kBitcast, broadcast2)); 285 HloInstruction* dynamic_slice4 = 286 builder.AddInstruction(HloInstruction::CreateDynamicSlice( 287 dynamic_slice_shape, bitcast3, param1, {4, 4})); 288 builder.AddInstruction(HloInstruction::CreateUnary( 289 dynamic_slice_shape, HloOpcode::kTanh, dynamic_slice4)); 290 291 auto module = CreateNewModule(); 292 module->AddEntryComputation(builder.Build()); 293 294 RunFusionAndCheckOpcodesWereFused( 295 module.get(), 296 {HloOpcode::kTanh, HloOpcode::kDynamicSlice, HloOpcode::kBitcast, 297 HloOpcode::kBroadcast, HloOpcode::kParameter, HloOpcode::kParameter}); 298 } 299 300 TEST_F(OpcodeFusionTest, Broadcast_Negate) { 301 HloComputation::Builder builder(TestName()); 302 Shape param_shape = ShapeUtil::MakeShape(F32, {8}); 303 Shape result_shape = ShapeUtil::MakeShape(F32, {8, 8}); 304 HloInstruction* param0 = builder.AddInstruction( 305 HloInstruction::CreateParameter(0, param_shape, "param")); 306 HloInstruction* broadcast1 = builder.AddInstruction( 307 HloInstruction::CreateBroadcast(result_shape, param0, {1})); 308 builder.AddInstruction(HloInstruction::CreateUnary( 309 result_shape, HloOpcode::kNegate, broadcast1)); 310 311 auto module = CreateNewModule(); 312 module->AddEntryComputation(builder.Build()); 313 314 RunFusionAndCheckOpcodesWereFused( 315 module.get(), 316 {HloOpcode::kNegate, HloOpcode::kBroadcast, HloOpcode::kParameter}); 317 } 318 319 TEST_F(OpcodeFusionTest, DynamicSlice_Negate) { 320 HloComputation::Builder builder(TestName()); 321 Shape param_shape = ShapeUtil::MakeShape(F32, {4}); 322 Shape slice_shape = ShapeUtil::MakeShape(F32, {1}); 323 Shape result_shape = ShapeUtil::MakeShape(F32, {2}); 324 HloInstruction* param0 = builder.AddInstruction( 325 HloInstruction::CreateParameter(0, param_shape, "param")); 326 HloInstruction* param1 = builder.AddInstruction( 327 HloInstruction::CreateParameter(1, slice_shape, "starts")); 328 HloInstruction* dynamic_slice2 = builder.AddInstruction( 329 HloInstruction::CreateDynamicSlice(result_shape, param0, param1, {2})); 330 builder.AddInstruction(HloInstruction::CreateUnary( 331 result_shape, HloOpcode::kNegate, dynamic_slice2)); 332 333 auto module = CreateNewModule(); 334 module->AddEntryComputation(builder.Build()); 335 336 RunFusionAndCheckOpcodesWereFused( 337 module.get(), {HloOpcode::kNegate, HloOpcode::kDynamicSlice, 338 HloOpcode::kParameter, HloOpcode::kParameter}); 339 } 340 341 TEST_F(OpcodeFusionTest, Exponential_Negate) { 342 HloComputation::Builder builder(TestName()); 343 Shape param_shape = ShapeUtil::MakeShape(F32, {4}); 344 HloInstruction* param0 = builder.AddInstruction( 345 HloInstruction::CreateParameter(0, param_shape, "param")); 346 HloInstruction* exp1 = builder.AddInstruction( 347 HloInstruction::CreateUnary(param_shape, HloOpcode::kExp, param0)); 348 builder.AddInstruction( 349 HloInstruction::CreateUnary(param_shape, HloOpcode::kNegate, exp1)); 350 351 auto module = CreateNewModule(); 352 module->AddEntryComputation(builder.Build()); 353 354 RunFusionAndCheckOpcodesWereFused( 355 module.get(), 356 {HloOpcode::kNegate, HloOpcode::kExp, HloOpcode::kParameter}); 357 } 358 359 TEST_F(OpcodeFusionTest, Reshape_Negate) { 360 HloComputation::Builder builder(TestName()); 361 Shape param_shape = ShapeUtil::MakeShape(F32, {4, 4}); 362 Shape result_shape = ShapeUtil::MakeShape(F32, {16}); 363 HloInstruction* param0 = builder.AddInstruction( 364 HloInstruction::CreateParameter(0, param_shape, "param")); 365 HloInstruction* reshape1 = builder.AddInstruction( 366 HloInstruction::CreateReshape(result_shape, param0)); 367 builder.AddInstruction( 368 HloInstruction::CreateUnary(result_shape, HloOpcode::kNegate, reshape1)); 369 370 auto module = CreateNewModule(); 371 module->AddEntryComputation(builder.Build()); 372 373 RunFusionAndCheckOpcodesWereFused( 374 module.get(), 375 {HloOpcode::kNegate, HloOpcode::kReshape, HloOpcode::kParameter}); 376 } 377 378 TEST_F(OpcodeFusionTest, Reverse_Negate) { 379 HloComputation::Builder builder(TestName()); 380 Shape param_shape = ShapeUtil::MakeShape(F32, {8}); 381 HloInstruction* param0 = builder.AddInstruction( 382 HloInstruction::CreateParameter(0, param_shape, "param")); 383 HloInstruction* reverse1 = builder.AddInstruction( 384 HloInstruction::CreateReverse(param_shape, param0, {0})); 385 builder.AddInstruction( 386 HloInstruction::CreateUnary(param_shape, HloOpcode::kNegate, reverse1)); 387 388 auto module = CreateNewModule(); 389 module->AddEntryComputation(builder.Build()); 390 391 RunFusionAndCheckOpcodesWereFused( 392 module.get(), 393 {HloOpcode::kNegate, HloOpcode::kReverse, HloOpcode::kParameter}); 394 } 395 396 TEST_F(OpcodeFusionTest, Slice_Negate) { 397 HloComputation::Builder builder(TestName()); 398 Shape param_shape = ShapeUtil::MakeShape(F32, {4}); 399 Shape slice_shape = ShapeUtil::MakeShape(F32, {2}); 400 HloInstruction* param0 = builder.AddInstruction( 401 HloInstruction::CreateParameter(0, param_shape, "param")); 402 HloInstruction* slice1 = builder.AddInstruction( 403 HloInstruction::CreateSlice(slice_shape, param0, {0}, {4}, {2})); 404 builder.AddInstruction(HloInstruction::CreateUnary( 405 ShapeUtil::MakeShape(S32, {2}), HloOpcode::kNegate, slice1)); 406 407 auto module = CreateNewModule(); 408 module->AddEntryComputation(builder.Build()); 409 410 RunFusionAndCheckOpcodesWereFused( 411 module.get(), 412 {HloOpcode::kNegate, HloOpcode::kSlice, HloOpcode::kParameter}); 413 } 414 415 TEST_F(OpcodeFusionTest, Exponential_Transpose_Negate) { 416 HloComputation::Builder builder(TestName()); 417 Shape param_shape = ShapeUtil::MakeShape(F32, {3, 4}); 418 Shape result_shape = ShapeUtil::MakeShape(F32, {4, 3}); 419 HloInstruction* param0 = builder.AddInstruction( 420 HloInstruction::CreateParameter(0, param_shape, "param")); 421 // InstructionFusion::ShouldFuse() precludes fusing a transpose whose operand 422 // is a parameter, so create an operand between the parameter and transpose. 423 HloInstruction* exp1 = builder.AddInstruction( 424 HloInstruction::CreateUnary(param_shape, HloOpcode::kExp, param0)); 425 HloInstruction* transpose2 = builder.AddInstruction( 426 HloInstruction::CreateTranspose(result_shape, exp1, {1, 0})); 427 builder.AddInstruction(HloInstruction::CreateUnary( 428 result_shape, HloOpcode::kNegate, transpose2)); 429 430 auto module = CreateNewModule(); 431 module->AddEntryComputation(builder.Build()); 432 433 RunFusionAndCheckOpcodesWereFused( 434 module.get(), {HloOpcode::kNegate, HloOpcode::kTranspose, HloOpcode::kExp, 435 HloOpcode::kParameter}); 436 } 437 438 TEST_F(OpcodeFusionTest, UnaryMapOfExp) { 439 auto module = CreateNewModule(); 440 441 HloComputation::Builder builder(TestName()); 442 Shape shape = ShapeUtil::MakeShape(F32, {3, 4}); 443 HloInstruction* param0 = builder.AddInstruction( 444 HloInstruction::CreateParameter(0, shape, "param")); 445 446 HloInstruction* exp = builder.AddInstruction( 447 HloInstruction::CreateUnary(shape, HloOpcode::kExp, param0)); 448 builder.AddInstruction(HloInstruction::CreateMap( 449 shape, {exp}, CreateAdderToOne(module.get()), /*static_operands=*/{})); 450 451 module->AddEntryComputation(builder.Build()); 452 453 RunFusionAndCheckOpcodesWereFused( 454 module.get(), {HloOpcode::kParameter, HloOpcode::kExp, HloOpcode::kMap}); 455 } 456 457 TEST_F(OpcodeFusionTest, BinaryMapOfExps) { 458 auto module = CreateNewModule(); 459 460 HloComputation::Builder builder(TestName()); 461 Shape shape = ShapeUtil::MakeShape(F32, {3, 4}); 462 HloInstruction* param0 = builder.AddInstruction( 463 HloInstruction::CreateParameter(0, shape, "param")); 464 HloInstruction* param1 = builder.AddInstruction( 465 HloInstruction::CreateParameter(1, shape, "param")); 466 467 HloInstruction* exp0 = builder.AddInstruction( 468 HloInstruction::CreateUnary(shape, HloOpcode::kExp, param0)); 469 HloInstruction* exp1 = builder.AddInstruction( 470 HloInstruction::CreateUnary(shape, HloOpcode::kExp, param1)); 471 472 builder.AddInstruction(HloInstruction::CreateMap( 473 shape, {exp0, exp1}, CreateMax(module.get()), /*static_operands=*/{})); 474 475 module->AddEntryComputation(builder.Build()); 476 477 RunFusionAndCheckOpcodesWereFused( 478 module.get(), {HloOpcode::kParameter, HloOpcode::kParameter, 479 HloOpcode::kExp, HloOpcode::kExp, HloOpcode::kMap}); 480 } 481 482 TEST_F(OpcodeFusionTest, DynamicSliceWithDynamicUpdateSlice) { 483 auto module = CreateNewModule(); 484 485 HloComputation::Builder builder(TestName()); 486 Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); 487 Shape slice_shape = ShapeUtil::MakeShape(F32, {10, 1, 1000}); 488 489 HloInstruction* slice = 490 builder.AddInstruction(HloInstruction::CreateDynamicSlice( 491 slice_shape, 492 builder.AddInstruction( 493 HloInstruction::CreateParameter(0, full_shape, "slice_from")), 494 builder.AddInstruction(HloInstruction::CreateParameter( 495 1, ShapeUtil::MakeShape(U32, {3}), "slice_indices")), 496 /*slice_sizes=*/{10, 1, 1000})); 497 498 builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( 499 full_shape, 500 builder.AddInstruction( 501 HloInstruction::CreateParameter(2, full_shape, "to_update")), 502 slice, 503 builder.AddInstruction(HloInstruction::CreateParameter( 504 3, ShapeUtil::MakeShape(U32, {3}), "update_indices")))); 505 506 module->AddEntryComputation(builder.Build()); 507 RunFusionAndCheckOpcodesWereFused( 508 module.get(), {HloOpcode::kDynamicSlice, HloOpcode::kDynamicUpdateSlice, 509 HloOpcode::kParameter, HloOpcode::kParameter, 510 HloOpcode::kParameter, HloOpcode::kParameter}); 511 } 512 513 TEST_F(OpcodeFusionTest, MessOfFusileNodes) { 514 auto module = CreateNewModule(); 515 HloComputation::Builder builder(TestName()); 516 517 Shape full_shape = ShapeUtil::MakeShape(F32, {4, 100, 10, 100, 50}); 518 519 auto loop_idx = builder.AddInstruction(HloInstruction::CreateReshape( 520 ShapeUtil::MakeShape(S32, {1}), 521 builder.AddInstruction(HloInstruction::CreateParameter( 522 0, ShapeUtil::MakeShape(S32, {}), "param0")))); 523 524 auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( 525 1, ShapeUtil::MakeShape(S32, {1}), "param1")); 526 auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate( 527 ShapeUtil::MakeShape(S32, {5}), 528 {loop_idx, param1, param1, param1, param1}, /*dimension=*/0)); 529 530 auto idx_choice = builder.AddInstruction(HloInstruction::CreateDynamicSlice( 531 ShapeUtil::MakeShape(S32, {1}), 532 builder.AddInstruction(HloInstruction::CreateParameter( 533 2, ShapeUtil::MakeShape(S32, {4}), "param2")), 534 loop_idx, 535 /*slice_sizes=*/{1})); 536 537 PaddingConfig padding_config; 538 padding_config.add_dimensions()->set_edge_padding_high(4); 539 auto pad = builder.AddInstruction(HloInstruction::CreatePad( 540 ShapeUtil::MakeShape(S32, {5}), idx_choice, 541 builder.AddInstruction( 542 HloInstruction::CreateConstant(Literal::CreateR0(0))), 543 padding_config)); 544 545 auto slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice( 546 ShapeUtil::MakeShape(F32, {1, 100, 10, 100, 50}), 547 builder.AddInstruction(HloInstruction::CreateParameter( 548 3, ShapeUtil::MakeShape(F32, {100, 100, 10, 100, 50}), "param3")), 549 pad, /*slice_sizes=*/{1, 100, 10, 100, 50})); 550 551 builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( 552 full_shape, 553 builder.AddInstruction( 554 HloInstruction::CreateParameter(4, full_shape, "param4")), 555 slice, concat)); 556 557 module->AddEntryComputation(builder.Build()); 558 RunFusionAndCheckOpcodesWereFused( 559 module.get(), 560 {HloOpcode::kConcatenate, HloOpcode::kPad, HloOpcode::kDynamicSlice, 561 HloOpcode::kDynamicSlice, HloOpcode::kDynamicUpdateSlice, 562 HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter, 563 HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter}); 564 } 565 566 // Tests that we do not fuse instructions in cases where instructions in the 567 // fusion would reuse elements from its operand due to an implicit broadcast. 568 TEST_F(OpcodeFusionTest, ReuseViaImplicitBroadcastUnary) { 569 Shape small_shape = ShapeUtil::MakeShape(F32, {1, 4}); 570 Shape large_shape = ShapeUtil::MakeShape(F32, {3, 4}); 571 572 HloComputation::Builder builder(TestName()); 573 574 HloInstruction* small_param = 575 builder.AddInstruction(HloInstruction::CreateParameter( 576 /*parameter_number=*/0, small_shape, "param")); 577 HloInstruction* small_exp = builder.AddInstruction( 578 HloInstruction::CreateUnary(small_shape, HloOpcode::kExp, small_param)); 579 builder.AddInstruction( 580 HloInstruction::CreateUnary(large_shape, HloOpcode::kExp, small_exp)); 581 582 std::unique_ptr<HloModule> module = CreateNewModule(); 583 module->AddEntryComputation(builder.Build()); 584 585 auto did_fusion = CpuInstructionFusion().Run(module.get()); 586 ASSERT_TRUE(did_fusion.ok()); 587 EXPECT_FALSE(did_fusion.ValueOrDie()); 588 ASSERT_THAT(module->entry_computation()->root_instruction(), 589 Not(op::Fusion())); 590 } 591 592 // Like ReuseViaImplicitBroadcastUnary but with a binary operation. 593 TEST_F(OpcodeFusionTest, ReuseViaImplicitBroadcastBinary) { 594 Shape small_shape = ShapeUtil::MakeShape(F32, {1, 4}); 595 Shape large_shape = ShapeUtil::MakeShape(F32, {3, 4}); 596 597 HloComputation::Builder builder(TestName()); 598 599 HloInstruction* small_param = 600 builder.AddInstruction(HloInstruction::CreateParameter( 601 /*parameter_number=*/0, small_shape, "param")); 602 HloInstruction* large_param = 603 builder.AddInstruction(HloInstruction::CreateParameter( 604 /*parameter_number=*/1, large_shape, "param")); 605 HloInstruction* small_exp = builder.AddInstruction( 606 HloInstruction::CreateUnary(small_shape, HloOpcode::kExp, small_param)); 607 608 builder.AddInstruction(HloInstruction::CreateBinary( 609 large_shape, HloOpcode::kAdd, small_exp, large_param)); 610 611 std::unique_ptr<HloModule> module = CreateNewModule(); 612 module->AddEntryComputation(builder.Build()); 613 614 auto did_fusion = CpuInstructionFusion().Run(module.get()); 615 ASSERT_TRUE(did_fusion.ok()); 616 EXPECT_FALSE(did_fusion.ValueOrDie()); 617 ASSERT_THAT(module->entry_computation()->root_instruction(), 618 Not(op::Fusion())); 619 } 620 621 void CreateComputationForDotAddOutputFusionTest(const string& test_name, 622 HloModule* module, int m, int k, 623 int n, 624 bool add_extra_use_for_dot) { 625 HloComputation::Builder builder(test_name); 626 627 Shape dot_lhs_shape = ShapeUtil::MakeShape(F32, {m, k}); 628 Shape dot_rhs_shape = ShapeUtil::MakeShape(F32, {k, n}); 629 Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n}); 630 631 auto* dot_lhs = builder.AddInstruction( 632 HloInstruction::CreateParameter(0, dot_lhs_shape, "param0")); 633 auto* dot_rhs = builder.AddInstruction( 634 HloInstruction::CreateParameter(1, dot_rhs_shape, "param1")); 635 auto* addend = builder.AddInstruction( 636 HloInstruction::CreateParameter(2, dot_shape, "param2")); 637 638 auto* dot = builder.AddInstruction( 639 HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs)); 640 builder.AddInstruction( 641 HloInstruction::CreateBinary(dot_shape, HloOpcode::kAdd, dot, addend)); 642 643 if (add_extra_use_for_dot) { 644 builder.AddInstruction( 645 HloInstruction::CreateOutfeed(dot_shape, dot, "no_config")); 646 } 647 648 module->AddEntryComputation(builder.Build()); 649 } 650 651 TEST_F(OpcodeFusionTest, DotAddOutputFusion_1x50x19) { 652 auto module = CreateNewModule(); 653 CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/1, 654 /*k=*/50, /*n=*/19, 655 /*add_extra_use_for_dot=*/false); 656 657 RunFusionAndCheckOpcodesWereFused( 658 module.get(), 659 {HloOpcode::kDot, HloOpcode::kAdd, HloOpcode::kParameter, 660 HloOpcode::kParameter, HloOpcode::kParameter}, 661 HloInstruction::FusionKind::kOutput); 662 } 663 664 TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1) { 665 auto module = CreateNewModule(); 666 CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, 667 /*k=*/50, /*n=*/1, 668 /*add_extra_use_for_dot=*/false); 669 670 RunFusionAndCheckOpcodesWereFused( 671 module.get(), 672 {HloOpcode::kDot, HloOpcode::kAdd, HloOpcode::kParameter, 673 HloOpcode::kParameter, HloOpcode::kParameter}, 674 HloInstruction::FusionKind::kOutput); 675 } 676 677 TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x19) { 678 auto module = CreateNewModule(); 679 CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, 680 /*k=*/50, /*n=*/19, 681 /*add_extra_use_for_dot=*/false); 682 683 TF_ASSERT_OK_AND_ASSIGN(bool fused_something, 684 CpuInstructionFusion().Run(module.get())); 685 EXPECT_FALSE(fused_something); 686 EXPECT_THAT(module->entry_computation()->root_instruction(), 687 Not(op::Fusion())); 688 } 689 690 TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1_multi_use) { 691 auto module = CreateNewModule(); 692 CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, 693 /*k=*/50, /*n=*/1, 694 /*add_extra_use_for_dot=*/true); 695 696 TF_ASSERT_OK_AND_ASSIGN(bool fused_something, 697 CpuInstructionFusion().Run(module.get())); 698 EXPECT_FALSE(fused_something); 699 EXPECT_THAT(module->entry_computation()->root_instruction(), 700 Not(op::Fusion())); 701 } 702 703 } // namespace 704 } // namespace cpu 705 } // namespace xla 706