1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/instruction_fusion.h" 17 18 #include "tensorflow/compiler/xla/service/hlo_matchers.h" 19 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 20 21 namespace xla { 22 23 using InstructionFusionTest = HloTestBase; 24 25 TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfParameterUnfused) { 26 HloComputation::Builder builder(TestName()); 27 auto param0 = builder.AddInstruction( 28 HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "0")); 29 auto reshape1 = builder.AddInstruction( 30 HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {1, 1}), param0)); 31 32 auto module = CreateNewModule(); 33 auto computation = module->AddEntryComputation(builder.Build()); 34 EXPECT_EQ(reshape1, computation->root_instruction()); 35 EXPECT_FALSE( 36 InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) 37 .Run(module.get()) 38 .ValueOrDie()); 39 } 40 41 TEST_F(InstructionFusionTest, PotentialBitcastSimpleReshapeOfParameterUnfused) { 42 HloComputation::Builder builder(TestName()); 43 auto param0 = builder.AddInstruction( 44 HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "0")); 45 auto reshape1 = builder.AddInstruction( 46 HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {1, 1}), param0)); 47 48 auto module = CreateNewModule(); 49 auto computation = module->AddEntryComputation(builder.Build()); 50 EXPECT_EQ(reshape1, computation->root_instruction()); 51 EXPECT_FALSE( 52 InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) 53 .Run(module.get()) 54 .ValueOrDie()); 55 } 56 57 TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) { 58 HloComputation::Builder builder(TestName()); 59 auto param0 = builder.AddInstruction( 60 HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "0")); 61 auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose( 62 ShapeUtil::MakeShape(S32, {}), param0, {})); 63 64 auto module = CreateNewModule(); 65 auto computation = module->AddEntryComputation(builder.Build()); 66 EXPECT_EQ(transpose1, computation->root_instruction()); 67 EXPECT_FALSE( 68 InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) 69 .Run(module.get()) 70 .ValueOrDie()); 71 } 72 73 TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) { 74 HloComputation::Builder builder(TestName()); 75 auto shape = ShapeUtil::MakeShape(F32, {16, 16}); 76 auto param0 = 77 builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0")); 78 auto param1 = 79 builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1")); 80 HloInstruction* binary1 = builder.AddInstruction( 81 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); 82 builder.AddInstruction(HloInstruction::CreateSend(binary1, 0)); 83 HloInstruction* unary = builder.AddInstruction( 84 HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1)); 85 86 auto module = CreateNewModule(); 87 auto computation = module->AddEntryComputation(builder.Build()); 88 EXPECT_EQ(unary, computation->root_instruction()); 89 EXPECT_FALSE( 90 InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) 91 .Run(module.get()) 92 .ValueOrDie()); 93 } 94 95 TEST_F(InstructionFusionTest, AllowUnaryDuplication) { 96 HloComputation::Builder builder(TestName()); 97 auto shape = ShapeUtil::MakeShape(F32, {16, 16}); 98 auto param0 = 99 builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0")); 100 HloInstruction* unary1 = builder.AddInstruction( 101 HloInstruction::CreateUnary(shape, HloOpcode::kFloor, param0)); 102 builder.AddInstruction(HloInstruction::CreateSend(unary1, 0)); 103 HloInstruction* unary2 = builder.AddInstruction( 104 HloInstruction::CreateUnary(shape, HloOpcode::kAbs, unary1)); 105 106 auto module = CreateNewModule(); 107 auto computation = module->AddEntryComputation(builder.Build()); 108 EXPECT_EQ(unary2, computation->root_instruction()); 109 EXPECT_TRUE( 110 InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) 111 .Run(module.get()) 112 .ValueOrDie()); 113 } 114 115 TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) { 116 auto shape = ShapeUtil::MakeShape(F32, {16, 16}); 117 auto small_shape = ShapeUtil::MakeShape(F32, {16}); 118 HloComputation::Builder builder(TestName()); 119 auto param0 = builder.AddInstruction( 120 HloInstruction::CreateParameter(0, small_shape, "0")); 121 auto param1 = 122 builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1")); 123 HloInstruction* binary1 = builder.AddInstruction( 124 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); 125 builder.AddInstruction(HloInstruction::CreateSend(binary1, 0)); 126 HloInstruction* unary = builder.AddInstruction( 127 HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1)); 128 129 auto module = CreateNewModule(); 130 auto computation = module->AddEntryComputation(builder.Build()); 131 EXPECT_EQ(unary, computation->root_instruction()); 132 EXPECT_TRUE( 133 InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) 134 .Run(module.get()) 135 .ValueOrDie()); 136 } 137 138 } // namespace xla 139