1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/defuser.h" 17 18 #include "tensorflow/compiler/xla/literal_util.h" 19 #include "tensorflow/compiler/xla/service/hlo_matchers.h" 20 #include "tensorflow/compiler/xla/shape_util.h" 21 #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" 22 23 namespace op = xla::testing::opcode_matchers; 24 25 namespace xla { 26 namespace { 27 28 class DefuserTest : public HloVerifiedTestBase { 29 protected: 30 // Returns the number of fusion instructions in the module. 31 int FusionCount() { 32 int count = 0; 33 for (HloComputation* computation : module().computations()) { 34 if (computation->IsFusionComputation()) { 35 count++; 36 } 37 } 38 return count; 39 } 40 41 Defuser defuser_; 42 const Shape shape_ = ShapeUtil::MakeShape(F32, {2, 2}); 43 }; 44 45 TEST_F(DefuserTest, NoFusionInstruction) { 46 auto builder = HloComputation::Builder(TestName()); 47 auto param0 = 48 builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); 49 auto param1 = 50 builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1")); 51 builder.AddInstruction( 52 HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1)); 53 54 module().AddEntryComputation(builder.Build()); 55 EXPECT_EQ(0, FusionCount()); 56 57 EXPECT_FALSE(defuser_.Run(&module()).ValueOrDie()); 58 } 59 60 TEST_F(DefuserTest, TrivialFusionInstructionAsRoot) { 61 auto builder = HloComputation::Builder(TestName()); 62 auto param0 = 63 builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); 64 auto param1 = 65 builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1")); 66 auto add = builder.AddInstruction( 67 HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1)); 68 69 auto computation = module().AddEntryComputation(builder.Build()); 70 computation->CreateFusionInstruction({add}, 71 HloInstruction::FusionKind::kLoop); 72 73 EXPECT_THAT(computation->root_instruction(), op::Fusion()); 74 75 EXPECT_EQ(1, FusionCount()); 76 EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); 77 EXPECT_EQ(0, FusionCount()); 78 79 EXPECT_THAT(computation->root_instruction(), 80 op::Add(op::Parameter(), op::Parameter())); 81 } 82 83 TEST_F(DefuserTest, TrivialFusionInstructionNotAsRoot) { 84 auto builder = HloComputation::Builder(TestName()); 85 auto param0 = 86 builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); 87 auto param1 = 88 builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1")); 89 auto add = builder.AddInstruction( 90 HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1)); 91 builder.AddInstruction( 92 HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add)); 93 94 auto computation = module().AddEntryComputation(builder.Build()); 95 computation->CreateFusionInstruction({add}, 96 HloInstruction::FusionKind::kLoop); 97 98 EXPECT_THAT(computation->root_instruction(), op::Negate(op::Fusion())); 99 100 EXPECT_EQ(1, FusionCount()); 101 EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); 102 EXPECT_EQ(0, FusionCount()); 103 104 EXPECT_THAT(computation->root_instruction(), 105 op::Negate(op::Add(op::Parameter(), op::Parameter()))); 106 } 107 108 TEST_F(DefuserTest, NonTrivialFusionInstruction) { 109 auto builder = HloComputation::Builder(TestName()); 110 auto param0 = 111 builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); 112 auto param1 = 113 builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1")); 114 auto param3 = 115 builder.AddInstruction(HloInstruction::CreateParameter(2, shape_, "p2")); 116 auto add = builder.AddInstruction( 117 HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1)); 118 auto negate = builder.AddInstruction( 119 HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add)); 120 auto sub = builder.AddInstruction( 121 HloInstruction::CreateBinary(shape_, HloOpcode::kSubtract, add, negate)); 122 auto mul = builder.AddInstruction( 123 HloInstruction::CreateBinary(shape_, HloOpcode::kMultiply, sub, param3)); 124 auto div = builder.AddInstruction( 125 HloInstruction::CreateBinary(shape_, HloOpcode::kDivide, mul, param3)); 126 auto constant = builder.AddInstruction(HloInstruction::CreateConstant( 127 Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}))); 128 auto add2 = builder.AddInstruction( 129 HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div)); 130 131 auto computation = module().AddEntryComputation(builder.Build()); 132 computation->CreateFusionInstruction( 133 {add2, constant, div, mul, sub, negate, add}, 134 HloInstruction::FusionKind::kLoop); 135 136 EXPECT_THAT(computation->root_instruction(), op::Fusion()); 137 138 EXPECT_EQ(1, FusionCount()); 139 EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); 140 EXPECT_EQ(0, FusionCount()); 141 142 EXPECT_THAT(computation->root_instruction(), 143 op::Add(op::Constant(), op::Divide())); 144 } 145 146 TEST_F(DefuserTest, MultipleFusionInstructions) { 147 auto builder = HloComputation::Builder(TestName()); 148 auto param0 = 149 builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); 150 auto param1 = 151 builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1")); 152 auto param3 = 153 builder.AddInstruction(HloInstruction::CreateParameter(2, shape_, "p2")); 154 auto add = builder.AddInstruction( 155 HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1)); 156 auto negate = builder.AddInstruction( 157 HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add)); 158 auto sub = builder.AddInstruction( 159 HloInstruction::CreateBinary(shape_, HloOpcode::kSubtract, add, negate)); 160 auto mul = builder.AddInstruction( 161 HloInstruction::CreateBinary(shape_, HloOpcode::kMultiply, sub, param3)); 162 auto div = builder.AddInstruction( 163 HloInstruction::CreateBinary(shape_, HloOpcode::kDivide, mul, param3)); 164 auto constant = builder.AddInstruction(HloInstruction::CreateConstant( 165 Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}))); 166 auto add2 = builder.AddInstruction( 167 HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div)); 168 169 auto computation = module().AddEntryComputation(builder.Build()); 170 computation->CreateFusionInstruction({add2, constant, div, mul}, 171 HloInstruction::FusionKind::kLoop); 172 computation->CreateFusionInstruction({sub, negate, add}, 173 HloInstruction::FusionKind::kLoop); 174 175 EXPECT_THAT(computation->root_instruction(), op::Fusion()); 176 177 EXPECT_EQ(2, FusionCount()); 178 EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); 179 EXPECT_EQ(0, FusionCount()); 180 181 EXPECT_THAT(computation->root_instruction(), 182 op::Add(op::Constant(), op::Divide())); 183 } 184 185 TEST_F(DefuserTest, NestedFusionInstructions) { 186 auto builder = HloComputation::Builder(TestName()); 187 auto param0 = 188 builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); 189 auto param1 = 190 builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1")); 191 auto add = builder.AddInstruction( 192 HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1)); 193 auto negate = builder.AddInstruction( 194 HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add)); 195 196 auto computation = module().AddEntryComputation(builder.Build()); 197 auto outer_fusion = computation->CreateFusionInstruction( 198 {negate, add}, HloInstruction::FusionKind::kLoop); 199 HloInstruction* fused_negate = outer_fusion->fused_expression_root(); 200 ASSERT_EQ(fused_negate->opcode(), HloOpcode::kNegate); 201 outer_fusion->fused_instructions_computation()->CreateFusionInstruction( 202 {fused_negate}, HloInstruction::FusionKind::kLoop); 203 204 EXPECT_THAT(computation->root_instruction(), op::Fusion()); 205 206 EXPECT_EQ(2, FusionCount()); 207 EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); 208 EXPECT_EQ(0, FusionCount()); 209 210 EXPECT_THAT(computation->root_instruction(), op::Negate(op::Add())); 211 } 212 213 } // namespace 214 } // namespace xla 215